test_sampler.py 7.32 KB
Newer Older
1
import backend as F
Da Zheng's avatar
Da Zheng committed
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
import scipy as sp
import dgl
from dgl import utils

def generate_rand_graph(n):
    """Build a read-only random ``dgl.DGLGraph`` from an ``n x n`` matrix.

    A sparse random matrix with density 0.1 is drawn in COO format; its
    non-zero pattern is converted to an int64 0/1 matrix and passed to
    ``dgl.DGLGraph`` with ``readonly=True``.
    """
    nonzero_pattern = sp.sparse.random(n, n, density=0.1, format='coo') != 0
    adjacency = nonzero_pattern.astype(np.int64)
    return dgl.DGLGraph(adjacency, readonly=True)

def test_1neighbor_sampler_all():
    """Check single-seed sampling when the sampling budget covers the graph.

    The sampler is run with batch size 1 and the value 100 as its third
    argument on a 100-node graph, so every subgraph is expected to hold one
    seed together with all of its in-edges from the parent graph.
    """
    g = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
                                                     num_workers=4):
        seed_ids = subg.layer_parent_nid(-1)
        assert len(seed_ids) == 1
        src, _, _ = g.in_edges(seed_ids, form='all')
        assert subg.number_of_nodes() == len(src) + 1
        assert subg.number_of_edges() == len(src)

        # The last layer, mapped back to parent IDs, must be exactly the seed.
        # (Comparing seed_ids with itself via `==` would be a tautology.)
        assert F.array_equal(seed_ids, subg.map_to_parent_nid(subg.layer_nid(-1)))
        child_src, _, _ = subg.in_edges(subg.layer_nid(-1), form='all')
        assert F.array_equal(child_src, subg.layer_nid(0))

        # The sampled sources must be the same as the parent's in-neighbors.
        src1 = subg.map_to_parent_nid(child_src)
        assert F.array_equal(src1, src)
Da Zheng's avatar
Da Zheng committed
28
29

def is_sorted(arr):
    """Return True when the 1-D sequence ``arr`` is in non-decreasing order.

    Empty and single-element inputs count as sorted.
    """
    arr = np.asarray(arr)
    # Comparing each element with its successor avoids sorting a copy.
    return bool(np.all(arr[:-1] <= arr[1:]))
Da Zheng's avatar
Da Zheng committed
31
32

def verify_subgraph(g, subg, seed_id):
    """Validate the sampled in-neighbors of one seed against the parent graph.

    Parameters
    ----------
    g
        The parent graph the subgraph was sampled from.
    subg
        The sampled subgraph yielded by the sampler.
    seed_id
        Backend tensor with the parent ID of the seed to check.

    Raises
    ------
    AssertionError
        If the seed is not in the last layer, or if its sampled neighbor
        list has duplicates, is unsorted, or contains a node that is not an
        in-neighbor of the seed in ``g``.
    """
    seed_id = F.asnumpy(seed_id)
    seeds = F.asnumpy(subg.map_to_parent_nid(subg.layer_nid(-1)))
    # Locate the seed inside the last layer; the mask is reused to pick the
    # matching subgraph-local node ID.
    is_seed = seeds == seed_id
    assert np.any(is_seed)
    child_seed = F.asnumpy(subg.layer_nid(-1))[is_seed]
    src, _, _ = g.in_edges(seed_id, form='all')
    child_src, _, _ = subg.in_edges(child_seed, form='all')

    child_src = F.asnumpy(child_src)
    # We don't allow duplicate elements in the neighbor list.
    assert len(np.unique(child_src)) == len(child_src)
    # The neighbor list also needs to be sorted.
    assert is_sorted(child_src)

    # A neighbor in the subgraph must also exist in the parent graph.
    parent_src = F.asnumpy(subg.map_to_parent_nid(child_src))
    assert np.all(np.isin(parent_src, F.asnumpy(src)))
Da Zheng's avatar
Da Zheng committed
49
50
51
52

def test_1neighbor_sampler():
    """Check single-seed sampling with a small sampling budget.

    With batch size 1 and a budget of 5, each subgraph may hold at most
    5 edges and 5 neighbors plus the seed itself.
    """
    g = generate_rand_graph(100)
    expand_factor = 5
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg in dgl.contrib.sampling.NeighborSampler(g, 1, expand_factor,
                                                     neighbor_type='in',
                                                     num_workers=4):
        seed_ids = subg.layer_parent_nid(-1)
        assert len(seed_ids) == 1
        # At most `expand_factor` sampled neighbors plus the seed.
        assert subg.number_of_nodes() <= expand_factor + 1
        assert subg.number_of_edges() <= expand_factor
        verify_subgraph(g, subg, seed_ids)

61
62
63
def test_prefetch_neighbor_sampler():
    """Check single-seed sampling with ``prefetch=True``.

    Runs the same assertions as the non-prefetching single-seed test: one
    seed per subgraph, at most 5 edges and at most 6 nodes.
    """
    g = generate_rand_graph(100)
    expand_factor = 5
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg in dgl.contrib.sampling.NeighborSampler(g, 1, expand_factor,
                                                     neighbor_type='in',
                                                     num_workers=4, prefetch=True):
        seed_ids = subg.layer_parent_nid(-1)
        assert len(seed_ids) == 1
        # At most `expand_factor` sampled neighbors plus the seed.
        assert subg.number_of_nodes() <= expand_factor + 1
        assert subg.number_of_edges() <= expand_factor
        verify_subgraph(g, subg, seed_ids)

Da Zheng's avatar
Da Zheng committed
72
73
74
def test_10neighbor_sampler_all():
    """Check batched sampling when the sampling budget covers the graph.

    The sampler is run with batch size 10 and the value 100 as its third
    argument on a 100-node graph; the sampled sources of each batch must
    equal the parent graph's in-edge sources for the same seeds.
    """
    g = generate_rand_graph(100)
    # Here each subgraph covers a batch of seeds rather than a single vertex.
    for subg in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
                                                     num_workers=4):
        seed_ids = subg.layer_parent_nid(-1)
        assert F.array_equal(seed_ids, subg.map_to_parent_nid(subg.layer_nid(-1)))

        src, _, _ = g.in_edges(seed_ids, form='all')
        child_src, _, _ = subg.in_edges(subg.layer_nid(-1), form='all')
        src1 = subg.map_to_parent_nid(child_src)
        assert F.array_equal(src1, src)
Da Zheng's avatar
Da Zheng committed
84
85
86

def check_10neighbor_sampler(g, seeds):
    """Run batched neighbor sampling on ``g`` and validate every subgraph.

    Parameters
    ----------
    g
        The parent graph to sample from.
    seeds
        Value forwarded as ``seed_nodes`` to the sampler; may be ``None``.
    """
    expand_factor = 5
    # Here each subgraph covers a batch of seeds rather than a single vertex.
    for subg in dgl.contrib.sampling.NeighborSampler(g, 10, expand_factor,
                                                     neighbor_type='in',
                                                     num_workers=4, seed_nodes=seeds):
        seed_ids = subg.layer_parent_nid(-1)
        # Per seed: at most `expand_factor` neighbors plus the seed itself.
        assert subg.number_of_nodes() <= (expand_factor + 1) * len(seed_ids)
        assert subg.number_of_edges() <= expand_factor * len(seed_ids)
        for seed_id in seed_ids:
            verify_subgraph(g, subg, seed_id)

def test_10neighbor_sampler():
    """Exercise batched sampling with and without an explicit seed list."""
    g = generate_rand_graph(100)

    # First pass: no explicit seed nodes are supplied.
    check_10neighbor_sampler(g, None)

    # Second pass: a de-duplicated random draw of node IDs, with a draw size
    # of one tenth of the node count.
    num_nodes = g.number_of_nodes()
    drawn = np.random.randint(0, num_nodes, size=int(num_nodes / 10))
    check_10neighbor_sampler(g, seeds=np.unique(drawn))

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
def test_layer_sampler(prefetch=False):
    """Check the structural invariants of subgraphs produced by LayerSampler.

    Parameters
    ----------
    prefetch : bool
        Forwarded to the sampler's ``prefetch`` argument.
    """
    g = generate_rand_graph(100)
    nid = g.nodes()
    src, dst, eid = g.all_edges(form='all', order='eid')
    n_batches = 5
    batch_size = 50
    seed_batches = [np.sort(np.random.choice(F.asnumpy(nid), batch_size, replace=False))
                    for _ in range(n_batches)]
    seed_nodes = np.hstack(seed_batches)
    layer_sizes = [50] * 3
    sampler = dgl.contrib.sampling.LayerSampler(
        g, batch_size, layer_sizes, 'in',
        seed_nodes=seed_nodes, num_workers=4, prefetch=prefetch)
    for sub_g in sampler:
        assert all(sub_g.layer_size(i) < size for i, size in enumerate(layer_sizes))
        # Every layer's node IDs must be valid subgraph-local IDs.
        sub_nid = F.arange(0, sub_g.number_of_nodes())
        assert all(np.all(np.isin(F.asnumpy(sub_g.layer_nid(i)), F.asnumpy(sub_nid)))
                   for i in range(sub_g.num_layers))
        # Node and edge IDs must map back to IDs that exist in the parent.
        assert np.all(np.isin(F.asnumpy(sub_g.map_to_parent_nid(sub_nid)),
                              F.asnumpy(nid)))
        sub_eid = F.arange(0, sub_g.number_of_edges())
        assert np.all(np.isin(F.asnumpy(sub_g.map_to_parent_eid(sub_eid)),
                              F.asnumpy(eid)))
        # The last layer must be one of the seed batches handed to the sampler.
        assert any(np.all(np.sort(F.asnumpy(sub_g.layer_parent_nid(-1))) == seed_batch)
                   for seed_batch in seed_batches)

        sub_src, sub_dst = sub_g.all_edges(order='eid')
        for i in range(sub_g.num_blocks):
            block_eid = sub_g.block_eid(i)
            block_src = sub_g.map_to_parent_nid(sub_src[block_eid])
            block_dst = sub_g.map_to_parent_nid(sub_dst[block_eid])

            block_parent_eid = sub_g.block_parent_eid(i)
            block_parent_src = src[block_parent_eid]
            block_parent_dst = dst[block_parent_eid]

            # Both endpoints of every block edge must agree with the parent
            # edge it maps to.
            assert np.all(F.asnumpy(block_src == block_parent_src))
            assert np.all(F.asnumpy(block_dst == block_parent_dst))

        # Layers partition the nodes and blocks partition the edges.
        n_layers = sub_g.num_layers
        sub_n = sub_g.number_of_nodes()
        assert sum(F.shape(sub_g.layer_nid(i))[0] for i in range(n_layers)) == sub_n
        n_blocks = sub_g.num_blocks
        sub_m = sub_g.number_of_edges()
        assert sum(F.shape(sub_g.block_eid(i))[0] for i in range(n_blocks)) == sub_m

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
def test_random_walk():
    """Check shape, start nodes and step size of random-walk traces.

    The graph is the path 0-1-2-3-4 with an edge in each direction, so
    every step of a walk must move to a node whose ID differs by one.
    """
    path_edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
    # Forward edges followed by the reversed edges, from the far end back.
    edge_list = path_edges + [(v, u) for u, v in reversed(path_edges)]
    seeds = [0, 1]
    n_traces, n_hops = 3, 4

    g = dgl.DGLGraph(edge_list, readonly=True)
    traces = F.zerocopy_to_numpy(
        dgl.contrib.sampling.random_walk(g, seeds, n_traces, n_hops))

    assert traces.shape == (len(seeds), n_traces, n_hops + 1)

    # Every trace of a seed starts at that seed.
    for seed_traces, seed in zip(traces, seeds):
        assert (seed_traces[:, 0] == seed).all()

    # only nodes with adjacent IDs are connected
    steps = np.diff(traces, axis=-1)
    assert (np.abs(steps) == 1).all()

Da Zheng's avatar
Da Zheng committed
166
167
168
169
170
if __name__ == '__main__':
    # Run every test in this module when executed as a script.
    test_1neighbor_sampler_all()
    test_10neighbor_sampler_all()
    test_1neighbor_sampler()
    test_prefetch_neighbor_sampler()
    test_10neighbor_sampler()
    test_layer_sampler()
    test_layer_sampler(prefetch=True)
    test_random_walk()