test_randomwalk.py 1.97 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import dgl
from dgl import utils
import backend as F
import numpy as np

def test_random_walk():
    """Check shape, starting nodes and step validity of sampled random walks.

    The graph is a bidirectional path 0 - 1 - 2 - 3 - 4, so every hop of a
    valid walk must move to a node whose ID differs from the current one by
    exactly 1.
    """
    edge_list = [(0, 1), (1, 2), (2, 3), (3, 4),
                 (4, 3), (3, 2), (2, 1), (1, 0)]
    seeds = [0, 1]
    n_traces = 3
    n_hops = 4

    g = dgl.DGLGraph(edge_list, readonly=True)
    traces = dgl.contrib.sampling.random_walk(g, seeds, n_traces, n_hops)
    traces = F.zerocopy_to_numpy(traces)

    # One row per seed, one column per trace, n_hops + 1 nodes per trace
    # (the starting node plus one node per hop).
    assert traces.shape == (len(seeds), n_traces, n_hops + 1)

    # Every trace must begin at the seed it was generated for.
    for i, seed in enumerate(seeds):
        assert (traces[i, :, 0] == seed).all()

    trace_diff = np.diff(traces, axis=-1)
    # only nodes with adjacent IDs are connected
    assert (np.abs(trace_diff) == 1).all()

def test_random_walk_with_restart():
    """Exercise the restart-based random walk samplers on a 5-node path graph.

    Three scenarios are covered: plain random walk with restart, the same
    sampler with its early-stopping arguments supplied, and the bipartite
    single-sided variant.
    """
    path_edges = [
        (0, 1), (1, 2), (2, 3), (3, 4),
        (4, 3), (3, 2), (2, 1), (1, 0),
    ]
    seeds = [0, 1]
    max_nodes = 10
    sampling = dgl.contrib.sampling

    g = dgl.DGLGraph(path_edges)

    def node_id_steps(trace):
        # Differences between consecutive node IDs along a single trace.
        return np.diff(F.zerocopy_to_numpy(trace), axis=-1)

    # test normal RWR
    per_seed = sampling.random_walk_with_restart(g, seeds, 0.2, max_nodes)
    assert len(per_seed) == len(seeds)
    for seed_traces in per_seed:
        visited = 0
        for trace in seed_traces:
            visited += len(trace)
            # Only nodes with adjacent IDs are connected in this graph.
            assert (np.abs(node_id_steps(trace)) == 1).all()
        # The traces of one seed together cover at least max_nodes entries.
        assert visited >= max_nodes

    # test RWR with early stopping
    per_seed = sampling.random_walk_with_restart(
        g, seeds, 1, 100, max_nodes, 1)
    assert len(per_seed) == len(seeds)
    for seed_traces in per_seed:
        # Sampling should have stopped before reaching the 100-node budget.
        assert sum(map(len, seed_traces)) < 100

    # test bipartite RWR
    per_seed = sampling.bipartite_single_sided_random_walk_with_restart(
        g, seeds, 0.2, max_nodes)
    assert len(per_seed) == len(seeds)
    for seed_traces in per_seed:
        for trace in seed_traces:
            # Consecutive recorded nodes must differ by an even amount
            # (presumably each recorded step spans two edges - confirm
            # against the sampler's documentation).
            assert (node_id_steps(trace) % 2 == 0).all()

if __name__ == '__main__':
    # Run every test in this module when the file is executed as a script.
    test_random_walk()
    test_random_walk_with_restart()