Commit c844413b authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

fix pylint

parent 49a4678c
...@@ -5,9 +5,9 @@ See `examples/megatron` for usage instructions. ...@@ -5,9 +5,9 @@ See `examples/megatron` for usage instructions.
""" """
import os import os
import math import math
import numpy as np
import random import random
from collections import OrderedDict from collections import OrderedDict
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -392,6 +392,7 @@ def get_checkpoint_name(checkpoints_path, iteration, ...@@ -392,6 +392,7 @@ def get_checkpoint_name(checkpoints_path, iteration,
def save_checkpoint(iteration, model, optimizer, lr_scheduler): def save_checkpoint(iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint with expert parallel """ """Save a model checkpoint with expert parallel """
# TODO: update patch
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
...@@ -405,15 +406,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -405,15 +406,13 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
print('saving checkpoint at iteration {:7d} to {}'.format( print('saving checkpoint at iteration {:7d} to {}'.format(
iteration, args.save), flush=True) iteration, args.save), flush=True)
data_parallel_rank = mpu.get_data_parallel_rank()
# Arguments, iteration, and model. # Arguments, iteration, and model.
state_dict = {} state_dict = {}
state_dict['args'] = args state_dict['args'] = args
state_dict['checkpoint_version'] = 3.0 state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration state_dict['iteration'] = iteration
keep_vars = False if mpu.get_data_parallel_rank() == 0 else True state_dict['model'] = model.state_dict_for_save_checkpoint(
state_dict['model'] = model.state_dict_for_save_checkpoint(keep_vars=keep_vars) keep_vars=(mpu.get_data_parallel_rank() > 0))
if mpu.get_data_parallel_rank() != 0: if mpu.get_data_parallel_rank() != 0:
...@@ -421,15 +420,17 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): ...@@ -421,15 +420,17 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict_new = state_dict.__class__() state_dict_new = state_dict.__class__()
for k, v in state_dict.items(): for k, v in state_dict.items():
# megatron uses both dict and OrderedDict in its state_dict # megatron uses both dict and OrderedDict in its state_dict
if isinstance(v, OrderedDict) or isinstance(v, dict): if isinstance(v, (OrderedDict, dict)):
v_new = extract_expert_param(v, expert_dp_comm) v_new = extract_expert_param(v, expert_dp_comm)
if len(v_new): if len(v_new) > 0:
state_dict_new[k] = v_new state_dict_new[k] = v_new
elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm: elif hasattr(v, 'dp_comm') and v.dp_comm == expert_dp_comm:
state_dict_new[k] = v.detach() state_dict_new[k] = v.detach()
return state_dict_new return state_dict_new
state_dict['model'] = extract_expert_param(state_dict['model'], 'none') state_dict['model'] = extract_expert_param(
state_dict['model'],
expert_dp_comm='none')
# Optimizer stuff. # Optimizer stuff.
if not args.no_save_optim: if not args.no_save_optim:
......
...@@ -15,10 +15,8 @@ class _Expert(nn.Module): ...@@ -15,10 +15,8 @@ class _Expert(nn.Module):
def __init__(self, num_expert, d_model, d_hidden, activation, rank=0): def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
super().__init__() super().__init__()
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
rank=rank) self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True,
rank=rank)
self.activation = activation self.activation = activation
def forward(self, inp, fwd_expert_count): def forward(self, inp, fwd_expert_count):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment