test_dataloader.py 27.4 KB
Newer Older
1
import os
2
3
4
5
6
7
import unittest
from collections.abc import Iterator, Mapping
from functools import partial

import backend as F

8
import dgl
9
import dgl.ops as OPS
10
11
import numpy as np
import pytest
12
import torch
13
import torch.distributed as dist
14
import torch.multiprocessing as mp
15
from utils import parametrize_idtype
16
17


18
@pytest.mark.parametrize("batch_size", [None, 16])
19
def test_graph_dataloader(batch_size):
20
    num_batches = 2
21
22
    num_samples = num_batches * (batch_size if batch_size is not None else 1)
    minigc_dataset = dgl.data.MiniGCDataset(num_samples, 10, 20)
23
24
25
    data_loader = dgl.dataloading.GraphDataLoader(
        minigc_dataset, batch_size=batch_size, shuffle=True
    )
26
    assert isinstance(iter(data_loader), Iterator)
27
28
    for graph, label in data_loader:
        assert isinstance(graph, dgl.DGLGraph)
29
30
31
32
33
34
        if batch_size is not None:
            assert F.asnumpy(label).shape[0] == batch_size
        else:
            # If batch size is None, the label element will be a single scalar following
            # PyTorch's practice.
            assert F.asnumpy(label).ndim == 0
35

36
37
38

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("num_workers", [0, 4])
39
40
41
def test_cluster_gcn(num_workers):
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
42
43
    sampler = dgl.dataloading.ClusterGCNSampler(g, 100)
    dataloader = dgl.dataloading.DataLoader(
44
45
        g, torch.arange(100), sampler, batch_size=4, num_workers=num_workers
    )
46
47
48
    assert len(dataloader) == 25
    for i, sg in enumerate(dataloader):
        pass
49

50
51

@pytest.mark.parametrize("num_workers", [0, 4])
52
53
54
def test_shadow(num_workers):
    g = dgl.data.CoraFullDataset()[0]
    sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
55
    dataloader = dgl.dataloading.DataLoader(
56
57
58
59
60
61
62
63
        g,
        torch.arange(g.num_nodes()),
        sampler,
        batch_size=5,
        shuffle=True,
        drop_last=False,
        num_workers=num_workers,
    )
64
    for i, (input_nodes, output_nodes, subgraph) in enumerate(dataloader):
65
        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
66
67
68
69
70
        assert torch.equal(input_nodes[: output_nodes.shape[0]], output_nodes)
        assert torch.equal(
            subgraph.ndata["label"], g.ndata["label"][input_nodes]
        )
        assert torch.equal(subgraph.ndata["feat"], g.ndata["feat"][input_nodes])
71
72
73
        if i == 5:
            break

74
75
76

@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("mode", ["node", "edge", "walk"])
77
78
79
def test_saint(num_workers, mode):
    g = dgl.data.CoraFullDataset()[0]

80
    if mode == "node":
81
        budget = 100
82
    elif mode == "edge":
83
        budget = 200
84
    elif mode == "walk":
85
86
87
88
        budget = (3, 2)

    sampler = dgl.dataloading.SAINTSampler(mode, budget)
    dataloader = dgl.dataloading.DataLoader(
89
90
        g, torch.arange(100), sampler, num_workers=num_workers
    )
91
92
93
    assert len(dataloader) == 100
    for sg in dataloader:
        pass
94

95

96
@parametrize_idtype
97
98
99
100
101
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("use_ddp", [False, True])
@pytest.mark.parametrize("use_mask", [False, True])
102
def test_neighbor_nonuniform(idtype, mode, use_ddp, use_mask):
103
104
105
106
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if mode != "cpu" and use_mask:
        pytest.skip("Masked sampling only works on CPU.")
107
    if use_ddp:
108
109
110
111
112
113
114
115
116
117
118
119
120
121
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
    g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(
        idtype
    )
    g.edata["p"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edata["mask"] = g.edata["p"] != 0
    if mode in ("cpu", "uva_cpu_indices"):
122
123
124
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())
    else:
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cuda())
125
    if mode == "pure_gpu":
126
        g = g.to(F.cuda())
127
    use_uva = mode.startswith("uva")
128

129
    if use_mask:
130
        prob, mask = None, "mask"
131
    else:
132
        prob, mask = "p", None
133

134
135
136
137
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [2], prob=prob, mask=mask
    )
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
138
        dataloader = dgl.dataloading.DataLoader(
139
140
141
142
143
            g,
            indices,
            sampler,
            batch_size=1,
            device=F.ctx(),
144
145
            num_workers=num_workers,
            use_uva=use_uva,
146
147
            use_ddp=use_ddp,
        )
148
149
150
151
152
153
154
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes.item()
            neighbors = set(input_nodes[1:].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}
155

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
    g = dgl.heterograph(
        {
            ("B", "BA", "A"): (
                [1, 2, 3, 4, 5, 6, 7, 8],
                [0, 0, 0, 0, 1, 1, 1, 1],
            ),
            ("C", "CA", "A"): (
                [1, 2, 3, 4, 5, 6, 7, 8],
                [0, 0, 0, 0, 1, 1, 1, 1],
            ),
        }
    ).astype(idtype)
    g.edges["BA"].data["p"] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edges["BA"].data["mask"] = g.edges["BA"].data["p"] != 0
    g.edges["CA"].data["p"] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
    g.edges["CA"].data["mask"] = g.edges["CA"].data["p"] != 0
    if mode == "pure_gpu":
173
        g = g.to(F.cuda())
174
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
175
        dataloader = dgl.dataloading.DataLoader(
176
177
178
179
180
            g,
            {"A": indices},
            sampler,
            batch_size=1,
            device=F.ctx(),
181
182
            num_workers=num_workers,
            use_uva=use_uva,
183
184
            use_ddp=use_ddp,
        )
185
        for input_nodes, output_nodes, blocks in dataloader:
186
            seed = output_nodes["A"].item()
187
            # Seed and neighbors are of different node types so slicing is not necessary here.
188
            neighbors = set(input_nodes["B"].cpu().numpy())
189
190
191
192
193
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

194
            neighbors = set(input_nodes["C"].cpu().numpy())
195
196
197
198
            if seed == 1:
                assert neighbors == {7, 8}
            elif seed == 0:
                assert neighbors == {3, 4}
199

200
201
    if use_ddp:
        dist.destroy_process_group()
202

203

204
205
206
207
208
209
210
211
212
def _check_dtype(data, dtype, attr_name):
    if isinstance(data, dict):
        for k, v in data.items():
            assert getattr(v, attr_name) == dtype
    elif isinstance(data, list):
        for v in data:
            assert getattr(v, attr_name) == dtype
    else:
        assert getattr(data, attr_name) == dtype
213

214

215
216
217
218
219
220
221
222
223
224
def _check_device(data):
    if isinstance(data, dict):
        for k, v in data.items():
            assert v.device == F.ctx()
    elif isinstance(data, list):
        for v in data:
            assert v.device == F.ctx()
    else:
        assert data.device == F.ctx()

225

226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
@pytest.mark.parametrize("sampler_name", ["full", "neighbor"])
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("nprocs", [1, 4])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ddp_dataloader_decompose_dataset(
    sampler_name, mode, nprocs, drop_last
):
    if torch.cuda.device_count() < nprocs and mode != "cpu":
        pytest.skip(
            "DDP dataloader needs sufficient GPUs for UVA and GPU sampling."
        )
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")

    if os.name == "nt":
        pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
    g, _, _, _ = _create_homogeneous()
    g = g.to(F.cpu())

    sampler = {
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
    indices = F.copy_to(F.arange(0, g.num_nodes()), F.cpu())
    data = indices, sampler
    arguments = mode, drop_last
    g.create_formats_()
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    mp.spawn(_ddp_runner, args=(nprocs, g, data, arguments), nprocs=nprocs)


def _ddp_runner(proc_id, nprocs, g, data, args):
    mode, drop_last = args
    indices, sampler = data
    if mode == "cpu":
        device = torch.device("cpu")
    else:
        device = torch.device(proc_id)
        torch.cuda.set_device(device)
    if mode == "pure_gpu":
        g = g.to(F.cuda())
    if mode in ("cpu", "uva_cpu_indices"):
        indices = indices.cpu()
    else:
        indices = indices.cuda()

    dist.init_process_group(
        "nccl" if mode != "cpu" else "gloo",
        "tcp://127.0.0.1:12347",
        world_size=nprocs,
        rank=proc_id,
    )
    use_uva = mode.startswith("uva")
    batch_size = g.num_nodes()
    shuffle = False
    for num_workers in [1, 4] if mode == "cpu" else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g,
            indices,
            sampler,
            device=device,
            batch_size=batch_size,  # g1.num_nodes(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=True,
            drop_last=drop_last,
            shuffle=shuffle,
        )
        max_nid = [0]
        for i, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
            block = blocks[-1]
            o_src, o_dst = block.edges()
            src_nodes_id = block.srcdata[dgl.NID][o_src]
            dst_nodes_id = block.dstdata[dgl.NID][o_dst]
            max_nid.append(np.max(dst_nodes_id.cpu().numpy()))

        local_max = torch.tensor(np.max(max_nid))
        if torch.distributed.get_backend() == "nccl":
            local_max = local_max.cuda()
        dist.reduce(local_max, 0, op=dist.ReduceOp.MAX)
        if proc_id == 0:
            if drop_last and not shuffle and local_max > 0:
                assert (
                    local_max.item()
                    == len(indices)
                    - len(indices) % nprocs
                    - 1
                    - (len(indices) // nprocs) % batch_size
                )
            elif not drop_last:
                assert local_max == len(indices) - 1
    dist.destroy_process_group()


nv-dlasalle's avatar
nv-dlasalle committed
322
@parametrize_idtype
323
324
325
@pytest.mark.parametrize(
    "sampler_name", ["full", "neighbor", "neighbor2", "labor"]
)
326
327
328
329
@pytest.mark.parametrize(
    "mode", ["cpu", "uva_cuda_indices", "uva_cpu_indices", "pure_gpu"]
)
@pytest.mark.parametrize("use_ddp", [False, True])
330
def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
331
332
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
333
    if use_ddp:
334
335
336
337
338
339
340
341
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
342
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
343
344
345
    g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
    g1.ndata["label"] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
    if mode in ("cpu", "uva_cpu_indices"):
346
347
348
        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cpu())
    else:
        indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), F.cuda())
349
    if mode == "pure_gpu":
350
351
        g1 = g1.to(F.cuda())

352
    use_uva = mode.startswith("uva")
353
354

    sampler = {
355
356
357
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
358
        "labor": dgl.dataloading.LaborSampler([3, 3]),
359
360
    }[sampler_name]
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
361
        dataloader = dgl.dataloading.DataLoader(
362
363
364
365
            g1,
            indices,
            sampler,
            device=F.ctx(),
366
            batch_size=g1.num_nodes(),
367
368
            num_workers=num_workers,
            use_uva=use_uva,
369
370
            use_ddp=use_ddp,
        )
371
372
373
374
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
            _check_dtype(input_nodes, idtype, "dtype")
            _check_dtype(output_nodes, idtype, "dtype")
            _check_dtype(blocks, idtype, "idtype")

    g2 = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [0, 0, 0, 1, 1, 1, 2],
                [1, 2, 3, 0, 2, 3, 0],
            ),
            ("user", "followed-by", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
            ("game", "played-by", "user"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),
        }
    ).astype(idtype)
393
    for ntype in g2.ntypes:
394
395
396
397
        g2.nodes[ntype].data["feat"] = F.copy_to(
            F.randn((g2.num_nodes(ntype), 8)), F.cpu()
        )
    if mode in ("cpu", "uva_cpu_indices"):
398
399
400
        indices = {nty: F.copy_to(g2.nodes(nty), F.cpu()) for nty in g2.ntypes}
    else:
        indices = {nty: F.copy_to(g2.nodes(nty), F.cuda()) for nty in g2.ntypes}
401
    if mode == "pure_gpu":
402
        g2 = g2.to(F.cuda())
403

404
    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
405
    sampler = {
406
407
408
409
410
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2
        ),
        "neighbor2": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
411
        "labor": dgl.dataloading.LaborSampler([3, 3]),
412
413
    }[sampler_name]
    for num_workers in [0, 1, 2] if mode == "cpu" else [0]:
414
        dataloader = dgl.dataloading.DataLoader(
415
416
417
418
419
            g2,
            indices,
            sampler,
            device=F.ctx(),
            batch_size=batch_size,
420
421
            num_workers=num_workers,
            use_uva=use_uva,
422
423
            use_ddp=use_ddp,
        )
424
425
426
427
428
        assert isinstance(iter(dataloader), Iterator)
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
429
430
431
            _check_dtype(input_nodes, idtype, "dtype")
            _check_dtype(output_nodes, idtype, "dtype")
            _check_dtype(blocks, idtype, "idtype")
432

433
434
    if use_ddp:
        dist.destroy_process_group()
435

436

437
@parametrize_idtype
438
439
440
441
442
443
444
445
446
447
448
@pytest.mark.parametrize("sampler_name", ["full", "neighbor"])
@pytest.mark.parametrize(
    "neg_sampler",
    [
        dgl.dataloading.negative_sampler.Uniform(2),
        dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
        dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3),
    ],
)
@pytest.mark.parametrize("mode", ["cpu", "uva", "pure_gpu"])
@pytest.mark.parametrize("use_ddp", [False, True])
449
def test_edge_dataloader(idtype, sampler_name, neg_sampler, mode, use_ddp):
450
451
452
453
454
    if mode != "cpu" and F.ctx() == F.cpu():
        pytest.skip("UVA and GPU sampling require a GPU.")
    if mode == "uva" and isinstance(
        neg_sampler, dgl.dataloading.negative_sampler.GlobalUniform
    ):
455
456
        pytest.skip("GlobalUniform don't support UVA yet.")
    if use_ddp:
457
458
459
460
461
462
463
464
        if os.name == "nt":
            pytest.skip("PyTorch 1.13.0+ has problems in Windows DDP...")
        dist.init_process_group(
            "gloo" if F.ctx() == F.cpu() else "nccl",
            "tcp://127.0.0.1:12347",
            world_size=1,
            rank=0,
        )
465
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
466
467
    g1.ndata["feat"] = F.copy_to(F.randn((5, 8)), F.cpu())
    if mode == "pure_gpu":
468
        g1 = g1.to(F.cuda())
469

470
    sampler = {
471
472
473
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
    }[sampler_name]
474

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
475
    # no negative sampler
476
477
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
478
479
480
481
482
483
484
485
        g1,
        g1.edges(form="eid"),
        edge_sampler,
        device=F.ctx(),
        batch_size=g1.num_edges(),
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
486
487
488
489
490
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
491
    # negative sampler
492
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
493
494
        sampler, negative_sampler=neg_sampler
    )
495
    dataloader = dgl.dataloading.DataLoader(
496
497
498
499
500
501
502
503
        g1,
        g1.edges(form="eid"),
        edge_sampler,
        device=F.ctx(),
        batch_size=g1.num_edges(),
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
504
505
506
507
508
509
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

510
511
512
513
514
515
516
517
518
519
520
521
522
523
    g2 = dgl.heterograph(
        {
            ("user", "follow", "user"): (
                [0, 0, 0, 1, 1, 1, 2],
                [1, 2, 3, 0, 2, 3, 0],
            ),
            ("user", "followed-by", "user"): (
                [1, 2, 3, 0, 2, 3, 0],
                [0, 0, 0, 1, 1, 1, 2],
            ),
            ("user", "play", "game"): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
            ("game", "played-by", "user"): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5]),
        }
    ).astype(idtype)
524
    for ntype in g2.ntypes:
525
526
527
528
        g2.nodes[ntype].data["feat"] = F.copy_to(
            F.randn((g2.num_nodes(ntype), 8)), F.cpu()
        )
    if mode == "pure_gpu":
529
530
        g2 = g2.to(F.cuda())

531
    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
532
    sampler = {
533
534
535
536
537
        "full": dgl.dataloading.MultiLayerFullNeighborSampler(2),
        "neighbor": dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2
        ),
    }[sampler_name]
538

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
539
    # no negative sampler
540
541
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
542
543
544
545
546
547
548
549
        g2,
        {ety: g2.edges(form="eid", etype=ety) for ety in g2.canonical_etypes},
        edge_sampler,
        device=F.ctx(),
        batch_size=batch_size,
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
550
    for input_nodes, pos_pair_graph, blocks in dataloader:
551
552
553
554
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
555
    # negative sampler
556
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
557
558
        sampler, negative_sampler=neg_sampler
    )
559
    dataloader = dgl.dataloading.DataLoader(
560
561
562
563
564
565
566
567
        g2,
        {ety: g2.edges(form="eid", etype=ety) for ety in g2.canonical_etypes},
        edge_sampler,
        device=F.ctx(),
        batch_size=batch_size,
        use_uva=(mode == "uva"),
        use_ddp=use_ddp,
    )
568

569
    assert isinstance(iter(dataloader), Iterator)
Quan (Andy) Gan's avatar
Quan (Andy) Gan committed
570
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
571
572
573
574
575
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

576
577
    if use_ddp:
        dist.destroy_process_group()
578

579

580
581
582
583
584
585
def _create_homogeneous():
    s = torch.randint(0, 200, (1000,), device=F.ctx())
    d = torch.randint(0, 200, (1000,), device=F.ctx())
    src = torch.cat([s, d])
    dst = torch.cat([d, s])
    g = dgl.graph((s, d), num_nodes=200)
586
587
588
    reverse_eids = torch.cat(
        [torch.arange(1000, 2000), torch.arange(0, 1000)]
    ).to(F.ctx())
589
590
591
592
    always_exclude = torch.randint(0, 1000, (50,), device=F.ctx())
    seed_edges = torch.arange(0, 1000, device=F.ctx())
    return g, reverse_eids, always_exclude, seed_edges

593

594
595
def _create_heterogeneous():
    edges = {}
596
    for utype, etype, vtype in [("A", "AA", "A"), ("A", "AB", "B")]:
597
598
599
        s = torch.randint(0, 200, (1000,), device=F.ctx())
        d = torch.randint(0, 200, (1000,), device=F.ctx())
        edges[utype, etype, vtype] = (s, d)
600
601
602
603
604
605
606
607
        edges[vtype, "rev-" + etype, utype] = (d, s)
    g = dgl.heterograph(edges, num_nodes_dict={"A": 200, "B": 200})
    reverse_etypes = {
        "AA": "rev-AA",
        "AB": "rev-AB",
        "rev-AA": "AA",
        "rev-AB": "AB",
    }
608
    always_exclude = {
609
610
611
        "AA": torch.randint(0, 1000, (50,), device=F.ctx()),
        "AB": torch.randint(0, 1000, (50,), device=F.ctx()),
    }
612
    seed_edges = {
613
614
615
        "AA": torch.arange(0, 1000, device=F.ctx()),
        "AB": torch.arange(0, 1000, device=F.ctx()),
    }
616
617
    return g, reverse_etypes, always_exclude, seed_edges

618

619
620
621
622
623
def _remove_duplicates(s, d):
    s, d = list(zip(*list(set(zip(s.tolist(), d.tolist())))))
    return torch.tensor(s, device=F.ctx()), torch.tensor(d, device=F.ctx())


624
625
626
def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
    if exclude == None:
        return always_exclude
627
628
629
630
631
632
633
    elif exclude == "self":
        return (
            torch.cat([pair_eids, always_exclude])
            if always_exclude is not None
            else pair_eids
        )
    elif exclude == "reverse_id":
634
        pair_eids = torch.cat([pair_eids, pair_eids + 1000])
635
636
637
638
639
640
        return (
            torch.cat([pair_eids, always_exclude])
            if always_exclude is not None
            else pair_eids
        )
    elif exclude == "reverse_types":
641
        pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()}
642
643
644
645
        if ("A", "AA", "A") in pair_eids:
            pair_eids[("A", "rev-AA", "A")] = pair_eids[("A", "AA", "A")]
        if ("A", "AB", "B") in pair_eids:
            pair_eids[("B", "rev-AB", "A")] = pair_eids[("A", "AB", "B")]
646
        if always_exclude is not None:
647
648
649
            always_exclude = {
                g.to_canonical_etype(k): v for k, v in always_exclude.items()
            }
650
651
652
653
654
655
656
            for k in always_exclude.keys():
                if k in pair_eids:
                    pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]])
                else:
                    pair_eids[k] = always_exclude[k]
        return pair_eids

657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673

@pytest.mark.parametrize("always_exclude_flag", [False, True])
@pytest.mark.parametrize(
    "exclude", [None, "self", "reverse_id", "reverse_types"]
)
@pytest.mark.parametrize(
    "sampler",
    [
        dgl.dataloading.MultiLayerFullNeighborSampler(1),
        dgl.dataloading.ShaDowKHopSampler([5]),
    ],
)
@pytest.mark.parametrize("batch_size", [1, 50])
def test_edge_dataloader_excludes(
    exclude, always_exclude_flag, batch_size, sampler
):
    if exclude == "reverse_types":
674
675
676
677
678
679
680
681
        g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
    else:
        g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
    g = g.to(F.ctx())
    if not always_exclude_flag:
        always_exclude = None

    kwargs = {}
682
683
684
685
686
687
688
689
690
    kwargs["exclude"] = (
        partial(_find_edges_to_exclude, g, exclude, always_exclude)
        if always_exclude_flag
        else exclude
    )
    kwargs["reverse_eids"] = reverse_eids if exclude == "reverse_id" else None
    kwargs["reverse_etypes"] = (
        reverse_etypes if exclude == "reverse_types" else None
    )
691
    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, **kwargs)
692

693
    dataloader = dgl.dataloading.DataLoader(
694
695
696
697
698
699
700
        g,
        seed_edges,
        sampler,
        batch_size=batch_size,
        device=F.ctx(),
        use_prefetch_thread=False,
    )
701
702
703
704
705
    for i, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
        if isinstance(blocks, list):
            subg = blocks[0]
        else:
            subg = blocks
706
        pair_eids = pair_graph.edata[dgl.EID]
707
        block_eids = subg.edata[dgl.EID]
708

709
710
711
        edges_to_exclude = _find_edges_to_exclude(
            g, exclude, always_exclude, pair_eids
        )
712
713
        if edges_to_exclude is None:
            continue
714
715
716
717
718
719
        edges_to_exclude = dgl.utils.recursive_apply(
            edges_to_exclude, lambda x: x.cpu().numpy()
        )
        block_eids = dgl.utils.recursive_apply(
            block_eids, lambda x: x.cpu().numpy()
        )
720
721
722
723
724
725
726

        if isinstance(edges_to_exclude, Mapping):
            for k in edges_to_exclude.keys():
                assert not np.isin(edges_to_exclude[k], block_eids[k]).any()
        else:
            assert not np.isin(edges_to_exclude, block_eids).any()

727
728
729
        if i == 10:
            break

730

731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
def test_edge_dataloader_exclusion_with_reverse_seed_nodes():
    utype, etype, vtype = ("A", "AB", "B")
    s = torch.randint(0, 20, (500,), device=F.ctx())
    d = torch.randint(0, 20, (500,), device=F.ctx())
    s, d = _remove_duplicates(s, d)
    g = dgl.heterograph({("A", "AB", "B"): (s, d), ("B", "BA", "A"): (d, s)})
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        dgl.dataloading.NeighborSampler(fanouts=[2, 2, 2]),
        exclude="reverse_types",
        reverse_etypes={"AB": "BA", "BA": "AB"},
    )
    seed_edges = {
        "AB": torch.arange(g.number_of_edges("AB"), device=F.ctx()),
        "BA": torch.arange(g.number_of_edges("BA"), device=F.ctx()),
    }
    dataloader = dgl.dataloading.DataLoader(
        g,
        seed_edges,
        sampler,
        batch_size=2,
        device=F.ctx(),
        shuffle=True,
        drop_last=False,
    )
    for _, pos_graph, mfgs in dataloader:
        s, d = pos_graph["AB"].edges()
        AB_pos = list(zip(s.tolist(), d.tolist()))
        s, d = pos_graph["BA"].edges()
        BA_pos = list(zip(s.tolist(), d.tolist()))

        s, d = mfgs[-1]["AB"].edges()
        AB_mfg = list(zip(s.tolist(), d.tolist()))
        s, d = mfgs[-1]["BA"].edges()
        BA_mfg = list(zip(s.tolist(), d.tolist()))

        assert all(edge not in AB_mfg for edge in AB_pos)
        assert all(edge not in BA_mfg for edge in BA_pos)


770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
def test_edge_dataloader_exclusion_without_all_reverses():
    data_dict = {
        ("A", "AB", "B"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ("B", "BA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
        ("B", "BC", "C"): (torch.tensor([0]), torch.tensor([0])),
        ("C", "CA", "A"): (torch.tensor([0, 1]), torch.tensor([0, 1])),
    }
    g = dgl.heterograph(data_dict=data_dict)
    block_sampler = dgl.dataloading.MultiLayerNeighborSampler(
        fanouts=[1], replace=True
    )
    block_sampler = dgl.dataloading.as_edge_prediction_sampler(
        block_sampler,
        exclude="reverse_types",
        reverse_etypes={"AB": "BA"},
    )
    d = dgl.dataloading.DataLoader(
        graph=g,
        indices={
            "AB": torch.tensor([0]),
            "BC": torch.tensor([0]),
        },
        graph_sampler=block_sampler,
        batch_size=2,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        device=F.ctx(),
        use_ddp=False,
    )

    next(iter(d))


804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
def dummy_worker_init_fn(worker_id):
    pass


def test_dataloader_worker_init_fn():
    dataset = dgl.data.CoraFullDataset()
    g = dataset[0]
    sampler = dgl.dataloading.MultiLayerNeighborSampler([2])
    dataloader = dgl.dataloading.DataLoader(
        g,
        torch.arange(100),
        sampler,
        batch_size=4,
        num_workers=4,
        worker_init_fn=dummy_worker_init_fn,
    )
    for _ in dataloader:
        pass


824
825
826
827
828
if __name__ == "__main__":
    # test_node_dataloader(F.int32, 'neighbor', None)
    test_edge_dataloader_excludes(
        "reverse_types", False, 1, dgl.dataloading.ShaDowKHopSampler([5])
    )
829
    test_edge_dataloader_exclusion_without_all_reverses()