Commit 7587ce48 authored by rusty1s

knn implementation

parent 787eaef6
#include <torch/extension.h>

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous")

at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random);

at::Tensor fps(at::Tensor x, at::Tensor batch, float ratio, bool random) {
  CHECK_CUDA(x);
  IS_CONTIGUOUS(x);
  CHECK_CUDA(batch);
  return fps_cuda(x, batch, ratio, random);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fps", &fps, "Farthest Point Sampling (CUDA)");
}
#include <torch/extension.h>

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous")

at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
                    at::Tensor batch_y, bool cosine);

at::Tensor knn(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
               at::Tensor batch_y, bool cosine) {
  CHECK_CUDA(x);
  IS_CONTIGUOUS(x);
  CHECK_CUDA(y);
  IS_CONTIGUOUS(y);
  CHECK_CUDA(batch_x);
  CHECK_CUDA(batch_y);
  return knn_cuda(x, y, k, batch_x, batch_y, cosine);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("knn", &knn, "k-Nearest Neighbor (CUDA)");
}
#include "radius_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
// Dot product and Euclidean norm helpers used by the cosine-distance branch.
template <typename scalar_t> struct Cosine {
  static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
                                        int64_t size) {
    scalar_t result = 0;
    for (int64_t i = 0; i < size; i++) {
      result += a[i] * b[i];
    }
    return result;
  }

  static inline __device__ scalar_t norm(const scalar_t *a, int64_t size) {
    scalar_t result = 0;
    for (int64_t i = 0; i < size; i++) {
      result += a[i] * a[i];
    }
    return sqrt(result);
  }
};
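
// knn_kernel: brute-force k-nearest-neighbor search. One thread block handles
// one example of the batch, with its threads striding over that example's
// query points in y. For every query point, distances to all points of x in
// the same example are computed (squared Euclidean, or the cosine-based
// distance |x| * |y| - <x, y>) and the K smallest are kept in a sorted list
// via insertion.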
template <typename scalar_t>
__global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
                           const int64_t *ptr_x, const int64_t *ptr_y,
                           scalar_t *dist, int64_t *row, int64_t *col,
                           int64_t K, int64_t dim, bool cosine) {

  const int64_t batch_idx = blockIdx.x;

  const int64_t x_start_idx = ptr_x[batch_idx];
  const int64_t x_end_idx = ptr_x[batch_idx + 1];

  const int64_t y_start_idx = ptr_y[batch_idx];
  const int64_t y_end_idx = ptr_y[batch_idx + 1];

  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
       n_y += THREADS) {

    for (int64_t k = 0; k < K; k++) {
      row[n_y * K + k] = n_y;
    }

    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
      scalar_t tmp_dist = 0;

      if (cosine) {
        // Cosine distance |x| * |y| - <x, y> for the current pair of points.
        tmp_dist = Cosine<scalar_t>::norm(x + n_x * dim, dim) *
                       Cosine<scalar_t>::norm(y + n_y * dim, dim) -
                   Cosine<scalar_t>::dot(x + n_x * dim, y + n_y * dim, dim);
      } else {
        // Squared Euclidean distance.
        for (int64_t d = 0; d < dim; d++) {
          tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
                      (x[n_x * dim + d] - y[n_y * dim + d]);
        }
      }

      // Insert into the sorted list of the K smallest distances seen so far.
      for (int64_t k_idx_1 = 0; k_idx_1 < K; k_idx_1++) {
        if (dist[n_y * K + k_idx_1] > tmp_dist) {
          for (ptrdiff_t k_idx_2 = K - 1; k_idx_2 > k_idx_1; k_idx_2--) {
            dist[n_y * K + k_idx_2] = dist[n_y * K + k_idx_2 - 1];
            col[n_y * K + k_idx_2] = col[n_y * K + k_idx_2 - 1];
          }
          dist[n_y * K + k_idx_1] = tmp_dist;
          col[n_y * K + k_idx_1] = n_x;
          break;
        }
      }
    }
  }
}
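
// Host entry point: ptr_x and ptr_y are CSR-style offset tensors of length
// batch_size + 1, so example b owns the rows x[ptr_x[b]:ptr_x[b + 1]] and
// y[ptr_y[b]:ptr_y[b + 1]]. dist starts at a large sentinel value and col at
// -1; slots still at -1 after the kernel (queries with fewer than k
// candidates) are masked out before assembling the [2, num_edges] result.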
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
                       torch::Tensor ptr_y, int64_t k, bool cosine) {
  CHECK_CUDA(x);
  CHECK_CUDA(y);
  CHECK_CUDA(ptr_x);
  CHECK_CUDA(ptr_y);
  cudaSetDevice(x.get_device());

  x = x.view({x.size(0), -1}).contiguous();
  y = y.view({y.size(0), -1}).contiguous();

  auto dist = torch::full(y.size(0) * k, 1e38, y.options());
  auto row = torch::empty(y.size(0) * k, ptr_y.options());
  auto col = torch::full(y.size(0) * k, -1, ptr_y.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
    knn_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
        dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
        col.data_ptr<int64_t>(), k, x.size(1), cosine);
  });

  auto mask = col != -1;
  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
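For reference, the selection performed by the kernel above can be mirrored on the CPU. The sketch below is illustrative only and not part of the commit: knn_reference, its container types, and the use of std::partial_sort are assumptions standing in for the kernel's insertion-based top-k, and it covers the squared-Euclidean path only.

#include <algorithm>
#include <cstdint>
#include <utility>
#include <vector>

// CPU sketch of the same per-example brute-force kNN: for every query point in
// y, compute squared Euclidean distances to all points of x that belong to the
// same example (given by CSR offsets ptr_x / ptr_y) and keep the k smallest.
// Returns (row, col) pairs, matching the pairs that knn_cuda stacks into its
// [2, num_edges] output.
std::vector<std::pair<int64_t, int64_t>>
knn_reference(const std::vector<float> &x, const std::vector<float> &y,
              const std::vector<int64_t> &ptr_x,
              const std::vector<int64_t> &ptr_y, int64_t k, int64_t dim) {
  std::vector<std::pair<int64_t, int64_t>> edges;
  for (size_t b = 0; b + 1 < ptr_x.size(); b++) {
    for (int64_t n_y = ptr_y[b]; n_y < ptr_y[b + 1]; n_y++) {
      std::vector<std::pair<float, int64_t>> cand; // (distance, index into x)
      for (int64_t n_x = ptr_x[b]; n_x < ptr_x[b + 1]; n_x++) {
        float dist = 0;
        for (int64_t d = 0; d < dim; d++) {
          float diff = x[n_x * dim + d] - y[n_y * dim + d];
          dist += diff * diff;
        }
        cand.emplace_back(dist, n_x);
      }
      int64_t num = std::min<int64_t>(k, (int64_t)cand.size());
      std::partial_sort(cand.begin(), cand.begin() + num, cand.end());
      for (int64_t i = 0; i < num; i++) {
        edges.emplace_back(n_y, cand[i].second);
      }
    }
  }
  return edges;
}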
#pragma once

#include <torch/extension.h>

torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
                       torch::Tensor ptr_y, int64_t k, bool cosine);
#include <ATen/ATen.h>
#include "compat.cuh"
#include "utils.cuh"
#define THREADS 1024
template <typename scalar_t> struct Cosine {
  static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
                                        size_t size) {
    scalar_t result = 0;
    for (ptrdiff_t i = 0; i < size; i++) {
      result += a[i] * b[i];
    }
    return result;
  }

  static inline __device__ scalar_t norm(const scalar_t *a, size_t size) {
    scalar_t result = 0;
    for (ptrdiff_t i = 0; i < size; i++) {
      result += a[i] * a[i];
    }
    return sqrt(result);
  }
};
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
           const int64_t *__restrict__ batch_x,
           const int64_t *__restrict__ batch_y, scalar_t *__restrict__ dist,
           int64_t *__restrict__ row, int64_t *__restrict__ col, size_t k,
           size_t dim, bool cosine) {

  const ptrdiff_t batch_idx = blockIdx.x;
  const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t start_idx_x = batch_x[batch_idx];
  const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];

  const ptrdiff_t start_idx_y = batch_y[batch_idx];
  const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];

  for (ptrdiff_t n_y = start_idx_y + idx; n_y < end_idx_y; n_y += THREADS) {

    for (ptrdiff_t k_idx = 0; k_idx < k; k_idx++) {
      row[n_y * k + k_idx] = n_y;
    }

    for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
      scalar_t tmp_dist = 0;

      if (cosine) {
        // Cosine distance |x| * |y| - <x, y> for the current pair of points.
        tmp_dist = Cosine<scalar_t>::norm(x + n_x * dim, dim) *
                       Cosine<scalar_t>::norm(y + n_y * dim, dim) -
                   Cosine<scalar_t>::dot(x + n_x * dim, y + n_y * dim, dim);
      } else {
        // Squared Euclidean distance.
        for (ptrdiff_t d = 0; d < dim; d++) {
          tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
                      (x[n_x * dim + d] - y[n_y * dim + d]);
        }
      }

      // Insert into the sorted list of the k smallest distances seen so far.
      for (ptrdiff_t k_idx_1 = 0; k_idx_1 < k; k_idx_1++) {
        if (dist[n_y * k + k_idx_1] > tmp_dist) {
          for (ptrdiff_t k_idx_2 = k - 1; k_idx_2 > k_idx_1; k_idx_2--) {
            dist[n_y * k + k_idx_2] = dist[n_y * k + k_idx_2 - 1];
            col[n_y * k + k_idx_2] = col[n_y * k + k_idx_2 - 1];
          }
          dist[n_y * k + k_idx_1] = tmp_dist;
          col[n_y * k + k_idx_1] = n_x;
          break;
        }
      }
    }
  }
}
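
// Host entry point for the batch-vector interface: batch_x and batch_y assign
// each point of x and y to its example and are assumed to be sorted in
// ascending order. They are converted below into offset arrays of length
// batch_size + 1 before being handed to the kernel.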
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
                    at::Tensor batch_y, bool cosine) {
  cudaSetDevice(x.get_device());

  // Copy the last batch index back to the host to obtain the batch size.
  auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
  cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = batch_sizes[0] + 1;

  // Convert the per-node batch vectors into offset arrays: node counts per
  // example (degree), then a cumulative sum prefixed with zero.
  batch_x = degree(batch_x, batch_size);
  batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
  batch_y = degree(batch_y, batch_size);
  batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

  auto dist = at::full(y.size(0) * k, 1e38, y.options());
  auto row = at::empty(y.size(0) * k, batch_y.options());
  auto col = at::full(y.size(0) * k, -1, batch_y.options());

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
    knn_kernel<scalar_t><<<batch_size, THREADS>>>(
        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
        dist.DATA_PTR<scalar_t>(), row.DATA_PTR<int64_t>(),
        col.DATA_PTR<int64_t>(), k, x.size(1), cosine);
  });

  // Drop slots that were never filled (queries with fewer than k neighbors).
  auto mask = col != -1;
  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
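The degree/cumsum step above turns a sorted per-node batch vector into CSR-style offsets. As a rough equivalent using only stock ATen calls (illustrative only, not part of the commit; at::bincount stands in for the degree helper from utils.cuh):

#include <ATen/ATen.h>

// Illustrative sketch: convert a sorted batch vector, e.g. [0, 0, 0, 1, 1],
// into CSR offsets [0, 3, 5] so that example b spans [ptr[b], ptr[b + 1]).
at::Tensor batch_to_ptr(at::Tensor batch, int64_t batch_size) {
  auto deg = at::bincount(batch, /*weights=*/{}, /*minlength=*/batch_size);
  return at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
}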
@@ -13,13 +13,12 @@ __global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
                               int64_t max_num_neighbors, int64_t dim) {
   const int64_t batch_idx = blockIdx.x;
-  // const ptrdiff_t idx = threadIdx.x;
-  const ptrdiff_t x_start_idx = ptr_x[batch_idx];
-  const ptrdiff_t x_end_idx = ptr_x[batch_idx + 1];
-  const ptrdiff_t y_start_idx = ptr_y[batch_idx];
-  const ptrdiff_t y_end_idx = ptr_y[batch_idx + 1];
+  const int64_t x_start_idx = ptr_x[batch_idx];
+  const int64_t x_end_idx = ptr_x[batch_idx + 1];
+  const int64_t y_start_idx = ptr_y[batch_idx];
+  const int64_t y_end_idx = ptr_y[batch_idx + 1];
   for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
        n_y += THREADS) {
...