Unverified commit 6e7ad798, authored by msbaines, committed by GitHub
Browse files

[refactor] moe: simplify logic removing top expert (#125)

parent 662667d0
...@@ -29,7 +29,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: ...@@ -29,7 +29,6 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]: def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""Implements Top2Gating on logits.""" """Implements Top2Gating on logits."""
gates = F.softmax(logits, dim=2) gates = F.softmax(logits, dim=2)
min_logit = torch.finfo(logits.dtype).min # type: ignore
# gates has shape of GSE # gates has shape of GSE
num_tokens = gates.shape[1] num_tokens = gates.shape[1]
...@@ -46,8 +45,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]: ...@@ -46,8 +45,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# Replace top-expert with min value # Replace top-expert with min value
mins = torch.full_like(logits, min_logit) logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
logits_except1 = torch.where(mask1.bool(), mins, logits_w_noise)
indices2_gs = torch.argmax(logits_except1, dim=2) indices2_gs = torch.argmax(logits_except1, dim=2)
mask2 = F.one_hot(indices2_gs, num_classes=num_experts) mask2 = F.one_hot(indices2_gs, num_classes=num_experts)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment