Commit 12b23ae3 authored by Jiezhong Qiu

remove unnecessary states when dp rank>0

parent bc655118
@@ -420,9 +420,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     # Arguments, iteration, and model.
     state_dict = {}
-    state_dict['args'] = args
-    state_dict['checkpoint_version'] = 3.0
-    state_dict['iteration'] = iteration
     state_dict['model'] = model.state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0))
@@ -456,18 +453,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
                         # since it has been saved by data parallel rank 0
                         state_dict['optimizer']['state'].pop(index)
                     index += 1
-    if lr_scheduler is not None:
-        state_dict['lr_scheduler'] = lr_scheduler.state_dict()
-    # RNG states.
-    if not args.no_save_rng:
-        state_dict['random_rng_state'] = random.getstate()
-        state_dict['np_rng_state'] = np.random.get_state()
-        state_dict['torch_rng_state'] = torch.get_rng_state()
-        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
-        state_dict['rng_tracker_states'] \
-            = mpu.get_cuda_rng_tracker().get_states()
+            state_dict['optimizer'].pop('param_groups')
     # Save.
     checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
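
Effect of the diff: data parallel rank 0 still writes the complete checkpoint, while every other rank now saves only the expert-specific states that live on it. Everything rank 0 already covers (args, checkpoint_version, iteration, lr_scheduler, RNG states, and the optimizer's param_groups) is dropped from the non-zero ranks' state dicts. A minimal sketch of that idea, using hypothetical names (build_rank_state_dict, full_state, expert_state) rather than FastMoE's actual helpers:

# Sketch only: illustrates the rationale of this commit, not FastMoE's code.
def build_rank_state_dict(dp_rank, full_state, expert_state):
    """Return the state dict a given data-parallel rank should save."""
    if dp_rank == 0:
        # Rank 0 keeps the complete checkpoint, as before this commit.
        return full_state
    # Ranks > 0 keep only what rank 0 cannot save for them: the expert
    # parameters local to this rank (plus their optimizer state in the
    # real code).
    return {'model': expert_state}

full_state = {
    'args': '<training args>',
    'iteration': 1000,
    'model': {'experts': ['w0', 'w1']},
    'random_rng_state': '<rng state>',
    'lr_scheduler': '<scheduler state>',
}
print(build_rank_state_dict(0, full_state, full_state['model']))  # everything
print(build_rank_state_dict(1, full_state, full_state['model']))  # experts only

This keeps the per-rank FastMoE checkpoint files small and avoids writing the same replicated state once per data-parallel rank; only rank 0's file is needed to restore the shared training state.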