test_sampler.py 4.67 KB
Newer Older
1
import backend as F
Da Zheng's avatar
Da Zheng committed
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
import scipy as sp
import dgl
from dgl import utils

def generate_rand_graph(n):
    """Build a read-only DGLGraph from a random ``n`` x ``n`` sparse matrix.

    A random sparse matrix with density 0.1 is generated, its non-zero
    pattern is converted to an int64 0/1 matrix, and the result is passed to
    ``dgl.DGLGraph`` with ``readonly=True``.

    Parameters
    ----------
    n : int
        Number of rows and columns of the random matrix.

    Returns
    -------
    dgl.DGLGraph
        The graph constructed from the random matrix.
    """
    # ``import scipy as sp`` alone does not guarantee that the ``sparse``
    # subpackage is loaded, so import it explicitly before using it.
    from scipy import sparse

    arr = (sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
    return dgl.DGLGraph(arr, readonly=True)

def test_1neighbor_sampler_all():
    """Check single-seed sampling when all in-neighbors can be sampled.

    The sampler is created with positional arguments ``1`` and ``100`` on a
    graph built by ``generate_rand_graph(100)``.
    NOTE(review): these are presumably the batch size and the expand factor
    of ``NeighborSampler`` -- confirm against the DGL API.

    For every sampled subgraph the test checks that:
    * exactly one seed is reported,
    * the node count equals the number of in-edges of the seed, plus one for
      the seed itself unless the seed has a self loop,
    * the in-neighbors of the seed inside the subgraph match the parent
      graph's in-neighbors after mapping them to subgraph node ids.
    """
    g = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
                                                          num_workers=4, return_seed_id=True):
        seed_ids = aux['seeds']
        assert len(seed_ids) == 1
        src, dst, _ = g.in_edges(seed_ids, form='all')
        # Test if there is a self loop
        self_loop = F.asnumpy(F.sum(src == dst, 0)) == 1
        if self_loop:
            # The seed is one of its own in-neighbors, so it is already
            # counted in ``src``.
            assert subg.number_of_nodes() == len(src)
        else:
            # All in-neighbors plus the seed vertex itself.
            assert subg.number_of_nodes() == len(src) + 1
        assert subg.number_of_edges() >= len(src)

        child_ids = subg.map_to_subgraph_nid(seed_ids)
        child_src, _, _ = subg.in_edges(child_ids, form='all')

        # Map the parent's in-neighbors to subgraph ids and require an
        # element-wise match with the subgraph's own in-neighbors.
        child_src1 = subg.map_to_subgraph_nid(src)
        assert F.asnumpy(F.sum(child_src1 == child_src, 0)) == len(src)
Da Zheng's avatar
Da Zheng committed
32
33

def is_sorted(arr):
    """Return True if the 1-D sequence ``arr`` is in non-decreasing order.

    Parameters
    ----------
    arr : array-like
        One-dimensional sequence of comparable values (NumPy array or any
        sequence accepted by ``np.asarray``).

    Returns
    -------
    bool
        True when every element is less than or equal to its successor.
        Empty and single-element inputs count as sorted.
    """
    arr = np.asarray(arr)
    # Comparing each element with its successor is O(n) and avoids building
    # a sorted copy just to compare it with the input.
    return bool(np.all(arr[:-1] <= arr[1:]))
Da Zheng's avatar
Da Zheng committed
35
36

def verify_subgraph(g, subg, seed_id):
    """Check the sampled in-neighbors of ``seed_id`` in ``subg`` against ``g``.

    Parameters
    ----------
    g : parent graph
        Graph the subgraph was sampled from.
    subg : sampled subgraph
        Subgraph that provides ``map_to_subgraph_nid`` and ``in_edges``.
    seed_id : node id(s)
        Seed vertex (in parent-graph ids) whose neighborhood is verified.

    Raises
    ------
    AssertionError
        If the seed's in-neighbor list inside the subgraph contains
        duplicates, is not sorted, or contains a node that is not an
        in-neighbor of the seed in the parent graph.
    """
    src, _, _ = g.in_edges(seed_id, form='all')
    child_id = subg.map_to_subgraph_nid(seed_id)
    child_src, _, _ = subg.in_edges(child_id, form='all')
    child_src = F.asnumpy(child_src)
    # We don't allow duplicate elements in the neighbor list.
    assert len(np.unique(child_src)) == len(child_src)
    # The neighbor list also needs to be sorted.
    assert is_sorted(child_src)

    child_src1 = F.asnumpy(subg.map_to_subgraph_nid(src))
    # Keep only non-negative ids.
    # NOTE(review): negative ids presumably mark parent nodes that were not
    # sampled into the subgraph -- confirm against map_to_subgraph_nid.
    child_src1 = child_src1[child_src1 >= 0]
    # Every sampled in-neighbor must be one of the parent's in-neighbors.
    unexpected = np.setdiff1d(child_src, child_src1)
    assert len(unexpected) == 0, \
        'sampled neighbors missing from the parent graph: {}'.format(unexpected)

def test_1neighbor_sampler():
    """Check single-seed sampling with a bounded number of neighbors.

    The sampler is created with positional arguments ``1`` and ``5``.
    NOTE(review): presumably batch size 1 and expand factor 5 -- confirm
    against the DGL ``NeighborSampler`` API.

    Each sampled subgraph must report one seed, contain at most 6 nodes and
    at most 5 edges, and pass ``verify_subgraph`` for its seed.
    """
    g = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
                                                          num_workers=4, return_seed_id=True):
        seed_ids = aux['seeds']
        assert len(seed_ids) == 1
        # Bounds asserted here: at most 5 edges, and at most 6 nodes
        # (one more than the edge bound, leaving room for the seed).
        assert subg.number_of_nodes() <= 6
        assert subg.number_of_edges() <= 5
        verify_subgraph(g, subg, seed_ids)

62
63
64
65
66
67
68
69
70
71
72
def test_prefetch_neighbor_sampler():
    """Run the single-seed bounded sampling checks with ``prefetch=True``.

    Each sampled subgraph must report one seed, contain at most 6 nodes and
    at most 5 edges, and pass ``verify_subgraph`` for its seed.
    """
    node_limit = 6
    edge_limit = 5
    graph = generate_rand_graph(100)
    # Single-seed batches: each subgraph is the sampled in-neighborhood of
    # one vertex.
    sampler = dgl.contrib.sampling.NeighborSampler(
        graph, 1, 5,
        neighbor_type='in',
        num_workers=4,
        return_seed_id=True,
        prefetch=True,
    )
    for sampled, info in sampler:
        seeds = info['seeds']
        assert len(seeds) == 1
        assert sampled.number_of_nodes() <= node_limit
        assert sampled.number_of_edges() <= edge_limit
        verify_subgraph(graph, sampled, seeds)

Da Zheng's avatar
Da Zheng committed
73
74
75
def test_10neighbor_sampler_all():
    """Check multi-seed sampling when all in-neighbors can be sampled.

    The sampler is created with positional arguments ``10`` and ``100`` on a
    graph built by ``generate_rand_graph(100)``.
    NOTE(review): presumably batch size 10 and expand factor 100 -- confirm
    against the DGL ``NeighborSampler`` API.

    For every sampled subgraph, the in-neighbors of the seeds inside the
    subgraph must match the parent graph's in-neighbors of the same seeds
    after mapping them to subgraph node ids.
    """
    g = generate_rand_graph(100)
    # Each batch holds several seed vertices, so the subgraph covers the
    # in-neighborhoods of all seeds in the batch rather than a single vertex.
    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
                                                          num_workers=4, return_seed_id=True):
        seed_ids = aux['seeds']
        src, _, _ = g.in_edges(seed_ids, form='all')

        child_ids = subg.map_to_subgraph_nid(seed_ids)
        child_src, _, _ = subg.in_edges(child_ids, form='all')

        # Element-wise comparison: relies on both ``in_edges`` calls
        # returning edges in the same order.
        child_src1 = subg.map_to_subgraph_nid(src)
        assert F.asnumpy(F.sum(child_src1 == child_src, 0)) == len(src)
Da Zheng's avatar
Da Zheng committed
86
87
88

def check_10neighbor_sampler(g, seeds):
    """Check multi-seed bounded sampling on ``g`` for the given seed nodes.

    Parameters
    ----------
    g : graph
        Graph to sample from.
    seeds : array-like or None
        Forwarded to ``NeighborSampler`` as ``seed_nodes``.
        NOTE(review): ``None`` presumably means "use every node" -- confirm
        against the DGL API.

    Raises
    ------
    AssertionError
        If a sampled subgraph exceeds the size bounds or fails
        ``verify_subgraph`` for any of its seeds.
    """
    # Each batch holds several seed vertices, so the subgraph covers the
    # sampled in-neighborhoods of all seeds in the batch.
    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in',
                                                          num_workers=4, seed_nodes=seeds,
                                                          return_seed_id=True):
        seed_ids = aux['seeds']
        # Size bounds scale with the number of seeds in the batch:
        # at most 5 edges and 6 nodes per seed.
        assert subg.number_of_nodes() <= 6 * len(seed_ids)
        assert subg.number_of_edges() <= 5 * len(seed_ids)
        for seed_id in seed_ids:
            verify_subgraph(g, subg, seed_id)

def test_10neighbor_sampler():
    """Run ``check_10neighbor_sampler`` without and with explicit seeds.

    The first call passes ``None`` as the seeds. The second passes a unique
    set of randomly drawn node ids; the number of draws is one tenth of the
    graph's node count.
    """
    graph = generate_rand_graph(100)

    # No explicit seed nodes.
    check_10neighbor_sampler(graph, None)

    # Random subset of nodes as seeds; duplicates are removed by np.unique.
    node_count = graph.number_of_nodes()
    drawn = np.random.randint(0, node_count, size=int(node_count / 10))
    check_10neighbor_sampler(graph, seeds=np.unique(drawn))

if __name__ == '__main__':
    # Run the sampler tests when this file is executed as a script.
    test_1neighbor_sampler_all()
    test_10neighbor_sampler_all()
    test_1neighbor_sampler()
    # The prefetch variant was defined but never run from the script entry point.
    test_prefetch_neighbor_sampler()
    test_10neighbor_sampler()