#include "radius_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 1024

// One thread block handles one example of the batch: for every query point in
// `y`, it scans all points of `x` that belong to the same example and records
// up to `max_num_neighbors` neighbors within `radius`.
template <typename scalar_t>
__global__ void radius_kernel(const scalar_t *x, const scalar_t *y,
                              const int64_t *ptr_x, const int64_t *ptr_y,
                              int64_t *row, int64_t *col, scalar_t radius,
                              int64_t max_num_neighbors, int64_t dim) {

  const int64_t batch_idx = blockIdx.x;

  const int64_t x_start_idx = ptr_x[batch_idx];
  const int64_t x_end_idx = ptr_x[batch_idx + 1];

  const int64_t y_start_idx = ptr_y[batch_idx];
  const int64_t y_end_idx = ptr_y[batch_idx + 1];

  // Threads of the block stride over the query points of this example.
  for (int64_t n_y = y_start_idx + threadIdx.x; n_y < y_end_idx;
       n_y += THREADS) {
    int64_t count = 0;
    for (int64_t n_x = x_start_idx; n_x < x_end_idx; n_x++) {
      // Euclidean distance between query point n_y and candidate point n_x.
      scalar_t dist = 0;
      for (int64_t d = 0; d < dim; d++) {
        dist += (x[n_x * dim + d] - y[n_y * dim + d]) *
                (x[n_x * dim + d] - y[n_y * dim + d]);
      }
      dist = sqrt(dist);

      if (dist <= radius) {
        // Each query point owns a fixed slot of `max_num_neighbors` entries.
        row[n_y * max_num_neighbors + count] = n_y;
        col[n_y * max_num_neighbors + count] = n_x;
        count++;
      }

      if (count >= max_num_neighbors) {
        break;
      }
    }
  }
}

// Returns a [2, num_edges] index tensor: row 0 indexes `y` (the query points),
// row 1 indexes `x` (their neighbors within radius `r`). `ptr_x`/`ptr_y` are
// optional CSR-style batch pointers.
torch::Tensor radius_cuda(torch::Tensor x, torch::Tensor y,
                          torch::optional<torch::Tensor> ptr_x,
                          torch::optional<torch::Tensor> ptr_y, double r,
                          int64_t max_num_neighbors) {
  CHECK_CUDA(x);
  CHECK_INPUT(x.dim() == 2);
  CHECK_CUDA(y);
  CHECK_INPUT(y.dim() == 2);
  cudaSetDevice(x.get_device());

  // Validate the batch pointers if given; otherwise treat all points of the
  // tensor as a single example.
  if (ptr_x.has_value()) {
    CHECK_CUDA(ptr_x.value());
    CHECK_INPUT(ptr_x.value().dim() == 1);
  } else {
    ptr_x = torch::tensor({0, x.size(0)}, x.options().dtype(torch::kLong));
  }
  if (ptr_y.has_value()) {
    CHECK_CUDA(ptr_y.value());
    CHECK_INPUT(ptr_y.value().dim() == 1);
  } else {
    ptr_y = torch::tensor({0, y.size(0)}, y.options().dtype(torch::kLong));
  }
  CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel());

  // Pre-allocate one fixed-size slot of `max_num_neighbors` entries per query
  // point; unused entries keep the sentinel value -1.
  auto row = torch::full({y.size(0) * max_num_neighbors}, -1,
                         ptr_y.value().options());
  auto col = torch::full({y.size(0) * max_num_neighbors}, -1,
                         ptr_y.value().options());

  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "radius_kernel", [&] {
    radius_kernel<scalar_t><<<ptr_x.value().size(0) - 1, THREADS, 0, stream>>>(
        x.data_ptr<scalar_t>(), y.data_ptr<scalar_t>(),
        ptr_x.value().data_ptr<int64_t>(), ptr_y.value().data_ptr<int64_t>(),
        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), r, max_num_neighbors,
        x.size(1));
  });

  // Drop the unused (-1) slots and stack into a [2, num_edges] tensor.
  auto mask = row != -1;
  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
}
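
// Minimal usage sketch (values are illustrative only, not from the original
// sources): without batch pointers, every point is treated as belonging to a
// single example.
//
//   auto x = torch::rand({100, 3}, torch::kCUDA);
//   auto y = torch::rand({20, 3}, torch::kCUDA);
//   auto edge = radius_cuda(x, y, torch::nullopt, torch::nullopt, 0.5, 32);
//   // edge[0] holds indices into `y`, edge[1] the matching neighbors in `x`.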