OpenDAS / torch-scatter / Commits

Commit cd6d8d68 authored Jan 09, 2020 by rusty1s
Parent: 7e82bc0e

    all cuda kernels done
Showing 3 changed files with 180 additions and 149 deletions:

  benchmark/scatter_segment.py  +59  -95
  cuda/segment_kernel.cu        +99  -47
  test/test_segment.py          +22  -7
benchmark/scatter_segment.py
# flake8: noqa
import time
import os.path as osp
import itertools
import argparse
import wget
import torch
from scipy.io import loadmat
import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr

parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True)
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce

iters = 20
device = 'cuda'
sizes = [1, 16, 32, 64, 128, 256, 512]
short_rows = [
...
@@ -40,13 +49,13 @@ def bold(text, flag=True):
def correctness(dataset):
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()

    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    dim_size = rowptr.size(0) - 1

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=device)
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
...
@@ -63,92 +72,71 @@ def correctness(dataset):
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)

            # out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
            # out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
            x = x.abs_().mul_(-1)
            # print(out1[:5])
            # print(out3[:5])
            out1, arg_out1 = scatter_min(x, row, 0, torch.zeros_like(out1))
            out2, arg_out2 = segment_coo(x, row, reduce='min')
            out3, arg_out3 = segment_csr(x, rowptr, reduce='min')
            # nnz = (out1 != out3).nonzero().flatten()
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)
            # nnz1 = nnz[0].item()
            # print(rowptr[nnz1], rowptr[nnz1 + 1])

            x = x.abs_()
            # print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
            # print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
            out1, arg_out1 = scatter_max(x, row, 0, torch.zeros_like(out1))
            out2, arg_out2 = segment_coo(x, row, reduce='max')
            out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
            # print(out1[nnz1])
            # print(out3[nnz1])
            assert torch.allclose(out1, out2, atol=1e-4)
            assert torch.allclose(out1, out3, atol=1e-4)
            # assert torch.allclose(out1, out3, atol=1e-4)
            # assert torch.all(arg_out1 == arg_out3)

        except RuntimeError:
            torch.cuda.empty_cache()
def time_func(func, x):
    try:
        torch.cuda.synchronize()
        t = time.perf_counter()
        for _ in range(iters):
            func(x)
        torch.cuda.synchronize()
        return time.perf_counter() - t
    except RuntimeError:
        torch.cuda.empty_cache()
        return float('inf')
@torch.no_grad()
def timing(dataset):
    group, name = dataset
    mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()

    rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
    rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
    row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
    row_perm = row[torch.randperm(row.size(0))]
    dim_size = rowptr.size(0) - 1
    avg_row_len = row.size(0) / dim_size

    sca_row = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')(
        x, row, dim=0, dim_size=dim_size)
    sca_col = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')(
        x, row_perm, dim=0, dim_size=dim_size)
    seg_coo = lambda x: segment_coo(x, row, reduce=args.reduce)
    seg_csr = lambda x: segment_csr(x, rowptr, reduce=args.reduce)
    dense1 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-2)
    dense2 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-1)

    t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []

    for size in sizes:
        try:
            x = torch.randn((row.size(0), size), device=device)
            x = torch.randn((row.size(0), size), device=args.device)
            x = x.squeeze(-1) if size == 1 else x

            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = scatter_add(x, row, dim=0, dim_size=dim_size)
                del out
                torch.cuda.synchronize()
                t1.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t1.append(float('inf'))

            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = scatter_add(x, row_perm, dim=0, dim_size=dim_size)
                del out
                torch.cuda.synchronize()
                t2.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t2.append(float('inf'))

            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = segment_coo(x, row, dim_size=dim_size, reduce='any')
                del out
                torch.cuda.synchronize()
                t3.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t3.append(float('inf'))

            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = segment_csr(x, rowptr, reduce='any')
                del out
                torch.cuda.synchronize()
                t4.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t4.append(float('inf'))

            t1 += [time_func(sca_row, x)]
            t2 += [time_func(sca_col, x)]
            t3 += [time_func(seg_coo, x)]
            t4 += [time_func(seg_csr, x)]

            del x
...
@@ -159,35 +147,11 @@ def timing(dataset):
        try:
            x = torch.randn((dim_size, int(avg_row_len + 1), size),
                            device=device)
            x = x.squeeze(-1) if size == 1 else x
            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = x.sum(dim=1)
                del out
                torch.cuda.synchronize()
                t5.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t5.append(float('inf'))
                            device=args.device)
            t5 += [time_func(dense1, x)]

            x = x.view(dim_size, size, int(avg_row_len + 1))
            x = x.squeeze(-2) if size == 1 else x
            try:
                torch.cuda.synchronize()
                t = time.perf_counter()
                for _ in range(iters):
                    out = x.sum(dim=-1)
                del out
                torch.cuda.synchronize()
                t6.append(time.perf_counter() - t)
            except RuntimeError:
                torch.cuda.empty_cache()
                t6.append(float('inf'))
            t6 += [time_func(dense2, x)]

            del x
...
@@ -221,7 +185,7 @@ def timing(dataset):
if __name__ == '__main__':
    for _ in range(10):  # Warmup.
        torch.randn(100, 100, device=device).sum()
        torch.randn(100, 100, device=args.device).sum()

    for dataset in itertools.chain(short_rows, long_rows):
        download(dataset)
        correctness(dataset)
...
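Note: the correctness() routine above leans on the fact that a CSR indptr and the equivalent COO row vector describe the same grouping, so scatter_add, segment_coo and segment_csr must agree. A minimal self-contained sketch of that check, assuming the same torch_scatter API this benchmark imports and using a small random SciPy matrix in place of the downloaded SuiteSparse .mat problems:

import torch
import scipy.sparse as sp
from torch_scatter import scatter_add, segment_coo, segment_csr

# Small random CSR matrix standing in for the downloaded problems.
# Move x/row/rowptr to 'cuda' if the segment ops are GPU-only in this checkout.
mat = sp.random(1000, 1000, density=0.01, format='csr')
rowptr = torch.from_numpy(mat.indptr).to(torch.long)
row = torch.from_numpy(mat.tocoo().row).to(torch.long)
dim_size = rowptr.size(0) - 1

x = torch.randn(row.size(0), 32)
out1 = scatter_add(x, row, dim=0, dim_size=dim_size)          # generic scatter
out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')   # sorted COO index
out3 = segment_csr(x, rowptr, reduce='add')                   # CSR index pointers
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)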
cuda/segment_kernel.cu

...
@@ -34,12 +34,22 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
    if (REDUCE == MIN) {
      return std::numeric_limits<scalar_t>::max();
    } else if (REDUCE == MAX) {
      return std::numeric_limits<scalar_t>::min();
      return std::numeric_limits<scalar_t>::lowest();
    } else {
      return (scalar_t)0;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val,
                                                scalar_t new_val) {
    if (REDUCE == ADD || REDUCE == MEAN) {
      *val = *val + new_val;
    } else if ((REDUCE == MIN && new_val < *val) ||
               (REDUCE == MAX && new_val > *val)) {
      *val = new_val;
    }
  }

  static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                int64_t *arg, int64_t new_arg) {
    if (REDUCE == ADD || REDUCE == MEAN) {
...
@@ -68,9 +78,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
    }
  }

  static inline __device__ void atomic_write(scalar_t *address, scalar_t val,
                                             int64_t *arg_address, int64_t arg) {
  static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
    if (REDUCE == ADD) {
      atomAdd(address, val);
    } else if (REDUCE == MEAN) {
...
@@ -80,14 +88,6 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
    } else if (REDUCE == MAX && val > *address) {
      atomMax(address, val);
    }
    if (REDUCE == MIN || REDUCE == MAX) {
      assert(false);  // TODO
      __syncthreads();
      if (*address == val) {
        *arg_address = arg;
      }
    }
  }
};
...
@@ -111,7 +111,7 @@ segment_csr_kernel(const scalar_t *src_data,
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = Reducer<scalar_t, REDUCE>::init(), tmp;
    scalar_t val = Reducer<scalar_t, REDUCE>::init();
    int64_t arg, arg_tmp;

    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
...
@@ -123,16 +123,10 @@ segment_csr_kernel(const scalar_t *src_data,
#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2) {
      // Parallel reduction inside a single warp.
      if (REDUCE == MIN || REDUCE == MAX) {
        tmp = __shfl_down_sync(FULL_MASK, val, i);
        if (REDUCE == MIN || REDUCE == MAX)
          arg_tmp = __shfl_down_sync(FULL_MASK, arg, i);
        // Only update valid entries.
        if (lane_idx < i && row_start + lane_idx + i < row_end)
          Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
      } else {
        Reducer<scalar_t, REDUCE>::update(
            &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
      }
      Reducer<scalar_t, REDUCE>::update(
          &val, __shfl_down_sync(FULL_MASK, val, i), &arg, arg_tmp);
    }

    if (lane_idx == 0) {
...
@@ -256,7 +250,7 @@ template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                   scalar_t *out_data, int64_t *arg_out_data, size_t E) {
                   scalar_t *out_data, size_t E, size_t N) {

  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel reduction across equal indices, and write the intermediate
...
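Note: the comment above describes the warp-level strategy; for reference, the reduction segment_coo_kernel computes over a sorted index can be written in a few lines of slow pure Python. This is only an illustrative sketch of the semantics, not the library implementation:

import torch

def segment_coo_reference(src, index, dim_size, reduce='sum'):
    # Entries of src that share an index value are reduced into out[index].
    out = torch.zeros(dim_size, *src.shape[1:], dtype=src.dtype)
    for i in range(dim_size):
        mask = index == i
        if mask.any():
            seg = src[mask]
            if reduce in ('sum', 'add'):
                out[i] = seg.sum(dim=0)
            elif reduce == 'mean':
                out[i] = seg.mean(dim=0)
            elif reduce == 'min':
                out[i] = seg.min(dim=0).values
            elif reduce == 'max':
                out[i] = seg.max(dim=0).values
    return out

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 0, 1, 1, 1, 3])
print(segment_coo_reference(src, index, dim_size=4, reduce='max'))
# tensor([2., 5., 0., 6.])  (empty group 2 keeps the zero init of this sketch)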
@@ -269,40 +263,52 @@ segment_coo_kernel(const scalar_t *src_data,
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = index_info.data[offset], next_idx;
    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;

    scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
    int64_t arg, arg_tmp;
    if (REDUCE == MIN || REDUCE == MAX) {
      arg = row_idx % index_info.sizes[index_info.dims - 1];
    }

#pragma unroll
    for (int i = 1; i < 32; i *= 2) {
      // Parallel reduction inside a single warp.
      tmp = __shfl_up_sync(FULL_MASK, val, i);
      if (REDUCE == MIN || REDUCE == MAX) {
        arg_tmp = __shfl_up_sync(FULL_MASK, arg, i);
      }
      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
      assert(idx >= next_idx);
      if (lane_idx >= i && idx == next_idx)
        Reducer<scalar_t, REDUCE>::update(&val, tmp, &arg, arg_tmp);
        Reducer<scalar_t, REDUCE>::update(&val, tmp);
    }

    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
    if (lane_idx == 32 - 1 || idx != next_idx) {
      Reducer<scalar_t, REDUCE>::atomic_write(out_data + idx, val,
                                              arg_out_data + idx, arg);
      Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
    }
  }
}
template <typename scalar_t>
__global__ void segment_coo_arg_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = index_info.data[offset];
    int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[row_idx] == val)
      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
  }
}

template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  // Each thread processes a single column and `TB` index entries. Coalesced
  // read and write is performed in column-major order. The intermediate
...
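Note: for the broadcast case (K > 1 columns per index entry) with the ADD reduction, the result this kernel produces coincides with PyTorch's index_add_; a tiny sketch for intuition, with made-up sizes:

import torch

E, K, N = 6, 4, 3                          # index entries, columns, output rows
src = torch.arange(E * K, dtype=torch.float).view(E, K)
index = torch.tensor([0, 0, 1, 1, 1, 2])   # sorted, one entry per src row

out = torch.zeros(N, K)
out.index_add_(0, index, src)              # rows sharing an index are summed
print(out)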
@@ -314,6 +320,7 @@ __global__ void segment_coo_broadcast_kernel(
  if (row_start < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_start, index_info);
    int out_idx = (row_start / index_info.sizes[index_info.dims - 1]) * N;
    int idx1 = __ldg(index_info.data + offset);
    scalar_t val = src_data[K * row_start + col_idx];
...
@@ -327,15 +334,42 @@ __global__ void segment_coo_broadcast_kernel(
                        i * index_info.strides[index_info.dims - 1]);
      assert(idx1 <= idx2);
      if (idx1 == idx2) {
        val += src_data[K * (row_start + i) + col_idx];
        Reducer<scalar_t, REDUCE>::update(
            &val, src_data[K * (row_start + i) + col_idx]);
      } else {
        atomAdd(out_data + K * idx1 + col_idx, val);
        Reducer<scalar_t, REDUCE>::atomic_write(
            out_data + (out_idx + idx1) * K + col_idx, val);
        val = src_data[K * (row_start + i) + col_idx];
      }
      idx1 = idx2;
    }
    atomAdd(out_data + K * idx1 + col_idx, val);
    Reducer<scalar_t, REDUCE>::atomic_write(
        out_data + (out_idx + idx1) * K + col_idx, val);
  }
}

template <typename scalar_t>
__global__ void segment_coo_arg_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int col_idx = thread_idx % K;

  if (row_idx < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = __ldg(index_info.data + offset);
    int out_idx =
        ((row_idx / index_info.sizes[index_info.dims - 1]) * N + idx) * K +
        col_idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[thread_idx] == val)
      arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
  }
}
...
@@ -371,6 +405,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
  auto E = index.numel();
  auto K = src.numel() / E;
  auto N = out.size(reduce_dim);
  auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
...
@@ -383,25 +418,37 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
      if (K == 1) {
        segment_coo_kernel<scalar_t, REDUCE, true>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                   out_data, arg_out_data, E);
                                                   out_data, E, N);
      } else if (avg_len <= 8) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
            <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
               0, stream>>>(src_data, index_info, out_data, E, K, N);
      } else if (avg_len <= 16) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
            <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
               0, stream>>>(src_data, index_info, out_data, E, K, N);
      } else if (avg_len <= 32) {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
            <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
                                         arg_out_data, E, K);
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K, N);
      } else {
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
            <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
                                         arg_out_data, E, K);
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K, N);
      }

      if (REDUCE == MIN || REDUCE == MAX) {
        if (K == 1) {
          segment_coo_arg_kernel<scalar_t>
              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, N);
        } else {
          segment_coo_arg_broadcast_kernel<scalar_t>
              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, K, N);
        }
      }
    });
  });
...
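Note: the dispatch above picks the per-thread tile size TB from the average segment length avg_len and sizes the grid from E and K. The same arithmetic mirrored in plain Python for quick sanity checks (the example values are made up):

def coo_broadcast_launch(E, K, avg_len):
    # TB grows with the average number of entries per output row.
    if avg_len <= 8:
        TB = 4
    elif avg_len <= 16:
        TB = 8
    elif avg_len <= 32:
        TB = 16
    else:
        TB = 32
    grid = ((E + (8 * TB) - 1) // (8 * TB), (K + 31) // 32)  # dim3(x, y)
    block = (32, 8)                                          # dim3(32, 8)
    return TB, grid, block

print(coo_broadcast_launch(E=100_000, K=64, avg_len=12.5))
# (8, (1563, 2), (32, 8))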
@@ -415,12 +462,17 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
        auto count_data = count.DATA_PTR<scalar_t>();
        segment_coo_kernel<scalar_t, ADD, false>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                   count_data, nullptr, E);
                                                   count_data, E, N);
      });
      count.clamp_(1);
      out.div_(count);
      arg_out = count;
      for (int i = reduce_dim + 1; i < out.dim(); i++) {
        count = count.unsqueeze(-1);
      }
      out.div_(count);
    }

    return std::make_tuple(out, arg_out);
...
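Note: the MEAN path above reuses the COO kernel to histogram the index into count, clamps empty rows to one, and unsqueezes count so the division broadcasts over trailing dimensions. The same post-processing expressed in PyTorch terms, as a sketch rather than the extension code:

import torch

index = torch.tensor([0, 0, 1, 1, 1, 3])
src = torch.randn(6, 8)

out = torch.zeros(4, 8).index_add_(0, index, src)            # summed groups
count = torch.zeros(4).index_add_(0, index, torch.ones(6))   # entries per group
count.clamp_(min=1)                                          # avoid 0-division
for _ in range(out.dim() - 1):
    count = count.unsqueeze(-1)                              # shape [4, 1]
out = out / count                                            # per-group mean
print(out.shape)                                             # torch.Size([4, 8])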
test/test_segment.py

...
@@ -3,7 +3,7 @@ from itertools import product
import pytest
import torch
from torch_scatter import segment_coo, segment_csr
from torch_scatter import scatter_add, scatter_mean, scatter_min  # noqa
from torch_scatter import scatter_max

from .utils import tensor
...
@@ -18,24 +18,39 @@ def test_forward(dtype, device):
                 device)
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    # src = tensor([-1, -2, -3, -4, -5, -6], dtype, device)
    src.requires_grad_()
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

    # out = scatter_min(src, index, dim=0)[0]
    out, arg = scatter_max(src, index, dim=0)
    print('SCA')
    print(out)
    print(arg)
    # print('SCA', out)
    # grad_out = torch.randn_like(out)
    # print(grad_out)
    # out.backward(grad_out)
    # print(src.grad)
    src.grad = None

    out = segment_csr(src, indptr, reduce='mean')
    print('CSR', out)
    # src.grad = None
    out, arg = segment_coo(src, index, reduce='max')
    print('COO')
    print(out)
    print(arg)
    out, arg = segment_csr(src, indptr, reduce='max')
    print('CSR')
    print(out)
    print(arg)
    # out.backward(grad_out)
    # print(src.grad)
    # out = out[0] if isinstance(out, tuple) else out
    # out.backward(torch.randn_like(out))
    out = segment_coo(src, index, reduce='mean')
    print('COO', out)
    # out = segment_coo(src, index, reduce='max')[0]
    # print('COO', out)
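Note: most of the original assertions in this test are still commented out while the kernels are being brought up. For hand-checking the forward pass, the tensors above group as indptr [0, 2, 5, 5, 6] -> segments [1, 2], [3, 4, 5], [], [6], and index [0, 0, 1, 1, 1, 3] encodes the same grouping in COO form. A standalone sketch, assuming a build where these ops are available on the chosen device:

import torch
from torch_scatter import segment_coo, segment_csr

src = torch.tensor([1., 2., 3., 4., 5., 6.])
indptr = torch.tensor([0, 2, 5, 5, 6])
index = torch.tensor([0, 0, 1, 1, 1, 3])

print(segment_csr(src, indptr, reduce='mean'))      # per-segment means
out, arg = segment_csr(src, indptr, reduce='max')   # 'max' also returns argmax
print(out, arg)
out, arg = segment_coo(src, index, reduce='max')
print(out, arg)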