"vscode:/vscode.git/clone" did not exist on "111fe85d0daf880e154c1b5aa477cba4ebf5e915"
Commit b56c2359 authored by rusty1s
Browse files

use bool mask

parent 1c4fdfe2
...@@ -6,13 +6,13 @@ ...@@ -6,13 +6,13 @@
#define BLOCKS(N) (N + THREADS - 1) / THREADS #define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t> template <typename scalar_t>
__global__ void unique_cuda_kernel(scalar_t *__restrict__ src, uint8_t *mask, __global__ void unique_cuda_kernel(scalar_t *__restrict__ src, bool *mask,
size_t numel) { size_t numel) {
const size_t index = blockIdx.x * blockDim.x + threadIdx.x; const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x; const size_t stride = blockDim.x * gridDim.x;
for (ptrdiff_t i = index; i < numel; i += stride) { for (ptrdiff_t i = index; i < numel; i += stride) {
if (i == 0 || src[i] != src[i - 1]) { if (i == 0 || src[i] != src[i - 1]) {
mask[i] = 1; mask[i] = true;
} }
} }
} }
...@@ -22,10 +22,10 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) { ...@@ -22,10 +22,10 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
at::Tensor perm; at::Tensor perm;
std::tie(src, perm) = src.sort(); std::tie(src, perm) = src.sort();
auto mask = at::zeros(src.numel(), src.options().dtype(at::kByte)); auto mask = at::zeros(src.numel(), src.options().dtype(at::kBool));
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "grid_cuda_kernel", [&] {
unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>( unique_cuda_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS>>>(
src.DATA_PTR<scalar_t>(), mask.DATA_PTR<uint8_t>(), src.numel()); src.DATA_PTR<scalar_t>(), mask.DATA_PTR<bool>(), src.numel());
}); });
src = src.masked_select(mask); src = src.masked_select(mask);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment