Commit af2b136f authored by Lawrence McAfee

optimizer saves list(group), not list(param).

parent 37ca7859
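
In short: the DistributedOptimizer checkpoint used to store one flat list of parameters, which loses the group boundaries; with this commit it stores one list of parameters per group, so loading can zip the saved groups against the current optimizer.param_groups entry by entry. A rough sketch of the two layouts (toy model and optimizer for illustration only, not the Megatron-LM code):

import torch

# Toy two-group optimizer; the real code operates on the DistributedOptimizer's
# inner optimizer.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD([
    {"params": [model.weight], "lr": 0.1},
    {"params": [model.bias], "lr": 0.01},
])

# Before this commit: a single flat list, group boundaries lost.
flat_params = [p for g in optimizer.param_groups for p in g["params"]]

# After this commit: one list of params per group, so the loader can pair the
# saved groups with the current param groups positionally.
grouped_params = [g["params"] for g in optimizer.param_groups]

assert len(flat_params) == 2 and len(grouped_params) == 2
assert grouped_params[0][0] is model.weight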
@@ -402,17 +402,17 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
             sys.exit()
 
     # set checkpoint version
-    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+    set_checkpoint_version(model_state_dict.get('checkpoint_version', 0))
 
     # Set iteration.
     if args.finetune or release:
         iteration = 0
     else:
         try:
-            iteration = state_dict['iteration']
+            iteration = model_state_dict['iteration']
         except KeyError:
             try: # Backward compatible with older checkpoints
-                iteration = state_dict['total_iters']
+                iteration = model_state_dict['total_iters']
             except KeyError:
                 print_rank_0('A metadata file exists but unable to load '
                              'iteration from checkpoint {}, exiting'.format(
@@ -422,8 +422,8 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     # Check arguments.
     assert args.consumed_train_samples == 0
     assert args.consumed_valid_samples == 0
-    if 'args' in state_dict:
-        checkpoint_args = state_dict['args']
+    if 'args' in model_state_dict:
+        checkpoint_args = model_state_dict['args']
         check_checkpoint_args(checkpoint_args)
         args.consumed_train_samples = getattr(checkpoint_args,
                                               'consumed_train_samples', 0)
@@ -435,11 +435,11 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
 
     # Model.
     if len(model) == 1:
-        model[0].load_state_dict(state_dict['model'], strict=strict)
+        model[0].load_state_dict(model_state_dict['model'], strict=strict)
     else:
         for i in range(len(model)):
             mpu.set_virtual_pipeline_model_parallel_rank(i)
-            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
+            model[i].load_state_dict(model_state_dict['model%d' % i], strict=strict)
 
     # Fix up query/key/value matrix ordering if needed
     checkpoint_version = get_checkpoint_version()
@@ -450,12 +450,12 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
     if not release and not args.finetune and not args.no_load_optim:
         try:
             if optimizer is not None:
-                optimizer.load_state_dict(state_dict['optimizer'])
+                optimizer.load_state_dict(optim_state_dict['optimizer'])
             if opt_param_scheduler is not None:
                 if 'lr_scheduler' in state_dict: # backward compatbility
-                    opt_param_scheduler.load_state_dict(state_dict['lr_scheduler'])
+                    opt_param_scheduler.load_state_dict(optim_state_dict['lr_scheduler'])
                 else:
-                    opt_param_scheduler.load_state_dict(state_dict['opt_param_scheduler'])
+                    opt_param_scheduler.load_state_dict(optim_state_dict['opt_param_scheduler'])
         except KeyError:
             print_rank_0('Unable to load optimizer from checkpoint {}. '
                          'Specify --no-load-optim or --finetune to prevent '
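The remaining hunks are in the DistributedOptimizer class. On the loader side above, the net effect of the renames is that model keys ('model', 'iteration', 'args', ...) are read from model_state_dict while optimizer and scheduler keys come from optim_state_dict. A minimal sketch of that split, assuming both dicts have already been loaded; apply_checkpoint and its signature are hypothetical, not Megatron-LM API:

def apply_checkpoint(model, optimizer, opt_param_scheduler,
                     model_state_dict, optim_state_dict, strict=True):
    # Model weights come from the model-side dict.
    model.load_state_dict(model_state_dict['model'], strict=strict)
    # Optimizer and LR-scheduler state come from the optimizer-side dict.
    if optimizer is not None and 'optimizer' in optim_state_dict:
        optimizer.load_state_dict(optim_state_dict['optimizer'])
    if opt_param_scheduler is not None:
        # Older checkpoints used the key 'lr_scheduler'.
        key = ('lr_scheduler' if 'lr_scheduler' in optim_state_dict
               else 'opt_param_scheduler')
        if key in optim_state_dict:
            opt_param_scheduler.load_state_dict(optim_state_dict[key])
    return model_state_dict.get('iteration', 0)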
@@ -306,8 +306,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         state_dict['optimizer'] = self.optimizer.state_dict()
         if self.grad_scaler:
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        state_dict['params'] = \
-            [ p for g in self.optimizer.param_groups for p in g["params"] ]
+        # state_dict['params'] = \
+        #     [ p for g in self.optimizer.param_groups for p in g["params"] ]
+        state_dict['groups'] = [g["params"] for g in self.optimizer.param_groups]
         # pax(0, { # ... only called on model rank 0
         #     # "optimizer" : self.optimizer,
         #     "state_dict" : state_dict,
@@ -329,10 +330,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                          'an old checkpoint ...')
         self.optimizer.load_state_dict(state_dict[optimizer_key])
-        pax(0, {
-            "state_dict" : state_dict,
-            "params" : state_dict["params"],
-        })
+        # pax(0, {
+        #     "state_dict" : state_dict,
+        #     "params" : state_dict["params"],
+        # })
 
         # Grad scaler.
         if 'grad_scaler' not in state_dict:
             print_rank_0('***WARNING*** found an old checkpoint, will not '
@@ -346,11 +347,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                              'Skipping loading grad scaler ...')
 
         # Copy data for the main params.
+        current_groups = [ g["params"] for g in self.optimizer.param_groups ]
+        params_key = 'params'
+        assert params_key in state_dict, "key 'params' not in state_dict."
+        # pax(0, {
+        #     "state_dict" : state_dict,
+        #     "current_groups" : current_groups,
+        #     "saved_groups" : state_dict[params_key],
+        # })
         for current_group, saved_group in zip(
-                self.fp32_from_float16_groups,
-                state_dict[fp32_from_float16_params_key]):
+                current_groups,
+                state_dict[params_key]):
+            pax(0, {
+                "current_group" : current_group,
+                "saved_group" : saved_group,
+            })
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
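For reference, the group-wise copy at the end of load_state_dict reduces to the pattern below; this is a standalone sketch with plain tensors standing in for the optimizer's main param groups:

import torch

# Two "current" groups (as built from optimizer.param_groups) and the matching
# saved groups from the checkpoint.
current_groups = [[torch.zeros(3)], [torch.zeros(2), torch.zeros(2)]]
saved_groups = [[torch.ones(3)], [torch.full((2,), 2.0), torch.full((2,), 3.0)]]

for current_group, saved_group in zip(current_groups, saved_groups):
    for current_param, saved_param in zip(current_group, saved_group):
        # In-place copy keeps the existing tensors (and any views onto them)
        # valid while adopting the checkpointed values.
        current_param.data.copy_(saved_param.data)

assert torch.equal(current_groups[0][0], torch.ones(3))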