"tests/vscode:/vscode.git/clone" did not exist on "186ef59283891fdaaf377c97da2f1872afbf5b7f"
Commit 1c69da9c authored by Jiezhong Qiu

avoid saving duplicated optim state in dp rank > 0

parent a54546a4
@@ -232,7 +232,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
             ),
         'model_optim_rng.pt')
 
 
-def save_checkpoint(iteration, model, optimizer, lr_scheduler):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
     """Save a model checkpoint with expert parallel """
     # TODO: update patch
     from megatron import get_args
@@ -240,6 +240,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()
 
     # Only rank zero of the data parallel writes to the disk.
     if isinstance(model, DistributedDataParallel):
         model = model.module
@@ -256,7 +257,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
         state_dict['model'] = model.state_dict_for_save_checkpoint(
             keep_vars=(mpu.get_data_parallel_rank() > 0))
 
-        if mpu.get_data_parallel_rank() != 0:
+        if mpu.get_data_parallel_rank() > 0:
             def extract_expert_param(state_dict, expert_dp_comm='none'):
                 state_dict_new = state_dict.__class__()
@@ -272,12 +273,24 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
             state_dict['model'] = extract_expert_param(
                 state_dict['model'],
-                expert_dp_comm='none')
+                expert_dp_comm)
 
         # Optimizer stuff.
         if not args.no_save_optim:
             if optimizer is not None:
                 state_dict['optimizer'] = optimizer.state_dict()
+                if mpu.get_data_parallel_rank() > 0:
+                    index = 0
+                    for param_group in optimizer.optimizer.param_groups:
+                        for param in param_group['params']:
+                            if not (hasattr(param, 'dp_comm') and \
+                                    param.dp_comm == expert_dp_comm):
+                                # this parameter is not an expert parameter
+                                # thus there is no need to save its state in current rank
+                                # since it has been saved by data parallel rank 0
+                                state_dict['optimizer']['state'].pop(index)
+                            index += 1
             if lr_scheduler is not None:
                 state_dict['lr_scheduler'] = lr_scheduler.state_dict()
...
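The added block only runs on data-parallel ranks other than 0: those ranks still call optimizer.state_dict(), but then drop the per-parameter state of every parameter that is not an expert parameter (i.e. not tagged with dp_comm == expert_dp_comm), since data-parallel rank 0 already writes that state to disk. Below is a minimal, self-contained sketch of the same idea, not the Megatron/FastMoE code itself: it assumes a plain torch.optim.Adam rather than Megatron's wrapped optimizer (hence optimizer.param_groups instead of optimizer.optimizer.param_groups), tags the "expert" parameters by hand with a dp_comm attribute, and simulates the data-parallel rank with a local variable.

import torch

expert_dp_comm = 'none'

# Toy model: one "expert" layer (saved by every expert-owning rank) and one
# ordinary layer (replicated across data-parallel ranks, saved only by rank 0).
expert = torch.nn.Linear(4, 4)
shared = torch.nn.Linear(4, 4)
for p in expert.parameters():
    p.dp_comm = expert_dp_comm          # hand-tagged expert parameter (assumption)

params = list(expert.parameters()) + list(shared.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

# Run one step so the optimizer actually has per-parameter state to save.
loss = shared(expert(torch.randn(2, 4))).sum()
loss.backward()
optimizer.step()

state_dict = {'optimizer': optimizer.state_dict()}

# On a data-parallel rank > 0, drop the state of every non-expert parameter:
# rank 0 already saves it, so keeping it here would duplicate it on disk.
simulated_dp_rank = 1                   # stand-in for mpu.get_data_parallel_rank()
if simulated_dp_rank > 0:
    index = 0
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            if not (hasattr(param, 'dp_comm') and
                    param.dp_comm == expert_dp_comm):
                state_dict['optimizer']['state'].pop(index)
            index += 1

# Only the expert weight and bias keep optimizer state on this rank.
print(sorted(state_dict['optimizer']['state'].keys()))   # -> [0, 1]

Popping by integer index works because a PyTorch optimizer state dict keys its 'state' entries by each parameter's position across the param groups, which is exactly the order the nested loop walks.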