rw.py 2.14 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2
from typing import Optional
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
import torch


rusty1s's avatar
rusty1s committed
7
8
9
10
@torch.jit.script
def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
                walk_length: int, p: float = 1, q: float = 1,
                coalesced: bool = False, num_nodes: Optional[int] = None):
    """Samples random walks of length :obj:`walk_length` from all node indices
    in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper.
    Edge indices :obj:`(row, col)` need to be coalesced/sorted according to
    :obj:`row` (use the :obj:`coalesced` argument to force this).

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        start (LongTensor): Nodes from where random walks start.
        walk_length (int): The walk length.
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. Currently unsupported: any value other than :obj:`1` is
            reset to :obj:`1` and a warning is emitted. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy. Currently
            unsupported in the same way as :obj:`p`. (default: :obj:`1`)
        coalesced (bool, optional): If set to :obj:`True`, will coalesce/sort
            the graph given by :obj:`(row, col)` according to :obj:`row`.
            (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes. If not given, it is
            inferred as one more than the largest index found in :obj:`row`,
            :obj:`col` and :obj:`start`. (default: :obj:`None`)

    :rtype: :class:`LongTensor`
    """
    if num_nodes is None:
        # Take ``start`` into account as well: a walk may begin at an isolated
        # node whose index exceeds every edge endpoint, and it still needs a
        # valid ``rowptr`` slot. ``max()`` raises on empty tensors, so each
        # input is only consulted when it is non-empty.
        inferred = 0
        if row.numel() > 0:
            inferred = max(inferred, int(row.max()) + 1)
        if col.numel() > 0:
            inferred = max(inferred, int(col.max()) + 1)
        if start.numel() > 0:
            inferred = max(inferred, int(start.max()) + 1)
        num_nodes = inferred

    if coalesced:
        # Sort edges lexicographically by (row, col) via a single flat key.
        perm = torch.argsort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    # Build a CSR row pointer: rowptr[i]:rowptr[i + 1] indexes the edges whose
    # source is node i (valid because edges are sorted by ``row``).
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    if p != 1. or q != 1.:  # pragma: no cover
        warnings.warn('Parameters `p` and `q` are not supported yet and will '
                      'be restored to their default values `p=1` and `q=1`.')
        p = 1.
        q = 1.

    return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
                                               p, q)