Commit b56c2359 authored by rusty1s's avatar rusty1s
Browse files

use bool mask

parent 1c4fdfe2
......@@ -6,13 +6,13 @@
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void unique_cuda_kernel(scalar_t *__restrict__ src, uint8_t *mask,
__global__ void unique_cuda_kernel(scalar_t *__restrict__ src, bool *mask,
size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) {
if (i == 0 || src[i] != src[i - 1]) {
mask[i] = 1;
mask[i] = true;
}
}
}
......@@ -22,10 +22,10 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
at::Tensor perm;
std::tie(src, perm) = src.sort();
auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte));
auto mask = at::zeros(src.numel(), src.options().dtype(at::kBool));
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
src.DATA_PTR<scalar_t>(), mask.DATA_PTR<uint8_t>(), src.numel());
src.DATA_PTR<scalar_t>(), mask.DATA_PTR<bool>(), src.numel());
});
src = src.masked_select(mask);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment