Commit f5a5d31a authored by Jiezhong Qiu
Browse files

avoid saving duplicated optim state in dp rank>0

parent c844413b
......@@ -390,7 +390,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
),
'model_optim_rng.pt')
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
"""Save a model checkpoint with expert parallel """
# TODO: update patch
from megatron import get_args
......@@ -398,6 +398,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
args = get_args()
# Only rank zero of the data parallel writes to the disk.
if isinstance(model, DistributedDataParallel):
model = model.module
......@@ -414,7 +415,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict['model'] = model.state_dict_for_save_checkpoint(
keep_vars=(mpu.get_data_parallel_rank() > 0))
if mpu.get_data_parallel_rank() != 0:
if mpu.get_data_parallel_rank() > 0:
def extract_expert_param(state_dict, expert_dp_comm='none'):
state_dict_new = state_dict.__class__()
......@@ -430,12 +431,24 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict['model'] = extract_expert_param(
state_dict['model'],
expert_dp_comm='none')
expert_dp_comm)
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if mpu.get_data_parallel_rank() > 0:
index = 0
for param_group in optimizer.optimizer.param_groups:
for param in param_group['params']:
if not (hasattr(param, 'dp_comm') and \
param.dp_comm == expert_dp_comm):
# this parameter is not an expert parameter
# thus there is no need to save its state in current rank
# since it has been saved by data parallel rank 0
state_dict['optimizer']['state'].pop(index)
index += 1
if lr_scheduler is not None:
state_dict['lr_scheduler'] = lr_scheduler.state_dict()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment