Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
d99c445a
Unverified
Commit
d99c445a
authored
Oct 16, 2020
by
msbaines
Committed by
GitHub
Oct 16, 2020
Browse files
[feat] moe: add all_to_all backward support (#137)
parent
1e6c547a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
43 additions
and
13 deletions
+43
-13
.circleci/config.yml
.circleci/config.yml
+2
-2
fairscale/nn/moe/moelayer.py
fairscale/nn/moe/moelayer.py
+22
-11
tests/nn/moe/test_moelayer.py
tests/nn/moe/test_moelayer.py
+19
-0
No files found.
.circleci/config.yml
View file @
d99c445a
...
@@ -42,7 +42,7 @@ install_dep_15: &install_dep_15
...
@@ -42,7 +42,7 @@ install_dep_15: &install_dep_15
-
run
:
-
run
:
name
:
Install Dependencies
name
:
Install Dependencies
command
:
|
command
:
|
sudo apt-get install -y
mpi-default
-dev
sudo apt-get install -y
libopenmpi
-dev
pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-test.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; print("Torch version:", torch.__version__)'
...
@@ -52,7 +52,7 @@ install_dep_16: &install_dep_16
...
@@ -52,7 +52,7 @@ install_dep_16: &install_dep_16
-
run
:
-
run
:
name
:
Install Dependencies
name
:
Install Dependencies
command
:
|
command
:
|
sudo apt-get install -y
mpi-default
-dev
sudo apt-get install -y
libopenmpi
-dev
pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-test.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; print("Torch version:", torch.__version__)'
...
...
fairscale/nn/moe/moelayer.py
View file @
d99c445a
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
# This source code is licensed under the BSD license found in the
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
...
@@ -19,6 +19,24 @@ else:
...
@@ -19,6 +19,24 @@ else:
# See https://arxiv.org/pdf/2006.16668.pdf for details.
# See https://arxiv.org/pdf/2006.16668.pdf for details.
# Based on https://github.com/pytorch/pytorch/pull/40762
class
_AllToAll
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
:
Any
,
group
:
dist
.
ProcessGroup
,
input
:
Tensor
)
->
Tensor
:
# type: ignore
ctx
.
group
=
group
world_size
=
dist
.
get_world_size
(
group
)
input
=
input
.
contiguous
()
output
=
torch
.
empty_like
(
input
)
input_chunks
=
list
(
input
.
chunk
(
world_size
))
output_chunks
=
list
(
output
.
chunk
(
world_size
))
dist
.
all_to_all
(
output_chunks
,
input_chunks
,
group
=
group
)
return
output
@
staticmethod
def
backward
(
ctx
:
Any
,
*
grad_output
:
Tensor
)
->
Tuple
[
None
,
Tensor
]:
return
(
None
,
_AllToAll
.
apply
(
ctx
.
group
,
*
grad_output
))
class
MOELayer
(
Base
):
class
MOELayer
(
Base
):
"""MOELayer module which implements MixtureOfExperts as described in Gshard_.
"""MOELayer module which implements MixtureOfExperts as described in Gshard_.
::
::
...
@@ -42,21 +60,14 @@ class MOELayer(Base):
...
@@ -42,21 +60,14 @@ class MOELayer(Base):
self
.
gate
=
gate
self
.
gate
=
gate
self
.
expert
=
expert
self
.
expert
=
expert
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
self
.
world_size
=
dist
.
get_world_size
(
self
.
group
)
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"gsec,gsm->egcm"
,
dispatch_mask
.
float
(),
input
)
dispatched_input
=
torch
.
einsum
(
"gsec,gsm->egcm"
,
dispatch_mask
.
float
(),
input
)
dispatched_input
=
dispatched_input
.
contiguous
()
return
_AllToAll
.
apply
(
self
.
group
,
dispatched_input
)
chunks
=
list
(
dispatched_input
.
chunk
(
self
.
world_size
))
dist
.
all_to_all
(
chunks
,
chunks
,
self
.
group
)
return
dispatched_input
def
all_to_all_combine
(
self
,
combine_weights
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
def
all_to_all_combine
(
self
,
combine_weights
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
expert_output
=
input
.
contiguous
()
expert_output
=
_AllToAll
.
apply
(
self
.
group
,
input
)
chunks
=
list
(
expert_output
.
chunk
(
self
.
world_size
))
return
torch
.
einsum
(
"gsec,egcm->gsm"
,
combine_weights
,
expert_output
)
dist
.
all_to_all
(
chunks
,
chunks
,
self
.
group
)
output
=
torch
.
einsum
(
"gsec,egcm->gsm"
,
combine_weights
,
expert_output
)
return
output
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
...
...
tests/nn/moe/test_moelayer.py
View file @
d99c445a
...
@@ -60,3 +60,22 @@ def test_forward(device):
...
@@ -60,3 +60,22 @@ def test_forward(device):
assert
output
.
shape
==
input
.
shape
assert
output
.
shape
==
input
.
shape
# Re-assembled output should match input due to identity expert.
# Re-assembled output should match input due to identity expert.
assert
torch
.
allclose
(
input
,
output
)
assert
torch
.
allclose
(
input
,
output
)
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
])
def
test_backward
(
device
):
loss
=
torch
.
nn
.
MSELoss
()
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
3
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
# Use identity matrix
expert
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
eye
(
model_dim
))
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
output
=
moe
(
input
)
assert
output
.
shape
==
input
.
shape
output
=
loss
(
output
,
input
)
output
.
backward
()
assert
torch
.
allclose
(
expert
.
weight
.
grad
,
torch
.
zeros_like
(
expert
.
weight
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment