Unverified Commit 20e5e266 authored by Xinyu Yao's avatar Xinyu Yao Committed by GitHub
Browse files

[GraphBolt] Remove old version exclude edges. (#7299)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 279ebb15
......@@ -54,4 +54,4 @@ from .internal import (
unique_and_compact,
unique_and_compact_csc_formats,
)
from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges
from .utils import add_reverse_edges, exclude_seed_edges
......@@ -115,9 +115,7 @@ class SampledSubgraph:
def exclude_edges(
self,
edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Dict[str, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor],
torch.Tensor,
],
assume_num_node_within_int32: bool = True,
......@@ -133,10 +131,9 @@ class SampledSubgraph:
----------
self : SampledSubgraph
The sampled subgraph.
edges : Union[Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]]]
edges : Union[torch.Tensor, Dict[str, torch.Tensor]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If
should be an N*2 tensor representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude.
assume_num_node_within_int32: bool
......@@ -165,8 +162,7 @@ class SampledSubgraph:
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> edges_to_exclude = {"A:relation:B": (torch.tensor([14, 15]),
... torch.tensor([11, 12]))}
>>> edges_to_exclude = {"A:relation:B": torch.tensor([[14, 11], [15, 12]])}
>>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.sampled_csc)
{'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),
......@@ -183,9 +179,9 @@ class SampledSubgraph:
assert (
assume_num_node_within_int32
), "Values > int32 are not supported yet."
assert (
isinstance(self.sampled_csc, (CSCFormatBase, tuple))
) == isinstance(edges, (tuple, torch.Tensor)), (
assert (isinstance(self.sampled_csc, CSCFormatBase)) == isinstance(
edges, torch.Tensor
), (
"The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous."
)
......@@ -202,14 +198,9 @@ class SampledSubgraph:
self.original_row_node_ids,
self.original_column_node_ids,
)
if isinstance(edges, torch.Tensor):
index = _exclude_homo_edges_2(
reverse_edges, edges, assume_num_node_within_int32
)
else:
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
return calling_class(*_slice_subgraph(self, index))
else:
index = {}
......@@ -234,18 +225,11 @@ class SampledSubgraph:
original_row_node_ids,
original_column_node_ids,
)
if isinstance(edges[etype], torch.Tensor):
index[etype] = _exclude_homo_edges_2(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
else:
index[etype] = _exclude_homo_edges(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
index[etype] = _exclude_homo_edges(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
return calling_class(*_slice_subgraph(self, index))
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
......@@ -286,26 +270,6 @@ def _relabel_two_arrays(lhs_array, rhs_array):
def _exclude_homo_edges(
edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: Tuple[torch.Tensor, torch.Tensor],
assume_num_node_within_int32: bool,
):
"""Return the indices of edges to be included."""
if assume_num_node_within_int32:
val = edges[0].long() << 32 | edges[1].long()
val_to_exclude = (
edges_to_exclude[0].long() << 32 | edges_to_exclude[1].long()
)
else:
# TODO: Add support for value > int32.
raise NotImplementedError(
"Values out of range int32 are not supported yet"
)
mask = ~isin(val, val_to_exclude)
return torch.nonzero(mask, as_tuple=True)[0]
def _exclude_homo_edges_2(
edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: torch.Tensor,
assume_num_node_within_int32: bool,
......
"""Utility functions for external use."""
from typing import Dict, Tuple, Union
from typing import Dict, Union
import torch
......@@ -8,72 +8,6 @@ from .minibatch import MiniBatch
def add_reverse_edges(
    edges: Union[
        Dict[str, Tuple[torch.Tensor, torch.Tensor]],
        Tuple[torch.Tensor, torch.Tensor],
    ],
    reverse_etypes_mapping: Dict[str, str] = None,
):
    r"""Append the reverse counterparts of `edges` and return the union.

    In a homogeneous graph the reverse of an edge swaps its source and
    destination node IDs. In a heterogeneous graph, reversing also swaps
    the node types, so the reversed edges land under the reverse edge
    type given by `reverse_etypes_mapping`. Typically used before
    `exclude_edges` to build the full set of edges to exclude.

    Note: The returned reverse edges may not actually exist in the
    original graph, and duplicates can appear because a reverse edge may
    already be present in `edges`.

    Parameters
    ----------
    edges : Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
                Tuple[torch.Tensor, torch.Tensor]]
        - If sampled subgraph is homogeneous, then `edges` should be a pair
        of tensors.
        - If sampled subgraph is heterogeneous, then `edges` should be a
        dictionary of edge types and the corresponding edges to exclude.
    reverse_etypes_mapping : Dict[str, str], optional
        The mapping from the original edge types to their reverse edge types.

    Returns
    -------
    Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
        Tuple[torch.Tensor, torch.Tensor]]
        The node pairs contain both the original edges and their reverse
        counterparts.

    Examples
    --------
    >>> edges = {"A:r:B": (torch.tensor([0, 1]), torch.tensor([1, 2]))}
    >>> print(gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"}))
    {'A:r:B': (tensor([0, 1]), tensor([1, 2])),
    'B:rr:A': (tensor([1, 2]), tensor([0, 1]))}

    >>> edges = (torch.tensor([0, 1]), torch.tensor([2, 1]))
    >>> print(gb.add_reverse_edges(edges))
    (tensor([0, 1, 2, 1]), tensor([2, 1, 0, 1]))
    """
    # Homogeneous case: concatenate the flipped pair onto the original.
    if isinstance(edges, tuple):
        src, dst = edges
        return (torch.cat([src, dst]), torch.cat([dst, src]))

    # Heterogeneous case: shallow-copy so the caller's dict is untouched,
    # then fold each etype's flipped edges into its reverse etype's entry.
    result = edges.copy()
    for etype, reverse_etype in reverse_etypes_mapping.items():
        if etype not in edges:
            continue
        src, dst = edges[etype]
        if reverse_etype in result:
            # Reverse etype already present (possibly extended by an
            # earlier iteration) — append to it.
            prev_src, prev_dst = result[reverse_etype]
            result[reverse_etype] = (
                torch.cat([prev_src, dst]),
                torch.cat([prev_dst, src]),
            )
        else:
            result[reverse_etype] = (dst, src)
    return result
def add_reverse_edges_2(
edges: Union[Dict[str, torch.Tensor], torch.Tensor],
reverse_etypes_mapping: Dict[str, str] = None,
):
......@@ -157,18 +91,11 @@ def exclude_seed_edges(
reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
if minibatch.node_pairs is not None:
edges_to_exclude = minibatch.node_pairs
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
minibatch.node_pairs, reverse_etypes_mapping
)
else:
edges_to_exclude = minibatch.seeds
if include_reverse_edges:
edges_to_exclude = add_reverse_edges_2(
edges_to_exclude, reverse_etypes_mapping
)
edges_to_exclude = minibatch.seeds
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
edges_to_exclude, reverse_etypes_mapping
)
minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges_to_exclude)
for subgraph in minibatch.sampled_subgraphs
......
......@@ -57,7 +57,7 @@ def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = (src_to_exclude, dst_to_exclude)
edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(2, -1).T
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3])
......@@ -107,7 +107,7 @@ def test_exclude_edges_homo_duplicated(reverse_row, reverse_column):
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = (src_to_exclude, dst_to_exclude)
edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(2, -1).T
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2])
......@@ -163,10 +163,14 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
)
edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
"A:relation:B": torch.cat(
(
src_to_exclude,
dst_to_exclude,
)
)
.view(2, -1)
.T
}
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = {
......@@ -231,10 +235,14 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
)
edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
"A:relation:B": torch.cat(
(
src_to_exclude,
dst_to_exclude,
)
)
.view(2, -1)
.T
}
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = {
......@@ -525,10 +533,14 @@ def test_sampled_subgraph_to_device():
original_edge_ids=original_edge_ids,
)
edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
"A:relation:B": torch.cat(
(
src_to_exclude,
dst_to_exclude,
)
)
.view(2, -1)
.T
}
graph = subgraph.exclude_edges(edges_to_exclude)
......
......@@ -5,67 +5,56 @@ import torch
def test_find_reverse_edges_homo():
    """Homogeneous N*2 edge tensor gains its flipped edges appended."""
    # NOTE(review): this span interleaved removed and added diff lines
    # (duplicate `edges` assignment, stray trailing assert); resolved to
    # the post-commit tensor-based version.
    edges = torch.tensor([[1, 3, 5], [2, 4, 5]]).T
    edges = gb.add_reverse_edges(edges)
    expected_edges = torch.tensor([[1, 3, 5, 2, 4, 5], [2, 4, 5, 1, 3, 5]]).T
    assert torch.equal(edges, expected_edges)
def test_find_reverse_edges_hetero():
    """Reversed edges are appended under the mapped reverse edge type."""
    # NOTE(review): diff interleaving left duplicate dict keys (old tuple
    # form and new tensor form); resolved to the post-commit version.
    edges = {
        "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
        "B:rr:A": torch.tensor([[3], [3]]).T,
    }
    edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"})
    expected_edges = {
        "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
        "B:rr:A": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,
    }
    assert torch.equal(edges["A:r:B"], expected_edges["A:r:B"])
    assert torch.equal(edges["B:rr:A"], expected_edges["B:rr:A"])
def test_find_reverse_edges_bi_reverse_types():
    """Two edge types mapped as mutual reverses each receive the other's flips."""
    # NOTE(review): diff interleaving left duplicate dict keys; resolved
    # to the post-commit tensor-based version.
    edges = {
        "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
        "B:rr:A": torch.tensor([[3], [3]]).T,
    }
    edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A", "B:rr:A": "A:r:B"})
    expected_edges = {
        "A:r:B": torch.tensor([[1, 5, 3], [2, 5, 3]]).T,
        "B:rr:A": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,
    }
    assert torch.equal(edges["A:r:B"], expected_edges["A:r:B"])
    assert torch.equal(edges["B:rr:A"], expected_edges["B:rr:A"])
def test_find_reverse_edges_circual_reverse_types():
    """A circular reverse-type mapping r1 -> r2 -> r3 -> r1 propagates flips."""
    # NOTE(review): diff interleaving left duplicate dict keys and a
    # duplicated "A:r1:B" assert; resolved to the post-commit version
    # with the redundant assert removed.
    edges = {
        "A:r1:B": torch.tensor([[1, 1]]),
        "B:r2:C": torch.tensor([[2, 2]]),
        "C:r3:A": torch.tensor([[3, 3]]),
    }
    edges = gb.add_reverse_edges(
        edges, {"A:r1:B": "B:r2:C", "B:r2:C": "C:r3:A", "C:r3:A": "A:r1:B"}
    )
    expected_edges = {
        "A:r1:B": torch.tensor([[1, 3], [1, 3]]).T,
        "B:r2:C": torch.tensor([[2, 1], [2, 1]]).T,
        "C:r3:A": torch.tensor([[3, 2], [3, 2]]).T,
    }
    assert torch.equal(edges["A:r1:B"], expected_edges["A:r1:B"])
    assert torch.equal(edges["B:r2:C"], expected_edges["B:r2:C"])
    assert torch.equal(edges["C:r3:A"], expected_edges["C:r3:A"])
......@@ -12,45 +12,8 @@ import torch
def test_add_reverse_edges_homo():
    """Homogeneous pair input: flipped edges are appended to each tensor."""
    forward = (torch.tensor([0, 1, 2, 3]), torch.tensor([4, 5, 6, 7]))
    result = gb.add_reverse_edges(forward)
    expected_src = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
    expected_dst = torch.tensor([4, 5, 6, 7, 0, 1, 2, 3])
    assert torch.equal(result[0], expected_src)
    assert torch.equal(result[1], expected_dst)
def test_add_reverse_edges_hetero():
    """Heterogeneous pair input, with and without a pre-existing reverse etype."""
    # Case 1: reverse etype absent from the input — it is created fresh.
    edges = {"n1:e1:n2": (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6]))}
    mapping = {"n1:e1:n2": "n2:e2:n1"}
    result = gb.add_reverse_edges(edges, mapping)
    assert torch.equal(result["n1:e1:n2"][0], torch.tensor([0, 1, 2]))
    assert torch.equal(result["n1:e1:n2"][1], torch.tensor([4, 5, 6]))
    assert torch.equal(result["n2:e2:n1"][0], torch.tensor([4, 5, 6]))
    assert torch.equal(result["n2:e2:n1"][1], torch.tensor([0, 1, 2]))

    # Case 2: reverse etype already present — flipped edges are appended.
    edges = {
        "n1:e1:n2": (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6])),
        "n2:e2:n1": (torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])),
    }
    result = gb.add_reverse_edges(edges, mapping)
    assert torch.equal(result["n1:e1:n2"][0], torch.tensor([0, 1, 2]))
    assert torch.equal(result["n1:e1:n2"][1], torch.tensor([4, 5, 6]))
    assert torch.equal(
        result["n2:e2:n1"][0], torch.tensor([7, 8, 9, 4, 5, 6])
    )
    assert torch.equal(
        result["n2:e2:n1"][1], torch.tensor([10, 11, 12, 0, 1, 2])
    )
def test_add_reverse_edges_2_homo():
edges = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).T
combined_edges = gb.add_reverse_edges_2(edges)
combined_edges = gb.add_reverse_edges(edges)
assert torch.equal(
combined_edges,
torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3]]).T,
......@@ -63,14 +26,14 @@ def test_add_reverse_edges_2_homo():
"Only tensor with shape N*2 is supported now, but got torch.Size([4])."
),
):
gb.add_reverse_edges_2(edges)
gb.add_reverse_edges(edges)
def test_add_reverse_edges_2_hetero():
def test_add_reverse_edges_hetero():
# reverse_etype doesn't exist in original etypes.
edges = {"n1:e1:n2": torch.tensor([[0, 1, 2], [4, 5, 6]]).T}
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"}
combined_edges = gb.add_reverse_edges_2(edges, reverse_etype_mapping)
combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping)
assert torch.equal(
combined_edges["n1:e1:n2"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T
)
......@@ -83,7 +46,7 @@ def test_add_reverse_edges_2_hetero():
"n2:e2:n1": torch.tensor([[7, 8, 9], [10, 11, 12]]).T,
}
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"}
combined_edges = gb.add_reverse_edges_2(edges, reverse_etype_mapping)
combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping)
assert torch.equal(
combined_edges["n1:e1:n2"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T
)
......@@ -102,7 +65,7 @@ def test_add_reverse_edges_2_hetero():
"Only tensor with shape N*2 is supported now, but got torch.Size([3])."
),
):
gb.add_reverse_edges_2(edges, reverse_etype_mapping)
gb.add_reverse_edges(edges, reverse_etype_mapping)
@unittest.skipIf(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment