Unverified Commit e83da060 authored by Benjamin Lefaudeux, committed by GitHub

[chore] Refactor unit testing, shared utils (#218)

parent 1db8bbda
@@ -23,7 +23,7 @@ from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_
from fairscale.nn.pipe import LazyModule, pipe
from fairscale.optim import GradScaler
from fairscale.optim.oss import OSS
-from tests.nn.model_parallel.commons import dist_init, get_worker_map
+from fairscale.utils.testing import dist_init, get_worker_map
try:
    from fairscale.optim import Adam # type: ignore
...
-# coding=utf-8
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
@@ -11,7 +9,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -19,14 +17,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# We're not responsible for pytest decorators
+# mypy: disallow_untyped_decorators = False
+"""
+Collection of some testing utilities for the Fairscale library. Please complement as you see fit, but refrain from ad-hoc test utils
+within the different feature sets and relative imports.
+"""
import functools
import inspect
+import logging
import multiprocessing
import os
import random
+from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy
-from packaging import version
import pytest
import torch
import torch.distributed as dist
@@ -38,11 +45,11 @@ from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
class IdentityLayer(torch.nn.Module):
-    def __init__(self, size, scale=1.0):
+    def __init__(self, size: int, scale: float = 1.0) -> None:
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))
-    def forward(self):
+    def forward(self, *_: Any, **__: Any) -> Any:
        return self.weight
@@ -54,7 +61,26 @@ def set_random_seed(seed: int) -> None:
    model_parallel_cuda_manual_seed(seed)
-def dist_init(rank, world_size, hostname=None):
+def torch_version() -> Tuple[int, ...]:
+    numbering = torch.__version__.split(".")
+    assert len(numbering) == 3
+    # Catch the torch version if run against internal pre-releases, like `1.8.0a0fb`
+    if not numbering[2].isnumeric():
+        # Two options here:
+        # - either skip this version (minor number check is not relevant)
+        # - or check that our codebase is not broken by this ongoing development.
+        # Assuming that we're interested in the second use case more than the first,
+        # return the pre-release or dev numbering
+        logging.warning(f"PyTorch pre-release version {torch.__version__} - assuming intent to test it")
+        numbering[2] = "0"
+    return tuple(int(n) for n in numbering)
+def dist_init(rank: int, world_size: int, hostname: Optional[str] = None) -> None:
    if hostname is None:
        hostname = "localhost"
    print(f"dist init r={rank}, world={world_size}, host={hostname}")
@@ -63,7 +89,7 @@ def dist_init(rank, world_size, hostname=None):
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
-    if version.parse(torch.__version__).release >= (1, 6, 0):
+    if torch_version() >= (1, 6, 0):
        init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=init_method)
@@ -77,6 +103,7 @@ def dist_init(rank, world_size, hostname=None):
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method),
        )
    else:
        if world_size > 1:
            rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)
@@ -87,21 +114,21 @@ def dist_init(rank, world_size, hostname=None):
        torch.cuda.set_device(rank % torch.cuda.device_count())
-def get_worker_map():
+def get_worker_map() -> Dict[Any, Any]:
    return {rank: f"Test{rank}" for rank in range(dist.get_world_size())}
-def get_world_sizes():
+def get_world_sizes() -> List[int]:
    limit = torch.cuda.device_count()
    return [x for x in [1, 2, 4, 8] if x <= limit]
-def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]):
+def spawn_for_all_world_sizes(test_func: Callable, world_sizes: List[int] = get_world_sizes(), args: Any = []) -> None:
    for world_size in world_sizes:
-        mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)
+        mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True) # type: ignore
-def worker_process(rank, world_size, func, args, error_queue):
+def worker_process(rank: int, world_size: int, func: Callable, args: Any, error_queue: Any) -> None:
    """Main function for unit tests launched with torch_spawn"""
    dist_init(rank, world_size)
@@ -120,11 +147,11 @@ def worker_process(rank, world_size, func, args, error_queue):
        raise e
-def torch_spawn(world_sizes=None):
+def torch_spawn(world_sizes: Optional[List[int]] = None) -> Callable:
    if world_sizes is None:
        world_sizes = get_world_sizes()
-    def prepare_test(func):
+    def prepare_test(func: Callable) -> Callable:
        """Function called with the test function as the argument. Generates a
        replacement which serves as the actual test function."""
@@ -138,8 +165,10 @@ def torch_spawn(world_sizes=None):
        )
        @functools.wraps(func)
-        def replacement(*args, **kwargs):
+        def replacement(*args: Any, **kwargs: Any) -> None:
            assert args == tuple()
+            assert world_sizes is not None # mypy crutch
            args = tuple(
                kwargs[p] for p in parameters if p != "rank"
            ) # converting named parameters to positional parameters to pass to `spawn`
@@ -174,7 +203,9 @@ def torch_spawn(world_sizes=None):
        # Register a function with the same name, prefixed with "test_" in the
        # calling module, so it will be picked up by pytest
-        caller_module = inspect.getmodule(inspect.currentframe().f_back)
+        current_frame = inspect.currentframe()
+        assert current_frame is not None
+        caller_module = inspect.getmodule(current_frame.f_back)
        setattr(caller_module, f"test_{name}", replacement)
        return func
...
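For illustration, a minimal sketch (not part of this commit) of how the torch_spawn decorator exported above is meant to be consumed by a test module. The test name and assertions are illustrative only; it assumes fairscale is installed and that a single-rank process group can be initialized on the host.

import torch.distributed as dist

from fairscale.utils.testing import torch_spawn


@torch_spawn([1])
def world_is_initialized() -> None:
    # torch_spawn registers a pytest-visible `test_world_is_initialized` wrapper
    # in this module; each spawned rank runs dist_init() before this body.
    assert dist.is_initialized()
    assert dist.get_world_size() == 1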
@@ -28,4 +28,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
@@ -7,6 +7,7 @@ from .. import device as _device
def is_available() -> bool: ...
def init() -> None: ...
def _lazy_call(callable) -> None: ...
+def _sleep(_: int) -> None: ...
class cudaStatus:
    SUCCESS: int
@@ -64,6 +65,12 @@ class Stream:
    def synchronize(self) -> None: ...
    def wait_stream(self, stream: Stream) -> None: ...
+class Event:
+    def __new__(cls, enable_timing: bool = False, blocking: bool = False, interprocess: bool = False) -> "Event": ...
+    def record(self, stream: Optional[Stream] = None) -> None: ...
+    def synchronize(self) -> None: ...
+    def elapsed_time(self, end_event: Event) -> int: ...
class stream:
    def __init__(self, stream: Optional[Stream] = ...) -> None: ...
    def __enter__(self) -> None: ...
...
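The Event and _sleep stubs above mirror the real torch.cuda API that the conftest files further down use to calibrate a GPU busy-wait; roughly (requires a CUDA device, the cycle count is illustrative):

import torch

# Measure how many GPU cycles elapse per millisecond, then derive a sleep helper.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
torch.cuda._sleep(1_000_000)  # spin the GPU for one million cycles
end.record()
end.synchronize()
cycles_per_ms = 1_000_000 / start.elapsed_time(end)  # elapsed_time() returns milliseconds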
@@ -35,7 +35,7 @@ def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optio
def is_initialized() -> bool: ...
-def init_process_group(backend: Union[str, Backend], timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
+def init_process_group(backend: Union[str, Backend], init_method: Optional[str] = None, timeout: datetime.timedelta = datetime.timedelta(0, 1800), rank: Optional[int] = None, world_size: Optional[int] = None): ...
def new_group(ranks: List[int], timeout: datetime.timedelta = datetime.timedelta(0, 1800), backend: Union[None, str, Backend] = None): ...
def all_to_all(output: List[Tensor], input: List[Tensor], group:Optional[ProcessGroup] = None, async_op: bool = False): ...
@@ -43,6 +43,8 @@ def all_to_all_single(output: Tensor, input: Tensor, output_split_size: Optional
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
def all_gather(tensor_list: List[Tensor], tensor: Tensor, group:Optional[ProcessGroup] = None, async_op: bool = False): ...
+def destroy_process_group() -> None: ...
def send(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
def isend(tensor: Tensor, dst: int, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> None: ...
def recv(tensor: Tensor, src: Optional[int] = None, group: Optional[ProcessGroup] = None, tag: Optional[int] = None) -> int: ...
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from typing import Union, Callable, Optional
+from typing import Union, Callable, Optional, Any
from torch.futures import Future
@@ -11,6 +11,11 @@ class RRef:
class WorkerInfo:
    ...
+class BackendType:
+    TENSORPIPE: Any
+    PROCESS_GROUP: Any
+def TensorPipeRpcBackendOptions(init_method: str) -> Any: ...
def rpc_async(
    to: Union[str, WorkerInfo],
@@ -30,3 +35,8 @@ def rpc_sync(
    timeout=-1.0,
) -> None:
    ...
+def init_rpc(name: str, backend: Optional[Any] = None, rank: int = -1, world_size: Optional[int] = None, rpc_backend_options: Optional[Any] = None) -> None: ...
+def shutdown() -> None: ...
@@ -25,7 +25,7 @@ import torch.nn.functional as F
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel.cross_entropy import vocab_parallel_cross_entropy
from fairscale.nn.model_parallel.mappings import scatter_to_model_parallel_region
-from tests.nn.model_parallel.commons import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes
+from fairscale.utils.testing import IdentityLayer, dist_init, set_random_seed, spawn_for_all_world_sizes
def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed):
...
@@ -22,7 +22,7 @@
import torch
from fairscale.nn.model_parallel import initialize as mpu
-from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes
+from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
def run_test_initialize_model_parallel(rank, model_parallel_size):
...
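For reference, the spawn helper's calling convention is that each worker receives (rank, world_size, *args); a self-contained sketch with a hypothetical worker (CPU-only, single world size) could look like:

from fairscale.utils.testing import spawn_for_all_world_sizes


def example_worker(rank: int, world_size: int, message: str) -> None:
    # Hypothetical worker: real tests such as run_test_initialize_model_parallel
    # call dist_init(rank, world_size) here before exercising the feature.
    print(f"rank {rank}/{world_size}: {message}")


if __name__ == "__main__":
    spawn_for_all_world_sizes(example_worker, world_sizes=[1], args=("hello",))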
@@ -31,13 +31,7 @@ from torch.nn.parameter import Parameter
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import layers
from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import (
-    dist_init,
-    get_world_sizes,
-    set_random_seed,
-    spawn_for_all_world_sizes,
-    torch_spawn,
-)
+from fairscale.utils.testing import dist_init, get_world_sizes, set_random_seed, spawn_for_all_world_sizes, torch_spawn
def run_test_parallel_embedding(rank, model_parallel_size):
...
@@ -24,7 +24,7 @@ import torch
from fairscale.nn.model_parallel import initialize as mpu
from fairscale.nn.model_parallel import random
from fairscale.nn.model_parallel.random import get_cuda_rng_tracker, model_parallel_cuda_manual_seed
-from tests.nn.model_parallel.commons import dist_init, spawn_for_all_world_sizes
+from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
def run_test_set_cuda_rng_state(rank, model_parallel_size):
...
@@ -17,17 +17,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import functools
+import os
+from typing import Any, Callable
import pytest
import torch
+from fairscale.nn.model_parallel import destroy_model_parallel
@pytest.fixture(autouse=True)
-def manual_seed_zero():
+def manual_seed_zero() -> None:
    torch.manual_seed(0)
+def cuda_sleep_impl(seconds, cycles_per_ms):
+    torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
@pytest.fixture(scope="session")
-def cuda_sleep():
+def cuda_sleep() -> Callable:
    # Warm-up CUDA.
    torch.empty(1, device="cuda")
@@ -40,11 +50,23 @@ def cuda_sleep():
    end.synchronize()
    cycles_per_ms = 1000000 / start.elapsed_time(end)
-    def cuda_sleep(seconds):
-        torch.cuda._sleep(int(seconds * cycles_per_ms * 1000))
-    return cuda_sleep
+    return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms)
-def pytest_report_header():
+def pytest_report_header() -> str:
    return f"torch: {torch.__version__}"
+def pytest_runtest_setup(item: Any) -> None:
+    print(f"setup mpi function called")
+def pytest_runtest_teardown(item: Any) -> None:
+    if "OMPI_COMM_WORLD_RANK" in os.environ:
+        destroy_model_parallel()
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+        try:
+            torch.distributed.rpc.shutdown()
+        except Exception:
+            pass
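A short sketch of how the refactored session-scoped fixture might be consumed by a test (hypothetical test body; needs a CUDA-capable machine):

def test_overlap_copy_and_compute(cuda_sleep) -> None:
    # cuda_sleep(seconds) enqueues a busy-wait of roughly that duration on the
    # current CUDA stream, via the functools.partial returned by the fixture.
    cuda_sleep(0.01)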
@@ -19,6 +19,7 @@
import functools
import os
+from typing import Any, Callable
import pytest
import torch
@@ -27,7 +28,7 @@ from fairscale.nn.model_parallel import destroy_model_parallel
@pytest.fixture(autouse=True)
-def manual_seed_zero():
+def manual_seed_zero() -> None:
    torch.manual_seed(0)
@@ -36,7 +37,7 @@ def cuda_sleep_impl(seconds, cycles_per_ms):
@pytest.fixture(scope="session")
-def cuda_sleep():
+def cuda_sleep() -> Callable:
    # Warm-up CUDA.
    torch.empty(1, device="cuda")
@@ -52,15 +53,15 @@ def cuda_sleep():
    return functools.partial(cuda_sleep_impl, cycles_per_ms=cycles_per_ms)
-def pytest_report_header():
+def pytest_report_header() -> str:
    return f"torch: {torch.__version__}"
-def pytest_runtest_setup(item):
+def pytest_runtest_setup(item: Any) -> None:
    print(f"setup mpi function called")
-def pytest_runtest_teardown(item):
+def pytest_runtest_teardown(item: Any) -> None:
    if "OMPI_COMM_WORLD_RANK" in os.environ:
        destroy_model_parallel()
        if torch.distributed.is_initialized():
...
@@ -26,7 +26,7 @@ from torch import nn
from fairscale.nn.pipe import LazyModule, Pipe
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([3])
...
@@ -26,7 +26,7 @@ from torch import nn
from fairscale.nn.pipe import Pipe, is_checkpointing, is_recomputing
from fairscale.nn.pipe.skip import pop, skippable, stash
from fairscale.nn.pipe.skip.tracker import current_skip_tracker
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn
@skippable(stash=["skip"])
...
@@ -23,7 +23,7 @@ from torch import nn
import torch.nn.functional as F
from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
...
@@ -22,7 +22,7 @@ import torch
from torch import nn
from fairscale.nn.pipe import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2])
...
@@ -21,9 +21,7 @@ from collections import OrderedDict
from copy import deepcopy
import os
import time
-from typing import Tuple
-from packaging import version
import pytest
import torch
from torch import nn
@@ -34,7 +32,7 @@ from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
)
from fairscale.nn.pipe import LazyModule, Pipe
-from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn
+from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
@torch_spawn([2])
@@ -373,24 +371,6 @@ def checkpoint_eval(pipeline_style):
    assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")
-def torch_version() -> Tuple[int, ...]:
-    result = version.parse(torch.__version__).release
-    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
-    # for which version.parse().release will return None (version becomes of LegacyVersion type)
-    if result is None:
-        # Two options here:
-        # - either skip this version,
-        # - or check that Pipe is not broken by this ongoing development.
-        # Assuming that we're interested in the second usecase more than the first,
-        # return the pre-release or dev numbering
-        numbering = torch.__version__.split(".")
-        result = (int(numbering[0]), int(numbering[1]), 0)
-    assert result
-    return result
@torch_spawn([2])
@pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
...
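The xfail marker above is the typical consumer of the shared torch_version() helper; the same gating pattern in a new test might look like this (the test itself is hypothetical):

import pytest

from fairscale.utils.testing import torch_version


@pytest.mark.skipif(torch_version() < (1, 6, 0), reason="requires torch >= 1.6.0")
def test_needs_recent_torch() -> None:
    assert torch_version() >= (1, 6, 0)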
@@ -8,7 +8,7 @@ from torch.distributed import rpc
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import PipeRPCWrapper
-from tests.nn.model_parallel.commons import get_worker_map, torch_spawn
+from fairscale.utils.testing import get_worker_map, torch_spawn
def init_rpc():
...
@@ -22,7 +22,7 @@ import torch
from torch import nn
from fairscale.nn import Pipe
-from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn
+from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn
@torch_spawn([2])
...