import warnings

import torch

# The compiled CUDA extension is only imported when a CUDA device is
# available; ``random_walk`` only calls into it for CUDA input tensors.
if torch.cuda.is_available():
    import torch_cluster.rw_cuda
def random_walk(row, col, start, walk_length, p=1, q=1, num_nodes=None):
    """Sample random walks beginning at the nodes in ``start``.

    Args:
        row: Index tensor, paired element-wise with ``col`` (presumably the
            source node of each edge -- TODO confirm expected ordering).
        col: Index tensor paired element-wise with ``row`` (presumably the
            target node of each edge).
        start: Tensor of node indices from which the walks begin.
        walk_length: Walk length, forwarded unchanged to the CUDA kernel.
        p, q: Walk bias parameters (presumably node2vec-style). Not supported
            yet: any value other than 1 emits a warning and both are reset
            to 1.
        num_nodes: Total number of nodes. When ``None`` it is inferred as one
            more than the largest index found in ``row``, ``col`` and
            ``start``.

    Returns:
        The result of ``torch_cluster.rw_cuda.rw`` for these inputs.

    Raises:
        NotImplementedError: If ``row`` is not a CUDA tensor; only a CUDA
            implementation exists.
    """
    if p != 1 or q != 1:
        # Note the trailing space after "will": the two literals are
        # concatenated, so without it the message would read "willbe".
        warnings.warn('Parameters `p` and `q` are not supported yet and will '
                      'be restored to their default values `p=1` and `q=1`.',
                      stacklevel=2)
        p = q = 1

    if num_nodes is None:
        # Look at every index tensor, not only ``row``: a node that appears
        # only as an edge target or as a start node would otherwise fall
        # outside ``[0, num_nodes)``. Empty tensors are skipped because
        # ``Tensor.max()`` raises on empty input.
        num_nodes = max((t.max().item() for t in (row, col, start)
                         if t.numel() > 0), default=-1) + 1

    if not row.is_cuda:
        raise NotImplementedError(
            'random_walk is currently only implemented for CUDA tensors.')

    return torch_cluster.rw_cuda.rw(row, col, start, walk_length, p, q,
                                    num_nodes)