cluster_kernel.cu 1.65 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
#include <stdexcept>
#include <string>

#include <ATen/ATen.h>
#include <ATen/cuda/detail/IndexUtils.cuh>

// Assigns every point in `pos` the linear id of the voxel (grid cell) that
// contains it.
//
// Launch layout: 1-D grid of 1-D blocks. Any grid size is valid, because each
// thread walks the points with a grid-stride loop.
//
// Parameters:
//   cluster   - output; one int64 voxel id per point (num_nodes entries).
//   pos       - point coordinates, read as pos[point, dim] through its
//               strides; pos.sizes[1] is the number of dimensions D.
//   size      - voxel edge length per dimension (D contiguous entries).
//   start     - grid origin per dimension (D contiguous entries).
//   end       - grid upper bound per dimension (D contiguous entries).
//   num_nodes - number of points (rows of pos) to process.
//
// The id is a mixed-radix number. Dimension d contributes
// trunc((p_d - start_d) / size_d), scaled by the product of the voxel counts
// of all lower dimensions. The voxel count along d is
// trunc((end_d - start_d) / size_d) + 1.
// NOTE(review): assumes start_d <= p_d <= end_d for every point. Points
// outside that box get ids that may collide or fall outside the valid range.
template <typename scalar_t>
__global__ void grid_cuda_kernel(
    int64_t *cluster, const at::cuda::detail::TensorInfo<scalar_t, int> pos,
    const scalar_t *__restrict__ size, const scalar_t *__restrict__ start,
    const scalar_t *__restrict__ end, const size_t num_nodes) {
  // Widen before multiplying so the global index cannot wrap in 32 bits.
  const int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t stride = (int64_t)blockDim.x * gridDim.x;
  const int64_t n = (int64_t)num_nodes;
  const int dims = pos.sizes[1];

  for (int64_t i = index; i < n; i += stride) {
    const scalar_t *point = pos.data + i * pos.strides[0];
    int64_t c = 0, k = 1;
    for (int d = 0; d < dims; d++) {
      const scalar_t offset = point[d * pos.strides[1]] - start[d];
      c += (int64_t)(offset / size[d]) * k;
      // k becomes the radix weight of the next dimension: multiply by the
      // number of voxels along d ("+ 1" counts the cell that holds `end`).
      k *= (int64_t)((end[d] - start[d]) / size[d]) + 1;
    }
    cluster[i] = c;
  }
}

// Computes a voxel-grid cluster id for every row of `pos` on the GPU.
//
// Parameters:
//   pos   - point coordinates, indexed as [point, dim] by the kernel.
//   size  - voxel edge length per dimension.
//   start - grid origin per dimension.
//   end   - grid upper bound per dimension.
//   size/start/end are cast to pos' scalar type and made contiguous here.
//
// Returns: an int64 tensor with pos.size(0) entries, allocated with pos'
// backend. It holds one voxel id per point.
//
// Throws: std::runtime_error if the kernel launch reports an error.
// NOTE(review): expects all inputs to be CUDA tensors on the current device;
// this is not checked here. The kernel is launched on the default stream.
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
                     at::Tensor end) {
  // The kernel reads these as flat, per-dimension arrays, so they must match
  // pos' scalar type and be densely laid out.
  size = size.toType(pos.type()).contiguous();
  start = start.toType(pos.type()).contiguous();
  end = end.toType(pos.type()).contiguous();

  const auto num_nodes = pos.size(0);
  auto cluster = at::empty(pos.type().toScalarType(at::kLong), {num_nodes});

  // A zero-sized grid is an invalid launch configuration; nothing to compute.
  if (num_nodes == 0) {
    return cluster;
  }

  const int threads = 1024;
  const dim3 blocks((num_nodes + threads - 1) / threads);

  AT_DISPATCH_ALL_TYPES(pos.type(), "grid_cuda", [&] {
    auto cluster_data = cluster.data<int64_t>();
    auto pos_info = at::cuda::detail::getTensorInfo<scalar_t, int>(pos);
    auto size_data = size.data<scalar_t>();
    auto start_data = start.data<scalar_t>();
    auto end_data = end.data<scalar_t>();
    grid_cuda_kernel<scalar_t><<<blocks, threads>>>(
        cluster_data, pos_info, size_data, start_data, end_data, num_nodes);
  });

  // Kernel launches do not return errors; surface launch failures here.
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string("grid_cuda: kernel launch failed: ") +
                             cudaGetErrorString(err));
  }

  return cluster;
}