Unverified commit a9787648, authored by BigBigDream and committed by GitHub

Add box_iou_rotated, ml_nms_rotated and nms_rotated (#625)



* add box_iou_rotated, ml_nms_rotated and nms_rotated

* fix lint

* fix lint

* fix .py lint

* fix cpp lint

* add newline at the end

* add new line

* fix unittest

* config google style

* fix lint

* lint

* lint

* yapf

* update

* fix lint

* fix lint

* fix lint

* fix

* fix format

* fix format

* add modified from

* add docstring and update others

* update docstring

* update docstring

* update

* fix bug

* fix bug

* fix bug
Co-authored-by: Cao Yuhang <yhcao6@gmail.com>
parent f61bb642
from .bbox import bbox_overlaps
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .corner_pool import CornerPool
@@ -16,7 +17,7 @@ from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
-from .nms import batched_nms, nms, nms_match, soft_nms
+from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .psa_mask import PSAMask
@@ -38,5 +39,5 @@ __all__ = [
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
-'SAConv2d', 'TINShift', 'tin_shift'
+'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated'
]
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])
def box_iou_rotated(bboxes1, bboxes2):
    """Return intersection-over-union (Jaccard index) of rotated boxes.

    Both sets of boxes are expected to be in
    (x_center, y_center, width, height, angle) format, with the angle given
    in radians.

    Args:
        bboxes1 (Tensor): rotated bboxes 1. It has shape (N, 5),
            indicating (x, y, w, h, theta) for each row.
        bboxes2 (Tensor): rotated bboxes 2. It has shape (M, 5),
            indicating (x, y, w, h, theta) for each row.

    Returns:
        iou (Tensor[N, M]): the NxM matrix containing the pairwise
            IoU values for every element in bboxes1 and bboxes2.
    """
    if torch.__version__ == 'parrots':
        out = torch.zeros((bboxes1.shape[0], bboxes2.shape[0]),
                          dtype=torch.float32).to(bboxes1.device)
        ext_module.box_iou_rotated(bboxes1, bboxes2, out)
    else:
        out = ext_module.box_iou_rotated(bboxes1, bboxes2)
    return out
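For reference, a minimal usage sketch (assuming mmcv was built with this op and a CUDA device is available). With angle = 0 the rotated IoU reduces to the ordinary axis-aligned IoU, which makes the expected value easy to check by hand:

import torch
from mmcv.ops import box_iou_rotated

# Two 2 x 2 boxes with angle = 0, centered at (0, 0) and (1, 0):
# intersection = 1 * 2 = 2, union = 4 + 4 - 2 = 6, so IoU = 1/3.
bboxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0, 0.0]]).cuda()
bboxes2 = torch.tensor([[1.0, 0.0, 2.0, 2.0, 0.0]]).cuda()
ious = box_iou_rotated(bboxes1, bboxes2)  # tensor([[0.3333]]), shape (1, 1)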
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
#ifndef BOX_IOU_ROTATED_CUDA_CUH
#define BOX_IOU_ROTATED_CUDA_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include "box_iou_rotated_utils.hpp"
// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }
template <typename T>
__global__ void box_iou_rotated_cuda_kernel(const int n_boxes1,
const int n_boxes2,
const T* dev_boxes1,
const T* dev_boxes2, T* dev_ious) {
const int row_start = blockIdx.x * blockDim.x;
const int col_start = blockIdx.y * blockDim.y;
const int row_size = min(n_boxes1 - row_start, blockDim.x);
const int col_size = min(n_boxes2 - col_start, blockDim.y);
__shared__ float block_boxes1[BLOCK_DIM_X * 5];
__shared__ float block_boxes2[BLOCK_DIM_Y * 5];
// It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
if (threadIdx.x < row_size && threadIdx.y == 0) {
block_boxes1[threadIdx.x * 5 + 0] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
block_boxes1[threadIdx.x * 5 + 1] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
block_boxes1[threadIdx.x * 5 + 2] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
block_boxes1[threadIdx.x * 5 + 3] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
block_boxes1[threadIdx.x * 5 + 4] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
}
if (threadIdx.x < col_size && threadIdx.y == 0) {
block_boxes2[threadIdx.x * 5 + 0] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
block_boxes2[threadIdx.x * 5 + 1] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
block_boxes2[threadIdx.x * 5 + 2] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
block_boxes2[threadIdx.x * 5 + 3] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
block_boxes2[threadIdx.x * 5 + 4] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size && threadIdx.y < col_size) {
int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y;
dev_ious[offset] = single_box_iou_rotated<T>(
block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
}
}
#endif
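The kernel above tiles the N x M IoU matrix into 32 x 16 blocks and stages each tile's boxes in shared memory before computing IoUs. A small Python sketch of the same grid arithmetic (the names divide_up, n_boxes1, n_boxes2 are illustrative, not from this commit):

BLOCK_DIM_X, BLOCK_DIM_Y = 32, 16  # threads per block, as in the kernel

def divide_up(x, y):
    # Mirrors divideUP: ceiling division for the grid size.
    return (x + y - 1) // y

n_boxes1, n_boxes2 = 1000, 500
grid = (divide_up(n_boxes1, BLOCK_DIM_X), divide_up(n_boxes2, BLOCK_DIM_Y))
# Each block handles one 32 x 16 tile: 32 boxes from set 1 and 16 from set 2
# are copied to shared memory once, then 512 threads each compute one IoU.
print(grid)  # (32, 32) -> 1024 blocks cover the 1000 x 500 IoU matrix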
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
#pragma once
#include <cassert>
#include <cmath>
#ifdef __CUDACC__
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
#else
#include <algorithm>
#define HOST_DEVICE
#define HOST_DEVICE_INLINE HOST_DEVICE inline
#endif
namespace {
template <typename T>
struct RotatedBox {
T x_ctr, y_ctr, w, h, a;
};
template <typename T>
struct Point {
T x, y;
HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
HOST_DEVICE_INLINE Point operator+(const Point& p) const {
return Point(x + p.x, y + p.y);
}
HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
x += p.x;
y += p.y;
return *this;
}
HOST_DEVICE_INLINE Point operator-(const Point& p) const {
return Point(x - p.x, y - p.y);
}
HOST_DEVICE_INLINE Point operator*(const T coeff) const {
return Point(x * coeff, y * coeff);
}
};
template <typename T>
HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.x + A.y * B.y;
}
template <typename T>
HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.y - B.x * A.y;
}
template <typename T>
HOST_DEVICE_INLINE void get_rotated_vertices(const RotatedBox<T>& box,
Point<T> (&pts)[4]) {
// M_PI / 180. == 0.01745329251
// double theta = box.a * 0.01745329251;
// MODIFIED
double theta = box.a;
T cosTheta2 = (T)cos(theta) * 0.5f;
T sinTheta2 = (T)sin(theta) * 0.5f;
// y: top --> down; x: left --> right
pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
pts[2].x = 2 * box.x_ctr - pts[0].x;
pts[2].y = 2 * box.y_ctr - pts[0].y;
pts[3].x = 2 * box.x_ctr - pts[1].x;
pts[3].y = 2 * box.y_ctr - pts[1].y;
}
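A NumPy transcription of get_rotated_vertices can help when checking the corner math off-device (a sketch; rotated_vertices is a hypothetical helper name, and the angle is in radians, matching the MODIFIED line above):

import numpy as np

def rotated_vertices(x_ctr, y_ctr, w, h, theta):
    # Half-extent direction components of the rotated box.
    c, s = 0.5 * np.cos(theta), 0.5 * np.sin(theta)
    p0 = (x_ctr - s * h - c * w, y_ctr + c * h - s * w)
    p1 = (x_ctr + s * h - c * w, y_ctr - c * h - s * w)
    # The remaining corners are reflections of p0/p1 through the center.
    p2 = (2 * x_ctr - p0[0], 2 * y_ctr - p0[1])
    p3 = (2 * x_ctr - p1[0], 2 * y_ctr - p1[1])
    return [p0, p1, p2, p3]

print(rotated_vertices(0.0, 0.0, 2.0, 2.0, 0.0))
# [(-1.0, 1.0), (-1.0, -1.0), (1.0, -1.0), (1.0, 1.0)] for an axis-aligned box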
template <typename T>
HOST_DEVICE_INLINE int get_intersection_points(const Point<T> (&pts1)[4],
const Point<T> (&pts2)[4],
Point<T> (&intersections)[24]) {
// Line vector
// A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
Point<T> vec1[4], vec2[4];
for (int i = 0; i < 4; i++) {
vec1[i] = pts1[(i + 1) % 4] - pts1[i];
vec2[i] = pts2[(i + 1) % 4] - pts2[i];
}
// Line test - test all line combos for intersection
int num = 0; // number of intersections
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
// Solve for 2x2 Ax=b
T det = cross_2d<T>(vec2[j], vec1[i]);
// This takes care of parallel lines
if (fabs(det) <= 1e-14) {
continue;
}
auto vec12 = pts2[j] - pts1[i];
T t1 = cross_2d<T>(vec2[j], vec12) / det;
T t2 = cross_2d<T>(vec1[i], vec12) / det;
if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
intersections[num++] = pts1[i] + vec1[i] * t1;
}
}
}
// Check for vertices of rect1 inside rect2
{
const auto& AB = vec2[0];
const auto& DA = vec2[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
// assume ABCD is the rectangle, and P is the point to be judged
// P is inside ABCD iff. P's projection on AB lies within AB
// and P's projection on AD lies within AD
auto AP = pts1[i] - pts2[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts1[i];
}
}
}
// Reverse the check - check for vertices of rect2 inside rect1
{
const auto& AB = vec1[0];
const auto& DA = vec1[3];
auto ABdotAB = dot_2d<T>(AB, AB);
auto ADdotAD = dot_2d<T>(DA, DA);
for (int i = 0; i < 4; i++) {
auto AP = pts2[i] - pts1[0];
auto APdotAB = dot_2d<T>(AP, AB);
auto APdotAD = -dot_2d<T>(AP, DA);
if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
(APdotAD <= ADdotAD)) {
intersections[num++] = pts2[i];
}
}
}
return num;
}
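The segment-segment test above solves a 2x2 linear system using only cross products. The same math in plain Python (a sketch; segment_intersection is a hypothetical name and tuples stand in for the Point struct):

def segment_intersection(p1, v1, p2, v2, eps=1e-14):
    # Solve p1 + t1 * v1 == p2 + t2 * v2 via 2D cross products.
    cross = lambda a, b: a[0] * b[1] - b[0] * a[1]
    det = cross(v2, v1)
    if abs(det) <= eps:  # parallel (or degenerate) edges are skipped, as above
        return None
    v12 = (p2[0] - p1[0], p2[1] - p1[1])
    t1 = cross(v2, v12) / det
    t2 = cross(v1, v12) / det
    if 0.0 <= t1 <= 1.0 and 0.0 <= t2 <= 1.0:
        return (p1[0] + t1 * v1[0], p1[1] + t1 * v1[1])
    return None

# Two diagonals of the unit square cross at its center:
print(segment_intersection((0, 0), (1, 1), (0, 1), (1, -1)))  # (0.5, 0.5)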
template <typename T>
HOST_DEVICE_INLINE int convex_hull_graham(const Point<T> (&p)[24],
const int& num_in, Point<T> (&q)[24],
bool shift_to_zero = false) {
assert(num_in >= 2);
// Step 1:
// Find point with minimum y
// if more than 1 points have the same minimum y,
// pick the one with the minimum x.
int t = 0;
for (int i = 1; i < num_in; i++) {
if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
t = i;
}
}
auto& start = p[t]; // starting point
// Step 2:
// Subtract the starting point from every point (for sorting in the next step)
for (int i = 0; i < num_in; i++) {
q[i] = p[i] - start;
}
// Swap the starting point to position 0
auto tmp = q[0];
q[0] = q[t];
q[t] = tmp;
// Step 3:
// Sort points 1 ~ num_in - 1 according to their relative cross-product values
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T dist[24];
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
#ifdef __CUDACC__
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
for (int i = 1; i < num_in - 1; i++) {
for (int j = i + 1; j < num_in; j++) {
T crossProduct = cross_2d<T>(q[i], q[j]);
if ((crossProduct < -1e-6) ||
(fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
auto q_tmp = q[i];
q[i] = q[j];
q[j] = q_tmp;
auto dist_tmp = dist[i];
dist[i] = dist[j];
dist[j] = dist_tmp;
}
}
}
#else
// CPU version
std::sort(q + 1, q + num_in,
[](const Point<T>& A, const Point<T>& B) -> bool {
T temp = cross_2d<T>(A, B);
if (fabs(temp) < 1e-6) {
return dot_2d<T>(A, A) < dot_2d<T>(B, B);
} else {
return temp > 0;
}
});
#endif
// Step 4:
// Make sure there are at least 2 points (that don't overlap with each other)
// in the stack
int k; // index of the non-overlapped second point
for (k = 1; k < num_in; k++) {
if (dist[k] > 1e-8) {
break;
}
}
if (k == num_in) {
// We reach the end, which means the convex hull is just one point
q[0] = p[t];
return 1;
}
q[1] = q[k];
int m = 2; // 2 points in the stack
// Step 5:
// Finally we can start the scanning process.
// When a non-convex relationship between the 3 points is found
// (either concave shape or duplicated points),
// we pop the previous point from the stack
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for (int i = k + 1; i < num_in; i++) {
while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
m--;
}
q[m++] = q[i];
}
// Step 6 (Optional):
// In general sense we need the original coordinates, so we
// need to shift the points back (reverting Step 2)
// But if we're only interested in getting the area/perimeter of the shape
// We can simply return.
if (!shift_to_zero) {
for (int i = 0; i < m; i++) {
q[i] += start;
}
}
return m;
}
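For intuition, here is a textbook Graham scan in Python. This is the standard counter-clockwise formulation with a dynamic stack, not a line-for-line port of the fixed-array variant above:

import math

def graham_scan(points):
    pts = sorted(set(points), key=lambda p: (p[1], p[0]))
    start = pts[0]  # lowest point, ties broken by x, as in Step 1 above
    # Sort the rest by polar angle around `start`; break ties by distance.
    rest = sorted(
        pts[1:],
        key=lambda p: (math.atan2(p[1] - start[1], p[0] - start[0]),
                       (p[0] - start[0])**2 + (p[1] - start[1])**2))
    cross = lambda o, a, b: ((a[0] - o[0]) * (b[1] - o[1]) -
                             (a[1] - o[1]) * (b[0] - o[0]))
    hull = [start]
    for p in rest:
        # Pop while the last turn is clockwise or collinear (non-convex).
        while len(hull) >= 2 and cross(hull[-2], hull[-1], p) <= 0:
            hull.pop()
        hull.append(p)
    return hull

print(graham_scan([(0, 0), (1, 0), (1, 1), (0, 1), (0.5, 0.5)]))
# [(0, 0), (1, 0), (1, 1), (0, 1)]: the interior point is dropped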
template <typename T>
HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
if (m <= 2) {
return 0;
}
T area = 0;
for (int i = 1; i < m - 1; i++) {
area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
}
return area / 2.0;
}
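polygon_area above is the shoelace formula fanned out from q[0]; a direct Python equivalent (fan_area is our name for the sketch):

def fan_area(pts):
    # Sum the areas of triangles (pts[0], pts[i], pts[i+1]) via cross
    # products; valid for the ordered convex polygons the hull step yields.
    area = 0.0
    for i in range(1, len(pts) - 1):
        ax, ay = pts[i][0] - pts[0][0], pts[i][1] - pts[0][1]
        bx, by = pts[i + 1][0] - pts[0][0], pts[i + 1][1] - pts[0][1]
        area += abs(ax * by - bx * ay)
    return area / 2.0

print(fan_area([(0, 0), (2, 0), (2, 1), (0, 1)]))  # 2.0 for a 2 x 1 rectangle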
template <typename T>
HOST_DEVICE_INLINE T rotated_boxes_intersection(const RotatedBox<T>& box1,
const RotatedBox<T>& box2) {
// There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
// from get_intersection_points
Point<T> intersectPts[24], orderedPts[24];
Point<T> pts1[4];
Point<T> pts2[4];
get_rotated_vertices<T>(box1, pts1);
get_rotated_vertices<T>(box2, pts2);
int num = get_intersection_points<T>(pts1, pts2, intersectPts);
if (num <= 2) {
return 0.0;
}
// Convex Hull to order the intersection points in clockwise order and find
// the contour area.
int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
return polygon_area<T>(orderedPts, num_convex);
}
} // namespace
template <typename T>
HOST_DEVICE_INLINE T single_box_iou_rotated(T const* const box1_raw,
T const* const box2_raw) {
// shift center to the middle point to achieve higher precision in result
RotatedBox<T> box1, box2;
auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
box1.x_ctr = box1_raw[0] - center_shift_x;
box1.y_ctr = box1_raw[1] - center_shift_y;
box1.w = box1_raw[2];
box1.h = box1_raw[3];
box1.a = box1_raw[4];
box2.x_ctr = box2_raw[0] - center_shift_x;
box2.y_ctr = box2_raw[1] - center_shift_y;
box2.w = box2_raw[2];
box2.h = box2_raw[3];
box2.a = box2_raw[4];
const T area1 = box1.w * box1.h;
const T area2 = box2.w * box2.h;
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}
const T intersection = rotated_boxes_intersection<T>(box1, box2);
const T iou = intersection / (area1 + area2 - intersection);
return iou;
}
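The center shift at the top of single_box_iou_rotated is a float32 precision guard: coordinates far from the origin waste mantissa bits, while the box geometry stays small. A quick demonstration of the effect (illustrative values only):

import numpy as np

# Near x ~ 1e6, float32 spacing is 0.0625, so sub-pixel geometry degrades;
# after subtracting the shared midpoint the values are small and exact.
x1 = np.float32(1e6 + 0.26)
x2 = np.float32(1e6 + 0.74)
shift = np.float32((x1 + x2) / 2.0)
print(x1 - shift, x2 - shift)  # approximately -0.25 and 0.25, near the origin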
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
#ifndef NMS_ROTATED_CUDA_CUH
#define NMS_ROTATED_CUDA_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include "box_iou_rotated_utils.hpp"
__host__ __device__ inline int divideUP(const int x, const int y) {
return (((x) + (y)-1) / (y));
}
namespace {
int const threadsPerBlock = sizeof(unsigned long long) * 8;
}
template <typename T>
__global__ void nms_rotated_cuda_kernel(const int n_boxes,
const float iou_threshold,
const T* dev_boxes,
unsigned long long* dev_mask,
const int multi_label) {
// nms_rotated_cuda_kernel is modified from torchvision's nms_cuda_kernel
if (multi_label == 1) {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
// Compared to nms_cuda_kernel, where each box is represented with 4 values
// (x1, y1, x2, y2), each rotated box is represented with 6 values
// (x_center, y_center, width, height, angle_radian, label) here.
__shared__ T block_boxes[threadsPerBlock * 6];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 6 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 0];
block_boxes[threadIdx.x * 6 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 1];
block_boxes[threadIdx.x * 6 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 2];
block_boxes[threadIdx.x * 6 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 3];
block_boxes[threadIdx.x * 6 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 4];
block_boxes[threadIdx.x * 6 + 5] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 5];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 6;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
// Instead of devIoU used by original horizontal nms, here
// we use the single_box_iou_rotated function from
// box_iou_rotated_utils.h
if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 6) >
iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = divideUP(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
} else {
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
const int col_size =
min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
// Compared to nms_cuda_kernel, where each box is represented with 4 values
// (x1, y1, x2, y2), each rotated box is represented with 5 values
// (x_center, y_center, width, height, angle_radian) here.
__shared__ T block_boxes[threadsPerBlock * 5];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
const T* cur_box = dev_boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
// Instead of devIoU used by original horizontal nms, here
// we use the single_box_iou_rotated function from
// box_iou_rotated_utils.h
if (single_box_iou_rotated<T>(cur_box, block_boxes + i * 5) >
iou_threshold) {
t |= 1ULL << i;
}
}
const int col_blocks = divideUP(n_boxes, threadsPerBlock);
dev_mask[cur_box_idx * col_blocks + col_start] = t;
}
}
}
#endif
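Each kernel thread packs its suppression decisions into a 64-bit word (one bit per box in the column block); the host then walks the mask greedily in score order. A Python sketch of that host-side decode, mirroring the loops in the .cu files below (decode_suppression_mask is our name):

def decode_suppression_mask(mask, n_boxes, threads_per_block=64):
    # mask[i * col_blocks + j] holds 64 bits: box i suppresses boxes
    # j*64 .. j*64+63 (in score-sorted order) where the bit is set.
    col_blocks = (n_boxes + threads_per_block - 1) // threads_per_block
    remv = [0] * col_blocks
    keep = []
    for i in range(n_boxes):
        nblock, inblock = divmod(i, threads_per_block)
        if not (remv[nblock] >> inblock) & 1:
            keep.append(i)  # box i survives; fold its row into the removal set
            for j in range(nblock, col_blocks):
                remv[j] |= mask[i * col_blocks + j]
    return keep

# Toy case: 3 boxes, box 0 suppresses box 1 only.
print(decode_suppression_mask([0b010, 0b000, 0b000], 3))  # [0, 2]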
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
#include "parrots_cpp_helper.hpp"
DArrayLite box_iou_rotated_cuda(const DArrayLite boxes1,
const DArrayLite boxes2, cudaStream_t stream,
CudaContext& ctx);
void box_iou_rotated(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
const auto& boxes1 = ins[0];
const auto& boxes2 = ins[1];
cudaStream_t stream = getStreamNative<CudaDevice>(ctx.getStream());
outs[0] = box_iou_rotated_cuda(boxes1, boxes2, stream, ctx);
}
PARROTS_EXTENSION_REGISTER(box_iou_rotated)
.input(2)
.output(1)
.apply(box_iou_rotated)
.done();
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
#include "box_iou_rotated_cuda.cuh"
#include "parrots_cuda_helper.hpp"
DArrayLite box_iou_rotated_cuda(const DArrayLite boxes1,
const DArrayLite boxes2, cudaStream_t stream,
CudaContext& ctx) {
using scalar_t = float;
int num_boxes1 = boxes1.dim(0);
int num_boxes2 = boxes2.dim(0);
auto ious = ctx.createDArrayLite(
DArraySpec::array(Prim::Float32, DArrayShape(num_boxes1 * num_boxes2)));
if (num_boxes1 > 0 && num_boxes2 > 0) {
const int blocks_x = divideUP(num_boxes1, BLOCK_DIM_X);
const int blocks_y = divideUP(num_boxes2, BLOCK_DIM_Y);
dim3 blocks(blocks_x, blocks_y);
dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.ptr<scalar_t>(), boxes2.ptr<scalar_t>(),
(scalar_t*)ious.ptr<scalar_t>());
PARROTS_CUDA_CHECK(cudaGetLastError());
}
return ious.view({num_boxes1, num_boxes2});
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/nms_rotated/nms_rotated.h
#include "parrots_cpp_helper.hpp"
DArrayLite nms_rotated_cuda(const DArrayLite dets, const DArrayLite scores,
const DArrayLite dets_sorted, float iou_threshold,
const int multi_label, cudaStream_t stream,
CudaContext& ctx);
// Interface for Python
void nms_rotated(CudaContext& ctx, const SSElement& attr,
const OperatorBase::in_list_t& ins,
OperatorBase::out_list_t& outs) {
float iou_threshold;
int multi_label;
SSAttrs(attr)
.get<float>("iou_threshold", iou_threshold)
.get<int>("multi_label", multi_label)
.done();
const auto& dets = ins[0];
const auto& scores = ins[1];
const auto& dets_sorted = ins[2];
cudaStream_t stream = getStreamNative<CudaDevice>(ctx.getStream());
outs[0] = nms_rotated_cuda(dets, scores, dets_sorted, iou_threshold,
multi_label, stream, ctx);
}
PARROTS_EXTENSION_REGISTER(nms_rotated)
.attr("multi_label")
.attr("iou_threshold")
.input(3)
.output(1)
.apply(nms_rotated)
.done();
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
#include "nms_rotated_cuda.cuh"
#include "parrots_cuda_helper.hpp"
DArrayLite nms_rotated_cuda(const DArrayLite dets, const DArrayLite scores,
const DArrayLite dets_sorted, float iou_threshold,
const int multi_label, cudaStream_t stream,
CudaContext& ctx) {
int dets_num = dets.dim(0);
const int col_blocks = divideUP(dets_num, threadsPerBlock);
auto mask = ctx.createDArrayLite(
DArraySpec::array(Prim::Int64, DArrayShape(dets_num * col_blocks)));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
PARROTS_DISPATCH_FLOATING_TYPES_AND_HALF(dets_sorted.elemType().prim(), [&] {
nms_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num, iou_threshold, dets_sorted.ptr<scalar_t>(),
(unsigned long long*)mask.ptr<int64_t>(), multi_label);
});
DArrayLite mask_cpu = ctx.createDArrayLite(mask, getHostProxy());
unsigned long long* mask_host = (unsigned long long*)mask_cpu.ptr<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
auto keep = ctx.createDArrayLite(
DArraySpec::array(Prim::Int64, DArrayShape(dets_num)), getHostProxy());
int64_t* keep_out = keep.ptr<int64_t>();
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[i] = 1;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
auto keep_cuda = ctx.createDArrayLite(keep, ctx.getProxy());
PARROTS_CUDA_CHECK(cudaGetLastError());
return keep_cuda;
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h
#include "pytorch_cpp_helper.hpp"
Tensor box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2);
#ifdef MMCV_WITH_CUDA
Tensor box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2);
#endif
// Interface for Python
Tensor box_iou_rotated(const Tensor boxes1, const Tensor boxes2) {
assert(boxes1.device().is_cuda() == boxes2.device().is_cuda());
if (boxes1.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
return box_iou_rotated_cuda(boxes1, boxes2);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return box_iou_rotated_cpu(boxes1, boxes2);
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp
#include "box_iou_rotated_utils.hpp"
#include "pytorch_cpp_helper.hpp"
template <typename T>
void box_iou_rotated_cpu_kernel(const Tensor boxes1, const Tensor boxes2,
Tensor ious) {
auto widths1 = boxes1.select(1, 2).contiguous();
auto heights1 = boxes1.select(1, 3).contiguous();
auto widths2 = boxes2.select(1, 2).contiguous();
auto heights2 = boxes2.select(1, 3).contiguous();
Tensor areas1 = widths1 * heights1;
Tensor areas2 = widths2 * heights2;
auto num_boxes1 = boxes1.size(0);
auto num_boxes2 = boxes2.size(0);
for (int i = 0; i < num_boxes1; i++) {
for (int j = 0; j < num_boxes2; j++) {
ious[i * num_boxes2 + j] = single_box_iou_rotated<T>(
boxes1[i].data_ptr<T>(), boxes2[j].data_ptr<T>());
}
}
}
Tensor box_iou_rotated_cpu(const Tensor boxes1, const Tensor boxes2) {
auto num_boxes1 = boxes1.size(0);
auto num_boxes2 = boxes2.size(0);
Tensor ious =
at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
box_iou_rotated_cpu_kernel<float>(boxes1, boxes2, ious);
// reshape from 1d array to 2d array
auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
return ious.reshape(shape);
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
#include "box_iou_rotated_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
Tensor box_iou_rotated_cuda(const Tensor boxes1, const Tensor boxes2) {
using scalar_t = float;
AT_ASSERTM(boxes1.type().is_cuda(), "boxes1 must be a CUDA tensor");
AT_ASSERTM(boxes2.type().is_cuda(), "boxes2 must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(boxes1.device());
int num_boxes1 = boxes1.size(0);
int num_boxes2 = boxes2.size(0);
Tensor ious =
at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat));
if (num_boxes1 > 0 && num_boxes2 > 0) {
const int blocks_x = at::cuda::ATenCeilDiv(num_boxes1, BLOCK_DIM_X);
const int blocks_y = at::cuda::ATenCeilDiv(num_boxes2, BLOCK_DIM_Y);
dim3 blocks(blocks_x, blocks_y);
dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
box_iou_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
num_boxes1, num_boxes2, boxes1.data_ptr<scalar_t>(),
boxes2.data_ptr<scalar_t>(), (scalar_t*)ious.data_ptr<scalar_t>());
AT_CUDA_CHECK(cudaGetLastError());
}
// reshape from 1d array to 2d array
auto shape = std::vector<int64_t>{num_boxes1, num_boxes2};
return ious.reshape(shape);
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/nms_rotated/nms_rotated.h
#include "pytorch_cpp_helper.hpp"
Tensor nms_rotated_cpu(const Tensor dets, const Tensor scores,
const float iou_threshold);
#ifdef MMCV_WITH_CUDA
Tensor nms_rotated_cuda(const Tensor dets, const Tensor scores,
const Tensor order, const Tensor dets_sorted,
const float iou_threshold, const int multi_label);
#endif
// Interface for Python
Tensor nms_rotated(const Tensor dets, const Tensor scores, const Tensor order,
const Tensor dets_sorted, const float iou_threshold,
const int multi_label) {
assert(dets.device().is_cuda() == scores.device().is_cuda());
if (dets.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
return nms_rotated_cuda(dets, scores, order, dets_sorted, iou_threshold,
multi_label);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return nms_rotated_cpu(dets, scores, iou_threshold);
}
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp
#include "box_iou_rotated_utils.hpp"
#include "pytorch_cpp_helper.hpp"
template <typename scalar_t>
Tensor nms_rotated_cpu_kernel(const Tensor dets, const Tensor scores,
const float iou_threshold) {
// nms_rotated_cpu_kernel is modified from torchvision's nms_cpu_kernel,
// however, the code in this function is much shorter because
// we delegate the IoU computation for rotated boxes to
// the single_box_iou_rotated function in box_iou_rotated_utils.h
AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
AT_ASSERTM(!scores.type().is_cuda(), "scores must be a CPU tensor");
AT_ASSERTM(dets.type() == scores.type(),
"dets should have the same type as scores");
if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}
auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto ndets = dets.size(0);
Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte));
Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong));
auto suppressed = suppressed_t.data_ptr<uint8_t>();
auto keep = keep_t.data_ptr<int64_t>();
auto order = order_t.data_ptr<int64_t>();
int64_t num_to_keep = 0;
for (int64_t _i = 0; _i < ndets; _i++) {
auto i = order[_i];
if (suppressed[i] == 1) {
continue;
}
keep[num_to_keep++] = i;
for (int64_t _j = _i + 1; _j < ndets; _j++) {
auto j = order[_j];
if (suppressed[j] == 1) {
continue;
}
auto ovr = single_box_iou_rotated<scalar_t>(dets[i].data_ptr<scalar_t>(),
dets[j].data_ptr<scalar_t>());
if (ovr >= iou_threshold) {
suppressed[j] = 1;
}
}
}
return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep);
}
Tensor nms_rotated_cpu(const Tensor dets, const Tensor scores,
const float iou_threshold) {
auto result = at::empty({0}, dets.options());
AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms_rotated", [&] {
result = nms_rotated_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
});
return result;
}
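The CPU kernel is plain greedy NMS. A Python mirror with the rotated IoU abstracted behind an iou callable (assumed given; greedy_nms is our name), useful as a reference when debugging:

def greedy_nms(boxes, scores, iou_threshold, iou):
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    suppressed = [False] * len(boxes)
    keep = []
    for pos, i in enumerate(order):
        if suppressed[i]:
            continue
        keep.append(i)  # highest-scoring survivor suppresses the rest
        for j in order[pos + 1:]:
            if not suppressed[j] and iou(boxes[i], boxes[j]) >= iou_threshold:
                suppressed[j] = True
    return keep

# With a trivial IoU stub, only identical "boxes" suppress each other:
same = lambda a, b: 1.0 if a == b else 0.0
print(greedy_nms(['A', 'B', 'A'], [0.9, 0.8, 0.7], 0.5, same))  # [0, 1]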
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
#include "nms_rotated_cuda.cuh"
#include "pytorch_cuda_helper.hpp"
Tensor nms_rotated_cuda(const Tensor dets, const Tensor scores,
const Tensor order_t, const Tensor dets_sorted,
float iou_threshold, const int multi_label) {
// using scalar_t = float;
AT_ASSERTM(dets.type().is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.type().is_cuda(), "scores must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(dets.device());
int dets_num = dets.size(0);
const int col_blocks = at::cuda::ATenCeilDiv(dets_num, threadsPerBlock);
Tensor mask =
at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong));
dim3 blocks(col_blocks, col_blocks);
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
dets_sorted.type(), "nms_rotated_kernel_cuda", [&] {
nms_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num, iou_threshold, dets_sorted.data<scalar_t>(),
(unsigned long long*)mask.data<int64_t>(), multi_label);
});
Tensor mask_cpu = mask.to(at::kCPU);
unsigned long long* mask_host = (unsigned long long*)mask_cpu.data<int64_t>();
std::vector<unsigned long long> remv(col_blocks);
memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
Tensor keep =
at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU));
int64_t* keep_out = keep.data<int64_t>();
int num_to_keep = 0;
for (int i = 0; i < dets_num; i++) {
int nblock = i / threadsPerBlock;
int inblock = i % threadsPerBlock;
if (!(remv[nblock] & (1ULL << inblock))) {
keep_out[num_to_keep++] = i;
unsigned long long* p = mask_host + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv[j] |= p[j];
}
}
}
AT_CUDA_CHECK(cudaGetLastError());
return order_t.index(
{keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep)
.to(order_t.device(), keep.scalar_type())});
}
@@ -175,6 +175,12 @@ Tensor top_pool_forward(Tensor input);
Tensor top_pool_backward(Tensor input, Tensor grad_output);
Tensor box_iou_rotated(const Tensor boxes1, const Tensor boxes2);
Tensor nms_rotated(const Tensor dets, Tensor scores, Tensor order,
Tensor dets_sorted, const float iou_threshold,
const int multi_label);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
@@ -357,4 +363,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("top_pool_backward", &top_pool_backward, "Top Pool Backward",
py::arg("input"), py::arg("grad_output"),
py::call_guard<py::gil_scoped_release>());
m.def("box_iou_rotated", &box_iou_rotated, "IoU for rotated boxes",
py::arg("boxes1"), py::arg("boxes2"));
m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"),
py::arg("scores"), py::arg("order"), py::arg("dets_sorted"),
py::arg("iou_threshold"), py::arg("multi_label"));
}
@@ -5,6 +5,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <THC/THCAtomics.cuh>
#include "common_cuda_helper.hpp"
@@ -6,7 +6,8 @@ import torch
from mmcv.utils import deprecated_api_warning
from ..utils import ext_loader
-ext_module = ext_loader.load_ext('_ext', ['nms', 'softnms', 'nms_match'])
+ext_module = ext_loader.load_ext(
+    '_ext', ['nms', 'softnms', 'nms_match', 'nms_rotated'])
# This function is modified from: https://github.com/pytorch/vision/
@@ -304,3 +305,52 @@ def nms_match(dets, iou_threshold):
return [dets.new_tensor(m, dtype=torch.long) for m in matched]
else:
return [np.array(m, dtype=np.int) for m in matched]
def nms_rotated(dets, scores, iou_threshold, labels=None):
    """Performs non-maximum suppression (NMS) on the rotated boxes according
    to their intersection-over-union (IoU).

    Rotated NMS iteratively removes lower scoring rotated boxes which have an
    IoU greater than iou_threshold with another (higher scoring) rotated box.

    Args:
        dets (Tensor): Rotated boxes in shape (N, 5). They are expected to
            be in (x_ctr, y_ctr, width, height, angle_radian) format.
        scores (Tensor): scores in shape (N, ).
        iou_threshold (float): IoU threshold for NMS.
        labels (Tensor, optional): labels of boxes in shape (N, ). If given,
            boxes with different labels never suppress each other.

    Returns:
        tuple: kept dets (boxes with scores appended as the last column) and
            their indices, which are always of the same data type as the
            input.
    """
    if dets.shape[0] == 0:
        return dets, None
    multi_label = labels is not None
    if multi_label:
        dets_wl = torch.cat((dets, labels.unsqueeze(1)), 1)
    else:
        dets_wl = dets
    _, order = scores.sort(0, descending=True)
    dets_sorted = dets_wl.index_select(0, order)
    if torch.__version__ == 'parrots':
        select = torch.zeros((dets.shape[0]),
                             dtype=torch.int64).to(dets.device)
        ext_module.nms_rotated(
            dets_wl,
            scores,
            dets_sorted,
            select,
            iou_threshold=iou_threshold,
            multi_label=multi_label)
        keep_inds = order.masked_select(select == 1)
    else:
        keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                           iou_threshold, multi_label)
    # keep_inds index into the original `dets`, so a single gather suffices.
    dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                     dim=1)
    return dets, keep_inds
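A usage sketch with hypothetical boxes (assumes a CUDA build of the extension): two nearly identical rotated boxes collapse to one, and the kept detections come back with their scores appended as a sixth column:

import torch
from mmcv.ops import nms_rotated

boxes = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.3],
                      [50.5, 49.5, 20.0, 10.0, 0.3],
                      [10.0, 10.0, 5.0, 5.0, 0.0]]).cuda()
scores = torch.tensor([0.9, 0.8, 0.7]).cuda()
dets, keep_inds = nms_rotated(boxes, scores, 0.5)
# keep_inds -> indices of the survivors ([0, 2] here, since box 1 overlaps
# box 0 heavily); dets has shape (2, 6): kept boxes plus their scores.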
import numpy as np
import torch
class TestBoxIoURotated(object):

    def test_box_iou_rotated(self):
        if not torch.cuda.is_available():
            return
        from mmcv.ops import box_iou_rotated
        b1 = torch.tensor(
            [[1.0, 1.0, 3.0, 4.0], [2.0, 2.0, 3.0, 4.0], [7.0, 7.0, 8.0, 8.0]],
            dtype=torch.float32).cuda()
        b2 = torch.tensor([[0.0, 2.0, 2.0, 5.0], [2.0, 1.0, 3.0, 3.0]],
                          dtype=torch.float32).cuda()
        expect_output = torch.tensor(
            [[0.2715, 0.0000], [0.1396, 0.0000], [0.0000, 0.0000]],
            dtype=torch.float32).cuda()
        output = box_iou_rotated(b1, b2)
        assert np.allclose(
            output.cpu().numpy(), expect_output.cpu().numpy(), atol=1e-4)
import numpy as np
import torch
class TestNmsRotated(object):

    def test_ml_nms_rotated(self):
        if not torch.cuda.is_available():
            return
        from mmcv.ops import nms_rotated
        np_boxes = np.array(
            [[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
             [3.0, 7.0, 10.0, 12.0, 0.3, 0.5], [1.0, 4.0, 13.0, 7.0, 0.6, 0.9]],
            dtype=np.float32)
        np_labels = np.array([1, 0, 1, 0], dtype=np.float32)
        np_expect_dets = np.array(
            [[1.0, 4.0, 13.0, 7.0, 0.6], [3.0, 6.0, 9.0, 11.0, 0.6],
             [6.0, 3.0, 8.0, 7.0, 0.5]],
            dtype=np.float32)
        np_expect_keep_inds = np.array([3, 1, 0], dtype=np.int64)
        boxes = torch.from_numpy(np_boxes).cuda()
        labels = torch.from_numpy(np_labels).cuda()
        # Columns 0-4 are (x_ctr, y_ctr, w, h, angle); the last column holds
        # the scores, matching nms_rotated's (dets, scores, ...) signature.
        dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5, labels)
        assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
        assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)

    def test_nms_rotated(self):
        if not torch.cuda.is_available():
            return
        from mmcv.ops import nms_rotated
        np_boxes = np.array(
            [[6.0, 3.0, 8.0, 7.0, 0.5, 0.7], [3.0, 6.0, 9.0, 11.0, 0.6, 0.8],
             [3.0, 7.0, 10.0, 12.0, 0.3, 0.5], [1.0, 4.0, 13.0, 7.0, 0.6, 0.9]],
            dtype=np.float32)
        np_expect_dets = np.array(
            [[1.0, 4.0, 13.0, 7.0, 0.6], [3.0, 6.0, 9.0, 11.0, 0.6],
             [6.0, 3.0, 8.0, 7.0, 0.5]],
            dtype=np.float32)
        np_expect_keep_inds = np.array([3, 1, 0], dtype=np.int64)
        boxes = torch.from_numpy(np_boxes).cuda()
        dets, keep_inds = nms_rotated(boxes[:, :5], boxes[:, -1], 0.5)
        assert np.allclose(dets.cpu().numpy()[:, :5], np_expect_dets)
        assert np.allclose(keep_inds.cpu().numpy(), np_expect_keep_inds)