rw.py 2.1 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
import warnings

rusty1s's avatar
rusty1s committed
3
import torch
rusty1s's avatar
rusty1s committed
4
import torch_cluster.rw_cpu
rusty1s's avatar
rusty1s committed
5
6
7
8
9

if torch.cuda.is_available():
    import torch_cluster.rw_cuda


rusty1s's avatar
rusty1s committed
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def random_walk(row, col, start, walk_length, p=1, q=1, coalesced=False,
                num_nodes=None):
    """Samples random walks of length :obj:`walk_length` from all node indices
    in :obj:`start` in the graph given by :obj:`(row, col)` as described in the
    `"node2vec: Scalable Feature Learning for Networks"
    <https://arxiv.org/abs/1607.00653>`_ paper.
    Edge indices :obj:`(row, col)` need to be sorted according to :obj:`row`
    (use the :obj:`coalesced` argument to force).

    Args:
        row (LongTensor): Source nodes.
        col (LongTensor): Target nodes.
        start (LongTensor): Nodes from where random walks start. The tensor is
            flattened before sampling, so any shape is accepted.
        walk_length (int): The walk length.
        p (float, optional): Likelihood of immediately revisiting a node in the
            walk. Not supported yet: any value other than :obj:`1` emits a
            warning and is reset to :obj:`1`. (default: :obj:`1`)
        q (float, optional): Control parameter to interpolate between
            breadth-first strategy and depth-first strategy. Not supported
            yet; handled the same way as :obj:`p`. (default: :obj:`1`)
        coalesced (bool, optional): If set to :obj:`True`, will sort the edges
            given by :obj:`(row, col)` by :obj:`row` (ties broken by
            :obj:`col`) before sampling. Duplicate edges are not removed.
            (default: :obj:`False`)
        num_nodes (int, optional): The number of nodes. If :obj:`None`, it is
            inferred as the largest index in :obj:`row` and :obj:`col` plus
            one. (default: :obj:`None`)

    :rtype: :class:`LongTensor`
    """
    if num_nodes is None:
        num_nodes = max(row.max(), col.max()).item() + 1

    if coalesced:
        # The combined key ``row * num_nodes + col`` orders edges by source
        # node first and by target node within each source node.
        _, perm = torch.sort(row * num_nodes + col)
        row, col = row[perm], col[perm]

    if p != 1 or q != 1:  # pragma: no cover
        # The backends are called with p = q = 1 until biased sampling is
        # supported; warn so callers know their values were ignored.
        warnings.warn('Parameters `p` and `q` are not supported yet and will '
                      'be restored to their default values `p=1` and `q=1`.')
        p = q = 1

    start = start.flatten()

    # Dispatch to the native backend matching the device of the edge indices.
    if row.is_cuda:  # pragma: no cover
        return torch_cluster.rw_cuda.rw(row, col, start, walk_length, p, q,
                                        num_nodes)
    else:
        return torch_cluster.rw_cpu.rw(row, col, start, walk_length, p, q,
                                       num_nodes)