unique_kernel.cu 1.04 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
#include <stdexcept>
#include <string>
#include <tuple>

#include <ATen/ATen.h>

#include "compat.cuh"

rusty1s's avatar
rusty1s committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// Number of threads per block used for the kernel launches in this file.
#define THREADS 1024
// Number of THREADS-sized blocks needed to cover N elements (ceiling division).
// Both the argument and the whole expansion are parenthesized so the macro
// keeps its meaning when N is an expression or when the result is used inside
// a larger expression (e.g. `x / BLOCKS(n)` or `BLOCKS(c ? a : b)`).
#define BLOCKS(N) (((N) + THREADS - 1) / THREADS)

// Flags the first element of every run of equal adjacent values.
//
// For each i in [0, numel), writes mask[i] = 1 when i == 0 or
// src[i] != src[i - 1]. All other entries of `mask` are left untouched, so the
// caller must pass a zero-initialized mask of at least `numel` bytes.
//
// Launch layout: 1D grid of 1D blocks. The loop is grid-stride, so any
// non-empty launch configuration covers all `numel` elements. No shared memory
// is used and no synchronization is required.
//
//   src   - input values, read-only, at least `numel` elements.
//   mask  - output flags, one byte per element.
//   numel - number of elements to process.
template <typename scalar_t>
__global__ void unique_cuda_kernel(const scalar_t *__restrict__ src,
                                   uint8_t *mask, size_t numel) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise a 32-bit
  // product that wraps for very large launches.
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  // Unsigned index matches the type of `numel` (no signed/unsigned compare).
  for (size_t i = index; i < numel; i += stride) {
    // `i == 0` short-circuits, so src[i - 1] is never read out of bounds.
    if (i == 0 || src[i] != src[i - 1]) {
      mask[i] = 1;
    }
  }
}

// Computes the unique values of `src` on the GPU.
//
// The input is sorted (via Tensor::sort()), a kernel flags the first element
// of every run of equal values, and the flagged entries are selected.
//
// Returns a tuple (values, perm):
//   values - the sorted unique values of `src`.
//   perm   - for each unique value, the index (as returned by sort()) of the
//            element that the sort placed first in its run of equal values.
//
// NOTE(review): the mask is 1-D with src.numel() entries and is applied with
// masked_select, so this presumably expects `src` to be a 1-D CUDA tensor —
// confirm against callers.
//
// Throws std::runtime_error if selecting the device or launching the kernel
// fails.
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
  cudaError_t err = cudaSetDevice(src.get_device());
  if (err != cudaSuccess) {
    throw std::runtime_error(
        std::string("unique_cuda: cudaSetDevice failed: ") +
        cudaGetErrorString(err));
  }

  at::Tensor perm;
  std::tie(src, perm) = src.sort();

  const auto numel = src.numel();
  // Zero-initialized: the kernel only ever writes 1s.
  auto mask = at::zeros(numel, src.options().dtype(at::kByte));

  // A grid dimension of 0 is an invalid launch configuration, so skip the
  // launch entirely for empty input; the all-zero (empty) mask is already
  // the correct result.
  if (numel > 0) {
    AT_DISPATCH_ALL_TYPES(src.scalar_type(), "unique_cuda_kernel", [&] {
      unique_cuda_kernel<scalar_t><<<BLOCKS(numel), THREADS>>>(
          src.DATA_PTR<scalar_t>(), mask.DATA_PTR<uint8_t>(), numel);
    });

    // Kernel launches do not return a status; launch-configuration errors
    // only surface through cudaGetLastError().
    err = cudaGetLastError();
    if (err != cudaSuccess) {
      throw std::runtime_error(
          std::string("unique_cuda: kernel launch failed: ") +
          cudaGetErrorString(err));
    }
  }

  src = src.masked_select(mask);
  perm = perm.masked_select(mask);

  return std::make_tuple(src, perm);
}