Commit 78d9af48 authored by rusty1s's avatar rusty1s
Browse files

sample adj

parent d3ae9f10
......@@ -7,7 +7,7 @@ __version__ = '0.6.3'
for library in [
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
'_saint', '_padding'
'_saint', '_padding', '_sample'
]:
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
library, [osp.dirname(__file__)]).origin)
......@@ -50,7 +50,7 @@ from .metis import partition # noqa
from .bandwidth import reverse_cuthill_mckee # noqa
from .saint import saint_subgraph # noqa
from .padding import padded_index, padded_index_select # noqa
from .sample import sample # noqa
from .sample import sample, sample_adj # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa
......
from typing import Optional
from typing import Optional, Tuple
import torch
from torch_sparse.tensor import SparseTensor
......@@ -22,4 +22,21 @@ def sample(src: SparseTensor, num_neighbors: int,
return col[rand]
def sample_adj(src: SparseTensor, subset: torch.Tensor, num_neighbors: int,
replace: bool = False) -> Tuple[SparseTensor, torch.Tensor]:
rowptr, col, _ = src.csr()
rowcount = src.storage.rowcount()
rowptr, col, n_id, e_id = torch.ops.torch_sparse.sample_adj(
rowptr, col, rowcount, subset, num_neighbors, replace)
out = SparseTensor(rowptr=rowptr, row=None, col=col, value=e_id,
sparse_sizes=(subset.size(0), n_id.size(0)),
is_sorted=True)
return out, n_id
SparseTensor.sample = sample
SparseTensor.sample_adj = sample_adj
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment