Unverified Commit c036222b authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Move exclude edges helper to util file (#6332)

parent 8275bc29
...@@ -20,6 +20,7 @@ from .sampled_subgraph import * ...@@ -20,6 +20,7 @@ from .sampled_subgraph import *
from .subgraph_sampler import * from .subgraph_sampler import *
from .utils import ( from .utils import (
add_reverse_edges, add_reverse_edges,
exclude_seed_edges,
unique_and_compact, unique_and_compact,
unique_and_compact_node_pairs, unique_and_compact_node_pairs,
) )
......
...@@ -9,7 +9,6 @@ import dgl ...@@ -9,7 +9,6 @@ import dgl
from .base import etype_str_to_tuple from .base import etype_str_to_tuple
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
from .utils import add_reverse_edges
__all__ = ["MiniBatch"] __all__ = ["MiniBatch"]
...@@ -226,50 +225,3 @@ class MiniBatch: ...@@ -226,50 +225,3 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.reverse_edge_ids block.edata[dgl.EID] = subgraph.reverse_edge_ids
return blocks return blocks
def exclude_edges(
minibatch: MiniBatch,
edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
],
):
"""
Exclude edges from the sampled subgraphs in the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
edges : Dict[str, Tuple[torch.Tensor, torch.Tensor]] or Tuple[torch.Tensor, torch.Tensor]
The edges to be excluded.
"""
minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges)
for subgraph in minibatch.sampled_subgraphs
]
return minibatch
def exclude_seed_edges(minibatch: MiniBatch):
"""Exclude seed edges from the sampled subgraphs in the minibatch."""
return exclude_edges(minibatch, minibatch.node_pairs)
def exclude_seed_edges_and_reverse(
minibatch: MiniBatch, reverse_etypes: Dict[str, str] = None
):
"""
Exclude seed edges and their reverse edges from the sampled subgraphs in
the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
reverse_etypes : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
edges_to_exclude = add_reverse_edges(minibatch.node_pairs, reverse_etypes)
return exclude_edges(minibatch, edges_to_exclude)
...@@ -33,5 +33,5 @@ class MiniBatchTransformer(Mapper): ...@@ -33,5 +33,5 @@ class MiniBatchTransformer(Mapper):
minibatch = self.transformer(minibatch) minibatch = self.transformer(minibatch)
assert isinstance( assert isinstance(
minibatch, MiniBatch minibatch, MiniBatch
), "The transformer output should be a instance of MiniBatch" ), "The transformer output should be an instance of MiniBatch"
return minibatch return minibatch
...@@ -6,6 +6,7 @@ from typing import Dict, List, Tuple, Union ...@@ -6,6 +6,7 @@ from typing import Dict, List, Tuple, Union
import torch import torch
from ..base import etype_str_to_tuple from ..base import etype_str_to_tuple
from ..minibatch import MiniBatch
def add_reverse_edges( def add_reverse_edges(
...@@ -13,7 +14,7 @@ def add_reverse_edges( ...@@ -13,7 +14,7 @@ def add_reverse_edges(
Dict[str, Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
], ],
reverse_etypes: Dict[str, str] = None, reverse_etypes_mapping: Dict[str, str] = None,
): ):
r""" r"""
This function finds the reverse edges of the given `edges` and returns the This function finds the reverse edges of the given `edges` and returns the
...@@ -33,7 +34,7 @@ def add_reverse_edges( ...@@ -33,7 +34,7 @@ def add_reverse_edges(
of tensors. of tensors.
- If sampled subgraph is heterogeneous, then `edges` should be a - If sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude. dictionary of edge types and the corresponding edges to exclude.
reverse_etypes : Dict[str, str], optional reverse_etypes_mapping : Dict[str, str], optional
The mapping from the original edge types to their reverse edge types. The mapping from the original edge types to their reverse edge types.
Returns Returns
...@@ -59,7 +60,7 @@ def add_reverse_edges( ...@@ -59,7 +60,7 @@ def add_reverse_edges(
return (torch.cat([u, v]), torch.cat([v, u])) return (torch.cat([u, v]), torch.cat([v, u]))
else: else:
combined_edges = edges.copy() combined_edges = edges.copy()
for etype, reverse_etype in reverse_etypes.items(): for etype, reverse_etype in reverse_etypes_mapping.items():
if etype in edges: if etype in edges:
if reverse_etype in combined_edges: if reverse_etype in combined_edges:
u, v = combined_edges[reverse_etype] u, v = combined_edges[reverse_etype]
...@@ -74,6 +75,34 @@ def add_reverse_edges( ...@@ -74,6 +75,34 @@ def add_reverse_edges(
return combined_edges return combined_edges
def exclude_seed_edges(
minibatch: MiniBatch,
include_reverse_edges: bool = False,
reverse_etypes_mapping: Dict[str, str] = None,
):
"""
Exclude seed edges with or without their reverse edges from the sampled
subgraphs in the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
edges_to_exclude = minibatch.node_pairs
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
minibatch.node_pairs, reverse_etypes_mapping
)
minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges_to_exclude)
for subgraph in minibatch.sampled_subgraphs
]
return minibatch
def unique_and_compact( def unique_and_compact(
nodes: Union[ nodes: Union[
List[torch.Tensor], List[torch.Tensor],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment