Unverified Commit 861253ca authored by Zeqiang Lai's avatar Zeqiang Lai Committed by GitHub
Browse files

Fix DCNv3 version compatibility (#108)

parent 8f2d1583
......@@ -15,6 +15,9 @@ from torch.autograd.function import once_differentiable
from torch.cuda.amp import custom_bwd, custom_fwd
import DCNv3
import pkg_resources
dcn_version = float(pkg_resources.get_distribution('DCNv3').version)
class DCNv3Function(Function):
@staticmethod
......@@ -38,15 +41,16 @@ class DCNv3Function(Function):
ctx.im2col_step = im2col_step
ctx.remove_center = remove_center
kwargs = {}
if remove_center:
kwargs['remove_center'] = remove_center
output = DCNv3.dcnv3_forward(
args = [
input, offset, mask, kernel_h,
kernel_w, stride_h, stride_w, pad_h,
pad_w, dilation_h, dilation_w, group,
group_channels, offset_scale, ctx.im2col_step, **kwargs)
group_channels, offset_scale, ctx.im2col_step
]
if remove_center or dcn_version > 1.0:
args.append(remove_center)
output = DCNv3.dcnv3_forward(*args)
ctx.save_for_backward(input, offset, mask)
return output
......@@ -57,16 +61,17 @@ class DCNv3Function(Function):
def backward(ctx, grad_output):
input, offset, mask = ctx.saved_tensors
kwargs = {}
if ctx.remove_center:
kwargs['remove_center'] = ctx.remove_center
grad_input, grad_offset, grad_mask = \
DCNv3.dcnv3_backward(
args = [
input, offset, mask, ctx.kernel_h,
ctx.kernel_w, ctx.stride_h, ctx.stride_w, ctx.pad_h,
ctx.pad_w, ctx.dilation_h, ctx.dilation_w, ctx.group,
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step, **kwargs)
ctx.group_channels, ctx.offset_scale, grad_output.contiguous(), ctx.im2col_step
]
if ctx.remove_center or dcn_version > 1.0:
args.append(ctx.remove_center)
grad_input, grad_offset, grad_mask = \
DCNv3.dcnv3_backward(*args)
return grad_input, grad_offset, grad_mask, \
None, None, None, None, None, None, None, None, None, None, None, None, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment