Unverified Commit 29d81c43 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[perf] nn.moe: replace einsum with faster equivalent code (#667)

Co-authored-by: @myleott
parent a9156260
......@@ -77,10 +77,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
locations2_s = torch.sum(locations2 * mask2, dim=1)
# Normalize gate probabilities
mask1_float = mask1.float()
mask2_float = mask2.float()
gates1_s = torch.einsum("se,se->s", gates, mask1_float)
gates2_s = torch.einsum("se,se->s", gates, mask2_float)
gates1_s = (gates * mask1).sum(dim=1) # einsum("se,se->s")
gates2_s = (gates * mask2).sum(dim=1) # einsum("se,se->s")
denom_s = gates1_s + gates2_s
# Avoid divide-by-zero
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
......@@ -88,12 +86,12 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
gates2_s /= denom_s
# Calculate combine_weights and dispatch_mask
gates1 = torch.einsum("s,se->se", gates1_s, mask1_float)
gates2 = torch.einsum("s,se->se", gates2_s, mask2_float)
gates1 = gates1_s.unsqueeze(-1) * mask1 # einsum("s,se->se")
gates2 = gates2_s.unsqueeze(-1) * mask2 # einsum("s,se->se")
locations1_sc = one_hot(locations1_s, num_classes=capacity)
locations2_sc = one_hot(locations2_s, num_classes=capacity)
combine1_sec = torch.einsum("se,sc->sec", gates1, locations1_sc)
combine2_sec = torch.einsum("se,sc->sec", gates2, locations2_sc)
combine1_sec = gates1.unsqueeze(2) * locations1_sc.unsqueeze(1) # einsum("se,sc->sec")
combine2_sec = gates2.unsqueeze(2) * locations2_sc.unsqueeze(1) # einsum("se,sc->sec")
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment