Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
337b5ea7
"src/array/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "ed2e5409113ea65a6c395a2d2d70c06b59b7e80a"
Unverified
Commit
337b5ea7
authored
Sep 12, 2023
by
Rhett Ying
Committed by
GitHub
Sep 12, 2023
Browse files
[GraphBolt] enable fanouts to be a list of int (#6309)
parent
89bed21a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
2 deletions
+32
-2
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+8
-2
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+24
-0
No files found.
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
337b5ea7
"""Neighbor subgraph samplers for GraphBolt."""
"""Neighbor subgraph samplers for GraphBolt."""
import
torch
from
torch.utils.data
import
functional_datapipe
from
torch.utils.data
import
functional_datapipe
from
..subgraph_sampler
import
SubgraphSampler
from
..subgraph_sampler
import
SubgraphSampler
...
@@ -37,7 +38,7 @@ class NeighborSampler(SubgraphSampler):
...
@@ -37,7 +38,7 @@ class NeighborSampler(SubgraphSampler):
The datapipe.
The datapipe.
graph : CSCSamplingGraph
graph : CSCSamplingGraph
The graph on which to perform subgraph sampling.
The graph on which to perform subgraph sampling.
fanouts: list[torch.Tensor]
fanouts: list[torch.Tensor]
or list[int]
The number of edges to be sampled for each node with or without
The number of edges to be sampled for each node with or without
considering edge types. The length of this parameter implicitly
considering edge types. The length of this parameter implicitly
signifies the layer of sampling being conducted.
signifies the layer of sampling being conducted.
...
@@ -81,7 +82,12 @@ class NeighborSampler(SubgraphSampler):
...
@@ -81,7 +82,12 @@ class NeighborSampler(SubgraphSampler):
3
3
"""
"""
super
().
__init__
(
datapipe
)
super
().
__init__
(
datapipe
)
self
.
fanouts
=
fanouts
# Convert fanouts to a list of tensors.
self
.
fanouts
=
[]
for
fanout
in
fanouts
:
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
self
.
fanouts
.
append
(
fanout
)
self
.
replace
=
replace
self
.
replace
=
replace
self
.
prob_name
=
prob_name
self
.
prob_name
=
prob_name
self
.
sampler
=
graph
.
sample_neighbors
self
.
sampler
=
graph
.
sample_neighbors
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
337b5ea7
...
@@ -41,6 +41,30 @@ def test_NeighborSampler_invoke(labor):
...
@@ -41,6 +41,30 @@ def test_NeighborSampler_invoke(labor):
assert
len
(
list
(
datapipe
))
==
5
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_NeighborSampler_fanouts
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
10
),
names
=
"seed_nodes"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
# `fanouts` is a list of tensors.
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
if
labor
:
datapipe
=
item_sampler
.
sample_layer_neighbor
(
graph
,
fanouts
)
else
:
datapipe
=
item_sampler
.
sample_neighbor
(
graph
,
fanouts
)
assert
len
(
list
(
datapipe
))
==
5
# `fanouts` is a list of integers.
fanouts
=
[
2
for
_
in
range
(
num_layer
)]
if
labor
:
datapipe
=
item_sampler
.
sample_layer_neighbor
(
graph
,
fanouts
)
else
:
datapipe
=
item_sampler
.
sample_neighbor
(
graph
,
fanouts
)
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Node
(
labor
):
def
test_SubgraphSampler_Node
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment