knn_cuda.cu 4.03 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
#include "radius_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

// Threads per block used when launching knn_kernel.
#define THREADS 1024

// Device helpers for cosine similarity over row-major point buffers, where
// row `n` of a buffer with `size` columns begins at element `n * size`.
template <typename scalar_t> struct Cosine {
  // Dot product of row `n_a` of `a` with row `n_b` of `b`; both rows hold
  // `size` elements.
  static inline __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
                                        int64_t n_a, int64_t n_b,
                                        int64_t size) {
    const scalar_t *lhs = a + n_a * size;
    const scalar_t *rhs = b + n_b * size;
    scalar_t acc = 0;
    int64_t j = 0;
    while (j < size) {
      acc += lhs[j] * rhs[j];
      ++j;
    }
    return acc;
  }

  // L2 norm (square root of the sum of squares) of row `n_a` of `a`, which
  // holds `size` elements.
  static inline __device__ scalar_t norm(const scalar_t *a, int64_t n_a,
                                         int64_t size) {
    const scalar_t *row_ptr = a + n_a * size;
    scalar_t sq_sum = 0;
    for (int64_t j = 0; j < size; ++j) {
      const scalar_t v = row_ptr[j];
      sq_sum += v * v;
    }
    return sqrt(sq_sum);
  }
};

// Brute-force k-nearest-neighbour search, one CUDA block per batch example.
//
// Launch layout: gridDim.x == number of examples, and ptr_x/ptr_y hold
// gridDim.x + 1 offsets. Any blockDim.x is valid, because each thread walks
// the query points of its example with a stride of blockDim.x. No shared
// memory is used.
//
// x:      reference points, row-major, `dim` values per row.
// y:      query points, row-major, `dim` values per row.
// ptr_x, ptr_y: offsets. Example b owns x rows [ptr_x[b], ptr_x[b + 1]) and
//         y rows [ptr_y[b], ptr_y[b + 1]).
// dist:   K running best distances per query, kept in ascending order. The
//         caller must pre-fill it with a large sentinel value.
// row:    output query index for each of the K slots of a query.
// col:    output reference index for each slot. The caller must pre-fill it
//         with -1 so that unfilled slots can be recognised afterwards.
// K:      number of neighbours kept per query.
// dim:    number of features per point.
// cosine: if true, the distance is 1 - cos(x, y); otherwise it is the squared
//         Euclidean distance.
template <typename scalar_t>
__global__ void knn_kernel(const scalar_t *x, const scalar_t *y,
                           const int64_t *ptr_x, const int64_t *ptr_y,
                           scalar_t *dist, int64_t *row, int64_t *col,
                           int64_t K, int64_t dim, bool cosine) {

  const int64_t batch_idx = blockIdx.x;

  const int64_t x_start_idx = ptr_x[batch_idx];
  const int64_t x_end_idx = ptr_x[batch_idx + 1];

  const int64_t y_start_idx = ptr_y[batch_idx];
  const int64_t y_end_idx = ptr_y[batch_idx + 1];

  // Stride by the real block size, not the THREADS macro, so the kernel is
  // correct for any launch configuration.
  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
       n_y += blockDim.x) {

    for (int64_t k = 0; k < K; k++) {
      row[n_y * K + k] = n_y;
    }

    // The query norm does not depend on n_x, so compute it once per query.
    const scalar_t y_norm =
        cosine ? Cosine<scalar_t>::norm(y, n_y, dim) : scalar_t(0);

    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {

      scalar_t tmp_dist = 0;
      if (cosine) {
        // NOTE: a zero-norm vector makes this 0/0 = NaN. NaN never compares
        // smaller than a stored distance, so such pairs are never selected.
        tmp_dist = Cosine<scalar_t>::dot(x, y, n_x, n_y, dim) /
                   (Cosine<scalar_t>::norm(x, n_x, dim) * y_norm);
        // Use scalar_t(1) so float kernels are not promoted to double.
        tmp_dist = scalar_t(1) - tmp_dist;
      } else {
        for (int64_t d = 0; d < dim; d++) {
          const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
          tmp_dist += diff * diff;
        }
      }

      // Insert tmp_dist into this query's ascending list of K best
      // distances. Larger entries shift one slot right and the last drops out.
      for (int64_t k_idx_1 = 0; k_idx_1 < K; k_idx_1++) {
        if (dist[n_y * K + k_idx_1] > tmp_dist) {
          for (int64_t k_idx_2 = K - 1; k_idx_2 > k_idx_1; k_idx_2--) {
            dist[n_y * K + k_idx_2] = dist[n_y * K + k_idx_2 - 1];
            col[n_y * K + k_idx_2] = col[n_y * K + k_idx_2 - 1];
          }
          dist[n_y * K + k_idx_1] = tmp_dist;
          col[n_y * K + k_idx_1] = n_x;
          break;
        }
      }
    }
  }
}

rusty1s's avatar
rusty1s committed
81
82
83
84
85
// For every point in `y`, finds its k nearest points in `x`. The search is
// restricted to the same batch example, as given by the offset tensors.
//
// x:      2-D CUDA tensor of reference points, shape [num_x, dim].
// y:      2-D CUDA tensor of query points, shape [num_y, dim]; dim must match x.
// ptr_x:  optional 1-D int64 CUDA offsets into x. Example b owns rows
//         [ptr_x[b], ptr_x[b + 1]). Defaults to one example spanning all of x.
// ptr_y:  optional 1-D int64 CUDA offsets into y, with the same number of
//         entries as ptr_x. Defaults to one example spanning all of y.
// k:      number of neighbours per query point.
// cosine: rank by 1 - cosine similarity instead of squared Euclidean distance.
//
// Returns a [2, E] int64 tensor. Row 0 holds query indices into y and row 1
// holds neighbour indices into x. Pairs for each query are ordered by
// increasing distance. A query whose example has fewer than k reference
// points contributes fewer than k pairs.
torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y,
                       torch::optional<torch::Tensor> ptr_x,
                       torch::optional<torch::Tensor> ptr_y, int64_t k,
                       bool cosine) {

  CHECK_CUDA(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CUDA(y);
  CHECK_INPUT(y.dim() == 2);
  // The kernel reads x.size(1) features from every row of y, so a mismatch
  // would read out of bounds.
  CHECK_INPUT(x.size(1) == y.size(1));
  AT_CUDA_CHECK(cudaSetDevice(x.get_device()));

  // The kernel addresses row n as `n * dim`, which assumes dense row-major
  // storage.
  x = x.contiguous();
  y = y.contiguous();

  if (ptr_x.has_value()) {
    CHECK_CUDA(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
    ptr_x = ptr_x.value().contiguous();
  }
  if (ptr_y.has_value()) {
    CHECK_CUDA(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
    ptr_y = ptr_y.value().contiguous();
  }

  // Nothing to search. Returning early also avoids calling torch::arange with
  // a zero step (which throws) when building the default offsets below.
  if (x.size(0) == 0 || y.size(0) == 0 || k == 0)
    return torch::empty({2, 0}, y.options().dtype(torch::kLong));

  if (!ptr_x.has_value()) {
    ptr_x = torch::arange(0, x.size(0) + 1, x.size(0),
                          x.options().dtype(torch::kLong));
  }
  if (!ptr_y.has_value()) {
    ptr_y = torch::arange(0, y.size(0) + 1, y.size(0),
                          y.options().dtype(torch::kLong));
  }
  CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

  // 1e38 is the "empty slot" sentinel distance; col == -1 marks unused slots.
  auto dist = torch::full(y.size(0) * k, 1e38, y.options());
  auto row = torch::empty(y.size(0) * k, ptr_y.value().options());
  auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options());

  // One block per example. A zero-sized grid is an invalid launch
  // configuration, so skip the kernel when there are no examples.
  const int64_t batch_size = ptr_x.value().size(0) - 1;
  if (batch_size > 0) {
    auto stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
      knn_kernel<scalar_t><<<batch_size, THREADS, 0, stream>>>(
          x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
          ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
          dist.data_ptr<scalar_t>(), row.data_ptr<int64_t>(),
          col.data_ptr<int64_t>(), k, x.size(1), cosine);
    });
    // Launch failures are reported only through the error state.
    AT_CUDA_CHECK(cudaGetLastError());
  }

  // Drop slots that never received a neighbour (fewer than k candidates).
  auto mask = col != -1;
  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}