Unverified Commit 861253ca authored by Zeqiang Lai's avatar Zeqiang Lai Committed by GitHub
Browse files

Fix DCNv3 version compatibility (#108)

parent 8f2d1583
...@@ -15,6 +15,9 @@ from torch.autograd.function import once_differentiable ...@@ -15,6 +15,9 @@ from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
import DCNv3 import DCNv3
import pkg_resources
dcn_version = float(pkg_resources.get_distribution('DCNv3').version)
class DCNv3Function(Function): class DCNv3Function(Function):
@staticmethod @staticmethod
...@@ -38,15 +41,16 @@ class DCNv3Function(Function): ...@@ -38,15 +41,16 @@ class DCNv3Function(Function):
ctx.im2col_step = im2col_step ctx.im2col_step = im2col_step
ctx.remove_center = remove_center ctx.remove_center = remove_center
kwargs = {} args = [
if remove_center:
kwargs['remove_center'] = remove_center
output = DCNv3.dcnv3_forward(
input, offset, mask, kernel_h, input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h, kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group, pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, ctx.im2col_step, **kwargs) group_channels, offset_scale, ctx.im2col_step
]
if remove_center or dcn_version > 1.0:
args.append(remove_center)
output = DCNv3.dcnv3_forward(*args)
ctx.save_for_backward(input, offset, mask) ctx.save_for_backward(input, offset, mask)
return output return output
...@@ -57,16 +61,17 @@ class DCNv3Function(Function): ...@@ -57,16 +61,17 @@ class DCNv3Function(Function):
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, offset, mask = ctx.saved_tensors input, offset, mask = ctx.saved_tensors
kwargs = {} args = [
if ctx.remove_center: input, offset, mask, ctx.kernel_h,
kwargs['remove_center'] = ctx.remove_center ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step
]
if ctx.remove_center or dcn_version > 1.0:
args.append(ctx.remove_center)
grad_input, grad_offset, grad_mask = \ grad_input, grad_offset, grad_mask = \
DCNv3.dcnv3_backward( DCNv3.dcnv3_backward(*args)
input, offset, mask, ctx.kernel_h,
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step, **kwargs)
return grad_input, grad_offset, grad_mask, \ return grad_input, grad_offset, grad_mask, \
None, None, None, None, None, None, None, None, None, None, None, None, None None, None, None, None, None, None, None, None, None, None, None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment