Commit bcfeaf3b authored by Jiezhong Qiu

fp16/fp32 optimizers have different data structures

parent 12b23ae3
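
The fp32 path stores the torch.optim state dict directly under state_dict['optimizer'], while Megatron's fp16 optimizer wraps the inner torch optimizer, so the same dict sits one level deeper under state_dict['optimizer']['optimizer']. A minimal sketch of the two layouts this commit distinguishes (the 'optimizer' nesting is what the diff relies on; other wrapper keys are illustrative assumptions):

```python
fp32_layout = {
    'optimizer': {                    # torch.optim state dict directly
        'state': {0: {'exp_avg': '...'}},
        'param_groups': ['...'],
    },
}

fp16_layout = {
    'optimizer': {                    # Megatron fp16 wrapper state dict
        'optimizer': {                # inner torch.optim state dict
            'state': {},              # may be empty if every step overflowed
            'param_groups': ['...'],
        },
        # wrapper-specific entries (e.g. a loss-scaler field, name assumed)
        # live alongside the inner 'optimizer' key
    },
}

def inner_optimizer(state_dict, fp16):
    """Hypothetical helper: return the torch.optim-level dict either way."""
    return state_dict['optimizer']['optimizer'] if fp16 else state_dict['optimizer']
```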
@@ -451,8 +451,15 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
                     # this parameter is not an expert parameter
                     # thus there is no need to save its state in current rank
                     # since it has been saved by data parallel rank 0
-                    state_dict['optimizer']['state'].pop(index)
+                    if args.fp16:
+                        # fp16 optimizer may have empty state due to overflow
+                        state_dict['optimizer']['optimizer']['state'].pop(index, None)
+                    else:
+                        state_dict['optimizer']['state'].pop(index)
                 index += 1
-        state_dict['optimizer'].pop('param_groups')
+        if args.fp16:
+            state_dict['optimizer']['optimizer'].pop('param_groups')
+        else:
+            state_dict['optimizer'].pop('param_groups')
     # Save.
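
The `pop(index, None)` default on the fp16 path is deliberate: under dynamic loss scaling, steps that overflow are skipped, so the inner optimizer's `state` can still be empty at checkpoint time. A tiny illustration:

```python
state = {}            # fp16 inner optimizer state: empty if every step so far overflowed
state.pop(0, None)    # safe no-op, returns None
# state.pop(0)        # would raise KeyError on the empty dict
```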
@@ -476,7 +483,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     torch.distributed.barrier()

-def merge_state_dict(state_dict_rank0, state_dict_local):
+def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
     """merge two state dicts, one from data parallel rank 0,
     another only contains expert states"""
     from megatron import print_rank_last
@@ -494,12 +501,15 @@ def merge_state_dict(state_dict_rank0, state_dict_local):
             before.sum={:7f}, after.sum={:7f}".format(k, before, after))
     merge_model(state_dict_rank0['model'], state_dict_local['model'])
-    for k, v in state_dict_local['optimizer']['state'].items():
+    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] if fp16 else state_dict_rank0['optimizer']
+    optimizer_local = state_dict_local['optimizer']['optimizer'] if fp16 else state_dict_local['optimizer']
+    for k, v in optimizer_local['state'].items():
         before = {kk: vv.sum().item() \
-            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
-        state_dict_rank0['optimizer']['state'][k] = v
+            for kk, vv in optimizer_rank0['state'][k].items()}
+        optimizer_rank0['state'][k] = v
         after = {kk: vv.sum().item() \
-            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+            for kk, vv in optimizer_rank0['state'][k].items()}
         print_rank_last("[merge optimizer] copy {}, \
             before.sum={}, after.sum={}".format(k, str(before), str(after)))
     return state_dict_rank0
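
To make the merge concrete, here is a self-contained toy run of the same optimizer-state copy on the fp32 layout (the helper name merge_optimizer_state and the toy tensors are illustrative, not part of the commit):

```python
import torch

def merge_optimizer_state(rank0_sd, local_sd, fp16):
    # Same indirection the hunk introduces: fp16 nests the torch
    # optimizer one level deeper under 'optimizer'.
    opt_rank0 = rank0_sd['optimizer']['optimizer'] if fp16 else rank0_sd['optimizer']
    opt_local = local_sd['optimizer']['optimizer'] if fp16 else local_sd['optimizer']
    for k, v in opt_local['state'].items():
        opt_rank0['state'][k] = v     # expert state overwrites rank-0 state
    return rank0_sd

# Parameter 1 plays the expert parameter: only the local checkpoint
# holds its up-to-date state.
rank0 = {'optimizer': {'state': {0: {'exp_avg': torch.zeros(2)},
                                 1: {'exp_avg': torch.zeros(2)}},
                       'param_groups': []}}
local = {'optimizer': {'state': {1: {'exp_avg': torch.ones(2)}},
                       'param_groups': []}}

merged = merge_optimizer_state(rank0, local, fp16=False)
print(merged['optimizer']['state'][1]['exp_avg'])   # tensor([1., 1.])
```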
@@ -581,7 +591,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
         state_dict_local = load_state_dict(checkpoint_name_local)
-        state_dict = merge_state_dict(state_dict_rank0, state_dict_local)
+        state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)
     # set checkpoint version
     set_checkpoint_version(state_dict.get('checkpoint_version', 0))