"vscode:/vscode.git/clone" did not exist on "ab4920a4cb97edcd989b8002aeae8dc04a3dde4e"
Unverified commit 5f99b50e authored by Yezhen Cong, committed by GitHub

[Feature] Support knn gpu op (#360)

* support knn gpu op

* made it more robust and fixed comments
parent 573946a8
@@ -9,6 +9,7 @@ from .gather_points import gather_points
from .group_points import (GroupAll, QueryAndGroup, group_points,
                           grouping_operation)
from .interpolate import three_interpolate, three_nn
from .knn import knn
from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
from .pointnet_modules import (PointFPModule, PointSAModule, PointSAModuleMSG,
                               build_sa_module)
@@ -25,7 +26,7 @@ __all__ = [
    'dynamic_scatter', 'DynamicScatter', 'sigmoid_focal_loss',
    'SigmoidFocalLoss', 'SparseBasicBlock', 'SparseBottleneck',
    'RoIAwarePool3d', 'points_in_boxes_gpu', 'points_in_boxes_cpu',
    'make_sparse_convmodule', 'ball_query', 'knn', 'furthest_point_sample',
    'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn',
    'gather_points', 'grouping_operation', 'group_points', 'GroupAll',
    'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule',
...
@@ -12,28 +12,28 @@ class GatherPoints(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor,
                indices: torch.Tensor) -> torch.Tensor:
        """forward.

        Args:
            features (Tensor): (B, C, N) features to gather.
            indices (Tensor): (B, M) where M is the number of points.

        Returns:
            Tensor: (B, C, M) where M is the number of points.
        """
        assert features.is_contiguous()
        assert indices.is_contiguous()

        B, npoint = indices.size()
        _, C, N = features.size()
        output = torch.cuda.FloatTensor(B, C, npoint)

        gather_points_ext.gather_points_wrapper(B, C, N, npoint, features,
                                                indices, output)

        ctx.for_backwards = (indices, C, N)
        ctx.mark_non_differentiable(indices)
        return output

    @staticmethod
...
from .knn import knn
__all__ = ['knn']
import torch
from torch.autograd import Function

from . import knn_ext


class KNN(Function):
    """KNN (CUDA).

    Find k-nearest points.
    """

    @staticmethod
    def forward(ctx,
                k: int,
                xyz: torch.Tensor,
                center_xyz: torch.Tensor,
                transposed: bool = False) -> torch.Tensor:
        """forward.

        Args:
            k (int): number of nearest neighbors.
            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
                xyz coordinates of the features.
            center_xyz (Tensor): (B, npoint, 3) if transposed == False,
                else (B, 3, npoint). Centers of the knn query.
            transposed (bool): whether the input tensors are transposed.
                Defaults to False.

        Returns:
            Tensor: (B, k, npoint) tensor with the indices of
                the features that form k-nearest neighbours.
        """
        assert k > 0

        B, npoint = center_xyz.shape[:2]
        N = xyz.shape[1]

        if not transposed:
            xyz = xyz.transpose(2, 1).contiguous()
            center_xyz = center_xyz.transpose(2, 1).contiguous()
        assert center_xyz.is_contiguous()
        assert xyz.is_contiguous()

        center_xyz_device = center_xyz.get_device()
        assert center_xyz_device == xyz.get_device(), \
            'center_xyz and xyz should be put on the same device'
        if torch.cuda.current_device() != center_xyz_device:
            torch.cuda.set_device(center_xyz_device)

        idx = center_xyz.new_zeros((B, k, npoint)).long()
        # The CUDA op works on one batch element at a time: xyz[bi] is (3, N),
        # center_xyz[bi] is (3, npoint) and idx[bi] is (k, npoint).
        for bi in range(B):
            knn_ext.knn_wrapper(xyz[bi], N, center_xyz[bi], npoint, idx[bi],
                                k)

        ctx.mark_non_differentiable(idx)
        # The extension returns 1-based indices, so shift them to 0-based.
        idx -= 1
        return idx

    @staticmethod
    def backward(ctx, a=None):
        return None, None


knn = KNN.apply
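A minimal usage sketch (not part of the commit), assuming the knn_ext extension has been built; shapes follow the docstring above and the toy tensors are illustrative only:

import torch
from mmdet3d.ops import knn

# Hypothetical inputs: a batch of 2 point clouds with 128 points each and
# 16 query centers per cloud, in the default transposed == False layout.
xyz = torch.rand(2, 128, 3).cuda()
center_xyz = torch.rand(2, 16, 3).cuda()

# idx has shape (B, k, npoint) = (2, 8, 16) and holds 0-based indices into xyz.
idx = knn(8, xyz, center_xyz)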
// Modified from https://github.com/unlimblue/KNN_CUDA
#include <vector>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_TYPE(x, t) AT_ASSERTM(x.dtype() == t, #x " must be " #t)
#define CHECK_CUDA(x) AT_ASSERTM(x.device().type() == at::Device::Type::CUDA, #x " must be on CUDA")
#define CHECK_INPUT(x, t) CHECK_CONTIGUOUS(x); CHECK_TYPE(x, t); CHECK_CUDA(x)
void knn_kernels_launcher(
  const float* ref_dev,
  int ref_nb,
  const float* query_dev,
  int query_nb,
  int dim,
  int k,
  float* dist_dev,
  long* ind_dev,
  cudaStream_t stream
);

void knn_wrapper(
  at::Tensor & ref,
  int ref_nb,
  at::Tensor & query,
  int query_nb,
  at::Tensor & ind,
  const int k
) {
  CHECK_INPUT(ref, at::kFloat);
  CHECK_INPUT(query, at::kFloat);

  const float * ref_dev = ref.data_ptr<float>();
  const float * query_dev = query.data_ptr<float>();
  int dim = query.size(0);
  // Scratch buffer holding all pairwise distances; sorted in place on device.
  auto dist = at::empty({ref_nb, query_nb}, query.options().dtype(at::kFloat));
  float * dist_dev = dist.data_ptr<float>();
  long * ind_dev = ind.data_ptr<long>();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  knn_kernels_launcher(
    ref_dev,
    ref_nb,
    query_dev,
    query_nb,
    dim,
    k,
    dist_dev,
    ind_dev,
    stream
  );
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("knn_wrapper", &knn_wrapper, "knn_wrapper");
}
/** Modified from https://github.com/unlimblue/KNN_CUDA
* which is the modified version of knn-CUDA
* from https://github.com/vincentfpgarcia/kNN-CUDA
* Last modified by Christopher B. Choy <chrischoy@ai.stanford.edu> 12/23/2016
* vincentfpgarcia wrote the original cuda code, Christopher modified it and
* set it up for pytorch 0.4, and unlimblue updated it to pytorch >= 1.0
*/
// Includes
#include <cstdio>
#include "cuda.h"
// Constants used by the program
#define BLOCK_DIM 16
#define DEBUG 0
/**
* Computes the distance between two matrix A (reference points) and
* B (query points) containing respectively wA and wB points.
*
* @param A pointer on the matrix A
* @param wA width of the matrix A = number of points in A
* @param B pointer on the matrix B
* @param wB width of the matrix B = number of points in B
* @param dim dimension of points = height of matrices A and B
* @param AB pointer on the matrix containing the wA*wB distances computed
*/
__global__ void cuComputeDistanceGlobal(const float* A, int wA,
const float* B, int wB, int dim, float* AB){
  // Declaration of the shared memory arrays used to store sub-matrices of A and B
  __shared__ float shared_A[BLOCK_DIM][BLOCK_DIM];
  __shared__ float shared_B[BLOCK_DIM][BLOCK_DIM];

  // Sub-matrix of A (begin, step, end) and sub-matrix of B (begin, step)
  __shared__ int begin_A;
  __shared__ int begin_B;
  __shared__ int step_A;
  __shared__ int step_B;
  __shared__ int end_A;

  // Thread index
  int tx = threadIdx.x;
  int ty = threadIdx.y;

  // Other variables
  float tmp;
  float ssd = 0;

  // Loop parameters
  begin_A = BLOCK_DIM * blockIdx.y;
  begin_B = BLOCK_DIM * blockIdx.x;
  step_A = BLOCK_DIM * wA;
  step_B = BLOCK_DIM * wB;
  end_A = begin_A + (dim - 1) * wA;

  // Conditions
  int cond0 = (begin_A + tx < wA);  // used to write in shared memory
  int cond1 = (begin_B + tx < wB);  // used to write in shared memory, to compute and to write in the output matrix
  int cond2 = (begin_A + ty < wA);  // used to compute and to write in the output matrix

  // Loop over all the sub-matrices of A and B required to compute the block sub-matrix
  for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) {
    // Load the matrices from device memory to shared memory; each thread loads one element of each matrix
    if (a / wA + ty < dim) {
      shared_A[ty][tx] = (cond0) ? A[a + wA * ty + tx] : 0;
      shared_B[ty][tx] = (cond1) ? B[b + wB * ty + tx] : 0;
    } else {
      shared_A[ty][tx] = 0;
      shared_B[ty][tx] = 0;
    }

    // Synchronize to make sure the matrices are loaded
    __syncthreads();

    // Compute the difference between the two matrices; each thread computes one element of the block sub-matrix
    if (cond2 && cond1) {
      for (int k = 0; k < BLOCK_DIM; ++k) {
        tmp = shared_A[k][ty] - shared_B[k][tx];
        ssd += tmp * tmp;
      }
    }

    // Synchronize to make sure the preceding computation is done before loading two new sub-matrices of A and B in the next iteration
    __syncthreads();
  }

  // Write the block sub-matrix to device memory; each thread writes one element
  if (cond2 && cond1)
    AB[(begin_A + ty) * wB + begin_B + tx] = ssd;
}
/**
 * Gathers the k smallest distances of each column of the distance matrix at the top of that column.
*
* @param dist distance matrix
* @param ind index matrix
* @param width width of the distance matrix and of the index matrix
* @param height height of the distance matrix and of the index matrix
* @param k number of neighbors to consider
*/
__global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){
  // Variables
  int l, i, j;
  float *p_dist;
  long *p_ind;
  float curr_dist, max_dist;
  long curr_row, max_row;
  unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;

  if (xIndex < width) {
    // Pointer shift, initialization, and max value
    p_dist = dist + xIndex;
    p_ind = ind + xIndex;
    max_dist = p_dist[0];
    p_ind[0] = 1;

    // Part 1: sort the k first elements
    for (l = 1; l < k; l++) {
      curr_row = l * width;
      curr_dist = p_dist[curr_row];
      if (curr_dist < max_dist) {
        i = l - 1;
        for (int a = 0; a < l - 1; a++) {
          if (p_dist[a * width] > curr_dist) {
            i = a;
            break;
          }
        }
        for (j = l; j > i; j--) {
          p_dist[j * width] = p_dist[(j - 1) * width];
          p_ind[j * width] = p_ind[(j - 1) * width];
        }
        p_dist[i * width] = curr_dist;
        p_ind[i * width] = l + 1;
      } else {
        p_ind[l * width] = l + 1;
      }
      max_dist = p_dist[curr_row];
    }

    // Part 2: insert the remaining elements among the k first lines
    max_row = (k - 1) * width;
    for (l = k; l < height; l++) {
      curr_dist = p_dist[l * width];
      if (curr_dist < max_dist) {
        i = k - 1;
        for (int a = 0; a < k - 1; a++) {
          if (p_dist[a * width] > curr_dist) {
            i = a;
            break;
          }
        }
        for (j = k - 1; j > i; j--) {
          p_dist[j * width] = p_dist[(j - 1) * width];
          p_ind[j * width] = p_ind[(j - 1) * width];
        }
        p_dist[i * width] = curr_dist;
        p_ind[i * width] = l + 1;
        max_dist = p_dist[max_row];
      }
    }
  }
}
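// Note on the index convention above: cuInsertionSort writes 1-based row
// numbers into ind (p_ind[0] = 1 and p_ind[...] = l + 1), following the
// original knn-CUDA code. The Python wrapper (KNN.forward) compensates with
// `idx -= 1` so that callers see 0-based indices.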
/**
 * Computes the square root of the k first lines (the k * width first
 * elements) of the distance matrix.
*
* @param dist distance matrix
* @param width width of the distance matrix
* @param k number of neighbors to consider
*/
__global__ void cuParallelSqrt(float *dist, int width, int k){
  unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
  unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y;
  if (xIndex < width && yIndex < k)
    dist[yIndex * width + xIndex] = sqrt(dist[yIndex * width + xIndex]);
}
void debug(float * dist_dev, long * ind_dev, const int query_nb, const int k){
  float* dist_host = new float[query_nb * k];
  long* idx_host = new long[query_nb * k];

  // Memory copy of output from device to host
  cudaMemcpy(dist_host, dist_dev, query_nb * k * sizeof(float),
             cudaMemcpyDeviceToHost);
  cudaMemcpy(idx_host, ind_dev, query_nb * k * sizeof(long),
             cudaMemcpyDeviceToHost);

  int i, j;
  for (i = 0; i < k; i++) {
    for (j = 0; j < query_nb; j++) {
      if (j % 8 == 0)
        printf("/\n");
      printf("%f ", sqrt(dist_host[i * query_nb + j]));
    }
    printf("\n");
  }

  delete[] dist_host;
  delete[] idx_host;
}
//-----------------------------------------------------------------------------------------------//
// K-th NEAREST NEIGHBORS //
//-----------------------------------------------------------------------------------------------//
/**
 * K nearest neighbor algorithm:
 * - Compute all pairwise distances between reference and query points
 * - Partially sort each column of the distance matrix to keep the k smallest
 *   distances and their indices
 * - Take the square root of the k first lines of the distance matrix
 *
 * @param ref_dev   reference points; device pointer to a (dim, ref_nb) matrix
 * @param ref_nb    number of reference points; width of the matrix
 * @param query_dev query points; device pointer to a (dim, query_nb) matrix
 * @param query_nb  number of query points; width of the matrix
 * @param dim       dimension of points; height of the matrices
 * @param k         number of neighbors to consider
 * @param dist_dev  distances to the k nearest neighbors; device scratch matrix
 * @param ind_dev   1-based indices of the k nearest neighbors; device matrix
 * @param stream    CUDA stream on which the kernels are launched
 */
void knn_kernels_launcher(const float* ref_dev, int ref_nb, const float* query_dev, int query_nb,
int dim, int k, float* dist_dev, long* ind_dev, cudaStream_t stream){
  // Grids and threads
  dim3 g_16x16(query_nb / BLOCK_DIM, ref_nb / BLOCK_DIM, 1);
  dim3 t_16x16(BLOCK_DIM, BLOCK_DIM, 1);
  if (query_nb % BLOCK_DIM != 0) g_16x16.x += 1;
  if (ref_nb % BLOCK_DIM != 0) g_16x16.y += 1;

  dim3 g_256x1(query_nb / 256, 1, 1);
  dim3 t_256x1(256, 1, 1);
  if (query_nb % 256 != 0) g_256x1.x += 1;

  dim3 g_k_16x16(query_nb / BLOCK_DIM, k / BLOCK_DIM, 1);
  dim3 t_k_16x16(BLOCK_DIM, BLOCK_DIM, 1);
  if (query_nb % BLOCK_DIM != 0) g_k_16x16.x += 1;
  if (k % BLOCK_DIM != 0) g_k_16x16.y += 1;

  // Kernel 1: Compute all the distances
  cuComputeDistanceGlobal<<<g_16x16, t_16x16, 0, stream>>>(
      ref_dev, ref_nb, query_dev, query_nb, dim, dist_dev);
#if DEBUG
  printf("Pre insertionSort\n");
  debug(dist_dev, ind_dev, query_nb, k);
#endif

  // Kernel 2: Sort each column
  cuInsertionSort<<<g_256x1, t_256x1, 0, stream>>>(
      dist_dev, ind_dev, query_nb, ref_nb, k);
#if DEBUG
  printf("Post insertionSort\n");
  debug(dist_dev, ind_dev, query_nb, k);
#endif

  // Kernel 3: Compute square root of the k first elements
  cuParallelSqrt<<<g_k_16x16, t_k_16x16, 0, stream>>>(dist_dev, query_nb, k);
}
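For readers who prefer to reason in PyTorch, the following sketch (not part of the commit; the function name is illustrative) reproduces what the three kernels compute for a single batch element, up to tie-breaking order among equal distances:

import torch


def knn_reference(ref, query, k):
    """Single-batch reference for the CUDA pipeline above.

    ref:   (3, ref_nb) reference points (xyz[bi] in KNN.forward).
    query: (3, query_nb) query centers (center_xyz[bi] in KNN.forward).
    Returns (k, query_nb) distances and 0-based indices.
    """
    # Kernel 1: all pairwise squared distances,
    # dist[i, j] = ||ref[:, i] - query[:, j]||^2.
    dist = ((ref.unsqueeze(2) - query.unsqueeze(1))**2).sum(dim=0)
    # Kernel 2: keep the k smallest entries of each column; the CUDA kernel
    # does this with a partial insertion sort and stores 1-based indices.
    topk_dist, topk_idx = dist.topk(k, dim=0, largest=False)
    # Kernel 3: square root of the k first lines of the distance matrix.
    return topk_dist.sqrt(), topk_idx  # topk_idx is already 0-based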
@@ -221,6 +221,11 @@ if __name__ == '__main__':
            module='mmdet3d.ops.ball_query',
            sources=['src/ball_query.cpp'],
            sources_cuda=['src/ball_query_cuda.cu']),
        make_cuda_ext(
            name='knn_ext',
            module='mmdet3d.ops.knn',
            sources=['src/knn.cpp'],
            sources_cuda=['src/knn_cuda.cu']),
        make_cuda_ext(
            name='group_points_ext',
            module='mmdet3d.ops.group_points',
...
@@ -3,7 +3,7 @@ import torch
from mmdet3d.ops import (ball_query, furthest_point_sample,
                         furthest_point_sample_with_dist, gather_points,
                         grouping_operation, knn, three_interpolate, three_nn)


def test_fps():
@@ -73,6 +73,51 @@ def test_ball_query():
    assert torch.all(idx == expected_idx)

def test_knn():
    if not torch.cuda.is_available():
        pytest.skip()
    new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625],
                             [-2.2769, 2.7817, -0.2334],
                             [-0.4003, 2.4666, -0.5116],
                             [-0.0740, 1.3147, -1.3625],
                             [-0.0740, 1.3147, -1.3625]],
                            [[-2.0289, 2.4952, -0.1708],
                             [-2.0668, 6.0278, -0.4875],
                             [0.4066, 1.4211, -0.2947],
                             [-2.0289, 2.4952, -0.1708],
                             [-2.0289, 2.4952, -0.1708]]]).cuda()
    xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
                         [-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466],
                         [-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626],
                         [-2.2769, 2.7817, -0.2334], [-0.2822, 1.3192, -1.3645],
                         [0.1533, 1.5024, -1.0432], [0.4917, 1.1529, -1.3496]],
                        [[-2.0289, 2.4952, -0.1708], [-0.7188, 0.9956, -0.5096],
                         [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
                         [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
                         [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
                         [0.3220, 1.4447, 0.3548],
                         [-0.9744, 2.3856, -1.2000]]]).cuda()

    # Compare the CUDA op against a brute-force distance computation.
    idx = knn(5, xyz, new_xyz)
    new_xyz_ = new_xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
    xyz_ = xyz.unsqueeze(1).repeat(1, new_xyz.shape[1], 1, 1)
    dist = ((new_xyz_ - xyz_) * (new_xyz_ - xyz_)).sum(-1)
    expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
    assert torch.all(idx == expected_idx)

    # Querying the point set against itself should also match.
    idx = knn(5, xyz, xyz)
    xyz_ = xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
    xyz__ = xyz.unsqueeze(1).repeat(1, xyz.shape[1], 1, 1)
    dist = ((xyz_ - xyz__) * (xyz_ - xyz__)).sum(-1)
    expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
    assert torch.all(idx == expected_idx)


def test_grouping_points():
    if not torch.cuda.is_available():
        pytest.skip()
...