OpenDAS / torch-scatter · Commits

Commit 9a91c42d
authored Jan 07, 2020 by rusty1s
reduce op in segment_csr
parent 611b2994
Showing 6 changed files with 346 additions and 192 deletions (+346 -192)

benchmark/main.py          +5    -5
cuda/segment.cpp           +16   -11
cuda/segment_kernel.cu     +200  -58
test/test_segment.py       +102  -96
torch_scatter/__init__.py  +3    -2
torch_scatter/segment.py   +20   -20
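The user-facing change: segment_csr and segment_coo now take a reduce argument ('add', 'mean', 'min', 'max'), and the 'min'/'max' variants additionally return the winning source positions. A minimal sketch of the new API using the toy values from test/test_segment.py on a CUDA device; the commented results are what the reduce semantics imply for this input, not output copied from the repository:

import torch
from torch_scatter import segment_csr, segment_coo

src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda')
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')   # CSR segment boundaries
index = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')  # the same segments as a sorted COO index

out = segment_csr(src, indptr, reduce='add')   # expected: [3., 12., 0., 6.]
out = segment_csr(src, indptr, reduce='mean')  # expected: [1.5, 4., 0., 6.]
out, arg_out = segment_csr(src, indptr, reduce='max')  # 'min'/'max' also return arg indices
out = segment_coo(src, index, reduce='add')    # COO counterpart of the first call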
benchmark/main.py

@@ -7,7 +7,7 @@ from scipy.io import loadmat
 import torch
 from torch_scatter import scatter_add
-from torch_scatter.segment import segment_add_csr, segment_add_coo
+from torch_scatter import segment_csr, segment_coo

 iters = 20
 device = 'cuda'
@@ -51,8 +51,8 @@ def correctness(dataset):
        x = x.unsqueeze(-1) if size == 1 else x

        out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
-        out2 = segment_add_coo(x, row, dim_size=dim_size)
-        out3 = segment_add_csr(x, rowptr)
+        out2 = segment_coo(x, row, dim_size=dim_size)
+        out3 = segment_csr(x, rowptr)

        assert torch.allclose(out1, out2, atol=1e-4)
        assert torch.allclose(out1, out3, atol=1e-4)
@@ -104,7 +104,7 @@ def timing(dataset):
        torch.cuda.synchronize()
        t = time.perf_counter()
        for _ in range(iters):
-            out = segment_add_coo(x, row, dim_size=dim_size)
+            out = segment_coo(x, row, dim_size=dim_size)
        del out
        torch.cuda.synchronize()
        t3.append(time.perf_counter() - t)
@@ -116,7 +116,7 @@ def timing(dataset):
        torch.cuda.synchronize()
        t = time.perf_counter()
        for _ in range(iters):
-            out = segment_add_csr(x, rowptr)
+            out = segment_csr(x, rowptr)
        del out
        torch.cuda.synchronize()
        t4.append(time.perf_counter() - t)
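The correctness() change above keeps the invariant the benchmark relies on: for a sorted index row and its CSR pointer rowptr, scatter_add, segment_coo and segment_csr must agree. A standalone sketch of that check with hypothetical toy tensors (CUDA device assumed):

import torch
from torch_scatter import scatter_add, segment_coo, segment_csr

row = torch.tensor([0, 0, 1, 1, 1, 3], device='cuda')    # sorted COO index
rowptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')    # the same segments in CSR form
x = torch.randn(row.size(0), 16, device='cuda')

out1 = scatter_add(x, row, dim=0, dim_size=4)
out2 = segment_coo(x, row, dim_size=4)
out3 = segment_csr(x, rowptr)
assert torch.allclose(out1, out2, atol=1e-4) and torch.allclose(out1, out3, atol=1e-4)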
cuda/segment.cpp

@@ -2,28 +2,33 @@
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")

-at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
-                                at::optional<at::Tensor> out_opt);
-at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out);
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+segment_csr_cuda(at::Tensor src, at::Tensor indptr,
+                 at::optional<at::Tensor> out_opt, std::string reduce);
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+                 std::string reduce);

-at::Tensor segment_add_csr(at::Tensor src, at::Tensor indptr,
-                           at::optional<at::Tensor> out_opt) {
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+segment_csr(at::Tensor src, at::Tensor indptr,
+            at::optional<at::Tensor> out_opt, std::string reduce) {
   CHECK_CUDA(src);
   CHECK_CUDA(indptr);
   if (out_opt.has_value())
     CHECK_CUDA(out_opt.value());
-  return segment_add_csr_cuda(src, indptr, out_opt);
+  return segment_csr_cuda(src, indptr, out_opt, reduce);
 }

-at::Tensor segment_add_coo(at::Tensor src, at::Tensor index, at::Tensor out) {
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
+            std::string reduce) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   CHECK_CUDA(out);
-  return segment_add_coo_cuda(src, index, out);
+  return segment_coo_cuda(src, index, out, reduce);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("segment_add_csr", &segment_add_csr, "Segment Add CSR (CUDA)");
-  m.def("segment_add_coo", &segment_add_coo, "Segment Add COO (CUDA)");
+  m.def("segment_csr", &segment_csr, "Segment CSR (CUDA)");
+  m.def("segment_coo", &segment_coo, "Segment COO (CUDA)");
 }
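Both bindings now return a std::tuple<at::Tensor, at::optional<at::Tensor>> so that 'min'/'max' can hand the arg indices back to Python in the same call. On the Python side this arrives as a plain (out, arg_out) pair, with arg_out being None for 'add'/'mean'. A sketch of calling the compiled extension directly, as the wrappers in torch_scatter/segment.py do; src and indptr are assumed to be CUDA tensors (enforced by CHECK_CUDA) and indptr an int64 tensor:

from torch_scatter import segment_cuda  # extension built from these CUDA sources

out, arg_out = segment_cuda.segment_csr(src, indptr, None, 'max')  # None maps to at::nullopt
out, arg_out = segment_cuda.segment_csr(src, indptr, None, 'add')  # arg_out is None here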
cuda/segment_kernel.cu

@@ -10,6 +10,11 @@
 #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
 #define FULL_MASK 0xffffffff

+#define ADD 0
+#define MEAN 1
+#define MIN 2
+#define MAX 3
+
 // We need our own `IndexToOffset` implementation since we do not want to access
 // the last element of the `indexptr`.
 template <typename T, typename I> struct IndexPtrToOffset {
@@ -26,14 +31,13 @@ template <typename T, typename I> struct IndexPtrToOffset {
   }
 };

-template <typename scalar_t, int TB>
+template <typename scalar_t, int REDUCE, int TB>
 __global__ void segment_add_csr_kernel(
     const scalar_t *src_data,
     const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
-    scalar_t *out_data, size_t N, size_t E) {
+    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t E) {

-  // Each warp processes exactly `32/TB` rows. We usually set `TB=32` and only
-  // make use of it in case the average row length is less than 32.
+  // Each warp processes exactly `32/TB` rows.
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   int row_idx = thread_idx / TB;
@@ -44,30 +48,90 @@ __global__ void segment_add_csr_kernel(
     int row_start = __ldg(indptr_info.data + offset);
     int row_end = __ldg(indptr_info.data + offset +
                         indptr_info.strides[indptr_info.dims - 1]);

-    scalar_t val = (scalar_t)0;
+    scalar_t val, tmp;
+    int64_t arg_val, arg_tmp;
+    if (REDUCE == ADD) {
+      val = (scalar_t)0;
+    } else if (REDUCE == MEAN) {
+      val = (scalar_t)0;
+    } else if (REDUCE == MIN) {
+      val = std::numeric_limits<scalar_t>::max();
+    } else if (REDUCE == MAX) {
+      val = std::numeric_limits<scalar_t>::min();
+    }
+
     offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
     for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) {
-      val += src_data[offset + src_idx]; // "Mostly" coalesced read.
+      tmp = src_data[offset + src_idx]; // "Mostly" coalesced read.
+      if (REDUCE == ADD) {
+        val += tmp;
+      } else if (REDUCE == MEAN) {
+        val += tmp;
+      } else if (REDUCE == MIN && tmp < val) {
+        val = tmp;
+        arg_val = src_idx;
+      } else if (REDUCE == MAX && tmp > val) {
+        val = tmp;
+        arg_val = src_idx;
+      }
     }

 #pragma unroll
     for (int i = TB / 2; i > 0; i /= 2) {
       // Parallel reduction inside a single warp.
-      val += __shfl_down_sync(FULL_MASK, val, i);
+      tmp = __shfl_down_sync(FULL_MASK, val, i);
+      if (REDUCE == ADD) {
+        val += tmp;
+      } else if (REDUCE == MEAN) {
+        val += tmp;
+      } else if (REDUCE == MIN) {
+        arg_tmp = __shfl_down_sync(FULL_MASK, arg_val, i);
+        if (tmp < val) {
+          val = tmp;
+          arg_val = arg_tmp;
+        }
+      } else if (REDUCE == MAX) {
+        arg_tmp = __shfl_down_sync(FULL_MASK, arg_val, i);
+        if (tmp > val) {
+          val = tmp;
+          arg_val = arg_tmp;
+        }
+      }
     }

     if (lane_idx == 0) {
-      out_data[row_idx] = val; // "Mostly" coalesced write.
+      // "Mostly" coalesced write.
+      if (REDUCE == ADD) {
+        out_data[row_idx] = val;
+      } else if (REDUCE == MEAN) {
+        out_data[row_idx] = val / (scalar_t)max(row_end - row_start, 1);
+      } else if (REDUCE == MIN) {
+        if (row_end - row_start > 0) {
+          out_data[row_idx] = val;
+          arg_out_data[row_idx] = arg_val;
+        } else {
+          out_data[row_idx] = 0;
+        }
+      } else if (REDUCE == MAX) {
+        if (row_end - row_start > 0) {
+          out_data[row_idx] = val;
+          arg_out_data[row_idx] = arg_val;
+        } else {
+          out_data[row_idx] = 0;
+        }
+      }
     }
   }
 }

-template <typename scalar_t>
+template <typename scalar_t, int REDUCE>
 __global__ void segment_add_csr_broadcast_kernel(
     const scalar_t *src_data,
     const at::cuda::detail::TensorInfo<int64_t, int> indptr_info,
-    scalar_t *out_data, size_t N, size_t K, size_t E) {
+    scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) {

   // Each thread processes exactly one row. It turned out that is more efficient
   // than using shared memory due to avoiding synchronization barriers.
@@ -81,19 +145,62 @@ __global__ void segment_add_csr_broadcast_kernel(
     int row_start = __ldg(indptr_info.data + offset);
     int row_end = __ldg(indptr_info.data + offset +
                         indptr_info.strides[indptr_info.dims - 1]);

-    scalar_t val = (scalar_t)0;
+    scalar_t val, tmp;
+    int64_t arg_val;
+    if (REDUCE == ADD) {
+      val = (scalar_t)0;
+    } else if (REDUCE == MEAN) {
+      val = (scalar_t)0;
+    } else if (REDUCE == MIN) {
+      val = std::numeric_limits<scalar_t>::max();
+    } else if (REDUCE == MAX) {
+      val = std::numeric_limits<scalar_t>::min();
+    }
+
     offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
     for (int src_idx = row_start; src_idx < row_end; src_idx++) {
-      val += src_data[offset + K * src_idx + lane_idx]; // Coalesced read.
+      tmp = src_data[offset + K * src_idx + lane_idx]; // Coalesced read.
+      if (REDUCE == ADD) {
+        val += tmp;
+      } else if (REDUCE == MEAN) {
+        val += tmp;
+      } else if (REDUCE == MIN && tmp < val) {
+        val = tmp;
+        arg_val = src_idx;
+      } else if (REDUCE == MAX && tmp > val) {
+        val = tmp;
+        arg_val = src_idx;
+      }
     }

-    out_data[thread_idx] = val; // Coalesced write.
+    // Coalesced write.
+    if (REDUCE == ADD) {
+      out_data[thread_idx] = val;
+    } else if (REDUCE == MEAN) {
+      out_data[thread_idx] = val / (scalar_t)max(row_end - row_start, 1);
+    } else if (REDUCE == MIN) {
+      if (row_end - row_start > 0) {
+        out_data[thread_idx] = val;
+        arg_out_data[thread_idx] = arg_val;
+      } else {
+        out_data[thread_idx] = 0;
+      }
+    } else if (REDUCE == MAX) {
+      if (row_end - row_start > 0) {
+        out_data[thread_idx] = val;
+        arg_out_data[thread_idx] = arg_val;
+      } else {
+        out_data[thread_idx] = 0;
+      }
+    }
   }
 }

-at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
-                                at::optional<at::Tensor> out_opt) {
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+segment_csr_cuda(at::Tensor src, at::Tensor indptr,
+                 at::optional<at::Tensor> out_opt, std::string reduce) {
   AT_ASSERTM(src.dim() >= indptr.dim());
   for (int i = 0; i < indptr.dim() - 1; i++)
@@ -104,7 +211,7 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
   at::Tensor out;
   if (out_opt.has_value()) {
-    out = out_opt.value();
+    out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
       if (i != reduce_dim)
         AT_ASSERTM(src.size(i) == out.size(i));
@@ -115,10 +222,15 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
     out = at::empty(sizes, src.options());
   }

+  at::optional<at::Tensor> arg_out = at::nullopt;
+  if (reduce == "min" || reduce == "max") {
+    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+  }
+
   auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(reduce_dim);
-  auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
+  // auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

   auto indptr_info = at::cuda::detail::getTensorInfo<int64_t, int>(indptr);
   auto stream = at::cuda::getCurrentCUDAStream();
@@ -126,37 +238,56 @@ at::Tensor segment_add_csr_cuda(at::Tensor src, at::Tensor indptr,
     auto src_data = src.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();

-    // Select the right kernel based on average row length and whether we need
+    // Select the right kernel based on the reduce operation and whether we need
     // broadcasting capabilties (K > 1):
-    if (K == 1 && avg_length <= 4) {
-      segment_add_csr_kernel<scalar_t, 4><<<BLOCKS(4, N), THREADS, 0, stream>>>(
-          src_data, indptr_info, out_data, N, E);
-    } else if (K == 1 && avg_length <= 8) {
-      segment_add_csr_kernel<scalar_t, 8><<<BLOCKS(8, N), THREADS, 0, stream>>>(
-          src_data, indptr_info, out_data, N, E);
-    } else if (K == 1 && avg_length <= 16) {
-      segment_add_csr_kernel<scalar_t, 16>
-          <<<BLOCKS(16, N), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                  out_data, N, E);
-    } else if (K == 1) {
-      segment_add_csr_kernel<scalar_t, 32>
-          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                  out_data, N, E);
-    } else {
-      segment_add_csr_broadcast_kernel<scalar_t>
-          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
-                                                     out_data, N, K, E);
+    if (K == 1 && reduce == "add") {
+      segment_add_csr_kernel<scalar_t, ADD, 1>
+          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                  out_data, nullptr, N, E);
+    } else if (K == 1 && reduce == "mean") {
+      segment_add_csr_kernel<scalar_t, MEAN, 1>
+          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                  out_data, nullptr, N, E);
+    } else if (K == 1 && reduce == "min") {
+      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+      segment_add_csr_kernel<scalar_t, MIN, 1>
+          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                  out_data, arg_out_data, N, E);
+    } else if (K == 1 && reduce == "max") {
+      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+      segment_add_csr_kernel<scalar_t, MAX, 1>
+          <<<BLOCKS(32, N), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                  out_data, arg_out_data, N, E);
+    } else if (reduce == "add") {
+      segment_add_csr_broadcast_kernel<scalar_t, ADD>
+          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                     out_data, nullptr, N, K, E);
+    } else if (reduce == "mean") {
+      segment_add_csr_broadcast_kernel<scalar_t, MEAN>
+          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                     out_data, nullptr, N, K, E);
+    } else if (reduce == "min") {
+      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+      segment_add_csr_broadcast_kernel<scalar_t, MIN>
+          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                     out_data, arg_out_data, N, K, E);
+    } else if (reduce == "max") {
+      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+      segment_add_csr_broadcast_kernel<scalar_t, MAX>
+          <<<BLOCKS(1, N * K), THREADS, 0, stream>>>(src_data, indptr_info,
+                                                     out_data, arg_out_data, N, K, E);
     }
   });

-  return out;
+  return std::make_tuple(out, arg_out);
 }

-template <typename scalar_t>
+template <typename scalar_t, int REDUCE>
 __global__ void segment_add_coo_kernel(
     const scalar_t *src_data,
     const at::cuda::detail::TensorInfo<int64_t, int> index_info,
-    scalar_t *out_data, size_t E) {
+    scalar_t *out_data, int64_t *arg_out_data, size_t E) {

   // Each thread processes exactly one entry. Within a warp, we perform a
   // parallel reduction across equal indices, and write the intermediate
@@ -187,11 +318,11 @@ __global__ void segment_add_coo_kernel(
   }
 }

-template <typename scalar_t, int TB>
+template <typename scalar_t, int REDUCE, int TB>
 __global__ void segment_add_coo_broadcast_kernel(
     const scalar_t *src_data,
     const at::cuda::detail::TensorInfo<int64_t, int> index_info,
-    scalar_t *out_data, size_t E, size_t K) {
+    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K) {

   // Each thread processes a single column and `TB` rows. Coalesced read and
   // write is performed in column-major order. The intermediate results are
@@ -228,49 +359,60 @@ __global__ void segment_add_coo_broadcast_kernel(
   }
 }

-at::Tensor segment_add_coo_cuda(at::Tensor src, at::Tensor index,
-                                at::Tensor out) {
+std::tuple<at::Tensor, at::optional<at::Tensor>>
+segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+                 std::string reduce) {
   AT_ASSERTM(src.dim() >= index.dim());
   for (int i = 0; i < index.dim(); i++)
     AT_ASSERTM(src.size(i) == index.size(i));

   src = src.contiguous();
   out = out.contiguous();
   auto reduce_dim = index.dim() - 1;

   for (int i = 0; i < out.dim(); i++)
     if (i != reduce_dim)
       AT_ASSERTM(src.size(i) == out.size(i));

+  at::optional<at::Tensor> arg_out = at::nullopt;
+  if (reduce == "min" || reduce == "max") {
+    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+  }
+
   auto E = index.numel();
   auto K = src.numel() / index.numel();
-  auto avg_length = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
+  auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);

   auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
   auto stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_add_coo_kernel", [&] {
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
     auto src_data = src.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();

     // Select the right kernel based on average row length (purely heuristic)
     // and whether we need broadcasting capabilties (K > 1):
     if (K == 1)
-      segment_add_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, E);
-    else if (avg_length <= 8)
-      segment_add_coo_broadcast_kernel<scalar_t, 4>
+      segment_add_coo_kernel<scalar_t, ADD>
+          <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
+                                                 out_data, nullptr, E);
+    else if (avg_len <= 8)
+      segment_add_coo_broadcast_kernel<scalar_t, ADD, 4>
          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
-             0, stream>>>(src_data, index_info, out_data, E, K);
-    else if (avg_length <= 16)
-      segment_add_coo_broadcast_kernel<scalar_t, 8>
+             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
+    else if (avg_len <= 16)
+      segment_add_coo_broadcast_kernel<scalar_t, ADD, 8>
          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
-             0, stream>>>(src_data, index_info, out_data, E, K);
-    else if (avg_length <= 32)
-      segment_add_coo_broadcast_kernel<scalar_t, 16>
+             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
+    else if (avg_len <= 32)
+      segment_add_coo_broadcast_kernel<scalar_t, ADD, 16>
          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
-             0, stream>>>(src_data, index_info, out_data, E, K);
+             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
     else
-      segment_add_coo_broadcast_kernel<scalar_t, 32>
+      segment_add_coo_broadcast_kernel<scalar_t, ADD, 32>
          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
-             0, stream>>>(src_data, index_info, out_data, E, K);
+             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
   });

-  return out;
+  return std::make_tuple(out, arg_out);
 }
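REDUCE is a compile-time template argument, so each kernel instantiation keeps only one branch of the if/else chains above. As a readability aid, a plain PyTorch sketch of the per-segment semantics the CSR kernels implement for a 1-D floating-point input (the helper name segment_csr_reference is hypothetical): empty segments write 0, 'mean' divides by max(length, 1), and 'min'/'max' also collect arg indices, with src.size(0) left as the fill value for empty segments.

import torch


def segment_csr_reference(src, indptr, reduce='add'):
    # Mirrors the kernel behaviour on a 1-D `src` with a 1-D int64 `indptr`.
    outs, args = [], []
    for row_start, row_end in zip(indptr[:-1].tolist(), indptr[1:].tolist()):
        seg = src[row_start:row_end]
        if reduce in ('add', 'mean'):
            val = seg.sum()
            outs.append(val / max(row_end - row_start, 1) if reduce == 'mean' else val)
        elif row_end > row_start:
            val, arg = seg.min(0) if reduce == 'min' else seg.max(0)
            outs.append(val)
            args.append(row_start + arg)   # position in `src`, like arg_out_data
        else:  # empty segment for 'min'/'max'
            outs.append(torch.zeros((), dtype=src.dtype))
            args.append(torch.tensor(src.size(0)))  # fill value of the arg tensor
    out = torch.stack(outs)
    return out if reduce in ('add', 'mean') else (out, torch.stack(args))

For example, segment_csr_reference(torch.tensor([1., 2., 3., 4., 5., 6.]), torch.tensor([0, 2, 5, 5, 6]), 'max') should give ([2., 5., 0., 6.], [1, 4, 6, 5]).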
test/test_segment.py

-import time
 from itertools import product

 import pytest
 import torch
-from torch_scatter import segment_add, scatter_add
-from torch_scatter.segment import segment_add_csr, segment_add_coo
+from torch_scatter import segment_coo, segment_csr

 from .utils import tensor
@@ -14,101 +12,109 @@ devices = [torch.device('cuda')]
 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_forward(dtype, device):
-    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
-                 device)
-    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
+    # src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
+    #              device)

-    out = segment_add_csr(src, indptr)
-    print('CSR', out)
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
+    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

-    out = segment_add_coo(src, index)
-    print('COO', out)
+    # out = segment_coo(src, index)
+    # print('COO', out)
+
+    out = segment_csr(src, indptr, reduce='add')
+    print('CSR', out)
+    out = segment_csr(src, indptr, reduce='mean')
+    print('CSR', out)
+    out = segment_csr(src, indptr, reduce='min')
+    print('CSR', out)
+    out = segment_csr(src, indptr, reduce='max')
+    print('CSR', out)


-@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
-def test_benchmark(dtype, device):
-    from torch_geometric.datasets import Planetoid, Reddit  # noqa
-    # data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
-    data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
-    row, col = data.edge_index
-    print(data.num_edges)
-    print(row.size(0) / data.num_nodes)
-
-    num_repeats = 1
-    row = row.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
-    col = col.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
-
-    # Warmup
-    for _ in range(10):
-        torch.randn(100, 100, device=device).sum()
-
-    x = torch.randn(row.size(0), device=device)
-
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
-    torch.cuda.synchronize()
-    print('Scatter Row', time.perf_counter() - t)
-
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        scatter_add(x, col, dim=0, dim_size=data.num_nodes)
-    torch.cuda.synchronize()
-    print('Scatter Col', time.perf_counter() - t)
-
-    rowcount = segment_add(torch.ones_like(row), row)
-    rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
-    torch.cuda.synchronize()
-
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        out3 = segment_add_csr(x, rowptr)
-    torch.cuda.synchronize()
-    print('CSR', time.perf_counter() - t)
-
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        out4 = segment_add_coo(x, row, dim_size=data.num_nodes)
-    torch.cuda.synchronize()
-    print('COO', time.perf_counter() - t)
-
-    assert torch.allclose(out1, out3, atol=1e-2)
-    assert torch.allclose(out1, out4, atol=1e-2)
-
-    x = torch.randn((row.size(0), 64), device=device)
-
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        out5 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
-    torch.cuda.synchronize()
-    print('Scatter Row + Dim', time.perf_counter() - t)
-
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        scatter_add(x, col, dim=0, dim_size=data.num_nodes)
-    torch.cuda.synchronize()
-    print('Scatter Col + Dim', time.perf_counter() - t)
-
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        out6 = segment_add_csr(x, rowptr)
-    torch.cuda.synchronize()
-    print('CSR + Dim', time.perf_counter() - t)
-
-    torch.cuda.synchronize()
-    t = time.perf_counter()
-    for _ in range(100):
-        out7 = segment_add_coo(x, row, dim_size=data.num_nodes)
-    torch.cuda.synchronize()
-    print('COO + Dim', time.perf_counter() - t)
-
-    assert torch.allclose(out5, out6, atol=1e-2)
-    assert torch.allclose(out5, out7, atol=1e-2)
+
+# @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
+# def test_benchmark(dtype, device):
+#     from torch_geometric.datasets import Planetoid, Reddit  # noqa
+#     # data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
+#     data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
+#     row, col = data.edge_index
+#     print(data.num_edges)
+#     print(row.size(0) / data.num_nodes)
+#     num_repeats = 1
+#     row = row.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
+#     col = col.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
+#     # Warmup
+#     for _ in range(10):
+#         torch.randn(100, 100, device=device).sum()
+#     x = torch.randn(row.size(0), device=device)
+#     torch.cuda.synchronize()
+#     t = time.perf_counter()
+#     for _ in range(100):
+#         out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
+#     torch.cuda.synchronize()
+#     print('Scatter Row', time.perf_counter() - t)
+#     torch.cuda.synchronize()
+#     t = time.perf_counter()
+#     for _ in range(100):
+#         scatter_add(x, col, dim=0, dim_size=data.num_nodes)
+#     torch.cuda.synchronize()
+#     print('Scatter Col', time.perf_counter() - t)
+#     rowcount = segment_add(torch.ones_like(row), row)
+#     rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
+#     torch.cuda.synchronize()
+#     torch.cuda.synchronize()
+#     t = time.perf_counter()
+#     for _ in range(100):
+#         out3 = segment_add_csr(x, rowptr)
+#     torch.cuda.synchronize()
+#     print('CSR', time.perf_counter() - t)
+#     torch.cuda.synchronize()
+#     t = time.perf_counter()
+#     for _ in range(100):
+#         out4 = segment_add_coo(x, row, dim_size=data.num_nodes)
+#     torch.cuda.synchronize()
+#     print('COO', time.perf_counter() - t)
+#     assert torch.allclose(out1, out3, atol=1e-2)
+#     assert torch.allclose(out1, out4, atol=1e-2)
+#     x = torch.randn((row.size(0), 64), device=device)
+#     torch.cuda.synchronize()
+#     t = time.perf_counter()
+#     for _ in range(100):
+#         out5 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
+#     torch.cuda.synchronize()
+#     print('Scatter Row + Dim', time.perf_counter() - t)
+#     torch.cuda.synchronize()
+#     t = time.perf_counter()
+#     for _ in range(100):
+#         scatter_add(x, col, dim=0, dim_size=data.num_nodes)
+#     torch.cuda.synchronize()
+#     print('Scatter Col + Dim', time.perf_counter() - t)
+#     torch.cuda.synchronize()
+#     t = time.perf_counter()
+#     for _ in range(100):
+#         out6 = segment_add_csr(x, rowptr)
+#     torch.cuda.synchronize()
+#     print('CSR + Dim', time.perf_counter() - t)
+#     torch.cuda.synchronize()
+#     t = time.perf_counter()
+#     for _ in range(100):
+#         out7 = segment_add_coo(x, row, dim_size=data.num_nodes)
+#     torch.cuda.synchronize()
+#     print('COO + Dim', time.perf_counter() - t)
+#     assert torch.allclose(out5, out6, atol=1e-2)
+#     assert torch.allclose(out5, out7, atol=1e-2)
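The (now commented-out) benchmark builds its CSR pointer from the sorted COO row vector via a rowcount/cumsum construction. The same idea in isolation, using plain scatter_add_ on a zero tensor instead of the removed segment_add helper (a sketch; row is assumed sorted and dim_size to cover all segments):

import torch

row = torch.tensor([0, 0, 1, 1, 1, 3])            # sorted COO index
dim_size = 4
rowcount = torch.zeros(dim_size, dtype=torch.long)
rowcount.scatter_add_(0, row, torch.ones_like(row))          # entries per segment: [2, 3, 0, 1]
rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
# rowptr == tensor([0, 2, 5, 5, 6]); usable as `indptr` for segment_csr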
torch_scatter/__init__.py

@@ -8,7 +8,7 @@ from .max import scatter_max
 from .min import scatter_min
 from .logsumexp import scatter_logsumexp
-from .segment import segment_add
+from .segment import segment_coo, segment_csr

 import torch_scatter.composite
@@ -24,7 +24,8 @@ __all__ = [
     'scatter_max',
     'scatter_min',
     'scatter_logsumexp',
-    'segment_add',
+    'segment_coo',
+    'segment_csr',
     'torch_scatter',
     '__version__',
 ]
torch_scatter/segment.py

 import torch
-from torch_scatter.add import scatter_add

 if torch.cuda.is_available():
-    import torch_scatter.segment_cuda
-
-
-def segment_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
-    return scatter_add(src, index, dim, out, dim_size, fill_value)
+    from torch_scatter import segment_cuda


-def segment_add_csr(src, indptr, out=None):
-    return torch_scatter.segment_cuda.segment_add_csr(src, indptr, out)
-
-
-def segment_add_coo(src, index, dim_size=None):
-    dim_size = index.max().item() + 1 if dim_size is None else dim_size
-    size = list(src.size())
-    size[index.dim() - 1] = dim_size
-    out = src.new_zeros(size)
-    torch_scatter.segment_cuda.segment_add_coo(src, index, out)
-    return out
+def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
+    assert reduce in ['add', 'mean', 'min', 'max']
+    if out is None:
+        dim_size = index.max().item() + 1 if dim_size is None else dim_size
+        size = list(src.size())
+        size[index.dim() - 1] = dim_size
+        out = src.new_zeros(size)  # TODO: DEPENDENT ON REDUCE
+    assert index.dtype == torch.long and src.dtype == out.dtype
+    out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
+    return out if arg_out is None else (out, arg_out)
+
+
+def segment_csr(src, indptr, out=None, reduce='add'):
+    assert reduce in ['add', 'mean', 'min', 'max']
+    assert indptr.dtype == torch.long
+    out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+    return out if arg_out is None else (out, arg_out)
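For reduce='min'/'max' the wrappers return (out, arg_out); arg_out stores positions along the reduced dimension of src, and empty segments keep the fill value src.size(dim). A small sketch of using the arg tensor (CUDA device assumed; the expected values follow from the kernel semantics, not from output in the repository):

import torch
from torch_scatter import segment_csr

src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda')
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')

out, arg_out = segment_csr(src, indptr, reduce='max')
# expected: out == [2., 5., 0., 6.]; arg_out == [1, 4, 6, 5]
valid = arg_out < src.size(0)          # the empty third segment keeps the fill value 6
picked = src[arg_out[valid]]           # gathers the selected maxima: [2., 5., 6.]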