# test_basics.py
from dgl.graph import DGLGraph

def message_func(src, edge):
    """Message function: send the source node's ``'h'`` value unchanged.

    Args:
        src: source-node state; its ``'h'`` entry is read.
        edge: edge state (unused).

    Returns:
        The source's ``'h'`` value, used directly as the message.
    """
    return src['h']

def update_func(node, accum):
    """Update function: new ``'h'`` is the node's current ``'h'`` plus ``accum``.

    Args:
        node: node state; its ``'h'`` entry is read.
        accum: the reduced (accumulated) incoming message.

    Returns:
        A dict with the single key ``'h'`` holding the updated value.
    """
    current = node['h']
    return {'h': current + accum}

def message_dict_func(src, edge):
    """Dict-style message function: wrap the source's ``'h'`` under key ``'m'``.

    Args:
        src: source-node state; its ``'h'`` entry is read.
        edge: edge state (unused).

    Returns:
        ``{'m': src['h']}``.
    """
    payload = src['h']
    return dict(m=payload)

def update_dict_func(node, accum):
    """Dict-style update function: add ``accum['m']`` to the node's ``'h'``.

    Args:
        node: node state; its ``'h'`` entry is read.
        accum: reduced messages, with the summed message under key ``'m'``.

    Returns:
        A dict with the single key ``'h'`` holding the updated value.
    """
    current = node['h']
    incoming = accum['m']
    return {'h': current + incoming}

def generate_graph():
    """Build the 10-node graph shared by the tests in this module.

    Node ``i`` (0..9) starts with ``h = i + 1``. Node 0 fans out to nodes
    1..8, each of which feeds node 9, and one back edge 9 -> 0 closes the loop.

    Returns:
        The populated ``DGLGraph``.
    """
    graph = DGLGraph()
    for node_id in range(10):  # 10 nodes.
        graph.add_node(node_id, h=node_id + 1)
    # node 0 is the source and node 9 the sink: 0 -> mid -> 9
    for mid in range(1, 9):
        graph.add_edge(0, mid)
        graph.add_edge(mid, 9)
    # back flow from the sink to the source
    graph.add_edge(9, 0)
    return graph

def check(g, h):
    """Assert that nodes 0..9 of ``g`` hold the expected ``'h'`` values.

    Both sides are converted with ``str()`` before being compared.

    Args:
        g: graph whose ``g.nodes[i]['h']`` is read for ``i`` in 0..9.
        h: iterable of expected values, one per node, in node order.

    Raises:
        AssertionError: if the stringified sequences differ; the message
            lists the observed and expected values.
    """
    observed = [str(g.nodes[idx]['h']) for idx in range(10)]
    wanted = list(map(str, h))
    assert observed == wanted, "nh=[%s], h=[%s]" % (
        ' '.join(observed), ' '.join(wanted))

def register1(g):
    """Register the plain (non-dict) message passing scheme on ``g``.

    Installs ``message_func`` (send the source's ``'h'``), ``update_func``
    (add the accumulated message to ``'h'``) and the built-in ``'sum'``
    reducer.

    Args:
        g: graph exposing ``register_message_func``, ``register_update_func``
            and ``register_reduce_func``.
    """
    g.register_message_func(message_func)
    g.register_update_func(update_func)
    g.register_reduce_func('sum')

def register2(g):
    """Register the dict-based message passing scheme on ``g``.

    Installs ``message_dict_func`` (message under key ``'m'``),
    ``update_dict_func`` (add ``accum['m']`` to ``'h'``) and the built-in
    ``'sum'`` reducer, in that order.

    Args:
        g: graph exposing ``register_message_func``, ``register_update_func``
            and ``register_reduce_func``.
    """
    registrations = (
        (g.register_message_func, message_dict_func),
        (g.register_update_func, update_dict_func),
        (g.register_reduce_func, 'sum'),
    )
    for register, target in registrations:
        register(target)

def _test_sendrecv(g):
    """Exercise single-edge ``sendto``/``recv`` on a graph from ``generate_graph``.

    The expected values assume one of ``register1``/``register2`` was applied:
    the message is the source's ``'h'``, messages are summed, and the sum is
    added to the receiver's ``'h'``.

    Args:
        g: a fresh graph from ``generate_graph()`` (node i holds h = i + 1).
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # 0 -> 1: node 1 becomes 2 + 1 = 3.
    g.sendto(0, 1)
    g.recv(1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # 5 -> 9 and 6 -> 9 consumed by a single recv: node 9 becomes 10 + 6 + 7 = 23.
    g.sendto(5, 9)
    g.sendto(6, 9)
    g.recv(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 23])

def _test_multi_sendrecv(g):
    """Exercise ``sendto``/``recv`` with node lists on a graph from ``generate_graph``.

    The expected values assume one of ``register1``/``register2`` was applied
    (message = source ``'h'``, reduce = sum, update = ``h + accum``).

    Args:
        g: a fresh graph from ``generate_graph()`` (node i holds h = i + 1).
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # one-many: 0 -> 1, 2, 3; each receiver gains h[0] = 1.
    g.sendto(0, [1, 2, 3])
    g.recv([1, 2, 3])
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 10])
    # many-one: 6, 7, 8 -> 9; node 9 becomes 10 + 7 + 8 + 9 = 34.
    g.sendto([6, 7, 8], 9)
    g.recv(9)
    check(g, [1, 3, 4, 5, 5, 6, 7, 8, 9, 34])
    # many-many: the expected values correspond to the pairwise edges
    # 0->4, 0->5, 4->9, 5->9. Nodes 4 and 5 each gain 1, and node 9 gains
    # 5 + 6 (the values 4 and 5 held before this recv): 34 + 11 = 45.
    g.sendto([0, 0, 4, 5], [4, 5, 9, 9])
    g.recv([4, 5, 9])
    check(g, [1, 3, 4, 5, 6, 7, 7, 8, 9, 45])

def _test_update_routines(g):
    """Exercise ``update_by_edge``, ``update_to``, ``update_from`` and ``update_all``.

    The expected values assume one of ``register1``/``register2`` was applied
    (message = source ``'h'``, reduce = sum, update = ``h + accum``).

    Args:
        g: a fresh graph from ``generate_graph()`` (node i holds h = i + 1).
    """
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    # single edge 0 -> 1: node 1 becomes 2 + 1 = 3.
    g.update_by_edge(0, 1)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 10])
    # into node 9 from 1..8: 10 + (3+3+4+5+6+7+8+9) = 55.
    g.update_to(9)
    check(g, [1, 3, 3, 4, 5, 6, 7, 8, 9, 55])
    # out of node 0 to 1..8: each gains h[0] = 1.
    g.update_from(0)
    check(g, [1, 4, 4, 5, 6, 7, 8, 9, 10, 55])
    # every edge at once, using pre-update values: node 0 gains h[9] = 55 -> 56,
    # nodes 1..8 gain 1, node 9 gains 4+4+5+6+7+8+9+10 = 53 -> 108.
    g.update_all()
    check(g, [56, 5, 5, 6, 7, 8, 9, 10, 11, 108])

def _test_update_to_0deg():
    """Check ``update_to`` on a node that has no incoming edges.

    The only edge is 0 -> 1, so node 0 has in-degree 0. The update callback
    asserts it receives ``accum is None`` and doubles ``'h'``, so node 0 must
    end with ``h == 4``. The reduce callback asserts it never sees
    ``msgs is None``.
    """
    def _send_src(src, edge):
        return src

    def _sum_msgs(node, msgs):
        assert msgs is not None
        return msgs.sum(1)

    def _double_h(node, accum):
        assert accum is None
        return {'h': node['h'] * 2}

    graph = DGLGraph()
    graph.add_node(0, h=2)
    graph.add_node(1, h=1)
    graph.add_edge(0, 1)
    graph.update_to(0, _send_src, _sum_msgs, _double_h)
    assert graph.nodes[0]['h'] == 4

def test_sendrecv():
    """Run the sendto/recv checks under both registration schemes."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_sendrecv(g)

def test_multi_sendrecv():
    """Run the multi-node sendto/recv checks under both registration schemes."""
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_multi_sendrecv(g)

def test_update_routines():
    """Run the update-routine checks under both registration schemes.

    Afterwards, run the zero in-degree ``update_to`` check. That check builds
    its own graph and callbacks.
    """
    for register in (register1, register2):
        g = generate_graph()
        register(g)
        _test_update_routines(g)

    _test_update_to_0deg()

if __name__ == '__main__':
    # When run as a script, execute every public test in file order.
    for _run in (test_sendrecv, test_multi_sendrecv, test_update_routines):
        _run()