test_dataloader.py 18.5 KB
Newer Older
1
import os
2
import numpy as np
3
import dgl
4
import dgl.ops as OPS
5
6
import backend as F
import unittest
7
import torch
8
import torch.distributed as dist
9
10
from functools import partial
from collections.abc import Iterator, Mapping
nv-dlasalle's avatar
nv-dlasalle committed
11
from test_utils import parametrize_idtype
12
import pytest
13
14


15
16
17
18
19
def test_graph_dataloader():
    """Smoke-test GraphDataLoader over a small graph-classification dataset."""
    graphs_per_batch = 16
    n_batches = 2
    dataset = dgl.data.MiniGCDataset(graphs_per_batch * n_batches, 10, 20)
    loader = dgl.dataloading.GraphDataLoader(
        dataset, batch_size=graphs_per_batch, shuffle=True)
    assert isinstance(iter(loader), Iterator)

    for batched_graph, labels in loader:
        # Each minibatch is a single batched DGLGraph plus one label per graph.
        assert isinstance(batched_graph, dgl.DGLGraph)
        assert F.asnumpy(labels).shape[0] == graphs_per_batch
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@pytest.mark.parametrize('num_workers', [0, 4])
def test_cluster_gcn(num_workers):
    """Run the Cluster-GCN sampler end to end on CoraFull."""
    graph = dgl.data.CoraFullDataset()[0]
    cluster_sampler = dgl.dataloading.ClusterGCNSampler(graph, 100)
    loader = dgl.dataloading.DataLoader(
        graph, torch.arange(100), cluster_sampler,
        batch_size=4, num_workers=num_workers)
    # 100 partitions in batches of 4 -> 25 minibatches.
    assert len(loader) == 25
    for _ in loader:
        pass
@pytest.mark.parametrize('num_workers', [0, 4])
def test_shadow(num_workers):
    """Check that ShaDow-GNN subgraphs carry the right nodes and features."""
    graph = dgl.data.CoraFullDataset()[0]
    shadow_sampler = dgl.dataloading.ShaDowKHopSampler([5, 10, 15])
    loader = dgl.dataloading.DataLoader(
        graph, torch.arange(graph.num_nodes()), shadow_sampler,
        batch_size=5, shuffle=True, drop_last=False, num_workers=num_workers)
    for batch_idx, (input_nodes, output_nodes, subgraph) in enumerate(loader):
        # The subgraph is induced on exactly the input nodes, with the seed
        # (output) nodes listed first.
        assert torch.equal(input_nodes, subgraph.ndata[dgl.NID])
        assert torch.equal(input_nodes[:output_nodes.shape[0]], output_nodes)
        # Node data is sliced from the parent graph.
        assert torch.equal(subgraph.ndata['label'], graph.ndata['label'][input_nodes])
        assert torch.equal(subgraph.ndata['feat'], graph.ndata['feat'][input_nodes])
        if batch_idx == 5:
            break
@pytest.mark.parametrize('num_workers', [0, 4])
@pytest.mark.parametrize('mode', ['node', 'edge', 'walk'])
def test_saint(num_workers, mode):
    """Smoke-test the GraphSAINT sampler in all three budget modes."""
    graph = dgl.data.CoraFullDataset()[0]

    # Budget semantics depend on the mode: node count, edge count, or
    # (number of roots, walk length) for random walks.
    budgets = {'node': 100, 'edge': 200, 'walk': (3, 2)}

    saint_sampler = dgl.dataloading.SAINTSampler(mode, budgets[mode])
    loader = dgl.dataloading.DataLoader(
        graph, torch.arange(100), saint_sampler, num_workers=num_workers)
    assert len(loader) == 100
    for _ in loader:
        pass
@parametrize_idtype
@pytest.mark.parametrize('mode', ['cpu', 'uva_cuda_indices', 'uva_cpu_indices', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True])
def test_neighbor_nonuniform(idtype, mode, use_ddp):
    """Probability-weighted neighbor sampling must only pick neighbors whose
    edge probability 'p' is nonzero.

    Covers a homogeneous graph and a heterogeneous graph with two incoming
    edge types, across CPU/UVA/pure-GPU modes and a single-rank DDP setup.
    """
    if mode != 'cpu' and F.ctx() == F.cpu():
        pytest.skip('UVA and GPU sampling require a GPU.')
    if use_ddp:
        # A world of size 1 is enough to exercise the use_ddp code path.
        dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
            'tcp://127.0.0.1:12347', world_size=1, rank=0)

    # Nodes 0 and 1 each have four in-neighbors, but only two of them carry
    # sampling probability 1; the other two carry probability 0 and must
    # never be sampled.
    g = dgl.graph(([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1])).astype(idtype)
    g.edata['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    if mode in ('cpu', 'uva_cpu_indices'):
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cpu())
    else:
        indices = F.copy_to(F.tensor([0, 1], idtype), F.cuda())
    if mode == 'pure_gpu':
        g = g.to(F.cuda())
    use_uva = mode.startswith('uva')

    sampler = dgl.dataloading.MultiLayerNeighborSampler([2], prob='p')
    for num_workers in [0, 1, 2] if mode == 'cpu' else [0]:
        # Use dgl.dataloading.DataLoader for consistency with the rest of
        # this file (NodeDataLoader is the deprecated legacy name).
        dataloader = dgl.dataloading.DataLoader(
            g, indices, sampler,
            batch_size=1, device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp)
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes.item()
            # input_nodes[0] is the seed itself; the rest are its sampled
            # neighbors.
            neighbors = set(input_nodes[1:].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

    # Heterogeneous case: each destination node of type A has two incoming
    # edge types whose probability masks are complements of each other.
    g = dgl.heterograph({
        ('B', 'BA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
        ('C', 'CA', 'A'): ([1, 2, 3, 4, 5, 6, 7, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
        }).astype(idtype)
    g.edges['BA'].data['p'] = torch.FloatTensor([1, 1, 0, 0, 1, 1, 0, 0])
    g.edges['CA'].data['p'] = torch.FloatTensor([0, 0, 1, 1, 0, 0, 1, 1])
    if mode == 'pure_gpu':
        g = g.to(F.cuda())
    for num_workers in [0, 1, 2] if mode == 'cpu' else [0]:
        dataloader = dgl.dataloading.DataLoader(
            g, {'A': indices}, sampler,
            batch_size=1, device=F.ctx(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp)
        for input_nodes, output_nodes, blocks in dataloader:
            seed = output_nodes['A'].item()
            # Seed and neighbors are of different node types so slicing is not necessary here.
            neighbors = set(input_nodes['B'].cpu().numpy())
            if seed == 1:
                assert neighbors == {5, 6}
            elif seed == 0:
                assert neighbors == {1, 2}

            neighbors = set(input_nodes['C'].cpu().numpy())
            if seed == 1:
                assert neighbors == {7, 8}
            elif seed == 0:
                assert neighbors == {3, 4}

    if use_ddp:
        dist.destroy_process_group()
def _check_dtype(data, dtype, attr_name):
    if isinstance(data, dict):
        for k, v in data.items():
            assert getattr(v, attr_name) == dtype
    elif isinstance(data, list):
        for v in data:
            assert getattr(v, attr_name) == dtype
    else:
        assert getattr(data, attr_name) == dtype
148

149
150
151
152
153
154
155
156
157
158
def _check_device(data):
    """Assert that every element of *data* lives on the current test device.

    *data* may be a dict of values, a list of values, or a single value
    (tensor or graph — anything exposing a ``device`` attribute).
    """
    if isinstance(data, dict):
        items = list(data.values())
    elif isinstance(data, list):
        items = data
    else:
        items = [data]
    for item in items:
        assert item.device == F.ctx()
@parametrize_idtype
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor', 'neighbor2'])
@pytest.mark.parametrize('mode', ['cpu', 'uva_cuda_indices', 'uva_cpu_indices', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True])
def test_node_dataloader(idtype, sampler_name, mode, use_ddp):
    """Node-wise sampling: returned nodes/blocks must land on the requested
    device with the requested index dtype, for homogeneous and heterogeneous
    graphs."""
    if mode != 'cpu' and F.ctx() == F.cpu():
        pytest.skip('UVA and GPU sampling require a GPU.')
    if use_ddp:
        dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
            'tcp://127.0.0.1:12347', world_size=1, rank=0)

    # ---- homogeneous graph ----
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
    g1.ndata['label'] = F.copy_to(F.randn((g1.num_nodes(),)), F.cpu())

    # UVA with CPU-resident indices keeps them on CPU; otherwise indices
    # follow the GPU.
    index_ctx = F.cpu() if mode in ('cpu', 'uva_cpu_indices') else F.cuda()
    indices = F.copy_to(F.arange(0, g1.num_nodes(), idtype), index_ctx)
    if mode == 'pure_gpu':
        g1 = g1.to(F.cuda())
    use_uva = mode.startswith('uva')

    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3]),
        'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
    worker_counts = [0, 1, 2] if mode == 'cpu' else [0]
    for num_workers in worker_counts:
        loader = dgl.dataloading.DataLoader(
            g1, indices, sampler, device=F.ctx(),
            batch_size=g1.num_nodes(),
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp)
        for input_nodes, output_nodes, blocks in loader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, 'dtype')
            _check_dtype(output_nodes, idtype, 'dtype')
            _check_dtype(blocks, idtype, 'idtype')

    # ---- heterogeneous graph ----
    g2 = dgl.heterograph({
         ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
         ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
         ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
         ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    }).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    indices = {nty: F.copy_to(g2.nodes(nty), index_ctx) for nty in g2.ntypes}
    if mode == 'pure_gpu':
        g2 = g2.to(F.cuda())

    # One batch covers every node of the largest node type.
    batch_size = max(g2.num_nodes(nty) for nty in g2.ntypes)
    sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2),
        'neighbor2': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]
    for num_workers in worker_counts:
        loader = dgl.dataloading.DataLoader(
            g2, indices, sampler,
            device=F.ctx(), batch_size=batch_size,
            num_workers=num_workers,
            use_uva=use_uva,
            use_ddp=use_ddp)
        assert isinstance(iter(loader), Iterator)
        for input_nodes, output_nodes, blocks in loader:
            _check_device(input_nodes)
            _check_device(output_nodes)
            _check_device(blocks)
            _check_dtype(input_nodes, idtype, 'dtype')
            _check_dtype(output_nodes, idtype, 'dtype')
            _check_dtype(blocks, idtype, 'idtype')

    if use_ddp:
        dist.destroy_process_group()
@parametrize_idtype
@pytest.mark.parametrize('sampler_name', ['full', 'neighbor'])
@pytest.mark.parametrize('neg_sampler', [
    dgl.dataloading.negative_sampler.Uniform(2),
    dgl.dataloading.negative_sampler.GlobalUniform(15, False, 3),
    dgl.dataloading.negative_sampler.GlobalUniform(15, True, 3)])
@pytest.mark.parametrize('mode', ['cpu', 'uva', 'pure_gpu'])
@pytest.mark.parametrize('use_ddp', [False, True])
def test_edge_dataloader(idtype, sampler_name, neg_sampler, mode, use_ddp):
    """Edge-wise sampling with and without negative sampling; every returned
    structure must live on the requested device."""
    if mode != 'cpu' and F.ctx() == F.cpu():
        pytest.skip('UVA and GPU sampling require a GPU.')
    if mode == 'uva' and isinstance(neg_sampler, dgl.dataloading.negative_sampler.GlobalUniform):
        pytest.skip("GlobalUniform don't support UVA yet.")
    if use_ddp:
        dist.init_process_group('gloo' if F.ctx() == F.cpu() else 'nccl',
            'tcp://127.0.0.1:12347', world_size=1, rank=0)
    use_uva = mode == 'uva'

    # ---- homogeneous graph ----
    g1 = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4])).astype(idtype)
    g1.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())
    if mode == 'pure_gpu':
        g1 = g1.to(F.cuda())

    layer_sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler([3, 3])}[sampler_name]

    # no negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(layer_sampler)
    loader = dgl.dataloading.DataLoader(
        g1, g1.edges(form='eid'), edge_sampler,
        device=F.ctx(), batch_size=g1.num_edges(),
        use_uva=use_uva, use_ddp=use_ddp)
    for input_nodes, pos_pair_graph, blocks in loader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        layer_sampler, negative_sampler=neg_sampler)
    loader = dgl.dataloading.DataLoader(
        g1, g1.edges(form='eid'), edge_sampler,
        device=F.ctx(), batch_size=g1.num_edges(),
        use_uva=use_uva, use_ddp=use_ddp)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in loader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    # ---- heterogeneous graph ----
    g2 = dgl.heterograph({
         ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
         ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
         ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
         ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    }).astype(idtype)
    for ntype in g2.ntypes:
        g2.nodes[ntype].data['feat'] = F.copy_to(F.randn((g2.num_nodes(ntype), 8)), F.cpu())
    if mode == 'pure_gpu':
        g2 = g2.to(F.cuda())

    # One batch covers every edge of the largest edge type.
    batch_size = max(g2.num_edges(ety) for ety in g2.canonical_etypes)
    layer_sampler = {
        'full': dgl.dataloading.MultiLayerFullNeighborSampler(2),
        'neighbor': dgl.dataloading.MultiLayerNeighborSampler(
            [{etype: 3 for etype in g2.etypes}] * 2),
        }[sampler_name]
    seed_edges = {ety: g2.edges(form='eid', etype=ety) for ety in g2.canonical_etypes}

    # no negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(layer_sampler)
    loader = dgl.dataloading.DataLoader(
        g2, seed_edges, edge_sampler,
        device=F.ctx(), batch_size=batch_size,
        use_uva=use_uva, use_ddp=use_ddp)
    for input_nodes, pos_pair_graph, blocks in loader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # negative sampler
    edge_sampler = dgl.dataloading.as_edge_prediction_sampler(
        layer_sampler, negative_sampler=neg_sampler)
    loader = dgl.dataloading.DataLoader(
        g2, seed_edges, edge_sampler,
        device=F.ctx(), batch_size=batch_size,
        use_uva=use_uva, use_ddp=use_ddp)
    assert isinstance(iter(loader), Iterator)
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in loader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    if use_ddp:
        dist.destroy_process_group()
def _create_homogeneous():
    """Build a random homogeneous graph where every edge has its reverse.

    Returns ``(g, reverse_eids, always_exclude, seed_edges)``:

    - ``g``: graph with 200 nodes and 2000 edges — the first 1000 are the
      forward edges, the last 1000 their reverses.
    - ``reverse_eids``: mapping eid ``i`` <-> ``i + 1000``, matching the
      'reverse_id' convention in ``_find_edges_to_exclude``.
    - ``always_exclude``: 50 random forward-edge IDs.
    - ``seed_edges``: the 1000 forward-edge IDs used as minibatch seeds.
    """
    s = torch.randint(0, 200, (1000,), device=F.ctx())
    d = torch.randint(0, 200, (1000,), device=F.ctx())
    src = torch.cat([s, d])
    dst = torch.cat([d, s])
    # Fix: the graph must be built from (src, dst) so it actually contains
    # the reverse edges; with (s, d) it would only have 1000 edges and
    # reverse_eids below would index out of range.
    g = dgl.graph((src, dst), num_nodes=200)
    reverse_eids = torch.cat([torch.arange(1000, 2000), torch.arange(0, 1000)]).to(F.ctx())
    always_exclude = torch.randint(0, 1000, (50,), device=F.ctx())
    seed_edges = torch.arange(0, 1000, device=F.ctx())
    return g, reverse_eids, always_exclude, seed_edges
def _create_heterogeneous():
    """Build a random heterogeneous graph where each edge type has an
    explicit reverse type.

    Returns ``(g, reverse_etypes, always_exclude, seed_edges)`` where
    ``reverse_etypes`` maps each etype to its reverse and back, and
    ``always_exclude``/``seed_edges`` are per-etype edge-ID tensors.
    """
    edge_dict = {}
    for src_type, rel, dst_type in [('A', 'AA', 'A'), ('A', 'AB', 'B')]:
        u = torch.randint(0, 200, (1000,), device=F.ctx())
        v = torch.randint(0, 200, (1000,), device=F.ctx())
        edge_dict[src_type, rel, dst_type] = (u, v)
        edge_dict[dst_type, 'rev-' + rel, src_type] = (v, u)
    g = dgl.heterograph(edge_dict, num_nodes_dict={'A': 200, 'B': 200})
    reverse_etypes = {'AA': 'rev-AA', 'AB': 'rev-AB', 'rev-AA': 'AA', 'rev-AB': 'AB'}
    always_exclude = {
        'AA': torch.randint(0, 1000, (50,), device=F.ctx()),
        'AB': torch.randint(0, 1000, (50,), device=F.ctx())}
    seed_edges = {
        'AA': torch.arange(0, 1000, device=F.ctx()),
        'AB': torch.arange(0, 1000, device=F.ctx())}
    return g, reverse_etypes, always_exclude, seed_edges
def _find_edges_to_exclude(g, exclude, always_exclude, pair_eids):
    if exclude == None:
        return always_exclude
    elif exclude == 'self':
        return torch.cat([pair_eids, always_exclude]) if always_exclude is not None else pair_eids
    elif exclude == 'reverse_id':
        pair_eids = torch.cat([pair_eids, pair_eids + 1000])
        return torch.cat([pair_eids, always_exclude]) if always_exclude is not None else pair_eids
    elif exclude == 'reverse_types':
        pair_eids = {g.to_canonical_etype(k): v for k, v in pair_eids.items()}
        if ('A', 'AA', 'A') in pair_eids:
            pair_eids[('A', 'rev-AA', 'A')] = pair_eids[('A', 'AA', 'A')]
        if ('A', 'AB', 'B') in pair_eids:
            pair_eids[('B', 'rev-AB', 'A')] = pair_eids[('A', 'AB', 'B')]
        if always_exclude is not None:
            always_exclude = {g.to_canonical_etype(k): v for k, v in always_exclude.items()}
            for k in always_exclude.keys():
                if k in pair_eids:
                    pair_eids[k] = torch.cat([pair_eids[k], always_exclude[k]])
                else:
                    pair_eids[k] = always_exclude[k]
        return pair_eids

@pytest.mark.parametrize('always_exclude_flag', [False, True])
@pytest.mark.parametrize('exclude', [None, 'self', 'reverse_id', 'reverse_types'])
@pytest.mark.parametrize('sampler', [dgl.dataloading.MultiLayerFullNeighborSampler(1),
                                     dgl.dataloading.ShaDowKHopSampler([5])])
@pytest.mark.parametrize('batch_size', [1, 50])
def test_edge_dataloader_excludes(exclude, always_exclude_flag, batch_size, sampler):
    """Edges requested for exclusion must never appear in the sampled blocks."""
    if exclude == 'reverse_types':
        g, reverse_etypes, always_exclude, seed_edges = _create_heterogeneous()
    else:
        g, reverse_eids, always_exclude, seed_edges = _create_homogeneous()
    g = g.to(F.ctx())
    if not always_exclude_flag:
        always_exclude = None

    # When always_exclude is active, pass a callable so the dataloader asks
    # us per-batch; otherwise pass the static policy string (or None).
    exclude_arg = (
        partial(_find_edges_to_exclude, g, exclude, always_exclude)
        if always_exclude_flag else exclude)
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler,
        exclude=exclude_arg,
        reverse_eids=reverse_eids if exclude == 'reverse_id' else None,
        reverse_etypes=reverse_etypes if exclude == 'reverse_types' else None)

    loader = dgl.dataloading.DataLoader(
        g, seed_edges, sampler, batch_size=batch_size, device=F.ctx(),
        use_prefetch_thread=False)
    for batch_idx, (input_nodes, pair_graph, blocks) in enumerate(loader):
        subg = blocks[0] if isinstance(blocks, list) else blocks
        pair_eids = pair_graph.edata[dgl.EID]
        block_eids = subg.edata[dgl.EID]

        # Recompute what should have been excluded and verify none of those
        # edge IDs leaked into the sampled block.
        excluded = _find_edges_to_exclude(g, exclude, always_exclude, pair_eids)
        if excluded is None:
            continue
        excluded = dgl.utils.recursive_apply(excluded, lambda x: x.cpu().numpy())
        block_eids = dgl.utils.recursive_apply(block_eids, lambda x: x.cpu().numpy())

        if isinstance(excluded, Mapping):
            for etype in excluded.keys():
                assert not np.isin(excluded[etype], block_eids[etype]).any()
        else:
            assert not np.isin(excluded, block_eids).any()

        if batch_idx == 10:
            break
if __name__ == '__main__':
    # Ad-hoc entry point for running a single case outside pytest.
    #test_node_dataloader(F.int32, 'neighbor', None)
    test_edge_dataloader_excludes(
        'reverse_types', False, 1, dgl.dataloading.ShaDowKHopSampler([5]))