# test_batched_graph.py: tests for batching and unbatching DGLGraphs.
1
import dgl
2
import backend as F
3
4
5
6
7
8
9
10
11
12
13

def tree1():
    # Raw docstring: in a normal string the "\" at the end of a diagram line
    # is a backslash-newline escape, which would delete the newline and
    # garble the picture.
    r"""Generate a tree
         0
        / \
       1   2
      / \
     3   4
    Edges are from leaves to root.

    Node feature ``'h'`` holds each node's own id (``[0, 1, 2, 3, 4]``).
    Edge feature ``'h'`` is a random ``(4, 10)`` tensor, one row per edge.

    Returns
    -------
    dgl.DGLGraph
        A 5-node, 4-edge graph.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(3, 1)
    g.add_edge(4, 1)
    g.add_edge(1, 0)
    g.add_edge(2, 0)
    g.ndata['h'] = F.tensor([0, 1, 2, 3, 4])
    g.edata['h'] = F.randn((4, 10))
    return g

def tree2():
    # Raw docstring: in a normal string the "\" at the end of a diagram line
    # is a backslash-newline escape, which would delete the newline and
    # garble the picture.
    r"""Generate a tree
         1
        / \
       4   3
      / \
     2   0
    Edges are from leaves to root.

    Node feature ``'h'`` holds each node's own id (``[0, 1, 2, 3, 4]``).
    Edge feature ``'h'`` is a random ``(4, 10)`` tensor, one row per edge.

    Returns
    -------
    dgl.DGLGraph
        A 5-node, 4-edge graph.
    """
    g = dgl.DGLGraph()
    g.add_nodes(5)
    g.add_edge(2, 4)
    g.add_edge(0, 4)
    g.add_edge(4, 1)
    g.add_edge(3, 1)
    g.ndata['h'] = F.tensor([0, 1, 2, 3, 4])
    g.edata['h'] = F.randn((4, 10))
    return g

def test_batch_unbatch():
    """Batch two trees, check the batch metadata, then unbatch them.

    The unbatched graphs must carry the same node and edge features as the
    trees that went in.
    """
    first, second = tree1(), tree2()

    batched = dgl.batch([first, second])
    assert batched.number_of_nodes() == 10
    assert batched.number_of_edges() == 8
    assert batched.batch_size == 2
    assert batched.batch_num_nodes == [5, 5]
    assert batched.batch_num_edges == [4, 4]

    out_first, out_second = dgl.unbatch(batched)
    for before, after in ((first, out_first), (second, out_second)):
        assert F.allclose(before.ndata['h'], after.ndata['h'])
        assert F.allclose(before.edata['h'], after.edata['h'])

def test_batch_unbatch1():
    """Batch an already-batched graph, then batch readonly graphs.

    Part 1: ``dgl.batch([t2, dgl.batch([t1, t2])])`` must report three
    components (t2, t1, t2) whose features match the inputs after unbatching.
    Part 2: after switching both trees to readonly, batching must leave the
    inputs' ``is_readonly`` flags unchanged. Unbatching must also give back
    the same nodes, edges and features.
    """
    t1 = tree1()
    t2 = tree2()

    inner = dgl.batch([t1, t2])
    outer = dgl.batch([t2, inner])
    assert outer.number_of_nodes() == 15
    assert outer.number_of_edges() == 12
    assert outer.batch_size == 3
    assert outer.batch_num_nodes == [5, 5, 5]
    assert outer.batch_num_edges == [4, 4, 4]

    s1, s2, s3 = dgl.unbatch(outer)
    for source, piece in ((t2, s1), (t1, s2), (t2, s3)):
        assert F.allclose(source.ndata['h'], piece.ndata['h'])
        assert F.allclose(source.edata['h'], piece.edata['h'])

    # Test batching readonly graphs: the inputs' flags must survive batch().
    t1.readonly()
    t2.readonly()
    flags_before = (t1.is_readonly, t2.is_readonly)
    bg = dgl.batch([t1, t2])
    assert (t1.is_readonly, t2.is_readonly) == flags_before

    assert bg.number_of_nodes() == 10
    assert bg.number_of_edges() == 8
    assert bg.batch_size == 2
    assert bg.batch_num_nodes == [5, 5]
    assert bg.batch_num_edges == [4, 4]

    rs1, rs2 = dgl.unbatch(bg)
    for restored, source in ((rs1, t1), (rs2, t2)):
        # edges()[0] / edges()[1]: source and destination endpoints.
        for side in (0, 1):
            assert F.allclose(restored.edges()[side], source.edges()[side])
        assert F.allclose(restored.nodes(), source.nodes())
        assert F.allclose(source.ndata['h'], restored.ndata['h'])
        assert F.allclose(source.edata['h'], restored.edata['h'])

def test_batch_unbatch_frame():
    """Check that feature frames are not shared between a batch and its inputs.

    Writing into a batched graph's features must leave the source graphs
    untouched, while unbatching that batched graph must reflect the writes.
    Regression test for https://github.com/dmlc/dgl/issues/1475.
    """
    dim = 10
    trees = [tree1(), tree2()]
    # Replace the default features with random (count, dim) tensors.
    for g in trees:
        g.ndata['h'] = F.randn((g.number_of_nodes(), dim))
        g.edata['h'] = F.randn((g.number_of_edges(), dim))
    t1, t2 = trees
    n1, e1 = t1.number_of_nodes(), t1.number_of_edges()
    n2, e2 = t2.number_of_nodes(), t2.number_of_edges()

    if F.backend_name == 'tensorflow':
        # tf's tensor is immutable, so the in-place slice writes below
        # cannot be expressed.
        return

    b1 = dgl.batch([t1, t2])
    b2 = dgl.batch([t2])
    # Zero the rows that belong to t1 inside b1, and to t2 inside b2.
    b1.ndata['h'][:n1] = F.zeros((n1, dim))
    b1.edata['h'][:e1] = F.zeros((e1, dim))
    b2.ndata['h'][:n2] = F.zeros((n2, dim))
    b2.edata['h'][:e2] = F.zeros((e2, dim))

    # The source graphs must not observe those writes.
    assert not F.allclose(t1.ndata['h'], F.zeros((n1, dim)))
    assert not F.allclose(t1.edata['h'], F.zeros((e1, dim)))
    assert not F.allclose(t2.ndata['h'], F.zeros((n2, dim)))
    assert not F.allclose(t2.edata['h'], F.zeros((e2, dim)))

    # Unbatching the modified batches must expose them.
    g1, g2 = dgl.unbatch(b1)
    (lone_t2,) = dgl.unbatch(b2)
    assert F.allclose(g1.ndata['h'], F.zeros((n1, dim)))
    assert F.allclose(g1.edata['h'], F.zeros((e1, dim)))
    assert F.allclose(g2.ndata['h'], t2.ndata['h'])
    assert F.allclose(g2.edata['h'], t2.edata['h'])
    assert F.allclose(lone_t2.ndata['h'], F.zeros((n2, dim)))
    assert F.allclose(lone_t2.edata['h'], F.zeros((e2, dim)))

def test_batch_unbatch2():
    """Set node/edge features directly on a batched graph and read them back."""
    def star(num_nodes):
        # Node 0 has an edge to every other node.
        g = dgl.DGLGraph()
        g.add_nodes(num_nodes)
        g.add_edges(0, list(range(1, num_nodes)))
        return g

    c = dgl.batch([star(4), star(3)])
    # 4 + 3 nodes and 3 + 2 edges across the batch.
    c.ndata['h'] = F.ones((7, 1))
    c.edata['w'] = F.ones((5, 1))
    assert F.allclose(c.ndata['h'], F.ones((7, 1)))
    assert F.allclose(c.edata['w'], F.ones((5, 1)))

def test_batch_send_then_recv():
    """Call send() and then recv() separately on a batch of tree1 and tree2.

    Messages copy the source's ``'h'``, and the reducer sums the mailbox.
    In the batch, tree2's node ids are shifted by 5.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])

    def copy_src(edges):
        return {'m': edges.src['h']}

    def sum_mailbox(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    bg.register_message_func(copy_src)
    bg.register_reduce_func(sum_mailbox)

    shift = 5
    src = [3, 4, 2 + shift, 0 + shift]
    dst = [1, 1, 4 + shift, 4 + shift]
    bg.send((src, dst))
    # recv is given each destination once: tree1's node 1, tree2's node 4.
    bg.recv([1, 4 + shift])

    t1, t2 = dgl.unbatch(bg)
    # tree1 node 1 receives 3 + 4; tree2 node 4 receives 2 + 0.
    assert F.asnumpy(t1.ndata['h'][1]) == 7
    assert F.asnumpy(t2.ndata['h'][4]) == 2

def test_batch_send_and_recv():
    """Call the fused send_and_recv() on a batch of tree1 and tree2.

    The edges are the same as in the separate send/recv test, so the results
    must match: tree1 node 1 gets 3 + 4, and tree2 node 4 gets 2 + 0.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])
    bg.register_message_func(lambda edges: {'m': edges.src['h']})
    bg.register_reduce_func(lambda nodes: {'h': F.sum(nodes.mailbox['m'], 1)})

    shift = 5  # tree2's node ids start after tree1's 5 nodes
    pairs = [(3, 1), (4, 1), (2 + shift, 4 + shift), (0 + shift, 4 + shift)]
    src, dst = map(list, zip(*pairs))
    bg.send_and_recv((src, dst))

    t1, t2 = dgl.unbatch(bg)
    assert F.asnumpy(t1.ndata['h'][1]) == 7
    assert F.asnumpy(t2.ndata['h'][4]) == 2

def test_batch_propagate():
    """Run two rounds of prop_edges on a batch of tree1 and tree2.

    Messages copy the source's ``'h'`` and are summed at the destination.
    Round 1 pushes the deepest leaves into their parents. Round 2 pushes the
    roots' children into the roots. tree2's node ids are shifted by 5 in
    the batch.
    """
    t1 = tree1()
    t2 = tree2()

    bg = dgl.batch([t1, t2])

    def copy_src(edges):
        return {'m': edges.src['h']}

    def sum_mailbox(nodes):
        return {'h': F.sum(nodes.mailbox['m'], 1)}

    bg.register_message_func(copy_src)
    bg.register_reduce_func(sum_mailbox)

    shift = 5
    order = [
        # round 1: deepest leaves -> their parents
        ([3, 4, 2 + shift, 0 + shift], [1, 1, 4 + shift, 4 + shift]),
        # round 2: children of the roots -> roots
        ([1, 2, 4 + shift, 3 + shift], [0, 0, 1 + shift, 1 + shift]),
    ]
    bg.prop_edges(order)

    t1, t2 = dgl.unbatch(bg)
    # tree1 root: (3 + 4) + 2; tree2 root: (2 + 0) + 3.
    assert F.asnumpy(t1.ndata['h'][0]) == 9
    assert F.asnumpy(t2.ndata['h'][1]) == 5

def test_batched_edge_ordering():
    """Check that edge features of a batch line up with the component graph.

    ``g1`` is the first component, so edge (4, 5) looked up in the batch must
    carry the same feature row as edge (4, 5) looked up in ``g1``.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    g1.edata['h'] = F.randn((5, 10))

    g2 = dgl.DGLGraph()
    g2.add_nodes(6)
    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
    g2.edata['h'] = F.randn((6, 10))

    batched = dgl.batch([g1, g2])
    from_batch = batched.edata['h'][batched.edge_id(4, 5)]
    from_g1 = g1.edata['h'][g1.edge_id(4, 5)]
    assert F.array_equal(from_batch, from_g1)

def test_batch_no_edge():
    """Batching must accept a component graph that has no edges.

    ``g3`` has one node and no edges and sits between two graphs that do
    have edges. The batch must still count it as its own component.
    """
    g1 = dgl.DGLGraph()
    g1.add_nodes(6)
    g1.add_edges([4, 4, 2, 2, 0], [5, 3, 3, 1, 1])
    g2 = dgl.DGLGraph()
    g2.add_nodes(6)
    g2.add_edges([0, 1, 2, 5, 4, 5], [1, 2, 3, 4, 3, 0])
    g3 = dgl.DGLGraph()
    g3.add_nodes(1)  # no edges
    bg = dgl.batch([g1, g3, g2])  # should not throw an error

    # Besides not raising, the metadata must include the empty component.
    assert bg.batch_size == 3
    assert bg.number_of_nodes() == 13
    assert bg.number_of_edges() == 11
    assert bg.batch_num_nodes == [6, 1, 6]
    assert bg.batch_num_edges == [5, 0, 6]

if __name__ == '__main__':
    # Run every test defined in this module when executed as a script.
    test_batch_unbatch()
    test_batch_unbatch1()
    test_batch_unbatch_frame()
    test_batch_unbatch2()
    test_batched_edge_ordering()
    test_batch_send_then_recv()
    test_batch_send_and_recv()
    test_batch_propagate()
    test_batch_no_edge()