Commit af2b136f authored by Lawrence McAfee

optimizer saves list(group), not list(param).

parent 37ca7859
@@ -402,17 +402,17 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
         sys.exit()

     # set checkpoint version
-    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+    set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))

     # Set iteration.
     if args.finetune or release:
         iteration = 0
     else:
         try:
-            iteration = state_dict['iteration']
+            iteration = model_state_dict['iteration']
         except KeyError:
             try:  # Backward compatible with older checkpoints
-                iteration = state_dict['total_iters']
+                iteration = model_state_dict['total_iters']
             except KeyError:
                 print_rank_0('A metadata file exists but unable to load '
                              'iteration from checkpoint {}, exiting'.format(
@@ -422,8 +422,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # Check arguments.
     assert args.consumed_train_samples == 0
     assert args.consumed_valid_samples == 0
-    if 'args' in state_dict:
-        checkpoint_args = state_dict['args']
+    if 'args' in model_state_dict:
+        checkpoint_args = model_state_dict['args']
         check_checkpoint_args(checkpoint_args)
         args.consumed_train_samples = getattr(checkpoint_args,
                                               'consumed_train_samples', 0)
@@ -435,11 +435,11 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri

     # Model.
     if len(model) == 1:
-        model[0].load_state_dict(state_dict['model'], strict=strict)
+        model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
-            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
+            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)

     # Fix up query/key/value matrix ordering if needed
     checkpoint_version = get_checkpoint_version()
@@ -450,12 +450,12 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
-                optimizer.load_state_dict(state_dict['optimizer'])
+                optimizer.load_state_dict(optim_state_dict['optimizer'])
             if opt_param_scheduler is not None:
                 if 'lr_scheduler' in state_dict:  # backward compatbility
-                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
+                    opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                 else:
-                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
+                    opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
@@ -306,8 +306,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         state_dict['optimizer'] = self.optimizer.state_dict()
         if self.grad_scaler:
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['params'] = \
-            [ p for g in self.optimizer.param_groups for p in g["params"] ]
+        # state_dict['params'] = \
+        #     [ p for g in self.optimizer.param_groups for p in g["params"] ]
+        state_dict['groups'] = [g["params"] for g in self.optimizer.param_groups]
         # pax(0, { # ... only called on model rank 0
         #     # "optimizer" : self.optimizer,
         #     "state_dict" : state_dict,
@@ -329,10 +330,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 'an old checkpoint ...')
         self.optimizer.load_state_dict(state_dict[optimizer_key])
-        pax(0, {
-            "state_dict" : state_dict,
-            "params" : state_dict["params"],
-        })
+        # pax(0, {
+        #     "state_dict" : state_dict,
+        #     "params" : state_dict["params"],
+        # })

         # Grad scaler.
         if 'grad_scaler' not in state_dict:
             print_rank_0('***WARNING*** found an old checkpoint, will not '
@@ -346,11 +347,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     'Skipping loading grad scaler ...')

         # Copy data for the main params.
+        current_groups = [ g["params"] for g in self.optimizer.param_groups ]
         params_key = 'params'
         assert params_key in state_dict, "key 'params' not in state_dict."

+        # pax(0, {
+        #     "state_dict" : state_dict,
+        #     "current_groups" : current_groups,
+        #     "saved_groups" : state_dict[params_key],
+        # })
         for current_group, saved_group in zip(
-                self.fp32_from_float16_groups,
-                state_dict[fp32_from_float16_params_key]):
+                current_groups,
+                state_dict[params_key]):
+            pax(0, {
+                "current_group" : current_group,
+                "saved_group" : saved_group,
+            })
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
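For reference, below is a minimal, self-contained sketch of the idea behind this commit, not the Megatron-LM code itself; the toy optimizer setup and the helper names (save_param_state, load_param_state) are hypothetical. It shows why saving one list of params per param group, rather than one flat list of params, preserves the group boundaries and lets the load path zip the saved groups directly against the current optimizer.param_groups.

import torch

def save_param_state(optimizer):
    # New scheme: one list of params per param group, so group boundaries survive.
    return [g["params"] for g in optimizer.param_groups]

def load_param_state(optimizer, saved_groups):
    # Restore group by group; zip() lines the saved groups up with the current ones.
    current_groups = [g["params"] for g in optimizer.param_groups]
    for current_group, saved_group in zip(current_groups, saved_groups):
        for current_param, saved_param in zip(current_group, saved_group):
            current_param.data.copy_(saved_param.data)

if __name__ == "__main__":
    decay = torch.nn.Parameter(torch.randn(4))
    no_decay = torch.nn.Parameter(torch.zeros(4))
    opt = torch.optim.SGD(
        [{"params": [decay], "weight_decay": 0.1},
         {"params": [no_decay], "weight_decay": 0.0}],
        lr=0.01)

    saved = save_param_state(opt)   # [[decay], [no_decay]] -- two groups preserved
    # A flat list would have been [decay, no_decay], with no record of which
    # group each param came from.
    load_param_state(opt, saved)    # copies saved values back, group by group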