Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
6fca568d
Commit
6fca568d
authored
Jul 28, 2021
by
rusty1s
Browse files
override shfl methods for torch.half
parent
66bcc36e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
12 deletions
+14
-12
csrc/cuda/segment_coo_cuda.cu
csrc/cuda/segment_coo_cuda.cu
+1
-6
csrc/cuda/segment_csr_cuda.cu
csrc/cuda/segment_csr_cuda.cu
+1
-6
csrc/cuda/utils.cuh
csrc/cuda/utils.cuh
+12
-0
No files found.
csrc/cuda/segment_coo_cuda.cu
View file @
6fca568d
...
...
@@ -3,7 +3,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <type_traits>
#include "reducer.cuh"
#include "utils.cuh"
...
...
@@ -26,10 +25,6 @@ segment_coo_kernel(const scalar_t *src_data,
int
lane_idx
=
row_idx
&
(
32
-
1
);
int
D
=
index_info
.
sizes
[
index_info
.
dims
-
1
];
using
cuda_scalar_t
=
typename
std
::
conditional
<
std
::
is_same
<
scalar_t
,
at
::
Half
>::
value
,
__half
,
scalar_t
>::
type
;
if
(
row_idx
<
E
)
{
int
offset
=
at
::
cuda
::
detail
::
IndexToOffset
<
int64_t
,
int
,
-
1
>::
get
(
row_idx
,
index_info
);
...
...
@@ -41,7 +36,7 @@ segment_coo_kernel(const scalar_t *src_data,
#pragma unroll
for
(
int
i
=
1
;
i
<
32
;
i
*=
2
)
{
// Parallel reduction inside a single warp.
tmp
=
__shfl_up_sync
(
FULL_MASK
,
(
cuda_scalar_t
)
val
,
i
);
tmp
=
__shfl_up_sync
(
FULL_MASK
,
val
,
i
);
next_idx
=
__shfl_up_sync
(
FULL_MASK
,
idx
,
i
);
if
(
lane_idx
>=
i
&&
row_idx
/
D
==
(
row_idx
-
i
)
/
D
)
{
assert
(
idx
>=
next_idx
);
...
...
csrc/cuda/segment_csr_cuda.cu
View file @
6fca568d
...
...
@@ -26,10 +26,6 @@ segment_csr_kernel(const scalar_t *src_data,
int
row_idx
=
thread_idx
/
TB
;
int
lane_idx
=
thread_idx
&
(
TB
-
1
);
using
cuda_scalar_t
=
typename
std
::
conditional
<
std
::
is_same
<
scalar_t
,
at
::
Half
>::
value
,
__half
,
scalar_t
>::
type
;
if
(
row_idx
<
N
)
{
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
row_idx
,
indptr_info
);
int64_t
row_start
=
__ldg
(
indptr_info
.
data
+
offset
);
...
...
@@ -52,8 +48,7 @@ segment_csr_kernel(const scalar_t *src_data,
if
(
REDUCE
==
MIN
||
REDUCE
==
MAX
)
arg_tmp
=
__shfl_down_sync
(
FULL_MASK
,
arg
,
i
);
Reducer
<
scalar_t
,
REDUCE
>::
update
(
&
val
,
__shfl_down_sync
(
FULL_MASK
,
(
cuda_scalar_t
)
val
,
i
),
&
arg
,
arg_tmp
);
&
val
,
__shfl_down_sync
(
FULL_MASK
,
val
,
i
),
&
arg
,
arg_tmp
);
}
if
(
lane_idx
==
0
)
{
...
...
csrc/cuda/utils.cuh
View file @
6fca568d
...
...
@@ -5,3 +5,15 @@
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
__device__
__inline__
at
::
Half
__shfl_up_sync
(
const
unsigned
mask
,
const
at
::
Half
var
,
const
unsigned
int
delta
)
{
return
__shfl_up_sync
(
mask
,
(
__half
)
var
,
delta
);
}
__device__
__inline__
at
::
Half
__shfl_down_sync
(
const
unsigned
mask
,
const
at
::
Half
var
,
const
unsigned
int
delta
)
{
return
__shfl_down_sync
(
mask
,
(
__half
)
var
,
delta
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment