test_batched_heterograph.py 13.7 KB
Newer Older
1
2
import dgl
import backend as F
Mufei Li's avatar
Mufei Li committed
3
import unittest
4
import pytest
5
6

from dgl.base import ALL
7
from utils import parametrize_dtype
8
from test_utils import check_graph_equal, get_cases
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

def check_equivalence_between_heterographs(g1, g2, node_attrs=None, edge_attrs=None):
    """Assert that two heterographs have the same schema, structure and features.

    Parameters
    ----------
    g1, g2
        The two graphs to compare.
    node_attrs : dict, optional
        Maps a node type to the names of the node features to compare.
        Node types for which ``g1`` has no nodes are skipped.
    edge_attrs : dict, optional
        Maps an edge type to the names of the edge features to compare.
        Edge types for which ``g1`` has no edges are skipped.

    Raises
    ------
    AssertionError
        On the first mismatch found.
    """
    # The schema must agree before anything else is compared.
    assert g1.ntypes == g2.ntypes
    assert g1.etypes == g2.etypes
    assert g1.canonical_etypes == g2.canonical_etypes

    for ntype in g1.ntypes:
        assert g1.number_of_nodes(ntype) == g2.number_of_nodes(ntype)

    for etype in g1.etypes:
        # Only query by short edge type name when its canonical mapping is
        # non-empty.
        if len(g1._etype2canonical[etype]) == 0:
            continue
        assert g1.number_of_edges(etype) == g2.number_of_edges(etype)

    for canonical in g1.canonical_etypes:
        assert g1.number_of_edges(canonical) == g2.number_of_edges(canonical)
        src_a, dst_a, eid_a = g1.edges(etype=canonical, form='all')
        src_b, dst_b, eid_b = g2.edges(etype=canonical, form='all')
        for left, right in ((src_a, src_b), (dst_a, dst_b), (eid_a, eid_b)):
            assert F.allclose(left, right)

    requested_node_feats = {} if node_attrs is None else node_attrs
    for ntype, feat_names in requested_node_feats.items():
        if g1.number_of_nodes(ntype) == 0:
            continue
        for name in feat_names:
            assert F.allclose(g1.nodes[ntype].data[name], g2.nodes[ntype].data[name])

    requested_edge_feats = {} if edge_attrs is None else edge_attrs
    for etype, feat_names in requested_edge_feats.items():
        if g1.number_of_edges(etype) == 0:
            continue
        for name in feat_names:
            assert F.allclose(g1.edges[etype].data[name], g2.edges[etype].data[name])

44
@pytest.mark.parametrize('gs', get_cases(['two_hetero_batch']))
@parametrize_dtype
def test_topology(gs, idtype):
    """Test batching two DGLHeteroGraphs where some nodes are isolated in some relations.

    Checks the batched graph's id type, device, schema, per-graph node and
    edge counts, relabeled node/edge ids, the round trip through
    ``dgl.unbatch``, and that ``batch_size`` survives an id-type cast and
    ``local_var``.

    Parameters
    ----------
    gs
        A pair of graphs supplied by the ``two_hetero_batch`` case.
    idtype
        The id type the graphs are cast to before batching.
    """
    g1, g2 = gs
    g1 = g1.astype(idtype).to(F.ctx())
    g2 = g2.astype(idtype).to(F.ctx())
    bg = dgl.batch([g1, g2])

    assert bg.idtype == idtype
    assert bg.device == F.ctx()
    assert bg.ntypes == g2.ntypes
    assert bg.etypes == g2.etypes
    assert bg.canonical_etypes == g2.canonical_etypes
    assert bg.batch_size == 2

    # Test number of nodes
    for ntype in bg.ntypes:
        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
            g1.number_of_nodes(ntype), g2.number_of_nodes(ntype)]
        assert bg.number_of_nodes(ntype) == (
            g1.number_of_nodes(ntype) + g2.number_of_nodes(ntype))

    # Test number of edges
    for etype in bg.canonical_etypes:
        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [
            g1.number_of_edges(etype), g2.number_of_edges(etype)]
        assert bg.number_of_edges(etype) == (
            g1.number_of_edges(etype) + g2.number_of_edges(etype))

    # Test relabeled nodes
    for ntype in bg.ntypes:
        assert list(F.asnumpy(bg.nodes(ntype))) == list(range(bg.number_of_nodes(ntype)))

    # Test relabeled edges
    # The expected ids below are tied to the graphs in the 'two_hetero_batch' case.
    src, dst = bg.edges(etype=('user', 'follows', 'user'))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 5, 6]
    src, dst = bg.edges(etype=('user', 'follows', 'developer'))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
    src, dst, eid = bg.edges(etype='plays', form='all')
    assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]
    assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6]

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    check_equivalence_between_heterographs(g1, g3)
    check_equivalence_between_heterographs(g2, g4)

    # Test dtype cast
    # NOTE(review): idtype is compared against the string "int32". If
    # parametrize_dtype supplies backend dtype objects rather than strings,
    # the first branch never runs and only bg.int() is exercised -- confirm.
    if idtype == "int32":
        bg_cast = bg.long()
    else:
        bg_cast = bg.int()
    assert bg.batch_size == bg_cast.batch_size

    # Test local var
    bg_local = bg.local_var()
    assert bg.batch_size == bg_local.batch_size
106
107

@parametrize_dtype
def test_batching_batched(idtype):
    """Test batching a DGLHeteroGraph and a BatchedDGLHeteroGraph."""

    def _build(follows, plays):
        # Every graph in this test shares the same two relations.
        return dgl.heterograph({
            ('user', 'follows', 'user'): follows,
            ('user', 'plays', 'game'): plays,
        }, idtype=idtype, device=F.ctx())

    g1 = _build(([0, 1], [1, 2]), ([0, 1], [0, 0]))
    g2 = _build(([0, 1], [1, 2]), ([0, 1], [0, 0]))
    bg1 = dgl.batch([g1, g2])
    g3 = _build(([0], [1]), ([1], [0]))
    # Batch the already-batched graph with one more plain graph.
    bg2 = dgl.batch([bg1, g3])

    assert bg2.idtype == idtype
    assert bg2.device == F.ctx()
    assert bg2.ntypes == g3.ntypes
    assert bg2.etypes == g3.etypes
    assert bg2.canonical_etypes == g3.canonical_etypes
    assert bg2.batch_size == 3

    members = (g1, g2, g3)

    # Per-graph and total node counts
    for ntype in bg2.ntypes:
        node_counts = [g.number_of_nodes(ntype) for g in members]
        assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == node_counts
        assert bg2.number_of_nodes(ntype) == sum(node_counts)

    # Per-graph and total edge counts
    for etype in bg2.canonical_etypes:
        edge_counts = [g.number_of_edges(etype) for g in members]
        assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == edge_counts
        assert bg2.number_of_edges(etype) == sum(edge_counts)

    # Node ids must be relabeled to a contiguous range
    for ntype in bg2.ntypes:
        total = bg2.number_of_nodes(ntype)
        assert list(F.asnumpy(bg2.nodes(ntype))) == list(range(total))

    # Edge endpoints must be shifted by the node offsets of the preceding graphs
    expected_endpoints = (
        ('follows', [0, 1, 3, 4, 6], [1, 2, 4, 5, 7]),
        ('plays', [0, 1, 3, 4, 7], [0, 0, 1, 1, 2]),
    )
    for etype, want_src, want_dst in expected_endpoints:
        src, dst = bg2.edges(etype=etype)
        assert list(F.asnumpy(src)) == want_src
        assert list(F.asnumpy(dst)) == want_dst

    # Unbatching must give back the three original graphs
    g4, g5, g6 = dgl.unbatch(bg2)
    for original, restored in ((g1, g4), (g2, g5), (g3, g6)):
        check_equivalence_between_heterographs(original, restored)

163
@parametrize_dtype
def test_features(idtype):
    """Test the features of batched DGLHeteroGraphs"""

    def _make_graph():
        # Both inputs share the same topology and the same feature values.
        g = dgl.heterograph({
            ('user', 'follows', 'user'): ([0, 1], [1, 2]),
            ('user', 'plays', 'game'): ([0, 1], [0, 0])
        }, idtype=idtype, device=F.ctx())
        g.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
        g.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
        g.nodes['game'].data['h1'] = F.tensor([[0.]])
        g.nodes['game'].data['h2'] = F.tensor([[1.]])
        g.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
        g.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
        g.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])
        return g

    g1 = _make_graph()
    g2 = _make_graph()

    def _check_node_feat(batched, ntype, name):
        # A batched node feature is the row-wise concatenation of the inputs' features.
        expected = F.cat([g1.nodes[ntype].data[name], g2.nodes[ntype].data[name]], dim=0)
        assert F.allclose(batched.nodes[ntype].data[name], expected)

    def _check_edge_feat(batched, etype, name):
        # Same rule as for nodes, applied to an edge type.
        expected = F.cat([g1.edges[etype].data[name], g2.edges[etype].data[name]], dim=0)
        assert F.allclose(batched.edges[etype].data[name], expected)

    # test default setting
    bg = dgl.batch([g1, g2])
    for ntype, name in (('user', 'h1'), ('user', 'h2'), ('game', 'h1'), ('game', 'h2')):
        _check_node_feat(bg, ntype, name)
    for etype, name in (('follows', 'h1'), ('follows', 'h2'), ('plays', 'h1')):
        _check_edge_feat(bg, etype, name)

    # test specifying ndata/edata
    bg = dgl.batch([g1, g2], ndata=['h2'], edata=['h1'])
    for ntype in ('user', 'game'):
        _check_node_feat(bg, ntype, 'h2')
    for etype in ('follows', 'plays'):
        _check_edge_feat(bg, etype, 'h1')
    # Features that were not requested must be absent from the batched graph.
    assert 'h1' not in bg.nodes['user'].data
    assert 'h1' not in bg.nodes['game'].data
    assert 'h2' not in bg.edges['follows'].data

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    for original, restored in ((g1, g3), (g2, g4)):
        check_equivalence_between_heterographs(
            original, restored,
            node_attrs={'user': ['h2'], 'game': ['h2']},
            edge_attrs={('user', 'follows', 'user'): ['h1']})

    # test legacy
    bg = dgl.batch([g1, g2], edge_attrs=['h1'])
    assert 'h2' not in bg.edges['follows'].data.keys()

@unittest.skipIf(F.backend_name == 'mxnet', reason="MXNet does not support split array with zero-length segment.")
@parametrize_dtype
def test_empty_relation(idtype):
    """Test batching and unbatching when a relation has no edges.

    The first graph has an empty ``plays`` relation and the second does not.
    The test checks per-graph node and edge counts, the batched features and
    the round trip through ``dgl.unbatch``. It then batches two graphs that
    have no edges at all.

    Parameters
    ----------
    idtype
        The id type used to construct every graph in the test.
    """
    g1 = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([], [])
    }, idtype=idtype, device=F.ctx())
    g1.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g1.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g1.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g1.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])

    g2 = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 1], [0, 0])
    }, idtype=idtype, device=F.ctx())
    g2.nodes['user'].data['h1'] = F.tensor([[0.], [1.], [2.]])
    g2.nodes['user'].data['h2'] = F.tensor([[3.], [4.], [5.]])
    g2.nodes['game'].data['h1'] = F.tensor([[0.]])
    g2.nodes['game'].data['h2'] = F.tensor([[1.]])
    g2.edges['follows'].data['h1'] = F.tensor([[0.], [1.]])
    g2.edges['follows'].data['h2'] = F.tensor([[2.], [3.]])
    g2.edges['plays'].data['h1'] = F.tensor([[0.], [1.]])

    bg = dgl.batch([g1, g2])

    # Test number of nodes
    for ntype in bg.ntypes:
        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
            g1.number_of_nodes(ntype), g2.number_of_nodes(ntype)]

    # Test number of edges
    for etype in bg.canonical_etypes:
        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [
            g1.number_of_edges(etype), g2.number_of_edges(etype)]

    # Test features
    assert F.allclose(bg.nodes['user'].data['h1'],
                      F.cat([g1.nodes['user'].data['h1'], g2.nodes['user'].data['h1']], dim=0))
    assert F.allclose(bg.nodes['user'].data['h2'],
                      F.cat([g1.nodes['user'].data['h2'], g2.nodes['user'].data['h2']], dim=0))
    # g1 sets no 'game' or 'plays' features, so the batched features for
    # those types are compared against g2's alone.
    assert F.allclose(bg.nodes['game'].data['h1'], g2.nodes['game'].data['h1'])
    assert F.allclose(bg.nodes['game'].data['h2'], g2.nodes['game'].data['h2'])
    assert F.allclose(bg.edges['follows'].data['h1'],
                      F.cat([g1.edges['follows'].data['h1'], g2.edges['follows'].data['h1']], dim=0))
    assert F.allclose(bg.edges['plays'].data['h1'], g2.edges['plays'].data['h1'])

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    check_equivalence_between_heterographs(
        g1, g3,
        node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
        edge_attrs={('user', 'follows', 'user'): ['h1']})
    check_equivalence_between_heterographs(
        g2, g4,
        node_attrs={'user': ['h1', 'h2'], 'game': ['h1', 'h2']},
        edge_attrs={('user', 'follows', 'user'): ['h1']})

    # Test graphs without edges. They are built with the parametrized id type
    # and test device like every other graph in this test, and the per-graph
    # node counts are checked so the batch result is not silently ignored.
    empty1 = dgl.heterograph({('u', 'r', 'v'): ([], [])}, {'u': 0, 'v': 4},
                             idtype=idtype, device=F.ctx())
    empty2 = dgl.heterograph({('u', 'r', 'v'): ([], [])}, {'u': 1, 'v': 5},
                             idtype=idtype, device=F.ctx())
    empty_bg = dgl.batch([empty1, empty2])
    for ntype in empty_bg.ntypes:
        assert F.asnumpy(empty_bg.batch_num_nodes(ntype)).tolist() == [
            empty1.number_of_nodes(ntype), empty2.number_of_nodes(ntype)]
299

Mufei Li's avatar
Mufei Li committed
300
@parametrize_dtype
def test_unbatch2(idtype):
    """Test unbatching with node/edge splits that differ from how the graphs were batched."""
    # Three identical graphs on the edges 0->1, 1->2, 2->3.
    originals = [
        dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())
        for _ in range(3)
    ]
    g1, g2, g3 = originals

    # batch 3 graphs but unbatch to 2
    batched = dgl.batch([g1, g2, g3])
    node_sizes = F.tensor([8, 4])
    edge_sizes = F.tensor([6, 3])
    f1, f2 = dgl.unbatch(batched, node_split=node_sizes, edge_split=edge_sizes)

    # The first piece holds the first two graphs, the second piece the last one.
    first_src, first_dst = f1.edges(order='eid')
    assert F.allclose(first_src, F.tensor([0, 1, 2, 4, 5, 6]))
    assert F.allclose(first_dst, F.tensor([1, 2, 3, 5, 6, 7]))
    second_src, second_dst = f2.edges(order='eid')
    assert F.allclose(second_src, F.tensor([0, 1, 2]))
    assert F.allclose(second_dst, F.tensor([1, 2, 3]))

    # batch 2 but unbatch to 3
    rebatched = dgl.batch([f1, f2])
    gg1, gg2, gg3 = dgl.unbatch(rebatched, F.tensor([4, 4, 4]), F.tensor([3, 3, 3]))
    for expected, restored in zip(originals, (gg1, gg2, gg3)):
        check_graph_equal(expected, restored)
323

324
if __name__ == '__main__':
    # Intentionally a no-op. The tests in this file take their arguments from
    # the pytest parametrization decorators applied to them, so run them
    # through pytest rather than calling them by hand from here.
    pass