Unverified Commit bc1e60e0 authored by Pavel Belevich, committed by GitHub

Fix pytorch version check (#716)

parent 00ec9ff1
@@ -11,6 +11,7 @@ from torch import Tensor, nn
 from torch.distributed import rpc
 from fairscale.nn.pipe import microbatch
+from fairscale.utils import torch_version
 from .data import DataConsumer
 from .graph import Node, PipelineModulesGraph
@@ -20,7 +21,7 @@ Device = Union[torch.device, int, str]
 def check_pytorch_version() -> None:
-    if list(map(int, torch.__version__.split("+")[0].split(".")[:2])) < [1, 9]:
+    if torch_version() < (1, 9, 0):
         raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
...
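Note: the new guard delegates to the shared torch_version() helper introduced later in this diff. A rough illustration of its effect, using hypothetical version strings and not part of the change itself:

from fairscale.utils import torch_version  # re-exported via the __init__.py change below

# With an older install the guard trips and check_pytorch_version() raises:
assert torch_version("1.8.1") < (1, 9, 0)
# A 1.9 build (with or without a local suffix such as "+cu111") passes the check:
assert torch_version("1.9.0+cu111") >= (1, 9, 0)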
@@ -11,6 +11,7 @@ import torch.distributed as dist
 from torch.distributed import ProcessGroup
 from fairscale.nn.checkpoint import is_checkpointing, is_recomputing
+from fairscale.utils import torch_version
 def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
@@ -45,7 +46,7 @@ def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) ->
     return mean, var, invstd, total_count
-if torch.__version__.split(".")[:2] >= ["1", "7"]:
+if torch_version()[:2] >= (1, 7):
     _forward = torch.jit.script(_forward)  # type: ignore
     _track_running_stats = torch.jit.script(_track_running_stats)  # type: ignore
...
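The lexicographic comparison removed here is likely the motivation for this fix: comparing version components as strings misorders two-digit minor versions. A quick illustration in plain Python, not part of the diff:

# String/list comparison breaks once PyTorch reaches 1.10:
assert not ["1", "10"] >= ["1", "7"]   # evaluates False, although 1.10 is newer than 1.7
# Integer tuples, as torch_version() returns, compare numerically:
assert (1, 10) >= (1, 7)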
@@ -27,6 +27,8 @@ from torch import Tensor, nn
 import torch.autograd
 import torch.cuda
+from fairscale.utils import torch_version
 from . import microbatch
 from .batchnorm import DeferredBatchNorm
 from .pipeline import Pipeline
@@ -256,7 +258,7 @@ class Pipe(Module):
     ) -> None:
         super().__init__()
-        if torch.__version__.split(".")[:2] >= ["1", "8"]:
+        if torch_version()[:2] >= (1, 8):
             warnings.warn(
                 "fairscale.nn.Pipe has been upstreamed to PyTorch as torch.distributed.pipeline.sync.Pipe. "
                 "It is now deprecated and will be removed in a future version of fairscale. "
...
@@ -3,6 +3,4 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import List
-__all__: List[str] = []
+from .version import *
@@ -51,6 +51,7 @@ import torch.nn as nn
 from fairscale.nn.model_parallel import destroy_model_parallel, initialize_model_parallel
 from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed
+from fairscale.utils import torch_version
 if TYPE_CHECKING:
     Base = nn.Module[Tensor]
@@ -105,23 +106,6 @@ def set_random_seed(seed: int) -> None:
     model_parallel_cuda_manual_seed(seed)
-def torch_version() -> Tuple[int, ...]:
-    numbering = torch.__version__.split("+")[0].split(".")[:3]
-    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
-    if not numbering[2].isnumeric():
-        # Two options here:
-        # - either skip this version (minor number check is not relevant)
-        # - or check that our codebase is not broken by this ongoing development.
-        # Assuming that we're interested in the second usecase more than the first,
-        # return the pre-release or dev numbering
-        logging.warning(f"Pytorch pre-release version {torch.__version__} - assuming intent to test it")
-        numbering[2] = "0"
-    return tuple(int(n) for n in numbering)
 # Global variable to cache the results from the first nvidia-smi execution.
 _smi_ver: Optional[str] = None
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from typing import List, Tuple

import torch

__all__: List[str] = ["torch_version"]


def torch_version(version: str = torch.__version__) -> Tuple[int, ...]:
    numbering = re.search(r"^(\d+).(\d+).(\d+)([^\+]*)(\+\S*)?$", version)
    if not numbering:
        return tuple()
    # Catch torch version if run against internal pre-releases, like `1.8.0a0fb`,
    if numbering.group(4):
        # Two options here:
        # - either skip this version (minor number check is not relevant)
        # - or check that our codebase is not broken by this ongoing development.
        # Assuming that we're interested in the second use-case more than the first,
        # return the pre-release or dev numbering
        logging.warning(f"Pytorch pre-release version {version} - assuming intent to test it")
    return tuple(int(numbering.group(n)) for n in range(1, 4))
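For reference, the values the new helper produces for a few representative version strings (an illustration, not part of the diff):

from fairscale.utils import torch_version

torch_version("1.9.0")        # (1, 9, 0)
torch_version("1.9.0+cu111")  # (1, 9, 0): the local build suffix is dropped
torch_version("1.8.0a0fb")    # (1, 8, 0): pre-release suffix, logs the warning above
torch_version("nightly")      # (): unparseable strings yield an empty tuple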
@@ -7,6 +7,7 @@ tests/utils/test_reduce_scatter_bucketer.py
 tests/utils/test_containers.py
 tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
+tests/utils/test_version.py
 tests/nn/checkpoint/test_checkpoint_activations.py
 tests/nn/checkpoint/test_checkpoint_activations_norm.py
 tests/nn/misc/test_grad_bucket.py
...
@@ -14,7 +14,7 @@ import torch
 import torch.nn
 import torch.nn as nn
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 class PositionalEncoding(nn.Module):
...
@@ -21,7 +21,7 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
-from fairscale.utils.testing import torch_version
+from fairscale.utils import torch_version
 CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
 GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
...
@@ -15,7 +15,8 @@ import pytest
 import torch
 from fairscale.experimental.nn.offload import OffloadModel
-from fairscale.utils.testing import skip_if_no_cuda, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_no_cuda
 if torch_version() >= (1, 8, 0):
     from fairscale.experimental.nn.auto_shard import shard_model
...
@@ -12,7 +12,8 @@ from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
-from fairscale.utils.testing import skip_if_no_cuda, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import skip_if_no_cuda
 def get_cuda_mem_allocated():
...
@@ -15,7 +15,8 @@ from torch.nn import BatchNorm2d, LayerNorm, Linear, Sequential
 from torch.optim import SGD
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
-from fairscale.utils.testing import objects_are_equal, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import objects_are_equal
 NORM_TYPES = [LayerNorm, BatchNorm2d]
 MP_TYPES = ["fp32", "fp16", "call_half"]
...
@@ -19,6 +19,7 @@ import torch.distributed
 from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel, TrainingState
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     DeviceAndTypeCheckModule,
     DummyProcessGroup,
@@ -26,7 +27,6 @@ from fairscale.utils.testing import (
     get_cycles_per_ms,
     objects_are_equal,
     spawn_for_all_world_sizes,
-    torch_version,
 )
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
...
@@ -18,7 +18,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, rmf, skip_if_no_cuda, teardown
 # A fixture to get tempfiles and ensure they are cleaned up.
...
@@ -21,15 +21,9 @@ import torch.optim as optim
 from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
+from fairscale.utils import torch_version
 from fairscale.utils.parallel import get_process_group_cached
-from fairscale.utils.testing import (
-    dist_init,
-    dump_all_tensors,
-    skip_if_single_gpu,
-    teardown,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils.testing import dist_init, dump_all_tensors, skip_if_single_gpu, teardown, temp_files_ctx
 def to_fsdp(module, fsdp_config):
...
@@ -19,7 +19,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
...
@@ -24,7 +24,8 @@ from fairscale.nn import checkpoint_wrapper
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import auto_wrap_bn
 from fairscale.nn.wrap import enable_wrap, wrap
-from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_single_gpu, teardown, temp_files_ctx
 class Model(nn.Module):
...
@@ -19,7 +19,8 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState
-from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown, torch_version
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, skip_if_no_cuda, teardown
 def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
...
@@ -21,14 +21,8 @@ import torch.nn as nn
 from fairscale.nn import enable_wrap, wrap
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-from fairscale.utils.testing import (
-    dist_init,
-    get_cycles_per_ms,
-    skip_if_single_gpu,
-    teardown,
-    temp_files_ctx,
-    torch_version,
-)
+from fairscale.utils import torch_version
+from fairscale.utils.testing import dist_init, get_cycles_per_ms, skip_if_single_gpu, teardown, temp_files_ctx
 class Layer(nn.Module):
...
@@ -36,6 +36,7 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
 from fairscale.optim.grad_scaler import ShardedGradScaler
+from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     dist_init,
     objects_are_equal,
@@ -44,7 +45,6 @@ from fairscale.utils.testing import (
     state_dict_norm,
     teardown,
     torch_cuda_version,
-    torch_version,
 )
 # Const test params.
...