# test_graph_batch.py: tests for batching and unbatching DGLGraph objects
# with dgl.batch / dgl.unbatch.
import networkx as nx
import dgl
import torch
import numpy as np

def tree1():
    r"""Generate a tree
         0
        / \
       1   2
      / \
     3   4
    Edges are from leaves to root.

    Node features are the node ids ``[0, 1, 2, 3, 4]`` stored as a float
    tensor; edge features are a random ``(4, 10)`` tensor.

    Returns:
        A ``dgl.DGLGraph`` with 5 nodes and 4 edges.
    """
    # Raw docstring: in a plain string the trailing ``\`` on the drawing lines
    # would act as a line continuation and collapse the tree picture.
    g = dgl.DGLGraph()
    for nid in range(5):
        g.add_node(nid)
    # Keep this insertion order: edge-feature rows presumably map to edges in
    # the order they are added.
    for src, dst in [(3, 1), (4, 1), (1, 0), (2, 0)]:
        g.add_edge(src, dst)
    g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
    g.set_e_repr(torch.randn(4, 10))
    return g

def tree2():
    r"""Generate a tree
         1
        / \
       4   3
      / \
     2   0
    Edges are from leaves to root.

    Node features are the node ids ``[0, 1, 2, 3, 4]`` stored as a float
    tensor; edge features are a random ``(4, 10)`` tensor.

    Returns:
        A ``dgl.DGLGraph`` with 5 nodes and 4 edges.
    """
    # Raw docstring: in a plain string the trailing ``\`` on the drawing lines
    # would act as a line continuation and collapse the tree picture.
    g = dgl.DGLGraph()
    for nid in range(5):
        g.add_node(nid)
    # Keep this insertion order: edge-feature rows presumably map to edges in
    # the order they are added.
    for src, dst in [(2, 4), (0, 4), (4, 1), (3, 1)]:
        g.add_edge(src, dst)
    g.set_n_repr(torch.Tensor([0, 1, 2, 3, 4]))
    g.set_e_repr(torch.randn(4, 10))
    return g

def test_batch_unbatch():
    """Round-tripping two graphs through batch/unbatch keeps their features.

    Snapshots the node and edge features of ``tree1()`` and ``tree2()``,
    batches and unbatches them, and checks the features read back from the
    original graphs are unchanged.
    """
    t1 = tree1()
    t2 = tree2()

    # Clone the snapshots so an in-place mutation of the stored tensors cannot
    # make each assertion compare a tensor with itself.
    n1 = t1.get_n_repr().clone()
    n2 = t2.get_n_repr().clone()
    e1 = t1.get_e_repr().clone()
    e2 = t2.get_e_repr().clone()

    bg = dgl.batch([t1, t2])
    # The return value is unused; the checks below read t1/t2 directly.
    dgl.unbatch(bg)

    assert n1.equal(t1.get_n_repr())
    assert n2.equal(t2.get_n_repr())
    assert e1.equal(t1.get_e_repr())
    assert e2.equal(t2.get_e_repr())


def test_batch_sendrecv():
    """One send/recv round on a batched graph updates both member trees.

    The message is the source node's feature and the reduce sums the
    messages, so after sending along the leaf edges node 1 of ``tree1``
    holds 3 + 4 = 7 and node 4 of ``tree2`` holds 2 + 0 = 2.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda src, edge: src)
    # Sums the incoming messages along dim 1 (presumably the per-node
    # message axis).
    bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))

    # Leaf-to-parent edges of each tree, in that tree's own node ids.
    e1 = [(3, 1), (4, 1)]
    e2 = [(2, 4), (0, 4)]

    # Presumably maps each graph's endpoints to node ids in the batched graph,
    # which is what send/recv below consume.
    u1, v1 = bg.query_new_edge(t1, *zip(*e1))
    u2, v2 = bg.query_new_edge(t2, *zip(*e2))
    u = np.concatenate((u1, u2)).tolist()
    v = np.concatenate((v1, v2)).tolist()

    bg.send(u, v)
    bg.recv(v)

    # The checks read t1/t2 after unbatching; the return value is unused.
    dgl.unbatch(bg)
    assert t1.get_n_repr()[1] == 7
    assert t2.get_n_repr()[4] == 2


def test_batch_propagate():
    """Two-step propagation on a batched graph reaches both roots.

    Step 1 sends along the leaf edges, step 2 along the edges into each root.
    With messages = source feature and reduce = sum:
    ``tree1`` root 0 gets (3 + 4) + 2 = 9, and ``tree2`` root 1 gets
    (2 + 0) + 3 = 5.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda src, edge: src)
    # Sums the incoming messages along dim 1 (presumably the per-node
    # message axis).
    bg.register_reduce_func(lambda node, msgs: torch.sum(msgs, 1))

    # One entry per propagation step: (tree1 edges, tree2 edges), each edge
    # given in that tree's own node ids.
    steps = [
        ([(3, 1), (4, 1)], [(2, 4), (0, 4)]),  # leaves -> their parents
        ([(1, 0), (2, 0)], [(4, 1), (3, 1)]),  # children -> root
    ]

    order = []
    for e1, e2 in steps:
        # Presumably maps each graph's endpoints to batched-graph node ids.
        u1, v1 = bg.query_new_edge(t1, *zip(*e1))
        u2, v2 = bg.query_new_edge(t2, *zip(*e2))
        u = np.concatenate((u1, u2)).tolist()
        v = np.concatenate((v1, v2)).tolist()
        order.append((u, v))

    bg.propagate(iterator=order)
    # The checks read t1/t2 after unbatching; the return value is unused.
    dgl.unbatch(bg)

    assert t1.get_n_repr()[0] == 9
    assert t2.get_n_repr()[1] == 5
def test_batched_edge_ordering():
    """Edge features of the first batched graph stay attached to their edges.

    Looks up edge (4, 5) in the batched graph and in ``g1`` (the first graph
    in the batch, so presumably its node ids are not offset). Checks that
    both lookups return the same edge-feature row.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes_from([0, 1, 2, 3, 4, 5])
    g1.add_edges_from([(4, 5), (4, 3), (2, 3), (2, 1), (0, 1)])
    # NOTE(review): bare attribute access whose result is unused; kept in case
    # reading edge_list has a side effect. Remove if it does not.
    g1.edge_list
    e1 = torch.randn(5, 10)
    g1.set_e_repr(e1)

    g2 = dgl.DGLGraph()
    g2.add_nodes_from([0, 1, 2, 3, 4, 5])
    g2.add_edges_from([(0, 1), (1, 2), (2, 3), (5, 4), (4, 3), (5, 0)])
    e2 = torch.randn(6, 10)
    g2.set_e_repr(e2)

    g = dgl.batch([g1, g2])
    r1 = g.get_e_repr()[g.get_edge_id(4, 5)]
    r2 = g1.get_e_repr()[g1.get_edge_id(4, 5)]
    assert torch.equal(r1, r2)

if __name__ == '__main__':
    # Allow running the tests directly without pytest.
    test_batch_unbatch()
    test_batched_edge_ordering()
    test_batch_sendrecv()
    test_batch_propagate()