test_sampler.py 4.11 KB
Newer Older
Da Zheng's avatar
Da Zheng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
os.environ['DGLBACKEND'] = 'mxnet'
import mxnet as mx
import numpy as np
import scipy as sp
import dgl
from dgl import utils

def generate_rand_graph(n):
    """Build a read-only random DGLGraph with ``n`` nodes.

    A random ``n`` x ``n`` sparse matrix with density 0.1 is generated in COO
    format, its non-zero pattern is converted to an int64 matrix, and that
    matrix is handed to ``dgl.DGLGraph`` with ``readonly=True``.

    Parameters
    ----------
    n : int
        Number of rows and columns of the random matrix.

    Returns
    -------
    dgl.DGLGraph
        The graph constructed from the random matrix.
    """
    # `import scipy` alone does not guarantee that the `sparse` submodule is
    # loaded; import it explicitly so this helper does not depend on another
    # package having imported `scipy.sparse` as a side effect.
    import scipy.sparse

    random_matrix = scipy.sparse.random(n, n, density=0.1, format='coo')
    adjacency = (random_matrix != 0).astype(np.int64)
    return dgl.DGLGraph(adjacency, readonly=True)

def test_1neighbor_sampler_all():
    """Check single-seed sampling with a neighbor limit of 100 on a 100-node graph.

    For every (subgraph, seeds) pair produced by the sampler, the subgraph is
    compared against the in-edges of the seed in the parent graph: node count,
    a lower bound on the edge count, and the mapped source ids.
    """
    graph = generate_rand_graph(100)
    # With these settings the sampler simply returns the neighborhood of a
    # single vertex.
    sampler = dgl.contrib.sampling.NeighborSampler(graph, 1, 100, neighbor_type='in',
                                                   num_workers=4)
    for subg, seed_ids in sampler:
        assert len(seed_ids) == 1
        src, dst, _ = graph._graph.in_edges(utils.toindex(seed_ids))

        # A self loop means the seed is already one of its own sources, so it
        # must not be counted a second time.
        has_self_loop = mx.nd.sum(src.tousertensor() == dst.tousertensor()).asnumpy() == 1
        expected_nodes = len(src) if has_self_loop else len(src) + 1
        assert subg.number_of_nodes() == expected_nodes
        assert subg.number_of_edges() >= len(src)

        # In-edges of the seed, expressed in subgraph node ids.
        local_seeds = subg.map_to_subgraph_nid(seed_ids)
        local_src, _, _ = subg._graph.in_edges(local_seeds)

        # Parent sources mapped into the subgraph must match element-wise.
        mapped_src = subg.map_to_subgraph_nid(src)
        num_equal = mx.nd.sum(mapped_src.tousertensor() == local_src.tousertensor()).asnumpy()
        assert num_equal == len(src)

def is_sorted(arr):
    """Return whether ``arr`` is in non-decreasing order.

    Parameters
    ----------
    arr : array-like
        One-dimensional sequence of comparable values.

    Returns
    -------
    numpy.bool_
        True when every element is <= its successor. Empty and
        single-element inputs count as sorted.
    """
    values = np.asarray(arr)
    # Compare each element with its successor instead of sorting a copy:
    # a single O(n) pass with no extra sorted array.
    return np.all(values[:-1] <= values[1:])

def verify_subgraph(g, subg, seed_id):
    """Validate the sampled in-neighbors of one seed against the parent graph.

    Parameters
    ----------
    g : parent graph; its ``_graph.in_edges`` is queried for the seed.
    subg : sampled subgraph; must provide ``map_to_subgraph_nid`` and
        ``_graph.in_edges``.
    seed_id : seed node id(s) in the parent graph, converted with
        ``utils.toindex``.

    Raises
    ------
    AssertionError
        If the sampled source list contains duplicates, is not sorted, or
        contains a node that is not a mapped in-neighbor of the seed in ``g``.
    """
    seed_id = utils.toindex(seed_id)
    parent_src, _, _ = g._graph.in_edges(utils.toindex(seed_id))

    local_seed = subg.map_to_subgraph_nid(seed_id)
    sampled_src, _, _ = subg._graph.in_edges(local_seed)
    sampled_src = sampled_src.tousertensor().asnumpy()

    # The neighbor list must not contain duplicate elements.
    assert len(np.unique(sampled_src)) == len(sampled_src)
    # The neighbor list must also be sorted.
    assert is_sorted(sampled_src)

    # Map the parent-graph sources into subgraph ids and drop negative
    # entries (presumably parent nodes that were not sampled -- confirm
    # against map_to_subgraph_nid).
    mapped_src = subg.map_to_subgraph_nid(parent_src).tousertensor().asnumpy()
    mapped_src = mapped_src[mapped_src >= 0]
    # Every sampled source has to be one of the seed's real in-neighbors.
    assert all(node in mapped_src for node in sampled_src)

def test_1neighbor_sampler():
    """Check single-seed sampling with a neighbor limit of 5 on a 100-node graph.

    Each sampled subgraph must have at most 6 nodes and at most 5 edges, and
    must pass ``verify_subgraph`` for its seed.
    """
    graph = generate_rand_graph(100)
    # With these settings the sampler simply returns the neighborhood of a
    # single vertex.
    sampler = dgl.contrib.sampling.NeighborSampler(graph, 1, 5, neighbor_type='in',
                                                   num_workers=4)
    for subg, seed_ids in sampler:
        assert len(seed_ids) == 1
        # At most 5 sampled edges, hence at most 5 sources plus the seed.
        assert subg.number_of_nodes() <= 6
        assert subg.number_of_edges() <= 5
        verify_subgraph(graph, subg, seed_ids)

def test_10neighbor_sampler_all():
    """Check sampling with sampler arguments (10, 100) on a 100-node graph.

    For every (subgraph, seeds) pair, the sources of the seeds' in-edges in
    the parent graph, once mapped to subgraph ids, must match the sources of
    the seeds' in-edges inside the subgraph element-wise.
    """
    graph = generate_rand_graph(100)
    # NOTE(review): 10 is presumably the number of seeds per sampled subgraph
    # and 100 the per-seed neighbor limit -- confirm against NeighborSampler.
    sampler = dgl.contrib.sampling.NeighborSampler(graph, 10, 100, neighbor_type='in',
                                                   num_workers=4)
    for subg, seed_ids in sampler:
        parent_src, _, _ = graph._graph.in_edges(utils.toindex(seed_ids))

        local_seeds = subg.map_to_subgraph_nid(seed_ids)
        local_src, _, _ = subg._graph.in_edges(local_seeds)

        mapped_src = subg.map_to_subgraph_nid(parent_src)
        num_equal = mx.nd.sum(mapped_src.tousertensor() == local_src.tousertensor()).asnumpy()
        assert num_equal == len(parent_src)

def check_10neighbor_sampler(g, seeds):
    """Run the sampler with arguments (10, 5) on ``g`` and validate each subgraph.

    Parameters
    ----------
    g : graph to sample from.
    seeds : value forwarded as ``seed_nodes`` to the sampler; callers in this
        file pass either ``None`` or an array of node ids.

    Raises
    ------
    AssertionError
        If a subgraph exceeds 6 nodes or 5 edges per seed, or if
        ``verify_subgraph`` fails for any of its seeds.
    """
    sampler = dgl.contrib.sampling.NeighborSampler(g, 10, 5, neighbor_type='in',
                                                   num_workers=4, seed_nodes=seeds)
    for subg, seed_ids in sampler:
        num_seeds = len(seed_ids)
        # Size bounds scale with the number of seeds in the batch.
        assert subg.number_of_nodes() <= 6 * num_seeds
        assert subg.number_of_edges() <= 5 * num_seeds
        for seed_id in seed_ids:
            verify_subgraph(g, subg, seed_id)

def test_10neighbor_sampler():
    """Run ``check_10neighbor_sampler`` without seeds and with random seeds.

    The explicit seed set is the de-duplicated result of drawing
    ``number_of_nodes / 10`` random node ids from the graph.
    """
    graph = generate_rand_graph(100)

    # First pass: no explicit seed nodes.
    check_10neighbor_sampler(graph, None)

    # Second pass: a random, de-duplicated subset of node ids as seeds.
    random_ids = np.random.randint(0, graph.number_of_nodes(),
                                   size=int(graph.number_of_nodes() / 10))
    check_10neighbor_sampler(graph, seeds=np.unique(random_ids))

if __name__ == '__main__':
    # When executed as a script, run every sampler test in a fixed order.
    for run_test in (test_1neighbor_sampler_all,
                     test_10neighbor_sampler_all,
                     test_1neighbor_sampler,
                     test_10neighbor_sampler):
        run_test()