Unverified Commit 1cd01db9 authored by dingchang, committed by GitHub

[Feature] Add gather points op from mmdet3d (#1338)

* add ops (gather points) in mmdet3d

* add ops (gather points) in mmdet3d

* refactor code

* refactor code

* fix typo
parent 70f902bb
......@@ -10,6 +10,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- FurthestPointSample
- FurthestPointSampleWithDist
- GatherPoints
- GeneralizedAttention
......
......@@ -10,6 +10,7 @@ MMCV provides CUDA ops commonly used in tasks such as detection and segmentation
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- FurthestPointSample
- FurthestPointSampleWithDist
- GatherPoints
- GeneralizedAttention
......
......@@ -20,6 +20,7 @@ from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .gather_points import gather_points
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .knn import knn
......@@ -59,6 +60,6 @@ __all__ = [
'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated',
'pixel_group', 'contour_expand', 'MultiScaleDeformableAttention',
'BorderAlign', 'border_align', 'furthest_point_sample',
'BorderAlign', 'border_align', 'gather_points', 'furthest_point_sample',
'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
]
// Copyright (c) OpenMMLab. All rights reserved
#ifndef GATHER_POINTS_CUDA_KERNEL_CUH
#define GATHER_POINTS_CUDA_KERNEL_CUH
#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#define TOTAL_THREADS 1024
template <typename T>
__global__ void gather_points_forward_cuda_kernel(int b, int c, int n, int m,
const T *points,
const int *__restrict__ idx,
T *out) {
// points: (B, C, N)
// idx: (B, M)
// output:
// out: (B, C, M)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
out += bs_idx * c * m + c_idx * m + pt_idx;
idx += bs_idx * m + pt_idx;
points += bs_idx * c * n + c_idx * n;
out[0] = points[idx[0]];
}
template <typename T>
__global__ void gather_points_backward_cuda_kernel(int b, int c, int n, int m,
const T *grad_out,
const int *__restrict__ idx,
T *grad_points) {
// grad_out: (B, C, M)
// idx: (B, M)
// output:
// grad_points: (B, C, N)
int bs_idx = blockIdx.z;
int c_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
grad_out += bs_idx * c * m + c_idx * m + pt_idx;
idx += bs_idx * m + pt_idx;
grad_points += bs_idx * c * n + c_idx * n;
atomicAdd(grad_points + idx[0], grad_out[0]);
}
#endif // GATHER_POINTS_CUDA_KERNEL_CUH
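The forward kernel above copies points[b, c, idx[b, m]] into out[b, c, m], and the backward kernel scatter-adds gradients back into grad_points with atomicAdd, since several output positions may reference the same input point. For intuition, a minimal pure-PyTorch sketch of the same semantics (the helper names gather_points_ref and gather_points_grad_ref are hypothetical, meant only for understanding or quick verification, not part of the op):

import torch


def gather_points_ref(points: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Reference forward: points (B, C, N), idx (B, M) -> out (B, C, M)."""
    B, C, N = points.shape
    # Expand idx to (B, C, M) so torch.gather selects along the last dim per channel.
    index = idx.long().unsqueeze(1).expand(B, C, idx.shape[1])
    return points.gather(2, index)


def gather_points_grad_ref(grad_out: torch.Tensor, idx: torch.Tensor,
                           n: int) -> torch.Tensor:
    """Reference backward: scatter-add grad_out (B, C, M) into (B, C, N)."""
    B, C, M = grad_out.shape
    index = idx.long().unsqueeze(1).expand(B, C, M)
    grad_points = grad_out.new_zeros((B, C, n))
    # Duplicate indices accumulate, mirroring the atomicAdd in the CUDA kernel.
    return grad_points.scatter_add_(2, index, grad_out)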
......@@ -21,8 +21,8 @@ void BallQueryForwardCUDAKernelLauncher(int b, int n, int m, float min_radius,
at::cuda::CUDAGuard device_guard(new_xyz.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK),
b); // blockIdx.x(col), blockIdx.y(row)
// blockIdx.x(col), blockIdx.y(row)
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
......
#include <stdio.h>
#include <stdlib.h>
#include "gather_points_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
const Tensor points,
const Tensor idx, Tensor out) {
// points: (B, C, N)
// idx: (B, npoints)
// output:
// out: (B, C, npoints)
at::cuda::CUDAGuard device_guard(points.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // blockIdx.x(point), blockIdx.y(channel), blockIdx.z(batch)
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
points.scalar_type(), "gather_points_forward_cuda_kernel", [&] {
gather_points_forward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, n, npoints, points.data_ptr<scalar_t>(),
idx.data_ptr<int>(), out.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
void GatherPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
const Tensor grad_out,
const Tensor idx,
Tensor grad_points) {
// grad_out: (B, C, npoints)
// idx: (B, npoints)
// output:
// grad_points: (B, C, N)
at::cuda::CUDAGuard device_guard(grad_out.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // blockIdx.x(point), blockIdx.y(channel), blockIdx.z(batch)
dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);
dim3 threads(THREADS_PER_BLOCK);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_out.scalar_type(), "gather_points_backward_cuda_kernel", [&] {
gather_points_backward_cuda_kernel<scalar_t>
<<<blocks, threads, 0, stream>>>(
b, c, n, npoints, grad_out.data_ptr<scalar_t>(),
idx.data_ptr<int>(), grad_points.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
}
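Each CUDA thread handles exactly one output element (bs_idx, c_idx, pt_idx), so the launcher uses DIVUP (ceiling division) along the point dimension and one block row per channel and per batch. A small Python sketch of the same launch arithmetic, assuming for illustration that THREADS_PER_BLOCK is 512 (the actual value comes from the shared CUDA helper header):

import math

THREADS_PER_BLOCK = 512  # assumed value, for illustration only


def launch_geometry(b, c, npoints):
    # DIVUP(npoints, THREADS_PER_BLOCK): ceiling division, so every point gets
    # a thread even when npoints is not a multiple of the block size; threads
    # with pt_idx >= m simply return early in the kernel.
    grid = (math.ceil(npoints / THREADS_PER_BLOCK), c, b)
    block = (THREADS_PER_BLOCK,)
    return grid, block


# Example: b=2, c=3, npoints=6 -> grid (1, 3, 2), i.e. one block per
# (channel, batch) pair.
print(launch_geometry(2, 3, 6))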
#include "pytorch_cpp_helper.hpp"
#ifdef MMCV_WITH_CUDA
void GatherPointsForwardCUDAKernelLauncher(int b, int c, int n, int npoints,
const Tensor points,
const Tensor idx, Tensor out);
void gather_points_forward_cuda(int b, int c, int n, int npoints,
const Tensor points, const Tensor idx,
Tensor out) {
GatherPointsForwardCUDAKernelLauncher(b, c, n, npoints, points, idx, out);
}
void GatherPointsBackwardCUDAKernelLauncher(int b, int c, int n, int npoints,
const Tensor grad_out,
const Tensor idx,
Tensor grad_points);
void gather_points_backward_cuda(int b, int c, int n, int npoints,
const Tensor grad_out, const Tensor idx,
Tensor grad_points) {
GatherPointsBackwardCUDAKernelLauncher(b, c, n, npoints, grad_out, idx,
grad_points);
}
#endif
void gather_points_forward(int b, int c, int n, int npoints,
Tensor points_tensor, Tensor idx_tensor,
Tensor out_tensor) {
if (points_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
gather_points_forward_cuda(b, c, n, npoints, points_tensor, idx_tensor,
out_tensor);
#else
AT_ERROR("gather_points is not compiled with GPU support");
#endif
} else {
AT_ERROR("gather_points is not implemented on CPU");
}
}
void gather_points_backward(int b, int c, int n, int npoints,
Tensor grad_out_tensor, Tensor idx_tensor,
Tensor grad_points_tensor) {
if (grad_out_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
gather_points_backward_cuda(b, c, n, npoints, grad_out_tensor, idx_tensor,
grad_points_tensor);
#else
AT_ERROR("gather_points is not compiled with GPU support");
#endif
} else {
AT_ERROR("gather_points is not implemented on CPU");
}
}
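Because neither branch has a CPU kernel, calling the op with CPU tensors surfaces the AT_ERROR message as a Python RuntimeError. A hedged sketch of how that behaviour could be exercised (the test name is hypothetical; the exact error text is the one hard-coded above):

import pytest
import torch

from mmcv.ops import gather_points


def test_gather_points_cpu_not_supported():
    # gather_points has no CPU implementation; the dispatcher raises for CPU input.
    features = torch.rand(1, 3, 10)
    idx = torch.zeros((1, 4), dtype=torch.int32)
    with pytest.raises(RuntimeError):
        gather_points(features, idx)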
......@@ -53,6 +53,14 @@ void deform_roi_pool_backward(Tensor grad_output, Tensor input, Tensor rois,
int pooled_width, float spatial_scale,
int sampling_ratio, float gamma);
void gather_points_forward(int b, int c, int n, int npoints,
Tensor points_tensor, Tensor idx_tensor,
Tensor out_tensor);
void gather_points_backward(int b, int c, int n, int npoints,
Tensor grad_out_tensor, Tensor idx_tensor,
Tensor grad_points_tensor);
void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha);
......@@ -256,6 +264,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"fused_bias_leakyrelu (CUDA)", py::arg("input"), py::arg("bias"),
py::arg("empty"), py::arg("act"), py::arg("grad"), py::arg("alpha"),
py::arg("scale"));
m.def("gather_points_forward", &gather_points_forward,
"gather_points_forward", py::arg("b"), py::arg("c"), py::arg("n"),
py::arg("npoints"), py::arg("points_tensor"), py::arg("idx_tensor"),
py::arg("out_tensor"));
m.def("gather_points_backward", &gather_points_backward,
"gather_points_backward", py::arg("b"), py::arg("c"), py::arg("n"),
py::arg("npoints"), py::arg("grad_out_tensor"), py::arg("idx_tensor"),
py::arg("grad_points_tensor"));
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
"get_compiling_cuda_version");
......
import torch
from torch.autograd import Function
from ..utils import ext_loader
ext_module = ext_loader.load_ext(
'_ext', ['gather_points_forward', 'gather_points_backward'])
class GatherPoints(Function):
"""Gather points with given index."""
@staticmethod
def forward(ctx, features: torch.Tensor,
indices: torch.Tensor) -> torch.Tensor:
"""
Args:
features (Tensor): (B, C, N) features to gather.
indices (Tensor): (B, M) where M is the number of points.
Returns:
Tensor: (B, C, M) where M is the number of points.
"""
assert features.is_contiguous()
assert indices.is_contiguous()
B, npoint = indices.size()
_, C, N = features.size()
        output = features.new_empty((B, C, npoint))
ext_module.gather_points_forward(B, C, N, npoint, features, indices,
output)
ctx.for_backwards = (indices, C, N)
ctx.mark_non_differentiable(indices)
return output
@staticmethod
def backward(ctx, grad_out):
idx, C, N = ctx.for_backwards
B, npoint = idx.size()
        grad_features = grad_out.new_zeros((B, C, N))
grad_out_data = grad_out.data.contiguous()
ext_module.gather_points_backward(B, C, N, npoint, grad_out_data, idx,
grad_features.data)
return grad_features, None
gather_points = GatherPoints.apply
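A minimal usage sketch of the exported gather_points function (shapes follow the docstring above; a CUDA-enabled MMCV build is assumed):

import torch

from mmcv.ops import gather_points

features = torch.rand(2, 3, 10).cuda()            # (B, C, N)
idx = torch.randint(0, 10, (2, 6)).int().cuda()   # (B, M) int32 indices
out = gather_points(features, idx)                 # (B, C, M)
assert out.shape == (2, 3, 6)

# The op is differentiable w.r.t. features; indices receive no gradient.
features.requires_grad_(True)
gather_points(features, idx).sum().backward()
assert features.grad.shape == features.shape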
import pytest
import torch
from mmcv.ops import gather_points
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_gather_points():
features = torch.tensor([[[
-1.6095, -0.1029, -0.8876, -1.2447, -2.4031, 0.3708, -1.1586, -1.4967,
-0.4800, 0.2252
],
[
1.9138, 3.4979, 1.6854, 1.5631, 3.6776,
3.1154, 2.1705, 2.5221, 2.0411, 3.1446
],
[
-1.4173, 0.3073, -1.4339, -1.4340, -1.2770,
-0.2867, -1.4162, -1.4044, -1.4245, -1.4074
]],
[[
0.2160, 0.0842, 0.3661, -0.2749, -0.4909,
-0.6066, -0.8773, -0.0745, -0.9496, 0.1434
],
[
1.3644, 1.8087, 1.6855, 1.9563, 1.2746,
1.9662, 0.9566, 1.8778, 1.1437, 1.3639
],
[
-0.7172, 0.1692, 0.2241, 0.0721, -0.7540,
0.0462, -0.6227, 0.3223, -0.6944, -0.5294
]]]).cuda()
idx = torch.tensor([[0, 1, 4, 0, 0, 0], [0, 5, 6, 0, 0, 0]]).int().cuda()
output = gather_points(features, idx)
expected_output = torch.tensor(
[[[-1.6095, -0.1029, -2.4031, -1.6095, -1.6095, -1.6095],
[1.9138, 3.4979, 3.6776, 1.9138, 1.9138, 1.9138],
[-1.4173, 0.3073, -1.2770, -1.4173, -1.4173, -1.4173]],
[[0.2160, -0.6066, -0.8773, 0.2160, 0.2160, 0.2160],
[1.3644, 1.9662, 0.9566, 1.3644, 1.3644, 1.3644],
[-0.7172, 0.0462, -0.6227, -0.7172, -0.7172, -0.7172]]]).cuda()
assert torch.allclose(output, expected_output)
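As a possible follow-up check (not part of this commit), the backward path can be cross-checked against an autograd-friendly torch.gather reference, whose backward performs the same scatter-add of gradients; a sketch:

import pytest
import torch

from mmcv.ops import gather_points


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_gather_points_grad():
    features = torch.rand(2, 3, 10).cuda().requires_grad_(True)
    idx = torch.randint(0, 10, (2, 6)).int().cuda()

    gather_points(features, idx).sum().backward()

    # Reference gradient via torch.gather.
    features_ref = features.detach().clone().requires_grad_(True)
    index = idx.long().unsqueeze(1).expand(-1, 3, -1)
    features_ref.gather(2, index).sum().backward()

    assert torch.allclose(features.grad, features_ref.grad)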