"git@developer.sourcefind.cn:OpenDAS/torchani.git" did not exist on "168b05939ee36f68a8c2a24d72f07dab8782657c"
Commit 269f3fd4 authored by Jiezhong Qiu's avatar Jiezhong Qiu
Browse files

fix pylint issues

parent 1a6073b5
...@@ -103,7 +103,7 @@ class DistributedGroupedDataParallel(nn.Module): ...@@ -103,7 +103,7 @@ class DistributedGroupedDataParallel(nn.Module):
synced = _unflatten_dense_tensors(coalesced, datas) synced = _unflatten_dense_tensors(coalesced, datas)
for d, s in zip(datas, synced): for d, s in zip(datas, synced):
d.copy_(s) d.copy_(s)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
r''' r'''
Directly call the module's forward function. Directly call the module's forward function.
......
r''' r'''
Layers that FMoE provides to users Layers that FMoE provides to users
''' '''
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import numpy as np import numpy as np
import math
from .functions import moe_prepare_forward from .functions import moe_prepare_forward
from .functions import MOEScatter, MOEGather, MOELinear from .functions import MOEScatter, MOEGather, MOELinear
...@@ -34,17 +34,18 @@ class FMoELinear(nn.Module): ...@@ -34,17 +34,18 @@ class FMoELinear(nn.Module):
''' '''
rng = np.random.default_rng(np.random.randint(2048) + self.rank) rng = np.random.default_rng(np.random.randint(2048) + self.rank)
# copied from https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_uniform_ # copied from torch.nn.init.kaiming_uniform_
fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in') fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5)) gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
std = gain / math.sqrt(fan) std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation bound = math.sqrt(3.0) * std
device = self.weight.device device = self.weight.device
dtype = self.weight.dtype dtype = self.weight.dtype
for i in range(self.num_expert): for i in range(self.num_expert):
weight = rng.uniform(-bound, bound, size=tuple(self.weight[i].size())) weight = rng.uniform(-bound, bound,
self.weight.data[i] = torch.tensor(weight, dtype=dtype, device=device) size=tuple(self.weight[i].size()))
self.weight.data[i] = torch.tensor(weight,
dtype=dtype, device=device)
def forward(self, inp, fwd_expert_count): def forward(self, inp, fwd_expert_count):
r''' r'''
......
...@@ -52,7 +52,8 @@ class FMoETransformerMLP(FMoE): ...@@ -52,7 +52,8 @@ class FMoETransformerMLP(FMoE):
super().__init__(num_expert=num_expert, d_model=d_model, gate=gate, super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
top_k=top_k, world_size=world_size, mp_group=mp_group, top_k=top_k, world_size=world_size, mp_group=mp_group,
expert_fn=expert_fn) expert_fn=expert_fn)
self.experts = _Expert(num_expert, d_model, d_hidden, activation, self.mp_rank) self.experts = _Expert(num_expert, d_model, d_hidden, activation,
rank=self.mp_rank)
self.pre_lnorm = pre_lnorm self.pre_lnorm = pre_lnorm
self.layer_norm = nn.LayerNorm(d_model) self.layer_norm = nn.LayerNorm(d_model)
self.mark_parallel_comm() self.mark_parallel_comm()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment