OpenDAS / dgl · Commits · badeaf19
"git@developer.sourcefind.cn:OpenDAS/torch-cluster.git" did not exist on "da65da5b95d733f24db94e17ce835ff25718c02c"
Unverified commit badeaf19, authored Feb 05, 2024 by Muhammed Fatih BALIN, committed via GitHub on Feb 05, 2024.

[GraphBolt][CUDA] Pipelined sampling optimization (#7039)

Parent: 4b265390
Showing 7 changed files with 352 additions and 75 deletions.
python/dgl/graphbolt/base.py  (+75, -0)
python/dgl/graphbolt/dataloader.py  (+27, -60)
python/dgl/graphbolt/impl/neighbor_sampler.py  (+150, -2)
python/dgl/graphbolt/minibatch_transformer.py  (+6, -2)
python/dgl/graphbolt/subgraph_sampler.py  (+1, -5)
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py  (+73, -0)
tests/python/pytorch/graphbolt/test_dataloader.py  (+20, -6)
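At a glance, the optimization moves the per-layer graph fetch onto a background worker thread and a side CUDA stream, keeps a small buffer of in-flight futures, and only blocks right before the fetched subgraph is needed for sampling. A rough, self-contained sketch of that producer/consumer pattern (plain-Python stand-ins, not the GraphBolt API):

from collections import deque
from concurrent.futures import ThreadPoolExecutor


def pipelined(minibatches, fetch, sample, buffer_size=1):
    # Submit each fetch to a background worker and keep up to `buffer_size`
    # futures in flight, so fetching overlaps with sampling/consumption.
    executor = ThreadPoolExecutor(max_workers=1)
    in_flight = deque()
    for minibatch in minibatches:
        in_flight.append(executor.submit(fetch, minibatch))
        if len(in_flight) > buffer_size:
            yield sample(in_flight.popleft().result())
    while in_flight:
        yield sample(in_flight.popleft().result())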
python/dgl/graphbolt/base.py
"""Base types and utilities for Graph Bolt."""
from
collections
import
deque
from
dataclasses
import
dataclass
import
torch
...
...
@@ -14,6 +15,10 @@ __all__ = [
"etype_str_to_tuple"
,
"etype_tuple_to_str"
,
"CopyTo"
,
"FutureWaiter"
,
"Waiter"
,
"Bufferer"
,
"EndMarker"
,
"isin"
,
"index_select"
,
"expand_indptr"
,
...
...
@@ -247,6 +252,76 @@ class CopyTo(IterDataPipe):
            yield data


@functional_datapipe("mark_end")
class EndMarker(IterDataPipe):
    """Used to mark the end of a datapipe and is a no-op."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        yield from self.datapipe


@functional_datapipe("buffer")
class Bufferer(IterDataPipe):
    """Buffers items before yielding them.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    buffer_size : int, optional
        The size of the buffer which stores the fetched samples. If data coming
        from datapipe has latency spikes, consider setting to a higher value.
        Default is 1.
    """

    def __init__(self, datapipe, buffer_size=1):
        self.datapipe = datapipe
        if buffer_size <= 0:
            raise ValueError(
                "'buffer_size' is required to be a positive integer."
            )
        self.buffer = deque(maxlen=buffer_size)

    def __iter__(self):
        for data in self.datapipe:
            if len(self.buffer) < self.buffer.maxlen:
                self.buffer.append(data)
            else:
                return_data = self.buffer.popleft()
                self.buffer.append(data)
                yield return_data
        while len(self.buffer) > 0:
            yield self.buffer.popleft()


@functional_datapipe("wait")
class Waiter(IterDataPipe):
    """Calls the wait function of all items."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for data in self.datapipe:
            data.wait()
            yield data


@functional_datapipe("wait_future")
class FutureWaiter(IterDataPipe):
    """Calls the result function of all items and returns their results."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for data in self.datapipe:
            yield data.result()
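A minimal usage sketch of how these registered stages chain together; it assumes `import dgl.graphbolt` has run so the functional names are available on any IterDataPipe, and the toy source and lambdas are illustrative only:

from concurrent.futures import ThreadPoolExecutor

from torchdata.datapipes.iter import IterableWrapper

import dgl.graphbolt  # noqa: F401  registers .buffer(), .wait_future(), .wait()

executor = ThreadPoolExecutor(max_workers=1)
# A datapipe whose items are futures, e.g. produced by an asynchronous fetch stage.
futures_dp = IterableWrapper(range(8)).map(
    lambda i: executor.submit(lambda x: x * x, i)
)
# buffer(2) keeps up to two futures in flight while downstream consumes;
# wait_future() calls .result() on each item, yielding 0, 1, 4, 9, ...
results = list(futures_dp.buffer(2).wait_future())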
@dataclass
class CSCFormatBase:
    r"""Basic class representing data in Compressed Sparse Column (CSC) format.
...
...
python/dgl/graphbolt/dataloader.py
"""Graph Bolt DataLoaders"""
from
co
llections
import
deque
from
co
ncurrent.futures
import
ThreadPoolExecutor
import
torch
import
torch.utils.data
...
...
@@ -9,6 +9,7 @@ import torchdata.datapipes as dp
from .base import CopyTo
from .feature_fetcher import FeatureFetcher
+ from .impl.neighbor_sampler import SamplePerLayer
from .internal import datapipe_graph_to_adjlist
from .item_sampler import ItemSampler
...
...
@@ -16,8 +17,6 @@ from .item_sampler import ItemSampler
__all__ = [
    "DataLoader",
-     "Awaiter",
-     "Bufferer",
]
...
...
@@ -40,61 +39,6 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
    return datapipe_graph


- class EndMarker(dp.iter.IterDataPipe):
-     """Used to mark the end of a datapipe and is a no-op."""
-
-     def __init__(self, datapipe):
-         self.datapipe = datapipe
-
-     def __iter__(self):
-         yield from self.datapipe
-
-
- class Bufferer(dp.iter.IterDataPipe):
-     """Buffers items before yielding them.
-
-     Parameters
-     ----------
-     datapipe : DataPipe
-         The data pipeline.
-     buffer_size : int, optional
-         The size of the buffer which stores the fetched samples. If data coming
-         from datapipe has latency spikes, consider setting to a higher value.
-         Default is 1.
-     """
-
-     def __init__(self, datapipe, buffer_size=1):
-         self.datapipe = datapipe
-         if buffer_size <= 0:
-             raise ValueError(
-                 "'buffer_size' is required to be a positive integer."
-             )
-         self.buffer = deque(maxlen=buffer_size)
-
-     def __iter__(self):
-         for data in self.datapipe:
-             if len(self.buffer) < self.buffer.maxlen:
-                 self.buffer.append(data)
-             else:
-                 return_data = self.buffer.popleft()
-                 self.buffer.append(data)
-                 yield return_data
-         while len(self.buffer) > 0:
-             yield self.buffer.popleft()
-
-
- class Awaiter(dp.iter.IterDataPipe):
-     """Calls the wait function of all items."""
-
-     def __init__(self, datapipe):
-         self.datapipe = datapipe
-
-     def __iter__(self):
-         for data in self.datapipe:
-             data.wait()
-             yield data


class MultiprocessingWrapper(dp.iter.IterDataPipe):
    """Wraps a datapipe with multiprocessing.
...
...
@@ -156,6 +100,10 @@ class DataLoader(torch.utils.data.DataLoader):
        If True, the data loader will overlap the UVA feature fetcher operations
        with the rest of operations by using an alternative CUDA stream. Default
        is True.
+     overlap_graph_fetch : bool, optional
+         If True, the data loader will overlap the UVA graph fetching operations
+         with the rest of operations by using an alternative CUDA stream. Default
+         is False.
    max_uva_threads : int, optional
        Limits the number of CUDA threads used for UVA copies so that the rest
        of the computations can run simultaneously with it. Setting it to a too
...
...
@@ -170,6 +118,7 @@ class DataLoader(torch.utils.data.DataLoader):
        num_workers=0,
        persistent_workers=True,
        overlap_feature_fetch=True,
+         overlap_graph_fetch=False,
        max_uva_threads=6144,
    ):
        # Multiprocessing requires two modifications to the datapipe:
...
...
@@ -179,7 +128,7 @@ class DataLoader(torch.utils.data.DataLoader):
        # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
        # of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
-         datapipe = EndMarker(datapipe)
+         datapipe = datapipe.mark_end()
        datapipe_graph = dp_utils.traverse_dps(datapipe)

        # (1) Insert minibatch distribution.
...
...
@@ -223,7 +172,25 @@ class DataLoader(torch.utils.data.DataLoader):
                datapipe_graph = dp_utils.replace_dp(
                    datapipe_graph,
                    feature_fetcher,
-                     Awaiter(Bufferer(feature_fetcher, buffer_size=1)),
+                     feature_fetcher.buffer(1).wait(),
                )

+         if (
+             overlap_graph_fetch
+             and num_workers == 0
+             and torch.cuda.is_available()
+         ):
+             torch.ops.graphbolt.set_max_uva_threads(max_uva_threads)
+             samplers = dp_utils.find_dps(
+                 datapipe_graph,
+                 SamplePerLayer,
+             )
+             executor = ThreadPoolExecutor(max_workers=1)
+             for sampler in samplers:
+                 datapipe_graph = dp_utils.replace_dp(
+                     datapipe_graph,
+                     sampler,
+                     sampler.fetch_and_sample(_get_uva_stream(), executor, 1),
+                 )

        # (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
...
...
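From the user's side, the new path is enabled with a single flag. A sketch based on the updated test further down; the `datapipe` is a placeholder for a sampling pipeline built on a GPU-resident graph, and the flag only takes effect with `num_workers=0` and CUDA available:

dataloader = dgl.graphbolt.DataLoader(
    datapipe,                    # placeholder: ItemSampler -> copy_to -> sampler -> FeatureFetcher
    overlap_feature_fetch=True,  # existing feature-fetch overlap
    overlap_graph_fetch=True,    # new: pipeline the per-layer graph fetch
)
for minibatch in dataloader:
    ...  # training step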
python/dgl/graphbolt/impl/neighbor_sampler.py
"""Neighbor subgraph samplers for GraphBolt."""
from
concurrent.futures
import
ThreadPoolExecutor
from
functools
import
partial
import
torch
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
..internal
import
compact_csc_format
,
unique_and_compact_csc_formats
from
..minibatch_transformer
import
MiniBatchTransformer
from
..subgraph_sampler
import
SubgraphSampler
from
.fused_csc_sampling_graph
import
fused_csc_sampling_graph
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
__all__
=
[
"NeighborSampler"
,
"LayerNeighborSampler"
]
__all__
=
[
"NeighborSampler"
,
"LayerNeighborSampler"
,
"SamplePerLayer"
,
"SamplePerLayerFromFetchedSubgraph"
,
"FetchInsubgraphData"
,
]
@functional_datapipe("fetch_insubgraph_data")
class FetchInsubgraphData(Mapper):
    """Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object.
    If the provided sample_per_layer_obj has a valid prob_name, then it reads
    the probabilities of all the fetched edges. Furthermore, if the
    type_per_edge tensor exists in the underlying graph, then the types of all
    the fetched edges are read as well."""

    def __init__(
        self, datapipe, sample_per_layer_obj, stream=None, executor=None
    ):
        super().__init__(datapipe, self._fetch_per_layer)
        self.graph = sample_per_layer_obj.sampler.__self__
        self.prob_name = sample_per_layer_obj.prob_name
        self.stream = stream
        if executor is None:
            self.executor = ThreadPoolExecutor(max_workers=1)
        else:
            self.executor = executor

    def _fetch_per_layer_impl(self, minibatch, stream):
        with torch.cuda.stream(self.stream):
            index = minibatch._seed_nodes
            if isinstance(index, dict):
                index = self.graph._convert_to_homogeneous_nodes(index)

            index, original_positions = index.sort()
            if (original_positions.diff() == 1).all().item():
                # is_sorted
                minibatch._subgraph_seed_nodes = None
            else:
                minibatch._subgraph_seed_nodes = original_positions
            index.record_stream(torch.cuda.current_stream())
            index_select_csc_with_indptr = partial(
                torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
            )

            def record_stream(tensor):
                if stream is not None and tensor.is_cuda:
                    tensor.record_stream(stream)

            indptr, indices = index_select_csc_with_indptr(
                self.graph.indices, index, None
            )
            record_stream(indptr)
            record_stream(indices)
            output_size = len(indices)
            if self.graph.type_per_edge is not None:
                _, type_per_edge = index_select_csc_with_indptr(
                    self.graph.type_per_edge, index, output_size
                )
                record_stream(type_per_edge)
            else:
                type_per_edge = None
            if self.graph.edge_attributes is not None:
                probs_or_mask = self.graph.edge_attributes.get(
                    self.prob_name, None
                )
                if probs_or_mask is not None:
                    _, probs_or_mask = index_select_csc_with_indptr(
                        probs_or_mask, index, output_size
                    )
                    record_stream(probs_or_mask)
            else:
                probs_or_mask = None
            if self.graph.node_type_offset is not None:
                node_type_offset = torch.searchsorted(
                    index, self.graph.node_type_offset
                )
            else:
                node_type_offset = None
            subgraph = fused_csc_sampling_graph(
                indptr,
                indices,
                node_type_offset=node_type_offset,
                type_per_edge=type_per_edge,
                node_type_to_id=self.graph.node_type_to_id,
                edge_type_to_id=self.graph.edge_type_to_id,
            )
            if self.prob_name is not None and probs_or_mask is not None:
                subgraph.edge_attributes = {self.prob_name: probs_or_mask}

            minibatch.sampled_subgraphs.insert(0, subgraph)

            if self.stream is not None:
                minibatch.wait = torch.cuda.current_stream().record_event().wait

            return minibatch

    def _fetch_per_layer(self, minibatch):
        current_stream = None
        if self.stream is not None:
            current_stream = torch.cuda.current_stream()
            self.stream.wait_stream(current_stream)
        return self.executor.submit(
            self._fetch_per_layer_impl, minibatch, current_stream
        )
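Note that `_fetch_per_layer` does not return a minibatch; it returns a `concurrent.futures.Future`, so the fetch runs off the main thread and the downstream `wait_future()` stage resolves it. A stripped-down illustration of that hand-off (standard library only, names hypothetical):

from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)


def fetch_impl(minibatch):
    # Stand-in for _fetch_per_layer_impl: the slow copy/index-select work.
    return minibatch


def fetch_async(minibatch):
    # Mirrors _fetch_per_layer: submit and return a Future immediately.
    return executor.submit(fetch_impl, minibatch)


future = fetch_async({"seed_nodes": [0, 1, 2]})
minibatch = future.result()  # what FutureWaiter does for each item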
@functional_datapipe("sample_per_layer_from_fetched_subgraph")
class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
    """Sample neighbor edges from a graph for a single layer."""

    def __init__(self, datapipe, sample_per_layer_obj):
        super().__init__(
            datapipe, self._sample_per_layer_from_fetched_subgraph
        )
        self.sampler_name = sample_per_layer_obj.sampler.__name__
        self.fanout = sample_per_layer_obj.fanout
        self.replace = sample_per_layer_obj.replace
        self.prob_name = sample_per_layer_obj.prob_name

    def _sample_per_layer_from_fetched_subgraph(self, minibatch):
        subgraph = minibatch.sampled_subgraphs[0]

        sampled_subgraph = getattr(subgraph, self.sampler_name)(
            minibatch._subgraph_seed_nodes,
            self.fanout,
            self.replace,
            self.prob_name,
        )
        delattr(minibatch, "_subgraph_seed_nodes")
        sampled_subgraph.original_column_node_ids = minibatch._seed_nodes
        minibatch.sampled_subgraphs[0] = sampled_subgraph

        return minibatch


@functional_datapipe("sample_per_layer")
...
...
@@ -72,6 +206,19 @@ class CompactPerLayer(MiniBatchTransformer):
        return minibatch


@functional_datapipe("fetch_and_sample")
class FetcherAndSampler(MiniBatchTransformer):
    """Overlapped graph sampling operation replacement."""

    def __init__(self, sampler, stream, executor, buffer_size):
        datapipe = sampler.datapipe.fetch_insubgraph_data(
            sampler, stream, executor
        )
        datapipe = datapipe.buffer(buffer_size).wait_future().wait()
        datapipe = datapipe.sample_per_layer_from_fetched_subgraph(sampler)
        super().__init__(datapipe)
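FetcherAndSampler can chain stages by name only because of the `@functional_datapipe` registrations above. For context, the registration mechanism itself looks roughly like this (a toy datapipe, not part of this commit):

from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


@functional_datapipe("times_two")
class TimesTwo(IterDataPipe):
    """Makes .times_two() available on every IterDataPipe instance."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for item in self.datapipe:
            yield 2 * item


assert list(IterableWrapper([1, 2, 3]).times_two()) == [2, 4, 6]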
@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
    # pylint: disable=abstract-method
...
...
@@ -173,7 +320,8 @@ class NeighborSampler(SubgraphSampler):
            datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler
        )

-     def _prepare(self, node_type_to_id, minibatch):
+     @staticmethod
+     def _prepare(node_type_to_id, minibatch):
        seeds = minibatch._seed_nodes
        # Enrich seeds with all node types.
        if isinstance(seeds, dict):
...
...
python/dgl/graphbolt/minibatch_transformer.py
...
...
@@ -29,10 +29,10 @@ class MiniBatchTransformer(Mapper):
    def __init__(
        self,
        datapipe,
-         transformer,
+         transformer=None,
    ):
        super().__init__(datapipe, self._transformer)
-         self.transformer = transformer
+         self.transformer = transformer or self._identity

    def _transformer(self, minibatch):
        minibatch = self.transformer(minibatch)
...
...
@@ -40,3 +40,7 @@ class MiniBatchTransformer(Mapper):
            minibatch, (MiniBatch,)
        ), "The transformer output should be an instance of MiniBatch"
        return minibatch

+     @staticmethod
+     def _identity(minibatch):
+         return minibatch
python/dgl/graphbolt/subgraph_sampler.py
...
...
@@ -46,11 +46,7 @@ class SubgraphSampler(MiniBatchTransformer):
        datapipe = datapipe.transform(self._preprocess)
        datapipe = self.sampling_stages(datapipe, *args, **kwargs)
        datapipe = datapipe.transform(self._postprocess)
-         super().__init__(datapipe, self._identity)
-
-     @staticmethod
-     def _identity(minibatch):
-         return minibatch
+         super().__init__(datapipe)

    @staticmethod
    def _postprocess(minibatch):
...
...
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py  (new file, 0 → 100644)
import unittest
from functools import partial

import backend as F

import dgl
import dgl.graphbolt as gb
import pytest
import torch


def get_hetero_graph():
    # COO graph:
    # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
    # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0] -> edge type.
    # num_nodes = 5, num_n1 = 2, num_n2 = 3
    ntypes = {"n1": 0, "n2": 1}
    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
    type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
    edge_attributes = {
        "weight": torch.FloatTensor(
            [2.5, 0, 8.4, 0, 0.4, 1.2, 2.5, 0, 8.4, 0.5]
        ),
        "mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]),
    }
    node_type_offset = torch.LongTensor([0, 2, 5])
    return gb.fused_csc_sampling_graph(
        indptr,
        indices,
        node_type_offset=node_type_offset,
        type_per_edge=type_per_edge,
        node_type_to_id=ntypes,
        edge_type_to_id=etypes,
        edge_attributes=edge_attributes,
    )


@unittest.skipIf(
    F._default_context_str != "gpu",
    reason="Enabled only on GPU.",
)
@pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("prob_name", [None, "weight", "mask"])
def test_NeighborSampler_GraphFetch(hetero, prob_name):
    items = torch.arange(3)
    names = "seed_nodes"
    itemset = gb.ItemSet(items, names=names)
    graph = get_hetero_graph().to(F.ctx())
    if hetero:
        itemset = gb.ItemSetDict({"n2": itemset})
    else:
        graph.type_per_edge = None
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    fanout = torch.LongTensor([2])
    datapipe = item_sampler.map(gb.SubgraphSampler._preprocess)
    datapipe = datapipe.map(
        partial(gb.NeighborSampler._prepare, graph.node_type_to_id)
    )
    sample_per_layer = gb.SamplePerLayer(
        datapipe, graph.sample_neighbors, fanout, False, prob_name
    )
    compact_per_layer = sample_per_layer.compact_per_layer(True)
    gb.seed(123)
    expected_results = list(compact_per_layer)
    datapipe = gb.FetchInsubgraphData(datapipe, sample_per_layer)
    datapipe = datapipe.wait_future()
    datapipe = gb.SamplePerLayerFromFetchedSubgraph(datapipe, sample_per_layer)
    datapipe = datapipe.compact_per_layer(True)
    gb.seed(123)
    new_results = list(datapipe)
    assert len(expected_results) == len(new_results)
    for a, b in zip(expected_results, new_results):
        assert repr(a) == repr(b)
tests/python/pytorch/graphbolt/test_dataloader.py
...
...
@@ -47,11 +47,21 @@ def test_DataLoader():
    F._default_context_str != "gpu",
    reason="This test requires the GPU.",
)
- @pytest.mark.parametrize("overlap_feature_fetch", [True, False])
+ @pytest.mark.parametrize(
+     "sampler_name", ["NeighborSampler", "LayerNeighborSampler"]
+ )
@pytest.mark.parametrize("enable_feature_fetch", [True, False])
- def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
+ @pytest.mark.parametrize("overlap_feature_fetch", [True, False])
+ @pytest.mark.parametrize("overlap_graph_fetch", [True, False])
+ def test_gpu_sampling_DataLoader(
+     sampler_name,
+     enable_feature_fetch,
+     overlap_feature_fetch,
+     overlap_graph_fetch,
+ ):
    N = 40
    B = 4
+     num_layers = 2
    itemset = dgl.graphbolt.ItemSet(torch.arange(N), names="seed_nodes")
    graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True).to(
        F.ctx()
...
...
@@ -68,10 +78,10 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
    datapipe = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
    datapipe = datapipe.copy_to(F.ctx(), extra_attrs=["seed_nodes"])
-     datapipe = dgl.graphbolt.NeighborSampler(
+     datapipe = getattr(dgl.graphbolt, sampler_name)(
        datapipe,
        graph,
-         fanouts=[torch.LongTensor([2]) for _ in range(2)],
+         fanouts=[torch.LongTensor([2]) for _ in range(num_layers)],
    )
    if enable_feature_fetch:
        datapipe = dgl.graphbolt.FeatureFetcher(
...
...
@@ -81,14 +91,18 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
    )

    dataloader = dgl.graphbolt.DataLoader(
-         datapipe, overlap_feature_fetch=overlap_feature_fetch
+         datapipe,
+         overlap_feature_fetch=overlap_feature_fetch,
+         overlap_graph_fetch=overlap_graph_fetch,
    )
    bufferer_awaiter_cnt = int(enable_feature_fetch and overlap_feature_fetch)
+     if overlap_graph_fetch:
+         bufferer_awaiter_cnt += num_layers
    datapipe = dataloader.dataset
    datapipe_graph = dp_utils.traverse_dps(datapipe)
    awaiters = dp_utils.find_dps(
        datapipe_graph,
-         dgl.graphbolt.Awaiter,
+         dgl.graphbolt.Waiter,
    )
    assert len(awaiters) == bufferer_awaiter_cnt
    bufferers = dp_utils.find_dps(
...
...