"vscode:/vscode.git/clone" did not exist on "595b37b665db6b419969c59624ec15be15ff716c"
Unverified Commit b085224f authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[GraphBolt][CUDA] Dataloader feature overlap fix (#7036)

parent 68377251
......@@ -139,10 +139,7 @@ def create_dataloader(
if args.storage_device == "cpu":
datapipe = datapipe.copy_to(device)
# Until https://github.com/dmlc/dgl/issues/7008, overlap should be False.
dataloader = gb.DataLoader(
datapipe, args.num_workers, overlap_feature_fetch=False
)
dataloader = gb.DataLoader(datapipe, args.num_workers)
# Return the fully-initialized DataLoader object.
return dataloader
......
......@@ -16,17 +16,18 @@ from .item_sampler import ItemSampler
__all__ = [
"DataLoader",
"Awaiter",
"Bufferer",
]
def _find_and_wrap_parent(
datapipe_graph, datapipe_adjlist, target_datapipe, wrapper, **kwargs
):
def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
"""Find parent of target_datapipe and wrap it with ."""
datapipes = dp_utils.find_dps(
datapipe_graph,
target_datapipe,
)
datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
for datapipe in datapipes:
datapipe_id = id(datapipe)
for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
......@@ -36,6 +37,7 @@ def _find_and_wrap_parent(
parent_datapipe,
wrapper(parent_datapipe, **kwargs),
)
return datapipe_graph
class EndMarker(dp.iter.IterDataPipe):
......@@ -45,8 +47,7 @@ class EndMarker(dp.iter.IterDataPipe):
self.datapipe = datapipe
def __iter__(self):
for data in self.datapipe:
yield data
yield from self.datapipe
class Bufferer(dp.iter.IterDataPipe):
......@@ -58,11 +59,11 @@ class Bufferer(dp.iter.IterDataPipe):
The data pipeline.
buffer_size : int, optional
The size of the buffer which stores the fetched samples. If data coming
from datapipe has latency spikes, consider increasing passing a high
value. Default is 2.
from datapipe has latency spikes, consider setting to a higher value.
Default is 1.
"""
def __init__(self, datapipe, buffer_size=2):
def __init__(self, datapipe, buffer_size=1):
self.datapipe = datapipe
if buffer_size <= 0:
raise ValueError(
......@@ -180,7 +181,6 @@ class DataLoader(torch.utils.data.DataLoader):
datapipe = EndMarker(datapipe)
datapipe_graph = dp_utils.traverse_dps(datapipe)
datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
# (1) Insert minibatch distribution.
# TODO(BarclayII): Currently I'm using sharding_filter() as a
......@@ -198,9 +198,8 @@ class DataLoader(torch.utils.data.DataLoader):
)
# (2) Cut datapipe at FeatureFetcher and wrap.
_find_and_wrap_parent(
datapipe_graph = _find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
FeatureFetcher,
MultiprocessingWrapper,
num_workers=num_workers,
......@@ -221,25 +220,16 @@ class DataLoader(torch.utils.data.DataLoader):
)
for feature_fetcher in feature_fetchers:
feature_fetcher.stream = _get_uva_stream()
_find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
EndMarker,
Bufferer,
buffer_size=2,
)
_find_and_wrap_parent(
datapipe_graph = dp_utils.replace_dp(
datapipe_graph,
datapipe_adjlist,
EndMarker,
Awaiter,
feature_fetcher,
Awaiter(Bufferer(feature_fetcher, buffer_size=1)),
)
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# data pipeline up to the CopyTo operation to run in a separate thread.
_find_and_wrap_parent(
datapipe_graph = _find_and_wrap_parent(
datapipe_graph,
datapipe_adjlist,
CopyTo,
dp.iter.Prefetcher,
buffer_size=2,
......
......@@ -7,6 +7,8 @@ import dgl.graphbolt
import pytest
import torch
import torchdata.dataloader2.graph as dp_utils
from . import gb_test_utils
......@@ -46,7 +48,8 @@ def test_DataLoader():
reason="This test requires the GPU.",
)
@pytest.mark.parametrize("overlap_feature_fetch", [True, False])
def test_gpu_sampling_DataLoader(overlap_feature_fetch):
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
N = 40
B = 4
itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
......@@ -70,6 +73,7 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
if enable_feature_fetch:
datapipe = dgl.graphbolt.FeatureFetcher(
datapipe,
feature_store,
......@@ -79,4 +83,17 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
dataloader = dgl.graphbolt.DataLoader(
datapipe, overlap_feature_fetch=overlap_feature_fetch
)
bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
datapipe = dataloader.dataset
datapipe_graph = dp_utils.traverse_dps(datapipe)
awaiters = dp_utils.find_dps(
datapipe_graph,
dgl.graphbolt.Awaiter,
)
assert len(awaiters) == bufferer_awaiter_cnt
bufferers = dp_utils.find_dps(
datapipe_graph,
dgl.graphbolt.Bufferer,
)
assert len(bufferers) == bufferer_awaiter_cnt
assert len(list(dataloader)) == N // B
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment