Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
b56c2359
Commit
b56c2359
authored
Oct 14, 2019
by
rusty1s
Browse files
use bool mask
parent
1c4fdfe2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
cuda/unique_kernel.cu
cuda/unique_kernel.cu
+4
-4
No files found.
cuda/unique_kernel.cu
View file @
b56c2359
...
...
@@ -6,13 +6,13 @@
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template
<
typename
scalar_t
>
__global__
void
unique_cuda_kernel
(
scalar_t
*
__restrict__
src
,
uint8_t
*
mask
,
__global__
void
unique_cuda_kernel
(
scalar_t
*
__restrict__
src
,
bool
*
mask
,
size_t
numel
)
{
const
size_t
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
*
gridDim
.
x
;
for
(
ptrdiff_t
i
=
index
;
i
<
numel
;
i
+=
stride
)
{
if
(
i
==
0
||
src
[
i
]
!=
src
[
i
-
1
])
{
mask
[
i
]
=
1
;
mask
[
i
]
=
true
;
}
}
}
...
...
@@ -22,10 +22,10 @@ std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
at
::
Tensor
perm
;
std
::
tie
(
src
,
perm
)
=
src
.
sort
();
auto
mask
=
at
::
zeros
(
src
.
numel
(),
src
.
options
().
dtype
(
at
::
kB
yte
));
auto
mask
=
at
::
zeros
(
src
.
numel
(),
src
.
options
().
dtype
(
at
::
kB
ool
));
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"grid_cuda_kernel"
,
[
&
]
{
unique_cuda_kernel
<
scalar_t
><<<
BLOCKS
(
src
.
numel
()),
THREADS
>>>
(
src
.
DATA_PTR
<
scalar_t
>
(),
mask
.
DATA_PTR
<
uint8_t
>
(),
src
.
numel
());
src
.
DATA_PTR
<
scalar_t
>
(),
mask
.
DATA_PTR
<
bool
>
(),
src
.
numel
());
});
src
=
src
.
masked_select
(
mask
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment