Commit f224b69b authored by Jiezhong Qiu

remove unnecessary states when dp rank > 0

parent 87c1e6bb
@@ -262,9 +262,6 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
     # Arguments, iteration, and model.
     state_dict = {}
-    state_dict['args'] = args
-    state_dict['checkpoint_version'] = 3.0
-    state_dict['iteration'] = iteration
     state_dict['model'] = model.state_dict_for_save_checkpoint(
         keep_vars=(mpu.get_data_parallel_rank() > 0))
@@ -298,18 +295,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
                         # since it has been saved by data parallel rank 0
                         state_dict['optimizer']['state'].pop(index)
                     index += 1
+            state_dict['optimizer'].pop('param_groups')
-        if lr_scheduler is not None:
-            state_dict['lr_scheduler'] = lr_scheduler.state_dict()
-
-    # RNG states.
-    if not args.no_save_rng:
-        state_dict['random_rng_state'] = random.getstate()
-        state_dict['np_rng_state'] = np.random.get_state()
-        state_dict['torch_rng_state'] = torch.get_rng_state()
-        state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
-        state_dict['rng_tracker_states'] \
-            = mpu.get_cuda_rng_tracker().get_states()
 
     # Save.
     checkpoint_name = get_fmoe_checkpoint_name(args.save, iteration)
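For context: with FastMoE's expert parallelism, everything other than the expert parameters (args, iteration counter, lr scheduler state, RNG states, optimizer param_groups) is replicated across data-parallel ranks, so it only needs to live in the data parallel rank 0 checkpoint; ranks > 0 now save just the states of their own expert parameters. Below is a minimal sketch of that filtering in plain PyTorch. It is not the fmoe.megatron code itself, and the names is_expert_param and build_rank_state_dict are hypothetical helpers introduced here for illustration.

import torch

def is_expert_param(param, expert_dp_comm='none'):
    # FastMoE tags expert parameters with a `dp_comm` attribute; parameters
    # without the tag are replicated across data-parallel ranks.
    return getattr(param, 'dp_comm', None) == expert_dp_comm

def build_rank_state_dict(iteration, model, optimizer, dp_rank,
                          expert_dp_comm='none'):
    # Hypothetical helper mirroring the behaviour after this commit.
    state_dict = {'model': model.state_dict()}
    if dp_rank == 0:
        # Only rank 0 records the shared, replicated bookkeeping
        # (iteration, RNG states, and so on, as in Megatron).
        state_dict['iteration'] = iteration
        state_dict['torch_rng_state'] = torch.get_rng_state()
    opt_state = optimizer.state_dict()
    if dp_rank > 0:
        # Drop per-parameter optimizer state for non-expert parameters;
        # rank 0 has already saved those. PyTorch keys `state` by the
        # parameter's flattened index across param_groups.
        index = 0
        for group in optimizer.param_groups:
            for param in group['params']:
                if not is_expert_param(param, expert_dp_comm):
                    opt_state['state'].pop(index, None)
                index += 1
        # param_groups are identical on every rank, so rank 0's copy suffices.
        opt_state.pop('param_groups')
    state_dict['optimizer'] = opt_state
    return state_dict

At load time, the symmetric merge would read the shared entries from the rank 0 checkpoint and overlay each rank's own expert states on top.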