Unverified Commit f2439cd7 authored by 6V's avatar 6V Committed by GitHub
Browse files

Fix modulated_deform_conv for torch_npu v2.1 (#2941)

parent 8523eeef
...@@ -58,7 +58,8 @@ class ModulatedDeformConv2dFunction(Function): ...@@ -58,7 +58,8 @@ class ModulatedDeformConv2dFunction(Function):
kernel_w, kernel_h, ctx.deform_groups) kernel_w, kernel_h, ctx.deform_groups)
select_offset = offset.index_select(1, sort_index_fp) select_offset = offset.index_select(1, sort_index_fp)
offset_all = torch.cat([select_offset, mask], dim=1) offset_all = torch.cat([select_offset, mask], dim=1)
output, offset_out = torch.npu_deformable_conv2d( import torch_npu
output, offset_out = torch_npu.npu_deformable_conv2d(
input_tensor, input_tensor,
weight, weight,
offset_all, offset_all,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment