rw.py 2.14 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
import warnings
rusty1s's avatar
rusty1s committed
2
from typing import Optional
rusty1s's avatar
rusty1s committed
3

rusty1s's avatar
rusty1s committed
4
5
6
import torch


rusty1s's avatar
rusty1s committed
7
8
9
@torch.jit.script
def random_walk(row: torch.Tensor, col: torch.Tensor, start: torch.Tensor,
                walk_length: int, p: float = 1, q: float = 1,
                coalesced: bool = True, num_nodes: Optional[int] = None):
    """Samples random walks of length :obj:`walk_length` from all node indices
    in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper.
    Edge indices :obj:`(row, col)` need to be coalesced/sorted according to
    :obj:`row` (use the :obj:`coalesced` argument to force this).

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        start (LongTensor): Nodes from where random walks start.
        walk_length (int): The walk length.
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. Currently not supported: any value other than :obj:`1`
            emits a warning and is reset to :obj:`1`. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy. Currently not
            supported: any value other than :obj:`1` emits a warning and is
            reset to :obj:`1`. (default: :obj:`1`)
        coalesced (bool, optional): If set to :obj:`True`, will coalesce/sort
            the graph given by :obj:`(row, col)` according to :obj:`row`.
            (default: :obj:`True`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as one plus the largest index found in :obj:`row`,
            :obj:`col` and :obj:`start`. (default: :obj:`None`)

    :rtype: :class:`LongTensor`
    """
    if num_nodes is None:
        # Include `start` so that start nodes without any incident edges
        # still get a valid (empty) slot in `rowptr`. Each `max()` is guarded
        # because reducing an empty tensor raises.
        n = 0
        if row.numel() > 0:
            n = max(int(row.max()), int(col.max())) + 1
        if start.numel() > 0:
            n = max(n, int(start.max()) + 1)
        num_nodes = n

    if coalesced:
        # Sort edges by (row, col) using a single flattened integer key.
        perm = torch.argsort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    # Build a CSR row pointer from per-node out-degrees:
    # rowptr[i]:rowptr[i + 1] spans node i's neighbours in the row-sorted `col`.
    deg = row.new_zeros(num_nodes)
    deg.scatter_add_(0, row, torch.ones_like(row))
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    if p != 1. or q != 1.:  # pragma: no cover
        # Biased (node2vec) walks are not implemented yet; fall back to
        # uniform walks.
        warnings.warn('Parameters `p` and `q` are not supported yet and will '
                      'be restored to their default values `p=1` and `q=1`.')
        p = q = 1.

    return torch.ops.torch_cluster.random_walk(rowptr, col, start, walk_length,
                                               p, q)