"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "dd07b19e27b737d844f62a8107228591f8d7bca8"
test_batch-heterograph.py 18 KB
Newer Older
Mufei Li's avatar
Mufei Li committed
1
import unittest
2

3
4
5
6
import backend as F

import dgl
import pytest
7
from dgl.base import ALL
8
from utils import check_graph_equal, get_cases, parametrize_idtype
9

Jinjing Zhou's avatar
Jinjing Zhou committed
10

11
12
13
def check_equivalence_between_heterographs(
    g1, g2, node_attrs=None, edge_attrs=None
):
    """Assert that two heterographs have the same schema, sizes and edges.

    Parameters
    ----------
    g1, g2 : DGLGraph
        The graphs to compare.
    node_attrs : dict, optional
        Maps a node type to the feature names that must match between the
        two graphs. Node types with no nodes in ``g1`` are skipped.
    edge_attrs : dict, optional
        Maps an edge type to the feature names that must match between the
        two graphs. Edge types with no edges in ``g1`` are skipped.
    """
    # Schema must agree before any per-type comparison makes sense.
    assert g1.ntypes == g2.ntypes
    assert g1.etypes == g2.etypes
    assert g1.canonical_etypes == g2.canonical_etypes

    for ntype in g1.ntypes:
        assert g1.num_nodes(ntype) == g2.num_nodes(ntype)

    for etype in g1.etypes:
        # NOTE(review): an empty entry presumably marks an edge type name
        # shared by several relations, which cannot be queried by name alone.
        if len(g1._etype2canonical[etype]) == 0:
            continue
        assert g1.num_edges(etype) == g2.num_edges(etype)

    for canonical in g1.canonical_etypes:
        assert g1.num_edges(canonical) == g2.num_edges(canonical)
        # Compare source IDs, destination IDs and edge IDs in turn.
        triple1 = g1.edges(etype=canonical, form="all")
        triple2 = g2.edges(etype=canonical, form="all")
        for part1, part2 in zip(triple1, triple2):
            assert F.allclose(part1, part2)

    if node_attrs is not None:
        for ntype, feat_names in node_attrs.items():
            if g1.num_nodes(ntype) == 0:
                continue
            for feat_name in feat_names:
                expected = g1.nodes[ntype].data[feat_name]
                actual = g2.nodes[ntype].data[feat_name]
                assert F.allclose(expected, actual)

    if edge_attrs is not None:
        for etype, feat_names in edge_attrs.items():
            if g1.num_edges(etype) == 0:
                continue
            for feat_name in feat_names:
                expected = g1.edges[etype].data[feat_name]
                actual = g2.edges[etype].data[feat_name]
                assert F.allclose(expected, actual)
Jinjing Zhou's avatar
Jinjing Zhou committed
50

51

52
@pytest.mark.parametrize("gs", get_cases(["two_hetero_batch"]))
@parametrize_idtype
def test_topology(gs, idtype):
    """Test batching two DGLGraphs where some nodes are isolated in some relations"""
    g1, g2 = gs
    g1 = g1.astype(idtype).to(F.ctx())
    g2 = g2.astype(idtype).to(F.ctx())
    bg = dgl.batch([g1, g2])

    assert bg.idtype == idtype
    assert bg.device == F.ctx()
    assert bg.ntypes == g2.ntypes
    assert bg.etypes == g2.etypes
    assert bg.canonical_etypes == g2.canonical_etypes
    assert bg.batch_size == 2

    # Test number of nodes
    for ntype in bg.ntypes:
        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
            g1.num_nodes(ntype),
            g2.num_nodes(ntype),
        ]
        assert bg.num_nodes(ntype) == (
            g1.num_nodes(ntype) + g2.num_nodes(ntype)
        )

    # Test number of edges
    for etype in bg.canonical_etypes:
        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [
            g1.num_edges(etype),
            g2.num_edges(etype),
        ]
        assert bg.num_edges(etype) == (
            g1.num_edges(etype) + g2.num_edges(etype)
        )

    # Test relabeled nodes
    for ntype in bg.ntypes:
        assert list(F.asnumpy(bg.nodes(ntype))) == list(
            range(bg.num_nodes(ntype))
        )

    # Test relabeled edges
    # NOTE(review): the literal IDs below are tied to the graphs produced by
    # the "two_hetero_batch" case in utils.get_cases.
    src, dst = bg.edges(etype=("user", "follows", "user"))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 5, 6]
    src, dst = bg.edges(etype=("user", "follows", "developer"))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
    src, dst, eid = bg.edges(etype="plays", form="all")
    assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]
    assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6]

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    check_equivalence_between_heterographs(g1, g3)
    check_equivalence_between_heterographs(g2, g4)

    # Test dtype cast
    # Compare against the backend dtype object: comparing with the string
    # "int32" does not match a dtype, so the int32 run would be cast to its
    # own dtype and the cast path would go untested.
    if idtype == F.int32:
        bg_cast = bg.long()
    else:
        bg_cast = bg.int()
    assert bg_cast.idtype != idtype
    assert bg.batch_size == bg_cast.batch_size

    # Test local var
    bg_local = bg.local_var()
    assert bg.batch_size == bg_local.batch_size
122

Jinjing Zhou's avatar
Jinjing Zhou committed
123

nv-dlasalle's avatar
nv-dlasalle committed
124
@parametrize_idtype
def test_batching_batched(idtype):
    """Test batching a DGLGraph and a batched DGLGraph."""

    def build(follows, plays):
        # All graphs in this test share the same two relations.
        return dgl.heterograph(
            {
                ("user", "follows", "user"): follows,
                ("user", "plays", "game"): plays,
            },
            idtype=idtype,
            device=F.ctx(),
        )

    g1 = build(([0, 1], [1, 2]), ([0, 1], [0, 0]))
    g2 = build(([0, 1], [1, 2]), ([0, 1], [0, 0]))
    bg1 = dgl.batch([g1, g2])
    g3 = build(([0], [1]), ([1], [0]))

    # Batching an already-batched graph must flatten to three components.
    bg2 = dgl.batch([bg1, g3])
    assert bg2.idtype == idtype
    assert bg2.device == F.ctx()
    assert bg2.ntypes == g3.ntypes
    assert bg2.etypes == g3.etypes
    assert bg2.canonical_etypes == g3.canonical_etypes
    assert bg2.batch_size == 3

    components = (g1, g2, g3)

    # Test number of nodes
    for ntype in bg2.ntypes:
        node_counts = [g.num_nodes(ntype) for g in components]
        assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == node_counts
        assert bg2.num_nodes(ntype) == sum(node_counts)

    # Test number of edges
    for etype in bg2.canonical_etypes:
        edge_counts = [g.num_edges(etype) for g in components]
        assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == edge_counts
        assert bg2.num_edges(etype) == sum(edge_counts)

    # Test relabeled nodes
    for ntype in bg2.ntypes:
        node_ids = list(F.asnumpy(bg2.nodes(ntype)))
        assert node_ids == list(range(bg2.num_nodes(ntype)))

    # Test relabeled edges
    follows_src, follows_dst = bg2.edges(etype="follows")
    assert list(F.asnumpy(follows_src)) == [0, 1, 3, 4, 6]
    assert list(F.asnumpy(follows_dst)) == [1, 2, 4, 5, 7]
    plays_src, plays_dst = bg2.edges(etype="plays")
    assert list(F.asnumpy(plays_src)) == [0, 1, 3, 4, 7]
    assert list(F.asnumpy(plays_dst)) == [0, 0, 1, 1, 2]

    # Test unbatching graphs
    g4, g5, g6 = dgl.unbatch(bg2)
    for original, recovered in zip(components, (g4, g5, g6)):
        check_equivalence_between_heterographs(original, recovered)

Jinjing Zhou's avatar
Jinjing Zhou committed
202

nv-dlasalle's avatar
nv-dlasalle committed
203
@parametrize_idtype
def test_features(idtype):
    """Test the features of batched DGLGraphs"""

    def make_graph():
        # Both inputs carry the same topology and the same feature values.
        g = dgl.heterograph(
            {
                ("user", "follows", "user"): ([0, 1], [1, 2]),
                ("user", "plays", "game"): ([0, 1], [0, 0]),
            },
            idtype=idtype,
            device=F.ctx(),
        )
        g.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]])
        g.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]])
        g.nodes["game"].data["h1"] = F.tensor([[0.0]])
        g.nodes["game"].data["h2"] = F.tensor([[1.0]])
        g.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]])
        g.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]])
        g.edges["plays"].data["h1"] = F.tensor([[0.0], [1.0]])
        return g

    g1 = make_graph()
    g2 = make_graph()

    def stacked_ndata(ntype, name):
        # Expected batched node feature: g1 rows followed by g2 rows.
        return F.cat(
            [g1.nodes[ntype].data[name], g2.nodes[ntype].data[name]], dim=0
        )

    def stacked_edata(etype, name):
        # Expected batched edge feature: g1 rows followed by g2 rows.
        return F.cat(
            [g1.edges[etype].data[name], g2.edges[etype].data[name]], dim=0
        )

    # test default setting
    bg = dgl.batch([g1, g2])
    for ntype, name in (
        ("user", "h1"),
        ("user", "h2"),
        ("game", "h1"),
        ("game", "h2"),
    ):
        assert F.allclose(bg.nodes[ntype].data[name], stacked_ndata(ntype, name))
    for etype, name in (("follows", "h1"), ("follows", "h2"), ("plays", "h1")):
        assert F.allclose(bg.edges[etype].data[name], stacked_edata(etype, name))

    # test specifying ndata/edata
    bg = dgl.batch([g1, g2], ndata=["h2"], edata=["h1"])
    for ntype in ("user", "game"):
        assert F.allclose(bg.nodes[ntype].data["h2"], stacked_ndata(ntype, "h2"))
    for etype in ("follows", "plays"):
        assert F.allclose(bg.edges[etype].data["h1"], stacked_edata(etype, "h1"))
    # Features that were not requested must be left out of the batch.
    assert "h1" not in bg.nodes["user"].data
    assert "h1" not in bg.nodes["game"].data
    assert "h2" not in bg.edges["follows"].data

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    for original, recovered in ((g1, g3), (g2, g4)):
        check_equivalence_between_heterographs(
            original,
            recovered,
            node_attrs={"user": ["h2"], "game": ["h2"]},
            edge_attrs={("user", "follows", "user"): ["h1"]},
        )
330

Jinjing Zhou's avatar
Jinjing Zhou committed
331

332
333
334
335
@unittest.skipIf(
    F.backend_name == "mxnet",
    reason="MXNet does not support split array with zero-length segment.",
)
@parametrize_idtype
def test_empty_relation(idtype):
    """Test batching graphs where a relation (or the whole graph) has no edges."""
    # g1 has no "plays" edges, and therefore no "game" nodes either.
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([], []),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g1.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]])
    g1.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]])
    g1.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]])
    g1.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]])

    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1], [0, 0]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g2.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]])
    g2.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]])
    g2.nodes["game"].data["h1"] = F.tensor([[0.0]])
    g2.nodes["game"].data["h2"] = F.tensor([[1.0]])
    g2.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]])
    g2.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]])
    g2.edges["plays"].data["h1"] = F.tensor([[0.0], [1.0]])

    bg = dgl.batch([g1, g2])

    # Test number of nodes
    for ntype in bg.ntypes:
        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
            g1.num_nodes(ntype),
            g2.num_nodes(ntype),
        ]

    # Test number of edges
    for etype in bg.canonical_etypes:
        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [
            g1.num_edges(etype),
            g2.num_edges(etype),
        ]

    # Test features
    assert F.allclose(
        bg.nodes["user"].data["h1"],
        F.cat(
            [g1.nodes["user"].data["h1"], g2.nodes["user"].data["h1"]], dim=0
        ),
    )
    assert F.allclose(
        bg.nodes["user"].data["h2"],
        F.cat(
            [g1.nodes["user"].data["h2"], g2.nodes["user"].data["h2"]], dim=0
        ),
    )
    # Only g2 contributes "game" nodes and "plays" edges, so the batched
    # features for those types equal g2's features.
    assert F.allclose(bg.nodes["game"].data["h1"], g2.nodes["game"].data["h1"])
    assert F.allclose(bg.nodes["game"].data["h2"], g2.nodes["game"].data["h2"])
    assert F.allclose(
        bg.edges["follows"].data["h1"],
        F.cat(
            [g1.edges["follows"].data["h1"], g2.edges["follows"].data["h1"]],
            dim=0,
        ),
    )
    assert F.allclose(
        bg.edges["plays"].data["h1"], g2.edges["plays"].data["h1"]
    )

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    check_equivalence_between_heterographs(
        g1,
        g3,
        node_attrs={"user": ["h1", "h2"], "game": ["h1", "h2"]},
        edge_attrs={("user", "follows", "user"): ["h1"]},
    )
    check_equivalence_between_heterographs(
        g2,
        g4,
        node_attrs={"user": ["h1", "h2"], "game": ["h1", "h2"]},
        edge_attrs={("user", "follows", "user"): ["h1"]},
    )

    # Test graphs without edges
    # Build them with the parametrized idtype/device so each parametrization
    # exercises a distinct configuration, and check the result instead of
    # only checking that batching does not raise.
    g1 = dgl.heterograph(
        {("u", "r", "v"): ([], [])},
        {"u": 0, "v": 4},
        idtype=idtype,
        device=F.ctx(),
    )
    g2 = dgl.heterograph(
        {("u", "r", "v"): ([], [])},
        {"u": 1, "v": 5},
        idtype=idtype,
        device=F.ctx(),
    )
    bg = dgl.batch([g1, g2])
    assert bg.batch_size == 2
    assert bg.num_nodes("u") == 1
    assert bg.num_nodes("v") == 9
    assert bg.num_edges(("u", "r", "v")) == 0
429

Jinjing Zhou's avatar
Jinjing Zhou committed
430

nv-dlasalle's avatar
nv-dlasalle committed
431
@parametrize_idtype
def test_unbatch2(idtype):
    """Test unbatching with node/edge splits that differ from the batch."""
    # batch 3 graphs but unbatch to 2
    g1, g2, g3 = [
        dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())
        for _ in range(3)
    ]
    bg = dgl.batch([g1, g2, g3])
    node_sizes = F.tensor([8, 4])
    edge_sizes = F.tensor([6, 3])
    f1, f2 = dgl.unbatch(bg, node_split=node_sizes, edge_split=edge_sizes)

    # The first part holds two of the original graphs, the second holds one.
    src, dst = f1.edges(order="eid")
    assert F.allclose(src, F.tensor([0, 1, 2, 4, 5, 6]))
    assert F.allclose(dst, F.tensor([1, 2, 3, 5, 6, 7]))
    src, dst = f2.edges(order="eid")
    assert F.allclose(src, F.tensor([0, 1, 2]))
    assert F.allclose(dst, F.tensor([1, 2, 3]))

    # batch 2 but unbatch to 3
    bg = dgl.batch([f1, f2])
    gg1, gg2, gg3 = dgl.unbatch(bg, F.tensor([4, 4, 4]), F.tensor([3, 3, 3]))
    for original, recovered in ((g1, gg1), (g2, gg2), (g3, gg3)):
        check_graph_equal(original, recovered)
454

Jinjing Zhou's avatar
Jinjing Zhou committed
455

nv-dlasalle's avatar
nv-dlasalle committed
456
@parametrize_idtype
def test_slice_batch(idtype):
    """Test that dgl.slice_batch recovers each graph of a batch, with features."""
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([], []),
            ("user", "follows", "game"): ([0, 0], [1, 4]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1], [0, 0]),
            ("user", "follows", "game"): ([0, 1], [1, 4]),
        },
        num_nodes_dict={"user": 4, "game": 6},
        idtype=idtype,
        device=F.ctx(),
    )
    g3 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0], [2]),
            ("user", "plays", "game"): ([1, 2], [3, 4]),
            ("user", "follows", "game"): ([], []),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g_list = [g1, g2, g3]
    bg = dgl.batch(g_list)
    bg.nodes["user"].data["h1"] = F.randn((bg.num_nodes("user"), 2))
    bg.nodes["user"].data["h2"] = F.randn((bg.num_nodes("user"), 5))
    bg.edges[("user", "follows", "user")].data["h1"] = F.randn(
        (bg.num_edges(("user", "follows", "user")), 2)
    )

    # The features above are attached to the batch, so the graphs in g_list
    # carry no data and iterating their features would check nothing. Use
    # the unbatched graphs as the per-graph feature reference instead.
    ref_list = dgl.unbatch(bg)
    assert len(ref_list) == len(g_list)

    for fmat in ["coo", "csr", "csc"]:
        bg = bg.formats(fmat)
        for i, (g_i, g_ref) in enumerate(zip(g_list, ref_list)):
            g_slice = dgl.slice_batch(bg, i)
            assert g_i.ntypes == g_slice.ntypes
            assert g_i.canonical_etypes == g_slice.canonical_etypes
            assert g_i.idtype == g_slice.idtype
            assert g_i.device == g_slice.device

            for nty in g_i.ntypes:
                assert g_i.num_nodes(nty) == g_slice.num_nodes(nty)
                for feat in g_ref.nodes[nty].data:
                    assert F.allclose(
                        g_ref.nodes[nty].data[feat],
                        g_slice.nodes[nty].data[feat],
                    )

            for ety in g_i.canonical_etypes:
                assert g_i.num_edges(ety) == g_slice.num_edges(ety)
                for feat in g_ref.edges[ety].data:
                    assert F.allclose(
                        g_ref.edges[ety].data[feat],
                        g_slice.edges[ety].data[feat],
                    )
515
516


nv-dlasalle's avatar
nv-dlasalle committed
517
@parametrize_idtype
def test_batch_keeps_empty_data(idtype):
    """Test that batching keeps feature keys whose tensors are empty."""
    rel = ("a", "to", "a")

    def empty_graph_with_features():
        # A graph with no edges, carrying empty node and edge features.
        g = dgl.heterograph({rel: ([], [])})
        g = g.astype(idtype).to(F.ctx())
        g.nodes["a"].data["nh"] = F.tensor([])
        g.edges[rel].data["eh"] = F.tensor([])
        return g

    g1 = empty_graph_with_features()
    g2 = empty_graph_with_features()

    g = dgl.batch([g1, g2])
    assert "nh" in g.nodes["a"].data
    assert "eh" in g.edges[rel].data

Jinjing Zhou's avatar
Jinjing Zhou committed
533

534
535
536
@unittest.skipIf(
    F._default_context_str == "gpu", reason="Issue is not related with GPU"
)
def test_batch_netypes():
    """Test for https://github.com/dmlc/dgl/issues/2808."""
    import networkx as nx

    nx_graph = nx.DiGraph()
    nx_graph.add_nodes_from(
        [1, 2, 3, 4],
        bipartite=0,
        some_attr=F.tensor([1, 2, 3, 4], dtype=F.float32),
    )
    nx_graph.add_nodes_from(["a", "b", "c"], bipartite=1)
    nx_graph.add_edges_from(
        [(1, "a"), (1, "b"), (2, "b"), (2, "c"), (3, "c"), (4, "a")]
    )

    # Convert with both node/edge type orderings, with and without
    # source-node attributes.
    converted = [
        dgl.bipartite_from_networkx(nx_graph, "A", "e", "B"),
        dgl.bipartite_from_networkx(nx_graph, "B", "e", "A"),
        dgl.bipartite_from_networkx(
            nx_graph, "A", "e", "B", u_attrs=["some_attr"]
        ),
        dgl.bipartite_from_networkx(
            nx_graph, "B", "e", "A", u_attrs=["some_attr"]
        ),
    ]
    # Batching must succeed for every variant; no result is inspected.
    for g in converted:
        dgl.batch((g, g, g))


562
if __name__ == "__main__":
    # Most tests in this module take pytest-injected parameters (e.g.
    # ``idtype``), so run the file through pytest rather than directly.
    pass