Unverified Commit 65efb4f5 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Add default prefetcher (#6541)

parent f5c8c0b5
......@@ -4,6 +4,7 @@ import torch.utils.data
import torchdata.dataloader2.graph as dp_utils
import torchdata.datapipes as dp
from .base import CopyTo
from .feature_fetcher import FeatureFetcher
from .item_sampler import ItemSampler
......@@ -16,6 +17,25 @@ __all__ = [
]
def _find_and_wrap_parent(
datapipe_graph, datapipe_adjlist, target_datapipe, wrapper, **kwargs
):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = dp_utils.find_dps(
datapipe_graph,
target_datapipe,
)
for datapipe in datapipes:
datapipe_id = id(datapipe)
for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
parent_datapipe,
wrapper(parent_datapipe, **kwargs),
)
class SingleProcessDataLoader(torch.utils.data.DataLoader):
"""Single process DataLoader.
......@@ -34,6 +54,17 @@ class SingleProcessDataLoader(torch.utils.data.DataLoader):
# The exception is that batch_size should be None, since we already
# have minibatch sampling and collating in ItemSampler.
def __init__(self, datapipe):
datapipe_graph = dp_utils.traverse_dps(datapipe)
datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
# Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# data pipeline up to the CopyTo operation to run in a separate thread.
_find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
CopyTo,
dp.iter.Prefetcher,
buffer_size=2,
)
super().__init__(datapipe, batch_size=None, num_workers=0)
......@@ -107,19 +138,23 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
)
# (2) Cut datapipe at FeatureFetcher and wrap.
feature_fetchers = dp_utils.find_dps(
_find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
FeatureFetcher,
MultiprocessingWrapper,
num_workers=num_workers,
)
# (3) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# data pipeline up to the CopyTo operation to run in a separate thread.
_find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
CopyTo,
dp.iter.Prefetcher,
buffer_size=2,
)
for feature_fetcher in feature_fetchers:
feature_fetcher_id = id(feature_fetcher)
for parent_datapipe_id in datapipe_adjlist[feature_fetcher_id][1]:
parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
parent_datapipe,
MultiprocessingWrapper(parent_datapipe, num_workers),
)
# The stages after feature fetching is still done in the main process.
# So we set num_workers to 0 here.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment