rw.py 2.13 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2
from typing import Optional
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
import torch


rusty1s's avatar
rusty1s committed
7
8
9
10
@torch.jit.script
def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
                walk_length: int, p: float = 1, q: float = 1,
                coalesced: bool = False, num_nodes: Optional[int] = None):
    """Samples random walks of length :obj:`walk_length` from all node indices
    in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper.
    Edge indices :obj:`(row, col)` need to be coalesced/sorted according to
    :obj:`row` (use the :obj:`coalesced` argument to force).

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        start (LongTensor): Nodes from where random walks start.
        walk_length (int): The walk length.
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. Currently not supported: any value other than :obj:`1`
            triggers a warning and is reset to :obj:`1`. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy. Currently not
            supported: any value other than :obj:`1` triggers a warning and is
            reset to :obj:`1`. (default: :obj:`1`)
        coalesced (bool, optional): If set to :obj:`True`, will coalesce/sort
            the graph given by :obj:`(row, col)` according to :obj:`row`.
            (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as the largest index in :obj:`row`/:obj:`col` plus one.
            (default: :obj:`None`)

    :rtype: :class:`LongTensor`
    """
    if num_nodes is None:
        # NOTE(review): row.max() raises on an empty edge list; callers with
        # no edges should pass num_nodes explicitly.
        num_nodes = max(int(row.max()), int(col.max())) + 1

    if coalesced:
        # Sort edges lexicographically by (row, col) via a single linear key.
        _, perm = torch.sort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    # Build the CSR row pointer: rowptr[i]..rowptr[i + 1] spans the (sorted)
    # edges whose source node is i.
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))
    rowptr = row.new_zeros(num_nodes + 1)
    # Use the function form: the `out=` overload of cumsum is exposed on
    # torch.cumsum, not on the Tensor.cumsum method in eager mode.
    torch.cumsum(deg, 0, out=rowptr[1:])

    if p != 1. or q != 1.:  # pragma: no cover
        warnings.warn('Parameters `p` and `q` are not supported yet and will '
                      'be restored to their default values `p=1` and `q=1`.')
        p = q = 1.

    return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
                                               p, q)