from typing import Optional, Tuple, Union

import torch
from torch import Tensor


@torch.jit.script
def random_walk(
    row: Tensor,
    col: Tensor,
    start: Tensor,
    walk_length: int,
    p: float = 1,
    q: float = 1,
    coalesced: bool = True,
    num_nodes: Optional[int] = None,
    return_edge_indices: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    """Samples random walks of length :obj:`walk_length` from all node indices
    in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper.
    Edge indices :obj:`(row, col)` need to be coalesced/sorted according to
    :obj:`row` (use the :obj:`coalesced` attribute to force).

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        start (LongTensor): Nodes from where random walks start.
        walk_length (int): The walk length.
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy (default: :obj:`1`)
        coalesced (bool, optional): If set to :obj:`True`, will coalesce/sort
            the graph given by :obj:`(row, col)` according to :obj:`row`
            (ties broken by :obj:`col`) before sampling. If set to
            :obj:`False`, the edges are used in the order given.
            (default: :obj:`True`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as one plus the largest index found in :obj:`row`,
            :obj:`col` and :obj:`start`. (default: :obj:`None`)
        return_edge_indices (bool, optional): Whether to additionally return
            the indices of edges traversed during the random walk.
            (default: :obj:`False`)

    Returns:
        The sampled node sequences if :obj:`return_edge_indices` is
        :obj:`False`, otherwise the tuple :obj:`(node_seq, edge_seq)` as
        produced by the underlying :obj:`torch_cluster.random_walk` operator.

    .. note::
        With :obj:`coalesced=True` the operator receives the *sorted* edge
        list and the sorting permutation is not returned, so
        :obj:`edge_seq` presumably indexes into the sorted edge order rather
        than the order of the :obj:`(row, col)` passed in -- TODO confirm
        against the operator implementation.

    :rtype: :class:`LongTensor`, or a tuple of two tensors if
        :obj:`return_edge_indices` is :obj:`True`
    """
    if num_nodes is None:
        # Infer the node count from the largest index that appears anywhere.
        num_nodes = max(int(row.max()), int(col.max()), int(start.max())) + 1

    if coalesced:
        # Linearise each (row, col) pair into a single key so one argsort
        # orders the edges by source node first and target node second.
        perm = torch.argsort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    # Count the outgoing edges of every node ...
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))
    # ... and turn the counts into a CSR-style row pointer: entry 0 stays 0
    # and entry i + 1 holds the cumulative edge count up to and including
    # node i (written in place through the `rowptr[1:]` view).
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
        rowptr, col, start, walk_length, p, q,
    )

    if return_edge_indices:
        return node_seq, edge_seq

    return node_seq