Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
5a3df0da
Unverified
Commit
5a3df0da
authored
Apr 02, 2021
by
msbaines
Committed by
GitHub
Apr 02, 2021
Browse files
[test] modify MOE tests to use NCCL (#570)
NCCL all_to_all is now supported in PyTorch (since v1.8.0) Fixes: #548
parent
60694da1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
72 deletions
+48
-72
.circleci/config.yml
.circleci/config.yml
+6
-34
requirements-test.txt
requirements-test.txt
+0
-2
tests/nn/moe/test_moe_layer.py
tests/nn/moe/test_moe_layer.py
+42
-36
No files found.
.circleci/config.yml
View file @
5a3df0da
...
...
@@ -66,17 +66,12 @@ install_dep_160: &install_dep_160
-
run
:
name
:
Install Dependencies with torch 1.6.0
command
:
|
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.6 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "6"], "wrong torch version"'
python -m torch.utils.collect_env
...
...
@@ -86,17 +81,12 @@ install_dep_171: &install_dep_171
-
run
:
name
:
Install Dependencies with torch 1.7.1
command
:
|
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.7 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "7"], "wrong torch version"'
python -m torch.utils.collect_env
...
...
@@ -106,10 +96,6 @@ install_dep_181: &install_dep_181
-
run
:
name
:
Install Dependencies with torch 1.8.1
command
:
|
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing
...
...
@@ -125,10 +111,6 @@ install_dep_190: &install_dep_190
-
run
:
name
:
Install Dependencies with torch 1.9.0
command
:
|
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing
...
...
@@ -184,13 +166,6 @@ upload_coverage: &upload_coverage
file
:
'
coverage.xml'
token
:
$CODECOV_TOKEN
run_mpi_unittests
:
&run_mpi_unittests
-
run
:
name
:
Run MPI Unit Tests
command
:
|
mpirun -n 4 python -m pytest -p torch_pg.pytest --only-mpi --junitxml=test-results/junit.xml --verbose tests/nn/moe
run_pipe_benchmark
:
&run_pipe_benchmark
-
run
:
name
:
Run Pipe Benchmark
...
...
@@ -276,14 +251,14 @@ jobs:
# Cache the venv directory that contains dependencies
-
restore_cache
:
keys
:
-
cache-key-cpu-py37-18
0-
1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
cache-key-cpu-py37-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_dep_1
7
1
-
<<
:
*install_dep_1
8
1
-
save_cache
:
paths
:
-
~/venv
key
:
cache-key-cpu-py37-18
0-
1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key
:
cache-key-cpu-py37-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_repo
...
...
@@ -292,7 +267,6 @@ jobs:
-
<<
:
*run_mypy
-
<<
:
*run_flake8
-
<<
:
*run_unittests
-
<<
:
*run_mpi_unittests
-
<<
:
*run_doc_build
-
store_test_results
:
...
...
@@ -311,13 +285,13 @@ jobs:
# Cache the venv directory that contains dependencies
-
restore_cache
:
keys
:
-
cache-key-cpu-py38-18
0-
1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_dep_1
7
1
-
cache-key-cpu-py38-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_dep_1
8
1
-
save_cache
:
paths
:
-
~/venv
key
:
cache-key-cpu-py38-18
0-
1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
key
:
cache-key-cpu-py38-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
-
<<
:
*install_repo
...
...
@@ -326,7 +300,6 @@ jobs:
-
<<
:
*run_mypy
-
<<
:
*run_flake8
-
<<
:
*run_unittests
-
<<
:
*run_mpi_unittests
-
<<
:
*run_doc_build
-
store_test_results
:
...
...
@@ -361,7 +334,6 @@ jobs:
-
<<
:
*run_mypy
-
<<
:
*run_flake8
-
<<
:
*run_unittests
# TODO(msb) - <<: *run_mpi_unittests
-
<<
:
*run_doc_build
-
store_test_results
:
...
...
requirements-test.txt
View file @
5a3df0da
...
...
@@ -10,8 +10,6 @@ mypy == 0.790
# Tools for unit tests & coverage.
pytest == 5.4.1
pytest-cov == 2.10.0
pytest-mpi == 0.4
pytest-timeout == 1.4.2
mpi4py == 3.0.3
remote-pdb >= 2.1.0
parameterized >= 0.8.1
tests/nn/moe/test_moe_layer.py
View file @
5a3df0da
...
...
@@ -3,44 +3,49 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import
o
s
import
functool
s
import
tempfile
import
pytest
import
torch
import
torch.distributed
as
dist
import
torch.multiprocessing
as
mp
from
fairscale.nn
import
MOELayer
,
Top2Gate
from
fairscale.utils.testing
import
torch_version
skip_if_no_cuda
=
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"cuda required"
)
pytestmark
=
pytest
.
mark
.
skipif
(
not
(
torch
.
cuda
.
is_available
()
and
torch_version
()
>=
(
1
,
8
,
0
)),
reason
=
"cuda and torch>=1.8.0 required"
)
BACKEND
=
dist
.
Backend
.
NCCL
if
torch
.
cuda
.
is_available
()
else
dist
.
Backend
.
GLOO
# type: ignore
devices
=
[
"cuda"
]
if
torch
.
cuda
.
is_available
():
devices
=
[
"cpu"
,
"cuda"
]
else
:
devices
=
[
"cpu"
]
URL
=
"file://"
+
tempfile
.
mkstemp
()[
1
]
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
# torch 1.5 compatibility
def
pg_worker
(
rank
,
world_size
,
init_file
,
func
,
*
args
):
init_url
=
"file://"
+
init_file
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
NCCL
,
rank
=
rank
,
world_size
=
world_size
,
init_method
=
init_url
)
torch
.
cuda
.
set_device
(
rank
)
dist
.
all_reduce
(
torch
.
zeros
(
1
).
cuda
())
func
(
*
args
)
dist
.
destroy_process_group
()
if
"OMPI_COMM_WORLD_SIZE"
in
os
.
environ
:
dist
.
init_process_group
(
backend
=
dist
.
Backend
.
MPI
,
init_method
=
URL
)
def
pg_test
(
world_size
=
torch
.
cuda
.
device_count
()):
def
decorator
(
func
):
@
functools
.
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
tempfile_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
pg_worker
,
args
=
(
world_size
,
tempfile_name
,
func
,
*
kwargs
.
values
()),
nprocs
=
world_size
)
def
setup_module
(
module
):
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
0
,
world_size
=
1
,
init_method
=
URL
)
globals
()[
"test_"
+
func
.
__name__
]
=
wrapper
return
func
def
teardown_module
(
module
):
if
"OMPI_COMM_WORLD_SIZE"
not
in
os
.
environ
:
torch
.
distributed
.
destroy_process_group
()
return
decorator
@
pg_test
(
world_size
=
1
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
create
(
device
):
def
create
(
device
):
model_dim
=
8
num_experts
=
4
gate
=
Top2Gate
(
model_dim
,
num_experts
)
...
...
@@ -48,8 +53,9 @@ def test_create(device):
moe
=
MOELayer
(
gate
,
expert
).
to
(
device
)
@
pg_test
(
world_size
=
1
)
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
test_
expert_params
(
device
):
def
expert_params
(
device
):
model_dim
=
8
num_experts
=
4
gate
=
Top2Gate
(
model_dim
,
num_experts
)
...
...
@@ -59,9 +65,9 @@ def test_expert_params(device):
assert
p
.
expert
is
True
@
p
y
test
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
def
test_
forward
(
device
):
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
forward
(
device
):
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
...
...
@@ -76,9 +82,9 @@ def test_forward(device):
assert
torch
.
allclose
(
input
,
output
)
@
p
y
test
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
def
test_
forward_multi
(
device
):
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
forward_multi
(
device
):
torch
.
set_printoptions
(
threshold
=
5000
)
num_local_experts
=
4
model_dim
=
4
...
...
@@ -117,9 +123,9 @@ class RoundRobinGate(torch.nn.Module):
return
0.0
,
output
,
output
.
bool
()
@
p
y
test
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
def
test_
forward_routing
(
device
):
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
forward_routing
(
device
):
model_dim
=
8
num_experts
=
dist
.
get_world_size
()
input
=
torch
.
randn
(
4
,
16
,
model_dim
).
to
(
device
)
...
...
@@ -138,9 +144,9 @@ def test_forward_routing(device):
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
@
p
y
test
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
def
test_
forward_routing_multi
(
device
):
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
forward_routing_multi
(
device
):
model_dim
=
8
num_local_experts
=
4
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
*
num_local_experts
...
...
@@ -163,9 +169,9 @@ def test_forward_routing_multi(device):
assert
torch
.
allclose
(
input
[:,
i
]
*
(
expert
+
1
),
output
[:,
i
])
@
p
y
test
.
mark
.
mpi
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
]
)
def
test_
backward
(
device
):
@
p
g_
test
()
@
pytest
.
mark
.
parametrize
(
"device"
,
devices
)
def
backward
(
device
):
loss
=
torch
.
nn
.
MSELoss
()
model_dim
=
8
num_experts
=
dist
.
get_world_size
(
dist
.
group
.
WORLD
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment