Unverified Commit b52041d9 authored by Benjamin Lefaudeux, committed by GitHub

[fix] MPI init for unit tests (#316)

* using a global variable to share the init filename across processes
parent ce2f64f9
@@ -53,6 +53,8 @@ skip_if_single_gpu = pytest.mark.skipif(
     not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
 )
 
+_, filename_mpi = tempfile.mkstemp()
+
 
 class IdentityLayer(torch.nn.Module):
     def __init__(self, size: int, scale: float = 1.0) -> None:
@@ -241,10 +243,12 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
         error_queue = multiprocessing.get_context("spawn").SimpleQueue()
 
         if "OMPI_COMM_WORLD_RANK" in os.environ:
+            global filename_mpi
+
             os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
             os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
-            _, filename = tempfile.mkstemp()
-            torch.distributed.init_process_group("mpi", init_method=f"file://{filename}")
+            torch.distributed.init_process_group("mpi", init_method=f"file://{filename_mpi}")
+
             world_size = torch.distributed.get_world_size()
             destroy_model_parallel()
             initialize_model_parallel(1, world_size)
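For context, a minimal sketch of the pattern this change relies on: a file:// init_method only works if every rank passes the same path to torch.distributed.init_process_group, so the temp file is created once at module level and reused, instead of a fresh tempfile.mkstemp() on every call. This is an illustration only, not the fairscale test harness itself; the helper name init_from_mpi_env is hypothetical, and the model-parallel setup/teardown from the diff is omitted.

import os
import tempfile

import torch.distributed as dist

# Created once when the module is loaded and shared afterwards; before the fix,
# a new temp file was created inside the branch below on each call, so the
# processes did not agree on a single rendezvous file.
_, filename_mpi = tempfile.mkstemp()


def init_from_mpi_env() -> None:
    # Hypothetical helper mirroring the fixed test setup.
    if "OMPI_COMM_WORLD_RANK" in os.environ:
        os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
        os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"]
        # All ranks point at the same shared file for rendezvous.
        dist.init_process_group("mpi", init_method=f"file://{filename_mpi}")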