nearest_cuda.cu 2.32 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include "nearest_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

// Number of threads per block used when launching nearest_kernel; it also
// sizes the kernel's per-block shared-memory scratch arrays.
// NOTE(review): the pairwise reduction in nearest_kernel appears to assume
// this is a power of two -- confirm before changing the value.
#define THREADS 1024

// For one row of `x` per block, finds the row of `y` (restricted to the same
// batch segment) with the smallest squared Euclidean distance.
//
// Launch layout: one block per row of `x` (blockIdx.x is the row index) and
// exactly THREADS threads per block. Uses THREADS * (sizeof(scalar_t) +
// sizeof(int64_t)) bytes of static shared memory.
//
// Arguments:
//   x, y       - row-major matrices with `dim` values per row.
//   ptr_x      - batch_size + 1 ascending offsets; rows
//                [ptr_x[b], ptr_x[b + 1]) of `x` belong to batch b.
//   ptr_y      - same layout for `y`; it is indexed up to batch_size.
//   out        - one entry per row of `x`; receives the absolute row index
//                into `y` of the nearest neighbour.
//   batch_size - number of batch segments.
//   dim        - number of features per row.
//
// If the matching segment of `y` is empty, out[n_x] is set to 0.
template <typename scalar_t>
__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
                               const int64_t *ptr_x, const int64_t *ptr_y,
                               int64_t *out, int64_t batch_size, int64_t dim) {

  const int64_t thread_idx = threadIdx.x;
  const int64_t n_x = blockIdx.x;

  // Locate the batch segment that contains row n_x. Defaults to 0 so that
  // the reads of ptr_y below stay defined even if no segment matches.
  int64_t batch_idx = 0;
  for (int64_t b = 0; b < batch_size; b++) {
    if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
      batch_idx = b;
      break;
    }
  }

  const int64_t y_start_idx = ptr_y[batch_idx];
  const int64_t y_end_idx = ptr_y[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  // Each thread scans every THREADS-th candidate row of the segment and
  // keeps its own running minimum. 1e38 acts as the "no candidate" sentinel.
  scalar_t best = 1e38;
  int64_t best_idx = 0;
  for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx;
       n_y += THREADS) {
    scalar_t dist = 0;
    for (int64_t d = 0; d < dim; d++) {
      const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
      dist += diff * diff;
    }

    if (dist < best) {
      best = dist;
      best_idx = n_y;
    }
  }

  best_dist[thread_idx] = best;
  best_dist_idx[thread_idx] = best_idx;

  // Pairwise reduction over shared memory: at step u, slots that are
  // multiples of 2^(u+1) absorb the slot 2^u to their right. The barrier
  // sits at the top of the loop so every thread in the block reaches it.
  for (int64_t u = 0; (1 << u) < THREADS; u++) {
    __syncthreads();
    if (thread_idx < (THREADS >> (u + 1))) {
      const int64_t idx_1 = (thread_idx * 2) << u;
      const int64_t idx_2 = (thread_idx * 2 + 1) << u;
      if (best_dist[idx_1] > best_dist[idx_2]) {
        best_dist[idx_1] = best_dist[idx_2];
        best_dist_idx[idx_1] = best_dist_idx[idx_2];
      }
    }
  }

  // Make the final write to slot 0 visible before thread 0 publishes it.
  __syncthreads();
  if (thread_idx == 0) {
    out[n_x] = best_dist_idx[0];
  }
}

// Host entry point: for every row of `x`, returns the index of the closest
// row of `y` within the same batch segment.
//
// Arguments:
//   x, y         - CUDA tensors; each is flattened to 2-D [rows, features]
//                  and made contiguous before the launch.
//   ptr_x, ptr_y - CUDA int64 tensors of segment offsets; ptr_x.size(0) - 1
//                  is passed to the kernel as the number of batches.
//
// Returns a 1-D int64 tensor with x.size(0) entries, created with the options
// of `ptr_x`. The kernel is enqueued on the current CUDA stream and is not
// synchronized here.
torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
                           torch::Tensor ptr_x, torch::Tensor ptr_y) {
  CHECK_CUDA(x);
  CHECK_CUDA(y);
  CHECK_CUDA(ptr_x);
  CHECK_CUDA(ptr_y);
  cudaSetDevice(x.get_device());

  x = x.view({x.size(0), -1}).contiguous();
  y = y.view({y.size(0), -1}).contiguous();

  // One output index per row of x; the shape list and the tensor options are
  // separate arguments of torch::empty.
  auto out = torch::empty({x.size(0)}, ptr_x.options());

  // A grid of zero blocks is an invalid launch configuration, so there is
  // nothing to launch when x has no rows.
  if (x.size(0) == 0)
    return out;

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
    // One block per row of x, THREADS threads per block.
    nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
        out.data_ptr<int64_t>(), ptr_x.size(0) - 1, x.size(1));
  });

  return out;
}