knn_cuda.cu
#include "radius_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 1024

// Dot product and Euclidean norm helpers used by the cosine-based distance.
template <typename scalar_t> struct Cosine {
  static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
                                        int64_t size) {
    scalar_t result = 0;
    for (int64_t i = 0; i < size; i++) {
      result += a[i] * b[i];
    }
    return result;
  }

  static inline __device__ scalar_t norm(const scalar_t *a, int64_t size) {
    scalar_t result = 0;
    for (int64_t i = 0; i < size; i++) {
      result += a[i] * a[i];
    }
    return sqrt(result);
  }
};

// One thread block per example; each thread walks a strided subset of the
// example's query points and keeps, per query, the K smallest distances and
// the matching x indices in ascending order.
template <typename scalar_t>
__global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
                           const int64_t *ptr_x, const int64_t *ptr_y,
                           scalar_t *dist, int64_t *row, int64_t *col,
                           int64_t K, int64_t dim, bool cosine) {

  // ptr_x / ptr_y hold CSR-style offsets delimiting the points of each example.
  const int64_t batch_idx = blockIdx.x;

  const int64_t x_start_idx = ptr_x[batch_idx];
  const int64_t x_end_idx = ptr_x[batch_idx + 1];

  const int64_t y_start_idx = ptr_y[batch_idx];
  const int64_t y_end_idx = ptr_y[batch_idx + 1];

  // The THREADS threads of the block stride over the query points of this
  // example.
  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
       n_y += THREADS) {

    // Every output slot of query n_y refers back to n_y itself.
    for (int64_t k = 0; k < K; k++) {
      row[n_y * K + k] = n_y;
    }

    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {

      scalar_t tmp_dist = 0;
      if (cosine) {
        // Cosine-based distance |x| * |y| - <x, y> of the current point pair.
        tmp_dist = Cosine<scalar_t>::norm(x + n_x * dim, dim) *
                       Cosine<scalar_t>::norm(y + n_y * dim, dim) -
                   Cosine<scalar_t>::dot(x + n_x * dim, y + n_y * dim, dim);
      } else {
        // Squared Euclidean distance.
        for (int64_t d = 0; d < dim; d++) {
          tmp_dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
                      (x[n_x * dim + d] - y[n_y * dim + d]);
        }
      }

      // Insert tmp_dist into the sorted top-K list of this query, shifting
      // larger entries one slot to the right.
      for (int64_t k_idx_1 = 0; k_idx_1 < K; k_idx_1++) {
        if (dist[n_y * K + k_idx_1] > tmp_dist) {
          for (ptrdiff_t k_idx_2 = K - 1; k_idx_2 > k_idx_1; k_idx_2--) {
            dist[n_y * K + k_idx_2] = dist[n_y * K + k_idx_2 - 1];
            col[n_y * K + k_idx_2] = col[n_y * K + k_idx_2 - 1];
          }
          dist[n_y * K + k_idx_1] = tmp_dist;
          col[n_y * K + k_idx_1] = n_x;
          break;
        }
      }
    }
  }
}

// For every point in y, finds its k nearest neighbors in x (per example, as
// delimited by ptr_x / ptr_y) and returns the pairs as a [2, num_edges] tensor.
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
                       torch::Tensor ptr_y, int64_t k, bool cosine) {
  CHECK_CUDA(x);
  CHECK_CUDA(y);
  CHECK_CUDA(ptr_x);
  CHECK_CUDA(ptr_y);
  cudaSetDevice(x.get_device());

  // Flatten trailing dimensions so every point is a contiguous `dim`-sized row.
  x = x.view({x.size(0), -1}).contiguous();
  y = y.view({y.size(0), -1}).contiguous();

  // Per-query buffers of size y.size(0) * k: distances start out "infinite" and
  // column indices at -1, so unfilled slots can be masked out afterwards.
  auto dist = torch::full(y.size(0) * k, 1e38, y.options());
  auto row = torch::empty(y.size(0) * k, ptr_y.options());
  auto col = torch::full(y.size(0) * k, -1, ptr_y.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  // Launch one block per example (ptr_x holds batch_size + 1 offsets).
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
    knn_kernel<scalar_t><<<ptr_x.size(0) - 1, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
        dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
        col.data_ptr<int64_t>(), k, x.size(1), cosine);
  });

  // Drop slots that never received a neighbor (examples with fewer than k
  // points) and return the remaining (query, neighbor) pairs.
  auto mask = col != -1;
  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
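
// Usage sketch (a hypothetical caller, not part of this file): assuming
// knn_cuda is declared in "knn_cuda.h", a single-example batch is described by
// pointer tensors {0, num_points}. Names and sizes below are illustrative only.
//
//   auto x = torch::rand({8, 3}, torch::device(torch::kCUDA));  // reference points
//   auto y = torch::rand({4, 3}, torch::device(torch::kCUDA));  // query points
//   auto opts = torch::dtype(torch::kLong).device(torch::kCUDA);
//   auto ptr_x = torch::tensor({0, 8}, opts);
//   auto ptr_y = torch::tensor({0, 4}, opts);
//   auto edge_index = knn_cuda(x, y, ptr_x, ptr_y, /*k=*/2, /*cosine=*/false);
//   // edge_index[0]: indices into y (queries); edge_index[1]: their k nearest x.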