Commit 269f3fd4 authored by Jiezhong Qiu

fix pylint issues

parent 1a6073b5
@@ -103,7 +103,7 @@ class DistributedGroupedDataParallel(nn.Module):
            synced = _unflatten_dense_tensors(coalesced, datas)
            for d, s in zip(datas, synced):
                d.copy_(s)

    def forward(self, *args, **kwargs):
        r'''
        Directly call the module's forward function.
......
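The d.copy_(s) loop in the hunk above is the tail end of the usual flatten / all-reduce / unflatten parameter-sync pattern. A minimal standalone sketch of that pattern, assuming a helper named allreduce_params (the helper name and the division by world size are illustrative, not the module's actual method):

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_params(params, group=None):
    # Flatten all tensors into one contiguous buffer so a single
    # all-reduce synchronises the whole group at once.
    datas = [p.data for p in params]
    coalesced = _flatten_dense_tensors(datas)
    dist.all_reduce(coalesced, group=group)
    coalesced /= dist.get_world_size(group=group)
    # Unflatten the reduced buffer and copy each slice back into
    # its original tensor in place (the same loop as in the diff).
    synced = _unflatten_dense_tensors(coalesced, datas)
    for d, s in zip(datas, synced):
        d.copy_(s)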
r'''
Layers that FMoE provides to users
'''
import math
import torch
import torch.nn as nn
import numpy as np
import math
from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear
@@ -34,17 +34,18 @@ class FMoELinear(nn.Module):
        '''
        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
        # copied from https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_
        # copied from torch.nn.init.kaiming_uniform_
        fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
        gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
        std = gain / math.sqrt(fan)
        bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
        bound = math.sqrt(3.0) * std
        device = self.weight.device
        dtype = self.weight.dtype
        for i in range(self.num_expert):
            weight = rng.uniform(-bound, bound, size=tuple(self.weight[i].size()))
            self.weight.data[i] = torch.tensor(weight, dtype=dtype, device=device)
            weight = rng.uniform(-bound, bound,
                                 size=tuple(self.weight[i].size()))
            self.weight.data[i] = torch.tensor(weight,
                                               dtype=dtype, device=device)

    def forward(self, inp, fwd_expert_count):
        r'''
......
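A note on the initialisation this hunk reformats: with negative slope a = sqrt(5), calculate_gain('leaky_relu', a) = sqrt(2 / (1 + a**2)) = 1/sqrt(3), so bound = sqrt(3) * gain / sqrt(fan_in) collapses to 1/sqrt(fan_in), the same bound nn.Linear uses by default. A small sketch checking that, with made-up expert and weight sizes:

import math
import torch
import torch.nn as nn

num_expert, d_out, d_in = 4, 8, 16            # illustrative sizes only
weight = torch.empty(num_expert, d_out, d_in)

fan = nn.init._calculate_correct_fan(weight[0], 'fan_in')   # fan_in = d_in
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))   # = 1/sqrt(3)
bound = math.sqrt(3.0) * gain / math.sqrt(fan)

assert abs(bound - 1.0 / math.sqrt(d_in)) < 1e-6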
@@ -52,7 +52,8 @@ class FMoETransformerMLP(FMoE):
        super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                         top_k=top_k, world_size=world_size, mp_group=mp_group,
                         expert_fn=expert_fn)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation, self.mp_rank)
        self.experts = _Expert(num_expert, d_model, d_hidden, activation,
                               rank=self.mp_rank)
        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)
        self.mark_parallel_comm()
......
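The pre_lnorm flag stored above follows the usual pre-norm vs post-norm Transformer convention: LayerNorm is applied either before the sublayer or after the residual addition. A generic sketch of that placement, not FMoE's actual forward (ResidualMLPBlock and core are illustrative names):

import torch.nn as nn

class ResidualMLPBlock(nn.Module):
    # `core` stands in for the expert MLP sublayer.
    def __init__(self, d_model, core, pre_lnorm=False):
        super().__init__()
        self.core = core
        self.pre_lnorm = pre_lnorm
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        if self.pre_lnorm:
            # pre-norm: normalise the input before the sublayer
            return x + self.core(self.layer_norm(x))
        # post-norm: normalise after the residual addition
        return self.layer_norm(x + self.core(x))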