Commit 90e394bc authored by Jiezhong Qiu

load/save fp32 main weight when fp16 training

parent bcfeaf3b
@@ -443,7 +443,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     if not args.no_save_optim:
         if optimizer is not None:
             state_dict['optimizer'] = optimizer.state_dict()
-            index = 0
+            param_global_idx = 0
             for param_group in optimizer.optimizer.param_groups:
                 for param in param_group['params']:
                     if not (hasattr(param, 'dp_comm') and \
@@ -453,12 +453,31 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
                         # since it has been saved by data parallel rank 0
                         if args.fp16:
                             # fp16 optimizer may have empty state due to overflow
-                            state_dict['optimizer']['optimizer']['state'].pop(index, None)
+                            state_dict['optimizer']['optimizer']['state'].pop(
+                                param_global_idx, None)
                         else:
-                            state_dict['optimizer']['state'].pop(index)
-                        index += 1
+                            state_dict['optimizer']['state'].pop(
+                                param_global_idx)
+                        param_global_idx += 1
             if args.fp16:
                 state_dict['optimizer']['optimizer'].pop('param_groups')
+                # fp32_from_fp16_params in state_dict is not a copy
+                # but a reference to optimizer.fp32_from_fp16_params,
+                # changing it in state_dict will change
+                # optimizer.fp32_from_fp16_params as well
+                # thus we create an empty fp32_from_fp16_params in state_dict
+                # and only insert expert parameters.
+                fp32_from_fp16_params = \
+                    state_dict['optimizer']['fp32_from_fp16_params']
+                state_dict['optimizer']['fp32_from_fp16_params'] = []
+                for param_group in fp32_from_fp16_params:
+                    param_group_copy = []
+                    for param in param_group:
+                        param_copy = param if hasattr(param, 'dp_comm') \
+                            and param.dp_comm == expert_dp_comm else None
+                        param_group_copy.append(param_copy)
+                    state_dict['optimizer']['fp32_from_fp16_params'].append(
+                        param_group_copy)
             else:
                 state_dict['optimizer'].pop('param_groups')
@@ -512,6 +531,14 @@ def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
                  for kk, vv in optimizer_rank0['state'][k].items()}
         print_rank_last("[merge optimizer] copy {}, \
             before.sum={}, after.sum={}".format(k, str(before), str(after)))
+    if fp16:
+        for group_idx, param_group in enumerate(state_dict_local['optimizer']['fp32_from_fp16_params']):
+            for param_in_group_idx, param in enumerate(param_group):
+                if param is not None:
+                    state_dict_rank0['optimizer']['fp32_from_fp16_params'][group_idx][param_in_group_idx] = param
+                    print_rank_last("[merge fp32_from_fp16_params] copy parameter ({:d}, {:d})".format(group_idx, param_in_group_idx))
     return state_dict_rank0


 def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
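
To make the aliasing comment in the save path concrete, here is a minimal, self-contained sketch of the idea. It is an illustration only: `ToyFP16Optimizer`, `keep_expert_fp32_params`, the toy parameters, and the `'none'` tag value are assumptions made for the example, not the project's API. The point it demonstrates is that `state_dict()` hands back a reference to the optimizer's live `fp32_from_fp16_params`, so the save path must build a fresh list (with `None` in the non-expert slots) instead of clearing entries in place.

```python
import torch


class ToyFP16Optimizer:
    """Hypothetical stand-in for an fp16 optimizer: state_dict() returns the
    *same* nested list of fp32 master weights that training keeps updating."""
    def __init__(self, fp16_params):
        self.fp32_from_fp16_params = [
            [torch.nn.Parameter(p.detach().float()) for p in fp16_params]
        ]

    def state_dict(self):
        # Note: this is a reference, not a copy.
        return {'fp32_from_fp16_params': self.fp32_from_fp16_params}


def keep_expert_fp32_params(state_dict, expert_dp_comm='none'):
    """Mirror the save-side logic: rebuild the list so only expert master
    weights go into the checkpoint, leaving the optimizer's copy untouched."""
    fp32_from_fp16_params = state_dict['fp32_from_fp16_params']
    state_dict['fp32_from_fp16_params'] = []
    for param_group in fp32_from_fp16_params:
        param_group_copy = [
            p if getattr(p, 'dp_comm', None) == expert_dp_comm else None
            for p in param_group
        ]
        state_dict['fp32_from_fp16_params'].append(param_group_copy)
    return state_dict


expert = torch.nn.Parameter(torch.zeros(2, dtype=torch.half))
dense = torch.nn.Parameter(torch.zeros(2, dtype=torch.half))
opt = ToyFP16Optimizer([expert, dense])
opt.fp32_from_fp16_params[0][0].dp_comm = 'none'  # tag the expert master weight

sd = keep_expert_fp32_params(opt.state_dict(), expert_dp_comm='none')
assert sd['fp32_from_fp16_params'][0][1] is None           # dense weight dropped from checkpoint
assert opt.fp32_from_fp16_params[0][1] is not None         # optimizer still holds it
```

Keeping `None` placeholders preserves the group/index layout, which is what lets the merge step address entries by position at load time.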
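On the load side, the new `merge_state_dict` branch fills those placeholders back in: every non-`None` expert entry from the local rank's checkpoint overwrites the entry at the same (group, index) position in the rank-0 checkpoint. A rough sketch with hypothetical data, where plain strings stand in for the fp32 master-weight tensors:

```python
def merge_fp32_from_fp16(rank0_params, local_params):
    """Non-None expert master weights from the local checkpoint replace the
    matching entries of the rank-0 checkpoint."""
    for group_idx, param_group in enumerate(local_params):
        for param_idx, param in enumerate(param_group):
            if param is not None:
                rank0_params[group_idx][param_idx] = param
    return rank0_params


# Hypothetical checkpoints: rank 0 carries the full structure, the local rank
# carries only the expert master weights it owns.
rank0 = [['dense_w', 'stale_expert_w']]
local = [[None, 'expert_w_from_this_rank']]
assert merge_fp32_from_fp16(rank0, local) == [['dense_w', 'expert_w_from_this_rank']]
```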