OpenDAS / torch-scatter / Commits

Commit 1adc8a71
authored Jan 08, 2020 by rusty1s

    segment any

parent 3c89ebc2
Showing 6 changed files with 106 additions and 22 deletions (+106 -22)
  benchmark/main.py         +8  -8
  cuda/gather.cpp           +19 -0
  cuda/gather_kernel.cu     +42 -0
  cuda/segment_kernel.cu    +28 -5
  test/test_segment.py      +6  -6
  torch_scatter/segment.py  +3  -3
benchmark/main.py

@@ -48,11 +48,11 @@ def correctness(dataset):
         for size in sizes:
             try:
                 x = torch.randn((row.size(0), size), device=device)
-                x = x.unsqueeze(-1) if size == 1 else x
+                x = x.squeeze(-1) if size == 1 else x

                 out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
-                out2 = segment_coo(x, row, dim_size=dim_size)
-                out3 = segment_csr(x, rowptr)
+                out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
+                out3 = segment_csr(x, rowptr, reduce='add')

                 assert torch.allclose(out1, out2, atol=1e-4)
                 assert torch.allclose(out1, out3, atol=1e-4)

@@ -74,7 +74,7 @@ def timing(dataset):
         for size in sizes:
             try:
                 x = torch.randn((row.size(0), size), device=device)
-                x = x.unsqueeze(-1) if size == 1 else x
+                x = x.squeeze(-1) if size == 1 else x

                 try:
                     torch.cuda.synchronize()

@@ -104,7 +104,7 @@ def timing(dataset):
                     torch.cuda.synchronize()
                     t = time.perf_counter()
                     for _ in range(iters):
-                        out = segment_coo(x, row, dim_size=dim_size)
+                        out = segment_coo(x, row, dim_size=dim_size, reduce='any')
                         del out
                     torch.cuda.synchronize()
                     t3.append(time.perf_counter() - t)

@@ -116,7 +116,7 @@ def timing(dataset):
                     torch.cuda.synchronize()
                     t = time.perf_counter()
                     for _ in range(iters):
-                        out = segment_csr(x, rowptr)
+                        out = segment_csr(x, rowptr, reduce='any')
                         del out
                     torch.cuda.synchronize()
                     t4.append(time.perf_counter() - t)

@@ -134,7 +134,7 @@ def timing(dataset):
             try:
                 x = torch.randn((dim_size, int(avg_row_len + 1), size),
                                 device=device)
-                x = x.unsqueeze(-1) if size == 1 else x
+                x = x.squeeze(-1) if size == 1 else x

                 try:
                     torch.cuda.synchronize()

@@ -149,7 +149,7 @@ def timing(dataset):
                     t5.append(float('inf'))

             x = x.view(dim_size, size, int(avg_row_len + 1))
-            x = x.unsqueeze(-2) if size == 1 else x
+            x = x.squeeze(-2) if size == 1 else x

             try:
                 torch.cuda.synchronize()
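The benchmark now passes `reduce` explicitly, and its timing loops exercise the new 'any' reduction. A minimal sketch of the timing pattern used above, with synthetic stand-ins for `row`, `x`, `dim_size`, and `iters` (the real script derives them from its datasets):

import time

import torch
from torch_scatter import segment_coo

device = 'cuda'
dim_size, iters = 1000, 10
row, _ = torch.randint(0, dim_size, (100000,), device=device).sort()
x = torch.randn((row.size(0), 32), device=device)

torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
    out = segment_coo(x, row, dim_size=dim_size, reduce='any')
    del out
torch.cuda.synchronize()
print('segment_coo/any:', time.perf_counter() - t)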
cuda/gather.cpp  0 → 100644

#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
                           at::optional<at::Tensor> out_opt);

at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
                      at::optional<at::Tensor> out_opt) {
  CHECK_CUDA(src);
  CHECK_CUDA(indptr);
  if (out_opt.has_value())
    CHECK_CUDA(out_opt.value());
  return gather_csr_cuda(src, indptr, out_opt);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gather_csr", &gather_csr, "Gather CSR (CUDA)");
}
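The binding above can be compiled and called ad hoc with PyTorch's JIT extension loader. This is only an illustration (torch-scatter builds its extensions through setup.py), and note that at this commit gather_csr_cuda only allocates the output; the kernel body is still to come:

import torch
from torch.utils.cpp_extension import load

# JIT-compile the two files added in this commit (illustrative setup).
gather_cuda = load(name='gather_cuda',
                   sources=['cuda/gather.cpp', 'cuda/gather_kernel.cu'])

src = torch.randn(3, 4, device='cuda')
indptr = torch.tensor([0, 2, 5, 6], device='cuda')

# With out_opt=None the output shape is derived from indptr[-1], here (6, 4);
# the returned tensor is allocated but not yet filled by a kernel.
out = gather_cuda.gather_csr(src, indptr, None)
print(out.size())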
cuda/gather_kernel.cu  0 → 100644

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

#include "atomics.cuh"
#include "compat.cuh"

#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff

at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
                           at::optional<at::Tensor> out_opt) {
  AT_ASSERTM(src.dim() >= indptr.dim());
  for (int i = 0; i < indptr.dim() - 1; i++)
    AT_ASSERTM(src.size(i) == indptr.size(i));

  src = src.contiguous();
  auto gather_dim = indptr.dim() - 1;
  AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1);

  at::Tensor out;
  if (out_opt.has_value()) {
    out = out_opt.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != gather_dim)
        AT_ASSERTM(src.size(i) == out.size(i));
  } else {
    // The output length along the gather dimension is the last entry of
    // indptr; copy that single value back to the host before allocating.
    int64_t *d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>();
    int64_t h_gather_size;
    cudaMemcpy(&h_gather_size, d_gather_size, sizeof(int64_t),
               cudaMemcpyDeviceToHost);
    auto sizes = src.sizes().vec();
    sizes[gather_dim] = h_gather_size;
    out = at::empty(sizes, src.options());
  }

  return out;
}
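Judging by the shape checks (src matches indptr minus one along the gather dimension, and the output length comes from indptr[-1]), gather_csr is the inverse broadcast of segment_csr: each value of src is repeated once per element of its segment. A pure-PyTorch model for the 1-D case, a sketch and not part of the commit:

import torch

def gather_csr_ref(src, indptr):
    # Repeat src[i] exactly (indptr[i+1] - indptr[i]) times along dim 0.
    return src.repeat_interleave(indptr[1:] - indptr[:-1], dim=0)

src = torch.tensor([1., 2., 3.])
indptr = torch.tensor([0, 2, 5, 6])
print(gather_csr_ref(src, indptr))  # tensor([1., 1., 2., 2., 2., 3.])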
cuda/segment_kernel.cu

@@ -11,7 +11,6 @@
 #define FULL_MASK 0xffffffff

 enum ReductionType { ADD, MEAN, MIN, MAX };

 #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...)                              \
   [&] {                                                                       \
     if (reduce == "add") {                                                    \

@@ -42,12 +41,12 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val,
                                                 int64_t *arg, int64_t new_arg) {
-    if ((REDUCE == MIN && new_val < *val) ||
-        (REDUCE == MAX && new_val > *val)) {
+    if (REDUCE == ADD || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
       *val = new_val;
       *arg = new_arg;
-    } else {
-      *val = *val + new_val;
     }
   }
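The reordered branches read most directly as a Python model of the struct's logic (a sketch, not an API of the package): 'add' and 'mean' accumulate unconditionally, while 'min'/'max' keep the winning value together with its argument index.

def update(reduce, val, new_val, arg, new_arg):
    # Mirrors Reducer<scalar_t, REDUCE>::update after this change.
    if reduce in ('add', 'mean'):
        return val + new_val, arg
    if (reduce == 'min' and new_val < val) or \
       (reduce == 'max' and new_val > val):
        return new_val, new_arg
    return val, arg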
@@ -220,6 +219,22 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }

+  if (reduce == "any") {
+    auto index = indptr.narrow(reduce_dim, 0, indptr.size(reduce_dim) - 1);
+    auto index2 = indptr.narrow(reduce_dim, 1, indptr.size(reduce_dim) - 1);
+    auto mask = (index2 - index) == 0;
+
+    for (int i = reduce_dim + 1; i < src.dim(); i++) {
+      index = index.unsqueeze(-1);
+      mask = mask.unsqueeze(-1);
+    }
+
+    at::gather_out(out, src, reduce_dim, index.expand(out.sizes()));
+    out.masked_fill_(mask.expand(out.sizes()), 0);
+
+    return std::make_tuple(out, arg_out);
+  }
+
   auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(reduce_dim);

@@ -351,6 +366,14 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }

+  if (reduce == "any") {
+    for (int i = reduce_dim + 1; i < src.dim(); i++) {
+      index = index.unsqueeze(-1);
+    }
+    out.scatter_(reduce_dim, index.expand(src.sizes()), src);
+    return std::make_tuple(out, arg_out);
+  }
+
   auto E = index.numel();
   auto K = src.numel() / index.numel();
   auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
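Both 'any' fast paths skip the reduction kernels entirely: the CSR branch gathers each segment's first element and zero-fills empty segments, while the COO branch scatters src so that one (unspecified) element per index survives. A 1-D PyTorch model of the two branches above (a sketch; the clamp only guards a trailing empty segment):

import torch

def segment_csr_any_ref(src, indptr):
    index = indptr[:-1]                    # left boundary of each segment
    mask = (indptr[1:] - index) == 0       # empty segments
    out = src[index.clamp(max=src.size(0) - 1)]
    return out.masked_fill(mask, 0)        # empty segments yield 0

def segment_coo_any_ref(src, index, dim_size):
    out = src.new_zeros(dim_size)
    out.scatter_(0, index, src)            # one element per index survives
    return out

src = torch.tensor([1., 2., 3., 4., 5., 6.])
print(segment_csr_any_ref(src, torch.tensor([0, 2, 5, 5, 6])))  # [1., 3., 0., 6.]
print(segment_coo_any_ref(src, torch.tensor([0, 0, 1, 1, 1, 3]), 4))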
test/test_segment.py

@@ -16,16 +16,16 @@ def test_forward(dtype, device):
     src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
                  device)
-    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
-    src.requires_grad_()
+    # src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
+    # src.requires_grad_()
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)

-    out = segment_csr(src, indptr, reduce='max')
-    out = out[0] if isinstance(out, tuple) else out
+    out = segment_csr(src, indptr, reduce='any')
     print('CSR', out)
-    out.backward(torch.randn_like(out))
+    # out = out[0] if isinstance(out, tuple) else out
+    # out.backward(torch.randn_like(out))

     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-    out = segment_coo(src, index, reduce='add')
+    out = segment_coo(src, index, reduce='any')
     print('COO', out)
torch_scatter/segment.py

@@ -7,7 +7,7 @@ if torch.cuda.is_available():
 class SegmentCSR(torch.autograd.Function):
     @staticmethod
     def forward(ctx, src, indptr, out, reduce):
-        assert reduce in ['add', 'mean', 'min', 'max']
+        assert reduce in ['any', 'add', 'mean', 'min', 'max']
         assert indptr.dtype == torch.long
         if out is not None:

@@ -30,12 +30,12 @@ class SegmentCSR(torch.autograd.Function):
 def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
-    assert reduce in ['add', 'mean', 'min', 'max']
+    assert reduce in ['any', 'add', 'mean', 'min', 'max']
     if out is None:
         dim_size = index.max().item() + 1 if dim_size is None else dim_size
         size = list(src.size())
         size[index.dim() - 1] = dim_size
-        out = src.new_zeros(size)  # TODO: DEPENDENT ON REDUCE
+        out = src.new_zeros(size)  # TODO: DEPENDS ON REDUCE
     assert index.dtype == torch.long and src.dtype == out.dtype
     out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
     return out if arg_out is None else (out, arg_out)
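With the assertion relaxed, 'any' is accepted by both entry points. A sketch of the new call pattern, adapted from the updated test (a CUDA device is assumed, since the branch lives in the segment_cuda extension):

import torch
from torch_scatter import segment_coo, segment_csr

device = 'cuda'
src = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
                   dtype=torch.float, device=device)

indptr = torch.tensor([0, 2, 5, 5, 6], device=device)
print('CSR', segment_csr(src, indptr, reduce='any'))

index = torch.tensor([0, 0, 1, 1, 1, 3], device=device)
print('COO', segment_coo(src, index, reduce='any'))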