OpenDAS / torch-scatter · Commit cd6d8d68

"all cuda kernels done"

Authored Jan 09, 2020 by rusty1s
Parent: 7e82bc0e

Showing 3 changed files with 180 additions and 149 deletions (+180, -149):
benchmark/scatter_segment.py  +59 -95
cuda/segment_kernel.cu        +99 -47
test/test_segment.py          +22 -7
benchmark/scatter_segment.py (view file @ cd6d8d68)
+# flake8: noqa
 import time
 import os.path as osp
 import itertools
+import argparse

 import wget
 import torch
 from scipy.io import loadmat

+import torch_scatter
 from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
 from torch_scatter import segment_coo, segment_csr

+parser = argparse.ArgumentParser()
+parser.add_argument('--reduce', type=str, required=True)
+parser.add_argument('--device', type=str, default='cuda')
+args = parser.parse_args()
+args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
+
 iters = 20
-device = 'cuda'
 sizes = [1, 16, 32, 64, 128, 256, 512]

 short_rows = [
...
@@ -40,13 +49,13 @@ def bold(text, flag=True):

 def correctness(dataset):
     group, name = dataset
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
-    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
-    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
+    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
+    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
     dim_size = rowptr.size(0) - 1

     for size in sizes:
         try:
-            x = torch.randn((row.size(0), size), device=device)
+            x = torch.randn((row.size(0), size), device=args.device)
             x = x.squeeze(-1) if size == 1 else x

             out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
...
@@ -63,92 +72,71 @@ def correctness(dataset):

             assert torch.allclose(out1, out2, atol=1e-4)
             assert torch.allclose(out1, out3, atol=1e-4)

-            # out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
-            # out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
-            # print(out1[:5])
-            # print(out3[:5])
-            # nnz = (out1 != out3).nonzero().flatten()
-            # nnz1 = nnz[0].item()
-            # print(rowptr[nnz1], rowptr[nnz1 + 1])
-            # print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
-            # print(out1[nnz1])
-            # print(out3[nnz1])
-            # assert torch.allclose(out1, out3, atol=1e-4)
-            # assert torch.all(arg_out1 == arg_out3)
+            x = x.abs_().mul_(-1)
+            out1, arg_out1 = scatter_min(x, row, 0, torch.zeros_like(out1))
+            out2, arg_out2 = segment_coo(x, row, reduce='min')
+            out3, arg_out3 = segment_csr(x, rowptr, reduce='min')
+            assert torch.allclose(out1, out2, atol=1e-4)
+            assert torch.allclose(out1, out3, atol=1e-4)
+
+            x = x.abs_()
+            out1, arg_out1 = scatter_max(x, row, 0, torch.zeros_like(out1))
+            out2, arg_out2 = segment_coo(x, row, reduce='max')
+            out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
+            assert torch.allclose(out1, out2, atol=1e-4)
+            assert torch.allclose(out1, out3, atol=1e-4)
         except RuntimeError:
             torch.cuda.empty_cache()

+def time_func(func, x):
+    try:
+        torch.cuda.synchronize()
+        t = time.perf_counter()
+        for _ in range(iters):
+            func(x)
+        torch.cuda.synchronize()
+        return time.perf_counter() - t
+    except RuntimeError:
+        torch.cuda.empty_cache()
+        return float('inf')
+
+
 @torch.no_grad()
 def timing(dataset):
     group, name = dataset
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
-    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
-    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
+    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
+    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
     row_perm = row[torch.randperm(row.size(0))]
     dim_size = rowptr.size(0) - 1
     avg_row_len = row.size(0) / dim_size
+    sca_row = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')(
+        x, row, dim=0, dim_size=dim_size)
+    sca_col = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')(
+        x, row_perm, dim=0, dim_size=dim_size)
+    seg_coo = lambda x: segment_coo(x, row, reduce=args.reduce)
+    seg_csr = lambda x: segment_csr(x, rowptr, reduce=args.reduce)
+    dense1 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-2)
+    dense2 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-1)
+
     t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
     for size in sizes:
         try:
-            x = torch.randn((row.size(0), size), device=device)
+            x = torch.randn((row.size(0), size), device=args.device)
             x = x.squeeze(-1) if size == 1 else x

-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = scatter_add(x, row, dim=0, dim_size=dim_size)
-                del out
-                torch.cuda.synchronize()
-                t1.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t1.append(float('inf'))
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = scatter_add(x, row_perm, dim=0, dim_size=dim_size)
-                del out
-                torch.cuda.synchronize()
-                t2.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t2.append(float('inf'))
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = segment_coo(x, row, dim_size=dim_size, reduce='any')
-                del out
-                torch.cuda.synchronize()
-                t3.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t3.append(float('inf'))
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = segment_csr(x, rowptr, reduce='any')
-                del out
-                torch.cuda.synchronize()
-                t4.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t4.append(float('inf'))
+            t1 += [time_func(sca_row, x)]
+            t2 += [time_func(sca_col, x)]
+            t3 += [time_func(seg_coo, x)]
+            t4 += [time_func(seg_csr, x)]

             del x
...
@@ -159,35 +147,11 @@ def timing(dataset):

         try:
             x = torch.randn((dim_size, int(avg_row_len + 1), size),
-                            device=device)
+                            device=args.device)
+            x = x.squeeze(-1) if size == 1 else x

-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = x.sum(dim=1)
-                del out
-                torch.cuda.synchronize()
-                t5.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t5.append(float('inf'))
+            t5 += [time_func(dense1, x)]

             x = x.view(dim_size, size, int(avg_row_len + 1))
-            try:
-                torch.cuda.synchronize()
-                t = time.perf_counter()
-                for _ in range(iters):
-                    out = x.sum(dim=-1)
-                del out
-                torch.cuda.synchronize()
-                t6.append(time.perf_counter() - t)
-            except RuntimeError:
-                torch.cuda.empty_cache()
-                t6.append(float('inf'))
+            x = x.squeeze(-2) if size == 1 else x
+            t6 += [time_func(dense2, x)]

             del x
...
@@ -221,7 +185,7 @@ def timing(dataset):

 if __name__ == '__main__':
     for _ in range(10):  # Warmup.
-        torch.randn(100, 100, device=device).sum()
+        torch.randn(100, 100, device=args.device).sum()
     for dataset in itertools.chain(short_rows, long_rows):
         download(dataset)
         correctness(dataset)
...
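Note: with the argument parser added at the top of this file, the benchmark is presumably invoked once per reduction type, along the lines of:

    python benchmark/scatter_segment.py --reduce add --device cuda
    python benchmark/scatter_segment.py --reduce max

`--reduce add` is mapped onto the dense reduction name 'sum' via `args.dense_reduce` (dense tensors use `torch.sum` rather than an "add" reduction), and `--device` defaults to 'cuda'.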
cuda/segment_kernel.cu (view file @ cd6d8d68)
...
@@ -34,12 +34,22 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {

     if (REDUCE == MIN) {
       return std::numeric_limits<scalar_t>::max();
     } else if (REDUCE == MAX) {
-      return std::numeric_limits<scalar_t>::min();
+      return std::numeric_limits<scalar_t>::lowest();
     } else {
       return (scalar_t)0;
     }
   }
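Note: the `min()` -> `lowest()` change fixes the identity element for MAX reductions. For floating-point types, `std::numeric_limits<T>::min()` is the smallest positive normal value, not the most negative representable one, so a max-reduction seeded with it would return wrong results for all-negative inputs; `lowest()` is the true minimum. The same distinction seen from the Python side (an illustration only; these torch calls are not part of the commit):

    import torch

    info = torch.finfo(torch.float32)
    print(info.tiny)  # ~1.18e-38, analogous to std::numeric_limits<float>::min()
    print(info.min)   # ~-3.40e+38, analogous to std::numeric_limits<float>::lowest()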
+  static inline __host__ __device__ void update(scalar_t *val,
+                                                scalar_t new_val) {
+    if (REDUCE == ADD || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+    }
+  }
+
   static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                 int64_t *arg, int64_t new_arg) {
     if (REDUCE == ADD || REDUCE == MEAN) {
...
@@ -68,9 +78,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {

     }
   }

-  static inline __device__ void atomic_write(scalar_t *address, scalar_t val,
-                                             int64_t *arg_address, int64_t arg) {
+  static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
     if (REDUCE == ADD) {
       atomAdd(address, val);
     } else if (REDUCE == MEAN) {
...
@@ -80,14 +88,6 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {

     } else if (REDUCE == MAX && val > *address) {
       atomMax(address, val);
     }
-    if (REDUCE == MIN || REDUCE == MAX) {
-      assert(false); // TODO
-      __syncthreads();
-      if (*address == val) {
-        *arg_address = arg;
-      }
-    }
   }
 };
...
@@ -111,7 +111,7 @@ segment_csr_kernel(const scalar_t *src_data,

     int row_end = __ldg(indptr_info.data + offset +
                         indptr_info.strides[indptr_info.dims - 1]);

-    scalar_t val = Reducer<scalar_t, REDUCE>::init(), tmp;
+    scalar_t val = Reducer<scalar_t, REDUCE>::init();
     int64_t arg, arg_tmp;

     offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
...
@@ -123,16 +123,10 @@ segment_csr_kernel(const scalar_t *src_data,

 #pragma unroll
     for (int i = TB / 2; i > 0; i /= 2) {
       // Parallel reduction inside a single warp.
-      if (REDUCE == MIN || REDUCE == MAX) {
-        tmp = __shfl_down_sync(FULL_MASK, val, i);
-        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
-        // Only update valid entries.
-        if (lane_idx < i && row_start + lane_idx + i < row_end)
-          Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
-      } else {
-        Reducer<scalar_t, REDUCE>::update(
-            &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
-      }
+      if (REDUCE == MIN || REDUCE == MAX)
+        arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
+      Reducer<scalar_t, REDUCE>::update(
+          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
     }

     if (lane_idx == 0) {
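Note: the loop above is a standard warp-level tree reduction. At each step, `__shfl_down_sync` lets lane k read the partial result of lane k + i, halving the number of live partials until lane 0 holds the reduction of the whole slice. A minimal Python sketch of the access pattern (an illustration assuming TB = 32; not library code):

    # Simulate a 32-lane warp max-reduction via shfl_down-style exchanges.
    vals = [float(v) for v in range(32)]       # one partial result per lane

    i = 16                                     # TB / 2
    while i > 0:
        shifted = vals[i:] + vals[:i]          # lane k reads lane k + i
        vals = [max(a, b) for a, b in zip(vals, shifted)]
        i //= 2

    # (Real hardware leaves the top lanes unchanged instead of wrapping
    # around, but lane 0's result is the same for min/max reductions.)
    assert vals[0] == 31.0                     # lane 0 holds the full reduction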
...
@@ -256,7 +250,7 @@ template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>

 __global__ void
 segment_coo_kernel(const scalar_t *src_data,
                    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
-                   scalar_t *out_data, int64_t *arg_out_data, size_t E) {
+                   scalar_t *out_data, size_t E, size_t N) {

   // Each thread processes exactly one entry. Within a warp, we perform a
   // parallel reduction across equal indices, and write the intermediate
...
@@ -269,40 +263,52 @@ segment_coo_kernel(const scalar_t *src_data,

     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_idx, index_info);
     int idx = index_info.data[offset], next_idx;
+    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;

     scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
-    int64_t arg, arg_tmp;
-    if (REDUCE == MIN || REDUCE == MAX) {
-      arg = row_idx % index_info.sizes[index_info.dims - 1];
-    }

 #pragma unroll
     for (int i = 1; i < 32; i *= 2) {
       // Parallel reduction inside a single warp.
       tmp = __shfl_up_sync(FULL_MASK, val, i);
-      if (REDUCE == MIN || REDUCE == MAX) {
-        arg_tmp = __shfl_up_sync(FULL_MASK, arg, i);
-      }
       next_idx = __shfl_up_sync(FULL_MASK, idx, i);
       assert(idx >= next_idx);
       if (lane_idx >= i && idx == next_idx)
-        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
+        Reducer<scalar_t, REDUCE>::update(&val, tmp);
     }

     next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
     if (lane_idx == 32 - 1 || idx != next_idx) {
-      Reducer<scalar_t, REDUCE>::atomic_write(out_data + idx, val,
-                                              arg_out_data + idx, arg);
+      Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
     }
   }
 }
+template <typename scalar_t>
+__global__ void segment_coo_arg_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
+
+  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (row_idx < E) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int idx = index_info.data[offset];
+    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;
+
+    scalar_t val = __ldg(out_data + out_idx);
+    if (src_data[row_idx] == val)
+      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
+  }
+}
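Note: the new `segment_coo_arg_kernel` realizes a two-pass argmin/argmax strategy that replaces the asserted-out atomic arg-write deleted from `Reducer::atomic_write` above: pass 1 reduces values only, and this second pass re-reads the input and records, for any position whose value equals the reduced result, its position within the reduced dimension. A hedged Python sketch of the idea (an illustration, not the library implementation):

    import torch

    src = torch.tensor([1., 4., 2., 3.])     # values, sorted by segment
    index = torch.tensor([0, 0, 1, 1])       # segment id of each value

    # Pass 1: reduce values only.
    out = torch.full((2,), float('-inf'))
    for i, seg in enumerate(index.tolist()):
        out[seg] = max(out[seg].item(), src[i].item())

    # Pass 2: recover an index attaining each reduced value (the kernel
    # stores the intra-segment position, row_idx % segment size; ties keep
    # whichever write lands last, mirroring the unordered device writes).
    arg = torch.full((2,), -1, dtype=torch.long)
    for i, seg in enumerate(index.tolist()):
        if src[i].item() == out[seg].item():
            arg[seg] = i

    assert out.tolist() == [4., 3.] and arg.tolist() == [1, 3]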
 template <typename scalar_t, ReductionType REDUCE, int TB>
 __global__ void segment_coo_broadcast_kernel(
     const scalar_t *src_data,
     const at::cuda::detail::TensorInfo<int64_t, int> index_info,
-    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {
+    scalar_t *out_data, size_t E, size_t K, size_t N) {

   // Each thread processes a single column and `TB` index entries. Coalesced
   // read and write is performed in column-major order. The intermediate
...
@@ -314,6 +320,7 @@ __global__ void segment_coo_broadcast_kernel(

   if (row_start < E && col_idx < K) {
     int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
         row_start, index_info);
+    int out_idx = (row_start / index_info.sizes[index_info.dims - 1]) * N;
     int idx1 = __ldg(index_info.data + offset);

     scalar_t val = src_data[K * row_start + col_idx];
...
@@ -327,15 +334,42 @@ __global__ void segment_coo_broadcast_kernel(

                        i * index_info.strides[index_info.dims - 1]);
       assert(idx1 <= idx2);
       if (idx1 == idx2) {
-        val += src_data[K * (row_start + i) + col_idx];
+        Reducer<scalar_t, REDUCE>::update(
+            &val, src_data[K * (row_start + i) + col_idx]);
       } else {
-        atomAdd(out_data + K * idx1 + col_idx, val);
+        Reducer<scalar_t, REDUCE>::atomic_write(
+            out_data + (out_idx + idx1) * K + col_idx, val);
         val = src_data[K * (row_start + i) + col_idx];
       }
       idx1 = idx2;
     }
-    atomAdd(out_data + K * idx1 + col_idx, val);
+    Reducer<scalar_t, REDUCE>::atomic_write(
+        out_data + (out_idx + idx1) * K + col_idx, val);
   }
 }

+template <typename scalar_t>
+__global__ void segment_coo_arg_broadcast_kernel(
+    const scalar_t *src_data,
+    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
+    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int row_idx = thread_idx / K;
+  int col_idx = thread_idx % K;
+
+  if (row_idx < E && col_idx < K) {
+    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
+        row_idx, index_info);
+    int idx = __ldg(index_info.data + offset);
+    int out_idx = ((row_idx / index_info.sizes[index_info.dims - 1]) * N + idx)
+                  * K + col_idx;
+
+    scalar_t val = __ldg(out_data + out_idx);
+    if (src_data[thread_idx] == val)
+      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
+  }
+}
...
@@ -371,6 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,

   auto E = index.numel();
   auto K = src.numel() / E;
+  auto N = out.size(reduce_dim);
   auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
...
@@ -383,25 +418,37 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,

       if (K == 1) {
         segment_coo_kernel<scalar_t, REDUCE, true>
             <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
-                                                   out_data, arg_out_data, E);
+                                                   out_data, E, N);
       } else if (avg_len <= 8) {
         segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
             <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32),
                dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
-                                         arg_out_data, E, K);
+                                         E, K, N);
       } else if (avg_len <= 16) {
         segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
             <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32),
                dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
-                                         arg_out_data, E, K);
+                                         E, K, N);
       } else if (avg_len <= 32) {
         segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
             <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32),
                dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
-                                         arg_out_data, E, K);
+                                         E, K, N);
       } else {
         segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
             <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32),
                dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
-                                         arg_out_data, E, K);
+                                         E, K, N);
       }
+
+      if (REDUCE == MIN || REDUCE == MAX) {
+        if (K == 1) {
+          segment_coo_arg_kernel<scalar_t>
+              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
+                  src_data, index_info, out_data, arg_out_data, E, N);
+        } else {
+          segment_coo_arg_broadcast_kernel<scalar_t>
+              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
+                  src_data, index_info, out_data, arg_out_data, E, K, N);
+        }
+      }
     });
   });
...
@@ -415,12 +462,17 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,

       auto count_data = count.DATA_PTR<scalar_t>();
       segment_coo_kernel<scalar_t, ADD, false>
           <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
-                                                 count_data, nullptr, E);
+                                                 count_data, E, N);
     });
     count.clamp_(1);
-    out.div_(count);
     arg_out = count;
+    for (int i = reduce_dim + 1; i < out.dim(); i++) {
+      count = count.unsqueeze(-1);
+    }
+    out.div_(count);
   }

   return std::make_tuple(out, arg_out);
...
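Note: the reordered mean logic fixes a broadcasting problem: `out.div_(count)` now runs only after `count` has been unsqueezed once per dimension trailing `reduce_dim`, so the per-segment counts divide along the reduced dimension instead of misaligning with the feature dimensions. The same issue from the Python side (an illustration, not library code):

    import torch

    out = torch.ones(3, 4)                  # summed values, reduce_dim = 0
    count = torch.tensor([2., 1., 4.])      # elements per segment, shape (3,)

    # out / count would raise (shapes (3, 4) vs (3,) right-align as 4 vs 3),
    # so append one singleton dim per trailing dimension before dividing:
    mean = out / count.unsqueeze(-1)        # (3, 1) broadcasts over the columns
    print(mean[:, 0])                       # tensor([0.5000, 1.0000, 0.2500])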
test/test_segment.py (view file @ cd6d8d68)
...
@@ -3,7 +3,7 @@ from itertools import product

 import pytest
 import torch

 from torch_scatter import segment_coo, segment_csr
-from torch_scatter import scatter_add, scatter_mean, scatter_min  # noqa
+from torch_scatter import scatter_max
 from .utils import tensor
...
@@ -18,24 +18,39 @@ def test_forward(dtype, device):

                  device)
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
+    # src = tensor([-1, -2, -3, -4, -5, -6], dtype, device)
     src.requires_grad_()
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

-    # print('SCA', out)
+    # out = scatter_min(src, index, dim=0)[0]
+    out, arg = scatter_max(src, index, dim=0)
+    print('SCA')
+    print(out)
+    print(arg)

     # grad_out = torch.randn_like(out)
     # print(grad_out)
     # out.backward(grad_out)
     # print(src.grad)
-    src.grad = None
+    # src.grad = None

-    out = segment_csr(src, indptr, reduce='mean')
-    print('CSR', out)
+    out, arg = segment_coo(src, index, reduce='max')
+    print('COO')
+    print(out)
+    print(arg)
+
+    out, arg = segment_csr(src, indptr, reduce='max')
+    print('CSR')
+    print(out)
+    print(arg)

     # out.backward(grad_out)
     # print(src.grad)
     # out = out[0] if isinstance(out, tuple) else out
     # out.backward(torch.randn_like(out))

-    out = segment_coo(src, index, reduce='mean')
-    print('COO', out)
+    # out = segment_coo(src, index, reduce='max')[0]
+    # print('COO', out)
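Note: the printed output above can be traced by hand from the fixtures (a worked example under standard torch_scatter semantics; the empty segment's value depends on the reduction's identity element):

    # src    = [1, 2, 3, 4, 5, 6]
    # indptr = [0, 2, 5, 5, 6]   -> CSR segments [1, 2], [3, 4, 5], [], [6]
    # index  = [0, 0, 1, 1, 1, 3] -> the same grouping in COO form
    #
    # segment_csr(src, indptr, reduce='max') -> values [2, 5, <identity>, 6]
    #                                           args   [1, 2, -, 0] (intra-segment)
    # segment_coo(src, index, reduce='max') and scatter_max(src, index, dim=0)
    # agree with this on the non-empty segments 0, 1 and 3.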