Unverified Commit 2157b708 authored by dingchang's avatar dingchang Committed by GitHub
Browse files

[Feature] Add iou3d op from mmdet3d (#1356)



* add ops (iou3d) in mmdet3d

* add unit test

* refactor code

* refactor code

* refactor code

* refactor code

* refactor code
Co-authored-by: default avatarzhouzaida <zhouzaida@163.com>
parent 4c8bfb48
......@@ -24,6 +24,7 @@ from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .gather_points import gather_points
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .iou3d import boxes_iou_bev, nms_bev, nms_normal_bev
from .knn import knn
from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d,
......@@ -53,82 +54,27 @@ from .upfirdn2d import upfirdn2d
from .voxelize import Voxelization, voxelization
__all__ = [
'bbox_overlaps',
'CARAFE',
'CARAFENaive',
'CARAFEPack',
'carafe',
'carafe_naive',
'CornerPool',
'DeformConv2d',
'DeformConv2dPack',
'deform_conv2d',
'DeformRoIPool',
'DeformRoIPoolPack',
'ModulatedDeformRoIPoolPack',
'deform_roi_pool',
'SigmoidFocalLoss',
'SoftmaxFocalLoss',
'sigmoid_focal_loss',
'softmax_focal_loss',
'get_compiler_version',
'get_compiling_cuda_version',
'get_onnxruntime_op_path',
'MaskedConv2d',
'masked_conv2d',
'ModulatedDeformConv2d',
'ModulatedDeformConv2dPack',
'modulated_deform_conv2d',
'batched_nms',
'nms',
'soft_nms',
'nms_match',
'RoIAlign',
'roi_align',
'RoIPool',
'roi_pool',
'SyncBatchNorm',
'Conv2d',
'ConvTranspose2d',
'Linear',
'MaxPool2d',
'CrissCrossAttention',
'PSAMask',
'point_sample',
'rel_roi_point_to_rel_img_point',
'SimpleRoIAlign',
'SAConv2d',
'TINShift',
'tin_shift',
'assign_score_withk',
'box_iou_rotated',
'RoIPointPool3d',
'nms_rotated',
'knn',
'ball_query',
'upfirdn2d',
'FusedBiasLeakyReLU',
'fused_bias_leakyrelu',
'RoIAlignRotated',
'roi_align_rotated',
'pixel_group',
'contour_expand',
'three_nn',
'three_interpolate',
'MultiScaleDeformableAttention',
'Voxelization',
'voxelization',
'dynamic_scatter',
'DynamicScatter',
'BorderAlign',
'border_align',
'gather_points',
'furthest_point_sample',
'furthest_point_sample_with_dist',
'PointsSampler',
'Correlation',
'RoIAwarePool3d',
'points_in_boxes_part',
'points_in_boxes_cpu',
'points_in_boxes_all',
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
'carafe_naive', 'CornerPool', 'DeformConv2d', 'DeformConv2dPack',
'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
'get_compiler_version', 'get_compiling_cuda_version',
'get_onnxruntime_op_path', 'MaskedConv2d', 'masked_conv2d',
'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'assign_score_withk',
'box_iou_rotated', 'RoIPointPool3d', 'nms_rotated', 'knn', 'ball_query',
'upfirdn2d', 'FusedBiasLeakyReLU', 'boxes_iou_bev', 'nms_bev',
'nms_normal_bev', 'fused_bias_leakyrelu', 'RoIAlignRotated',
'roi_align_rotated', 'pixel_group', 'contour_expand', 'three_nn',
'three_interpolate', 'MultiScaleDeformableAttention', 'BorderAlign',
'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation',
'Voxelization', 'voxelization', 'dynamic_scatter', 'DynamicScatter',
'RoIAwarePool3d', 'points_in_boxes_part', 'points_in_boxes_cpu',
'points_in_boxes_all'
]
// Copyright (c) OpenMMLab. All rights reserved
#ifndef IOU3D_CUDA_KERNEL_CUH
#define IOU3D_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
const int THREADS_PER_BLOCK_IOU3D = 16;
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
__device__ const float EPS = 1e-8;
struct Point {
float x, y;
__device__ Point() {}
__device__ Point(double _x, double _y) { x = _x, y = _y; }
__device__ void set(float _x, float _y) {
x = _x;
y = _y;
}
__device__ Point operator+(const Point &b) const {
return Point(x + b.x, y + b.y);
}
__device__ Point operator-(const Point &b) const {
return Point(x - b.x, y - b.y);
}
};
__device__ inline float cross(const Point &a, const Point &b) {
return a.x * b.y - a.y * b.x;
}
__device__ inline float cross(const Point &p1, const Point &p2,
const Point &p0) {
return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
}
__device__ int check_rect_cross(const Point &p1, const Point &p2,
const Point &q1, const Point &q2) {
int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) &&
min(q1.x, q2.x) <= max(p1.x, p2.x) &&
min(p1.y, p2.y) <= max(q1.y, q2.y) &&
min(q1.y, q2.y) <= max(p1.y, p2.y);
return ret;
}
__device__ inline int check_in_box2d(const float *box, const Point &p) {
// params: box (5) [x1, y1, x2, y2, angle]
const float MARGIN = 1e-5;
float center_x = (box[0] + box[2]) / 2;
float center_y = (box[1] + box[3]) / 2;
float angle_cos = cos(-box[4]),
angle_sin =
sin(-box[4]); // rotate the point in the opposite direction of box
float rot_x =
(p.x - center_x) * angle_cos - (p.y - center_y) * angle_sin + center_x;
float rot_y =
(p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + center_y;
return (rot_x > box[0] - MARGIN && rot_x < box[2] + MARGIN &&
rot_y > box[1] - MARGIN && rot_y < box[3] + MARGIN);
}
__device__ inline int intersection(const Point &p1, const Point &p0,
const Point &q1, const Point &q0,
Point &ans_point) {
// fast exclusion
if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
// check cross standing
float s1 = cross(q0, p1, p0);
float s2 = cross(p1, q1, p0);
float s3 = cross(p0, q1, q0);
float s4 = cross(q1, p1, q0);
if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0;
// calculate intersection of two lines
float s5 = cross(q1, p1, p0);
if (fabs(s5 - s1) > EPS) {
ans_point.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
ans_point.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
} else {
float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
float D = a0 * b1 - a1 * b0;
ans_point.x = (b0 * c1 - b1 * c0) / D;
ans_point.y = (a1 * c0 - a0 * c1) / D;
}
return 1;
}
__device__ inline void rotate_around_center(const Point &center,
const float angle_cos,
const float angle_sin, Point &p) {
float new_x =
(p.x - center.x) * angle_cos - (p.y - center.y) * angle_sin + center.x;
float new_y =
(p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;
p.set(new_x, new_y);
}
__device__ inline int point_cmp(const Point &a, const Point &b,
const Point &center) {
return atan2(a.y - center.y, a.x - center.x) >
atan2(b.y - center.y, b.x - center.x);
}
__device__ inline float box_overlap(const float *box_a, const float *box_b) {
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3],
a_angle = box_a[4];
float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3],
b_angle = box_b[4];
Point center_a((a_x1 + a_x2) / 2, (a_y1 + a_y2) / 2);
Point center_b((b_x1 + b_x2) / 2, (b_y1 + b_y2) / 2);
Point box_a_corners[5];
box_a_corners[0].set(a_x1, a_y1);
box_a_corners[1].set(a_x2, a_y1);
box_a_corners[2].set(a_x2, a_y2);
box_a_corners[3].set(a_x1, a_y2);
Point box_b_corners[5];
box_b_corners[0].set(b_x1, b_y1);
box_b_corners[1].set(b_x2, b_y1);
box_b_corners[2].set(b_x2, b_y2);
box_b_corners[3].set(b_x1, b_y2);
// get oriented corners
float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
for (int k = 0; k < 4; k++) {
rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
}
box_a_corners[4] = box_a_corners[0];
box_b_corners[4] = box_b_corners[0];
// get intersection of lines
Point cross_points[16];
Point poly_center;
int cnt = 0, flag = 0;
poly_center.set(0, 0);
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
flag = intersection(box_a_corners[i + 1], box_a_corners[i],
box_b_corners[j + 1], box_b_corners[j],
cross_points[cnt]);
if (flag) {
poly_center = poly_center + cross_points[cnt];
cnt++;
}
}
}
// check corners
for (int k = 0; k < 4; k++) {
if (check_in_box2d(box_a, box_b_corners[k])) {
poly_center = poly_center + box_b_corners[k];
cross_points[cnt] = box_b_corners[k];
cnt++;
}
if (check_in_box2d(box_b, box_a_corners[k])) {
poly_center = poly_center + box_a_corners[k];
cross_points[cnt] = box_a_corners[k];
cnt++;
}
}
poly_center.x /= cnt;
poly_center.y /= cnt;
// sort the points of polygon
Point temp;
for (int j = 0; j < cnt - 1; j++) {
for (int i = 0; i < cnt - j - 1; i++) {
if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {
temp = cross_points[i];
cross_points[i] = cross_points[i + 1];
cross_points[i + 1] = temp;
}
}
}
// get the overlap areas
float area = 0;
for (int k = 0; k < cnt - 1; k++) {
area += cross(cross_points[k] - cross_points[0],
cross_points[k + 1] - cross_points[0]);
}
return fabs(area) / 2.0;
}
__device__ inline float iou_bev(const float *box_a, const float *box_b) {
// params: box_a (5) [x1, y1, x2, y2, angle]
// params: box_b (5) [x1, y1, x2, y2, angle]
float sa = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1]);
float sb = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1]);
float s_overlap = box_overlap(box_a, box_b);
return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
}
__global__ void iou3d_boxes_overlap_bev_forward_cuda_kernel(
const int num_a, const float *boxes_a, const int num_b,
const float *boxes_b, float *ans_overlap) {
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
if (a_idx >= num_a || b_idx >= num_b) {
return;
}
const float *cur_box_a = boxes_a + a_idx * 5;
const float *cur_box_b = boxes_b + b_idx * 5;
float s_overlap = box_overlap(cur_box_a, cur_box_b);
ans_overlap[a_idx * num_b + b_idx] = s_overlap;
}
__global__ void iou3d_boxes_iou_bev_forward_cuda_kernel(const int num_a,
const float *boxes_a,
const int num_b,
const float *boxes_b,
float *ans_iou) {
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
if (a_idx >= num_a || b_idx >= num_b) {
return;
}
const float *cur_box_a = boxes_a + a_idx * 5;
const float *cur_box_b = boxes_b + b_idx * 5;
float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
}
__global__ void nms_forward_cuda_kernel(const int boxes_num,
const float nms_overlap_thresh,
const float *boxes,
unsigned long long *mask) {
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const float *cur_box = boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_bev(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
__device__ inline float iou_normal(float const *const a, float const *const b) {
float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]);
float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]);
float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f);
float interS = width * height;
float Sa = (a[2] - a[0]) * (a[3] - a[1]);
float Sb = (b[2] - b[0]) * (b[3] - b[1]);
return interS / fmaxf(Sa + Sb - interS, EPS);
}
__global__ void nms_normal_forward_cuda_kernel(const int boxes_num,
const float nms_overlap_thresh,
const float *boxes,
unsigned long long *mask) {
// params: boxes (N, 5) [x1, y1, x2, y2, ry]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 5];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 5 + 0] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 0];
block_boxes[threadIdx.x * 5 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 1];
block_boxes[threadIdx.x * 5 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 2];
block_boxes[threadIdx.x * 5 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 3];
block_boxes[threadIdx.x * 5 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const float *cur_box = boxes + cur_box_idx * 5;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_normal(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
#endif // IOU3D_CUDA_KERNEL_CUH
......@@ -6,6 +6,8 @@
using namespace at;
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CPU(x) \
......
// Modified from
// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms_kernel.cu
/*
3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <stdio.h>
#include "iou3d_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a,
const int num_b,
const Tensor boxes_b,
Tensor ans_overlap) {
at::cuda::CUDAGuard device_guard(boxes_a.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D),
DIVUP(num_a, THREADS_PER_BLOCK_IOU3D));
dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D);
iou3d_boxes_overlap_bev_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
num_a, boxes_a.data_ptr<float>(), num_b, boxes_b.data_ptr<float>(),
ans_overlap.data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
}
void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a,
const int num_b,
const Tensor boxes_b,
Tensor ans_iou) {
at::cuda::CUDAGuard device_guard(boxes_a.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK_IOU3D),
DIVUP(num_a, THREADS_PER_BLOCK_IOU3D));
dim3 threads(THREADS_PER_BLOCK_IOU3D, THREADS_PER_BLOCK_IOU3D);
iou3d_boxes_iou_bev_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
num_a, boxes_a.data_ptr<float>(), num_b, boxes_b.data_ptr<float>(),
ans_iou.data_ptr<float>());
AT_CUDA_CHECK(cudaGetLastError());
}
void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long *mask, int boxes_num,
float nms_overlap_thresh) {
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
nms_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(), mask);
AT_CUDA_CHECK(cudaGetLastError());
}
void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long *mask,
int boxes_num,
float nms_overlap_thresh) {
at::cuda::CUDAGuard device_guard(boxes.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
nms_normal_forward_cuda_kernel<<<blocks, threads, 0, stream>>>(
boxes_num, nms_overlap_thresh, boxes.data_ptr<float>(), mask);
AT_CUDA_CHECK(cudaGetLastError());
}
// Modified from
// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp
/*
3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include "pytorch_cpp_helper.hpp"
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
#ifdef MMCV_WITH_CUDA
#include <cuda.h>
#include <cuda_runtime_api.h>
#define CHECK_ERROR(state) \
{ gpuAssert((state), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort) exit(code);
}
}
void IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a,
const int num_b,
const Tensor boxes_b,
Tensor ans_overlap);
void iou3d_boxes_overlap_bev_forward_cuda(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_overlap) {
IoU3DBoxesOverlapBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b,
ans_overlap);
};
void IoU3DBoxesIoUBevForwardCUDAKernelLauncher(const int num_a,
const Tensor boxes_a,
const int num_b,
const Tensor boxes_b,
Tensor ans_iou);
void iou3d_boxes_iou_bev_forward_cuda(const int num_a, const Tensor boxes_a,
const int num_b, const Tensor boxes_b,
Tensor ans_iou) {
IoU3DBoxesIoUBevForwardCUDAKernelLauncher(num_a, boxes_a, num_b, boxes_b,
ans_iou);
};
void IoU3DNMSForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long *mask, int boxes_num,
float nms_overlap_thresh);
void iou3d_nms_forward_cuda(const Tensor boxes, unsigned long long *mask,
int boxes_num, float nms_overlap_thresh) {
IoU3DNMSForwardCUDAKernelLauncher(boxes, mask, boxes_num, nms_overlap_thresh);
};
void IoU3DNMSNormalForwardCUDAKernelLauncher(const Tensor boxes,
unsigned long long *mask,
int boxes_num,
float nms_overlap_thresh);
void iou3d_nms_normal_forward_cuda(const Tensor boxes, unsigned long long *mask,
int boxes_num, float nms_overlap_thresh) {
IoU3DNMSNormalForwardCUDAKernelLauncher(boxes, mask, boxes_num,
nms_overlap_thresh);
};
#endif
void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
Tensor ans_overlap) {
// params boxes_a: (N, 5) [x1, y1, x2, y2, ry]
// params boxes_b: (M, 5)
// params ans_overlap: (N, M)
if (boxes_a.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(boxes_a);
CHECK_CUDA_INPUT(boxes_b);
CHECK_CUDA_INPUT(ans_overlap);
int num_a = boxes_a.size(0);
int num_b = boxes_b.size(0);
iou3d_boxes_overlap_bev_forward_cuda(num_a, boxes_a, num_b, boxes_b,
ans_overlap);
#else
AT_ERROR("iou3d_boxes_overlap_bev is not compiled with GPU support");
#endif
} else {
AT_ERROR("iou3d_boxes_overlap_bev is not implemented on CPU");
}
}
void iou3d_boxes_iou_bev_forward(Tensor boxes_a, Tensor boxes_b,
Tensor ans_iou) {
// params boxes_a: (N, 5) [x1, y1, x2, y2, ry]
// params boxes_b: (M, 5)
// params ans_overlap: (N, M)
if (boxes_a.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(boxes_a);
CHECK_CUDA_INPUT(boxes_b);
CHECK_CUDA_INPUT(ans_iou);
int num_a = boxes_a.size(0);
int num_b = boxes_b.size(0);
iou3d_boxes_iou_bev_forward_cuda(num_a, boxes_a, num_b, boxes_b, ans_iou);
#else
AT_ERROR("iou3d_boxes_iou_bev is not compiled with GPU support");
#endif
} else {
AT_ERROR("iou3d_boxes_iou_bev is not implemented on CPU");
}
}
int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh) {
// params boxes: (N, 5) [x1, y1, x2, y2, ry]
// params keep: (N)
if (boxes.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(boxes);
CHECK_CONTIGUOUS(keep);
int boxes_num = boxes.size(0);
int64_t *keep_data = keep.data_ptr<int64_t>();
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
unsigned long long *mask_data = NULL;
CHECK_ERROR(
cudaMalloc((void **)&mask_data,
boxes_num * col_blocks * sizeof(unsigned long long)));
iou3d_nms_forward_cuda(boxes, mask_data, boxes_num, nms_overlap_thresh);
// unsigned long long mask_cpu[boxes_num * col_blocks];
// unsigned long long *mask_cpu = new unsigned long long [boxes_num *
// col_blocks];
std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
boxes_num * col_blocks * sizeof(unsigned long long),
cudaMemcpyDeviceToHost));
cudaFree(mask_data);
unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;
if (!(remv_cpu[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
unsigned long long *p = &mask_cpu[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu[j] |= p[j];
}
}
}
delete[] remv_cpu;
if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
return num_to_keep;
#else
AT_ERROR("iou3d_nms is not compiled with GPU support");
#endif
} else {
AT_ERROR("iou3d_nms is not implemented on CPU");
}
}
int iou3d_nms_normal_forward(Tensor boxes, Tensor keep,
float nms_overlap_thresh) {
// params boxes: (N, 5) [x1, y1, x2, y2, ry]
// params keep: (N)
if (boxes.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(boxes);
CHECK_CONTIGUOUS(keep);
int boxes_num = boxes.size(0);
int64_t *keep_data = keep.data_ptr<int64_t>();
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
unsigned long long *mask_data = NULL;
CHECK_ERROR(
cudaMalloc((void **)&mask_data,
boxes_num * col_blocks * sizeof(unsigned long long)));
iou3d_nms_normal_forward_cuda(boxes, mask_data, boxes_num,
nms_overlap_thresh);
// unsigned long long mask_cpu[boxes_num * col_blocks];
// unsigned long long *mask_cpu = new unsigned long long [boxes_num *
// col_blocks];
std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
boxes_num * col_blocks * sizeof(unsigned long long),
cudaMemcpyDeviceToHost));
cudaFree(mask_data);
unsigned long long *remv_cpu = new unsigned long long[col_blocks]();
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;
if (!(remv_cpu[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
unsigned long long *p = &mask_cpu[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu[j] |= p[j];
}
}
}
delete[] remv_cpu;
if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
return num_to_keep;
#else
AT_ERROR("iou3d_nms_normal is not compiled with GPU support");
#endif
} else {
AT_ERROR("iou3d_nms_normal is not implemented on CPU");
}
}
......@@ -105,6 +105,17 @@ void three_nn_forward(int b, int n, int m, Tensor unknown_tensor,
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset);
void iou3d_boxes_overlap_bev_forward(Tensor boxes_a, Tensor boxes_b,
Tensor ans_overlap);
void iou3d_boxes_iou_bev_forward(Tensor boxes_a, Tensor boxes_b,
Tensor ans_iou);
int iou3d_nms_forward(Tensor boxes, Tensor keep, float nms_overlap_thresh);
int iou3d_nms_normal_forward(Tensor boxes, Tensor keep,
float nms_overlap_thresh);
void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor);
......@@ -442,6 +453,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("bbox_overlaps", &bbox_overlaps, "bbox_overlaps", py::arg("bboxes1"),
py::arg("bboxes2"), py::arg("ious"), py::arg("mode"),
py::arg("aligned"), py::arg("offset"));
m.def("iou3d_boxes_overlap_bev_forward", &iou3d_boxes_overlap_bev_forward,
"iou3d_boxes_overlap_bev_forward", py::arg("boxes_a"),
py::arg("boxes_b"), py::arg("ans_overlap"));
m.def("iou3d_boxes_iou_bev_forward", &iou3d_boxes_iou_bev_forward,
"iou3d_boxes_iou_bev_forward", py::arg("boxes_a"), py::arg("boxes_b"),
py::arg("ans_iou"));
m.def("iou3d_nms_forward", &iou3d_nms_forward, "iou3d_nms_forward",
py::arg("boxes"), py::arg("keep"), py::arg("nms_overlap_thresh"));
m.def("iou3d_nms_normal_forward", &iou3d_nms_normal_forward,
"iou3d_nms_normal_forward", py::arg("boxes"), py::arg("keep"),
py::arg("nms_overlap_thresh"));
m.def("furthest_point_sampling_forward", &furthest_point_sampling_forward,
"furthest_point_sampling_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("points_tensor"), py::arg("temp_tensor"),
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', [
'iou3d_boxes_iou_bev_forward', 'iou3d_nms_forward',
'iou3d_nms_normal_forward'
])
def boxes_iou_bev(boxes_a, boxes_b):
"""Calculate boxes IoU in the Bird's Eye View.
Args:
boxes_a (torch.Tensor): Input boxes a with shape (M, 5).
boxes_b (torch.Tensor): Input boxes b with shape (N, 5).
Returns:
ans_iou (torch.Tensor): IoU result with shape (M, N).
"""
ans_iou = boxes_a.new_zeros(
torch.Size((boxes_a.shape[0], boxes_b.shape[0])))
ext_module.iou3d_boxes_iou_bev_forward(boxes_a.contiguous(),
boxes_b.contiguous(), ans_iou)
return ans_iou
def nms_bev(boxes, scores, thresh, pre_max_size=None, post_max_size=None):
"""NMS function GPU implementation (for BEV boxes). The overlap of two
boxes for IoU calculation is defined as the exact overlapping area of the
two boxes. In this function, one can also set ``pre_max_size`` and
``post_max_size``.
Args:
boxes (torch.Tensor): Input boxes with the shape of [N, 5]
([x1, y1, x2, y2, ry]).
scores (torch.Tensor): Scores of boxes with the shape of [N].
thresh (float): Overlap threshold of NMS.
pre_max_size (int, optional): Max size of boxes before NMS.
Default: None.
post_max_size (int, optional): Max size of boxes after NMS.
Default: None.
Returns:
torch.Tensor: Indexes after NMS.
"""
order = scores.sort(0, descending=True)[1]
if pre_max_size is not None:
order = order[:pre_max_size]
boxes = boxes[order].contiguous()
keep = torch.zeros(boxes.size(0), dtype=torch.long)
num_out = ext_module.iou3d_nms_forward(boxes, keep, thresh)
keep = order[keep[:num_out].cuda(boxes.device)].contiguous()
if post_max_size is not None:
keep = keep[:post_max_size]
return keep
def nms_normal_bev(boxes, scores, thresh):
"""Normal NMS function GPU implementation (for BEV boxes). The overlap of
two boxes for IoU calculation is defined as the exact overlapping area of
the two boxes WITH their yaw angle set to 0.
Args:
boxes (torch.Tensor): Input boxes with shape (N, 5).
scores (torch.Tensor): Scores of predicted boxes with shape (N).
thresh (float): Overlap threshold of NMS.
Returns:
torch.Tensor: Remaining indices with scores in descending order.
"""
order = scores.sort(0, descending=True)[1]
boxes = boxes[order].contiguous()
keep = torch.zeros(boxes.size(0), dtype=torch.long)
num_out = ext_module.iou3d_nms_normal_forward(boxes, keep, thresh)
return order[keep[:num_out].cuda(boxes.device)].contiguous()
import numpy as np
import pytest
import torch
from mmcv.ops import boxes_iou_bev, nms_bev, nms_normal_bev
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_boxes_iou_bev():
np_boxes1 = np.asarray(
[[1.0, 1.0, 3.0, 4.0, 0.5], [2.0, 2.0, 3.0, 4.0, 0.6],
[7.0, 7.0, 8.0, 8.0, 0.4]],
dtype=np.float32)
np_boxes2 = np.asarray(
[[0.0, 2.0, 2.0, 5.0, 0.3], [2.0, 1.0, 3.0, 3.0, 0.5],
[5.0, 5.0, 6.0, 7.0, 0.4]],
dtype=np.float32)
np_expect_ious = np.asarray(
[[0.2621, 0.2948, 0.0000], [0.0549, 0.1587, 0.0000],
[0.0000, 0.0000, 0.0000]],
dtype=np.float32)
boxes1 = torch.from_numpy(np_boxes1).cuda()
boxes2 = torch.from_numpy(np_boxes2).cuda()
ious = boxes_iou_bev(boxes1, boxes2)
assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_nms_gpu():
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
np_inds = np.array([1, 0, 3])
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
inds = nms_bev(boxes.cuda(), scores.cuda(), thresh=0.3)
assert np.allclose(inds.cpu().numpy(), np_inds) # test gpu
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_nms_normal_gpu():
np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
[3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
dtype=np.float32)
np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
np_inds = np.array([1, 2, 0, 3])
boxes = torch.from_numpy(np_boxes)
scores = torch.from_numpy(np_scores)
inds = nms_normal_bev(boxes.cuda(), scores.cuda(), thresh=0.3)
assert np.allclose(inds.cpu().numpy(), np_inds) # test gpu
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment