OpenDAS / torch-scatter · Commits

Commit 0ad76a83, authored Dec 24, 2019 by rusty1s
parent 1b316a63

warp parallel segment implementation
Showing 4 changed files with 147 additions and 20 deletions (+147 −20):

cuda/segment.cpp          +11 −5
cuda/segment_kernel.cu    +55 −8
test/test_segment.py      +60 −4
torch_scatter/segment.py  +21 −3
cuda/segment.cpp (+11 −5)

```diff
@@ -2,17 +2,23 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
 
-std::tuple<at::Tensor, at::Tensor> segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
+at::Tensor segment_add_cuda(at::Tensor src, at::Tensor indptr, int64_t dim);
+void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
 
-std::tuple<at::Tensor, at::Tensor> segment_add(at::Tensor src, at::Tensor index, at::Tensor out) {
+at::Tensor segment_add(at::Tensor src, at::Tensor indptr, int64_t dim) {
+  CHECK_CUDA(src);
+  CHECK_CUDA(indptr);
+  return segment_add_cuda(src, indptr, dim);
+}
+
+void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   CHECK_CUDA(out);
-  return segment_add_cuda(src, index, out);
+  return segment_add_thrust_cuda(src, index, out);
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("segment_add", &segment_add, "Segment Add (CUDA)");
+  m.def("segment_add_thrust", &segment_add_thrust, "Segment Add Thrust (CUDA)");
 }
```
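The rebuilt bindings change `segment_add` from the old tuple-returning `(src, index, out)` signature to a CSR-style `(src, indptr, dim)` call, and keep the previous Thrust path available as `segment_add_thrust`. A minimal sketch of how the two entry points would be driven from Python, assuming the extension is built and importable as `torch_scatter.segment_cuda` (the name used by `torch_scatter/segment.py` below); the expected values simply follow from summing each `indptr` slice:

```python
import torch
import torch_scatter.segment_cuda  # compiled extension defined by segment.cpp

src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda')

# CSR-style entry point: indptr has one entry more than there are segments.
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')
out = torch_scatter.segment_cuda.segment_add(src, indptr, 0)
# expected: [3., 12., 0., 6.]  (the empty segment [5, 5) stays zero)

# Thrust fallback keeps the old (src, index, out) shape and writes into `out`.
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')
out2 = torch.zeros(4, device='cuda')
torch_scatter.segment_cuda.segment_add_thrust(src, index, out2)
```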
cuda/segment_kernel.cu (+55 −8)

```diff
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <THC/THCGeneral.h>
 #include <THC/THCThrustAllocator.cuh>
@@ -8,10 +10,57 @@
 #include "compat.cuh"
 
-std::tuple<at::Tensor, at::Tensor> segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
-  cudaSetDevice(src.get_device());
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+#define THREADS 256
+#define FULL_MASK 0xffffffff
+
+template <typename scalar_t, int TB>
+__global__ void segment_add_kernel(const scalar_t *src_data,
+                                   const int64_t *indptr_data,
+                                   scalar_t *out_data, size_t numel) {
+
+  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  int warp_idx = thread_idx / TB;
+  int lane_idx = thread_idx & (TB - 1);
+
+  if (warp_idx < numel) {
+    int row_start = __ldg(indptr_data + warp_idx);
+    int row_end = __ldg(indptr_data + warp_idx + 1);
+
+    scalar_t val = (scalar_t)0;
+    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
+      val += __ldg(src_data + src_idx);
+    }
+
+#pragma unroll
+    for (int offset = TB / 2; offset > 0; offset /= 2)
+      val += __shfl_down_sync(FULL_MASK, val, offset); // Parallel reduction.
+
+    if (lane_idx == 0) {
+      out_data[warp_idx] = val;
+    }
+  }
+}
+
+at::Tensor segment_add_cuda(at::Tensor src, at::Tensor indptr, int64_t dim) {
+  auto numel = indptr.numel() - 1;
+  auto out = at::empty({numel}, src.options());
+
+  auto stream = at::cuda::getCurrentCUDAStream();
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
+    auto indptr_data = indptr.DATA_PTR<int64_t>();
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+    segment_add_kernel<scalar_t, 32>
+        <<<(32 * numel + THREADS - 1) / THREADS, THREADS, 0, stream>>>(
+            src_data, indptr_data, out_data, numel);
+  });
+
+  return out;
+}
+
+void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
+  auto stream = at::cuda::getCurrentCUDAStream();
   auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
   auto policy = thrust::cuda::par(allocator).on(stream);
@@ -20,13 +69,11 @@ segment_add_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
   auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
   auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());
 
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
     auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
     auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());
-    thrust::reduce_by_key(policy, index_data, index_data + index.size(0),
+    thrust::reduce_by_key(policy, index_data, index_data + index.numel(),
                           src_data, key_data, out_data);
   });
-
-  return std::make_tuple(out, key);
 }
```
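For orientation, the kernel assigns each group of `TB = 32` consecutive threads to one `indptr` segment: the lanes read the segment with a stride of 32 and `__shfl_down_sync` folds the 32 partial sums into lane 0, which writes the result. A minimal CPU sketch of the same reduction (not part of the commit, only a reference to check the kernel against, assuming `indptr` follows the usual CSR convention):

```python
import torch

def segment_add_reference(src: torch.Tensor, indptr: torch.Tensor) -> torch.Tensor:
    # Per segment i, sum src[indptr[i]:indptr[i+1]] -- the value one warp
    # produces in segment_add_kernel via strided loads plus a shuffle reduction.
    numel = indptr.numel() - 1
    out = src.new_zeros(numel)
    for i in range(numel):
        out[i] = src[indptr[i]:indptr[i + 1]].sum()
    return out

# Launch geometry used by segment_add_cuda: 32 threads per segment, packed
# into 256-thread blocks, so the grid size is ceil(32 * numel / 256).
TB, THREADS, numel = 32, 256, 4
blocks = (TB * numel + THREADS - 1) // THREADS  # -> 1 for this toy size
```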
test/test_segment.py (+60 −4)

```diff
 import time
 from itertools import product
 
 import pytest
 import torch
-from torch_scatter import segment_add
+from torch_scatter import segment_add, scatter_add
+from torch_scatter.segment import segment_add2
 
 from .utils import tensor
@@ -14,7 +16,61 @@ devices = [torch.device('cuda')]
 def test_forward(dtype, device):
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-    out, key = segment_add(src, index, dim=0)
-    print(out)
-    print(key)
+    out = segment_add(src, index, dim=0)
+    print('Thrust', out)
+
+
+@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+def test_forward2(dtype, device):
+    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
+    # indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
+    indptr = tensor([[0, 2, 5, 5, 6]], torch.long, device)
+    out = segment_add2(src, indptr, dim=0)
+    print('My', out)
+
+
+@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+def test_benchmark(dtype, device):
+    from torch_geometric.datasets import Planetoid, Reddit  # noqa
+    data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
+    # data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
+    data = Reddit('/tmp/Reddit')[0].to(device)
+    row, col = data.edge_index
+    x = torch.randn(data.num_edges, device=device)
+    print(row.size(0) / data.num_nodes)
+
+    # Warmup
+    for _ in range(10):
+        torch.randn(100, 100, device=device).sum()
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    for _ in range(100):
+        out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
+    torch.cuda.synchronize()
+    print(time.perf_counter() - t)
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    for _ in range(100):
+        out2 = segment_add(x, row, dim=0, dim_size=data.num_nodes)
+    torch.cuda.synchronize()
+    print(time.perf_counter() - t)
+
+    assert torch.allclose(out1, out2, atol=1e-2)
+
+    rowcount = segment_add(torch.ones_like(row), row)
+    rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
+    torch.cuda.synchronize()
+
+    torch.cuda.synchronize()
+    t = time.perf_counter()
+    for _ in range(100):
+        out3 = segment_add2(x, rowptr, dim=0)
+    torch.cuda.synchronize()
+    print(time.perf_counter() - t)
+
+    assert torch.allclose(out1, out3, atol=1e-2)
```
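The benchmark repeats the same measurement pattern three times: synchronize, start the clock, run 100 iterations, synchronize again, read the clock. A small helper capturing that pattern (purely illustrative, not part of the commit; the `cuda_time` name is hypothetical) would look like this:

```python
import time
import torch

def cuda_time(fn, iters: int = 100) -> float:
    # Synchronize before starting and before reading the clock so that
    # asynchronous kernel launches are actually included in the measurement.
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.perf_counter() - t

# e.g. cuda_time(lambda: segment_add2(x, rowptr, dim=0))
```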
torch_scatter/segment.py (+21 −3)

```diff
 import torch
 
 from torch_scatter.utils.gen import gen
+from torch_scatter.add import scatter_add
 
 if torch.cuda.is_available():
     import torch_scatter.segment_cuda
@@ -10,6 +11,23 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
     src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
     if src.size(dim) == 0:  # pragma: no cover
         return out
-    assert src.is_cuda
-    out, key = torch_scatter.segment_cuda.segment_add(src, index, out)
-    return out, key
+    if not src.is_cuda:
+        return scatter_add(src, index, dim, out, dim_size, fill_value)
+
+    # index = index.transpose(dim, -1).contiguous()
+    # src = src.transpose(dim, -1).contiguous()
+    # out = out.transpose(dim, -1).contiguous()
+
+    # print(index)
+    # print(src)
+
+    torch_scatter.segment_cuda.segment_add_thrust(src, index, out)
+
+    # out = out.transpose(dim, -1).contiguous()
+    # key = key.transpose(dim, -1).contiguous()
+
+    return out
+
+
+def segment_add2(src, indptr, dim=-1):
+    return torch_scatter.segment_cuda.segment_add(src, indptr, dim)
```
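With this change, `segment_add` keeps the index-based interface (CPU input falls back to `scatter_add`, CUDA input goes through the Thrust path), while the new `segment_add2` expects a CSR `indptr`. A short usage sketch, assuming a sorted `index`; `torch.bincount` is just one way to build the pointer vector (the benchmark above builds it from `ones_like` + `cumsum` instead):

```python
import torch
from torch_scatter import segment_add
from torch_scatter.segment import segment_add2

index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')
src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda')

# Convert the sorted index into a CSR pointer: [0, 2, 5, 5, 6].
rowcount = torch.bincount(index, minlength=4)
rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)])

out_csr = segment_add2(src, rowptr)        # warp-parallel kernel
out_coo = segment_add(src, index, dim=0)   # Thrust reduce_by_key path
# Note: reduce_by_key compacts runs of equal keys, so when `index` skips a
# segment id the two outputs are laid out differently.
```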