# test_basics2.py
from dgl import DGLGraph
from dgl.graph import __REPR__

def message_func(hu, e_uv):
    """Message function that forwards *hu* unchanged; *e_uv* is ignored."""
    return hu

def message_not_called(hu, e_uv):
    """Sentinel message function that fails the test if it is ever invoked.

    Raises AssertionError explicitly instead of using ``assert False``,
    because bare asserts are stripped under ``python -O`` and the sentinel
    would then silently succeed.
    """
    raise AssertionError("message function should not have been called")

# Committed by Minjie Wang.
def reduce_not_called(h, msgs):
    """Sentinel reduce function that fails the test if it is ever invoked.

    Raises AssertionError explicitly instead of using ``assert False``,
    because bare asserts are stripped under ``python -O`` and the sentinel
    would then silently return a value.
    """
    raise AssertionError("reduce function should not have been called")

def reduce_func(h, msgs):
    """Reduce function: return *h* plus the sum of all incoming *msgs*."""
    return h + sum(msgs)

def check(g, h):
    """Assert that nodes 0..9 of *g* carry the ``__REPR__`` values listed in *h*.

    Both sides are converted with ``str()`` before comparing, presumably so
    that stored node values (whatever their type) compare equal to plain
    int literals; the failure message prints both lists.
    """
    nh = [str(g.nodes[i][__REPR__]) for i in range(10)]
    h = [str(x) for x in h]
    assert nh == h, "nh=[%s], h=[%s]" % (' '.join(nh), ' '.join(h))

def generate_graph():
    """Build the 10-node graph shared by the recv tests.

    Node ``i`` (0..9) is added with ``__REPR__=i + 1``.  Node 0 has an edge
    to each of nodes 1..8, and each of nodes 1..8 has an edge to node 9,
    so 0 is the only source and 9 the only sink.
    """
    g = DGLGraph()
    for i in range(10):
        g.add_node(i, __REPR__=i+1)  # 10 nodes.
    # create a graph where 0 is the source and 9 is the sink
    for i in range(1, 9):
        g.add_edge(0, i)
        g.add_edge(i, 9)
    return g

def test_no_msg_recv():
    """recv() on nodes with no pending messages skips message/reduce.

    The message and reduce functions are sentinels that fail if called; the
    registered node-apply function still runs, so every value grows by 1.
    """
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_not_called)
    g.register_reduce_func(reduce_not_called)
    g.register_apply_node_func(lambda h : h + 1)
    for i in range(10):
        g.recv(i)
    check(g, [2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

def test_double_recv():
    """A second recv() with no new sends must not call the reducer again."""
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_reduce_func(reduce_func)
    g.send(1, 9)
    g.send(2, 9)
    g.recv(9)
    # Expected 15 = 10 (node 9's own value) + 2 + 3 (nodes 1 and 2).
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15])
    # Nothing was sent since the last recv, so the reducer must not run.
    g.register_reduce_func(reduce_not_called)
    g.recv(9)

def test_pull_0deg():
    """pull() on a node with no in-edges runs only the update function.

    Node 0's only edge is 0 -> 1, so pulling node 0 must not invoke the
    message or reduce callbacks; the update doubles ``h`` from 2 to 4.
    """
    g = DGLGraph()
    g.add_node(0, h=2)
    g.add_node(1, h=1)
    g.add_edge(0, 1)

    # Sentinels raise explicitly: a bare ``assert False`` vanishes under -O.
    def _message(src, edge):
        raise AssertionError("_message must not run for a 0-in-degree node")

    def _reduce(node, msgs):
        raise AssertionError("_reduce must not run for a 0-in-degree node")

    def _update(node):
        return {'h': node['h'] * 2}

    g.pull(0, _message, _reduce, _update)
    assert g.nodes[0]['h'] == 4

if __name__ == '__main__':
    # Allow running the tests directly, without a test runner.
    test_no_msg_recv()
    test_double_recv()
    test_pull_0deg()