Commit f804a121 authored by Rick Ho
Browse files

update test

parent 980cf4b6
......@@ -28,11 +28,12 @@ class BruteForceMoELinear(nn.Module):
self.top_k = top_k
def forward(self, inp, gate_idx, gate_score):
gate_long = gate_idx.long()
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
gate_long = gate_idx.long().view(-1)
batch_size = inp.size(0)
o = torch.empty(batch_size, self.d_model, dtype=inp.dtype, device=inp.device)
for i in range(self.weight_htoh4.shape[0]):
idx = gate_idx == i
idx = gate_long == i
x = inp[idx]
x = x @ self.weight_htoh4[i].t()
x = x + self.bias_htoh4[i]
......@@ -56,7 +57,8 @@ class BruteForceMoE(nn.Module):
self.experts = [expert(d_model) for _ in range(num_expert * world_size)]
def forward(self, inp, gate_idx, gate_score):
gate_long = gate_idx.long()
inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
gate_long = gate_idx.long().view(-1)
batch_size = inp.size(0)
x = inp.new_zeros((batch_size, self.d_model))
for i in range(batch_size):
......
......@@ -58,5 +58,5 @@ def test_switch_gate(d_model, batch_size, n_expert, cap):
if __name__ == '__main__':
_ensure_initialized()
# test_gshard_gate(4096, 1024, 4, .2)
test_switch_gate(4096, 1024, 4, .2)
test_gshard_gate(4096, 1024, 4, .2)
# test_switch_gate(4096, 1024, 4, .2)
......@@ -39,9 +39,8 @@ def _perform_forward(
inp_raw.requires_grad = True
gate_idx, gate_score = moe.gate(inp_raw)
inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
moe_out = moe(inp)
raw_out = moe_raw(inp_repeated, gate_idx, gate_score)
raw_out = moe_raw(inp_raw, gate_idx, gate_score)
raw_out.mean().backward()
moe_out.mean().backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment