# test_basics2.py
from dgl import DGLGraph
from dgl.graph import __REPR__

def message_func(hu, e_uv):
    """Message function: forward the source node's value unchanged.

    The edge representation ``e_uv`` is ignored.
    """
    outgoing = hu
    return outgoing

def message_not_called(hu, e_uv):
    """Message function that must never run; fails the test if invoked.

    An explicit ``raise`` is used instead of ``assert False`` so the guard
    still fires when Python runs with ``-O`` (which strips assert statements).
    """
    raise AssertionError("message function should not have been called")

def reduce_not_called(h, msgs):
    """Reduce function that must never run; fails the test if invoked.

    An explicit ``raise`` is used instead of ``assert False`` so the guard
    still fires when Python runs with ``-O`` (which strips assert statements).
    """
    raise AssertionError("reduce function should not have been called")

def update_no_msg(h, accum):
    """Update function for nodes that received no messages.

    Asserts that ``accum`` is ``None`` and returns ``h + 1``.
    """
    assert accum is None
    bumped = h + 1
    return bumped

def update_func(h, accum):
    """Update function that adds the reduced message to the node value.

    Asserts that ``accum`` is not ``None`` and returns ``h + accum``.
    """
    assert accum is not None
    combined = h + accum
    return combined

def check(g, h):
    """Assert that the ``__REPR__`` values of nodes 0-9 of ``g`` equal ``h``.

    Both sides are compared via their ``str()`` form. On mismatch, the
    assertion message lists the actual (``nh``) and expected (``h``) values.
    """
    actual = [str(g.nodes[idx][__REPR__]) for idx in range(10)]
    expected = list(map(str, h))
    assert actual == expected, "nh=[%s], h=[%s]" % (
        ' '.join(actual),
        ' '.join(expected),
    )

def generate_graph():
    """Build the 10-node test graph.

    Node ``n`` is created with ``__REPR__ = n + 1``. For every ``n`` in
    1..8 there is an edge ``0 -> n`` and an edge ``n -> 9``, so node 0 is
    the only source and node 9 the only sink.
    """
    g = DGLGraph()
    for nid in range(10):
        g.add_node(nid, __REPR__=nid + 1)
    # Fan out from source 0 through nodes 1..8 into sink 9.
    for mid in range(1, 9):
        g.add_edge(0, mid)
        g.add_edge(mid, 9)
    return g

def test_no_msg_update():
    """recv() on nodes with no pending messages only runs the update step.

    The message/reduce functions registered here fail if called, and
    ``update_no_msg`` requires ``accum is None``. Every node value should
    therefore just be incremented by one.
    """
    g = generate_graph()
    check(g, list(range(1, 11)))
    g.register_message_func(message_not_called)
    g.register_reduce_func(reduce_not_called)
    g.register_update_func(update_no_msg)
    for nid in range(10):
        g.recv(nid)
    check(g, list(range(2, 12)))

def test_double_recv():
    """A second recv() on a node whose messages were already consumed must fail.

    Nodes 1 and 2 (values 2 and 3) send to node 9 (value 10). The first
    recv reduces with 'sum' and applies ``update_func``, giving 10 + 5 = 15.
    The second recv(9) has no pending messages and is expected to raise.
    This is presumably ``update_func``'s ``accum is not None`` assertion
    (TODO confirm against DGLGraph.recv).
    """
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_func)
    g.register_reduce_func('sum')
    g.register_update_func(update_func)
    g.sendto(1, 9)
    g.sendto(2, 9)
    g.recv(9)
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 15])
    try:
        # The second recv should see a None message and fail.
        g.recv(9)
    except Exception:
        # Expected failure. Catch Exception (not a bare except) so that
        # KeyboardInterrupt/SystemExit still propagate.
        pass
    else:
        # Explicit raise instead of `assert False` so it survives `python -O`.
        raise AssertionError("second recv(9) should have raised")

def test_recv_no_pred():
    """recv() on the source node (no in-edges) must only run the update step.

    The registered message/reduce functions fail if called. Node 0 should go
    from 1 to 2 via ``update_no_msg``, as for every node in
    ``test_no_msg_update``, and all other nodes should stay unchanged.
    """
    g = generate_graph()
    check(g, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    g.register_message_func(message_not_called)
    g.register_reduce_func(reduce_not_called)
    g.register_update_func(update_no_msg)
    g.recv(0)
    # Verify the outcome; previously the test ended without checking it.
    check(g, [2, 2, 3, 4, 5, 6, 7, 8, 9, 10])

if __name__ == '__main__':
    # Run the tests in file order when executed as a script.
    for run_test in (test_no_msg_update, test_double_recv, test_recv_no_pred):
        run_test()