Unverified commit 7abaa2be, authored by Benjamin Lefaudeux, committed by GitHub

[hotfix] Properly catch a given test failing if there are not enough GPUs (#274)

* Properly catch a given test failing when there are not enough GPUs.
parent 60c8de4a
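In short, the fix changes the contract of the shared test helper: `dist_init` now reports whether the requested world size is actually reachable, and callers return early instead of crashing when it is not. A minimal, self-contained sketch of that pattern (the `_demo_*` names are hypothetical; only the logic mirrors the diff below):

```python
import logging
import tempfile

import torch
import torch.distributed
import torch.multiprocessing as mp


def _demo_dist_init(rank: int, world_size: int, filename: str) -> bool:
    # Mirrors the new dist_init contract: return False instead of failing
    # when NCCL is selected but the node has fewer GPUs than world_size.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if backend == "nccl" and torch.cuda.device_count() < world_size:
        logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
        return False

    torch.distributed.init_process_group(
        backend=backend, rank=rank, world_size=world_size, init_method=f"file://{filename}"
    )
    return True


def _demo_worker(rank: int, world_size: int, filename: str) -> None:
    if not _demo_dist_init(rank, world_size, filename):
        return  # the test body is skipped gracefully on this rank
    # ... the actual test body would run here ...
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    # One fresh rendezvous file per run, so unrelated tests can coexist.
    mp.spawn(_demo_worker, args=(2, tempfile.mkstemp()[1]), nprocs=2, join=True)
```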
@@ -42,7 +42,7 @@ from torch.distributed import rpc
 import torch.multiprocessing as mp
 import torch.nn as nn
 
-from fairscale.nn.model_parallel import initialize_model_parallel
+from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
@@ -82,11 +82,13 @@ def torch_version() -> Tuple[int, ...]:
     return tuple(int(n) for n in numbering)
 
 
-def dist_init(rank: int, world_size: int, filename: str) -> None:
+def dist_init(rank: int, world_size: int, filename: str) -> bool:
     """
     Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
     tests to be run concurrently.
+
+    Return false if not enough GPUs present in the system.
 
     .. warning: This limits the usecase to all ranks being on the same node
     """
@@ -97,6 +99,11 @@ def dist_init(rank: int, world_size: int, filename: str) -> None:
 
     if torch_version() >= (1, 6, 0):
         backend = "nccl" if torch.cuda.is_available() else "gloo"
+
+        if backend == "nccl" and torch.cuda.device_count() < world_size:
+            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
+            return False
+
         torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)
 
         # New file for RPC init
@@ -121,6 +128,8 @@ def dist_init(rank: int, world_size: int, filename: str) -> None:
     if torch.cuda.is_available() and torch.cuda.device_count():
         torch.cuda.set_device(rank % torch.cuda.device_count())
 
+    return True
+
 
 def get_worker_map() -> Dict[Any, Any]:
     return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}
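One detail worth noting above: `rank % torch.cuda.device_count()` assigns devices round-robin, so the helper still maps every rank to a device when the world size exceeds the GPU count. A toy illustration, with assumed numbers:

```python
# Assuming 2 visible GPUs and a world size of 4, ranks map to devices 0, 1, 0, 1:
device_count = 2
print([rank % device_count for rank in range(4)])  # [0, 1, 0, 1]
```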
@@ -134,10 +143,6 @@ def get_world_sizes() -> List[int]:
 
 def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
     for world_size in world_sizes:
-        if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
-            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
-            continue
-
         filename = tempfile.mkstemp()[1]
 
         mp.spawn(test_func, args=(world_size, filename, *args), nprocs=world_size, join=True)  # type: ignore
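With this spawn-time pre-filter removed, `mp.spawn` always launches `world_size` processes and each rank decides for itself, inside `dist_init`, whether it can take part. A hedged sketch of that launch path (`_demo_test` is hypothetical; `mp.spawn` passes the rank as the first argument):

```python
import tempfile

import torch.multiprocessing as mp


def _demo_test(rank: int, world_size: int, filename: str) -> None:
    # A real test would call dist_init() here and return early on False.
    print(f"rank {rank}/{world_size}, rendezvous file: {filename}")


if __name__ == "__main__":
    for world_size in [1, 2]:
        filename = tempfile.mkstemp()[1]  # fresh rendezvous file per world size
        mp.spawn(_demo_test, args=(world_size, filename), nprocs=world_size, join=True)
```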
@@ -145,7 +150,9 @@ def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
 
 def worker_process(rank: int, world_size: int, filename: str, func: Callable, args: Any, error_queue: Any) -> None:
     """Main function for unit tests launched with torch_spawn"""
-    dist_init(rank, world_size, filename)
+    if not dist_init(rank, world_size, filename):
+        return
+
     kwargs = {}
     if "OMPI_COMM_WORLD_RANK" not in os.environ:
         kwargs["pipeline_backend"] = "gloo"
@@ -167,6 +174,7 @@ def worker_process(rank: int, world_size: int, filename: str, func: Callable, args: Any, error_queue: Any) -> None:
 
 def teardown() -> None:
+    destroy_model_parallel()
     if torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()
     try:
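The added `destroy_model_parallel()` call resets fairscale's model-parallel state between tests, while the existing `is_initialized()` guard already makes the rest of the teardown safe for workers that returned early out of `dist_init`. The same guard pattern, in isolation:

```python
import torch.distributed


def safe_teardown() -> None:
    # Safe even when dist_init() returned False and no process group exists.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```

The last hunk, below, touches a second file in the commit: a pipe test that hard-requires four GPUs.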
@@ -449,6 +449,7 @@ def exception(pipeline_style):
     # FIXME(tom) should probably signal to all hosts in group to stop
 
 @torch_spawn([4])
+@pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
 @pytest.mark.xfail(strict=True)
 @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
 def exception_early_stop_asap(pipeline_style):
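The new `skipif` matters because of the strict `xfail` right below it: on a machine with fewer than four GPUs the workers would now return early, the test would "pass", and `xfail(strict=True)` would flag that unexpected pass as a failure. Skipping outright keeps the strict marker meaningful. The same combination on a hypothetical standalone test:

```python
import pytest
import torch


@pytest.mark.skipif(
    torch.cuda.is_available() and torch.cuda.device_count() < 4,
    reason="Not enough GPUs",
)
@pytest.mark.xfail(strict=True)  # must actually fail whenever it runs
def test_demo_needs_four_gpus() -> None:
    raise RuntimeError("exercised only on hosts with at least 4 GPUs")
```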