"test/git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "1cb25232bdefcaad8ad88c540442981e8d8cab0e"
Unverified Commit 99b30a04 authored by msbaines, committed by GitHub

[perf] nn.moe: workaround inefficiency in PyTorch's one_hot (#666)

Workaround for https://github.com/pytorch/pytorch/issues/55579

Co-authored-by: @shruti-bh, @myleott
parent 6db68518
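To make the perf claim concrete, here is a minimal, hypothetical micro-benchmark of F.one_hot against the zeros-plus-scatter_ replacement this commit introduces. The sizes, loop count, and timing harness are illustrative assumptions, not part of the commit; actual speedups depend on device and tensor shapes.

import time

import torch
import torch.nn.functional as F


def scatter_one_hot(t: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Same technique as the patch: preallocate zeros, then scatter 1s
    # along a new trailing class dimension at the index positions.
    ret = torch.zeros(t.shape + (num_classes,), device=t.device, dtype=t.dtype)
    ret.scatter_(-1, t.unsqueeze(-1), 1)
    return ret


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
indices = torch.randint(0, 16, (8192,), device=device)  # made-up workload

for name, fn in [("F.one_hot", lambda: F.one_hot(indices, num_classes=16)),
                 ("scatter", lambda: scatter_one_hot(indices, 16))]:
    fn()  # warm-up
    if device.type == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(1000):
        fn()
    if device.type == "cuda":
        torch.cuda.synchronize()
    print(name, f"{time.perf_counter() - start:.3f}s")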
@@ -26,6 +26,14 @@ def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
     return gumbel(shape)
 
 
+def one_hot(tensor: torch.Tensor, num_classes: int) -> Tensor:
+    """Workaround for https://github.com/pytorch/pytorch/issues/55579"""
+    assert num_classes > 0, "num_classes must be a positive integer"
+    ret = torch.zeros(tensor.shape + (num_classes,), device=tensor.device, dtype=tensor.dtype)
+    ret.scatter_(-1, tensor.unsqueeze(-1), 1)
+    return ret
+
+
 def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     """Implements Top2Gating on logits."""
     gates = F.softmax(logits, dim=1)
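As a quick sanity check (not part of the commit), the new helper matches F.one_hot for integer index tensors. The snippet assumes the one_hot function added above is in scope; the shapes are made up.

import torch
import torch.nn.functional as F

indices = torch.randint(0, 4, (5, 3))    # made-up shape, int64 indices
ref = F.one_hot(indices, num_classes=4)  # always returns int64
ours = one_hot(indices, num_classes=4)   # patch helper: matches input dtype (int64 here)
assert torch.equal(ours, ref)
assert ours.shape == (5, 3, 4)           # one-hot adds a trailing class dimension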
@@ -39,7 +47,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
 
     # Create a mask for 1st's expert per token
     indices1_s = torch.argmax(gates, dim=1)
-    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
+    mask1 = one_hot(indices1_s, num_classes=num_experts)
 
     # Create a mask for 2nd's expert per token using Gumbel-max trick
     # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
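For readers unfamiliar with the Gumbel-max trick cited in the comment: taking an argmax over logits plus i.i.d. Gumbel noise draws a sample from the corresponding softmax distribution, which is how the gate picks a stochastic second expert. A standalone illustration (the gating code draws its noise via gumbel_rsample; here standard Gumbel noise is generated inline):

import torch

logits = torch.tensor([1.0, 2.0, 0.5])
# Gumbel-max: argmax(logits + Gumbel noise) ~ Categorical(softmax(logits)).
gumbel_noise = -torch.log(-torch.log(torch.rand(100_000, 3)))
samples = torch.argmax(logits + gumbel_noise, dim=1)
empirical = torch.bincount(samples, minlength=3).float() / samples.numel()
print(empirical)                       # close to ...
print(torch.softmax(logits, dim=0))    # ... the softmax probabilities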
@@ -47,7 +55,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     # Replace top-expert with min value
     logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
     indices2_s = torch.argmax(logits_except1, dim=1)
-    mask2 = F.one_hot(indices2_s, num_classes=num_experts)
+    mask2 = one_hot(indices2_s, num_classes=num_experts)
 
     # Compute locations in capacity buffer
     locations1 = torch.cumsum(mask1, dim=0) - 1
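The cumsum at the end of this hunk is what assigns each token a slot in its expert's capacity buffer: a running count of how many earlier tokens picked the same expert, minus one. A tiny worked example with a made-up mask:

import torch

# 4 tokens (rows), 2 experts (columns); each row is a one-hot expert choice.
mask1 = torch.tensor([[1, 0],
                      [1, 0],
                      [0, 1],
                      [1, 0]])
locations1 = torch.cumsum(mask1, dim=0) - 1
print(locations1)
# tensor([[ 0, -1],
#         [ 1, -1],
#         [ 1,  0],
#         [ 2,  0]])
# Only the entry in each token's chosen column matters; reducing with the
# mask, e.g. (locations1 * mask1).sum(dim=1), yields slots [0, 1, 0, 2].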
@@ -82,8 +90,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     # Calculate combine_weights and dispatch_mask
     gates1 = torch.einsum("s,se->se", gates1_s, mask1_float)
     gates2 = torch.einsum("s,se->se", gates2_s, mask2_float)
-    locations1_sc = F.one_hot(locations1_s, num_classes=capacity)
-    locations2_sc = F.one_hot(locations2_s, num_classes=capacity)
+    locations1_sc = one_hot(locations1_s, num_classes=capacity)
+    locations2_sc = one_hot(locations2_s, num_classes=capacity)
     combine1_sec = torch.einsum("se,sc->sec", gates1, locations1_sc)
     combine2_sec = torch.einsum("se,sc->sec", gates2, locations2_sc)
     combine_weights = combine1_sec + combine2_sec
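The final einsums place each token's gate value at its (expert, capacity-slot) coordinate, yielding the s-by-e-by-c combine weights. A shape-only sketch with assumed sizes (in the real code gates1 is already zero outside the chosen expert, and the dispatch mask is derived from combine_weights):

import torch

S, E, C = 6, 2, 3  # tokens, experts, capacity: made-up sizes
gates1 = torch.rand(S, E)  # stand-in; real gates1 is masked to one expert per row
locations1_sc = torch.zeros(S, C)
locations1_sc[torch.arange(S), torch.randint(0, C, (S,))] = 1  # one-hot slot per token
combine1_sec = torch.einsum("se,sc->sec", gates1, locations1_sc)
assert combine1_sec.shape == (S, E, C)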