"vscode:/vscode.git/clone" did not exist on "16b4b823e3b638de7ee56a251de4f76ceab97d13"
test_basics.py 2.71 KB
Newer Older
1
2
from dgl.graph import DGLGraph

def message_func(src, edge):
    """Message UDF that forwards the source node's raw ``'h'`` value.

    Args:
        src: mapping-like node data of the sending node; only ``'h'`` is read.
        edge: edge data; unused.

    Returns:
        ``src['h']`` itself (not wrapped in a dict).
    """
    return src['h']

def reduce_func(node, msgs):
    """Reduce UDF: add up raw (non-dict) messages and store the total under ``'m'``.

    ``node`` is ignored. An empty ``msgs`` gives ``{'m': 0}``.
    """
    total = sum(msgs)
    return {'m': total}

def apply_func(node):
    """Node-apply UDF: return a new ``'h'`` equal to ``node['h'] + node['m']``.

    ``node['m']`` is presumably the value produced by the reduce step
    (``reduce_func`` returns ``{'m': ...}``); confirm against DGLGraph.
    """
    return {'h': node['h'] + node['m']}

def message_dict_func(src, edge):
    """Dict-style message UDF: wrap the source node's ``'h'`` under key ``'m'``.

    Args:
        src: mapping-like node data of the sending node; only ``'h'`` is read.
        edge: edge data; unused.
    """
    return {'m': src['h']}

def reduce_dict_func(node, msgs):
    """Dict-style reduce UDF: sum the ``'m'`` entry of every incoming message.

    ``node`` is ignored. Returns ``{'m': total}``, with a total of 0 when
    there are no messages.
    """
    return {'m': sum(message['m'] for message in msgs)}

def apply_dict_func(node):
    """Dict-style node-apply UDF: return a new ``'h'`` equal to ``node['h'] + node['m']``.

    ``node['m']`` is presumably the value written by the reduce step
    (``reduce_dict_func`` returns ``{'m': ...}``); confirm against DGLGraph.
    """
    return {'h': node['h'] + node['m']}

def generate_graph():
    """Build the 10-node fixture graph used by every test.

    Node ``k`` starts with ``h = k + 1``. Node 0 fans out to nodes 1..8,
    each of which feeds node 9, and an extra edge 9 -> 0 closes the loop.
    """
    graph = DGLGraph()
    for node_id in range(10):
        graph.add_node(node_id, h=node_id + 1)
    # Source 0 -> middle layer 1..8 -> sink 9 (added edge pair by edge pair).
    for middle in range(1, 9):
        graph.add_edge(0, middle)
        graph.add_edge(middle, 9)
    # Back edge from the sink to the source.
    graph.add_edge(9, 0)
    return graph
def check(g, h):
    """Assert that the ``'h'`` feature of nodes 0..9 of ``g`` equals ``h``.

    Both sides are compared as ``str()`` lists. The same strings are joined
    into the failure message.

    Args:
        g: graph whose ``g.nodes[i]['h']`` is readable for ``i`` in 0..9.
        h: expected ``'h'`` values for nodes 0..9, in node order.

    Raises:
        AssertionError: if any node's value differs from the expectation.
    """
    actual = [str(g.nodes[i]['h']) for i in range(10)]
    # Use a new name instead of rebinding the ``h`` parameter.
    expected = [str(x) for x in h]
    assert actual == expected, "nh=[%s], h=[%s]" % (' '.join(actual), ' '.join(expected))

def register1(g):
    """Register the raw-value UDFs on ``g``.

    The message UDF is ``message_func``, the reduce UDF is ``reduce_func`` and
    the node-apply UDF is ``apply_func``.
    """
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.register_apply_node_func(apply_func)

def register2(g):
    """Register the dict-returning UDFs on ``g``.

    The message UDF is ``message_dict_func``, the reduce UDF is
    ``reduce_dict_func`` and the node-apply UDF is ``apply_dict_func``.
    """
    g.register_message_func(message_dict_func)
    g.register_reduce_func(reduce_dict_func)
    g.register_apply_node_func(apply_dict_func)

def _test_sendrecv(g):
    """Exercise single-edge ``send`` followed by ``recv``.

    ``g`` is expected to be a fresh ``generate_graph()`` result with one of
    the ``register*`` helpers applied, so node ``k`` starts at ``h = k + 1``.
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # 0 -> 1: node 1 becomes 2 + 1 = 3.
    g.send(0, 1)
    g.recv(1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # Two separate sends into node 9, then a single recv: 10 + 6 + 7 = 23.
    g.send(5, 9)
    g.send(6, 9)
    g.recv(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])

def _test_multi_sendrecv(g):
    """Exercise ``send``/``recv`` with lists of endpoints.

    Covers one-to-many, many-to-one and many-to-many edge batches. ``g`` is
    expected to be a fresh ``generate_graph()`` result with one of the
    ``register*`` helpers applied.
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # one-many: nodes 1, 2 and 3 each gain h[0] = 1.
    g.send(0, [1, 2, 3])
    g.recv([1, 2, 3])
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
    # many-one: node 9 becomes 10 + 7 + 8 + 9 = 34.
    g.send([6, 7, 8], 9)
    g.recv(9)
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
    # many-many: edges 0->4, 0->5, 4->9, 5->9. The expected node 9 value,
    # 34 + 5 + 6 = 45, uses the values nodes 4 and 5 held when sent.
    g.send([0, 0, 4, 5], [4, 5, 9, 9])
    g.recv([4, 5, 9])
    check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])

def _test_update_routines(g):
    """Exercise the composite update routines.

    Covers ``send_and_recv``, ``pull``, ``push`` and ``update_all``. ``g`` is
    expected to be a fresh ``generate_graph()`` result with one of the
    ``register*`` helpers applied.
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # 0 -> 1 in one call: node 1 becomes 2 + 1 = 3.
    g.send_and_recv(0, 1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # Node 9 pulls from its predecessors 1..8: 10 + (3+3+4+5+6+7+8+9) = 55.
    g.pull(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
    # Node 0 pushes to its successors 1..8: each gains h[0] = 1.
    g.push(0)
    check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
    # Every node receives from its predecessors at once:
    # node 0 gets 1 + 55 = 56; nodes 1..8 gain 1;
    # node 9 gets 55 + (4+4+5+6+7+8+9+10) = 108.
    g.update_all()
    check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])

def test_sendrecv():
    """Run the send/recv scenario once per UDF registration style."""
    for register in (register1, register2):
        graph = generate_graph()
        register(graph)
        _test_sendrecv(graph)

def test_multi_sendrecv():
    """Run the batched send/recv scenario once per UDF registration style."""
    for register in (register1, register2):
        graph = generate_graph()
        register(graph)
        _test_multi_sendrecv(graph)

def test_update_routines():
    """Run the update-routine scenario once per UDF registration style.

    Each run uses a freshly generated graph, so the two runs do not share state.
    """
    for register in (register1, register2):
        graph = generate_graph()
        register(graph)
        _test_update_routines(graph)

if __name__ == '__main__':
    # Allow running the module directly, without pytest, in a fixed order.
    for run_test in (test_sendrecv, test_multi_sendrecv, test_update_routines):
        run_test()