# Tests for batching (dgl.batch) and unbatching (dgl.unbatch) of DGLGraph objects.
import dgl
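# backend (imported as F) supplies the tensor helpers used below: tensor, randn, ones, sum, allclose, array_equal, asnumpy.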
import backend as F

def tree1():
    r"""Generate a tree
         0
        / \
       1   2
      / \
     3   4
    Edges are from leaves to root.

    Edges are added in the order (3, 1), (4, 1), (1, 0), (2, 0).
    Node feature 'h' holds each node's own id (0..4); edge feature 'h' is
    filled with random values from ``F.randn((4, 10))``.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(3, 1)
    g.add_edge(4, 1)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.ndata['h'] = F.tensor([0, 1, 2, 3, 4])
    g.edata['h'] = F.randn((4, 10))
    return g

def tree2():
    r"""Generate a tree
         1
        / \
       4   3
      / \
     2   0
    Edges are from leaves to root.

    Edges are added in the order (2, 4), (0, 4), (4, 1), (3, 1).
    Node feature 'h' holds each node's own id (0..4); edge feature 'h' is
    filled with random values from ``F.randn((4, 10))``.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(2, 4)
    g.add_edge(0, 4)
    g.add_edge(4, 1)
    g.add_edge(3, 1)
    g.ndata['h'] = F.tensor([0, 1, 2, 3, 4])
    g.edata['h'] = F.randn((4, 10))
    return g

def test_batch_unbatch():
    """Batch two trees, check the batched sizes, then unbatch them.

    Unbatching must return the graphs in their original order with the
    node and edge features they had before batching.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    assert bg.number_of_nodes() == 10
    assert bg.number_of_edges() == 8
    assert bg.batch_size == 2
    assert bg.batch_num_nodes == [5, 5]
    assert bg.batch_num_edges == [4, 4]

    tt1, tt2 = dgl.unbatch(bg)
    assert F.allclose(t1.ndata['h'], tt1.ndata['h'])
    assert F.allclose(t1.edata['h'], tt1.edata['h'])
    assert F.allclose(t2.ndata['h'], tt2.ndata['h'])
    assert F.allclose(t2.edata['h'], tt2.edata['h'])

def test_batch_unbatch1():
    """Batch an already-batched graph, then batch read-only graphs.

    ``b2`` batches ``t2`` together with the two-graph batch ``b1``; the
    result reports three member graphs, and unbatching yields them in the
    order t2, t1, t2.  The second half checks that batching graphs on which
    ``readonly()`` was called leaves their ``is_readonly`` flag unchanged,
    and that unbatching restores their edges, nodes and features.
    """
    t1 = tree1()
    t2 = tree2()
    b1 = dgl.batch([t1, t2])
    b2 = dgl.batch([t2, b1])
    assert b2.number_of_nodes() == 15
    assert b2.number_of_edges() == 12
    assert b2.batch_size == 3
    assert b2.batch_num_nodes == [5, 5, 5]
    assert b2.batch_num_edges == [4, 4, 4]

    s1, s2, s3 = dgl.unbatch(b2)
    assert F.allclose(t2.ndata['h'], s1.ndata['h'])
    assert F.allclose(t2.edata['h'], s1.edata['h'])
    assert F.allclose(t1.ndata['h'], s2.ndata['h'])
    assert F.allclose(t1.edata['h'], s2.edata['h'])
    assert F.allclose(t2.ndata['h'], s3.ndata['h'])
    assert F.allclose(t2.edata['h'], s3.edata['h'])

    # Test batching readonly graphs
    t1.readonly()
    t2.readonly()
    t1_was_readonly = t1.is_readonly
    t2_was_readonly = t2.is_readonly
    bg = dgl.batch([t1, t2])

    # Batching must not change the inputs' read-only flag.
    assert t1.is_readonly == t1_was_readonly
    assert t2.is_readonly == t2_was_readonly
    assert bg.number_of_nodes() == 10
    assert bg.number_of_edges() == 8
    assert bg.batch_size == 2
    assert bg.batch_num_nodes == [5, 5]
    assert bg.batch_num_edges == [4, 4]

    rs1, rs2 = dgl.unbatch(bg)
    assert F.allclose(rs1.edges()[0], t1.edges()[0])
    assert F.allclose(rs1.edges()[1], t1.edges()[1])
    assert F.allclose(rs2.edges()[0], t2.edges()[0])
    assert F.allclose(rs2.edges()[1], t2.edges()[1])
    assert F.allclose(rs1.nodes(), t1.nodes())
    assert F.allclose(rs2.nodes(), t2.nodes())
    assert F.allclose(t1.ndata['h'], rs1.ndata['h'])
    assert F.allclose(t1.edata['h'], rs1.edata['h'])
    assert F.allclose(t2.ndata['h'], rs2.ndata['h'])
    assert F.allclose(t2.edata['h'], rs2.edata['h'])

def test_batch_unbatch2():
    """Set and read features on a graph produced by batching.

    ``a`` has 4 nodes / 3 edges and ``b`` has 3 nodes / 2 edges, so the
    batched graph ``c`` takes node features with 7 rows and edge features
    with 5 rows.
    """
    # test setting/getting features after batch
    a = dgl.DGLGraph()
    a.add_nodes(4)
    a.add_edges(0, [1, 2, 3])
    b = dgl.DGLGraph()
    b.add_nodes(3)
    b.add_edges(0, [1, 2])
    c = dgl.batch([a, b])
    c.ndata['h'] = F.ones((7, 1))
    c.edata['w'] = F.ones((5, 1))
    assert F.allclose(c.ndata['h'], F.ones((7, 1)))
    assert F.allclose(c.edata['w'], F.ones((5, 1)))

def test_batch_send_then_recv():
    """Send messages on a batched graph, then receive on selected nodes.

    Each tree's node feature 'h' equals its node id, the message copies the
    source node's 'h', and the reducer sums the mailbox along dim 1, so
    after one round t1 node 1 holds 3 + 4 = 7 and t2 node 4 holds
    2 + 0 = 2.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda edges: {'m' : edges.src['h']})
    bg.register_reduce_func(lambda nodes: {'h' : F.sum(nodes.mailbox['m'], 1)})

    # Nodes of t2 are addressed as ``id + 5`` in the batched graph.
    u = [3, 4, 2 + 5, 0 + 5]
    v = [1, 1, 4 + 5, 4 + 5]

    bg.send((u, v))
    bg.recv([1, 9]) # assuming recv takes in unique nodes

    t1, t2 = dgl.unbatch(bg)
    assert F.asnumpy(t1.ndata['h'][1]) == 7
    assert F.asnumpy(t2.ndata['h'][4]) == 2

def test_batch_send_and_recv():
    """Send and receive in one call on a batched graph.

    Same edges and expected values as ``test_batch_send_then_recv``
    computed here: t1 node 1 sums ids 3 + 4 = 7 and t2 node 4 sums
    ids 2 + 0 = 2.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda edges: {'m' : edges.src['h']})
    bg.register_reduce_func(lambda nodes: {'h' : F.sum(nodes.mailbox['m'], 1)})

    # Nodes of t2 are addressed as ``id + 5`` in the batched graph.
    u = [3, 4, 2 + 5, 0 + 5]
    v = [1, 1, 4 + 5, 4 + 5]

    bg.send_and_recv((u, v))

    t1, t2 = dgl.unbatch(bg)
    assert F.asnumpy(t1.ndata['h'][1]) == 7
    assert F.asnumpy(t2.ndata['h'][4]) == 2

def test_batch_propagate():
    """Propagate leaf-to-root over both trees in two hand-written steps.

    Node 'h' starts as the node id; each step replaces the destination's
    'h' with the sum of its sources' 'h'.  Step 1 gives t1 node 1 = 3 + 4 = 7
    and t2 node 4 = 2 + 0 = 2; step 2 gives t1 root 0 = 7 + 2 = 9 and
    t2 root 1 = 2 + 3 = 5.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda edges: {'m' : edges.src['h']})
    bg.register_reduce_func(lambda nodes: {'h' : F.sum(nodes.mailbox['m'], 1)})

    # The edge frontiers are written out by hand (not computed from the
    # graph); nodes of t2 are addressed as ``id + 5`` in the batched graph.
    order = []

    # step 1: edges leaving the deepest leaves
    u = [3, 4, 2 + 5, 0 + 5]
    v = [1, 1, 4 + 5, 4 + 5]
    order.append((u, v))

    # step 2: edges into the roots
    u = [1, 2, 4 + 5, 3 + 5]
    v = [0, 0, 1 + 5, 1 + 5]
    order.append((u, v))

    bg.prop_edges(order)
    t1, t2 = dgl.unbatch(bg)

    assert F.asnumpy(t1.ndata['h'][0]) == 9
    assert F.asnumpy(t2.ndata['h'][1]) == 5

def test_batched_edge_ordering():
    """Edge features must stay attached to the same (src, dst) pair after batching.

    Checks one edge from each member graph; edges of ``g2`` are addressed
    in the batched graph with its node ids shifted by 6 (``g1`` has 6 nodes),
    matching the ``id + offset`` convention used by the other tests here.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    e1 = F.randn((5, 10))
    g1.edata['h'] = e1
    g2 = dgl.DGLGraph()
    g2.add_nodes(6)
    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
    e2 = F.randn((6, 10))
    g2.edata['h'] = e2
    g = dgl.batch([g1, g2])

    # Edge from the first member graph (no node offset).
    r1 = g.edata['h'][g.edge_id(4, 5)]
    r2 = g1.edata['h'][g1.edge_id(4, 5)]
    assert F.array_equal(r1, r2)

    # Edge from the second member graph (node ids offset by 6).
    r3 = g.edata['h'][g.edge_id(5 + 6, 4 + 6)]
    r4 = g2.edata['h'][g2.edge_id(5, 4)]
    assert F.array_equal(r3, r4)

def test_batch_no_edge():
    """Batching must accept a member graph that has nodes but no edges.

    ``g1`` has 6 nodes / 5 edges, ``g3`` 1 node / 0 edges, and ``g2``
    6 nodes / 6 edges; the batched sizes must reflect all three in order.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    g2 = dgl.DGLGraph()
    g2.add_nodes(6)
    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
    g3 = dgl.DGLGraph()
    g3.add_nodes(1)  # no edges
    g = dgl.batch([g1, g3, g2]) # should not throw an error
    assert g.batch_size == 3
    assert g.number_of_nodes() == 13
    assert g.number_of_edges() == 11
    assert g.batch_num_nodes == [6, 1, 6]
    assert g.batch_num_edges == [5, 0, 6]

if __name__ == '__main__':
    test_batch_unbatch()
    test_batch_unbatch1()
    # NOTE(review): the calls below are disabled when this file is run as a
    # script (a test runner collecting test_* functions is unaffected);
    # confirm whether they should be re-enabled.
    #test_batch_unbatch2()
    #test_batched_edge_ordering()
    #test_batch_send_then_recv()
    #test_batch_send_and_recv()
    #test_batch_propagate()
    #test_batch_no_edge()