Unverified Commit 53271e3d authored by Ziyi Wu, committed by GitHub

[Enhance] Faster implementation of KNN (#586)

* add knn_heap gpu op support

* add unit test

* remove old knn & rename knn_heap to knn

* interface consistency
parent b3e792bc
@@ -5,7 +5,9 @@ from . import knn_ext
 class KNN(Function):
-    """KNN (CUDA).
+    r"""KNN (CUDA) based on heap data structure.
+
+    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
+    scene_seg/lib/pointops/src/knnquery_heap>`_.
 
     Find k-nearest points.
     """
@@ -14,9 +16,9 @@ class KNN(Function):
     def forward(ctx,
                 k: int,
                 xyz: torch.Tensor,
-                center_xyz: torch.Tensor,
+                center_xyz: torch.Tensor = None,
                 transposed: bool = False) -> torch.Tensor:
-        """forward.
+        """Forward.
 
         Args:
             k (int): number of nearest neighbors.
@@ -34,15 +36,15 @@ class KNN(Function):
         """
         assert k > 0
-        if not transposed:
+        if center_xyz is None:
+            center_xyz = xyz
+        if transposed:
             xyz = xyz.transpose(2, 1).contiguous()
             center_xyz = center_xyz.transpose(2, 1).contiguous()
-        B, _, npoint = center_xyz.shape
-        N = xyz.shape[2]
-        assert center_xyz.is_contiguous()
-        assert xyz.is_contiguous()
+        assert xyz.is_contiguous()  # [B, N, 3]
+        assert center_xyz.is_contiguous()  # [B, npoint, 3]
 
         center_xyz_device = center_xyz.get_device()
         assert center_xyz_device == xyz.get_device(), \
@@ -50,20 +52,21 @@ class KNN(Function):
         if torch.cuda.current_device() != center_xyz_device:
             torch.cuda.set_device(center_xyz_device)
 
-        idx = center_xyz.new_zeros((B, k, npoint)).long()
-
-        for bi in range(B):
-            knn_ext.knn_wrapper(xyz[bi], N, center_xyz[bi], npoint, idx[bi], k)
+        B, npoint, _ = center_xyz.shape
+        N = xyz.shape[1]
+        idx = center_xyz.new_zeros((B, npoint, k)).int()
+        dist2 = center_xyz.new_zeros((B, npoint, k)).float()
+        knn_ext.knn_wrapper(B, N, npoint, k, xyz, center_xyz, idx, dist2)
+        # idx shape to [B, k, npoint]
+        idx = idx.transpose(2, 1).contiguous()
 
         ctx.mark_non_differentiable(idx)
-        idx -= 1
 
         return idx
 
     @staticmethod
     def backward(ctx, a=None):
-        return None, None
+        return None, None, None
 
 knn = KNN.apply
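
For context, a minimal usage sketch of the updated interface. The import path and tensor shapes here are assumptions read off the `forward` signature above, not part of this commit:

import torch
from mmdet3d.ops import knn  # import path assumed; `knn = KNN.apply`

xyz = torch.rand(2, 1024, 3).cuda()        # [B, N, 3] source points
center_xyz = torch.rand(2, 256, 3).cuda()  # [B, npoint, 3] query centers
idx = knn(16, xyz, center_xyz)             # [B, 16, npoint] int32 neighbor indices
idx_self = knn(16, xyz)                    # center_xyz now defaults to xyz itself
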
-// Modified from https://github.com/unlimblue/KNN_CUDA
+// Modified from https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
 
-#include <vector>
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
+#include <torch/serialize/tensor.h>
+#include <torch/extension.h>
+#include <vector>
+#include <THC/THC.h>
+#include <ATen/cuda/CUDAContext.h>
+
+extern THCState *state;
 
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_TYPE(x, t) AT_ASSERTM(x.dtype() == t, #x " must be " #t)
-#define CHECK_CUDA(x) AT_ASSERTM(x.device().type() == at::Device::Type::CUDA, #x " must be on CUDA")
-#define CHECK_INPUT(x, t) CHECK_CONTIGUOUS(x); CHECK_TYPE(x, t); CHECK_CUDA(x)
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x, " must be a CUDAtensor ")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
 
-void knn_kernels_launcher(
-    const float* ref_dev,
-    int ref_nb,
-    const float* query_dev,
-    int query_nb,
-    int dim,
-    int k,
-    float* dist_dev,
-    long* ind_dev,
-    cudaStream_t stream
-);
+void knn_kernel_launcher(
+    int b,
+    int n,
+    int m,
+    int nsample,
+    const float *xyz,
+    const float *new_xyz,
+    int *idx,
+    float *dist2,
+    cudaStream_t stream
+);
 
-// std::vector<at::Tensor> knn_wrapper(
-void knn_wrapper(
-    at::Tensor & ref,
-    int ref_nb,
-    at::Tensor & query,
-    int query_nb,
-    at::Tensor & ind,
-    const int k
-) {
-    CHECK_INPUT(ref, at::kFloat);
-    CHECK_INPUT(query, at::kFloat);
-    const float * ref_dev = ref.data_ptr<float>();
-    const float * query_dev = query.data_ptr<float>();
-    int dim = query.size(0);
-    auto dist = at::empty({ref_nb, query_nb}, query.options().dtype(at::kFloat));
-    float * dist_dev = dist.data_ptr<float>();
-    long * ind_dev = ind.data_ptr<long>();
-    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    knn_kernels_launcher(
-        ref_dev,
-        ref_nb,
-        query_dev,
-        query_nb,
-        dim,
-        k,
-        dist_dev,
-        ind_dev,
-        stream
-    );
-}
+void knn_wrapper(int b, int n, int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor)
+{
+    CHECK_INPUT(new_xyz_tensor);
+    CHECK_INPUT(xyz_tensor);
+
+    const float *new_xyz = new_xyz_tensor.data_ptr<float>();
+    const float *xyz = xyz_tensor.data_ptr<float>();
+    int *idx = idx_tensor.data_ptr<int>();
+    float *dist2 = dist2_tensor.data_ptr<float>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    knn_kernel_launcher(b, n, m, nsample, xyz, new_xyz, idx, dist2, stream);
+}
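
For reference, the Python `forward` above preallocates both output tensors and hands them to this wrapper in a single batched call; a minimal sketch of that calling convention (illustrative only, with the module path assumed from `from . import knn_ext`):

import torch
from mmdet3d.ops.knn import knn_ext  # compiled extension; import path assumed

B, N, npoint, k = 2, 1024, 256, 16
xyz = torch.rand(B, N, 3).cuda().contiguous()              # reference points
center_xyz = torch.rand(B, npoint, 3).cuda().contiguous()  # query points
idx = center_xyz.new_zeros((B, npoint, k)).int()           # filled in-place
dist2 = center_xyz.new_zeros((B, npoint, k)).float()       # squared distances, in-place
knn_ext.knn_wrapper(B, N, npoint, k, xyz, center_xyz, idx, dist2)
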
-/** Modified from https://github.com/unlimblue/KNN_CUDA
- * which is the modified version of knn-CUDA
- * from https://github.com/vincentfpgarcia/kNN-CUDA
- * Last modified by Christopher B. Choy <chrischoy@ai.stanford.edu> 12/23/2016
- * vincentfpgarcia wrote the original cuda code, Christopher modified it and
- * set it up for pytorch 0.4, and unlimblue updated it to pytorch >= 1.0
- */
-
-// Includes
-#include <cstdio>
-#include "cuda.h"
-
-// Constants used by the program
-#define BLOCK_DIM 16
-#define DEBUG 0
-
-/**
- * Computes the distance between two matrix A (reference points) and
- * B (query points) containing respectively wA and wB points.
- *
- * @param A pointer on the matrix A
- * @param wA width of the matrix A = number of points in A
- * @param B pointer on the matrix B
- * @param wB width of the matrix B = number of points in B
- * @param dim dimension of points = height of matrices A and B
- * @param AB pointer on the matrix containing the wA*wB distances computed
- */
-__global__ void cuComputeDistanceGlobal(const float* A, int wA,
-    const float* B, int wB, int dim, float* AB){
-  // Declaration of the shared memory arrays As and Bs used to store the sub-matrix of A and B
-  __shared__ float shared_A[BLOCK_DIM][BLOCK_DIM];
-  __shared__ float shared_B[BLOCK_DIM][BLOCK_DIM];
-
-  // Sub-matrix of A (begin, step, end) and Sub-matrix of B (begin, step)
-  __shared__ int begin_A;
-  __shared__ int begin_B;
-  __shared__ int step_A;
-  __shared__ int step_B;
-  __shared__ int end_A;
-
-  // Thread index
-  int tx = threadIdx.x;
-  int ty = threadIdx.y;
-
-  // Other variables
-  float tmp;
-  float ssd = 0;
-
-  // Loop parameters
-  begin_A = BLOCK_DIM * blockIdx.y;
-  begin_B = BLOCK_DIM * blockIdx.x;
-  step_A = BLOCK_DIM * wA;
-  step_B = BLOCK_DIM * wB;
-  end_A = begin_A + (dim-1) * wA;
-
-  // Conditions
-  int cond0 = (begin_A + tx < wA); // used to write in shared memory
-  int cond1 = (begin_B + tx < wB); // used to write in shared memory & to computations and to write in output matrix
-  int cond2 = (begin_A + ty < wA); // used to computations and to write in output matrix
-
-  // Loop over all the sub-matrices of A and B required to compute the block sub-matrix
-  for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) {
-    // Load the matrices from device memory to shared memory; each thread loads one element of each matrix
-    if (a/wA + ty < dim){
-      shared_A[ty][tx] = (cond0)? A[a + wA * ty + tx] : 0;
-      shared_B[ty][tx] = (cond1)? B[b + wB * ty + tx] : 0;
-    }
-    else{
-      shared_A[ty][tx] = 0;
-      shared_B[ty][tx] = 0;
-    }
-
-    // Synchronize to make sure the matrices are loaded
-    __syncthreads();
-
-    // Compute the difference between the two matrixes; each thread computes one element of the block sub-matrix
-    if (cond2 && cond1){
-      for (int k = 0; k < BLOCK_DIM; ++k){
-        tmp = shared_A[k][ty] - shared_B[k][tx];
-        ssd += tmp*tmp;
-      }
-    }
-
-    // Synchronize to make sure that the preceding computation is done before loading two new sub-matrices of A and B in the next iteration
-    __syncthreads();
-  }
-
-  // Write the block sub-matrix to device memory; each thread writes one element
-  if (cond2 && cond1)
-    AB[(begin_A + ty) * wB + begin_B + tx] = ssd;
-}
-
-/**
- * Gathers k-th smallest distances for each column of the distance matrix in the top.
- *
- * @param dist distance matrix
- * @param ind index matrix
- * @param width width of the distance matrix and of the index matrix
- * @param height height of the distance matrix and of the index matrix
- * @param k number of neighbors to consider
- */
-__global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){
-  // Variables
-  int l, i, j;
-  float *p_dist;
-  long *p_ind;
-  float curr_dist, max_dist;
-  long curr_row, max_row;
-  unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
-
-  if (xIndex<width){
-    // Pointer shift, initialization, and max value
-    p_dist = dist + xIndex;
-    p_ind = ind + xIndex;
-    max_dist = p_dist[0];
-    p_ind[0] = 1;
-
-    // Part 1 : sort kth firt elementZ
-    for (l=1; l<k; l++){
-      curr_row = l * width;
-      curr_dist = p_dist[curr_row];
-      if (curr_dist<max_dist){
-        i=l-1;
-        for (int a=0; a<l-1; a++){
-          if (p_dist[a*width]>curr_dist){
-            i=a;
-            break;
-          }
-        }
-        for (j=l; j>i; j--){
-          p_dist[j*width] = p_dist[(j-1)*width];
-          p_ind[j*width] = p_ind[(j-1)*width];
-        }
-        p_dist[i*width] = curr_dist;
-        p_ind[i*width] = l + 1;
-      } else {
-        p_ind[l*width] = l + 1;
-      }
-      max_dist = p_dist[curr_row];
-    }
-
-    // Part 2 : insert element in the k-th first lines
-    max_row = (k-1)*width;
-    for (l=k; l<height; l++){
-      curr_dist = p_dist[l*width];
-      if (curr_dist<max_dist){
-        i=k-1;
-        for (int a=0; a<k-1; a++){
-          if (p_dist[a*width]>curr_dist){
-            i=a;
-            break;
-          }
-        }
-        for (j=k-1; j>i; j--){
-          p_dist[j*width] = p_dist[(j-1)*width];
-          p_ind[j*width] = p_ind[(j-1)*width];
-        }
-        p_dist[i*width] = curr_dist;
-        p_ind[i*width] = l + 1;
-        max_dist = p_dist[max_row];
-      }
-    }
-  }
-}
-
-/**
- * Computes the square root of the first line (width-th first element)
- * of the distance matrix.
- *
- * @param dist distance matrix
- * @param width width of the distance matrix
- * @param k number of neighbors to consider
- */
-__global__ void cuParallelSqrt(float *dist, int width, int k){
-  unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
-  unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y;
-  if (xIndex<width && yIndex<k)
-    dist[yIndex*width + xIndex] = sqrt(dist[yIndex*width + xIndex]);
-}
-
-void debug(float * dist_dev, long * ind_dev, const int query_nb, const int k){
-  float* dist_host = new float[query_nb * k];
-  long* idx_host = new long[query_nb * k];
-
-  // Memory copy of output from device to host
-  cudaMemcpy(dist_host, dist_dev,
-    query_nb * k * sizeof(float), cudaMemcpyDeviceToHost);
-  cudaMemcpy(idx_host, ind_dev,
-    query_nb * k * sizeof(long), cudaMemcpyDeviceToHost);
-
-  int i, j;
-  for(i = 0; i < k; i++){
-    for (j = 0; j < query_nb; j++) {
-      if (j % 8 == 0)
-        printf("/\n");
-      printf("%f ", sqrt(dist_host[i*query_nb + j]));
-    }
-    printf("\n");
-  }
-}
-
-//-----------------------------------------------------------------------------------------------//
-//                                 K-th NEAREST NEIGHBORS                                         //
-//-----------------------------------------------------------------------------------------------//
-
-/**
- * K nearest neighbor algorithm
- * - Initialize CUDA
- * - Allocate device memory
- * - Copy point sets (reference and query points) from host to device memory
- * - Compute the distances + indexes to the k nearest neighbors for each query point
- * - Copy distances from device to host memory
- *
- * @param ref_host reference points ; pointer to linear matrix
- * @param ref_nb number of reference points ; width of the matrix
- * @param query_host query points ; pointer to linear matrix
- * @param query_nb number of query points ; width of the matrix
- * @param dim dimension of points ; height of the matrices
- * @param k number of neighbor to consider
- * @param dist_host distances to k nearest neighbors ; pointer to linear matrix
- * @param dist_host indexes of the k nearest neighbors ; pointer to linear matrix
- */
-void knn_kernels_launcher(const float* ref_dev, int ref_nb, const float* query_dev, int query_nb,
-    int dim, int k, float* dist_dev, long* ind_dev, cudaStream_t stream){
-  // Grids ans threads
-  dim3 g_16x16(query_nb / BLOCK_DIM, ref_nb / BLOCK_DIM, 1);
-  dim3 t_16x16(BLOCK_DIM, BLOCK_DIM, 1);
-  if (query_nb % BLOCK_DIM != 0) g_16x16.x += 1;
-  if (ref_nb % BLOCK_DIM != 0) g_16x16.y += 1;
-  //
-  dim3 g_256x1(query_nb / 256, 1, 1);
-  dim3 t_256x1(256, 1, 1);
-  if (query_nb%256 != 0) g_256x1.x += 1;
-  dim3 g_k_16x16(query_nb / BLOCK_DIM, k / BLOCK_DIM, 1);
-  dim3 t_k_16x16(BLOCK_DIM, BLOCK_DIM, 1);
-  if (query_nb % BLOCK_DIM != 0) g_k_16x16.x += 1;
-  if (k % BLOCK_DIM != 0) g_k_16x16.y += 1;
-
-  // Kernel 1: Compute all the distances
-  cuComputeDistanceGlobal<<<g_16x16, t_16x16, 0, stream>>>(ref_dev, ref_nb,
-    query_dev, query_nb, dim, dist_dev);
-#if DEBUG
-  printf("Pre insertionSort\n");
-  debug(dist_dev, ind_dev, query_nb, k);
-#endif
-
-  // Kernel 2: Sort each column
-  cuInsertionSort<<<g_256x1, t_256x1, 0, stream>>>(dist_dev, ind_dev, query_nb, ref_nb, k);
-#if DEBUG
-  printf("Post insertionSort\n");
-  debug(dist_dev, ind_dev, query_nb, k);
-#endif
-
-  // Kernel 3: Compute square root of k first elements
-  cuParallelSqrt<<<g_k_16x16,t_k_16x16, 0, stream>>>(dist_dev, query_nb, k);
-}
+// Modified from https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
+
+#include <cmath>
+#include <cstdio>
+
+#define THREADS_PER_BLOCK 256
+#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
+
+__device__ void swap_float(float *x, float *y)
+{
+    float tmp = *x;
+    *x = *y;
+    *y = tmp;
+}
+
+__device__ void swap_int(int *x, int *y)
+{
+    int tmp = *x;
+    *x = *y;
+    *y = tmp;
+}
+
+__device__ void reheap(float *dist, int *idx, int k)
+{
+    int root = 0;
+    int child = root * 2 + 1;
+    while (child < k)
+    {
+        if(child + 1 < k && dist[child+1] > dist[child])
+            child++;
+        if(dist[root] > dist[child])
+            return;
+        swap_float(&dist[root], &dist[child]);
+        swap_int(&idx[root], &idx[child]);
+        root = child;
+        child = root * 2 + 1;
+    }
+}
+
+__device__ void heap_sort(float *dist, int *idx, int k)
+{
+    int i;
+    for (i = k - 1; i > 0; i--)
+    {
+        swap_float(&dist[0], &dist[i]);
+        swap_int(&idx[0], &idx[i]);
+        reheap(dist, idx, i);
+    }
+}
+
+// input: xyz (b, n, 3) new_xyz (b, m, 3)
+// output: idx (b, m, nsample) dist2 (b, m, nsample)
+__global__ void knn_kernel(int b, int n, int m, int nsample, const float *__restrict__ xyz, const float *__restrict__ new_xyz, int *__restrict__ idx, float *__restrict__ dist2) {
+    int bs_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || pt_idx >= m) return;
+
+    new_xyz += bs_idx * m * 3 + pt_idx * 3;
+    xyz += bs_idx * n * 3;
+    idx += bs_idx * m * nsample + pt_idx * nsample;
+    dist2 += bs_idx * m * nsample + pt_idx * nsample;
+
+    float new_x = new_xyz[0];
+    float new_y = new_xyz[1];
+    float new_z = new_xyz[2];
+
+    float best_dist[100];
+    int best_idx[100];
+    for(int i = 0; i < nsample; i++){
+        best_dist[i] = 1e10;
+        best_idx[i] = 0;
+    }
+    for(int i = 0; i < n; i++){
+        float x = xyz[i * 3 + 0];
+        float y = xyz[i * 3 + 1];
+        float z = xyz[i * 3 + 2];
+        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
+        if (d2 < best_dist[0]){
+            best_dist[0] = d2;
+            best_idx[0] = i;
+            reheap(best_dist, best_idx, nsample);
+        }
+    }
+    heap_sort(best_dist, best_idx, nsample);
+    for(int i = 0; i < nsample; i++){
+        idx[i] = best_idx[i];
+        dist2[i] = best_dist[i];
+    }
+}
+
+void knn_kernel_launcher(int b, int n, int m, int nsample, const float *xyz, const float *new_xyz, int *idx, float *dist2, cudaStream_t stream) {
+    // param new_xyz: (B, m, 3)
+    // param xyz: (B, n, 3)
+    // param idx: (B, m, nsample)
+
+    cudaError_t err;
+
+    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    knn_kernel<<<blocks, threads, 0, stream>>>(b, n, m, nsample, xyz, new_xyz, idx, dist2);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
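
To make the heap logic above easier to verify, here is a plain-Python transcription of `reheap`/`heap_sort` and the per-query loop from `knn_kernel`; it is a reference sketch for checking the kernel on small inputs, not part of the commit:

def reheap(dist, idx, k):
    # Sift the root down to restore the max-heap property on dist[:k].
    root, child = 0, 1
    while child < k:
        if child + 1 < k and dist[child + 1] > dist[child]:
            child += 1
        if dist[root] > dist[child]:
            return
        dist[root], dist[child] = dist[child], dist[root]
        idx[root], idx[child] = idx[child], idx[root]
        root, child = child, 2 * child + 1

def knn_one_query(query, points, k):
    # Heap-based k-NN for a single query point, mirroring knn_kernel:
    # keep the k smallest squared distances in a max-heap whose root is
    # the current k-th best, then heap-sort the result ascending.
    best_dist, best_idx = [1e10] * k, [0] * k
    for i, p in enumerate(points):
        d2 = sum((q - c) ** 2 for q, c in zip(query, p))
        if d2 < best_dist[0]:  # beats the current worst of the k best
            best_dist[0], best_idx[0] = d2, i
            reheap(best_dist, best_idx, k)
    for i in range(k - 1, 0, -1):  # heap_sort
        best_dist[0], best_dist[i] = best_dist[i], best_dist[0]
        best_idx[0], best_idx[i] = best_idx[i], best_idx[0]
        reheap(best_dist, best_idx, i)
    return best_idx, best_dist

For example, `knn_one_query((0, 0, 0), [(1, 0, 0), (0, 2, 0), (0, 0, 0.5)], 2)` returns `([2, 0], [0.25, 1.0])`. Because each candidate point only compares against the heap root, the per-point cost is O(log k) instead of the O(k) of insertion sort, which is where the speedup in this commit comes from.
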