Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
317c0945
Unverified
Commit
317c0945
authored
Nov 11, 2020
by
msbaines
Committed by
GitHub
Nov 11, 2020
Browse files
[fix] moe: fix bug for multiple experts per-gpu case (#184)
parent
89176e34
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
6 deletions
+36
-6
fairscale/nn/moe/moe_layer.py
fairscale/nn/moe/moe_layer.py
+11
-6
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_moe_layer.py
+25
-0
No files found.
fairscale/nn/moe/moe_layer.py
View file @
317c0945
...
@@ -24,7 +24,6 @@ class _AllToAll(torch.autograd.Function):
...
@@ -24,7 +24,6 @@ class _AllToAll(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
:
Any
,
group
:
dist
.
ProcessGroup
,
input
:
Tensor
)
->
Tensor
:
# type: ignore
def
forward
(
ctx
:
Any
,
group
:
dist
.
ProcessGroup
,
input
:
Tensor
)
->
Tensor
:
# type: ignore
ctx
.
group
=
group
ctx
.
group
=
group
world_size
=
dist
.
get_world_size
(
group
)
input
=
input
.
contiguous
()
input
=
input
.
contiguous
()
output
=
torch
.
empty_like
(
input
)
output
=
torch
.
empty_like
(
input
)
dist
.
all_to_all_single
(
output
,
input
,
group
=
group
)
dist
.
all_to_all_single
(
output
,
input
,
group
=
group
)
...
@@ -64,6 +63,8 @@ class MOELayer(Base):
...
@@ -64,6 +63,8 @@ class MOELayer(Base):
for
expert
in
self
.
experts
:
for
expert
in
self
.
experts
:
for
p
in
experts
.
parameters
():
for
p
in
experts
.
parameters
():
p
.
expert
=
True
# type: ignore
p
.
expert
=
True
# type: ignore
self
.
world_size
=
dist
.
get_world_size
(
self
.
group
)
self
.
num_local_experts
=
len
(
self
.
experts
)
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"sec,sm->ecm"
,
dispatch_mask
.
float
(),
input
)
dispatched_input
=
torch
.
einsum
(
"sec,sm->ecm"
,
dispatch_mask
.
float
(),
input
)
...
@@ -79,15 +80,19 @@ class MOELayer(Base):
...
@@ -79,15 +80,19 @@ class MOELayer(Base):
assert
input
[
0
].
shape
[
0
]
%
len
(
self
.
experts
)
==
0
,
"num tokens must be order of number of local experts"
assert
input
[
0
].
shape
[
0
]
%
len
(
self
.
experts
)
==
0
,
"num tokens must be order of number of local experts"
# Implement Algorithm 2 from GShard paper.
# Implement Algorithm 2 from GShard paper.
shape
=
input
[
0
].
shape
d_model
=
input
[
0
].
shape
[
2
]
# Reshape into S tokens by dropping sequence dimension.
# Reshape into S tokens by dropping sequence dimension.
reshaped_input
=
input
[
0
].
reshape
(
-
1
,
shape
[
2
]
)
reshaped_input
=
input
[
0
].
reshape
(
-
1
,
d_model
)
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
self
.
l_aux
,
combine_weights
,
dispatching_mask
=
self
.
gate
(
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
dispatched_input
=
self
.
all_to_all_dispatch
(
dispatching_mask
,
reshaped_input
)
chunks
=
dispatched_input
.
chunk
(
len
(
self
.
experts
),
dim
=
0
)
# Re-shape after all-to-all: ecm -> gecm
dispatched_input
=
dispatched_input
.
reshape
(
self
.
world_size
,
self
.
num_local_experts
,
-
1
,
d_model
)
chunks
=
dispatched_input
.
chunk
(
self
.
num_local_experts
,
dim
=
1
)
expert_outputs
=
[]
expert_outputs
=
[]
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
for
chunk
,
expert
in
zip
(
chunks
,
self
.
experts
):
expert_outputs
+=
[
expert
(
chunk
)]
expert_outputs
+=
[
expert
(
chunk
)]
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
0
)
expert_output
=
torch
.
cat
(
expert_outputs
,
dim
=
1
)
# Re-shape back: gecm -> ecm
expert_output
=
expert_output
.
reshape
(
self
.
world_size
*
self
.
num_local_experts
,
-
1
,
d_model
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
combined_output
=
self
.
all_to_all_combine
(
combine_weights
,
expert_output
)
return
combined_output
.
reshape
(
shape
)
return
combined_output
.
reshape
(
input
[
0
].
shape
)
tests/nn/moe/test_moe_layer.py
View file @
317c0945
...
@@ -135,6 +135,31 @@ def test_forward_routing(device):
...
@@ -135,6 +135,31 @@ def test_forward_routing(device):
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
])
def
test_forward_routing_multi
(
device
):
model_dim
=
8
num_local_experts
=
4
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
*
num_local_experts
input
=
torch
.
randn
(
4
*
num_local_experts
,
16
,
model_dim
).
to
(
device
)
gate
=
RoundRobinGate
(
model_dim
,
num_experts
)
experts
=
[]
for
i
in
range
(
num_local_experts
):
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use scaling matrix (each rank has a different scale)
scale
=
dist
.
get_rank
()
*
num_local_experts
+
i
+
1
expert
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
eye
(
model_dim
)
*
scale
)
experts
+=
[
expert
]
moe
=
MOELayer
(
gate
,
torch
.
nn
.
ModuleList
(
experts
)).
to
(
device
)
output
=
moe
(
input
)
assert
output
.
shape
==
input
.
shape
# Verify that each token was sent to the correct expert by checking its scale.
t
=
input
.
shape
[
1
]
for
i
in
range
(
t
):
expert
=
i
%
num_experts
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
])
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
])
def
test_backward
(
device
):
def
test_backward
(
device
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment