"...pytorch/mvgrl/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "bcffdb82c9260de4449d70be526447adc6cb14d1"
Unverified commit f7f4e73a authored by Quan (Andy) Gan, committed by GitHub

[GraphBolt] multiprocess dataloader (#5959)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent a33fafb7
"""Graph Bolt DataLoaders"""
import torch.utils.data
import torchdata.dataloader2.graph as dp_utils
import torchdata.datapipes as dp
from .datapipe_utils import datapipe_graph_to_adjlist
from .feature_fetcher import FeatureFetcher
from .minibatch_sampler import MinibatchSampler


class SingleProcessDataLoader(torch.utils.data.DataLoader):
    # Set batch_size to None so PyTorch's DataLoader does not batch again; we
    # already have minibatch sampling and collating in MinibatchSampler.
    def __init__(self, datapipe):
        super().__init__(datapipe, batch_size=None, num_workers=0)


class MultiprocessingWrapper(dp.iter.IterDataPipe):
    """Wraps a datapipe with multiprocessing.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    num_workers : int, optional
        The number of worker processes. Default is 0, meaning that there
        will be no multiprocessing.
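
    Examples
    --------
    A rough usage sketch; ``datapipe`` is assumed to already yield
    fully-formed minibatches (e.g. the output of an upstream sampler):

    >>> wrapped = MultiprocessingWrapper(datapipe, num_workers=2)
    >>> for minibatch in wrapped:
    ...     pass  # minibatches are produced by the worker processes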
"""
def __init__(self, datapipe, num_workers=0):
self.datapipe = datapipe
self.dataloader = torch.utils.data.DataLoader(
datapipe,
batch_size=None,
num_workers=num_workers,
)
def __iter__(self):
yield from self.dataloader


class MultiProcessDataLoader(torch.utils.data.DataLoader):
    """Multiprocessing DataLoader.

    Iterates over the data pipeline with everything before feature fetching
    (i.e. :class:`dgl.graphbolt.FeatureFetcher`) in subprocesses, and
    everything after feature fetching in the main process. The datapipe is
    modified in-place as a result.

    Only works on a single GPU.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    num_workers : int, optional
        Number of worker processes. Default is 0, which is identical to
        :class:`SingleProcessDataLoader`.
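
    Examples
    --------
    A rough usage sketch; ``itemset``, ``sampler_fn`` and ``fetch_fn`` are
    illustrative placeholders for a pipeline of :class:`MinibatchSampler`,
    :class:`SubgraphSampler` and :class:`FeatureFetcher` stages:

    >>> minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=4)
    >>> subgraph_sampler = dgl.graphbolt.SubgraphSampler(
    ...     minibatch_sampler, sampler_fn
    ... )
    >>> feature_fetcher = dgl.graphbolt.FeatureFetcher(subgraph_sampler, fetch_fn)
    >>> dataloader = MultiProcessDataLoader(feature_fetcher, num_workers=4)
    >>> for minibatch in dataloader:
    ...     pass  # stages before feature fetching run in the worker processes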
"""
def __init__(self, datapipe, num_workers=0):
# Multiprocessing requires two modifications to the datapipe:
#
# 1. Insert a stage after MinibatchSampler to distribute the
# minibatches evenly across processes.
# 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
datapipe_graph = dp_utils.traverse_dps(datapipe)
datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
# (1) Insert minibatch distribution.
# TODO(BarclayII): Currently I'm using sharding_filter() as a
# concept demonstration. Later on minibatch distribution should be
# merged into MinibatchSampler to maximize efficiency.
minibatch_samplers = dp_utils.find_dps(
datapipe_graph,
MinibatchSampler,
)
for minibatch_sampler in minibatch_samplers:
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
minibatch_sampler,
minibatch_sampler.sharding_filter(),
)
# (2) Cut datapipe at FeatureFetcher and wrap.
feature_fetchers = dp_utils.find_dps(
datapipe_graph,
FeatureFetcher,
)
for feature_fetcher in feature_fetchers:
feature_fetcher_id = id(feature_fetcher)
for parent_datapipe_id in datapipe_adjlist[feature_fetcher_id][1]:
parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
parent_datapipe,
MultiprocessingWrapper(parent_datapipe, num_workers),
)
# The stages after feature fetching is still done in the main process.
# So we set num_workers to 0 here.
super().__init__(datapipe, batch_size=None, num_workers=0)
"""DataPipe utilities"""
def _get_parents(result_dict, datapipe_graph):
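    # Walk the nested parent dict produced by traverse_dps and record, for
    # every datapipe id, the datapipe object together with the ids of its
    # direct parents, recursing into parents that have not been seen yet.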
    for k, (v, parents) in datapipe_graph.items():
        if k not in result_dict:
            result_dict[k] = (v, list(parents.keys()))
            _get_parents(result_dict, parents)


def datapipe_graph_to_adjlist(datapipe_graph):
    """Given a DataPipe graph returned by
    :func:`torch.utils.data.graph.traverse_dps` in DAG form, convert it into
    adjacency list form.

    Namely, :func:`torch.utils.data.graph.traverse_dps` returns the following
    data structure:

    .. code::

       {
           id(datapipe): (
               datapipe,
               {
                   id(parent1_of_datapipe): (parent1_of_datapipe, {...}),
                   id(parent2_of_datapipe): (parent2_of_datapipe, {...}),
                   ...
               }
           )
       }

    We convert it into the following for easier access:

    .. code::

       {
           id(datapipe1): (
               datapipe1,
               [id(parent1_of_datapipe1), id(parent2_of_datapipe1), ...]
           ),
           id(datapipe2): (
               datapipe2,
               [id(parent1_of_datapipe2), id(parent2_of_datapipe2), ...]
           ),
           ...
       }
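
    Examples
    --------
    A small sketch of the conversion; ``IterableWrapper`` and its ``map``
    are ordinary torchdata datapipes used here only to build a two-node
    graph:

    >>> import torchdata.datapipes as dp
    >>> from torch.utils.data.graph import traverse_dps
    >>> source = dp.iter.IterableWrapper(range(10))
    >>> mapped = source.map(lambda x: x + 1)
    >>> adjlist = datapipe_graph_to_adjlist(traverse_dps(mapped))
    >>> adjlist[id(mapped)][1] == [id(source)]
    True
    >>> adjlist[id(source)][1]
    []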
"""
result_dict = {}
_get_parents(result_dict, datapipe_graph)
return result_dict
from functools import partial

import backend as F

import dgl
import dgl.graphbolt
import gb_test_utils
import torch


def sampler_func(graph, data):
    seeds = data
    sampler = dgl.dataloading.NeighborSampler([2, 2])
    return sampler.sample(graph, seeds)


def fetch_func(features, labels, data):
    input_nodes, output_nodes, adjs = data
    input_features = features.read(input_nodes)
    output_labels = labels.read(output_nodes)
    return input_features, output_labels, adjs


def test_DataLoader():
    N = 40
    B = 4
    itemset = dgl.graphbolt.ItemSet(torch.arange(N))
    # TODO(BarclayII): temporarily using DGLGraph. Should test using
    # GraphBolt's storage as well once issue #5953 is resolved.
    graph = dgl.add_reverse_edges(dgl.rand_graph(200, 6000))
    features = dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4))
    labels = dgl.graphbolt.TorchBasedFeatureStore(torch.randint(0, 10, (200,)))

    minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
    subgraph_sampler = dgl.graphbolt.SubgraphSampler(
        minibatch_sampler,
        partial(sampler_func, graph),
    )
    feature_fetcher = dgl.graphbolt.FeatureFetcher(
        subgraph_sampler,
        partial(fetch_func, features, labels),
    )
    device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())

    dataloader = dgl.graphbolt.MultiProcessDataLoader(
        device_transferrer,
        num_workers=4,
    )
    assert len(list(dataloader)) == N // B
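

# A minimal companion sketch: the same pipeline, rebuilt from scratch, should
# also be drivable by SingleProcessDataLoader, which runs every stage in the
# main process.
def test_SingleProcessDataLoader():
    N = 40
    B = 4
    itemset = dgl.graphbolt.ItemSet(torch.arange(N))
    graph = dgl.add_reverse_edges(dgl.rand_graph(200, 6000))
    features = dgl.graphbolt.TorchBasedFeatureStore(torch.randn(200, 4))
    labels = dgl.graphbolt.TorchBasedFeatureStore(torch.randint(0, 10, (200,)))
    minibatch_sampler = dgl.graphbolt.MinibatchSampler(itemset, batch_size=B)
    subgraph_sampler = dgl.graphbolt.SubgraphSampler(
        minibatch_sampler,
        partial(sampler_func, graph),
    )
    feature_fetcher = dgl.graphbolt.FeatureFetcher(
        subgraph_sampler,
        partial(fetch_func, features, labels),
    )
    device_transferrer = dgl.graphbolt.CopyTo(feature_fetcher, F.ctx())
    dataloader = dgl.graphbolt.SingleProcessDataLoader(device_transferrer)
    assert len(list(dataloader)) == N // B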