OpenDAS / torch-scatter / Commits / 6e561c88
"apps/kg/vscode:/vscode.git/clone" did not exist on "7b3a7b14381acf7d5d8213e3e36a94fdf69c827b"
Commit 6e561c88, authored Jan 07, 2020 by rusty1s
atomics

Parent: 9725b043
Showing 2 changed files with 44 additions and 19 deletions (+44 / -19):

cuda/segment_kernel.cu  +40 / -9
test/test_segment.py    +4 / -10
cuda/segment_kernel.cu (view file @ 6e561c88)
...
@@ -67,10 +67,30 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
       }
     }
   }
 
+  static inline __device__ void atom_write(scalar_t *address, scalar_t val,
+                                           int64_t *arg_address, int64_t arg) {
+    if (REDUCE == ADD) {
+      atomAdd(address, val);
+    } else if (REDUCE == MEAN) {
+      atomAdd(address, val);
+    } else if (REDUCE == MIN && val < *address) {
+      atomMin(address, val);
+    } else if (REDUCE == MAX && val > *address) {
+      atomMax(address, val);
+    }
+
+    if (REDUCE == MIN || REDUCE == MAX) {
+      __syncthreads();
+      if (*address == val) {
+        *arg_address = arg;
+      }
+    }
+  }
 };
 
-// We need our own `IndexToOffset` implementation since we do not want to access
-// the last element of the `indexptr`.
+// We need our own `IndexToOffset` implementation since we do not want to
+// access the last element of the `indexptr`.
 template <typename scalar_t> struct IndexPtrToOffset {
   static inline __host__ __device__ int
   get(int idx, const at::cuda::detail::TensorInfo<scalar_t, int> &info) {
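The new `atom_write` helper centralizes the per-reduction atomic writes: add and mean go through `atomAdd`, while min and max first do a guarded atomic update on the value and then, after a block-wide sync, the thread whose value survived stores its argument index. As a rough standalone illustration of that value-then-argument pattern (not torch-scatter code: it uses only stock CUDA atomics on `int` data, a single segment, and a made-up kernel name and launch shape):

#include <climits>
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel (hypothetical, single segment, single block): every thread owns
// one value; an atomic min picks the winning value, and after a block-wide
// sync the thread whose value matches the result records its index as argmin.
__global__ void seg_min_arg_kernel(const int *val, int *out, long long *arg_out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int v = (i < n) ? val[i] : INT_MAX;
  if (v < *out)     // cheap pre-check, like `REDUCE == MIN && val < *address`
    atomicMin(out, v);
  __syncthreads();  // wait until all atomics of this block have landed
  if (i < n && *out == v)
    *arg_out = i;   // one of the winning threads records its index
}

int main() {
  const int n = 6;
  int host_val[n] = {5, 3, 8, 1, 9, 4};
  int *val, *out, init = INT_MAX, out_h;
  long long *arg_out, arg_h;
  cudaMalloc(&val, n * sizeof(int));
  cudaMalloc(&out, sizeof(int));
  cudaMalloc(&arg_out, sizeof(long long));
  cudaMemcpy(val, host_val, n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(out, &init, sizeof(int), cudaMemcpyHostToDevice);
  seg_min_arg_kernel<<<1, 32>>>(val, out, arg_out, n);
  cudaMemcpy(&out_h, out, sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(&arg_h, arg_out, sizeof(long long), cudaMemcpyDeviceToHost);
  printf("min=%d argmin=%lld\n", out_h, arg_h);  // expected: min=1 argmin=3
  cudaFree(val); cudaFree(out); cudaFree(arg_out);
  return 0;
}

Note that `__syncthreads()` only orders threads within one block; the sketch sidesteps cross-block ordering of the argument slot by launching a single block.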
...
@@ -92,8 +112,8 @@ segment_csr_kernel(const scalar_t *src_data,
                    scalar_t *out_data, int64_t *arg_out_data, size_t N,
                    size_t E) {
 
-  // Each warp processes exactly `32/TB` rows and aggregates all row values via
-  // a parallel reduction.
+  // Each warp processes exactly `32/TB` rows and aggregates all row values
+  // via a parallel reduction.
 
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int row_idx = thread_idx / TB;
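The comment above describes the work split: `TB` threads cooperate on one CSR row, so a warp of 32 threads covers `32/TB` rows. A tiny host-side illustration of that index arithmetic (TB = 4 is just an assumed example value, not a constant from the repository):

#include <cstdio>

int main() {
  const int TB = 4;  // threads cooperating per row (assumed example value)
  for (int thread_idx = 0; thread_idx < 12; ++thread_idx) {
    int row_idx = thread_idx / TB;   // CSR row this thread helps reduce
    int lane_idx = thread_idx % TB;  // its position within the row's group
    printf("thread %2d -> row %d, lane %d\n", thread_idx, row_idx, lane_idx);
  }
  return 0;
}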
...
@@ -106,7 +126,7 @@ segment_csr_kernel(const scalar_t *src_data,
                              indptr_info.strides[indptr_info.dims - 1]);
 
   scalar_t val = Reducer<scalar_t, REDUCE>::init();
-  int64_t arg, tmp;
+  int64_t arg, arg_tmp;
 
   offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
   for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
...
@@ -118,10 +138,10 @@ segment_csr_kernel(const scalar_t *src_data,
     for (int i = TB / 2; i > 0; i /= 2) {
       // Parallel reduction inside a single warp.
       if (REDUCE == MIN || REDUCE == MAX) {
-        tmp = __shfl_down_sync(FULL_MASK, arg, i);
+        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
       }
       Reducer<scalar_t, REDUCE>::update(
-          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, tmp);
+          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
     }
 
     if (lane_idx == 0) {
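The loop above is the classic warp-level tree reduction: at each step a lane folds in the value shuffled down from `i` lanes away, and for min/max the per-lane argument now travels alongside it. A minimal standalone sketch of the value part only, for a full warp summing 32 floats (kernel name and setup are illustrative, not from the repository):

#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Standalone sketch of a warp-level tree reduction: each lane starts with its
// own value and, after log2(32) shuffle steps, lane 0 holds the full sum.
__global__ void warp_sum_kernel(const float *in, float *out) {
  float val = in[threadIdx.x];
  for (int i = 32 / 2; i > 0; i /= 2)
    val += __shfl_down_sync(FULL_MASK, val, i);  // pull value from lane (threadIdx.x + i)
  if (threadIdx.x == 0)
    *out = val;  // only lane 0 has the complete sum
}

int main() {
  float host_in[32], *in, *out, result;
  for (int i = 0; i < 32; ++i)
    host_in[i] = 1.0f;  // 32 ones -> expected sum 32
  cudaMalloc(&in, 32 * sizeof(float));
  cudaMalloc(&out, sizeof(float));
  cudaMemcpy(in, host_in, 32 * sizeof(float), cudaMemcpyHostToDevice);
  warp_sum_kernel<<<1, 32>>>(in, out);
  cudaMemcpy(&result, out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", result);  // expected: 32.0
  cudaFree(in);
  cudaFree(out);
  return 0;
}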
...
@@ -241,20 +261,27 @@ segment_coo_kernel(const scalar_t *src_data,
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
     int idx = index_info.data[offset], next_idx;
     scalar_t val = src_data[row_idx], tmp;
+    int64_t arg = row_idx % index_info.sizes[index_info.dims - 1], arg_tmp;
 
 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
+      // Parallel reduction inside a single warp.
       tmp = __shfl_up_sync(FULL_MASK, val, i);
+      if (REDUCE == MIN || REDUCE == MAX) {
+        arg_tmp = __shfl_up_sync(FULL_MASK, arg, i);
+      }
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       assert(idx >= next_idx);
       if (lane_idx >= i && idx == next_idx)
-        val += tmp;
+        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
     }
 
     next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
     if (lane_idx == 32 - 1 || idx != next_idx) {
-      atomAdd(out_data + idx, val);
+      Reducer<scalar_t, REDUCE>::atom_write(out_data + idx, val,
+                                            arg_out_data + idx, arg);
     }
   }
 }
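In the COO kernel the warp instead performs a segmented inclusive scan: `__shfl_up_sync` pulls the partial result from `i` lanes back, and it is only folded in when that lane carries the same (sorted) segment index; the last lane of each segment then flushes its total, now through `atom_write` so that min/max and their arguments are handled as well. A self-contained sketch of the sum-only version of that pattern (names and the 8-entries-per-segment toy data are made up):

#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Segmented warp scan sketch: each lane holds one (segment index, value) pair,
// the indices are assumed sorted, and a lane only accepts a shuffled partial
// sum when it came from a lane belonging to the same segment.
__global__ void warp_segment_sum_kernel(const int *index, const float *src, float *out) {
  int lane = threadIdx.x;
  int idx = index[lane];
  float val = src[lane];

  for (int i = 1; i < 32; i *= 2) {
    float tmp = __shfl_up_sync(FULL_MASK, val, i);
    int prev_idx = __shfl_up_sync(FULL_MASK, idx, i);
    if (lane >= i && idx == prev_idx)
      val += tmp;  // accumulate only within the same segment
  }

  // The last lane of each segment writes the segment total.
  int next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
  if (lane == 32 - 1 || idx != next_idx)
    atomicAdd(out + idx, val);
}

int main() {
  // 32 sorted segment ids over 4 segments, value 1.0 everywhere, so the
  // expected output is simply the number of entries per segment.
  int host_index[32];
  float host_src[32], host_out[4] = {0, 0, 0, 0};
  for (int i = 0; i < 32; ++i) {
    host_index[i] = i / 8;  // segments 0..3, eight entries each
    host_src[i] = 1.0f;
  }
  int *index;
  float *src, *out;
  cudaMalloc(&index, sizeof(host_index));
  cudaMalloc(&src, sizeof(host_src));
  cudaMalloc(&out, sizeof(host_out));
  cudaMemcpy(index, host_index, sizeof(host_index), cudaMemcpyHostToDevice);
  cudaMemcpy(src, host_src, sizeof(host_src), cudaMemcpyHostToDevice);
  cudaMemcpy(out, host_out, sizeof(host_out), cudaMemcpyHostToDevice);
  warp_segment_sum_kernel<<<1, 32>>>(index, src, out);
  cudaMemcpy(host_out, out, sizeof(host_out), cudaMemcpyDeviceToHost);
  printf("%g %g %g %g\n", host_out[0], host_out[1], host_out[2], host_out[3]);  // expected: 8 8 8 8
  cudaFree(index); cudaFree(src); cudaFree(out);
  return 0;
}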
...
@@ -365,5 +392,9 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
          0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
   });
 
+  if (reduce == "mean") {
+    AT_ASSERTM(false); // TODO: DIVIDE ENTRIES.
+  }
+
   return std::make_tuple(out, arg_out);
 }
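The "mean" branch is intentionally left failing (`AT_ASSERTM(false)`) with a TODO to divide the accumulated entries. As a rough host-side sketch of the missing step, not the repository's implementation: count how many entries fell into each segment and divide the per-segment sums, guarding empty segments against division by zero.

#include <cstdio>
#include <vector>

// Hypothetical illustration of the TODO: turn per-segment sums into means by
// dividing by per-segment entry counts derived from the (sorted) index tensor.
int main() {
  std::vector<int> index = {0, 0, 1, 1, 1, 3};         // sorted segment ids
  std::vector<float> sum = {3.0f, 12.0f, 0.0f, 6.0f};  // per-segment 'add' result
  std::vector<int> count(sum.size(), 0);
  for (int idx : index)
    ++count[idx];                                      // entries per segment
  for (size_t k = 0; k < sum.size(); ++k) {
    float mean = count[k] > 0 ? sum[k] / count[k] : 0.0f;
    printf("segment %zu: sum=%g count=%d mean=%g\n", k, sum[k], count[k], mean);
  }
  return 0;
}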
test/test_segment.py (view file @ 6e561c88)
...
@@ -18,18 +18,12 @@ def test_forward(dtype, device):
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
-    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-
-    # out = segment_coo(src, index)
-    # print('COO', out)
 
     out = segment_csr(src, indptr, reduce='add')
     print('CSR', out)
-    out = segment_csr(src, indptr, reduce='mean')
-    print('CSR', out)
-    out = segment_csr(src, indptr, reduce='min')
-    print('CSR', out)
-    out = segment_csr(src, indptr, reduce='max')
-    print('CSR', out)
+
+    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
+    out = segment_coo(src, index, reduce='add')
+    print('COO', out)
 
 # @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
...
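For reference, with src = [1, 2, 3, 4, 5, 6] and indptr = [0, 2, 5, 5, 6], segment_csr(src, indptr, reduce='add') should print [3, 12, 0, 6] (the sums over rows [0:2], [2:5], [5:5] and [5:6] of src), and segment_coo(src, index, reduce='add') with index = [0, 0, 1, 1, 1, 3] should print the same sums, since the index tensor groups the six entries identically and leaves segment 2 empty.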