Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c036222b
Unverified
Commit
c036222b
authored
Sep 15, 2023
by
peizhou001
Committed by
GitHub
Sep 15, 2023
Browse files
[Graphbolt] Move exclude edges helper to util file (#6332)
parent
8275bc29
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
52 deletions
+34
-52
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+1
-0
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+0
-48
python/dgl/graphbolt/minibatch_transformer.py
python/dgl/graphbolt/minibatch_transformer.py
+1
-1
python/dgl/graphbolt/utils/sample_utils.py
python/dgl/graphbolt/utils/sample_utils.py
+32
-3
No files found.
python/dgl/graphbolt/__init__.py
View file @
c036222b
...
...
@@ -20,6 +20,7 @@ from .sampled_subgraph import *
from
.subgraph_sampler
import
*
from
.utils
import
(
add_reverse_edges
,
exclude_seed_edges
,
unique_and_compact
,
unique_and_compact_node_pairs
,
)
...
...
python/dgl/graphbolt/minibatch.py
View file @
c036222b
...
...
@@ -9,7 +9,6 @@ import dgl
from
.base
import
etype_str_to_tuple
from
.sampled_subgraph
import
SampledSubgraph
from
.utils
import
add_reverse_edges
__all__
=
[
"MiniBatch"
]
...
...
@@ -226,50 +225,3 @@ class MiniBatch:
block
.
edata
[
dgl
.
EID
]
=
subgraph
.
reverse_edge_ids
return
blocks
def
exclude_edges
(
minibatch
:
MiniBatch
,
edges
:
Union
[
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
],
):
"""
Exclude edges from the sampled subgraphs in the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
edges : Dict[str, Tuple[torch.Tensor, torch.Tensor]] or Tuple[torch.Tensor, torch.Tensor]
The edges to be excluded.
"""
minibatch
.
sampled_subgraphs
=
[
subgraph
.
exclude_edges
(
edges
)
for
subgraph
in
minibatch
.
sampled_subgraphs
]
return
minibatch
def
exclude_seed_edges
(
minibatch
:
MiniBatch
):
"""Exclude seed edges from the sampled subgraphs in the minibatch."""
return
exclude_edges
(
minibatch
,
minibatch
.
node_pairs
)
def
exclude_seed_edges_and_reverse
(
minibatch
:
MiniBatch
,
reverse_etypes
:
Dict
[
str
,
str
]
=
None
):
"""
Exclude seed edges and their reverse edges from the sampled subgraphs in
the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
reverse_etypes : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
edges_to_exclude
=
add_reverse_edges
(
minibatch
.
node_pairs
,
reverse_etypes
)
return
exclude_edges
(
minibatch
,
edges_to_exclude
)
python/dgl/graphbolt/minibatch_transformer.py
View file @
c036222b
...
...
@@ -33,5 +33,5 @@ class MiniBatchTransformer(Mapper):
minibatch
=
self
.
transformer
(
minibatch
)
assert
isinstance
(
minibatch
,
MiniBatch
),
"The transformer output should be a instance of MiniBatch"
),
"The transformer output should be a
n
instance of MiniBatch"
return
minibatch
python/dgl/graphbolt/utils/sample_utils.py
View file @
c036222b
...
...
@@ -6,6 +6,7 @@ from typing import Dict, List, Tuple, Union
import
torch
from
..base
import
etype_str_to_tuple
from
..minibatch
import
MiniBatch
def
add_reverse_edges
(
...
...
@@ -13,7 +14,7 @@ def add_reverse_edges(
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
],
reverse_etypes
:
Dict
[
str
,
str
]
=
None
,
reverse_etypes
_mapping
:
Dict
[
str
,
str
]
=
None
,
):
r
"""
This function finds the reverse edges of the given `edges` and returns the
...
...
@@ -33,7 +34,7 @@ def add_reverse_edges(
of tensors.
- If sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude.
reverse_etypes : Dict[str, str], optional
reverse_etypes
_mapping
: Dict[str, str], optional
The mapping from the original edge types to their reverse edge types.
Returns
...
...
@@ -59,7 +60,7 @@ def add_reverse_edges(
return
(
torch
.
cat
([
u
,
v
]),
torch
.
cat
([
v
,
u
]))
else
:
combined_edges
=
edges
.
copy
()
for
etype
,
reverse_etype
in
reverse_etypes
.
items
():
for
etype
,
reverse_etype
in
reverse_etypes
_mapping
.
items
():
if
etype
in
edges
:
if
reverse_etype
in
combined_edges
:
u
,
v
=
combined_edges
[
reverse_etype
]
...
...
@@ -74,6 +75,34 @@ def add_reverse_edges(
return
combined_edges
def
exclude_seed_edges
(
minibatch
:
MiniBatch
,
include_reverse_edges
:
bool
=
False
,
reverse_etypes_mapping
:
Dict
[
str
,
str
]
=
None
,
):
"""
Exclude seed edges with or without their reverse edges from the sampled
subgraphs in the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
edges_to_exclude
=
minibatch
.
node_pairs
if
include_reverse_edges
:
edges_to_exclude
=
add_reverse_edges
(
minibatch
.
node_pairs
,
reverse_etypes_mapping
)
minibatch
.
sampled_subgraphs
=
[
subgraph
.
exclude_edges
(
edges_to_exclude
)
for
subgraph
in
minibatch
.
sampled_subgraphs
]
return
minibatch
def
unique_and_compact
(
nodes
:
Union
[
List
[
torch
.
Tensor
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment