Unverified Commit e2e524df authored by Tong He, committed by GitHub

[Feature] Add GraphDataLoader implementation (#2496)

* add graph dataloader

* add to doc

* fix

* fix

* fix docstring

* update according to torch default_collate

* add unittest

* fix

* fix lint

* fix
parent fa5f4d6a
@@ -16,6 +16,7 @@ and an ``EdgeDataLoader`` for edge/link prediction task.
.. autoclass:: NodeDataLoader
.. autoclass:: EdgeDataLoader
.. autoclass:: GraphDataLoader
.. _api-dataloading-neighbor-sampling:
Neighbor Sampler
...
"""Data loaders""" """Data loaders"""
from collections.abc import Mapping from collections.abc import Mapping, Sequence
from abc import ABC, abstractproperty, abstractmethod from abc import ABC, abstractproperty, abstractmethod
import re
import numpy as np import numpy as np
from .. import transform from .. import transform
from ..base import NID, EID from ..base import NID, EID
from .. import backend as F from .. import backend as F
from .. import utils from .. import utils
from ..batch import batch
from ..convert import heterograph from ..convert import heterograph
from ..heterograph import DGLHeteroGraph as DGLGraph
from ..distributed.dist_graph import DistGraph from ..distributed.dist_graph import DistGraph
# pylint: disable=unused-argument # pylint: disable=unused-argument
@@ -678,3 +681,82 @@ class EdgeCollator(Collator):
            return self._collate(items)
        else:
            return self._collate_with_negative_sampling(items)
class GraphCollator(object):
    """Given a set of graphs as well as their graph-level data, the collate function
    will batch the graphs into a batched graph, and stack the tensors into a single
    bigger tensor. If the example is a container (such as a sequence or a mapping),
    the collate function preserves the structure and collates each of the elements
    recursively.

    If the set of graphs has no graph-level data, the collate function will yield a
    batched graph.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
    """
    def __init__(self):
        self.graph_collate_err_msg_format = (
            "graph_collate: batch must contain DGLGraph, tensors, numpy arrays, "
            "numbers, dicts or lists; found {}")
        self.np_str_obj_array_pattern = re.compile(r'[SaUO]')

    # This implementation is based on torch.utils.data._utils.collate.default_collate
    def collate(self, items):
        """This function is similar to ``torch.utils.data._utils.collate.default_collate``.
        It combines the sampled graphs and corresponding graph-level data
        into a batched graph and tensors.

        Parameters
        ----------
        items : list of data points or tuples
            Elements in the list are expected to have the same length.
            Each sub-element will be batched as a batched graph, or a
            batched tensor correspondingly.

        Returns
        -------
        A tuple of the batching results.
        """
        elem = items[0]
        elem_type = type(elem)
        if isinstance(elem, DGLGraph):
            batched_graphs = batch(items)
            return batched_graphs
        elif F.is_tensor(elem):
            return F.stack(items, 0)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # array of string classes and object
                if self.np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                    raise TypeError(self.graph_collate_err_msg_format.format(elem.dtype))
                return self.collate([F.tensor(b) for b in items])
            elif elem.shape == ():  # scalars
                return F.tensor(items)
        elif isinstance(elem, float):
            return F.tensor(items, dtype=F.float64)
        elif isinstance(elem, int):
            return F.tensor(items)
        elif isinstance(elem, (str, bytes)):
            return items
        elif isinstance(elem, Mapping):
            return {key: self.collate([d[key] for d in items]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.collate(samples) for samples in zip(*items)))
        elif isinstance(elem, Sequence):
            # check to make sure that the elements in batch have consistent size
            item_iter = iter(items)
            elem_size = len(next(item_iter))
            if not all(len(elem) == elem_size for elem in item_iter):
                raise RuntimeError('each element in list of batch should be of equal size')
            transposed = zip(*items)
            return [self.collate(samples) for samples in transposed]

        raise TypeError(self.graph_collate_err_msg_format.format(elem_type))
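For illustration, here is a minimal sketch (not part of the commit) of how the recursive collation behaves on a list of (graph, label) pairs, assuming the PyTorch backend and that GraphCollator is exported from dgl.dataloading as in this change:

import dgl
import torch
from dgl.dataloading import GraphCollator

collator = GraphCollator()
# Two (graph, label) pairs, mimicking two samples drawn from a dataset.
items = [
    (dgl.graph(([0, 1], [1, 2])), torch.tensor(0)),
    (dgl.graph(([0], [1])), torch.tensor(1)),
]
# Tuples are Sequences, so collate recurses per position: the graphs are
# merged via dgl.batch and the scalar labels stacked into one tensor.
batched_graph, labels = collator.collate(items)
assert batched_graph.batch_size == 2
assert labels.tolist() == [0, 1]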
"""DGL PyTorch DataLoaders""" """DGL PyTorch DataLoaders"""
import inspect import inspect
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from ..dataloader import NodeCollator, EdgeCollator from ..dataloader import NodeCollator, EdgeCollator, GraphCollator
from ...distributed import DistGraph from ...distributed import DistGraph
from ...distributed import DistDataLoader from ...distributed import DistDataLoader
@@ -414,3 +414,53 @@ class EdgeDataLoader:
    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)
class GraphDataLoader:
    """PyTorch dataloader for batch-iterating over a set of graphs, generating the
    batched graph and the corresponding label tensor (if provided) of the said minibatch.

    Parameters
    ----------
    dataset : Dataset
        The dataset of graphs (with optional graph-level data) to iterate over.
    collate : Function, default is None
        The customized collate function. Will use the default collate
        function if not given.
    kwargs : dict
        Arguments being passed to :py:class:`torch.utils.data.DataLoader`.

    Examples
    --------
    To train a GNN for graph classification on a set of graphs in ``dataset`` (assume
    the backend is PyTorch):

    >>> dataloader = dgl.dataloading.GraphDataLoader(
    ...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
    >>> for batched_graph, labels in dataloader:
    ...     train_on(batched_graph, labels)
    """
    collator_arglist = inspect.getfullargspec(GraphCollator).args

    def __init__(self, dataset, collate=None, **kwargs):
        # Split keyword arguments between the collator and the underlying
        # torch DataLoader, based on the collator's constructor signature.
        collator_kwargs = {}
        dataloader_kwargs = {}
        for k, v in kwargs.items():
            if k in self.collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if collate is None:
            self.collate = GraphCollator(**collator_kwargs).collate
        else:
            self.collate = collate

        self.dataloader = DataLoader(dataset=dataset,
                                     collate_fn=self.collate,
                                     **dataloader_kwargs)

    def __iter__(self):
        """Return the iterator of the data loader."""
        return iter(self.dataloader)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.dataloader)
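As a usage sketch of the collate override (hypothetical, not part of the commit), a custom collate that keeps only the batched graph and drops the labels might look like this, reusing the MiniGCDataset from the test below:

import dgl
from dgl.data import MiniGCDataset
from dgl.dataloading import GraphDataLoader

# Hypothetical custom collate: batch the graphs, discard the labels.
# The default GraphCollator would return (batched_graph, stacked_labels).
def graphs_only_collate(items):
    return dgl.batch([g for g, _ in items])

dataset = MiniGCDataset(32, 10, 20)  # 32 synthetic graphs with 10-20 nodes
loader = GraphDataLoader(dataset, collate=graphs_only_collate, batch_size=8)
for batched_graph in loader:
    assert batched_graph.batch_size == 8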
@@ -181,6 +181,15 @@ def test_neighbor_sampler_dataloader():
        collator.dataset, collate_fn=collator.collate, batch_size=2, shuffle=True, drop_last=False)
    _check_neighbor_sampling_dataloader(_g, nid, dl, mode)
def test_graph_dataloader():
    batch_size = 16
    num_batches = 2
    minigc_dataset = dgl.data.MiniGCDataset(batch_size * num_batches, 10, 20)
    data_loader = dgl.dataloading.GraphDataLoader(minigc_dataset, batch_size=batch_size, shuffle=True)
    for graph, label in data_loader:
        assert isinstance(graph, dgl.DGLGraph)
        assert F.asnumpy(label).shape[0] == batch_size

if __name__ == '__main__':
    test_neighbor_sampler_dataloader()
    test_graph_dataloader()