Unverified Commit 406955e7 authored by Rick Ho, committed by GitHub

Merge pull request #4 from laekov/init_expert

merge init expert
parents 02ead75d 269f3fd4
@@ -103,7 +103,7 @@ class DistributedGroupedDataParallel(nn.Module):
             synced = _unflatten_dense_tensors(coalesced, datas)
             for d, s in zip(datas, synced):
                 d.copy_(s)
 
     def forward(self, *args, **kwargs):
         r'''
         Directly call the module's forward function.
......
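For reference, the coalesce/scatter pattern in this hunk relies on torch's flatten helpers: many small tensors are packed into one contiguous buffer, communicated as a whole, then unflattened and copied back in place. A minimal standalone sketch of the same pattern (the actual broadcast/all-reduce over the process group is omitted here):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

datas = [torch.randn(3, 4), torch.randn(5)]      # stand-ins for parameter tensors
coalesced = _flatten_dense_tensors(datas)        # one contiguous 1-D buffer
# ... in the real module the coalesced buffer is reduced/broadcast here ...
synced = _unflatten_dense_tensors(coalesced, datas)
for d, s in zip(datas, synced):
    d.copy_(s)                                   # write the synced values back in place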
 r'''
 Layers that FMoE provides to users
 '''
+import math
 import torch
 import torch.nn as nn
+import numpy as np
 from .functions import moe_prepare_forward
 from .functions import MOEScatter, MOEGather, MOELinear
@@ -17,11 +19,12 @@ class FMoELinear(nn.Module):
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
     '''
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
+    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, rank=0):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
+        self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
         self.reset_parameters()
@@ -29,10 +32,20 @@ class FMoELinear(nn.Module):
         r'''
         Initialize the weight as linear layers
         '''
+        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
+        # copied from torch.nn.init.kaiming_uniform_
+        fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
+        gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
+        std = gain / math.sqrt(fan)
+        bound = math.sqrt(3.0) * std
+        device = self.weight.device
+        dtype = self.weight.dtype
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat,
-                               out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+            weight = rng.uniform(-bound, bound,
+                                 size=tuple(self.weight[i].size()))
+            self.weight.data[i] = torch.tensor(weight,
+                                               dtype=dtype, device=device)
 
     def forward(self, inp, fwd_expert_count):
         r'''
......
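For context, the new reset_parameters draws each expert's weight from the same uniform bound that torch.nn.init.kaiming_uniform_(..., a=math.sqrt(5)) would use, but samples it with a NumPy generator seeded by the worker rank so that different workers initialize their experts differently. A standalone sketch of the bound computation (num_expert, in_feat, out_feat and rank below are illustrative values, not taken from the diff):

import math
import numpy as np
import torch
import torch.nn as nn

num_expert, out_feat, in_feat, rank = 4, 1024, 1024, 0
weight = torch.empty(num_expert, out_feat, in_feat)

# Same bound as kaiming_uniform_ with a=sqrt(5):
# gain = sqrt(2 / (1 + 5)), fan = in_feat, bound = sqrt(3) * gain / sqrt(fan)
fan = nn.init._calculate_correct_fan(weight[0], 'fan_in')
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
bound = math.sqrt(3.0) * gain / math.sqrt(fan)

# Rank-dependent seed, mirroring the diff above
rng = np.random.default_rng(np.random.randint(2048) + rank)
for i in range(num_expert):
    weight[i] = torch.tensor(rng.uniform(-bound, bound, size=tuple(weight[i].size())),
                             dtype=weight.dtype)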
@@ -12,10 +12,10 @@ class _Expert(nn.Module):
     An expert using 2 FMoELinear modules to speed up the computation of experts
     within one worker.
     '''
-    def __init__(self, num_expert, d_model, d_hidden, activation):
+    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, rank)
         self.activation = activation
 
     def forward(self, inp, fwd_expert_count):
@@ -52,7 +52,8 @@ class FMoETransformerMLP(FMoE):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                          top_k=top_k, world_size=world_size, mp_group=mp_group,
                          expert_fn=expert_fn)
-        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation,
+                               rank=self.mp_rank)
         self.pre_lnorm = pre_lnorm
         self.layer_norm = nn.LayerNorm(d_model)
         self.mark_parallel_comm()
......
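The rank argument now threads from FMoETransformerMLP (via self.mp_rank) through _Expert into each FMoELinear, so every model-parallel worker seeds its expert weights differently. A rough check, assuming the layers module above is importable as fmoe.layers (the module path is not shown in this diff):

import torch
from fmoe.layers import FMoELinear  # assumed import path

w0 = FMoELinear(num_expert=4, in_feat=16, out_feat=32, rank=0)
w1 = FMoELinear(num_expert=4, in_feat=16, out_feat=32, rank=1)

# Each instance seeds its generator with np.random.randint(2048) + rank,
# so the two weight tensors are expected to differ.
print(torch.equal(w0.weight, w1.weight))  # expected: False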
@@ -10,7 +10,6 @@ cxx_flags = [
 ext_libs = []
 
 if os.environ.get('USE_NCCL', '0') == '1':
     cxx_flags.append('-DMOE_USE_NCCL')
-    ext_libs.append('nccl')
 
 if __name__ == '__main__':
......
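With this change, NCCL support in the build is controlled only by the USE_NCCL environment variable, which adds the -DMOE_USE_NCCL compile flag; 'nccl' is no longer appended to ext_libs. A small sketch of the same switch, with a placeholder flag list standing in for the real one in setup.py:

import os

cxx_flags = ['-O3']   # placeholder; the actual flag list lives in setup.py
ext_libs = []

# Export USE_NCCL=1 before building to enable the compile-time switch.
if os.environ.get('USE_NCCL', '0') == '1':
    cxx_flags.append('-DMOE_USE_NCCL')
    # 'nccl' is no longer added to ext_libs after this change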