grid_kernel.cu 1.36 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>

rusty1s's avatar
rusty1s committed
4
#include "common.cuh"
rusty1s's avatar
cleaner  
rusty1s committed
5

rusty1s's avatar
rusty1s committed
6
// Assigns every node the linear id of the regular grid cell that contains it.
//
// For node i and dimension d the cell coordinate is
//   (int64_t)((pos[i][d] - start[d]) / size[d])
// and the per-dimension coordinates are folded into a single id using
// mixed-radix weights: the weight of dimension d is the product of the cell
// counts of all earlier dimensions, which makes the id unique per cell.
//
// Parameters:
//   cluster    output; cluster[i] receives the cell id of node i.
//   pos        node positions, addressed as pos[i][d] via strides[0] and
//              strides[1]; sizes[1] is taken as the number of dimensions.
//   size       cell extent per dimension (indexed as size[d]).
//   start/end  lower/upper bound of the grid per dimension.
//   num_nodes  number of nodes to process.
//
// Launch layout: 1-D grid of 1-D blocks. The kernel uses a grid-stride loop,
// so any grid size covers all num_nodes nodes. No shared memory is used.
template <typename scalar_t>
__global__ void
grid_kernel(int64_t *cluster, at::cuda::detail::TensorInfo<scalar_t, int> pos,
            const scalar_t *__restrict__ size,
            const scalar_t *__restrict__ start,
            const scalar_t *__restrict__ end, size_t num_nodes) {
  // Widen before multiplying so the product cannot wrap in 32-bit arithmetic.
  const size_t index = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = (size_t)blockDim.x * gridDim.x;
  const ptrdiff_t num_dims = pos.sizes[1];

  for (size_t i = index; i < num_nodes; i += stride) {
    // Signed row offset: strides are signed ints.
    const ptrdiff_t row = (ptrdiff_t)i * pos.strides[0];

    int64_t c = 0; // accumulated cell id
    int64_t k = 1; // radix weight of the current dimension
    for (ptrdiff_t d = 0; d < num_dims; d++) {
      const scalar_t tmp = pos.data[row + d * pos.strides[1]] - start[d];
      c += (int64_t)(tmp / size[d]) * k;
      // Coordinates in dimension d range over 0..floor((end-start)/size), i.e.
      // that many cells plus one. The weight must be multiplied (not summed)
      // by this count, otherwise distinct cells collide for >= 3 dimensions.
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    cluster[i] = c;
  }
}

rusty1s's avatar
rusty1s committed
25
26
27
// Host entry point: computes a grid-cell id for every row of `pos` on the GPU.
//
// Parameters:
//   pos    node positions; pos.size(0) is the number of nodes. The kernel
//          reads it as a 2-D tensor (NOTE(review): rank is not validated
//          here -- confirm callers always pass a 2-D tensor).
//   size   cell extent per dimension.
//   start  lower grid bound per dimension.
//   end    upper grid bound per dimension.
//
// Returns a kLong tensor of length pos.size(0), allocated from pos's type
// family, holding one cell id per node.
//
// The kernel is launched with BLOCKS(pos.size(0)) blocks of THREADS threads
// (both come from common.cuh).
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
                at::Tensor end) {
  auto cluster = at::empty(pos.type().toScalarType(at::kLong), {pos.size(0)});

  // Nothing to compute for an empty input; returning here also avoids
  // launching a kernel with a possibly zero-sized grid.
  if (pos.size(0) == 0) {
    return cluster;
  }

  // Convert once into the by-value parameters so the converted buffers stay
  // alive until this function returns (not merely until the end of the launch
  // statement), and make them contiguous because the kernel indexes them with
  // a plain [d].
  size = size.toType(pos.type()).contiguous();
  start = start.toType(pos.type()).contiguous();
  end = end.toType(pos.type()).contiguous();

  AT_DISPATCH_ALL_TYPES(pos.type(), "grid_kernel", [&] {
    grid_kernel<scalar_t><<<BLOCKS(pos.size(0)), THREADS>>>(
        cluster.data<int64_t>(),
        at::cuda::detail::getTensorInfo<scalar_t, int>(pos),
        size.data<scalar_t>(), start.data<scalar_t>(), end.data<scalar_t>(),
        pos.size(0));
  });

  return cluster;
}