Unverified commit b61a2217, authored by Shaden Smith and committed by GitHub

Distributed testing (#6)

* Adds distributed_test decorator and some unit tests.

* Setting NCCL backend.

* Parametrizes test.

* rank -> local_rank

* Temporarily disable CUDA initialization.
parent 97f36c11
import os
import time

import torch
import torch.distributed as dist
from torch.multiprocessing import Process

import pytest

# Worker timeout _after_ the first worker has completed.
DEEPSPEED_UNIT_WORKER_TIMEOUT = 5


def distributed_test(world_size=2):
    """A decorator for executing a function (e.g., a unit test) in a distributed manner.

    This decorator manages the spawning and joining of processes, initialization of
    torch.distributed, and catching of errors.

    Usage example:
        @distributed_test(world_size=[2, 3])
        def my_test():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            assert rank < world_size

    Arguments:
        world_size (int or list): number of ranks to spawn. Can be a list to spawn
        multiple tests.
    """
    def dist_wrap(run_func):
        """Second-level decorator for dist_test. This actually wraps the function. """
        def dist_init(local_rank, num_procs, *func_args, **func_kwargs):
            """Initialize torch.distributed and execute the user function.

            Runs inside each spawned worker process.
            """
            os.environ['MASTER_ADDR'] = '127.0.0.1'
            os.environ['MASTER_PORT'] = '29500'
            dist.init_process_group(backend='nccl',
                                    init_method='env://',
                                    rank=local_rank,
                                    world_size=num_procs)

            # XXX temporarily disabled due to CUDA runtime error?
            #if torch.cuda.is_available():
            #    torch.cuda.set_device(local_rank)

            run_func(*func_args, **func_kwargs)

        def dist_launcher(num_procs, *func_args, **func_kwargs):
            """Launch processes and gracefully handle failures. """

            # Spawn all workers on subprocesses.
            processes = []
            for local_rank in range(num_procs):
                p = Process(target=dist_init,
                            args=(local_rank,
                                  num_procs,
                                  *func_args),
                            kwargs=func_kwargs)
                p.start()
                processes.append(p)

            # Now loop and wait for a test to complete. The spin-wait here isn't a big
            # deal because the number of processes will be O(#GPUs) << O(#CPUs).
            any_done = False
            while not any_done:
                for p in processes:
                    if not p.is_alive():
                        any_done = True
                        break

            # Wait for all other processes to complete.
            for p in processes:
                p.join(DEEPSPEED_UNIT_WORKER_TIMEOUT)

            # Report failures: hung workers, signal kills, and nonzero exit codes.
            failed = [(rank, p) for rank, p in enumerate(processes) if p.exitcode != 0]
            for rank, p in failed:
                # If it still hasn't terminated, kill it because it hung.
                if p.exitcode is None:
                    p.terminate()
                    pytest.fail(f'Worker {rank} hung.', pytrace=False)
                if p.exitcode < 0:
                    pytest.fail(f'Worker {rank} killed by signal {-p.exitcode}',
                                pytrace=False)
                if p.exitcode > 0:
                    pytest.fail(f'Worker {rank} exited with code {p.exitcode}',
                                pytrace=False)

        def run_func_decorator(*func_args, **func_kwargs):
            """Entry point for @distributed_test(). """
            if isinstance(world_size, int):
                dist_launcher(world_size, *func_args, **func_kwargs)
            elif isinstance(world_size, list):
                for procs in world_size:
                    dist_launcher(procs, *func_args, **func_kwargs)
                    time.sleep(0.5)
            else:
                raise TypeError('world_size must be an integer or a list of integers.')

        return run_func_decorator

    return dist_wrap
import torch.distributed as dist

from common import distributed_test

import pytest


@distributed_test(world_size=3)
def test_init():
    assert dist.is_initialized()
    assert dist.get_world_size() == 3
    assert dist.get_rank() < 3


# Demonstration of pytest's parametrization.
@pytest.mark.parametrize('number,color', [(1138, 'purple')])
def test_dist_args(number, color):
    """Outer test function with inputs from pytest.mark.parametrize(). Uses a distributed
    helper function.
    """
    @distributed_test(world_size=2)
    def _test_dist_args_helper(x, color='red'):
        assert dist.get_world_size() == 2
        assert x == 1138
        assert color == 'purple'

    # Ensure that we can pass args to distributed_test-decorated functions.
    _test_dist_args_helper(number, color=color)
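
As the decorator's docstring notes, world_size may also be a list, which launches the wrapped test once per listed size. A minimal sketch of that form (the test name is hypothetical and not part of this commit; it assumes the same NCCL/GPU environment as the tests above):

# Hypothetical sketch: exercise the list form of world_size.
# The wrapped body runs once for each listed size.
@distributed_test(world_size=[1, 2])
def test_init_list_sketch():
    assert dist.is_initialized()
    # Every rank is below the world size of the current launch.
    assert dist.get_rank() < dist.get_world_size()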