Unverified commit bc1e60e0, authored by Pavel Belevich and committed by GitHub

Fix pytorch version check (#716)

parent 00ec9ff1
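
Note: the commit replaces the ad-hoc, per-file PyTorch version checks (string-list comparisons and inline int parsing) with a single shared helper, fairscale.utils.torch_version(), which returns an integer tuple. As an illustrative aside (not part of the commit), the old string-based comparison misorders two-digit minor versions, while the tuple returned by the new helper (defined later in this diff) compares correctly:

>>> "1.10.0".split(".")[:2] >= ["1", "7"]    # lexicographic: "10" sorts before "7"
False
>>> torch_version("1.10.0") >= (1, 7, 0)     # integer tuples order correctly
True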
@@ -11,6 +11,7 @@ from torch import Tensor, nn
 from torch.distributed import rpc
 from fairscale.nn.pipe import microbatch
+from fairscale.utils import torch_version
 from .data import DataConsumer
 from .graph import Node, PipelineModulesGraph
@@ -20,7 +21,7 @@ Device = Union[torch.device, int, str]
 def check_pytorch_version() -> None:
-    if list(map(int, torch.__version__.split("+")[0].split(".")[:2])) < [1, 9]:
+    if torch_version() < (1, 9, 0):
         raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
......
@@ -11,6 +11,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
+from fairscale.utils import torch_version
 def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
@@ -45,7 +46,7 @@ def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) ->
     return mean, var, invstd, total_count
-if torch.__version__.split(".")[:2] >= ["1", "7"]:
+if torch_version()[:2] >= (1, 7):
     _forward = torch.jit.script(_forward)  # type: ignore
     _track_running_stats = torch.jit.script(_track_running_stats)  # type: ignore
......
@@ -27,6 +27,8 @@ from torch import Tensor, nn
 import torch.autograd
 import torch.cuda
+from fairscale.utils import torch_version
 from . import microbatch
 from .batchnorm import DeferredBatchNorm
 from .pipeline import Pipeline
@@ -256,7 +258,7 @@ class Pipe(Module):
     ) -> None:
         super().__init__()
-        if torch.__version__.split(".")[:2] >= ["1", "8"]:
+        if torch_version()[:2] >= (1, 8):
             warnings.warn(
                 "fairscale.nn.Pipe has been upstreamed to PyTorch as torch.distributed.pipeline.sync.Pipe. "
                 "It is now deprecated and will be removed in a future version of fairscale. "
......
@@ -3,6 +3,4 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import List
-__all__: List[str] = []
+from .version import *
@@ -51,6 +51,7 @@ import torch.nn as nn
 from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
+from fairscale.utils import torch_version
 if TYPE_CHECKING:
     Base = nn.Module[Tensor]
@@ -105,23 +106,6 @@ def set_random_seed(seed: int) -> None:
     model_parallel_cuda_manual_seed(seed)
-def torch_version() -> Tuple[int, ...]:
-    numbering = torch.__version__.split("+")[0].split(".")[:3]
-    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
-    if not numbering[2].isnumeric():
-        # Two options here:
-        # - either skip this version (minor number check is not relevant)
-        # - or check that our codebase is not broken by this ongoing development.
-        # Assuming that we're interested in the second usecase more than the first,
-        # return the pre-release or dev numbering
-        logging.warning(f"Pytorch pre-release version {torch.__version__} - assuming intent to test it")
-        numbering[2] = "0"
-    return tuple(int(n) for n in numbering)
 # Global variable to cache the results from the first nvidia-smi execution.
 _smi_ver: Optional[str] = None
......
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+import logging
+import re
+from typing import List, Tuple
+import torch
+__all__: List[str] = ["torch_version"]
+def torch_version(version: str = torch.__version__) -> Tuple[int, ...]:
+    numbering = re.search(r"^(\d+).(\d+).(\d+)([^\+]*)(\+\S*)?$", version)
+    if not numbering:
+        return tuple()
+    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
+    if numbering.group(4):
+        # Two options here:
+        # - either skip this version (minor number check is not relevant)
+        # - or check that our codebase is not broken by this ongoing development.
+        # Assuming that we're interested in the second use-case more than the first,
+        # return the pre-release or dev numbering
+        logging.warning(f"Pytorch pre-release version {version} - assuming intent to test it")
+    return tuple(int(numbering.group(n)) for n in range(1, 4))
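
For reference, a doctest-style sketch of how the new helper behaves on a few version strings, based on the regex above (illustrative only, not part of the commit):

>>> torch_version("1.9.0")           # regular release
(1, 9, 0)
>>> torch_version("1.9.0+cu111")     # local build suffix after "+" is ignored
(1, 9, 0)
>>> torch_version("1.10.0a0+fb")     # pre-release: logs a warning, still returns the numbers
(1, 10, 0)
>>> torch_version("master")          # unparseable string yields an empty tuple
()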
@@ -7,6 +7,7 @@ tests/utils/test_reduce_scatter_bucketer.py
 tests/utils/test_containers.py
 tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
+tests/utils/test_version.py
 tests/nn/checkpoint/test_checkpoint_activations.py
 tests/nn/checkpoint/test_checkpoint_activations_norm.py
 tests/nn/misc/test_grad_bucket.py
......
@@ -14,7 +14,7 @@ import torch
 import torch.nn
 import torch.nn as nn
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 class PositionalEncoding(nn.Module):
......
@@ -21,7 +21,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
 GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
......
@@ -15,7 +15,8 @@ import pytest
 import torch
 from fairscale.experimental.nn.offload import OffloadModel
-from fairscale.utils.testing import skip_if_no_cuda, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_no_cuda
 if torch_version() >= (1, 8, 0):
     from fairscale.experimental.nn.auto_shard import shard_model
......
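
The torch_version() >= (1, 8, 0) guard above keeps the auto_shard import off older PyTorch builds. A hypothetical pytest-style variant of the same gating for an individual test (sketch only; the test name is made up and not from this commit):

import pytest

from fairscale.utils import torch_version

@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="auto_shard requires PyTorch >= 1.8")
def test_shard_model_requires_torch_1_8():
    ...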
@@ -12,7 +12,8 @@ from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
-from fairscale.utils.testing import skip_if_no_cuda, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_no_cuda
 def get_cuda_mem_allocated():
......
@@ -15,7 +15,8 @@ from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
 from torch.optim import SGD
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.utils.testing import objects_are_equal, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import objects_are_equal
 NORM_TYPES = [LayerNorm, BatchNorm2d]
 MP_TYPES = ["fp32", "fp16", "call_half"]
......
@@ -19,6 +19,7 @@ import torch.distributed
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     DeviceAndTypeCheckModule,
     DummyProcessGroup,
@@ -26,7 +27,6 @@ from fairscale.utils.testing import (
     get_cycles_per_ms,
     objects_are_equal,
     spawn_for_all_world_sizes,
-    torch_version,
 )
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
......
@@ -18,7 +18,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown
 # A fixture to get tempfiles and ensure they are cleaned up.
......
@@ -21,15 +21,9 @@ import torch.optim as optim
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
+from fairscale.utils import torch_version
 from fairscale.utils.parallel import get_process_group_cached
-from fairscale.utils.testing import (
-    dist_init,
-    dump_all_tensors,
-    skip_if_single_gpu,
-    teardown,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
 def to_fsdp(module, fsdp_config):
......
@@ -19,7 +19,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
......
@@ -24,7 +24,8 @@ from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
 from fairscale.nn.wrap import enable_wrap, wrap
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 class Model(nn.Module):
......
@@ -19,7 +19,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
......
@@ -21,14 +21,8 @@ import torch.nn as nn
 from fairscale.nn import enable_wrap, wrap
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import (
-    dist_init,
-    get_cycles_per_ms,
-    skip_if_single_gpu,
-    teardown,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
 class Layer(nn.Module):
......
@@ -36,6 +36,7 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
 from fairscale.optim.grad_scaler import ShardedGradScaler
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     dist_init,
     objects_are_equal,
@@ -44,7 +45,6 @@ from fairscale.utils.testing import (
     state_dict_norm,
     teardown,
     torch_cuda_version,
-    torch_version,
 )
 # Const test params.
......