Unverified commit 7abaa2be, authored by Benjamin Lefaudeux, committed by GitHub

[hotfix] Properly catch a given test failing when there are not enough GPUs (#274)

* Properly catch a given test failing when there are not enough GPUs
parent 60c8de4a
@@ -42,7 +42,7 @@ from torch.distributed import rpc
 import torch.multiprocessing as mp
 import torch.nn as nn
-from fairscale.nn.model_parallel import initialize_model_parallel
+from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
@@ -82,11 +82,13 @@ def torch_version() -> Tuple[int, ...]:
     return tuple(int(n) for n in numbering)
 
-def dist_init(rank: int, world_size: int, filename: str) -> None:
+def dist_init(rank: int, world_size: int, filename: str) -> bool:
     """
     Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
     tests to be run concurrently.
+    Return false if not enough GPUs present in the system.
+
     .. warning: This limits the usecase to all ranks being on the same node
     """
@@ -97,6 +99,11 @@ def dist_init(rank: int, world_size: int, filename: str) -> None:
     if torch_version() >= (1, 6, 0):
         backend = "nccl" if torch.cuda.is_available() else "gloo"
+
+        if backend == "nccl" and torch.cuda.device_count() < world_size:
+            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
+            return False
+
         torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)
 
         # New file for RPC init
@@ -121,6 +128,8 @@ def dist_init(rank: int, world_size: int, filename: str) -> None:
     if torch.cuda.is_available() and torch.cuda.device_count():
         torch.cuda.set_device(rank % torch.cuda.device_count())
 
+    return True
+
 
 def get_worker_map() -> Dict[Any, Any]:
     return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}
@@ -134,10 +143,6 @@ def get_world_sizes() -> List[int]:
 def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
     for world_size in world_sizes:
-        if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
-            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
-            continue
-
         filename = tempfile.mkstemp()[1]
 
         mp.spawn(test_func, args=(world_size, filename, *args), nprocs=world_size, join=True)  # type: ignore
@@ -145,7 +150,9 @@ def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_
 def worker_process(rank: int, world_size: int, filename: str, func: Callable, args: Any, error_queue: Any) -> None:
     """Main function for unit tests launced with torch_spawn"""
-    dist_init(rank, world_size, filename)
+    if not dist_init(rank, world_size, filename):
+        return
+
     kwargs = {}
     if "OMPI_COMM_WORLD_RANK" not in os.environ:
         kwargs["pipeline_backend"] = "gloo"
@@ -167,6 +174,7 @@ def worker_process(rank: int, world_size: int, filename: str, func: Callable, ar
 def teardown() -> None:
+    destroy_model_parallel()
     if torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()
     try:
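The hunks above move the "not enough GPUs" check from `spawn_for_all_world_sizes` into `dist_init`, which now reports the problem via its boolean return value so each spawned worker can bail out cleanly. The sketch below is not part of this commit; `check_world_size` and `noop_worker` are hypothetical names used only to illustrate why the guard-and-return pattern lets `mp.spawn` join all ranks successfully on an under-provisioned machine instead of failing the test.

```python
# Minimal, self-contained sketch of the guard-and-bail pattern introduced above.
# check_world_size mirrors the new check inside dist_init; noop_worker mirrors
# the early return added to worker_process. Both names are illustrative only.
import logging

import torch
import torch.multiprocessing as mp


def check_world_size(world_size: int) -> bool:
    # NCCL needs one GPU per rank; on a machine with too few GPUs we warn and
    # report failure instead of letting init_process_group fail later.
    if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
        logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
        return False
    return True


def noop_worker(rank: int, world_size: int) -> None:
    if not check_world_size(world_size):
        return  # bail out cleanly; mp.spawn still joins all ranks successfully
    # ... a real test body would call dist_init and run its assertions here ...


if __name__ == "__main__":
    mp.spawn(noop_worker, args=(4,), nprocs=4, join=True)
```

The second diff below applies the complementary fix at the pytest level, skipping a pipe test that hard-requires 4 GPUs before any workers are spawned.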
@@ -449,6 +449,7 @@ def exception(pipeline_style):
 # FIXME(tom) should probably signal to all hosts in group to stop
 @torch_spawn([4])
+@pytest.mark.skipif(torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs")
 @pytest.mark.xfail(strict=True)
 @pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
 def exception_early_stop_asap(pipeline_style):
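For tests that hard-require a fixed number of GPUs, the commit relies on `pytest.mark.skipif` so the test is skipped up front rather than spawning workers that cannot form the process group. Below is a standalone sketch of the same pattern; the test name and body are placeholders, not code from the repository.

```python
# Standalone sketch of the skip pattern used above: skip when CUDA is present
# but the machine has fewer than the 4 GPUs the test needs. When CUDA is not
# available at all, the test still runs (the gloo backend does not need GPUs).
import pytest
import torch


@pytest.mark.skipif(
    torch.cuda.is_available() and torch.cuda.device_count() < 4, reason="Not enough GPUs"
)
def test_needs_four_gpus() -> None:
    # Placeholder body; the real test spawns 4 ranks and exercises the pipeline.
    pass
```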