// radius_cuda.cu
// CUDA kernel and host wrapper for a fixed-radius neighbor search.
#include "radius_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

// Number of threads per block used when launching radius_kernel.
#define THREADS 256

// Radius search kernel: for every row of `y`, records up to
// `max_num_neighbors` rows of `x` (taken from the same example) whose squared
// Euclidean distance to that `y` row is strictly smaller than `r`.
//
// Launch layout: 1D grid of 1D blocks, one thread per row of `y`; surplus
// threads (global index >= m) exit immediately. No shared memory is used.
//
// Parameters:
//   x, y              Point coordinates, indexed as `p[i * dim + d]`.
//   ptr_x             Segment boundaries into `x`; example `e` spans rows
//                     [ptr_x[e], ptr_x[e + 1]).
//   ptr_y             Segment boundaries into `y`; only forwarded to
//                     get_example_idx (declared outside this file —
//                     presumably it returns the example containing row n_y;
//                     confirm in utils.cuh).
//   row, col          Outputs with at least m * max_num_neighbors entries.
//                     Thread n_y writes slots
//                     [n_y * max_num_neighbors, n_y * max_num_neighbors + count)
//                     with `row = n_y` and `col = n_x`; remaining slots are
//                     left untouched, so the caller must pre-fill them.
//   r                 Threshold compared against the *squared* distance, i.e.
//                     the caller passes the squared radius.
//   n                 Not read by the kernel (kept for interface stability).
//   m                 Number of rows in `y`.
//   dim               Number of coordinates per point.
//   num_examples      Forwarded to get_example_idx.
//   max_num_neighbors Per-row cap on recorded neighbors; values <= 0 record
//                     nothing.
template <typename scalar_t>
__global__ void
radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
              const int64_t *__restrict__ ptr_x,
              const int64_t *__restrict__ ptr_y, int64_t *__restrict__ row,
              int64_t *__restrict__ col, const scalar_t r, const int64_t n,
              const int64_t m, const int64_t dim, const int64_t num_examples,
              const int64_t max_num_neighbors) {

  // Widen before multiplying so the global index cannot wrap in 32 bits.
  const int64_t n_y =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (n_y >= m)
    return;

  const int64_t example_idx = get_example_idx(n_y, ptr_y, num_examples);

  // Loop-invariant values: candidate range in x and base offsets.
  const int64_t x_begin = ptr_x[example_idx];
  const int64_t x_end = ptr_x[example_idx + 1];
  const int64_t y_offset = n_y * dim;
  const int64_t out_offset = n_y * max_num_neighbors;

  int64_t count = 0;

  // The cap is tested *before* any write, so a non-positive
  // max_num_neighbors can never touch row/col out of bounds.
  for (int64_t n_x = x_begin; n_x < x_end && count < max_num_neighbors;
       n_x++) {
    const int64_t x_offset = n_x * dim;

    // Squared Euclidean distance, accumulated in the input precision.
    scalar_t dist = 0;
    for (int64_t d = 0; d < dim; d++) {
      const scalar_t diff = x[x_offset + d] - y[y_offset + d];
      dist += diff * diff;
    }

    if (dist < r) {
      row[out_offset + count] = n_y;
      col[out_offset + count] = n_x;
      count++;
    }
  }
}

// Host-side entry point: validates inputs and launches radius_kernel.
// Finds, for every row of `y`, up to `max_num_neighbors` rows of `x` whose
// Euclidean distance to it is strictly smaller than `r`.
//
// Parameters:
//   x, y              Contiguous 2-D CUDA tensors with the same number of
//                     columns. Supported dtypes are the floating types
//                     plus Half (see the dispatch below).
//   ptr_x, ptr_y      Optional 1-D CUDA tensors of segment boundaries that
//                     split `x` / `y` into examples; both must have the same
//                     number of elements. When omitted, all rows are treated
//                     as a single example, i.e. the boundaries are
//                     [0, size(0)].
//   r                 Search radius; it is squared here and compared against
//                     squared distances inside the kernel.
//   max_num_neighbors Non-negative cap on neighbors reported per row of `y`.
//
// Returns:
//   A tensor of shape (2, E): the first row holds indices into `y`, the
//   second row the matching indices into `x`.
//
// Side effects:
//   Makes the device of `x` the current CUDA device (it is not restored) and
//   launches on the current CUDA stream.
torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y,
                          torch::optional<torch::Tensor> ptr_x,
                          torch::optional<torch::Tensor> ptr_y, const double r,
                          const int64_t max_num_neighbors) {
  CHECK_CUDA(x);
  CHECK_CONTIGUOUS(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CUDA(y);
  CHECK_CONTIGUOUS(y);
  CHECK_INPUT(y.dim() == 2);
  CHECK_INPUT(x.size(1) == y.size(1));
  CHECK_INPUT(max_num_neighbors >= 0);

  // Surface a failure to select the device instead of silently ignoring it.
  AT_CUDA_CHECK(cudaSetDevice(x.get_device()));

  if (ptr_x.has_value()) {
    CHECK_CUDA(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
    // The kernel reads the raw buffer, so strides must be trivial.
    ptr_x = ptr_x.value().contiguous();
  } else {
    // [0, 1] * N == [0, N]; unlike arange(0, N + 1, N) this also works for
    // N == 0, where a zero step would be rejected.
    ptr_x = torch::arange(0, 2, x.options().dtype(torch::kLong)) * x.size(0);
  }

  if (ptr_y.has_value()) {
    CHECK_CUDA(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
    ptr_y = ptr_y.value().contiguous();
  } else {
    ptr_y = torch::arange(0, 2, y.options().dtype(torch::kLong)) * y.size(0);
  }

  CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

  // One slot per (y row, neighbor rank); -1 marks slots the kernel left empty.
  auto row = torch::full({y.size(0) * max_num_neighbors}, -1,
                         ptr_y.value().options());
  auto col = torch::full({y.size(0) * max_num_neighbors}, -1,
                         ptr_y.value().options());

  // A grid with zero blocks is an invalid launch configuration, and with no
  // output slots there is nothing to compute, so skip the launch entirely.
  if (row.numel() > 0) {
    dim3 BLOCKS((y.size(0) + THREADS - 1) / THREADS);

    auto stream = at::cuda::getCurrentCUDAStream();
    auto scalar_type = x.scalar_type();
    AT_DISPATCH_FLOATING_TYPES_AND(
        at::ScalarType::Half, scalar_type, "radius_cuda", [&] {
          radius_kernel<scalar_t><<<BLOCKS, THREADS, 0, stream>>>(
              x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
              ptr_x.value().data_ptr<int64_t>(),
              ptr_y.value().data_ptr<int64_t>(), row.data_ptr<int64_t>(),
              col.data_ptr<int64_t>(), r * r, x.size(0), y.size(0), x.size(1),
              ptr_x.value().numel() - 1, max_num_neighbors);
        });
    // Kernel launches do not return a status; query it explicitly.
    AT_CUDA_CHECK(cudaGetLastError());
  }

  // Drop the unused (-1) slots and return the pairs as a (2, E) tensor.
  auto mask = row != -1;
  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}