"""Subgraph extraction and GraphSAINT-style samplers for ``SparseTensor``."""

from typing import Tuple

import numpy as np
import torch
from torch_scatter import scatter_add
from torch_sparse.tensor import SparseTensor


def _inverse_degree(count: torch.Tensor) -> torch.Tensor:
    """Return ``1 / count`` as a float tensor, with infinite entries set to 0.

    Entries whose count is zero would yield ``inf``; they are zeroed so that
    isolated nodes contribute no sampling probability.
    """
    # NOTE(review): the in-place ``pow_`` is safe only because ``.to(float)``
    # copies an integer count tensor -- confirm counts are never float32.
    inv = count.to(torch.float).pow_(-1)
    inv[inv == float('inf')] = 0
    return inv


def subgraph(src: SparseTensor,
             node_idx: torch.Tensor) -> Tuple[SparseTensor, torch.Tensor]:
    """Build the sparse matrix restricted to the nodes in ``node_idx``.

    Args:
        src: Source sparse matrix.
        node_idx: Indices of the nodes to keep.

    Returns:
        A tuple ``(out, edge_index)`` where ``out`` is a square
        ``SparseTensor`` of size ``node_idx.size(0)`` and ``edge_index`` is
        the index tensor used to gather the kept values from ``src``.
    """
    row, col, value = src.coo()
    rowptr = src.storage.rowptr()

    # The custom op returns the new row/col indices together with the
    # positions of the retained entries in the original COO layout.
    row, col, edge_index = torch.ops.torch_sparse.saint_subgraph(
        node_idx, rowptr, row, col)

    if value is not None:
        value = value[edge_index]

    num_kept = node_idx.size(0)
    out = SparseTensor(row=row, rowptr=None, col=col, value=value,
                       sparse_sizes=(num_kept, num_kept), is_sorted=True)
    return out, edge_index


def sample_node(src: SparseTensor,
                num_nodes: int) -> Tuple[SparseTensor, torch.Tensor]:
    """Sample nodes with probability based on squared inverse in-degrees.

    Each stored entry ``(row, col)`` contributes ``(1 / in_deg(col)) ** 2`` to
    the weight of ``row``; the weights are normalized and ``num_nodes`` draws
    are made with replacement, then de-duplicated.

    Args:
        src: Source sparse matrix.
        num_nodes: Number of draws (the result may hold fewer unique nodes).

    Returns:
        A tuple ``(src.permute(node_idx), node_idx)``.
    """
    row, col, _ = src.coo()

    inv_in_deg = _inverse_degree(src.storage.colcount())

    prob = inv_in_deg[col]
    prob.mul_(prob)  # Square the per-entry weight in place.
    prob = scatter_add(prob, row, dim=0, dim_size=src.size(0))
    prob.div_(prob.sum())

    node_idx = prob.multinomial(num_nodes, replacement=True).unique()

    return src.permute(node_idx), node_idx


def sample_edge(src: SparseTensor,
                num_edges: int) -> Tuple[SparseTensor, torch.Tensor]:
    """Sample edges and return the matrix restricted to their endpoints.

    An entry ``(row, col)`` is drawn with probability proportional to
    ``1 / out_deg(row) + 1 / in_deg(col)``. ``num_edges`` draws are made with
    replacement.

    Args:
        src: Source sparse matrix.
        num_edges: Number of edge draws.

    Returns:
        A tuple ``(src.permute(node_idx), node_idx)`` where ``node_idx`` holds
        the unique row and column endpoints of the sampled edges.
    """
    row, col, _ = src.coo()

    inv_out_deg = _inverse_degree(src.storage.rowcount())
    inv_in_deg = _inverse_degree(src.storage.colcount())

    prob = inv_out_deg[row] + inv_in_deg[col]
    prob.div_(prob.sum())

    edge_idx = prob.multinomial(num_edges, replacement=True)
    # Keep BOTH endpoints: using only ``col`` would drop the row endpoint and
    # a sampled edge could be missing from the returned subgraph.
    node_idx = torch.cat([row[edge_idx], col[edge_idx]]).unique()

    return src.permute(node_idx), node_idx


def sample_rw(src: SparseTensor, num_root_nodes: int,
              walk_length: int) -> Tuple[SparseTensor, torch.Tensor]:
    """Sample nodes visited by random walks started at random root nodes.

    Args:
        src: Source sparse matrix.
        num_root_nodes: Number of distinct start nodes, drawn without
            replacement via ``np.random.choice``.
        walk_length: Walk length passed to the ``random_walk`` op.

    Returns:
        A tuple ``(src.permute(node_idx), node_idx)`` where ``node_idx`` holds
        the unique nodes appearing in the op's output.

    Raises:
        ValueError: Raised by ``np.random.choice`` when ``num_root_nodes``
            exceeds ``src.size(0)``.
    """
    rowptr, col, _ = src.csr()

    # Root selection uses NumPy's global RNG, not torch's.
    start = np.random.choice(src.size(0), size=num_root_nodes, replace=False)
    start = torch.from_numpy(start).to(src.device(), torch.long)

    out = torch.ops.torch_sparse.random_walk(rowptr, col, start, walk_length)

    node_idx = out.flatten().unique()

    return src.permute(node_idx), node_idx


# Expose the samplers as SparseTensor methods.
SparseTensor.sample_node = sample_node
SparseTensor.sample_edge = sample_edge
SparseTensor.sample_rw = sample_rw