grid_kernel.cu 1.55 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
#include <stdexcept>
#include <string>

#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

rusty1s's avatar
rusty1s committed
5
6
#include "compat.cuh"

rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// Number of threads per block used for every launch in this file.
#define THREADS 1024
// Ceil-division of N by THREADS: the number of blocks needed so that
// BLOCKS(N) * THREADS >= N. Both the argument and the whole expansion are
// parenthesized so the macro stays correct inside larger expressions
// (e.g. `2 * BLOCKS(n)`) and with low-precedence arguments (e.g. `a | b`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Maps every row of `pos` to the flattened index of the regular-grid cell it
// falls into and stores that index in `cluster`.
//
// For row i the kernel walks the pos.sizes[1] coordinate dimensions d and
// accumulates
//   c += (int64_t)((pos[i][d] - start[d]) / size[d]) * k
//   k *= (int64_t)((end[d] - start[d]) / size[d]) + 1
// starting from c = 0, k = 1, so dimension 0 varies fastest in the flattened
// index. The (int64_t) casts truncate toward zero.
//
// Parameters:
//   cluster  output; cluster[i] is written for every i in [0, numel)
//   pos      tensor descriptor, addressed through strides[0] / strides[1]
//            (no contiguity is required by the indexing below)
//   size     per-dimension cell extent, read at indices [0, pos.sizes[1])
//   start    per-dimension grid origin, read at indices [0, pos.sizes[1])
//   end      per-dimension grid upper bound, read at indices [0, pos.sizes[1])
//   numel    number of rows of `pos` to process
//
// Launch layout: 1-D grid of 1-D blocks. Each thread starts at its global
// index and advances by blockDim.x * gridDim.x (grid-stride loop), so any
// grid size covers all `numel` rows. No shared memory is used.
template <typename scalar_t>
__global__ void grid_kernel(int64_t *cluster,
                            at::cuda::detail::TensorInfo<scalar_t, int64_t> pos,
                            const scalar_t *__restrict__ size,
                            const scalar_t *__restrict__ start,
                            const scalar_t *__restrict__ end, size_t numel) {
  // blockIdx/blockDim/gridDim components are 32-bit unsigned; widen before
  // multiplying so the products cannot wrap for very large launches.
  const int64_t index =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  // Compare in a single signed type instead of mixing ptrdiff_t and size_t.
  const int64_t total = static_cast<int64_t>(numel);
  // Loop-invariant: number of coordinate dimensions and the two strides.
  const int64_t dims = pos.sizes[1];
  const int64_t row_stride = pos.strides[0];
  const int64_t col_stride = pos.strides[1];

  for (int64_t i = index; i < total; i += stride) {
    int64_t c = 0, k = 1;
    for (int64_t d = 0; d < dims; d++) {
      // Offset of the coordinate from the grid origin in dimension d.
      scalar_t p = pos.data[i * row_stride + d * col_stride] - start[d];
      // Cell index along d, weighted by the cell count of all lower dims.
      c += (int64_t)(p / size[d]) * k;
      // Number of cells along d; becomes the weight for the next dimension.
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    cluster[i] = c;
  }
}

// Host entry point: computes one grid-cell index per row of `pos` on the GPU.
//
// Parameters:
//   pos    tensor whose device is selected via pos.get_device(); the kernel
//          reads pos.sizes[1] and strides[0]/strides[1], so `pos` is expected
//          to be 2-D
//   size   per-dimension cell extent, accessed as pos's scalar type
//   start  per-dimension grid origin, accessed as pos's scalar type
//   end    per-dimension grid upper bound, accessed as pos's scalar type
//   NOTE(review): size/start/end are passed to the kernel as raw pointers; this
//   presumably requires them to be contiguous, on the same device as `pos`,
//   and to hold at least pos.size(1) elements -- confirm against callers.
//
// Returns: an int64 tensor with pos.size(0) elements, created with pos's
//          options (dtype overridden to at::kLong).
//
// Throws: std::runtime_error if selecting the device or launching the kernel
//         reports a CUDA error.
//
// The launch passes no stream argument, i.e. it runs on the default stream,
// and is not synchronized here.
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end) {
  const cudaError_t device_status = cudaSetDevice(pos.get_device());
  if (device_status != cudaSuccess) {
    throw std::runtime_error(std::string("grid_cuda: cudaSetDevice failed: ") +
                             cudaGetErrorString(device_status));
  }

  auto cluster = at::empty(pos.size(0), pos.options().dtype(at::kLong));

  // BLOCKS(0) is 0 and a launch with zero blocks is an invalid configuration,
  // so there is nothing to launch for empty input.
  const int64_t numel = cluster.numel();
  if (numel == 0) {
    return cluster;
  }

  AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
    grid_kernel<scalar_t><<<BLOCKS(numel), THREADS>>>(
        cluster.DATA_PTR<int64_t>(),
        at::cuda::detail::getTensorInfo<scalar_t, int64_t>(pos),
        size.DATA_PTR<scalar_t>(), start.DATA_PTR<scalar_t>(),
        end.DATA_PTR<scalar_t>(), numel);
  });

  // Kernel launches do not return a status; launch errors must be fetched
  // explicitly or they surface later at an unrelated call.
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    throw std::runtime_error(
        std::string("grid_cuda: grid_kernel launch failed: ") +
        cudaGetErrorString(launch_status));
  }

  return cluster;
}