// nearest_kernel.cu
#include <ATen/ATen.h>

#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024 // threads per block; the reduction below assumes a power of two

template <typename scalar_t>
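// One thread block per query point in x: the block's THREADS threads scan a
// strided slice of the candidates in y that belong to the same example, and a
// shared-memory tree reduction then selects the index of the closest one.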
__global__ void nearest_kernel(const scalar_t *__restrict__ x,
                               const scalar_t *__restrict__ y,
                               const int64_t *__restrict__ batch_x,
                               const int64_t *__restrict__ batch_y,
                               int64_t *__restrict__ out, const size_t dim) {

  const ptrdiff_t n_x = blockIdx.x;         // query point handled by this block
  const ptrdiff_t batch_idx = batch_x[n_x]; // example the query belongs to
  const ptrdiff_t idx = threadIdx.x;

  const ptrdiff_t start_idx = batch_y[batch_idx];   // first candidate row in y
  const ptrdiff_t end_idx = batch_y[batch_idx + 1]; // one past the last row

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  scalar_t best = 1e38; // sentinel exceeding any realistic squared distance
  ptrdiff_t best_idx = 0;
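  // Strided scan: thread idx visits candidates start_idx + idx,
  // start_idx + idx + THREADS, ... and keeps its private argmin.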
  for (ptrdiff_t n_y = start_idx + idx; n_y < end_idx; n_y += THREADS) {

    scalar_t dist = 0;
    for (ptrdiff_t d = 0; d < dim; d++) {
      const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
      dist += diff * diff; // accumulate the squared Euclidean distance
    }

    if (dist < best) {
      best = dist;
      best_idx = n_y;
    }
  }

  best_dist[idx] = best;
  best_dist_idx[idx] = best_idx;
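  // Pairwise tree reduction in shared memory; each level halves the number of
  // active threads, which is why THREADS must be a power of two.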

  for (int64_t u = 0; (1 << u) < THREADS; u++) {
    __syncthreads();
    if (idx < (THREADS >> (u + 1))) {
      int64_t idx_1 = (idx * 2) << u;
      int64_t idx_2 = (idx * 2 + 1) << u;
      if (best_dist[idx_1] > best_dist[idx_2]) {
        best_dist[idx_1] = best_dist[idx_2];
        best_dist_idx[idx_1] = best_dist_idx[idx_2];
      }
    }
  }

  __syncthreads();
  if (idx == 0) {
    out[n_x] = best_dist_idx[0];
  }
}

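// For each point in x, computes the index of its nearest neighbor in y,
// restricted to rows of the same example. batch_x and batch_y assign each row
// of x and y to an example and are assumed to be sorted in ascending order.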
at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
                        at::Tensor batch_y) {
  cudaSetDevice(x.get_device());
  // Copy the last batch index of x back to the host; since batch vectors are
  // sorted and consecutive, this determines the number of examples. A stack
  // variable avoids the unfreed malloc of the original code.
  int64_t last_batch;
  cudaMemcpy(&last_batch, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  auto batch_size = last_batch + 1;

  // Convert the per-row batch vector of y into CSR-style offsets, so that
  // [batch_y[b], batch_y[b + 1]) is the row range of example b within y.
  batch_y = degree(batch_y, batch_size);
  batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

  auto out = at::empty_like(batch_x); // one index into y per point in x

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
    // Launch one block per query point; THREADS threads cooperate per block.
    nearest_kernel<scalar_t><<<x.size(0), THREADS>>>(
        x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
        batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
        out.DATA_PTR<int64_t>(), x.size(1));
  });

  return out;
}
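
// A minimal usage sketch (an assumption for illustration: this file is built
// into a PyTorch C++ extension that exposes nearest_cuda; the binding itself
// is not part of this file):
//
//   auto opts_f = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
//   auto opts_l = at::TensorOptions().device(at::kCUDA).dtype(at::kLong);
//   auto x = at::rand({8, 3}, opts_f);      // 8 query points in R^3
//   auto y = at::rand({16, 3}, opts_f);     // 16 candidate points
//   auto batch_x = at::zeros({8}, opts_l);  // all queries in example 0
//   auto batch_y = at::zeros({16}, opts_l); // all candidates in example 0
//   auto out = nearest_cuda(x, y, batch_x, batch_y);
//   // out[i] is the row index in y of the nearest neighbor of x[i].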