Unverified Commit e0323b4e authored by mengpenghui's avatar mengpenghui Committed by GitHub
Browse files

[Fix] Check the version of torchvision in __init__ of DCN (#2556)

parent 71ee2a61
...@@ -438,12 +438,9 @@ class DeformConv2dPack(DeformConv2d): ...@@ -438,12 +438,9 @@ class DeformConv2dPack(DeformConv2d):
if IS_MLU_AVAILABLE: if IS_MLU_AVAILABLE:
import torchvision import torchvision
from torchvision.ops import deform_conv2d as tv_deform_conv2d
from mmcv.utils import digit_version from mmcv.utils import digit_version
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
from torchvision.ops import deform_conv2d as tv_deform_conv2d
@CONV_LAYERS.register_module('DCN', force=True) @CONV_LAYERS.register_module('DCN', force=True)
class DeformConv2dPack_MLU(DeformConv2d): class DeformConv2dPack_MLU(DeformConv2d):
...@@ -471,6 +468,8 @@ if IS_MLU_AVAILABLE: ...@@ -471,6 +468,8 @@ if IS_MLU_AVAILABLE:
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d( self.conv_offset = nn.Conv2d(
...@@ -494,7 +493,6 @@ if IS_MLU_AVAILABLE: ...@@ -494,7 +493,6 @@ if IS_MLU_AVAILABLE:
) == 0, 'batch size must be divisible by im2col_step' ) == 0, 'batch size must be divisible by im2col_step'
offset = self.conv_offset(x) offset = self.conv_offset(x)
x = x.type_as(offset) x = x.type_as(offset)
weight = self.weight weight = self.weight.type_as(x)
weight = weight.type_as(x)
return tv_deform_conv2d(x, offset, weight, None, self.stride, return tv_deform_conv2d(x, offset, weight, None, self.stride,
self.padding, self.dilation) self.padding, self.dilation)
...@@ -356,11 +356,9 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d): ...@@ -356,11 +356,9 @@ class ModulatedDeformConv2dPack(ModulatedDeformConv2d):
if IS_MLU_AVAILABLE: if IS_MLU_AVAILABLE:
import torchvision import torchvision
from torchvision.ops import deform_conv2d as tv_deform_conv2d
from mmcv.utils import digit_version from mmcv.utils import digit_version
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
from torchvision.ops import deform_conv2d as tv_deform_conv2d
@CONV_LAYERS.register_module('DCNv2', force=True) @CONV_LAYERS.register_module('DCNv2', force=True)
class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d): class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d):
...@@ -383,6 +381,8 @@ if IS_MLU_AVAILABLE: ...@@ -383,6 +381,8 @@ if IS_MLU_AVAILABLE:
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
assert digit_version(torchvision.__version__) >= digit_version(
'0.10.0a0'), 'the version of torchvision should be >= 0.10.0'
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.conv_offset = nn.Conv2d( self.conv_offset = nn.Conv2d(
self.in_channels, self.in_channels,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment