Unverified Commit ced802d0 authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolot] Add the exclude edges interface (#5988)

parent 7c2ea23a
...@@ -37,13 +37,13 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -37,13 +37,13 @@ class SampledSubgraphImpl(SampledSubgraph):
{('A', 'B', 'relation'): tensor([19, 20, 21])} {('A', 'B', 'relation'): tensor([19, 20, 21])}
""" """
node_pairs: Union[ node_pairs: Union[
Dict[Tuple[str, str, str], Tuple[torch.tensor, torch.tensor]], Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.tensor, torch.tensor], Tuple[torch.Tensor, torch.Tensor],
] = None ] = None
reverse_column_node_ids: Union[Dict[str, torch.tensor], torch.tensor] = None reverse_column_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
reverse_row_node_ids: Union[Dict[str, torch.tensor], torch.tensor] = None reverse_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
reverse_edge_ids: Union[ reverse_edge_ids: Union[
Dict[Tuple[str, str, str], torch.tensor], torch.tensor Dict[Tuple[str, str, str], torch.Tensor], torch.Tensor
] = None ] = None
def __post_init__(self): def __post_init__(self):
...@@ -68,3 +68,33 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -68,3 +68,33 @@ class SampledSubgraphImpl(SampledSubgraph):
assert all( assert all(
isinstance(item, torch.Tensor) for item in self.node_pairs isinstance(item, torch.Tensor) for item in self.node_pairs
), "Nodes in pairs should be of type torch.Tensor." ), "Nodes in pairs should be of type torch.Tensor."
def exclude_edges(
subgraph: SampledSubgraphImpl,
edges: Union[
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
],
) -> SampledSubgraphImpl:
r"""Exclude edges from the sampled subgraph.
Parameters
----------
subgraph : SampledSubgraphImpl
The sampled subgraph.
edges : Union[Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a dictionary
of edge types and the corresponding edges to exclude.
Returns
-------
SampledSubgraphImpl
The sampled subgraph with the excluded edges.
"""
# TODO(zhenkun): Implement this.
raise NotImplementedError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment