Unverified Commit 5a3df0da authored by msbaines's avatar msbaines Committed by GitHub
Browse files

[test] modify MOE tests to use NCCL (#570)

NCCL all_to_all is now supported in PyTorch (since v1.8.0)

Fixes: #548
parent 60694da1
...@@ -66,17 +66,12 @@ install_dep_160: &install_dep_160 ...@@ -66,17 +66,12 @@ install_dep_160: &install_dep_160
- run: - run:
name: Install Dependencies with torch 1.6.0 name: Install Dependencies with torch 1.6.0
command: | command: |
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.6 && exit 0; fi if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.6 && exit 0; fi
# start installing # start installing
pip install --progress-bar off torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html pip install --progress-bar off torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)' python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "6"], "wrong torch version"' python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "6"], "wrong torch version"'
python -m torch.utils.collect_env python -m torch.utils.collect_env
...@@ -86,17 +81,12 @@ install_dep_171: &install_dep_171 ...@@ -86,17 +81,12 @@ install_dep_171: &install_dep_171
- run: - run:
name: Install Dependencies with torch 1.7.1 name: Install Dependencies with torch 1.7.1
command: | command: |
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.7 && exit 0; fi if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.7 && exit 0; fi
# start installing # start installing
pip install --progress-bar off torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html pip install --progress-bar off torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
pip install --progress-bar off -r requirements-test.txt pip install --progress-bar off -r requirements-test.txt
pip install --progress-bar off -r requirements-benchmarks.txt pip install --progress-bar off -r requirements-benchmarks.txt
pip install --progress-bar off git+https://github.com/msbaines/torch_pg.git@c85c96f#egg=torch-pg
python -c 'import torch; print("Torch version:", torch.__version__)' python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "7"], "wrong torch version"' python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "7"], "wrong torch version"'
python -m torch.utils.collect_env python -m torch.utils.collect_env
...@@ -106,10 +96,6 @@ install_dep_181: &install_dep_181 ...@@ -106,10 +96,6 @@ install_dep_181: &install_dep_181
- run: - run:
name: Install Dependencies with torch 1.8.1 name: Install Dependencies with torch 1.8.1
command: | command: |
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing # start installing
...@@ -125,10 +111,6 @@ install_dep_190: &install_dep_190 ...@@ -125,10 +111,6 @@ install_dep_190: &install_dep_190
- run: - run:
name: Install Dependencies with torch 1.9.0 name: Install Dependencies with torch 1.9.0
command: | command: |
# make sure that apt-get retries if needed
sudo sh -c "echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80-retries"
sudo apt-get update -y
sudo apt-get install -y libopenmpi-dev
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip # check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.8 && exit 0; fi
# start installing # start installing
...@@ -184,13 +166,6 @@ upload_coverage: &upload_coverage ...@@ -184,13 +166,6 @@ upload_coverage: &upload_coverage
file: 'coverage.xml' file: 'coverage.xml'
token: $CODECOV_TOKEN token: $CODECOV_TOKEN
run_mpi_unittests: &run_mpi_unittests
- run:
name: Run MPI Unit Tests
command: |
mpirun -n 4 python -m pytest -p torch_pg.pytest --only-mpi --junitxml=test-results/junit.xml --verbose tests/nn/moe
run_pipe_benchmark: &run_pipe_benchmark run_pipe_benchmark: &run_pipe_benchmark
- run: - run:
name: Run Pipe Benchmark name: Run Pipe Benchmark
...@@ -276,14 +251,14 @@ jobs: ...@@ -276,14 +251,14 @@ jobs:
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-cpu-py37-180-1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-cpu-py37-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_171 - <<: *install_dep_181
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-cpu-py37-180-1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-cpu-py37-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -292,7 +267,6 @@ jobs: ...@@ -292,7 +267,6 @@ jobs:
- <<: *run_mypy - <<: *run_mypy
- <<: *run_flake8 - <<: *run_flake8
- <<: *run_unittests - <<: *run_unittests
- <<: *run_mpi_unittests
- <<: *run_doc_build - <<: *run_doc_build
- store_test_results: - store_test_results:
...@@ -311,13 +285,13 @@ jobs: ...@@ -311,13 +285,13 @@ jobs:
# Cache the venv directory that contains dependencies # Cache the venv directory that contains dependencies
- restore_cache: - restore_cache:
keys: keys:
- cache-key-cpu-py38-180-1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} - cache-key-cpu-py38-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_dep_171 - <<: *install_dep_181
- save_cache: - save_cache:
paths: paths:
- ~/venv - ~/venv
key: cache-key-cpu-py38-180-1-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}} key: cache-key-cpu-py38-181-{{ checksum "setup.py"}}-{{ checksum "requirements-test.txt"}}
- <<: *install_repo - <<: *install_repo
...@@ -326,7 +300,6 @@ jobs: ...@@ -326,7 +300,6 @@ jobs:
- <<: *run_mypy - <<: *run_mypy
- <<: *run_flake8 - <<: *run_flake8
- <<: *run_unittests - <<: *run_unittests
- <<: *run_mpi_unittests
- <<: *run_doc_build - <<: *run_doc_build
- store_test_results: - store_test_results:
...@@ -361,7 +334,6 @@ jobs: ...@@ -361,7 +334,6 @@ jobs:
- <<: *run_mypy - <<: *run_mypy
- <<: *run_flake8 - <<: *run_flake8
- <<: *run_unittests - <<: *run_unittests
# TODO(msb) - <<: *run_mpi_unittests
- <<: *run_doc_build - <<: *run_doc_build
- store_test_results: - store_test_results:
......
...@@ -10,8 +10,6 @@ mypy == 0.790 ...@@ -10,8 +10,6 @@ mypy == 0.790
# Tools for unit tests & coverage. # Tools for unit tests & coverage.
pytest == 5.4.1 pytest == 5.4.1
pytest-cov == 2.10.0 pytest-cov == 2.10.0
pytest-mpi == 0.4
pytest-timeout == 1.4.2 pytest-timeout == 1.4.2
mpi4py == 3.0.3
remote-pdb >= 2.1.0 remote-pdb >= 2.1.0
parameterized >= 0.8.1 parameterized >= 0.8.1
...@@ -3,44 +3,49 @@ ...@@ -3,44 +3,49 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import os import functools
import tempfile import tempfile
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.nn import MOELayer, Top2Gate from fairscale.nn import MOELayer, Top2Gate
from fairscale.utils.testing import torch_version
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") pytestmark = pytest.mark.skipif(
not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)), reason="cuda and torch>=1.8.0 required"
)
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore devices = ["cuda"]
if torch.cuda.is_available():
devices = ["cpu", "cuda"]
else:
devices = ["cpu"]
URL = "file://" + tempfile.mkstemp()[1]
os.environ["MASTER_ADDR"] = "localhost" def pg_worker(rank, world_size, init_file, func, *args):
os.environ["MASTER_PORT"] = "29501" # torch 1.5 compatibility init_url = "file://" + init_file
dist.init_process_group(backend=dist.Backend.NCCL, rank=rank, world_size=world_size, init_method=init_url)
torch.cuda.set_device(rank)
dist.all_reduce(torch.zeros(1).cuda())
func(*args)
dist.destroy_process_group()
if "OMPI_COMM_WORLD_SIZE" in os.environ:
dist.init_process_group(backend=dist.Backend.MPI, init_method=URL)
def pg_test(world_size=torch.cuda.device_count()):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
tempfile_name = tempfile.mkstemp()[1]
mp.spawn(pg_worker, args=(world_size, tempfile_name, func, *kwargs.values()), nprocs=world_size)
def setup_module(module): globals()["test_" + func.__name__] = wrapper
if "OMPI_COMM_WORLD_SIZE" not in os.environ: return func
dist.init_process_group(backend=BACKEND, rank=0, world_size=1, init_method=URL)
return decorator
def teardown_module(module):
if "OMPI_COMM_WORLD_SIZE" not in os.environ:
torch.distributed.destroy_process_group()
@pg_test(world_size=1)
@pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("device", devices)
def test_create(device): def create(device):
model_dim = 8 model_dim = 8
num_experts = 4 num_experts = 4
gate = Top2Gate(model_dim, num_experts) gate = Top2Gate(model_dim, num_experts)
...@@ -48,8 +53,9 @@ def test_create(device): ...@@ -48,8 +53,9 @@ def test_create(device):
moe = MOELayer(gate, expert).to(device) moe = MOELayer(gate, expert).to(device)
@pg_test(world_size=1)
@pytest.mark.parametrize("device", devices) @pytest.mark.parametrize("device", devices)
def test_expert_params(device): def expert_params(device):
model_dim = 8 model_dim = 8
num_experts = 4 num_experts = 4
gate = Top2Gate(model_dim, num_experts) gate = Top2Gate(model_dim, num_experts)
...@@ -59,9 +65,9 @@ def test_expert_params(device): ...@@ -59,9 +65,9 @@ def test_expert_params(device):
assert p.expert is True assert p.expert is True
@pytest.mark.mpi @pg_test()
@pytest.mark.parametrize("device", ["cpu"]) @pytest.mark.parametrize("device", devices)
def test_forward(device): def forward(device):
model_dim = 8 model_dim = 8
num_experts = dist.get_world_size(dist.group.WORLD) num_experts = dist.get_world_size(dist.group.WORLD)
input = torch.randn(4, 16, model_dim).to(device) input = torch.randn(4, 16, model_dim).to(device)
...@@ -76,9 +82,9 @@ def test_forward(device): ...@@ -76,9 +82,9 @@ def test_forward(device):
assert torch.allclose(input, output) assert torch.allclose(input, output)
@pytest.mark.mpi @pg_test()
@pytest.mark.parametrize("device", ["cpu"]) @pytest.mark.parametrize("device", devices)
def test_forward_multi(device): def forward_multi(device):
torch.set_printoptions(threshold=5000) torch.set_printoptions(threshold=5000)
num_local_experts = 4 num_local_experts = 4
model_dim = 4 model_dim = 4
...@@ -117,9 +123,9 @@ class RoundRobinGate(torch.nn.Module): ...@@ -117,9 +123,9 @@ class RoundRobinGate(torch.nn.Module):
return 0.0, output, output.bool() return 0.0, output, output.bool()
@pytest.mark.mpi @pg_test()
@pytest.mark.parametrize("device", ["cpu"]) @pytest.mark.parametrize("device", devices)
def test_forward_routing(device): def forward_routing(device):
model_dim = 8 model_dim = 8
num_experts = dist.get_world_size() num_experts = dist.get_world_size()
input = torch.randn(4, 16, model_dim).to(device) input = torch.randn(4, 16, model_dim).to(device)
...@@ -138,9 +144,9 @@ def test_forward_routing(device): ...@@ -138,9 +144,9 @@ def test_forward_routing(device):
assert torch.allclose(input[:, i] * (expert + 1), output[:, i]) assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
@pytest.mark.mpi @pg_test()
@pytest.mark.parametrize("device", ["cpu"]) @pytest.mark.parametrize("device", devices)
def test_forward_routing_multi(device): def forward_routing_multi(device):
model_dim = 8 model_dim = 8
num_local_experts = 4 num_local_experts = 4
num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts num_experts = dist.get_world_size(dist.group.WORLD) * num_local_experts
...@@ -163,9 +169,9 @@ def test_forward_routing_multi(device): ...@@ -163,9 +169,9 @@ def test_forward_routing_multi(device):
assert torch.allclose(input[:, i] * (expert + 1), output[:, i]) assert torch.allclose(input[:, i] * (expert + 1), output[:, i])
@pytest.mark.mpi @pg_test()
@pytest.mark.parametrize("device", ["cpu"]) @pytest.mark.parametrize("device", devices)
def test_backward(device): def backward(device):
loss = torch.nn.MSELoss() loss = torch.nn.MSELoss()
model_dim = 8 model_dim = 8
num_experts = dist.get_world_size(dist.group.WORLD) num_experts = dist.get_world_size(dist.group.WORLD)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment