Unverified Commit 662667d0 authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[fix] moe: fix Top2Gate to work on GPU (#124)

parent 7815f6f3
...@@ -7,13 +7,23 @@ ...@@ -7,13 +7,23 @@
# Code is inspired by Top2GatingOnLogits from lingvo: # Code is inspired by Top2GatingOnLogits from lingvo:
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477 # https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
from typing import Tuple from typing import Callable, Dict, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
import torch.nn.functional as F import torch.nn.functional as F
gumbel = torch.distributions.gumbel.Gumbel(0, 1) # type: ignore gumbel_map: Dict[torch.device, Callable] = {}
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    """Draw a reparameterized Gumbel(0, 1) sample of ``shape`` on ``device``.

    The bound ``rsample`` callable is built lazily per device and memoized in
    the module-level ``gumbel_map``, so the distribution's parameter tensors
    are allocated on each device only once rather than on every call.
    """
    sampler = gumbel_map.get(device)
    if sampler is None:
        # Parameters must live on the target device so sampling happens there.
        loc = torch.tensor(0.0, device=device)
        scale = torch.tensor(1.0, device=device)
        sampler = torch.distributions.gumbel.Gumbel(loc, scale).rsample  # type: ignore
        gumbel_map[device] = sampler
    return sampler(shape)
def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]: def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
...@@ -34,7 +44,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]: ...@@ -34,7 +44,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
# Create a mask for 2nd's expert per token using Gumbel-max trick # Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/ # https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise = logits + gumbel.rsample(logits.shape) logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
# Replace top-expert with min value # Replace top-expert with min value
mins = torch.full_like(logits, min_logit) mins = torch.full_like(logits, min_logit)
logits_except1 = torch.where(mask1.bool(), mins, logits_w_noise) logits_except1 = torch.where(mask1.bool(), mins, logits_w_noise)
...@@ -57,8 +67,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]: ...@@ -57,8 +67,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
mask2 *= torch.lt(locations2, capacity) mask2 *= torch.lt(locations2, capacity)
# Store the capacity location for each token # Store the capacity location for each token
locations1_gs = torch.einsum("gse,gse->gs", locations1, mask1) locations1_gs = torch.sum(locations1 * mask1, dim=2)
locations2_gs = torch.einsum("gse,gse->gs", locations2, mask2) locations2_gs = torch.sum(locations2 * mask2, dim=2)
# Normalize gate probabilities # Normalize gate probabilities
mask1_float = mask1.float() mask1_float = mask1.float()
......
...@@ -9,15 +9,22 @@ import torch ...@@ -9,15 +9,22 @@ import torch
from fairscale.nn import Top2Gate from fairscale.nn import Top2Gate
from fairscale.nn.moe.top2gate import top2gating from fairscale.nn.moe.top2gate import top2gating
# Mark for tests that require a GPU: skipped when torch reports no CUDA device.
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def test_create(): def test_create():
gate = Top2Gate(4, 8) gate = Top2Gate(4, 8)
def test_forward(): @skip_if_no_cuda
def test_create_cuda():
    """Smoke test: constructing Top2Gate and moving it to a CUDA device must not raise."""
    gate = Top2Gate(4, 8).cuda()
def do_test_forward(device):
torch.manual_seed(3) torch.manual_seed(3)
input = torch.randn(3, 12, 4) input = torch.randn(3, 12, 4).to(device)
gate = Top2Gate(4, 6) gate = Top2Gate(4, 6).to(device)
capacity = 2 * 12 // 6 capacity = 2 * 12 // 6
l_aux, combine_weights, dispatch_mask = gate(input) l_aux, combine_weights, dispatch_mask = gate(input)
assert pytest.approx(l_aux.item(), 0.0283) assert pytest.approx(l_aux.item(), 0.0283)
...@@ -33,6 +40,15 @@ def test_forward(): ...@@ -33,6 +40,15 @@ def test_forward():
assert weights_sum == pytest.approx(36.0) assert weights_sum == pytest.approx(36.0)
def test_forward_cpu():
    """Run the shared forward-pass checks on the CPU."""
    do_test_forward("cpu")
@skip_if_no_cuda
def test_forward_cuda():
    """Run the shared forward-pass checks on a CUDA device (skipped without one)."""
    do_test_forward("cuda")
# Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper. # Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper.
def test_top1s(): def test_top1s():
num_tokens = 8 num_tokens = 8
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment