Commit 3c24222c authored by Jiezhong Qiu

add and initialize bias term in FMoELinear

parent 406955e7
@@ -19,13 +19,18 @@ class FMoELinear(nn.Module):
     performed in parallel to increase the performance.
     The FMoELinear module provides such function.
     '''
-    def __init__(self, num_expert=32, in_feat=1024, out_feat=1024, rank=0):
+    def __init__(self, num_expert: int, in_feat: int, out_feat: int,
+            bias: bool = True, rank: int = 0):
         super().__init__()
         self.num_expert = num_expert
         self.in_feat = in_feat
         self.out_feat = out_feat
         self.rank = rank
         self.weight = nn.Parameter(torch.Tensor(num_expert, out_feat, in_feat))
+        if bias:
+            self.bias = nn.Parameter(torch.Tensor(num_expert, out_feat))
+        else:
+            self.register_parameter('bias', None)
         self.reset_parameters()

     def reset_parameters(self):
@@ -41,17 +46,24 @@ class FMoELinear(nn.Module):
         bound = math.sqrt(3.0) * std
         device = self.weight.device
         dtype = self.weight.dtype
-        for i in range(self.num_expert):
-            weight = rng.uniform(-bound, bound,
-                    size=tuple(self.weight[i].size()))
-            self.weight.data[i] = torch.tensor(weight,
-                    dtype=dtype, device=device)
+        weight = rng.uniform(-bound, bound, size=tuple(self.weight.size()))
+        self.weight.data = torch.tensor(weight, dtype=dtype, device=device)
+
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
+            bound = 1 / math.sqrt(fan_in)
+            bias = rng.uniform(-bound, bound, size=tuple(self.bias.size()))
+            self.bias.data = torch.tensor(bias, dtype=dtype, device=device)

     def forward(self, inp, fwd_expert_count):
         r'''
         Call MOE function
         '''
-        return MOELinear.apply(inp, self.weight, fwd_expert_count)
+        x = MOELinear.apply(inp, self.weight, fwd_expert_count)
+        if self.bias is not None:
+            bias = torch.repeat_interleave(self.bias, fwd_expert_count, dim=0)
+            x = x + bias
+        return x


 def mark_module_parallel_comm(module, comm):
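For context, a minimal sketch (not part of this commit) of how the added bias is applied in forward: MOELinear.apply returns outputs grouped by expert, so the per-expert bias of shape (num_expert, out_feat) is expanded with torch.repeat_interleave to one row per input before being added. The sizes and routing counts below are illustrative assumptions, not values from the repository.

    import torch

    num_expert, out_feat = 2, 4                  # illustrative sizes, not from the commit
    bias = torch.randn(num_expert, out_feat)     # stands in for FMoELinear.bias
    fwd_expert_count = torch.tensor([3, 1])      # rows routed to expert 0 and expert 1

    # Rows 0-2 of the expert-major output receive bias[0]; row 3 receives bias[1].
    expanded = torch.repeat_interleave(bias, fwd_expert_count, dim=0)
    assert expanded.shape == (4, out_feat)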