Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
d762df6a
Commit
d762df6a
authored
Dec 28, 2019
by
rusty1s
Browse files
broadcasting capabilities
parent
cca0044c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
31 deletions
+80
-31
cuda/segment_kernel.cu
cuda/segment_kernel.cu
+74
-27
test/test_segment.py
test/test_segment.py
+6
-4
No files found.
cuda/segment_kernel.cu
View file @
d762df6a
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
...
...
@@ -8,32 +10,70 @@
#include "atomics.cuh"
#include "compat.cuh"
#include "index.cuh"
// Thread-block size used by every kernel launch in this file.
#define THREADS 256
// Ceil-divide: number of THREADS-sized blocks needed when each of the N
// segments is processed by a TB-thread group. Arguments and the full
// expansion are parenthesized so the macro stays correct when called with
// expression arguments (e.g. `BLOCKS(tb, a + b)`).
#define BLOCKS(TB, N) (((TB) * (N) + THREADS - 1) / THREADS)
// Warp-shuffle mask: all 32 lanes participate.
#define FULL_MASK 0xffffffff
// template <typename scalar_t, int TB>
// __global__ void segment_add_csr_broadcast_kernel(const scalar_t *src_data,
// const int64_t *indptr_data,
// scalar_t *out_data,
// size_t numel) {}
// template <typename T, typename I> struct IndexPtrToOffset<T, I> {
// static inline __host__ __device__ I
// get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
// return idx;
// I offset = idx % (info.sizes[info.dims - 1] - 1);
// idx /= info.sizes[info.dims - 1] - 1;
// for (int i = info.dims - 2; i >= 0; --i) {
// offset += (idx % info.sizes[i]) * info.strides[i];
// idx /= info.sizes[i];
// }
// return offset;
// }
// };
// Maps a flat segment index to the storage offset of the matching entry of
// the index-pointer tensor described by `info`, honoring its strides so a
// broadcast (expanded) `indptr` is read correctly. The last dimension holds
// `sizes[last] - 1` segments per row (its stride is assumed to be 1 —
// asserted by the host wrapper), so the index is first split into
// (row, position-within-row) before the remaining strides are applied.
template <typename T, typename I> struct IndexPtrToOffset {
  static __host__ __device__ I
  get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
    const I segments_per_row = info.sizes[info.dims - 1] - 1;
    I offset = idx % segments_per_row; // position inside the innermost row
    idx /= segments_per_row;
    // Walk the leading dimensions from innermost to outermost, applying
    // each dimension's stride to the remaining row index.
    for (int dim = info.dims - 2; dim >= 0; --dim) {
      offset += (idx % info.sizes[dim]) * info.strides[dim];
      idx /= info.sizes[dim];
    }
    return offset;
  }
};
// Segment-sum kernel with broadcasting support: each group of TB threads
// cooperatively reduces one CSR segment of `src_data` into one entry of
// `out_data`.
//
//   N - total number of segments (over all broadcast rows of `indptr`).
//   E - extent of the reduced dimension of `src`.
//
// NOTE(review): this span of the page interleaved the pre- and post-commit
// kernel variants; this is the consolidated post-commit (broadcasting)
// version. The closing braces were elided in the rendered diff — confirm
// against the repository.
template <typename scalar_t, int TB>
__global__ void segment_add_csr_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int warp_idx = thread_idx / TB;       // which segment this group reduces
  int lane_idx = thread_idx & (TB - 1); // position within the TB-group

  if (warp_idx < N) {
    // Resolve the (possibly broadcast/strided) location of this segment's
    // index-pointer pair.
    auto offset = IndexPtrToOffset<int64_t, int>::get(warp_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset + 1);

    scalar_t val = (scalar_t)0;
    // Offset of this segment's row inside the flattened `src_data`:
    // row index = warp_idx / (segments per indptr row), times row length E.
    offset = (warp_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    // Strided accumulation: lane k sums elements k, k+TB, k+2*TB, ...
    for (int src_idx = row_start + lane_idx; src_idx < row_end;
         src_idx += TB) {
      val += src_data[offset + src_idx];
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2)
      val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction.

    // Lane 0 holds the group's total after the shuffle reduction.
    if (lane_idx == 0) {
      out_data[warp_idx] = val;
    }
  }
}
// Segment-sum over the reduced (last indptr) dimension of `src`, with
// CSR-style boundaries given by `indptr`. `indptr` may have fewer leading
// dimensions than `src`; it only needs to be contiguous in its last
// dimension (stride(-1) == 1), so an `expand`-ed indptr is accepted.
//
// NOTE(review): this span of the page interleaved the pre- and post-commit
// host wrappers; this is the consolidated post-commit (broadcasting)
// version, with the elided closing brace restored.
at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  src = src.contiguous();
  AT_ASSERTM(indptr.stride(-1) == 1);    // kernel reads indptr[i], indptr[i+1]
  AT_ASSERTM(src.dim() >= indptr.dim()); // indptr broadcasts over leading dims

  auto reduce_dim = indptr.dim() - 1;

  // Output matches `src` except the reduced dimension shrinks to the
  // number of segments (= indptr entries - 1 along its last dimension).
  auto sizes = src.sizes().vec();
  sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
  auto out = at::empty(sizes, src.options());

  // N: total segment count across all (broadcast) indptr rows.
  auto N = (indptr.size(-1) - 1) * (indptr.numel() / indptr.size(-1));
  // E: extent of the reduced dimension of `src`.
  auto E = src.size(reduce_dim);
  // Heuristic: pick the per-segment thread-group size TB from the average
  // segment length.
  auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    if (avg_length <= 4) {
      segment_add_csr_kernel<scalar_t, 4>
          <<<BLOCKS(4, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                 out_data, N, E);
    } else if (avg_length <= 8) {
      segment_add_csr_kernel<scalar_t, 8>
          <<<BLOCKS(8, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                 out_data, N, E);
    } else if (avg_length <= 16) {
      segment_add_csr_kernel<scalar_t, 16>
          <<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, N, E);
    } else {
      segment_add_csr_kernel<scalar_t, 32>
          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
                                                  out_data, N, E);
    }
  });

  return out;
}
test/test_segment.py
View file @
d762df6a
...
...
@@ -22,9 +22,11 @@ def test_forward(dtype, device):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_forward2
(
dtype
,
device
):
src
=
tensor
([
1
,
2
,
3
,
4
,
5
,
6
],
dtype
,
device
)
src
=
tensor
([
[
1
,
2
,
3
,
4
,
5
,
6
],
[
1
,
3
,
5
,
7
,
9
,
11
]],
dtype
,
device
)
indptr
=
tensor
([[
0
,
2
,
5
,
5
,
6
]],
torch
.
long
,
device
)
indptr
=
indptr
.
view
(
1
,
-
1
).
expand
(
2
,
-
1
)
assert
indptr
.
stride
(
-
1
)
==
1
out
=
segment_add_csr
(
src
,
indptr
)
print
(
'CSR'
,
out
)
...
...
@@ -36,9 +38,9 @@ def test_forward2(dtype, device):
@
pytest
.
mark
.
parametrize
(
'dtype,device'
,
product
(
dtypes
,
devices
))
def
test_benchmark
(
dtype
,
device
):
from
torch_geometric.datasets
import
Planetoid
,
Reddit
# noqa
data
=
Planetoid
(
'/tmp/Cora'
,
'Cora'
)[
0
].
to
(
device
)
data
=
Planetoid
(
'/tmp/PubMed'
,
'PubMed'
)[
0
].
to
(
device
)
#
data = Reddit('/tmp/Reddit')[0].to(device)
#
data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
#
data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
data
=
Reddit
(
'/tmp/Reddit'
)[
0
].
to
(
device
)
row
,
col
=
data
.
edge_index
x
=
torch
.
randn
(
data
.
num_edges
,
device
=
device
)
print
(
row
.
size
(
0
)
/
data
.
num_nodes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment