Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
339cf060
Unverified
Commit
339cf060
authored
Oct 23, 2020
by
msbaines
Committed by
GitHub
Oct 23, 2020
Browse files
[feat] moe: add support for multiple experts per device (#161)
parent
95ddbc19
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
10 deletions
+44
-10
fairscale/nn/moe/moe_layer.py
fairscale/nn/moe/moe_layer.py
+17
-7
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_moe_layer.py
+27
-3
No files found.
fairscale/nn/moe/moe_layer.py
View file @
339cf060
...
...
@@ -3,12 +3,12 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
,
Union
,
cast
import
torch
from
torch
import
Tensor
import
torch.distributed
as
dist
from
torch.nn
import
Module
from
torch.nn
import
Module
,
ModuleList
if
TYPE_CHECKING
:
Base
=
Module
[
Tensor
]
...
...
@@ -55,13 +55,17 @@ class MOELayer(Base):
expert network
"""
def __init__(self, gate: Module, experts: Union[Module, ModuleList], group: Optional[Any] = None) -> None:
    """Construct a Mixture-of-Experts layer.

    Args:
        gate: gating network; called on the reshaped input and expected to
            return ``(l_aux, combine_weights, dispatch_mask)``.
        experts: a single expert ``Module`` or a ``ModuleList`` of local
            experts (one or more experts hosted on this device).
        group: process group used for the all-to-all exchange; defaults to
            ``dist.group.WORLD`` when ``None``.
    """
    super().__init__()
    self.gate = gate
    # Normalize to a ModuleList so a single expert and multiple local
    # experts are handled uniformly downstream (len(self.experts) is used
    # to chunk the dispatched input in forward()).
    if isinstance(experts, ModuleList):
        self.experts = cast(ModuleList, experts)
    else:
        self.experts = ModuleList([experts])
    self.group = group if group is not None else dist.group.WORLD
    for expert in self.experts:
        # Fix: was `experts.parameters()`, which re-tagged the parameters of
        # *every* local expert once per iteration; tag each expert's own
        # parameters exactly once instead. The `expert` attribute lets
        # optimizer/sharding code identify expert parameters.
        for p in expert.parameters():
            p.expert = True  # type: ignore
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"gsec,gsm->egcm"
,
dispatch_mask
.
float
(),
input
)
...
...
@@ -74,6 +78,7 @@ class MOELayer(Base):
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
[
0
].
shape
)
==
4
,
"input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
assert
input
[
0
].
shape
[
0
]
==
len
(
self
.
experts
),
"group dimension size must be equal to number of local experts"
# Implement Algorithm 2 from GShard paper.
shape
=
input
[
0
].
shape
...
...
@@ -81,6 +86,11 @@ class MOELayer(Base):
reshaped_input
=
input
[
0
].
reshape
(
shape
[
0
],
-
1
,
shape
[
3
])
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
expert_output
=
self
.
expert
(
dispatched_input
)
assert
dispatched_input
.
shape
[
1
]
==
len
(
self
.
experts
)
chunks
=
dispatched_input
.
chunk
(
len
(
self
.
experts
),
dim
=
1
)
expert_outputs
=
[]
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
expert_outputs
+=
[
expert
(
chunk
)]
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
1
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
return
combined_output
.
reshape
(
shape
)
tests/nn/moe/test_moe_layer.py
View file @
339cf060
...
...
@@ -61,7 +61,7 @@ def test_expert_params(device):
def
test_forward
(
device
):
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
3
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
...
...
@@ -73,6 +73,30 @@ def test_forward(device):
assert
torch
.
allclose
(
input
,
output
)
@pytest.mark.mpi
@pytest.mark.parametrize("device", ["cpu"])
def test_forward_multi(device):
    """Forward pass with multiple local experts per device.

    Uses identity-matrix experts so any token that was routed to an expert
    comes back unchanged, which makes the output directly comparable to the
    input.
    """
    # Fix: dropped the `torch.set_printoptions(threshold=5000)` debug
    # leftover — it mutated global torch state from inside a test.
    num_local_experts = 4
    model_dim = 4
    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
    input = torch.randn(num_local_experts, 4, 16, model_dim).to(device)
    gate = Top2Gate(model_dim, num_experts)
    experts = []
    for _ in range(num_local_experts):
        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
        # Use identity matrix
        expert.weight = torch.nn.Parameter(torch.eye(model_dim))
        experts.append(expert)
    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
    output = moe(input)
    assert output.shape == input.shape
    # 90% of the input should have gone to an expert
    assert len(output.nonzero(as_tuple=False)) / output.numel() > 0.90
    # Except for zeros, re-assembled output should match input due to identity expert.
    assert torch.allclose(input, torch.where(output > 0, output, input))
# Test Gate which round-robin routes tokens to experts
class
RoundRobinGate
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
model_dim
,
num_experts
):
...
...
@@ -95,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
def
test_forward_routing
(
device
):
model_dim
=
8
num_experts
=
dist
.
get_world_size
()
input
=
torch
.
randn
(
3
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
RoundRobinGate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use scaling matrix (each rank has a different scale)
...
...
@@ -117,7 +141,7 @@ def test_backward(device):
loss
=
torch
.
nn
.
MSELoss
()
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
3
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment