Unverified Commit 86c9ac2c authored by BigBigDream's avatar BigBigDream Committed by GitHub
Browse files

add 'iof' mode for box_iou_rotated (#753)

* add iof mode for box_iou_rotated

* update doc

* fix lint

* fix lint

* fix lint

* fix lint

* fix lint
parent de767fc1
...@@ -3,7 +3,7 @@ from ..utils import ext_loader ...@@ -3,7 +3,7 @@ from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated']) ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
def box_iou_rotated(bboxes1, bboxes2, aligned=False): def box_iou_rotated(bboxes1, bboxes2, mode='iou', aligned=False):
"""Return intersection-over-union (Jaccard index) of boxes. """Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in Both sets of boxes are expected to be in
...@@ -18,10 +18,15 @@ def box_iou_rotated(bboxes1, bboxes2, aligned=False): ...@@ -18,10 +18,15 @@ def box_iou_rotated(bboxes1, bboxes2, aligned=False):
It has shape (N, 5), indicating (x, y, w, h, theta) for each row. It has shape (N, 5), indicating (x, y, w, h, theta) for each row.
boxes2 (Tensor): rotated bboxes 2. \ boxes2 (Tensor): rotated bboxes 2. \
It has shape (M, 5), indicating (x, y, w, h, theta) for each row. It has shape (M, 5), indicating (x, y, w, h, theta) for each row.
mode (str): "iou" (intersection over union) or iof (intersection over
foreground).
Returns: Returns:
ious(Tensor): shape (N, M) if aligned == False else shape (N,) ious(Tensor): shape (N, M) if aligned == False else shape (N,)
""" """
assert mode in ['iou', 'iof']
mode_dict = {'iou': 0, 'iof': 1}
mode_flag = mode_dict[mode]
rows = bboxes1.size(0) rows = bboxes1.size(0)
cols = bboxes2.size(0) cols = bboxes2.size(0)
if aligned: if aligned:
...@@ -30,7 +35,8 @@ def box_iou_rotated(bboxes1, bboxes2, aligned=False): ...@@ -30,7 +35,8 @@ def box_iou_rotated(bboxes1, bboxes2, aligned=False):
ious = bboxes1.new_zeros((rows * cols)) ious = bboxes1.new_zeros((rows * cols))
bboxes1 = bboxes1.contiguous() bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous() bboxes2 = bboxes2.contiguous()
ext_module.box_iou_rotated(bboxes1, bboxes2, ious, aligned=aligned) ext_module.box_iou_rotated(
bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
if not aligned: if not aligned:
ious = ious.view(rows, cols) ious = ious.view(rows, cols)
return ious return ious
...@@ -18,11 +18,9 @@ const int BLOCK_DIM_Y = 16; ...@@ -18,11 +18,9 @@ const int BLOCK_DIM_Y = 16;
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); } inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
template <typename T> template <typename T>
__global__ void box_iou_rotated_cuda_kernel(const int n_boxes1, __global__ void box_iou_rotated_cuda_kernel(
const int n_boxes2, const int n_boxes1, const int n_boxes2, const T* dev_boxes1,
const T* dev_boxes1, const T* dev_boxes2, T* dev_ious, const int mode_flag, const bool aligned) {
const T* dev_boxes2, T* dev_ious,
const bool aligned) {
if (aligned) { if (aligned) {
CUDA_1D_KERNEL_LOOP(index, n_boxes1) { CUDA_1D_KERNEL_LOOP(index, n_boxes1) {
int b1 = index; int b1 = index;
...@@ -47,7 +45,8 @@ __global__ void box_iou_rotated_cuda_kernel(const int n_boxes1, ...@@ -47,7 +45,8 @@ __global__ void box_iou_rotated_cuda_kernel(const int n_boxes1,
block_boxes2[3] = dev_boxes2[base2 + 3]; block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4]; block_boxes2[4] = dev_boxes2[base2 + 4];
dev_ious[index] = single_box_iou_rotated<T>(block_boxes1, block_boxes2); dev_ious[index] =
single_box_iou_rotated<T>(block_boxes1, block_boxes2, mode_flag);
} }
} else { } else {
CUDA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) { CUDA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) {
...@@ -73,7 +72,8 @@ __global__ void box_iou_rotated_cuda_kernel(const int n_boxes1, ...@@ -73,7 +72,8 @@ __global__ void box_iou_rotated_cuda_kernel(const int n_boxes1,
block_boxes2[3] = dev_boxes2[base2 + 3]; block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4]; block_boxes2[4] = dev_boxes2[base2 + 4];
dev_ious[index] = single_box_iou_rotated<T>(block_boxes1, block_boxes2); dev_ious[index] =
single_box_iou_rotated<T>(block_boxes1, block_boxes2, mode_flag);
} }
} }
} }
......
...@@ -308,7 +308,8 @@ HOST_DEVICE_INLINE T rotated_boxes_intersection(const RotatedBox<T>& box1, ...@@ -308,7 +308,8 @@ HOST_DEVICE_INLINE T rotated_boxes_intersection(const RotatedBox<T>& box1,
template <typename T> template <typename T>
HOST_DEVICE_INLINE T single_box_iou_rotated(T const* const box1_raw, HOST_DEVICE_INLINE T single_box_iou_rotated(T const* const box1_raw,
T const* const box2_raw) { T const* const box2_raw,
const int mode_flag) {
// shift center to the middle point to achieve higher precision in result // shift center to the middle point to achieve higher precision in result
RotatedBox<T> box1, box2; RotatedBox<T> box1, box2;
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0; auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
...@@ -331,6 +332,12 @@ HOST_DEVICE_INLINE T single_box_iou_rotated(T const* const box1_raw, ...@@ -331,6 +332,12 @@ HOST_DEVICE_INLINE T single_box_iou_rotated(T const* const box1_raw,
} }
const T intersection = rotated_boxes_intersection<T>(box1, box2); const T intersection = rotated_boxes_intersection<T>(box1, box2);
const T iou = intersection / (area1 + area2 - intersection); T baseS = 1.0;
if (mode_flag == 0) {
baseS = (area1 + area2 - intersection);
} else if (mode_flag == 1) {
baseS = area1;
}
const T iou = intersection / baseS;
return iou; return iou;
} }
...@@ -71,7 +71,7 @@ __global__ void nms_rotated_cuda_kernel(const int n_boxes, ...@@ -71,7 +71,7 @@ __global__ void nms_rotated_cuda_kernel(const int n_boxes,
// Instead of devIoU used by original horizontal nms, here // Instead of devIoU used by original horizontal nms, here
// we use the single_box_iou_rotated function from // we use the single_box_iou_rotated function from
// box_iou_rotated_utils.h // box_iou_rotated_utils.h
if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 6) > if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 6, 0) >
iou_threshold) { iou_threshold) {
t |= 1ULL << i; t |= 1ULL << i;
} }
...@@ -121,7 +121,7 @@ __global__ void nms_rotated_cuda_kernel(const int n_boxes, ...@@ -121,7 +121,7 @@ __global__ void nms_rotated_cuda_kernel(const int n_boxes,
// Instead of devIoU used by original horizontal nms, here // Instead of devIoU used by original horizontal nms, here
// we use the single_box_iou_rotated function from // we use the single_box_iou_rotated function from
// box_iou_rotated_utils.h // box_iou_rotated_utils.h
if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 5) > if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 5, 0) >
iou_threshold) { iou_threshold) {
t |= 1ULL << i; t |= 1ULL << i;
} }
......
...@@ -5,11 +5,12 @@ ...@@ -5,11 +5,12 @@
void box_iou_rotated_cpu_launcher(const DArrayLite boxes1, void box_iou_rotated_cpu_launcher(const DArrayLite boxes1,
const DArrayLite boxes2, DArrayLite ious, const DArrayLite boxes2, DArrayLite ious,
const bool aligned); const int mode_flag, const bool aligned);
void box_iou_rotated_cuda_launcher(const DArrayLite boxes1, void box_iou_rotated_cuda_launcher(const DArrayLite boxes1,
const DArrayLite boxes2, DArrayLite ious, const DArrayLite boxes2, DArrayLite ious,
const bool aligned, cudaStream_t stream); const int mode_flag, const bool aligned,
cudaStream_t stream);
void box_iou_rotated_cpu(HostContext& ctx, const SSElement& attr, void box_iou_rotated_cpu(HostContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins, const OperatorBase::in_list_t& ins,
...@@ -18,9 +19,13 @@ void box_iou_rotated_cpu(HostContext& ctx, const SSElement& attr, ...@@ -18,9 +19,13 @@ void box_iou_rotated_cpu(HostContext& ctx, const SSElement& attr,
const auto& boxes2 = ins[1]; const auto& boxes2 = ins[1];
bool aligned; bool aligned;
SSAttrs(attr).get<bool>("aligned", aligned).done(); int mode_flag;
SSAttrs(attr)
.get<bool>("aligned", aligned)
.get<int>("mode_flag", mode_flag)
.done();
auto& ious = outs[0]; auto& ious = outs[0];
box_iou_rotated_cpu_launcher(boxes1, boxes2, ious, aligned); box_iou_rotated_cpu_launcher(boxes1, boxes2, ious, mode_flag, aligned);
} }
void box_iou_rotated_cuda(CudaContext& ctx, const SSElement& attr, void box_iou_rotated_cuda(CudaContext& ctx, const SSElement& attr,
...@@ -30,15 +35,21 @@ void box_iou_rotated_cuda(CudaContext& ctx, const SSElement& attr, ...@@ -30,15 +35,21 @@ void box_iou_rotated_cuda(CudaContext& ctx, const SSElement& attr,
const auto& boxes2 = ins[1]; const auto& boxes2 = ins[1];
bool aligned; bool aligned;
SSAttrs(attr).get<bool>("aligned", aligned).done(); int mode_flag;
SSAttrs(attr)
.get<bool>("aligned", aligned)
.get<int>("mode_flag", mode_flag)
.done();
cudaStream_t stream = getStreamNative<CudaDevice>(ctx.getStream()); cudaStream_t stream = getStreamNative<CudaDevice>(ctx.getStream());
auto& ious = outs[0]; auto& ious = outs[0];
box_iou_rotated_cuda_launcher(boxes1, boxes2, ious, aligned, stream); box_iou_rotated_cuda_launcher(boxes1, boxes2, ious, mode_flag, aligned,
stream);
} }
PARROTS_EXTENSION_REGISTER(box_iou_rotated) PARROTS_EXTENSION_REGISTER(box_iou_rotated)
.attr("aligned") .attr("aligned")
.attr("mode_flag")
.input(2) .input(2)
.output(1) .output(1)
.apply(box_iou_rotated_cpu) .apply(box_iou_rotated_cpu)
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
template <typename T> template <typename T>
void box_iou_rotated_cpu_kernel(const DArrayLite boxes1, void box_iou_rotated_cpu_kernel(const DArrayLite boxes1,
const DArrayLite boxes2, DArrayLite ious, const DArrayLite boxes2, DArrayLite ious,
const bool aligned) { const int mode_flag, const bool aligned) {
int output_size = ious.size(); int output_size = ious.size();
int num_boxes1 = boxes1.dim(0); int num_boxes1 = boxes1.dim(0);
int num_boxes2 = boxes2.dim(0); int num_boxes2 = boxes2.dim(0);
...@@ -16,14 +16,14 @@ void box_iou_rotated_cpu_kernel(const DArrayLite boxes1, ...@@ -16,14 +16,14 @@ void box_iou_rotated_cpu_kernel(const DArrayLite boxes1,
if (aligned) { if (aligned) {
for (int i = 0; i < output_size; i++) { for (int i = 0; i < output_size; i++) {
ious_ptr[i] = ious_ptr[i] = single_box_iou_rotated<T>(boxes1[i].ptr<T>(),
single_box_iou_rotated<T>(boxes1[i].ptr<T>(), boxes2[i].ptr<T>()); boxes2[i].ptr<T>(), mode_flag);
} }
} else { } else {
for (int i = 0; i < num_boxes1; i++) { for (int i = 0; i < num_boxes1; i++) {
for (int j = 0; j < num_boxes2; j++) { for (int j = 0; j < num_boxes2; j++) {
ious_ptr[i * num_boxes2 + j] = ious_ptr[i * num_boxes2 + j] = single_box_iou_rotated<T>(
single_box_iou_rotated<T>(boxes1[i].ptr<T>(), boxes2[j].ptr<T>()); boxes1[i].ptr<T>(), boxes2[j].ptr<T>(), mode_flag);
} }
} }
} }
...@@ -31,6 +31,6 @@ void box_iou_rotated_cpu_kernel(const DArrayLite boxes1, ...@@ -31,6 +31,6 @@ void box_iou_rotated_cpu_kernel(const DArrayLite boxes1,
void box_iou_rotated_cpu_launcher(const DArrayLite boxes1, void box_iou_rotated_cpu_launcher(const DArrayLite boxes1,
const DArrayLite boxes2, DArrayLite ious, const DArrayLite boxes2, DArrayLite ious,
const bool aligned) { const int mode_flag, const bool aligned) {
box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious, aligned); box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious, mode_flag, aligned);
} }
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
void box_iou_rotated_cuda_launcher(const DArrayLite boxes1, void box_iou_rotated_cuda_launcher(const DArrayLite boxes1,
const DArrayLite boxes2, DArrayLite ious, const DArrayLite boxes2, DArrayLite ious,
const bool aligned, cudaStream_t stream) { const int mode_flag, const bool aligned,
cudaStream_t stream) {
using scalar_t = float; using scalar_t = float;
int output_size = ious.size(); int output_size = ious.size();
...@@ -16,7 +17,8 @@ void box_iou_rotated_cuda_launcher(const DArrayLite boxes1, ...@@ -16,7 +17,8 @@ void box_iou_rotated_cuda_launcher(const DArrayLite boxes1,
box_iou_rotated_cuda_kernel<scalar_t> box_iou_rotated_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>( <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.ptr<scalar_t>(), num_boxes1, num_boxes2, boxes1.ptr<scalar_t>(),
boxes2.ptr<scalar_t>(), (scalar_t*)ious.ptr<scalar_t>(), aligned); boxes2.ptr<scalar_t>(), (scalar_t*)ious.ptr<scalar_t>(), mode_flag,
aligned);
PARROTS_CUDA_CHECK(cudaGetLastError()); PARROTS_CUDA_CHECK(cudaGetLastError());
} }
...@@ -4,26 +4,26 @@ ...@@ -4,26 +4,26 @@
#include "pytorch_cpp_helper.hpp" #include "pytorch_cpp_helper.hpp"
void box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const bool aligned); const int mode_flag, const bool aligned);
#ifdef MMCV_WITH_CUDA #ifdef MMCV_WITH_CUDA
void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const bool aligned); const int mode_flag, const bool aligned);
#endif #endif
// Interface for Python // Interface for Python
// inline is needed to prevent multiple function definitions when this header is // inline is needed to prevent multiple function definitions when this header is
// included by different cpps // included by different cpps
void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const bool aligned) { const int mode_flag, const bool aligned) {
assert(boxes1.device().is_cuda() == boxes2.device().is_cuda()); assert(boxes1.device().is_cuda() == boxes2.device().is_cuda());
if (boxes1.device().is_cuda()) { if (boxes1.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA #ifdef MMCV_WITH_CUDA
box_iou_rotated_cuda(boxes1, boxes2, ious, aligned); box_iou_rotated_cuda(boxes1, boxes2, ious, mode_flag, aligned);
#else #else
AT_ERROR("Not compiled with GPU support"); AT_ERROR("Not compiled with GPU support");
#endif #endif
} else { } else {
box_iou_rotated_cpu(boxes1, boxes2, ious, aligned); box_iou_rotated_cpu(boxes1, boxes2, ious, mode_flag, aligned);
} }
} }
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
template <typename T> template <typename T>
void box_iou_rotated_cpu_kernel(const Tensor boxes1, const Tensor boxes2, void box_iou_rotated_cpu_kernel(const Tensor boxes1, const Tensor boxes2,
Tensor ious, const bool aligned) { Tensor ious, const int mode_flag,
const bool aligned) {
int output_size = ious.numel(); int output_size = ious.numel();
auto num_boxes1 = boxes1.size(0); auto num_boxes1 = boxes1.size(0);
auto num_boxes2 = boxes2.size(0); auto num_boxes2 = boxes2.size(0);
...@@ -14,19 +15,19 @@ void box_iou_rotated_cpu_kernel(const Tensor boxes1, const Tensor boxes2, ...@@ -14,19 +15,19 @@ void box_iou_rotated_cpu_kernel(const Tensor boxes1, const Tensor boxes2,
if (aligned) { if (aligned) {
for (int i = 0; i < output_size; i++) { for (int i = 0; i < output_size; i++) {
ious[i] = single_box_iou_rotated<T>(boxes1[i].data_ptr<T>(), ious[i] = single_box_iou_rotated<T>(boxes1[i].data_ptr<T>(),
boxes2[i].data_ptr<T>()); boxes2[i].data_ptr<T>(), mode_flag);
} }
} else { } else {
for (int i = 0; i < num_boxes1; i++) { for (int i = 0; i < num_boxes1; i++) {
for (int j = 0; j < num_boxes2; j++) { for (int j = 0; j < num_boxes2; j++) {
ious[i * num_boxes2 + j] = single_box_iou_rotated<T>( ious[i * num_boxes2 + j] = single_box_iou_rotated<T>(
boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>()); boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>(), mode_flag);
} }
} }
} }
} }
void box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const bool aligned) { const int mode_flag, const bool aligned) {
box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious, aligned); box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious, mode_flag, aligned);
} }
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "pytorch_cuda_helper.hpp" #include "pytorch_cuda_helper.hpp"
void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const bool aligned) { const int mode_flag, const bool aligned) {
using scalar_t = float; using scalar_t = float;
AT_ASSERTM(boxes1.type().is_cuda(), "boxes1 must be a CUDA tensor"); AT_ASSERTM(boxes1.type().is_cuda(), "boxes1 must be a CUDA tensor");
AT_ASSERTM(boxes2.type().is_cuda(), "boxes2 must be a CUDA tensor"); AT_ASSERTM(boxes2.type().is_cuda(), "boxes2 must be a CUDA tensor");
...@@ -20,6 +20,6 @@ void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious, ...@@ -20,6 +20,6 @@ void box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>( <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(), num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(), boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
aligned); mode_flag, aligned);
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
} }
...@@ -46,8 +46,8 @@ Tensor nms_rotated_cpu_kernel(const Tensor dets, const Tensor scores, ...@@ -46,8 +46,8 @@ Tensor nms_rotated_cpu_kernel(const Tensor dets, const Tensor scores,
continue; continue;
} }
auto ovr = single_box_iou_rotated<scalar_t>(dets[i].data_ptr<scalar_t>(), auto ovr = single_box_iou_rotated<scalar_t>(
dets[j].data_ptr<scalar_t>()); dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>(), 0);
if (ovr >= iou_threshold) { if (ovr >= iou_threshold) {
suppressed[j] = 1; suppressed[j] = 1;
} }
......
...@@ -176,7 +176,7 @@ Tensor top_pool_forward(Tensor input); ...@@ -176,7 +176,7 @@ Tensor top_pool_forward(Tensor input);
Tensor top_pool_backward(Tensor input, Tensor grad_output); Tensor top_pool_backward(Tensor input, Tensor grad_output);
void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious, void box_iou_rotated(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const bool aligned); const int mode_flag, const bool aligned);
Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order, Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold, const Tensor dets_sorted, const float iou_threshold,
...@@ -366,7 +366,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -366,7 +366,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes", m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes",
py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"), py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"),
py::arg("aligned")); py::arg("mode_flag"), py::arg("aligned"));
m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"), m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"),
py::arg("scores"), py::arg("order"), py::arg("dets_sorted"), py::arg("scores"), py::arg("order"), py::arg("dets_sorted"),
py::arg("iou_threshold"), py::arg("multi_label")); py::arg("iou_threshold"), py::arg("multi_label"));
......
...@@ -60,3 +60,58 @@ class TestBoxIoURotated(object): ...@@ -60,3 +60,58 @@ class TestBoxIoURotated(object):
ious = box_iou_rotated(boxes1, boxes2, aligned=True) ious = box_iou_rotated(boxes1, boxes2, aligned=True)
assert np.allclose( assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4) ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
def test_box_iou_rotated_iof_cpu(self):
from mmcv.ops import box_iou_rotated
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
dtype=np.float32)
np_boxes2 = np.asarray(
[[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
[5.0, 5.0, 6.0, 7.0, 0.4]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.4959, 0.5306, 0.0000], [0.1823, 0.5420, 0.1832],
[0.0000, 0.0000, 0.4404]],
dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.4959, 0.5420, 0.4404],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1)
boxes2 = torch.from_numpy(np_boxes2)
ious = box_iou_rotated(boxes1, boxes2, mode='iof')
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_rotated(boxes1, boxes2, mode='iof', aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_box_iou_rotated_iof_cuda(self):
from mmcv.ops import box_iou_rotated
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
dtype=np.float32)
np_boxes2 = np.asarray(
[[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
[5.0, 5.0, 6.0, 7.0, 0.4]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.4959, 0.5306, 0.0000], [0.1823, 0.5420, 0.1832],
[0.0000, 0.0000, 0.4404]],
dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.4959, 0.5420, 0.4404],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
ious = box_iou_rotated(boxes1, boxes2, mode='iof')
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_rotated(boxes1, boxes2, mode='iof', aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment