OpenDAS / fairscale · Commits

Unverified commit 6d802f5a, authored Oct 13, 2020 by msbaines, committed by GitHub on Oct 13, 2020.

[feat] moe: add all_to_all support (#134)

Parent: 177151e0

Showing 6 changed files, with 69 additions and 32 deletions (+69, −32):
.circleci/config.yml                   +8   −0
fairscale/nn/moe/moelayer.py           +18  −9
requirements-test.txt                  +2   −0
stubs/torch/distributed/__init__.pyi   +6   −1
tests/nn/moe/test_moelayer.py          +31  −22
tests/optim/test_oss.py                +4   −0
.circleci/config.yml

@@ -42,6 +42,7 @@ install_dep_15: &install_dep_15
   - run:
       name: Install Dependencies
       command: |
+        sudo apt-get install -y mpi-default-dev
         pip install --progress-bar off torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'

@@ -51,6 +52,7 @@ install_dep_16: &install_dep_16
   - run:
       name: Install Dependencies
       command: |
+        sudo apt-get install -y mpi-default-dev
         pip install --progress-bar off torch==1.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
         pip install --progress-bar off -r requirements-test.txt
         python -c 'import torch; print("Torch version:", torch.__version__)'

@@ -84,6 +86,12 @@ run_unittests: &run_unittests
       command: |
         pytest --junitxml=test-results/junit.xml --verbose

+run_mpi_unittests: &run_mpi_unittests
+  - run:
+      name: Run MPI Unit Tests
+      command: |
+        mpirun -n4 python -m pytest --only-mpi --junitxml=test-results/junit.xml --verbose
+
 run_flake8: &run_flake8
   - run:
       name: Run Linter (flake8)
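Note for readers: the new run_mpi_unittests step relies on the pytest-mpi plugin (added to requirements-test.txt below), which skips tests marked pytest.mark.mpi in a normal run and selects only those tests when pytest is invoked with --only-mpi under mpirun. A minimal sketch of such a test, assuming a process group has been initialized in the module's setup as the test file below does; the test body is illustrative, not part of this commit:

import pytest
import torch.distributed as dist

@pytest.mark.mpi  # collected only when pytest runs with --only-mpi (pytest-mpi)
def test_ranks_agree_on_world_size():
    # Under `mpirun -n4 python -m pytest --only-mpi`, every MPI rank runs this test.
    assert dist.is_initialized()
    assert dist.get_world_size() >= 1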
fairscale/nn/moe/moelayer.py

@@ -3,10 +3,11 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 from torch import Tensor
+import torch.distributed as dist
 from torch.nn import Module

 if TYPE_CHECKING:

@@ -24,7 +25,8 @@ class MOELayer(Base):
         gate = Top2Gate(model_dim, num_experts)
         moe = MOELayer(gate, expert)
-        l_aux, combine_weights, dispatch_mask = moe(input)
+        output = moe(input)
+        l_aux = moe.l_aux

     .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf

@@ -35,24 +37,31 @@ class MOELayer(Base):
         expert network
     """

-    def __init__(self, gate: Module, expert: Module) -> None:
+    def __init__(self, gate: Module, expert: Module, group: Optional[Any] = None) -> None:
         super().__init__()
         self.gate = gate
         self.expert = expert
+        self.group = group if group is not None else dist.group.WORLD
+        self.world_size = dist.get_world_size(self.group)

     def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
         dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
-        # TODO(msb) all-to-all
-        dispatched_input = dispatched_input.contiguous()
+        dispatched_input = torch.squeeze(dispatched_input, 0)  # drop E dimension
+        chunks = list(dispatched_input.chunk(self.world_size))
+        dist.all_to_all(chunks, chunks, self.group)
         return dispatched_input

     def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
-        # TODO(msb) all-to-all
-        expert_output = input.contiguous()
-        output = torch.einsum("gsec,gecm->gsm", combine_weights, expert_output)
+        expert_output = torch.unsqueeze(input, 1)  # add E dimension
+        chunks = list(expert_output.chunk(self.world_size))
+        dist.all_to_all(chunks, chunks, self.group)
+        output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
         return output

-    def forward(self, *input: Any, **kwargs: Any) -> Tensor:
+    def forward(self, *input: Tensor, **kwargs: Any) -> Tensor:
+        assert len(input) == 1, "only single input Tensor supported"
+        assert len(input[0].shape) == 4, "input Tensor must have dimensions: (g)roup, (s)equence, (t)oken, (m)odel"
+
         # Implement Algorithm 2 from GShard paper.
         shape = input[0].shape
         # Reshape into S tokens per group.
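For readers following the new dispatch/combine path: after the squeeze, dispatched_input is laid out as (experts, capacity, model_dim), chunk() splits it along dim 0 into world_size views, and dist.all_to_all sends chunk i to rank i while receiving one chunk from every rank into the same storage. A rough standalone sketch of that idiom, not part of the commit; it assumes an initialized process group and an expert count divisible by the world size:

import torch
import torch.distributed as dist

def exchange_expert_shards(x: torch.Tensor, group=None) -> torch.Tensor:
    # x: (num_experts, capacity, model_dim); shard i is addressed to rank i.
    world_size = dist.get_world_size(group)
    x = x.contiguous()
    # chunk() returns views into x's storage, so passing the same list as
    # both the output and the input of all_to_all overwrites x in place
    # with the shards received from the other ranks, as the layer above does.
    chunks = list(x.chunk(world_size))
    dist.all_to_all(chunks, chunks, group)
    return x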
requirements-test.txt

 black == 19.10b0
 flake8 == 3.7.9
 isort == 4.3.21
+mpi4py == 3.0.3
 mypy == 0.770
 pytest == 5.4.1
 pytest-cov == 2.10.0
+pytest-mpi == 0.4
 torchtext == 0.6.0
 torch >= 1.5.1
 torchvision >= 0.6.0
stubs/torch/distributed/__init__.pyi

@@ -6,7 +6,10 @@ import datetime
 from . import rpc as rpc

-class Backend: ...
+class Backend:
+    GLOO: str
+    MPI: str
+    NCCL: str

 class ProcessGroup:
     def size(self) -> int: ...

@@ -29,8 +32,10 @@ def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
 def is_initialized() -> bool: ...
 def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
 def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
+def all_to_all(output: List[Tensor], input: List[Tensor], group: Optional[ProcessGroup] = None, async_op: bool = False): ...
 def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group: Optional[ProcessGroup] = None, async_op: bool = False): ...
 def all_gather(tensor_list: List[Tensor], tensor: Tensor, group: Optional[ProcessGroup] = None, async_op: bool = False): ...
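The all_to_all stub added here matches the signature of torch.distributed.all_to_all: each rank supplies a list of input tensors, one addressed to every rank, and a list of output buffers, one received from every rank. A minimal usage sketch, assuming an initialized two-rank process group on a backend that supports all-to-all (e.g. MPI):

import torch
import torch.distributed as dist

rank = dist.get_rank()
world_size = dist.get_world_size()  # assumed to be 2 in this sketch
inputs = [torch.tensor([rank * 10 + i]) for i in range(world_size)]   # what this rank sends
outputs = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]  # receive buffers
dist.all_to_all(outputs, inputs)
# outputs[i] now holds the tensor that rank i addressed to this rank:
# on rank 0: [0], [10]; on rank 1: [1], [11]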
tests/nn/moe/test_moelayer.py

@@ -3,34 +3,53 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

+import os
+
 import pytest
 import torch
+import torch.distributed as dist

 from fairscale.nn import MOELayer, Top2Gate

 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")

+BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
+
 if torch.cuda.is_available():
     devices = ["cpu", "cuda"]
 else:
     devices = ["cpu"]

+os.environ["MASTER_ADDR"] = "localhost"
+os.environ["MASTER_PORT"] = "29501"
+if "OMPI_COMM_WORLD_SIZE" in os.environ:
+    dist.init_process_group(backend=dist.Backend.MPI)
+
+def setup_module(module):
+    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
+        dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
+
+def teardown_module(module):
+    if "OMPI_COMM_WORLD_SIZE" not in os.environ:
+        torch.distributed.destroy_process_group()
+
-def test_create():
-    model_dim = 8
-    num_experts = 4
-    gate = Top2Gate(model_dim, num_experts)
-    expert = torch.nn.Linear(model_dim, model_dim)
-    moe = MOELayer(gate, expert)
-
-@skip_if_no_cuda
-def test_create_cuda():
+@pytest.mark.parametrize("device", devices)
+def test_create(device):
     model_dim = 8
     num_experts = 4
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim)
-    moe = MOELayer(gate, expert).cuda()
+    moe = MOELayer(gate, expert).to(device)

-def do_test_forward(device):
+@pytest.mark.mpi
+@pytest.mark.parametrize("device", ["cpu"])
+def test_forward(device):
     model_dim = 8
-    num_experts = 1
+    num_experts = dist.get_world_size(dist.group.WORLD)
     input = torch.randn(3, 4, 16, model_dim).to(device)
     gate = Top2Gate(model_dim, num_experts)
     expert = torch.nn.Linear(model_dim, model_dim, bias=False)

@@ -38,16 +57,6 @@ def do_test_forward(device):
     expert.weight = torch.nn.Parameter(torch.eye(model_dim))
     moe = MOELayer(gate, expert).to(device)
     output = moe(input)
-    assert moe.l_aux.item() == 1.0
     assert output.shape == input.shape
     # Re-assembled output should match input due to identity expert.
-    assert torch.equal(input, output)
-
-def test_forward_cpu():
-    do_test_forward("cpu")
-
-@skip_if_no_cuda
-def test_forward_cuda():
-    do_test_forward("cuda")
+    assert torch.allclose(input, output)
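With this change, test_forward sizes the gate so there is one expert per rank, which is the layout the new all-to-all path in moelayer.py assumes. A hedged sketch of the same pattern outside the test harness, assuming a process group is already initialized; the call pattern follows the updated MOELayer docstring:

import torch
import torch.distributed as dist
from fairscale.nn import MOELayer, Top2Gate

model_dim = 8
num_experts = dist.get_world_size(dist.group.WORLD)  # one expert per rank
gate = Top2Gate(model_dim, num_experts)
expert = torch.nn.Linear(model_dim, model_dim)
moe = MOELayer(gate, expert)  # group defaults to dist.group.WORLD
x = torch.randn(3, 4, 16, model_dim)  # (g)roup, (s)equence, (t)oken, (m)odel dims
output = moe(x)
l_aux = moe.l_aux  # auxiliary load-balancing loss, now exposed as an attribute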
tests/optim/test_oss.py

@@ -29,6 +29,10 @@ def setup_module(module):
     dist.init_process_group(backend=BACKEND, rank=0, world_size=1)

+def teardown_module(module):
+    torch.distributed.destroy_process_group()
+
 def dist_init(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29501"