Unverified Commit b640cab5 authored by Benjamin Lefaudeux, committed by GitHub

[chore] Move all unit tests dist init to being file based (#272)

* file-based dist init
* nicer handling of world sizes that exceed the number of available GPUs: warn and skip instead of failing
parent 290afecd
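For context, a file:// init_method makes torch.distributed rendezvous through a shared file instead of a TCP host/port pair, so concurrent test runs no longer collide on a hard-coded port. A minimal single-process sketch (the gloo backend and the temporary path here are illustrative, not part of this diff):

    import tempfile

    import torch.distributed as dist

    # Every rank must be given the same path; the file acts as the rendezvous point.
    sync_file = tempfile.mkstemp()[1]

    dist.init_process_group(
        backend="gloo",  # the diff picks "nccl" when CUDA is available
        init_method="file://" + sync_file,
        rank=0,
        world_size=1,
    )
    dist.destroy_process_group()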
@@ -31,6 +31,7 @@ import logging
 import multiprocessing
 import os
 import random
+import tempfile
 from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy
@@ -63,7 +64,7 @@ def set_random_seed(seed: int) -> None:


 def torch_version() -> Tuple[int, ...]:
-    numbering = torch.__version__.split(".")
+    numbering = torch.__version__.split("+")[0].split(".")
     assert len(numbering) == 3
@@ -81,35 +82,41 @@ def torch_version() -> Tuple[int, ...]:
     return tuple(int(n) for n in numbering)


-def dist_init(rank: int, world_size: int, hostname: Optional[str] = None) -> None:
-    if hostname is None:
-        hostname = "localhost"
-    print(f"dist init r={rank}, world={world_size}, host={hostname}")
-    os.environ["MASTER_ADDR"] = hostname
-    os.environ["MASTER_PORT"] = "10638"
+def dist_init(rank: int, world_size: int, filename: str) -> None:
+    """
+    Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
+    tests to be run concurrently.
+
+    .. warning:: This limits the use case to all ranks being on the same node
+    """
+    print(f"dist init r={rank}, world={world_size}")
     os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["RANK"] = str(rank)
+    url = "file://" + filename

     if torch_version() >= (1, 6, 0):
-        init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
         backend = "nccl" if torch.cuda.is_available() else "gloo"
-        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=init_method)
+        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)

-        os.environ["MASTER_ADDR"] = hostname
-        os.environ["MASTER_PORT"] = "10639"
-        init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
+        # New file for RPC init
+        filename_rpc = filename + "_rpc"
+        open(filename_rpc, "w")
+        url = "file://" + filename_rpc

         rpc.init_rpc(
             f"Test{rank}",
             rank=rank,
             world_size=world_size,
             backend=rpc.BackendType.TENSORPIPE,
-            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
+            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=url),
         )
     else:
         if world_size > 1:
             rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
         else:
-            torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+            torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)

     if torch.cuda.is_available() and torch.cuda.device_count():
         torch.cuda.set_device(rank % torch.cuda.device_count())
@@ -125,14 +132,20 @@ def get_world_sizes() -> List[int]:


 def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
     for world_size in world_sizes:
-        mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)  # type: ignore
+        if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
+            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
+            continue
+
+        filename = tempfile.mkstemp()[1]
+        mp.spawn(test_func, args=(world_size, filename, *args), nprocs=world_size, join=True)  # type: ignore


-def worker_process(rank: int, world_size: int, func: Callable, args: Any, error_queue: Any) -> None:
+def worker_process(rank: int, world_size: int, filename: str, func: Callable, args: Any, error_queue: Any) -> None:
     """Main function for unit tests launched with torch_spawn"""
-    dist_init(rank, world_size)
+    dist_init(rank, world_size, filename)
     kwargs = {}
     if "OMPI_COMM_WORLD_RANK" not in os.environ:
         kwargs["pipeline_backend"] = "gloo"
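Test functions spawned through these helpers now receive the rendezvous file as their third positional argument, i.e. test_func(rank, world_size, filename, *args). A sketch of what a conforming test looks like (run_test_example is hypothetical):

    import torch

    from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


    def run_test_example(rank, world_size, filename):
        # rank is injected by mp.spawn; world_size and filename come from args
        dist_init(rank, world_size, filename)
        assert torch.distributed.get_world_size() == world_size


    if __name__ == "__main__":
        spawn_for_all_world_sizes(run_test_example)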
@@ -145,8 +158,22 @@ def worker_process(rank: int, world_size: int, func: Callable, args: Any, error_queue: Any) -> None:
         if e.__class__.__name__ == "Skipped":
             error_queue.put(str(e))
             return
+
+        # Make sure that the group is properly destroyed, even for tests which check for exceptions being raised
+        teardown()
         raise e

+    teardown()
+
+
+def teardown() -> None:
+    if torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
+
+    try:
+        torch.distributed.rpc.shutdown()
+    except Exception:
+        pass
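The try/except around rpc.shutdown() is there because shutdown raises when RPC was never initialized for the current process, and not every branch of dist_init sets RPC up. A sketch of the cleanup pattern a worker ends up following (hypothetical worker, test body elided):

    from fairscale.utils.testing import dist_init, teardown


    def run_test_worker(rank, world_size, filename):
        dist_init(rank, world_size, filename)
        try:
            ...  # actual test body
        finally:
            # destroy the process group and best-effort shut down RPC,
            # even when the test body raised
            teardown()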
 def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
     if world_sizes is None:
...
@@ -187,7 +214,9 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
         if world_size in world_sizes:
             try:
                 func(*args)
+                teardown()
             except BaseException as e:
+                teardown()
                 print(f"got exception {e} from test")

                 import traceback
...
@@ -19,6 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch
 import torch.nn.functional as F
@@ -49,8 +50,8 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
     return loss, identity.weight.grad


-def run_test_cross_entropy(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_cross_entropy(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)

     if torch.distributed.get_rank() == 0:
         print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size))
...
@@ -19,14 +19,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import torch

 from fairscale.nn.model_parallel import initialize as mpu
 from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


-def run_test_initialize_model_parallel(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_initialize_model_parallel(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)

     if torch.distributed.get_rank() == 0:
         print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))
@@ -62,8 +63,8 @@ def run_test_initialize_model_parallel(rank, model_parallel_size):
     print(">> passed the test :-)")


-def run_test_get_model_parallel_src_rank(rank, model_parallel_size_):
-    dist_init(rank, model_parallel_size_)
+def run_test_get_model_parallel_src_rank(rank, model_parallel_size_, filename):
+    dist_init(rank, model_parallel_size_, filename)

     if torch.distributed.get_rank() == 0:
         print("> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_))
...
@@ -20,6 +20,7 @@
 # limitations under the License.

 import os
+import tempfile

 import pytest
 import torch
@@ -34,8 +35,8 @@ from fairscale.nn.pipe import Pipe
 from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn


-def run_test_parallel_embedding(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_parallel_embedding(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)

     if torch.distributed.get_rank() == 0:
         print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))
@@ -104,8 +105,8 @@ def run_test_parallel_embedding(rank, model_parallel_size):
     print(">> passed the test :-)")


-def run_test_initialize_affine_weight(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_initialize_affine_weight(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)
     mpu.initialize_model_parallel(model_parallel_size)

     if torch.distributed.get_rank() == 0:
@@ -180,8 +181,8 @@ class IdentityLayer2D(torch.nn.Module):
         return self.weight


-def run_test_column_parallel_linear(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_column_parallel_linear(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)
     mpu.initialize_model_parallel(model_parallel_size)

     if torch.distributed.get_rank() == 0:
@@ -241,8 +242,8 @@ def run_test_column_parallel_linear(rank, model_parallel_size):
     print(" >> passed the test :-)")


-def run_test_row_parallel_linear(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_row_parallel_linear(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)
     mpu.initialize_model_parallel(model_parallel_size)

     if torch.distributed.get_rank() == 0:
@@ -301,14 +302,14 @@ def run_test_row_parallel_linear(rank, model_parallel_size):
     print(" >> passed the test :-)")


-def run_test_pipe(rank, world_size, skip_dist_init=False):
+def run_test_pipe(rank, world_size, filename, skip_dist_init=False):
     pipe_world_size = 2

     if world_size == 1:
         return

     if not skip_dist_init:
-        dist_init(rank, world_size)
+        dist_init(rank, world_size, filename)
     else:
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = "29502"
@@ -566,7 +567,8 @@ def test_row_parallel():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 def mpi_pipe():
     mpu.destroy_model_parallel()
-    run_test_pipe(torch.distributed.get_rank(), torch.distributed.get_world_size(), skip_dist_init=True)
+    tempfile_init = tempfile.mkstemp()[1]
+    run_test_pipe(torch.distributed.get_rank(), torch.distributed.get_world_size(), tempfile_init, skip_dist_init=True)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
...
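A side note on the tempfile.mkstemp()[1] idiom used throughout this diff: mkstemp returns an (OS file descriptor, path) tuple, and indexing [1] keeps only the path, leaving the descriptor open for the lifetime of the process. A more explicit equivalent would be:

    import os
    import tempfile

    fd, sync_file = tempfile.mkstemp()
    os.close(fd)  # only the path is needed for a file:// init_method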
@@ -27,8 +27,8 @@ from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed
 from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes


-def run_test_set_cuda_rng_state(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_set_cuda_rng_state(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)

     if torch.distributed.get_rank() == 0:
         print("> testing set_rng_state with size {} ...".format(model_parallel_size))
@@ -96,8 +96,8 @@ def run_test_set_cuda_rng_state(rank, model_parallel_size):
     print(">> passed the test :-)")


-def run_test_cuda_rng_tracker(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_cuda_rng_tracker(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)

     if torch.distributed.get_rank() == 0:
         print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))
@@ -172,8 +172,8 @@ def run_test_cuda_rng_tracker(rank, model_parallel_size):
     print(">> passed the test :-)")


-def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size):
-    dist_init(rank, model_parallel_size)
+def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size, filename):
+    dist_init(rank, model_parallel_size, filename)

     if torch.distributed.get_rank() == 0:
         print("> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size))
...
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.

 import os
+import tempfile

 import pytest
 import torch
@@ -20,15 +21,15 @@ if torch.cuda.is_available():
 else:
     devices = ["cpu"]

-os.environ["MASTER_ADDR"] = "localhost"
-os.environ["MASTER_PORT"] = "29501"
+URL = "file://" + tempfile.mkstemp()[1]

 if "OMPI_COMM_WORLD_SIZE" in os.environ:
-    dist.init_process_group(backend=dist.Backend.MPI)
+    dist.init_process_group(backend=dist.Backend.MPI, init_method=URL)


 def setup_module(module):
     if "OMPI_COMM_WORLD_SIZE" not in os.environ:
-        dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
+        dist.init_process_group(backend=BACKEND, rank=0, world_size=1, init_method=URL)


 def teardown_module(module):
...
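Since URL is computed once at import time, both the MPI branch and setup_module rendezvous through the same file for this module. A minimal sketch of the resulting pattern (the gloo backend and the destroy call in teardown_module are assumptions; the real module body is elided above):

    import tempfile

    import torch.distributed as dist

    URL = "file://" + tempfile.mkstemp()[1]


    def setup_module(module):
        # single-process group, rendezvous via the module-wide file
        dist.init_process_group(backend="gloo", rank=0, world_size=1, init_method=URL)


    def teardown_module(module):
        dist.destroy_process_group()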