Commit b178e6fc authored by Lawrence McAfee

error fixes & tested.

parent 977efdfb
@@ -194,47 +194,57 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler):
     model_checkpoint_name, optim_checkpoint_name = \
         get_checkpoint_names(args.save, iteration, args.use_distributed_optimizer)
 
-    # Save args, model, RNG.
+    # Collect args, model, RNG.
+    model_state_dict = {}
     if not torch.distributed.is_initialized() \
        or mpu.get_data_parallel_rank() == 0:
 
         # Arguments, iteration, and model.
-        state_dict = {}
-        state_dict['args'] = args
-        state_dict['checkpoint_version'] = 3.0
-        state_dict['iteration'] = iteration
+        model_state_dict['args'] = args
+        model_state_dict['checkpoint_version'] = 3.0
+        model_state_dict['iteration'] = iteration
         if len(model) == 1:
-            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
+            model_state_dict['model'] = model[0].state_dict_for_save_checkpoint()
         else:
             for i in range(len(model)):
                 mpu.set_virtual_pipeline_model_parallel_rank(i)
-                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()
+                model_state_dict['model%d' % i] = \
+                    model[i].state_dict_for_save_checkpoint()
 
         # RNG states.
         if not args.no_save_rng:
-            state_dict["rng_state"] = rng_state
+            model_state_dict["rng_state"] = rng_state
 
-        # Save.
-        ensure_directory_exists(model_checkpoint_name)
-        torch.save(state_dict, model_checkpoint_name)
-
-    # Save optimizer state. (Optimizer is saved separately from the model, due
+    # Collect optimizer state. (Optimizer is saved separately from the model, due
     # to the conflicting data pattern when using the distributed optimizer.)
+    optim_state_dict = {}
     if not args.no_save_optim \
        and (not torch.distributed.is_initialized()
             or mpu.get_data_parallel_rank() == 0
             or args.use_distributed_optimizer):
 
         # Optimizer stuff.
-        state_dict = {}
         if optimizer is not None:
-            state_dict['optimizer'] = optimizer.state_dict()
+            optim_state_dict['optimizer'] = optimizer.state_dict()
         if opt_param_scheduler is not None:
-            state_dict['opt_param_scheduler'] = opt_param_scheduler.state_dict()
+            optim_state_dict['opt_param_scheduler'] = \
+                opt_param_scheduler.state_dict()
 
-        # Save.
-        ensure_directory_exists(optim_checkpoint_name)
-        torch.save(state_dict, optim_checkpoint_name)
+    # Save.
+    if args.use_distributed_optimizer:
+        # Save model separate from optimizer.
+        if model_state_dict:
+            ensure_directory_exists(model_checkpoint_name)
+            torch.save(model_state_dict, model_checkpoint_name)
+        if optim_state_dict:
+            ensure_directory_exists(optim_checkpoint_name)
+            torch.save(optim_state_dict, optim_checkpoint_name)
+    else:
+        # Save model and optimizer together.
+        state_dict = {**model_state_dict, **optim_state_dict}
+        if state_dict:  # only saves if populated (i.e., inherits conditions above)
+            ensure_directory_exists(model_checkpoint_name)
+            torch.save(state_dict, model_checkpoint_name)
 
     # Wait so everyone is done (necessary)
     if torch.distributed.is_initialized():
...
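For reference, the branching introduced by this commit can be read in isolation: when the distributed optimizer is enabled, model and optimizer state go to two separate files; otherwise the two collected dicts are merged and written as a single checkpoint, as before. Below is a minimal standalone sketch of that save logic under stated assumptions; the function name save_split_or_merged, the toy state dicts, and the checkpoint paths are hypothetical stand-ins, and os.makedirs is used in place of Megatron-LM's ensure_directory_exists / get_checkpoint_names helpers.

    # Minimal sketch (hypothetical names/paths), not the actual Megatron-LM code.
    import os
    import torch

    def save_split_or_merged(model_state_dict, optim_state_dict,
                             model_ckpt_path, optim_ckpt_path,
                             use_distributed_optimizer):
        """Mirror the branching above: two files when the distributed
        optimizer is used, one merged file otherwise."""
        if use_distributed_optimizer:
            # Model and optimizer are written separately; each file is
            # written only if this rank actually collected something.
            if model_state_dict:
                os.makedirs(os.path.dirname(model_ckpt_path), exist_ok=True)
                torch.save(model_state_dict, model_ckpt_path)
            if optim_state_dict:
                os.makedirs(os.path.dirname(optim_ckpt_path), exist_ok=True)
                torch.save(optim_state_dict, optim_ckpt_path)
        else:
            # Single file: merge the two dicts (their keys do not overlap).
            state_dict = {**model_state_dict, **optim_state_dict}
            if state_dict:
                os.makedirs(os.path.dirname(model_ckpt_path), exist_ok=True)
                torch.save(state_dict, model_ckpt_path)

    # Example: a rank that collected both pieces, non-distributed optimizer,
    # so everything lands in one file (paths are made up for illustration).
    save_split_or_merged(
        {'iteration': 100, 'model': {'w': torch.zeros(2)}},
        {'optimizer': {'step': 100}},
        'ckpt/iter_0000100/model_rank_00.pt',
        'ckpt/iter_0000100/optim_rank_00.pt',
        use_distributed_optimizer=False)

Merging with {**model_state_dict, **optim_state_dict} works because the model-side keys (args, checkpoint_version, iteration, model*, rng_state) and the optimizer-side keys (optimizer, opt_param_scheduler) are disjoint, so the non-distributed-optimizer path keeps producing the same single-file layout as before this change.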