// knn_cuda.cu
// NOTE(review): this file implements knn_cuda; presumably it should include its
// own header (knn_cuda.h) rather than radius_cuda.h — confirm against the repo.
#include "radius_cuda.h"

#include <cmath>

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

#include "utils.cuh"

// Number of threads per block used when launching knn_kernel (one thread per
// query row of `y`).
#define THREADS 256

// Device helpers for the cosine distance between rows of row-major matrices
// that have `size` columns; `n_a` / `n_b` select which row is used.
template <typename scalar_t> struct Cosine {
  // Inner product of row `n_a` of `a` with row `n_b` of `b`, accumulated in
  // `scalar_t` in increasing column order.
  static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
                                        int64_t n_a, int64_t n_b,
                                        int64_t size) {
    const scalar_t *lhs = a + n_a * size;
    const scalar_t *rhs = b + n_b * size;
    scalar_t acc = 0;
    for (int64_t col = 0; col < size; ++col)
      acc += lhs[col] * rhs[col];
    return acc;
  }

  // Euclidean (L2) norm of row `n_a` of `a`: the square root of the row's
  // inner product with itself.
  static inline __device__ scalar_t norm(const scalar_t *a, int64_t n_a,
                                         int64_t size) {
    return sqrt(dot(a, a, n_a, n_a, size));
  }
};

// Brute-force k-nearest-neighbour search.
//
// Launch layout: 1-D grid with one thread per query row `n_y` of `y`
// (0 <= n_y < m); surplus threads exit immediately. Each thread scans every
// row of `x` in the half-open range ptr_x[example_idx] .. ptr_x[example_idx+1]
// of the example its query belongs to (as resolved by get_example_idx from
// utils.cuh) and keeps the `k` closest rows in a per-thread list sorted by
// ascending distance.
//
// Distance: squared Euclidean distance, or `1 - cosine similarity` when
// `cosine` is set. A NaN distance (e.g. cosine with a zero-norm row) never
// compares smaller than anything and is therefore never selected.
//
// Output: for e in [0, k), row[n_y * k + e] = n_y and col[n_y * k + e] is the
// index into `x` of the e-th closest row, or -1 when the example has fewer
// than `k` usable candidates.
//
// Preconditions: k <= 100 (size of the per-thread buffers; enforced by the
// host wrapper), x and y row-major with `dim` columns. `n` is currently unused.
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
           const int64_t *__restrict__ ptr_x, const int64_t *__restrict__ ptr_y,
           int64_t *__restrict__ row, int64_t *__restrict__ col,
           const int64_t k, const int64_t n, const int64_t m, const int64_t dim,
           const int64_t num_examples, const bool cosine) {

  // Widen before multiplying so the global index cannot wrap in 32 bits.
  const int64_t n_y =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (n_y >= m)
    return;

  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);

  // Sorted (ascending) list of the best candidates seen so far. Seeding with
  // +inf instead of a finite "large" value guarantees that every finite
  // distance can be selected, no matter how far apart the points are.
  scalar_t best_dist[100];
  int64_t best_idx[100];
  for (int64_t e = 0; e < k; e++) {
    best_dist[e] = static_cast<scalar_t>(INFINITY);
    best_idx[e] = -1;
  }

  // The query norm does not depend on the candidate, so compute it once
  // rather than once per candidate.
  scalar_t norm_y = 0;
  if (cosine)
    norm_y = Cosine<scalar_t>::norm(y, n_y, dim);

  const int64_t x_start = ptr_x[example_idx];
  const int64_t x_end = ptr_x[example_idx + 1];

  for (int64_t n_x = x_start; n_x < x_end; n_x++) {
    scalar_t tmp_dist = 0;

    if (cosine) {
      tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
                 (Cosine<scalar_t>::norm(x, n_x, dim) * norm_y);
      // Stay in scalar_t arithmetic (a bare `1.` would promote to double).
      tmp_dist = static_cast<scalar_t>(1) - tmp_dist;
    } else {
      for (int64_t d = 0; d < dim; d++) {
        const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
        tmp_dist += diff * diff;
      }
    }

    // Insertion into the sorted list: find the first slot with a larger
    // distance, shift the tail down by one (dropping the last entry), and
    // store the new candidate there.
    for (int64_t e1 = 0; e1 < k; e1++) {
      if (best_dist[e1] > tmp_dist) {
        for (int64_t e2 = k - 1; e2 > e1; e2--) {
          best_dist[e2] = best_dist[e2 - 1];
          best_idx[e2] = best_idx[e2 - 1];
        }
        best_dist[e1] = tmp_dist;
        best_idx[e1] = n_x;
        break;
      }
    }
  }

  for (int64_t e = 0; e < k; e++) {
    row[n_y * k + e] = n_y;
    col[n_y * k + e] = best_idx[e];
  }
}

// Host entry point for the GPU k-nearest-neighbour search.
//
// For every row of `y`, finds up to `k` closest rows of `x` from the same
// example. `ptr_x` holds offsets into `x`; example e spans rows
// ptr_x[e] .. ptr_x[e + 1]. The example of each `y` row is resolved from
// `ptr_y` by get_example_idx in utils.cuh. When a pointer vector is omitted,
// all rows form a single example. Distance is squared Euclidean, or
// `1 - cosine similarity` when `cosine` is true.
//
// Returns a [2, E] int64 tensor:
//   - row 0 holds indices into `y`;
//   - row 1 holds the matching indices into `x`.
// The neighbours of each `y` row appear in ascending distance. Slots that
// could not be filled (the example has fewer than `k` candidates) are
// omitted. Empty `x` or `y` yields a [2, 0] tensor.
//
// Requires:
//   - x and y: contiguous 2-D CUDA tensors with equal column counts;
//   - optional ptr_x / ptr_y: 1-D CUDA tensors with equal numel;
//   - k <= 100 (size of the per-thread buffers in knn_kernel).
torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y,
                       torch::optional<torch::Tensor> ptr_x,
                       torch::optional<torch::Tensor> ptr_y, const int64_t k,
                       const bool cosine) {

  CHECK_CUDA(x);
  CHECK_CONTIGUOUS(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CUDA(y);
  CHECK_CONTIGUOUS(y);
  CHECK_INPUT(y.dim() == 2);
  CHECK_INPUT(x.size(1) == y.size(1));
  AT_ASSERTM(k <= 100, "`k` needs to smaller than or equal to 100");

  if (ptr_x.has_value()) {
    CHECK_CUDA(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
  }
  if (ptr_y.has_value()) {
    CHECK_CUDA(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
  }

  // With no reference rows or no query rows there are no pairs, so return
  // early. This also avoids two failures further down:
  //   - for an empty `x`, the default `ptr_x` would call `arange` with a zero
  //     step, which throws;
  //   - for an empty `y`, the kernel would be launched with a zero-sized grid,
  //     an invalid launch configuration.
  if (x.size(0) == 0 || y.size(0) == 0)
    return torch::empty({2, 0}, y.options().dtype(torch::kLong));

  // Without explicit pointers, all rows form one example: offsets [0, N].
  if (!ptr_x.has_value())
    ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
                          x.options().dtype(torch::kLong));
  if (!ptr_y.has_value())
    ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
                          y.options().dtype(torch::kLong));

  CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

  AT_CUDA_CHECK(cudaSetDevice(x.get_device()));

  // `col` starts at -1 so that unfilled neighbour slots can be masked out.
  auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
  auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());

  dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);

  auto stream = at::cuda::getCurrentCUDAStream();
  auto scalar_type = x.scalar_type();
  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] {
    knn_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), k, x.size(0),
        y.size(0), x.size(1), ptr_x.value().numel() - 1, cosine);
  });
  // A kernel launch does not return an error directly. Surface a failed
  // launch here instead of in some unrelated later CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  auto mask = col != -1;
  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}