# test_sampler.py
# Committed by Da Zheng.
import os
os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx
import numpy as np
import scipy as sp
import dgl
from dgl import utils

def generate_rand_graph(n):
    """Build a read-only random ``dgl.DGLGraph`` from an ``n`` x ``n`` sparse matrix.

    The matrix is drawn with density 0.1 in COO format; its nonzero pattern is
    converted to int64 and handed to ``dgl.DGLGraph`` with ``readonly=True``.
    """
    # A bare `import scipy` is not guaranteed to load the `sparse` subpackage,
    # so import it explicitly rather than relying on another module doing so.
    import scipy.sparse

    arr = (scipy.sparse.random(n, n, density=0.1, format='coo') != 0).astype(np.int64)
    return dgl.DGLGraph(arr, readonly=True)

def test_1neighbor_sampler_all():
    """Check single-seed sampling when the fan-out limit is as large as the graph.

    The sampler is called with positional arguments ``1`` and ``100``
    (presumably batch size and per-node neighbor limit -- confirm against the
    ``NeighborSampler`` signature). Each yielded subgraph is compared with the
    full in-edge list of its one seed in the parent graph.
    """
    g = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 100, neighbor_type='in',
                                                          num_workers=4, return_seed_id=True):
        seed_ids = aux['seeds']
        assert len(seed_ids) == 1
        src, dst, _ = g.in_edges(seed_ids, form='all')
        # Test if there is a self loop
        self_loop = mx.nd.sum(src == dst).asnumpy() == 1
        if self_loop:
            # The seed is already counted among its own in-neighbors.
            assert subg.number_of_nodes() == len(src)
        else:
            # All in-neighbors plus the seed itself.
            assert subg.number_of_nodes() == len(src) + 1
        assert subg.number_of_edges() >= len(src)

        child_ids = subg.map_to_subgraph_nid(seed_ids)
        child_src, _, _ = subg.in_edges(child_ids, form='all')

        # The in-neighbors in the subgraph must match, position by position,
        # the parent in-neighbors translated to subgraph node ids.
        child_src1 = subg.map_to_subgraph_nid(src)
        assert mx.nd.sum(child_src1 == child_src).asnumpy() == len(src)

def is_sorted(arr):
    """Tell whether ``arr`` equals its own ascending-sorted copy.

    Counts the positions where ``np.sort(arr)`` and ``arr`` agree and compares
    that count with ``len(arr)``.
    """
    ascending = np.sort(arr)
    agreeing = (ascending == arr).sum()
    return agreeing == len(arr)

def verify_subgraph(g, subg, seed_id):
    """Validate the in-neighbors that ``subg`` holds for ``seed_id``.

    Asserts that the seed's in-neighbor list inside ``subg`` has no duplicates,
    is sorted ascending, and contains only nodes that are in-neighbors of the
    seed in the parent graph ``g``.
    """
    src, _, _ = g.in_edges(seed_id, form='all')
    child_id = subg.map_to_subgraph_nid(seed_id)
    child_src, _, _ = subg.in_edges(child_id, form='all')
    child_src = child_src.asnumpy()
    # We don't allow duplicate elements in the neighbor list.
    assert(len(np.unique(child_src)) == len(child_src))
    # The neighbor list also needs to be sorted.
    assert(is_sorted(child_src))

    child_src1 = subg.map_to_subgraph_nid(src).asnumpy()
    # Keep only non-negative ids; presumably a negative id marks a parent
    # neighbor that was not sampled into the subgraph -- TODO confirm.
    child_src1 = child_src1[child_src1 >= 0]
    # Every sampled in-neighbor must be a real in-neighbor in the parent graph.
    assert np.isin(child_src, child_src1).all()

def test_1neighbor_sampler():
    """Check single-seed sampling with a neighbor limit of 5.

    Each yielded subgraph must have one seed, at most 6 nodes and at most
    5 edges, and must pass ``verify_subgraph`` for that seed.
    """
    g = generate_rand_graph(100)
    # In this case, NeighborSampling simply gets the neighborhood of a single vertex.
    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 1, 5, neighbor_type='in',
                                                          num_workers=4, return_seed_id=True):
        seed_ids = aux['seeds']
        assert len(seed_ids) == 1
        # At most 5 sampled in-neighbors plus the seed itself.
        assert subg.number_of_nodes() <= 6
        assert subg.number_of_edges() <= 5
        verify_subgraph(g, subg, seed_ids)

def test_10neighbor_sampler_all():
    """Check multi-seed sampling when the fan-out limit is as large as the graph.

    The sampler is called with positional arguments ``10`` and ``100``
    (presumably batch size and per-node neighbor limit -- confirm against the
    ``NeighborSampler`` signature). For each batch, the seeds' in-neighbors in
    the subgraph are compared with their in-neighbors in the parent graph.
    """
    g = generate_rand_graph(100)
    # Each subgraph here is built from a batch of seeds, not a single vertex.
    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 100, neighbor_type='in',
                                                          num_workers=4, return_seed_id=True):
        seed_ids = aux['seeds']
        src, _, _ = g.in_edges(seed_ids, form='all')

        child_ids = subg.map_to_subgraph_nid(seed_ids)
        child_src, _, _ = subg.in_edges(child_ids, form='all')

        # The in-neighbors in the subgraph must match, position by position,
        # the parent in-neighbors translated to subgraph node ids.
        child_src1 = subg.map_to_subgraph_nid(src)
        assert mx.nd.sum(child_src1 == child_src).asnumpy() == len(src)

def check_10neighbor_sampler(g, seeds):
    """Run multi-seed sampling with a neighbor limit of 5 and validate each batch.

    ``seeds`` is forwarded as ``seed_nodes``; the caller in this file passes
    either ``None`` or an array of node ids. Every yielded subgraph must stay
    within 6 nodes and 5 edges per seed and pass ``verify_subgraph`` for each
    seed in the batch.
    """
    # Each subgraph here is built from a batch of seeds, not a single vertex.
    for subg, aux in dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in',
                                                          num_workers=4, seed_nodes=seeds,
                                                          return_seed_id=True):
        seed_ids = aux['seeds']
        # Per seed: at most 5 sampled in-neighbors plus the seed itself.
        assert subg.number_of_nodes() <= 6 * len(seed_ids)
        assert subg.number_of_edges() <= 5 * len(seed_ids)
        for seed_id in seed_ids:
            verify_subgraph(g, subg, seed_id)

def test_10neighbor_sampler():
    """Run ``check_10neighbor_sampler`` with no explicit seeds, then with random ones."""
    g = generate_rand_graph(100)
    check_10neighbor_sampler(g, None)

    # Draw a tenth of the node count as random node ids and deduplicate them.
    num_nodes = g.number_of_nodes()
    num_draws = int(g.number_of_nodes() / 10)
    random_ids = np.random.randint(0, num_nodes, size=num_draws)
    check_10neighbor_sampler(g, seeds=np.unique(random_ids))

if __name__ == '__main__':
    # When executed as a script, run every sampler test in a fixed order.
    for run_test in (test_1neighbor_sampler_all,
                     test_10neighbor_sampler_all,
                     test_1neighbor_sampler,
                     test_10neighbor_sampler):
        run_test()