Commit e58e7b3b authored by Rich Ho
Browse files

add random routing in gshard gate

parent ba2b7aa9
...@@ -9,9 +9,11 @@ from .utils import limit_by_capacity ...@@ -9,9 +9,11 @@ from .utils import limit_by_capacity
class GShardGate(NaiveGate): class GShardGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size, capacity=(1.2, 2.4)): def __init__(self, d_model, num_expert, world_size,
capacity=(1.2, 2.4), random_routing=True):
super().__init__(d_model, num_expert, world_size, top_k=2) super().__init__(d_model, num_expert, world_size, top_k=2)
self.capacity = capacity self.capacity = capacity
self.random_routing = True
def forward(self, x): def forward(self, x):
naive_outs = super().forward(x, return_all_scores=True) naive_outs = super().forward(x, return_all_scores=True)
...@@ -34,4 +36,9 @@ class GShardGate(NaiveGate): ...@@ -34,4 +36,9 @@ class GShardGate(NaiveGate):
capacity = math.ceil(cap_rate * x.shape[0]) capacity = math.ceil(cap_rate * x.shape[0])
limit_by_capacity(topk_idx, self.num_expert, self.world_size, capacity) limit_by_capacity(topk_idx, self.num_expert, self.world_size, capacity)
if self.random_routing:
rand_routing_prob = torch.rand(gate_score.size(0), device=x.device)
mask = (2 * topk_val[:, 1] < rand_routing_prob)
topk_idx[:, 1].masked_fill_(mask, -1)
return topk_idx, topk_val return topk_idx, topk_val
...@@ -39,7 +39,6 @@ class SwitchGate(NaiveGate): ...@@ -39,7 +39,6 @@ class SwitchGate(NaiveGate):
score, k=1, dim=-1, largest=True score, k=1, dim=-1, largest=True
) # [.. x top_k] ) # [.. x top_k]
top1_score = top1_score.to(dtype=inp.dtype) top1_score = top1_score.to(dtype=inp.dtype)
top1_score = top1_score.to(dtype=inp.dtype)
cap_rate = self.capacity[0 if self.training else 1] cap_rate = self.capacity[0 if self.training else 1]
capacity = math.ceil(cap_rate * inp.shape[0]) capacity = math.ceil(cap_rate * inp.shape[0])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment