Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
badeaf19
"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "9bb3507014059bddea1fa6a53cebaedd96824cf1"
Unverified
Commit
badeaf19
authored
Feb 05, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Feb 05, 2024
Browse files
[GraphBolt][CUDA] Pipelined sampling optimization (#7039)
parent
4b265390
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
352 additions
and
75 deletions
+352
-75
python/dgl/graphbolt/base.py
python/dgl/graphbolt/base.py
+75
-0
python/dgl/graphbolt/dataloader.py
python/dgl/graphbolt/dataloader.py
+27
-60
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+150
-2
python/dgl/graphbolt/minibatch_transformer.py
python/dgl/graphbolt/minibatch_transformer.py
+6
-2
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+1
-5
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
+73
-0
tests/python/pytorch/graphbolt/test_dataloader.py
tests/python/pytorch/graphbolt/test_dataloader.py
+20
-6
No files found.
python/dgl/graphbolt/base.py
View file @
badeaf19
"""Base types and utilities for Graph Bolt."""
"""Base types and utilities for Graph Bolt."""
from
collections
import
deque
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
import
torch
import
torch
...
@@ -14,6 +15,10 @@ __all__ = [
...
@@ -14,6 +15,10 @@ __all__ = [
"etype_str_to_tuple"
,
"etype_str_to_tuple"
,
"etype_tuple_to_str"
,
"etype_tuple_to_str"
,
"CopyTo"
,
"CopyTo"
,
"FutureWaiter"
,
"Waiter"
,
"Bufferer"
,
"EndMarker"
,
"isin"
,
"isin"
,
"index_select"
,
"index_select"
,
"expand_indptr"
,
"expand_indptr"
,
...
@@ -247,6 +252,76 @@ class CopyTo(IterDataPipe):
...
@@ -247,6 +252,76 @@ class CopyTo(IterDataPipe):
yield
data
yield
data
@
functional_datapipe
(
"mark_end"
)
class
EndMarker
(
IterDataPipe
):
"""Used to mark the end of a datapipe and is a no-op."""
def
__init__
(
self
,
datapipe
):
self
.
datapipe
=
datapipe
def
__iter__
(
self
):
yield
from
self
.
datapipe
@
functional_datapipe
(
"buffer"
)
class
Bufferer
(
IterDataPipe
):
"""Buffers items before yielding them.
Parameters
----------
datapipe : DataPipe
The data pipeline.
buffer_size : int, optional
The size of the buffer which stores the fetched samples. If data coming
from datapipe has latency spikes, consider setting to a higher value.
Default is 1.
"""
def
__init__
(
self
,
datapipe
,
buffer_size
=
1
):
self
.
datapipe
=
datapipe
if
buffer_size
<=
0
:
raise
ValueError
(
"'buffer_size' is required to be a positive integer."
)
self
.
buffer
=
deque
(
maxlen
=
buffer_size
)
def
__iter__
(
self
):
for
data
in
self
.
datapipe
:
if
len
(
self
.
buffer
)
<
self
.
buffer
.
maxlen
:
self
.
buffer
.
append
(
data
)
else
:
return_data
=
self
.
buffer
.
popleft
()
self
.
buffer
.
append
(
data
)
yield
return_data
while
len
(
self
.
buffer
)
>
0
:
yield
self
.
buffer
.
popleft
()
@
functional_datapipe
(
"wait"
)
class
Waiter
(
IterDataPipe
):
"""Calls the wait function of all items."""
def
__init__
(
self
,
datapipe
):
self
.
datapipe
=
datapipe
def
__iter__
(
self
):
for
data
in
self
.
datapipe
:
data
.
wait
()
yield
data
@
functional_datapipe
(
"wait_future"
)
class
FutureWaiter
(
IterDataPipe
):
"""Calls the result function of all items and returns their results."""
def
__init__
(
self
,
datapipe
):
self
.
datapipe
=
datapipe
def
__iter__
(
self
):
for
data
in
self
.
datapipe
:
yield
data
.
result
()
@
dataclass
@
dataclass
class
CSCFormatBase
:
class
CSCFormatBase
:
r
"""Basic class representing data in Compressed Sparse Column (CSC) format.
r
"""Basic class representing data in Compressed Sparse Column (CSC) format.
...
...
python/dgl/graphbolt/dataloader.py
View file @
badeaf19
"""Graph Bolt DataLoaders"""
"""Graph Bolt DataLoaders"""
from
co
llections
import
deque
from
co
ncurrent.futures
import
ThreadPoolExecutor
import
torch
import
torch
import
torch.utils.data
import
torch.utils.data
...
@@ -9,6 +9,7 @@ import torchdata.datapipes as dp
...
@@ -9,6 +9,7 @@ import torchdata.datapipes as dp
from
.base
import
CopyTo
from
.base
import
CopyTo
from
.feature_fetcher
import
FeatureFetcher
from
.feature_fetcher
import
FeatureFetcher
from
.impl.neighbor_sampler
import
SamplePerLayer
from
.internal
import
datapipe_graph_to_adjlist
from
.internal
import
datapipe_graph_to_adjlist
from
.item_sampler
import
ItemSampler
from
.item_sampler
import
ItemSampler
...
@@ -16,8 +17,6 @@ from .item_sampler import ItemSampler
...
@@ -16,8 +17,6 @@ from .item_sampler import ItemSampler
__all__
=
[
__all__
=
[
"DataLoader"
,
"DataLoader"
,
"Awaiter"
,
"Bufferer"
,
]
]
...
@@ -40,61 +39,6 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
...
@@ -40,61 +39,6 @@ def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
return
datapipe_graph
return
datapipe_graph
class
EndMarker
(
dp
.
iter
.
IterDataPipe
):
"""Used to mark the end of a datapipe and is a no-op."""
def
__init__
(
self
,
datapipe
):
self
.
datapipe
=
datapipe
def
__iter__
(
self
):
yield
from
self
.
datapipe
class
Bufferer
(
dp
.
iter
.
IterDataPipe
):
"""Buffers items before yielding them.
Parameters
----------
datapipe : DataPipe
The data pipeline.
buffer_size : int, optional
The size of the buffer which stores the fetched samples. If data coming
from datapipe has latency spikes, consider setting to a higher value.
Default is 1.
"""
def
__init__
(
self
,
datapipe
,
buffer_size
=
1
):
self
.
datapipe
=
datapipe
if
buffer_size
<=
0
:
raise
ValueError
(
"'buffer_size' is required to be a positive integer."
)
self
.
buffer
=
deque
(
maxlen
=
buffer_size
)
def
__iter__
(
self
):
for
data
in
self
.
datapipe
:
if
len
(
self
.
buffer
)
<
self
.
buffer
.
maxlen
:
self
.
buffer
.
append
(
data
)
else
:
return_data
=
self
.
buffer
.
popleft
()
self
.
buffer
.
append
(
data
)
yield
return_data
while
len
(
self
.
buffer
)
>
0
:
yield
self
.
buffer
.
popleft
()
class
Awaiter
(
dp
.
iter
.
IterDataPipe
):
"""Calls the wait function of all items."""
def
__init__
(
self
,
datapipe
):
self
.
datapipe
=
datapipe
def
__iter__
(
self
):
for
data
in
self
.
datapipe
:
data
.
wait
()
yield
data
class
MultiprocessingWrapper
(
dp
.
iter
.
IterDataPipe
):
class
MultiprocessingWrapper
(
dp
.
iter
.
IterDataPipe
):
"""Wraps a datapipe with multiprocessing.
"""Wraps a datapipe with multiprocessing.
...
@@ -156,6 +100,10 @@ class DataLoader(torch.utils.data.DataLoader):
...
@@ -156,6 +100,10 @@ class DataLoader(torch.utils.data.DataLoader):
If True, the data loader will overlap the UVA feature fetcher operations
If True, the data loader will overlap the UVA feature fetcher operations
with the rest of operations by using an alternative CUDA stream. Default
with the rest of operations by using an alternative CUDA stream. Default
is True.
is True.
overlap_graph_fetch : bool, optional
If True, the data loader will overlap the UVA graph fetching operations
with the rest of operations by using an alternative CUDA stream. Default
is False.
max_uva_threads : int, optional
max_uva_threads : int, optional
Limits the number of CUDA threads used for UVA copies so that the rest
Limits the number of CUDA threads used for UVA copies so that the rest
of the computations can run simultaneously with it. Setting it to a too
of the computations can run simultaneously with it. Setting it to a too
...
@@ -170,6 +118,7 @@ class DataLoader(torch.utils.data.DataLoader):
...
@@ -170,6 +118,7 @@ class DataLoader(torch.utils.data.DataLoader):
num_workers
=
0
,
num_workers
=
0
,
persistent_workers
=
True
,
persistent_workers
=
True
,
overlap_feature_fetch
=
True
,
overlap_feature_fetch
=
True
,
overlap_graph_fetch
=
False
,
max_uva_threads
=
6144
,
max_uva_threads
=
6144
,
):
):
# Multiprocessing requires two modifications to the datapipe:
# Multiprocessing requires two modifications to the datapipe:
...
@@ -179,7 +128,7 @@ class DataLoader(torch.utils.data.DataLoader):
...
@@ -179,7 +128,7 @@ class DataLoader(torch.utils.data.DataLoader):
# 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
# 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
# of the FeatureFetcher with a multiprocessing PyTorch DataLoader.
datapipe
=
EndMarker
(
datapipe
)
datapipe
=
datapipe
.
mark_end
(
)
datapipe_graph
=
dp_utils
.
traverse_dps
(
datapipe
)
datapipe_graph
=
dp_utils
.
traverse_dps
(
datapipe
)
# (1) Insert minibatch distribution.
# (1) Insert minibatch distribution.
...
@@ -223,7 +172,25 @@ class DataLoader(torch.utils.data.DataLoader):
...
@@ -223,7 +172,25 @@ class DataLoader(torch.utils.data.DataLoader):
datapipe_graph
=
dp_utils
.
replace_dp
(
datapipe_graph
=
dp_utils
.
replace_dp
(
datapipe_graph
,
datapipe_graph
,
feature_fetcher
,
feature_fetcher
,
Awaiter
(
Bufferer
(
feature_fetcher
,
buffer_size
=
1
)),
feature_fetcher
.
buffer
(
1
).
wait
(),
)
if
(
overlap_graph_fetch
and
num_workers
==
0
and
torch
.
cuda
.
is_available
()
):
torch
.
ops
.
graphbolt
.
set_max_uva_threads
(
max_uva_threads
)
samplers
=
dp_utils
.
find_dps
(
datapipe_graph
,
SamplePerLayer
,
)
executor
=
ThreadPoolExecutor
(
max_workers
=
1
)
for
sampler
in
samplers
:
datapipe_graph
=
dp_utils
.
replace_dp
(
datapipe_graph
,
sampler
,
sampler
.
fetch_and_sample
(
_get_uva_stream
(),
executor
,
1
),
)
)
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
badeaf19
"""Neighbor subgraph samplers for GraphBolt."""
"""Neighbor subgraph samplers for GraphBolt."""
from
concurrent.futures
import
ThreadPoolExecutor
from
functools
import
partial
from
functools
import
partial
import
torch
import
torch
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
..internal
import
compact_csc_format
,
unique_and_compact_csc_formats
from
..internal
import
compact_csc_format
,
unique_and_compact_csc_formats
from
..minibatch_transformer
import
MiniBatchTransformer
from
..minibatch_transformer
import
MiniBatchTransformer
from
..subgraph_sampler
import
SubgraphSampler
from
..subgraph_sampler
import
SubgraphSampler
from
.fused_csc_sampling_graph
import
fused_csc_sampling_graph
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
__all__
=
[
"NeighborSampler"
,
"LayerNeighborSampler"
]
__all__
=
[
"NeighborSampler"
,
"LayerNeighborSampler"
,
"SamplePerLayer"
,
"SamplePerLayerFromFetchedSubgraph"
,
"FetchInsubgraphData"
,
]
@
functional_datapipe
(
"fetch_insubgraph_data"
)
class
FetchInsubgraphData
(
Mapper
):
"""Fetches the insubgraph and wraps it in a FusedCSCSamplingGraph object. If
the provided sample_per_layer_obj has a valid prob_name, then it reads the
probabilies of all the fetched edges. Furthermore, if type_per_array tensor
exists in the underlying graph, then the types of all the fetched edges are
read as well."""
def
__init__
(
self
,
datapipe
,
sample_per_layer_obj
,
stream
=
None
,
executor
=
None
):
super
().
__init__
(
datapipe
,
self
.
_fetch_per_layer
)
self
.
graph
=
sample_per_layer_obj
.
sampler
.
__self__
self
.
prob_name
=
sample_per_layer_obj
.
prob_name
self
.
stream
=
stream
if
executor
is
None
:
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
1
)
else
:
self
.
executor
=
executor
def
_fetch_per_layer_impl
(
self
,
minibatch
,
stream
):
with
torch
.
cuda
.
stream
(
self
.
stream
):
index
=
minibatch
.
_seed_nodes
if
isinstance
(
index
,
dict
):
index
=
self
.
graph
.
_convert_to_homogeneous_nodes
(
index
)
index
,
original_positions
=
index
.
sort
()
if
(
original_positions
.
diff
()
==
1
).
all
().
item
():
# is_sorted
minibatch
.
_subgraph_seed_nodes
=
None
else
:
minibatch
.
_subgraph_seed_nodes
=
original_positions
index
.
record_stream
(
torch
.
cuda
.
current_stream
())
index_select_csc_with_indptr
=
partial
(
torch
.
ops
.
graphbolt
.
index_select_csc
,
self
.
graph
.
csc_indptr
)
def
record_stream
(
tensor
):
if
stream
is
not
None
and
tensor
.
is_cuda
:
tensor
.
record_stream
(
stream
)
indptr
,
indices
=
index_select_csc_with_indptr
(
self
.
graph
.
indices
,
index
,
None
)
record_stream
(
indptr
)
record_stream
(
indices
)
output_size
=
len
(
indices
)
if
self
.
graph
.
type_per_edge
is
not
None
:
_
,
type_per_edge
=
index_select_csc_with_indptr
(
self
.
graph
.
type_per_edge
,
index
,
output_size
)
record_stream
(
type_per_edge
)
else
:
type_per_edge
=
None
if
self
.
graph
.
edge_attributes
is
not
None
:
probs_or_mask
=
self
.
graph
.
edge_attributes
.
get
(
self
.
prob_name
,
None
)
if
probs_or_mask
is
not
None
:
_
,
probs_or_mask
=
index_select_csc_with_indptr
(
probs_or_mask
,
index
,
output_size
)
record_stream
(
probs_or_mask
)
else
:
probs_or_mask
=
None
if
self
.
graph
.
node_type_offset
is
not
None
:
node_type_offset
=
torch
.
searchsorted
(
index
,
self
.
graph
.
node_type_offset
)
else
:
node_type_offset
=
None
subgraph
=
fused_csc_sampling_graph
(
indptr
,
indices
,
node_type_offset
=
node_type_offset
,
type_per_edge
=
type_per_edge
,
node_type_to_id
=
self
.
graph
.
node_type_to_id
,
edge_type_to_id
=
self
.
graph
.
edge_type_to_id
,
)
if
self
.
prob_name
is
not
None
and
probs_or_mask
is
not
None
:
subgraph
.
edge_attributes
=
{
self
.
prob_name
:
probs_or_mask
}
minibatch
.
sampled_subgraphs
.
insert
(
0
,
subgraph
)
if
self
.
stream
is
not
None
:
minibatch
.
wait
=
torch
.
cuda
.
current_stream
().
record_event
().
wait
return
minibatch
def
_fetch_per_layer
(
self
,
minibatch
):
current_stream
=
None
if
self
.
stream
is
not
None
:
current_stream
=
torch
.
cuda
.
current_stream
()
self
.
stream
.
wait_stream
(
current_stream
)
return
self
.
executor
.
submit
(
self
.
_fetch_per_layer_impl
,
minibatch
,
current_stream
)
@
functional_datapipe
(
"sample_per_layer_from_fetched_subgraph"
)
class
SamplePerLayerFromFetchedSubgraph
(
MiniBatchTransformer
):
"""Sample neighbor edges from a graph for a single layer."""
def
__init__
(
self
,
datapipe
,
sample_per_layer_obj
):
super
().
__init__
(
datapipe
,
self
.
_sample_per_layer_from_fetched_subgraph
)
self
.
sampler_name
=
sample_per_layer_obj
.
sampler
.
__name__
self
.
fanout
=
sample_per_layer_obj
.
fanout
self
.
replace
=
sample_per_layer_obj
.
replace
self
.
prob_name
=
sample_per_layer_obj
.
prob_name
def
_sample_per_layer_from_fetched_subgraph
(
self
,
minibatch
):
subgraph
=
minibatch
.
sampled_subgraphs
[
0
]
sampled_subgraph
=
getattr
(
subgraph
,
self
.
sampler_name
)(
minibatch
.
_subgraph_seed_nodes
,
self
.
fanout
,
self
.
replace
,
self
.
prob_name
,
)
delattr
(
minibatch
,
"_subgraph_seed_nodes"
)
sampled_subgraph
.
original_column_node_ids
=
minibatch
.
_seed_nodes
minibatch
.
sampled_subgraphs
[
0
]
=
sampled_subgraph
return
minibatch
@
functional_datapipe
(
"sample_per_layer"
)
@
functional_datapipe
(
"sample_per_layer"
)
...
@@ -72,6 +206,19 @@ class CompactPerLayer(MiniBatchTransformer):
...
@@ -72,6 +206,19 @@ class CompactPerLayer(MiniBatchTransformer):
return
minibatch
return
minibatch
@
functional_datapipe
(
"fetch_and_sample"
)
class
FetcherAndSampler
(
MiniBatchTransformer
):
"""Overlapped graph sampling operation replacement."""
def
__init__
(
self
,
sampler
,
stream
,
executor
,
buffer_size
):
datapipe
=
sampler
.
datapipe
.
fetch_insubgraph_data
(
sampler
,
stream
,
executor
)
datapipe
=
datapipe
.
buffer
(
buffer_size
).
wait_future
().
wait
()
datapipe
=
datapipe
.
sample_per_layer_from_fetched_subgraph
(
sampler
)
super
().
__init__
(
datapipe
)
@
functional_datapipe
(
"sample_neighbor"
)
@
functional_datapipe
(
"sample_neighbor"
)
class
NeighborSampler
(
SubgraphSampler
):
class
NeighborSampler
(
SubgraphSampler
):
# pylint: disable=abstract-method
# pylint: disable=abstract-method
...
@@ -173,7 +320,8 @@ class NeighborSampler(SubgraphSampler):
...
@@ -173,7 +320,8 @@ class NeighborSampler(SubgraphSampler):
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
)
)
def
_prepare
(
self
,
node_type_to_id
,
minibatch
):
@
staticmethod
def
_prepare
(
node_type_to_id
,
minibatch
):
seeds
=
minibatch
.
_seed_nodes
seeds
=
minibatch
.
_seed_nodes
# Enrich seeds with all node types.
# Enrich seeds with all node types.
if
isinstance
(
seeds
,
dict
):
if
isinstance
(
seeds
,
dict
):
...
...
python/dgl/graphbolt/minibatch_transformer.py
View file @
badeaf19
...
@@ -29,10 +29,10 @@ class MiniBatchTransformer(Mapper):
...
@@ -29,10 +29,10 @@ class MiniBatchTransformer(Mapper):
def
__init__
(
def
__init__
(
self
,
self
,
datapipe
,
datapipe
,
transformer
,
transformer
=
None
,
):
):
super
().
__init__
(
datapipe
,
self
.
_transformer
)
super
().
__init__
(
datapipe
,
self
.
_transformer
)
self
.
transformer
=
transformer
self
.
transformer
=
transformer
or
self
.
_identity
def
_transformer
(
self
,
minibatch
):
def
_transformer
(
self
,
minibatch
):
minibatch
=
self
.
transformer
(
minibatch
)
minibatch
=
self
.
transformer
(
minibatch
)
...
@@ -40,3 +40,7 @@ class MiniBatchTransformer(Mapper):
...
@@ -40,3 +40,7 @@ class MiniBatchTransformer(Mapper):
minibatch
,
(
MiniBatch
,)
minibatch
,
(
MiniBatch
,)
),
"The transformer output should be an instance of MiniBatch"
),
"The transformer output should be an instance of MiniBatch"
return
minibatch
return
minibatch
@
staticmethod
def
_identity
(
minibatch
):
return
minibatch
python/dgl/graphbolt/subgraph_sampler.py
View file @
badeaf19
...
@@ -46,11 +46,7 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -46,11 +46,7 @@ class SubgraphSampler(MiniBatchTransformer):
datapipe
=
datapipe
.
transform
(
self
.
_preprocess
)
datapipe
=
datapipe
.
transform
(
self
.
_preprocess
)
datapipe
=
self
.
sampling_stages
(
datapipe
,
*
args
,
**
kwargs
)
datapipe
=
self
.
sampling_stages
(
datapipe
,
*
args
,
**
kwargs
)
datapipe
=
datapipe
.
transform
(
self
.
_postprocess
)
datapipe
=
datapipe
.
transform
(
self
.
_postprocess
)
super
().
__init__
(
datapipe
,
self
.
_identity
)
super
().
__init__
(
datapipe
)
@
staticmethod
def
_identity
(
minibatch
):
return
minibatch
@
staticmethod
@
staticmethod
def
_postprocess
(
minibatch
):
def
_postprocess
(
minibatch
):
...
...
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
0 → 100644
View file @
badeaf19
import
unittest
from
functools
import
partial
import
backend
as
F
import
dgl
import
dgl.graphbolt
as
gb
import
pytest
import
torch
def
get_hetero_graph
():
# COO graph:
# [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
# [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes
=
{
"n1"
:
0
,
"n2"
:
1
}
etypes
=
{
"n1:e1:n2"
:
0
,
"n2:e2:n1"
:
1
}
indptr
=
torch
.
LongTensor
([
0
,
2
,
4
,
6
,
8
,
10
])
indices
=
torch
.
LongTensor
([
2
,
4
,
2
,
3
,
0
,
1
,
1
,
0
,
0
,
1
])
type_per_edge
=
torch
.
LongTensor
([
1
,
1
,
1
,
1
,
0
,
0
,
0
,
0
,
0
,
0
])
edge_attributes
=
{
"weight"
:
torch
.
FloatTensor
(
[
2.5
,
0
,
8.4
,
0
,
0.4
,
1.2
,
2.5
,
0
,
8.4
,
0.5
]
),
"mask"
:
torch
.
BoolTensor
([
1
,
0
,
1
,
0
,
1
,
1
,
1
,
0
,
1
,
1
]),
}
node_type_offset
=
torch
.
LongTensor
([
0
,
2
,
5
])
return
gb
.
fused_csc_sampling_graph
(
indptr
,
indices
,
node_type_offset
=
node_type_offset
,
type_per_edge
=
type_per_edge
,
node_type_to_id
=
ntypes
,
edge_type_to_id
=
etypes
,
edge_attributes
=
edge_attributes
,
)
@
unittest
.
skipIf
(
F
.
_default_context_str
!=
"gpu"
,
reason
=
"Enabled only on GPU."
)
@
pytest
.
mark
.
parametrize
(
"hetero"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"prob_name"
,
[
None
,
"weight"
,
"mask"
])
def
test_NeighborSampler_GraphFetch
(
hetero
,
prob_name
):
items
=
torch
.
arange
(
3
)
names
=
"seed_nodes"
itemset
=
gb
.
ItemSet
(
items
,
names
=
names
)
graph
=
get_hetero_graph
().
to
(
F
.
ctx
())
if
hetero
:
itemset
=
gb
.
ItemSetDict
({
"n2"
:
itemset
})
else
:
graph
.
type_per_edge
=
None
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
).
copy_to
(
F
.
ctx
())
fanout
=
torch
.
LongTensor
([
2
])
datapipe
=
item_sampler
.
map
(
gb
.
SubgraphSampler
.
_preprocess
)
datapipe
=
datapipe
.
map
(
partial
(
gb
.
NeighborSampler
.
_prepare
,
graph
.
node_type_to_id
)
)
sample_per_layer
=
gb
.
SamplePerLayer
(
datapipe
,
graph
.
sample_neighbors
,
fanout
,
False
,
prob_name
)
compact_per_layer
=
sample_per_layer
.
compact_per_layer
(
True
)
gb
.
seed
(
123
)
expected_results
=
list
(
compact_per_layer
)
datapipe
=
gb
.
FetchInsubgraphData
(
datapipe
,
sample_per_layer
)
datapipe
=
datapipe
.
wait_future
()
datapipe
=
gb
.
SamplePerLayerFromFetchedSubgraph
(
datapipe
,
sample_per_layer
)
datapipe
=
datapipe
.
compact_per_layer
(
True
)
gb
.
seed
(
123
)
new_results
=
list
(
datapipe
)
assert
len
(
expected_results
)
==
len
(
new_results
)
for
a
,
b
in
zip
(
expected_results
,
new_results
):
assert
repr
(
a
)
==
repr
(
b
)
tests/python/pytorch/graphbolt/test_dataloader.py
View file @
badeaf19
...
@@ -47,11 +47,21 @@ def test_DataLoader():
...
@@ -47,11 +47,21 @@ def test_DataLoader():
F
.
_default_context_str
!=
"gpu"
,
F
.
_default_context_str
!=
"gpu"
,
reason
=
"This test requires the GPU."
,
reason
=
"This test requires the GPU."
,
)
)
@
pytest
.
mark
.
parametrize
(
"overlap_feature_fetch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"sampler_name"
,
[
"NeighborSampler"
,
"LayerNeighborSampler"
]
)
@
pytest
.
mark
.
parametrize
(
"enable_feature_fetch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"enable_feature_fetch"
,
[
True
,
False
])
def
test_gpu_sampling_DataLoader
(
overlap_feature_fetch
,
enable_feature_fetch
):
@
pytest
.
mark
.
parametrize
(
"overlap_feature_fetch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"overlap_graph_fetch"
,
[
True
,
False
])
def
test_gpu_sampling_DataLoader
(
sampler_name
,
enable_feature_fetch
,
overlap_feature_fetch
,
overlap_graph_fetch
,
):
N
=
40
N
=
40
B
=
4
B
=
4
num_layers
=
2
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
),
names
=
"seed_nodes"
)
itemset
=
dgl
.
graphbolt
.
ItemSet
(
torch
.
arange
(
N
),
names
=
"seed_nodes"
)
graph
=
gb_test_utils
.
rand_csc_graph
(
200
,
0.15
,
bidirection_edge
=
True
).
to
(
graph
=
gb_test_utils
.
rand_csc_graph
(
200
,
0.15
,
bidirection_edge
=
True
).
to
(
F
.
ctx
()
F
.
ctx
()
...
@@ -68,10 +78,10 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
...
@@ -68,10 +78,10 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
datapipe
=
dgl
.
graphbolt
.
ItemSampler
(
itemset
,
batch_size
=
B
)
datapipe
=
dgl
.
graphbolt
.
ItemSampler
(
itemset
,
batch_size
=
B
)
datapipe
=
datapipe
.
copy_to
(
F
.
ctx
(),
extra_attrs
=
[
"seed_nodes"
])
datapipe
=
datapipe
.
copy_to
(
F
.
ctx
(),
extra_attrs
=
[
"seed_nodes"
])
datapipe
=
dgl
.
graphbolt
.
NeighborSampler
(
datapipe
=
getattr
(
dgl
.
graphbolt
,
sampler_name
)
(
datapipe
,
datapipe
,
graph
,
graph
,
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
2
)],
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layers
)],
)
)
if
enable_feature_fetch
:
if
enable_feature_fetch
:
datapipe
=
dgl
.
graphbolt
.
FeatureFetcher
(
datapipe
=
dgl
.
graphbolt
.
FeatureFetcher
(
...
@@ -81,14 +91,18 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
...
@@ -81,14 +91,18 @@ def test_gpu_sampling_DataLoader(overlap_feature_fetch, enable_feature_fetch):
)
)
dataloader
=
dgl
.
graphbolt
.
DataLoader
(
dataloader
=
dgl
.
graphbolt
.
DataLoader
(
datapipe
,
overlap_feature_fetch
=
overlap_feature_fetch
datapipe
,
overlap_feature_fetch
=
overlap_feature_fetch
,
overlap_graph_fetch
=
overlap_graph_fetch
,
)
)
bufferer_awaiter_cnt
=
int
(
enable_feature_fetch
and
overlap_feature_fetch
)
bufferer_awaiter_cnt
=
int
(
enable_feature_fetch
and
overlap_feature_fetch
)
if
overlap_graph_fetch
:
bufferer_awaiter_cnt
+=
num_layers
datapipe
=
dataloader
.
dataset
datapipe
=
dataloader
.
dataset
datapipe_graph
=
dp_utils
.
traverse_dps
(
datapipe
)
datapipe_graph
=
dp_utils
.
traverse_dps
(
datapipe
)
awaiters
=
dp_utils
.
find_dps
(
awaiters
=
dp_utils
.
find_dps
(
datapipe_graph
,
datapipe_graph
,
dgl
.
graphbolt
.
Aw
aiter
,
dgl
.
graphbolt
.
W
aiter
,
)
)
assert
len
(
awaiters
)
==
bufferer_awaiter_cnt
assert
len
(
awaiters
)
==
bufferer_awaiter_cnt
bufferers
=
dp_utils
.
find_dps
(
bufferers
=
dp_utils
.
find_dps
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment