Unverified Commit b640cab5 authored by Benjamin Lefaudeux, committed by GitHub

[chore] Move all unit tests' dist init to being file based (#272)

* file-based dist init (a minimal sketch of the approach follows below)
* nicer handling of requested world sizes that exceed the number of available GPUs: warn and skip instead of breaking
parent 290afecd
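As a rough illustration of the change (not part of the commit), here is a minimal sketch of file-based rendezvous with torch.distributed: every rank points init_method at the same temporary file instead of agreeing on a MASTER_ADDR/MASTER_PORT pair, so concurrent test runs on one node cannot collide on a port. The gloo backend, the two-rank world size, and the _demo_worker name are illustrative choices, not part of the diff.

# Minimal sketch of the file-based init that dist_init below relies on.
import tempfile

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _demo_worker(rank: int, world_size: int, filename: str) -> None:
    # All ranks rendezvous through the shared file instead of a TCP address/port.
    dist.init_process_group(
        backend="gloo", rank=rank, world_size=world_size, init_method="file://" + filename
    )
    t = torch.ones(1)
    dist.all_reduce(t)  # sanity check: t becomes world_size on every rank
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    filename = tempfile.mkstemp()[1]  # empty temp file shared by the spawned ranks
    mp.spawn(_demo_worker, args=(world_size, filename), nprocs=world_size, join=True)

The same idea is applied to the RPC init in the diff by deriving a second file name from the first.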
@@ -31,6 +31,7 @@ import logging
import multiprocessing
import os
import random
import tempfile
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy
@@ -63,7 +64,7 @@ def set_random_seed(seed: int) -> None:
def torch_version() -> Tuple[int, ...]:
numbering = torch.__version__.split(".")
numbering = torch.__version__.split("+")[0].split(".")
assert len(numbering) == 3
@@ -81,35 +82,41 @@ def torch_version() -> Tuple[int, ...]:
return tuple(int(n) for n in numbering)
def dist_init(rank: int, world_size: int, hostname: Optional[str] = None) -> None:
if hostname is None:
hostname = "localhost"
print(f"dist init r={rank}, world={world_size}, host={hostname}")
os.environ["MASTER_ADDR"] = hostname
os.environ["MASTER_PORT"] = "10638"
def dist_init(rank: int, world_size: int, filename: str) -> None:
"""
Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
tests to be run concurrently.
.. warning:: This limits the use case to all ranks being on the same node
"""
print(f"dist init r={rank}, world={world_size}")
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(rank)
url = "file://" + filename
if torch_version() >= (1, 6, 0):
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
backend = "nccl" if torch.cuda.is_available() else "gloo"
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=init_method)
os.environ["MASTER_ADDR"] = hostname
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)
# New file for RPC init
filename_rpc = filename + "_rpc"
open(filename_rpc, "w")
url = "file://" + filename_rpc
rpc.init_rpc(
f"Test{rank}",
rank=rank,
world_size=world_size,
backend=rpc.BackendType.TENSORPIPE,
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=url),
)
else:
if world_size > 1:
rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
else:
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)
if torch.cuda.is_available() and torch.cuda.device_count():
torch.cuda.set_device(rank % torch.cuda.device_count())
@@ -125,14 +132,20 @@ def get_world_sizes() -> List[int]:
def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
for world_size in world_sizes:
mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True) # type: ignore
if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
continue
filename = tempfile.mkstemp()[1]
mp.spawn(test_func, args=(world_size, filename, *args), nprocs=world_size, join=True) # type: ignore
def worker_process(rank: int, world_size: int, func: Callable, args: Any, error_queue: Any) -> None:
def worker_process(rank: int, world_size: int, filename: str, func: Callable, args: Any, error_queue: Any) -> None:
"""Main function for unit tests launced with torch_spawn"""
dist_init(rank, world_size)
dist_init(rank, world_size, filename)
kwargs = {}
if "OMPI_COMM_WORLD_RANK" not in os.environ:
kwargs["pipeline_backend"] = "gloo"
@@ -145,8 +158,22 @@ def worker_process(rank: int, world_size: int, func: Callable, args: Any, error_
if e.__class__.__name__ == "Skipped":
error_queue.put(str(e))
return
# Make sure that the group is properly destroyed, even for tests which check for exceptions being raised
teardown()
raise e
teardown()
def teardown() -> None:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
try:
torch.distributed.rpc.shutdown()
except Exception:
pass
def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
if world_sizes is None:
@@ -187,7 +214,9 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
if world_size in world_sizes:
try:
func(*args)
teardown()
except BaseException as e:
teardown()
print(f"got exception {e} from test")
import traceback
......
@@ -19,6 +19,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
@@ -49,8 +50,8 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
return loss, identity.weight.grad
def run_test_cross_entropy(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_cross_entropy(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
if torch.distributed.get_rank() == 0:
print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size))
......
@@ -19,14 +19,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
def run_test_initialize_model_parallel(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_initialize_model_parallel(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
if torch.distributed.get_rank() == 0:
print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))
@@ -62,8 +63,8 @@ def run_test_initialize_model_parallel(rank, model_parallel_size):
print(">> passed the test :-)")
def run_test_get_model_parallel_src_rank(rank, model_parallel_size_):
dist_init(rank, model_parallel_size_)
def run_test_get_model_parallel_src_rank(rank, model_parallel_size_, filename):
dist_init(rank, model_parallel_size_, filename)
if torch.distributed.get_rank() == 0:
print("> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_))
......
@@ -20,6 +20,7 @@
# limitations under the License.
import os
import tempfile
import pytest
import torch
@@ -34,8 +35,8 @@ from fairscale.nn.pipe import Pipe
from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
def run_test_parallel_embedding(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_parallel_embedding(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
if torch.distributed.get_rank() == 0:
print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))
@@ -104,8 +105,8 @@ def run_test_parallel_embedding(rank, model_parallel_size):
print(">> passed the test :-)")
def run_test_initialize_affine_weight(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_initialize_affine_weight(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
@@ -180,8 +181,8 @@ class IdentityLayer2D(torch.nn.Module):
return self.weight
def run_test_column_parallel_linear(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_column_parallel_linear(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
@@ -241,8 +242,8 @@ def run_test_column_parallel_linear(rank, model_parallel_size):
print(" >> passed the test :-)")
def run_test_row_parallel_linear(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_row_parallel_linear(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0:
@@ -301,14 +302,14 @@ def run_test_row_parallel_linear(rank, model_parallel_size):
print(" >> passed the test :-)")
def run_test_pipe(rank, world_size, skip_dist_init=False):
def run_test_pipe(rank, world_size, filename, skip_dist_init=False):
pipe_world_size = 2
if world_size == 1:
return
if not skip_dist_init:
dist_init(rank, world_size)
dist_init(rank, world_size, filename)
else:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502"
@@ -566,7 +567,8 @@ def test_row_parallel():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
def mpi_pipe():
mpu.destroy_model_parallel()
run_test_pipe(torch.distributed.get_rank(), torch.distributed.get_world_size(), skip_dist_init=True)
tempfile_init = tempfile.mkstemp()[1]
run_test_pipe(torch.distributed.get_rank(), torch.distributed.get_world_size(), tempfile_init, skip_dist_init=True)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
......
@@ -27,8 +27,8 @@ from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_paral
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
def run_test_set_cuda_rng_state(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_set_cuda_rng_state(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
if torch.distributed.get_rank() == 0:
print("> testing set_rng_state with size {} ...".format(model_parallel_size))
@@ -96,8 +96,8 @@ def run_test_set_cuda_rng_state(rank, model_parallel_size):
print(">> passed the test :-)")
def run_test_cuda_rng_tracker(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_cuda_rng_tracker(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
if torch.distributed.get_rank() == 0:
print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))
@@ -172,8 +172,8 @@ def run_test_cuda_rng_tracker(rank, model_parallel_size):
print(">> passed the test :-)")
def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size):
dist_init(rank, model_parallel_size)
def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size, filename):
dist_init(rank, model_parallel_size, filename)
if torch.distributed.get_rank() == 0:
print("> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size))
......
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.
import os
import tempfile
import pytest
import torch
@@ -20,15 +21,15 @@ if torch.cuda.is_available():
else:
devices = ["cpu"]
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
URL = "file://" + tempfile.mkstemp()[1]
if "OMPI_COMM_WORLD_SIZE" in os.environ:
dist.init_process_group(backend=dist.Backend.MPI)
dist.init_process_group(backend=dist.Backend.MPI, init_method=URL)
def setup_module(module):
if "OMPI_COMM_WORLD_SIZE" not in os.environ:
dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
dist.init_process_group(backend=BACKEND, rank=0, world_size=1, init_method=URL)
def teardown_module(module):
......
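For completeness, a hedged sketch of how a unit test can use the updated helpers (not part of the commit; run_my_check and test_spawned_world_sizes are hypothetical names): spawn_for_all_world_sizes now hands each spawn a fresh temporary file and, per the second bullet of the commit message, warns and skips any world size larger than the number of visible GPUs instead of failing.

# Hypothetical test written against the helpers changed in this diff; assumes
# teardown() is importable from fairscale.utils.testing as defined above.
import torch

from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes, teardown


def run_my_check(rank: int, world_size: int, filename: str) -> None:
    # filename is the per-spawn temp file created by spawn_for_all_world_sizes
    dist_init(rank, world_size, filename)
    assert torch.distributed.get_world_size() == world_size
    teardown()  # destroy the process group / RPC so later tests start clean


def test_spawned_world_sizes() -> None:
    # Runs run_my_check once per candidate world size; oversized requests are
    # skipped with a warning rather than failing the suite.
    spawn_for_all_world_sizes(run_my_check)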