Unverified Commit 662667d0 authored by msbaines, committed by GitHub
Browse files

[fix] moe: fix Top2Gate to work on GPU (#124)

parent 7815f6f3
......@@ -7,13 +7,23 @@
# Code is inspired by Top2GatingOnLogits from lingvo:
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
from typing import Tuple
from typing import Callable, Dict, Tuple
import torch
from torch import Tensor
import torch.nn.functional as F
gumbel = torch.distributions.gumbel.Gumbel(0, 1) # type: ignore
# Per-device cache of bound Gumbel(0, 1) samplers, so the distribution's
# loc/scale tensors are created on each device only once.
gumbel_map: Dict[torch.device, Callable] = {}


def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    """Draw a reparameterized Gumbel(0, 1) sample of ``shape`` on ``device``.

    The first call for a given device builds the distribution with its
    loc/scale tensors placed on that device and memoizes the bound
    ``rsample`` callable in ``gumbel_map``; later calls reuse it.
    """
    try:
        sampler = gumbel_map[device]
    except KeyError:
        loc = torch.tensor(0.0, device=device)
        scale = torch.tensor(1.0, device=device)
        sampler = torch.distributions.gumbel.Gumbel(loc, scale).rsample  # type: ignore
        gumbel_map[device] = sampler
    return sampler(shape)
def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
......@@ -34,7 +44,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel.rsample(logits.shape)
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# Replace top-expert with min value
mins = torch.full_like(logits, min_logit)
logits_except1 = torch.where(mask1.bool(), mins, logits_w_noise)
......@@ -57,8 +67,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
mask2 *= torch.lt(locations2, capacity)
# Store the capacity location for each token
locations1_gs = torch.einsum("gse,gse->gs", locations1, mask1)
locations2_gs = torch.einsum("gse,gse->gs", locations2, mask2)
locations1_gs = torch.sum(locations1 * mask1, dim=2)
locations2_gs = torch.sum(locations2 * mask2, dim=2)
# Normalize gate probabilities
mask1_float = mask1.float()
......
......@@ -9,15 +9,22 @@ import torch
from fairscale.nn import Top2Gate
from fairscale.nn.moe.top2gate import top2gating
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_create():
    """Smoke test: constructing a Top2Gate (model dim 4, 8 experts) raises no error."""
    Top2Gate(4, 8)
def test_forward():
@skip_if_no_cuda
def test_create_cuda():
    """Smoke test: a Top2Gate moves to a CUDA device without error (skipped when no GPU)."""
    Top2Gate(4, 8).cuda()
def do_test_forward(device):
torch.manual_seed(3)
input = torch.randn(3, 12, 4)
gate = Top2Gate(4, 6)
input = torch.randn(3, 12, 4).to(device)
gate = Top2Gate(4, 6).to(device)
capacity = 2 * 12 // 6
l_aux, combine_weights, dispatch_mask = gate(input)
assert pytest.approx(l_aux.item(), 0.0283)
......@@ -33,6 +40,15 @@ def test_forward():
assert weights_sum == pytest.approx(36.0)
def test_forward_cpu():
    """Exercise the shared forward-pass check on the CPU."""
    do_test_forward("cpu")
@skip_if_no_cuda
def test_forward_cuda():
    """Exercise the shared forward-pass check on a CUDA device (skipped when no GPU)."""
    do_test_forward("cuda")
# Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper.
def test_top1s():
num_tokens = 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment