Unverified Commit 7b18b977 authored by Cao Yuhang's avatar Cao Yuhang Committed by GitHub
Browse files

fix saconv (#489)

* fix saconv

* add parrots condition

* add unittest

* fix torch version
parent eacaf475
@@ -4,6 +4,7 @@ import torch.nn.functional as F
 from mmcv.cnn import CONV_LAYERS, ConvAWS2d, constant_init
 from mmcv.ops.deform_conv import deform_conv2d
+from mmcv.utils import TORCH_VERSION

 @CONV_LAYERS.register_module(name='SAC')
@@ -102,7 +103,10 @@ class SAConv2d(ConvAWS2d):
             out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                   self.dilation, self.groups, 1)
         else:
-            out_s = super().conv2d_forward(x, weight)
+            if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
+                out_s = super().conv2d_forward(x, weight)
+            else:
+                out_s = super()._conv_forward(x, weight)
         ori_p = self.padding
         ori_d = self.dilation
         self.padding = tuple(3 * p for p in self.padding)
@@ -113,7 +117,10 @@ class SAConv2d(ConvAWS2d):
             out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
                                   self.dilation, self.groups, 1)
         else:
-            out_l = super().conv2d_forward(x, weight)
+            if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
+                out_l = super().conv2d_forward(x, weight)
+            else:
+                out_l = super()._conv_forward(x, weight)
         out = switch * out_s + (1 - switch) * out_l
         self.padding = ori_p
         self.dilation = ori_d
...
import pytest
import torch
import torch.nn as nn
from mmcv.ops import SAConv2d
def test_sacconv():
    """Smoke-test SAConv2d output shapes against a matching nn.Conv2d.

    Covers four configurations: the default case, dilation >= 2,
    deformable offsets (``use_deform=True``, CUDA only — expected to
    raise ``RuntimeError`` on CPU), and grouped convolution.
    """
    # test with normal case
    x = torch.rand(1, 3, 256, 256)
    saconv = SAConv2d(3, 5, kernel_size=3, padding=1)
    sac_out = saconv(x)
    refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=1)
    refer_out = refer_conv(x)
    assert sac_out.shape == refer_out.shape

    # test with dilation >= 2
    dilated_saconv = SAConv2d(3, 5, kernel_size=3, padding=2, dilation=2)
    dilated_sac_out = dilated_saconv(x)
    refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=2, dilation=2)
    refer_out = refer_conv(x)
    assert dilated_sac_out.shape == refer_out.shape

    # test with deform
    deform_saconv = SAConv2d(3, 5, kernel_size=3, padding=1, use_deform=True)
    if torch.cuda.is_available():
        x = torch.rand(1, 3, 256, 256).cuda()
        deform_saconv = SAConv2d(
            3, 5, kernel_size=3, padding=1, use_deform=True).cuda()
        # The module already lives on the GPU, so its output is a CUDA
        # tensor; the original's extra `.cuda()` on the output was redundant.
        deform_sac_out = deform_saconv(x)
        refer_conv = nn.Conv2d(3, 5, kernel_size=3, padding=1).cuda()
        refer_out = refer_conv(x)
        assert deform_sac_out.shape == refer_out.shape
    else:
        with pytest.raises(RuntimeError):
            # deform conv is not implemented on cpu
            deform_saconv(x)

    # test with groups >= 2
    x = torch.rand(1, 4, 256, 256)
    group_saconv = SAConv2d(4, 4, kernel_size=3, padding=1, groups=2)
    group_sac_out = group_saconv(x)
    refer_conv = nn.Conv2d(4, 4, kernel_size=3, padding=1, groups=2)
    refer_out = refer_conv(x)
    assert group_sac_out.shape == refer_out.shape
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment