// nearest_kernel.cu
#include <ATen/ATen.h>

#include "utils.cuh"

#define THREADS 1024

// For each point in x, find its nearest neighbor among the y-points of the
// same example: block n_x owns point x[n_x], its THREADS threads scan that
// example's y-range in a strided loop, then cooperatively reduce the
// per-thread minima in shared memory.
//
// Launch layout: one block per x-point, blockDim.x == THREADS exactly (the
// shared arrays below are sized for THREADS threads, and THREADS must be a
// power of two for the tree reduction).
//
// x:        x-point coordinates, row-major, `dim` values per point
// y:        y-point coordinates, row-major, `dim` values per point
// batch_x:  example id per x-point (selects the y-range to search)
// batch_y:  CSR-style offsets into y, one entry per example plus a sentinel
// out:      per x-point, the euclidean distance to its nearest y-point
// out_idx:  per x-point, the global index into y of that nearest point
// dim:      number of coordinates per point
template <typename scalar_t>
__global__ void
nearest_kernel(scalar_t *__restrict__ x, scalar_t *__restrict__ y,
               int64_t *__restrict__ batch_x, int64_t *__restrict__ batch_y,
               scalar_t *__restrict__ out, int64_t *__restrict__ out_idx,
               size_t dim) {

  const ptrdiff_t n_x = blockIdx.x;         // x-point handled by this block
  const ptrdiff_t batch_idx = batch_x[n_x]; // its example id
  const ptrdiff_t idx = threadIdx.x;

  // Contiguous range [start_idx, end_idx) of y-points in the same example.
  const ptrdiff_t start_idx = batch_y[batch_idx];
  const ptrdiff_t end_idx = batch_y[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  // Per-thread minimum over a THREADS-strided slice of the y-range.
  scalar_t best = 1e38;
  ptrdiff_t best_idx = 0;
  for (ptrdiff_t n_y = start_idx + idx; n_y < end_idx; n_y += THREADS) {

    scalar_t dist = 0;
    for (ptrdiff_t d = 0; d < dim; d++) {
      // Cache the coordinate difference in a register so each operand is
      // loaded from global memory once per iteration instead of twice.
      const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
      dist += diff * diff;
    }
    dist = sqrt(dist);

    if (dist < best) {
      best = dist;
      best_idx = n_y;
    }
  }

  best_dist[idx] = best;
  best_dist_idx[idx] = best_idx;

  // Tree reduction over shared memory (assumes THREADS is a power of two).
  // After the last round, slot 0 holds the block-wide minimum and its index.
  for (int64_t u = 0; (1 << u) < THREADS; u++) {
    __syncthreads();
    if (idx < (THREADS >> (u + 1))) {
      int64_t idx_1 = (idx * 2) << u;
      int64_t idx_2 = (idx * 2 + 1) << u;
      if (best_dist[idx_1] > best_dist[idx_2]) {
        best_dist[idx_1] = best_dist[idx_2];
        best_dist_idx[idx_1] = best_dist_idx[idx_2];
      }
    }
  }

  __syncthreads();
  if (idx == 0) {
    out[n_x] = best_dist[0];
    out_idx[n_x] = best_dist_idx[0];
  }
}

// For each point in x, compute the nearest point in y belonging to the same
// example. Returns a tuple (distance, index-into-y), one entry per x-point.
//
// batch_x / batch_y map each point to its example id and are assumed sorted
// in ascending order (batch_y is converted to CSR offsets below).
std::tuple<at::Tensor, at::Tensor> nearest_cuda(at::Tensor x, at::Tensor y,
                                                at::Tensor batch_x,
                                                at::Tensor batch_y) {
  // The number of examples is the last (largest) batch id plus one. Copy
  // that single value into a stack variable — the previous heap-allocated
  // staging buffer was never freed (memory leak).
  int64_t batch_size;
  cudaMemcpy(&batch_size, batch_x[-1].data<int64_t>(), sizeof(int64_t),
             cudaMemcpyDeviceToHost);
  batch_size += 1;

  // Turn per-point batch ids into CSR-style offsets: example i owns the
  // y-points in [batch_y[i], batch_y[i + 1]).
  batch_y = degree(batch_y, batch_size);
  batch_y = at::cat({at::zeros(1, batch_y.options()), batch_y.cumsum(0)}, 0);

  auto out = at::empty(x.size(0), x.options());
  auto out_idx = at::empty_like(batch_x);

  // One block per x-point; THREADS threads reduce over its y-range.
  // Dispatch label fixed from the copy-pasted "fps_kernel".
  AT_DISPATCH_FLOATING_TYPES(x.type(), "nearest_kernel", [&] {
    nearest_kernel<scalar_t><<<x.size(0), THREADS>>>(
        x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
        batch_y.data<int64_t>(), out.data<scalar_t>(), out_idx.data<int64_t>(),
        x.size(1));
  });

  return std::make_tuple(out, out_idx);
}