"git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "f903fda39e4c3f712fa6e161be89e60fd9470952"
Commit 450549a4 authored by Rick Ho
Browse files

fix zero gate and constant gate

parent 956d3aef
...@@ -27,7 +27,4 @@ class ZeroGate(BaseGate): ...@@ -27,7 +27,4 @@ class ZeroGate(BaseGate):
gate_score = ( gate_score = (
torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k torch.ones(inp.shape[0] * self.top_k, device=inp.device) / self.top_k
) )
gate_score_all = torch.zeros(inp.shape[0], self.num_expert, device=inp.device) return idx, gate_score.reshape(-1, 1, self.top_k)
gate_score_all[:, 0] = 1
return idx, gate_score.reshape(-1, 1, self.top_k), gate_score_all
...@@ -12,7 +12,7 @@ class ConstantGate(torch.nn.Module): ...@@ -12,7 +12,7 @@ class ConstantGate(torch.nn.Module):
idx = torch.zeros((inp.shape[0] * self.top_k,), dtype=torch.int64, idx = torch.zeros((inp.shape[0] * self.top_k,), dtype=torch.int64,
device=inp.device) device=inp.device)
score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2 score = torch.ones((inp.shape[0], 1, self.top_k), device=inp.device) / 2
return idx, score, None return idx, score
def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1): def test_zero_fwd(num_expert=2, batch_size=4, d_hidden=8, world_size=1):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment