Unverified Commit ca99624f authored by sherie, Committed by GitHub

[Fix] Fix the support for nms_rotated in Ascend (#2931)

parent b361a81a
@@ -27,16 +27,21 @@
 #define NPU_NAME_SPACE at_npu::native

-#if MMCV_WITH_XLA
+#ifdef MMCV_WITH_XLA
 #define REGISTER_NPU_IMPL(key, value) REGISTER_DEVICE_IMPL(key, XLA, value)
 #else
 #define REGISTER_NPU_IMPL(key, value) \
   REGISTER_DEVICE_IMPL(key, PrivateUse1, value)
 #endif

-#define CHECK_NPU(x)                                                            \
-  TORCH_CHECK(                                                                  \
-      x.device().type() == at::kXLA || x.device().type() == at::kPrivateUse1,   \
-      #x " must be a NPU tensor")
+#ifdef MMCV_WITH_XLA
+#define CHECK_NPU(x) \
+  TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
+#else
+#define CHECK_NPU(x)                                    \
+  TORCH_CHECK(x.device().type() == at::kPrivateUse1, #x \
+              " must be a NPU "                         \
+              "tensor")
+#endif

 #endif  // PYTORCH_NPU_HELPER_HPP_
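For reference, the reworked header now selects a single expected device type at compile time instead of accepting either one at runtime. A minimal sketch of how CHECK_NPU ends up being used, assuming libtorch headers and a hypothetical check_inputs helper (not part of the patch):

#include <torch/torch.h>

// MMCV_WITH_XLA is defined for torch <= 2.0.0 builds (see the setup.py change
// below); otherwise NPU tensors are expected to carry the PrivateUse1 key.
#ifdef MMCV_WITH_XLA
#define CHECK_NPU(x) \
  TORCH_CHECK(x.device().type() == at::kXLA, #x " must be a NPU tensor")
#else
#define CHECK_NPU(x) \
  TORCH_CHECK(x.device().type() == at::kPrivateUse1, #x " must be a NPU tensor")
#endif

// Hypothetical helper: validate kernel inputs before launching an NPU op.
void check_inputs(const at::Tensor& dets, const at::Tensor& scores) {
  CHECK_NPU(dets);
  CHECK_NPU(scores);
}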
@@ -36,11 +36,13 @@ Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
 #else
     AT_ERROR("Not compiled with GPU support");
 #endif
+#ifdef MMCV_WITH_XLA
   } else if (dets.device().type() == at::kXLA) {
-#ifdef MMCV_WITH_NPU
     return nms_rotated_npu(dets, scores, labels, iou_threshold);
-#else
-    AT_ERROR("Not compiled with NPU support");
+#endif
+#ifdef MMCV_WITH_KPRIVATE
+  } else if (dets.device().type() == at::kPrivateUse1) {
+    return nms_rotated_npu(dets, scores, labels, iou_threshold);
 #endif
 #ifdef MMCV_WITH_MLU
   } else if (dets.device().type() == at::kMLU) {
...
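The single kXLA branch guarded by MMCV_WITH_NPU is split into two branches so that builds against newer torch_npu, where NPU tensors report at::kPrivateUse1 rather than at::kXLA, still reach nms_rotated_npu. A rough runtime equivalent of what the pair of #ifdef blocks expresses at compile time, written as a hypothetical helper (not part of the patch):

#include <torch/torch.h>

// An Ascend NPU tensor reports kXLA under MMCV_WITH_XLA builds (torch <= 2.0.0)
// and kPrivateUse1 under MMCV_WITH_KPRIVATE builds (torch > 2.0.0).
bool is_npu_tensor(const at::Tensor& t) {
  const auto type = t.device().type();
  return type == at::kXLA || type == at::kPrivateUse1;
}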
@@ -21,9 +21,53 @@ void gather_points_forward_npu(int b, int c, int n, int npoints,
       .Attr("batch_dims", batch_dims)
       .Run();
 }

+void gather_points_backward_npu(int b, int c, int n, int npoints,
+                                const Tensor grad_out, const Tensor idx,
+                                Tensor grad_points) {
+  at::Tensor indices = idx;
+  if (idx.scalar_type() != at::ScalarType::Int) {
+    indices = idx.to(at::kInt);
+  }
+  if (idx.dim() == 0) {
+    indices.unsqueeze_(0);
+  }
+  int64_t dim = 0;
+  at::SmallVector<int64_t, N> pad_size = array_to_small_vector(idx.sizes());
+  at::Tensor trans_grad_points = grad_points.transpose(1, 2).contiguous();
+  at::Tensor grad_points_view = trans_grad_points.view(
+      {trans_grad_points.sizes()[0] * trans_grad_points.sizes()[1],
+       trans_grad_points.sizes()[2]});
+  at::Tensor trans_grad_out = grad_out.transpose(1, 2).contiguous();
+  trans_grad_out = trans_grad_out.view(
+      {trans_grad_out.sizes()[0] * trans_grad_out.sizes()[1],
+       trans_grad_out.sizes()[2]});
+  auto index = at::arange(0, b);
+  index = index.to(grad_out.device());
+  index = at::mul(index, n);
+  index = index.view({b, 1});
+  index = at::broadcast_to(index, pad_size);
+  indices = at::add(index, indices);
+  indices = indices.view({-1});
+  OpCommand cmd;
+  cmd.Name("InplaceIndexAdd")
+      .Input(grad_points_view)
+      .Input(indices)
+      .Input(trans_grad_out)
+      .Output(grad_points_view)
+      .Attr("axis", dim)
+      .Run();
+  at::Tensor grad_points_result =
+      grad_points_view.view(trans_grad_points.sizes());
+  grad_points_result = grad_points_result.transpose(1, 2);
+  grad_points.copy_(grad_points_result);
+}
+
 void gather_points_forward_impl(int b, int c, int n, int npoints,
                                 const Tensor points, const Tensor idx,
                                 Tensor out);

+void gather_points_backward_impl(int b, int c, int n, int npoints,
+                                 const Tensor grad_out, const Tensor idx,
+                                 Tensor grad_points);
+
 REGISTER_NPU_IMPL(gather_points_forward_impl, gather_points_forward_npu);
+REGISTER_NPU_IMPL(gather_points_backward_impl, gather_points_backward_npu);
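The new gather_points_backward_npu flattens the (batch, point) index pair into a single row index so one InplaceIndexAdd call accumulates the gradients of every batch at once. A rough CPU/CUDA reference of the same scatter using stock ATen ops, with shapes inferred from how the kernel uses its arguments (grad_out: (b, c, npoints), idx: (b, npoints), grad_points: (b, c, n)); a sketch, not the NPU implementation:

#include <torch/torch.h>

void gather_points_backward_reference(int64_t b, int64_t n,
                                      const at::Tensor& grad_out,
                                      const at::Tensor& idx,
                                      at::Tensor grad_points) {
  // row = batch * n + idx[batch][j], so every (batch, point) pair maps to a
  // unique row of the flattened (b * n, c) gradient buffer.
  at::Tensor flat_idx =
      (at::arange(b, idx.options()).view({b, 1}) * n + idx).view({-1});

  // (b, c, n) -> (b * n, c) and (b, c, npoints) -> (b * npoints, c).
  at::Tensor dst = grad_points.transpose(1, 2).contiguous().view({b * n, -1});
  at::Tensor src =
      grad_out.transpose(1, 2).contiguous().view({flat_idx.size(0), -1});

  // Accumulate each gradient row into the point it was gathered from.
  dst.index_add_(0, flat_idx, src);

  // dst is a contiguous copy, so write the result back in (b, c, n) layout,
  // mirroring the final copy_ in the NPU kernel.
  grad_points.copy_(dst.view({b, n, -1}).transpose(1, 2));
}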
@@ -396,8 +396,10 @@ def get_extensions():
                 from torch_npu.utils.cpp_extension import NpuExtension
                 define_macros += [('MMCV_WITH_NPU', None)]
                 extension = NpuExtension
-                if parse_version(torch.__version__) >= parse_version('2.0.0'):
+                if parse_version(torch.__version__) <= parse_version('2.0.0'):
                     define_macros += [('MMCV_WITH_XLA', None)]
+                if parse_version(torch.__version__) > parse_version('2.0.0'):
+                    define_macros += [('MMCV_WITH_KPRIVATE', None)]
             except Exception:
                 raise ImportError('can not find any torch_npu')
             # src
...