OpenDAS / torch-scatter · Commits

Commit d4325fd1, authored Dec 31, 2019 by rusty1s
segment coo kernels
parent 0b3069fe

Showing 4 changed files with 140 additions and 144 deletions (+140, -144)
cuda/segment.cpp           +0    -8
cuda/segment_kernel.cu     +116  -98
test/test_segment.py       +20   -27
torch_scatter/segment.py   +4    -11
cuda/segment.cpp
...
...
@@ -20,15 +20,7 @@ void segment_add_coo(at::Tensor src, at::Tensor index, at::Tensor out) {
  segment_add_coo_cuda(src, index, out);
}

void segment_add_thrust(at::Tensor src, at::Tensor index, at::Tensor out) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_CUDA(out);
  return segment_add_thrust_cuda(src, index, out);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("segment_add_csr", &segment_add_csr, "Segment Add CSR (CUDA)");
  m.def("segment_add_coo", &segment_add_coo, "Segment Add COO (CUDA)");
  m.def("segment_add_thrust", &segment_add_thrust, "Segment Add Thrust (CUDA)");
}
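The bindings above are reached from Python through the compiled extension module (`torch_scatter.segment_cuda`), as `torch_scatter/segment.py` below does. A minimal sketch of such a call, assuming a CUDA build of the extension and illustrative shapes:

import torch
import torch_scatter

# Hedged sketch: mirrors how torch_scatter/segment.py invokes the binding.
# `segment_add_coo` accumulates src[e] into out[index[e]] in-place.
src = torch.randn(6, device='cuda')
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')
out = src.new_zeros(4)
torch_scatter.segment_cuda.segment_add_coo(src, index, out)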
cuda/segment_kernel.cu
...
...
@@ -3,19 +3,15 @@
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/execution_policy.h>
#include "atomics.cuh"
#include "compat.cuh"
#include "index.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
// We need our own `IndexToOffset` implementation since we do not want to access
// the last element of the `indexptr`.
template <typename T, typename I> struct IndexPtrToOffset {
  static __host__ __device__ I
  get(I idx, const at::cuda::detail::TensorInfo<T, I> &info) {
...
...
@@ -36,12 +32,15 @@ __global__ void segment_add_csr_kernel(
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t E) {

  // Each warp processes exactly `32/TB` rows. We usually set `TB=32` and only
  // make use of it in case the average row length is less than 32.
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);

  if (row_idx < N) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);
...
...
@@ -49,15 +48,17 @@ __global__ void segment_add_csr_kernel(
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
    for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
      val += src_data[offset + src_idx];
      val += src_data[offset + src_idx]; // "Mostly" coalesced read.
    }

#pragma unroll
    for (int i = TB / 2; i > 0; i /= 2)
      val += __shfl_down_sync(FULL_MASK, val, i); // Parallel reduction
    for (int i = TB / 2; i > 0; i /= 2) {
      // Parallel reduction inside a single warp.
      val += __shfl_down_sync(FULL_MASK, val, i);
    }

    if (lane_idx == 0) {
      out_data[row_idx] = val;
      out_data[row_idx] = val; // "Mostly" coalesced write.
    }
  }
}
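The `__shfl_down_sync` loops above fold each warp's partial row sums together, halving the active width each step until lane 0 holds the row total. A rough Python model of that reduction tree (a software sketch with a hypothetical list of per-lane values, only lane 0's result being meaningful):

def warp_reduce_sum_model(vals):
    # Models the __shfl_down_sync tree: each step, lane i adds the value held
    # by lane i + offset; lanes past the end are ignored, as only lane 0 is used.
    width = len(vals)  # e.g. TB = 32
    offset = width // 2
    while offset > 0:
        vals = [v + (vals[i + offset] if i + offset < width else 0)
                for i, v in enumerate(vals)]
        offset //= 2
    return vals[0]

assert warp_reduce_sum_model(list(range(32))) == sum(range(32))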
...
...
@@ -68,12 +69,15 @@ __global__ void segment_add_csr_broadcast_kernel(
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  // Each thread processes exactly one row. It turned out that is more efficient
  // than using shared memory due to avoiding synchronization barriers.
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int lane_idx = thread_idx % K;

  if (thread_idx < N * K) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);
...
...
@@ -81,53 +85,10 @@ __global__ void segment_add_csr_broadcast_kernel(
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
    for (int src_idx = row_start; src_idx < row_end; src_idx++) {
      // Coalesced read into `src_data`.
      val += src_data[offset + K * src_idx + lane_idx];
    }
    out_data[thread_idx] = val; // Coalesced write into `out_data`
  }
}

template <typename scalar_t, int TB>
__global__ void segment_add_csr_broadcast_kernel2(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
    scalar_t *out_data, size_t N, size_t K, size_t E) {

  int thread_idx = blockIdx.y * blockDim.y + threadIdx.y;
  int row_idx = thread_idx / TB;
  int lane_idx = thread_idx & (TB - 1);
  int col_idx = blockIdx.x * blockDim.x + threadIdx.x;

  __shared__ scalar_t vals[32][32];

  if (row_idx < N) {
    auto offset = IndexPtrToOffset<int64_t, int>::get(row_idx, indptr_info);
    int row_start = __ldg(indptr_info.data + offset);
    int row_end = __ldg(indptr_info.data + offset +
                        indptr_info.strides[indptr_info.dims - 1]);

    scalar_t val = (scalar_t)0;
    offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;

    if (col_idx < K) {
      for (int i = row_start + lane_idx; i < row_end; i += TB) {
        val += src_data[offset + K * i + col_idx];
      }
    }

    vals[threadIdx.x][threadIdx.y] = val;
    __syncthreads();

#pragma unroll
    for (int i = 1; i < TB; i *= 2) {
      vals[threadIdx.x][threadIdx.y] += vals[threadIdx.x][threadIdx.y + i];
      __syncthreads();
      val += src_data[offset + K * src_idx + lane_idx]; // Coalesced read.
    }

    if (col_idx < K && lane_idx == 0) {
      out_data[row_idx * K + col_idx] = vals[threadIdx.x][threadIdx.y];
    }
    out_data[thread_idx] = val; // Coalesced write.
  }
}
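Both broadcast variants compute the same quantity as the plain CSR kernel, extended over a trailing feature dimension of size K: for each consecutive pair of `indptr` entries, the rows of `src` in that half-open range are summed. A minimal PyTorch reference of that semantics, written as a sketch under the assumption of sum aggregation and a 1-D `indptr`:

import torch

def segment_add_csr_ref(src, indptr):
    # Reference sketch: out[i] = src[indptr[i]:indptr[i + 1]].sum(dim=0);
    # empty segments stay zero.
    out = src.new_zeros((indptr.numel() - 1,) + tuple(src.shape[1:]))
    for i in range(indptr.numel() - 1):
        row_start, row_end = int(indptr[i]), int(indptr[i + 1])
        if row_end > row_start:
            out[i] = src[row_start:row_end].sum(dim=0)
    return out

src = torch.tensor([1., 2., 3., 4., 5., 6.])
indptr = torch.tensor([0, 2, 5, 5, 6])
print(segment_add_csr_ref(src, indptr))  # tensor([ 3., 12.,  0.,  6.])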
...
...
@@ -150,10 +111,12 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_csr_kernel", [&] {
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    // Select the right kernel based on average row length and whether we need
    // broadcasting capabilties (K > 1):
    if (K == 1 && avg_length <= 4) {
      segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
          src_data, indptr_info, out_data, N, E);
...
...
@@ -178,65 +141,120 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr) {
  return out;
}

template <typename scalar_t, int TB>
__global__ void segment_add_coo_kernel(const scalar_t *src_data,
                                       const int64_t *index_data,
                                       scalar_t *out_data, size_t numel) {
template <typename scalar_t>
__global__ void segment_add_coo_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = thread_idx & (TB - 1);
  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel reduction across equal indices, and write the intermediate
  // result via atomics.
  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);

  if (thread_idx < numel) {
    auto idx = __ldg(index_data + thread_idx);
    scalar_t val = src_data[thread_idx], tmp;
  if (row_idx < E) {
    int offset =
        at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(row_idx, index_info);
    int idx = index_info.data[offset], next_idx;
    scalar_t val = src_data[row_idx], tmp;

#pragma unroll
    for (int offset = 1; offset < TB; offset *= 2) {
      tmp = __shfl_up_sync(FULL_MASK, val, offset);
      int idx_next = __ldg(index_data + thread_idx - offset);
      // AT_ASSERTM(lane_idx < offset || idx <= idx_next);
      if (lane_idx >= offset && idx == idx_next) {
    for (int i = 1; i < 32; i *= 2) {
      tmp = __shfl_up_sync(FULL_MASK, val, i);
      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
      if (lane_idx >= i && idx == next_idx)
        val += tmp;
    }
    }

    if (lane_idx == TB - 1 || idx != __ldg(index_data + thread_idx + 1)) {
    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
    if (lane_idx == 32 - 1 || idx != next_idx) {
      atomAdd(out_data + idx, val);
    }
  }
}

void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto numel = src.numel();
  auto avg_length = (float)numel / (float)out.numel();
template <typename scalar_t, int TB>
__global__ void segment_add_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K) {

  auto index_data = index.DATA_PTR<int64_t>();
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();
  // Each thread processes a single column and `TB` rows. Coalesced read and
  // write is performed in column-major order. The intermediate results are
  // written via atomics.
    segment_add_coo_kernel<scalar_t, 32>
        <<<BLOCKS(1, numel), THREADS, 0, stream>>>(src_data, index_data,
                                                   out_data, numel);
  });
  int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

  if (row_start < E && col_idx < K) {
    int offset =
        at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(row_start, index_info);
    int idx1 = __ldg(index_info.data + offset);
    scalar_t val = src_data[K * row_start + col_idx];

#pragma unroll
    for (int i = 1; i < TB; i++) {
      if (row_start + i >= E)
        break;
      int idx2 = __ldg(index_info.data + offset +
                       i * index_info.strides[index_info.dims - 1]);
      if (idx1 == idx2) {
        val += src_data[K * (row_start + i) + col_idx];
      } else {
        atomAdd(out_data + K * idx1 + col_idx, val);
        val = src_data[K * (row_start + i) + col_idx];
      }
      idx1 = idx2;
    }
    atomAdd(out_data + K * idx1 + col_idx, val);
  }
}

void segment_add_thrust_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  auto stream = at::cuda::getCurrentCUDAStream();
  auto allocator = THCThrustAllocator(at::globalContext().lazyInitCUDA());
  auto policy = thrust::cuda::par(allocator).on(stream);
void segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out) {
  AT_ASSERTM(src.dim() >= index.dim());
  for (int i = 0; i < index.dim(); i++)
    AT_ASSERTM(src.size(i) == index.size(i));
  src = src.contiguous();
  auto reduce_dim = index.dim() - 1;

  auto key = at::full_like(out, -1, out.options().dtype(at::kLong));
  for (int i = 0; i < out.dim(); i++)
    if (i != reduce_dim)
      AT_ASSERTM(src.size(i) == out.size(i));

  auto index_data = thrust::device_ptr<int64_t>(index.DATA_PTR<int64_t>());
  auto key_data = thrust::device_ptr<int64_t>(key.DATA_PTR<int64_t>());
  auto E = index.numel();
  auto K = src.numel() / index.numel();
  auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_thrust_kernel", [&] {
    auto src_data = thrust::device_ptr<scalar_t>(src.DATA_PTR<scalar_t>());
    auto out_data = thrust::device_ptr<scalar_t>(out.DATA_PTR<scalar_t>());
  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
    auto src_data = src.DATA_PTR<scalar_t>();
    auto out_data = out.DATA_PTR<scalar_t>();

    thrust::reduce_by_key(policy, index_data, index_data + index.numel(),
                          src_data, key_data, out_data);
    if (K == 1)
      segment_add_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, E);
    else if (avg_length <= 8)
      segment_add_coo_broadcast_kernel<scalar_t, 4>
          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, E, K);
    else if (avg_length <= 16)
      segment_add_coo_broadcast_kernel<scalar_t, 8>
          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
             stream>>>(src_data, index_info, out_data, E, K);
    else if (avg_length <= 32)
      segment_add_coo_broadcast_kernel<scalar_t, 16>
          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, E, K);
    else
      segment_add_coo_broadcast_kernel<scalar_t, 32>
          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
             0, stream>>>(src_data, index_info, out_data, E, K);
  });
}
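Ignoring the warp-level combining and the atomics, the COO kernels above implement a plain index-sum: every entry of `src` is accumulated into `out` at the position given by `index`, broadcast over the trailing feature dimension K. A hedged PyTorch reference of that result:

import torch

def segment_add_coo_ref(src, index, dim_size):
    # Sequentially: out[index[e]] += src[e] for every entry e.
    out = src.new_zeros((dim_size,) + tuple(src.shape[1:]))
    out.index_add_(0, index, src)
    return out

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 0, 1, 1, 1, 3])
print(segment_add_coo_ref(src, index, dim_size=4))  # tensor([ 3., 12.,  0.,  6.])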
test/test_segment.py
...
...
@@ -14,25 +14,16 @@ devices = [torch.device('cuda')]
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

    out = segment_add(src, index, dim=0)
    # print('Thrust', out)


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward2(dtype, device):
    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
                 device)
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
    # indptr = indptr.view(1, -1).expand(2, -1).t().contiguous().t()
    out = segment_add_csr(src, indptr)
    print('CSR', out)

    # index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
    # out = segment_add_coo(src, index)
    # print('COO', out)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
    out = segment_add_coo(src, index)
    print('COO', out)
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
...
...
@@ -40,15 +31,20 @@ def test_benchmark(dtype, device):
    from torch_geometric.datasets import Planetoid, Reddit  # noqa
    # data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
    data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
    # data = Reddit('/tmp/Reddit')[0].to(device)
    row, col = data.edge_index
    x = torch.randn(data.num_edges, device=device)
    print(data.num_edges)
    print(row.size(0) / data.num_nodes)

    num_repeats = 1
    row = row.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
    col = col.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()

    # Warmup
    for _ in range(10):
        torch.randn(100, 100, device=device).sum()
    x = torch.randn(row.size(0), device=device)

    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
...
...
@@ -63,16 +59,6 @@ def test_benchmark(dtype, device):
    torch.cuda.synchronize()
    print('Scatter Col', time.perf_counter() - t)

    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        out2 = segment_add(x, row, dim=0, dim_size=data.num_nodes)
    torch.cuda.synchronize()
    print('Thrust', time.perf_counter() - t)
    assert torch.allclose(out1, out2, atol=1e-2)

    rowcount = segment_add(torch.ones_like(row), row)
    rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)

    torch.cuda.synchronize()
...
...
@@ -84,8 +70,6 @@ def test_benchmark(dtype, device):
    torch.cuda.synchronize()
    print('CSR', time.perf_counter() - t)
    assert torch.allclose(out1, out3, atol=1e-2)

    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
...
...
@@ -93,9 +77,10 @@ def test_benchmark(dtype, device):
    torch.cuda.synchronize()
    print('COO', time.perf_counter() - t)
    assert torch.allclose(out1, out3, atol=1e-2)
    assert torch.allclose(out1, out4, atol=1e-2)

    x = torch.randn((data.num_edges, 32), device=device)
    x = torch.randn((row.size(0), 64), device=device)

    torch.cuda.synchronize()
    t = time.perf_counter()
...
...
@@ -118,4 +103,12 @@ def test_benchmark(dtype, device):
    torch.cuda.synchronize()
    print('CSR + Dim', time.perf_counter() - t)

    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(100):
        out7 = segment_add_coo(x, row, dim_size=data.num_nodes)
    torch.cuda.synchronize()
    print('COO + Dim', time.perf_counter() - t)

    assert torch.allclose(out5, out6, atol=1e-2)
    assert torch.allclose(out5, out7, atol=1e-2)
torch_scatter/segment.py
...
...
@@ -8,16 +8,7 @@ if torch.cuda.is_available():
def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
    src, out, index, dim = gen(src, index, dim, out, dim_size, fill_value)
    if src.size(dim) == 0:  # pragma: no cover
        return out

    if not src.is_cuda:
        return scatter_add(src, index, dim, out, dim_size, fill_value)

    torch_scatter.segment_cuda.segment_add_thrust(src, index, out)
    return out
    return scatter_add(src, index, dim, out, dim_size, fill_value)


def segment_add_csr(src, indptr):
...
...
@@ -26,6 +17,8 @@ def segment_add_csr(src, indptr):
def segment_add_coo(src, index, dim_size=None):
    dim_size = index.max().item() + 1 if dim_size is None else dim_size
    out = src.new_zeros(dim_size)
    size = list(src.size())
    size[index.dim() - 1] = dim_size
    out = src.new_zeros(size)
    torch_scatter.segment_cuda.segment_add_coo(src, index, out)
    return out
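The size computation above shapes `out` like `src`, except that the dimension being reduced over (the last dimension of `index`) is replaced by `dim_size`. A small sketch of that bookkeeping with assumed shapes:

import torch

src = torch.randn(6, 2)
index = torch.tensor([0, 0, 1, 1, 1, 3])

dim_size = int(index.max()) + 1      # 4 when not passed explicitly
size = list(src.size())
size[index.dim() - 1] = dim_size     # replace the reduced dimension
print(size)                          # [4, 2]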