OpenDAS / fairscale · Commits

Unverified commit c6d9be79, authored Oct 20, 2020 by msbaines, committed by GitHub on Oct 20, 2020
[test] moe: add a more thorough MOELayer routing test (#151)
Parent: 49a3d9bc

1 changed file with 38 additions and 0 deletions:

tests/nn/moe/test_moe_layer.py (+38, -0)
tests/nn/moe/test_moe_layer.py @ c6d9be79

@@ -73,6 +73,44 @@ def test_forward(device):
     assert torch.allclose(input, output)
 
 
+# Test Gate which round-robin routes tokens to experts
+class RoundRobinGate(torch.nn.Module):
+    def __init__(self, model_dim, num_experts):
+        super().__init__()
+        self.model_dim = model_dim
+        self.num_experts = num_experts
+
+    def forward(self, input):
+        g, s, _ = input.shape
+        assert s % self.num_experts == 0
+        capacity = 2 * s // self.num_experts
+        output = torch.zeros(g, s, self.num_experts, capacity, dtype=input.dtype, device=input.device)
+        for i in range(s):
+            output[:, i, i % self.num_experts, i // self.num_experts] = 1.0
+        return 0.0, output, output.bool()
+
+
+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_forward_routing(device):
+    model_dim = 8
+    num_experts = dist.get_world_size()
+    input = torch.randn(3, 4, 16, model_dim).to(device)
+    gate = RoundRobinGate(model_dim, num_experts)
+    expert = torch.nn.Linear(model_dim, model_dim, bias=False)
+    # Use scaling matrix (each rank has a different scale)
+    scale = dist.get_rank() + 1
+    expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
+    moe = MOELayer(gate, expert).to(device)
+    output = moe(input)
+    assert output.shape == input.shape
+    # Verify that each token was sent to the correct expert by checking its scale.
+    t = input.shape[2]
+    for i in range(t):
+        expert = i % num_experts
+        assert torch.allclose(input[:, :, i] * (expert + 1), output[:, :, i])
+
+
 @pytest.mark.mpi
 @pytest.mark.parametrize("device", ["cpu"])
 def test_backward(device):
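For context on what the gate's third return value encodes, here is a small standalone sketch (an illustration added here, not part of the commit) that builds the same (group, token, expert, capacity) dispatch mask as RoundRobinGate for four tokens and two experts and prints which expert and slot each token lands in; the names and sizes are chosen for the example only.

# Illustration only (not from the commit): the round-robin dispatch mask for a tiny case.
import torch

g, s, num_experts = 1, 4, 2        # one group, four tokens, two experts
capacity = 2 * s // num_experts    # same 2x-capacity rule as the gate above -> 4 slots per expert
mask = torch.zeros(g, s, num_experts, capacity)
for i in range(s):
    mask[:, i, i % num_experts, i // num_experts] = 1.0

# Each nonzero index is (group, token, expert, slot):
# token 0 -> expert 0, slot 0; token 1 -> expert 1, slot 0;
# token 2 -> expert 0, slot 1; token 3 -> expert 1, slot 1.
print(mask.nonzero())

Because each rank's expert in test_forward_routing is an identity matrix times (rank + 1), this round-robin routing is what lets the test recover which expert handled token i simply by checking that output[:, :, i] equals input[:, :, i] * (i % num_experts + 1).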
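Both test_forward_routing and the existing test_backward carry the pytest.mark.mpi marker and query dist.get_world_size() / dist.get_rank(), so they only make sense in a multi-process job where each rank hosts one expert. The launcher is not part of this commit; as an assumption, an invocation along the lines of mpirun -n 2 python -m pytest tests/nn/moe/test_moe_layer.py -m mpi would run the routing test with two experts, one per rank.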