Unverified Commit 987d34b0 authored by MrShadowY's avatar MrShadowY Committed by GitHub
Browse files

[Feature] Add the support of BoxIouRotated op for ascend device (#2842)


Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 0a2f60ba
...@@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc. ...@@ -9,7 +9,7 @@ We implement common ops used in detection, segmentation, etc.
| BallQuery | | √ | √ | | | | BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ | | BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | | | BorderAlign | | √ | | | |
| BoxIouRotated | √ | √ | √ | | | | BoxIouRotated | √ | √ | √ | | |
| BoxIouQuadri | √ | √ | | | | | BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | | | CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | | | ChamferDistance | | √ | | | |
......
...@@ -9,7 +9,7 @@ MMCV 提供了检测、分割等任务中常用的算子 ...@@ -9,7 +9,7 @@ MMCV 提供了检测、分割等任务中常用的算子
| BallQuery | | √ | √ | | | | BallQuery | | √ | √ | | |
| BBoxOverlaps | | √ | √ | √ | √ | | BBoxOverlaps | | √ | √ | √ | √ |
| BorderAlign | | √ | | | | | BorderAlign | | √ | | | |
| BoxIouRotated | √ | √ | √ | | | | BoxIouRotated | √ | √ | √ | | |
| BoxIouQuadri | √ | √ | | | | | BoxIouQuadri | √ | √ | | | |
| CARAFE | | √ | √ | | | | CARAFE | | √ | √ | | |
| ChamferDistance | | √ | | | | | ChamferDistance | | √ | | | |
......
...@@ -142,6 +142,11 @@ def box_iou_rotated(bboxes1: torch.Tensor, ...@@ -142,6 +142,11 @@ def box_iou_rotated(bboxes1: torch.Tensor,
flip_mat[-1] = -1 flip_mat[-1] = -1
bboxes1 = bboxes1 * flip_mat bboxes1 = bboxes1 * flip_mat
bboxes2 = bboxes2 * flip_mat bboxes2 = bboxes2 * flip_mat
if bboxes1.device.type == 'npu':
scale_mat = bboxes1.new_ones(bboxes1.shape[-1])
scale_mat[-1] = 1.0 / 0.01745329252
bboxes1 = bboxes1 * scale_mat
bboxes2 = bboxes2 * scale_mat
bboxes1 = bboxes1.contiguous() bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous() bboxes2 = bboxes2.contiguous()
ext_module.box_iou_rotated( ext_module.box_iou_rotated(
......
#include "pytorch_npu_helper.hpp"
using namespace NPU_NAME_SPACE;
using namespace std;
void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
void box_iou_rotated_npu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
at::Tensor boxes = at::ones_like(boxes1);
at::Tensor query_boxes = at::ones_like(boxes2);
boxes = boxes1.transpose(0, 1).unsqueeze(0);
query_boxes = boxes2.transpose(0, 1).unsqueeze(0);
bool is_trans = false;
string modeStr = "iou";
if (mode_flag == 1) {
modeStr = "iof";
}
bool is_cross = true;
if (aligned) {
is_cross = false;
}
float v_threshold = 0;
float e_threshold = 0;
OpCommand cmd;
cmd.Name("RotatedIou")
.Input(boxes)
.Input(query_boxes)
.Output(ious)
.Attr("trans", is_trans)
.Attr("mode", modeStr)
.Attr("is_cross", is_cross)
.Attr("v_threshold", v_threshold)
.Attr("e_threshold", e_threshold)
.Run();
if (is_cross) {
ious = ious.view({boxes1.size(0), boxes2.size(0)});
} else {
ious = ious.view({boxes1.size(0), 1});
}
}
REGISTER_NPU_IMPL(box_iou_rotated_impl, box_iou_rotated_npu);
...@@ -4,7 +4,7 @@ import pytest ...@@ -4,7 +4,7 @@ import pytest
import torch import torch
from mmcv.ops import box_iou_rotated from mmcv.ops import box_iou_rotated
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE
class TestBoxIoURotated: class TestBoxIoURotated:
...@@ -54,7 +54,11 @@ class TestBoxIoURotated: ...@@ -54,7 +54,11 @@ class TestBoxIoURotated:
pytest.param( pytest.param(
'mlu', 'mlu',
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')) not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
]) ])
def test_box_iou_rotated(self, device): def test_box_iou_rotated(self, device):
np_boxes1 = np.asarray( np_boxes1 = np.asarray(
...@@ -137,7 +141,11 @@ class TestBoxIoURotated: ...@@ -137,7 +141,11 @@ class TestBoxIoURotated:
pytest.param( pytest.param(
'mlu', 'mlu',
marks=pytest.mark.skipif( marks=pytest.mark.skipif(
not IS_MLU_AVAILABLE, reason='requires MLU support')) not IS_MLU_AVAILABLE, reason='requires MLU support')),
pytest.param(
'npu',
marks=pytest.mark.skipif(
not IS_NPU_AVAILABLE, reason='requires NPU support'))
]) ])
def test_box_iou_rotated_iof(self, device): def test_box_iou_rotated_iof(self, device):
np_boxes1 = np.asarray( np_boxes1 = np.asarray(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment