Unverified Commit ef48a473 authored by Haodong Duan's avatar Haodong Duan Committed by GitHub
Browse files

[Improvement] Improve digit_version & use it for version_checking (#1185)

* improve digit_version & use it for version_checking

* more testing for digit_version

* setuptools >= 50 is needed

* fix CI

* add debugging log

* >= to ==

* fix lint

* remove

* add failure case

* replace

* fix

* consider TORCH_VERSION == 'parrots'

* add unittest

* digit_version does not deal with the case where 'parrots' is in the version name.
parent c06be0d5
...@@ -41,6 +41,8 @@ jobs: ...@@ -41,6 +41,8 @@ jobs:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install system dependencies - name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install - name: Build and install
run: rm -rf .eggs && pip install -e . run: rm -rf .eggs && pip install -e .
- name: Validate the installation - name: Validate the installation
...@@ -75,6 +77,8 @@ jobs: ...@@ -75,6 +77,8 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
- name: Install PyTorch - name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install - name: Build and install
run: rm -rf .eggs && pip install -e . run: rm -rf .eggs && pip install -e .
- name: Validate the installation - name: Validate the installation
...@@ -118,6 +122,8 @@ jobs: ...@@ -118,6 +122,8 @@ jobs:
if: ${{matrix.torchvision == '0.4.2'}} if: ${{matrix.torchvision == '0.4.2'}}
- name: Install PyTorch - name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install - name: Build and install
run: rm -rf .eggs && pip install -e . run: rm -rf .eggs && pip install -e .
- name: Validate the installation - name: Validate the installation
...@@ -188,6 +194,8 @@ jobs: ...@@ -188,6 +194,8 @@ jobs:
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install system dependencies - name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install - name: Build and install
run: rm -rf .eggs && pip install -e . run: rm -rf .eggs && pip install -e .
- name: Validate the installation - name: Validate the installation
...@@ -258,6 +266,8 @@ jobs: ...@@ -258,6 +266,8 @@ jobs:
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install system dependencies - name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install - name: Build and install
run: rm -rf .eggs && pip install -e . run: rm -rf .eggs && pip install -e .
- name: Validate the installation - name: Validate the installation
...@@ -310,6 +320,8 @@ jobs: ...@@ -310,6 +320,8 @@ jobs:
if: ${{matrix.torchvision == '0.4.2'}} if: ${{matrix.torchvision == '0.4.2'}}
- name: Install PyTorch - name: Install PyTorch
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install - name: Build and install
run: | run: |
rm -rf .eggs rm -rf .eggs
......
from distutils.version import LooseVersion
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.utils import TORCH_VERSION, build_from_cfg from mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
from .registry import ACTIVATION_LAYERS from .registry import ACTIVATION_LAYERS
for module in [ for module in [
...@@ -73,7 +71,7 @@ class GELU(nn.Module): ...@@ -73,7 +71,7 @@ class GELU(nn.Module):
if (TORCH_VERSION == 'parrots' if (TORCH_VERSION == 'parrots'
or LooseVersion(TORCH_VERSION) < LooseVersion('1.4')): or digit_version(TORCH_VERSION) < digit_version('1.4')):
ACTIVATION_LAYERS.register_module(module=GELU) ACTIVATION_LAYERS.register_module(module=GELU)
else: else:
ACTIVATION_LAYERS.register_module(module=nn.GELU) ACTIVATION_LAYERS.register_module(module=nn.GELU)
......
from distutils.version import LooseVersion
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from mmcv.ops.deform_conv import deform_conv2d from mmcv.ops.deform_conv import deform_conv2d
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
@CONV_LAYERS.register_module(name='SAC') @CONV_LAYERS.register_module(name='SAC')
...@@ -108,10 +106,10 @@ class SAConv2d(ConvAWS2d): ...@@ -108,10 +106,10 @@ class SAConv2d(ConvAWS2d):
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding, out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1) self.dilation, self.groups, 1)
else: else:
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0') if (TORCH_VERSION == 'parrots'
or TORCH_VERSION == 'parrots'): or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
out_s = super().conv2d_forward(x, weight) out_s = super().conv2d_forward(x, weight)
elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'): elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
# bias is a required argument of _conv_forward in torch 1.8.0 # bias is a required argument of _conv_forward in torch 1.8.0
out_s = super()._conv_forward(x, weight, zero_bias) out_s = super()._conv_forward(x, weight, zero_bias)
else: else:
...@@ -126,10 +124,10 @@ class SAConv2d(ConvAWS2d): ...@@ -126,10 +124,10 @@ class SAConv2d(ConvAWS2d):
out_l = deform_conv2d(x, offset, weight, self.stride, self.padding, out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1) self.dilation, self.groups, 1)
else: else:
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0') if (TORCH_VERSION == 'parrots'
or TORCH_VERSION == 'parrots'): or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
out_l = super().conv2d_forward(x, weight) out_l = super().conv2d_forward(x, weight)
elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'): elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
# bias is a required argument of _conv_forward in torch 1.8.0 # bias is a required argument of _conv_forward in torch 1.8.0
out_l = super()._conv_forward(x, weight, zero_bias) out_l = super()._conv_forward(x, weight, zero_bias)
else: else:
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from distutils.version import LooseVersion
import torch import torch
from torch.nn.parallel.distributed import (DistributedDataParallel, from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors) _find_tensors)
from mmcv import print_log from mmcv import print_log
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
from .scatter_gather import scatter_kwargs from .scatter_gather import scatter_kwargs
...@@ -39,8 +37,9 @@ class MMDistributedDataParallel(DistributedDataParallel): ...@@ -39,8 +37,9 @@ class MMDistributedDataParallel(DistributedDataParallel):
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward. # end of backward to the beginning of forward.
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots' if ('parrots' not in TORCH_VERSION
not in TORCH_VERSION) and self.reducer._rebuild_buckets(): and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
print_log( print_log(
'Reducer buckets have been rebuilt in this iteration.', 'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv') logger='mmcv')
...@@ -65,7 +64,8 @@ class MMDistributedDataParallel(DistributedDataParallel): ...@@ -65,7 +64,8 @@ class MMDistributedDataParallel(DistributedDataParallel):
else: else:
self.reducer.prepare_for_backward([]) self.reducer.prepare_for_backward([])
else: else:
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'): if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False self.require_forward_param_sync = False
return output return output
...@@ -79,8 +79,9 @@ class MMDistributedDataParallel(DistributedDataParallel): ...@@ -79,8 +79,9 @@ class MMDistributedDataParallel(DistributedDataParallel):
""" """
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward. # end of backward to the beginning of forward.
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots' if ('parrots' not in TORCH_VERSION
not in TORCH_VERSION) and self.reducer._rebuild_buckets(): and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
print_log( print_log(
'Reducer buckets have been rebuilt in this iteration.', 'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv') logger='mmcv')
...@@ -105,6 +106,7 @@ class MMDistributedDataParallel(DistributedDataParallel): ...@@ -105,6 +106,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
else: else:
self.reducer.prepare_for_backward([]) self.reducer.prepare_for_backward([])
else: else:
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'): if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False self.require_forward_param_sync = False
return output return output
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from distutils.version import LooseVersion
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _take_tensors, from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors) _unflatten_dense_tensors)
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
from .registry import MODULE_WRAPPERS from .registry import MODULE_WRAPPERS
from .scatter_gather import scatter_kwargs from .scatter_gather import scatter_kwargs
...@@ -42,7 +40,8 @@ class MMDistributedDataParallel(nn.Module): ...@@ -42,7 +40,8 @@ class MMDistributedDataParallel(nn.Module):
self._dist_broadcast_coalesced(module_states, self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size) self.broadcast_bucket_size)
if self.broadcast_buffers: if self.broadcast_buffers:
if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'): if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) < digit_version('1.0')):
buffers = [b.data for b in self.module._all_buffers()] buffers = [b.data for b in self.module._all_buffers()]
else: else:
buffers = [b.data for b in self.module.buffers()] buffers = [b.data for b in self.module.buffers()]
......
...@@ -3,7 +3,6 @@ import functools ...@@ -3,7 +3,6 @@ import functools
import os import os
import subprocess import subprocess
from collections import OrderedDict from collections import OrderedDict
from distutils.version import LooseVersion
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -11,7 +10,7 @@ from torch import distributed as dist ...@@ -11,7 +10,7 @@ from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors, from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors) _unflatten_dense_tensors)
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
def init_dist(launcher, backend='nccl', **kwargs): def init_dist(launcher, backend='nccl', **kwargs):
...@@ -79,7 +78,8 @@ def _init_dist_slurm(backend, port=None): ...@@ -79,7 +78,8 @@ def _init_dist_slurm(backend, port=None):
def get_dist_info(): def get_dist_info():
if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'): if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) < digit_version('1.0')):
initialized = dist._initialized initialized = dist._initialized
else: else:
if dist.is_available(): if dist.is_available():
......
import functools import functools
import warnings import warnings
from collections import abc from collections import abc
from distutils.version import LooseVersion
from inspect import getfullargspec from inspect import getfullargspec
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads from .dist_utils import allreduce_grads as _allreduce_grads
try: try:
...@@ -122,8 +121,8 @@ def auto_fp16(apply_to=None, out_fp32=False): ...@@ -122,8 +121,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
else: else:
new_kwargs[arg_name] = arg_value new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method # apply converted arguments to the decorated method
if (TORCH_VERSION != 'parrots' if (TORCH_VERSION != 'parrots' and
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True): with autocast(enabled=True):
output = old_func(*new_args, **new_kwargs) output = old_func(*new_args, **new_kwargs)
else: else:
...@@ -208,8 +207,8 @@ def force_fp32(apply_to=None, out_fp16=False): ...@@ -208,8 +207,8 @@ def force_fp32(apply_to=None, out_fp16=False):
else: else:
new_kwargs[arg_name] = arg_value new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method # apply converted arguments to the decorated method
if (TORCH_VERSION != 'parrots' if (TORCH_VERSION != 'parrots' and
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=False): with autocast(enabled=False):
output = old_func(*new_args, **new_kwargs) output = old_func(*new_args, **new_kwargs)
else: else:
...@@ -249,7 +248,7 @@ def wrap_fp16_model(model): ...@@ -249,7 +248,7 @@ def wrap_fp16_model(model):
model (nn.Module): Model in FP32. model (nn.Module): Model in FP32.
""" """
if (TORCH_VERSION == 'parrots' if (TORCH_VERSION == 'parrots'
or LooseVersion(TORCH_VERSION) < LooseVersion('1.6.0')): or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
# convert model to fp16 # convert model to fp16
model.half() model.half()
# patch the normalization layers to make it work in fp32 mode # patch the normalization layers to make it work in fp32 mode
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp import os.path as osp
from distutils.version import LooseVersion
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
from ...dist_utils import master_only from ...dist_utils import master_only
from ..hook import HOOKS from ..hook import HOOKS
from .base import LoggerHook from .base import LoggerHook
...@@ -24,8 +23,8 @@ class TensorboardLoggerHook(LoggerHook): ...@@ -24,8 +23,8 @@ class TensorboardLoggerHook(LoggerHook):
@master_only @master_only
def before_run(self, runner): def before_run(self, runner):
super(TensorboardLoggerHook, self).before_run(runner) super(TensorboardLoggerHook, self).before_run(runner)
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.1') if (TORCH_VERSION == 'parrots'
or TORCH_VERSION == 'parrots'): or digit_version(TORCH_VERSION) < digit_version('1.1')):
try: try:
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
except ImportError: except ImportError:
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import copy import copy
from collections import defaultdict from collections import defaultdict
from distutils.version import LooseVersion
from itertools import chain from itertools import chain
from torch.nn.utils import clip_grad from torch.nn.utils import clip_grad
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
from ..dist_utils import allreduce_grads from ..dist_utils import allreduce_grads
from ..fp16_utils import LossScaler, wrap_fp16_model from ..fp16_utils import LossScaler, wrap_fp16_model
from .hook import HOOKS, Hook from .hook import HOOKS, Hook
...@@ -44,7 +43,7 @@ class OptimizerHook(Hook): ...@@ -44,7 +43,7 @@ class OptimizerHook(Hook):
if (TORCH_VERSION != 'parrots' if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
@HOOKS.register_module() @HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook): class Fp16OptimizerHook(OptimizerHook):
......
from functools import partial from functools import partial
from pkg_resources import parse_version
import torch import torch
from mmcv.utils import digit_version
TORCH_VERSION = torch.__version__ TORCH_VERSION = torch.__version__
is_rocm_pytorch = False is_rocm_pytorch = False
if parse_version(TORCH_VERSION) >= parse_version('1.5'): if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.5')):
from torch.utils.cpp_extension import ROCM_HOME from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False (ROCM_HOME is not None)) else False
......
import warnings import warnings
from distutils.version import LooseVersion
import torch import torch
from mmcv.utils import digit_version
def is_jit_tracing() -> bool: def is_jit_tracing() -> bool:
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'): if (torch.__version__ != 'parrots'
and digit_version(torch.__version__) >= digit_version('1.6.0')):
on_trace = torch.jit.is_tracing() on_trace = torch.jit.is_tracing()
# In PyTorch 1.6, torch.jit.is_tracing has a bug. # In PyTorch 1.6, torch.jit.is_tracing has a bug.
# Refers to https://github.com/pytorch/pytorch/issues/42448 # Refers to https://github.com/pytorch/pytorch/issues/42448
......
import os import os
import subprocess import subprocess
import warnings
from pkg_resources import parse_version
def digit_version(version_str): def digit_version(version_str: str, length: int = 4):
"""Convert a version string into a tuple of integers. """Convert a version string into a tuple of integers.
This method is usually used for comparing two versions. This method is usually used for comparing two versions. For pre-release
versions: alpha < beta < rc.
Args: Args:
version_str (str): The version string. version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.
Returns: Returns:
tuple[int]: The version info in digits (integers). tuple[int]: The version info in digits (integers).
""" """
digit_version = [] assert 'parrots' not in version_str
for x in version_str.split('.'): version = parse_version(version_str)
if x.isdigit(): assert version.release, f'failed to parse version {version_str}'
digit_version.append(int(x)) release = list(version.release)
elif x.find('rc') != -1: release = release[:length]
patch_version = x.split('rc') if len(release) < length:
digit_version.append(int(patch_version[0]) - 1) release = release + [0] * (length - len(release))
digit_version.append(int(patch_version[1])) if version.is_prerelease:
return tuple(digit_version) mapping = {'a': -3, 'b': -2, 'rc': -1}
val = -4
# version.pre can be None
if version.pre:
if version.pre[0] not in mapping:
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
'version checking may go wrong')
else:
val = mapping[version.pre[0]]
release.extend([val, version.pre[-1]])
else:
release.extend([val, 0])
elif version.is_postrelease:
release.extend([1, version.post])
else:
release.extend([0, 0])
return tuple(release)
def _minimal_ext_cmd(cmd): def _minimal_ext_cmd(cmd):
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from pkg_resources import parse_version
__version__ = '1.3.9' __version__ = '1.3.9'
def parse_version_info(version_str: str) -> tuple: def parse_version_info(version_str: str, length: int = 4) -> tuple:
"""Parse a version string into a tuple. """Parse a version string into a tuple.
Args: Args:
version_str (str): The version string. version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.
Returns: Returns:
tuple[int | str]: The version info, e.g., "1.3.0" is parsed into tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
(1, 3, 0), and "2.0.0rc1" is parsed into (2, 0, 0, 'rc1'). (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
(2, 0, 0, 0, 'rc', 1) (when length is set to 4).
""" """
version_info = [] version = parse_version(version_str)
for x in version_str.split('.'): assert version.release, f'failed to parse version {version_str}'
if x.isdigit(): release = list(version.release)
version_info.append(int(x)) release = release[:length]
elif x.find('rc') != -1: if len(release) < length:
patch_version = x.split('rc') release = release + [0] * (length - len(release))
version_info.append(int(patch_version[0])) if version.is_prerelease:
version_info.append(f'rc{patch_version[1]}') release.extend(list(version.pre))
return tuple(version_info) elif version.is_postrelease:
release.extend(list(version.post))
else:
release.extend([0, 0])
return tuple(release)
version_info = parse_version_info(__version__) version_info = parse_version_info(__version__)
......
...@@ -3,4 +3,5 @@ numpy ...@@ -3,4 +3,5 @@ numpy
Pillow Pillow
pyyaml pyyaml
regex;sys_platform=='win32' regex;sys_platform=='win32'
setuptools>=52
yapf yapf
...@@ -6,4 +6,5 @@ onnxruntime==1.4.0 ...@@ -6,4 +6,5 @@ onnxruntime==1.4.0
pytest pytest
PyTurboJPEG PyTurboJPEG
scipy scipy
setuptools>=52
tiffile tiffile
from distutils.version import LooseVersion
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
try: try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
...@@ -144,7 +142,7 @@ class TestDeformconv(object): ...@@ -144,7 +142,7 @@ class TestDeformconv(object):
# test amp when torch version >= '1.6.0', the type of # test amp when torch version >= '1.6.0', the type of
# input data for deformconv might be torch.float or torch.half # input data for deformconv might be torch.float or torch.half
if (TORCH_VERSION != 'parrots' if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True): with autocast(enabled=True):
self._test_amp_deformconv(torch.float, 1e-1) self._test_amp_deformconv(torch.float, 1e-1)
self._test_amp_deformconv(torch.half, 1e-1) self._test_amp_deformconv(torch.half, 1e-1)
import os import os
from distutils.version import LooseVersion
import numpy import numpy
import torch import torch
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
try: try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast # If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
...@@ -114,7 +113,7 @@ class TestMdconv(object): ...@@ -114,7 +113,7 @@ class TestMdconv(object):
# test amp when torch version >= '1.6.0', the type of # test amp when torch version >= '1.6.0', the type of
# input data for mdconv might be torch.float or torch.half # input data for mdconv might be torch.float or torch.half
if (TORCH_VERSION != 'parrots' if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True): with autocast(enabled=True):
self._test_amp_mdconv(torch.float) self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half) self._test_amp_mdconv(torch.half)
import os import os
import random import random
from pkg_resources import parse_version
import numpy as np import numpy as np
import torch import torch
from mmcv.runner import set_random_seed from mmcv.runner import set_random_seed
from mmcv.utils import TORCH_VERSION from mmcv.utils import TORCH_VERSION, digit_version
is_rocm_pytorch = False is_rocm_pytorch = False
if parse_version(TORCH_VERSION) >= parse_version('1.5'): if digit_version(TORCH_VERSION) >= digit_version('1.5'):
from torch.utils.cpp_extension import ROCM_HOME from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False (ROCM_HOME is not None)) else False
......
from distutils.version import LooseVersion
import pytest import pytest
import torch import torch
from mmcv.utils import is_jit_tracing from mmcv.utils import digit_version, is_jit_tracing
@pytest.mark.skipif( @pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion('1.6.0'), digit_version(torch.__version__) < digit_version('1.6.0'),
reason='torch.jit.is_tracing is not available before 1.6.0') reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing(): def test_is_jit_tracing():
......
from unittest.mock import patch from unittest.mock import patch
from mmcv import digit_version, get_git_hash, parse_version_info import pytest
from mmcv import get_git_hash, parse_version_info
from mmcv.utils import digit_version
def test_digit_version(): def test_digit_version():
assert digit_version('0.2.16') == (0, 2, 16) assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
assert digit_version('1.2.3') == (1, 2, 3) assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
assert digit_version('1.2.3rc0') == (1, 2, 2, 0) assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
assert digit_version('1.2.3rc1') == (1, 2, 2, 1) assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
assert digit_version('1.0rc0') == (1, -1, 0) assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
assert digit_version('1.0') == digit_version('1.0.0')
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
assert digit_version('1.0.0a') < digit_version('1.0.0b')
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
assert digit_version('1.0.0') < digit_version('1.0.0post')
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)
with pytest.raises(AssertionError):
digit_version('a')
with pytest.raises(AssertionError):
digit_version('1x')
with pytest.raises(AssertionError):
digit_version('1.x')
def test_parse_version_info(): def test_parse_version_info():
assert parse_version_info('0.2.16') == (0, 2, 16) assert parse_version_info('0.2.16') == (0, 2, 16, 0, 0, 0)
assert parse_version_info('1.2.3') == (1, 2, 3) assert parse_version_info('1.2.3') == (1, 2, 3, 0, 0, 0)
assert parse_version_info('1.2.3rc0') == (1, 2, 3, 'rc0') assert parse_version_info('1.2.3rc0') == (1, 2, 3, 0, 'rc', 0)
assert parse_version_info('1.2.3rc1') == (1, 2, 3, 'rc1') assert parse_version_info('1.2.3rc1') == (1, 2, 3, 0, 'rc', 1)
assert parse_version_info('1.0rc0') == (1, 0, 'rc0') assert parse_version_info('1.0rc0') == (1, 0, 0, 0, 'rc', 0)
def _mock_cmd_success(cmd): def _mock_cmd_success(cmd):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment