Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
498188dd
Unverified
Commit
498188dd
authored
Dec 22, 2023
by
Rhett Ying
Committed by
GitHub
Dec 22, 2023
Browse files
[GraphBolt] exclude nothing if edge is not found in node pairs (#6807)
parent
b04c9797
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
17 deletions
+36
-17
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+16
-3
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
+20
-14
No files found.
python/dgl/graphbolt/sampled_subgraph.py
View file @
498188dd
"""Graphbolt sampled subgraph."""
"""Graphbolt sampled subgraph."""
# pylint: disable= invalid-name
# pylint: disable= invalid-name
from
typing
import
Dict
,
Tuple
,
Union
from
typing
import
Dict
,
Tuple
,
Union
...
@@ -181,6 +182,10 @@ class SampledSubgraph:
...
@@ -181,6 +182,10 @@ class SampledSubgraph:
index
=
{}
index
=
{}
is_cscformat
=
0
is_cscformat
=
0
for
etype
,
pair
in
self
.
node_pairs
.
items
():
for
etype
,
pair
in
self
.
node_pairs
.
items
():
if
etype
not
in
edges
:
# No edges need to be excluded.
index
[
etype
]
=
None
continue
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
original_row_node_ids
=
(
original_row_node_ids
=
(
None
None
...
@@ -207,7 +212,7 @@ class SampledSubgraph:
...
@@ -207,7 +212,7 @@ class SampledSubgraph:
)
)
index
[
etype
]
=
_exclude_homo_edges
(
index
[
etype
]
=
_exclude_homo_edges
(
reverse_edges
,
reverse_edges
,
edges
.
get
(
etype
)
,
edges
[
etype
]
,
assume_num_node_within_int32
,
assume_num_node_within_int32
,
)
)
if
is_cscformat
:
if
is_cscformat
:
...
@@ -266,8 +271,12 @@ def _relabel_two_arrays(lhs_array, rhs_array):
...
@@ -266,8 +271,12 @@ def _relabel_two_arrays(lhs_array, rhs_array):
return
mapping
[:
lhs_array
.
numel
()],
mapping
[
lhs_array
.
numel
()
:]
return
mapping
[:
lhs_array
.
numel
()],
mapping
[
lhs_array
.
numel
()
:]
def
_exclude_homo_edges
(
edges
,
edges_to_exclude
,
assume_num_node_within_int32
):
def
_exclude_homo_edges
(
"""Return the indices of edges that are not in edges_to_exclude."""
edges
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
edges_to_exclude
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
assume_num_node_within_int32
:
bool
,
):
"""Return the indices of edges to be included."""
if
assume_num_node_within_int32
:
if
assume_num_node_within_int32
:
val
=
edges
[
0
]
<<
32
|
edges
[
1
]
val
=
edges
[
0
]
<<
32
|
edges
[
1
]
val_to_exclude
=
edges_to_exclude
[
0
]
<<
32
|
edges_to_exclude
[
1
]
val_to_exclude
=
edges_to_exclude
[
0
]
<<
32
|
edges_to_exclude
[
1
]
...
@@ -286,6 +295,8 @@ def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
...
@@ -286,6 +295,8 @@ def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
def
_index_select
(
obj
,
index
):
def
_index_select
(
obj
,
index
):
if
obj
is
None
:
if
obj
is
None
:
return
None
return
None
if
index
is
None
:
return
obj
if
isinstance
(
obj
,
torch
.
Tensor
):
if
isinstance
(
obj
,
torch
.
Tensor
):
return
obj
[
index
]
return
obj
[
index
]
if
isinstance
(
obj
,
tuple
):
if
isinstance
(
obj
,
tuple
):
...
@@ -312,6 +323,8 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
...
@@ -312,6 +323,8 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
def
_index_select
(
obj
,
index
):
def
_index_select
(
obj
,
index
):
if
obj
is
None
:
if
obj
is
None
:
return
None
return
None
if
index
is
None
:
return
obj
if
isinstance
(
obj
,
CSCFormatBase
):
if
isinstance
(
obj
,
CSCFormatBase
):
new_indices
=
obj
.
indices
[
index
]
new_indices
=
obj
.
indices
[
index
]
new_indptr
=
torch
.
searchsorted
(
index
,
obj
.
indptr
)
new_indptr
=
torch
.
searchsorted
(
index
,
obj
.
indptr
)
...
...
tests/python/pytorch/graphbolt/test_subgraph_sampler.py
View file @
498188dd
from
functools
import
partial
import
dgl
import
dgl
import
dgl.graphbolt
as
gb
import
dgl.graphbolt
as
gb
import
pytest
import
pytest
...
@@ -88,25 +90,27 @@ def to_link_batch(data):
...
@@ -88,25 +90,27 @@ def to_link_batch(data):
def
test_SubgraphSampler_Link
(
labor
):
def
test_SubgraphSampler_Link
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
item_sampler
,
graph
,
fanouts
)
datapipe
=
Sampler
(
datapipe
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
def
test_SubgraphSampler_Link_With_Negative
(
labor
):
def
test_SubgraphSampler_Link_With_Negative
(
labor
):
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
graph
=
gb_test_utils
.
rand_csc_graph
(
20
,
0.15
,
bidirection_edge
=
True
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
itemset
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
20
).
reshape
(
-
1
,
2
),
names
=
"node_pairs"
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
graph
,
1
)
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
negative_dp
,
graph
,
fanouts
)
datapipe
=
Sampler
(
datapipe
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
def
get_hetero_graph
():
def
get_hetero_graph
():
...
@@ -163,12 +167,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
...
@@ -163,12 +167,13 @@ def test_SubgraphSampler_Link_Hetero(labor):
}
}
)
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
item_sampler
,
graph
,
fanouts
)
datapipe
=
Sampler
(
datapipe
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
...
@@ -187,13 +192,14 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
...
@@ -187,13 +192,14 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(labor):
}
}
)
)
item_sampler
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
datapipe
=
gb
.
ItemSampler
(
itemset
,
batch_size
=
2
)
num_layer
=
2
num_layer
=
2
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
fanouts
=
[
torch
.
LongTensor
([
2
])
for
_
in
range
(
num_layer
)]
negative_dp
=
gb
.
UniformNegativeSampler
(
item_sampler
,
graph
,
1
)
datapipe
=
gb
.
UniformNegativeSampler
(
datapipe
,
graph
,
1
)
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
Sampler
=
gb
.
LayerNeighborSampler
if
labor
else
gb
.
NeighborSampler
neighbor_dp
=
Sampler
(
negative_dp
,
graph
,
fanouts
)
datapipe
=
Sampler
(
datapipe
,
graph
,
fanouts
)
assert
len
(
list
(
neighbor_dp
))
==
5
datapipe
=
datapipe
.
transform
(
partial
(
gb
.
exclude_seed_edges
))
assert
len
(
list
(
datapipe
))
==
5
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"labor"
,
[
False
,
True
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment