# test_sampler.py: tests for dgl.contrib.sampling.NeighborSampler.
import backend as F
# Numerical libraries and the DGL package under test.
import numpy as np
import scipy as sp
import dgl
from dgl import utils

def generate_rand_graph(n):
    """Build a read-only random DGLGraph with ``n`` nodes.

    A sparse ``n x n`` matrix with density 0.1 is drawn at random, its
    non-zero pattern is turned into a 0/1 ``int64`` matrix, and that matrix is
    passed to ``dgl.DGLGraph`` with ``readonly=True``.

    Parameters
    ----------
    n : int
        Number of rows and columns of the random matrix.

    Returns
    -------
    dgl.DGLGraph
        The graph constructed from the random matrix.
    """
    # ``import scipy as sp`` alone does not guarantee that the ``sparse``
    # subpackage has been loaded; import it explicitly so this helper does not
    # depend on some other module having imported it first.
    import scipy.sparse

    pattern = scipy.sparse.random(n, n, density=0.1, format='coo') != 0
    return dgl.DGLGraph(pattern.astype(np.int64), readonly=True)

def test_1neighbor_sampler_all():
    """Check single-seed sampling when all in-neighbors should be kept.

    The sampler is run on a 100-node random graph with positional arguments
    ``1`` and ``100`` (presumably the batch size and the per-node neighbor
    cap -- confirm against the NeighborSampler signature).  For every sampled
    subgraph the test asserts that it contains the seed plus exactly the
    in-neighbors the seed has in the parent graph.
    """
    g = generate_rand_graph(100)
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, 1, 100, neighbor_type='in', num_workers=4, return_seed_id=True)
    for subg, aux in sampler:
        seed_ids = aux['seeds']
        assert len(seed_ids) == 1

        # Only the source endpoints of the seed's in-edges are needed.
        src, _, _ = g.in_edges(seed_ids, form='all')
        assert subg.number_of_nodes() == len(src) + 1
        assert subg.number_of_edges() == len(src)

        # Compare tensors with the backend helper, as the 10-seed variant of
        # this test does, instead of relying on the truthiness of ``==``.
        assert F.array_equal(seed_ids, subg.layer_parent_nid(-1))
        child_src, _, _ = subg.in_edges(subg.layer_nid(-1), form='all')
        assert F.array_equal(child_src, subg.layer_nid(0))

        # Mapped back to parent ids, the sampled sources must equal the
        # parent graph's in-neighbors of the seed.
        src1 = subg.map_to_parent_nid(child_src)
        assert F.array_equal(src1, src)

def is_sorted(arr):
    """Return whether the 1-D sequence ``arr`` is in non-decreasing order.

    Parameters
    ----------
    arr : array-like
        One-dimensional sequence of comparable values.

    Returns
    -------
    numpy.bool_
        True when every element is <= its successor.  Empty and
        single-element inputs count as sorted.
    """
    # Convert first so that plain lists are compared element-wise rather than
    # lexicographically.
    arr = np.asarray(arr)
    # Comparing adjacent pairs is O(n); no sorted copy is needed.
    return np.all(arr[:-1] <= arr[1:])

def verify_subgraph(g, subg, seed_id):
    """Validate the in-neighbors sampled for ``seed_id`` in ``subg``.

    Checks that the seed appears in the last layer of ``subg``, that the
    sampled neighbor list of the seed has no duplicates and is sorted, and
    that every sampled neighbor is also an in-neighbor of the seed in ``g``.

    Parameters
    ----------
    g : parent graph that ``subg`` was sampled from.
    subg : sampled subgraph exposing ``layer_nid``, ``map_to_parent_nid`` and
        ``in_edges``.
    seed_id : backend tensor holding the parent id of the seed to check.

    Raises
    ------
    AssertionError
        If any of the checks above fails.
    """
    seed_id = F.asnumpy(seed_id)
    last_layer = subg.layer_nid(-1)
    seeds = F.asnumpy(subg.map_to_parent_nid(last_layer))
    assert seed_id in seeds

    # Select the subgraph-local id(s) whose parent id is the seed.
    child_seed = F.asnumpy(last_layer)[seeds == seed_id]
    src, _, _ = g.in_edges(seed_id, form='all')
    child_src, _, _ = subg.in_edges(child_seed, form='all')

    child_src = F.asnumpy(child_src)
    # We don't allow duplicate elements in the neighbor list.
    assert len(np.unique(child_src)) == len(child_src)
    # The neighbor list also needs to be sorted.
    assert is_sorted(child_src)

    # A neighbor in the subgraph must also exist in the parent graph.  Both
    # sides are converted to numpy so the membership test does not depend on
    # how the active backend implements ``in`` for its tensors.
    parent_src = F.asnumpy(subg.map_to_parent_nid(child_src))
    assert np.isin(parent_src, F.asnumpy(src)).all()

def test_1neighbor_sampler():
    """Check single-seed sampling with at most 5 sampled neighbors.

    The sampler is run on a 100-node random graph with positional arguments
    ``1`` and ``5`` (presumably the batch size and the per-node neighbor
    cap -- confirm against the NeighborSampler signature).  Every subgraph
    must have one seed, at most 6 nodes and at most 5 edges, and must pass
    ``verify_subgraph``.
    """
    g = generate_rand_graph(100)
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, 1, 5, neighbor_type='in', num_workers=4, return_seed_id=True)
    for subg, aux in sampler:
        seed_ids = aux['seeds']
        assert len(seed_ids) == 1
        # One seed plus at most five sampled in-neighbors.
        assert subg.number_of_nodes() <= 6
        assert subg.number_of_edges() <= 5
        verify_subgraph(g, subg, seed_ids)


def test_prefetch_neighbor_sampler():
    """Check single-seed sampling with ``prefetch=True``.

    Runs the sampler with the same arguments as the non-prefetching
    single-seed test, plus ``prefetch=True``, and applies the same checks:
    one seed, at most 6 nodes, at most 5 edges, and ``verify_subgraph``.
    """
    g = generate_rand_graph(100)
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, 1, 5, neighbor_type='in', num_workers=4, return_seed_id=True,
        prefetch=True)
    for subg, aux in sampler:
        seed_ids = aux['seeds']
        assert len(seed_ids) == 1
        # One seed plus at most five sampled in-neighbors.
        assert subg.number_of_nodes() <= 6
        assert subg.number_of_edges() <= 5
        verify_subgraph(g, subg, seed_ids)


def test_10neighbor_sampler_all():
    """Check multi-seed sampling when all in-neighbors should be kept.

    The sampler is run on a 100-node random graph with positional arguments
    ``10`` and ``100`` (presumably the batch size and the per-node neighbor
    cap -- confirm against the NeighborSampler signature).  For every
    subgraph, the last layer must map back to the reported seeds and the
    sampled sources must equal the seeds' in-neighbors in the parent graph.
    """
    g = generate_rand_graph(100)
    # Each sampled subgraph covers a batch of seeds rather than a single
    # vertex, so the checks below compare whole id arrays.
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, 10, 100, neighbor_type='in', num_workers=4, return_seed_id=True)
    for subg, aux in sampler:
        seed_ids = aux['seeds']
        last_layer = subg.layer_nid(-1)
        assert F.array_equal(seed_ids, subg.map_to_parent_nid(last_layer))

        # Only the source endpoints of the in-edges are compared.
        src, _, _ = g.in_edges(seed_ids, form='all')
        child_src, _, _ = subg.in_edges(subg.layer_nid(-1), form='all')
        src1 = subg.map_to_parent_nid(child_src)
        assert F.array_equal(src1, src)

def check_10neighbor_sampler(g, seeds):
    """Check multi-seed sampling with at most 5 sampled neighbors per seed.

    Parameters
    ----------
    g : graph to sample from.
    seeds : value forwarded as ``seed_nodes`` to the sampler; the caller in
        this file passes either ``None`` or an array of node ids.

    Raises
    ------
    AssertionError
        If a subgraph exceeds 6 nodes or 5 edges per seed, or if
        ``verify_subgraph`` fails for any seed.
    """
    # Each sampled subgraph covers a batch of seeds, so the size bounds scale
    # with the number of seeds in the batch.
    sampler = dgl.contrib.sampling.NeighborSampler(
        g, 10, 5, neighbor_type='in', num_workers=4, seed_nodes=seeds,
        return_seed_id=True)
    for subg, aux in sampler:
        seed_ids = aux['seeds']
        num_seeds = len(seed_ids)
        assert subg.number_of_nodes() <= 6 * num_seeds
        assert subg.number_of_edges() <= 5 * num_seeds
        for seed_id in seed_ids:
            verify_subgraph(g, subg, seed_id)

def test_10neighbor_sampler():
    """Run ``check_10neighbor_sampler`` without and with explicit seeds.

    The second run uses the unique values of a random draw of node ids whose
    size is one tenth of the node count.
    """
    graph = generate_rand_graph(100)

    # First pass: no explicit seed nodes.
    check_10neighbor_sampler(graph, None)

    # Second pass: a random, de-duplicated subset of node ids as seeds.
    num_nodes = graph.number_of_nodes()
    drawn = np.random.randint(0, num_nodes, size=num_nodes // 10)
    check_10neighbor_sampler(graph, seeds=np.unique(drawn))

if __name__ == '__main__':
    # Run every test defined in this file when it is executed as a script,
    # including the prefetch variant.
    test_1neighbor_sampler_all()
    test_10neighbor_sampler_all()
    test_1neighbor_sampler()
    test_prefetch_neighbor_sampler()
    test_10neighbor_sampler()