Unverified commit 40f5fd36, authored by Rick Ho and committed by GitHub
Browse files

Merge pull request #130 from Fragile-azalea/master

Fix GShardGate top1_idx
parents f2e0d1d5 5f8e3999
...@@ -21,8 +21,7 @@ class GShardGate(NaiveGate): ...@@ -21,8 +21,7 @@ class GShardGate(NaiveGate):
topk_idx, topk_val, gate_score = naive_outs topk_idx, topk_val, gate_score = naive_outs
S = gate_score.shape[0] S = gate_score.shape[0]
top_k = topk_idx.shape[0] // gate_score.shape[0] top1_idx = topk_idx.view((-1, self.top_k))[:, 0]
top1_idx = topk_idx.view((-1, top_k))[:, 0]
c_e = torch.scatter_add( c_e = torch.scatter_add(
torch.zeros(self.tot_expert, device=top1_idx.device), torch.zeros(self.tot_expert, device=top1_idx.device),
0, 0,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment