"vscode:/vscode.git/clone" did not exist on "7842c3ed0d39ea8cc912e4e61ed4f1e9f654a58c"
Unverified Commit c036222b authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Move exclude edges helper to util file (#6332)

parent 8275bc29
......@@ -20,6 +20,7 @@ from .sampled_subgraph import *
from .subgraph_sampler import *
from .utils import (
add_reverse_edges,
exclude_seed_edges,
unique_and_compact,
unique_and_compact_node_pairs,
)
......
......@@ -9,7 +9,6 @@ import dgl
from .base import etype_str_to_tuple
from .sampled_subgraph import SampledSubgraph
from .utils import add_reverse_edges
__all__ = ["MiniBatch"]
......@@ -226,50 +225,3 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.reverse_edge_ids
return blocks
def exclude_edges(
minibatch: MiniBatch,
edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
],
):
"""
Exclude edges from the sampled subgraphs in the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
edges : Dict[str, Tuple[torch.Tensor, torch.Tensor]] or Tuple[torch.Tensor, torch.Tensor]
The edges to be excluded.
"""
minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges)
for subgraph in minibatch.sampled_subgraphs
]
return minibatch
def exclude_seed_edges(minibatch: MiniBatch):
"""Exclude seed edges from the sampled subgraphs in the minibatch."""
return exclude_edges(minibatch, minibatch.node_pairs)
def exclude_seed_edges_and_reverse(
minibatch: MiniBatch, reverse_etypes: Dict[str, str] = None
):
"""
Exclude seed edges and their reverse edges from the sampled subgraphs in
the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
reverse_etypes : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
edges_to_exclude = add_reverse_edges(minibatch.node_pairs, reverse_etypes)
return exclude_edges(minibatch, edges_to_exclude)
......@@ -33,5 +33,5 @@ class MiniBatchTransformer(Mapper):
minibatch = self.transformer(minibatch)
assert isinstance(
minibatch, MiniBatch
), "The transformer output should be a instance of MiniBatch"
), "The transformer output should be an instance of MiniBatch"
return minibatch
......@@ -6,6 +6,7 @@ from typing import Dict, List, Tuple, Union
import torch
from ..base import etype_str_to_tuple
from ..minibatch import MiniBatch
def add_reverse_edges(
......@@ -13,7 +14,7 @@ def add_reverse_edges(
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
],
reverse_etypes: Dict[str, str] = None,
reverse_etypes_mapping: Dict[str, str] = None,
):
r"""
This function finds the reverse edges of the given `edges` and returns the
......@@ -33,7 +34,7 @@ def add_reverse_edges(
of tensors.
- If sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude.
reverse_etypes : Dict[str, str], optional
reverse_etypes_mapping : Dict[str, str], optional
The mapping from the original edge types to their reverse edge types.
Returns
......@@ -59,7 +60,7 @@ def add_reverse_edges(
return (torch.cat([u, v]), torch.cat([v, u]))
else:
combined_edges = edges.copy()
for etype, reverse_etype in reverse_etypes.items():
for etype, reverse_etype in reverse_etypes_mapping.items():
if etype in edges:
if reverse_etype in combined_edges:
u, v = combined_edges[reverse_etype]
......@@ -74,6 +75,34 @@ def add_reverse_edges(
return combined_edges
def exclude_seed_edges(
minibatch: MiniBatch,
include_reverse_edges: bool = False,
reverse_etypes_mapping: Dict[str, str] = None,
):
"""
Exclude seed edges with or without their reverse edges from the sampled
subgraphs in the minibatch.
Parameters
----------
minibatch : MiniBatch
The minibatch.
reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
edges_to_exclude = minibatch.node_pairs
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
minibatch.node_pairs, reverse_etypes_mapping
)
minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges_to_exclude)
for subgraph in minibatch.sampled_subgraphs
]
return minibatch
def unique_and_compact(
nodes: Union[
List[torch.Tensor],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment