# test_basics.py: send/recv and update-routine tests for DGLGraph.
from dgl.graph import DGLGraph

def message_func(src, edge):
    """Message function used by register1: forward the source node's 'h'.

    ``src`` is indexed with the key 'h'; ``edge`` is accepted but unused.
    """
    return src['h']

def update_func(node, accum):
    """Update function used by register1: add the reduced message to 'h'.

    Returns a new dict whose 'h' is ``node['h'] + accum``.
    """
    updated = node['h'] + accum
    return {'h': updated}


def message_dict_func(src, edge):
    """Message function used by register2: wrap the source 'h' under key 'm'.

    ``edge`` is accepted but unused.
    """
    return dict(m=src['h'])

def update_dict_func(node, accum):
    """Update function used by register2: add the reduced 'm' message to 'h'.

    Returns a new dict whose 'h' is ``node['h'] + accum['m']``.
    """
    current = node['h']
    incoming = accum['m']
    return {'h': current + incoming}

def generate_graph():
    """Build the 10-node test graph.

    Node ``n`` starts with feature ``h = n + 1`` (so h is 1..10). Node 0 fans
    out to nodes 1..8, each of which feeds node 9, and a single back edge
    9 -> 0 closes the loop.
    """
    g = DGLGraph()
    for nid in range(10):
        g.add_node(nid, h=nid + 1)
    # source 0 -> middle layer 1..8 -> sink 9
    for mid in range(1, 9):
        g.add_edge(0, mid)
        g.add_edge(mid, 9)
    # back flow from the sink to the source
    g.add_edge(9, 0)
    return g

def check(g, h):
    """Assert that nodes 0..9 of ``g`` carry the features listed in ``h``.

    Both sides are compared via ``str()``. Only the first 10 nodes are
    inspected, so ``h`` must have exactly 10 entries for the check to pass.
    """
    actual = [str(g.nodes[nid]['h']) for nid in range(10)]
    expected = list(map(str, h))
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))

def register1(g):
    """Register the plain-value message/update functions and a 'sum' reducer on ``g``."""
    g.register_message_func(message_func)
    g.register_update_func(update_func)
    g.register_reduce_func('sum')

def register2(g):
    """Register the dict-valued message/update functions and a 'sum' reducer on ``g``."""
    registrations = (
        (g.register_message_func, message_dict_func),
        (g.register_update_func, update_dict_func),
        (g.register_reduce_func, 'sum'),
    )
    # same call order as registering them one by one
    for register, arg in registrations:
        register(arg)

def _test_sendrecv(g):
    """Exercise single-edge ``sendto``/``recv`` on a graph from generate_graph().

    ``g`` must have message/update/reduce functions registered (register1 or
    register2). Initial features are h = [1, ..., 10].
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # expected: node 1 adds h[0] = 1 -> 2 + 1 = 3
    g.sendto(0, 1)
    g.recv(1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # expected: node 9 adds the 'sum' of h[5] and h[6] -> 10 + (6 + 7) = 23
    g.sendto(5, 9)
    g.sendto(6, 9)
    g.recv(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])

def _test_multi_sendrecv(g):
    """Exercise ``sendto``/``recv`` with lists of node ids.

    Covers one-to-many, many-to-one and many-to-many messaging on a graph
    from generate_graph() (initial h = [1, ..., 10]) with functions
    registered by register1 or register2.
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # one-many: node 0 (h=1) messages nodes 1, 2, 3; each gains 1
    g.sendto(0, [1, 2, 3])
    g.recv([1, 2, 3])
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
    # many-one: nodes 6, 7, 8 (h = 7, 8, 9) message node 9 -> 10 + 24 = 34
    g.sendto([6, 7, 8], 9)
    g.recv(9)
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
    # many-many: the expected values imply src/dst lists are paired
    # element-wise (0->4, 0->5, 4->9, 5->9). Node 9 gains h[4] + h[5]
    # as they were when sent (5 + 6) -> 34 + 11 = 45.
    g.sendto([0, 0, 4, 5], [4, 5, 9, 9])
    g.recv([4, 5, 9])
    check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])

def _test_update_routines(g):
    """Exercise update_by_edge, update_to, update_from and update_all in turn.

    Each step starts from the features left by the previous one on a graph
    from generate_graph() (initial h = [1, ..., 10]) with functions
    registered by register1 or register2.
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # edge 0->1 only: node 1 becomes 2 + 1 = 3
    g.update_by_edge(0, 1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # in-edges of 9 (from 1..8, h sum = 3+3+4+5+6+7+8+9 = 45): 10 + 45 = 55
    g.update_to(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
    # out-edges of 0 (to 1..8): each gains h[0] = 1
    g.update_from(0)
    check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
    # every edge at once: 0 gains 55 from 9 -> 56; nodes 1..8 gain 1 from 0;
    # 9 gains 4+4+5+6+7+8+9+10 = 53 from 1..8 -> 108
    g.update_all()
    check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])

def test_sendrecv():
    """Run the send/recv scenario once per registration style."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_sendrecv(g)

def test_multi_sendrecv():
    """Run the multi-node send/recv scenario once per registration style."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_multi_sendrecv(g)

def test_update_routines():
    """Run the update-routine scenario once per registration style."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_update_routines(g)

if __name__ == '__main__':
    # run every test in definition order when executed as a script
    for run_test in (test_sendrecv, test_multi_sendrecv, test_update_routines):
        run_test()