"""Tests for batching, unbatching and slicing heterogeneous DGLGraphs (test_batch-heterograph.py)."""
import unittest
import backend as F

import dgl
import pytest
from dgl.base import ALL
from test_utils import check_graph_equal, get_cases, parametrize_idtype


def check_equivalence_between_heterographs(
    g1, g2, node_attrs=None, edge_attrs=None
):
    """Assert that two heterographs have the same structure and features.

    Parameters
    ----------
    g1, g2 : DGLGraph
        Graphs to compare; the node/edge types of ``g1`` drive the checks.
    node_attrs : dict, optional
        Maps a node type to the names of node features to compare. Node types
        with no nodes in ``g1`` are skipped.
    edge_attrs : dict, optional
        Maps an edge type (name or canonical tuple) to the names of edge
        features to compare. Edge types with no edges in ``g1`` are skipped.
    """
    assert g1.ntypes == g2.ntypes
    assert g1.etypes == g2.etypes
    assert g1.canonical_etypes == g2.canonical_etypes

    for nty in g1.ntypes:
        assert g1.number_of_nodes(nty) == g2.number_of_nodes(nty)

    for ety in g1.etypes:
        # Only query by plain edge-type name when it maps to a canonical edge
        # type; presumably an empty entry marks a name shared by several
        # relations, which cannot be used on its own.
        if len(g1._etype2canonical[ety]) > 0:
            assert g1.number_of_edges(ety) == g2.number_of_edges(ety)

    for ety in g1.canonical_etypes:
        assert g1.number_of_edges(ety) == g2.number_of_edges(ety)
        src1, dst1, eid1 = g1.edges(etype=ety, form="all")
        src2, dst2, eid2 = g2.edges(etype=ety, form="all")
        assert F.allclose(src1, src2)
        assert F.allclose(dst1, dst2)
        assert F.allclose(eid1, eid2)

    if node_attrs is not None:
        for nty, feat_names in node_attrs.items():
            if g1.number_of_nodes(nty) == 0:
                continue
            for feat_name in feat_names:
                assert F.allclose(
                    g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name]
                )

    if edge_attrs is not None:
        for ety, feat_names in edge_attrs.items():
            if g1.number_of_edges(ety) == 0:
                continue
            for feat_name in feat_names:
                assert F.allclose(
                    g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]
                )


@pytest.mark.parametrize("gs", get_cases(["two_hetero_batch"]))
@parametrize_idtype
def test_topology(gs, idtype):
    """Test batching two DGLGraphs where some nodes are isolated in some relations"""
    g1, g2 = gs
    g1 = g1.astype(idtype).to(F.ctx())
    g2 = g2.astype(idtype).to(F.ctx())
    bg = dgl.batch([g1, g2])

    assert bg.idtype == idtype
    assert bg.device == F.ctx()
    assert bg.ntypes == g2.ntypes
    assert bg.etypes == g2.etypes
    assert bg.canonical_etypes == g2.canonical_etypes
    assert bg.batch_size == 2

    # Test number of nodes
    for ntype in bg.ntypes:
        expected = [g1.number_of_nodes(ntype), g2.number_of_nodes(ntype)]
        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == expected
        assert bg.number_of_nodes(ntype) == sum(expected)

    # Test number of edges
    for etype in bg.canonical_etypes:
        expected = [g1.number_of_edges(etype), g2.number_of_edges(etype)]
        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == expected
        assert bg.number_of_edges(etype) == sum(expected)

    # Test relabeled nodes
    for ntype in bg.ntypes:
        assert list(F.asnumpy(bg.nodes(ntype))) == list(
            range(bg.number_of_nodes(ntype))
        )

    # Test relabeled edges
    src, dst = bg.edges(etype=("user", "follows", "user"))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 5, 6]
    src, dst = bg.edges(etype=("user", "follows", "developer"))
    assert list(F.asnumpy(src)) == [0, 1, 4, 5]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5]
    src, dst, eid = bg.edges(etype="plays", form="all")
    assert list(F.asnumpy(src)) == [0, 1, 2, 3, 4, 5, 6]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2, 2, 3]
    assert list(F.asnumpy(eid)) == [0, 1, 2, 3, 4, 5, 6]

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    check_equivalence_between_heterographs(g1, g3)
    check_equivalence_between_heterographs(g2, g4)

    # Test dtype cast: cast to the *other* index type. ``idtype`` is a backend
    # dtype object, so it must be compared with ``F.int32`` (comparing with the
    # string "int32" never matches on some backends, e.g. PyTorch).
    if idtype == F.int32:
        bg_cast = bg.long()
        expected_idtype = F.int64
    else:
        bg_cast = bg.int()
        expected_idtype = F.int32
    assert bg_cast.idtype == expected_idtype
    assert bg.batch_size == bg_cast.batch_size

    # Test local var
    bg_local = bg.local_var()
    assert bg.batch_size == bg_local.batch_size


@parametrize_idtype
def test_batching_batched(idtype):
    """Test batching a DGLGraph and a batched DGLGraph."""
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1], [0, 0]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1], [0, 0]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    bg1 = dgl.batch([g1, g2])
    g3 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0], [1]),
            ("user", "plays", "game"): ([1], [0]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    # Batching an already-batched graph must flatten it into its components.
    bg2 = dgl.batch([bg1, g3])
    assert bg2.idtype == idtype
    assert bg2.device == F.ctx()
    assert bg2.ntypes == g3.ntypes
    assert bg2.etypes == g3.etypes
    assert bg2.canonical_etypes == g3.canonical_etypes
    assert bg2.batch_size == 3

    # Test number of nodes
    for ntype in bg2.ntypes:
        expected = [g.number_of_nodes(ntype) for g in (g1, g2, g3)]
        assert F.asnumpy(bg2.batch_num_nodes(ntype)).tolist() == expected
        assert bg2.number_of_nodes(ntype) == sum(expected)

    # Test number of edges
    for etype in bg2.canonical_etypes:
        expected = [g.number_of_edges(etype) for g in (g1, g2, g3)]
        assert F.asnumpy(bg2.batch_num_edges(etype)).tolist() == expected
        assert bg2.number_of_edges(etype) == sum(expected)

    # Test relabeled nodes
    for ntype in bg2.ntypes:
        assert list(F.asnumpy(bg2.nodes(ntype))) == list(
            range(bg2.number_of_nodes(ntype))
        )

    # Test relabeled edges
    src, dst = bg2.edges(etype="follows")
    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 6]
    assert list(F.asnumpy(dst)) == [1, 2, 4, 5, 7]
    src, dst = bg2.edges(etype="plays")
    assert list(F.asnumpy(src)) == [0, 1, 3, 4, 7]
    assert list(F.asnumpy(dst)) == [0, 0, 1, 1, 2]

    # Test unbatching graphs
    g4, g5, g6 = dgl.unbatch(bg2)
    check_equivalence_between_heterographs(g1, g4)
    check_equivalence_between_heterographs(g2, g5)
    check_equivalence_between_heterographs(g3, g6)

@parametrize_idtype
def test_features(idtype):
    """Test the features of batched DGLGraphs"""

    def _make_graph(offset):
        # Build the two-relation graph used by this test. ``offset`` shifts all
        # feature values so that the two batched graphs carry different data
        # and the concatenation order is observable.
        g = dgl.heterograph(
            {
                ("user", "follows", "user"): ([0, 1], [1, 2]),
                ("user", "plays", "game"): ([0, 1], [0, 0]),
            },
            idtype=idtype,
            device=F.ctx(),
        )
        g.nodes["user"].data["h1"] = F.tensor(
            [[offset + 0.0], [offset + 1.0], [offset + 2.0]]
        )
        g.nodes["user"].data["h2"] = F.tensor(
            [[offset + 3.0], [offset + 4.0], [offset + 5.0]]
        )
        g.nodes["game"].data["h1"] = F.tensor([[offset + 0.0]])
        g.nodes["game"].data["h2"] = F.tensor([[offset + 1.0]])
        g.edges["follows"].data["h1"] = F.tensor([[offset + 0.0], [offset + 1.0]])
        g.edges["follows"].data["h2"] = F.tensor([[offset + 2.0], [offset + 3.0]])
        g.edges["plays"].data["h1"] = F.tensor([[offset + 0.0], [offset + 1.0]])
        return g

    g1 = _make_graph(0.0)
    g2 = _make_graph(10.0)

    def _check_batched(bg, node_feats, edge_feats):
        # Batched features must be the row-wise concatenation of g1's then g2's.
        for ntype, feat in node_feats:
            assert F.allclose(
                bg.nodes[ntype].data[feat],
                F.cat(
                    [g1.nodes[ntype].data[feat], g2.nodes[ntype].data[feat]],
                    dim=0,
                ),
            )
        for etype, feat in edge_feats:
            assert F.allclose(
                bg.edges[etype].data[feat],
                F.cat(
                    [g1.edges[etype].data[feat], g2.edges[etype].data[feat]],
                    dim=0,
                ),
            )

    # test default setting
    bg = dgl.batch([g1, g2])
    _check_batched(
        bg,
        [("user", "h1"), ("user", "h2"), ("game", "h1"), ("game", "h2")],
        [("follows", "h1"), ("follows", "h2"), ("plays", "h1")],
    )

    # test specifying ndata/edata
    bg = dgl.batch([g1, g2], ndata=["h2"], edata=["h1"])
    _check_batched(
        bg,
        [("user", "h2"), ("game", "h2")],
        [("follows", "h1"), ("plays", "h1")],
    )
    assert "h1" not in bg.nodes["user"].data
    assert "h1" not in bg.nodes["game"].data
    assert "h2" not in bg.edges["follows"].data

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    for orig, unbatched in ((g1, g3), (g2, g4)):
        check_equivalence_between_heterographs(
            orig,
            unbatched,
            node_attrs={"user": ["h2"], "game": ["h2"]},
            edge_attrs={
                ("user", "follows", "user"): ["h1"],
                ("user", "plays", "game"): ["h1"],
            },
        )


@unittest.skipIf(
    F.backend_name == "mxnet",
    reason="MXNet does not support split array with zero-length segment.",
)
@parametrize_idtype
def test_empty_relation(idtype):
    """Test batching/unbatching DGLGraphs where some relations have no edges."""
    # g1's "plays" relation is empty, so it has no "game" nodes at all.
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([], []),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g1.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]])
    g1.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]])
    g1.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]])
    g1.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]])

    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1], [0, 0]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g2.nodes["user"].data["h1"] = F.tensor([[0.0], [1.0], [2.0]])
    g2.nodes["user"].data["h2"] = F.tensor([[3.0], [4.0], [5.0]])
    g2.nodes["game"].data["h1"] = F.tensor([[0.0]])
    g2.nodes["game"].data["h2"] = F.tensor([[1.0]])
    g2.edges["follows"].data["h1"] = F.tensor([[0.0], [1.0]])
    g2.edges["follows"].data["h2"] = F.tensor([[2.0], [3.0]])
    g2.edges["plays"].data["h1"] = F.tensor([[0.0], [1.0]])

    bg = dgl.batch([g1, g2])

    # Test number of nodes
    for ntype in bg.ntypes:
        assert F.asnumpy(bg.batch_num_nodes(ntype)).tolist() == [
            g1.number_of_nodes(ntype),
            g2.number_of_nodes(ntype),
        ]

    # Test number of edges
    for etype in bg.canonical_etypes:
        assert F.asnumpy(bg.batch_num_edges(etype)).tolist() == [
            g1.number_of_edges(etype),
            g2.number_of_edges(etype),
        ]

    # Test features
    for feat in ("h1", "h2"):
        assert F.allclose(
            bg.nodes["user"].data[feat],
            F.cat(
                [g1.nodes["user"].data[feat], g2.nodes["user"].data[feat]],
                dim=0,
            ),
        )
        # g1 contributes no "game" nodes, so only g2's rows remain.
        assert F.allclose(
            bg.nodes["game"].data[feat], g2.nodes["game"].data[feat]
        )
    assert F.allclose(
        bg.edges["follows"].data["h1"],
        F.cat(
            [g1.edges["follows"].data["h1"], g2.edges["follows"].data["h1"]],
            dim=0,
        ),
    )
    assert F.allclose(bg.edges["plays"].data["h1"], g2.edges["plays"].data["h1"])

    # Test unbatching graphs
    g3, g4 = dgl.unbatch(bg)
    for orig, unbatched in ((g1, g3), (g2, g4)):
        check_equivalence_between_heterographs(
            orig,
            unbatched,
            node_attrs={"user": ["h1", "h2"], "game": ["h1", "h2"]},
            edge_attrs={("user", "follows", "user"): ["h1"]},
        )

    # Test graphs without edges
    e1 = dgl.heterograph(
        {("u", "r", "v"): ([], [])},
        {"u": 0, "v": 4},
        idtype=idtype,
        device=F.ctx(),
    )
    e2 = dgl.heterograph(
        {("u", "r", "v"): ([], [])},
        {"u": 1, "v": 5},
        idtype=idtype,
        device=F.ctx(),
    )
    ebg = dgl.batch([e1, e2])
    assert F.asnumpy(ebg.batch_num_nodes("u")).tolist() == [0, 1]
    assert F.asnumpy(ebg.batch_num_nodes("v")).tolist() == [4, 5]
    assert F.asnumpy(ebg.batch_num_edges("r")).tolist() == [0, 0]


@parametrize_idtype
def test_unbatch2(idtype):
    """Test unbatching with explicit node/edge splits that differ from the batch."""
    # batch 3 graphs but unbatch to 2
    g1 = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())
    g2 = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())
    g3 = dgl.graph(([0, 1, 2], [1, 2, 3]), idtype=idtype, device=F.ctx())
    bg = dgl.batch([g1, g2, g3])
    bnn = F.tensor([8, 4])
    bne = F.tensor([6, 3])
    f1, f2 = dgl.unbatch(bg, node_split=bnn, edge_split=bne)
    # f1 holds the first two 4-node chains, relabeled into one 8-node graph.
    u, v = f1.edges(order="eid")
    assert F.allclose(u, F.tensor([0, 1, 2, 4, 5, 6]))
    assert F.allclose(v, F.tensor([1, 2, 3, 5, 6, 7]))
    u, v = f2.edges(order="eid")
    assert F.allclose(u, F.tensor([0, 1, 2]))
    assert F.allclose(v, F.tensor([1, 2, 3]))

    # batch 2 but unbatch to 3
    bg = dgl.batch([f1, f2])
    gg1, gg2, gg3 = dgl.unbatch(bg, F.tensor([4, 4, 4]), F.tensor([3, 3, 3]))
    check_graph_equal(g1, gg1)
    check_graph_equal(g2, gg2)
    check_graph_equal(g3, gg3)


@parametrize_idtype
def test_slice_batch(idtype):
    """Test that ``dgl.slice_batch`` recovers each graph and its feature rows."""
    g1 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([], []),
            ("user", "follows", "game"): ([0, 0], [1, 4]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g2 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 1], [1, 2]),
            ("user", "plays", "game"): ([0, 1], [0, 0]),
            ("user", "follows", "game"): ([0, 1], [1, 4]),
        },
        num_nodes_dict={"user": 4, "game": 6},
        idtype=idtype,
        device=F.ctx(),
    )
    g3 = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0], [2]),
            ("user", "plays", "game"): ([1, 2], [3, 4]),
            ("user", "follows", "game"): ([], []),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    g_list = [g1, g2, g3]
    bg = dgl.batch(g_list)
    # Features live only on the batched graph; the component graphs have none.
    bg.nodes["user"].data["h1"] = F.randn((bg.num_nodes("user"), 2))
    bg.nodes["user"].data["h2"] = F.randn((bg.num_nodes("user"), 5))
    bg.edges[("user", "follows", "user")].data["h1"] = F.randn(
        (bg.num_edges(("user", "follows", "user")), 2)
    )
    for fmat in ["coo", "csr", "csc"]:
        # Restrict the original batched graph to one format per iteration
        # rather than chaining restrictions across iterations.
        bg_fmt = bg.formats(fmat)
        # Offset of the current component's first node/edge, per type.
        node_start = {nty: 0 for nty in bg.ntypes}
        edge_start = {ety: 0 for ety in bg.canonical_etypes}
        for i, g_i in enumerate(g_list):
            g_slice = dgl.slice_batch(bg_fmt, i)
            assert g_i.ntypes == g_slice.ntypes
            assert g_i.canonical_etypes == g_slice.canonical_etypes
            assert g_i.idtype == g_slice.idtype
            assert g_i.device == g_slice.device

            for nty in g_i.ntypes:
                num = g_i.num_nodes(nty)
                assert num == g_slice.num_nodes(nty)
                lo = node_start[nty]
                # The slice must carry exactly this component's feature rows.
                for feat, value in bg.nodes[nty].data.items():
                    assert F.allclose(
                        g_slice.nodes[nty].data[feat], value[lo : lo + num]
                    )
                node_start[nty] += num

            for ety in g_i.canonical_etypes:
                num = g_i.num_edges(ety)
                assert num == g_slice.num_edges(ety)
                lo = edge_start[ety]
                for feat, value in bg.edges[ety].data.items():
                    assert F.allclose(
                        g_slice.edges[ety].data[feat], value[lo : lo + num]
                    )
                edge_start[ety] += num


@parametrize_idtype
def test_batch_keeps_empty_data(idtype):
    """Test that batching keeps features whose tensors are empty."""

    def _make_graph():
        # A single-relation graph with no nodes or edges, carrying empty
        # node ("nh") and edge ("eh") features.
        g = (
            dgl.heterograph({("a", "to", "a"): ([], [])})
            .astype(idtype)
            .to(F.ctx())
        )
        g.nodes["a"].data["nh"] = F.tensor([])
        g.edges[("a", "to", "a")].data["eh"] = F.tensor([])
        return g

    g = dgl.batch([_make_graph(), _make_graph()])
    assert "nh" in g.nodes["a"].data
    assert "eh" in g.edges[("a", "to", "a")].data

@unittest.skipIf(
    F._default_context_str == "gpu", reason="Issue is not related with GPU"
)
def test_batch_netypes():
    """Test batching bipartite graphs built from NetworkX (issue #2808)."""
    # Test for https://github.com/dmlc/dgl/issues/2808
    import networkx as nx

    B = nx.DiGraph()
    B.add_nodes_from(
        [1, 2, 3, 4],
        bipartite=0,
        some_attr=F.tensor([1, 2, 3, 4], dtype=F.float32),
    )
    B.add_nodes_from(["a", "b", "c"], bipartite=1)
    B.add_edges_from(
        [(1, "a"), (1, "b"), (2, "b"), (2, "c"), (3, "c"), (4, "a")]
    )

    graphs = [
        dgl.bipartite_from_networkx(B, "A", "e", "B"),
        dgl.bipartite_from_networkx(B, "B", "e", "A"),
        dgl.bipartite_from_networkx(B, "A", "e", "B", u_attrs=["some_attr"]),
        dgl.bipartite_from_networkx(B, "B", "e", "A", u_attrs=["some_attr"]),
    ]
    for g in graphs:
        bg = dgl.batch((g, g, g))
        assert bg.batch_size == 3


if __name__ == "__main__":
    # The tests rely on pytest parametrization (idtype, graph cases), so they
    # cannot be called directly; run this module through pytest instead.
    raise SystemExit(pytest.main([__file__]))