// unique_kernel.cu
// (Recovered from a repository web view; last committed by rusty1s.)
#include <stdexcept>
#include <string>

#include <ATen/ATen.h>

// Threads per block used for every launch in this file.
#define THREADS 1024
// Number of THREADS-sized blocks needed to cover N elements (ceil-division).
// Both the argument and the whole expansion are parenthesized so the macro
// composes correctly inside larger expressions, e.g. `x / BLOCKS(n)` or
// `BLOCKS(c ? a : b)`.
#define BLOCKS(N) ((((N)) + THREADS - 1) / THREADS)

// Marks the first element of every run of equal adjacent values.
//
// Launch: 1-D grid of 1-D blocks. Any grid size is valid because each thread
// walks the input with a grid-stride loop (stride = blockDim.x * gridDim.x).
// No shared memory is used.
//
// Preconditions:
//   - `src` points to `numel` device-resident elements in which equal values
//     are adjacent (the host wrapper sorts before launching).
//   - `mask` points to `numel` device-resident bytes that the caller has
//     zero-initialized; this kernel only ever writes 1s into it.
template <typename scalar_t>
__global__ void unique_cuda_kernel(const scalar_t *__restrict__ src,
                                   uint8_t *mask, size_t numel) {
  // Widen to size_t BEFORE multiplying: blockIdx.x * blockDim.x is otherwise
  // evaluated in 32-bit unsigned arithmetic and can wrap on huge inputs.
  const size_t index =
      static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  // Unsigned index matches the type of `numel` (no signed/unsigned compare).
  for (size_t i = index; i < numel; i += stride) {
    // Element i starts a new run if it is the very first element or differs
    // from its predecessor.
    if (i == 0 || src[i] != src[i - 1]) {
      mask[i] = 1;
    }
  }
}

// Computes the unique values of `src`.
//
// Returns a tuple (values, indices):
//   - values:  the distinct values of `src` in ascending order
//              (src.sort() is called with its default, ascending order).
//   - indices: for each distinct value, the position in the ORIGINAL `src`
//              of one of its occurrences (taken from the sort permutation;
//              which occurrence depends on sort stability).
//
// NOTE(review): assumes `src` is a 1-D CUDA tensor. sort() orders along the
// last dimension only, while the kernel treats the data as flat. Confirm
// against callers.
//
// Throws std::runtime_error if the kernel launch reports a CUDA error.
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
  cudaSetDevice(src.get_device());

  // Sorting makes equal values adjacent; `perm` maps sorted -> original index.
  at::Tensor perm;
  std::tie(src, perm) = src.sort();

  // Byte mask on the same backend as `src`, zero-filled; the kernel sets 1 at
  // the first element of each run of equal values.
  auto mask = at::zeros(src.numel(), src.type().toScalarType(at::kByte));

  // A launch with zero blocks is an invalid configuration, so skip the kernel
  // for empty input; the empty all-zero mask already gives the right answer.
  if (src.numel() > 0) {
    AT_DISPATCH_ALL_TYPES(src.type(), "unique_cuda_kernel", [&] {
      unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
          src.data<scalar_t>(), mask.data<uint8_t>(), src.numel());
    });

    // Kernel launches do not return errors directly; surface launch-time
    // failures (e.g. bad configuration) here instead of at some later call.
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
      throw std::runtime_error(
          std::string("unique_cuda_kernel launch failed: ") +
          cudaGetErrorString(err));
    }
  }

  // Keep only the first element of each run, for both values and indices.
  src = src.masked_select(mask);
  perm = perm.masked_select(mask);

  return std::make_tuple(src, perm);
}