test_basics.py 2.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from dgl.graph import DGLGraph

def message_func(src, dst, edge):
    """Return the message sent along an edge: the source node's 'h' value.

    ``dst`` and ``edge`` are accepted to match the message-function
    signature but are not used.
    """
    feature = src['h']
    return feature

def update_func(node, msgs):
    """Return the node's new state: its 'h' plus the sum of incoming messages.

    With no messages ``sum`` yields 0, so 'h' keeps its value.
    """
    incoming = sum(msgs)
    new_h = node['h'] + incoming
    return dict(h=new_h)

def generate_graph():
    """Build the 10-node fixture graph used by every test in this module.

    Node ``k`` (0..9) starts with feature ``h = k + 1``. Node 0 fans out to
    nodes 1..8, each of which feeds node 9, and a single back edge 9 -> 0
    closes the cycle.
    """
    graph = DGLGraph()
    for nid in range(10):
        graph.add_node(nid, h=nid + 1)
    # Edges are added in the same interleaved order as before (0->k, then
    # k->9), in case the graph assigns edge ids by insertion order.
    for mid in range(1, 9):
        graph.add_edge(0, mid)
        graph.add_edge(mid, 9)
    graph.add_edge(9, 0)
    return graph

def check(g, h):
    """Assert that the 'h' features of nodes 0..9 of *g* match *h*.

    Both sides are compared by their ``str`` form, so ``1`` and ``1.0`` are
    treated as different values. On mismatch, the assertion message lists
    both sequences.
    """
    actual = list(map(str, (g.nodes[idx]['h'] for idx in range(10))))
    expected = list(map(str, h))
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))

def test_sendrecv():
    """Exercise single-edge sendto/recvfrom on the fixture graph."""
    graph = generate_graph()
    check(graph, list(range(1, 11)))
    graph.register_message_func(message_func)
    graph.register_update_func(update_func)

    # Edge 0 -> 1: node 1 becomes 2 + h[0] == 3.
    graph.sendto(0, 1)
    graph.recvfrom(1, [0])
    check(graph, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])

    # Edges 5 -> 9 and 6 -> 9: node 9 becomes 10 + 6 + 7 == 23.
    for sender in (5, 6):
        graph.sendto(sender, 9)
    graph.recvfrom(9, [5, 6])
    check(graph, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])

def test_multi_sendrecv():
    """Exercise sendto/recvfrom with lists of sources and/or destinations."""
    graph = generate_graph()
    check(graph, list(range(1, 11)))
    graph.register_message_func(message_func)
    graph.register_update_func(update_func)

    # One-to-many: nodes 1, 2, 3 each gain h[0] == 1.
    graph.sendto(0, [1, 2, 3])
    graph.recvfrom([1, 2, 3], [[0], [0], [0]])
    check(graph, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])

    # Many-to-one: node 9 gains 7 + 8 + 9 == 24.
    graph.sendto([6, 7, 8], 9)
    graph.recvfrom(9, [6, 7, 8])
    check(graph, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])

    # Many-to-many over edges 0->4, 0->5, 4->9, 5->9. The expected 45 is
    # 34 + 5 + 6, i.e. the senders' h values as of the sendto call, before
    # nodes 4 and 5 are themselves updated.
    graph.sendto([0, 0, 4, 5], [4, 5, 9, 9])
    graph.recvfrom([4, 5, 9], [[0], [0], [4, 5]])
    check(graph, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])

def test_update_routines():
    """Exercise the combined send+receive helpers on the fixture graph."""
    graph = generate_graph()
    check(graph, list(range(1, 11)))
    graph.register_message_func(message_func)
    graph.register_update_func(update_func)

    # Single edge 0 -> 1: node 1 becomes 2 + 1 == 3.
    graph.update_by_edge(0, 1)
    check(graph, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])

    # All in-edges of node 9 (from 1..8): 10 + (3+3+4+5+6+7+8+9) == 55.
    graph.update_to(9)
    check(graph, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])

    # All out-edges of node 0: nodes 1..8 each gain 1.
    graph.update_from(0)
    check(graph, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])

    # Every edge at once: node 0 gains h[9] == 55, nodes 1..8 gain 1, and
    # node 9 gains 4+4+5+6+7+8+9+10 == 53.
    graph.update_all()
    check(graph, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])

if __name__ == '__main__':
    # Run the tests in the same fixed order when executed as a script.
    for _test in (test_sendrecv, test_multi_sendrecv, test_update_routines):
        _test()