OpenDAS / fairscale
Commit c6d9be79 (Unverified)
Authored Oct 20, 2020 by msbaines, committed by GitHub on Oct 20, 2020
[test] moe: add a more thorough MOELayer routing test (#151)
Parent 49a3d9bc

Showing 1 changed file with 38 additions and 0 deletions:
tests/nn/moe/test_moe_layer.py  +38  -0
tests/nn/moe/test_moe_layer.py
@@ -73,6 +73,44 @@ def test_forward(device):
     assert torch.allclose(input, output)


+# Test Gate which round-robin routes tokens to experts
+class RoundRobinGate(torch.nn.Module):
+    def __init__(self, model_dim, num_experts):
+        super().__init__()
+        self.model_dim = model_dim
+        self.num_experts = num_experts
+
+    def forward(self, input):
+        g, s, _ = input.shape
+        assert s % self.num_experts == 0
+        capacity = 2 * s // self.num_experts
+        output = torch.zeros(g, s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
+        for i in range(s):
+            output[:, i, i % self.num_experts, i // self.num_experts] = 1.0
+        return 0.0, output, output.bool()
+
+
+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_forward_routing(device):
+    model_dim = 8
+    num_experts = dist.get_world_size()
+    input = torch.randn(3, 4, 16, model_dim).to(device)
+    gate = RoundRobinGate(model_dim, num_experts)
+    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
+    # Use scaling matrix (each rank has a different scale)
+    scale = dist.get_rank() + 1
+    expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
+    moe = MOELayer(gate, expert).to(device)
+    output = moe(input)
+    assert output.shape == input.shape
+    # Verify that each token was sent to the correct expert by checking its scale.
+    t = input.shape[2]
+    for i in range(t):
+        expert = i % num_experts
+        assert torch.allclose(input[:, :, i] * (expert + 1), output[:, :, i])
+
+
 @pytest.mark.mpi
 @pytest.mark.parametrize("device", ["cpu"])
 def test_backward(device):
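For context on the routing convention the new test exercises: the sketch below (not part of the commit) walks through how a one-hot dispatch mask like the one RoundRobinGate builds sends token i to expert i % num_experts and capacity slot i // num_experts. The einsum-based dispatch step is an assumption modeled on the GShard formulation; the actual MOELayer internals are not shown in this diff.

# Minimal sketch, not part of the commit. The einsum dispatch is an
# assumption (GShard-style); fairscale's MOELayer internals are not shown here.
import torch

g, s, m = 1, 4, 2                # groups, tokens per group, model dim
num_experts = 2
capacity = 2 * s // num_experts  # same capacity rule as RoundRobinGate

input = torch.arange(g * s * m, dtype=torch.float32).reshape(g, s, m)

# One-hot dispatch mask: token i -> expert i % num_experts, slot i // num_experts.
mask = torch.zeros(g, s, num_experts, capacity)
for i in range(s):
    mask[:, i, i % num_experts, i // num_experts] = 1.0

# Gather each token into its (expert, capacity) slot.
dispatched = torch.einsum("gsec,gsm->gecm", mask, input)
print(dispatched.shape)  # torch.Size([1, 2, 4, 2]) -> (group, expert, capacity, model_dim)

# Tokens 0 and 2 land in expert 0; tokens 1 and 3 land in expert 1.
assert torch.equal(dispatched[0, 0, 0], input[0, 0])
assert torch.equal(dispatched[0, 1, 0], input[0, 1])
assert torch.equal(dispatched[0, 0, 1], input[0, 2])
assert torch.equal(dispatched[0, 1, 1], input[0, 3])

Because round-robin routing sends exactly s // num_experts tokens to each expert, the 2x capacity factor used by the gate leaves the upper half of each expert's slots empty, so no token is dropped and the per-rank scale check in test_forward_routing holds for every position.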