Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
5a3df0da
Unverified
Commit
5a3df0da
authored
Apr 02, 2021
by
msbaines
Committed by
GitHub
Apr 02, 2021
Browse files
[test] modify MOE tests to use NCCL (#570)
NCCL all_to_all is now supported in PyTorch (since v1.8.0) Fixes: #548
parent
60694da1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
72 deletions
+48
-72
.circleci/config.yml
.circleci/config.yml
+6
-34
requirements-test.txt
requirements-test.txt
+0
-2
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_moe_layer.py
+42
-36
No files found.
.circleci/config.yml
View file @
5a3df0da
...
@@ -66,17 +66,12 @@ install_dep_160: &install_dep_160
...
@@ -66,17 +66,12 @@ install_dep_160: &install_dep_160
-
run
:
-
run
:
name
:
Install Dependencies with torch 1.6.0
name
:
Install Dependencies with torch 1.6.0
command
:
|
command
:
|
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.6 && exit 0; fi
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.6 && exit 0; fi
# start installing
# start installing
pip install --progress-bar off torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "6"], "wrong torch version"'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "6"], "wrong torch version"'
python -m torch.utils.collect_env
python -m torch.utils.collect_env
...
@@ -86,17 +81,12 @@ install_dep_171: &install_dep_171
...
@@ -86,17 +81,12 @@ install_dep_171: &install_dep_171
-
run
:
-
run
:
name
:
Install Dependencies with torch 1.7.1
name
:
Install Dependencies with torch 1.7.1
command
:
|
command
:
|
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.7 && exit 0; fi
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.7 && exit 0; fi
# start installing
# start installing
pip install --progress-bar off torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "7"], "wrong torch version"'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "7"], "wrong torch version"'
python -m torch.utils.collect_env
python -m torch.utils.collect_env
...
@@ -106,10 +96,6 @@ install_dep_181: &install_dep_181
...
@@ -106,10 +96,6 @@ install_dep_181: &install_dep_181
-
run
:
-
run
:
name
:
Install Dependencies with torch 1.8.1
name
:
Install Dependencies with torch 1.8.1
command
:
|
command
:
|
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing
# start installing
...
@@ -125,10 +111,6 @@ install_dep_190: &install_dep_190
...
@@ -125,10 +111,6 @@ install_dep_190: &install_dep_190
-
run
:
-
run
:
name
:
Install Dependencies with torch 1.9.0
name
:
Install Dependencies with torch 1.9.0
command
:
|
command
:
|
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing
# start installing
...
@@ -184,13 +166,6 @@ upload_coverage: &upload_coverage
...
@@ -184,13 +166,6 @@ upload_coverage: &upload_coverage
file
:
'
coverage.xml'
file
:
'
coverage.xml'
token
:
$CODECOV_TOKEN
token
:
$CODECOV_TOKEN
run_mpi_unittests
:
&run_mpi_unittests
-
run
:
name
:
Run MPI Unit Tests
command
:
|
mpirun -n 4 python -m pytest -p torch_pg.pytest --only-mpi --junitxml=test-results/junit.xml --verbose tests/nn/moe
run_pipe_benchmark
:
&run_pipe_benchmark
run_pipe_benchmark
:
&run_pipe_benchmark
-
run
:
-
run
:
name
:
Run Pipe Benchmark
name
:
Run Pipe Benchmark
...
@@ -276,14 +251,14 @@ jobs:
...
@@ -276,14 +251,14 @@ jobs:
# Cache the venv directory that contains dependencies
# Cache the venv directory that contains dependencies
-
restore_cache
:
-
restore_cache
:
keys
:
keys
:
-
cache-key-cpu-py37-18
0-
1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
cache-key-cpu-py37-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_dep_1
7
1
-
<<
:
*install_dep_1
8
1
-
save_cache
:
-
save_cache
:
paths
:
paths
:
-
~/venv
-
~/venv
key
:
cache-key-cpu-py37-18
0-
1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key
:
cache-key-cpu-py37-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_repo
-
<<
:
*install_repo
...
@@ -292,7 +267,6 @@ jobs:
...
@@ -292,7 +267,6 @@ jobs:
-
<<
:
*run_mypy
-
<<
:
*run_mypy
-
<<
:
*run_flake8
-
<<
:
*run_flake8
-
<<
:
*run_unittests
-
<<
:
*run_unittests
-
<<
:
*run_mpi_unittests
-
<<
:
*run_doc_build
-
<<
:
*run_doc_build
-
store_test_results
:
-
store_test_results
:
...
@@ -311,13 +285,13 @@ jobs:
...
@@ -311,13 +285,13 @@ jobs:
# Cache the venv directory that contains dependencies
# Cache the venv directory that contains dependencies
-
restore_cache
:
-
restore_cache
:
keys
:
keys
:
-
cache-key-cpu-py38-18
0-
1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
cache-key-cpu-py38-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_dep_1
7
1
-
<<
:
*install_dep_1
8
1
-
save_cache
:
-
save_cache
:
paths
:
paths
:
-
~/venv
-
~/venv
key
:
cache-key-cpu-py38-18
0-
1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key
:
cache-key-cpu-py38-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_repo
-
<<
:
*install_repo
...
@@ -326,7 +300,6 @@ jobs:
...
@@ -326,7 +300,6 @@ jobs:
-
<<
:
*run_mypy
-
<<
:
*run_mypy
-
<<
:
*run_flake8
-
<<
:
*run_flake8
-
<<
:
*run_unittests
-
<<
:
*run_unittests
-
<<
:
*run_mpi_unittests
-
<<
:
*run_doc_build
-
<<
:
*run_doc_build
-
store_test_results
:
-
store_test_results
:
...
@@ -361,7 +334,6 @@ jobs:
...
@@ -361,7 +334,6 @@ jobs:
-
<<
:
*run_mypy
-
<<
:
*run_mypy
-
<<
:
*run_flake8
-
<<
:
*run_flake8
-
<<
:
*run_unittests
-
<<
:
*run_unittests
# TODO(msb) - <<: *run_mpi_unittests
-
<<
:
*run_doc_build
-
<<
:
*run_doc_build
-
store_test_results
:
-
store_test_results
:
...
...
requirements-test.txt
View file @
5a3df0da
...
@@ -10,8 +10,6 @@ mypy == 0.790
...
@@ -10,8 +10,6 @@ mypy == 0.790
# Tools for unit tests & coverage.
# Tools for unit tests & coverage.
pytest == 5.4.1
pytest == 5.4.1
pytest-cov == 2.10.0
pytest-cov == 2.10.0
pytest-mpi == 0.4
pytest-timeout == 1.4.2
pytest-timeout == 1.4.2
mpi4py == 3.0.3
remote-pdb >= 2.1.0
remote-pdb >= 2.1.0
parameterized >= 0.8.1
parameterized >= 0.8.1
tests/nn/moe/test_moe_layer.py
View file @
5a3df0da
...
@@ -3,44 +3,49 @@
...
@@ -3,44 +3,49 @@
# This source code is licensed under the BSD license found in the
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
import
o
s
import
functool
s
import
tempfile
import
tempfile
import
pytest
import
pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
fairscale.nn
import
MOELayer
,
Top2Gate
from
fairscale.nn
import
MOELayer
,
Top2Gate
from
fairscale.utils.testing
import
torch_version
skip_if_no_cuda
=
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
pytestmark
=
pytest
.
mark
.
skipif
(
not
(
torch
.
cuda
.
is_available
()
and
torch_version
()
>=
(
1
,
8
,
0
)),
reason
=
"cuda and torch>=1.8.0 required"
)
BACKEND
=
dist
.
Backend
.
NCCL
if
torch
.
cuda
.
is_available
()
else
dist
.
Backend
.
GLOO
# type: ignore
devices
=
[
"cuda"
]
if
torch
.
cuda
.
is_available
():
devices
=
[
"cpu"
,
"cuda"
]
else
:
devices
=
[
"cpu"
]
URL
=
"file://"
+
tempfile
.
mkstemp
()[
1
]
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
def
pg_worker
(
rank
,
world_size
,
init_file
,
func
,
*
args
):
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
# torch 1.5 compatibility
init_url
=
"file://"
+
init_file
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
NCCL
,
rank
=
rank
,
world_size
=
world_size
,
init_method
=
init_url
)
torch
.
cuda
.
set_device
(
rank
)
dist
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
func
(
*
args
)
dist
.
destroy_process_group
()
if
"OMPI_COMM_WORLD_SIZE"
in
os
.
environ
:
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
MPI
,
init_method
=
URL
)
def
pg_test
(
world_size
=
torch
.
cuda
.
device_count
()):
def
decorator
(
func
):
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
tempfile_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
pg_worker
,
args
=
(
world_size
,
tempfile_name
,
func
,
*
kwargs
.
values
()),
nprocs
=
world_size
)
def
setup_module
(
module
):
globals
()[
"test_"
+
func
.
__name__
]
=
wrapper
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
return
func
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
,
init_method
=
URL
)
return
decorator
def
teardown_module
(
module
):
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
torch
.
distributed
.
destroy_process_group
()
@
pg_test
(
world_size
=
1
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
create
(
device
):
def
create
(
device
):
model_dim
=
8
model_dim
=
8
num_experts
=
4
num_experts
=
4
gate
=
Top2Gate
(
model_dim
,
num_experts
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
...
@@ -48,8 +53,9 @@ def test_create(device):
...
@@ -48,8 +53,9 @@ def test_create(device):
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
@
pg_test
(
world_size
=
1
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
expert_params
(
device
):
def
expert_params
(
device
):
model_dim
=
8
model_dim
=
8
num_experts
=
4
num_experts
=
4
gate
=
Top2Gate
(
model_dim
,
num_experts
)
gate
=
Top2Gate
(
model_dim
,
num_experts
)
...
@@ -59,9 +65,9 @@ def test_expert_params(device):
...
@@ -59,9 +65,9 @@ def test_expert_params(device):
assert
p
.
expert
is
True
assert
p
.
expert
is
True
@
p
y
test
.
mark
.
mpi
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
forward
(
device
):
def
forward
(
device
):
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
...
@@ -76,9 +82,9 @@ def test_forward(device):
...
@@ -76,9 +82,9 @@ def test_forward(device):
assert
torch
.
allclose
(
input
,
output
)
assert
torch
.
allclose
(
input
,
output
)
@
p
y
test
.
mark
.
mpi
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
forward_multi
(
device
):
def
forward_multi
(
device
):
torch
.
set_printoptions
(
threshold
=
5000
)
torch
.
set_printoptions
(
threshold
=
5000
)
num_local_experts
=
4
num_local_experts
=
4
model_dim
=
4
model_dim
=
4
...
@@ -117,9 +123,9 @@ class RoundRobinGate(torch.nn.Module):
...
@@ -117,9 +123,9 @@ class RoundRobinGate(torch.nn.Module):
return
0.0
,
output
,
output
.
bool
()
return
0.0
,
output
,
output
.
bool
()
@
p
y
test
.
mark
.
mpi
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
forward_routing
(
device
):
def
forward_routing
(
device
):
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
()
num_experts
=
dist
.
get_world_size
()
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
...
@@ -138,9 +144,9 @@ def test_forward_routing(device):
...
@@ -138,9 +144,9 @@ def test_forward_routing(device):
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
@
p
y
test
.
mark
.
mpi
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
forward_routing_multi
(
device
):
def
forward_routing_multi
(
device
):
model_dim
=
8
model_dim
=
8
num_local_experts
=
4
num_local_experts
=
4
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
*
num_local_experts
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
*
num_local_experts
...
@@ -163,9 +169,9 @@ def test_forward_routing_multi(device):
...
@@ -163,9 +169,9 @@ def test_forward_routing_multi(device):
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
@
p
y
test
.
mark
.
mpi
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
backward
(
device
):
def
backward
(
device
):
loss
=
torch
.
nn
.
MSELoss
()
loss
=
torch
.
nn
.
MSELoss
()
model_dim
=
8
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment