"src/graph/sampling/vscode:/vscode.git/clone" did not exist on "7359481497b1ba30d029bdfabe6a4bb6333f27ca"
__init__.py 21.1 KB
Newer Older
1
2
3
"""DGL PyTorch DataLoaders"""
import inspect
from torch.utils.data import DataLoader
from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
from ...distributed import DistGraph
from ...distributed import DistDataLoader

def _remove_kwargs_dist(kwargs):
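    """Remove arguments that DistDataLoader does not support from ``kwargs``."""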
    if 'num_workers' in kwargs:
        del kwargs['num_workers']
    if 'pin_memory' in kwargs:
        del kwargs['pin_memory']
        print('Distributed DataLoader does not support pin_memory')
    return kwargs

# The following code is a fix to the PyTorch-specific issue in
# https://github.com/dmlc/dgl/issues/2137
#
# Basically the sampled blocks/subgraphs contain the features extracted from the
# parent graph.  In DGL, the blocks/subgraphs will hold a reference to the parent
# graph feature tensor and an index tensor, so that the features could be extracted upon
# request.  However, in the context of multiprocessed sampling, we do not need to
# transmit the parent graph feature tensor from the subprocess to the main process,
# since they are exactly the same tensor, and transmitting a tensor from a subprocess
# to the main process is costly in PyTorch as it uses shared memory.  We work around
# it with the following trick:
#
# In the collator running in the sampler processes:
# For each frame in the block, we check each column and the column with the same name
# in the corresponding parent frame.  If the storage of the former column is the
# same object as the latter column, we are sure that the former column is a
# subcolumn of the latter, and set the storage of the former column as None.
#
# In the iterator of the main process:
# For each frame in the block, we check each column and the column with the same name
# in the corresponding parent frame.  If the storage of the former column is None,
# we replace it with the storage of the latter column.
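#
# A minimal sketch of the round trip implemented by the helpers below
# (control flow simplified for illustration):
#
#   # In the collator, inside a sampler process:
#   _pop_blocks_storage(blocks, g)      # shared columns now hold storage None
#   # ... blocks are pickled to the main process without the parent graph's
#   # feature tensors ...
#
#   # In the iterator, inside the main process:
#   _restore_blocks_storage(blocks, g)  # reattach the parent graph's storage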

def _pop_subframe_storage(subframe, frame):
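    # A column whose storage is the very same object as the parent frame's
    # column is a view into the parent; drop the reference so that the shared
    # tensor is not pickled together with the subframe.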
    for key, col in subframe._columns.items():
        if key in frame._columns and col.storage is frame._columns[key].storage:
            col.storage = None

def _pop_subgraph_storage(subg, g):
    for ntype in subg.ntypes:
        if ntype not in g.ntypes:
            continue
        subframe = subg._node_frames[subg.get_ntype_id(ntype)]
        frame = g._node_frames[g.get_ntype_id(ntype)]
        _pop_subframe_storage(subframe, frame)
    for etype in subg.canonical_etypes:
        if etype not in g.canonical_etypes:
            continue
        subframe = subg._edge_frames[subg.get_etype_id(etype)]
        frame = g._edge_frames[g.get_etype_id(etype)]
        _pop_subframe_storage(subframe, frame)

def _pop_blocks_storage(blocks, g):
    for block in blocks:
        for ntype in block.srctypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_src(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _pop_subframe_storage(subframe, frame)
        for ntype in block.dsttypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_dst(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _pop_subframe_storage(subframe, frame)
        for etype in block.canonical_etypes:
            if etype not in g.canonical_etypes:
                continue
            subframe = block._edge_frames[block.get_etype_id(etype)]
            frame = g._edge_frames[g.get_etype_id(etype)]
            _pop_subframe_storage(subframe, frame)

def _restore_subframe_storage(subframe, frame):
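    # A storage of None marks a column detached by _pop_subframe_storage;
    # point it back at the parent column's storage.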
    for key, col in subframe._columns.items():
        if col.storage is None:
            col.storage = frame._columns[key].storage

def _restore_subgraph_storage(subg, g):
    for ntype in subg.ntypes:
        if ntype not in g.ntypes:
            continue
        subframe = subg._node_frames[subg.get_ntype_id(ntype)]
        frame = g._node_frames[g.get_ntype_id(ntype)]
        _restore_subframe_storage(subframe, frame)
    for etype in subg.canonical_etypes:
        if etype not in g.canonical_etypes:
            continue
        subframe = subg._edge_frames[subg.get_etype_id(etype)]
        frame = g._edge_frames[g.get_etype_id(etype)]
        _restore_subframe_storage(subframe, frame)

def _restore_blocks_storage(blocks, g):
    for block in blocks:
        for ntype in block.srctypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_src(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _restore_subframe_storage(subframe, frame)
        for ntype in block.dsttypes:
            if ntype not in g.ntypes:
                continue
            subframe = block._node_frames[block.get_ntype_id_from_dst(ntype)]
            frame = g._node_frames[g.get_ntype_id(ntype)]
            _restore_subframe_storage(subframe, frame)
        for etype in block.canonical_etypes:
            if etype not in g.canonical_etypes:
                continue
            subframe = block._edge_frames[block.get_etype_id(etype)]
            frame = g._edge_frames[g.get_etype_id(etype)]
            _restore_subframe_storage(subframe, frame)

class _NodeCollator(NodeCollator):
    def collate(self, items):
        # input_nodes, output_nodes, blocks
        result = super().collate(items)
        _pop_blocks_storage(result[-1], self.g)
        return result

class _EdgeCollator(EdgeCollator):
    def collate(self, items):
        if self.negative_sampler is None:
            # input_nodes, pair_graph, blocks
            result = super().collate(items)
            _pop_subgraph_storage(result[1], self.g)
            _pop_blocks_storage(result[-1], self.g_sampling)
            return result
        else:
            # input_nodes, pair_graph, neg_pair_graph, blocks
            result = super().collate(items)
            _pop_subgraph_storage(result[1], self.g)
            _pop_subgraph_storage(result[2], self.g)
            _pop_blocks_storage(result[-1], self.g_sampling)
            return result

def _to_device(data, device):
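    """Move a tensor, or a dict or list of tensors, to the given device."""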
    if isinstance(data, dict):
        for k, v in data.items():
            data[k] = v.to(device)
    elif isinstance(data, list):
        data = [item.to(device) for item in data]
    else:
        data = data.to(device)
    return data

class _NodeDataLoaderIter:
    def __init__(self, node_dataloader):
        self.device = node_dataloader.device
        self.node_dataloader = node_dataloader
        self.iter_ = iter(node_dataloader.dataloader)

    def __next__(self):
        # input_nodes, output_nodes, blocks
        result_ = next(self.iter_)
        _restore_blocks_storage(result_[-1], self.node_dataloader.collator.g)

        result = [_to_device(data, self.device) for data in result_]
        return result

class _EdgeDataLoaderIter:
    def __init__(self, edge_dataloader):
        self.device = edge_dataloader.device
        self.edge_dataloader = edge_dataloader
        self.iter_ = iter(edge_dataloader.dataloader)

    def __next__(self):
        result_ = next(self.iter_)

        if self.edge_dataloader.collator.negative_sampler is not None:
            # input_nodes, pair_graph, neg_pair_graph, blocks
            # Otherwise, input_nodes, pair_graph, blocks
            _restore_subgraph_storage(result_[2], self.edge_dataloader.collator.g)
        _restore_subgraph_storage(result_[1], self.edge_dataloader.collator.g)
        _restore_blocks_storage(result_[-1], self.edge_dataloader.collator.g_sampling)

        result = [_to_device(data, self.device) for data in result_]
        return result

class NodeDataLoader:
    """PyTorch dataloader for batch-iterating over a set of nodes, generating the list
    of blocks as computation dependency of the said minibatch.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    nids : Tensor or dict[ntype, Tensor]
        The node set to compute outputs.
    block_sampler : dgl.dataloading.BlockSampler
        The neighborhood sampler.
    device : device context, optional
        The device of the generated blocks in each iteration, which should be a
        PyTorch device object (e.g., ``torch.device``).
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a 3-layer GNN for node classification on a set of nodes ``train_nid`` on
    a homogeneous graph where each node takes messages from all neighbors (assume
    the backend is PyTorch):

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.NodeDataLoader(
    ...     g, train_nid, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, output_nodes, blocks in dataloader:
    ...     train_on(input_nodes, output_nodes, blocks)
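
    For a heterogeneous graph, ``nids`` can instead be a dictionary mapping node
    types to node ID tensors (the node type ``'user'`` below is illustrative):

    >>> dataloader = dgl.dataloading.NodeDataLoader(
    ...     hetero_g, {'user': train_nid}, sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)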
    """
    collator_arglist = inspect.getfullargspec(NodeCollator).args

    def __init__(self, g, nids, block_sampler, device='cpu', **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if isinstance(g, DistGraph):
            assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
            # Distributed DataLoader currently does not support heterogeneous graphs
            # and does not copy features.  Fallback to normal solution
            self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs)
            _remove_kwargs_dist(dataloader_kwargs)
            self.dataloader = DistDataLoader(self.collator.dataset,
                                             collate_fn=self.collator.collate,
                                             **dataloader_kwargs)
            self.is_distributed = True
        else:
            self.collator = _NodeCollator(g, nids, block_sampler, **collator_kwargs)
            self.dataloader = DataLoader(self.collator.dataset,
                                         collate_fn=self.collator.collate,
                                         **dataloader_kwargs)
            self.is_distributed = False

            # Precompute the CSR and CSC representations so that each worker
            # subprocess does not duplicate the work.
            if dataloader_kwargs.get('num_workers', 0) > 0:
                g.create_formats_()
        self.device = device

    def __iter__(self):
        """Return the iterator of the data loader."""
        if self.is_distributed:
            # Directly use the iterator of DistDataLoader, which doesn't copy features anyway.
            return iter(self.dataloader)
        else:
            return _NodeDataLoaderIter(self)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

class EdgeDataLoader:
    """PyTorch dataloader for batch-iterating over a set of edges, generating the list
    of blocks as computation dependency of the said minibatch for edge classification,
    edge regression, and link prediction.

    For each iteration, the object will yield

    * A tensor of input nodes necessary for computing the representation on edges, or
      a dictionary of node type names and such tensors.

    * A subgraph that contains only the edges in the minibatch and their incident nodes.
      Note that the graph has the same metagraph as the original graph.

    * If a negative sampler is given, another graph that contains the "negative edges",
      connecting the source and destination nodes yielded from the given negative sampler.

    * A list of blocks necessary for computing the representation of the incident nodes
      of the edges in the minibatch.

    For more details, please refer to :ref:`guide-minibatch-edge-classification-sampler`
    and :ref:`guide-minibatch-link-classification-sampler`.

    Parameters
    ----------
    g : DGLGraph
        The graph.
    eids : Tensor or dict[etype, Tensor]
        The edge set in graph :attr:`g` to compute outputs.
    block_sampler : dgl.dataloading.BlockSampler
        The neighborhood sampler.
    device : device context, optional
        The device of the generated blocks and graphs in each iteration, which should be a
        PyTorch device object (e.g., ``torch.device``).
    g_sampling : DGLGraph, optional
        The graph where neighborhood sampling is performed.

        One may wish to iterate over the edges in one graph while performing
        sampling in another graph.  This is the case, for example, when iterating
        over the validation or test edge set while performing neighborhood
        sampling on the graph formed by only the training edge set.

        If None, assume to be the same as ``g``.
    exclude : str, optional
        Whether and how to exclude dependencies related to the sampled edges in the
        minibatch.  Possible values are

        * None,
        * ``reverse_id``,
        * ``reverse_types``

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    reverse_edge_ids : Tensor or dict[etype, Tensor], optional
        The mapping from the original edge IDs to the ID of their reverse edges.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    reverse_etypes : dict[etype, etype], optional
        The mapping from the original edge types to their reverse edge types.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    negative_sampler : callable, optional
        The negative sampler.

        See the description of the argument with the same name in the docstring of
        :class:`~dgl.dataloading.EdgeCollator` for more details.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    The following example shows how to train a 3-layer GNN for edge classification on a
    set of edges ``train_eid`` on a homogeneous undirected graph.  Each node takes
    messages from all neighbors.

    Say that you have an array of source node IDs ``src`` and another array of destination
    node IDs ``dst``.  One can make it bidirectional by adding another set of edges
    that connects from ``dst`` to ``src``:

    >>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

    The ID of an edge and the ID of its reverse edge then differ by exactly ``|E|``,
    where ``|E|`` is the length of your source/destination array.  The reverse edge
    mapping can be obtained by

    >>> E = len(src)
    >>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

    Note that the sampled edges as well as their reverse edges are removed from
    computation dependencies of the incident nodes.  That is, the edges will not
    be involved in neighbor sampling and message aggregation.  This is a common
    trick to avoid information leakage.

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, exclude='reverse_id',
    ...     reverse_eids=reverse_eids,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pair_graph, blocks)

    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` on a
    homogeneous graph where each node takes messages from all neighbors (assume the
    backend is PyTorch), with 5 uniformly chosen negative samples per edge:

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, train_eid, sampler, exclude='reverse_id',
    ...     reverse_eids=reverse_eids, negative_sampler=neg_sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

    For heterogeneous graphs, the reverse of an edge may have a different edge type
    from the original edge.  For instance, consider that you have an array of
    user-item clicks, represented by a user array ``user`` and an item array ``item``.
    You may want to build a heterogeneous graph with a user-click-item relation and an
    item-clicked-by-user relation.

    >>> g = dgl.heterograph({
    ...     ('user', 'click', 'item'): (user, item),
    ...     ('item', 'clicked-by', 'user'): (item, user)})

    To train a 3-layer GNN for edge classification on a set of edges ``train_eid`` with
    type ``click``, you can write

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',
    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pair_graph, blocks)

    To train a 3-layer GNN for link prediction on a set of edges ``train_eid`` with type
    ``click``, you can write

    >>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
    >>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
    >>> dataloader = dgl.dataloading.EdgeDataLoader(
    ...     g, {'click': train_eid}, sampler, exclude='reverse_types',
    ...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
    ...     negative_sampler=neg_sampler,
    ...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
    ...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

    See also
    --------
    :class:`~dgl.dataloading.dataloader.EdgeCollator`

    For end-to-end usages, please refer to the following tutorial/examples:

    * Edge classification on heterogeneous graph: GCMC

    * Link prediction on homogeneous graph: GraphSAGE for unsupervised learning

    * Link prediction on heterogeneous graph: RGCN for link prediction.
    """
    collator_arglist = inspect.getfullargspec(EdgeCollator).args

    def __init__(self, g, eids, block_sampler, device='cpu', **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v
        self.collator = _EdgeCollator(g, eids, block_sampler, **collator_kwargs)

        assert not isinstance(g, DistGraph), \
                'EdgeDataLoader does not support DistGraph for now. ' \
                + 'Please use DistDataLoader directly.'
        self.dataloader = DataLoader(
            self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
        self.device = device

        # Precompute the CSR and CSC representations so that each worker
        # subprocess does not duplicate the work.
        if dataloader_kwargs.get('num_workers', 0) > 0:
            g.create_formats_()

    def __iter__(self):
        """Return the iterator of the data loader."""
        return _EdgeDataLoaderIter(self)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)

class GraphDataLoader:
    """PyTorch dataloader for batch-iterating over a set of graphs, generating the batched
    graph and corresponding label tensor (if provided) of the said minibatch.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        The dataset to iterate over.  Each item is typically a graph, or a pair
        of a graph and its label, as in the example below.
    collate_fn : Function, default is None
        The customized collate function. Will use the default collate
        function if not given.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
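
    To plug in a customized collate function, pass it as ``collate_fn``.  The
    function below is a minimal illustrative sketch for a dataset of
    ``(graph, label)`` pairs, not part of DGL:

    >>> def my_collate(samples):
    ...     graphs, labels = map(list, zip(*samples))
    ...     return dgl.batch(graphs), torch.tensor(labels)
    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, collate_fn=my_collate, batch_size=1024)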
    """
    collator_arglist = inspect.getfullargspec(GraphCollator).args

    def __init__(self, dataset, collate_fn=None, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if collate_fn is None:
            self.collate = GraphCollator(**collator_kwargs).collate
        else:
            self.collate = collate_fn

        self.dataloader = DataLoader(dataset=dataset,
                                     collate_fn=self.collate,
                                     **dataloader_kwargs)

    def __iter__(self):
        """Return the iterator of the data loader."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)