Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
78d9af48
Commit
78d9af48
authored
May 21, 2020
by
rusty1s
Browse files
sample adj
parent
d3ae9f10
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
3 deletions
+20
-3
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-2
torch_sparse/sample.py
torch_sparse/sample.py
+18
-1
No files found.
torch_sparse/__init__.py
View file @
78d9af48
...
...
@@ -7,7 +7,7 @@ __version__ = '0.6.3'
for
library
in
[
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_spspmm'
,
'_metis'
,
'_rw'
,
'_saint'
,
'_padding'
'_saint'
,
'_padding'
,
'_sample'
]:
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
library
,
[
osp
.
dirname
(
__file__
)]).
origin
)
...
...
@@ -50,7 +50,7 @@ from .metis import partition # noqa
from
.bandwidth
import
reverse_cuthill_mckee
# noqa
from
.saint
import
saint_subgraph
# noqa
from
.padding
import
padded_index
,
padded_index_select
# noqa
from
.sample
import
sample
# noqa
from
.sample
import
sample
,
sample_adj
# noqa
from
.convert
import
to_torch_sparse
,
from_torch_sparse
# noqa
from
.convert
import
to_scipy
,
from_scipy
# noqa
...
...
torch_sparse/sample.py
View file @
78d9af48
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
import
torch
from
torch_sparse.tensor
import
SparseTensor
...
...
@@ -22,4 +22,21 @@ def sample(src: SparseTensor, num_neighbors: int,
return
col
[
rand
]
def
sample_adj
(
src
:
SparseTensor
,
subset
:
torch
.
Tensor
,
num_neighbors
:
int
,
replace
:
bool
=
False
)
->
Tuple
[
SparseTensor
,
torch
.
Tensor
]:
rowptr
,
col
,
_
=
src
.
csr
()
rowcount
=
src
.
storage
.
rowcount
()
rowptr
,
col
,
n_id
,
e_id
=
torch
.
ops
.
torch_sparse
.
sample_adj
(
rowptr
,
col
,
rowcount
,
subset
,
num_neighbors
,
replace
)
out
=
SparseTensor
(
rowptr
=
rowptr
,
row
=
None
,
col
=
col
,
value
=
e_id
,
sparse_sizes
=
(
subset
.
size
(
0
),
n_id
.
size
(
0
)),
is_sorted
=
True
)
return
out
,
n_id
SparseTensor
.
sample
=
sample
SparseTensor
.
sample_adj
=
sample_adj
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment