"""DGL PyTorch DataLoaders"""
import inspect
import math
from distutils.version import LooseVersion
import torch as th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
from ...distributed import DistGraph
from ...distributed import DistDataLoader
from ...ndarray import NDArray as DGLNDArray
from ... import backend as F
from ...base import DGLError

PYTORCH_VER = LooseVersion(th.__version__)
PYTORCH_16 = PYTORCH_VER >= LooseVersion("1.6.0")
PYTORCH_17 = PYTORCH_VER >= LooseVersion("1.7.0")
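# DistributedSampler only accepts a ``seed`` argument from PyTorch 1.6 and a
# ``drop_last`` argument from PyTorch 1.7, hence the version gating used in
# ``_create_dist_sampler`` below.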

def _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed):
    # Note: will change the content of dataloader_kwargs
    dist_sampler_kwargs = {'shuffle': dataloader_kwargs['shuffle']}
    dataloader_kwargs['shuffle'] = False
    if PYTORCH_16:
        dist_sampler_kwargs['seed'] = ddp_seed
    if PYTORCH_17:
        dist_sampler_kwargs['drop_last'] = dataloader_kwargs['drop_last']
        dataloader_kwargs['drop_last'] = False

    return DistributedSampler(dataset, **dist_sampler_kwargs)

class _ScalarDataBatcherIter:
    def __init__(self, dataset, batch_size, drop_last):
        self.dataset = dataset
        self.batch_size = batch_size
        self.index = 0
        self.drop_last = drop_last

    # Make this an iterator for PyTorch Lightning compatibility
    def __iter__(self):
        return self

    def __next__(self):
        num_items = self.dataset.shape[0]
        if self.index >= num_items:
            raise StopIteration
        end_idx = self.index + self.batch_size
        if end_idx > num_items:
            if self.drop_last:
                raise StopIteration
            end_idx = num_items
        batch = self.dataset[self.index:end_idx]
        self.index += self.batch_size

        return batch

class _ScalarDataBatcher(th.utils.data.IterableDataset):
    """Custom Dataset wrapper to return mini-batches as tensors, rather than as
    lists. When the dataset is on the GPU, this significantly reduces
    the overhead. For the case of a batch size of 1024, instead of giving a
    list of 1024 tensors to the collator, a single tensor of 1024 dimensions
    is passed in.
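
    A minimal sketch of the intended behavior (illustration only; this private class
    is normally constructed by ``NodeDataLoader`` rather than by users):

    >>> batcher = _ScalarDataBatcher(torch.arange(10), batch_size=4)
    >>> list(batcher)
    [tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8, 9])]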
    """
    def __init__(self, dataset, shuffle=False, batch_size=1,
                 drop_last=False, use_ddp=False, ddp_seed=0):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.use_ddp = use_ddp
        if use_ddp:
            self.rank = dist.get_rank()
            self.num_replicas = dist.get_world_size()
            self.seed = ddp_seed
            self.epoch = 0
            # The following code (and the idea of cross-process shuffling with the same seed)
            # comes from PyTorch.  See torch/utils/data/distributed.py for details.

            # If the dataset length is evenly divisible by # of replicas, then there
            # is no need to drop any sample, since the dataset will be split evenly.
            if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore
                # Split to nearest available length that is evenly divisible.
                # This is to ensure each rank receives the same amount of data when
                # using this Sampler.
                self.num_samples = math.ceil(
                    # `type:ignore` is required because Dataset cannot provide a default __len__
                    # see NOTE in pytorch/torch/utils/data/sampler.py
                    (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore
                )
            else:
                self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore
            self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        if self.use_ddp:
            return self._iter_ddp()
        else:
            return self._iter_non_ddp()

    def _divide_by_worker(self, dataset):
        worker_info = th.utils.data.get_worker_info()
        if worker_info:
            # worker gets only a fraction of the dataset
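            # (e.g. 10 items across 3 workers: worker 0 takes items [0:4],
            # worker 1 takes [4:7] and worker 2 takes [7:10])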
            chunk_size = dataset.shape[0] // worker_info.num_workers
            left_over = dataset.shape[0] % worker_info.num_workers
            start = (chunk_size*worker_info.id) + min(left_over, worker_info.id)
            end = start + chunk_size + (worker_info.id < left_over)
            assert worker_info.id < worker_info.num_workers-1 or \
                end == dataset.shape[0]
            dataset = dataset[start:end]

        return dataset

    def _iter_non_ddp(self):
        dataset = self._divide_by_worker(self.dataset)

        if self.shuffle:
            # permute the dataset
            perm = th.randperm(dataset.shape[0], device=dataset.device)
            dataset = dataset[perm]

        return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last)

    def _iter_ddp(self):
        # The following code (and the idea of cross-process shuffling with the same seed)
        # comes from PyTorch.  See torch/utils/data/distributed.py for details.
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = th.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = th.randperm(len(self.dataset), generator=g)
        else:
            indices = th.arange(len(self.dataset))

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            indices = th.cat([indices, indices[:(self.total_size - indices.shape[0])]])
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert indices.shape[0] == self.total_size

        # subsample
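        # e.g. with 4 replicas, rank 1 keeps positions 1, 5, 9, ... of the permuted indices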
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert indices.shape[0] == self.num_samples

        # Dividing by worker is our own stuff.
        dataset = self._divide_by_worker(self.dataset[indices])
        return _ScalarDataBatcherIter(dataset, self.batch_size, self.drop_last)

    def __len__(self):
        num_samples = self.num_samples if self.use_ddp else self.dataset.shape[0]
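        # ceiling division: a trailing partial batch counts as an extra batch
        # unless drop_last is set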
        return (num_samples + (0 if self.drop_last else self.batch_size - 1)) // self.batch_size

    def set_epoch(self, epoch):
        """Set epoch number for distributed training."""
        self.epoch = epoch

def _remove_kwargs_dist(kwargs):
    if 'num_workers' in kwargs:
        del kwargs['num_workers']
    if 'pin_memory' in kwargs:
        del kwargs['pin_memory']
        print('Distributed DataLoader does not support pin_memory')
    return kwargs

# The following code is a fix to the PyTorch-specific issue in
# https://github.com/dmlc/dgl/issues/2137
#
# Basically the sampled MFGs/subgraphs contain the features extracted from the
# parent graph.  In DGL, the MFGs/subgraphs will hold a reference to the parent
# graph feature tensor and an index tensor, so that the features could be extracted upon
# request.  However, in the context of multiprocessed sampling, we do not need to
# transmit the parent graph feature tensor from the subprocess to the main process,
# since they are exactly the same tensor, and transmitting a tensor from a subprocess
# to the main process is costly in PyTorch as it uses shared memory.  We work around
# it with the following trick:
#
# In the collator running in the sampler processes:
# For each frame in the MFG, we check each column against the column with the same name
# in the corresponding parent frame.  If the storage of the former column is the
# same object as the latter column, we are sure that the former column is a
# subcolumn of the latter, and set the storage of the former column as None.
#
# In the iterator of the main process:
# For each frame in the MFG, we check each column against the column with the same name
# in the corresponding parent frame.  If the storage of the former column is None,
# we replace it with the storage of the latter column.
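#
# A rough sketch of the round trip using the private helpers defined below
# (control flow simplified for illustration only):
#
#   # in the collator (sampler process):
#   input_nodes, output_nodes, blocks = ...      # sample as usual
#   _pop_blocks_storage(blocks, g)               # detach shared parent storages
#
#   # in the dataloader iterator (main process):
#   input_nodes, output_nodes, blocks = next(it) # receive from the subprocess
#   _restore_blocks_storage(blocks, g)           # re-attach the parent storages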

def _pop_subframe_storage(subframe, frame):
    for key, col in subframe._columns.items():
        if key in frame._columns and col.storage is frame._columns[key].storage:
            col.storage = None

def _pop_subgraph_storage(subg, g):
    for ntype in subg.ntypes:
        if ntype not in g.ntypes:
            continue
        subframe = subg._node_frames[subg.get_ntype_id(ntype)]
        frame = g._node_frames[g.get_ntype_id(ntype)]
        _pop_subframe_storage(subframe, frame)
    for etype in subg.canonical_etypes:
        if etype not in g.canonical_etypes:
            continue
        subframe = subg._edge_frames[subg.get_etype_id(etype)]
        frame = g._edge_frames[g.get_etype_id(etype)]
        _pop_subframe_storage(subframe, frame)

def _pop_blocks_storage(blocks, g):
    for block in blocks:
        for ntype in block.srctypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_src(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _pop_subframe_storage(subframe, frame)
        for ntype in block.dsttypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_dst(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _pop_subframe_storage(subframe, frame)
        for etype in block.canonical_etypes:
            if etype not in g.canonical_etypes:
                continue
            subframe = block._edge_frames[block.get_etype_id(etype)]
            frame = g._edge_frames[g.get_etype_id(etype)]
            _pop_subframe_storage(subframe, frame)

def _restore_subframe_storage(subframe, frame):
    for key, col in subframe._columns.items():
        if col.storage is None:
            col.storage = frame._columns[key].storage

def _restore_subgraph_storage(subg, g):
    for ntype in subg.ntypes:
        if ntype not in g.ntypes:
            continue
        subframe = subg._node_frames[subg.get_ntype_id(ntype)]
        frame = g._node_frames[g.get_ntype_id(ntype)]
        _restore_subframe_storage(subframe, frame)
    for etype in subg.canonical_etypes:
        if etype not in g.canonical_etypes:
            continue
        subframe = subg._edge_frames[subg.get_etype_id(etype)]
        frame = g._edge_frames[g.get_etype_id(etype)]
        _restore_subframe_storage(subframe, frame)

def _restore_blocks_storage(blocks, g):
    for block in blocks:
        for ntype in block.srctypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_src(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _restore_subframe_storage(subframe, frame)
        for ntype in block.dsttypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_dst(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _restore_subframe_storage(subframe, frame)
        for etype in block.canonical_etypes:
            if etype not in g.canonical_etypes:
                continue
            subframe = block._edge_frames[block.get_etype_id(etype)]
            frame = g._edge_frames[g.get_etype_id(etype)]
            _restore_subframe_storage(subframe, frame)

class _NodeCollator(NodeCollator):
    def collate(self, items):
        # input_nodes, output_nodes, blocks
        result = super().collate(items)
        _pop_blocks_storage(result[-1], self.g)
        return result

class _EdgeCollator(EdgeCollator):
    def collate(self, items):
        if self.negative_sampler is None:
            # input_nodes, pair_graph, blocks
            result = super().collate(items)
            _pop_subgraph_storage(result[1], self.g)
            _pop_blocks_storage(result[-1], self.g_sampling)
            return result
        else:
            # input_nodes, pair_graph, neg_pair_graph, blocks
            result = super().collate(items)
            _pop_subgraph_storage(result[1], self.g)
            _pop_subgraph_storage(result[2], self.g)
            _pop_blocks_storage(result[-1], self.g_sampling)
            return result

def _to_device(data, device):
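    # Move a tensor, or every tensor inside a dict or list, onto ``device``.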
    if isinstance(data, dict):
        for k, v in data.items():
            data[k] = v.to(device)
    elif isinstance(data, list):
        data = [item.to(device) for item in data]
    else:
        data = data.to(device)
    return data

class _NodeDataLoaderIter:
    def __init__(self, node_dataloader):
        self.device = node_dataloader.device
        self.node_dataloader = node_dataloader
        self.iter_ = iter(node_dataloader.dataloader)

    # Make this an iterator for PyTorch Lightning compatibility
    def __iter__(self):
        return self

    def __next__(self):
        # input_nodes, output_nodes, blocks
        result_ = next(self.iter_)
        _restore_blocks_storage(result_[-1], self.node_dataloader.collator.g)

        result = [_to_device(data, self.device) for data in result_]
        return result

class _EdgeDataLoaderIter:
    def __init__(self, edge_dataloader):
        self.device = edge_dataloader.device
        self.edge_dataloader = edge_dataloader
        self.iter_ = iter(edge_dataloader.dataloader)

    # Make this an iterator for PyTorch Lightning compatibility
    def __iter__(self):
        return self

    def __next__(self):
        result_ = next(self.iter_)

        if self.edge_dataloader.collator.negative_sampler is not None:
            # The result is (input_nodes, pair_graph, neg_pair_graph, blocks) when a
            # negative sampler is given; otherwise it is (input_nodes, pair_graph, blocks).
            _restore_subgraph_storage(result_[2], self.edge_dataloader.collator.g)
        _restore_subgraph_storage(result_[1], self.edge_dataloader.collator.g)
        _restore_blocks_storage(result_[-1], self.edge_dataloader.collator.g_sampling)

        result = [_to_device(data, self.device) for data in result_]
        return result

class NodeDataLoader:
    """PyTorch dataloader for batch-iterating over a set of nodes, generating the list
    of message flow graphs (MFGs) as the computation dependency of the said minibatch.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    nids : Tensor or dict[ntype, Tensor]
        The node set to compute outputs.
    block_sampler : dgl.dataloading.BlockSampler
        The neighborhood sampler.
    device : device context, optional
        The device of the generated MFGs in each iteration, which should be a
        PyTorch device object (e.g., ``torch.device``).
    use_ddp : boolean, optional
        If True, tells the DataLoader to split the training set for each
        participating process appropriately using
        :class:`torch.utils.data.distributed.DistributedSampler`.

        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
    ddp_seed : int, optional
        The seed for shuffling the dataset in
        :class:`torch.utils.data.distributed.DistributedSampler`.

        Only effective when :attr:`use_ddp` is True.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
    a homogeneous graph where each node takes messages from all neighbors (assume
    the backend is PyTorch):

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.NodeDataLoader(
    ...     g, train_nid, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, output_nodes, blocks in dataloader:
    ...     train_on(input_nodes, output_nodes, blocks)
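
    To place the sampled MFGs on a GPU instead (a sketch only; a non-CPU ``device``
    requires ``num_workers`` to be 0):

    >>> dataloader = dgl.dataloading.NodeDataLoader(
    ...     g, train_nid, sampler, device='cuda',
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=0)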

    **Using with Distributed Data Parallel**

    If you are using PyTorch's distributed training (e.g. when using
    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by turning
    on the :attr:`use_ddp` option:

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.NodeDataLoader(
    ...     g, train_nid, sampler, use_ddp=True,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for epoch in range(start_epoch, n_epochs):
    ...     dataloader.set_epoch(epoch)
    ...     for input_nodes, output_nodes, blocks in dataloader:
    ...         train_on(input_nodes, output_nodes, blocks)

    Notes
    -----
    Please refer to
    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`
    and :ref:`User Guide Section 6 <guide-minibatch>` for usage.
    """
    collator_arglist = inspect.getfullargspec(NodeCollator).args

    def __init__(self, g, nids, block_sampler, device='cpu', use_ddp=False, ddp_seed=0, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if isinstance(g, DistGraph):
            assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
            # Distributed DataLoader currently does not support heterogeneous graphs
            # and does not copy features.  Fall back to the normal solution.
            self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs)
            _remove_kwargs_dist(dataloader_kwargs)
            self.dataloader = DistDataLoader(self.collator.dataset,
                                             collate_fn=self.collator.collate,
                                             **dataloader_kwargs)
            self.is_distributed = True
        else:
            self.collator = _NodeCollator(g, nids, block_sampler, **collator_kwargs)
            dataset = self.collator.dataset
            use_scalar_batcher = False

            if th.device(device) != th.device('cpu'):
                # Only use the '_ScalarDataBatcher' for the GPU, as it
                # doesn't seem to have a performance benefit on the CPU.
                assert 'num_workers' not in dataloader_kwargs or \
                    dataloader_kwargs['num_workers'] == 0, \
                    'When performing dataloading from the GPU, num_workers ' \
                    'must be zero.'

                batch_size = dataloader_kwargs.get('batch_size', 0)

                if batch_size > 1:
                    if isinstance(dataset, DGLNDArray):
                        # the dataset needs to be a torch tensor for the
                        # _ScalarDataBatcher
                        dataset = F.zerocopy_from_dgl_ndarray(dataset)
                    if isinstance(dataset, th.Tensor):
                        shuffle = dataloader_kwargs.get('shuffle', False)
                        drop_last = dataloader_kwargs.get('drop_last', False)
                        # manually batch into tensors
                        dataset = _ScalarDataBatcher(dataset,
                                                     batch_size=batch_size,
                                                     shuffle=shuffle,
                                                     drop_last=drop_last,
                                                     use_ddp=use_ddp,
                                                     ddp_seed=ddp_seed)
                        # need to overwrite things that will be handled by the batcher
                        dataloader_kwargs['batch_size'] = None
                        dataloader_kwargs['shuffle'] = False
                        dataloader_kwargs['drop_last'] = False
                        use_scalar_batcher = True
                        self.scalar_batcher = dataset

            self.use_ddp = use_ddp
            self.use_scalar_batcher = use_scalar_batcher
            if use_ddp and not use_scalar_batcher:
                self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
                dataloader_kwargs['sampler'] = self.dist_sampler

            self.dataloader = DataLoader(
                dataset,
                collate_fn=self.collator.collate,
                **dataloader_kwargs)

            self.is_distributed = False

            # Precompute the CSR and CSC representations so each subprocess does not
            # duplicate.
            if dataloader_kwargs.get('num_workers', 0) > 0:
                g.create_formats_()
        self.device = device

    def __iter__(self):
        """Return the iterator of the data loader."""
        if self.is_distributed:
            # Directly use the iterator of DistDataLoader, which doesn't copy features anyway.
            return iter(self.dataloader)
        else:
            return _NodeDataLoaderIter(self)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

    def set_epoch(self, epoch):
        """Sets the epoch number for the underlying sampler which ensures all replicas
        to use a different ordering for each epoch.

        Only available when :attr:`use_ddp` is True.

        Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.

        Parameters
        ----------
        epoch : int
            The epoch number.
        """
        if self.use_ddp:
            if self.use_scalar_batcher:
                self.scalar_batcher.set_epoch(epoch)
            else:
                self.dist_sampler.set_epoch(epoch)
        else:
            raise DGLError('set_epoch is only available when use_ddp is True.')

class EdgeDataLoader:
    """PyTorch dataloader for batch-iterating over a set of edges, generating the list
    of message flow graphs (MFGs) as the computation dependency of the said minibatch for
    edge classification, edge regression, and link prediction.

    For each iteration, the object will yield

    * A tensor of input nodes necessary for computing the representation on edges, or
      a dictionary of node type names and such tensors.

    * A subgraph that contains only the edges in the minibatch and their incident nodes.
      Note that the graph has the same metagraph as the original graph.

    * If a negative sampler is given, another graph that contains the "negative edges",
      connecting the source and destination nodes yielded from the given negative sampler.

    * A list of MFGs necessary for computing the representation of the incident nodes
      of the edges in the minibatch.

    For more details, please refer to :ref:`guide-minibatch-edge-classification-sampler`
    and :ref:`guide-minibatch-link-classification-sampler`.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    eids : Tensor or dict[etype, Tensor]
        The edge set in graph :attr:`g` to compute outputs.
    block_sampler : dgl.dataloading.BlockSampler
        The neighborhood sampler.
    device : device context, optional
        The device of the generated MFGs and graphs in each iteration, which should be a
        PyTorch device object (e.g., ``torch.device``).
    g_sampling : DGLGraph, optional
        The graph where neighborhood sampling is performed.

        One may wish to iterate over the edges in one graph while performing sampling on
        another graph.  This is the case, for instance, when iterating over the validation
        or test edge set while performing neighborhood sampling on the graph formed by
        only the training edge set.

        If None, assume to be the same as ``g``.
    exclude : str, optional
        Whether and how to exclude dependencies related to the sampled edges in the
        minibatch.  Possible values are

        * None,
        * ``reverse_id``,
        * ``reverse_types``

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    reverse_eids : Tensor or dict[etype, Tensor], optional
        A tensor of reverse edge ID mapping.  The i-th element indicates the ID of
        the i-th edge's reverse edge.

        If the graph is heterogeneous, this argument requires a dictionary of edge
        types and the reverse edge ID mapping tensors.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    reverse_etypes : dict[etype, etype], optional
        The mapping from the original edge types to their reverse edge types.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    negative_sampler : callable, optional
        The negative sampler.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    use_ddp : boolean, optional
        If True, tells the DataLoader to split the training set for each
        participating process appropriately using
        :class:`torch.utils.data.distributed.DistributedSampler`.

        The dataloader will have a :attr:`dist_sampler` attribute to set the
        epoch number, as recommended by PyTorch.

        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
    ddp_seed : int, optional
        The seed for shuffling the dataset in
        :class:`torch.utils.data.distributed.DistributedSampler`.

        Only effective when :attr:`use_ddp` is True.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    The following example shows how to train a 3-layer GNN for edge classification on a
    set of edges ``train_eid`` on a homogeneous undirected graph.  Each node takes
    messages from all neighbors.

    Say that you have an array of source node IDs ``src`` and another array of destination
    node IDs ``dst``.  One can make it bidirectional by adding another set of edges
    that connects from ``dst`` to ``src``:

    >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

    One can then know that the ID difference of an edge and its reverse edge is ``|E|``,
    where ``|E|`` is the length of your source/destination array.  The reverse edge
    mapping can be obtained by

    >>> E = len(src)
    >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

    Note that the sampled edges as well as their reverse edges are removed from
    computation dependencies of the incident nodes.  That is, the edges will not be
    involved in neighbor sampling and message aggregation.  This is a common trick
    to avoid information leakage.

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, exclude='reverse_id',
    ...     reverse_eids=reverse_eids,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pair_graph, blocks)

    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a
    homogeneous graph where each node takes messages from all neighbors (assume the
    backend is PyTorch), with 5 uniformly chosen negative samples per edge:

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, exclude='reverse_id',
    ...     reverse_eids=reverse_eids, negative_sampler=neg_sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

    For heterogeneous graphs, the reverse of an edge may have a different edge type
    from the original edge.  For instance, consider that you have an array of
    user-item clicks, represented by a user array ``user`` and an item array ``item``.
    You may want to build a heterogeneous graph with a user-click-item relation and an
    item-clicked-by-user relation.

    >>> g = dgl.heterograph({
    ...     ('user', 'click', 'item'): (user, item),
    ...     ('item', 'clicked-by', 'user'): (item, user)})

    To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with
    type ``click``, you can write

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',
    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pair_graph, blocks)

    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type
    ``click``, you can write

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, exclude='reverse_types',
    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
    ...     negative_sampler=neg_sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

    **Using with Distributed Data Parallel**

    If you are using PyTorch's distributed training (e.g. when using
    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
    turning on the :attr:`use_ddp` option:

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
    ...     reverse_eids=reverse_eids,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for epoch in range(start_epoch, n_epochs):
    ...     dataloader.set_epoch(epoch)
    ...     for input_nodes, pair_graph, blocks in dataloader:
    ...         train_on(input_nodes, pair_graph, blocks)

    See also
    --------
    dgl.dataloading.dataloader.EdgeCollator

    Notes
    -----
    Please refer to
    :doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`
    and :ref:`User Guide Section 6 <guide-minibatch>` for usage.

    For end-to-end usages, please refer to the following tutorial/examples:

    * Edge classification on heterogeneous graph: GCMC

    * Link prediction on homogeneous graph: GraphSAGE for unsupervised learning

    * Link prediction on heterogeneous graph: RGCN for link prediction.
    """
    collator_arglist = inspect.getfullargspec(EdgeCollator).args

    def __init__(self, g, eids, block_sampler, device='cpu', use_ddp=False, ddp_seed=0, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v
        self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs)
        dataset = self.collator.dataset

        assert not isinstance(g, DistGraph), \
                'EdgeDataLoader does not support DistGraph for now. ' \
                + 'Please use DistDataLoader directly.'

        self.use_ddp = use_ddp
        if use_ddp:
            self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
            dataloader_kwargs['sampler'] = self.dist_sampler

        self.dataloader = DataLoader(
            dataset,
            collate_fn=self.collator.collate,
            **dataloader_kwargs)

        self.device = device

        # Precompute the CSR and CSC representations so each subprocess does not
        # duplicate.
        if dataloader_kwargs.get('num_workers', 0) > 0:
            g.create_formats_()

    def __iter__(self):
        """Return the iterator of the data loader."""
        return _EdgeDataLoaderIter(self)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

    def set_epoch(self, epoch):
        """Sets the epoch number for the underlying sampler which ensures all replicas
        to use a different ordering for each epoch.

        Only available when :attr:`use_ddp` is True.

        Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.

        Parameters
        ----------
        epoch : int
            The epoch number.
        """
        if self.use_ddp:
            self.dist_sampler.set_epoch(epoch)
        else:
            raise DGLError('set_epoch is only available when use_ddp is True.')

class GraphDataLoader:
    """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
    graph and corresponding label tensor (if provided) of the said minibatch.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The dataset from which to load the graphs.
    collate_fn : Function, default is None
        The customized collate function. Will use the default collate
        function if not given.
    use_ddp : boolean, optional
        If True, tells the DataLoader to split the training set for each
        participating process appropriately using
        :class:`torch.utils.data.distributed.DistributedSampler`.

        Overrides the :attr:`sampler` argument of :class:`torch.utils.data.DataLoader`.
    ddp_seed : int, optional
        The seed for shuffling the dataset in
        :class:`torch.utils.data.distributed.DistributedSampler`.

        Only effective when :attr:`use_ddp` is True.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
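
    If each element of ``dataset`` is a ``(graph, label)`` pair, a custom
    ``collate_fn`` could look like the following sketch (illustrative only; the
    default collator already covers this common case):

    >>> def collate(samples):
    ...     graphs, labels = map(list, zip(*samples))
    ...     return dgl.batch(graphs), torch.tensor(labels)
    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, collate_fn=collate, batch_size=32)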

    **Using with Distributed Data Parallel**

    If you are using PyTorch's distributed training (e.g. when using
    :mod:`torch.nn.parallel.DistributedDataParallel`), you can train the model by
    turning on the :attr:`use_ddp` option:

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for epoch in range(start_epoch, n_epochs):
    ...     dataloader.set_epoch(epoch)
    ...     for batched_graph, labels in dataloader:
    ...         train_on(batched_graph, labels)
    """
    collator_arglist = inspect.getfullargspec(GraphCollator).args

    def __init__(self, dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if collate_fn is None:
            self.collate = GraphCollator(**collator_kwargs).collate
        else:
            self.collate = collate_fn

        self.use_ddp = use_ddp
        if use_ddp:
            self.dist_sampler = _create_dist_sampler(dataset, dataloader_kwargs, ddp_seed)
            dataloader_kwargs['sampler'] = self.dist_sampler

        self.dataloader = DataLoader(dataset=dataset,
                                     collate_fn=self.collate,
                                     **dataloader_kwargs)

    def __iter__(self):
        """Return the iterator of the data loader."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

    def set_epoch(self, epoch):
        """Sets the epoch number for the underlying sampler which ensures all replicas
        to use a different ordering for each epoch.

        Only available when :attr:`use_ddp` is True.

        Calls :meth:`torch.utils.data.distributed.DistributedSampler.set_epoch`.

        Parameters
        ----------
        epoch : int
            The epoch number.
        """
        if self.use_ddp:
            self.dist_sampler.set_epoch(epoch)
        else:
            raise DGLError('set_epoch is only available when use_ddp is True.')