"""DGL PyTorch DataLoaders"""
import inspect
import torch as th
from torch.utils.data import DataLoader
from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
from ...distributed import DistGraph
from ...distributed import DistDataLoader
from ...ndarray import NDArray as DGLNDArray
from ... import backend as F

class _ScalarDataBatcherIter:
    def __init__(self, dataset, batch_size, drop_last):
        self.dataset = dataset
        self.batch_size = batch_size
        self.index = 0
        self.drop_last = drop_last

    # Make this an iterator for PyTorch Lightning compatibility
    def __iter__(self):
        return self

    def __next__(self):
        num_items = self.dataset.shape[0]
        if self.index >= num_items:
            raise StopIteration
        end_idx = self.index + self.batch_size
        if end_idx > num_items:
            if self.drop_last:
                raise StopIteration
            end_idx = num_items
        batch = self.dataset[self.index:end_idx]
        self.index += self.batch_size

        return batch

class _ScalarDataBatcher(th.utils.data.IterableDataset):
    """Custom Dataset wrapper to return mini-batches as tensors, rather than as
    lists. When the dataset is on the GPU, this significantly reduces
    the overhead. For example, with a batch size of 1024, instead of giving a
    list of 1024 scalar tensors to the collator, a single tensor of 1024 elements
    is passed in.
    """
    def __init__(self, dataset, shuffle=False, batch_size=1,
                 drop_last=False):
        super(_ScalarDataBatcher, self).__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        worker_info = th.utils.data.get_worker_info()
        dataset = self.dataset
        if worker_info:
            # worker gets only a fraction of the dataset
            chunk_size = dataset.shape[0] // worker_info.num_workers
            left_over = dataset.shape[0] % worker_info.num_workers
            start = (chunk_size*worker_info.id) + min(left_over, worker_info.id)
            end = start + chunk_size + (worker_info.id < left_over)
            assert worker_info.id < worker_info.num_workers-1 or \
                end == dataset.shape[0]
            dataset = dataset[start:end]

        if self.shuffle:
            # permute the dataset
            perm = th.randperm(dataset.shape[0], device=dataset.device)
            dataset = dataset[perm]

        return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last)

    def __len__(self):
        return (self.dataset.shape[0] + (0 if self.drop_last else self.batch_size - 1)) // \
            self.batch_size
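
# A minimal sketch of the batcher in isolation (hypothetical values): wrapping a
# node-ID tensor of length 5 with ``batch_size=2`` and ``drop_last=False`` yields
# whole tensor slices instead of Python lists of scalar tensors, e.g.
#
#     batcher = _ScalarDataBatcher(th.arange(5), batch_size=2)
#     list(batcher)  # [tensor([0, 1]), tensor([2, 3]), tensor([4])]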

def _remove_kwargs_dist(kwargs):
    if 'num_workers' in kwargs:
        del kwargs['num_workers']
    if 'pin_memory' in kwargs:
        del kwargs['pin_memory']
        print('Distributed DataLoader does not support pin_memory')
    return kwargs

# The following code is a fix to the PyTorch-specific issue in
# https://github.com/dmlc/dgl/issues/2137
#
# Basically the sampled MFGs/subgraphs contain the features extracted from the
# parent graph.  In DGL, the MFGs/subgraphs will hold a reference to the parent
# graph feature tensor and an index tensor, so that the features could be extracted upon
# request.  However, in the context of multiprocessed sampling, we do not need to
# transmit the parent graph feature tensor from the subprocess to the main process,
# since they are exactly the same tensor, and transmitting a tensor from a subprocess
# to the main process is costly in PyTorch as it uses shared memory.  We work around
# it with the following trick:
#
# In the collator running in the sampler processes:
# For each frame in the MFG, we check each column and the column with the same name
# in the corresponding parent frame.  If the storage of the former column is the
# same object as the latter column, we are sure that the former column is a
# subcolumn of the latter, and set the storage of the former column as None.
#
# In the iterator of the main process:
# For each frame in the MFG, we check each column and the column with the same name
# in the corresponding parent frame.  If the storage of the former column is None,
# we replace it with the storage of the latter column.
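#
# In short, the round trip performed by the helpers below is (a sketch):
#
#     # in the sampler (worker) process, inside the collator:
#     _pop_blocks_storage(blocks, g)       # detach storages shared with the parent graph
#     # back in the main process, inside the dataloader iterator:
#     _restore_blocks_storage(blocks, g)   # reattach them from the parent graph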

def _pop_subframe_storage(subframe, frame):
    for key, col in subframe._columns.items():
        if key in frame._columns and col.storage is frame._columns[key].storage:
            col.storage = None

def _pop_subgraph_storage(subg, g):
    for ntype in subg.ntypes:
        if ntype not in g.ntypes:
            continue
        subframe = subg._node_frames[subg.get_ntype_id(ntype)]
        frame = g._node_frames[g.get_ntype_id(ntype)]
        _pop_subframe_storage(subframe, frame)
    for etype in subg.canonical_etypes:
        if etype not in g.canonical_etypes:
            continue
        subframe = subg._edge_frames[subg.get_etype_id(etype)]
        frame = g._edge_frames[g.get_etype_id(etype)]
        _pop_subframe_storage(subframe, frame)

def _pop_blocks_storage(blocks, g):
    for block in blocks:
        for ntype in block.srctypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_src(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _pop_subframe_storage(subframe, frame)
        for ntype in block.dsttypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_dst(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _pop_subframe_storage(subframe, frame)
        for etype in block.canonical_etypes:
            if etype not in g.canonical_etypes:
                continue
            subframe = block._edge_frames[block.get_etype_id(etype)]
            frame = g._edge_frames[g.get_etype_id(etype)]
            _pop_subframe_storage(subframe, frame)

def _restore_subframe_storage(subframe, frame):
    for key, col in subframe._columns.items():
        if col.storage is None:
            col.storage = frame._columns[key].storage

def _restore_subgraph_storage(subg, g):
    for ntype in subg.ntypes:
        if ntype not in g.ntypes:
            continue
        subframe = subg._node_frames[subg.get_ntype_id(ntype)]
        frame = g._node_frames[g.get_ntype_id(ntype)]
        _restore_subframe_storage(subframe, frame)
    for etype in subg.canonical_etypes:
        if etype not in g.canonical_etypes:
            continue
        subframe = subg._edge_frames[subg.get_etype_id(etype)]
        frame = g._edge_frames[g.get_etype_id(etype)]
        _restore_subframe_storage(subframe, frame)

def _restore_blocks_storage(blocks, g):
    for block in blocks:
        for ntype in block.srctypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_src(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _restore_subframe_storage(subframe, frame)
        for ntype in block.dsttypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_dst(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _restore_subframe_storage(subframe, frame)
        for etype in block.canonical_etypes:
            if etype not in g.canonical_etypes:
                continue
            subframe = block._edge_frames[block.get_etype_id(etype)]
            frame = g._edge_frames[g.get_etype_id(etype)]
            _restore_subframe_storage(subframe, frame)

class _NodeCollator(NodeCollator):
    def collate(self, items):
        # input_nodes, output_nodes, blocks
        result = super().collate(items)
        _pop_blocks_storage(result[-1], self.g)
        return result

class _EdgeCollator(EdgeCollator):
    def collate(self, items):
        if self.negative_sampler is None:
            # input_nodes, pair_graph, blocks
            result = super().collate(items)
            _pop_subgraph_storage(result[1], self.g)
            _pop_blocks_storage(result[-1], self.g_sampling)
            return result
        else:
            # input_nodes, pair_graph, neg_pair_graph, blocks
            result = super().collate(items)
            _pop_subgraph_storage(result[1], self.g)
            _pop_subgraph_storage(result[2], self.g)
            _pop_blocks_storage(result[-1], self.g_sampling)
            return result

def _to_device(data, device):
    if isinstance(data, dict):
        for k, v in data.items():
            data[k] = v.to(device)
    elif isinstance(data, list):
        data = [item.to(device) for item in data]
    else:
        data = data.to(device)
    return data

class _NodeDataLoaderIter:
    def __init__(self, node_dataloader):
        self.device = node_dataloader.device
        self.node_dataloader = node_dataloader
        self.iter_ = iter(node_dataloader.dataloader)

    # Make this an iterator for PyTorch Lightning compatibility
    def __iter__(self):
        return self

    def __next__(self):
        # input_nodes, output_nodes, blocks
        result_ = next(self.iter_)
        _restore_blocks_storage(result_[-1], self.node_dataloader.collator.g)

        result = [_to_device(data, self.device) for data in result_]
        return result

class _EdgeDataLoaderIter:
    def __init__(self, edge_dataloader):
        self.device = edge_dataloader.device
        self.edge_dataloader = edge_dataloader
        self.iter_ = iter(edge_dataloader.dataloader)

    # Make this an iterator for PyTorch Lightning compatibility
    def __iter__(self):
        return self

    def __next__(self):
        result_ = next(self.iter_)

        if self.edge_dataloader.collator.negative_sampler is not None:
            # input_nodes, pair_graph, neg_pair_graph, blocks if a negative sampler
            # is given.  Otherwise, input_nodes, pair_graph, blocks
            _restore_subgraph_storage(result_[2], self.edge_dataloader.collator.g)
        _restore_subgraph_storage(result_[1], self.edge_dataloader.collator.g)
        _restore_blocks_storage(result_[-1], self.edge_dataloader.collator.g_sampling)

        result = [_to_device(data, self.device) for data in result_]
        return result

class NodeDataLoader:
    """PyTorch dataloader for batch-iterating over a set of nodes, generating the list
    of message flow graphs (MFGs) as computation dependency of the said minibatch.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    nids : Tensor or dict[ntype, Tensor]
        The node set to compute outputs.
    block_sampler : dgl.dataloading.BlockSampler
        The neighborhood sampler.
    device : device context, optional
        The device of the generated MFGs in each iteration, which should be a
        PyTorch device object (e.g., ``torch.device``).
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
    a homogeneous graph where each node takes messages from all neighbors (assume
    the backend is PyTorch):

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.NodeDataLoader(
    ...     g, train_nid, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, output_nodes, blocks in dataloader:
    ...     train_on(input_nodes, output_nodes, blocks)
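
    To have the generated MFGs copied to a non-CPU device, pass it as ``device``
    (a minimal sketch assuming a CUDA-capable setup; as the constructor requires,
    ``num_workers`` must be 0 in that case):

    >>> dataloader = dgl.dataloading.NodeDataLoader(
    ...     g, train_nid, sampler, device='cuda',
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=0)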

    Notes
    -----
    Please refer to
    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`
    and :ref:`User Guide Section 6 <guide-minibatch>` for usage.
    """
    collator_arglist = inspect.getfullargspec(NodeCollator).args

    def __init__(self, g, nids, block_sampler, device='cpu', **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if isinstance(g, DistGraph):
            assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
            # Distributed DataLoader currently does not support heterogeneous graphs
            # and does not copy features.  Fallback to normal solution
            self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs)
            _remove_kwargs_dist(dataloader_kwargs)
            self.dataloader = DistDataLoader(self.collator.dataset,
                                             collate_fn=self.collator.collate,
                                             **dataloader_kwargs)
            self.is_distributed = True
        else:
            self.collator = _NodeCollator(g, nids, block_sampler, **collator_kwargs)
            dataset = self.collator.dataset

            if th.device(device) != th.device('cpu'):
                # Only use the '_ScalarDataBatcher' for the GPU, as it
                # doesn't seem to have a performance benefit on the CPU.
                assert 'num_workers' not in dataloader_kwargs or \
                    dataloader_kwargs['num_workers'] == 0, \
                    'When performing dataloading from the GPU, num_workers ' \
                    'must be zero.'

                batch_size = dataloader_kwargs.get('batch_size', 0)

                if batch_size > 1:
                    if isinstance(dataset, DGLNDArray):
                        # the dataset needs to be a torch tensor for the
                        # _ScalarDataBatcher
                        dataset = F.zerocopy_from_dgl_ndarray(dataset)
                    if isinstance(dataset, th.Tensor):
                        shuffle = dataloader_kwargs.get('shuffle', False)
                        drop_last = dataloader_kwargs.get('drop_last', False)
                        # manually batch into tensors
                        dataset = _ScalarDataBatcher(dataset,
                                                     batch_size=batch_size,
                                                     shuffle=shuffle,
                                                     drop_last=drop_last)
                        # need to overwrite things that will be handled by the batcher
                        dataloader_kwargs['batch_size'] = None
                        dataloader_kwargs['shuffle'] = False
                        dataloader_kwargs['drop_last'] = False

            self.dataloader = DataLoader(
                dataset,
                collate_fn=self.collator.collate,
                **dataloader_kwargs)
            self.is_distributed = False

            # Precompute the CSR and CSC representations so each subprocess does not
            # duplicate.
            if dataloader_kwargs.get('num_workers', 0) > 0:
                g.create_formats_()
        self.device = device

    def __iter__(self):
        """Return the iterator of the data loader."""
        if self.is_distributed:
            # Directly use the iterator of DistDataLoader, which doesn't copy features anyway.
            return iter(self.dataloader)
        else:
            return _NodeDataLoaderIter(self)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

class EdgeDataLoader:
    """PyTorch dataloader for batch-iterating over a set of edges, generating the list
    of message flow graphs (MFGs) as computation dependency of the said minibatch for
    edge classification, edge regression, and link prediction.

    For each iteration, the object will yield

    * A tensor of input nodes necessary for computing the representation on edges, or
      a dictionary of node type names and such tensors.

    * A subgraph that contains only the edges in the minibatch and their incident nodes.
      Note that the graph has an identical metagraph with the original graph.

    * If a negative sampler is given, another graph that contains the "negative edges",
      connecting the source and destination nodes yielded from the given negative sampler.

    * A list of MFGs necessary for computing the representation of the incident nodes
      of the edges in the minibatch.

    For more details, please refer to :ref:`guide-minibatch-edge-classification-sampler`
    and :ref:`guide-minibatch-link-classification-sampler`.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    eids : Tensor or dict[etype, Tensor]
        The edge set in graph :attr:`g` to compute outputs.
    block_sampler : dgl.dataloading.BlockSampler
        The neighborhood sampler.
    device : device context, optional
        The device of the generated MFGs and graphs in each iteration, which should be a
        PyTorch device object (e.g., ``torch.device``).
    g_sampling : DGLGraph, optional
        The graph where neighborhood sampling is performed.

        One may wish to iterate over the edges in one graph while performing sampling
        in another graph.  This is typically the case when iterating over the
        validation and test edge sets while performing neighborhood sampling on the
        graph formed by only the training edge set (see the final example below).

        If None, assume to be the same as ``g``.
    exclude : str, optional
        Whether and how to exclude dependencies related to the sampled edges in the
        minibatch.  Possible values are

        * None,
        * ``reverse_id``,
        * ``reverse_types``

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    reverse_eids : Tensor or dict[etype, Tensor], optional
        A tensor of reverse edge ID mapping.  The i-th element indicates the ID of
        the i-th edge's reverse edge.

        If the graph is heterogeneous, this argument requires a dictionary of edge
        types and the reverse edge ID mapping tensors.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    reverse_etypes : dict[etype, etype], optional
        The mapping from the original edge types to their reverse edge types.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    negative_sampler : callable, optional
        The negative sampler.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    The following example shows how to train a 3-layer GNN for edge classification on a
    set of edges ``train_eid`` on a homogeneous undirected graph.  Each node takes
    messages from all neighbors.

    Say that you have an array of source node IDs ``src`` and another array of destination
    node IDs ``dst``.  One can make it bidirectional by adding another set of edges
    that connects from ``dst`` to ``src``:

    >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

    One can then know that the ID difference of an edge and its reverse edge is ``|E|``,
    where ``|E|`` is the length of your source/destination array.  The reverse edge
    mapping can be obtained by

    >>> E = len(src)
    >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

    Note that the sampled edges as well as their reverse edges are removed from
    computation dependencies of the incident nodes.  That is, the edges will not
    be involved in neighbor sampling and message aggregation.  This is a common trick
    to avoid information leakage.

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, exclude='reverse_id',
    ...     reverse_eids=reverse_eids,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pair_graph, blocks)

    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a
    homogeneous graph where each node takes messages from all neighbors (assume the
    backend is PyTorch), with 5 uniformly chosen negative samples per edge:

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, exclude='reverse_id',
    ...     reverse_eids=reverse_eids, negative_sampler=neg_sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

    For heterogeneous graphs, the reverse of an edge may have a different edge type
    from the original edge.  For instance, consider that you have an array of
    user-item clicks, represented by a user array ``user`` and an item array ``item``.
    You may want to build a heterogeneous graph with a user-click-item relation and an
    item-clicked-by-user relation.

    >>> g = dgl.heterograph({
    ...     ('user', 'click', 'item'): (user, item),
    ...     ('item', 'clicked-by', 'user'): (item, user)})

    To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with
    type ``click``, you can write

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',
    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pair_graph, blocks)

    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type
    ``click``, you can write

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, exclude='reverse_types',
    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
    ...     negative_sampler=neg_sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
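
    Returning to the homogeneous setting above, the ``g_sampling`` argument lets one
    iterate over one edge set while sampling neighbors from a different graph
    (a minimal sketch; ``val_eid``, ``train_g`` and ``evaluate_on`` are hypothetical
    placeholders for validation edge IDs, a graph built only from training edges,
    and an evaluation routine):

    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, val_eid, sampler, g_sampling=train_g,
    ...     batch_size=1024, shuffle=False, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     evaluate_on(input_nodes, pair_graph, blocks)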

    See also
    --------
    dgl.dataloading.dataloader.EdgeCollator

    Notes
    -----
    Please refer to
    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`
    and :ref:`User Guide Section 6 <guide-minibatch>` for usage.

    For end-to-end usages, please refer to the following tutorial/examples:

    * Edge classification on heterogeneous graph: GCMC

    * Link prediction on homogeneous graph: GraphSAGE for unsupervised learning

    * Link prediction on heterogeneous graph: RGCN for link prediction.
    """
    collator_arglist = inspect.getfullargspec(EdgeCollator).args

    def __init__(self, g, eids, block_sampler, device='cpu', **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v
        self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs)

        assert not isinstance(g, DistGraph), \
                'EdgeDataLoader does not support DistGraph for now. ' \
                + 'Please use DistDataLoader directly.'
        self.dataloader = DataLoader(
            self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
        self.device = device

        # Precompute the CSR and CSC representations so each subprocess does not
        # duplicate.
        if dataloader_kwargs.get('num_workers', 0) > 0:
            g.create_formats_()

    def __iter__(self):
        """Return the iterator of the data loader."""
        return _EdgeDataLoaderIter(self)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

class GraphDataLoader:
    """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
    graph and corresponding label tensor (if provided) of the said minibatch.

    Parameters
    ----------
    dataset : Dataset
        The dataset from which to load the graphs (and labels, if provided).
    collate_fn : Function, default is None
        The customized collate function. Will use the default collate
        function if not given.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
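
    A custom ``collate_fn`` can also be supplied instead of the default
    :class:`~dgl.dataloading.GraphCollator` (a minimal sketch; ``my_collate`` is a
    hypothetical function assuming each sample is a ``(graph, label)`` pair with a
    scalar label):

    >>> def my_collate(samples):
    ...     graphs, labels = map(list, zip(*samples))
    ...     return dgl.batch(graphs), torch.tensor(labels)
    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, collate_fn=my_collate, batch_size=32)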
    """
    collator_arglist = inspect.getfullargspec(GraphCollator).args

    def __init__(self, dataset, collate_fn=None, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if collate_fn is None:
            self.collate = GraphCollator(**collator_kwargs).collate
        else:
            self.collate = collate_fn

        self.dataloader = DataLoader(dataset=dataset,
                                     collate_fn=self.collate,
                                     **dataloader_kwargs)

    def __iter__(self):
        """Return the iterator of the data loader."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)