// nearest_cuda.cu
#include "nearest_cuda.h"

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#include "utils.cuh"

#define THREADS 1024

template <typename scalar_t>
__global__ void nearest_kernel(const scalar_t *x, const scalar_t *y,
                               const int64_t *ptr_x, const int64_t *ptr_y,
                               int64_t *out, int64_t batch_size, int64_t dim) {
  // For the query point `n_x` (one block per query point), find the index of
  // the closest point in `y` belonging to the same batch, by squared
  // Euclidean distance, and write it to `out[n_x]`.
  //
  // Expected launch layout: grid = (num x points), block = (THREADS).
  // Preconditions: `ptr_x` / `ptr_y` are CSR-style offset vectors with
  // `batch_size + 1` entries each; `x` and `y` are contiguous row-major
  // [num_points, dim] buffers.

  const int64_t thread_idx = threadIdx.x;
  const int64_t n_x = blockIdx.x;

  // Locate the batch this query point belongs to. Initialize defensively:
  // if `n_x` fell outside every [ptr_x[b], ptr_x[b + 1]) range, an
  // uninitialized batch_idx would cause an out-of-bounds read of ptr_y.
  int64_t batch_idx = 0;
  for (int64_t b = 0; b < batch_size; b++) {
    if (n_x >= ptr_x[b] && n_x < ptr_x[b + 1]) {
      batch_idx = b;
      break;
    }
  }

  const int64_t y_start_idx = ptr_y[batch_idx];
  const int64_t y_end_idx = ptr_y[batch_idx + 1];

  __shared__ scalar_t best_dist[THREADS];
  __shared__ int64_t best_dist_idx[THREADS];

  // Per-thread scan: each thread strides over the batch's y points with step
  // THREADS, keeping its own running minimum. The sentinel is large enough
  // that any real squared distance replaces it (still finite as a float).
  scalar_t best = (scalar_t)1e38;
  int64_t best_idx = 0;
  for (int64_t n_y = y_start_idx + thread_idx; n_y < y_end_idx;
       n_y += THREADS) {
    scalar_t dist = 0;
    for (int64_t d = 0; d < dim; d++) {
      const scalar_t diff = x[n_x * dim + d] - y[n_y * dim + d];
      dist += diff * diff;
    }

    if (dist < best) {
      best = dist;
      best_idx = n_y;
    }
  }

  best_dist[thread_idx] = best;
  best_dist_idx[thread_idx] = best_idx;

  // Pairwise tree reduction over shared memory (THREADS is a power of two).
  // The barrier at the top of each round makes the previous round's writes
  // visible before any thread reads them; it is outside the divergent
  // branch, so all threads reach it.
  for (int64_t u = 0; (1 << u) < THREADS; u++) {
    __syncthreads();
    if (thread_idx < (THREADS >> (u + 1))) {
      int64_t idx_1 = (thread_idx * 2) << u;
      int64_t idx_2 = (thread_idx * 2 + 1) << u;
      if (best_dist[idx_1] > best_dist[idx_2]) {
        best_dist[idx_1] = best_dist[idx_2];
        best_dist_idx[idx_1] = best_dist_idx[idx_2];
      }
    }
  }

  __syncthreads();
  if (thread_idx == 0) {
    // Slot 0 now holds the block-wide argmin.
    out[n_x] = best_dist_idx[0];
  }
}

// For every point in `x`, compute the index (into `y`) of its nearest
// neighbor within the same batch. `ptr_x` / `ptr_y` are CSR-style offset
// vectors with batch_size + 1 entries delimiting each example's points.
//
// Returns an int64 tensor of shape [x.size(0)] (options taken from ptr_x).
torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y,
                           torch::Tensor ptr_x, torch::Tensor ptr_y) {
  CHECK_CUDA(x);
  CHECK_CUDA(y);
  CHECK_CUDA(ptr_x);
  CHECK_CUDA(ptr_y);
  cudaSetDevice(x.get_device());

  // Flatten any trailing feature dimensions so points are rows of a
  // contiguous 2D [num_points, dim] buffer, as the kernel expects.
  x = x.view({x.size(0), -1}).contiguous();
  y = y.view({y.size(0), -1}).contiguous();

  auto out = torch::empty({x.size(0)}, ptr_x.options());

  // A zero-sized grid is an invalid launch configuration; nothing to do.
  if (x.size(0) == 0)
    return out;

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "nearest_kernel", [&] {
    // One block per x point; THREADS threads reduce over that point's
    // candidate set in y.
    nearest_kernel<scalar_t><<<x.size(0), THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.data_ptr<int64_t>(), ptr_y.data_ptr<int64_t>(),
        out.data_ptr<int64_t>(), ptr_x.size(0) - 1, x.size(1));
  });
  // Surface launch-configuration errors immediately instead of leaving a
  // sticky error for some unrelated later CUDA call to report.
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return out;
}