Commit eda4b3d7 authored by rusty1s's avatar rusty1s
Browse files

random walk

parent fff381c5
import pytest import pytest
import torch import torch
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
from torch_sparse.saint import subgraph
from .utils import devices from .utils import devices
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_subgraph(device): def test_saint_subgraph(device):
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4]) row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3]) col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
adj = SparseTensor(row=row, col=col).to(device) adj = SparseTensor(row=row, col=col).to(device)
node_idx = torch.tensor([0, 1, 2]) node_idx = torch.tensor([0, 1, 2])
adj, edge_index = subgraph(adj, node_idx) adj, edge_index = adj.saint_subgraph(node_idx)
@pytest.mark.parametrize('device', devices)
def test_sample_node(device):
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
adj = SparseTensor(row=row, col=col).to(device)
adj, perm = adj.sample_node(num_nodes=3)
@pytest.mark.parametrize('device', devices)
def test_sample_edge(device):
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
adj = SparseTensor(row=row, col=col).to(device)
adj, perm = adj.sample_edge(num_edges=3)
@pytest.mark.parametrize('device', devices)
def test_sample_rw(device):
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
adj = SparseTensor(row=row, col=col).to(device)
adj, perm = adj.sample_rw(num_root_nodes=3, walk_length=2)
...@@ -55,8 +55,9 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa ...@@ -55,8 +55,9 @@ from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa from .matmul import matmul # noqa
from .cat import cat, cat_diag # noqa from .cat import cat, cat_diag # noqa
from .rw import random_walk # noqa
from .metis import partition # noqa from .metis import partition # noqa
from .saint import sample_node, sample_edge, sample_rw # noqa from .saint import saint_subgraph # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa from .convert import to_scipy, from_scipy # noqa
...@@ -96,10 +97,9 @@ __all__ = [ ...@@ -96,10 +97,9 @@ __all__ = [
'matmul', 'matmul',
'cat', 'cat',
'cat_diag', 'cat_diag',
'random_walk',
'partition', 'partition',
'sample_node', 'saint_subgraph',
'sample_edge',
'sample_rw',
'to_torch_sparse', 'to_torch_sparse',
'from_torch_sparse', 'from_torch_sparse',
'to_scipy', 'to_scipy',
......
import torch
from torch_sparse.tensor import SparseTensor
def random_walk(src: SparseTensor, start: torch.Tensor,
walk_length: int) -> torch.Tensor:
rowptr, col, _ = src.csr()
return torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
SparseTensor.random_walk = random_walk
from typing import Tuple from typing import Tuple
import torch import torch
import numpy as np
from torch_scatter import scatter_add
from torch_sparse.tensor import SparseTensor from torch_sparse.tensor import SparseTensor
def subgraph(src: SparseTensor, def saint_subgraph(src: SparseTensor, node_idx: torch.Tensor
node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]: ) -> Tuple[SparseTensor, torch.Tensor]:
row, col, value = src.coo() row, col, value = src.coo()
rowptr = src.storage.rowptr() rowptr = src.storage.rowptr()
...@@ -28,58 +26,4 @@ def subgraph(src: SparseTensor, ...@@ -28,58 +26,4 @@ def subgraph(src: SparseTensor,
return out, edge_index return out, edge_index
def sample_node(src: SparseTensor, SparseTensor.saint_subgraph = saint_subgraph
num_nodes: int) -> Tuple[SparseTensor, torch.Tensor]:
row, col, _ = src.coo()
inv_in_deg = src.storage.colcount().to(torch.float).pow_(-1)
inv_in_deg[inv_in_deg == float('inf')] = 0
prob = inv_in_deg[col]
prob.mul_(prob)
prob = scatter_add(prob, row, dim=0, dim_size=src.size(0))
prob.div_(prob.sum())
node_idx = prob.multinomial(num_nodes, replacement=True).unique()
return src.permute(node_idx), node_idx
def sample_edge(src: SparseTensor,
num_edges: int) -> Tuple[SparseTensor, torch.Tensor]:
row, col, _ = src.coo()
inv_out_deg = src.storage.rowcount().to(torch.float).pow_(-1)
inv_out_deg[inv_out_deg == float('inf')] = 0
inv_in_deg = src.storage.colcount().to(torch.float).pow_(-1)
inv_in_deg[inv_in_deg == float('inf')] = 0
prob = inv_out_deg[row] + inv_in_deg[col]
prob.div_(prob.sum())
edge_idx = prob.multinomial(num_edges, replacement=True)
node_idx = col[edge_idx].unique()
return src.permute(node_idx), node_idx
def sample_rw(src: SparseTensor, num_root_nodes: int,
walk_length: int) -> Tuple[SparseTensor, torch.Tensor]:
rowptr, col, _ = src.csr()
start = np.random.choice(src.size(0), size=num_root_nodes, replace=False)
start = torch.from_numpy(start).to(src.device(), torch.long)
out = torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)
node_idx = out.flatten().unique()
return src.permute(node_idx), node_idx
SparseTensor.sample_node = sample_node
SparseTensor.sample_edge = sample_edge
SparseTensor.sample_rw = sample_rw
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment