Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b085224f
Unverified
Commit
b085224f
authored
Jan 30, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 30, 2024
Browse files
[GraphBolt][CUDA] Dataloader feature overlap fix (#7036)
parent
68377251
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
36 deletions
+40
-36
examples/multigpu/graphbolt/node_classification.py
examples/multigpu/graphbolt/node_classification.py
+1
-4
python/dgl/graphbolt/dataloader.py
python/dgl/graphbolt/dataloader.py
+16
-26
tests/python/pytorch/graphbolt/test_dataloader.py
tests/python/pytorch/graphbolt/test_dataloader.py
+23
-6
No files found.
examples/multigpu/graphbolt/node_classification.py
View file @
b085224f
...
@@ -139,10 +139,7 @@ def create_dataloader(
...
@@ -139,10 +139,7 @@ def create_dataloader(
if
args
.
storage_device
==
"cpu"
:
if
args
.
storage_device
==
"cpu"
:
datapipe
=
datapipe
.
copy_to
(
device
)
datapipe
=
datapipe
.
copy_to
(
device
)
# Until https://github.com/dmlc/dgl/issues/7008, overlap should be False.
dataloader
=
gb
.
DataLoader
(
datapipe
,
args
.
num_workers
)
dataloader
=
gb
.
DataLoader
(
datapipe
,
args
.
num_workers
,
overlap_feature_fetch
=
False
)
# Return the fully-initialized DataLoader object.
# Return the fully-initialized DataLoader object.
return
dataloader
return
dataloader
...
...
python/dgl/graphbolt/dataloader.py
View file @
b085224f
...
@@ -16,17 +16,18 @@ from .item_sampler import ItemSampler
...
@@ -16,17 +16,18 @@ from .item_sampler import ItemSampler
__all__
=
[
__all__
=
[
"DataLoader"
,
"DataLoader"
,
"Awaiter"
,
"Bufferer"
,
]
]
def
_find_and_wrap_parent
(
def
_find_and_wrap_parent
(
datapipe_graph
,
target_datapipe
,
wrapper
,
**
kwargs
):
datapipe_graph
,
datapipe_adjlist
,
target_datapipe
,
wrapper
,
**
kwargs
):
"""Find parent of target_datapipe and wrap it with ."""
"""Find parent of target_datapipe and wrap it with ."""
datapipes
=
dp_utils
.
find_dps
(
datapipes
=
dp_utils
.
find_dps
(
datapipe_graph
,
datapipe_graph
,
target_datapipe
,
target_datapipe
,
)
)
datapipe_adjlist
=
datapipe_graph_to_adjlist
(
datapipe_graph
)
for
datapipe
in
datapipes
:
for
datapipe
in
datapipes
:
datapipe_id
=
id
(
datapipe
)
datapipe_id
=
id
(
datapipe
)
for
parent_datapipe_id
in
datapipe_adjlist
[
datapipe_id
][
1
]:
for
parent_datapipe_id
in
datapipe_adjlist
[
datapipe_id
][
1
]:
...
@@ -36,6 +37,7 @@ def _find_and_wrap_parent(
...
@@ -36,6 +37,7 @@ def _find_and_wrap_parent(
parent_datapipe
,
parent_datapipe
,
wrapper
(
parent_datapipe
,
**
kwargs
),
wrapper
(
parent_datapipe
,
**
kwargs
),
)
)
return
datapipe_graph
class
EndMarker
(
dp
.
iter
.
IterDataPipe
):
class
EndMarker
(
dp
.
iter
.
IterDataPipe
):
...
@@ -45,8 +47,7 @@ class EndMarker(dp.iter.IterDataPipe):
...
@@ -45,8 +47,7 @@ class EndMarker(dp.iter.IterDataPipe):
self
.
datapipe
=
datapipe
self
.
datapipe
=
datapipe
def
__iter__
(
self
):
def
__iter__
(
self
):
for
data
in
self
.
datapipe
:
yield
from
self
.
datapipe
yield
data
class
Bufferer
(
dp
.
iter
.
IterDataPipe
):
class
Bufferer
(
dp
.
iter
.
IterDataPipe
):
...
@@ -58,11 +59,11 @@ class Bufferer(dp.iter.IterDataPipe):
...
@@ -58,11 +59,11 @@ class Bufferer(dp.iter.IterDataPipe):
The data pipeline.
The data pipeline.
buffer_size : int, optional
buffer_size : int, optional
The size of the buffer which stores the fetched samples. If data coming
The size of the buffer which stores the fetched samples. If data coming
from datapipe has latency spikes, consider
increasing passing a high
from datapipe has latency spikes, consider
setting to a higher value.
value.
Default is
2
.
Default is
1
.
"""
"""
def
__init__
(
self
,
datapipe
,
buffer_size
=
2
):
def
__init__
(
self
,
datapipe
,
buffer_size
=
1
):
self
.
datapipe
=
datapipe
self
.
datapipe
=
datapipe
if
buffer_size
<=
0
:
if
buffer_size
<=
0
:
raise
ValueError
(
raise
ValueError
(
...
@@ -180,7 +181,6 @@ class DataLoader(torch.utils.data.DataLoader):
...
@@ -180,7 +181,6 @@ class DataLoader(torch.utils.data.DataLoader):
datapipe
=
EndMarker
(
datapipe
)
datapipe
=
EndMarker
(
datapipe
)
datapipe_graph
=
dp_utils
.
traverse_dps
(
datapipe
)
datapipe_graph
=
dp_utils
.
traverse_dps
(
datapipe
)
datapipe_adjlist
=
datapipe_graph_to_adjlist
(
datapipe_graph
)
# (1) Insert minibatch distribution.
# (1) Insert minibatch distribution.
# TODO(BarclayII): Currently I'm using sharding_filter() as a
# TODO(BarclayII): Currently I'm using sharding_filter() as a
...
@@ -198,9 +198,8 @@ class DataLoader(torch.utils.data.DataLoader):
...
@@ -198,9 +198,8 @@ class DataLoader(torch.utils.data.DataLoader):
)
)
# (2) Cut datapipe at FeatureFetcher and wrap.
# (2) Cut datapipe at FeatureFetcher and wrap.
_find_and_wrap_parent
(
datapipe_graph
=
_find_and_wrap_parent
(
datapipe_graph
,
datapipe_graph
,
datapipe_adjlist
,
FeatureFetcher
,
FeatureFetcher
,
MultiprocessingWrapper
,
MultiprocessingWrapper
,
num_workers
=
num_workers
,
num_workers
=
num_workers
,
...
@@ -221,25 +220,16 @@ class DataLoader(torch.utils.data.DataLoader):
...
@@ -221,25 +220,16 @@ class DataLoader(torch.utils.data.DataLoader):
)
)
for
feature_fetcher
in
feature_fetchers
:
for
feature_fetcher
in
feature_fetchers
:
feature_fetcher
.
stream
=
_get_uva_stream
()
feature_fetcher
.
stream
=
_get_uva_stream
()
_find_and_wrap_parent
(
datapipe_graph
=
dp_utils
.
replace_dp
(
datapipe_graph
,
datapipe_adjlist
,
EndMarker
,
Bufferer
,
buffer_size
=
2
,
)
_find_and_wrap_parent
(
datapipe_graph
,
datapipe_graph
,
datapipe_adjlist
,
feature_fetcher
,
EndMarker
,
Awaiter
(
Bufferer
(
feature_fetcher
,
buffer_size
=
1
)),
Awaiter
,
)
)
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# data pipeline up to the CopyTo operation to run in a separate thread.
# data pipeline up to the CopyTo operation to run in a separate thread.
_find_and_wrap_parent
(
datapipe_graph
=
_find_and_wrap_parent
(
datapipe_graph
,
datapipe_graph
,
datapipe_adjlist
,
CopyTo
,
CopyTo
,
dp
.
iter
.
Prefetcher
,
dp
.
iter
.
Prefetcher
,
buffer_size
=
2
,
buffer_size
=
2
,
...
...
tests/python/pytorch/graphbolt/test_dataloader.py
View file @
b085224f
...
@@ -7,6 +7,8 @@ import dgl.graphbolt
...
@@ -7,6 +7,8 @@ import dgl.graphbolt
import
pytest
import
pytest
import
torch
import
torch
import
torchdata.dataloader2.graph
as
dp_utils
from
.
import
gb_test_utils
from
.
import
gb_test_utils
...
@@ -46,7 +48,8 @@ def test_DataLoader():
...
@@ -46,7 +48,8 @@ def test_DataLoader():
reason
=
"This test requires the GPU."
,
reason
=
"This test requires the GPU."
,
)
)
@
pytest
.
mark
.
parametrize
(
"overlap_feature_fetch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"overlap_feature_fetch"
,
[
True
,
False
])
def
test_gpu_sampling_DataLoader
(
overlap_feature_fetch
):
@
pytest
.
mark
.
parametrize
(
"enable_feature_fetch"
,
[
True
,
False
])
def
test_gpu_sampling_DataLoader
(
overlap_feature_fetch
,
enable_feature_fetch
):
N
=
40
N
=
40
B
=
4
B
=
4
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
),
names
=
"seed_nodes"
)
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
),
names
=
"seed_nodes"
)
...
@@ -70,6 +73,7 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
...
@@ -70,6 +73,7 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
graph
,
graph
,
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
)
)
if
enable_feature_fetch
:
datapipe
=
dgl
.
graphbolt
.
FeatureFetcher
(
datapipe
=
dgl
.
graphbolt
.
FeatureFetcher
(
datapipe
,
datapipe
,
feature_store
,
feature_store
,
...
@@ -79,4 +83,17 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
...
@@ -79,4 +83,17 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch):
dataloader
=
dgl
.
graphbolt
.
DataLoader
(
dataloader
=
dgl
.
graphbolt
.
DataLoader
(
datapipe
,
overlap_feature_fetch
=
overlap_feature_fetch
datapipe
,
overlap_feature_fetch
=
overlap_feature_fetch
)
)
bufferer_awaiter_cnt
=
int
(
enable_feature_fetch
and
overlap_feature_fetch
)
datapipe
=
dataloader
.
dataset
datapipe_graph
=
dp_utils
.
traverse_dps
(
datapipe
)
awaiters
=
dp_utils
.
find_dps
(
datapipe_graph
,
dgl
.
graphbolt
.
Awaiter
,
)
assert
len
(
awaiters
)
==
bufferer_awaiter_cnt
bufferers
=
dp_utils
.
find_dps
(
datapipe_graph
,
dgl
.
graphbolt
.
Bufferer
,
)
assert
len
(
bufferers
)
==
bufferer_awaiter_cnt
assert
len
(
list
(
dataloader
))
==
N
//
B
assert
len
(
list
(
dataloader
))
==
N
//
B
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment