"docs/vscode:/vscode.git/clone" did not exist on "aa3c46d99acfaa145bdf620f821de9b409c2e6c6"
Commit 68997976 authored by Jiezhong Qiu

fix pylint

parent 90e394bc
@@ -396,11 +396,12 @@ def get_fmoe_checkpoint_name(checkpoints_path, iteration,
         ),
         'model_optim_rng.pt')
 
 
-def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
+def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint with expert parallel """
     # TODO: update patch
     from megatron import get_args
     from megatron import mpu
+    expert_dp_comm = 'none'
     if mpu.get_data_parallel_rank() == 0:
         # at dp rank 0, we still follows the native load_checkpoint by megatron
@@ -520,8 +521,10 @@ def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
                 before.sum={:7f}, after.sum={:7f}".format(k, before, after))
 
     merge_model(state_dict_rank0['model'], state_dict_local['model'])
-    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] if fp16 else state_dict_rank0['optimizer']
-    optimizer_local = state_dict_local['optimizer']['optimizer'] if fp16 else state_dict_local['optimizer']
+    optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] \
+        if fp16 else state_dict_rank0['optimizer']
+    optimizer_local = state_dict_local['optimizer']['optimizer'] \
+        if fp16 else state_dict_local['optimizer']
 
     for k, v in optimizer_local['state'].items():
         before = {kk: vv.sum().item() \
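The second hunk fixes pylint's line-too-long check (C0301; the default limit is 100 characters) by splitting each conditional expression with a backslash continuation. A standalone sketch of the same fix with dummy data in place of the Megatron state dict:

    # Standalone illustration with dummy data (not the Megatron state dict).
    state_dict = {'optimizer': {'optimizer': {'state': {}}}}
    fp16 = True

    # The long conditional expression is split so each physical line stays
    # under pylint's limit; the backslash continues the logical line.
    optimizer_state = state_dict['optimizer']['optimizer'] \
        if fp16 else state_dict['optimizer']
    print(optimizer_state)

PEP 8 generally prefers implicit continuation inside parentheses, but the backslash matches the style already used elsewhere in this file (see the `before = {kk: vv.sum().item() \` line above).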
@@ -547,7 +550,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
     from megatron import get_args
     from megatron import mpu
     from megatron import print_rank_last
-    from megatron.checkpointing import get_checkpoint_tracker_filename, set_checkpoint_version, check_checkpoint_args, update_num_microbatches
+    from megatron.checkpointing import get_checkpoint_tracker_filename
+    from megatron.checkpointing import set_checkpoint_version
+    from megatron.checkpointing import check_checkpoint_args
+    from megatron.checkpointing import update_num_microbatches
     if mpu.get_data_parallel_rank() == 0:
         # at dp rank 0, we still follow the native load_checkpoint by megatron
         from megatron.checkpointing import load_checkpoint as load_checkpoint_native
...
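The third hunk applies the same line-length fix to a long import by giving each name its own `from ... import` statement. A runnable stand-in using the standard library instead of `megatron.checkpointing`:

    # Stand-in example: one import per line keeps every line short and
    # makes each imported name easy to grep for.
    from os.path import join
    from os.path import exists

    print(join('checkpoints', 'iter_0000100'))
    print(exists('.'))

A single parenthesized import, e.g. `from megatron.checkpointing import (get_checkpoint_tracker_filename, ...)`, would also satisfy pylint; the per-line form presumably mirrors the `from megatron import ...` lines directly above it.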