"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4774fe3afc61b40a56244e9411a7c3e64ae8147f"
Commit f804a121 authored by Rick Ho's avatar Rick Ho
Browse files

update test

parent 980cf4b6
...@@ -28,11 +28,12 @@ class BruteForceMoELinear(nn.Module): ...@@ -28,11 +28,12 @@ class BruteForceMoELinear(nn.Module):
self.top_k = top_k self.top_k = top_k
def forward(self, inp, gate_idx, gate_score): def forward(self, inp, gate_idx, gate_score):
gate_long = gate_idx.long() inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
gate_long = gate_idx.long().view(-1)
batch_size = inp.size(0) batch_size = inp.size(0)
o = torch.empty(batch_size, self.d_model, dtype=inp.dtype, device=inp.device) o = torch.empty(batch_size, self.d_model, dtype=inp.dtype, device=inp.device)
for i in range(self.weight_htoh4.shape[0]): for i in range(self.weight_htoh4.shape[0]):
idx = gate_idx == i idx = gate_long == i
x = inp[idx] x = inp[idx]
x = x @ self.weight_htoh4[i].t() x = x @ self.weight_htoh4[i].t()
x = x + self.bias_htoh4[i] x = x + self.bias_htoh4[i]
...@@ -56,7 +57,8 @@ class BruteForceMoE(nn.Module): ...@@ -56,7 +57,8 @@ class BruteForceMoE(nn.Module):
self.experts = [expert(d_model) for _ in range(num_expert * world_size)] self.experts = [expert(d_model) for _ in range(num_expert * world_size)]
def forward(self, inp, gate_idx, gate_score): def forward(self, inp, gate_idx, gate_score):
gate_long = gate_idx.long() inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
gate_long = gate_idx.long().view(-1)
batch_size = inp.size(0) batch_size = inp.size(0)
x = inp.new_zeros((batch_size, self.d_model)) x = inp.new_zeros((batch_size, self.d_model))
for i in range(batch_size): for i in range(batch_size):
......
...@@ -58,5 +58,5 @@ def test_switch_gate(d_model, batch_size, n_expert, cap): ...@@ -58,5 +58,5 @@ def test_switch_gate(d_model, batch_size, n_expert, cap):
if __name__ == '__main__': if __name__ == '__main__':
_ensure_initialized() _ensure_initialized()
# test_gshard_gate(4096, 1024, 4, .2) test_gshard_gate(4096, 1024, 4, .2)
test_switch_gate(4096, 1024, 4, .2) # test_switch_gate(4096, 1024, 4, .2)
...@@ -39,9 +39,8 @@ def _perform_forward( ...@@ -39,9 +39,8 @@ def _perform_forward(
inp_raw.requires_grad = True inp_raw.requires_grad = True
gate_idx, gate_score = moe.gate(inp_raw) gate_idx, gate_score = moe.gate(inp_raw)
inp_repeated = inp_raw.repeat_interleave(repeats=top_k, dim=0)
moe_out = moe(inp) moe_out = moe(inp)
raw_out = moe_raw(inp_repeated, gate_idx, gate_score) raw_out = moe_raw(inp_raw, gate_idx, gate_score)
raw_out.mean().backward() raw_out.mean().backward()
moe_out.mean().backward() moe_out.mean().backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment