// knn_kernel.cu
#include <ATen/ATen.h>

#include "utils.cuh"

#define THREADS 1024

// Cosine-similarity helpers adapted from
// https://github.com/adamantmc/CudaCosineSimilarity/blob/master/src/CudaCosineSimilarity.cu
// Dot product of two length-`size` vectors, accumulated in scalar_t.
//
// A per-thread helper (callable from device and host code), not a kernel:
// kernels must return void and cannot hand a value back to the caller.
//
// a, b: pointers to `size` contiguous elements each.
// Returns sum_i a[i] * b[i]; returns 0 when size == 0.
template <typename scalar_t>
__host__ __device__ scalar_t dot(const scalar_t *a, const scalar_t *b,
                                 size_t size) {
  scalar_t result = 0;
  for (size_t i = 0; i < size; i++) {
    result += a[i] * b[i];
  }
  return result;
}

// Euclidean (L2) norm of a length-`size` vector, accumulated in scalar_t.
//
// A per-thread helper (callable from device and host code), not a kernel:
// kernels must return void and cannot hand a value back to the caller.
//
// a: pointer to `size` contiguous elements.
// Returns sqrt(sum_i a[i]^2); returns 0 when size == 0.
template <typename scalar_t>
__host__ __device__ scalar_t norm(const scalar_t *a, size_t size) {
  scalar_t sum_sq = 0;
  for (size_t i = 0; i < size; i++) {
    sum_sq += a[i] * a[i];
  }
  return sqrt(sum_sq);
}

// Brute-force k-nearest-neighbour search; blockIdx.x selects the batch.
//
// The threads of a block stride (by blockDim.x) over the query points y of
// that batch. Each thread maintains, per query, a list of the k smallest
// distances found so far, kept sorted ascending.
//
// x, y:    reference / query points, indexed row-major as (num_points, dim).
// batch_x, batch_y: offsets; the points of batch b are
//          [batch[b], batch[b + 1]). Needs gridDim.x + 1 entries each.
// dist:    (num_y * k) running k-best distances. Must be pre-filled by the
//          caller with a value larger than any real distance, because a
//          candidate is only inserted where it is strictly smaller.
// row:     (num_y * k) output query index for every slot.
// col:     (num_y * k) output reference index. Slots never filled keep
//          whatever the caller initialised them to.
// k:       number of neighbours kept per query.
// dim:     number of features per point.
// cosine:  if true, distance = 1 - <x,y> / (|x||y|) (cosine distance;
//          defined as 1 if either vector has zero length). Otherwise the
//          squared Euclidean distance is used.
template <typename scalar_t>
__global__ void
knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
           const int64_t *__restrict__ batch_x,
           const int64_t *__restrict__ batch_y, scalar_t *__restrict__ dist,
           int64_t *__restrict__ row, int64_t *__restrict__ col, size_t k,
           size_t dim, bool cosine) {
  const ptrdiff_t batch_idx = blockIdx.x;

  const ptrdiff_t start_idx_x = batch_x[batch_idx];
  const ptrdiff_t end_idx_x = batch_x[batch_idx + 1];

  const ptrdiff_t start_idx_y = batch_y[batch_idx];
  const ptrdiff_t end_idx_y = batch_y[batch_idx + 1];

  for (ptrdiff_t n_y = start_idx_y + threadIdx.x; n_y < end_idx_y;
       n_y += blockDim.x) {
    const scalar_t *y_vec = y + n_y * dim;
    scalar_t *y_dist = dist + n_y * k;
    int64_t *y_col = col + n_y * k;

    for (ptrdiff_t k_idx = 0; k_idx < k; k_idx++) {
      row[n_y * k + k_idx] = n_y;
    }

    // |y| does not depend on the candidate, so compute it once per query.
    scalar_t y_norm = 0;
    if (cosine) {
      for (size_t d = 0; d < dim; d++) {
        y_norm += y_vec[d] * y_vec[d];
      }
      y_norm = sqrt(y_norm);
    }

    for (ptrdiff_t n_x = start_idx_x; n_x < end_idx_x; n_x++) {
      const scalar_t *x_vec = x + n_x * dim;
      scalar_t tmp_dist = 0;

      if (cosine) {
        scalar_t xy = 0;
        scalar_t x_norm_sq = 0;
        for (size_t d = 0; d < dim; d++) {
          xy += x_vec[d] * y_vec[d];
          x_norm_sq += x_vec[d] * x_vec[d];
        }
        const scalar_t denom = sqrt(x_norm_sq) * y_norm;
        // A zero-length vector has no direction. Without this guard the
        // division would yield NaN, which fails every `>` comparison and
        // would silently never be inserted.
        tmp_dist = denom > 0 ? scalar_t(1) - xy / denom : scalar_t(1);
      } else {
        for (size_t d = 0; d < dim; d++) {
          const scalar_t diff = x_vec[d] - y_vec[d];
          tmp_dist += diff * diff;
        }
      }

      // Insert into the sorted k-best list: find the first slot with a
      // strictly larger distance, shift the tail right by one, and write.
      // Strict `>` keeps the earlier candidate on ties.
      for (ptrdiff_t k_idx_1 = 0; k_idx_1 < k; k_idx_1++) {
        if (y_dist[k_idx_1] > tmp_dist) {
          for (ptrdiff_t k_idx_2 = k - 1; k_idx_2 > k_idx_1; k_idx_2--) {
            y_dist[k_idx_2] = y_dist[k_idx_2 - 1];
            y_col[k_idx_2] = y_col[k_idx_2 - 1];
          }
          y_dist[k_idx_1] = tmp_dist;
          y_col[k_idx_1] = n_x;
          break;
        }
      }
    }
  }
}

// Host entry point: for every point in `y`, find the k nearest points of `x`
// that share its batch id.
//
// x, y:    CUDA floating-point tensors. The kernel indexes them row-major
//          as (num_points, dim), with dim = x.size(1).
// k:       number of neighbours per query point.
// batch_x, batch_y: int64 batch id per point.
//          NOTE(review): the batch count is taken from batch_x's last
//          entry, so this assumes batch_x is sorted ascending and non-empty,
//          and that batch_y uses no id beyond it. Confirm with callers.
// cosine:  use cosine distance instead of squared Euclidean distance.
//
// Returns a (2, E) int64 tensor of (y index, x index) pairs. A query whose
// batch has fewer than k reference points contributes fewer than k pairs.
// Throws if the kernel launch fails (e.g. too many resources requested).
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
                    at::Tensor batch_y, bool cosine) {
  cudaSetDevice(x.get_device());

  // Copy the last batch id straight into a stack variable; a heap buffer
  // here would have to be freed on every path. cudaMemcpy into pageable host
  // memory is synchronous, so the value is valid once it returns.
  int64_t last_batch_id = 0;
  cudaMemcpy(&last_batch_id, batch_x[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  const int64_t batch_size = last_batch_id + 1;

  // Turn per-point batch ids into offsets [0, c_0, c_0 + c_1, ...] via
  // cumsum. `degree` presumably counts points per batch id (defined in
  // utils.cuh); confirm there.
  batch_x = degree(batch_x, batch_size);
  batch_x = at::cat({at::zeros(1, batch_x.options()), batch_x.cumsum(0)}, 0);
  batch_y = degree(batch_y, batch_size);
  batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

  // The kernel only inserts strictly smaller distances, so start from a
  // large sentinel. col == -1 marks slots that were never filled.
  auto dist = at::full(y.size(0) * k, 1e38, y.options());
  auto row = at::empty(y.size(0) * k, batch_y.options());
  auto col = at::full(y.size(0) * k, -1, batch_y.options());

  // One block per batch.
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
    knn_kernel<scalar_t><<<batch_size, THREADS>>>(
        x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
        batch_y.data<int64_t>(), dist.data<scalar_t>(), row.data<int64_t>(),
        col.data<int64_t>(), k, x.size(1), cosine);
  });

  // Launches do not report errors directly. A failed launch would otherwise
  // leave col all -1 and silently return an empty result.
  const cudaError_t launch_err = cudaGetLastError();
  if (launch_err != cudaSuccess) {
    AT_ERROR("knn_kernel launch failed: ", cudaGetErrorString(launch_err));
  }

  auto mask = col != -1;
  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}