OpenDAS / torch-scatter · Commits

Commit 58d0025d
Authored Dec 27, 2019 by rusty1s
Parent: fe67ccbd

    coo segment impl

Showing 4 changed files with 82 additions and 27 deletions:

    cuda/segment.cpp           +4   -3
    cuda/segment_kernel.cu     +42  -4
    test/test_segment.py       +28  -10
    torch_scatter/segment.py   +8   -10
cuda/segment.cpp

@@ -3,7 +3,7 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

 at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr);
-at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index);
+void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
 void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out);

@@ -13,10 +13,11 @@ at::Tensor segment_add_csr(at::Tensor src, at::Tensor indptr) {
   return segment_add_csr_cuda(src, indptr);
 }

-at::Tensor segment_add_coo(at::Tensor src, at::Tensor index) {
+void segment_add_coo(at::Tensor src, at::Tensor index, at::Tensor out) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
-  return segment_add_coo_cuda(src, index);
+  CHECK_CUDA(out);
+  segment_add_coo_cuda(src, index, out);
 }

 void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
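The binding now writes into a caller-provided output tensor instead of returning a fresh one. As a rough illustration (not code from this commit; the helper name coo_segment_sum is made up), a Python caller allocates and zero-initializes the output itself before invoking the extension, along the lines of the wrapper added in torch_scatter/segment.py further down. Zero-initialization matters because the kernel accumulates into out with atomics.

import torch_scatter

# Illustrative sketch only -- mirrors the new void segment_add_coo(src, index, out)
# calling convention; not part of the commit.
def coo_segment_sum(src, index, dim_size=None):
    # out must be zero-initialized: the CUDA kernel accumulates into it.
    dim_size = int(index.max()) + 1 if dim_size is None else dim_size
    out = src.new_zeros(dim_size)
    torch_scatter.segment_cuda.segment_add_coo(src, index, out)
    return out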
cuda/segment_kernel.cu

@@ -6,6 +6,7 @@
 #include <thrust/execution_policy.h>

 #include "atomics.cuh"
 #include "compat.cuh"

+#define THREADS 256

@@ -41,14 +42,14 @@ __global__ void segment_add_csr_kernel(const scalar_t *src_data,
 }

 at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
-  auto numel = indptr.numel() - 1;
+  auto numel = indptr.numel() - 1; // TODO
   auto avg_length = (float)src.numel() / (float)numel;
   auto out = at::empty({numel}, src.options());

   auto indptr_data = indptr.DATA_PTR<int64_t>();
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
     auto src_data = src.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
@@ -73,8 +74,45 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
   return out;
 }

-at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index) {
-  return src;
+template <typename scalar_t>
+__global__ void segment_add_coo_kernel(const scalar_t *src_data,
+                                       const int64_t *index_data,
+                                       scalar_t *out_data, size_t numel) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int lane_idx = thread_idx & (32 - 1);
+
+  if (thread_idx < numel) {
+    auto idx = __ldg(index_data + thread_idx);
+    scalar_t val = src_data[thread_idx], tmp;
+
+#pragma unroll
+    for (int offset = 1; offset < 32; offset *= 2) {
+      tmp = __shfl_up_sync(FULL_MASK, val, offset);
+      if (lane_idx >= offset && idx == __ldg(index_data + thread_idx - offset)) {
+        val += tmp;
+      }
+    }
+
+    if (lane_idx == 31 || idx != __ldg(index_data + thread_idx + 1)) {
+      atomAdd(out_data + idx, val);
+    }
+  }
+}
+
+void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
+  auto numel = src.numel();
+  auto index_data = index.DATA_PTR<int64_t>();
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+    segment_add_coo_kernel<scalar_t><<<BLOCKS(1, numel), THREADS, 0, stream>>>(
+        src_data, index_data, out_data, numel);
+  });
+}
+
 void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
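The new segment_add_coo_kernel relies on index being sorted, so equal indices sit in consecutive threads: each warp runs an inclusive shuffle-up scan that only accumulates across lanes belonging to the same segment, and only the last lane of each equal-index run issues a single atomAdd into the output. A rough pure-Python emulation of that control flow (illustration only, not code from the repository):

# Pure-Python emulation of the warp-level segmented sum performed by
# segment_add_coo_kernel. Assumes `index` is sorted and addresses an output
# of length dim_size.
def segment_add_coo_emulated(src, index, dim_size):
    out = [0] * dim_size
    numel = len(src)
    for warp_start in range(0, numel, 32):                  # one "warp" of up to 32 lanes
        lanes = list(range(warp_start, min(warp_start + 32, numel)))
        val = [src[i] for i in lanes]                       # per-lane register
        for offset in (1, 2, 4, 8, 16):                     # __shfl_up_sync distances
            prev = list(val)                                # lane values before this step
            for lane, i in enumerate(lanes):
                # only accumulate if the lane `offset` positions back
                # belongs to the same segment (same index value)
                if lane >= offset and index[i] == index[i - offset]:
                    val[lane] += prev[lane - offset]
        for lane, i in enumerate(lanes):
            # the last lane of each equal-index run flushes its partial sum;
            # in the kernel this is the atomAdd(out_data + idx, val) call
            if lane == len(lanes) - 1 or index[i] != index[i + 1]:
                out[index[i]] += val[lane]
    return out

# e.g. segment_add_coo_emulated([1, 2, 3, 4, 5, 6], [0, 0, 1, 1, 1, 3], 4)
# returns [3, 12, 0, 6], matching scatter_add on the same inputs.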
test/test_segment.py

@@ -4,7 +4,7 @@ from itertools import product
 import pytest
 import torch
 from torch_scatter import segment_add, scatter_add
-from torch_scatter.segment import segment_add2
+from torch_scatter.segment import segment_add_csr, segment_add_coo

 from .utils import tensor
@@ -23,20 +23,22 @@ def test_forward(dtype, device):
 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_forward2(dtype, device):
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
     # indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
     indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
-    out = segment_add2(src, indptr)
-    print('My', out)
+    out = segment_add_csr(src, indptr)
+    print('CSR', out)
+
+    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
+    out = segment_add_coo(src, index)
+    print('COO', out)


 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_benchmark(dtype, device):
     from torch_geometric.datasets import Planetoid, Reddit  # noqa
-    # data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
+    data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
     # data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
-    data = Reddit('/tmp/Reddit')[0].to(device)
+    # data = Reddit('/tmp/Reddit')[0].to(device)
     row, col = data.edge_index
     x = torch.randn(data.num_edges, device=device)
     print(row.size(0) / data.num_nodes)
@@ -50,7 +52,14 @@ def test_benchmark(dtype, device):
     for _ in range(100):
         out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
     torch.cuda.synchronize()
-    print(time.perf_counter() - t)
+    print('Scatter Row', time.perf_counter() - t)
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    for _ in range(100):
+        scatter_add(x, col, dim=0, dim_size=data.num_nodes)
+    torch.cuda.synchronize()
+    print('Scatter Col', time.perf_counter() - t)

     torch.cuda.synchronize()

@@ -58,7 +67,7 @@ def test_benchmark(dtype, device):
     for _ in range(100):
         out2 = segment_add(x, row, dim=0, dim_size=data.num_nodes)
     torch.cuda.synchronize()
-    print(time.perf_counter() - t)
+    print('Thrust', time.perf_counter() - t)

     assert torch.allclose(out1, out2, atol=1e-2)

@@ -69,8 +78,17 @@ def test_benchmark(dtype, device):
     torch.cuda.synchronize()
     t = time.perf_counter()
     for _ in range(100):
-        out3 = segment_add2(x, rowptr)
+        out3 = segment_add_csr(x, rowptr)
     torch.cuda.synchronize()
-    print(time.perf_counter() - t)
+    print('CSR', time.perf_counter() - t)

     assert torch.allclose(out1, out3, atol=1e-2)
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    for _ in range(100):
+        out4 = segment_add_coo(x, row, dim_size=data.num_nodes)
+    torch.cuda.synchronize()
+    print('COO', time.perf_counter() - t)
+
+    assert torch.allclose(out1, out4, atol=1e-2)
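The CSR benchmark above times segment_add_csr(x, rowptr), but rowptr itself is built outside the hunks shown here. As a hedged sketch (the helper below is hypothetical, not from the repository), one way to derive a CSR index pointer from the sorted row vector is via per-node counts:

import torch

# Hypothetical helper, not part of this commit: build a CSR index pointer
# from a sorted `row` vector so that rowptr[i]:rowptr[i + 1] spans row i.
def build_rowptr(row, num_nodes):
    deg = torch.bincount(row, minlength=num_nodes)   # number of entries per node
    return torch.cat([deg.new_zeros(1), deg.cumsum(0)])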
torch_scatter/segment.py

@@ -15,19 +15,17 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
     if not src.is_cuda:
         return scatter_add(src, index, dim, out, dim_size, fill_value)

-    # index = index.transpose(dim, -1).contiguous()
-    # src = src.transpose(dim, -1).contiguous()
-    # out = out.transpose(dim, -1).contiguous()
-    # print(index)
-    # print(src)
     torch_scatter.segment_cuda.segment_add_thrust(src, index, out)
-    # out = out.transpose(dim, -1).contiguous()
-    # key = key.transpose(dim, -1).contiguous()

     return out


-def segment_add2(src, indptr):
+def segment_add_csr(src, indptr):
     return torch_scatter.segment_cuda.segment_add_csr(src, indptr)
+
+
+def segment_add_coo(src, index, dim_size=None):
+    dim_size = index.max().item() + 1 if dim_size is None else dim_size
+    out = src.new_zeros(dim_size)
+    torch_scatter.segment_cuda.segment_add_coo(src, index, out)
+    return out
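A minimal usage sketch for the two new wrappers, mirroring test_forward2 (illustration only; it assumes the segment_cuda extension is built and a CUDA device is available):

import torch
from torch_scatter.segment import segment_add_csr, segment_add_coo

src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')   # sorted segment ids
indptr = torch.tensor([[0, 2, 5, 5, 6]], device='cuda')   # CSR segment boundaries

out_coo = segment_add_coo(src, index)    # expected: [3., 12., 0., 6.]
out_csr = segment_add_csr(src, indptr)   # same sums, segments given by indptr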