Unverified commit b085224f, authored by Muhammed Fatih BALIN, committed by GitHub

[GraphBolt][CUDA] Dataloader feature overlap fix (#7036)

parent 68377251
@@ -139,10 +139,7 @@ def create_dataloader(
     if args.storage_device == "cpu":
         datapipe = datapipe.copy_to(device)
-    # Until https://github.com/dmlc/dgl/issues/7008, overlap should be False.
-    dataloader = gb.DataLoader(
-        datapipe, args.num_workers, overlap_feature_fetch=False
-    )
+    dataloader = gb.DataLoader(datapipe, args.num_workers)
     # Return the fully-initialized DataLoader object.
     return dataloader
...
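With the overlap path fixed, the example drops the overlap_feature_fetch=False workaround for issue #7008. A minimal sketch of the simplified call site, with datapipe and args as in the example above, and assuming overlapped feature fetching is the default when the flag is left unspecified:

    import dgl.graphbolt as gb

    # Overlap is assumed to be on by default after this fix; the flag is
    # only needed to opt out explicitly.
    dataloader = gb.DataLoader(datapipe, args.num_workers)
    # dataloader = gb.DataLoader(
    #     datapipe, args.num_workers, overlap_feature_fetch=False
    # )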
@@ -16,17 +16,18 @@ from .item_sampler import ItemSampler
 __all__ = [
     "DataLoader",
+    "Awaiter",
+    "Bufferer",
 ]


-def _find_and_wrap_parent(
-    datapipe_graph, datapipe_adjlist, target_datapipe, wrapper, **kwargs
-):
+def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
     """Find the parents of target_datapipe and wrap them with ``wrapper``."""
     datapipes = dp_utils.find_dps(
         datapipe_graph,
         target_datapipe,
     )
+    datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
     for datapipe in datapipes:
         datapipe_id = id(datapipe)
         for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
@@ -36,6 +37,7 @@ def _find_and_wrap_parent(
                 parent_datapipe,
                 wrapper(parent_datapipe, **kwargs),
             )
+    return datapipe_graph


 class EndMarker(dp.iter.IterDataPipe):
@@ -45,8 +47,7 @@ class EndMarker(dp.iter.IterDataPipe):
         self.datapipe = datapipe

     def __iter__(self):
-        for data in self.datapipe:
-            yield data
+        yield from self.datapipe


 class Bufferer(dp.iter.IterDataPipe):
@@ -58,11 +59,11 @@ class Bufferer(dp.iter.IterDataPipe):
         The data pipeline.
     buffer_size : int, optional
         The size of the buffer which stores the fetched samples. If data coming
-        from datapipe has latency spikes, consider increasing passing a high
-        value. Default is 2.
+        from datapipe has latency spikes, consider setting to a higher value.
+        Default is 1.
     """

-    def __init__(self, datapipe, buffer_size=2):
+    def __init__(self, datapipe, buffer_size=1):
         self.datapipe = datapipe
         if buffer_size <= 0:
             raise ValueError(
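The commit only changes Bufferer's default buffer size, so its body is not part of the diff. For context, a minimal sketch of what a Bufferer-style datapipe does, written against the standard torchdata IterDataPipe protocol; this is an illustration, not the class body from this commit:

    from collections import deque

    import torchdata.datapipes as dp


    class BuffererSketch(dp.iter.IterDataPipe):
        """Keeps up to `buffer_size` items pulled ahead of the consumer,
        giving asynchronous upstream work (e.g. a feature copy on a side
        CUDA stream) time to finish before each item moves downstream."""

        def __init__(self, datapipe, buffer_size=1):
            if buffer_size <= 0:
                raise ValueError("'buffer_size' should be a positive integer.")
            self.datapipe = datapipe
            self.buffer_size = buffer_size

        def __iter__(self):
            buffer = deque()
            for data in self.datapipe:
                buffer.append(data)
                # Start yielding only once more than `buffer_size` items are
                # held, so `buffer_size` items stay in flight at all times.
                if len(buffer) > self.buffer_size:
                    yield buffer.popleft()
            # Drain the remaining items at the end of the epoch.
            while buffer:
                yield buffer.popleft()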
@@ -180,7 +181,6 @@ class DataLoader(torch.utils.data.DataLoader):
         datapipe = EndMarker(datapipe)
         datapipe_graph = dp_utils.traverse_dps(datapipe)
-        datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)

         # (1) Insert minibatch distribution.
         # TODO(BarclayII): Currently I'm using sharding_filter() as a
@@ -198,9 +198,8 @@ class DataLoader(torch.utils.data.DataLoader):
         )

         # (2) Cut datapipe at FeatureFetcher and wrap.
-        _find_and_wrap_parent(
+        datapipe_graph = _find_and_wrap_parent(
             datapipe_graph,
-            datapipe_adjlist,
             FeatureFetcher,
             MultiprocessingWrapper,
             num_workers=num_workers,
@@ -221,25 +220,16 @@ class DataLoader(torch.utils.data.DataLoader):
             )
             for feature_fetcher in feature_fetchers:
                 feature_fetcher.stream = _get_uva_stream()
-                _find_and_wrap_parent(
-                    datapipe_graph,
-                    datapipe_adjlist,
-                    EndMarker,
-                    Bufferer,
-                    buffer_size=2,
-                )
-                _find_and_wrap_parent(
-                    datapipe_graph,
-                    datapipe_adjlist,
-                    EndMarker,
-                    Awaiter,
-                )
+                datapipe_graph = dp_utils.replace_dp(
+                    datapipe_graph,
+                    feature_fetcher,
+                    Awaiter(Bufferer(feature_fetcher, buffer_size=1)),
+                )

         # (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
         # data pipeline up to the CopyTo operation to run in a separate thread.
-        _find_and_wrap_parent(
+        datapipe_graph = _find_and_wrap_parent(
             datapipe_graph,
-            datapipe_adjlist,
             CopyTo,
             dp.iter.Prefetcher,
             buffer_size=2,
...
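The overlap wrapping between steps (2) and (4) now targets the FeatureFetcher node directly: dp_utils.replace_dp swaps it for Awaiter(Bufferer(feature_fetcher, buffer_size=1)) in the traversed graph, instead of wrapping whatever precedes the EndMarker twice. A hedged sketch of the Awaiter half and of the rewiring idiom, assuming buffered items expose a wait() method; everything here apart from the torchdata graph utilities is illustrative, not the commit's exact internals:

    import torchdata.datapipes as dp
    import torchdata.dataloader2.graph as dp_utils


    class _AsyncItem:
        """Stand-in for a minibatch whose feature copy was launched on a
        side stream; wait() would synchronize that stream for real."""

        def __init__(self, value):
            self.value = value

        def wait(self):
            return self


    class AwaiterSketch(dp.iter.IterDataPipe):
        """Blocks on each item's pending asynchronous work just before
        handing it to the consumer, after the buffering window elapses."""

        def __init__(self, datapipe):
            self.datapipe = datapipe

        def __iter__(self):
            for data in self.datapipe:
                data.wait()
                yield data


    # Rewiring idiom from the diff: traverse the pipeline, then replace a
    # node in the graph with its wrapped version.
    source = dp.iter.IterableWrapper([_AsyncItem(i) for i in range(4)])
    tail = source.map(lambda x: x)  # Some downstream stage.
    graph = dp_utils.traverse_dps(tail)
    graph = dp_utils.replace_dp(graph, source, AwaiterSketch(source))

With buffer_size=1, one minibatch's fetch is always in flight while the previous one is consumed; buffering without awaiting would hand out minibatches whose features may not be ready, and awaiting without buffering would serialize the fetches.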
@@ -7,6 +7,8 @@ import dgl.graphbolt
 import pytest
 import torch

+import torchdata.dataloader2.graph as dp_utils
+
 from . import gb_test_utils
@@ -46,7 +48,8 @@ def test_DataLoader():
     reason="This test requires the GPU.",
 )
 @pytest.mark.parametrize("overlap_feature_fetch", [True, False])
-def test_gpu_sampling_DataLoader(overlap_feature_fetch):
+@pytest.mark.parametrize("enable_feature_fetch", [True, False])
+def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
     N = 40
     B = 4
     itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
@@ -70,6 +73,7 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
         graph,
         fanouts=[torch.LongTensor([2]) for _ in range(2)],
     )
+    if enable_feature_fetch:
         datapipe = dgl.graphbolt.FeatureFetcher(
             datapipe,
             feature_store,
@@ -79,4 +83,17 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
     dataloader = dgl.graphbolt.DataLoader(
         datapipe, overlap_feature_fetch=overlap_feature_fetch
     )
+    bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
+    datapipe = dataloader.dataset
+    datapipe_graph = dp_utils.traverse_dps(datapipe)
+    awaiters = dp_utils.find_dps(
+        datapipe_graph,
+        dgl.graphbolt.Awaiter,
+    )
+    assert len(awaiters) == bufferer_awaiter_cnt
+    bufferers = dp_utils.find_dps(
+        datapipe_graph,
+        dgl.graphbolt.Bufferer,
+    )
+    assert len(bufferers) == bufferer_awaiter_cnt
     assert len(list(dataloader)) == N // B
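The new assertions double as a recipe for checking, from the outside, whether the overlap machinery was actually inserted into a dataloader. A small helper along those lines, assuming the same public names the test relies on (dgl.graphbolt.Awaiter and dgl.graphbolt.Bufferer, exported by this commit, with the datapipe reachable via dataloader.dataset):

    import torchdata.dataloader2.graph as dp_utils

    import dgl.graphbolt as gb


    def overlap_enabled(dataloader):
        """Return True if the dataloader's datapipe graph contains the
        Awaiter/Bufferer pair inserted by the feature-fetch overlap path."""
        graph = dp_utils.traverse_dps(dataloader.dataset)
        awaiters = dp_utils.find_dps(graph, gb.Awaiter)
        bufferers = dp_utils.find_dps(graph, gb.Bufferer)
        return len(awaiters) > 0 and len(bufferers) > 0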