Commit 5f5ccd47 authored by Jiezhong Qiu

init FMoELinear param using mp_rank and numpy rng

parent da11cb76
@@ -17,11 +17,12 @@ class FMoELinear(nn.Module):
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
     '''
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024):
+    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, rank=0):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
+        self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
         self.reset_parameters()
@@ -29,10 +30,14 @@ class FMoELinear(nn.Module):
         r'''
         Initialize the weight as linear layers
         '''
+        rng = np.random.default_rng(np.random.randint(2048) + self.rank)
+        fan = nn.init._calculate_correct_fan(self.weight[0], 'fan_in')
+        gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
+        std = gain / math.sqrt(fan)
+        bound = math.sqrt(3.0) * std  # calculate uniform bounds from standard deviation
         for i in range(self.num_expert):
-            linear = nn.Linear(in_features=self.in_feat,
-                               out_features=self.out_feat)
-            self.weight.data[i] = linear.weight.data
+            weight = rng.uniform(-bound, bound, size=tuple(self.weight[i].size()))
+            self.weight.data[i] = torch.from_numpy(weight)

     def forward(self, inp, fwd_expert_count):
         r'''
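The new `reset_parameters` reproduces the bound that `nn.Linear`'s default initialization uses (`kaiming_uniform_` with `a=math.sqrt(5)`, which works out to `1/sqrt(fan_in)`), but draws the samples from a NumPy generator whose seed is offset by the rank, so each model-parallel rank gets distinct expert weights. Note the seed combines a draw from NumPy's global RNG with the rank offset. A minimal standalone sketch of the same scheme (the sizes and base seed here are illustrative, not from the commit):

```python
import math

import numpy as np
import torch
from torch import nn

# Illustrative sketch: compute the bound used by nn.Linear's default
# init, kaiming_uniform_(w, a=math.sqrt(5)), then sample from a NumPy
# RNG seeded per rank.  Shapes and the base seed (1234) are made up.
num_expert, out_feat, in_feat, rank = 4, 64, 64, 0

weight = torch.empty(num_expert, out_feat, in_feat)
fan = nn.init._calculate_correct_fan(weight[0], 'fan_in')   # fan_in == in_feat
gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))   # gain for a=sqrt(5)
bound = math.sqrt(3.0) * gain / math.sqrt(fan)              # == 1/sqrt(fan_in)

rng = np.random.default_rng(1234 + rank)                    # rank-distinct stream
for i in range(num_expert):
    # from_numpy yields float64; the indexed assignment copies and
    # converts into the float32 parameter storage.
    weight[i] = torch.from_numpy(
        rng.uniform(-bound, bound, size=tuple(weight[i].shape)))
```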
@@ -12,10 +12,10 @@ class _Expert(nn.Module):
     An expert using 2 FMoELinear modules to speed up the computation of experts
     within one worker.
     '''
-    def __init__(self, num_expert, d_model, d_hidden, activation):
+    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, rank)
         self.activation = activation

     def forward(self, inp, fwd_expert_count):
@@ -52,7 +52,7 @@ class FMoETransformerMLP(FMoE):
         super().__init__(num_expert=num_expert, d_model=d_model, gate=gate,
                          top_k=top_k, world_size=world_size, mp_group=mp_group,
                          expert_fn=expert_fn)
-        self.experts = _Expert(num_expert, d_model, d_hidden, activation)
+        self.experts = _Expert(num_expert, d_model, d_hidden, activation, self.mp_rank)
         self.pre_lnorm = pre_lnorm
         self.layer_norm = nn.LayerNorm(d_model)
         self.mark_parallel_comm()
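The last hunk threads `self.mp_rank`, presumably set on the `FMoE` base class from the model-parallel group, down into the experts. A hedged sketch of how such a rank could be derived from a `torch.distributed` process group; `get_mp_rank` is a hypothetical helper, not part of this commit:

```python
import torch.distributed as dist

# Hypothetical helper: derive a model-parallel rank from a process
# group (None meaning the default group), falling back to 0 when
# torch.distributed is not initialized, e.g. in single-GPU runs.
def get_mp_rank(mp_group=None):
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank(group=mp_group)
```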