import os
import unittest
from collections.abc import Iterator, Mapping
from functools import partial

import backend as F

import dgl
import dgl.ops as OPS
import numpy as np
import pytest
import torch
import torch.distributed as dist
from pytests_utils import parametrize_idtype


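# GraphDataLoader over a graph-classification dataset (MiniGC): iteration
# yields (batched graph, label) pairs whose label batch dimension matches
# batch_size, or a scalar label when batch_size is None.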
@pytest.mark.parametrize("batch_size", [None, 16])
def test_graph_dataloader(batch_size):
    num_batches = 2
    num_samples = num_batches * (batch_size if batch_size is not None else 1)
    minigc_dataset = dgl.data.MiniGCDataset(num_samples, 10, 20)
    data_loader = dgl.dataloading.GraphDataLoader(
        minigc_dataset, batch_size=batch_size, shuffle=True
    )
    assert isinstance(iter(data_loader), Iterator)
    for graph, label in data_loader:
        assert isinstance(graph, dgl.DGLGraph)
        if batch_size is not None:
            assert F.asnumpy(label).shape[0] == batch_size
        else:
            # If batch size is None, the label element will be a single
            # scalar following PyTorch's practice.
            assert F.asnumpy(label).ndim == 0


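# ClusterGCNSampler smoke test: partitions CoraFull into 100 clusters and
# checks that the DataLoader yields the expected number of batches.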
@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("num_workers", [0, 4])
def test_cluster_gcn(num_workers):
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
    sampler = dgl.dataloading.ClusterGCNSampler(g, 100)
    dataloader = dgl.dataloading.DataLoader(
        g, torch.arange(100), sampler, batch_size=4, num_workers=num_workers
    )
    assert len(dataloader) == 25
    for i, sg in enumerate(dataloader):
        pass


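# ShaDowKHopSampler: every mini-batch returns a subgraph whose node set equals
# input_nodes (seed nodes first) and whose features/labels are sliced from the
# parent graph.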
@pytest.mark.parametrize("num_workers", [0, 4])
def test_shadow(num_workers):
    g = dgl.data.CoraFullDataset()[0]
    sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
    dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.num_nodes()),
        sampler,
        batch_size=5,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
    )
    for i, (input_nodes, output_nodes, subgraph) in enumerate(dataloader):
        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
        assert torch.equal(input_nodes[: output_nodes.shape[0]], output_nodes)
        assert torch.equal(
            subgraph.ndata["label"], g.ndata["label"][input_nodes]
        )
        assert torch.equal(subgraph.ndata["feat"], g.ndata["feat"][input_nodes])
        if i == 5:
            break


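# SAINTSampler smoke test covering node, edge, and random-walk budgets.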
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("mode", ["node", "edge", "walk"])
def test_saint(num_workers, mode):
    g = dgl.data.CoraFullDataset()[0]

    if mode == "node":
        budget = 100
    elif mode == "edge":
        budget = 200
    elif mode == "walk":
        budget = (3, 2)

    sampler = dgl.dataloading.SAINTSampler(mode, budget)
    dataloader = dgl.dataloading.DataLoader(
        g, torch.arange(100), sampler, num_workers=num_workers
    )
    assert len(dataloader) == 100
    for sg in dataloader:
        pass


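# Non-uniform neighbor sampling: with probabilities/masks that zero out half
# of the edges, only the unmasked neighbors may ever be sampled, on both
# homogeneous and heterogeneous graphs and across CPU/UVA/GPU modes.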
@parametrize_idtype
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("use_ddp", [False, True])
@pytest.mark.parametrize("use_mask", [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp, use_mask):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if mode != "cpu" and use_mask:
        pytest.skip("Masked sampling only works on CPU.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(
        idtype
    )
    g.edata["p"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edata["mask"] = g.edata["p"] != 0
    if mode in ("cpu", "uva_cpu_indices"):
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())
    else:
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cuda())
    if mode == "pure_gpu":
        g = g.to(F.cuda())
    use_uva = mode.startswith("uva")

    if use_mask:
        prob, mask = None, "mask"
    else:
        prob, mask = "p", None

    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [2], prob=prob, mask=mask
    )
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            indices,
            sampler,
            batch_size=1,
            device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes.item()
            neighbors = set(input_nodes[1:].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

    g = dgl.heterograph(
        {
            ("B", "BA", "A"): (
                [1, 2, 3, 4, 5, 6, 7, 8],
                [0, 0, 0, 0, 1, 1, 1, 1],
            ),
            ("C", "CA", "A"): (
                [1, 2, 3, 4, 5, 6, 7, 8],
                [0, 0, 0, 0, 1, 1, 1, 1],
            ),
        }
    ).astype(idtype)
    g.edges["BA"].data["p"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edges["BA"].data["mask"] = g.edges["BA"].data["p"] != 0
    g.edges["CA"].data["p"] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
    g.edges["CA"].data["mask"] = g.edges["CA"].data["p"] != 0
    if mode == "pure_gpu":
        g = g.to(F.cuda())
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            {"A": indices},
            sampler,
            batch_size=1,
            device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes["A"].item()
            # Seed and neighbors are of different node types, so slicing is
            # not necessary here.
            neighbors = set(input_nodes["B"].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

            neighbors = set(input_nodes["C"].cpu().numpy())
            if seed == 1:
                assert neighbors == {7, 8}
            elif seed == 0:
                assert neighbors == {3, 4}

    if use_ddp:
        dist.destroy_process_group()


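# Helpers: assert that every tensor (or dict/list of tensors and blocks)
# returned by a dataloader has the expected dtype/idtype and lives on F.ctx().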
def _check_dtype(data, dtype, attr_name):
    if isinstance(data, dict):
        for k, v in data.items():
            assert getattr(v, attr_name) == dtype
    elif isinstance(data, list):
        for v in data:
            assert getattr(v, attr_name) == dtype
    else:
        assert getattr(data, attr_name) == dtype


def _check_device(data):
    if isinstance(data, dict):
        for k, v in data.items():
            assert v.device == F.ctx()
    elif isinstance(data, list):
        for v in data:
            assert v.device == F.ctx()
    else:
        assert data.device == F.ctx()


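# Node-wise DataLoader over a homogeneous and a heterogeneous graph: sampled
# IDs and blocks must land on the requested device and keep the index dtype.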
@parametrize_idtype
@pytest.mark.parametrize("sampler_name", ["full", "neighbor", "neighbor2"])
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("use_ddp", [False, True])
def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
    g1.ndata["label"] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
    if mode in ("cpu", "uva_cpu_indices"):
        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cpu())
    else:
        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cuda())
    if mode == "pure_gpu":
        g1 = g1.to(F.cuda())

    use_uva = mode.startswith("uva")

    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g1,
            indices,
            sampler,
            device=F.ctx(),
            batch_size=g1.num_nodes(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, "dtype")
            _check_dtype(output_nodes, idtype, "dtype")
            _check_dtype(blocks, idtype, "idtype")

    g2 = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [0, 0, 0, 1, 1, 1, 2],
                [1, 2, 3, 0, 2, 3, 0],
            ),
            ("user", "followed-by", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
            ("game", "played-by", "user"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),
        }
    ).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data["feat"] = F.copy_to(
            F.randn((g2.num_nodes(ntype), 8)), F.cpu()
        )
    if mode in ("cpu", "uva_cpu_indices"):
        indices = {nty: F.copy_to(g2.nodes(nty), F.cpu()) for nty in g2.ntypes}
    else:
        indices = {nty: F.copy_to(g2.nodes(nty), F.cuda()) for nty in g2.ntypes}
    if mode == "pure_gpu":
        g2 = g2.to(F.cuda())

    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2
        ),
        "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g2,
            indices,
            sampler,
            device=F.ctx(),
            batch_size=batch_size,
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        assert isinstance(iter(dataloader), Iterator)
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, "dtype")
            _check_dtype(output_nodes, idtype, "dtype")
            _check_dtype(blocks, idtype, "idtype")

    if use_ddp:
        dist.destroy_process_group()


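# Edge-wise DataLoader built with as_edge_prediction_sampler, with and
# without a negative sampler, on homogeneous and heterogeneous graphs; checks
# that the returned structures land on the requested device.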
@parametrize_idtype
@pytest.mark.parametrize("sampler_name", ["full", "neighbor"])
@pytest.mark.parametrize(
    "neg_sampler",
    [
        dgl.dataloading.negative_sampler.Uniform(2),
        dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
        dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3),
    ],
)
@pytest.mark.parametrize("mode", ["cpu", "uva", "pure_gpu"])
@pytest.mark.parametrize("use_ddp", [False, True])
def test_edge_dataloader(idtype, sampler_name, neg_sampler, mode, use_ddp):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if mode == "uva" and isinstance(
        neg_sampler, dgl.dataloading.negative_sampler.GlobalUniform
    ):
        pytest.skip("GlobalUniform doesn't support UVA yet.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
    if mode == "pure_gpu":
        g1 = g1.to(F.cuda())

    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]

    # no negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        g1,
        g1.edges(form="eid"),
        edge_sampler,
        device=F.ctx(),
        batch_size=g1.num_edges(),
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=neg_sampler
    )
    dataloader = dgl.dataloading.DataLoader(
        g1,
        g1.edges(form="eid"),
        edge_sampler,
        device=F.ctx(),
        batch_size=g1.num_edges(),
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    g2 = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [0, 0, 0, 1, 1, 1, 2],
                [1, 2, 3, 0, 2, 3, 0],
            ),
            ("user", "followed-by", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
            ("game", "played-by", "user"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),
        }
    ).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data["feat"] = F.copy_to(
            F.randn((g2.num_nodes(ntype), 8)), F.cpu()
        )
    if mode == "pure_gpu":
        g2 = g2.to(F.cuda())

    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2
        ),
    }[sampler_name]

    # no negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        g2,
        {ety: g2.edges(form="eid", etype=ety) for ety in g2.canonical_etypes},
        edge_sampler,
        device=F.ctx(),
        batch_size=batch_size,
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=neg_sampler
    )
    dataloader = dgl.dataloading.DataLoader(
        g2,
        {ety: g2.edges(form="eid", etype=ety) for ety in g2.canonical_etypes},
        edge_sampler,
        device=F.ctx(),
        batch_size=batch_size,
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )

    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    if use_ddp:
        dist.destroy_process_group()


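# Fixture for the exclusion tests: a random homogeneous graph whose first
# 1000 edges have their reverses stored at eid + 1000, plus the reverse-eid
# mapping, a set of always-excluded edges, and the seed edges.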
def _create_homogeneous():
    s = torch.randint(0, 200, (1000,), device=F.ctx())
    d = torch.randint(0, 200, (1000,), device=F.ctx())
    src = torch.cat([s, d])
    dst = torch.cat([d, s])
    # Each edge i in [0, 1000) has its reverse edge at eid i + 1000.
    g = dgl.graph((src, dst), num_nodes=200)
    reverse_eids = torch.cat(
        [torch.arange(1000, 2000), torch.arange(0, 1000)]
    ).to(F.ctx())
    always_exclude = torch.randint(0, 1000, (50,), device=F.ctx())
    seed_edges = torch.arange(0, 1000, device=F.ctx())
    return g, reverse_eids, always_exclude, seed_edges


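# Fixture for the exclusion tests: a random heterograph with "AA"/"AB" edges
# and their "rev-*" counterparts, plus the reverse-etype mapping,
# always-excluded edges, and seed edges.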
def _create_heterogeneous():
    edges = {}
    for utype, etype, vtype in [("A", "AA", "A"), ("A", "AB", "B")]:
        s = torch.randint(0, 200, (1000,), device=F.ctx())
        d = torch.randint(0, 200, (1000,), device=F.ctx())
        edges[utype, etype, vtype] = (s, d)
        edges[vtype, "rev-" + etype, utype] = (d, s)
    g = dgl.heterograph(edges, num_nodes_dict={"A": 200, "B": 200})
    reverse_etypes = {
        "AA": "rev-AA",
        "AB": "rev-AB",
        "rev-AA": "AA",
        "rev-AB": "AB",
    }
    always_exclude = {
        "AA": torch.randint(0, 1000, (50,), device=F.ctx()),
        "AB": torch.randint(0, 1000, (50,), device=F.ctx()),
    }
    seed_edges = {
        "AA": torch.arange(0, 1000, device=F.ctx()),
        "AB": torch.arange(0, 1000, device=F.ctx()),
    }
    return g, reverse_etypes, always_exclude, seed_edges


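# Computes the set of edge IDs that must not appear in the sampled blocks for
# a given exclude mode; also used (via functools.partial) as a custom
# "exclude" callable when an always-excluded set is supplied.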
def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
    if exclude is None:
        return always_exclude
    elif exclude == "self":
        return (
            torch.cat([pair_eids, always_exclude])
            if always_exclude is not None
            else pair_eids
        )
    elif exclude == "reverse_id":
        pair_eids = torch.cat([pair_eids, pair_eids + 1000])
        return (
            torch.cat([pair_eids, always_exclude])
            if always_exclude is not None
            else pair_eids
        )
    elif exclude == "reverse_types":
        pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()}
        if ("A", "AA", "A") in pair_eids:
            pair_eids[("A", "rev-AA", "A")] = pair_eids[("A", "AA", "A")]
        if ("A", "AB", "B") in pair_eids:
            pair_eids[("B", "rev-AB", "A")] = pair_eids[("A", "AB", "B")]
        if always_exclude is not None:
            always_exclude = {
                g.to_canonical_etype(k): v for k, v in always_exclude.items()
            }
            for k in always_exclude.keys():
                if k in pair_eids:
                    pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]])
                else:
                    pair_eids[k] = always_exclude[k]
        return pair_eids


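# For every exclude mode (None, "self", "reverse_id", "reverse_types"),
# optionally combined with a user-provided exclusion set, none of the excluded
# edge IDs may show up in the sampled blocks.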
@pytest.mark.parametrize("always_exclude_flag", [False, True])
@pytest.mark.parametrize(
    "exclude", [None, "self", "reverse_id", "reverse_types"]
)
@pytest.mark.parametrize(
    "sampler",
    [
        dgl.dataloading.MultiLayerFullNeighborSampler(1),
        dgl.dataloading.ShaDowKHopSampler([5]),
    ],
)
@pytest.mark.parametrize("batch_size", [1, 50])
def test_edge_dataloader_excludes(
    exclude, always_exclude_flag, batch_size, sampler
):
    if exclude == "reverse_types":
        g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
    else:
        g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
    g = g.to(F.ctx())
    if not always_exclude_flag:
        always_exclude = None

    kwargs = {}
    kwargs["exclude"] = (
        partial(_find_edges_to_exclude, g, exclude, always_exclude)
        if always_exclude_flag
        else exclude
    )
    kwargs["reverse_eids"] = reverse_eids if exclude == "reverse_id" else None
    kwargs["reverse_etypes"] = (
        reverse_etypes if exclude == "reverse_types" else None
    )
    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, **kwargs)

    dataloader = dgl.dataloading.DataLoader(
        g,
        seed_edges,
        sampler,
        batch_size=batch_size,
        device=F.ctx(),
        use_prefetch_thread=False,
    )
    for i, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
        if isinstance(blocks, list):
            subg = blocks[0]
        else:
            subg = blocks
        pair_eids = pair_graph.edata[dgl.EID]
        block_eids = subg.edata[dgl.EID]

        edges_to_exclude = _find_edges_to_exclude(
            g, exclude, always_exclude, pair_eids
        )
        if edges_to_exclude is None:
            continue
        edges_to_exclude = dgl.utils.recursive_apply(
            edges_to_exclude, lambda x: x.cpu().numpy()
        )
        block_eids = dgl.utils.recursive_apply(
            block_eids, lambda x: x.cpu().numpy()
        )

        if isinstance(edges_to_exclude, Mapping):
            for k in edges_to_exclude.keys():
                assert not np.isin(edges_to_exclude[k], block_eids[k]).any()
        else:
            assert not np.isin(edges_to_exclude, block_eids).any()

        if i == 10:
            break


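# "reverse_types" exclusion should still work when only a subset of the edge
# types has a reverse type registered (here only "AB" -> "BA").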
def test_edge_dataloader_exclusion_without_all_reverses():
    data_dict = {
        ("A", "AB", "B"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ("B", "BA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ("B", "BC", "C"): (torch.tensor([0]), torch.tensor([0])),
        ("C", "CA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
    }
    g = dgl.heterograph(data_dict=data_dict)
    block_sampler = dgl.dataloading.MultiLayerNeighborSampler(
        fanouts=[1], replace=True
    )
    block_sampler = dgl.dataloading.as_edge_prediction_sampler(
        block_sampler,
        exclude="reverse_types",
        reverse_etypes={"AB": "BA"},
    )
    d = dgl.dataloading.DataLoader(
        graph=g,
        indices={
            "AB": torch.tensor([0]),
            "BC": torch.tensor([0]),
        },
        graph_sampler=block_sampler,
        batch_size=2,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        device=F.ctx(),
        use_ddp=False,
    )

    next(iter(d))


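# A user-supplied worker_init_fn should be accepted without breaking
# multi-worker loading.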
def dummy_worker_init_fn(worker_id):
    pass


def test_dataloader_worker_init_fn():
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
    sampler = dgl.dataloading.MultiLayerNeighborSampler([2])
    dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(100),
        sampler,
        batch_size=4,
        num_workers=4,
        worker_init_fn=dummy_worker_init_fn,
    )
    for _ in dataloader:
        pass


if __name__ == "__main__":
    # test_node_dataloader(F.int32, 'neighbor', None)
    test_edge_dataloader_excludes(
        "reverse_types", False, 1, dgl.dataloading.ShaDowKHopSampler([5])
    )
    test_edge_dataloader_exclusion_without_all_reverses()