Commit cfd30236 authored by zms1999's avatar zms1999
Browse files

Fix GShardGate random_routing

parent 4edeccd9
......@@ -14,7 +14,7 @@ class GShardGate(NaiveGate):
assert topk == 2, 'topk should be 2 in gshard'
super().__init__(d_model, num_expert, world_size, top_k=2)
self.capacity = capacity
self.random_routing = True
self.random_routing = random_routing
def forward(self, x):
naive_outs = super().forward(x, return_all_scores=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment