Unverified Commit 438aa017 authored by Shaden Smith, committed by GitHub

Enables NCCL backend in @distributed_test (#13)

* Enables NCCL backend in @distributed_test

* Adds pytest-forked to avoid CUDA re-initialization issue.

* paste typo

* transcription typo
parent 188f7d4e
@@ -14,13 +14,14 @@ model convergence tests are found in `tests/model/`.
 ### Unit Tests
 [PyTest](https://docs.pytest.org/en/latest/) is used to execute tests. PyTest can be
-installed from PyPI via `pip install pytest`. Simply invoke `pytest` to run the unit
-tests:
+installed from PyPI via `pip install pytest`. Simply invoke `pytest --forked` to run the
+unit tests:
 
-    pytest tests/unit/
+    pytest --forked tests/unit/
 
 You can also provide the `-v` flag to `pytest` to see additional information about the
-tests.
+tests. Note that [pytest-forked](https://github.com/pytest-dev/pytest-forked) and the
+`--forked` flag are required to test CUDA functionality in distributed tests.
 
 ### Model Tests
 To execute model tests, first [install DeepSpeed](#installation). The
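For context, the kind of unit test these instructions target looks roughly like the sketch below; the test name, the decorator's import path, and the use of `broadcast` are illustrative assumptions, not code from this commit. Because each such test initializes torch.distributed and touches CUDA, pytest-forked runs every test in its own forked process to avoid CUDA re-initialization errors.

    # Hypothetical distributed unit test; the import path and body are
    # assumptions made for illustration only.
    import torch
    import torch.distributed as dist

    from common import distributed_test  # assumed location of the decorator


    @distributed_test(world_size=2)
    def test_broadcast_example():
        # Each rank starts from a different value; after broadcasting from
        # rank 0, every rank should hold rank 0's tensor of ones.
        x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
        dist.broadcast(x, src=0)
        assert torch.all(x == 1.0)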
@@ -41,8 +41,7 @@ jobs:
     displayName: 'Code linter'
 
   - script: |
       pip install --user pytest
-      pytest --verbose tests/unit/
+      pytest --forked --verbose tests/unit/
     displayName: 'Unit tests'
 
   - script: |
@@ -7,11 +7,11 @@ from torch.multiprocessing import Process
 import pytest
 
-# Worker timeout _after_ the first worker has completed.
-DEEPSPEED_UNIT_WORKER_TIMEOUT = 5
+# Worker timeout *after* the first worker has completed.
+DEEPSPEED_UNIT_WORKER_TIMEOUT = 10
 
 
-def distributed_test(world_size=2, backend='gloo'):
+def distributed_test(world_size=2, backend='nccl'):
     """A decorator for executing a function (e.g., a unit test) in a distributed manner.
 
     This decorator manages the spawning and joining of processes, initialization of
     torch.distributed, and catching of errors.
@@ -38,9 +38,8 @@ def distributed_test(world_size=2, backend='gloo'):
                                 rank=local_rank,
                                 world_size=num_procs)
 
-        # XXX temporarily disabled due to CUDA runtime error?
-        #if torch.cuda.is_available():
-        #    torch.cuda.set_device(local_rank)
+        if torch.cuda.is_available():
+            torch.cuda.set_device(local_rank)
 
         run_func(*func_args, **func_kwargs)
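To make the decorator's docstring concrete, here is a minimal sketch of the same pattern, assuming an env:// rendezvous launched once per rank; the function name, the master address/port values, and the exit-code check are illustrative and not taken from this commit.

    # Minimal sketch of a decorator in the spirit of @distributed_test: spawn
    # one process per rank, initialize torch.distributed with NCCL, pin each
    # rank to its GPU, run the wrapped function, then join the workers.
    import os
    from functools import wraps

    import torch
    import torch.distributed as dist
    from torch.multiprocessing import Process


    def simple_distributed_test(world_size=2, backend='nccl'):  # hypothetical name
        def decorator(run_func):
            @wraps(run_func)
            def wrapper(*func_args, **func_kwargs):
                def worker(local_rank):
                    # Rendezvous settings for the default env:// init method;
                    # the address and port here are placeholder values.
                    os.environ['MASTER_ADDR'] = '127.0.0.1'
                    os.environ['MASTER_PORT'] = '29503'
                    dist.init_process_group(backend=backend,
                                            rank=local_rank,
                                            world_size=world_size)
                    if torch.cuda.is_available():
                        torch.cuda.set_device(local_rank)
                    run_func(*func_args, **func_kwargs)

                procs = [Process(target=worker, args=(rank, ))
                         for rank in range(world_size)]
                for p in procs:
                    p.start()
                for p in procs:
                    p.join()
                # A non-zero exit code means a worker raised; surface that as
                # a failure in the parent (pytest) process.
                assert all(p.exitcode == 0 for p in procs)

            return wrapper
        return decorator

With NCCL, each rank has to be pinned to its own GPU before any collective runs, which is what the re-enabled `torch.cuda.set_device(local_rank)` call in this diff takes care of.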
@@ -29,9 +29,10 @@ def test_dist_args(number, color):
     _test_dist_args_helper(number, color=color)
 
 
-@distributed_test(world_size=2)
+@distributed_test(world_size=[1, 2, 4])
 def test_dist_allreduce():
-    x = torch.ones(1, 3) * (dist.get_rank() + 1)
-    result = torch.ones(1, 3) * 3
+    x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
+    sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
+    result = torch.ones(1, 3).cuda() * sum_of_ranks
     dist.all_reduce(x)
     assert torch.all(x == result)
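The new assertion relies on the sum of `(rank + 1)` over all ranks being the triangular number n * (n + 1) / 2. A quick stand-alone check of that arithmetic for the parametrized world sizes:

    # Sanity check of the expected all-reduce result: every rank contributes
    # (rank + 1), so the reduced value is 1 + 2 + ... + n = n * (n + 1) / 2.
    for world_size in [1, 2, 4]:
        sum_of_ranks = (world_size * (world_size + 1)) // 2
        assert sum_of_ranks == sum(rank + 1 for rank in range(world_size))
        print(world_size, sum_of_ranks)  # 1 -> 1, 2 -> 3, 4 -> 10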