Unverified Commit 09c8e8d9 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Add adding reverse edges helper (#6288)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
parent 8d50b34b
......@@ -18,7 +18,11 @@ from .item_sampler import *
from .negative_sampler import *
from .sampled_subgraph import *
from .subgraph_sampler import *
from .utils import unique_and_compact, unique_and_compact_node_pairs
from .utils import (
add_reverse_edges,
unique_and_compact,
unique_and_compact_node_pairs,
)
def load_graphbolt():
......
......@@ -8,6 +8,72 @@ import torch
from ..base import etype_str_to_tuple
def add_reverse_edges(
edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
],
reverse_etypes: Dict[str, str] = None,
):
r"""
This function finds the reverse edges of the given `edges` and returns the
composition of them. In a homogeneous graph, reverse edges have inverted
source and destination node IDs. While in a heterogeneous graph, reversing
also involves swapping node IDs and their types. This function could be
used before `exclude_edges` function to help find targeting edges.
Note: The found reverse edges may not really exists in the original graph.
And repeat edges could be added becasue reverse edges may already exists in
the `edges`.
Parameters
----------
edges : Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
- If sampled subgraph is homogeneous, then `edges` should be a pair of
of tensors.
- If sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude.
reverse_etypes : Dict[str, str], optional
The mapping from the original edge types to their reverse edge types.
Returns
-------
Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
The node pairs contain both the original edges and their reverse
counterparts.
Examples
--------
>>> edges = {"A:r:B": (torch.tensor([0, 1]), torch.tensor([1, 2]))}
>>> print(gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"}))
{'A:r:B': (tensor([0, 1]), tensor([1, 2])),
'B:rr:A': (tensor([1, 2]), tensor([0, 1]))}
>>> edges = (torch.tensor([0, 1]), torch.tensor([2, 1]))
>>> print(gb.add_reverse_edges(edges))
(tensor([0, 1, 2, 1]), tensor([2, 1, 0, 1]))
"""
if isinstance(edges, tuple):
u, v = edges
return (torch.cat([u, v]), torch.cat([v, u]))
else:
combined_edges = edges.copy()
for etype, reverse_etype in reverse_etypes.items():
if etype in edges:
if reverse_etype in combined_edges:
u, v = combined_edges[reverse_etype]
u = torch.cat([u, edges[etype][1]])
v = torch.cat([v, edges[etype][0]])
combined_edges[reverse_etype] = (u, v)
else:
combined_edges[reverse_etype] = (
edges[etype][1],
edges[etype][0],
)
return combined_edges
def unique_and_compact(
nodes: Union[
List[torch.Tensor],
......
......@@ -3,6 +3,73 @@ import pytest
import torch
def test_find_reverse_edges_homo():
edges = (torch.tensor([1, 3, 5]), torch.tensor([2, 4, 5]))
edges = gb.add_reverse_edges(edges)
expected_edges = (
torch.tensor([1, 3, 5, 2, 4, 5]),
torch.tensor([2, 4, 5, 1, 3, 5]),
)
assert torch.equal(edges[0], expected_edges[0])
assert torch.equal(edges[1], expected_edges[1])
def test_find_reverse_edges_hetero():
edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])),
"B:rr:A": (torch.tensor([3]), torch.tensor([3])),
}
edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"})
expected_edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])),
"B:rr:A": (torch.tensor([3, 2, 5]), torch.tensor([3, 1, 5])),
}
assert torch.equal(edges["A:r:B"][0], expected_edges["A:r:B"][0])
assert torch.equal(edges["A:r:B"][1], expected_edges["A:r:B"][1])
assert torch.equal(edges["B:rr:A"][0], expected_edges["B:rr:A"][0])
assert torch.equal(edges["B:rr:A"][1], expected_edges["B:rr:A"][1])
def test_find_reverse_edges_bi_reverse_types():
edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])),
"B:rr:A": (torch.tensor([3]), torch.tensor([3])),
}
edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A", "B:rr:A": "A:r:B"})
expected_edges = {
"A:r:B": (torch.tensor([1, 5, 3]), torch.tensor([2, 5, 3])),
"B:rr:A": (torch.tensor([3, 2, 5]), torch.tensor([3, 1, 5])),
}
assert torch.equal(edges["A:r:B"][0], expected_edges["A:r:B"][0])
assert torch.equal(edges["A:r:B"][1], expected_edges["A:r:B"][1])
assert torch.equal(edges["B:rr:A"][0], expected_edges["B:rr:A"][0])
assert torch.equal(edges["B:rr:A"][1], expected_edges["B:rr:A"][1])
def test_find_reverse_edges_circual_reverse_types():
edges = {
"A:r1:B": (torch.tensor([1]), torch.tensor([1])),
"B:r2:C": (torch.tensor([2]), torch.tensor([2])),
"C:r3:A": (torch.tensor([3]), torch.tensor([3])),
}
edges = gb.add_reverse_edges(
edges, {"A:r1:B": "B:r2:C", "B:r2:C": "C:r3:A", "C:r3:A": "A:r1:B"}
)
expected_edges = {
"A:r1:B": (torch.tensor([1, 3]), torch.tensor([1, 3])),
"B:r2:C": (torch.tensor([2, 1]), torch.tensor([2, 1])),
"C:r3:A": (torch.tensor([3, 2]), torch.tensor([3, 2])),
}
assert torch.equal(edges["A:r1:B"][0], expected_edges["A:r1:B"][0])
assert torch.equal(edges["A:r1:B"][1], expected_edges["A:r1:B"][1])
assert torch.equal(edges["B:r2:C"][0], expected_edges["B:r2:C"][0])
assert torch.equal(edges["B:r2:C"][1], expected_edges["B:r2:C"][1])
assert torch.equal(edges["A:r1:B"][0], expected_edges["A:r1:B"][0])
assert torch.equal(edges["A:r1:B"][1], expected_edges["A:r1:B"][1])
assert torch.equal(edges["C:r3:A"][0], expected_edges["C:r3:A"][0])
assert torch.equal(edges["C:r3:A"][1], expected_edges["C:r3:A"][1])
def test_unique_and_compact_hetero():
N1 = torch.randint(0, 50, (30,))
N2 = torch.randint(0, 50, (20,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment