Commit 450549a4 authored by Rick Ho
Browse files

fix zero gate and constant gate

parent 956d3aef
......@@ -27,7 +27,4 @@ class ZeroGate(BaseGate):
gate_score = (
torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
)
-gate_score_all = torch.zeros(inp.shape[0], self.num_expert, device=inp.device)
-gate_score_all[:, 0] = 1
-return idx, gate_score.reshape(-1, 1, self.top_k), gate_score_all
+return idx, gate_score.reshape(-1, 1, self.top_k)
......@@ -12,7 +12,7 @@ class ConstantGate(torch.nn.Module):
idx = torch.zeros((inp.shape[0] * self.top_k,), dtype=torch.int64,
device=inp.device)
score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2
-return idx, score, None
+return idx, score
def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment