Unverified commit bc1e60e0 authored by Pavel Belevich, committed by GitHub

Fix pytorch version check (#716)

parent 00ec9ff1
@@ -20,7 +20,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
......
@@ -13,7 +13,8 @@ import torch.multiprocessing as mp
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 @skip_if_single_gpu
......
@@ -22,13 +22,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
 from fairscale.optim.grad_scaler import ShardedGradScaler
-from fairscale.utils.testing import (
-    check_same_model_params,
-    skip_if_no_cuda,
-    skip_if_single_gpu,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils import torch_version
+from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
 """
 Check that ShardedDDP gets the same results as DDP in a variety of scenarii
......
@@ -12,7 +12,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from fairscale.nn import MOELayer, Top2Gate
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 pytestmark = pytest.mark.skipif(
     not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)), reason="cuda and torch>=1.8.0 required"
......
@@ -29,7 +29,8 @@ from torch import nn
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import AsyncPipe
 from fairscale.nn.pipe.types import LazyModule
-from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import get_worker_map, torch_spawn
 @torch_spawn([2])
......
@@ -8,6 +8,7 @@ from torch.distributed import rpc
 from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
 from fairscale.nn.pipe import PipeRPCWrapper
+from fairscale.utils import torch_version
 from fairscale.utils.testing import get_worker_map, torch_spawn
@@ -242,7 +243,7 @@ def rpc_multiple_tensors():
 @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 # TODO(msb) Fix this
-@pytest.mark.skipif(torch.__version__.split("+")[0].split(".") >= ["1", "8", "0"], reason="disabled for torch 1.8.0")
+@pytest.mark.skipif(torch_version() >= (1, 8, 0), reason="disabled for torch 1.8.0")
 def construct_only_rank_zero():
     model = [nn.Linear(10, 10), nn.ReLU()]
     if torch.distributed.get_rank() == 0:
......
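Note on the skipif change in the hunk above: the old marker split torch.__version__ into a list of strings and compared it lexicographically, which misorders two-digit components (the string "10" sorts before "8"), so the skip condition would silently stop matching on torch 1.10. Comparing integer tuples, as torch_version() does, orders versions numerically. A small, self-contained illustration of the difference (editor's sketch, not part of the diff):

old_style = "1.10.0+cu102".split("+")[0].split(".")  # ['1', '10', '0']
print(old_style >= ["1", "8", "0"])  # False: string comparison puts '10' before '8'

new_style = tuple(int(x) for x in "1.10.0+cu102".split("+")[0].split("."))  # (1, 10, 0)
print(new_style >= (1, 8, 0))  # True: integer tuples compare numerically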
@@ -23,13 +23,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 import fairscale.optim as optim
 import fairscale.utils as utils
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     check_same_model_params,
     check_same_models_across_ranks,
     skip_if_no_cuda,
     skip_if_py39_no_cuda,
     skip_if_single_gpu,
-    torch_version,
 )
 BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
......
@@ -12,6 +12,7 @@ from unittest import mock
 from parameterized import parameterized
 import torch
+from fairscale.utils import torch_version
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
@@ -28,8 +29,7 @@ CONFIG_OPTIONS = [
 class TestReduceScatterBucketer(unittest.TestCase):
     # TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`.
     def setUp(self):
-        major, minor = torch.__version__.split(".")[:2]
-        major, minor = int(major), int(minor)
+        major, minor, _ = torch_version()
         if major < 1 or (major == 1 and minor < 6):
             raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter")
         if not torch.cuda.is_available():
......
+from fairscale.utils import torch_version
+def test_torch_version():
+    assert torch_version("") == tuple()
+    assert torch_version("bad format") == tuple()
+    assert torch_version("1.9.0") == (1, 9, 0)
+    assert torch_version("1.10.0a0+gitbc6fc3e") == (1, 10, 0)
+    assert torch_version("1.7.0+cu102") == (1, 7, 0)
+    assert torch_version("1.10.0a0+fb") == (1, 10, 0)