Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
339cf060
Unverified
Commit
339cf060
authored
Oct 23, 2020
by
msbaines
Committed by
GitHub
Oct 23, 2020
Browse files
[feat] moe: add support for multiple experts per device (#161)
parent
95ddbc19
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
10 deletions
+44
-10
fairscale/nn/moe/moe_layer.py
fairscale/nn/moe/moe_layer.py
+17
-7
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_moe_layer.py
+27
-3
No files found.
fairscale/nn/moe/moe_layer.py
View file @
339cf060
...
@@ -3,12 +3,12 @@
...
@@ -3,12 +3,12 @@
# This source code is licensed under the BSD license found in the
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
,
Union
,
cast
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.nn
import
Module
from
torch.nn
import
Module
,
ModuleList
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
Base
=
Module
[
Tensor
]
Base
=
Module
[
Tensor
]
...
@@ -55,13 +55,17 @@ class MOELayer(Base):
...
@@ -55,13 +55,17 @@ class MOELayer(Base):
expert network
expert network
"""
"""
def
__init__
(
self
,
gate
:
Module
,
expert
:
Module
,
group
:
Optional
[
Any
]
=
None
)
->
None
:
def
__init__
(
self
,
gate
:
Module
,
expert
s
:
Union
[
Module
,
ModuleList
]
,
group
:
Optional
[
Any
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate
=
gate
self
.
gate
=
gate
self
.
expert
=
expert
if
type
(
experts
)
==
ModuleList
:
self
.
experts
=
cast
(
ModuleList
,
experts
)
else
:
self
.
experts
=
ModuleList
([
experts
])
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
for
p
in
expert
.
parameters
():
for
expert
in
self
.
experts
:
p
.
expert
=
True
# type: ignore
for
p
in
experts
.
parameters
():
p
.
expert
=
True
# type: ignore
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"gsec,gsm->egcm"
,
dispatch_mask
.
float
(),
input
)
dispatched_input
=
torch
.
einsum
(
"gsec,gsm->egcm"
,
dispatch_mask
.
float
(),
input
)
...
@@ -74,6 +78,7 @@ class MOELayer(Base):
...
@@ -74,6 +78,7 @@ class MOELayer(Base):
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
[
0
].
shape
)
==
4
,
"input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
assert
len
(
input
[
0
].
shape
)
==
4
,
"input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
assert
input
[
0
].
shape
[
0
]
==
len
(
self
.
experts
),
"group dimension size must be equal to number of local experts"
# Implement Algorithm 2 from GShard paper.
# Implement Algorithm 2 from GShard paper.
shape
=
input
[
0
].
shape
shape
=
input
[
0
].
shape
...
@@ -81,6 +86,11 @@ class MOELayer(Base):
...
@@ -81,6 +86,11 @@ class MOELayer(Base):
reshaped_input
=
input
[
0
].
reshape
(
shape
[
0
],
-
1
,
shape
[
3
])
reshaped_input
=
input
[
0
].
reshape
(
shape
[
0
],
-
1
,
shape
[
3
])
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
expert_output
=
self
.
expert
(
dispatched_input
)
assert
dispatched_input
.
shape
[
1
]
==
len
(
self
.
experts
)
chunks
=
dispatched_input
.
chunk
(
len
(
self
.
experts
),
dim
=
1
)
expert_outputs
=
[]
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
expert_outputs
+=
[
expert
(
chunk
)]
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
1
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
return
combined_output
.
reshape
(
shape
)
return
combined_output
.
reshape
(
shape
)
tests/nn/moe/test_moe_layer.py
View file @
339cf060
...
@@ -61,7 +61,7 @@ def test_expert_params(device):
...
@@ -61,7 +61,7 @@ def test_expert_params(device):
def
test_forward
(
device
):
def
test_forward
(
device
):
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
3
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
# Use identity matrix
...
@@ -73,6 +73,30 @@ def test_forward(device):
...
@@ -73,6 +73,30 @@ def test_forward(device):
assert
torch
.
allclose
(
input
,
output
)
assert
torch
.
allclose
(
input
,
output
)
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
])
def
test_forward_multi
(
device
):
torch
.
set_printoptions
(
threshold
=
5000
)
num_local_experts
=
4
model_dim
=
4
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
*
num_local_experts
input
=
torch
.
randn
(
num_local_experts
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
experts
=
[]
for
i
in
range
(
num_local_experts
):
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
expert
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
eye
(
model_dim
))
experts
+=
[
expert
]
moe
=
MOELayer
(
gate
,
torch
.
nn
.
ModuleList
(
experts
)).
to
(
device
)
output
=
moe
(
input
)
assert
output
.
shape
==
input
.
shape
# 90% of the input should have gone to an expert
assert
len
(
output
.
nonzero
(
as_tuple
=
False
))
/
output
.
numel
()
>
0.90
# Except for zeros, re-assembled output should match input due to identity expert.
assert
torch
.
allclose
(
input
,
torch
.
where
(
output
>
0
,
output
,
input
))
# Test Gate which round-robin routes tokens to experts
# Test Gate which round-robin routes tokens to experts
class
RoundRobinGate
(
torch
.
nn
.
Module
):
class
RoundRobinGate
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
model_dim
,
num_experts
):
def
__init__
(
self
,
model_dim
,
num_experts
):
...
@@ -95,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
...
@@ -95,7 +119,7 @@ class RoundRobinGate(torch.nn.Module):
def
test_forward_routing
(
device
):
def
test_forward_routing
(
device
):
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
()
num_experts
=
dist
.
get_world_size
()
input
=
torch
.
randn
(
3
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
RoundRobinGate
(
model_dim
,
num_experts
)
gate
=
RoundRobinGate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use scaling matrix (each rank has a different scale)
# Use scaling matrix (each rank has a different scale)
...
@@ -117,7 +141,7 @@ def test_backward(device):
...
@@ -117,7 +141,7 @@ def test_backward(device):
loss
=
torch
.
nn
.
MSELoss
()
loss
=
torch
.
nn
.
MSELoss
()
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
3
,
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
1
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
# Use identity matrix
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment