rw.py 1.92 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
from typing import Optional
rusty1s's avatar
rusty1s committed
2

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
from torch import Tensor
rusty1s's avatar
rusty1s committed
5
6


rusty1s's avatar
rusty1s committed
7
@torch.jit.script
def random_walk(row: Tensor, col: Tensor, start: Tensor, walk_length: int,
                p: float = 1, q: float = 1, coalesced: bool = True,
                num_nodes: Optional[int] = None) -> Tensor:
    """Samples random walks of length :obj:`walk_length` from all node indices
    in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper.
    Edge indices :obj:`(row, col)` need to be coalesced/sorted according to
    :obj:`row` (use the :obj:`coalesced` argument to force).

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        start (LongTensor): Nodes from where random walks start.
        walk_length (int): The walk length.
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy (default: :obj:`1`)
        coalesced (bool, optional): If set to :obj:`True`, will sort the graph
            given by :obj:`(row, col)` according to :obj:`row` (and then
            :obj:`col`). Duplicate edges are not removed.
            (default: :obj:`True`)
        num_nodes (int, optional): The number of nodes. If not given, it is
            inferred as one plus the largest node index found in :obj:`row`,
            :obj:`col` or :obj:`start`. (default: :obj:`None`)

    :rtype: :class:`LongTensor`
    """
    if num_nodes is None:
        # Also look at `start`: a start node with no edges and an index above
        # every edge endpoint would otherwise have no entry in `rowptr`.
        max_index = max(int(row.max()), int(col.max()))
        if start.numel() > 0:
            max_index = max(max_index, int(start.max()))
        num_nodes = max_index + 1

    if coalesced:
        # Sort edges by (row, col) through the combined key
        # `row * num_nodes + col`. This is a pure sort; duplicates are kept.
        perm = torch.argsort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    # Build the CSR row pointer from per-node out-degrees:
    # rowptr[i]:rowptr[i + 1] indexes the neighbours of node i in `col`.
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    # NOTE(review): the extension op presumably returns a tuple; only its
    # first element (the sampled walks) is exposed here.
    return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
                                               p, q)[0]