Commit 0adaf7f9 authored by rusty1s

improve knn performance

parent 442e8d9c
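In short: the CUDA k-NN kernel previously launched one thread block per batch example, with THREADS = 1024 threads striding over that example's query points; it now launches one thread per query point (THREADS = 256, grid size derived from y.size(0)) and resolves each query point's example on the fly via the new get_example_idx helper. A minimal usage sketch of the Python entry points this kernel backs (assuming the torch_cluster knn and knn_graph wrappers; tensor sizes are illustrative):

    import torch
    from torch_cluster import knn, knn_graph

    x = torch.randn(1000, 3, device='cuda')  # reference points
    y = torch.randn(500, 3, device='cuda')   # query points

    # For every point in y, find its k = 5 nearest points in x.
    assign_index = knn(x, y, k=5)

    # k-NN graph over a single cloud, matching the call in the updated test.
    edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True)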
@@ -4,3 +4,4 @@
 #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -4,7 +4,7 @@
 #include "utils.cuh"
-#define THREADS 1024
+#define THREADS 256
 template <typename scalar_t> struct Cosine {
   static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
@@ -27,30 +27,36 @@ template <typename scalar_t> struct Cosine {
   }
 };
-template <typename scalar_t>
-__global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
-                           const int64_t *ptr_x, const int64_t *ptr_y,
-                           scalar_t *dist, int64_t *row, int64_t *col,
-                           int64_t K, int64_t dim, bool cosine) {
-  const int64_t batch_idx = blockIdx.x;
-  const int64_t x_start_idx = ptr_x[batch_idx];
-  const int64_t x_end_idx = ptr_x[batch_idx + 1];
+__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
+                                   const int64_t num_examples) {
+  for (int64_t i = 0; i < num_examples; i++) {
+    if (ptr[i + 1] > idx)
+      return i;
+  }
+  return num_examples - 1;
+}
-  const int64_t y_start_idx = ptr_y[batch_idx];
-  const int64_t y_end_idx = ptr_y[batch_idx + 1];
+template <typename scalar_t>
+__global__ void
+knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
+           const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
+           scalar_t *__restrict__ dist, int64_t *__restrict__ row,
+           int64_t *__restrict__ col, const int64_t k, const int64_t n,
+           const int64_t m, const int64_t dim, const int64_t num_examples,
+           const bool cosine) {
-  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
-       n_y += THREADS) {
+  const int64_t n_y = blockIdx.x * blockDim.x + threadIdx.x;
+  if (n_y >= m)
+    return;
-    for (int64_t k = 0; k < K; k++) {
-      row[n_y * K + k] = n_y;
-    }
+  for (int64_t e = 0; e < k; e++)
+    row[n_y * k + e] = n_y;
-    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
+  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
+  for (int64_t n_x = ptr_x[example_idx]; n_x < ptr_x[example_idx + 1]; n_x++) {
     scalar_t tmp_dist = 0;
     if (cosine) {
       tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
                  (Cosine<scalar_t>::norm(x, n_x, dim) *
@@ -63,59 +69,63 @@ __global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
       }
     }
-      for (int64_t k_idx_1 = 0; k_idx_1 < K; k_idx_1++) {
-        if (dist[n_y * K + k_idx_1] > tmp_dist) {
-          for (ptrdiff_t k_idx_2 = K - 1; k_idx_2 > k_idx_1; k_idx_2--) {
-            dist[n_y * K + k_idx_2] = dist[n_y * K + k_idx_2 - 1];
-            col[n_y * K + k_idx_2] = col[n_y * K + k_idx_2 - 1];
+    for (int64_t e1 = 0; e1 < k; e1++) {
+      if (dist[n_y * k + e1] > tmp_dist) {
+        for (int64_t e2 = k - 1; e2 > e1; e2--) {
+          dist[n_y * k + e2] = dist[n_y * k + e2 - 1];
+          col[n_y * k + e2] = col[n_y * k + e2 - 1];
         }
-          dist[n_y * K + k_idx_1] = tmp_dist;
-          col[n_y * K + k_idx_1] = n_x;
+        dist[n_y * k + e1] = tmp_dist;
+        col[n_y * k + e1] = n_x;
         break;
       }
     }
   }
-  }
 }
-torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
+torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
                        torch::optional<torch::Tensor> ptr_x,
-                       torch::optional<torch::Tensor> ptr_y, int64_t k,
-                       bool cosine) {
+                       torch::optional<torch::Tensor> ptr_y, const int64_t k,
+                       const bool cosine) {
   CHECK_CUDA(x);
   CHECK_INPUT(x.dim() == 2);
   CHECK_CUDA(y);
   CHECK_INPUT(y.dim() == 2);
-  cudaSetDevice(x.get_device());
+  CHECK_INPUT(x.size(1) == y.size(1));
   if (ptr_x.has_value()) {
     CHECK_CUDA(ptr_x.value());
     CHECK_INPUT(ptr_x.value().dim() == 1);
-  } else {
+  } else
     ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
                           x.options().dtype(torch::kLong));
-  }
   if (ptr_y.has_value()) {
     CHECK_CUDA(ptr_y.value());
     CHECK_INPUT(ptr_y.value().dim() == 1);
-  } else {
+  } else
    ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
                          y.options().dtype(torch::kLong));
-  }
   CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());
-  auto dist = torch::full(y.size(0) * k, 1e38, y.options());
+  cudaSetDevice(x.get_device());
+  auto dist = torch::full(y.size(0) * k, 1e10, y.options());
   auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
   auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());
+  dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);
   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
-    knn_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
+    knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
        dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
-        col.data_ptr<int64_t>(), k, x.size(1), cosine);
+        col.data_ptr<int64_t>(), k, x.size(0), y.size(0), x.size(1),
+        ptr_x.value().numel() - 1, cosine);
   });
   auto mask = col != -1;
......
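The new launch configuration covers all m = y.size(0) query points with ceil(m / THREADS) blocks of THREADS threads; any thread with n_y >= m returns immediately. A quick sanity check of the grid arithmetic (plain Python, values are illustrative):

    THREADS = 256                          # matches the new #define
    m = 1000                               # number of query points (illustrative)
    BLOCKS = (m + THREADS - 1) // THREADS  # 4 blocks -> 1024 threads, 24 exit early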
@@ -71,8 +71,7 @@ def test_knn_graph(dtype, device):
 def test_knn_graph_large(dtype, device):
     x = torch.randn(1000, 3, dtype=dtype, device=device)
-    edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True,
-                           num_workers=6)
+    edge_index = knn_graph(x, k=5, flow='target_to_source', loop=True)
     tree = scipy.spatial.cKDTree(x.cpu().numpy())
     _, col = tree.query(x.cpu(), k=5)
......