Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
fe67ccbd
Commit
fe67ccbd
authored
Dec 24, 2019
by
rusty1s
Browse files
update with variable TB
parent
0ad76a83
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
46 additions
and
19 deletions
+46
-19
cuda/segment.cpp
cuda/segment.cpp
+13
-4
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+28
-10
test/test_segment.py
test/test_segment.py
+3
-3
torch_scatter/segment.py
torch_scatter/segment.py
+2
-2
No files found.
cuda/segment.cpp
View file @
fe67ccbd
...
...
@@ -2,13 +2,21 @@
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
at
::
Tensor
segment_add_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
int64_t
dim
);
at
::
Tensor
segment_add_csr_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
);
at
::
Tensor
segment_add_coo_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
);
void
segment_add_thrust_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
);
at
::
Tensor
segment_add
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
int64_t
dim
)
{
at
::
Tensor
segment_add
_csr
(
at
::
Tensor
src
,
at
::
Tensor
indptr
)
{
CHECK_CUDA
(
src
);
CHECK_CUDA
(
indptr
);
return
segment_add_cuda
(
src
,
indptr
,
dim
);
return
segment_add_csr_cuda
(
src
,
indptr
);
}
// Checked entry point for the COO variant: asserts that `src` and `index`
// are CUDA tensors (via CHECK_CUDA, in that order) and then forwards both,
// unchanged, to segment_add_coo_cuda, returning whatever it produces.
at::Tensor segment_add_coo(at::Tensor src, at::Tensor index) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  auto result = segment_add_coo_cuda(src, index);
  return result;
}
void
segment_add_thrust
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
)
{
...
...
@@ -19,6 +27,7 @@ void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"segment_add"
,
&
segment_add
,
"Segment Add (CUDA)"
);
m
.
def
(
"segment_add_csr"
,
&
segment_add_csr
,
"Segment Add CSR (CUDA)"
);
m
.
def
(
"segment_add_coo"
,
&
segment_add_coo
,
"Segment Add COO (CUDA)"
);
m
.
def
(
"segment_add_thrust"
,
&
segment_add_thrust
,
"Segment Add Thrust (CUDA)"
);
}
cuda/segment_kernel.cu
View file @
fe67ccbd
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
...
...
@@ -11,12 +9,13 @@
#include "compat.cuh"
// Threads per block; passed as the block dimension of the kernel launches
// in this file.
#define THREADS 256
// Grid size for a launch that needs TB threads for each of N items:
// ceil(TB * N / THREADS). Both arguments and the full expansion are
// parenthesized so that compound arguments (e.g. `a + b`) and surrounding
// operators (e.g. `x / BLOCKS(...)`) evaluate with the intended precedence.
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
#define FULL_MASK 0xffffffff
template
<
typename
scalar_t
,
int
TB
>
__global__
void
segment_add_kernel
(
const
scalar_t
*
src_data
,
const
int64_t
*
indptr_data
,
scalar_t
*
out_data
,
size_t
numel
)
{
__global__
void
segment_add_
csr_
kernel
(
const
scalar_t
*
src_data
,
const
int64_t
*
indptr_data
,
scalar_t
*
out_data
,
size_t
numel
)
{
int
thread_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
warp_idx
=
thread_idx
/
TB
;
...
...
@@ -41,24 +40,43 @@ __global__ void segment_add_kernel(const scalar_t *src_data,
}
}
at
::
Tensor
segment_add_cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
int64_t
dim
)
{
at
::
Tensor
segment_add_
csr_
cuda
(
at
::
Tensor
src
,
at
::
Tensor
indptr
)
{
auto
numel
=
indptr
.
numel
()
-
1
;
auto
avg_length
=
(
float
)
src
.
numel
()
/
(
float
)
numel
;
auto
out
=
at
::
empty
({
numel
},
src
.
options
());
auto
indptr_data
=
indptr
.
DATA_PTR
<
int64_t
>
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"segment_add_kernel"
,
[
&
]
{
auto
indptr_data
=
indptr
.
DATA_PTR
<
int64_t
>
();
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
segment_add_kernel
<
scalar_t
,
32
>
<<<
(
32
*
numel
+
THREADS
-
1
)
/
THREADS
,
THREADS
,
0
,
stream
>>>
(
src_data
,
indptr_data
,
out_data
,
numel
);
if
(
avg_length
<=
4
)
segment_add_csr_kernel
<
scalar_t
,
4
>
<<<
BLOCKS
(
4
,
numel
),
THREADS
,
0
,
stream
>>>
(
src_data
,
indptr_data
,
out_data
,
numel
);
else
if
(
avg_length
<=
8
)
segment_add_csr_kernel
<
scalar_t
,
8
>
<<<
BLOCKS
(
8
,
numel
),
THREADS
,
0
,
stream
>>>
(
src_data
,
indptr_data
,
out_data
,
numel
);
else
if
(
avg_length
<=
16
)
segment_add_csr_kernel
<
scalar_t
,
16
>
<<<
BLOCKS
(
16
,
numel
),
THREADS
,
0
,
stream
>>>
(
src_data
,
indptr_data
,
out_data
,
numel
);
else
segment_add_csr_kernel
<
scalar_t
,
32
>
<<<
BLOCKS
(
32
,
numel
),
THREADS
,
0
,
stream
>>>
(
src_data
,
indptr_data
,
out_data
,
numel
);
});
return
out
;
}
// COO variant of the segment-add CUDA entry point.
// NOTE(review): as written this launches no kernel and simply hands `src`
// back to the caller; `index` is never read. This looks like a placeholder
// for a not-yet-written implementation -- confirm before relying on it.
at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index) {
  (void)index; // intentionally unused for now
  return src;
}
void
segment_add_thrust_cuda
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
out
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
allocator
=
THCThrustAllocator
(
at
::
globalContext
().
lazyInitCUDA
());
...
...
test/test_segment.py
View file @
fe67ccbd
...
...
@@ -27,14 +27,14 @@ def test_forward2(dtype, device):
indptr
=
tensor
([[
0
,
2
,
5
,
5
,
6
]],
torch
.
long
,
device
)
out
=
segment_add2
(
src
,
indptr
,
dim
=
0
)
out
=
segment_add2
(
src
,
indptr
)
print
(
'My'
,
out
)
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_benchmark
(
dtype
,
device
):
from
torch_geometric.datasets
import
Planetoid
,
Reddit
# noqa
data
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)[
0
].
to
(
device
)
#
data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
# data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
data
=
Reddit
(
'/tmp/Reddit'
)[
0
].
to
(
device
)
row
,
col
=
data
.
edge_index
...
...
@@ -69,7 +69,7 @@ def test_benchmark(dtype, device):
torch
.
cuda
.
synchronize
()
t
=
time
.
perf_counter
()
for
_
in
range
(
100
):
out3
=
segment_add2
(
x
,
rowptr
,
dim
=
0
)
out3
=
segment_add2
(
x
,
rowptr
)
torch
.
cuda
.
synchronize
()
print
(
time
.
perf_counter
()
-
t
)
...
...
torch_scatter/segment.py
View file @
fe67ccbd
...
...
@@ -29,5 +29,5 @@ def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
return
out
def
segment_add2
(
src
,
indptr
,
dim
=-
1
):
return
torch_scatter
.
segment_cuda
.
segment_add
(
src
,
indptr
,
dim
)
def segment_add2(src, indptr):
    """Forward ``src`` and ``indptr`` to the ``segment_add_csr`` extension op.

    Thin wrapper: both arguments are passed through unchanged to
    ``torch_scatter.segment_cuda.segment_add_csr`` and its result is
    returned as-is.
    """
    csr_op = torch_scatter.segment_cuda.segment_add_csr
    return csr_op(src, indptr)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment