Commit c844413b authored by Jiezhong Qiu
Browse files

Fix pylint warnings

parent 49a4678c
......@@ -5,9 +5,9 @@ See `examples/megatron` for usage instructions.
"""
import os
import math
import numpy as np
import random
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -392,6 +392,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint with expert parallel """
# TODO: update patch
from megatron import get_args
from megatron import mpu
......@@ -405,15 +406,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
print('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True)
data_parallel_rank = mpu.get_data_parallel_rank()
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
keep_vars = False if mpu.get_data_parallel_rank() == 0 else True
state_dict['model'] = model.state_dict_for_save_checkpoint(keep_vars=keep_vars)
state_dict['model'] = model.state_dict_for_save_checkpoint(
keep_vars=(mpu.get_data_parallel_rank() > 0))
if mpu.get_data_parallel_rank() != 0:
......@@ -421,15 +420,17 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict_new = state_dict.__class__()
for k, v in state_dict.items():
# megatron uses both dict and OrderedDict in its state_dict
if isinstance(v, OrderedDict) or isinstance(v, dict):
if isinstance(v, (OrderedDict, dict)):
v_new = extract_expert_param(v, expert_dp_comm)
if len(v_new):
if len(v_new) > 0:
state_dict_new[k] = v_new
elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
state_dict_new[k] = v.detach()
return state_dict_new
state_dict['model'] = extract_expert_param(state_dict['model'], 'none')
state_dict['model'] = extract_expert_param(
state_dict['model'],
expert_dp_comm='none')
# Optimizer stuff.
if not args.no_save_optim:
......
......@@ -15,10 +15,8 @@ class _Expert(nn.Module):
def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
super().__init__()
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True,
rank=rank)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True,
rank=rank)
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
self.activation = activation
def forward(self, inp, fwd_expert_count):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment