Commit a372eb66 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

fix pylint

parent d8124b80
......@@ -238,11 +238,12 @@ def get_fmoe_checkpoint_name(checkpoints_path, iteration,
),
'model_optim_rng.pt')
def save_checkpoint(iteration, model, optimizer, lr_scheduler, expert_dp_comm='none'):
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint with expert parallel """
# TODO: update patch
from megatron import get_args
from megatron import mpu
expert_dp_comm = 'none'
if mpu.get_data_parallel_rank() == 0:
# at dp rank 0, we still follows the native load_checkpoint by megatron
......@@ -362,8 +363,10 @@ def merge_state_dict(state_dict_rank0, state_dict_local, fp16):
before.sum={:7f}, after.sum={:7f}".format(k, before, after))
merge_model(state_dict_rank0['model'], state_dict_local['model'])
optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] if fp16 else state_dict_rank0['optimizer']
optimizer_local = state_dict_local['optimizer']['optimizer'] if fp16 else state_dict_local['optimizer']
optimizer_rank0 = state_dict_rank0['optimizer']['optimizer'] \
if fp16 else state_dict_rank0['optimizer']
optimizer_local = state_dict_local['optimizer']['optimizer'] \
if fp16 else state_dict_local['optimizer']
for k, v in optimizer_local['state'].items():
before = {kk: vv.sum().item() \
......@@ -389,7 +392,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
from megatron import get_args
from megatron import mpu
from megatron import print_rank_last
from megatron.checkpointing import get_checkpoint_tracker_filename, set_checkpoint_version, check_checkpoint_args, update_num_microbatches
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import set_checkpoint_version
from megatron.checkpointing import check_checkpoint_args
from megatron.checkpointing import update_num_microbatches
if mpu.get_data_parallel_rank() == 0:
# at dp rank 0, we still follow the native load_checkpoint by megatron
from megatron.checkpointing import load_checkpoint as load_checkpoint_native
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment