Commit 59f10bf3 authored by Jiezhong Qiu

fp16/fp32 optimizers have different data structures

parent f224b69b
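
For context on the commit message: with plain fp32 training, state_dict['optimizer'] is the torch optimizer's own state dict, so 'state' and 'param_groups' sit at the top level; with fp16, the optimizer is a wrapper whose state dict nests the inner torch optimizer one level deeper under another 'optimizer' key. A rough sketch of the two shapes, inferred from the key paths used in the diff below (the dict names and example entries are illustrative assumptions, not code from this repository):

# fp32: state_dict['optimizer'] is the torch optimizer state dict itself
fp32_shape = {
    'optimizer': {
        'state': {0: {'exp_avg': '...'}, 1: {'exp_avg': '...'}},
        'param_groups': ['...'],
    },
}

# fp16: the wrapper nests the torch optimizer state dict under another
# 'optimizer' key (other wrapper entries omitted)
fp16_shape = {
    'optimizer': {
        'optimizer': {
            'state': {0: {'exp_avg': '...'}},   # may be empty after an overflow step
            'param_groups': ['...'],
        },
    },
}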
@@ -293,8 +293,15 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
                 # this parameter is not an expert parameter
                 # thus there is no need to save its state in current rank
                 # since it has been saved by data parallel rank 0
-                state_dict['optimizer']['state'].pop(index)
+                if args.fp16:
+                    # fp16 optimizer may have empty state due to overflow
+                    state_dict['optimizer']['optimizer']['state'].pop(index, None)
+                else:
+                    state_dict['optimizer']['state'].pop(index)
             index += 1
-        state_dict['optimizer'].pop('param_groups')
+        if args.fp16:
+            state_dict['optimizer']['optimizer'].pop('param_groups')
+        else:
+            state_dict['optimizer'].pop('param_groups')
     # Save.
@@ -318,7 +325,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='n
     torch.distributed.barrier()


-def merge_state_dict(state_dict_rank0, state_dict_local):
+def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
     """merge two state dicts, one from data parallel rank 0,
     another only contains expert states"""
     from megatron import print_rank_last
@@ -336,12 +343,15 @@ def merge_state_dict(state_dict_rank0, state_dict_local):
             before.sum={:7f}, after.sum={:7f}".format(k, before, after))
     merge_model(state_dict_rank0['model'], state_dict_local['model'])
-    for k, v in state_dict_local['optimizer']['state'].items():
+    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] if fp16 else state_dict_rank0['optimizer']
+    optimizer_local = state_dict_local['optimizer']['optimizer'] if fp16 else state_dict_local['optimizer']
+    for k, v in optimizer_local['state'].items():
         before = {kk: vv.sum().item() \
-            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
-        state_dict_rank0['optimizer']['state'][k] = v
+            for kk, vv in optimizer_rank0['state'][k].items()}
+        optimizer_rank0['state'][k] = v
         after = {kk: vv.sum().item() \
-            for kk, vv in state_dict_rank0['optimizer']['state'][k].items()}
+            for kk, vv in optimizer_rank0['state'][k].items()}
         print_rank_last("[merge optimizer] copy {}, \
             before.sum={}, after.sum={}".format(k, str(before), str(after)))
     return state_dict_rank0
@@ -423,7 +433,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
         state_dict_rank0 = load_state_dict(checkpoint_name_rank0)
         state_dict_local = load_state_dict(checkpoint_name_local)
-        state_dict = merge_state_dict(state_dict_rank0, state_dict_local)
+        state_dict = merge_state_dict(state_dict_rank0, state_dict_local, args.fp16)
         # set checkpoint version
         set_checkpoint_version(state_dict.get('checkpoint_version', 0))
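
The same fp16/fp32 unwrapping appears in both save_checkpoint and merge_state_dict above. A minimal sketch of how those branches could be expressed as one helper (the helper name is hypothetical and not part of this commit):

def unwrap_optimizer_state(state_dict, fp16):
    # The fp16 wrapper stores the real torch optimizer state one level deeper.
    opt = state_dict['optimizer']
    return opt['optimizer'] if fp16 else opt

# Usage mirroring the diff (pop with a default, since the fp16 optimizer
# may have an empty 'state' after an overflow step):
#   unwrap_optimizer_state(state_dict, args.fp16)['state'].pop(index, None)
#   unwrap_optimizer_state(state_dict, args.fp16).pop('param_groups')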