Unverified Commit 8d87bf44 authored by Yinlei Sun's avatar Yinlei Sun Committed by GitHub
Browse files

[Fix] Fix deform_conv ops on Ascend NPU (#2832)

parent e322848e
...@@ -57,7 +57,8 @@ class DeformConv2dFunction(Function): ...@@ -57,7 +57,8 @@ class DeformConv2dFunction(Function):
input_tensor, grad_output, offset_out, weight, offset_all, input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]], kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]], stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
ctx.padding[1]],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups, deformable_groups=ctx.deform_groups, groups=ctx.groups, deformable_groups=ctx.deform_groups,
modulated=True) modulated=True)
......
...@@ -64,7 +64,9 @@ class ModulatedDeformConv2dFunction(Function): ...@@ -64,7 +64,9 @@ class ModulatedDeformConv2dFunction(Function):
conv2d_bias, conv2d_bias,
kernel_size=[kernel_w, kernel_h], kernel_size=[kernel_w, kernel_h],
stride=[1, 1, ctx.stride[0], ctx.stride[1]], stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]], padding=[
ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1]
],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups, groups=ctx.groups,
deformable_groups=ctx.deform_groups, deformable_groups=ctx.deform_groups,
...@@ -84,7 +86,8 @@ class ModulatedDeformConv2dFunction(Function): ...@@ -84,7 +86,8 @@ class ModulatedDeformConv2dFunction(Function):
input_tensor, grad_output, offset_out, weight, offset_all, input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]], kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]], stride=[1, 1, ctx.stride[0], ctx.stride[1]],
padding=[1, 1, ctx.padding[0], ctx.padding[1]], padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1],
ctx.padding[1]],
dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]],
groups=ctx.groups, deformable_groups=ctx.deform_groups, groups=ctx.groups, deformable_groups=ctx.deform_groups,
modulated=True) modulated=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment