Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
8b77e547
Commit
8b77e547
authored
Apr 06, 2020
by
rusty1s
Browse files
added python wrapper
parent
631df924
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
38 additions
and
22 deletions
+38
-22
test/test_padding.py
test/test_padding.py
+8
-10
torch_sparse/__init__.py
torch_sparse/__init__.py
+3
-0
torch_sparse/metis.py
torch_sparse/metis.py
+3
-5
torch_sparse/padding.py
torch_sparse/padding.py
+21
-0
torch_sparse/saint.py
torch_sparse/saint.py
+3
-7
No files found.
test/test_padding.py
View file @
8b77e547
...
@@ -2,7 +2,7 @@ from itertools import product
...
@@ -2,7 +2,7 @@ from itertools import product
import
pytest
import
pytest
import
torch
import
torch
from
torch_sparse
import
SparseTensor
from
torch_sparse
import
SparseTensor
,
padded_index_select
from
.utils
import
grad_dtypes
,
tensor
from
.utils
import
grad_dtypes
,
tensor
...
@@ -14,11 +14,9 @@ def test_padded_index_select(dtype, device):
...
@@ -14,11 +14,9 @@ def test_padded_index_select(dtype, device):
row
=
torch
.
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
3
])
row
=
torch
.
tensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
2
,
2
,
3
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
2
,
3
,
1
,
3
,
2
])
col
=
torch
.
tensor
([
0
,
1
,
2
,
3
,
0
,
2
,
3
,
1
,
3
,
2
])
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
rowptr
,
col
,
_
=
adj
.
csr
()
rowcount
=
adj
.
storage
.
rowcount
()
binptr
=
torch
.
tensor
([
0
,
3
,
5
],
device
=
device
)
binptr
=
torch
.
tensor
([
0
,
3
,
5
],
device
=
device
)
data
=
torch
.
ops
.
torch_sparse
.
padded_index
(
rowptr
,
col
,
rowcount
,
binptr
)
data
=
adj
.
padded_index
(
binptr
)
node_perm
,
row_perm
,
col_perm
,
mask
,
node_size
,
edge_size
=
data
node_perm
,
row_perm
,
col_perm
,
mask
,
node_size
,
edge_size
=
data
assert
node_perm
.
tolist
()
==
[
2
,
3
,
0
,
1
]
assert
node_perm
.
tolist
()
==
[
2
,
3
,
0
,
1
]
...
@@ -29,21 +27,21 @@ def test_padded_index_select(dtype, device):
...
@@ -29,21 +27,21 @@ def test_padded_index_select(dtype, device):
assert
edge_size
==
[
4
,
8
]
assert
edge_size
==
[
4
,
8
]
x
=
tensor
([
0
,
1
,
2
,
3
],
dtype
,
device
).
view
(
-
1
,
1
).
requires_grad_
()
x
=
tensor
([
0
,
1
,
2
,
3
],
dtype
,
device
).
view
(
-
1
,
1
).
requires_grad_
()
fill_value
=
torch
.
tensor
(
0.
,
dtype
=
dtype
)
x_j
=
padded_index_select
(
x
,
col_perm
)
out
=
torch
.
ops
.
torch_sparse
.
padded_index_select
(
x
,
col_perm
,
fill_value
)
assert
out
.
flatten
().
tolist
()
==
[
1
,
3
,
2
,
0
,
0
,
1
,
2
,
3
,
0
,
2
,
3
,
0
]
assert
x_j
.
flatten
().
tolist
()
==
[
1
,
3
,
2
,
0
,
0
,
1
,
2
,
3
,
0
,
2
,
3
,
0
]
grad_out
=
tensor
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
],
dtype
,
device
)
grad_out
=
tensor
([
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
],
dtype
,
device
)
out
.
backward
(
grad_out
.
view
(
-
1
,
1
))
x_j
.
backward
(
grad_out
.
view
(
-
1
,
1
))
assert
x
.
grad
.
flatten
().
tolist
()
==
[
12
,
5
,
17
,
18
]
assert
x
.
grad
.
flatten
().
tolist
()
==
[
12
,
5
,
17
,
18
]
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_padded_index_select_runtime
():
def
test_padded_index_select_runtime
(
device
):
return
return
from
torch_geometric.datasets
import
Planetoid
from
torch_geometric.datasets
import
Planetoid
device
=
torch
.
device
(
'cuda'
)
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
start
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
...
...
torch_sparse/__init__.py
View file @
8b77e547
...
@@ -58,6 +58,7 @@ from .cat import cat, cat_diag # noqa
...
@@ -58,6 +58,7 @@ from .cat import cat, cat_diag # noqa
from
.rw
import
random_walk
# noqa
from
.rw
import
random_walk
# noqa
from
.metis
import
partition
# noqa
from
.metis
import
partition
# noqa
from
.saint
import
saint_subgraph
# noqa
from
.saint
import
saint_subgraph
# noqa
from
.padding
import
padded_index
,
padded_index_select
# noqa
from
.convert
import
to_torch_sparse
,
from_torch_sparse
# noqa
from
.convert
import
to_torch_sparse
,
from_torch_sparse
# noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
...
@@ -100,6 +101,8 @@ __all__ = [
...
@@ -100,6 +101,8 @@ __all__ = [
'random_walk'
,
'random_walk'
,
'partition'
,
'partition'
,
'saint_subgraph'
,
'saint_subgraph'
,
'padded_index'
,
'padded_index_select'
,
'to_torch_sparse'
,
'to_torch_sparse'
,
'from_torch_sparse'
,
'from_torch_sparse'
,
'to_scipy'
,
'to_scipy'
,
...
...
torch_sparse/metis.py
View file @
8b77e547
...
@@ -5,9 +5,8 @@ from torch_sparse.tensor import SparseTensor
...
@@ -5,9 +5,8 @@ from torch_sparse.tensor import SparseTensor
from
torch_sparse.permute
import
permute
from
torch_sparse.permute
import
permute
def
partition
(
def
partition
(
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
src
:
SparseTensor
,
num_parts
:
int
,
recursive
:
bool
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
rowptr
,
col
=
src
.
storage
.
rowptr
().
cpu
(),
src
.
storage
.
col
().
cpu
()
rowptr
,
col
=
src
.
storage
.
rowptr
().
cpu
(),
src
.
storage
.
col
().
cpu
()
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
num_parts
,
cluster
=
torch
.
ops
.
torch_sparse
.
partition
(
rowptr
,
col
,
num_parts
,
...
@@ -21,5 +20,4 @@ def partition(
...
@@ -21,5 +20,4 @@ def partition(
return
out
,
partptr
,
perm
return
out
,
partptr
,
perm
SparseTensor
.
partition
=
lambda
self
,
num_parts
,
recursive
=
False
:
partition
(
SparseTensor
.
partition
=
partition
self
,
num_parts
,
recursive
)
torch_sparse/padding.py
0 → 100644
View file @
8b77e547
from
typing
import
Tuple
,
List
import
torch
from
torch_sparse.tensor
import
SparseTensor
def
padded_index
(
src
:
SparseTensor
,
binptr
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
List
[
int
],
List
[
int
]]:
return
torch
.
ops
.
torch_sparse
.
padded_index
(
src
.
storage
.
rowptr
(),
src
.
storage
.
col
(),
src
.
storage
.
rowcount
(),
binptr
)
def
padded_index_select
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
fill_value
:
float
=
0.
)
->
torch
.
Tensor
:
fill_value
=
torch
.
tensor
(
fill_value
,
dtype
=
src
.
dtype
)
return
torch
.
ops
.
torch_sparse
.
padded_index_select
(
src
,
index
,
fill_value
)
SparseTensor
.
padded_index
=
padded_index
torch_sparse/saint.py
View file @
8b77e547
...
@@ -15,13 +15,9 @@ def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
...
@@ -15,13 +15,9 @@ def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
[
edge_index
]
value
=
value
[
edge_index
]
out
=
SparseTensor
(
out
=
SparseTensor
(
row
=
row
,
rowptr
=
None
,
col
=
col
,
value
=
value
,
row
=
row
,
sparse_sizes
=
(
node_idx
.
size
(
0
),
node_idx
.
size
(
0
)),
rowptr
=
None
,
is_sorted
=
True
)
col
=
col
,
value
=
value
,
sparse_sizes
=
(
node_idx
.
size
(
0
),
node_idx
.
size
(
0
)),
is_sorted
=
True
)
return
out
,
edge_index
return
out
,
edge_index
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment