knn_kernel.cu 3.53 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <ATen/ATen.h>

rusty1s's avatar
rusty1s committed
3
#include "compat.cuh"
rusty1s's avatar
rusty1s committed
4
5
6
7
#include "utils.cuh"

#define THREADS 1024

rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
/**
 * Small vector helpers used to build the cosine-based score in the k-NN
 * kernel.
 *
 * Both helpers operate on `size` contiguous elements starting at the given
 * pointer(s); callers are responsible for passing a pointer to the first
 * element of the vector they want (i.e. already offset to the right row).
 *
 * The helpers are marked `__host__ __device__` so they can be called from
 * kernels as before and additionally be exercised from host code.
 */
template <typename scalar_t> struct Cosine {
  /**
   * Dot product of two vectors.
   *
   * @param a    pointer to the first element of the first vector
   * @param b    pointer to the first element of the second vector
   * @param size number of elements read from each vector
   * @return     sum over i of a[i] * b[i]; 0 when size is 0
   */
  static inline __host__ __device__ scalar_t dot(const scalar_t *a,
                                                 const scalar_t *b,
                                                 size_t size) {
    scalar_t result = 0;
    // Unsigned index matches the type of `size` (no signed/unsigned compare).
    for (size_t i = 0; i < size; i++) {
      result += a[i] * b[i];
    }
    return result;
  }

  /**
   * Euclidean (L2) norm of a vector.
   *
   * @param a    pointer to the first element of the vector
   * @param size number of elements read
   * @return     square root of the sum of squares; 0 when size is 0
   */
  static inline __host__ __device__ scalar_t norm(const scalar_t *a,
                                                  size_t size) {
    scalar_t result = 0;
    for (size_t i = 0; i < size; i++) {
      result += a[i] * a[i];
    }
    return sqrt(result);
  }
};
26

rusty1s's avatar
rusty1s committed
27
28
29
30
31
32
/**
 * Brute-force k-nearest-neighbour search: for every point in `y`, keep the
 * `k` points of `x` (from the same batch segment) with the smallest score.
 *
 * Launch layout: one block per batch segment (`blockIdx.x` indexes the
 * offset arrays); the threads of a block stride over the `y` points of that
 * segment by the block size.
 *
 * @param x        row-major points, `dim` values per point
 * @param y        row-major query points, `dim` values per point
 * @param batch_x  segment offsets into `x`; entries `[b]` and `[b + 1]` are
 *                 read for block `b`
 * @param batch_y  segment offsets into `y`, same convention as `batch_x`
 * @param dist     per-query sorted scores, `k` slots per `y` point. Read
 *                 before being written, so the caller must pre-fill it with
 *                 a value larger than any real score.
 * @param row      output, `k` slots per `y` point, filled with the `y` index
 * @param col      output, `k` slots per `y` point, receives the `x` indices;
 *                 slots that never receive a neighbour keep the caller's
 *                 initial value
 * @param k        number of neighbours kept per query point
 * @param dim      number of values per point
 * @param cosine   if true, score with `|x| * |y| - <x, y>`; otherwise with
 *                 the squared Euclidean distance
 */
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
           const int64_t *__restrict__ batch_x,
           const int64_t *__restrict__ batch_y, scalar_t *__restrict__ dist,
           int64_t *__restrict__ row, int64_t *__restrict__ col, size_t k,
           size_t dim, bool cosine) {

  const ptrdiff_t batch_idx = blockIdx.x;
  const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t start_idx_x = batch_x[batch_idx];
  const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];

  const ptrdiff_t start_idx_y = batch_y[batch_idx];
  const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];

  // Signed copy of k so the downward shift loop below terminates for k == 0.
  const ptrdiff_t num_k = static_cast<ptrdiff_t>(k);

  // Stride by the actual block size so the loop stays correct for any launch
  // configuration.
  for (ptrdiff_t n_y = start_idx_y + idx; n_y < end_idx_y;
       n_y += blockDim.x) {

    const scalar_t *y_row = y + n_y * dim;
    scalar_t *dist_row = dist + n_y * k;
    int64_t *row_row = row + n_y * k;
    int64_t *col_row = col + n_y * k;

    for (ptrdiff_t k_idx = 0; k_idx < num_k; k_idx++) {
      row_row[k_idx] = n_y;
    }

    // The norm of the query point does not depend on n_x; compute it once.
    const scalar_t y_norm =
        cosine ? Cosine<scalar_t>::norm(y_row, dim) : scalar_t(0);

    for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
      const scalar_t *x_row = x + n_x * dim;

      scalar_t tmp_dist = 0;
      if (cosine) {
        // Score the *current* rows. Passing the base pointers here would
        // compare row 0 of x with row 0 of y for every pair.
        // NOTE(review): this score is scaled by both norms rather than being
        // the normalized cosine distance 1 - <x, y> / (|x| * |y|) — confirm
        // this is the intended ranking.
        tmp_dist = Cosine<scalar_t>::norm(x_row, dim) * y_norm -
                   Cosine<scalar_t>::dot(x_row, y_row, dim);
      } else {
        for (size_t d = 0; d < dim; d++) {
          const scalar_t diff = x_row[d] - y_row[d];
          tmp_dist += diff * diff;
        }
      }

      // Insertion into the ascending list of the k best scores so far: find
      // the first slot with a larger score, shift the tail down by one
      // (dropping the last entry) and place the candidate there.
      for (ptrdiff_t k_idx_1 = 0; k_idx_1 < num_k; k_idx_1++) {
        if (dist_row[k_idx_1] > tmp_dist) {
          for (ptrdiff_t k_idx_2 = num_k - 1; k_idx_2 > k_idx_1; k_idx_2--) {
            dist_row[k_idx_2] = dist_row[k_idx_2 - 1];
            col_row[k_idx_2] = col_row[k_idx_2 - 1];
          }
          dist_row[k_idx_1] = tmp_dist;
          col_row[k_idx_1] = n_x;
          break;
        }
      }
    }
  }
}

/**
 * Host entry point for the k-NN search: prepares per-batch offsets, launches
 * `knn_kernel` with one block per batch and returns the found pairs.
 *
 * @param x        points to search in; `x.size(1)` is used as the feature
 *                 dimension
 * @param y        query points; `y.size(0) * k` result slots are allocated
 * @param k        number of neighbours per query point
 * @param batch_x  int64 batch index per `x` point. Its last entry + 1 is
 *                 used as the number of batches — assumes the vector is
 *                 sorted ascending and non-empty; TODO confirm with callers.
 * @param batch_y  int64 batch index per `y` point
 * @param cosine   forwarded to the kernel to select the cosine-based score
 * @return         tensor stacking `row` (query indices) and `col` (neighbour
 *                 indices) along dim 0, restricted to slots where a
 *                 neighbour was found (`col != -1`)
 */
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
                    at::Tensor batch_y, bool cosine) {
  cudaSetDevice(x.get_device());

  // Read the last batch index into a stack variable: a heap buffer for a
  // single value is unnecessary and was never freed.
  int64_t last_batch_idx = 0;
  const cudaError_t copy_err =
      cudaMemcpy(&last_batch_idx, batch_x[-1].DATA_PTR<int64_t>(),
                 sizeof(int64_t), cudaMemcpyDeviceToHost);
  AT_ASSERTM(copy_err == cudaSuccess,
             "knn_cuda: reading the batch size failed: ",
             cudaGetErrorString(copy_err));
  auto batch_size = last_batch_idx + 1;

  // Turn per-point batch indices into segment offsets with a leading zero.
  // `degree` comes from elsewhere in the project; presumably it counts the
  // points per batch index — verify against utils.cuh.
  batch_x = degree(batch_x, batch_size);
  batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
  batch_y = degree(batch_y, batch_size);
  batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

  // `dist` starts at a large sentinel so any real score replaces it; `col`
  // starts at -1 so unfilled slots can be masked out afterwards.
  auto dist = at::full(y.size(0) * k, 1e38, y.options());
  auto row = at::empty(y.size(0) * k, batch_y.options());
  auto col = at::full(y.size(0) * k, -1, batch_y.options());

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
    knn_kernel<scalar_t><<<batch_size, THREADS>>>(
        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
        dist.DATA_PTR<scalar_t>(), row.DATA_PTR<int64_t>(),
        col.DATA_PTR<int64_t>(), k, x.size(1), cosine);
  });

  // Kernel launches do not report errors directly; surface bad launch
  // configurations here instead of at some unrelated later call.
  const cudaError_t launch_err = cudaGetLastError();
  AT_ASSERTM(launch_err == cudaSuccess,
             "knn_cuda: knn_kernel launch failed: ",
             cudaGetErrorString(launch_err));

  auto mask = col != -1;
  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}