Unverified Commit 20e5e266 authored by Xinyu Yao's avatar Xinyu Yao Committed by GitHub
Browse files

[GraphBolt] Remove old version exclude edges. (#7299)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 279ebb15
...@@ -54,4 +54,4 @@ from .internal import ( ...@@ -54,4 +54,4 @@ from .internal import (
unique_and_compact, unique_and_compact,
unique_and_compact_csc_formats, unique_and_compact_csc_formats,
) )
from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges from .utils import add_reverse_edges, exclude_seed_edges
...@@ -115,9 +115,7 @@ class SampledSubgraph: ...@@ -115,9 +115,7 @@ class SampledSubgraph:
def exclude_edges( def exclude_edges(
self, self,
edges: Union[ edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Dict[str, torch.Tensor], Dict[str, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor],
torch.Tensor, torch.Tensor,
], ],
assume_num_node_within_int32: bool = True, assume_num_node_within_int32: bool = True,
...@@ -133,10 +131,9 @@ class SampledSubgraph: ...@@ -133,10 +131,9 @@ class SampledSubgraph:
---------- ----------
self : SampledSubgraph self : SampledSubgraph
The sampled subgraph. The sampled subgraph.
edges : Union[Tuple[torch.Tensor, torch.Tensor], edges : Union[torch.Tensor, Dict[str, torch.Tensor]]
Dict[str, Tuple[torch.Tensor, torch.Tensor]]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges` Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If should be a N*2 tensors representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude. dictionary of edge types and the corresponding edges to exclude.
assume_num_node_within_int32: bool assume_num_node_within_int32: bool
...@@ -165,8 +162,7 @@ class SampledSubgraph: ...@@ -165,8 +162,7 @@ class SampledSubgraph:
... original_row_node_ids=original_row_node_ids, ... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids ... original_edge_ids=original_edge_ids
... ) ... )
>>> edges_to_exclude = {"A:relation:B": (torch.tensor([14, 15]), >>> edges_to_exclude = {"A:relation:B": torch.tensor([[14, 11], [15, 12]])}
... torch.tensor([11, 12]))}
>>> result = subgraph.exclude_edges(edges_to_exclude) >>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.sampled_csc) >>> print(result.sampled_csc)
{'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]), {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),
...@@ -183,9 +179,9 @@ class SampledSubgraph: ...@@ -183,9 +179,9 @@ class SampledSubgraph:
assert ( assert (
assume_num_node_within_int32 assume_num_node_within_int32
), "Values > int32 are not supported yet." ), "Values > int32 are not supported yet."
assert ( assert (isinstance(self.sampled_csc, CSCFormatBase)) == isinstance(
isinstance(self.sampled_csc, (CSCFormatBase, tuple)) edges, torch.Tensor
) == isinstance(edges, (tuple, torch.Tensor)), ( ), (
"The sampled subgraph and the edges to exclude should be both " "The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous." "homogeneous or both heterogeneous."
) )
...@@ -202,11 +198,6 @@ class SampledSubgraph: ...@@ -202,11 +198,6 @@ class SampledSubgraph:
self.original_row_node_ids, self.original_row_node_ids,
self.original_column_node_ids, self.original_column_node_ids,
) )
if isinstance(edges, torch.Tensor):
index = _exclude_homo_edges_2(
reverse_edges, edges, assume_num_node_within_int32
)
else:
index = _exclude_homo_edges( index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32 reverse_edges, edges, assume_num_node_within_int32
) )
...@@ -234,13 +225,6 @@ class SampledSubgraph: ...@@ -234,13 +225,6 @@ class SampledSubgraph:
original_row_node_ids, original_row_node_ids,
original_column_node_ids, original_column_node_ids,
) )
if isinstance(edges[etype], torch.Tensor):
index[etype] = _exclude_homo_edges_2(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
else:
index[etype] = _exclude_homo_edges( index[etype] = _exclude_homo_edges(
reverse_edges, reverse_edges,
edges[etype], edges[etype],
...@@ -286,26 +270,6 @@ def _relabel_two_arrays(lhs_array, rhs_array): ...@@ -286,26 +270,6 @@ def _relabel_two_arrays(lhs_array, rhs_array):
def _exclude_homo_edges( def _exclude_homo_edges(
edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: Tuple[torch.Tensor, torch.Tensor],
assume_num_node_within_int32: bool,
):
"""Return the indices of edges to be included."""
if assume_num_node_within_int32:
val = edges[0].long() << 32 | edges[1].long()
val_to_exclude = (
edges_to_exclude[0].long() << 32 | edges_to_exclude[1].long()
)
else:
# TODO: Add support for value > int32.
raise NotImplementedError(
"Values out of range int32 are not supported yet"
)
mask = ~isin(val, val_to_exclude)
return torch.nonzero(mask, as_tuple=True)[0]
def _exclude_homo_edges_2(
edges: Tuple[torch.Tensor, torch.Tensor], edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: torch.Tensor, edges_to_exclude: torch.Tensor,
assume_num_node_within_int32: bool, assume_num_node_within_int32: bool,
......
"""Utility functions for external use.""" """Utility functions for external use."""
from typing import Dict, Tuple, Union from typing import Dict, Union
import torch import torch
...@@ -8,72 +8,6 @@ from .minibatch import MiniBatch ...@@ -8,72 +8,6 @@ from .minibatch import MiniBatch
def add_reverse_edges( def add_reverse_edges(
edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
],
reverse_etypes_mapping: Dict[str, str] = None,
):
r"""
This function finds the reverse edges of the given `edges` and returns the
composition of them. In a homogeneous graph, reverse edges have inverted
source and destination node IDs. While in a heterogeneous graph, reversing
also involves swapping node IDs and their types. This function could be
used before `exclude_edges` function to help find targeting edges.
Note: The found reverse edges may not really exists in the original graph.
And repeat edges could be added becasue reverse edges may already exists in
the `edges`.
Parameters
----------
edges : Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
- If sampled subgraph is homogeneous, then `edges` should be a pair of
of tensors.
- If sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude.
reverse_etypes_mapping : Dict[str, str], optional
The mapping from the original edge types to their reverse edge types.
Returns
-------
Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
The node pairs contain both the original edges and their reverse
counterparts.
Examples
--------
>>> edges = {"A:r:B": (torch.tensor([0, 1]), torch.tensor([1, 2]))}
>>> print(gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"}))
{'A:r:B': (tensor([0, 1]), tensor([1, 2])),
'B:rr:A': (tensor([1, 2]), tensor([0, 1]))}
>>> edges = (torch.tensor([0, 1]), torch.tensor([2, 1]))
>>> print(gb.add_reverse_edges(edges))
(tensor([0, 1, 2, 1]), tensor([2, 1, 0, 1]))
"""
if isinstance(edges, tuple):
u, v = edges
return (torch.cat([u, v]), torch.cat([v, u]))
else:
combined_edges = edges.copy()
for etype, reverse_etype in reverse_etypes_mapping.items():
if etype in edges:
if reverse_etype in combined_edges:
u, v = combined_edges[reverse_etype]
u = torch.cat([u, edges[etype][1]])
v = torch.cat([v, edges[etype][0]])
combined_edges[reverse_etype] = (u, v)
else:
combined_edges[reverse_etype] = (
edges[etype][1],
edges[etype][0],
)
return combined_edges
def add_reverse_edges_2(
edges: Union[Dict[str, torch.Tensor], torch.Tensor], edges: Union[Dict[str, torch.Tensor], torch.Tensor],
reverse_etypes_mapping: Dict[str, str] = None, reverse_etypes_mapping: Dict[str, str] = None,
): ):
...@@ -157,16 +91,9 @@ def exclude_seed_edges( ...@@ -157,16 +91,9 @@ def exclude_seed_edges(
reverse_etypes_mapping : Dict[str, str] = None reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types. The mapping from the original edge types to their reverse edge types.
""" """
if minibatch.node_pairs is not None:
edges_to_exclude = minibatch.node_pairs
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
minibatch.node_pairs, reverse_etypes_mapping
)
else:
edges_to_exclude = minibatch.seeds edges_to_exclude = minibatch.seeds
if include_reverse_edges: if include_reverse_edges:
edges_to_exclude = add_reverse_edges_2( edges_to_exclude = add_reverse_edges(
edges_to_exclude, reverse_etypes_mapping edges_to_exclude, reverse_etypes_mapping
) )
minibatch.sampled_subgraphs = [ minibatch.sampled_subgraphs = [
......
...@@ -57,7 +57,7 @@ def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column): ...@@ -57,7 +57,7 @@ def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):
original_row_node_ids, original_row_node_ids,
original_edge_ids, original_edge_ids,
) )
edges_to_exclude = (src_to_exclude, dst_to_exclude) edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(2, -1).T
result = subgraph.exclude_edges(edges_to_exclude) result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase( expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3]) indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3])
...@@ -107,7 +107,7 @@ def test_exclude_edges_homo_duplicated(reverse_row, reverse_column): ...@@ -107,7 +107,7 @@ def test_exclude_edges_homo_duplicated(reverse_row, reverse_column):
original_row_node_ids, original_row_node_ids,
original_edge_ids, original_edge_ids,
) )
edges_to_exclude = (src_to_exclude, dst_to_exclude) edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(2, -1).T
result = subgraph.exclude_edges(edges_to_exclude) result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase( expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2]) indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2])
...@@ -163,10 +163,14 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column): ...@@ -163,10 +163,14 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
) )
edges_to_exclude = { edges_to_exclude = {
"A:relation:B": ( "A:relation:B": torch.cat(
(
src_to_exclude, src_to_exclude,
dst_to_exclude, dst_to_exclude,
) )
)
.view(2, -1)
.T
} }
result = subgraph.exclude_edges(edges_to_exclude) result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = { expected_csc_formats = {
...@@ -231,10 +235,14 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column): ...@@ -231,10 +235,14 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
) )
edges_to_exclude = { edges_to_exclude = {
"A:relation:B": ( "A:relation:B": torch.cat(
(
src_to_exclude, src_to_exclude,
dst_to_exclude, dst_to_exclude,
) )
)
.view(2, -1)
.T
} }
result = subgraph.exclude_edges(edges_to_exclude) result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = { expected_csc_formats = {
...@@ -525,10 +533,14 @@ def test_sampled_subgraph_to_device(): ...@@ -525,10 +533,14 @@ def test_sampled_subgraph_to_device():
original_edge_ids=original_edge_ids, original_edge_ids=original_edge_ids,
) )
edges_to_exclude = { edges_to_exclude = {
"A:relation:B": ( "A:relation:B": torch.cat(
(
src_to_exclude, src_to_exclude,
dst_to_exclude, dst_to_exclude,
) )
)
.view(2, -1)
.T
} }
graph = subgraph.exclude_edges(edges_to_exclude) graph = subgraph.exclude_edges(edges_to_exclude)
......
...@@ -5,67 +5,56 @@ import torch ...@@ -5,67 +5,56 @@ import torch
def test_find_reverse_edges_homo(): def test_find_reverse_edges_homo():
edges = (torch.tensor([1, 3, 5]), torch.tensor([2, 4, 5])) edges = torch.tensor([[1, 3, 5], [2, 4, 5]]).T
edges = gb.add_reverse_edges(edges) edges = gb.add_reverse_edges(edges)
expected_edges = ( expected_edges = torch.tensor([[1, 3, 5, 2, 4, 5], [2, 4, 5, 1, 3, 5]]).T
torch.tensor([1, 3, 5, 2, 4, 5]), assert torch.equal(edges, expected_edges)
torch.tensor([2, 4, 5, 1, 3, 5]),
)
assert torch.equal(edges[0], expected_edges[0])
assert torch.equal(edges[1], expected_edges[1]) assert torch.equal(edges[1], expected_edges[1])
def test_find_reverse_edges_hetero(): def test_find_reverse_edges_hetero():
edges = { edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])), "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
"B:rr:A": (torch.tensor([3]), torch.tensor([3])), "B:rr:A": torch.tensor([[3], [3]]).T,
} }
edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"}) edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"})
expected_edges = { expected_edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])), "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
"B:rr:A": (torch.tensor([3, 2, 5]), torch.tensor([3, 1, 5])), "B:rr:A": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,
} }
assert torch.equal(edges["A:r:B"][0], expected_edges["A:r:B"][0]) assert torch.equal(edges["A:r:B"], expected_edges["A:r:B"])
assert torch.equal(edges["A:r:B"][1], expected_edges["A:r:B"][1]) assert torch.equal(edges["B:rr:A"], expected_edges["B:rr:A"])
assert torch.equal(edges["B:rr:A"][0], expected_edges["B:rr:A"][0])
assert torch.equal(edges["B:rr:A"][1], expected_edges["B:rr:A"][1])
def test_find_reverse_edges_bi_reverse_types(): def test_find_reverse_edges_bi_reverse_types():
edges = { edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])), "A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
"B:rr:A": (torch.tensor([3]), torch.tensor([3])), "B:rr:A": torch.tensor([[3], [3]]).T,
} }
edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A", "B:rr:A": "A:r:B"}) edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A", "B:rr:A": "A:r:B"})
expected_edges = { expected_edges = {
"A:r:B": (torch.tensor([1, 5, 3]), torch.tensor([2, 5, 3])), "A:r:B": torch.tensor([[1, 5, 3], [2, 5, 3]]).T,
"B:rr:A": (torch.tensor([3, 2, 5]), torch.tensor([3, 1, 5])), "B:rr:A": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,
} }
assert torch.equal(edges["A:r:B"][0], expected_edges["A:r:B"][0]) assert torch.equal(edges["A:r:B"], expected_edges["A:r:B"])
assert torch.equal(edges["A:r:B"][1], expected_edges["A:r:B"][1]) assert torch.equal(edges["B:rr:A"], expected_edges["B:rr:A"])
assert torch.equal(edges["B:rr:A"][0], expected_edges["B:rr:A"][0])
assert torch.equal(edges["B:rr:A"][1], expected_edges["B:rr:A"][1])
def test_find_reverse_edges_circual_reverse_types(): def test_find_reverse_edges_circual_reverse_types():
edges = { edges = {
"A:r1:B": (torch.tensor([1]), torch.tensor([1])), "A:r1:B": torch.tensor([[1, 1]]),
"B:r2:C": (torch.tensor([2]), torch.tensor([2])), "B:r2:C": torch.tensor([[2, 2]]),
"C:r3:A": (torch.tensor([3]), torch.tensor([3])), "C:r3:A": torch.tensor([[3, 3]]),
} }
edges = gb.add_reverse_edges( edges = gb.add_reverse_edges(
edges, {"A:r1:B": "B:r2:C", "B:r2:C": "C:r3:A", "C:r3:A": "A:r1:B"} edges, {"A:r1:B": "B:r2:C", "B:r2:C": "C:r3:A", "C:r3:A": "A:r1:B"}
) )
expected_edges = { expected_edges = {
"A:r1:B": (torch.tensor([1, 3]), torch.tensor([1, 3])), "A:r1:B": torch.tensor([[1, 3], [1, 3]]).T,
"B:r2:C": (torch.tensor([2, 1]), torch.tensor([2, 1])), "B:r2:C": torch.tensor([[2, 1], [2, 1]]).T,
"C:r3:A": (torch.tensor([3, 2]), torch.tensor([3, 2])), "C:r3:A": torch.tensor([[3, 2], [3, 2]]).T,
} }
assert torch.equal(edges["A:r1:B"][0], expected_edges["A:r1:B"][0]) assert torch.equal(edges["A:r1:B"], expected_edges["A:r1:B"])
assert torch.equal(edges["A:r1:B"][1], expected_edges["A:r1:B"][1]) assert torch.equal(edges["B:r2:C"], expected_edges["B:r2:C"])
assert torch.equal(edges["B:r2:C"][0], expected_edges["B:r2:C"][0]) assert torch.equal(edges["A:r1:B"], expected_edges["A:r1:B"])
assert torch.equal(edges["B:r2:C"][1], expected_edges["B:r2:C"][1]) assert torch.equal(edges["C:r3:A"], expected_edges["C:r3:A"])
assert torch.equal(edges["A:r1:B"][0], expected_edges["A:r1:B"][0])
assert torch.equal(edges["A:r1:B"][1], expected_edges["A:r1:B"][1])
assert torch.equal(edges["C:r3:A"][0], expected_edges["C:r3:A"][0])
assert torch.equal(edges["C:r3:A"][1], expected_edges["C:r3:A"][1])
...@@ -12,45 +12,8 @@ import torch ...@@ -12,45 +12,8 @@ import torch
def test_add_reverse_edges_homo(): def test_add_reverse_edges_homo():
edges = (torch.tensor([0, 1, 2, 3]), torch.tensor([4, 5, 6, 7]))
combined_edges = gb.add_reverse_edges(edges)
assert torch.equal(
combined_edges[0], torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
)
assert torch.equal(
combined_edges[1], torch.tensor([4, 5, 6, 7, 0, 1, 2, 3])
)
def test_add_reverse_edges_hetero():
# reverse_etype doesn't exist in original etypes.
edges = {"n1:e1:n2": (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6]))}
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"}
combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping)
assert torch.equal(combined_edges["n1:e1:n2"][0], torch.tensor([0, 1, 2]))
assert torch.equal(combined_edges["n1:e1:n2"][1], torch.tensor([4, 5, 6]))
assert torch.equal(combined_edges["n2:e2:n1"][0], torch.tensor([4, 5, 6]))
assert torch.equal(combined_edges["n2:e2:n1"][1], torch.tensor([0, 1, 2]))
# reverse_etype exists in original etypes.
edges = {
"n1:e1:n2": (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6])),
"n2:e2:n1": (torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])),
}
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"}
combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping)
assert torch.equal(combined_edges["n1:e1:n2"][0], torch.tensor([0, 1, 2]))
assert torch.equal(combined_edges["n1:e1:n2"][1], torch.tensor([4, 5, 6]))
assert torch.equal(
combined_edges["n2:e2:n1"][0], torch.tensor([7, 8, 9, 4, 5, 6])
)
assert torch.equal(
combined_edges["n2:e2:n1"][1], torch.tensor([10, 11, 12, 0, 1, 2])
)
def test_add_reverse_edges_2_homo():
edges = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).T edges = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).T
combined_edges = gb.add_reverse_edges_2(edges) combined_edges = gb.add_reverse_edges(edges)
assert torch.equal( assert torch.equal(
combined_edges, combined_edges,
torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3]]).T, torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7], [4, 5, 6, 7, 0, 1, 2, 3]]).T,
...@@ -63,14 +26,14 @@ def test_add_reverse_edges_2_homo(): ...@@ -63,14 +26,14 @@ def test_add_reverse_edges_2_homo():
"Only tensor with shape N*2 is supported now, but got torch.Size([4])." "Only tensor with shape N*2 is supported now, but got torch.Size([4])."
), ),
): ):
gb.add_reverse_edges_2(edges) gb.add_reverse_edges(edges)
def test_add_reverse_edges_2_hetero(): def test_add_reverse_edges_hetero():
# reverse_etype doesn't exist in original etypes. # reverse_etype doesn't exist in original etypes.
edges = {"n1:e1:n2": torch.tensor([[0, 1, 2], [4, 5, 6]]).T} edges = {"n1:e1:n2": torch.tensor([[0, 1, 2], [4, 5, 6]]).T}
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"} reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"}
combined_edges = gb.add_reverse_edges_2(edges, reverse_etype_mapping) combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping)
assert torch.equal( assert torch.equal(
combined_edges["n1:e1:n2"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T combined_edges["n1:e1:n2"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T
) )
...@@ -83,7 +46,7 @@ def test_add_reverse_edges_2_hetero(): ...@@ -83,7 +46,7 @@ def test_add_reverse_edges_2_hetero():
"n2:e2:n1": torch.tensor([[7, 8, 9], [10, 11, 12]]).T, "n2:e2:n1": torch.tensor([[7, 8, 9], [10, 11, 12]]).T,
} }
reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"} reverse_etype_mapping = {"n1:e1:n2": "n2:e2:n1"}
combined_edges = gb.add_reverse_edges_2(edges, reverse_etype_mapping) combined_edges = gb.add_reverse_edges(edges, reverse_etype_mapping)
assert torch.equal( assert torch.equal(
combined_edges["n1:e1:n2"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T combined_edges["n1:e1:n2"], torch.tensor([[0, 1, 2], [4, 5, 6]]).T
) )
...@@ -102,7 +65,7 @@ def test_add_reverse_edges_2_hetero(): ...@@ -102,7 +65,7 @@ def test_add_reverse_edges_2_hetero():
"Only tensor with shape N*2 is supported now, but got torch.Size([3])." "Only tensor with shape N*2 is supported now, but got torch.Size([3])."
), ),
): ):
gb.add_reverse_edges_2(edges, reverse_etype_mapping) gb.add_reverse_edges(edges, reverse_etype_mapping)
@unittest.skipIf( @unittest.skipIf(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment