Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
50eb1014
Unverified
Commit
50eb1014
authored
Feb 01, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Feb 01, 2024
Browse files
[GraphBolt] Refactor NeighborSampler and expose fine-grained datapipes. (#6983)
parent
e602ab1b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
146 additions
and
57 deletions
+146
-57
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+98
-48
python/dgl/graphbolt/subgraph_sampler.py
python/dgl/graphbolt/subgraph_sampler.py
+48
-9
No files found.
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
50eb1014
"""Neighbor subgraph samplers for GraphBolt."""
"""Neighbor subgraph samplers for GraphBolt."""
from
functools
import
partial
import
torch
import
torch
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
..internal
import
compact_csc_format
,
unique_and_compact_csc_formats
from
..internal
import
compact_csc_format
,
unique_and_compact_csc_formats
from
..minibatch_transformer
import
MiniBatchTransformer
from
..subgraph_sampler
import
SubgraphSampler
from
..subgraph_sampler
import
SubgraphSampler
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
...
@@ -12,8 +15,66 @@ from .sampled_subgraph_impl import SampledSubgraphImpl
...
@@ -12,8 +15,66 @@ from .sampled_subgraph_impl import SampledSubgraphImpl
__all__
=
[
"NeighborSampler"
,
"LayerNeighborSampler"
]
__all__
=
[
"NeighborSampler"
,
"LayerNeighborSampler"
]
@
functional_datapipe
(
"sample_per_layer"
)
class
SamplePerLayer
(
MiniBatchTransformer
):
"""Sample neighbor edges from a graph for a single layer."""
def
__init__
(
self
,
datapipe
,
sampler
,
fanout
,
replace
,
prob_name
):
super
().
__init__
(
datapipe
,
self
.
_sample_per_layer
)
self
.
sampler
=
sampler
self
.
fanout
=
fanout
self
.
replace
=
replace
self
.
prob_name
=
prob_name
def
_sample_per_layer
(
self
,
minibatch
):
subgraph
=
self
.
sampler
(
minibatch
.
_seed_nodes
,
self
.
fanout
,
self
.
replace
,
self
.
prob_name
)
minibatch
.
sampled_subgraphs
.
insert
(
0
,
subgraph
)
return
minibatch
@
functional_datapipe
(
"compact_per_layer"
)
class
CompactPerLayer
(
MiniBatchTransformer
):
"""Compact the sampled edges for a single layer."""
def
__init__
(
self
,
datapipe
,
deduplicate
):
super
().
__init__
(
datapipe
,
self
.
_compact_per_layer
)
self
.
deduplicate
=
deduplicate
def
_compact_per_layer
(
self
,
minibatch
):
subgraph
=
minibatch
.
sampled_subgraphs
[
0
]
seeds
=
minibatch
.
_seed_nodes
if
self
.
deduplicate
:
(
original_row_node_ids
,
compacted_csc_format
,
)
=
unique_and_compact_csc_formats
(
subgraph
.
sampled_csc
,
seeds
)
subgraph
=
SampledSubgraphImpl
(
sampled_csc
=
compacted_csc_format
,
original_column_node_ids
=
seeds
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
subgraph
.
original_edge_ids
,
)
else
:
(
original_row_node_ids
,
compacted_csc_format
,
)
=
compact_csc_format
(
subgraph
.
sampled_csc
,
seeds
)
subgraph
=
SampledSubgraphImpl
(
sampled_csc
=
compacted_csc_format
,
original_column_node_ids
=
seeds
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
subgraph
.
original_edge_ids
,
)
minibatch
.
_seed_nodes
=
original_row_node_ids
minibatch
.
sampled_subgraphs
[
0
]
=
subgraph
return
minibatch
@
functional_datapipe
(
"sample_neighbor"
)
@
functional_datapipe
(
"sample_neighbor"
)
class
NeighborSampler
(
SubgraphSampler
):
class
NeighborSampler
(
SubgraphSampler
):
# pylint: disable=abstract-method
"""Sample neighbor edges from a graph and return a subgraph.
"""Sample neighbor edges from a graph and return a subgraph.
Functional name: :obj:`sample_neighbor`.
Functional name: :obj:`sample_neighbor`.
...
@@ -95,6 +156,7 @@ class NeighborSampler(SubgraphSampler):
...
@@ -95,6 +156,7 @@ class NeighborSampler(SubgraphSampler):
)]
)]
"""
"""
# pylint: disable=useless-super-delegation
def
__init__
(
def
__init__
(
self
,
self
,
datapipe
,
datapipe
,
...
@@ -103,26 +165,19 @@ class NeighborSampler(SubgraphSampler):
...
@@ -103,26 +165,19 @@ class NeighborSampler(SubgraphSampler):
replace
=
False
,
replace
=
False
,
prob_name
=
None
,
prob_name
=
None
,
deduplicate
=
True
,
deduplicate
=
True
,
sampler
=
None
,
):
):
super
().
__init__
(
datapipe
)
if
sampler
is
None
:
self
.
graph
=
graph
sampler
=
graph
.
sample_neighbors
# Convert fanouts to a list of tensors.
super
().
__init__
(
self
.
fanouts
=
[]
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
for
fanout
in
fanouts
:
)
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
self
.
fanouts
.
insert
(
0
,
fanout
)
self
.
replace
=
replace
self
.
prob_name
=
prob_name
self
.
deduplicate
=
deduplicate
self
.
sampler
=
graph
.
sample_neighbors
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
def
_prepare
(
self
,
node_type_to_id
,
minibatch
):
subgraphs
=
[]
seeds
=
minibatch
.
_seed_nodes
num_layers
=
len
(
self
.
fanouts
)
# Enrich seeds with all node types.
# Enrich seeds with all node types.
if
isinstance
(
seeds
,
dict
):
if
isinstance
(
seeds
,
dict
):
ntypes
=
list
(
self
.
graph
.
node_type_to_id
.
keys
())
ntypes
=
list
(
node_type_to_id
.
keys
())
# Loop over different seeds to extract the device they are on.
# Loop over different seeds to extract the device they are on.
device
=
None
device
=
None
dtype
=
None
dtype
=
None
...
@@ -134,42 +189,37 @@ class NeighborSampler(SubgraphSampler):
...
@@ -134,42 +189,37 @@ class NeighborSampler(SubgraphSampler):
seeds
=
{
seeds
=
{
ntype
:
seeds
.
get
(
ntype
,
default_tensor
)
for
ntype
in
ntypes
ntype
:
seeds
.
get
(
ntype
,
default_tensor
)
for
ntype
in
ntypes
}
}
for
hop
in
range
(
num_layers
):
minibatch
.
_seed_nodes
=
seeds
subgraph
=
self
.
sampler
(
minibatch
.
sampled_subgraphs
=
[]
seeds
,
return
minibatch
self
.
fanouts
[
hop
],
self
.
replace
,
@
staticmethod
self
.
prob_name
,
def
_set_input_nodes
(
minibatch
):
minibatch
.
input_nodes
=
minibatch
.
_seed_nodes
return
minibatch
# pylint: disable=arguments-differ
def
sampling_stages
(
self
,
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
):
datapipe
=
datapipe
.
transform
(
partial
(
self
.
_prepare
,
graph
.
node_type_to_id
)
)
for
fanout
in
reversed
(
fanouts
):
# Convert fanout to tensor.
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
datapipe
=
datapipe
.
sample_per_layer
(
sampler
,
fanout
,
replace
,
prob_name
)
)
if
self
.
deduplicate
:
datapipe
=
datapipe
.
compact_per_layer
(
deduplicate
)
(
original_row_node_ids
,
return
datapipe
.
transform
(
self
.
_set_input_nodes
)
compacted_csc_format
,
)
=
unique_and_compact_csc_formats
(
subgraph
.
sampled_csc
,
seeds
)
subgraph
=
SampledSubgraphImpl
(
sampled_csc
=
compacted_csc_format
,
original_column_node_ids
=
seeds
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
subgraph
.
original_edge_ids
,
)
else
:
(
original_row_node_ids
,
compacted_csc_format
,
)
=
compact_csc_format
(
subgraph
.
sampled_csc
,
seeds
)
subgraph
=
SampledSubgraphImpl
(
sampled_csc
=
compacted_csc_format
,
original_column_node_ids
=
seeds
,
original_row_node_ids
=
original_row_node_ids
,
original_edge_ids
=
subgraph
.
original_edge_ids
,
)
subgraphs
.
insert
(
0
,
subgraph
)
seeds
=
original_row_node_ids
return
seeds
,
subgraphs
@
functional_datapipe
(
"sample_layer_neighbor"
)
@
functional_datapipe
(
"sample_layer_neighbor"
)
class
LayerNeighborSampler
(
NeighborSampler
):
class
LayerNeighborSampler
(
NeighborSampler
):
# pylint: disable=abstract-method
"""Sample layer neighbor edges from a graph and return a subgraph.
"""Sample layer neighbor edges from a graph and return a subgraph.
Functional name: :obj:`sample_layer_neighbor`.
Functional name: :obj:`sample_layer_neighbor`.
...
@@ -280,5 +330,5 @@ class LayerNeighborSampler(NeighborSampler):
...
@@ -280,5 +330,5 @@ class LayerNeighborSampler(NeighborSampler):
replace
,
replace
,
prob_name
,
prob_name
,
deduplicate
,
deduplicate
,
graph
.
sample_layer_neighbors
,
)
)
self
.
sampler
=
graph
.
sample_layer_neighbors
python/dgl/graphbolt/subgraph_sampler.py
View file @
50eb1014
...
@@ -22,21 +22,44 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -22,21 +22,44 @@ class SubgraphSampler(MiniBatchTransformer):
Functional name: :obj:`sample_subgraph`.
Functional name: :obj:`sample_subgraph`.
This class is the base class of all subgraph samplers. Any subclass of
This class is the base class of all subgraph samplers. Any subclass of
SubgraphSampler should implement the :meth:`sample_subgraphs` method.
SubgraphSampler should implement either the :meth:`sample_subgraphs` method
or the :meth:`sampling_stages` method to define the fine-grained sampling
stages to take advantage of optimizations provided by the GraphBolt
DataLoader.
Parameters
Parameters
----------
----------
datapipe : DataPipe
datapipe : DataPipe
The datapipe.
The datapipe.
args : Non-Keyword Arguments
Arguments to be passed into sampling_stages.
kwargs : Keyword Arguments
Arguments to be passed into sampling_stages.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
datapipe
,
datapipe
,
*
args
,
**
kwargs
,
):
):
super
().
__init__
(
datapipe
,
self
.
_sample
)
datapipe
=
datapipe
.
transform
(
self
.
_preprocess
)
datapipe
=
self
.
sampling_stages
(
datapipe
,
*
args
,
**
kwargs
)
datapipe
=
datapipe
.
transform
(
self
.
_postprocess
)
super
().
__init__
(
datapipe
,
self
.
_identity
)
def
_sample
(
self
,
minibatch
):
@
staticmethod
def
_identity
(
minibatch
):
return
minibatch
@
staticmethod
def
_postprocess
(
minibatch
):
delattr
(
minibatch
,
"_seed_nodes"
)
delattr
(
minibatch
,
"_seeds_timestamp"
)
return
minibatch
@
staticmethod
def
_preprocess
(
minibatch
):
if
minibatch
.
node_pairs
is
not
None
:
if
minibatch
.
node_pairs
is
not
None
:
(
(
seeds
,
seeds
,
...
@@ -44,7 +67,7 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -44,7 +67,7 @@ class SubgraphSampler(MiniBatchTransformer):
minibatch
.
compacted_node_pairs
,
minibatch
.
compacted_node_pairs
,
minibatch
.
compacted_negative_srcs
,
minibatch
.
compacted_negative_srcs
,
minibatch
.
compacted_negative_dsts
,
minibatch
.
compacted_negative_dsts
,
)
=
self
.
_node_pairs_preprocess
(
minibatch
)
)
=
SubgraphSampler
.
_node_pairs_preprocess
(
minibatch
)
elif
minibatch
.
seed_nodes
is
not
None
:
elif
minibatch
.
seed_nodes
is
not
None
:
seeds
=
minibatch
.
seed_nodes
seeds
=
minibatch
.
seed_nodes
seeds_timestamp
=
(
seeds_timestamp
=
(
...
@@ -55,13 +78,12 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -55,13 +78,12 @@ class SubgraphSampler(MiniBatchTransformer):
f
"Invalid minibatch
{
minibatch
}
: Either `node_pairs` or "
f
"Invalid minibatch
{
minibatch
}
: Either `node_pairs` or "
"`seed_nodes` should have a value."
"`seed_nodes` should have a value."
)
)
(
minibatch
.
_seed_nodes
=
seeds
minibatch
.
input_nodes
,
minibatch
.
_seeds_timestamp
=
seeds_timestamp
minibatch
.
sampled_subgraphs
,
)
=
self
.
sample_subgraphs
(
seeds
,
seeds_timestamp
)
return
minibatch
return
minibatch
def
_node_pairs_preprocess
(
self
,
minibatch
):
@
staticmethod
def
_node_pairs_preprocess
(
minibatch
):
use_timestamp
=
hasattr
(
minibatch
,
"timestamp"
)
use_timestamp
=
hasattr
(
minibatch
,
"timestamp"
)
node_pairs
=
minibatch
.
node_pairs
node_pairs
=
minibatch
.
node_pairs
neg_src
,
neg_dst
=
minibatch
.
negative_srcs
,
minibatch
.
negative_dsts
neg_src
,
neg_dst
=
minibatch
.
negative_srcs
,
minibatch
.
negative_dsts
...
@@ -191,6 +213,23 @@ class SubgraphSampler(MiniBatchTransformer):
...
@@ -191,6 +213,23 @@ class SubgraphSampler(MiniBatchTransformer):
compacted_negative_dsts
if
has_neg_dst
else
None
,
compacted_negative_dsts
if
has_neg_dst
else
None
,
)
)
def
_sample
(
self
,
minibatch
):
(
minibatch
.
input_nodes
,
minibatch
.
sampled_subgraphs
,
)
=
self
.
sample_subgraphs
(
minibatch
.
_seed_nodes
,
minibatch
.
_seeds_timestamp
)
return
minibatch
def
sampling_stages
(
self
,
datapipe
):
"""The sampling stages are defined here by chaining to the datapipe. The
default implementation expects :meth:`sample_subgraphs` to be
implemented. To define fine-grained stages, this method should be
overridden.
"""
return
datapipe
.
transform
(
self
.
_sample
)
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
def
sample_subgraphs
(
self
,
seeds
,
seeds_timestamp
):
"""Sample subgraphs from the given seeds, possibly with temporal constraints.
"""Sample subgraphs from the given seeds, possibly with temporal constraints.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment