knn_cuda.cu 4.28 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
// NOTE(review): this is the knn translation unit — confirm "radius_cuda.h" is
// the intended header (a knn-specific header may have been meant).
#include "radius_cuda.h"

#include <limits>

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

rusty1s's avatar
rusty1s committed
7
// Threads per block for the knn_kernel launch. The host sizes the grid as
// ceil(y.size(0) / THREADS), so each query row of y gets one thread.
#define THREADS 256
rusty1s's avatar
rusty1s committed
8
9
10

// Device helpers for cosine similarity over row-major feature matrices:
// row `r` of a matrix with `size` columns starts at element r * size.
template <typename scalar_t> struct Cosine {
  // Dot product of row n_a of `a` with row n_b of `b` (both `size` wide).
  static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
                                        int64_t n_a, int64_t n_b,
                                        int64_t size) {
    const scalar_t *row_a = a + n_a * size;
    const scalar_t *row_b = b + n_b * size;
    scalar_t result = 0;
    for (int64_t i = 0; i < size; i++)
      result += row_a[i] * row_b[i];
    return result;
  }

  // Euclidean (L2) norm of row n_a of `a`. Returns 0 for an all-zero row,
  // which makes a subsequent cosine division produce NaN/inf.
  static inline __device__ scalar_t norm(const scalar_t *a, int64_t n_a,
                                         int64_t size) {
    const scalar_t *row_a = a + n_a * size;
    scalar_t result = 0;
    for (int64_t i = 0; i < size; i++)
      result += row_a[i] * row_a[i];
    return sqrt(result);
  }
};

rusty1s's avatar
rusty1s committed
30
31
32
33
34
35
36
37
// Returns the example (batch) id that flat row index `idx` belongs to.
//
// `ptr` is an offset array of num_examples + 1 entries; example i covers rows
// [ptr[i], ptr[i + 1]). The scan walks forward past every example whose end
// offset is <= idx and stops at the first one that ends after idx. If no
// example qualifies, the last example id (num_examples - 1) is returned.
// NOTE(review): assumes `ptr` is non-decreasing; this is not checked here.
__device__ int64_t get_example_idx(int64_t idx, const int64_t *ptr,
                                   const int64_t num_examples) {
  int64_t example = 0;
  while (example < num_examples && ptr[example + 1] <= idx)
    ++example;
  return example < num_examples ? example : num_examples - 1;
}
rusty1s's avatar
rusty1s committed
38

rusty1s's avatar
rusty1s committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Brute-force k-nearest-neighbour kernel.
//
// Launch layout: 1-D grid of 1-D blocks with at least m threads in total and
// no shared memory. Thread n_y handles query row n_y of `y`.
//
// Data layout: `x` is [n, dim] and `y` is [m, dim], both dense row-major.
// `ptr_x` and `ptr_y` are offset arrays of num_examples + 1 entries. The
// thread looks up its example e through ptr_y and scans only the x rows
// ptr_x[e] .. ptr_x[e + 1] - 1.
//
// Output: for each query, slots [n_y * k, n_y * k + k) of dist/col hold the k
// best candidates in ascending distance order, maintained by insertion.
// row[...] is filled with n_y.
// Preconditions set by the host wrapper: `dist` is pre-filled with a large
// sentinel and `col` with -1. Only candidates strictly below the stored value
// can enter a slot, so slots still at -1 mean "fewer than k candidates".
//
// Distance is squared Euclidean, or 1 - cosine similarity when `cosine` is set.
// A zero-norm row yields a NaN cosine distance. NaN never compares smaller
// than a stored value, so such candidates are skipped.
// `n` (row count of x) is not read by the kernel.
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
           const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
           scalar_t *__restrict__ dist, int64_t *__restrict__ row,
           int64_t *__restrict__ col, const int64_t k, const int64_t n,
           const int64_t m, const int64_t dim, const int64_t num_examples,
           const bool cosine) {

  // Widen before multiplying: blockIdx.x * blockDim.x is 32-bit unsigned math
  // and would wrap for very large m.
  const int64_t n_y =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (n_y >= m)
    return;

  // This thread's slice of the outputs.
  scalar_t *dist_y = dist + n_y * k;
  int64_t *col_y = col + n_y * k;
  int64_t *row_y = row + n_y * k;

  for (int64_t e = 0; e < k; e++)
    row_y[e] = n_y;

  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);
  const int64_t x_begin = ptr_x[example_idx];
  const int64_t x_end = ptr_x[example_idx + 1];

  const scalar_t *y_row = y + n_y * dim;
  // The query norm does not depend on n_x, so compute it once per thread.
  const scalar_t norm_y =
      cosine ? Cosine<scalar_t>::norm(y, n_y, dim) : scalar_t(0);

  for (int64_t n_x = x_begin; n_x < x_end; n_x++) {
    scalar_t tmp_dist = 0;

    if (cosine) {
      const scalar_t similarity =
          Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
          (Cosine<scalar_t>::norm(x, n_x, dim) * norm_y);
      // scalar_t literal avoids promoting float kernels to double.
      tmp_dist = scalar_t(1) - similarity;
    } else {
      const scalar_t *x_row = x + n_x * dim;
      for (int64_t d = 0; d < dim; d++) {
        const scalar_t diff = x_row[d] - y_row[d];
        tmp_dist += diff * diff;
      }
    }

    // Sorted insertion into the top-k list. Find the first slot holding a
    // larger distance, shift the tail down one (dropping the last entry), then
    // write the candidate there.
    for (int64_t e1 = 0; e1 < k; e1++) {
      if (dist_y[e1] > tmp_dist) {
        for (int64_t e2 = k - 1; e2 > e1; e2--) {
          dist_y[e2] = dist_y[e2 - 1];
          col_y[e2] = col_y[e2 - 1];
        }
        dist_y[e1] = tmp_dist;
        col_y[e1] = n_x;
        break;
      }
    }
  }
}

rusty1s's avatar
rusty1s committed
86
// Offsets [0, N] describing a single example that spans all N rows of `src`.
// Built as arange(2) * N because arange(0, N + 1, N) throws
// ("step must be nonzero") when N == 0.
static torch::Tensor single_example_ptr(const torch::Tensor &src) {
  return torch::arange(2, src.options().dtype(torch::kLong)) * src.size(0);
}

// k-nearest-neighbour search on the GPU.
//
// For every row of `y`, finds up to `k` rows of `x` in the same example with
// the smallest squared Euclidean distance. When `cosine` is true it uses the
// smallest 1 - cosine similarity instead. Examples are delimited by the
// offset tensors `ptr_x` / `ptr_y`, which must have the same length
// (num_examples + 1).
//
// x:      [N, D] CUDA tensor (float or double).
// y:      [M, D] CUDA tensor with the same dtype and D.
// ptr_x / ptr_y: optional 1-D int64 CUDA offset tensors. When omitted, the
//         whole matrix is treated as one example.
//         NOTE(review): offsets are assumed non-decreasing; not validated here.
//
// Returns a [2, E] int64 tensor. Row 0 holds query indices into y; row 1 holds
// the matching neighbour indices into x, ascending by distance per query. A
// query gets fewer than k entries when its example has fewer than k usable
// candidates in x.
torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
                       torch::optional<torch::Tensor> ptr_x,
                       torch::optional<torch::Tensor> ptr_y, const int64_t k,
                       const bool cosine) {
  CHECK_CUDA(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CUDA(y);
  CHECK_INPUT(y.dim() == 2);
  CHECK_INPUT(x.size(1) == y.size(1));

  if (ptr_x.has_value()) {
    CHECK_CUDA(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
  } else
    ptr_x = single_example_ptr(x);

  if (ptr_y.has_value()) {
    CHECK_CUDA(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
  } else
    ptr_y = single_example_ptr(y);

  CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

  // NOTE(review): this changes the calling thread's current device without
  // restoring it; a scoped device guard would be cleaner.
  AT_CUDA_CHECK(cudaSetDevice(x.get_device()));

  // The kernel indexes raw pointers as dense row-major buffers, so strided
  // views (e.g. transposes or slices) must be materialised first.
  const auto x_c = x.contiguous();
  const auto y_c = y.contiguous();
  const auto px = ptr_x.value().contiguous();
  const auto py = ptr_y.value().contiguous();

  const int64_t m = y_c.size(0);

  // +inf sentinel, so any finite distance can enter the top-k list. A finite
  // sentinel such as 1e10 would silently drop farther neighbours.
  auto dist = torch::full(m * k, std::numeric_limits<double>::infinity(),
                          y_c.options());
  auto row = torch::empty(m * k, py.options());
  auto col = torch::full(m * k, -1, py.options());

  // A zero-block grid is an invalid launch configuration, so skip the kernel
  // entirely when there are no queries.
  if (m > 0) {
    dim3 BLOCKS((m + THREADS - 1) / THREADS);

    auto stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES(x_c.scalar_type(), "knn_kernel", [&] {
      knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
          x_c.data_ptr<scalar_t>(), y_c.data_ptr<scalar_t>(),
          px.data_ptr<int64_t>(), py.data_ptr<int64_t>(),
          dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
          col.data_ptr<int64_t>(), k, x_c.size(0), m, x_c.size(1),
          px.numel() - 1, cosine);
    });
    // Surface launch-configuration errors here instead of at an unrelated
    // later CUDA call.
    AT_CUDA_CHECK(cudaGetLastError());
  }

  // Unfilled slots keep col == -1; drop them.
  auto mask = col != -1;
  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}