Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
662667d0
Unverified
Commit
662667d0
authored
Oct 05, 2020
by
msbaines
Committed by
GitHub
Oct 05, 2020
Browse files
[fix] moe: fix Top2Gate to work on GPU (#124)
parent
7815f6f3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
8 deletions
+34
-8
fairscale/nn/moe/top2gate.py
fairscale/nn/moe/top2gate.py
+15
-5
tests/nn/moe/test_top2gating.py
tests/nn/moe/test_top2gating.py
+19
-3
No files found.
fairscale/nn/moe/top2gate.py
View file @
662667d0
...
...
@@ -7,13 +7,23 @@
# Code is inspired by Top2GatingOnLogits from lingvo:
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
from
typing
import
Tuple
from
typing
import
Callable
,
Dict
,
Tuple
import
torch
from
torch
import
Tensor
import
torch.nn.functional
as
F
gumbel
=
torch
.
distributions
.
gumbel
.
Gumbel
(
0
,
1
)
# type: ignore
gumbel_map
:
Dict
[
torch
.
device
,
Callable
]
=
{}
def
gumbel_rsample
(
shape
:
Tuple
,
device
:
torch
.
device
)
->
Tensor
:
gumbel
=
gumbel_map
.
get
(
device
)
if
gumbel
is
None
:
one
=
torch
.
tensor
(
1.0
,
device
=
device
)
zero
=
torch
.
tensor
(
0.0
,
device
=
device
)
gumbel
=
torch
.
distributions
.
gumbel
.
Gumbel
(
zero
,
one
).
rsample
# type: ignore
gumbel_map
[
device
]
=
gumbel
return
gumbel
(
shape
)
def
top2gating
(
logits
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
...
...
@@ -34,7 +44,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise
=
logits
+
gumbel
.
rsample
(
logits
.
shape
)
logits_w_noise
=
logits
+
gumbel
_
rsample
(
logits
.
shape
,
device
=
logits
.
device
)
# Replace top-expert with min value
mins
=
torch
.
full_like
(
logits
,
min_logit
)
logits_except1
=
torch
.
where
(
mask1
.
bool
(),
mins
,
logits_w_noise
)
...
...
@@ -57,8 +67,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
mask2
*=
torch
.
lt
(
locations2
,
capacity
)
# Store the capacity location for each token
locations1_gs
=
torch
.
ein
sum
(
"gse,gse->gs"
,
locations1
,
mask1
)
locations2_gs
=
torch
.
ein
sum
(
"gse,gse->gs"
,
locations2
,
mask2
)
locations1_gs
=
torch
.
sum
(
locations1
*
mask1
,
dim
=
2
)
locations2_gs
=
torch
.
sum
(
locations2
*
mask2
,
dim
=
2
)
# Normalize gate probabilities
mask1_float
=
mask1
.
float
()
...
...
tests/nn/moe/test_top2gating.py
View file @
662667d0
...
...
@@ -9,15 +9,22 @@ import torch
from
fairscale.nn
import
Top2Gate
from
fairscale.nn.moe.top2gate
import
top2gating
skip_if_no_cuda
=
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
def
test_create
():
gate
=
Top2Gate
(
4
,
8
)
def
test_forward
():
@
skip_if_no_cuda
def
test_create_cuda
():
gate
=
Top2Gate
(
4
,
8
).
cuda
()
def
do_test_forward
(
device
):
torch
.
manual_seed
(
3
)
input
=
torch
.
randn
(
3
,
12
,
4
)
gate
=
Top2Gate
(
4
,
6
)
input
=
torch
.
randn
(
3
,
12
,
4
)
.
to
(
device
)
gate
=
Top2Gate
(
4
,
6
)
.
to
(
device
)
capacity
=
2
*
12
//
6
l_aux
,
combine_weights
,
dispatch_mask
=
gate
(
input
)
assert
pytest
.
approx
(
l_aux
.
item
(),
0.0283
)
...
...
@@ -33,6 +40,15 @@ def test_forward():
assert
weights_sum
==
pytest
.
approx
(
36.0
)
def
test_forward_cpu
():
do_test_forward
(
"cpu"
)
@
skip_if_no_cuda
def
test_forward_cuda
():
do_test_forward
(
"cuda"
)
# Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper.
def
test_top1s
():
num_tokens
=
8
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment