Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
93990a90
"tests/python/pytorch/geometry/test_geometry.py" did not exist on "1425150459963514047ac3a7ce84574eaf463a2b"
Unverified
Commit
93990a90
authored
Mar 13, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Mar 13, 2024
Browse files
[GraphBolt] Refactor `NeighborSamplerImpl` (#7207)
parent
f0c7efa9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
50 deletions
+73
-50
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+73
-50
No files found.
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
93990a90
...
@@ -230,8 +230,72 @@ class FetcherAndSampler(MiniBatchTransformer):
...
@@ -230,8 +230,72 @@ class FetcherAndSampler(MiniBatchTransformer):
super
().
__init__
(
datapipe
)
super
().
__init__
(
datapipe
)
class
NeighborSamplerImpl
(
SubgraphSampler
):
# pylint: disable=abstract-method
"""Base class for NeighborSamplers."""
# pylint: disable=useless-super-delegation
def
__init__
(
self
,
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
,
):
super
().
__init__
(
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
)
@
staticmethod
def
_prepare
(
node_type_to_id
,
minibatch
):
seeds
=
minibatch
.
_seed_nodes
# Enrich seeds with all node types.
if
isinstance
(
seeds
,
dict
):
ntypes
=
list
(
node_type_to_id
.
keys
())
# Loop over different seeds to extract the device they are on.
device
=
None
dtype
=
None
for
_
,
seed
in
seeds
.
items
():
device
=
seed
.
device
dtype
=
seed
.
dtype
break
default_tensor
=
torch
.
tensor
([],
dtype
=
dtype
,
device
=
device
)
seeds
=
{
ntype
:
seeds
.
get
(
ntype
,
default_tensor
)
for
ntype
in
ntypes
}
minibatch
.
_seed_nodes
=
seeds
minibatch
.
sampled_subgraphs
=
[]
return
minibatch
@
staticmethod
def
_set_input_nodes
(
minibatch
):
minibatch
.
input_nodes
=
minibatch
.
_seed_nodes
return
minibatch
# pylint: disable=arguments-differ
def
sampling_stages
(
self
,
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
):
datapipe
=
datapipe
.
transform
(
partial
(
self
.
_prepare
,
graph
.
node_type_to_id
)
)
for
fanout
in
reversed
(
fanouts
):
# Convert fanout to tensor.
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
datapipe
=
datapipe
.
sample_per_layer
(
sampler
,
fanout
,
replace
,
prob_name
)
datapipe
=
datapipe
.
compact_per_layer
(
deduplicate
)
return
datapipe
.
transform
(
self
.
_set_input_nodes
)
@
functional_datapipe
(
"sample_neighbor"
)
@
functional_datapipe
(
"sample_neighbor"
)
class
NeighborSampler
(
Subgraph
Sampler
):
class
NeighborSampler
(
Neighbor
Sampler
Impl
):
# pylint: disable=abstract-method
# pylint: disable=abstract-method
"""Sample neighbor edges from a graph and return a subgraph.
"""Sample neighbor edges from a graph and return a subgraph.
...
@@ -323,61 +387,20 @@ class NeighborSampler(SubgraphSampler):
...
@@ -323,61 +387,20 @@ class NeighborSampler(SubgraphSampler):
replace
=
False
,
replace
=
False
,
prob_name
=
None
,
prob_name
=
None
,
deduplicate
=
True
,
deduplicate
=
True
,
sampler
=
None
,
):
):
if
sampler
is
None
:
sampler
=
graph
.
sample_neighbors
super
().
__init__
(
super
().
__init__
(
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
datapipe
,
)
graph
,
fanouts
,
@
staticmethod
replace
,
def
_prepare
(
node_type_to_id
,
minibatch
):
prob_name
,
seeds
=
minibatch
.
_seed_nodes
deduplicate
,
# Enrich seeds with all node types.
graph
.
sample_neighbors
,
if
isinstance
(
seeds
,
dict
):
ntypes
=
list
(
node_type_to_id
.
keys
())
# Loop over different seeds to extract the device they are on.
device
=
None
dtype
=
None
for
_
,
seed
in
seeds
.
items
():
device
=
seed
.
device
dtype
=
seed
.
dtype
break
default_tensor
=
torch
.
tensor
([],
dtype
=
dtype
,
device
=
device
)
seeds
=
{
ntype
:
seeds
.
get
(
ntype
,
default_tensor
)
for
ntype
in
ntypes
}
minibatch
.
_seed_nodes
=
seeds
minibatch
.
sampled_subgraphs
=
[]
return
minibatch
@
staticmethod
def
_set_input_nodes
(
minibatch
):
minibatch
.
input_nodes
=
minibatch
.
_seed_nodes
return
minibatch
# pylint: disable=arguments-differ
def
sampling_stages
(
self
,
datapipe
,
graph
,
fanouts
,
replace
,
prob_name
,
deduplicate
,
sampler
):
datapipe
=
datapipe
.
transform
(
partial
(
self
.
_prepare
,
graph
.
node_type_to_id
)
)
for
fanout
in
reversed
(
fanouts
):
# Convert fanout to tensor.
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
datapipe
=
datapipe
.
sample_per_layer
(
sampler
,
fanout
,
replace
,
prob_name
)
)
datapipe
=
datapipe
.
compact_per_layer
(
deduplicate
)
return
datapipe
.
transform
(
self
.
_set_input_nodes
)
@
functional_datapipe
(
"sample_layer_neighbor"
)
@
functional_datapipe
(
"sample_layer_neighbor"
)
class
LayerNeighborSampler
(
NeighborSampler
):
class
LayerNeighborSampler
(
NeighborSampler
Impl
):
# pylint: disable=abstract-method
# pylint: disable=abstract-method
"""Sample layer neighbor edges from a graph and return a subgraph.
"""Sample layer neighbor edges from a graph and return a subgraph.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment