Unverified Commit 65efb4f5 authored by peizhou001, committed by GitHub

[Graphbolt] Add default prefetcher (#6541)

parent f5c8c0b5
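
Summary: both GraphBolt dataloaders now rewrite their datapipe graph so that the parent of the CopyTo node is wrapped in torchdata's Prefetcher. Everything upstream of the host-to-device copy then runs in a background thread feeding a small buffer. A minimal standalone sketch of the wrapper being inserted (the range source and the print are illustrative assumptions, not from this commit):

    import torchdata.datapipes as dp

    # Prefetcher pulls from its upstream datapipe in a background thread
    # and keeps up to buffer_size items ready for the consumer.
    pipe = dp.iter.IterableWrapper(range(10))
    pipe = dp.iter.Prefetcher(pipe, buffer_size=2)  # the wrapper this commit inserts

    print(list(pipe))  # [0, 1, 2, ..., 9]: order is preserved, only timing changes
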
@@ -4,6 +4,7 @@ import torch.utils.data
 import torchdata.dataloader2.graph as dp_utils
 import torchdata.datapipes as dp
+from .base import CopyTo
 from .feature_fetcher import FeatureFetcher
 from .item_sampler import ItemSampler
@@ -16,6 +17,25 @@ __all__ = [
 ]
+
+
+def _find_and_wrap_parent(
+    datapipe_graph, datapipe_adjlist, target_datapipe, wrapper, **kwargs
+):
+    """Find the parents of target_datapipe and wrap them with ``wrapper``."""
+    datapipes = dp_utils.find_dps(
+        datapipe_graph,
+        target_datapipe,
+    )
+    for datapipe in datapipes:
+        datapipe_id = id(datapipe)
+        for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
+            parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
+            datapipe_graph = dp_utils.replace_dp(
+                datapipe_graph,
+                parent_datapipe,
+                wrapper(parent_datapipe, **kwargs),
+            )
 class SingleProcessDataLoader(torch.utils.data.DataLoader):
     """Single process DataLoader.
@@ -34,6 +54,17 @@ class SingleProcessDataLoader(torch.utils.data.DataLoader):
     # The exception is that batch_size should be None, since we already
     # have minibatch sampling and collating in ItemSampler.
     def __init__(self, datapipe):
+        datapipe_graph = dp_utils.traverse_dps(datapipe)
+        datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
+        # Cut datapipe at CopyTo and wrap with prefetcher. This enables the
+        # data pipeline up to the CopyTo operation to run in a separate thread.
+        _find_and_wrap_parent(
+            datapipe_graph,
+            datapipe_adjlist,
+            CopyTo,
+            dp.iter.Prefetcher,
+            buffer_size=2,
+        )
         super().__init__(datapipe, batch_size=None, num_workers=0)
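
Why cut at CopyTo: with the prefetcher sitting just below the copy node, sampling and feature fetching for the next minibatches proceed while the consumer copies and trains on the current one, and buffer_size=2 bounds how many minibatches are in flight. A hypothetical timing sketch of that overlap (the sleeps are stand-ins, not measurements from this commit):

    import time

    import torchdata.datapipes as dp

    def slow_stage(x):
        time.sleep(0.01)  # stand-in for sampling + feature fetching
        return x

    pipe = dp.iter.IterableWrapper(range(50)).map(slow_stage)
    pipe = dp.iter.Prefetcher(pipe, buffer_size=2)

    for item in pipe:     # consumer loop: stand-in for CopyTo + model compute
        time.sleep(0.01)  # the background thread refills the buffer meanwhile
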
@@ -107,18 +138,22 @@ class MultiProcessDataLoader(torch.utils.data.DataLoader):
         )
         # (2) Cut datapipe at FeatureFetcher and wrap.
-        feature_fetchers = dp_utils.find_dps(
-            datapipe_graph,
-            FeatureFetcher,
-        )
-        for feature_fetcher in feature_fetchers:
-            feature_fetcher_id = id(feature_fetcher)
-            for parent_datapipe_id in datapipe_adjlist[feature_fetcher_id][1]:
-                parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
-                datapipe_graph = dp_utils.replace_dp(
-                    datapipe_graph,
-                    parent_datapipe,
-                    MultiprocessingWrapper(parent_datapipe, num_workers),
-                )
+        _find_and_wrap_parent(
+            datapipe_graph,
+            datapipe_adjlist,
+            FeatureFetcher,
+            MultiprocessingWrapper,
+            num_workers=num_workers,
+        )
+
+        # (3) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
+        # data pipeline up to the CopyTo operation to run in a separate thread.
+        _find_and_wrap_parent(
+            datapipe_graph,
+            datapipe_adjlist,
+            CopyTo,
+            dp.iter.Prefetcher,
+            buffer_size=2,
+        )
         # The stages after feature fetching are still done in the main process.
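
In the multi-process loader the helper is reused twice: step (2) pushes everything upstream of FeatureFetcher into DataLoader worker processes via MultiprocessingWrapper, and step (3) adds the same thread prefetcher before CopyTo. For intuition, a hypothetical sketch of what a wrapper like MultiprocessingWrapper could look like (_MPWrapper and its signature are assumptions; the real class is defined elsewhere in this file and is not shown in this diff):

    import torch.utils.data
    import torchdata.datapipes as dp

    class _MPWrapper(dp.iter.IterDataPipe):
        """Run the wrapped datapipe inside DataLoader worker processes."""

        def __init__(self, datapipe, num_workers=0):
            super().__init__()
            self.dataloader = torch.utils.data.DataLoader(
                datapipe,
                batch_size=None,  # minibatching already happens in ItemSampler
                num_workers=num_workers,
            )

        def __iter__(self):
            yield from self.dataloader

This matches the comment above: once items cross the wrapper boundary they are consumed in the main process, so every stage after feature fetching stays there.
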