import os
import unittest
from collections.abc import Iterator, Mapping
from functools import partial

import numpy as np
import pytest
import torch
import torch.distributed as dist

import dgl
import dgl.ops as OPS

import backend as F
from test_utils import parametrize_idtype

15
16
@pytest.mark.parametrize('batch_size', [None, 16])
def test_graph_dataloader(batch_size):
    """Smoke-test GraphDataLoader over MiniGC, both batched and unbatched."""
    n_batches = 2
    per_batch = 1 if batch_size is None else batch_size
    dataset = dgl.data.MiniGCDataset(n_batches * per_batch, 10, 20)
    loader = dgl.dataloading.GraphDataLoader(dataset, batch_size=batch_size, shuffle=True)
    assert isinstance(iter(loader), Iterator)
    for graph, label in loader:
        assert isinstance(graph, dgl.DGLGraph)
        labels = F.asnumpy(label)
        if batch_size is None:
            # With batch_size=None each label element is an unbatched scalar,
            # following PyTorch's DataLoader convention.
            assert labels.ndim == 0
        else:
            assert labels.shape[0] == batch_size
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize('num_workers', [0, 4])
def test_cluster_gcn(num_workers):
    """ClusterGCNSampler should partition CoraFull and iterate without errors."""
    graph = dgl.data.CoraFullDataset()[0]
    cluster_sampler = dgl.dataloading.ClusterGCNSampler(graph, 100)
    loader = dgl.dataloading.DataLoader(
        graph, torch.arange(100), cluster_sampler, batch_size=4, num_workers=num_workers)
    # 100 partitions with batch_size=4 gives exactly 25 batches.
    assert len(loader) == 25
    for _ in loader:
        pass
@pytest.mark.parametrize('num_workers', [0, 4])
def test_shadow(num_workers):
    """ShaDow-GNN sampling: the subgraph must carry the seeds first, and its
    node data must match the parent graph's rows for the sampled nodes."""
    graph = dgl.data.CoraFullDataset()[0]
    shadow = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
    loader = dgl.dataloading.DataLoader(
        graph, torch.arange(graph.num_nodes()), shadow,
        batch_size=5, shuffle=True, drop_last=False, num_workers=num_workers)
    for step, (input_nodes, output_nodes, subgraph) in enumerate(loader):
        # Subgraph node ordering equals input_nodes, seeds leading.
        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
        assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes)
        assert torch.equal(subgraph.ndata['label'], graph.ndata['label'][input_nodes])
        assert torch.equal(subgraph.ndata['feat'], graph.ndata['feat'][input_nodes])
        # A handful of minibatches is enough for the smoke test.
        if step == 5:
            break
@pytest.mark.parametrize('num_workers', [0, 4])
@pytest.mark.parametrize('mode', ['node', 'edge', 'walk'])
def test_saint(num_workers, mode):
    """SAINTSampler smoke test covering all three budget modes."""
    graph = dgl.data.CoraFullDataset()[0]
    # Budget semantics depend on the mode: node count, edge count, or
    # (number of root nodes, walk length) for random walks.
    budget = {'node': 100, 'edge': 200, 'walk': (3, 2)}[mode]
    saint = dgl.dataloading.SAINTSampler(mode, budget)
    loader = dgl.dataloading.DataLoader(
        graph, torch.arange(100), saint, num_workers=num_workers)
    assert len(loader) == 100
    for _ in loader:
        pass
@parametrize_idtype
@pytest.mark.parametrize('mode', ['cpu', 'uva_cuda_indices', 'uva_cpu_indices', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True])
@pytest.mark.parametrize('use_mask', [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp, use_mask):
    """Weighted/masked neighbor sampling must only ever pick neighbors whose
    probability (or mask) is nonzero."""
    if mode != 'cpu' and F.ctx() == F.cpu():
        pytest.skip('UVA and GPU sampling require a GPU.')
    if mode != 'cpu' and use_mask:
        pytest.skip('Masked sampling only works on CPU.')
    if use_ddp:
        if os.name == 'nt':
            pytest.skip('PyTorch 1.13.0+ has problems in Windows DDP...')
        backend = 'gloo' if F.ctx() == F.cpu() else 'nccl'
        dist.init_process_group(backend, 'tcp://127.0.0.1:12347', world_size=1, rank=0)

    # Homogeneous case: seeds 0 and 1 each have four in-neighbors, half of
    # which carry zero probability and must never be sampled.
    g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(idtype)
    g.edata['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edata['mask'] = (g.edata['p'] != 0)
    indices_ctx = F.cpu() if mode in ('cpu', 'uva_cpu_indices') else F.cuda()
    indices = F.copy_to(F.tensor([0, 1], idtype), indices_ctx)
    if mode == 'pure_gpu':
        g = g.to(F.cuda())
    use_uva = mode.startswith('uva')

    # Sample either by probability 'p' or by boolean 'mask', never both.
    prob = None if use_mask else 'p'
    mask = 'mask' if use_mask else None
    sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob=prob, mask=mask)
    worker_counts = [0, 1, 2] if mode == 'cpu' else [0]
    for num_workers in worker_counts:
        dataloader = dgl.dataloading.DataLoader(
            g, indices, sampler,
            batch_size=1, device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp)
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes.item()
            sampled = set(input_nodes[1:].cpu().numpy())
            if seed == 1:
                assert sampled == {5, 6}
            elif seed == 0:
                assert sampled == {1, 2}

    # Heterogeneous case: the same structure duplicated over two edge types
    # with complementary probabilities.
    g = dgl.heterograph({
        ('B', 'BA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
        ('C', 'CA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
        }).astype(idtype)
    g.edges['BA'].data['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edges['BA'].data['mask'] = (g.edges['BA'].data['p'] != 0)
    g.edges['CA'].data['p'] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
    g.edges['CA'].data['mask'] = (g.edges['CA'].data['p'] != 0)
    if mode == 'pure_gpu':
        g = g.to(F.cuda())
    for num_workers in worker_counts:
        dataloader = dgl.dataloading.DataLoader(
            g, {'A': indices}, sampler,
            batch_size=1, device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp)
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes['A'].item()
            # Seeds and neighbors have different node types, so no slicing
            # of input_nodes is necessary here.
            sampled_b = set(input_nodes['B'].cpu().numpy())
            if seed == 1:
                assert sampled_b == {5, 6}
            elif seed == 0:
                assert sampled_b == {1, 2}
            sampled_c = set(input_nodes['C'].cpu().numpy())
            if seed == 1:
                assert sampled_c == {7, 8}
            elif seed == 0:
                assert sampled_c == {3, 4}

    if use_ddp:
        dist.destroy_process_group()
def _check_dtype(data, dtype, attr_name):
    if isinstance(data, dict):
        for k, v in data.items():
            assert getattr(v, attr_name) == dtype
    elif isinstance(data, list):
        for v in data:
            assert getattr(v, attr_name) == dtype
    else:
        assert getattr(data, attr_name) == dtype
167

168
169
170
171
172
173
174
175
176
177
def _check_device(data):
    """Assert that ``data`` (or every value/element of a dict/list of data)
    resides on the backend's current device ``F.ctx()``."""
    target = F.ctx()
    if isinstance(data, dict):
        values = list(data.values())
    elif isinstance(data, list):
        values = data
    else:
        values = [data]
    for v in values:
        assert v.device == target
@parametrize_idtype
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
@pytest.mark.parametrize('mode', ['cpu', 'uva_cuda_indices', 'uva_cpu_indices', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True])
def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
    """Node-wise DataLoader must yield input/output nodes and blocks on the
    target device with the requested index dtype, for homo- and heterographs."""
    if mode != 'cpu' and F.ctx() == F.cpu():
        pytest.skip('UVA and GPU sampling require a GPU.')
    if use_ddp:
        if os.name == 'nt':
            pytest.skip('PyTorch 1.13.0+ has problems in Windows DDP...')
        backend = 'gloo' if F.ctx() == F.cpu() else 'nccl'
        dist.init_process_group(backend, 'tcp://127.0.0.1:12347', world_size=1, rank=0)

    # Homogeneous graph with CPU-resident node features and labels.
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
    g1.ndata['label'] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())
    indices_ctx = F.cpu() if mode in ('cpu', 'uva_cpu_indices') else F.cuda()
    indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), indices_ctx)
    if mode == 'pure_gpu':
        g1 = g1.to(F.cuda())
    use_uva = mode.startswith('uva')

    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
    worker_counts = [0, 1, 2] if mode == 'cpu' else [0]
    for num_workers in worker_counts:
        dataloader = dgl.dataloading.DataLoader(
            g1, indices, sampler, device=F.ctx(),
            batch_size=g1.num_nodes(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp)
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, 'dtype')
            _check_dtype(output_nodes, idtype, 'dtype')
            _check_dtype(blocks, idtype, 'idtype')

    # Heterogeneous graph; a single batch covers the largest node type.
    g2 = dgl.heterograph({
         ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
         ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
         ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
         ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    }).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    indices = {nty: F.copy_to(g2.nodes(nty), indices_ctx) for nty in g2.ntypes}
    if mode == 'pure_gpu':
        g2 = g2.to(F.cuda())
    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2),
        'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
    for num_workers in worker_counts:
        dataloader = dgl.dataloading.DataLoader(
            g2, indices, sampler,
            device=F.ctx(), batch_size=batch_size,
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp)
        assert isinstance(iter(dataloader), Iterator)
        for input_nodes, output_nodes, blocks in dataloader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, 'dtype')
            _check_dtype(output_nodes, idtype, 'dtype')
            _check_dtype(blocks, idtype, 'idtype')

    if use_ddp:
        dist.destroy_process_group()
@parametrize_idtype
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor'])
@pytest.mark.parametrize('neg_sampler', [
    dgl.dataloading.negative_sampler.Uniform(2),
    dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
    dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3)])
@pytest.mark.parametrize('mode', ['cpu', 'uva', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True])
def test_edge_dataloader(idtype, sampler_name, neg_sampler, mode, use_ddp):
    """Edge-wise DataLoader smoke test, with and without negative sampling,
    on homo- and heterographs; every yielded item must be on F.ctx()."""
    if mode != 'cpu' and F.ctx() == F.cpu():
        pytest.skip('UVA and GPU sampling require a GPU.')
    if mode == 'uva' and isinstance(neg_sampler, dgl.dataloading.negative_sampler.GlobalUniform):
        pytest.skip("GlobalUniform don't support UVA yet.")
    if use_ddp:
        if os.name == 'nt':
            pytest.skip('PyTorch 1.13.0+ has problems in Windows DDP...')
        backend = 'gloo' if F.ctx() == F.cpu() else 'nccl'
        dist.init_process_group(backend, 'tcp://127.0.0.1:12347', world_size=1, rank=0)
    use_uva = (mode == 'uva')

    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
    if mode == 'pure_gpu':
        g1 = g1.to(F.cuda())
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]

    # Homogeneous graph, no negative sampler.
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        g1, g1.edges(form='eid'), edge_sampler,
        device=F.ctx(), batch_size=g1.num_edges(),
        use_uva=use_uva, use_ddp=use_ddp)
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # Homogeneous graph, with negative sampler.
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=neg_sampler)
    dataloader = dgl.dataloading.DataLoader(
        g1, g1.edges(form='eid'), edge_sampler,
        device=F.ctx(), batch_size=g1.num_edges(),
        use_uva=use_uva, use_ddp=use_ddp)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    g2 = dgl.heterograph({
         ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
         ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
         ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
         ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    }).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    if mode == 'pure_gpu':
        g2 = g2.to(F.cuda())
    # One batch covers the edge type with the most edges.
    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([{etype: 3 for etype in g2.etypes}] * 2),
        }[sampler_name]
    all_eids = {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes}

    # Heterogeneous graph, no negative sampler.
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
    dataloader = dgl.dataloading.DataLoader(
        g2, all_eids, edge_sampler,
        device=F.ctx(), batch_size=batch_size,
        use_uva=use_uva, use_ddp=use_ddp)
    for input_nodes, pos_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # Heterogeneous graph, with negative sampler.
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, negative_sampler=neg_sampler)
    dataloader = dgl.dataloading.DataLoader(
        g2, all_eids, edge_sampler,
        device=F.ctx(), batch_size=batch_size,
        use_uva=use_uva, use_ddp=use_ddp)
    assert isinstance(iter(dataloader), Iterator)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    if use_ddp:
        dist.destroy_process_group()
def _create_homogeneous():
    """Build a random homogeneous graph fixture for the exclusion tests.

    Returns
    -------
    (g, reverse_eids, always_exclude, seed_edges)
        ``g`` has 2000 edges: edges 0..999 are (s, d) and edges 1000..1999
        are the reverses (d, s), so ``reverse_eids[i]`` maps edge ``i`` to
        its reverse at ``i +/- 1000``.  ``seed_edges`` covers the forward
        half only; ``always_exclude`` is a random subset of it.
    """
    s = torch.randint(0, 200, (1000,), device=F.ctx())
    d = torch.randint(0, 200, (1000,), device=F.ctx())
    src = torch.cat([s, d])
    dst = torch.cat([d, s])
    # BUG FIX: the graph must contain both directions ((src, dst) -> 2000
    # edges), not just (s, d); otherwise reverse_eids and the '+ 1000'
    # arithmetic in _find_edges_to_exclude point past g.num_edges().
    g = dgl.graph((src, dst), num_nodes=200)
    reverse_eids = torch.cat([torch.arange(1000, 2000), torch.arange(0, 1000)]).to(F.ctx())
    always_exclude = torch.randint(0, 1000, (50,), device=F.ctx())
    seed_edges = torch.arange(0, 1000, device=F.ctx())
    return g, reverse_eids, always_exclude, seed_edges
def _create_heterogeneous():
    """Build a random heterograph fixture with explicit reverse edge types.

    Returns (g, reverse_etypes, always_exclude, seed_edges): the graph pairs
    each relation with a 'rev-*' mirror, reverse_etypes maps each etype to
    its mirror in both directions, and the eid dicts cover the forward types.
    """
    device = F.ctx()
    edges = {}
    for utype, etype, vtype in [('A', 'AA', 'A'), ('A', 'AB', 'B')]:
        s = torch.randint(0, 200, (1000,), device=device)
        d = torch.randint(0, 200, (1000,), device=device)
        edges[utype, etype, vtype] = (s, d)
        # Mirror every relation with an explicit reverse edge type.
        edges[vtype, 'rev-' + etype, utype] = (d, s)
    g = dgl.heterograph(edges, num_nodes_dict={'A': 200, 'B': 200})
    reverse_etypes = {'AA': 'rev-AA', 'AB': 'rev-AB', 'rev-AA': 'AA', 'rev-AB': 'AB'}
    always_exclude = {etype: torch.randint(0, 1000, (50,), device=device)
                      for etype in ('AA', 'AB')}
    seed_edges = {etype: torch.arange(0, 1000, device=device)
                  for etype in ('AA', 'AB')}
    return g, reverse_etypes, always_exclude, seed_edges
def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
    if exclude == None:
        return always_exclude
    elif exclude == 'self':
        return torch.cat([pair_eids, always_exclude]) if always_exclude is not None else pair_eids
    elif exclude == 'reverse_id':
        pair_eids = torch.cat([pair_eids, pair_eids + 1000])
        return torch.cat([pair_eids, always_exclude]) if always_exclude is not None else pair_eids
    elif exclude == 'reverse_types':
        pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()}
        if ('A', 'AA', 'A') in pair_eids:
            pair_eids[('A', 'rev-AA', 'A')] = pair_eids[('A', 'AA', 'A')]
        if ('A', 'AB', 'B') in pair_eids:
            pair_eids[('B', 'rev-AB', 'A')] = pair_eids[('A', 'AB', 'B')]
        if always_exclude is not None:
            always_exclude = {g.to_canonical_etype(k): v for k, v in always_exclude.items()}
            for k in always_exclude.keys():
                if k in pair_eids:
                    pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]])
                else:
                    pair_eids[k] = always_exclude[k]
        return pair_eids

@pytest.mark.parametrize('always_exclude_flag', [False, True])
@pytest.mark.parametrize('exclude', [None, 'self', 'reverse_id', 'reverse_types'])
@pytest.mark.parametrize('sampler', [dgl.dataloading.MultiLayerFullNeighborSampler(1),
                                     dgl.dataloading.ShaDowKHopSampler([5])])
@pytest.mark.parametrize('batch_size', [1, 50])
def test_edge_dataloader_excludes(exclude, always_exclude_flag, batch_size, sampler):
    """Edges designated by the exclude policy must never show up in the
    sampled blocks/subgraphs."""
    if exclude == 'reverse_types':
        g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
    else:
        g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
    g = g.to(F.ctx())
    if not always_exclude_flag:
        always_exclude = None

    # With extra always-excluded edges we hand the dataloader a callable
    # policy; otherwise the plain string/None policy is forwarded as-is.
    if always_exclude_flag:
        policy = partial(_find_edges_to_exclude, g, exclude, always_exclude)
    else:
        policy = exclude
    kwargs = {
        'exclude': policy,
        'reverse_eids': reverse_eids if exclude == 'reverse_id' else None,
        'reverse_etypes': reverse_etypes if exclude == 'reverse_types' else None,
    }
    sampler = dgl.dataloading.as_edge_prediction_sampler(sampler, **kwargs)

    dataloader = dgl.dataloading.DataLoader(
        g, seed_edges, sampler, batch_size=batch_size, device=F.ctx(), use_prefetch_thread=False)
    for step, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
        # MultiLayer samplers yield a list of blocks; ShaDow yields one subgraph.
        subg = blocks[0] if isinstance(blocks, list) else blocks
        pair_eids = pair_graph.edata[dgl.EID]
        block_eids = subg.edata[dgl.EID]

        expected = _find_edges_to_exclude(g, exclude, always_exclude, pair_eids)
        if expected is None:
            continue
        expected = dgl.utils.recursive_apply(expected, lambda x: x.cpu().numpy())
        block_eids = dgl.utils.recursive_apply(block_eids, lambda x: x.cpu().numpy())

        if isinstance(expected, Mapping):
            for k in expected.keys():
                assert not np.isin(expected[k], block_eids[k]).any()
        else:
            assert not np.isin(expected, block_eids).any()

        if step == 10:
            break
if __name__ == '__main__':
    # Manual entry point for debugging a single parametrization.
    #test_node_dataloader(F.int32, 'neighbor', None)
    test_edge_dataloader_excludes('reverse_types', False, 1, dgl.dataloading.ShaDowKHopSampler([5]))