Unverified Commit 1d5ee6e0 authored by Haodong Duan, committed by GitHub

use LooseVersion for version checking (#1158)

parent 21845db4
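The rationale for this change: comparing raw version strings is lexicographic, so it breaks once a component reaches two digits. As a string, '1.10.0' sorts before '1.4' even though 1.10 is the newer release. LooseVersion from distutils parses the numeric components and compares them as numbers. A minimal sketch of the difference, assuming an illustrative TORCH_VERSION value (mmcv itself derives it from torch.__version__):

```python
from distutils.version import LooseVersion

TORCH_VERSION = '1.10.0'  # illustrative value only; mmcv reads the real one from torch.__version__

# Lexicographic string comparison is wrong here: '1' < '6' at the third character.
print(TORCH_VERSION >= '1.6.0')                              # False

# LooseVersion compares the parsed components [1, 10, 0] and [1, 6, 0] numerically.
print(LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0'))  # True
```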
+from distutils.version import LooseVersion
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -70,7 +72,8 @@ class GELU(nn.Module):
         return F.gelu(input)
-if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.4':
+if (TORCH_VERSION == 'parrots'
+        or LooseVersion(TORCH_VERSION) < LooseVersion('1.4')):
     ACTIVATION_LAYERS.register_module(module=GELU)
 else:
     ACTIVATION_LAYERS.register_module(module=nn.GELU)
...
+from distutils.version import LooseVersion
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -103,7 +105,8 @@ class SAConv2d(ConvAWS2d):
             out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                   self.dilation, self.groups, 1)
         else:
-            if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
+            if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
+                    or TORCH_VERSION == 'parrots'):
                 out_s = super().conv2d_forward(x, weight)
             else:
                 out_s = super()._conv_forward(x, weight)
@@ -117,7 +120,8 @@ class SAConv2d(ConvAWS2d):
             out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                   self.dilation, self.groups, 1)
         else:
-            if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
+            if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
+                    or TORCH_VERSION == 'parrots'):
                 out_l = super().conv2d_forward(x, weight)
             else:
                 out_l = super()._conv_forward(x, weight)
...
 # Copyright (c) Open-MMLab. All rights reserved.
+from distutils.version import LooseVersion
 import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)
@@ -37,7 +39,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
         # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
         # end of backward to the beginning of forward.
-        if (TORCH_VERSION >= '1.7' and 'parrots'
+        if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
                 not in TORCH_VERSION) and self.reducer._rebuild_buckets():
             print_log(
                 'Reducer buckets have been rebuilt in this iteration.',
@@ -63,7 +65,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if TORCH_VERSION > '1.2':
+            if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
                 self.require_forward_param_sync = False
         return output
@@ -77,7 +79,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
         """
         # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
         # end of backward to the beginning of forward.
-        if (TORCH_VERSION >= '1.7' and 'parrots'
+        if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
                 not in TORCH_VERSION) and self.reducer._rebuild_buckets():
             print_log(
                 'Reducer buckets have been rebuilt in this iteration.',
@@ -103,6 +105,6 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if TORCH_VERSION > '1.2':
+            if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
                 self.require_forward_param_sync = False
         return output
 # Copyright (c) Open-MMLab. All rights reserved.
+from distutils.version import LooseVersion
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -40,7 +42,7 @@ class MMDistributedDataParallel(nn.Module):
             self._dist_broadcast_coalesced(module_states,
                                            self.broadcast_bucket_size)
         if self.broadcast_buffers:
-            if TORCH_VERSION < '1.0':
+            if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
                 buffers = [b.data for b in self.module._all_buffers()]
             else:
                 buffers = [b.data for b in self.module.buffers()]
...
@@ -3,6 +3,7 @@ import functools
 import os
 import subprocess
 from collections import OrderedDict
+from distutils.version import LooseVersion
 import torch
 import torch.multiprocessing as mp
@@ -78,7 +79,7 @@ def _init_dist_slurm(backend, port=None):
 def get_dist_info():
-    if TORCH_VERSION < '1.0':
+    if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
         initialized = dist._initialized
     else:
         if dist.is_available():
...
 import functools
 import warnings
 from collections import abc
+from distutils.version import LooseVersion
 from inspect import getfullargspec
 import numpy as np
@@ -121,7 +122,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
                     else:
                         new_kwargs[arg_name] = arg_value
             # apply converted arguments to the decorated method
-            if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+            if (TORCH_VERSION != 'parrots'
+                    and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
                 with autocast(enabled=True):
                     output = old_func(*new_args, **new_kwargs)
             else:
@@ -206,7 +208,8 @@ def force_fp32(apply_to=None, out_fp16=False):
                     else:
                         new_kwargs[arg_name] = arg_value
             # apply converted arguments to the decorated method
-            if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+            if (TORCH_VERSION != 'parrots'
+                    and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
                 with autocast(enabled=False):
                     output = old_func(*new_args, **new_kwargs)
             else:
@@ -245,7 +248,8 @@ def wrap_fp16_model(model):
     Args:
         model (nn.Module): Model in FP32.
     """
-    if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0':
+    if (TORCH_VERSION == 'parrots'
+            or LooseVersion(TORCH_VERSION) < LooseVersion('1.6.0')):
         # convert model to fp16
         model.half()
         # patch the normalization layers to make it work in fp32 mode
...
 # Copyright (c) Open-MMLab. All rights reserved.
 import os.path as osp
+from distutils.version import LooseVersion
 from mmcv.utils import TORCH_VERSION
 from ...dist_utils import master_only
@@ -23,7 +24,8 @@ class TensorboardLoggerHook(LoggerHook):
     @master_only
     def before_run(self, runner):
         super(TensorboardLoggerHook, self).before_run(runner)
-        if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
+        if (LooseVersion(TORCH_VERSION) < LooseVersion('1.1')
+                or TORCH_VERSION == 'parrots'):
             try:
                 from tensorboardX import SummaryWriter
             except ImportError:
...
 # Copyright (c) Open-MMLab. All rights reserved.
 import copy
 from collections import defaultdict
+from distutils.version import LooseVersion
 from itertools import chain
 from torch.nn.utils import clip_grad
@@ -42,7 +43,8 @@ class OptimizerHook(Hook):
         runner.optimizer.step()
-if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+if (TORCH_VERSION != 'parrots'
+        and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
     @HOOKS.register_module()
     class Fp16OptimizerHook(OptimizerHook):
...
+from distutils.version import LooseVersion
 import numpy as np
 import pytest
 import torch
@@ -141,7 +143,8 @@ class TestDeformconv(object):
         # test amp when torch version >= '1.6.0', the type of
         # input data for deformconv might be torch.float or torch.half
-        if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+        if (TORCH_VERSION != 'parrots'
+                and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
             with autocast(enabled=True):
                 self._test_amp_deformconv(torch.float, 1e-1)
                 self._test_amp_deformconv(torch.half, 1e-1)
 import os
+from distutils.version import LooseVersion
 import numpy
 import torch
@@ -112,7 +113,8 @@ class TestMdconv(object):
         # test amp when torch version >= '1.6.0', the type of
         # input data for mdconv might be torch.float or torch.half
-        if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+        if (TORCH_VERSION != 'parrots'
+                and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
             with autocast(enabled=True):
                 self._test_amp_mdconv(torch.float)
                 self._test_amp_mdconv(torch.half)
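A detail visible throughout the hunks above: TORCH_VERSION is the literal string 'parrots' when mmcv runs on the parrots backend, and LooseVersion cannot order that non-numeric marker against a numeric version (on Python 3 the comparison raises TypeError), so the explicit checks against 'parrots' are kept alongside the LooseVersion comparisons. A small sketch with an assumed TORCH_VERSION value, mirroring the GELU registration check above:

```python
from distutils.version import LooseVersion

TORCH_VERSION = 'parrots'  # assumed value: the marker mmcv reports on the parrots backend

# The equality check short-circuits here; evaluating
# LooseVersion('parrots') < LooseVersion('1.4') would raise TypeError on Python 3,
# because ['parrots'] cannot be compared with [1, 4].
if TORCH_VERSION == 'parrots' or LooseVersion(TORCH_VERSION) < LooseVersion('1.4'):
    print('register the fallback GELU implementation')
else:
    print('register nn.GELU from PyTorch')
```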