"git@developer.sourcefind.cn:change/sglang.git" did not exist on "c3d2ad4ee6881fbcf0e68a2b159f9c91a7a1451a"
Unverified Commit 53271e3d authored by Ziyi Wu's avatar Ziyi Wu Committed by GitHub
Browse files

[Enhance] Faster implementation of KNN (#586)

* add knn_heap gpu op support

* add unit test

* remove old knn & rename knn_heap to knn

* interface consistency
parent b3e792bc
......@@ -5,7 +5,9 @@ from . import knn_ext
class KNN(Function):
"""KNN (CUDA).
r"""KNN (CUDA) based on heap data structure.
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
scene_seg/lib/pointops/src/knnquery_heap>`_.
Find k-nearest points.
"""
......@@ -14,9 +16,9 @@ class KNN(Function):
def forward(ctx,
k: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor,
center_xyz: torch.Tensor = None,
transposed: bool = False) -> torch.Tensor:
"""forward.
"""Forward.
Args:
k (int): number of nearest neighbors.
......@@ -34,15 +36,15 @@ class KNN(Function):
"""
assert k > 0
if not transposed:
if center_xyz is None:
center_xyz = xyz
if transposed:
xyz = xyz.transpose(2, 1).contiguous()
center_xyz = center_xyz.transpose(2, 1).contiguous()
B, _, npoint = center_xyz.shape
N = xyz.shape[2]
assert center_xyz.is_contiguous()
assert xyz.is_contiguous()
assert xyz.is_contiguous() # [B, N, 3]
assert center_xyz.is_contiguous() # [B, npoint, 3]
center_xyz_device = center_xyz.get_device()
assert center_xyz_device == xyz.get_device(), \
......@@ -50,20 +52,21 @@ class KNN(Function):
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)
idx = center_xyz.new_zeros((B, k, npoint)).long()
B, npoint, _ = center_xyz.shape
N = xyz.shape[1]
for bi in range(B):
knn_ext.knn_wrapper(xyz[bi], N, center_xyz[bi], npoint, idx[bi], k)
idx = center_xyz.new_zeros((B, npoint, k)).int()
dist2 = center_xyz.new_zeros((B, npoint, k)).float()
knn_ext.knn_wrapper(B, N, npoint, k, xyz, center_xyz, idx, dist2)
# idx shape to [B, k, npoint]
idx = idx.transpose(2, 1).contiguous()
ctx.mark_non_differentiable(idx)
idx -= 1
return idx
@staticmethod
def backward(ctx, a=None):
return None, None
return None, None, None
knn = KNN.apply
// Modified from https://github.com/unlimblue/KNN_CUDA
// Modified from https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
#include <vector>
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include <vector>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h>
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_TYPE(x, t) AT_ASSERTM(x.dtype() == t, #x " must be " #t)
#define CHECK_CUDA(x) AT_ASSERTM(x.device().type() == at::Device::Type::CUDA, #x " must be on CUDA")
#define CHECK_INPUT(x, t) CHECK_CONTIGUOUS(x); CHECK_TYPE(x, t); CHECK_CUDA(x)
void knn_kernels_launcher(
const float* ref_dev,
int ref_nb,
const float* query_dev,
int query_nb,
int dim,
int k,
float* dist_dev,
long* ind_dev,
extern THCState *state;
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
void knn_kernel_launcher(
int b,
int n,
int m,
int nsample,
const float *xyz,
const float *new_xyz,
int *idx,
float *dist2,
cudaStream_t stream
);
// std::vector<at::Tensor> knn_wrapper(
void knn_wrapper(
at::Tensor & ref,
int ref_nb,
at::Tensor & query,
int query_nb,
at::Tensor & ind,
const int k
) {
CHECK_INPUT(ref, at::kFloat);
CHECK_INPUT(query, at::kFloat);
const float * ref_dev = ref.data_ptr<float>();
const float * query_dev = query.data_ptr<float>();
int dim = query.size(0);
auto dist = at::empty({ref_nb, query_nb}, query.options().dtype(at::kFloat));
float * dist_dev = dist.data_ptr<float>();
long * ind_dev = ind.data_ptr<long>();
void knn_wrapper(int b, int n, int m, int nsample, at::Tensor xyz_tensor, at::Tensor new_xyz_tensor, at::Tensor idx_tensor, at::Tensor dist2_tensor)
{
CHECK_INPUT(new_xyz_tensor);
CHECK_INPUT(xyz_tensor);
const float *new_xyz = new_xyz_tensor.data_ptr<float>();
const float *xyz = xyz_tensor.data_ptr<float>();
int *idx = idx_tensor.data_ptr<int>();
float *dist2 = dist2_tensor.data_ptr<float>();
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
knn_kernels_launcher(
ref_dev,
ref_nb,
query_dev,
query_nb,
dim,
k,
dist_dev,
ind_dev,
stream
);
knn_kernel_launcher(b, n, m, nsample, xyz, new_xyz, idx, dist2, stream);
}
......
/** Modified from https://github.com/unlimblue/KNN_CUDA
* which is the modified version of knn-CUDA
* from https://github.com/vincentfpgarcia/kNN-CUDA
* Last modified by Christopher B. Choy <chrischoy@ai.stanford.edu> 12/23/2016
* vincentfpgarcia wrote the original cuda code, Christopher modified it and
* set it up for pytorch 0.4, and unlimblue updated it to pytorch >= 1.0
*/
// Modified from https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
// Includes
#include <cmath>
#include <cstdio>
#include "cuda.h"
// Constants used by the program
#define BLOCK_DIM 16
#define DEBUG 0
#define THREADS_PER_BLOCK 256
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
/**
* Computes the distance between two matrix A (reference points) and
* B (query points) containing respectively wA and wB points.
*
* @param A pointer on the matrix A
* @param wA width of the matrix A = number of points in A
* @param B pointer on the matrix B
* @param wB width of the matrix B = number of points in B
* @param dim dimension of points = height of matrices A and B
* @param AB pointer on the matrix containing the wA*wB distances computed
*/
__global__ void cuComputeDistanceGlobal(const float* A, int wA,
const float* B, int wB, int dim, float* AB){
// Declaration of the shared memory arrays As and Bs used to store the sub-matrix of A and B
__shared__ float shared_A[BLOCK_DIM][BLOCK_DIM];
__shared__ float shared_B[BLOCK_DIM][BLOCK_DIM];
// Sub-matrix of A (begin, step, end) and Sub-matrix of B (begin, step)
__shared__ int begin_A;
__shared__ int begin_B;
__shared__ int step_A;
__shared__ int step_B;
__shared__ int end_A;
// Thread index
int tx = threadIdx.x;
int ty = threadIdx.y;
__device__ void swap_float(float *x, float *y)
{
float tmp = *x;
*x = *y;
*y = tmp;
}
// Other variables
float tmp;
float ssd = 0;
// Loop parameters
begin_A = BLOCK_DIM * blockIdx.y;
begin_B = BLOCK_DIM * blockIdx.x;
step_A = BLOCK_DIM * wA;
step_B = BLOCK_DIM * wB;
end_A = begin_A + (dim-1) * wA;
__device__ void swap_int(int *x, int *y)
{
int tmp = *x;
*x = *y;
*y = tmp;
}
// Conditions
int cond0 = (begin_A + tx < wA); // used to write in shared memory
int cond1 = (begin_B + tx < wB); // used to write in shared memory & to computations and to write in output matrix
int cond2 = (begin_A + ty < wA); // used to computations and to write in output matrix
// Loop over all the sub-matrices of A and B required to compute the block sub-matrix
for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) {
// Load the matrices from device memory to shared memory; each thread loads one element of each matrix
if (a/wA + ty < dim){
shared_A[ty][tx] = (cond0)? A[a + wA * ty + tx] : 0;
shared_B[ty][tx] = (cond1)? B[b + wB * ty + tx] : 0;
}
else{
shared_A[ty][tx] = 0;
shared_B[ty][tx] = 0;
__device__ void reheap(float *dist, int *idx, int k)
{
int root = 0;
int child = root * 2 + 1;
while (child < k)
{
if(child + 1 < k && dist[child+1] > dist[child])
child++;
if(dist[root] > dist[child])
return;
swap_float(&dist[root], &dist[child]);
swap_int(&idx[root], &idx[child]);
root = child;
child = root * 2 + 1;
}
}
// Synchronize to make sure the matrices are loaded
__syncthreads();
// Compute the difference between the two matrixes; each thread computes one element of the block sub-matrix
if (cond2 && cond1){
for (int k = 0; k < BLOCK_DIM; ++k){
tmp = shared_A[k][ty] - shared_B[k][tx];
ssd += tmp*tmp;
}
__device__ void heap_sort(float *dist, int *idx, int k)
{
int i;
for (i = k - 1; i > 0; i--)
{
swap_float(&dist[0], &dist[i]);
swap_int(&idx[0], &idx[i]);
reheap(dist, idx, i);
}
// Synchronize to make sure that the preceding computation is done before loading two new sub-matrices of A and B in the next iteration
__syncthreads();
}
// Write the block sub-matrix to device memory; each thread writes one element
if (cond2 && cond1)
AB[(begin_A + ty) * wB + begin_B + tx] = ssd;
}
/**
* Gathers k-th smallest distances for each column of the distance matrix in the top.
*
* @param dist distance matrix
* @param ind index matrix
* @param width width of the distance matrix and of the index matrix
* @param height height of the distance matrix and of the index matrix
* @param k number of neighbors to consider
*/
__global__ void cuInsertionSort(float *dist, long *ind, int width, int height, int k){
// input: xyz (b, n, 3) new_xyz (b, m, 3)
// output: idx (b, m, nsample) dist2 (b, m, nsample)
__global__ void knn_kernel(int b, int n, int m, int nsample, const float *__restrict__ xyz, const float *__restrict__ new_xyz, int *__restrict__ idx, float *__restrict__ dist2) {
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;
// Variables
int l, i, j;
float *p_dist;
long *p_ind;
float curr_dist, max_dist;
long curr_row, max_row;
unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
if (xIndex<width){
// Pointer shift, initialization, and max value
p_dist = dist + xIndex;
p_ind = ind + xIndex;
max_dist = p_dist[0];
p_ind[0] = 1;
new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
dist2 += bs_idx * m * nsample + pt_idx * nsample;
// Part 1 : sort kth firt elementZ
for (l=1; l<k; l++){
curr_row = l * width;
curr_dist = p_dist[curr_row];
if (curr_dist<max_dist){
i=l-1;
for (int a=0; a<l-1; a++){
if (p_dist[a*width]>curr_dist){
i=a;
break;
}
}
for (j=l; j>i; j--){
p_dist[j*width] = p_dist[(j-1)*width];
p_ind[j*width] = p_ind[(j-1)*width];
}
p_dist[i*width] = curr_dist;
p_ind[i*width] = l + 1;
} else {
p_ind[l*width] = l + 1;
}
max_dist = p_dist[curr_row];
}
float new_x = new_xyz[0];
float new_y = new_xyz[1];
float new_z = new_xyz[2];
// Part 2 : insert element in the k-th first lines
max_row = (k-1)*width;
for (l=k; l<height; l++){
curr_dist = p_dist[l*width];
if (curr_dist<max_dist){
i=k-1;
for (int a=0; a<k-1; a++){
if (p_dist[a*width]>curr_dist){
i=a;
break;
}
}
for (j=k-1; j>i; j--){
p_dist[j*width] = p_dist[(j-1)*width];
p_ind[j*width] = p_ind[(j-1)*width];
float best_dist[100];
int best_idx[100];
for(int i = 0; i < nsample; i++){
best_dist[i] = 1e10;
best_idx[i] = 0;
}
for(int i = 0; i < n; i++){
float x = xyz[i * 3 + 0];
float y = xyz[i * 3 + 1];
float z = xyz[i * 3 + 2];
float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
if (d2 < best_dist[0]){
best_dist[0] = d2;
best_idx[0] = i;
reheap(best_dist, best_idx, nsample);
}
p_dist[i*width] = curr_dist;
p_ind[i*width] = l + 1;
max_dist = p_dist[max_row];
}
}
}
}
/**
* Computes the square root of the first line (width-th first element)
* of the distance matrix.
*
* @param dist distance matrix
* @param width width of the distance matrix
* @param k number of neighbors to consider
*/
__global__ void cuParallelSqrt(float *dist, int width, int k){
unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y;
if (xIndex<width && yIndex<k)
dist[yIndex*width + xIndex] = sqrt(dist[yIndex*width + xIndex]);
}
void debug(float * dist_dev, long * ind_dev, const int query_nb, const int k){
float* dist_host = new float[query_nb * k];
long* idx_host = new long[query_nb * k];
// Memory copy of output from device to host
cudaMemcpy(dist_host, dist_dev,
query_nb * k * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(idx_host, ind_dev,
query_nb * k * sizeof(long), cudaMemcpyDeviceToHost);
int i, j;
for(i = 0; i < k; i++){
for (j = 0; j < query_nb; j++) {
if (j % 8 == 0)
printf("/\n");
printf("%f ", sqrt(dist_host[i*query_nb + j]));
heap_sort(best_dist, best_idx, nsample);
for(int i = 0; i < nsample; i++){
idx[i] = best_idx[i];
dist2[i] = best_dist[i];
}
printf("\n");
}
}
void knn_kernel_launcher(int b, int n, int m, int nsample, const float *xyz, const float *new_xyz, int *idx, float *dist2, cudaStream_t stream) {
// param new_xyz: (B, m, 3)
// param xyz: (B, n, 3)
// param idx: (B, m, nsample)
//-----------------------------------------------------------------------------------------------//
// K-th NEAREST NEIGHBORS //
//-----------------------------------------------------------------------------------------------//
/**
* K nearest neighbor algorithm
* - Initialize CUDA
* - Allocate device memory
* - Copy point sets (reference and query points) from host to device memory
* - Compute the distances + indexes to the k nearest neighbors for each query point
* - Copy distances from device to host memory
*
* @param ref_host reference points ; pointer to linear matrix
* @param ref_nb number of reference points ; width of the matrix
* @param query_host query points ; pointer to linear matrix
* @param query_nb number of query points ; width of the matrix
* @param dim dimension of points ; height of the matrices
* @param k number of neighbor to consider
* @param dist_host distances to k nearest neighbors ; pointer to linear matrix
* @param dist_host indexes of the k nearest neighbors ; pointer to linear matrix
*
*/
void knn_kernels_launcher(const float* ref_dev, int ref_nb, const float* query_dev, int query_nb,
int dim, int k, float* dist_dev, long* ind_dev, cudaStream_t stream){
// Grids ans threads
dim3 g_16x16(query_nb / BLOCK_DIM, ref_nb / BLOCK_DIM, 1);
dim3 t_16x16(BLOCK_DIM, BLOCK_DIM, 1);
if (query_nb % BLOCK_DIM != 0) g_16x16.x += 1;
if (ref_nb % BLOCK_DIM != 0) g_16x16.y += 1;
//
dim3 g_256x1(query_nb / 256, 1, 1);
dim3 t_256x1(256, 1, 1);
if (query_nb%256 != 0) g_256x1.x += 1;
dim3 g_k_16x16(query_nb / BLOCK_DIM, k / BLOCK_DIM, 1);
dim3 t_k_16x16(BLOCK_DIM, BLOCK_DIM, 1);
if (query_nb % BLOCK_DIM != 0) g_k_16x16.x += 1;
if (k % BLOCK_DIM != 0) g_k_16x16.y += 1;
cudaError_t err;
// Kernel 1: Compute all the distances
cuComputeDistanceGlobal<<<g_16x16, t_16x16, 0, stream>>>(ref_dev, ref_nb,
query_dev, query_nb, dim, dist_dev);
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK);
#if DEBUG
printf("Pre insertionSort\n");
debug(dist_dev, ind_dev, query_nb, k);
#endif
knn_kernel<<<blocks, threads, 0, stream>>>(b, n, m, nsample, xyz, new_xyz, idx, dist2);
// cudaDeviceSynchronize(); // for using printf in kernel function
// Kernel 2: Sort each column
cuInsertionSort<<<g_256x1, t_256x1, 0, stream>>>(dist_dev, ind_dev, query_nb, ref_nb, k);
#if DEBUG
printf("Post insertionSort\n");
debug(dist_dev, ind_dev, query_nb, k);
#endif
// Kernel 3: Compute square root of k first elements
cuParallelSqrt<<<g_k_16x16,t_k_16x16, 0, stream>>>(dist_dev, query_nb, k);
err = cudaGetLastError();
if (cudaSuccess != err) {
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment