Commit ee38e7f9 authored by Mohammad Shoeybi

Fixed a deserialization issue with old checkpoints

parent 9993ea25
...
@@ -338,7 +338,19 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
             torch.distributed.get_rank(), checkpoint_name))

     # Load the checkpoint.
-    sd = torch.load(checkpoint_name, map_location='cpu')
+    try:
+        sd = torch.load(checkpoint_name, map_location='cpu')
+    except ModuleNotFoundError:
+        # For backward compatibility.
+        print_rank_0(' > deserializing using the old code structure ...')
+        import sys
+        sys.modules['fp16.loss_scaler'] = sys.modules[
+            'megatron.fp16.loss_scaler']
+        sd = torch.load(checkpoint_name, map_location='cpu')
+        sys.modules.pop('fp16.loss_scaler', None)
+    except:
+        print_rank_0('could not load the checkpoint')
+        exit()

     # Iterations.
     if args.finetune or release:
...
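The failure this patch works around: the fp16 package was moved under the megatron package, so checkpoints written before the move contain pickled objects whose recorded import path is the old 'fp16.loss_scaler'. When torch.load unpickles them, the import of that stale path raises ModuleNotFoundError. Because Python's import machinery consults sys.modules before importing, aliasing the old path to the already-importable new module lets the stale references resolve. Below is a minimal, self-contained sketch of the same trick; the helper name load_with_module_alias is hypothetical and not part of Megatron.

import importlib
import sys

import torch


def load_with_module_alias(path, old_name, new_name):
    """Load a checkpoint whose pickled objects reference a module that
    has since moved from `old_name` to `new_name`."""
    try:
        return torch.load(path, map_location='cpu')
    except ModuleNotFoundError:
        # The unpickler imports the module path stored in the checkpoint.
        # Registering the old path in sys.modules makes that import
        # succeed by returning the new module instead.
        sys.modules[old_name] = importlib.import_module(new_name)
        try:
            return torch.load(path, map_location='cpu')
        finally:
            # Drop the alias so it cannot mask a genuinely missing
            # module elsewhere in the program.
            sys.modules.pop(old_name, None)


# Usage mirroring the patch above:
# sd = load_with_module_alias(checkpoint_name,
#                             'fp16.loss_scaler',
#                             'megatron.fp16.loss_scaler')

Popping the alias after the load, as the patch also does, keeps the workaround scoped to deserialization rather than leaving a permanent fake entry in sys.modules.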