GitLab · OpenDAS / fairscale · Commits

Commit 317c0945 (unverified)
Authored Nov 11, 2020 by msbaines; committed via GitHub on Nov 11, 2020
[fix] moe: fix bug for multiple experts per-gpu case (#184)
Parent: 89176e34

Showing 2 changed files, with 36 additions and 6 deletions (+36 -6):

  fairscale/nn/moe/moe_layer.py   +11 -6
  tests/nn/moe/test_moe_layer.py  +25 -0
fairscale/nn/moe/moe_layer.py (+11 -6)

```diff
@@ -24,7 +24,6 @@ class _AllToAll(torch.autograd.Function):
     @staticmethod
     def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
         ctx.group = group
-        world_size = dist.get_world_size(group)
         input = input.contiguous()
         output = torch.empty_like(input)
         dist.all_to_all_single(output, input, group=group)
@@ -64,6 +63,8 @@ class MOELayer(Base):
         for expert in self.experts:
             for p in experts.parameters():
                 p.expert = True  # type: ignore
+        self.world_size = dist.get_world_size(self.group)
+        self.num_local_experts = len(self.experts)
 
     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
         dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.float(), input)
@@ -79,15 +80,19 @@ class MOELayer(Base):
         assert input[0].shape[0] % len(self.experts) == 0, "num tokens must be order of number of local experts"
         # Implement Algorithm 2 from GShard paper.
-        shape = input[0].shape
+        d_model = input[0].shape[2]
         # Reshape into S tokens by dropping sequence dimension.
-        reshaped_input = input[0].reshape(-1, shape[2])
+        reshaped_input = input[0].reshape(-1, d_model)
         self.l_aux, combine_weights, dispatching_mask = self.gate(reshaped_input)
         dispatched_input = self.all_to_all_dispatch(dispatching_mask, reshaped_input)
-        chunks = dispatched_input.chunk(len(self.experts), dim=0)
+        # Re-shape after all-to-all: ecm -> gecm
+        dispatched_input = dispatched_input.reshape(self.world_size, self.num_local_experts, -1, d_model)
+        chunks = dispatched_input.chunk(self.num_local_experts, dim=1)
         expert_outputs = []
         for chunk, expert in zip(chunks, self.experts):
             expert_outputs += [expert(chunk)]
-        expert_output = torch.cat(expert_outputs, dim=0)
+        expert_output = torch.cat(expert_outputs, dim=1)
+        # Re-shape back: gecm -> ecm
+        expert_output = expert_output.reshape(self.world_size * self.num_local_experts, -1, d_model)
         combined_output = self.all_to_all_combine(combine_weights, expert_output)
-        return combined_output.reshape(shape)
+        return combined_output.reshape(input[0].shape)
```
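The bug, as the diff shows: after the dispatch all-to-all, dim 0 of the (e, c, m) tensor is laid out as world_size groups of num_local_experts rows, so chunking it directly into len(self.experts) contiguous slabs only routes correctly when each rank hosts a single expert. The fix reshapes to (g, e, c, m) and chunks along the local-expert dimension instead, reversing the reshape before the combine-side all-to-all. The standalone sketch below (not fairscale code; all shape values are made up for illustration) shows the indexing difference on dummy tensors:

```python
import torch

# Made-up sizes for illustration only.
world_size, num_local_experts, capacity, d_model = 2, 2, 3, 4
num_experts = world_size * num_local_experts

# After the dispatch all-to-all, row g * num_local_experts + e of the
# (e, c, m) tensor holds tokens that source rank g sent to local expert e.
# Label every element of each row with its target local-expert id.
dispatched = torch.empty(num_experts, capacity, d_model)
for g in range(world_size):
    for e in range(num_local_experts):
        dispatched[g * num_local_experts + e] = e

# Old (buggy) routing: contiguous slabs along dim 0 mix target experts.
old_chunks = dispatched.chunk(num_local_experts, dim=0)
print([c.unique().tolist() for c in old_chunks])   # [[0.0, 1.0], [0.0, 1.0]]

# Fixed routing: ecm -> gecm, then chunk along the local-expert dim,
# so chunk e contains exactly the rows destined for local expert e.
gecm = dispatched.reshape(world_size, num_local_experts, -1, d_model)
new_chunks = gecm.chunk(num_local_experts, dim=1)
print([c.unique().tolist() for c in new_chunks])   # [[0.0], [1.0]]
```

Concatenating the per-expert outputs back along dim 1 and reshaping (g, e, c, m) to (g·e, c, m) then restores the layout the combine-side all-to-all expects, which is exactly what the two new reshape lines in forward() do.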
tests/nn/moe/test_moe_layer.py (+25 -0)

```diff
@@ -135,6 +135,31 @@ def test_forward_routing(device):
         assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
 
 
+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_forward_routing_multi(device):
+    model_dim = 8
+    num_local_experts = 4
+    num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
+    input = torch.randn(4 * num_local_experts, 16, model_dim).to(device)
+    gate = RoundRobinGate(model_dim, num_experts)
+    experts = []
+    for i in range(num_local_experts):
+        expert = torch.nn.Linear(model_dim, model_dim, bias=False)
+        # Use scaling matrix (each rank has a different scale)
+        scale = dist.get_rank() * num_local_experts + i + 1
+        expert.weight = torch.nn.Parameter(torch.eye(model_dim) * scale)
+        experts += [expert]
+    moe = MOELayer(gate, torch.nn.ModuleList(experts)).to(device)
+    output = moe(input)
+    assert output.shape == input.shape
+    # Verify that each token was sent to the correct expert by checking its scale.
+    t = input.shape[1]
+    for i in range(t):
+        expert = i % num_experts
+        assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
+
+
 @pytest.mark.mpi
 @pytest.mark.parametrize("device", ["cpu"])
 def test_backward(device):
```
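The new test_forward_routing_multi mirrors test_forward_routing but with four experts per rank, each scaling tokens by a distinct rank-dependent factor, so a misrouted token would come back with the wrong scale. It relies on the suite's RoundRobinGate helper, which is not part of this diff. As a rough sketch only: a gate satisfying the contract visible in moe_layer.py (gate(input) returns (l_aux, combine_weights, dispatch_mask), with dispatch_mask shaped [s, e, c] to match einsum("sec,sm->ecm", ...)) could look like the following. The class name, capacity rule, and zero auxiliary loss are assumptions, not fairscale's actual helper.

```python
import torch


class RoundRobinGateSketch(torch.nn.Module):
    """Hypothetical round-robin gate: token i goes to expert i % num_experts."""

    def __init__(self, model_dim: int, num_experts: int) -> None:
        super().__init__()
        self.num_experts = num_experts  # model_dim kept for signature parity; unused

    def forward(self, input: torch.Tensor):
        s = input.shape[0]
        # Assumed capacity rule: token count divides evenly across experts.
        capacity = s // self.num_experts
        # One-hot [s, e, c] mask: token i fills slot i // num_experts of
        # expert i % num_experts.
        dispatch_mask = torch.zeros(s, self.num_experts, capacity, dtype=torch.bool)
        for i in range(s):
            dispatch_mask[i, i % self.num_experts, i // self.num_experts] = True
        combine_weights = dispatch_mask.float()  # route each token straight back
        l_aux = input.new_zeros(())  # fixed gate: no load-balancing loss
        return l_aux, combine_weights, dispatch_mask
```

Since both tests carry @pytest.mark.mpi, they are meant to run with several ranks under an MPI launcher, e.g. something along the lines of `mpirun -n 2 python -m pytest tests/nn/moe/test_moe_layer.py`; the exact invocation depends on the project's test setup.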