Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
6d802f5a
Unverified
Commit
6d802f5a
authored
Oct 13, 2020
by
msbaines
Committed by
GitHub
Oct 13, 2020
Browse files
[feat] moe: add all_to_all support (#134)
parent
177151e0
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
69 additions
and
32 deletions
+69
-32
.circleci/config.yml
.circleci/config.yml
+8
-0
fairscale/nn/moe/moelayer.py
fairscale/nn/moe/moelayer.py
+18
-9
requirements-test.txt
requirements-test.txt
+2
-0
stubs/torch/distributed/__init__.pyi
stubs/torch/distributed/__init__.pyi
+6
-1
tests/nn/moe/test_moelayer.py
tests/nn/moe/test_moelayer.py
+31
-22
tests/optim/test_oss.py
tests/optim/test_oss.py
+4
-0
No files found.
.circleci/config.yml
View file @
6d802f5a
...
...
@@ -42,6 +42,7 @@ install_dep_15: &install_dep_15
-
run
:
name
:
Install Dependencies
command
:
|
sudo apt-get install -y mpi-default-dev
pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
...
...
@@ -51,6 +52,7 @@ install_dep_16: &install_dep_16
-
run
:
name
:
Install Dependencies
command
:
|
sudo apt-get install -y mpi-default-dev
pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
...
...
@@ -84,6 +86,12 @@ run_unittests: &run_unittests
command
:
|
pytest --junitxml=test-results/junit.xml --verbose
run_mpi_unittests
:
&run_mpi_unittests
-
run
:
name
:
Run MPI Unit Tests
command
:
|
mpirun -n4 python -m pytest -only-mpi --junitxml=test-results/junit.xml --verbose
run_flake8
:
&run_flake8
-
run
:
name
:
Run Linter (flake8)
...
...
fairscale/nn/moe/moelayer.py
View file @
6d802f5a
...
...
@@ -3,10 +3,11 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
from
torch
import
Tensor
import
torch.distributed
as
dist
from
torch.nn
import
Module
if
TYPE_CHECKING
:
...
...
@@ -24,7 +25,8 @@ class MOELayer(Base):
gate = Top2Gate(model_dim, num_experts)
moe = MOELayer(gate, expert)
l_aux, combine_weights, dispatch_mask = moe(input)
output = moe(input)
l_aux = moe.l_aux
.. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
...
...
@@ -35,24 +37,31 @@ class MOELayer(Base):
expert network
"""
def
__init__
(
self
,
gate
:
Module
,
expert
:
Module
)
->
None
:
def
__init__
(
self
,
gate
:
Module
,
expert
:
Module
,
group
:
Optional
[
Any
]
=
None
)
->
None
:
super
().
__init__
()
self
.
gate
=
gate
self
.
expert
=
expert
self
.
group
=
group
if
group
is
not
None
else
dist
.
group
.
WORLD
self
.
world_size
=
dist
.
get_world_size
(
self
.
group
)
def
all_to_all_dispatch
(
self
,
dispatch_mask
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
dispatched_input
=
torch
.
einsum
(
"gsec,gsm->egcm"
,
dispatch_mask
.
float
(),
input
)
# TODO(msb) all-to-all
dispatched_input
=
torch
.
squeeze
(
dispatched_input
,
0
)
# drop E dimension
dispatched_input
=
dispatched_input
.
contiguous
()
chunks
=
list
(
dispatched_input
.
chunk
(
self
.
world_size
))
dist
.
all_to_all
(
chunks
,
chunks
,
self
.
group
)
return
dispatched_input
def
all_to_all_combine
(
self
,
combine_weights
:
Tensor
,
input
:
Tensor
)
->
Tensor
:
# TODO(msb) all-to-all
expert_output
=
torch
.
unsqueeze
(
input
,
1
)
# add E dimension
output
=
torch
.
einsum
(
"gsec,gecm->gsm"
,
combine_weights
,
expert_output
)
expert_output
=
input
.
contiguous
()
chunks
=
list
(
expert_output
.
chunk
(
self
.
world_size
))
dist
.
all_to_all
(
chunks
,
chunks
,
self
.
group
)
output
=
torch
.
einsum
(
"gsec,egcm->gsm"
,
combine_weights
,
expert_output
)
return
output
def
forward
(
self
,
*
input
:
Any
,
**
kwargs
:
Any
)
->
Tensor
:
def
forward
(
self
,
*
input
:
Tensor
,
**
kwargs
:
Any
)
->
Tensor
:
assert
len
(
input
)
==
1
,
"only single input Tensor supported"
assert
len
(
input
[
0
].
shape
)
==
4
,
"input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
# Implement Algorithm 2 from GShard paper.
shape
=
input
[
0
].
shape
# Reshape into S tokens per group.
...
...
requirements-test.txt
View file @
6d802f5a
black == 19.10b0
flake8 == 3.7.9
isort == 4.3.21
mpi4py == 3.0.3
mypy == 0.770
pytest == 5.4.1
pytest-cov == 2.10.0
pytest-mpi == 0.4
torchtext == 0.6.0
torch >= 1.5.1
torchvision >= 0.6.0
...
...
stubs/torch/distributed/__init__.pyi
View file @
6d802f5a
...
...
@@ -6,7 +6,10 @@ import datetime
from . import rpc as rpc
class Backend: ...
class Backend:
GLOO: str
MPI: str
NCCL: str
class ProcessGroup:
def size(self) -> int: ...
...
...
@@ -29,8 +32,10 @@ def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
def is_initialized() -> bool: ...
def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
def all_to_all(output: List[Tensor], intput: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
...
...
tests/nn/moe/test_moelayer.py
View file @
6d802f5a
...
...
@@ -3,34 +3,53 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import
os
import
pytest
import
torch
import
torch.distributed
as
dist
from
fairscale.nn
import
MOELayer
,
Top2Gate
skip_if_no_cuda
=
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
BACKEND
=
dist
.
Backend
.
NCCL
if
torch
.
cuda
.
is_available
()
else
dist
.
Backend
.
GLOO
# type: ignore
def
test_create
():
model_dim
=
8
num_experts
=
4
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
)
moe
=
MOELayer
(
gate
,
expert
)
if
torch
.
cuda
.
is_available
():
devices
=
[
"cpu"
,
"cuda"
]
else
:
devices
=
[
"cpu"
]
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
if
"OMPI_COMM_WORLD_SIZE"
in
os
.
environ
:
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
MPI
)
def
setup_module
(
module
):
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
)
@
skip_if_no_cuda
def
test_create_cuda
():
def
teardown_module
(
module
):
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
torch
.
distributed
.
destroy_process_group
()
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_create
(
device
):
model_dim
=
8
num_experts
=
4
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
)
moe
=
MOELayer
(
gate
,
expert
).
cuda
(
)
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
def
do_test_forward
(
device
):
@
pytest
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
])
def
test_forward
(
device
):
model_dim
=
8
num_experts
=
1
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
3
,
4
,
16
,
model_dim
).
to
(
device
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
expert
=
torch
.
nn
.
Linear
(
model_dim
,
model_dim
,
bias
=
False
)
...
...
@@ -38,16 +57,6 @@ def do_test_forward(device):
expert
.
weight
=
torch
.
nn
.
Parameter
(
torch
.
eye
(
model_dim
))
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
output
=
moe
(
input
)
assert
moe
.
l_aux
.
item
()
==
1.0
assert
output
.
shape
==
input
.
shape
# Re-assembled output should match input due to identity expert.
assert
torch
.
equal
(
input
,
output
)
def
test_forward_cpu
():
do_test_forward
(
"cpu"
)
@
skip_if_no_cuda
def
test_forward_cuda
():
do_test_forward
(
"cuda"
)
assert
torch
.
allclose
(
input
,
output
)
tests/optim/test_oss.py
View file @
6d802f5a
...
...
@@ -29,6 +29,10 @@ def setup_module(module):
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
)
def
teardown_module
(
module
):
torch
.
distributed
.
destroy_process_group
()
def
dist_init
(
rank
,
world_size
):
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment