#include "grid_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

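// Maps each point to the flat index of the grid cell that contains it. The
// cell index is a mixed-radix number over the D coordinates, where dimension d
// contributes (end[d] - start[d]) / size[d] + 1 cells.
// Example: pos = (0.5, 1.5), size = (1, 1), start = (0, 0), end = (2, 2) lands
// in cell (0, 1); with 3 cells per dimension the flat index is 0 + 1 * 3 = 3.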
template <typename scalar_t>
__global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
                            const scalar_t *start, const scalar_t *end,
                            int64_t *out, int64_t D, int64_t numel) {
  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (thread_idx < numel) {
    int64_t c = 0, k = 1;
    for (int64_t d = 0; d < D; d++) {
      // Shift into the grid's local frame, convert to a cell coordinate along
      // dimension d, and accumulate it into the flat index with stride k.
      scalar_t p = pos[thread_idx * D + d] - start[d];
      c += (int64_t)(p / size[d]) * k;
      // Number of cells along dimension d, used as the stride for the next one.
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    out[thread_idx] = c;
  }
}

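// Host entry point: validates the CUDA inputs, derives `start`/`end` from the
// point cloud when they are not supplied, and launches one thread per point.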
torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size,
                       torch::optional<torch::Tensor> optional_start,
                       torch::optional<torch::Tensor> optional_end) {
  CHECK_CUDA(pos);
  CHECK_CUDA(size);
  cudaSetDevice(pos.get_device());

  if (optional_start.has_value())
    CHECK_CUDA(optional_start.value());
  if (optional_end.has_value())
    CHECK_CUDA(optional_end.value());

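  // Flatten the positions to shape (N, D); `size` must provide one grid
  // spacing per coordinate dimension.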
  pos = pos.view({pos.size(0), -1}).contiguous();
  size = size.contiguous();

  CHECK_INPUT(size.numel() == pos.size(1));

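  // Fall back to the per-dimension minimum/maximum of the positions when no
  // explicit grid boundaries are given.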
  if (!optional_start.has_value())
    optional_start = std::get<0>(pos.min(0));
  else {
    optional_start = optional_start.value().contiguous();
    CHECK_INPUT(optional_start.value().numel() == pos.size(1));
  }

  if (!optional_end.has_value())
    optional_end = std::get<0>(pos.max(0));
  else {
    optional_end = optional_end.value().contiguous();
    CHECK_INPUT(optional_end.value().numel() == pos.size(1));
  }

  auto start = optional_start.value();
  auto end = optional_end.value();

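  // One flat cell index per input point.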
  auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));

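  // Launch on the current CUDA stream, dispatching on the dtype of `pos`.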
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
    grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
        pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
        start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
        out.data_ptr<int64_t>(), pos.size(1), out.numel());
  });

  return out;
}