Unverified commit 79365ee6, authored by Benjamin Lefaudeux, committed by GitHub

[fix] Flaky tests (#283)

* adding the pytest-timeout plugin to properly root out hanging tests
* removing redundant code; slightly more reasonable timeout; works on a single CUDA device
* finding the root bug for some of the CPU hangs: RPC init (see the sketch below)
* propagating all the RPC init test changes to the pipe and model-parallel tests
parent 7cc8b34a
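The commit message points at RPC init as the root cause of some of the CPU hangs; the fix gives each spawned test its own pair of temporary rendezvous files (one for torch.distributed, one for the RPC layer) and lets mp.spawn handle process joining. A minimal sketch of that pattern, not code from this commit (run_my_test and test_something are hypothetical names; the real helpers are in fairscale.utils.testing):

    import tempfile

    import torch.multiprocessing as mp

    from fairscale.utils.testing import dist_init


    def run_my_test(rank: int, world_size: int, filename: str, filename_rpc: str) -> None:
        # File-based init for both the process group and the RPC layer,
        # so that unrelated tests never collide on a TCP port.
        assert dist_init(rank, world_size, filename, filename_rpc)
        # ... actual test body and teardown omitted ...


    def test_something() -> None:
        world_size = 2
        _, filename = tempfile.mkstemp()
        _, filename_rpc = tempfile.mkstemp()

        # join=True: let torch.multiprocessing handle joining; the previous
        # join=False + context.join(timeout=...) approach proved unstable.
        mp.spawn(run_my_test, args=(world_size, filename, filename_rpc), nprocs=world_size, join=True)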
...@@ -103,7 +103,7 @@ run_coverage: &run_coverage ...@@ -103,7 +103,7 @@ run_coverage: &run_coverage
- run: - run:
name: Run Unit Tests With Coverage name: Run Unit Tests With Coverage
command: | command: |
pytest --cov-report=xml --cov=./ pytest --cov-report=xml --cov=./ --timeout 60
#Uploading test coverage for Python code #Uploading test coverage for Python code
bash <(curl -s https://codecov.io/bash) -f coverage.xml -cF Python bash <(curl -s https://codecov.io/bash) -f coverage.xml -cF Python
...@@ -111,7 +111,7 @@ run_unittests: &run_unittests ...@@ -111,7 +111,7 @@ run_unittests: &run_unittests
- run: - run:
name: Run Unit Tests name: Run Unit Tests
command: | command: |
pytest --junitxml=test-results/junit.xml --verbose pytest --junitxml=test-results/junit.xml --verbose --timeout 60
run_mpi_unittests: &run_mpi_unittests run_mpi_unittests: &run_mpi_unittests
- run: - run:
......
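For reference (not part of this diff): besides the --timeout 60 flag added to the CI commands above, the pytest-timeout plugin also accepts a per-test budget via a marker, which is useful when a single test legitimately needs more or less time than the global limit. A small assumed-standalone example:

    import time

    import pytest


    @pytest.mark.timeout(60)  # fail (instead of hang) if this test runs past 60 s
    def test_does_not_hang():
        time.sleep(0.1)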
@@ -45,6 +45,14 @@ import torch.nn as nn
 from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed

+skip_if_no_cuda = pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.device_count() < 1, reason="CUDA required"
+)
+skip_if_single_gpu = pytest.mark.skipif(
+    not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason="multiple GPUs required"
+)
+
 class IdentityLayer(torch.nn.Module):
     def __init__(self, size: int, scale: float = 1.0) -> None:

@@ -82,7 +90,7 @@ def torch_version() -> Tuple[int, ...]:
     return tuple(int(n) for n in numbering)

-def dist_init(rank: int, world_size: int, filename: str) -> bool:
+def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
     """
     Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
     tests to be run concurrently.
@@ -106,24 +114,22 @@ def dist_init(rank: int, world_size: int, filename: str) -> bool:
         torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)

-        # New file for RPC init
-        filename_rpc = filename + "_rpc"
-        open(filename_rpc, "w")
-        url = "file://" + filename_rpc
+        url_rpc = "file://" + filename_rpc
         rpc.init_rpc(
             f"Test{rank}",
             rank=rank,
             world_size=world_size,
             backend=rpc.BackendType.TENSORPIPE,
-            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=url),
+            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=url_rpc),
         )
     else:
         if world_size > 1:
             rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
-        else:
+        elif torch.cuda.is_available():
             torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)
+        else:
+            return False

     if torch.cuda.is_available() and torch.cuda.device_count():
         torch.cuda.set_device(rank % torch.cuda.device_count())
@@ -143,25 +149,34 @@ def get_world_sizes() -> List[int]:
 def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
     for world_size in world_sizes:
-        filename = tempfile.mkstemp()[1]
-        context = mp.spawn(test_func, args=(world_size, filename, *args), nprocs=world_size, join=False)  # type: ignore
-        context.join(timeout=60.0)
+        _, filename = tempfile.mkstemp()
+        _, filename_rpc = tempfile.mkstemp()
+
+        # (lefaudeux) Let mp handle the process joining, join=False and handling context has been unstable in the past
+        mp.spawn(test_func, args=(world_size, filename, filename_rpc, *args), nprocs=world_size, join=True)  # type: ignore

-def worker_process(rank: int, world_size: int, filename: str, func: Callable, args: Any, error_queue: Any) -> None:
+def worker_process(
+    rank: int, world_size: int, filename: str, filename_rpc: str, func: Callable, args: Any, error_queue: Any
+) -> None:
     """Main function for unit tests launced with torch_spawn"""

-    if not dist_init(rank, world_size, filename):
+    if not dist_init(rank, world_size, filename, filename_rpc):
+        logging.warning("failed initializing torch distributed")
         return

     kwargs = {}
     if "OMPI_COMM_WORLD_RANK" not in os.environ:
         kwargs["pipeline_backend"] = "gloo"

     initialize_model_parallel(1, world_size, **kwargs)

     try:
         func(*args)
         teardown()
     except BaseException as e:
+        logging.warning(f" Rank {rank}: {e}")
+
         # Make sure that the group is properly destroyed, even for tests which check for exceptions being raised
         teardown()

@@ -176,6 +191,7 @@ def worker_process(rank: int, world_size: int, filename: str, func: Callable, ar
 def teardown() -> None:
     destroy_model_parallel()

     if torch.distributed.is_initialized():
         torch.distributed.destroy_process_group()

     try:
@@ -226,13 +242,12 @@ def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
                 teardown()
             except BaseException as e:
                 teardown()
-                print(f"got exception {e} from test")
                 import traceback

                 print(f"{traceback.format_exc()}")
                 raise e
         else:
-            pytest.skip(f"requested world size doesn't match current world size")
+            pytest.skip("Requested world size doesn't match current world size")
     else:
         spawn_for_all_world_sizes(worker_process, world_sizes, (func, args, error_queue))

@@ -274,6 +289,10 @@ class _Block(nn.Module):
 class GPT2(nn.Module):
+    """
+    GPT2 pytorch implementation, for testing purposes in the image-GPT context
+    Credits: https://github.com/teddykoker/image-gpt"""
+
     def __init__(
         self, embed_dim: int, num_heads: int, num_layers: int, num_positions: int, num_vocab: int, num_classes: int
     ) -> None:
...
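Illustration only (hypothetical test names): with the skip_if_no_cuda / skip_if_single_gpu markers now living in fairscale.utils.testing (see the hunk above), a test module imports them instead of redefining its own copies, which is what the remaining diffs below do:

    import torch

    from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu


    @skip_if_no_cuda
    def test_needs_one_gpu():
        # Only runs when at least one CUDA device is visible.
        assert torch.cuda.is_available()


    @skip_if_single_gpu
    def test_needs_two_gpus():
        # Only runs when at least two CUDA devices are visible.
        assert torch.cuda.device_count() >= 2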
@@ -6,6 +6,7 @@ mypy == 0.770
 pytest == 5.4.1
 pytest-cov == 2.10.0
 pytest-mpi == 0.4
+pytest-timeout == 1.4.2
 torchtext == 0.6.0
 torch >= 1.5.1
 torchvision >= 0.6.0
...
...@@ -13,7 +13,6 @@ import tempfile ...@@ -13,7 +13,6 @@ import tempfile
from typing import List from typing import List
import numpy as np import numpy as np
import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -22,10 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -22,10 +21,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS from fairscale.optim import OSS
from fairscale.utils.testing import GPT2 from fairscale.utils.testing import GPT2, skip_if_no_cuda, skip_if_single_gpu
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
def run_one_step(rank, world_size, backend, device, temp_file_name): def run_one_step(rank, world_size, backend, device, temp_file_name):
......
...@@ -50,8 +50,8 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed): ...@@ -50,8 +50,8 @@ def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
return loss, identity.weight.grad return loss, identity.weight.grad
def run_test_cross_entropy(rank, model_parallel_size, filename): def run_test_cross_entropy(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size)) print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size))
......
...@@ -26,8 +26,8 @@ from fairscale.nn.model_parallel import initialize as mpu ...@@ -26,8 +26,8 @@ from fairscale.nn.model_parallel import initialize as mpu
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
def run_test_initialize_model_parallel(rank, model_parallel_size, filename): def run_test_initialize_model_parallel(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size)) print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))
...@@ -63,8 +63,8 @@ def run_test_initialize_model_parallel(rank, model_parallel_size, filename): ...@@ -63,8 +63,8 @@ def run_test_initialize_model_parallel(rank, model_parallel_size, filename):
print(">> passed the test :-)") print(">> passed the test :-)")
def run_test_get_model_parallel_src_rank(rank, model_parallel_size_, filename): def run_test_get_model_parallel_src_rank(rank, model_parallel_size_, filename, filename_rpc):
dist_init(rank, model_parallel_size_, filename) dist_init(rank, model_parallel_size_, filename, filename_rpc)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print("> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_)) print("> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_))
......
...@@ -35,8 +35,8 @@ from fairscale.nn.pipe import Pipe ...@@ -35,8 +35,8 @@ from fairscale.nn.pipe import Pipe
from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
def run_test_parallel_embedding(rank, model_parallel_size, filename): def run_test_parallel_embedding(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size)) print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))
...@@ -105,8 +105,8 @@ def run_test_parallel_embedding(rank, model_parallel_size, filename): ...@@ -105,8 +105,8 @@ def run_test_parallel_embedding(rank, model_parallel_size, filename):
print(">> passed the test :-)") print(">> passed the test :-)")
def run_test_initialize_affine_weight(rank, model_parallel_size, filename): def run_test_initialize_affine_weight(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -181,8 +181,8 @@ class IdentityLayer2D(torch.nn.Module): ...@@ -181,8 +181,8 @@ class IdentityLayer2D(torch.nn.Module):
return self.weight return self.weight
def run_test_column_parallel_linear(rank, model_parallel_size, filename): def run_test_column_parallel_linear(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -242,8 +242,8 @@ def run_test_column_parallel_linear(rank, model_parallel_size, filename): ...@@ -242,8 +242,8 @@ def run_test_column_parallel_linear(rank, model_parallel_size, filename):
print(" >> passed the test :-)") print(" >> passed the test :-)")
def run_test_row_parallel_linear(rank, model_parallel_size, filename): def run_test_row_parallel_linear(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
mpu.initialize_model_parallel(model_parallel_size) mpu.initialize_model_parallel(model_parallel_size)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
...@@ -302,14 +302,14 @@ def run_test_row_parallel_linear(rank, model_parallel_size, filename): ...@@ -302,14 +302,14 @@ def run_test_row_parallel_linear(rank, model_parallel_size, filename):
print(" >> passed the test :-)") print(" >> passed the test :-)")
def run_test_pipe(rank, world_size, filename, skip_dist_init=False): def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False):
pipe_world_size = 2 pipe_world_size = 2
if world_size == 1: if world_size == 1:
return return
if not skip_dist_init: if not skip_dist_init:
dist_init(rank, world_size, filename) dist_init(rank, world_size, filename, filename_rpc)
else: else:
os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502" os.environ["MASTER_PORT"] = "29502"
@@ -567,8 +567,16 @@ def test_row_parallel():
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 def mpi_pipe():
     mpu.destroy_model_parallel()
-    tempfile_init = tempfile.mkstemp()[1]
-    run_test_pipe(torch.distributed.get_rank(), torch.distributed.get_world_size(), tempfile_init, skip_dist_init=True)
+    _, tempfile_init = tempfile.mkstemp()
+    _, tempfile_rpc_init = tempfile.mkstemp()
+
+    run_test_pipe(
+        torch.distributed.get_rank(),
+        torch.distributed.get_world_size(),
+        tempfile_init,
+        tempfile_rpc_init,
+        skip_dist_init=True,
+    )

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
...
...@@ -27,8 +27,8 @@ from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_paral ...@@ -27,8 +27,8 @@ from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_paral
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
def run_test_set_cuda_rng_state(rank, model_parallel_size, filename): def run_test_set_cuda_rng_state(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print("> testing set_rng_state with size {} ...".format(model_parallel_size)) print("> testing set_rng_state with size {} ...".format(model_parallel_size))
...@@ -96,8 +96,8 @@ def run_test_set_cuda_rng_state(rank, model_parallel_size, filename): ...@@ -96,8 +96,8 @@ def run_test_set_cuda_rng_state(rank, model_parallel_size, filename):
print(">> passed the test :-)") print(">> passed the test :-)")
def run_test_cuda_rng_tracker(rank, model_parallel_size, filename): def run_test_cuda_rng_tracker(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print("> testing cuda rng tracker with size {} ...".format(model_parallel_size)) print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))
...@@ -172,8 +172,8 @@ def run_test_cuda_rng_tracker(rank, model_parallel_size, filename): ...@@ -172,8 +172,8 @@ def run_test_cuda_rng_tracker(rank, model_parallel_size, filename):
print(">> passed the test :-)") print(">> passed the test :-)")
def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size, filename): def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size, filename, filename_rpc):
dist_init(rank, model_parallel_size, filename) dist_init(rank, model_parallel_size, filename, filename_rpc)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print("> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size)) print("> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size))
......
...@@ -24,9 +24,10 @@ from torch import nn ...@@ -24,9 +24,10 @@ from torch import nn
from fairscale.nn.pipe import Pipe from fairscale.nn.pipe import Pipe
from fairscale.nn.pipe.skip import pop, skippable, stash from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
from fairscale.utils.testing import skip_if_single_gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @skip_if_single_gpu
@pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"])
@pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"])
def test_1to3(balance, checkpoint): def test_1to3(balance, checkpoint):
......
...@@ -23,6 +23,7 @@ from torch import nn ...@@ -23,6 +23,7 @@ from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
from fairscale.nn.pipe import Pipe from fairscale.nn.pipe import Pipe
from fairscale.utils.testing import skip_if_single_gpu
def test_python_autograd_function(): def test_python_autograd_function():
...@@ -81,7 +82,7 @@ def test_exception_no_hang(): ...@@ -81,7 +82,7 @@ def test_exception_no_hang():
model(torch.rand(3)) model(torch.rand(3))
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") @skip_if_single_gpu
def test_tuple_wait(cuda_sleep): def test_tuple_wait(cuda_sleep):
# In v0.0.3, Wait is applied to only the first tensor on a micro-batch. # In v0.0.3, Wait is applied to only the first tensor on a micro-batch.
# Under this behavior, if checkpointing was disabled, there's a possibility # Under this behavior, if checkpointing was disabled, there's a possibility
......
...@@ -22,6 +22,7 @@ from torch.utils.checkpoint import checkpoint as torch_checkpoint ...@@ -22,6 +22,7 @@ from torch.utils.checkpoint import checkpoint as torch_checkpoint
from fairscale.nn.pipe.checkpoint import Checkpointing, Function, TensorOrTensors from fairscale.nn.pipe.checkpoint import Checkpointing, Function, TensorOrTensors
from fairscale.nn.pipe.microbatch import Batch from fairscale.nn.pipe.microbatch import Batch
from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu
# This test is mainly for checking pytorch & checkpointing behavior. pipe's checkpointing # This test is mainly for checking pytorch & checkpointing behavior. pipe's checkpointing
# code is tested already in another file. Therefore, we can run this test less frequently. # code is tested already in another file. Therefore, we can run this test less frequently.
...@@ -30,8 +31,6 @@ run_test = False ...@@ -30,8 +31,6 @@ run_test = False
if os.getpid() % 100 == 42: if os.getpid() % 100 == 42:
run_test = True run_test = True
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
skip_if_not_needed = pytest.mark.skipif(not run_test, reason="Skipping due to test frequency") skip_if_not_needed = pytest.mark.skipif(not run_test, reason="Skipping due to test frequency")
......
...@@ -22,8 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -22,8 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD from torch.optim import SGD
from fairscale.optim import AdaScale from fairscale.optim import AdaScale
from fairscale.utils.testing import skip_if_single_gpu
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs are required")
def _dist_init(rank, world_size, tempfile_name, backend): def _dist_init(rank, world_size, tempfile_name, backend):
......
...@@ -21,8 +21,7 @@ import torch.multiprocessing as mp ...@@ -21,8 +21,7 @@ import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim import fairscale.optim as optim
from fairscale.utils.testing import skip_if_no_cuda, skip_if_single_gpu
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu") DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
...@@ -225,7 +224,7 @@ def run_test_add_param_group(rank, world_size, tempfile_name): ...@@ -225,7 +224,7 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
def test_add_param_group(): def test_add_param_group():
world_size = 3 world_size = 3
if torch.cuda.device_count() < world_size: if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
pytest.skip("Not enough GPUs for NCCL-based test") pytest.skip("Not enough GPUs for NCCL-based test")
temp_file_name = tempfile.mkstemp()[1] temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True) mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True)
...@@ -273,9 +272,9 @@ def run_test_step(rank, world_size, tempfile_name): ...@@ -273,9 +272,9 @@ def run_test_step(rank, world_size, tempfile_name):
dist.destroy_process_group() dist.destroy_process_group()
@skip_if_no_cuda @skip_if_single_gpu
def test_step(): def test_step():
world_size = min(2, torch.cuda.device_count()) world_size = 2
temp_file_name = tempfile.mkstemp()[1] temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_step, args=(world_size, temp_file_name), nprocs=world_size, join=True) mp.spawn(run_test_step, args=(world_size, temp_file_name), nprocs=world_size, join=True)
...@@ -347,7 +346,7 @@ def run_test_sharding(rank, world_size, tempfile_name): ...@@ -347,7 +346,7 @@ def run_test_sharding(rank, world_size, tempfile_name):
def test_sharding(): def test_sharding():
world_size = 3 world_size = 3
if torch.cuda.device_count() < world_size: if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
pytest.skip("Not enough GPUs for NCCL-based test") pytest.skip("Not enough GPUs for NCCL-based test")
temp_file_name = tempfile.mkstemp()[1] temp_file_name = tempfile.mkstemp()[1]
......
...@@ -21,8 +21,7 @@ from torch.optim import SGD ...@@ -21,8 +21,7 @@ from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR from torch.optim.lr_scheduler import LambdaLR
from fairscale.optim import AdaScale from fairscale.optim import AdaScale
from fairscale.utils.testing import skip_if_no_cuda
skip_if_no_gpu = pytest.mark.skipif(torch.cuda.device_count() < 1, reason="1 GPU is required")
def test_basic_cpu(): def test_basic_cpu():
...@@ -114,15 +113,15 @@ def test_grad_accum(test_case, cpu): ...@@ -114,15 +113,15 @@ def test_grad_accum(test_case, cpu):
optim.zero_grad() optim.zero_grad()
@skip_if_no_gpu @skip_if_no_cuda
def test_state_checkpointing(): def test_state_checkpointing():
""" Test state checkpointing on GPU since that's the common case. """Test state checkpointing on GPU since that's the common case.
Note, we don't support checkpointing in the middle of gradient accumulation Note, we don't support checkpointing in the middle of gradient accumulation
step. Therefore, it is not tested here. step. Therefore, it is not tested here.
AdaScale doesn't have distributed state. Otherwise, it will need AdaScale doesn't have distributed state. Otherwise, it will need
a unit test for checkpointing with DDP. a unit test for checkpointing with DDP.
""" """
# Constants. # Constants.
accum_steps = 3 accum_steps = 3
...@@ -207,7 +206,7 @@ def test_lr_scheduler(): ...@@ -207,7 +206,7 @@ def test_lr_scheduler():
assert np.allclose(optim.param_groups[0]["lr"], 0.1 / 10 ** (epoch + 1)), optim.param_groups[0]["lr"] assert np.allclose(optim.param_groups[0]["lr"], 0.1 / 10 ** (epoch + 1)), optim.param_groups[0]["lr"]
@skip_if_no_gpu @skip_if_no_cuda
@pytest.mark.parametrize("debias_ewma", [True, False]) @pytest.mark.parametrize("debias_ewma", [True, False])
def test_add_param_group(debias_ewma): def test_add_param_group(debias_ewma):
"""Test AdaScale supports add_param_group() API.""" """Test AdaScale supports add_param_group() API."""
...@@ -376,7 +375,7 @@ def test_scale_not_equal_default(test_case): ...@@ -376,7 +375,7 @@ def test_scale_not_equal_default(test_case):
assert np.allclose(optim.gain(), exp_gain), optim.gain() assert np.allclose(optim.gain(), exp_gain), optim.gain()
@skip_if_no_gpu @skip_if_no_cuda
def test_unhook(): def test_unhook():
"""Test unhook that frees the tensor from CUDA memory.""" """Test unhook that frees the tensor from CUDA memory."""
model = Linear(123, 456, bias=False).cuda() # unique shape so that it can be found model = Linear(123, 456, bias=False).cuda() # unique shape so that it can be found
......