// unique_kernel.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#include "compat.cuh"

// Launch configuration for the 1-D kernels in this file.
//   THREADS   - threads per block.
//   BLOCKS(N) - ceil(N / THREADS): blocks needed so each of N elements gets a
//               thread. The argument and the whole expression are
//               parenthesized so the macro composes safely inside larger
//               expressions (e.g. `x / BLOCKS(n)` or `BLOCKS(c ? a : b)`).
#define THREADS 1024
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Marks the first element of every run of equal values.
//
// Launch: 1-D grid of 1-D blocks. Any grid size is valid because each thread
// advances through the array in steps of gridDim.x * blockDim.x.
// Preconditions:
//   * `src` holds `numel` elements; equal values are expected to be adjacent
//     (unique_cuda sorts `src` before launching this kernel).
//   * `mask` holds `numel` entries and is zero-initialized by the caller; this
//     kernel only ever writes `true` and never clears an entry.
// Postcondition: mask[i] is true iff i == 0 or src[i] != src[i - 1].
template <typename scalar_t>
__global__ void unique_cuda_kernel(const scalar_t *__restrict__ src,
                                   bool *__restrict__ mask, size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise evaluated in
  // 32-bit unsigned arithmetic and wraps for arrays past 2^32 elements.
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  // Unsigned index matches `numel`, avoiding a signed/unsigned comparison.
  for (size_t i = index; i < numel; i += stride) {
    if (i == 0 || src[i] != src[i - 1]) {
      mask[i] = true;
    }
  }
}

// Returns the unique values of `src` together with, for each unique value, an
// index into the original `src` where that value occurs.
//
// Steps: sort `src` (keeping the sort permutation), mark the first element of
// every run of equal values on the GPU, then keep only the marked positions of
// both the sorted values and the permutation.
//
// NOTE(review): `sort()` sorts along the last dimension while the kernel walks
// the data as one flat array, so this appears to assume a 1-D `src` — confirm
// with callers.
// NOTE(review): for duplicated values the returned index is whichever
// occurrence the sort placed first; a stable sort is not requested, so which
// occurrence that is may not be deterministic.
//
// Returns: (unique values in ascending order, corresponding indices).
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
  // Make src's device current. Checking the result makes a non-CUDA tensor
  // (device index -1) fail loudly instead of launching on a host pointer.
  AT_CUDA_CHECK(cudaSetDevice(src.get_device()));

  at::Tensor perm;
  std::tie(src, perm) = src.sort();

  auto mask = at::zeros(src.numel(), src.options().dtype(at::kBool));

  // A launch with zero blocks is an invalid configuration. An empty input needs
  // no marking, and masked_select below handles empty tensors.
  if (src.numel() > 0) {
    // Use PyTorch's current stream so the kernel is ordered after sort() and
    // before masked_select(), which are enqueued on that same stream.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_ALL_TYPES(src.scalar_type(), "unique_cuda_kernel", [&] {
      unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
          src.DATA_PTR<scalar_t>(), mask.DATA_PTR<bool>(), src.numel());
    });
    // Surface launch-configuration errors here rather than at a later call.
    AT_CUDA_CHECK(cudaGetLastError());
  }

  src = src.masked_select(mask);
  perm = perm.masked_select(mask);

  return std::make_tuple(src, perm);
}