import os
import unittest
from collections.abc import Iterator, Mapping
from functools import partial

import backend as F

import dgl
import dgl.ops as OPS
import numpy as np
import pytest
import torch
import torch.distributed as dist
from test_utils import parametrize_idtype


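# GraphDataLoader over a small MiniGC dataset: with batch_size=16 the labels
# come back batched; with batch_size=None each label is an unbatched scalar,
# following PyTorch's convention.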
@pytest.mark.parametrize("batch_size", [None, 16])
def test_graph_dataloader(batch_size):
    num_batches = 2
    num_samples = num_batches * (batch_size if batch_size is not None else 1)
    minigc_dataset = dgl.data.MiniGCDataset(num_samples, 10, 20)
    data_loader = dgl.dataloading.GraphDataLoader(
        minigc_dataset, batch_size=batch_size, shuffle=True
    )
    assert isinstance(iter(data_loader), Iterator)
    for graph, label in data_loader:
        assert isinstance(graph, dgl.DGLGraph)
        if batch_size is not None:
            assert F.asnumpy(label).shape[0] == batch_size
        else:
            # If batch size is None, the label element will be a single scalar following
            # PyTorch's practice.
            assert F.asnumpy(label).ndim == 0


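# Cluster-GCN sampling on CoraFull: partition into 100 clusters and load 4
# clusters per batch, giving 25 batches.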
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("num_workers", [0, 4])
def test_cluster_gcn(num_workers):
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
    sampler = dgl.dataloading.ClusterGCNSampler(g, 100)
    dataloader = dgl.dataloading.DataLoader(
        g, torch.arange(100), sampler, batch_size=4, num_workers=num_workers
    )
    assert len(dataloader) == 25
    for i, sg in enumerate(dataloader):
        pass


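# ShaDow-KHop sampling: each yielded subgraph is induced on the sampled input
# nodes, with the seed (output) nodes first and node data sliced from the full
# graph.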
@pytest.mark.parametrize("num_workers", [0, 4])
def test_shadow(num_workers):
    g = dgl.data.CoraFullDataset()[0]
    sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
    dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.num_nodes()),
        sampler,
        batch_size=5,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
    )
    for i, (input_nodes, output_nodes, subgraph) in enumerate(dataloader):
        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
        assert torch.equal(input_nodes[: output_nodes.shape[0]], output_nodes)
        assert torch.equal(
            subgraph.ndata["label"], g.ndata["label"][input_nodes]
        )
        assert torch.equal(subgraph.ndata["feat"], g.ndata["feat"][input_nodes])
        if i == 5:
            break


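# GraphSAINT sampling with a node, edge, or random-walk budget; 100 seeds give
# 100 subgraphs.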
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("mode", ["node", "edge", "walk"])
def test_saint(num_workers, mode):
    g = dgl.data.CoraFullDataset()[0]

    if mode == "node":
        budget = 100
    elif mode == "edge":
        budget = 200
    elif mode == "walk":
        budget = (3, 2)

    sampler = dgl.dataloading.SAINTSampler(mode, budget)
    dataloader = dgl.dataloading.DataLoader(
        g, torch.arange(100), sampler, num_workers=num_workers
    )
    assert len(dataloader) == 100
    for sg in dataloader:
        pass


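# Non-uniform neighbor sampling: edges with zero probability (or a False mask)
# must never be drawn, on both homogeneous and heterogeneous graphs, across
# CPU/UVA/GPU modes and optionally under a single-process DDP group.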
@parametrize_idtype
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("use_ddp", [False, True])
@pytest.mark.parametrize("use_mask", [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp, use_mask):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if mode != "cpu" and use_mask:
        pytest.skip("Masked sampling only works on CPU.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(
        idtype
    )
    g.edata["p"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edata["mask"] = g.edata["p"] != 0
    if mode in ("cpu", "uva_cpu_indices"):
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())
    else:
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cuda())
    if mode == "pure_gpu":
        g = g.to(F.cuda())
    use_uva = mode.startswith("uva")

    if use_mask:
        prob, mask = None, "mask"
    else:
        prob, mask = "p", None

    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [2], prob=prob, mask=mask
    )
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            indices,
            sampler,
            batch_size=1,
            device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes.item()
            neighbors = set(input_nodes[1:].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

    g = dgl.heterograph(
        {
            ("B", "BA", "A"): (
                [1, 2, 3, 4, 5, 6, 7, 8],
                [0, 0, 0, 0, 1, 1, 1, 1],
            ),
            ("C", "CA", "A"): (
                [1, 2, 3, 4, 5, 6, 7, 8],
                [0, 0, 0, 0, 1, 1, 1, 1],
            ),
        }
    ).astype(idtype)
    g.edges["BA"].data["p"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edges["BA"].data["mask"] = g.edges["BA"].data["p"] != 0
    g.edges["CA"].data["p"] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
    g.edges["CA"].data["mask"] = g.edges["CA"].data["p"] != 0
    if mode == "pure_gpu":
        g = g.to(F.cuda())
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            {"A": indices},
            sampler,
            batch_size=1,
            device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes["A"].item()
            # Seed and neighbors are of different node types so slicing is not necessary here.
            neighbors = set(input_nodes["B"].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

            neighbors = set(input_nodes["C"].cpu().numpy())
            if seed == 1:
                assert neighbors == {7, 8}
            elif seed == 0:
                assert neighbors == {3, 4}

    if use_ddp:
        dist.destroy_process_group()


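# Helpers that walk a dict/list of tensors or blocks and assert the expected
# dtype (or idtype) and device.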
def _check_dtype(data, dtype, attr_name):
    if isinstance(data, dict):
        for k, v in data.items():
            assert getattr(v, attr_name) == dtype
    elif isinstance(data, list):
        for v in data:
            assert getattr(v, attr_name) == dtype
    else:
        assert getattr(data, attr_name) == dtype


def _check_device(data):
    if isinstance(data, dict):
        for k, v in data.items():
            assert v.device == F.ctx()
    elif isinstance(data, list):
        for v in data:
            assert v.device == F.ctx()
    else:
        assert data.device == F.ctx()


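# Node-wise DataLoader on homogeneous and heterogeneous graphs: every returned
# tensor and block must live on the requested device and keep the parametrized
# index dtype.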
@parametrize_idtype
@pytest.mark.parametrize("sampler_name", ["full", "neighbor", "neighbor2"])
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("use_ddp", [False, True])
def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
    g1.ndata["label"] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
    if mode in ("cpu", "uva_cpu_indices"):
        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cpu())
    else:
        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cuda())
    if mode == "pure_gpu":
        g1 = g1.to(F.cuda())

    use_uva = mode.startswith("uva")

    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g1,
            indices,
            sampler,
            device=F.ctx(),
            batch_size=g1.num_nodes(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, "dtype")
            _check_dtype(output_nodes, idtype, "dtype")
            _check_dtype(blocks, idtype, "idtype")

    g2 = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [0, 0, 0, 1, 1, 1, 2],
                [1, 2, 3, 0, 2, 3, 0],
            ),
            ("user", "followed-by", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
            ("game", "played-by", "user"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),
        }
    ).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data["feat"] = F.copy_to(
            F.randn((g2.num_nodes(ntype), 8)), F.cpu()
        )
    if mode in ("cpu", "uva_cpu_indices"):
        indices = {nty: F.copy_to(g2.nodes(nty), F.cpu()) for nty in g2.ntypes}
    else:
        indices = {nty: F.copy_to(g2.nodes(nty), F.cuda()) for nty in g2.ntypes}
    if mode == "pure_gpu":
        g2 = g2.to(F.cuda())

    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2
        ),
        "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g2,
            indices,
            sampler,
            device=F.ctx(),
            batch_size=batch_size,
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        assert isinstance(iter(dataloader), Iterator)
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, "dtype")
            _check_dtype(output_nodes, idtype, "dtype")
            _check_dtype(blocks, idtype, "idtype")

    if use_ddp:
        dist.destroy_process_group()


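# Edge-wise DataLoader built with as_edge_prediction_sampler, with and without
# negative sampling, on homogeneous and heterogeneous graphs; checks device
# placement of the returned minibatches.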
@parametrize_idtype
@pytest.mark.parametrize("sampler_name", ["full", "neighbor"])
@pytest.mark.parametrize(
    "neg_sampler",
    [
        dgl.dataloading.negative_sampler.Uniform(2),
        dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
        dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3),
    ],
)
@pytest.mark.parametrize("mode", ["cpu", "uva", "pure_gpu"])
@pytest.mark.parametrize("use_ddp", [False, True])
def test_edge_dataloader(idtype, sampler_name, neg_sampler, mode, use_ddp):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if mode == "uva" and isinstance(
        neg_sampler, dgl.dataloading.negative_sampler.GlobalUniform
    ):
        pytest.skip("GlobalUniform doesn't support UVA yet.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
    if mode == "pure_gpu":
        g1 = g1.to(F.cuda())

    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]

    # no negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        g1,
        g1.edges(form="eid"),
        edge_sampler,
        device=F.ctx(),
        batch_size=g1.num_edges(),
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=neg_sampler
    )
    dataloader = dgl.dataloading.DataLoader(
        g1,
        g1.edges(form="eid"),
        edge_sampler,
        device=F.ctx(),
        batch_size=g1.num_edges(),
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    g2 = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [0, 0, 0, 1, 1, 1, 2],
                [1, 2, 3, 0, 2, 3, 0],
            ),
            ("user", "followed-by", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
            ("game", "played-by", "user"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),
        }
    ).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data["feat"] = F.copy_to(
            F.randn((g2.num_nodes(ntype), 8)), F.cpu()
        )
    if mode == "pure_gpu":
        g2 = g2.to(F.cuda())

    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2
        ),
    }[sampler_name]

    # no negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        g2,
        {ety: g2.edges(form="eid", etype=ety) for ety in g2.canonical_etypes},
        edge_sampler,
        device=F.ctx(),
        batch_size=batch_size,
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=neg_sampler
    )
    dataloader = dgl.dataloading.DataLoader(
        g2,
        {ety: g2.edges(form="eid", etype=ety) for ety in g2.canonical_etypes},
        edge_sampler,
        device=F.ctx(),
        batch_size=batch_size,
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )

    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    if use_ddp:
        dist.destroy_process_group()


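# Fixtures for test_edge_dataloader_excludes: graphs whose reverse edges
# (homogeneous) or reverse edge types (heterogeneous) are known by
# construction, plus a fixed set of edge IDs to always exclude.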
def _create_homogeneous():
    s = torch.randint(0, 200, (1000,), device=F.ctx())
    d = torch.randint(0, 200, (1000,), device=F.ctx())
    src = torch.cat([s, d])
    dst = torch.cat([d, s])
    # Build the graph with both directions so that edge i and edge i + 1000
    # are mutual reverses, matching reverse_eids below.
    g = dgl.graph((src, dst), num_nodes=200)
    reverse_eids = torch.cat(
        [torch.arange(1000, 2000), torch.arange(0, 1000)]
    ).to(F.ctx())
    always_exclude = torch.randint(0, 1000, (50,), device=F.ctx())
    seed_edges = torch.arange(0, 1000, device=F.ctx())
    return g, reverse_eids, always_exclude, seed_edges


def _create_heterogeneous():
    edges = {}
    for utype, etype, vtype in [("A", "AA", "A"), ("A", "AB", "B")]:
        s = torch.randint(0, 200, (1000,), device=F.ctx())
        d = torch.randint(0, 200, (1000,), device=F.ctx())
        edges[utype, etype, vtype] = (s, d)
        edges[vtype, "rev-" + etype, utype] = (d, s)
    g = dgl.heterograph(edges, num_nodes_dict={"A": 200, "B": 200})
    reverse_etypes = {
        "AA": "rev-AA",
        "AB": "rev-AB",
        "rev-AA": "AA",
        "rev-AB": "AB",
    }
    always_exclude = {
        "AA": torch.randint(0, 1000, (50,), device=F.ctx()),
        "AB": torch.randint(0, 1000, (50,), device=F.ctx()),
    }
    seed_edges = {
        "AA": torch.arange(0, 1000, device=F.ctx()),
        "AB": torch.arange(0, 1000, device=F.ctx()),
    }
    return g, reverse_etypes, always_exclude, seed_edges


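# Reference implementation of the exclusion rule: given the minibatch (pair)
# edge IDs, return all edge IDs that the sampler is expected to have excluded
# from the sampled blocks.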
def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
    if exclude is None:
        return always_exclude
    elif exclude == "self":
        return (
            torch.cat([pair_eids, always_exclude])
            if always_exclude is not None
            else pair_eids
        )
    elif exclude == "reverse_id":
        pair_eids = torch.cat([pair_eids, pair_eids + 1000])
        return (
            torch.cat([pair_eids, always_exclude])
            if always_exclude is not None
            else pair_eids
        )
    elif exclude == "reverse_types":
        pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()}
        if ("A", "AA", "A") in pair_eids:
            pair_eids[("A", "rev-AA", "A")] = pair_eids[("A", "AA", "A")]
        if ("A", "AB", "B") in pair_eids:
            pair_eids[("B", "rev-AB", "A")] = pair_eids[("A", "AB", "B")]
        if always_exclude is not None:
            always_exclude = {
                g.to_canonical_etype(k): v for k, v in always_exclude.items()
            }
            for k in always_exclude.keys():
                if k in pair_eids:
                    pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]])
                else:
                    pair_eids[k] = always_exclude[k]
        return pair_eids


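# The excluded edges (self, reverse IDs, or reverse types, plus the optional
# always-excluded set) must not appear in the sampled blocks or ShaDow
# subgraphs.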
@pytest.mark.parametrize("always_exclude_flag", [False, True])
@pytest.mark.parametrize(
    "exclude", [None, "self", "reverse_id", "reverse_types"]
)
@pytest.mark.parametrize(
    "sampler",
    [
        dgl.dataloading.MultiLayerFullNeighborSampler(1),
        dgl.dataloading.ShaDowKHopSampler([5]),
    ],
)
@pytest.mark.parametrize("batch_size", [1, 50])
def test_edge_dataloader_excludes(
    exclude, always_exclude_flag, batch_size, sampler
):
    if exclude == "reverse_types":
        g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
    else:
        g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
    g = g.to(F.ctx())
    if not always_exclude_flag:
        always_exclude = None

    kwargs = {}
    kwargs["exclude"] = (
        partial(_find_edges_to_exclude, g, exclude, always_exclude)
        if always_exclude_flag
        else exclude
    )
    kwargs["reverse_eids"] = reverse_eids if exclude == "reverse_id" else None
    kwargs["reverse_etypes"] = (
        reverse_etypes if exclude == "reverse_types" else None
    )
    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, **kwargs)

    dataloader = dgl.dataloading.DataLoader(
        g,
        seed_edges,
        sampler,
        batch_size=batch_size,
        device=F.ctx(),
        use_prefetch_thread=False,
    )
    for i, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
        if isinstance(blocks, list):
            subg = blocks[0]
        else:
            subg = blocks
        pair_eids = pair_graph.edata[dgl.EID]
        block_eids = subg.edata[dgl.EID]

        edges_to_exclude = _find_edges_to_exclude(
            g, exclude, always_exclude, pair_eids
        )
        if edges_to_exclude is None:
            continue
        edges_to_exclude = dgl.utils.recursive_apply(
            edges_to_exclude, lambda x: x.cpu().numpy()
        )
        block_eids = dgl.utils.recursive_apply(
            block_eids, lambda x: x.cpu().numpy()
        )

        if isinstance(edges_to_exclude, Mapping):
            for k in edges_to_exclude.keys():
                assert not np.isin(edges_to_exclude[k], block_eids[k]).any()
        else:
            assert not np.isin(edges_to_exclude, block_eids).any()

        if i == 10:
            break


if __name__ == "__main__":
    # test_node_dataloader(F.int32, 'neighbor', None)
    test_edge_dataloader_excludes(
        "reverse_types", False, 1, dgl.dataloading.ShaDowKHopSampler([5])
    )