Unverified Commit a4c82617 authored by Yue Zhou, committed by GitHub

[Feature] Add box_iou_quadri & nms_quadri (#2277)

* update

* update

* fix link

* fix bug

* update nms_quadri

* fix lint

* Update test_nms_quadri.py

* Update box_iou_quadri.py

* fix bug

* Update test_nms_quadri.py

* Update box_iou_rotated_utils.hpp

* Update box_iou_quadri.py

* Update mmcv/ops/nms.py
parent 75ea2f89
@@ -10,6 +10,7 @@ We implement common ops used in detection, segmentation, etc.
| BBoxOverlaps | | √ | √ | √ |
| BorderAlign | | √ | | |
| BoxIouRotated | √ | √ | | |
| BoxIouQuadri | √ | √ | | |
| CARAFE | | √ | √ | |
| ChamferDistance | | √ | | |
| CrissCrossAttention | | √ | | |
@@ -35,6 +36,7 @@ We implement common ops used in detection, segmentation, etc.
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| NMSQuadri | √ | √ | | |
| PixelGroup | √ | | | |
| PointsInBoxes | √ | √ | | |
| PointsInPolygons | | √ | | |
......
@@ -10,6 +10,7 @@ MMCV provides commonly used operators for detection, segmentation, and other tasks
| BBoxOverlaps | | √ | √ | √ |
| BorderAlign | | √ | | |
| BoxIouRotated | √ | √ | | |
| BoxIouQuadri | √ | √ | | |
| CARAFE | | √ | √ | |
| ChamferDistance | | √ | | |
| CrissCrossAttention | | √ | | |
@@ -35,6 +36,7 @@ MMCV provides commonly used operators for detection, segmentation, and other tasks
| MultiScaleDeformableAttn | | √ | | |
| NMS | √ | √ | √ | |
| NMSRotated | √ | √ | | |
| NMSQuadri | √ | √ | | |
| PixelGroup | √ | | | |
| PointsInBoxes | √ | √ | | |
| PointsInPolygons | | √ | | |
......
@@ -4,6 +4,7 @@ from .assign_score_withk import assign_score_withk
from .ball_query import ball_query
from .bbox import bbox_overlaps
from .border_align import BorderAlign, border_align
from .box_iou_quadri import box_iou_quadri
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
@@ -37,7 +38,7 @@ from .modulated_deform_conv import (ModulatedDeformConv2d,
                                    ModulatedDeformConv2dPack,
                                    modulated_deform_conv2d)
from .multi_scale_deform_attn import MultiScaleDeformableAttention
-from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
+from .nms import batched_nms, nms, nms_match, nms_quadri, nms_rotated, soft_nms
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
                           rel_roi_point_to_rel_img_point)
@@ -82,13 +83,14 @@ __all__ = [
    'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
    'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
    'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
-    'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
-    'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
-    'rotated_feature_align', 'RiRoIAlignRotated', 'riroi_align_rotated',
-    'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'QueryAndGroup',
-    'GroupAll', 'grouping_operation', 'contour_expand', 'three_nn',
-    'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
-    'border_align', 'gather_points', 'furthest_point_sample',
+    'box_iou_rotated', 'box_iou_quadri', 'RoIPointPool3d', 'nms_rotated',
+    'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
+    'fused_bias_leakyrelu', 'rotated_feature_align', 'RiRoIAlignRotated',
+    'riroi_align_rotated', 'RoIAlignRotated', 'roi_align_rotated',
+    'pixel_group', 'QueryAndGroup', 'GroupAll', 'grouping_operation',
+    'contour_expand', 'three_nn', 'three_interpolate',
+    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
+    'gather_points', 'furthest_point_sample', 'nms_quadri',
    'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
    'boxes_iou3d', 'boxes_iou_bev', 'boxes_overlap_bev', 'nms_bev',
    'nms_normal_bev', 'nms3d', 'nms3d_normal', 'Voxelization', 'voxelization',
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['box_iou_quadri'])
def box_iou_quadri(bboxes1: torch.Tensor,
bboxes2: torch.Tensor,
mode: str = 'iou',
aligned: bool = False) -> torch.Tensor:
"""Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in
(x1, y1, ..., x4, y4) format.
If ``aligned`` is ``False``, then calculate the ious between each bbox
of bboxes1 and bboxes2, otherwise the ious between each aligned pair of
bboxes1 and bboxes2.
Args:
bboxes1 (torch.Tensor): quadrilateral bboxes 1. It has shape (N, 8),
indicating (x1, y1, ..., x4, y4) for each row.
bboxes2 (torch.Tensor): quadrilateral bboxes 2. It has shape (M, 8),
indicating (x1, y1, ..., x4, y4) for each row.
mode (str): "iou" (intersection over union) or iof (intersection over
foreground).
Returns:
torch.Tensor: Return the ious betweens boxes. If ``aligned`` is
``False``, the shape of ious is (N, M) else (N,).
"""
assert mode in ['iou', 'iof']
mode_dict = {'iou': 0, 'iof': 1}
mode_flag = mode_dict[mode]
rows = bboxes1.size(0)
cols = bboxes2.size(0)
if aligned:
ious = bboxes1.new_zeros(rows)
else:
ious = bboxes1.new_zeros(rows * cols)
bboxes1 = bboxes1.contiguous()
bboxes2 = bboxes2.contiguous()
ext_module.box_iou_quadri(
bboxes1, bboxes2, ious, mode_flag=mode_flag, aligned=aligned)
if not aligned:
ious = ious.view(rows, cols)
return ious
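
A minimal usage sketch of the wrapper above; it assumes an mmcv build with the compiled `_ext` ops available, and the corner coordinates are made-up illustrative values, not test fixtures:

import torch
from mmcv.ops import box_iou_quadri

# Each row is one quadrilateral given as (x1, y1, x2, y2, x3, y3, x4, y4).
bboxes1 = torch.tensor([[0., 0., 2., 0., 2., 2., 0., 2.],
                        [1., 1., 3., 1., 3., 3., 1., 3.]])
bboxes2 = torch.tensor([[1., 0., 3., 0., 3., 2., 1., 2.],
                        [0., 0., 2., 0., 2., 2., 0., 2.]])

ious = box_iou_quadri(bboxes1, bboxes2)                        # shape (2, 2)
ious_aligned = box_iou_quadri(bboxes1, bboxes2, aligned=True)  # shape (2,)
iofs = box_iou_quadri(bboxes1, bboxes2, mode='iof')            # normalized by bboxes1 areas
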
@@ -270,6 +270,17 @@ HOST_DEVICE_INLINE int convex_hull_graham(const Point<T> (&p)[24],
  return m;
}
template <typename T>
HOST_DEVICE_INLINE T quadri_box_area(const Point<T> (&q)[4]) {
T area = 0;
#pragma unroll
for (int i = 1; i < 3; i++) {
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
  if (m <= 2) {
@@ -308,6 +319,25 @@ HOST_DEVICE_INLINE T rotated_boxes_intersection(const RotatedBox<T>& box1,
  return polygon_area<T>(orderedPts, num_convex);
}
template <typename T>
HOST_DEVICE_INLINE T quadri_boxes_intersection(const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4]) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from rotated_rect_intersection_pts
Point<T> intersectPts[24], orderedPts[24];
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
return polygon_area<T>(orderedPts, num_convex);
}
}  // namespace
template <typename T>
@@ -345,3 +375,52 @@ HOST_DEVICE_INLINE T single_box_iou_rotated(T const* const box1_raw,
  const T iou = intersection / baseS;
  return iou;
}
template <typename T>
HOST_DEVICE_INLINE T single_box_iou_quadri(T const* const pts1_raw,
T const* const pts2_raw,
const int mode_flag) {
// shift center to the middle point to achieve higher precision in result
Point<T> pts1[4], pts2[4];
auto center_shift_x =
(pts1_raw[0] + pts2_raw[0] + pts1_raw[2] + pts2_raw[2] + pts1_raw[4] +
pts2_raw[4] + pts1_raw[6] + pts2_raw[6]) /
8.0;
auto center_shift_y =
(pts1_raw[1] + pts2_raw[1] + pts1_raw[3] + pts2_raw[3] + pts1_raw[5] +
pts2_raw[5] + pts1_raw[7] + pts2_raw[7]) /
8.0;
pts1[0].x = pts1_raw[0] - center_shift_x;
pts1[0].y = pts1_raw[1] - center_shift_y;
pts1[1].x = pts1_raw[2] - center_shift_x;
pts1[1].y = pts1_raw[3] - center_shift_y;
pts1[2].x = pts1_raw[4] - center_shift_x;
pts1[2].y = pts1_raw[5] - center_shift_y;
pts1[3].x = pts1_raw[6] - center_shift_x;
pts1[3].y = pts1_raw[7] - center_shift_y;
pts2[0].x = pts2_raw[0] - center_shift_x;
pts2[0].y = pts2_raw[1] - center_shift_y;
pts2[1].x = pts2_raw[2] - center_shift_x;
pts2[1].y = pts2_raw[3] - center_shift_y;
pts2[2].x = pts2_raw[4] - center_shift_x;
pts2[2].y = pts2_raw[5] - center_shift_y;
pts2[3].x = pts2_raw[6] - center_shift_x;
pts2[3].y = pts2_raw[7] - center_shift_y;
const T area1 = quadri_box_area<T>(pts1);
const T area2 = quadri_box_area<T>(pts2);
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}
const T intersection = quadri_boxes_intersection<T>(pts1, pts2);
T baseS = 1.0;
if (mode_flag == 0) {
baseS = (area1 + area2 - intersection);
} else if (mode_flag == 1) {
baseS = area1;
}
const T iou = intersection / baseS;
return iou;
}
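
For intuition, single_box_iou_quadri shifts both quadrilaterals to their common mean point for numerical stability, computes each area as two triangles fanned from the first vertex, collects the intersection points of the two polygons, orders them with a Graham-scan convex hull, and applies the shoelace formula. A rough pure-Python cross-check is sketched below; it relies on shapely, which is not an mmcv dependency and is used here only for illustration, and it assumes convex, non-self-intersecting quadrilaterals:

# Illustrative cross-check for single_box_iou_quadri using shapely.
# Inputs are flat lists (x1, y1, ..., x4, y4).
from shapely.geometry import Polygon


def quadri_iou_ref(pts1, pts2, mode='iou'):
    p1 = Polygon(list(zip(pts1[0::2], pts1[1::2])))
    p2 = Polygon(list(zip(pts2[0::2], pts2[1::2])))
    inter = p1.intersection(p2).area
    base = p1.area + p2.area - inter if mode == 'iou' else p1.area
    return inter / base if base > 1e-14 else 0.0


# Two axis-aligned 2x2 squares overlapping in a 1x2 strip: IoU = 2 / 6.
print(quadri_iou_ref([0, 0, 2, 0, 2, 2, 0, 2], [1, 0, 3, 0, 3, 2, 1, 2]))
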
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#ifndef BOX_IOU_QUADRI_CUDA_CUH
#define BOX_IOU_QUADRI_CUDA_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include "box_iou_rotated_utils.hpp"
// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
template <typename T>
__global__ void box_iou_quadri_cuda_kernel(
const int n_boxes1, const int n_boxes2, const T* dev_boxes1,
const T* dev_boxes2, T* dev_ious, const int mode_flag, const bool aligned) {
if (aligned) {
CUDA_1D_KERNEL_LOOP(index, n_boxes1) {
int b1 = index;
int b2 = index;
int base1 = b1 * 8;
float block_boxes1[8];
float block_boxes2[8];
block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
block_boxes1[2] = dev_boxes1[base1 + 2];
block_boxes1[3] = dev_boxes1[base1 + 3];
block_boxes1[4] = dev_boxes1[base1 + 4];
block_boxes1[5] = dev_boxes1[base1 + 5];
block_boxes1[6] = dev_boxes1[base1 + 6];
block_boxes1[7] = dev_boxes1[base1 + 7];
int base2 = b2 * 8;
block_boxes2[0] = dev_boxes2[base2 + 0];
block_boxes2[1] = dev_boxes2[base2 + 1];
block_boxes2[2] = dev_boxes2[base2 + 2];
block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4];
block_boxes2[5] = dev_boxes2[base2 + 5];
block_boxes2[6] = dev_boxes2[base2 + 6];
block_boxes2[7] = dev_boxes2[base2 + 7];
dev_ious[index] =
single_box_iou_quadri<T>(block_boxes1, block_boxes2, mode_flag);
}
} else {
CUDA_1D_KERNEL_LOOP(index, n_boxes1 * n_boxes2) {
int b1 = index / n_boxes2;
int b2 = index % n_boxes2;
int base1 = b1 * 8;
float block_boxes1[8];
float block_boxes2[8];
block_boxes1[0] = dev_boxes1[base1 + 0];
block_boxes1[1] = dev_boxes1[base1 + 1];
block_boxes1[2] = dev_boxes1[base1 + 2];
block_boxes1[3] = dev_boxes1[base1 + 3];
block_boxes1[4] = dev_boxes1[base1 + 4];
block_boxes1[5] = dev_boxes1[base1 + 5];
block_boxes1[6] = dev_boxes1[base1 + 6];
block_boxes1[7] = dev_boxes1[base1 + 7];
int base2 = b2 * 8;
block_boxes2[0] = dev_boxes2[base2 + 0];
block_boxes2[1] = dev_boxes2[base2 + 1];
block_boxes2[2] = dev_boxes2[base2 + 2];
block_boxes2[3] = dev_boxes2[base2 + 3];
block_boxes2[4] = dev_boxes2[base2 + 4];
block_boxes2[5] = dev_boxes2[base2 + 5];
block_boxes2[6] = dev_boxes2[base2 + 6];
block_boxes2[7] = dev_boxes2[base2 + 7];
dev_ious[index] =
single_box_iou_quadri<T>(block_boxes1, block_boxes2, mode_flag);
}
}
}
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#ifndef NMS_QUADRI_CUDA_CUH
#define NMS_QUADRI_CUDA_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include "box_iou_rotated_utils.hpp"
__host__ __device__ inline int divideUP(const int x, const int y) {
return (((x) + (y)-1) / (y));
}
namespace {
int const threadsPerBlock = sizeof(unsigned long long) * 8;
}
template <typename T>
__global__ void nms_quadri_cuda_kernel(const int n_boxes,
const float iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask,
const int multi_label) {
if (multi_label == 1) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
    // Compared to nms_cuda_kernel, where each box is represented with 4 values
    // (x1, y1, x2, y2), each quadrilateral box is represented with 8 values
    // (x1, y1, ..., x4, y4) here.
__shared__ T block_boxes[threadsPerBlock * 8];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 8 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 0];
block_boxes[threadIdx.x * 8 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 1];
block_boxes[threadIdx.x * 8 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 2];
block_boxes[threadIdx.x * 8 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 3];
block_boxes[threadIdx.x * 8 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 4];
block_boxes[threadIdx.x * 8 + 5] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 5];
block_boxes[threadIdx.x * 8 + 6] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 6];
block_boxes[threadIdx.x * 8 + 7] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 9 + 7];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 9;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
        // Instead of devIoU used by the original horizontal nms, here
        // we use the single_box_iou_quadri function from
        // box_iou_rotated_utils.hpp
if (single_box_iou_quadri<T>(cur_box, block_boxes + i * 8, 0) >
iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = divideUP(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
} else {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
    // Compared to nms_cuda_kernel, where each box is represented with 4 values
    // (x1, y1, x2, y2), each quadrilateral box is represented with 8 values
    // (x1, y1, ..., x4, y4) here.
__shared__ T block_boxes[threadsPerBlock * 8];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 8 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 0];
block_boxes[threadIdx.x * 8 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 1];
block_boxes[threadIdx.x * 8 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 2];
block_boxes[threadIdx.x * 8 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 3];
block_boxes[threadIdx.x * 8 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 4];
block_boxes[threadIdx.x * 8 + 5] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 5];
block_boxes[threadIdx.x * 8 + 6] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 6];
block_boxes[threadIdx.x * 8 + 7] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 8 + 7];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 8;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
        // Instead of devIoU used by the original horizontal nms, here
        // we use the single_box_iou_quadri function from
        // box_iou_rotated_utils.hpp
if (single_box_iou_quadri<T>(cur_box, block_boxes + i * 8, 0) >
iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = divideUP(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
}
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
DISPATCH_DEVICE_IMPL(box_iou_quadri_impl, boxes1, boxes2, ious, mode_flag,
aligned);
}
// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
void box_iou_quadri(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
box_iou_quadri_impl(boxes1, boxes2, ious, mode_flag, aligned);
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "box_iou_rotated_utils.hpp"
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"
template <typename T>
void box_iou_quadri_cpu_kernel(const Tensor boxes1, const Tensor boxes2,
Tensor ious, const int mode_flag,
const bool aligned) {
int output_size = ious.numel();
auto num_boxes1 = boxes1.size(0);
auto num_boxes2 = boxes2.size(0);
if (aligned) {
for (int i = 0; i < output_size; i++) {
ious[i] = single_box_iou_quadri<T>(boxes1[i].data_ptr<T>(),
boxes2[i].data_ptr<T>(), mode_flag);
}
} else {
for (int i = 0; i < num_boxes1; i++) {
for (int j = 0; j < num_boxes2; j++) {
ious[i * num_boxes2 + j] = single_box_iou_quadri<T>(
boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>(), mode_flag);
}
}
}
}
void box_iou_quadri_cpu(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
box_iou_quadri_cpu_kernel<float>(boxes1, boxes2, ious, mode_flag, aligned);
}
void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
REGISTER_DEVICE_IMPL(box_iou_quadri_impl, CPU, box_iou_quadri_cpu);
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "box_iou_rotated_utils.hpp"
#include "pytorch_cpp_helper.hpp"
template <typename scalar_t>
Tensor nms_quadri_cpu_kernel(const Tensor dets, const Tensor scores,
const float iou_threshold) {
// nms_quadri_cpu_kernel is modified from torchvision's nms_cpu_kernel,
// however, the code in this function is much shorter because
// we delegate the IoU computation for quadri boxes to
// the single_box_iou_quadri function in box_iou_rotated_utils.h
AT_ASSERTM(!dets.is_cuda(), "dets must be a CPU tensor");
AT_ASSERTM(!scores.is_cuda(), "scores must be a CPU tensor");
AT_ASSERTM(dets.scalar_type() == scores.scalar_type(),
"dets should have the same type as scores");
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto ndets = dets.size(0);
Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
int64_t num_to_keep = 0;
for (int64_t _i = 0; _i < ndets; _i++) {
auto i = order[_i];
if (suppressed[i] == 1) {
continue;
}
keep[num_to_keep++] = i;
for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1) {
continue;
}
auto ovr = single_box_iou_quadri<scalar_t>(
dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>(), 0);
if (ovr >= iou_threshold) {
suppressed[j] = 1;
}
}
}
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}
Tensor nms_quadri_cpu(const Tensor dets, const Tensor scores,
const float iou_threshold) {
auto result = at::empty({0}, dets.options());
AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_quadri", [&] {
result = nms_quadri_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
});
return result;
}
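
The CPU path above is the standard greedy NMS loop, specialised to quadrilaterals by delegating the overlap test to single_box_iou_quadri. A rough Python rendering of the same control flow (illustrative only; iou_fn is any pairwise IoU callable, for example the quadri_iou_ref helper sketched earlier, and is not part of mmcv):

import torch


def nms_quadri_py(dets, scores, iou_threshold, iou_fn):
    # dets: (N, 8) quadrilaterals, scores: (N,), iou_fn: pairwise IoU callable.
    order = scores.sort(0, descending=True)[1]
    suppressed = torch.zeros(len(dets), dtype=torch.bool)
    keep = []
    for _i in range(len(order)):
        i = order[_i].item()
        if suppressed[i]:
            continue
        keep.append(i)  # highest-scoring unsuppressed box survives
        for _j in range(_i + 1, len(order)):
            j = order[_j].item()
            if not suppressed[j] and \
                    iou_fn(dets[i].tolist(), dets[j].tolist()) >= iou_threshold:
                suppressed[j] = True
    return torch.tensor(keep, dtype=torch.long)
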
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "box_iou_quadri_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
void box_iou_quadri_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned) {
using scalar_t = float;
AT_ASSERTM(boxes1.is_cuda(), "boxes1 must be a CUDA tensor");
AT_ASSERTM(boxes2.is_cuda(), "boxes2 must be a CUDA tensor");
int output_size = ious.numel();
int num_boxes1 = boxes1.size(0);
int num_boxes2 = boxes2.size(0);
at::cuda::CUDAGuard device_guard(boxes1.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
box_iou_quadri_cuda_kernel<scalar_t>
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>(),
mode_flag, aligned);
AT_CUDA_CHECK(cudaGetLastError());
}
@@ -125,6 +125,13 @@ void box_iou_rotated_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
                          const int mode_flag, const bool aligned);
REGISTER_DEVICE_IMPL(box_iou_rotated_impl, CUDA, box_iou_rotated_cuda);
void box_iou_quadri_cuda(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
void box_iou_quadri_impl(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
REGISTER_DEVICE_IMPL(box_iou_quadri_impl, CUDA, box_iou_quadri_cuda);
void CARAFEForwardCUDAKernelLauncher(const Tensor features, const Tensor masks,
                                     Tensor rfeatures, Tensor routput,
                                     Tensor rmasks, Tensor output,
......
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "nms_quadri_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
Tensor nms_quadri_cuda(const Tensor dets, const Tensor scores,
const Tensor order_t, const Tensor dets_sorted,
float iou_threshold, const int multi_label) {
// using scalar_t = float;
AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
int dets_num = dets.size(0);
const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);
Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
dets_sorted.scalar_type(), "nms_quadri_kernel_cuda", [&] {
nms_quadri_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num, iou_threshold, dets_sorted.data_ptr<scalar_t>(),
(unsigned long long*)mask.data_ptr<int64_t>(), multi_label);
});
Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host =
(unsigned long long*)mask_cpu.data_ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data_ptr<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}
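
The CUDA path splits the work in two stages: the kernel fills a dets_num x col_blocks matrix of 64-bit words, where bit j of mask[i][b] is set when box b * 64 + j overlaps box i above iou_threshold (boxes are pre-sorted by score), and the host then performs the greedy reduction on that bitmask. A small NumPy sketch of the host-side reduction, assuming mask has shape (dets_num, col_blocks) with dtype uint64:

import numpy as np


def reduce_suppression_mask(mask, dets_num, threads_per_block=64):
    # mask: (dets_num, col_blocks) uint64 bitmask produced by the kernel,
    # with rows ordered by descending score.
    col_blocks = (dets_num + threads_per_block - 1) // threads_per_block
    remv = np.zeros(col_blocks, dtype=np.uint64)  # accumulated suppressed bits
    keep = []
    for i in range(dets_num):
        nblock, inblock = divmod(i, threads_per_block)
        if not (int(remv[nblock]) >> inblock) & 1:
            keep.append(i)      # box i survives
            remv |= mask[i]     # suppress every box it overlaps
    return keep
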
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#include "pytorch_cpp_helper.hpp"
Tensor nms_quadri_cpu(const Tensor dets, const Tensor scores,
const float iou_threshold);
#ifdef MMCV_WITH_CUDA
Tensor nms_quadri_cuda(const Tensor dets, const Tensor scores,
const Tensor order, const Tensor dets_sorted,
const float iou_threshold, const int multi_label);
#endif
// Interface for Python
// inline is needed to prevent multiple function definitions when this header is
// included by different cpps
Tensor nms_quadri(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label) {
assert(dets.device().is_cuda() == scores.device().is_cuda());
if (dets.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
return nms_quadri_cuda(dets, scores, order, dets_sorted, iou_threshold,
multi_label);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return nms_quadri_cpu(dets, scores, iou_threshold);
}
@@ -423,6 +423,13 @@ void chamfer_distance_backward(const Tensor xyz1, const Tensor xyz2,
                               Tensor graddist2, Tensor gradxyz1,
                               Tensor gradxyz2);
void box_iou_quadri(const Tensor boxes1, const Tensor boxes2, Tensor ious,
const int mode_flag, const bool aligned);
Tensor nms_quadri(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
        py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
@@ -853,4 +860,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("input"), py::arg("rois"), py::arg("grad_rois"),
        py::arg("pooled_height"), py::arg("pooled_width"),
        py::arg("spatial_scale"));
m.def("box_iou_quadri", &box_iou_quadri, "IoU for quadrilateral boxes",
py::arg("boxes1"), py::arg("boxes2"), py::arg("ious"),
py::arg("mode_flag"), py::arg("aligned"));
m.def("nms_quadri", &nms_quadri, "NMS for quadrilateral boxes",
py::arg("dets"), py::arg("scores"), py::arg("order"),
py::arg("dets_sorted"), py::arg("iou_threshold"),
py::arg("multi_label"));
}
@@ -8,7 +8,7 @@ from torch import Tensor
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
-    '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated'])
+    '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated', 'nms_quadri'])
# This function is modified from: https://github.com/pytorch/vision/
@@ -459,3 +459,45 @@ def nms_rotated(dets: Tensor,
    dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                     dim=1)
    return dets, keep_inds
def nms_quadri(dets: Tensor,
scores: Tensor,
iou_threshold: float,
labels: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Performs non-maximum suppression (NMS) on the quadrilateral boxes
according to their intersection-over-union (IoU).
Quadri NMS iteratively removes lower scoring quadrilateral boxes
which have an IoU greater than iou_threshold with another (higher
scoring) quadrilateral box.
Args:
dets (torch.Tensor): Quadri boxes in shape (N, 8).
They are expected to be in
(x1, y1, ..., x4, y4) format.
scores (torch.Tensor): scores in shape (N, ).
iou_threshold (float): IoU thresh for NMS.
labels (torch.Tensor, optional): boxes' label in shape (N,).
Returns:
tuple: kept dets(boxes and scores) and indice, which is always the
same data type as the input.
"""
if dets.shape[0] == 0:
return dets, None
multi_label = labels is not None
if multi_label:
        dets_with_labels = \
            torch.cat((dets, labels.unsqueeze(1)), 1)  # type: ignore
    else:
        dets_with_labels = dets
    _, order = scores.sort(0, descending=True)
    dets_sorted = dets_with_labels.index_select(0, order)
    keep_inds = ext_module.nms_quadri(dets_with_labels, scores, order,
dets_sorted, iou_threshold, multi_label)
dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
dim=1)
return dets, keep_inds
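
A minimal usage sketch for the nms_quadri wrapper above, assuming an mmcv build with the nms_quadri extension compiled; the coordinates are made-up values, not the test fixtures:

import torch
from mmcv.ops import nms_quadri

# Each row is (x1, y1, ..., x4, y4); scores are passed separately.
boxes = torch.tensor([[0., 0., 2., 0., 2., 2., 0., 2.],
                      [1., 0., 3., 0., 3., 2., 1., 2.],
                      [5., 5., 7., 5., 7., 7., 5., 7.]])
scores = torch.tensor([0.9, 0.6, 0.8])

# The second box overlaps the first with IoU = 1/3 > 0.3, so it is suppressed.
dets, keep = nms_quadri(boxes, scores, 0.3)
# dets: (num_kept, 9) kept boxes with their score appended; keep: their indices.

# Labels are appended as a 9th column and forwarded to the multi-label kernel path.
labels = torch.tensor([0., 1., 1.])
dets_ml, keep_ml = nms_quadri(boxes, scores, 0.3, labels)
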
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE
class TestBoxIoUQuadri:
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_box_iou_quadri_cuda(self, device):
from mmcv.ops import box_iou_quadri
np_boxes1 = np.asarray([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0]],
dtype=np.float32)
np_boxes2 = np.asarray([[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0],
[2.0, 1.0, 2.0, 4.0, 4.0, 4.0, 4.0, 1.0],
[7.0, 6.0, 7.0, 8.0, 9.0, 8.0, 9.0, 6.0]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.0714, 1.0000, 0.0000], [0.0000, 0.5000, 0.0000],
[0.0000, 0.0000, 0.5000]],
dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.0714, 0.5000, 0.5000],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)
ious = box_iou_quadri(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_quadri(boxes1, boxes2, aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_box_iou_quadri_iof_cuda(self, device):
from mmcv.ops import box_iou_quadri
np_boxes1 = np.asarray([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0]],
dtype=np.float32)
np_boxes2 = np.asarray([[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0],
[2.0, 1.0, 2.0, 4.0, 4.0, 4.0, 4.0, 1.0],
[7.0, 6.0, 7.0, 8.0, 9.0, 8.0, 9.0, 6.0]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.1111, 1.0000, 0.0000], [0.0000, 1.0000, 0.0000],
[0.0000, 0.0000, 1.0000]],
dtype=np.float32)
np_expect_ious_aligned = np.asarray([0.1111, 1.0000, 1.0000],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).to(device)
boxes2 = torch.from_numpy(np_boxes2).to(device)
ious = box_iou_quadri(boxes1, boxes2, mode='iof')
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
ious = box_iou_quadri(boxes1, boxes2, mode='iof', aligned=True)
assert np.allclose(
ious.cpu().numpy(), np_expect_ious_aligned, atol=1e-4)
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch
from mmcv.utils import IS_CUDA_AVAILABLE
class TestNMSQuadri:
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_ml_nms_quadri(self, device):
from mmcv.ops import nms_quadri
np_boxes = np.array([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 0.7],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0, 0.8],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0, 0.5],
[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.9]],
dtype=np.float32)
np_labels = np.array([1, 0, 1, 0], dtype=np.float32)
np_expect_dets = np.array([[0., 0., 0., 2., 2., 2., 2., 0.],
[2., 2., 3., 4., 4., 2., 3., 1.],
[7., 7., 8., 8., 9., 7., 8., 6.]],
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 2], dtype=np.int64)
boxes = torch.from_numpy(np_boxes).to(device)
labels = torch.from_numpy(np_labels).to(device)
dets, keep_inds = nms_quadri(boxes[:, :8], boxes[:, -1], 0.3, labels)
assert np.allclose(dets.cpu().numpy()[:, :8], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_nms_quadri(self, device):
from mmcv.ops import nms_quadri
np_boxes = np.array([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 0.7],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0, 0.8],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0, 0.5],
[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.9]],
dtype=np.float32)
np_expect_dets = np.array([[0., 0., 0., 2., 2., 2., 2., 0.],
[2., 2., 3., 4., 4., 2., 3., 1.],
[7., 7., 8., 8., 9., 7., 8., 6.]],
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 2], dtype=np.int64)
boxes = torch.from_numpy(np_boxes).to(device)
dets, keep_inds = nms_quadri(boxes[:, :8], boxes[:, -1], 0.3)
assert np.allclose(dets.cpu().numpy()[:, :8], np_expect_dets)
assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)
@pytest.mark.parametrize('device', [
'cpu',
pytest.param(
'cuda',
marks=pytest.mark.skipif(
not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
])
def test_batched_nms(self, device):
# test batched_nms with nms_quadri
from mmcv.ops import batched_nms
np_boxes = np.array([[1.0, 1.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 0.7],
[2.0, 2.0, 3.0, 4.0, 4.0, 2.0, 3.0, 1.0, 0.8],
[7.0, 7.0, 8.0, 8.0, 9.0, 7.0, 8.0, 6.0, 0.5],
[0.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0, 0.0, 0.9]],
dtype=np.float32)
np_labels = np.array([1, 0, 1, 0], dtype=np.float32)
np_expect_agnostic_dets = np.array([[0., 0., 0., 2., 2., 2., 2., 0.],
[2., 2., 3., 4., 4., 2., 3., 1.],
[7., 7., 8., 8., 9., 7., 8., 6.]],
dtype=np.float32)
np_expect_agnostic_keep_inds = np.array([3, 1, 2], dtype=np.int64)
np_expect_dets = np.array([[0., 0., 0., 2., 2., 2., 2., 0.],
[2., 2., 3., 4., 4., 2., 3., 1.],
[1., 1., 3., 4., 4., 4., 4., 1.],
[7., 7., 8., 8., 9., 7., 8., 6.]],
dtype=np.float32)
np_expect_keep_inds = np.array([3, 1, 0, 2], dtype=np.int64)
nms_cfg = dict(type='nms_quadri', iou_threshold=0.3)
# test class_agnostic is True
boxes, keep = batched_nms(
torch.from_numpy(np_boxes[:, :8]).to(device),
torch.from_numpy(np_boxes[:, -1]).to(device),
torch.from_numpy(np_labels).to(device),
nms_cfg,
class_agnostic=True)
assert np.allclose(boxes.cpu().numpy()[:, :8], np_expect_agnostic_dets)
assert np.allclose(keep.cpu().numpy(), np_expect_agnostic_keep_inds)
# test class_agnostic is False
boxes, keep = batched_nms(
torch.from_numpy(np_boxes[:, :8]).to(device),
torch.from_numpy(np_boxes[:, -1]).to(device),
torch.from_numpy(np_labels).to(device),
nms_cfg,
class_agnostic=False)
assert np.allclose(boxes.cpu().numpy()[:, :8], np_expect_dets)
assert np.allclose(keep.cpu().numpy(), np_expect_keep_inds)