"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "0c2373d0bba3499e95776e7936e207d8a1676e65"
Unverified Commit 7e285d35 authored by Haodong Duan's avatar Haodong Duan Committed by GitHub
Browse files

[Fix] Fix saconv (#1147)

* fix saconv

* update

* update

* fix

* use LooseVersion
parent b035fe91
...@@ -100,6 +100,9 @@ class SAConv2d(ConvAWS2d): ...@@ -100,6 +100,9 @@ class SAConv2d(ConvAWS2d):
switch = self.switch(avg_x) switch = self.switch(avg_x)
# sac # sac
weight = self._get_weight(self.weight) weight = self._get_weight(self.weight)
zero_bias = torch.zeros(
self.out_channels, device=weight.device, dtype=weight.dtype)
if self.use_deform: if self.use_deform:
offset = self.offset_s(avg_x) offset = self.offset_s(avg_x)
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding, out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
...@@ -108,6 +111,9 @@ class SAConv2d(ConvAWS2d): ...@@ -108,6 +111,9 @@ class SAConv2d(ConvAWS2d):
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0') if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'): or TORCH_VERSION == 'parrots'):
out_s = super().conv2d_forward(x, weight) out_s = super().conv2d_forward(x, weight)
elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'):
# bias is a required argument of _conv_forward in torch 1.8.0
out_s = super()._conv_forward(x, weight, zero_bias)
else: else:
out_s = super()._conv_forward(x, weight) out_s = super()._conv_forward(x, weight)
ori_p = self.padding ori_p = self.padding
...@@ -123,8 +129,12 @@ class SAConv2d(ConvAWS2d): ...@@ -123,8 +129,12 @@ class SAConv2d(ConvAWS2d):
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0') if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'): or TORCH_VERSION == 'parrots'):
out_l = super().conv2d_forward(x, weight) out_l = super().conv2d_forward(x, weight)
elif LooseVersion(TORCH_VERSION) >= LooseVersion('1.8.0'):
# bias is a required argument of _conv_forward in torch 1.8.0
out_l = super()._conv_forward(x, weight, zero_bias)
else: else:
out_l = super()._conv_forward(x, weight) out_l = super()._conv_forward(x, weight)
out = switch * out_s + (1 - switch) * out_l out = switch * out_s + (1 - switch) * out_l
self.padding = ori_p self.padding = ori_p
self.dilation = ori_d self.dilation = ori_d
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment