Unverified commit 50d1fffb, authored by liuhw, committed by GitHub
Browse files

[Fix] Keep the shape of iou op's parameter 2 smaller than parameter 1 (#2821)

parent f7382417
...@@ -12,11 +12,17 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, ...@@ -12,11 +12,17 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
if (mode == 1) { if (mode == 1) {
modeStr = "iof"; modeStr = "iof";
} }
bool swap_flag = false;
at::Tensor bboxesFP32 = bboxes2; at::Tensor bboxesFP32 = bboxes2;
at::Tensor gtboxesFP32 = bboxes1; at::Tensor gtboxesFP32 = bboxes1;
if (bboxes2.size(0) < bboxes1.size(0)) {
swap_flag = true;
bboxesFP32 = bboxes1;
gtboxesFP32 = bboxes2;
}
if (bboxes2.scalar_type() != at::ScalarType::Float) { if (bboxes2.scalar_type() != at::ScalarType::Float) {
bboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxes2, at::kFloat); bboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxesFP32, at::kFloat);
gtboxesFP32 = NPUNativeFunctions::npu_dtype_cast(bboxes1, at::kFloat); gtboxesFP32 = NPUNativeFunctions::npu_dtype_cast(gtboxesFP32, at::kFloat);
} }
c10::SmallVector<int64_t, SIZE> iousSize = {gtboxesFP32.size(0), c10::SmallVector<int64_t, SIZE> iousSize = {gtboxesFP32.size(0),
bboxesFP32.size(0)}; bboxesFP32.size(0)};
...@@ -38,6 +44,7 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious, ...@@ -38,6 +44,7 @@ void bbox_overlaps_npu(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
if (bboxes2.scalar_type() != at::ScalarType::Float) { if (bboxes2.scalar_type() != at::ScalarType::Float) {
iousFP32 = NPUNativeFunctions::npu_dtype_cast(iousFP32, at::kHalf); iousFP32 = NPUNativeFunctions::npu_dtype_cast(iousFP32, at::kHalf);
} }
iousFP32 = swap_flag ? iousFP32.transpose(0, 1) : iousFP32;
ious.copy_(iousFP32); ious.copy_(iousFP32);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment