Commit 74ba3797 authored by Lawrence McAfee
Browse files

Fixed load_checkpoint(): only load the separate optimizer checkpoint file when the distributed optimizer is in use; otherwise reuse the model checkpoint's state dict.

parent b178e6fc
......@@ -367,7 +367,10 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# Load the checkpoint.
try:
model_state_dict = torch.load(model_checkpoint_name, map_location='cpu')
optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
if args.use_distributed_optimizer:
optim_state_dict = torch.load(optim_checkpoint_name, map_location='cpu')
else:
optim_state_dict = model_state_dict
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment