Commit a762d33c authored by Rick Ho's avatar Rick Ho
Browse files

capacity numerical fix

parent 6cb550fd
...@@ -34,7 +34,8 @@ class GShardGate(NaiveGate): ...@@ -34,7 +34,8 @@ class GShardGate(NaiveGate):
self.set_loss(loss) self.set_loss(loss)
cap_rate = self.capacity[0 if self.training else 1] cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * x.shape[0]) // self.world_size capacity = math.ceil(cap_rate * x.shape[0])
capacity = capacity * self.top_k // (self.world_size * self.num_expert)
capacity = torch.ones(self.num_expert * self.world_size, capacity = torch.ones(self.num_expert * self.world_size,
dtype=torch.int32, device=topk_idx.device) * capacity dtype=torch.int32, device=topk_idx.device) * capacity
topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, capacity, topk_idx = fmoe_native.prune_gate_by_capacity(topk_idx, capacity,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment