Unverified Commit ef48a473 authored by Haodong Duan's avatar Haodong Duan Committed by GitHub
Browse files

[Improvement] Improve digit_version & use it for version_checking (#1185)

* improve digit_version & use it for version_checking

* more testing for digit_version

* setuptools >= 50 is needed

* fix CI

* add debugging log

* >= to ==

* fix lint

* remove

* add failure case

* replace

* fix

* consider TORCH_VERSION == 'parrots'

* add unittest

* digit_version does not handle the case where 'parrots' appears in the version name.
parent c06be0d5
......@@ -41,6 +41,8 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Validate the installation
......@@ -75,6 +77,8 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Validate the installation
......@@ -118,6 +122,8 @@ jobs:
if: ${{matrix.torchvision == '0.4.2'}}
- name: Install PyTorch
run: pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Validate the installation
......@@ -188,6 +194,8 @@ jobs:
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Validate the installation
......@@ -258,6 +266,8 @@ jobs:
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} -f https://download.pytorch.org/whl/torch_stable.html
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg libturbojpeg
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Validate the installation
......@@ -310,6 +320,8 @@ jobs:
if: ${{matrix.torchvision == '0.4.2'}}
- name: Install PyTorch
run: pip install torch==${{matrix.torch}} torchvision==${{matrix.torchvision}} --no-cache-dir
- name: Upgrade Setuptools
run: pip install setuptools==52
- name: Build and install
run: |
rm -rf .eggs
......
from distutils.version import LooseVersion
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.utils import TORCH_VERSION, build_from_cfg
from mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
from .registry import ACTIVATION_LAYERS
for module in [
......@@ -73,7 +71,7 @@ class GELU(nn.Module):
if (TORCH_VERSION == 'parrots'
or LooseVersion(TORCH_VERSION) < LooseVersion('1.4')):
or digit_version(TORCH_VERSION) < digit_version('1.4')):
ACTIVATION_LAYERS.register_module(module=GELU)
else:
ACTIVATION_LAYERS.register_module(module=nn.GELU)
......
from distutils.version import LooseVersion
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
from mmcv.ops.deform_conv import deform_conv2d
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
@CONV_LAYERS.register_module(name='SAC')
......@@ -108,10 +106,10 @@ class SAConv2d(ConvAWS2d):
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'):
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
out_s = super().conv2d_forward(x, weight)
elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'):
elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
# bias is a required argument of _conv_forward in torch 1.8.0
out_s = super()._conv_forward(x, weight, zero_bias)
else:
......@@ -126,10 +124,10 @@ class SAConv2d(ConvAWS2d):
out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'):
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.5.0')):
out_l = super().conv2d_forward(x, weight)
elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'):
elif digit_version(TORCH_VERSION) >= digit_version('1.8.0'):
# bias is a required argument of _conv_forward in torch 1.8.0
out_l = super()._conv_forward(x, weight, zero_bias)
else:
......
# Copyright (c) Open-MMLab. All rights reserved.
from distutils.version import LooseVersion
import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmcv import print_log
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
from .scatter_gather import scatter_kwargs
......@@ -39,8 +37,9 @@ class MMDistributedDataParallel(DistributedDataParallel):
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
print_log(
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
......@@ -65,7 +64,8 @@ class MMDistributedDataParallel(DistributedDataParallel):
else:
self.reducer.prepare_for_backward([])
else:
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
......@@ -79,8 +79,9 @@ class MMDistributedDataParallel(DistributedDataParallel):
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
print_log(
'Reducer buckets have been rebuilt in this iteration.',
logger='mmcv')
......@@ -105,6 +106,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
else:
self.reducer.prepare_for_backward([])
else:
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
# Copyright (c) Open-MMLab. All rights reserved.
from distutils.version import LooseVersion
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
from .registry import MODULE_WRAPPERS
from .scatter_gather import scatter_kwargs
......@@ -42,7 +40,8 @@ class MMDistributedDataParallel(nn.Module):
self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size)
if self.broadcast_buffers:
if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) < digit_version('1.0')):
buffers = [b.data for b in self.module._all_buffers()]
else:
buffers = [b.data for b in self.module.buffers()]
......
......@@ -3,7 +3,6 @@ import functools
import os
import subprocess
from collections import OrderedDict
from distutils.version import LooseVersion
import torch
import torch.multiprocessing as mp
......@@ -11,7 +10,7 @@ from torch import distributed as dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
def init_dist(launcher, backend='nccl', **kwargs):
......@@ -79,7 +78,8 @@ def _init_dist_slurm(backend, port=None):
def get_dist_info():
if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) < digit_version('1.0')):
initialized = dist._initialized
else:
if dist.is_available():
......
import functools
import warnings
from collections import abc
from distutils.version import LooseVersion
from inspect import getfullargspec
import numpy as np
import torch
import torch.nn as nn
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
from .dist_utils import allreduce_grads as _allreduce_grads
try:
......@@ -122,8 +121,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
if (TORCH_VERSION != 'parrots' and
digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
output = old_func(*new_args, **new_kwargs)
else:
......@@ -208,8 +207,8 @@ def force_fp32(apply_to=None, out_fp16=False):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
if (TORCH_VERSION != 'parrots' and
digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=False):
output = old_func(*new_args, **new_kwargs)
else:
......@@ -249,7 +248,7 @@ def wrap_fp16_model(model):
model (nn.Module): Model in FP32.
"""
if (TORCH_VERSION == 'parrots'
or LooseVersion(TORCH_VERSION) < LooseVersion('1.6.0')):
or digit_version(TORCH_VERSION) < digit_version('1.6.0')):
# convert model to fp16
model.half()
# patch the normalization layers to make it work in fp32 mode
......
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
from distutils.version import LooseVersion
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
......@@ -24,8 +23,8 @@ class TensorboardLoggerHook(LoggerHook):
@master_only
def before_run(self, runner):
super(TensorboardLoggerHook, self).before_run(runner)
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.1')
or TORCH_VERSION == 'parrots'):
if (TORCH_VERSION == 'parrots'
or digit_version(TORCH_VERSION) < digit_version('1.1')):
try:
from tensorboardX import SummaryWriter
except ImportError:
......
# Copyright (c) Open-MMLab. All rights reserved.
import copy
from collections import defaultdict
from distutils.version import LooseVersion
from itertools import chain
from torch.nn.utils import clip_grad
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
from ..dist_utils import allreduce_grads
from ..fp16_utils import LossScaler, wrap_fp16_model
from .hook import HOOKS, Hook
......@@ -44,7 +43,7 @@ class OptimizerHook(Hook):
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
......
from functools import partial
from pkg_resources import parse_version
import torch
from mmcv.utils import digit_version
TORCH_VERSION = torch.__version__
is_rocm_pytorch = False
if parse_version(TORCH_VERSION) >= parse_version('1.5'):
if (TORCH_VERSION != 'parrots'
and digit_version(TORCH_VERSION) >= digit_version('1.5')):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
......
import warnings
from distutils.version import LooseVersion
import torch
from mmcv.utils import digit_version
def is_jit_tracing() -> bool:
if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
if (torch.__version__ != 'parrots'
and digit_version(torch.__version__) >= digit_version('1.6.0')):
on_trace = torch.jit.is_tracing()
# In PyTorch 1.6, torch.jit.is_tracing has a bug.
# Refers to https://github.com/pytorch/pytorch/issues/42448
......
import os
import subprocess
import warnings
from pkg_resources import parse_version
def digit_version(version_str: str, length: int = 4):
    """Convert a version string into a tuple of integers.

    This method is usually used for comparing two versions. For pre-release
    versions: alpha < beta < rc.

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int]: The version info in digits (integers).
    """
    # digit_version is intentionally unaware of the 'parrots' backend;
    # callers must check for it before calling (see the PR description).
    assert 'parrots' not in version_str
    version = parse_version(version_str)
    assert version.release, f'failed to parse version {version_str}'
    # Normalize the release segment to exactly `length` integers so that
    # e.g. '1.5' and '1.5.0' compare equal.
    release = list(version.release)
    release = release[:length]
    if len(release) < length:
        release = release + [0] * (length - len(release))
    if version.is_prerelease:
        # Negative markers keep pre-releases ordered before the final
        # release: dev (-4) < alpha (-3) < beta (-2) < rc (-1) < final (0).
        mapping = {'a': -3, 'b': -2, 'rc': -1}
        val = -4
        # version.pre can be None (e.g. for a pure dev release)
        if version.pre:
            if version.pre[0] not in mapping:
                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
                              'version checking may go wrong')
            else:
                val = mapping[version.pre[0]]
            release.extend([val, version.pre[-1]])
        else:
            release.extend([val, 0])
    elif version.is_postrelease:
        # Post-releases sort after the corresponding final release.
        release.extend([1, version.post])
    else:
        release.extend([0, 0])
    return tuple(release)
def _minimal_ext_cmd(cmd):
......
# Copyright (c) Open-MMLab. All rights reserved.
from pkg_resources import parse_version
__version__ = '1.3.9'
def parse_version_info(version_str: str, length: int = 4) -> tuple:
    """Parse a version string into a tuple.

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
            (1, 3, 0, 0, 0, 0), and "2.0.0rc1" is parsed into
            (2, 0, 0, 0, 'rc', 1) (when length is set to 4).
    """
    version = parse_version(version_str)
    assert version.release, f'failed to parse version {version_str}'
    # Pad/truncate the release segment to exactly `length` integers so
    # '1.3' and '1.3.0' produce the same tuple.
    release = list(version.release)
    release = release[:length]
    if len(release) < length:
        release = release + [0] * (length - len(release))
    if version.is_prerelease:
        # version.pre is a tuple like ('rc', 1)
        release.extend(list(version.pre))
    elif version.is_postrelease:
        # version.post is an int (not iterable), so it must be wrapped in a
        # list; mirrors digit_version's (1, N) post-release encoding.
        release.extend([1, version.post])
    else:
        release.extend([0, 0])
    return tuple(release)
version_info = parse_version_info(__version__)
......
......@@ -3,4 +3,5 @@ numpy
Pillow
pyyaml
regex;sys_platform=='win32'
setuptools>=52
yapf
......@@ -6,4 +6,5 @@ onnxruntime==1.4.0
pytest
PyTurboJPEG
scipy
setuptools>=52
tiffile
from distutils.version import LooseVersion
import numpy as np
import pytest
import torch
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
......@@ -144,7 +142,7 @@ class TestDeformconv(object):
# test amp when torch version >= '1.6.0', the type of
# input data for deformconv might be torch.float or torch.half
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
self._test_amp_deformconv(torch.float, 1e-1)
self._test_amp_deformconv(torch.half, 1e-1)
import os
from distutils.version import LooseVersion
import numpy
import torch
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
try:
# If PyTorch version >= 1.6.0 and fp16 is enabled, torch.cuda.amp.autocast
......@@ -114,7 +113,7 @@ class TestMdconv(object):
# test amp when torch version >= '1.6.0', the type of
# input data for mdconv might be torch.float or torch.half
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
and digit_version(TORCH_VERSION) >= digit_version('1.6.0')):
with autocast(enabled=True):
self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half)
import os
import random
from pkg_resources import parse_version
import numpy as np
import torch
from mmcv.runner import set_random_seed
from mmcv.utils import TORCH_VERSION
from mmcv.utils import TORCH_VERSION, digit_version
is_rocm_pytorch = False
if parse_version(TORCH_VERSION) >= parse_version('1.5'):
if digit_version(TORCH_VERSION) >= digit_version('1.5'):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
......
from distutils.version import LooseVersion
import pytest
import torch
from mmcv.utils import is_jit_tracing
from mmcv.utils import digit_version, is_jit_tracing
@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion('1.6.0'),
digit_version(torch.__version__) < digit_version('1.6.0'),
reason='torch.jit.is_tracing is not available before 1.6.0')
def test_is_jit_tracing():
......
from unittest.mock import patch
from mmcv import digit_version, get_git_hash, parse_version_info
import pytest
from mmcv import get_git_hash, parse_version_info
from mmcv.utils import digit_version
def test_digit_version():
    """Check digit_version's 6-tuple encoding, ordering and error cases."""
    # Plain releases are padded to 4 levels plus (0, 0).
    assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
    assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
    # Pre-releases carry a negative marker (-1 for rc) and the rc number.
    assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
    assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
    assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
    # Padding makes short versions equal to their zero-extended forms.
    assert digit_version('1.0') == digit_version('1.0.0')
    # Local version identifiers (e.g. CUDA build tags) are ignored.
    assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
    # Ordering: dev < alpha < beta < rc < final < post.
    assert digit_version('1.0.0dev') < digit_version('1.0.0a')
    assert digit_version('1.0.0a') < digit_version('1.0.0a1')
    assert digit_version('1.0.0a') < digit_version('1.0.0b')
    assert digit_version('1.0.0b') < digit_version('1.0.0rc')
    assert digit_version('1.0.0rc1') < digit_version('1.0.0')
    assert digit_version('1.0.0') < digit_version('1.0.0post')
    assert digit_version('1.0.0post') < digit_version('1.0.0post1')
    # A leading 'v' prefix is tolerated.
    assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
    assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)
    # Unparseable strings raise (via the internal assertions).
    with pytest.raises(AssertionError):
        digit_version('a')
    with pytest.raises(AssertionError):
        digit_version('1x')
    with pytest.raises(AssertionError):
        digit_version('1.x')
def test_parse_version_info():
    """Check parse_version_info's 6-tuple encoding (rc kept as a string)."""
    # Plain releases are padded to 4 levels plus (0, 0).
    assert parse_version_info('0.2.16') == (0, 2, 16, 0, 0, 0)
    assert parse_version_info('1.2.3') == (1, 2, 3, 0, 0, 0)
    # Pre-releases keep the textual marker and its number.
    assert parse_version_info('1.2.3rc0') == (1, 2, 3, 0, 'rc', 0)
    assert parse_version_info('1.2.3rc1') == (1, 2, 3, 0, 'rc', 1)
    assert parse_version_info('1.0rc0') == (1, 0, 0, 0, 'rc', 0)
def _mock_cmd_success(cmd):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment