Unverified Commit f946a933 authored by sherie's avatar sherie Committed by GitHub
Browse files

[Enhancement] Add the dtype limit of nms_npu to maintain consistency with the GPU (#2724)



* Increase the dtype limit to maintain consistency with the gpu.

* update error message
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

---------

Co-authored-by: momo609 <963372609.qq.com>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 8ceac934
...@@ -4,6 +4,8 @@ using namespace NPU_NAME_SPACE; ...@@ -4,6 +4,8 @@ using namespace NPU_NAME_SPACE;
using namespace std; using namespace std;
Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) { Tensor nms_npu(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
TORCH_CHECK((boxes.scalar_type == at::ScalarType::Float),
"The type of boxes tensor passed in nms_npu should be float");
int64_t offset_64 = offset; int64_t offset_64 = offset;
at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor( at::Tensor iou_threshold_y = at_npu::native::OpPreparation::ApplyTensor(
{}, boxes.options().dtype(at::kFloat), boxes) {}, boxes.options().dtype(at::kFloat), boxes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment