Unverified Commit bc1e60e0 authored by Pavel Belevich's avatar Pavel Belevich Committed by GitHub
Browse files

Fix pytorch version check (#716)

parent 00ec9ff1
...@@ -20,7 +20,8 @@ from torch.optim import SGD ...@@ -20,7 +20,8 @@ from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState from fairscale.nn.data_parallel.fully_sharded_data_parallel import TrainingState
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case): def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
......
...@@ -13,7 +13,8 @@ import torch.multiprocessing as mp ...@@ -13,7 +13,8 @@ import torch.multiprocessing as mp
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
@skip_if_single_gpu @skip_if_single_gpu
......
...@@ -22,13 +22,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -22,13 +22,8 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import ( from fairscale.utils import torch_version
check_same_model_params, from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
skip_if_no_cuda,
skip_if_single_gpu,
temp_files_ctx,
torch_version,
)
""" """
Check that ShardedDDP gets the same results as DDP in a variety of scenarii Check that ShardedDDP gets the same results as DDP in a variety of scenarii
......
...@@ -12,7 +12,7 @@ import torch.distributed as dist ...@@ -12,7 +12,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
from fairscale.nn import MOELayer, Top2Gate from fairscale.nn import MOELayer, Top2Gate
from fairscale.utils.testing import torch_version from fairscale.utils import torch_version
pytestmark = pytest.mark.skipif( pytestmark = pytest.mark.skipif(
not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)), reason="cuda and torch>=1.8.0 required" not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)), reason="cuda and torch>=1.8.0 required"
......
...@@ -29,7 +29,8 @@ from torch import nn ...@@ -29,7 +29,8 @@ from torch import nn
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import AsyncPipe from fairscale.nn.pipe import AsyncPipe
from fairscale.nn.pipe.types import LazyModule from fairscale.nn.pipe.types import LazyModule
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version from fairscale.utils import torch_version
from fairscale.utils.testing import get_worker_map, torch_spawn
@torch_spawn([2]) @torch_spawn([2])
......
...@@ -8,6 +8,7 @@ from torch.distributed import rpc ...@@ -8,6 +8,7 @@ from torch.distributed import rpc
from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group
from fairscale.nn.pipe import PipeRPCWrapper from fairscale.nn.pipe import PipeRPCWrapper
from fairscale.utils import torch_version
from fairscale.utils.testing import get_worker_map, torch_spawn from fairscale.utils.testing import get_worker_map, torch_spawn
...@@ -242,7 +243,7 @@ def rpc_multiple_tensors(): ...@@ -242,7 +243,7 @@ def rpc_multiple_tensors():
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi") @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
# TODO(msb) Fix this # TODO(msb) Fix this
@pytest.mark.skipif(torch.__version__.split("+")[0].split(".") >= ["1", "8", "0"], reason="disabled for torch 1.8.0") @pytest.mark.skipif(torch_version() >= (1, 8, 0), reason="disabled for torch 1.8.0")
def construct_only_rank_zero(): def construct_only_rank_zero():
model = [nn.Linear(10, 10), nn.ReLU()] model = [nn.Linear(10, 10), nn.ReLU()]
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
......
...@@ -23,13 +23,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -23,13 +23,13 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import fairscale.optim as optim import fairscale.optim as optim
import fairscale.utils as utils import fairscale.utils as utils
from fairscale.utils import torch_version
from fairscale.utils.testing import ( from fairscale.utils.testing import (
check_same_model_params, check_same_model_params,
check_same_models_across_ranks, check_same_models_across_ranks,
skip_if_no_cuda, skip_if_no_cuda,
skip_if_py39_no_cuda, skip_if_py39_no_cuda,
skip_if_single_gpu, skip_if_single_gpu,
torch_version,
) )
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO # type: ignore
......
...@@ -12,6 +12,7 @@ from unittest import mock ...@@ -12,6 +12,7 @@ from unittest import mock
from parameterized import parameterized from parameterized import parameterized
import torch import torch
from fairscale.utils import torch_version
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes
...@@ -28,8 +29,7 @@ CONFIG_OPTIONS = [ ...@@ -28,8 +29,7 @@ CONFIG_OPTIONS = [
class TestReduceScatterBucketer(unittest.TestCase): class TestReduceScatterBucketer(unittest.TestCase):
# TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`. # TODO(sshleifer): check if possible to reuse `DistributedTest, spawn_and_init`.
def setUp(self): def setUp(self):
major, minor = torch.__version__.split(".")[:2] major, minor, _ = torch_version()
major, minor = int(major), int(minor)
if major < 1 or (major == 1 and minor < 6): if major < 1 or (major == 1 and minor < 6):
raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter") raise unittest.SkipTest("Need pytorch version >= 1.6 due to reduce_scatter")
if not torch.cuda.is_available(): if not torch.cuda.is_available():
......
from fairscale.utils import torch_version
def test_torch_version():
    """torch_version() parses a torch version string into a (major, minor, patch)
    int tuple, ignoring local/dev suffixes; unparseable input yields an empty tuple."""
    cases = {
        "": tuple(),
        "bad format": tuple(),
        "1.9.0": (1, 9, 0),
        "1.10.0a0+gitbc6fc3e": (1, 10, 0),
        "1.7.0+cu102": (1, 7, 0),
        "1.10.0a0+fb": (1, 10, 0),
    }
    for version_string, expected in cases.items():
        assert torch_version(version_string) == expected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment