Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
06074d73
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4ab94cb7aa01828960f48f55fa477110f4d191fb"
Unverified
Commit
06074d73
authored
Sep 19, 2023
by
Rhett Ying
Committed by
GitHub
Sep 19, 2023
Browse files
[GraphBolt] enrich node types for input/output nodes of sampled subgraph (#6348)
parent
adf49937
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
3 deletions
+30
-3
python/dgl/graphbolt/impl/csc_sampling_graph.py
python/dgl/graphbolt/impl/csc_sampling_graph.py
+0
-2
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+8
-0
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
.../python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
+5
-1
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+17
-0
No files found.
python/dgl/graphbolt/impl/csc_sampling_graph.py
View file @
06074d73
...
@@ -242,8 +242,6 @@ class CSCSamplingGraph:
...
@@ -242,8 +242,6 @@ class CSCSamplingGraph:
src_ntype_id
=
self
.
metadata
.
node_type_to_id
[
src_ntype
]
src_ntype_id
=
self
.
metadata
.
node_type_to_id
[
src_ntype
]
dst_ntype_id
=
self
.
metadata
.
node_type_to_id
[
dst_ntype
]
dst_ntype_id
=
self
.
metadata
.
node_type_to_id
[
dst_ntype
]
mask
=
type_per_edge
==
etype_id
mask
=
type_per_edge
==
etype_id
if
mask
.
count_nonzero
()
==
0
:
continue
hetero_row
=
row
[
mask
]
-
self
.
node_type_offset
[
src_ntype_id
]
hetero_row
=
row
[
mask
]
-
self
.
node_type_offset
[
src_ntype_id
]
hetero_column
=
(
hetero_column
=
(
column
[
mask
]
-
self
.
node_type_offset
[
dst_ntype_id
]
column
[
mask
]
-
self
.
node_type_offset
[
dst_ntype_id
]
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
06074d73
...
@@ -78,6 +78,7 @@ class NeighborSampler(SubgraphSampler):
...
@@ -78,6 +78,7 @@ class NeighborSampler(SubgraphSampler):
3
3
"""
"""
super
().
__init__
(
datapipe
)
super
().
__init__
(
datapipe
)
self
.
graph
=
graph
# Convert fanouts to a list of tensors.
# Convert fanouts to a list of tensors.
self
.
fanouts
=
[]
self
.
fanouts
=
[]
for
fanout
in
fanouts
:
for
fanout
in
fanouts
:
...
@@ -91,6 +92,13 @@ class NeighborSampler(SubgraphSampler):
...
@@ -91,6 +92,13 @@ class NeighborSampler(SubgraphSampler):
def
_sample_subgraphs
(
self
,
seeds
):
def
_sample_subgraphs
(
self
,
seeds
):
subgraphs
=
[]
subgraphs
=
[]
num_layers
=
len
(
self
.
fanouts
)
num_layers
=
len
(
self
.
fanouts
)
# Enrich seeds with all node types.
if
isinstance
(
seeds
,
dict
):
ntypes
=
list
(
self
.
graph
.
metadata
.
node_type_to_id
.
keys
())
seeds
=
{
ntype
:
seeds
.
get
(
ntype
,
torch
.
LongTensor
([]))
for
ntype
in
ntypes
}
for
hop
in
range
(
num_layers
):
for
hop
in
range
(
num_layers
):
subgraph
=
self
.
sampler
(
subgraph
=
self
.
sampler
(
seeds
,
seeds
,
...
...
tests/python/pytorch/graphbolt/impl/test_csc_sampling_graph.py
View file @
06074d73
...
@@ -584,8 +584,12 @@ def test_sample_neighbors_hetero(labor):
...
@@ -584,8 +584,12 @@ def test_sample_neighbors_hetero(labor):
torch
.
LongTensor
([
0
,
2
]),
torch
.
LongTensor
([
0
,
2
]),
torch
.
LongTensor
([
0
,
0
]),
torch
.
LongTensor
([
0
,
0
]),
),
),
"n1:e1:n2"
:
(
torch
.
LongTensor
([]),
torch
.
LongTensor
([]),
),
}
}
assert
len
(
subgraph
.
node_pairs
)
==
1
assert
len
(
subgraph
.
node_pairs
)
==
2
for
etype
,
pairs
in
expected_node_pairs
.
items
():
for
etype
,
pairs
in
expected_node_pairs
.
items
():
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
0
],
pairs
[
0
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
0
],
pairs
[
0
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
1
],
pairs
[
1
])
assert
torch
.
equal
(
subgraph
.
node_pairs
[
etype
][
1
],
pairs
[
1
])
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
06074d73
...
@@ -129,6 +129,23 @@ def get_hetero_graph():
...
@@ -129,6 +129,23 @@ def get_hetero_graph():
)
)
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Node_Hetero
(
labor
):
graph
=
get_hetero_graph
()
itemset
=
gb
.
ItemSetDict
(
{
"n2"
:
gb
.
ItemSet
(
torch
.
arange
(
3
),
names
=
"seed_nodes"
)}
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
sampler_dp
=
Sampler
(
item_sampler
,
graph
,
fanouts
)
assert
len
(
list
(
sampler_dp
))
==
2
for
minibatch
in
sampler_dp
:
blocks
=
minibatch
.
to_dgl_blocks
()
assert
len
(
blocks
)
==
num_layer
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link_Hetero
(
labor
):
def
test_SubgraphSampler_Link_Hetero
(
labor
):
graph
=
get_hetero_graph
()
graph
=
get_hetero_graph
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment