import os
import unittest
from collections.abc import Iterator, Mapping
from functools import partial

import backend as F

import dgl
import dgl.ops as OPS
import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from utils import parametrize_idtype


@pytest.mark.parametrize("batch_size", [None, 16])
def test_graph_dataloader(batch_size):
    num_batches = 2
    num_samples = num_batches * (batch_size if batch_size is not None else 1)
    minigc_dataset = dgl.data.MiniGCDataset(num_samples, 10, 20)
    data_loader = dgl.dataloading.GraphDataLoader(
        minigc_dataset, batch_size=batch_size, shuffle=True
    )
    assert isinstance(iter(data_loader), Iterator)
    for graph, label in data_loader:
        assert isinstance(graph, dgl.DGLGraph)
        if batch_size is not None:
            assert F.asnumpy(label).shape[0] == batch_size
        else:
            # If batch size is None, the label element will be a single scalar following
            # PyTorch's practice.
            assert F.asnumpy(label).ndim == 0


@unittest.skipIf(os.name == "nt", reason="Windows is not supported yet")
@pytest.mark.parametrize("num_workers", [0, 4])
def test_cluster_gcn(num_workers):
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
    sampler = dgl.dataloading.ClusterGCNSampler(g, 100)
    dataloader = dgl.dataloading.DataLoader(
        g, torch.arange(100), sampler, batch_size=4, num_workers=num_workers
    )
    assert len(dataloader) == 25
    for i, sg in enumerate(dataloader):
        pass


@pytest.mark.parametrize("num_workers", [0, 4])
def test_shadow(num_workers):
    g = dgl.data.CoraFullDataset()[0]
    sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
    dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(g.num_nodes()),
        sampler,
        batch_size=5,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
    )
    for i, (input_nodes, output_nodes, subgraph) in enumerate(dataloader):
        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
        assert torch.equal(input_nodes[: output_nodes.shape[0]], output_nodes)
        assert torch.equal(
            subgraph.ndata["label"], g.ndata["label"][input_nodes]
        )
        assert torch.equal(subgraph.ndata["feat"], g.ndata["feat"][input_nodes])
        if i == 5:
            break


@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("mode", ["node", "edge", "walk"])
def test_saint(num_workers, mode):
    g = dgl.data.CoraFullDataset()[0]

    if mode == "node":
        budget = 100
    elif mode == "edge":
        budget = 200
    elif mode == "walk":
        budget = (3, 2)

    sampler = dgl.dataloading.SAINTSampler(mode, budget)
    dataloader = dgl.dataloading.DataLoader(
        g, torch.arange(100), sampler, num_workers=num_workers
    )
    assert len(dataloader) == 100
    for sg in dataloader:
        pass


@parametrize_idtype
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("use_ddp", [False, True])
@pytest.mark.parametrize("use_mask", [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp, use_mask):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if mode != "cpu" and use_mask:
        pytest.skip("Masked sampling only works on CPU.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(
        idtype
    )
    g.edata["p"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edata["mask"] = g.edata["p"] != 0
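    # Edges carry a float probability "p" and an equivalent boolean "mask";
    # edges with p == 0 must never be drawn, so seed 0 can only sample
    # neighbors {1, 2} and seed 1 can only sample {5, 6} (checked below).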
    if mode in ("cpu", "uva_cpu_indices"):
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())
    else:
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cuda())
    if mode == "pure_gpu":
        g = g.to(F.cuda())
    use_uva = mode.startswith("uva")

    if use_mask:
        prob, mask = None, "mask"
    else:
        prob, mask = "p", None

    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [2], prob=prob, mask=mask
    )
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            indices,
            sampler,
            batch_size=1,
            device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes.item()
            neighbors = set(input_nodes[1:].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

    g = dgl.heterograph(
        {
            ("B", "BA", "A"): (
                [1, 2, 3, 4, 5, 6, 7, 8],
                [0, 0, 0, 0, 1, 1, 1, 1],
            ),
            ("C", "CA", "A"): (
                [1, 2, 3, 4, 5, 6, 7, 8],
                [0, 0, 0, 0, 1, 1, 1, 1],
            ),
        }
    ).astype(idtype)
    g.edges["BA"].data["p"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edges["BA"].data["mask"] = g.edges["BA"].data["p"] != 0
    g.edges["CA"].data["p"] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
    g.edges["CA"].data["mask"] = g.edges["CA"].data["p"] != 0
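    # Per-etype probabilities: via "BA" seed 0 can only draw {1, 2} and seed 1
    # only {5, 6}, while via "CA" seed 0 can only draw {3, 4} and seed 1 only
    # {7, 8}; the loop below checks both edge types.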
    if mode == "pure_gpu":
        g = g.to(F.cuda())
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            {"A": indices},
            sampler,
            batch_size=1,
            device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes["A"].item()
            # Seed and neighbors are of different node types so slicing is not necessary here.
            neighbors = set(input_nodes["B"].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

            neighbors = set(input_nodes["C"].cpu().numpy())
            if seed == 1:
                assert neighbors == {7, 8}
            elif seed == 0:
                assert neighbors == {3, 4}

    if use_ddp:
        dist.destroy_process_group()


def _check_dtype(data, dtype, attr_name):
    if isinstance(data, dict):
        for k, v in data.items():
            assert getattr(v, attr_name) == dtype
    elif isinstance(data, list):
        for v in data:
            assert getattr(v, attr_name) == dtype
    else:
        assert getattr(data, attr_name) == dtype


def _check_device(data):
    if isinstance(data, dict):
        for k, v in data.items():
            assert v.device == F.ctx()
    elif isinstance(data, list):
        for v in data:
            assert v.device == F.ctx()
    else:
        assert data.device == F.ctx()


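# Spawns `nprocs` workers that each build a DDP-enabled DataLoader over the
# same seed nodes; _ddp_runner then checks that the union of the per-rank
# splits covers exactly the seed ids expected for the given drop_last setting.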
@pytest.mark.parametrize("sampler_name", ["full", "neighbor"])
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("nprocs", [1, 4])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ddp_dataloader_decompose_dataset(
    sampler_name, mode, nprocs, drop_last
):
    if torch.cuda.device_count() < nprocs and mode != "cpu":
        pytest.skip(
            "DDP dataloader needs sufficient GPUs for UVA and GPU sampling."
        )
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")

    if os.name == "nt":
        pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
    g, _, _, _ = _create_homogeneous()
    g = g.to(F.cpu())

    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
    indices = F.copy_to(F.arange(0, g.num_nodes()), F.cpu())
    data = indices, sampler
    arguments = mode, drop_last
    g.create_formats_()
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    mp.spawn(_ddp_runner, args=(nprocs, g, data, arguments), nprocs=nprocs)


def _ddp_runner(proc_id, nprocs, g, data, args):
    mode, drop_last = args
    indices, sampler = data
    if mode == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(proc_id)
        torch.cuda.set_device(device)
    if mode == "pure_gpu":
        g = g.to(F.cuda())
    if mode in ("cpu", "uva_cpu_indices"):
        indices = indices.cpu()
    else:
        indices = indices.cuda()

    dist.init_process_group(
        "nccl" if mode != "cpu" else "gloo",
        "tcp://127.0.0.1:12347",
        world_size=nprocs,
        rank=proc_id,
    )
    use_uva = mode.startswith("uva")
    batch_size = g.num_nodes()
    shuffle = False
    for num_workers in [1, 4] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            indices,
            sampler,
            device=device,
            batch_size=batch_size,  # equals g.num_nodes()
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=True,
            drop_last=drop_last,
            shuffle=shuffle,
        )
        max_nid = [0]
        for i, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
            block = blocks[-1]
            o_src, o_dst = block.edges()
            src_nodes_id = block.srcdata[dgl.NID][o_src]
            dst_nodes_id = block.dstdata[dgl.NID][o_dst]
            max_nid.append(np.max(dst_nodes_id.cpu().numpy()))

        local_max = torch.tensor(np.max(max_nid))
        if torch.distributed.get_backend() == "nccl":
            local_max = local_max.cuda()
        dist.reduce(local_max, 0, op=dist.ReduceOp.MAX)
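        # Rank 0 inspects the largest destination node id seen on any rank:
        # with drop_last (and no shuffling) the trailing seeds that do not
        # fill a complete per-rank batch must never be produced, while
        # without drop_last every seed up to len(indices) - 1 must appear.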
        if proc_id == 0:
            if drop_last and not shuffle and local_max > 0:
                assert (
                    local_max.item()
                    == len(indices)
                    - len(indices) % nprocs
                    - 1
                    - (len(indices) // nprocs) % batch_size
                )
            elif not drop_last:
                assert local_max == len(indices) - 1
    dist.destroy_process_group()


@parametrize_idtype
@pytest.mark.parametrize("sampler_name", ["full", "neighbor", "neighbor2"])
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("use_ddp", [False, True])
def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
    g1.ndata["label"] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
    if mode in ("cpu", "uva_cpu_indices"):
        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cpu())
    else:
        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cuda())
    if mode == "pure_gpu":
        g1 = g1.to(F.cuda())

    use_uva = mode.startswith("uva")

    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g1,
            indices,
            sampler,
            device=F.ctx(),
            batch_size=g1.num_nodes(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, "dtype")
            _check_dtype(output_nodes, idtype, "dtype")
            _check_dtype(blocks, idtype, "idtype")

    g2 = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [0, 0, 0, 1, 1, 1, 2],
                [1, 2, 3, 0, 2, 3, 0],
            ),
            ("user", "followed-by", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
            ("game", "played-by", "user"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),
        }
    ).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data["feat"] = F.copy_to(
            F.randn((g2.num_nodes(ntype), 8)), F.cpu()
        )
    if mode in ("cpu", "uva_cpu_indices"):
        indices = {nty: F.copy_to(g2.nodes(nty), F.cpu()) for nty in g2.ntypes}
    else:
        indices = {nty: F.copy_to(g2.nodes(nty), F.cuda()) for nty in g2.ntypes}
    if mode == "pure_gpu":
        g2 = g2.to(F.cuda())

    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2
        ),
        "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g2,
            indices,
            sampler,
            device=F.ctx(),
            batch_size=batch_size,
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp,
        )
        assert isinstance(iter(dataloader), Iterator)
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, "dtype")
            _check_dtype(output_nodes, idtype, "dtype")
            _check_dtype(blocks, idtype, "idtype")

    if use_ddp:
        dist.destroy_process_group()


@parametrize_idtype
@pytest.mark.parametrize("sampler_name", ["full", "neighbor"])
@pytest.mark.parametrize(
    "neg_sampler",
    [
        dgl.dataloading.negative_sampler.Uniform(2),
        dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
        dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3),
    ],
)
@pytest.mark.parametrize("mode", ["cpu", "uva", "pure_gpu"])
@pytest.mark.parametrize("use_ddp", [False, True])
def test_edge_dataloader(idtype, sampler_name, neg_sampler, mode, use_ddp):
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if mode == "uva" and isinstance(
        neg_sampler, dgl.dataloading.negative_sampler.GlobalUniform
    ):
        pytest.skip("GlobalUniform does not support UVA yet.")
    if use_ddp:
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
    if mode == "pure_gpu":
        g1 = g1.to(F.cuda())

    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]

    # no negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        g1,
        g1.edges(form="eid"),
        edge_sampler,
        device=F.ctx(),
        batch_size=g1.num_edges(),
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=neg_sampler
    )
    dataloader = dgl.dataloading.DataLoader(
        g1,
        g1.edges(form="eid"),
        edge_sampler,
        device=F.ctx(),
        batch_size=g1.num_edges(),
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    g2 = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [0, 0, 0, 1, 1, 1, 2],
                [1, 2, 3, 0, 2, 3, 0],
            ),
            ("user", "followed-by", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
            ("game", "played-by", "user"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),
        }
    ).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data["feat"] = F.copy_to(
            F.randn((g2.num_nodes(ntype), 8)), F.cpu()
        )
    if mode == "pure_gpu":
        g2 = g2.to(F.cuda())

    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2
        ),
    }[sampler_name]

    # no negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        g2,
        {ety: g2.edges(form="eid", etype=ety) for ety in g2.canonical_etypes},
        edge_sampler,
        device=F.ctx(),
        batch_size=batch_size,
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=neg_sampler
    )
    dataloader = dgl.dataloading.DataLoader(
        g2,
        {ety: g2.edges(form="eid", etype=ety) for ety in g2.canonical_etypes},
        edge_sampler,
        device=F.ctx(),
        batch_size=batch_size,
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )

    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    if use_ddp:
        dist.destroy_process_group()


def _create_homogeneous():
    s = torch.randint(0, 200, (1000,), device=F.ctx())
    d = torch.randint(0, 200, (1000,), device=F.ctx())
    src = torch.cat([s, d])
    dst = torch.cat([d, s])
    g = dgl.graph((src, dst), num_nodes=200)
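    # Edges 0-999 are the sampled (s, d) pairs and edges 1000-1999 are their
    # reverses, so edge i and edge i + 1000 form a reverse pair (used by the
    # "reverse_id" exclusion mode below).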
    reverse_eids = torch.cat(
        [torch.arange(1000, 2000), torch.arange(0, 1000)]
    ).to(F.ctx())
    always_exclude = torch.randint(0, 1000, (50,), device=F.ctx())
    seed_edges = torch.arange(0, 1000, device=F.ctx())
    return g, reverse_eids, always_exclude, seed_edges


def _create_heterogeneous():
    edges = {}
    for utype, etype, vtype in [("A", "AA", "A"), ("A", "AB", "B")]:
        s = torch.randint(0, 200, (1000,), device=F.ctx())
        d = torch.randint(0, 200, (1000,), device=F.ctx())
        edges[utype, etype, vtype] = (s, d)
        edges[vtype, "rev-" + etype, utype] = (d, s)
    g = dgl.heterograph(edges, num_nodes_dict={"A": 200, "B": 200})
    reverse_etypes = {
        "AA": "rev-AA",
        "AB": "rev-AB",
        "rev-AA": "AA",
        "rev-AB": "AB",
    }
    always_exclude = {
        "AA": torch.randint(0, 1000, (50,), device=F.ctx()),
        "AB": torch.randint(0, 1000, (50,), device=F.ctx()),
    }
    seed_edges = {
        "AA": torch.arange(0, 1000, device=F.ctx()),
        "AB": torch.arange(0, 1000, device=F.ctx()),
    }
    return g, reverse_etypes, always_exclude, seed_edges


def _remove_duplicates(s, d):
    s, d = list(zip(*list(set(zip(s.tolist(), d.tolist())))))
    return torch.tensor(s, device=F.ctx()), torch.tensor(d, device=F.ctx())


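# Recomputes, independently of the dataloader, the edge ids that each
# `exclude` mode of as_edge_prediction_sampler is expected to drop, so the
# test below can assert that none of them appear in the sampled blocks.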
def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
    if exclude is None:
        return always_exclude
    elif exclude == "self":
        return (
            torch.cat([pair_eids, always_exclude])
            if always_exclude is not None
            else pair_eids
        )
    elif exclude == "reverse_id":
        pair_eids = torch.cat([pair_eids, pair_eids + 1000])
        return (
            torch.cat([pair_eids, always_exclude])
            if always_exclude is not None
            else pair_eids
        )
    elif exclude == "reverse_types":
        pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()}
        if ("A", "AA", "A") in pair_eids:
            pair_eids[("A", "rev-AA", "A")] = pair_eids[("A", "AA", "A")]
        if ("A", "AB", "B") in pair_eids:
            pair_eids[("B", "rev-AB", "A")] = pair_eids[("A", "AB", "B")]
        if always_exclude is not None:
            always_exclude = {
                g.to_canonical_etype(k): v for k, v in always_exclude.items()
            }
            for k in always_exclude.keys():
                if k in pair_eids:
                    pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]])
                else:
                    pair_eids[k] = always_exclude[k]
        return pair_eids


@pytest.mark.parametrize("always_exclude_flag", [False, True])
@pytest.mark.parametrize(
    "exclude", [None, "self", "reverse_id", "reverse_types"]
)
@pytest.mark.parametrize(
    "sampler",
    [
        dgl.dataloading.MultiLayerFullNeighborSampler(1),
        dgl.dataloading.ShaDowKHopSampler([5]),
    ],
)
@pytest.mark.parametrize("batch_size", [1, 50])
def test_edge_dataloader_excludes(
    exclude, always_exclude_flag, batch_size, sampler
):
    if exclude == "reverse_types":
        g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
    else:
        g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
    g = g.to(F.ctx())
    if not always_exclude_flag:
        always_exclude = None

    kwargs = {}
    kwargs["exclude"] = (
        partial(_find_edges_to_exclude, g, exclude, always_exclude)
        if always_exclude_flag
        else exclude
    )
    kwargs["reverse_eids"] = reverse_eids if exclude == "reverse_id" else None
    kwargs["reverse_etypes"] = (
        reverse_etypes if exclude == "reverse_types" else None
    )
    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, **kwargs)

    dataloader = dgl.dataloading.DataLoader(
        g,
        seed_edges,
        sampler,
        batch_size=batch_size,
        device=F.ctx(),
        use_prefetch_thread=False,
    )
    for i, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
        if isinstance(blocks, list):
            subg = blocks[0]
        else:
            subg = blocks
        pair_eids = pair_graph.edata[dgl.EID]
        block_eids = subg.edata[dgl.EID]

        edges_to_exclude = _find_edges_to_exclude(
            g, exclude, always_exclude, pair_eids
        )
        if edges_to_exclude is None:
            continue
        edges_to_exclude = dgl.utils.recursive_apply(
            edges_to_exclude, lambda x: x.cpu().numpy()
        )
        block_eids = dgl.utils.recursive_apply(
            block_eids, lambda x: x.cpu().numpy()
        )

        if isinstance(edges_to_exclude, Mapping):
            for k in edges_to_exclude.keys():
                assert not np.isin(edges_to_exclude[k], block_eids[k]).any()
        else:
            assert not np.isin(edges_to_exclude, block_eids).any()

        if i == 10:
            break


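# Seed edges span an edge type and its reverse; the sampled MFGs must not
# contain any edge from the positive pairs of either direction.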
def test_edge_dataloader_exclusion_with_reverse_seed_nodes():
    utype, etype, vtype = ("A", "AB", "B")
    s = torch.randint(0, 20, (500,), device=F.ctx())
    d = torch.randint(0, 20, (500,), device=F.ctx())
    s, d = _remove_duplicates(s, d)
    g = dgl.heterograph({("A", "AB", "B"): (s, d), ("B", "BA", "A"): (d, s)})
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        dgl.dataloading.NeighborSampler(fanouts=[2, 2, 2]),
        exclude="reverse_types",
        reverse_etypes={"AB": "BA", "BA": "AB"},
    )
    seed_edges = {
        "AB": torch.arange(g.number_of_edges("AB"), device=F.ctx()),
        "BA": torch.arange(g.number_of_edges("BA"), device=F.ctx()),
    }
    dataloader = dgl.dataloading.DataLoader(
        g,
        seed_edges,
        sampler,
        batch_size=2,
        device=F.ctx(),
        shuffle=True,
        drop_last=False,
    )
    for _, pos_graph, mfgs in dataloader:
        s, d = pos_graph["AB"].edges()
        AB_pos = list(zip(s.tolist(), d.tolist()))
        s, d = pos_graph["BA"].edges()
        BA_pos = list(zip(s.tolist(), d.tolist()))

        s, d = mfgs[-1]["AB"].edges()
        AB_mfg = list(zip(s.tolist(), d.tolist()))
        s, d = mfgs[-1]["BA"].edges()
        BA_mfg = list(zip(s.tolist(), d.tolist()))

        assert all(edge not in AB_mfg for edge in AB_pos)
        assert all(edge not in BA_mfg for edge in BA_pos)


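# Exclusion with "reverse_types" when only a subset of the seed edge types has
# an entry in reverse_etypes; iterating the dataloader should simply succeed.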
def test_edge_dataloader_exclusion_without_all_reverses():
    data_dict = {
        ("A", "AB", "B"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ("B", "BA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ("B", "BC", "C"): (torch.tensor([0]), torch.tensor([0])),
        ("C", "CA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
    }
    g = dgl.heterograph(data_dict=data_dict)
    block_sampler = dgl.dataloading.MultiLayerNeighborSampler(
        fanouts=[1], replace=True
    )
    block_sampler = dgl.dataloading.as_edge_prediction_sampler(
        block_sampler,
        exclude="reverse_types",
        reverse_etypes={"AB": "BA"},
    )
    d = dgl.dataloading.DataLoader(
        graph=g,
        indices={
            "AB": torch.tensor([0]),
            "BC": torch.tensor([0]),
        },
        graph_sampler=block_sampler,
        batch_size=2,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        device=F.ctx(),
        use_ddp=False,
    )

    next(iter(d))


def dummy_worker_init_fn(worker_id):
    pass


def test_dataloader_worker_init_fn():
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
    sampler = dgl.dataloading.MultiLayerNeighborSampler([2])
    dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(100),
        sampler,
        batch_size=4,
        num_workers=4,
        worker_init_fn=dummy_worker_init_fn,
    )
    for _ in dataloader:
        pass


if __name__ == "__main__":
    # test_node_dataloader(F.int32, 'neighbor', None)
    test_edge_dataloader_excludes(
        "reverse_types", False, 1, dgl.dataloading.ShaDowKHopSampler([5])
    )
    test_edge_dataloader_exclusion_without_all_reverses()