from typing import Optional, Tuple, Union

import torch
from torch import Tensor


@torch.jit.script
def random_walk(
    row: Tensor,
    col: Tensor,
    start: Tensor,
    walk_length: int,
    p: float = 1,
    q: float = 1,
    coalesced: bool = True,
    num_nodes: Optional[int] = None,
    return_edge_indices: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Samples random walks of length :obj:`walk_length` from all node indices
    in :obj:`start` in the graph given by :obj:`(row, col)` as described in
    the `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper.

    The edges :obj:`(row, col)` must be sorted by :obj:`row` before the
    row-pointer array is built. Leave :obj:`coalesced` set to :obj:`True` to
    let this function sort them (by :obj:`row`, then :obj:`col`); the sort
    only reorders edges, it does not remove duplicates.

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        start (LongTensor): Nodes from where random walks start.
        walk_length (int): The walk length.
        p (float, optional): Likelihood of immediately revisiting a node in
            the walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy.
            (default: :obj:`1`)
        coalesced (bool, optional): If set to :obj:`True`, will sort the
            graph given by :obj:`(row, col)` according to :obj:`row`.
            Set to :obj:`False` only if the edges are already sorted.
            (default: :obj:`True`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it
            is inferred as one plus the largest index found in :obj:`row`,
            :obj:`col` and :obj:`start` (empty tensors are ignored).
            (default: :obj:`None`)
        return_edge_indices (bool, optional): Whether to additionally return
            the indices of edges traversed during the random walk.
            (default: :obj:`False`)

    :rtype: :class:`LongTensor` if :obj:`return_edge_indices` is
        :obj:`False`, otherwise a tuple
        (:class:`LongTensor`, :class:`LongTensor`) of node and edge sequences.

    .. note::
        NOTE(review): with :obj:`coalesced=True` the edges are reordered
        before the sampling op runs and the edge sequence is returned
        unchanged, so the edge indices presumably refer to the *sorted* edge
        order rather than the caller's original order -- confirm against the
        op's contract before relying on them with unsorted input.
    """
    if num_nodes is None:
        # `Tensor.max()` raises on an empty tensor, so only reduce over the
        # tensors that actually hold indices. An edge-less graph or an empty
        # `start` then infers the same value the caller could have passed
        # explicitly instead of crashing here.
        max_index = -1
        if row.numel() > 0:
            max_index = max(max_index, int(row.max()))
        if col.numel() > 0:
            max_index = max(max_index, int(col.max()))
        if start.numel() > 0:
            max_index = max(max_index, int(start.max()))
        num_nodes = max_index + 1

    if coalesced:
        # Sort edges lexicographically by (row, col) using a single
        # combined integer key per edge.
        perm = torch.argsort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    # Out-degree of every node: add one per edge at its source index.
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))

    # Row pointers: rowptr[0] == 0 and rowptr[i + 1] is the cumulative
    # degree up to and including node i.
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
        rowptr,
        col,
        start,
        walk_length,
        p,
        q,
    )

    if return_edge_indices:
        return node_seq, edge_seq

    return node_seq