Unverified Commit f3dfc413 authored by dingchang's avatar dingchang Committed by GitHub
Browse files

[Feature] Add ballquery op from mmdet3d (#1332)

parent 4e101e0b
......@@ -2,6 +2,7 @@
We implement common CUDA ops used in detection, segmentation, etc.
- BallQuery
- BBoxOverlaps
- CARAFE
- CrissCrossAttention
......
......@@ -2,6 +2,7 @@
MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- BallQuery
- BBoxOverlaps
- CARAFE
- CrissCrossAttention
......
# Copyright (c) OpenMMLab. All rights reserved.
from .ball_query import ball_query
from .bbox import bbox_overlaps
from .border_align import BorderAlign, border_align
from .box_iou_rotated import box_iou_rotated
......@@ -50,7 +51,7 @@ __all__ = [
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align'
]
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['ball_query_forward'])
class BallQuery(Function):
"""Find nearby points in spherical space."""
@staticmethod
def forward(ctx, min_radius: float, max_radius: float, sample_num: int,
xyz: torch.Tensor, center_xyz: torch.Tensor) -> torch.Tensor:
"""
Args:
min_radius (float): minimum radius of the balls.
max_radius (float): maximum radius of the balls.
sample_num (int): maximum number of features in the balls.
xyz (Tensor): (B, N, 3) xyz coordinates of the features.
center_xyz (Tensor): (B, npoint, 3) centers of the ball query.
Returns:
Tensor: (B, npoint, nsample) tensor with the indicies of
the features that form the query balls.
"""
assert center_xyz.is_contiguous()
assert xyz.is_contiguous()
assert min_radius < max_radius
B, N, _ = xyz.size()
npoint = center_xyz.size(1)
idx = xyz.new_zeros(B, npoint, sample_num, dtype=torch.int)
ext_module.ball_query_forward(B, N, npoint, min_radius, max_radius,
sample_num, center_xyz, xyz, idx)
ctx.mark_non_differentiable(idx)
return idx
@staticmethod
def backward(ctx, a=None):
return None, None, None, None
ball_query = BallQuery.apply
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#ifndef BALL_QUERY_CUDA_KERNEL_CUH
#define BALL_QUERY_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
template <typename T>
__global__ void ball_query_forward_cuda_kernel(int b, int n, int m,
float min_radius,
float max_radius, int nsample,
const T* new_xyz, const T* xyz,
int* idx) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
float max_radius2 = max_radius * max_radius;
float min_radius2 = min_radius * min_radius;
T new_x = new_xyz[0];
T new_y = new_xyz[1];
T new_z = new_xyz[2];
int cnt = 0;
for (int k = 0; k < n; ++k) {
T x = xyz[k * 3 + 0];
T y = xyz[k * 3 + 1];
T z = xyz[k * 3 + 2];
T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 == 0 || (d2 >= min_radius2 && d2 < max_radius2)) {
if (cnt == 0) {
for (int l = 0; l < nsample; ++l) {
idx[l] = k;
}
}
idx[cnt] = k;
++cnt;
if (cnt >= nsample) break;
}
}
}
#endif // BALL_QUERY_CUDA_KERNEL_CUH
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query.cpp
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius,
float max_radius, int nsample,
const Tensor new_xyz, const Tensor xyz,
Tensor idx);
void ball_query_forward_cuda(int b, int n, int m, float min_radius,
float max_radius, int nsample,
const Tensor new_xyz, const Tensor xyz,
Tensor idx) {
BallQueryForwardCUDAKernelLauncher(b, n, m, min_radius, max_radius, nsample,
new_xyz, xyz, idx);
};
#endif
void ball_query_forward(int b, int n, int m, float min_radius, float max_radius,
int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor,
Tensor idx_tensor) {
if (new_xyz_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(new_xyz_tensor);
CHECK_CUDA_INPUT(xyz_tensor);
ball_query_forward_cuda(b, n, m, min_radius, max_radius, nsample,
new_xyz_tensor, xyz_tensor, idx_tensor);
#else
AT_ERROR("ball_query is not compiled with GPU support");
#endif
} else {
AT_ERROR("ball_query is not implemented on CPU");
}
}
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/ball_query_gpu.cu
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "ball_query_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius,
float max_radius, int nsample,
const Tensor new_xyz, const Tensor xyz,
Tensor idx) {
// new_xyz: (B, M, 3)
// xyz: (B, N, 3)
// output:
// idx: (B, M, nsample)
at::cuda::CUDAGuard device_guard(new_xyz.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK),
b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
new_xyz.scalar_type(), "ball_query_forward_cuda_kernel", [&] {
ball_query_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, n, m, min_radius, max_radius, nsample,
new_xyz.data_ptr<scalar_t>(), xyz.data_ptr<scalar_t>(),
idx.data_ptr<int>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
......@@ -111,16 +111,16 @@ Tensor nms(Tensor boxes, Tensor scores, float iou_threshold, int offset);
Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold,
float sigma, float min_score, int method, int offset);
std::vector<std::vector<int> > nms_match(Tensor dets, float iou_threshold);
std::vector<std::vector<int>> nms_match(Tensor dets, float iou_threshold);
std::vector<std::vector<float> > pixel_group(
std::vector<std::vector<float>> pixel_group(
Tensor score, Tensor mask, Tensor embedding, Tensor kernel_label,
Tensor kernel_contour, int kernel_region_num, float distance_threshold);
std::vector<std::vector<int> > contour_expand(Tensor kernel_mask,
Tensor internal_kernel_label,
int min_kernel_area,
int kernel_num);
std::vector<std::vector<int>> contour_expand(Tensor kernel_mask,
Tensor internal_kernel_label,
int min_kernel_area,
int kernel_num);
void roi_align_forward(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x, int aligned_height,
......@@ -172,6 +172,10 @@ void tin_shift_forward(Tensor input, Tensor shift, Tensor output);
void tin_shift_backward(Tensor grad_output, Tensor shift, Tensor grad_input);
void ball_query_forward(int b, int n, int m, float min_radius, float max_radius,
int nsample, Tensor new_xyz_tensor, Tensor xyz_tensor,
Tensor idx_tensor);
Tensor bottom_pool_forward(Tensor input);
Tensor bottom_pool_backward(Tensor input, Tensor grad_output);
......@@ -415,6 +419,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("nms_rotated", &nms_rotated, "NMS for rotated boxes", py::arg("dets"),
py::arg("scores"), py::arg("order"), py::arg("dets_sorted"),
py::arg("iou_threshold"), py::arg("multi_label"));
m.def("ball_query_forward", &ball_query_forward, "ball_query_forward",
py::arg("b"), py::arg("n"), py::arg("m"), py::arg("min_radius"),
py::arg("max_radius"), py::arg("nsample"), py::arg("new_xyz_tensor"),
py::arg("xyz_tensor"), py::arg("idx_tensor"));
m.def("roi_align_rotated_forward", &roi_align_rotated_forward,
"roi_align_rotated forward", py::arg("input"), py::arg("rois"),
py::arg("output"), py::arg("pooled_height"), py::arg("pooled_width"),
......
import pytest
import torch
from mmcv.ops import ball_query
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_ball_query():
new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625],
[-2.2769, 2.7817, -0.2334],
[-0.4003, 2.4666, -0.5116],
[-0.0740, 1.3147, -1.3625],
[-0.0740, 1.3147, -1.3625]],
[[-2.0289, 2.4952, -0.1708],
[-2.0668, 6.0278, -0.4875],
[0.4066, 1.4211, -0.2947],
[-2.0289, 2.4952, -0.1708],
[-2.0289, 2.4952, -0.1708]]]).cuda()
xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
[-0.4003, 2.4666,
-0.5116], [-0.5251, 2.4379, -0.8466],
[-0.9691, 1.1418,
-1.3733], [-0.2232, 0.9561, -1.3626],
[-2.2769, 2.7817, -0.2334],
[-0.2822, 1.3192, -1.3645], [0.1533, 1.5024, -1.0432],
[0.4917, 1.1529, -1.3496]],
[[-2.0289, 2.4952,
-0.1708], [-0.7188, 0.9956, -0.5096],
[-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
[0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
[-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856,
-1.2000]]]).cuda()
idx = ball_query(0, 0.2, 5, xyz, new_xyz)
expected_idx = torch.tensor([[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6],
[2, 2, 2, 2, 2], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2],
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]).cuda()
assert torch.all(idx == expected_idx)
# test dilated ball query
idx = ball_query(0.2, 0.4, 5, xyz, new_xyz)
expected_idx = torch.tensor([[[0, 5, 7, 0, 0], [6, 6, 6, 6, 6],
[2, 3, 2, 2, 2], [0, 5, 7, 0, 0],
[0, 5, 7, 0, 0]],
[[0, 0, 0, 0, 0], [2, 2, 2, 2, 2],
[7, 7, 7, 7, 7], [0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]]]).cuda()
assert torch.all(idx == expected_idx)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment