Commit a88d1124 authored by Jiezhong Qiu

test bias, dp still not passed

parent 3c24222c
@@ -60,11 +60,19 @@ class FMoELinear(nn.Module):
         Call MOE function
         '''
         x = MOELinear.apply(inp, self.weight, fwd_expert_count)
-        if self.bias:
-            bias = torch.repeat_interleave(self.bias, fwd_expert_count, dim=0)
+        if self.bias is not None:
+            bias = torch.repeat_interleave(self.bias,
+                    fwd_expert_count.to(self.bias.device), dim=0)
             x = x + bias
         return x
 
+    def extra_repr(self) -> str:
+        return 'num_expert={}, in_features={}, \
+out_features={}, bias={}, rank={}'.format(
+            self.num_expert, self.in_feat,
+            self.out_feat, self.bias is not None, self.rank
+        )
+
 
 def mark_module_parallel_comm(module, comm):
     r'''
......
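The new bias path relies on torch.repeat_interleave to turn the per-expert bias matrix into one bias row per token, aligned with the expert-sorted input. Below is a minimal standalone sketch of that expansion; the shapes and values are illustrative and not taken from the commit.

import torch

# Assume 2 experts with output dim 3; expert 0 receives 2 tokens,
# expert 1 receives 3 tokens in the current batch.
bias = torch.arange(6, dtype=torch.float32).view(2, 3)   # (num_expert, out_feat)
fwd_expert_count = torch.tensor([2, 3])                   # tokens routed per expert

# repeat_interleave repeats row i of `bias` fwd_expert_count[i] times,
# yielding one bias row per token.
per_token_bias = torch.repeat_interleave(
    bias, fwd_expert_count.to(bias.device), dim=0)
print(per_token_bias.shape)  # torch.Size([5, 3])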
@@ -14,8 +14,10 @@ class _Expert(nn.Module):
     '''
     def __init__(self, num_expert, d_model, d_hidden, activation, rank=0):
         super().__init__()
-        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, rank)
-        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, rank)
+        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden,
+                bias=True, rank=rank)
+        self.h4toh = FMoELinear(num_expert, d_hidden, d_model,
+                bias=True, rank=rank)
         self.activation = activation
 
     def forward(self, inp, fwd_expert_count):
......
@@ -20,9 +20,15 @@ class BruteForceMoELinear(nn.Module):
         self.weight_htoh4 = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_hidden, d_model)
         )
+        self.bias_htoh4 = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_hidden)
+        )
         self.weight_h4toh = nn.Parameter(
             torch.Tensor(num_expert * world_size, d_model, d_hidden)
         )
+        self.bias_h4toh = nn.Parameter(
+            torch.Tensor(num_expert * world_size, d_model)
+        )
         self.top_k = top_k
 
     def forward(self, inp, gate_idx, gate_score):
@@ -34,8 +40,10 @@ class BruteForceMoELinear(nn.Module):
             idx = (gate_idx == i)
             x = inp[idx]
             x = x @ self.weight_htoh4[i].t()
+            x = x + self.bias_htoh4[i]
             x = self.activation(x)
             x = x @ self.weight_h4toh[i].t()
+            x = x + self.bias_h4toh[i]
             o[idx] = x
         x = torch.bmm(gate_score, o.view(-1, self.top_k,
                                          self.d_model)).reshape(-1, self.d_model)
......
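For context, BruteForceMoELinear recomputes the MoE output expert by expert and serves as the reference the tests compare against; the two added lines insert the new bias terms on each side of the activation. The following is an illustrative, self-contained sketch of that reference path with made-up dimensions and ReLU standing in for the activation; it is not code from the repository.

import torch

num_expert, d_model, d_hidden, top_k = 4, 8, 16, 2
tokens = 6
inp = torch.randn(tokens * top_k, d_model)              # each token duplicated top_k times
gate_idx = torch.randint(num_expert, (tokens * top_k,)) # chosen expert per duplicated token
gate_score = torch.softmax(torch.randn(tokens, 1, top_k), dim=-1)

w1 = torch.randn(num_expert, d_hidden, d_model)
b1 = torch.randn(num_expert, d_hidden)
w2 = torch.randn(num_expert, d_model, d_hidden)
b2 = torch.randn(num_expert, d_model)

o = torch.empty(tokens * top_k, d_model)
for i in range(num_expert):
    idx = (gate_idx == i)
    x = inp[idx] @ w1[i].t() + b1[i]                    # biased h -> 4h projection
    x = torch.relu(x)
    x = x @ w2[i].t() + b2[i]                           # biased 4h -> h projection
    o[idx] = x
# Mix the top_k expert outputs of each token with its gate scores.
out = torch.bmm(gate_score, o.view(-1, top_k, d_model)).reshape(-1, d_model)
print(out.shape)  # torch.Size([6, 8])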
@@ -100,19 +100,31 @@ def test_fmoe_linear(
     if world_size == 1:
         moe_raw.weight_htoh4.data = experts.htoh4.weight.data.clone()
+        moe_raw.bias_htoh4.data = experts.htoh4.bias.data.clone()
         moe_raw.weight_h4toh.data = experts.h4toh.weight.data.clone()
+        moe_raw.bias_h4toh.data = experts.h4toh.bias.data.clone()
     else:
         weight_htoh4_array = [
             torch.empty_like(experts.htoh4.weight.data) for _ in range(world_size)
         ]
+        bias_htoh4_array = [
+            torch.empty_like(experts.htoh4.bias.data) for _ in range(world_size)
+        ]
         torch.distributed.all_gather(weight_htoh4_array, experts.htoh4.weight.data)
+        torch.distributed.all_gather(bias_htoh4_array, experts.htoh4.bias.data)
         moe_raw.weight_htoh4.data = torch.cat(weight_htoh4_array, dim=0)
+        moe_raw.bias_htoh4.data = torch.cat(bias_htoh4_array, dim=0)
 
         weight_h4toh_array = [
             torch.empty_like(experts.h4toh.weight.data) for _ in range(world_size)
         ]
+        bias_h4toh_array = [
+            torch.empty_like(experts.h4toh.bias.data) for _ in range(world_size)
+        ]
         torch.distributed.all_gather(weight_h4toh_array, experts.h4toh.weight.data)
+        torch.distributed.all_gather(bias_h4toh_array, experts.h4toh.bias.data)
         moe_raw.weight_h4toh.data = torch.cat(weight_h4toh_array, dim=0)
+        moe_raw.bias_h4toh.data = torch.cat(bias_h4toh_array, dim=0)
 
     moe_out, raw_out = _perform_forward(
         moe, moe_raw, batch_size, d_model, top_k, rank, mp_group
......
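In the multi-GPU branch the test rebuilds the full expert parameter set on every rank by all-gathering each local shard and concatenating along the expert dimension; with this commit the biases follow the same pattern as the weights. A minimal sketch of that pattern is below, assuming a process group has already been initialized; gather_param is a hypothetical helper for illustration, not part of the repository.

import torch
import torch.distributed as dist

def gather_param(local_param, world_size):
    # One buffer per rank; all_gather fills them with every rank's local copy.
    buffers = [torch.empty_like(local_param.data) for _ in range(world_size)]
    dist.all_gather(buffers, local_param.data)
    # Concatenate along dim 0 so rank r's experts occupy
    # rows [r * num_expert, (r + 1) * num_expert).
    return torch.cat(buffers, dim=0)

# Hypothetical usage mirroring the test (requires dist.init_process_group):
# moe_raw.bias_htoh4.data = gather_param(experts.htoh4.bias, world_size)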