Unverified Commit 9a3f1249 authored by Ming-Hsuan-Tu's avatar Ming-Hsuan-Tu Committed by GitHub
Browse files

[fix] Missing arguments when converting dcn to onnx (#624)



* fix issues when converting deformable convolution to onnx

* keep  and  for interface consistency
Co-authored-by: default avatarmaningsheng <maningsheng@sensetime.com>
parent bcf85026
...@@ -20,20 +20,29 @@ ext_module = ext_loader.load_ext('_ext', [ ...@@ -20,20 +20,29 @@ ext_module = ext_loader.load_ext('_ext', [
class DeformConv2dFunction(Function): class DeformConv2dFunction(Function):
@staticmethod @staticmethod
def symbolic(g, input, offset, weight, stride, padding, dilation, groups, def symbolic(g,
deform_groups, bias, im2col_step): input,
offset,
weight,
stride,
padding,
dilation,
groups,
deform_groups,
bias=False,
im2col_step=32):
return g.op( return g.op(
'MMCVDeformConv2d', 'MMCVDeformConv2d',
input, input,
offset, offset,
weight, weight,
stride=stride, stride_i=stride,
padding=padding, padding_i=padding,
dilation=dilation, dilation_i=dilation,
groups=groups, groups_i=groups,
deform_groups=deform_groups, deform_groups_i=deform_groups,
bias=bias, bias_i=bias,
im2col_step=im2col_step) im2col_step_i=im2col_step)
@staticmethod @staticmethod
def forward(ctx, def forward(ctx,
...@@ -52,7 +61,6 @@ class DeformConv2dFunction(Function): ...@@ -52,7 +61,6 @@ class DeformConv2dFunction(Function):
f'Expected 4D tensor as input, got {input.dim()}D tensor \ f'Expected 4D tensor as input, got {input.dim()}D tensor \
instead.') instead.')
assert bias is False, 'Only support bias is False.' assert bias is False, 'Only support bias is False.'
ctx.stride = _pair(stride) ctx.stride = _pair(stride)
ctx.padding = _pair(padding) ctx.padding = _pair(padding)
ctx.dilation = _pair(dilation) ctx.dilation = _pair(dilation)
......
...@@ -27,11 +27,11 @@ class ModulatedDeformConv2dFunction(Function): ...@@ -27,11 +27,11 @@ class ModulatedDeformConv2dFunction(Function):
mask, mask,
weight, weight,
bias, bias,
stride=stride, stride_i=stride,
padding=padding, padding_i=padding,
dilation=dilation, dilation_i=dilation,
groups=groups, groups_i=groups,
deform_groups=deform_groups) deform_groups_i=deform_groups)
@staticmethod @staticmethod
def forward(ctx, def forward(ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment