OpenDAS / torch-scatter
"sgl-kernel/vscode:/vscode.git/clone" did not exist on "31dfff7da7ade6703303a67bfe6ef52ead97640a"
Commit 3c89ebc2, authored Jan 07, 2020 by rusty1s

autograd function

parent 6e561c88
Showing 3 changed files with 80 additions and 133 deletions:
cuda/segment_kernel.cu      +40  -38
test/test_segment.py         +8  -91
torch_scatter/segment.py    +32   -4
cuda/segment_kernel.cu
@@ -68,8 +68,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
       }
     }
 
-  static inline __device__ void atom_write(scalar_t *address, scalar_t val,
-                                           int64_t *arg_address, int64_t arg) {
+  static inline __device__ void atomic_write(scalar_t *address, scalar_t val,
+                                             int64_t *arg_address, int64_t arg) {
     if (REDUCE == ADD) {
       atomAdd(address, val);
     } else if (REDUCE == MEAN) {
@@ -81,6 +82,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     }
 
     if (REDUCE == MIN || REDUCE == MAX) {
+      assert(false); // TODO
       __syncthreads();
       if (*address == val) {
         *arg_address = arg;
@@ -280,7 +282,7 @@ segment_coo_kernel(const scalar_t *src_data,
       next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
 
       if (lane_idx == 32 - 1 || idx != next_idx) {
-        Reducer<scalar_t, REDUCE>::atom_write(out_data + idx, val,
-                                              arg_out_data + idx, arg);
+        Reducer<scalar_t, REDUCE>::atomic_write(out_data + idx, val,
+                                                arg_out_data + idx, arg);
       }
     }
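The guard in this hunk elects exactly one writer per run of equal indices within a warp: each lane peeks at the index held by the next lane via __shfl_down_sync, and only the last lane of the warp, or the last lane of its run, performs the atomic write. A minimal Python emulation of that election (illustrative only; writing_lanes is not part of the library):

# Lane i reads the index held by lane i + 1 (what
# __shfl_down_sync(FULL_MASK, idx, 1) provides) and writes only if it is
# lane 31 or the next lane starts a new segment.
def writing_lanes(idx):
    assert len(idx) == 32  # one sorted segment index per lane of a warp
    lanes = []
    for lane in range(32):
        next_idx = idx[lane + 1] if lane + 1 < 32 else None
        if lane == 32 - 1 or idx[lane] != next_idx:
            lanes.append(lane)
    return lanes

print(writing_lanes([0] * 4 + [1] * 10 + [2] * 18))  # -> [3, 13, 31]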
@@ -343,8 +345,10 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     AT_ASSERTM(src.size(i) == out.size(i));
 
   at::optional<at::Tensor> arg_out = at::nullopt;
+  int64_t *arg_out_data = nullptr;
   if (reduce == "min" || reduce == "max") {
     arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
   auto E = index.numel();
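The fill value chosen for arg_out here is src.size(reduce_dim), which reads as a sentinel: one past the last valid source position, so output entries that never receive a min/max candidate stay distinguishable. A short Python sketch of that reading (an assumption drawn from the call above; names are illustrative):

import torch

src_len = 6                    # stands in for src.size(reduce_dim)
arg_out = torch.full((4,), src_len, dtype=torch.long)  # mirrors at::full_like
print(arg_out)                 # tensor([6, 6, 6, 6]) -> every slot "untouched"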
@@ -357,43 +361,41 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     auto src_data = src.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
 
-    if (K == 1 && reduce == "add") {
-      segment_coo_kernel<scalar_t, ADD><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, nullptr, E);
-    } else if (K == 1 && reduce == "mean") {
-      segment_coo_kernel<scalar_t, MEAN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, nullptr, E);
-    } else if (K == 1 && reduce == "min") {
-      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
-      segment_coo_kernel<scalar_t, MIN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, arg_out_data, E);
-    } else if (K == 1 && reduce == "max") {
-      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
-      segment_coo_kernel<scalar_t, MAX><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, arg_out_data, E);
-    } else if (avg_len <= 8)
-      segment_coo_broadcast_kernel<scalar_t, ADD, 4>
-          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
-             stream>>>(src_data, index_info, out_data, nullptr, E, K);
-    else if (avg_len <= 16)
-      segment_coo_broadcast_kernel<scalar_t, ADD, 8>
-          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
-             stream>>>(src_data, index_info, out_data, nullptr, E, K);
-    else if (avg_len <= 32)
-      segment_coo_broadcast_kernel<scalar_t, ADD, 16>
-          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8), 0,
-             stream>>>(src_data, index_info, out_data, nullptr, E, K);
-    else
-      segment_coo_broadcast_kernel<scalar_t, ADD, 32>
-          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8), 0,
-             stream>>>(src_data, index_info, out_data, nullptr, E, K);
+    // Select the right kernel based on average row length (purely heuristic)
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      // and whether we need broadcasting capabilties (K > 1):
+      if (K == 1) {
+        segment_coo_kernel<scalar_t, REDUCE>
+            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
+                                                   out_data, arg_out_data, E);
+      } else if (avg_len <= 8) {
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
+            <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
+               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
+      } else if (avg_len <= 16) {
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
+            <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
+               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
+      } else if (avg_len <= 32) {
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
+            <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32),
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
+                                         arg_out_data, E, K);
+      } else {
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
+            <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32),
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
+                                         arg_out_data, E, K);
+      }
+    });
   });
 
   if (reduce == "mean") {
-    AT_ASSERTM(false); // TODO: DIVIDE ENTRIES.
+    auto count = at::empty_like(index, out.options());
+    AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
+      auto count_data = count.DATA_PTR<scalar_t>();
+      AT_ASSERTM(false); // TODO
+    });
+    out = out / count;
+    arg_out = count;
   }
 
   return std::make_tuple(out, arg_out);
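All four broadcast launches above share one configuration pattern: blocks of dim3(32, 8) threads, a grid x dimension that tiles the E index entries in chunks of 8 * TB (TB being the per-thread tile factor 4, 8, 16, or 32 picked from avg_len), and a grid y dimension covering the K columns one warp at a time. A small Python sketch of the arithmetic (the helper name is illustrative):

def broadcast_launch_dims(E, K, TB):
    # Mirrors dim3(((E + (8 * TB) - 1) / (8 * TB)), (K + 31) / 32), dim3(32, 8)
    grid_x = (E + (8 * TB) - 1) // (8 * TB)  # ceil(E / (8 * TB))
    grid_y = (K + 31) // 32                  # ceil(K / 32), one warp per slice
    return (grid_x, grid_y), (32, 8)         # (grid, block)

# Example: E = 1000 index entries, K = 64 columns; avg_len <= 8 picks TB = 4.
print(broadcast_launch_dims(1000, 64, 4))    # ((32, 2), (32, 8))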
test/test_segment.py
@@ -10,105 +10,22 @@ dtypes = [torch.float]
 devices = [torch.device('cuda')]
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_forward(dtype, device):
-    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
-                 device)
+    # src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
+    #              device)
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
+    src.requires_grad_()
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
 
-    out = segment_csr(src, indptr, reduce='add')
+    out = segment_csr(src, indptr, reduce='max')
+    out = out[0] if isinstance(out, tuple) else out
     print('CSR', out)
+    out.backward(torch.randn_like(out))
 
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
     out = segment_coo(src, index, reduce='add')
     print('COO', out)
-
-
-# @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
-# def test_benchmark(dtype, device):
-#     from torch_geometric.datasets import Planetoid, Reddit  # noqa
-#     # data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
-#     data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
-#     row, col = data.edge_index
-#     print(data.num_edges)
-#     print(row.size(0) / data.num_nodes)
-#     num_repeats = 1
-#     row = row.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
-#     col = col.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
-#     # Warmup
-#     for _ in range(10):
-#         torch.randn(100, 100, device=device).sum()
-#     x = torch.randn(row.size(0), device=device)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('Scatter Row', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         scatter_add(x, col, dim=0, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('Scatter Col', time.perf_counter() - t)
-#     rowcount = segment_add(torch.ones_like(row), row)
-#     rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
-#     torch.cuda.synchronize()
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out3 = segment_add_csr(x, rowptr)
-#     torch.cuda.synchronize()
-#     print('CSR', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out4 = segment_add_coo(x, row, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('COO', time.perf_counter() - t)
-#     assert torch.allclose(out1, out3, atol=1e-2)
-#     assert torch.allclose(out1, out4, atol=1e-2)
-#     x = torch.randn((row.size(0), 64), device=device)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out5 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('Scatter Row + Dim', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         scatter_add(x, col, dim=0, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('Scatter Col + Dim', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out6 = segment_add_csr(x, rowptr)
-#     torch.cuda.synchronize()
-#     print('CSR + Dim', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out7 = segment_add_coo(x, row, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('COO + Dim', time.perf_counter() - t)
-#     assert torch.allclose(out5, out6, atol=1e-2)
-#     assert torch.allclose(out5, out7, atol=1e-2)
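For reference, the fixtures in the reworked test decompose as follows (hand-worked in plain Python; the value the library assigns to the empty segment is left open here):

src = [1, 2, 3, 4, 5, 6]
indptr = [0, 2, 5, 5, 6]           # CSR boundaries -> 4 segments
segments = [src[indptr[i]:indptr[i + 1]] for i in range(len(indptr) - 1)]
print(segments)                    # [[1, 2], [3, 4, 5], [], [6]]
print([max(s) if s else None for s in segments])  # reduce='max' -> [2, 5, ?, 6]

index = [0, 0, 1, 1, 1, 3]         # the equivalent COO form of indptr
sums = [sum(v for v, i in zip(src, index) if i == k) for k in range(4)]
print(sums)                        # reduce='add' -> [3, 12, 0, 6]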
torch_scatter/segment.py
@@ -4,6 +4,31 @@ if torch.cuda.is_available():
     from torch_scatter import segment_cuda
 
 
+class SegmentCSR(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, src, indptr, out, reduce):
+        assert reduce in ['add', 'mean', 'min', 'max']
+        assert indptr.dtype == torch.long
+
+        if out is not None:
+            ctx.mark_dirty(out)
+        ctx.reduce = reduce
+        ctx.save_for_backward(src, indptr)
+
+        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+        return out if arg_out is None else (out, arg_out)
+
+    @staticmethod
+    def backward(ctx, grad_out, *args):
+        src, indptr = ctx.saved_tensors
+
+        grad_src = None
+        if ctx.needs_input_grad[0]:
+            grad_src = src
+
+        return grad_src, None, None, None
+
+
 def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
     assert reduce in ['add', 'mean', 'min', 'max']
 
     if out is None:
@@ -17,7 +42,10 @@ def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
 def segment_csr(src, indptr, out=None, reduce='add'):
-    assert reduce in ['add', 'mean', 'min', 'max']
-    assert indptr.dtype == torch.long
-    out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
-    return out if arg_out is None else (out, arg_out)
+    return SegmentCSR.apply(src, indptr, out, reduce)
+    # assert reduce in ['add', 'mean', 'min', 'max']
+    # assert indptr.dtype == torch.long
+    # out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+    # return out if arg_out is None else (out, arg_out)
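A sketch of how the new autograd entry point is driven (assuming a CUDA build at this revision; for reduce='max' the CUDA op returns an argmax tensor, so segment_csr yields a tuple). Note that backward at this commit is still a stub returning src itself as grad_src, so the resulting gradient is a placeholder, not a true derivative:

import torch
from torch_scatter import segment_csr

src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda').requires_grad_()
indptr = torch.tensor([0, 2, 5, 5, 6], dtype=torch.long, device='cuda')

out, arg_out = segment_csr(src, indptr, reduce='max')  # values and argmax
out.backward(torch.randn_like(out))
print(src.grad)  # placeholder: backward currently hands back src itself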