Unverified commit 1d5ee6e0 authored by Haodong Duan, committed by GitHub

use LooseVersion for version checking (#1158)

parent 21845db4
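
The rationale for the change: Python compares bare version strings lexicographically, so checks like `TORCH_VERSION >= '1.6.0'` silently go wrong once a multi-digit component appears ('1.10.0' sorts before '1.6.0' as a string). `distutils.version.LooseVersion` parses the string into numeric components and compares those instead. A minimal sketch of the difference:

```python
from distutils.version import LooseVersion

# String comparison is character-by-character: '1' < '6' at the third
# character, so '1.10.0' incorrectly sorts before '1.6.0'.
assert ('1.10.0' >= '1.6.0') is False

# LooseVersion parses '1.10.0' into [1, 10, 0] and compares the
# components numerically, giving the intended ordering.
assert LooseVersion('1.10.0') >= LooseVersion('1.6.0')
```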
+from distutils.version import LooseVersion
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -70,7 +72,8 @@ class GELU(nn.Module):
         return F.gelu(input)
 
 
-if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.4':
+if (TORCH_VERSION == 'parrots'
+        or LooseVersion(TORCH_VERSION) < LooseVersion('1.4')):
     ACTIVATION_LAYERS.register_module(module=GELU)
 else:
     ACTIVATION_LAYERS.register_module(module=nn.GELU)
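
For context, this branch decides which implementation backs the `'GELU'` entry in the activation registry: mmcv's hand-rolled `GELU` module on old PyTorch (and parrots), where `nn.GELU` is unavailable, and the built-in `nn.GELU` otherwise. Downstream code is unaffected either way; a sketch of the usual registry-driven usage:

```python
from mmcv.cnn import build_activation_layer

# Resolves to mmcv's GELU on torch < 1.4 and to nn.GELU on newer
# versions; callers never see the version check.
gelu = build_activation_layer(dict(type='GELU'))
```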
......
+from distutils.version import LooseVersion
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -103,7 +105,8 @@ class SAConv2d(ConvAWS2d):
             out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                   self.dilation, self.groups, 1)
         else:
-            if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
+            if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
+                    or TORCH_VERSION == 'parrots'):
                 out_s = super().conv2d_forward(x, weight)
             else:
                 out_s = super()._conv_forward(x, weight)
@@ -117,7 +120,8 @@ class SAConv2d(ConvAWS2d):
             out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                   self.dilation, self.groups, 1)
         else:
-            if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
+            if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
+                    or TORCH_VERSION == 'parrots'):
                 out_l = super().conv2d_forward(x, weight)
             else:
                 out_l = super()._conv_forward(x, weight)
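
One subtlety in the two hunks above: they evaluate `LooseVersion(TORCH_VERSION)` before testing for the `'parrots'` sentinel, whereas the other hunks in this commit check `TORCH_VERSION == 'parrots'` first. The ordering matters because, on Python 3, comparing a `LooseVersion` built from a purely alphabetic string against a numeric one raises `TypeError`. A standalone sketch of the hazard, using the same sentinel value:

```python
from distutils.version import LooseVersion

TORCH_VERSION = 'parrots'  # non-numeric version string

# Safe ordering: the equality test short-circuits before any
# LooseVersion comparison happens.
safe = (TORCH_VERSION == 'parrots'
        or LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0'))

# Hazardous ordering: ['parrots'] is compared element-wise with
# [1, 5, 0], i.e. str < int, which raises TypeError on Python 3.
try:
    LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
except TypeError:
    print('LooseVersion cannot compare a non-numeric version')
```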
......
 # Copyright (c) Open-MMLab. All rights reserved.
+from distutils.version import LooseVersion
+
 import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
                                            _find_tensors)
@@ -37,7 +39,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
 
         # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
         # end of backward to the beginning of forward.
-        if (TORCH_VERSION >= '1.7' and 'parrots'
+        if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
                 not in TORCH_VERSION) and self.reducer._rebuild_buckets():
             print_log(
                 'Reducer buckets have been rebuilt in this iteration.',
@@ -63,7 +65,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if TORCH_VERSION > '1.2':
+            if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
                 self.require_forward_param_sync = False
         return output
 
@@ -77,7 +79,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
         """
         # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
         # end of backward to the beginning of forward.
-        if (TORCH_VERSION >= '1.7' and 'parrots'
+        if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
                 not in TORCH_VERSION) and self.reducer._rebuild_buckets():
             print_log(
                 'Reducer buckets have been rebuilt in this iteration.',
@@ -103,6 +105,6 @@ class MMDistributedDataParallel(DistributedDataParallel):
             else:
                 self.reducer.prepare_for_backward([])
         else:
-            if TORCH_VERSION > '1.2':
+            if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
                 self.require_forward_param_sync = False
         return output
 # Copyright (c) Open-MMLab. All rights reserved.
+from distutils.version import LooseVersion
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
@@ -40,7 +42,7 @@ class MMDistributedDataParallel(nn.Module):
             self._dist_broadcast_coalesced(module_states,
                                            self.broadcast_bucket_size)
         if self.broadcast_buffers:
-            if TORCH_VERSION < '1.0':
+            if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
                 buffers = [b.data for b in self.module._all_buffers()]
             else:
                 buffers = [b.data for b in self.module.buffers()]
......
@@ -3,6 +3,7 @@ import functools
 import os
 import subprocess
 from collections import OrderedDict
+from distutils.version import LooseVersion
 
 import torch
 import torch.multiprocessing as mp
@@ -78,7 +79,7 @@ def _init_dist_slurm(backend, port=None):
 
 
 def get_dist_info():
-    if TORCH_VERSION < '1.0':
+    if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
         initialized = dist._initialized
     else:
         if dist.is_available():
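
The change above only affects how `get_dist_info` detects whether the process group is initialized; its contract is unchanged. Typical usage, for reference:

```python
from mmcv.runner import get_dist_info

# Returns (0, 1) when torch.distributed is unavailable or not yet
# initialized; otherwise the process rank and world size.
rank, world_size = get_dist_info()
if rank == 0:
    print(f'running with world size {world_size}')
```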
......
 import functools
 import warnings
 from collections import abc
+from distutils.version import LooseVersion
 from inspect import getfullargspec
 
 import numpy as np
@@ -121,7 +122,8 @@ def auto_fp16(apply_to=None, out_fp32=False):
                 else:
                     new_kwargs[arg_name] = arg_value
             # apply converted arguments to the decorated method
-            if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+            if (TORCH_VERSION != 'parrots'
+                    and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
                 with autocast(enabled=True):
                     output = old_func(*new_args, **new_kwargs)
             else:
@@ -206,7 +208,8 @@ def force_fp32(apply_to=None, out_fp16=False):
                 else:
                     new_kwargs[arg_name] = arg_value
             # apply converted arguments to the decorated method
-            if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+            if (TORCH_VERSION != 'parrots'
+                    and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
                 with autocast(enabled=False):
                     output = old_func(*new_args, **new_kwargs)
             else:
@@ -245,7 +248,8 @@ def wrap_fp16_model(model):
     Args:
         model (nn.Module): Model in FP32.
     """
-    if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0':
+    if (TORCH_VERSION == 'parrots'
+            or LooseVersion(TORCH_VERSION) < LooseVersion('1.6.0')):
         # convert model to fp16
         model.half()
         # patch the normalization layers to make it work in fp32 mode
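
All three fp16_utils hunks gate on the same fact: native AMP (`torch.cuda.amp.autocast`) only exists from PyTorch 1.6, so `auto_fp16`/`force_fp32` wrap the decorated call in an autocast region there and fall back to mmcv's manual tensor casting on older versions, while `wrap_fp16_model` half-casts the model by hand on the old path. How the decorator is applied is unchanged by this commit; an illustrative sketch (the module here is hypothetical):

```python
import torch.nn as nn

from mmcv.runner import auto_fp16


class ExampleHead(nn.Module):  # hypothetical module

    @auto_fp16(apply_to=('feat', ))
    def forward(self, feat):
        # Once fp16 mode is enabled on the module, `feat` is cast to
        # half precision (or the call runs under autocast on
        # torch >= 1.6); other arguments are left untouched.
        return feat.sum()
```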
......
 # Copyright (c) Open-MMLab. All rights reserved.
 import os.path as osp
+from distutils.version import LooseVersion
 
 from mmcv.utils import TORCH_VERSION
 from ...dist_utils import master_only
@@ -23,7 +24,8 @@ class TensorboardLoggerHook(LoggerHook):
     @master_only
     def before_run(self, runner):
         super(TensorboardLoggerHook, self).before_run(runner)
-        if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
+        if (LooseVersion(TORCH_VERSION) < LooseVersion('1.1')
+                or TORCH_VERSION == 'parrots'):
             try:
                 from tensorboardX import SummaryWriter
             except ImportError:
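
The guarded import reflects where `SummaryWriter` lives: before PyTorch 1.1 (and on parrots) it is only available from the external `tensorboardX` package, while newer releases bundle it as `torch.utils.tensorboard`. The `else` branch is truncated in the hunk above; a standalone sketch of the pattern (an assumption of its shape, not a quote of the file):

```python
from distutils.version import LooseVersion

import torch

if LooseVersion(torch.__version__) < LooseVersion('1.1'):
    from tensorboardX import SummaryWriter  # external package on old torch
else:
    from torch.utils.tensorboard import SummaryWriter  # bundled since 1.1

writer = SummaryWriter(log_dir='./tf_logs')
writer.add_scalar('loss', 0.5, global_step=1)
writer.close()
```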
......
 # Copyright (c) Open-MMLab. All rights reserved.
 import copy
 from collections import defaultdict
+from distutils.version import LooseVersion
 from itertools import chain
 
 from torch.nn.utils import clip_grad
@@ -42,7 +43,8 @@ class OptimizerHook(Hook):
         runner.optimizer.step()
 
 
-if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+if (TORCH_VERSION != 'parrots'
+        and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
 
     @HOOKS.register_module()
     class Fp16OptimizerHook(OptimizerHook):
......
+from distutils.version import LooseVersion
+
 import numpy as np
 import pytest
 import torch
@@ -141,7 +143,8 @@ class TestDeformconv(object):
 
         # test amp when torch version >= '1.6.0', the type of
         # input data for deformconv might be torch.float or torch.half
-        if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+        if (TORCH_VERSION != 'parrots'
+                and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
             with autocast(enabled=True):
                 self._test_amp_deformconv(torch.float, 1e-1)
                 self._test_amp_deformconv(torch.half, 1e-1)
 import os
+from distutils.version import LooseVersion
 
 import numpy
 import torch
@@ -112,7 +113,8 @@ class TestMdconv(object):
 
         # test amp when torch version >= '1.6.0', the type of
         # input data for mdconv might be torch.float or torch.half
-        if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
+        if (TORCH_VERSION != 'parrots'
+                and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
             with autocast(enabled=True):
                 self._test_amp_mdconv(torch.float)
                 self._test_amp_mdconv(torch.half)
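
Both test files guard their AMP cases the same way: `torch.cuda.amp.autocast` only exists from PyTorch 1.6, so on older builds and on parrots the half-precision paths are skipped entirely. A minimal sketch of the guarded usage (assuming a CUDA device is present):

```python
from distutils.version import LooseVersion

import torch

# autocast is only available from torch 1.6; parrots reports a
# non-numeric version string, so it is excluded before comparing.
if (torch.__version__ != 'parrots'
        and LooseVersion(torch.__version__) >= LooseVersion('1.6.0')
        and torch.cuda.is_available()):
    from torch.cuda.amp import autocast

    with autocast(enabled=True):
        a = torch.ones(4, 4, device='cuda')
        b = torch.ones(4, 4, device='cuda')
        out = a @ b  # matmul runs in float16 under autocast
        assert out.dtype == torch.half
```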