Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
662667d0
Unverified
Commit
662667d0
authored
Oct 05, 2020
by
msbaines
Committed by
GitHub
Oct 05, 2020
Browse files
[fix] moe: fix Top2Gate to work on GPU (#124)
parent
7815f6f3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
8 deletions
+34
-8
fairscale/nn/moe/top2gate.py
fairscale/nn/moe/top2gate.py
+15
-5
tests/nn/moe/test_top2gating.py
tests/nn/moe/test_top2gating.py
+19
-3
No files found.
fairscale/nn/moe/top2gate.py
View file @
662667d0
...
@@ -7,13 +7,23 @@
...
@@ -7,13 +7,23 @@
# Code is inspired by Top2GatingOnLogits from lingvo:
# Code is inspired by Top2GatingOnLogits from lingvo:
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
# https://github.com/tensorflow/lingvo/blob/21b8106c5f1d30a196c98eedc441d4fd70833b11/lingvo/core/moe_layers.py#L477
from
typing
import
Tuple
from
typing
import
Callable
,
Dict
,
Tuple
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
gumbel
=
torch
.
distributions
.
gumbel
.
Gumbel
(
0
,
1
)
# type: ignore
# Per-device cache of bound Gumbel(0, 1) rsample callables. Building the
# distribution once per device avoids re-allocating its loc/scale tensors
# on every call (and keeps samples on the right device).
gumbel_map: Dict[torch.device, Callable] = {}


def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    """Draw a reparameterized Gumbel(0, 1) sample of ``shape`` on ``device``.

    The sampler for each device is constructed lazily on first use and
    memoized in ``gumbel_map``.
    """
    try:
        sampler = gumbel_map[device]
    except KeyError:
        zero = torch.tensor(0.0, device=device)
        one = torch.tensor(1.0, device=device)
        sampler = torch.distributions.gumbel.Gumbel(zero, one).rsample  # type: ignore
        gumbel_map[device] = sampler
    return sampler(shape)
def
top2gating
(
logits
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
def
top2gating
(
logits
:
torch
.
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
...
@@ -34,7 +44,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
...
@@ -34,7 +44,7 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
# Create a mask for 2nd's expert per token using Gumbel-max trick
# Create a mask for 2nd's expert per token using Gumbel-max trick
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
# https://timvieira.github.io/blog/post/2014/07/31/gumbel-max-trick/
logits_w_noise
=
logits
+
gumbel
.
rsample
(
logits
.
shape
)
logits_w_noise
=
logits
+
gumbel
_
rsample
(
logits
.
shape
,
device
=
logits
.
device
)
# Replace top-expert with min value
# Replace top-expert with min value
mins
=
torch
.
full_like
(
logits
,
min_logit
)
mins
=
torch
.
full_like
(
logits
,
min_logit
)
logits_except1
=
torch
.
where
(
mask1
.
bool
(),
mins
,
logits_w_noise
)
logits_except1
=
torch
.
where
(
mask1
.
bool
(),
mins
,
logits_w_noise
)
...
@@ -57,8 +67,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
...
@@ -57,8 +67,8 @@ def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
mask2
*=
torch
.
lt
(
locations2
,
capacity
)
mask2
*=
torch
.
lt
(
locations2
,
capacity
)
# Store the capacity location for each token
# Store the capacity location for each token
locations1_gs
=
torch
.
ein
sum
(
"gse,gse->gs"
,
locations1
,
mask1
)
locations1_gs
=
torch
.
sum
(
locations1
*
mask1
,
dim
=
2
)
locations2_gs
=
torch
.
ein
sum
(
"gse,gse->gs"
,
locations2
,
mask2
)
locations2_gs
=
torch
.
sum
(
locations2
*
mask2
,
dim
=
2
)
# Normalize gate probabilities
# Normalize gate probabilities
mask1_float
=
mask1
.
float
()
mask1_float
=
mask1
.
float
()
...
...
tests/nn/moe/test_top2gating.py
View file @
662667d0
...
@@ -9,15 +9,22 @@ import torch
...
@@ -9,15 +9,22 @@ import torch
from
fairscale.nn
import
Top2Gate
from
fairscale.nn
import
Top2Gate
from
fairscale.nn.moe.top2gate
import
top2gating
from
fairscale.nn.moe.top2gate
import
top2gating
# Marker for tests that require a GPU: such tests are skipped when CUDA
# is not available on the host running the suite.
_cuda_available = torch.cuda.is_available()
skip_if_no_cuda = pytest.mark.skipif(not _cuda_available, reason="cuda required")
def test_create():
    """Smoke test: constructing a Top2Gate on the CPU must not raise."""
    Top2Gate(4, 8)
@skip_if_no_cuda
def test_create_cuda():
    """Smoke test: constructing a Top2Gate and moving it to the GPU must not raise."""
    Top2Gate(4, 8).cuda()
def
do_test_forward
(
device
):
torch
.
manual_seed
(
3
)
torch
.
manual_seed
(
3
)
input
=
torch
.
randn
(
3
,
12
,
4
)
input
=
torch
.
randn
(
3
,
12
,
4
)
.
to
(
device
)
gate
=
Top2Gate
(
4
,
6
)
gate
=
Top2Gate
(
4
,
6
)
.
to
(
device
)
capacity
=
2
*
12
//
6
capacity
=
2
*
12
//
6
l_aux
,
combine_weights
,
dispatch_mask
=
gate
(
input
)
l_aux
,
combine_weights
,
dispatch_mask
=
gate
(
input
)
assert
pytest
.
approx
(
l_aux
.
item
(),
0.0283
)
assert
pytest
.
approx
(
l_aux
.
item
(),
0.0283
)
...
@@ -33,6 +40,15 @@ def test_forward():
...
@@ -33,6 +40,15 @@ def test_forward():
assert
weights_sum
==
pytest
.
approx
(
36.0
)
assert
weights_sum
==
pytest
.
approx
(
36.0
)
def test_forward_cpu():
    """Run the shared forward-pass checks on the CPU."""
    do_test_forward(device="cpu")
@skip_if_no_cuda
def test_forward_cuda():
    """Run the shared forward-pass checks on the GPU (skipped without CUDA)."""
    do_test_forward(device="cuda")
# Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper.
# Verify that top gate is allocated capacity as per Algorithm 1 in GShard paper.
def
test_top1s
():
def
test_top1s
():
num_tokens
=
8
num_tokens
=
8
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment