Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
eda4b3d7
Commit
eda4b3d7
authored
Mar 21, 2020
by
rusty1s
Browse files
random walk
parent
fff381c5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
93 deletions
+20
-93
test/test_saint.py
test/test_saint.py
+2
-30
torch_sparse/__init__.py
torch_sparse/__init__.py
+4
-4
torch_sparse/rw.py
torch_sparse/rw.py
+11
-0
torch_sparse/saint.py
torch_sparse/saint.py
+3
-59
No files found.
test/test_saint.py
View file @
eda4b3d7
import
pytest
import
pytest
import
torch
import
torch
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.saint
import
subgraph
from
.utils
import
devices
from
.utils
import
devices
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_subgraph
(
device
):
def
test_
saint_
subgraph
(
device
):
row
=
torch
.
tensor
([
0
,
0
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
4
])
row
=
torch
.
tensor
([
0
,
0
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
4
])
col
=
torch
.
tensor
([
1
,
2
,
0
,
2
,
0
,
1
,
3
,
2
,
4
,
3
])
col
=
torch
.
tensor
([
1
,
2
,
0
,
2
,
0
,
1
,
3
,
2
,
4
,
3
])
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
node_idx
=
torch
.
tensor
([
0
,
1
,
2
])
node_idx
=
torch
.
tensor
([
0
,
1
,
2
])
adj
,
edge_index
=
subgraph
(
adj
,
node_idx
)
adj
,
edge_index
=
adj
.
saint_subgraph
(
node_idx
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_sample_node
(
device
):
row
=
torch
.
tensor
([
0
,
0
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
4
])
col
=
torch
.
tensor
([
1
,
2
,
0
,
2
,
0
,
1
,
3
,
2
,
4
,
3
])
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
adj
,
perm
=
adj
.
sample_node
(
num_nodes
=
3
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_sample_edge
(
device
):
row
=
torch
.
tensor
([
0
,
0
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
4
])
col
=
torch
.
tensor
([
1
,
2
,
0
,
2
,
0
,
1
,
3
,
2
,
4
,
3
])
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
adj
,
perm
=
adj
.
sample_edge
(
num_edges
=
3
)
@
pytest
.
mark
.
parametrize
(
'device'
,
devices
)
def
test_sample_rw
(
device
):
row
=
torch
.
tensor
([
0
,
0
,
1
,
1
,
2
,
2
,
2
,
3
,
3
,
4
])
col
=
torch
.
tensor
([
1
,
2
,
0
,
2
,
0
,
1
,
3
,
2
,
4
,
3
])
adj
=
SparseTensor
(
row
=
row
,
col
=
col
).
to
(
device
)
adj
,
perm
=
adj
.
sample_rw
(
num_root_nodes
=
3
,
walk_length
=
2
)
torch_sparse/__init__.py
View file @
eda4b3d7
...
@@ -55,8 +55,9 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
...
@@ -55,8 +55,9 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa
from
.reduce
import
sum
,
mean
,
min
,
max
# noqa
from
.matmul
import
matmul
# noqa
from
.matmul
import
matmul
# noqa
from
.cat
import
cat
,
cat_diag
# noqa
from
.cat
import
cat
,
cat_diag
# noqa
from
.rw
import
random_walk
# noqa
from
.metis
import
partition
# noqa
from
.metis
import
partition
# noqa
from
.saint
import
sa
mple_node
,
sample_edge
,
sample_rw
# noqa
from
.saint
import
sa
int_subgraph
# noqa
from
.convert
import
to_torch_sparse
,
from_torch_sparse
# noqa
from
.convert
import
to_torch_sparse
,
from_torch_sparse
# noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
...
@@ -96,10 +97,9 @@ __all__ = [
...
@@ -96,10 +97,9 @@ __all__ = [
'matmul'
,
'matmul'
,
'cat'
,
'cat'
,
'cat_diag'
,
'cat_diag'
,
'random_walk'
,
'partition'
,
'partition'
,
'sample_node'
,
'saint_subgraph'
,
'sample_edge'
,
'sample_rw'
,
'to_torch_sparse'
,
'to_torch_sparse'
,
'from_torch_sparse'
,
'from_torch_sparse'
,
'to_scipy'
,
'to_scipy'
,
...
...
torch_sparse/rw.py
0 → 100644
View file @
eda4b3d7
import
torch
from
torch_sparse.tensor
import
SparseTensor
def
random_walk
(
src
:
SparseTensor
,
start
:
torch
.
Tensor
,
walk_length
:
int
)
->
torch
.
Tensor
:
rowptr
,
col
,
_
=
src
.
csr
()
return
torch
.
ops
.
torch_sparse
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
)
SparseTensor
.
random_walk
=
random_walk
torch_sparse/saint.py
View file @
eda4b3d7
from
typing
import
Tuple
from
typing
import
Tuple
import
torch
import
torch
import
numpy
as
np
from
torch_scatter
import
scatter_add
from
torch_sparse.tensor
import
SparseTensor
from
torch_sparse.tensor
import
SparseTensor
def
subgraph
(
src
:
SparseTensor
,
def
saint_
subgraph
(
src
:
SparseTensor
,
node_idx
:
torch
.
Tensor
node_idx
:
torch
.
Tensor
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
row
,
col
,
value
=
src
.
coo
()
row
,
col
,
value
=
src
.
coo
()
rowptr
=
src
.
storage
.
rowptr
()
rowptr
=
src
.
storage
.
rowptr
()
...
@@ -28,58 +26,4 @@ def subgraph(src: SparseTensor,
...
@@ -28,58 +26,4 @@ def subgraph(src: SparseTensor,
return
out
,
edge_index
return
out
,
edge_index
def
sample_node
(
src
:
SparseTensor
,
SparseTensor
.
saint_subgraph
=
saint_subgraph
num_nodes
:
int
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
row
,
col
,
_
=
src
.
coo
()
inv_in_deg
=
src
.
storage
.
colcount
().
to
(
torch
.
float
).
pow_
(
-
1
)
inv_in_deg
[
inv_in_deg
==
float
(
'inf'
)]
=
0
prob
=
inv_in_deg
[
col
]
prob
.
mul_
(
prob
)
prob
=
scatter_add
(
prob
,
row
,
dim
=
0
,
dim_size
=
src
.
size
(
0
))
prob
.
div_
(
prob
.
sum
())
node_idx
=
prob
.
multinomial
(
num_nodes
,
replacement
=
True
).
unique
()
return
src
.
permute
(
node_idx
),
node_idx
def
sample_edge
(
src
:
SparseTensor
,
num_edges
:
int
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
row
,
col
,
_
=
src
.
coo
()
inv_out_deg
=
src
.
storage
.
rowcount
().
to
(
torch
.
float
).
pow_
(
-
1
)
inv_out_deg
[
inv_out_deg
==
float
(
'inf'
)]
=
0
inv_in_deg
=
src
.
storage
.
colcount
().
to
(
torch
.
float
).
pow_
(
-
1
)
inv_in_deg
[
inv_in_deg
==
float
(
'inf'
)]
=
0
prob
=
inv_out_deg
[
row
]
+
inv_in_deg
[
col
]
prob
.
div_
(
prob
.
sum
())
edge_idx
=
prob
.
multinomial
(
num_edges
,
replacement
=
True
)
node_idx
=
col
[
edge_idx
].
unique
()
return
src
.
permute
(
node_idx
),
node_idx
def
sample_rw
(
src
:
SparseTensor
,
num_root_nodes
:
int
,
walk_length
:
int
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
rowptr
,
col
,
_
=
src
.
csr
()
start
=
np
.
random
.
choice
(
src
.
size
(
0
),
size
=
num_root_nodes
,
replace
=
False
)
start
=
torch
.
from_numpy
(
start
).
to
(
src
.
device
(),
torch
.
long
)
out
=
torch
.
ops
.
torch_sparse
.
random_walk
(
rowptr
,
col
,
start
,
walk_length
)
node_idx
=
out
.
flatten
().
unique
()
return
src
.
permute
(
node_idx
),
node_idx
SparseTensor
.
sample_node
=
sample_node
SparseTensor
.
sample_edge
=
sample_edge
SparseTensor
.
sample_rw
=
sample_rw
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment