cluster_kernel.cu 1.64 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#include <stdexcept>
#include <string>

#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>

/**
 * Maps every row of `pos` to the flattened index of the grid cell it falls in.
 *
 * For row i the per-dimension cell coordinate is
 *   (pos[i][d] - start[d]) / size[d]   (truncated to an integer)
 * and the coordinates are combined as a mixed-radix number whose radix in
 * dimension d is the number of cells along d:
 *   (end[d] - start[d]) / size[d] + 1.
 * The "+ 1" keeps a point lying exactly on end[d] inside the valid range.
 *
 * @param cluster  output, one int64 cell index per row of `pos` (n entries)
 * @param pos      tensor info; rows are indexed via strides[0], the
 *                 coordinates of a row via strides[1]; sizes[1] is the
 *                 number of dimensions that are iterated
 * @param size     cell extent per dimension (read at indices [0, sizes[1]))
 * @param start    grid origin per dimension (read at indices [0, sizes[1]))
 * @param end      grid upper bound per dimension (read at [0, sizes[1]))
 * @param n        number of rows to process
 *
 * The loop advances by blockDim.x * gridDim.x, so any 1-D launch
 * configuration covers all n rows. No shared memory is used.
 * No guard is applied against size[d] == 0.
 */
template <typename scalar_t>
__global__ void grid_cuda_kernel(
    int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos,
    const scalar_t *__restrict__ size, const scalar_t *__restrict__ start,
    const scalar_t *__restrict__ end, const size_t n) {
  // Widen before multiplying: the built-in index variables are 32-bit, so the
  // product could otherwise wrap before being stored in a 64-bit variable.
  const ptrdiff_t first =
      static_cast<ptrdiff_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const ptrdiff_t step = static_cast<ptrdiff_t>(blockDim.x) * gridDim.x;
  const ptrdiff_t count = static_cast<ptrdiff_t>(n);

  for (ptrdiff_t i = first; i < count; i += step) {
    int64_t c = 0;  // flattened cell index accumulated so far
    int64_t k = 1;  // product of the cell counts of all lower dimensions

    for (ptrdiff_t d = 0; d < pos.sizes[1]; d++) {
      const scalar_t offset =
          pos.data[i * pos.strides[0] + d * pos.strides[1]] - start[d];
      c += static_cast<int64_t>(offset / size[d]) * k;
      // The multiplier must grow multiplicatively; an additive update yields
      // wrong strides (and colliding indices) from the third dimension on.
      k *= static_cast<int64_t>((end[d] - start[d]) / size[d]) + 1;
    }

    cluster[i] = c;
  }
}

/**
 * Host entry point: computes, for every row of `pos`, the index of the regular
 * grid cell it falls into and returns the indices as an int64 tensor.
 *
 * @param pos    2-D tensor; one row per point, one column per dimension
 * @param size   per-dimension cell extent (at least pos.size(1) entries)
 * @param start  per-dimension grid origin (at least pos.size(1) entries)
 * @param end    per-dimension grid upper bound (at least pos.size(1) entries)
 * @return       tensor of pos.size(0) int64 cell indices, created from
 *               pos.type() with scalar type switched to kLong
 *
 * `size`, `start` and `end` are converted to the type of `pos` before use.
 * The kernel is launched with the default `<<<blocks, threads>>>` stream
 * argument and is not synchronized here.
 *
 * @throws std::runtime_error on malformed arguments or if the kernel launch
 *         is rejected by the CUDA runtime.
 */
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end) {
  // The kernel reads sizes[1] / strides[1], which only exist for 2-D input.
  if (pos.dim() != 2) {
    throw std::runtime_error("grid_cuda: expected `pos` to be a 2-D tensor");
  }

  // The kernel indexes these through raw pointers, so they must be dense and
  // share the element type of `pos`.
  size = size.toType(pos.type()).contiguous();
  start = start.toType(pos.type()).contiguous();
  end = end.toType(pos.type()).contiguous();

  const auto num_nodes = pos.size(0);
  const auto num_dims = pos.size(1);
  if (size.numel() < num_dims || start.numel() < num_dims ||
      end.numel() < num_dims) {
    throw std::runtime_error(
        "grid_cuda: `size`, `start` and `end` need one entry per column of "
        "`pos`");
  }

  auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});

  // A grid with zero blocks is an invalid launch configuration; there is
  // nothing to compute for empty input anyway.
  if (num_nodes == 0) {
    return cluster;
  }

  const int threads = 1024;
  const dim3 blocks(
      static_cast<unsigned int>((num_nodes + threads - 1) / threads));

  AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda", [&] {
    auto cluster_data = cluster.data<int64_t>();
    auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
    auto size_data = size.data<scalar_t>();
    auto start_data = start.data<scalar_t>();
    auto end_data = end.data<scalar_t>();
    grid_cuda_kernel<scalar_t><<<blocks, threads>>>(
        cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
  });

  // Kernel launches do not return a status; launch-configuration errors only
  // show up through cudaGetLastError().
  const cudaError_t launch_status = cudaGetLastError();
  if (launch_status != cudaSuccess) {
    throw std::runtime_error(
        std::string("grid_cuda: kernel launch failed: ") +
        cudaGetErrorString(launch_status));
  }

  return cluster;
}