Unverified Commit 6b99f328 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] use str etype in graphs (#6235)

parent 911c4aba
......@@ -66,10 +66,7 @@ class FeatureFetcher(Mapper):
edges = (
subgraph.reverse_edge_ids
if not type_name
# TODO(#6211): Clean up the edge type converter.
else subgraph.reverse_edge_ids.get(
tuple(type_name.split(":")), None
)
else subgraph.reverse_edge_ids.get(type_name, None)
)
if edges is not None:
data.edge_feature[i][
......
......@@ -4,14 +4,14 @@ import os
import tarfile
import tempfile
from collections import defaultdict
from typing import Dict, Optional, Tuple, Union
from typing import Dict, Optional, Union
import torch
from ...base import ETYPE
from ...convert import to_homogeneous
from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple
from ..base import etype_str_to_tuple, etype_tuple_to_str
from .sampled_subgraph_impl import SampledSubgraphImpl
......@@ -21,7 +21,7 @@ class GraphMetadata:
def __init__(
self,
node_type_to_id: Dict[str, int],
edge_type_to_id: Dict[Tuple[str, str, str], int],
edge_type_to_id: Dict[str, int],
):
"""Initialize the GraphMetadata object.
......@@ -29,7 +29,7 @@ class GraphMetadata:
----------
node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs.
edge_type_to_id : Dict[Tuple[str, str, str], int]
edge_type_to_id : Dict[str, int]
Dictionary from edge types to edge type IDs.
Raises
......@@ -55,7 +55,7 @@ class GraphMetadata:
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = edge_type
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
......@@ -238,7 +238,7 @@ class CSCSamplingGraph:
# converted to heterogeneous graphs.
node_pairs = defaultdict(list)
for etype, etype_id in self.metadata.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id
......@@ -719,7 +719,8 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
# Initialize metadata.
node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
edge_type_to_id = {
etype: g.get_etype_id(etype) for etype in g.canonical_etypes
etype_tuple_to_str(etype): g.get_etype_id(etype)
for etype in g.canonical_etypes
}
metadata = GraphMetadata(node_type_to_id, edge_type_to_id)
......
......@@ -5,6 +5,7 @@ from typing import Dict, Tuple, Union
import torch
from ..base import etype_str_to_tuple
from ..sampled_subgraph import SampledSubgraph
......@@ -14,11 +15,11 @@ class SampledSubgraphImpl(SampledSubgraph):
Examples
--------
>>> node_pairs = {('A', 'relation', 'B'): (torch.tensor([0, 1, 2]),
>>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {('A', 'relation', 'B'): torch.tensor([19, 20, 21])}
>>> reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids,
......@@ -26,33 +27,29 @@ class SampledSubgraphImpl(SampledSubgraph):
... reverse_edge_ids=reverse_edge_ids
... )
>>> print(subgraph.node_pairs)
{('A', 'relation', 'B'): (tensor([0, 1, 2]), tensor([0, 1, 2]))}
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.reverse_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(subgraph.reverse_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(subgraph.reverse_edge_ids)
{('A', 'relation', 'B'): tensor([19, 20, 21])}
{"A:relation:B": tensor([19, 20, 21])}
"""
node_pairs: Union[
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
] = None
reverse_column_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
reverse_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
reverse_edge_ids: Union[
Dict[Tuple[str, str, str], torch.Tensor], torch.Tensor
] = None
reverse_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self):
if isinstance(self.node_pairs, dict):
for etype, pair in self.node_pairs.items():
assert (
isinstance(etype, tuple) and len(etype) == 3
), "Edge type should be a triplet of strings (str, str, str)."
assert all(
isinstance(item, str) for item in etype
), "Edge type should be a triplet of strings (str, str, str)."
isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3
), "Edge type should be a string in format of str:str:str."
assert (
isinstance(pair, tuple) and len(pair) == 2
), "Node pair should be a source-destination tuple (u, v)."
......@@ -127,7 +124,7 @@ def _slice_subgraph(subgraph: SampledSubgraphImpl, index: torch.Tensor):
def exclude_edges(
subgraph: SampledSubgraphImpl,
edges: Union[
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
],
) -> SampledSubgraphImpl:
......@@ -142,7 +139,7 @@ def exclude_edges(
----------
subgraph : SampledSubgraphImpl
The sampled subgraph.
edges : Union[Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
edges : Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If
......@@ -156,11 +153,11 @@ def exclude_edges(
Examples
--------
>>> node_pairs = {('A', 'relation', 'B'): (torch.tensor([0, 1, 2]),
>>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {('A', 'relation', 'B'): torch.tensor([19, 20, 21])}
>>> reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids,
......@@ -170,13 +167,13 @@ def exclude_edges(
>>> exclude_edges = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> result = gb.exclude_edges(subgraph, exclude_edges)
>>> print(result.node_pairs)
{('A', 'relation', 'B'): (tensor([0]), tensor([0]))}
{"A:relation:B": (tensor([0]), tensor([0]))}
>>> print(result.reverse_column_node_ids)
{'B': tensor([10, 11, 12])}
>>> print(result.reverse_row_node_ids)
{'A': tensor([13, 14, 15])}
>>> print(result.reverse_edge_ids)
{('A', 'relation', 'B'): tensor([19])}
{"A:relation:B": tensor([19])}
"""
assert isinstance(subgraph.node_pairs, tuple) == isinstance(edges, tuple), (
"The sampled subgraph and the edges to exclude should be both "
......@@ -197,15 +194,16 @@ def exclude_edges(
else:
index = {}
for etype, pair in subgraph.node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
reverse_row_node_ids = (
None
if subgraph.reverse_row_node_ids is None
else subgraph.reverse_row_node_ids.get(etype[0])
else subgraph.reverse_row_node_ids.get(src_type)
)
reverse_column_node_ids = (
None
if subgraph.reverse_column_node_ids is None
else subgraph.reverse_column_node_ids.get(etype[2])
else subgraph.reverse_column_node_ids.get(dst_type)
)
reverse_edges = _to_reverse_ids(
pair,
......
"""Graphbolt sampled subgraph."""
# pylint: disable= invalid-name
from typing import Dict, Tuple
from typing import Dict, Tuple, Union
import torch
......@@ -14,7 +14,10 @@ class SampledSubgraph:
@property
def node_pairs(
self,
) -> Tuple[torch.Tensor] or Dict[(str, str, str), Tuple[torch.Tensor]]:
) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
]:
"""Returns the node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It should be in the format ('u', 'v')
representing source and destination pairs.
......@@ -26,7 +29,7 @@ class SampledSubgraph:
@property
def reverse_column_node_ids(
self,
) -> torch.Tensor or Dict[str, torch.Tensor]:
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse column node ids the original graph.
Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
......@@ -42,7 +45,9 @@ class SampledSubgraph:
return None
@property
def reverse_row_node_ids(self) -> torch.Tensor or Dict[str, torch.Tensor]:
def reverse_row_node_ids(
self,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse row node ids the original graph.
Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
......@@ -57,7 +62,7 @@ class SampledSubgraph:
return None
@property
def reverse_edge_ids(self) -> torch.Tensor or Dict[str, torch.Tensor]:
def reverse_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse edge ids the original graph.
Reverse edge ids in the original graph. This is useful when edge
features are needed.
......
......@@ -5,6 +5,8 @@ from typing import Dict, List, Tuple, Union
import torch
from ..base import etype_str_to_tuple
def unique_and_compact(
nodes: Union[
......@@ -61,7 +63,7 @@ def unique_and_compact(
def unique_and_compact_node_pairs(
node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
],
unique_dst_nodes: Union[
torch.Tensor,
......@@ -73,8 +75,8 @@ def unique_and_compact_node_pairs(
Parameters
----------
node_pairs : Tuple[torch.Tensor, torch.Tensor] or \
Dict(Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor])
node_pairs : Union[Tuple[torch.Tensor, torch.Tensor],
Dict(str, Tuple[torch.Tensor, torch.Tensor])]
Node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It means the graph is homogeneous.
Also, it should be in the format ('u', 'v') representing source
......@@ -102,20 +104,20 @@ def unique_and_compact_node_pairs(
>>> import dgl.graphbolt as gb
>>> N1 = torch.LongTensor([1, 2, 2])
>>> N2 = torch.LongTensor([5, 6, 5])
>>> node_pairs = {("n1", "e1", "n2"): (N1, N2),
... ("n2", "e2", "n1"): (N2, N1)}
>>> node_pairs = {"n1:e1:n2": (N1, N2),
... "n2:e2:n1": (N2, N1)}
>>> unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
... node_pairs
... )
>>> print(unique_nodes)
{'n1': tensor([1, 2]), 'n2': tensor([5, 6])}
>>> print(compacted_node_pairs)
{('n1', 'e1', 'n2'): (tensor([0, 1, 1]), tensor([0, 1, 0])),
('n2', 'e2', 'n1'): (tensor([0, 1, 0]), tensor([0, 1, 1]))}
{"n1:e1:n2": (tensor([0, 1, 1]), tensor([0, 1, 0])),
"n2:e2:n1": (tensor([0, 1, 0]), tensor([0, 1, 1]))}
"""
is_homogeneous = not isinstance(node_pairs, dict)
if is_homogeneous:
node_pairs = {("_N", "_E", "_N"): node_pairs}
node_pairs = {"_N:_E:_N": node_pairs}
if unique_dst_nodes is not None:
assert isinstance(
unique_dst_nodes, torch.Tensor
......@@ -126,8 +128,9 @@ def unique_and_compact_node_pairs(
src_nodes = defaultdict(list)
dst_nodes = defaultdict(list)
for etype, (src_node, dst_node) in node_pairs.items():
src_nodes[etype[0]].append(src_node)
dst_nodes[etype[2]].append(dst_node)
src_type, _, dst_type = etype_str_to_tuple(etype)
src_nodes[src_type].append(src_node)
dst_nodes[dst_type].append(dst_node)
src_nodes = {ntype: torch.cat(nodes) for ntype, nodes in src_nodes.items()}
dst_nodes = {ntype: torch.cat(nodes) for ntype, nodes in dst_nodes.items()}
# Compute unique destination nodes if not provided.
......@@ -156,7 +159,7 @@ def unique_and_compact_node_pairs(
# Map back with the same order.
for etype, pair in node_pairs.items():
num_elem = pair[0].size(0)
src_type, _, dst_type = etype
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted_src[src_type][:num_elem]
dst = compacted_dst[dst_type][:num_elem]
compacted_node_pairs[etype] = (src, dst)
......
......@@ -43,7 +43,7 @@ def get_metadata(num_ntypes, num_etypes):
for n2 in range(n1, num_ntypes):
if count >= num_etypes:
break
etypes.update({(f"n{n1}", f"e{count}", f"n{n2}"): count})
etypes.update({f"n{n1}:e{count}:n{n2}": count})
count += 1
return gb.GraphMetadata(ntypes, etypes)
......
......@@ -73,7 +73,7 @@ def test_hetero_empty_graph(num_nodes):
)
def test_metadata_with_ntype_exception(ntypes):
with pytest.raises(Exception):
gb.GraphMetadata(ntypes, {("n1", "e1", "n2"): 1})
gb.GraphMetadata(ntypes, {"n1:e1:n2": 1})
@unittest.skipIf(
......@@ -87,9 +87,9 @@ def test_metadata_with_ntype_exception(ntypes):
{"e1": 1},
{("n1", "e1"): 1},
{("n1", "e1", 10): 1},
{("n1", "e1", "n2"): 1, ("n1", "e2", "n3"): 1},
{"n1:e1:n2": 1, ("n1", "e2", "n3"): 1},
{("n1", "e1", "n10"): 1},
{("n1", "e1", "n2"): 1.5},
{"n1:e1:n2": 1.5},
],
)
def test_metadata_with_etype_exception(etypes):
......@@ -320,10 +320,10 @@ def test_in_subgraph_heterogeneous():
"N1": 1,
}
etypes = {
("N0", "R0", "N0"): 0,
("N0", "R1", "N1"): 1,
("N1", "R2", "N0"): 2,
("N1", "R3", "N1"): 3,
"N0:R0:N0": 0,
"N0:R1:N1": 1,
"N1:R2:N0": 2,
"N1:R3:N1": 3,
}
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
......@@ -403,8 +403,8 @@ def test_sample_neighbors_homo():
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor):
"""Original graph in COO:
("n1", "e1", "n2"):[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
("n2", "e2", "n1"):[0, 0, 1, 2], [0, 1, 1 ,0]
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
......@@ -413,7 +413,7 @@ def test_sample_neighbors_hetero(labor):
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
num_nodes = 5
num_edges = 9
......@@ -441,11 +441,11 @@ def test_sample_neighbors_hetero(labor):
# Verify in subgraph.
expected_node_pairs = {
("n1", "e1", "n2"): (
"n1:e1:n2": (
torch.LongTensor([0, 1]),
torch.LongTensor([0, 0]),
),
("n2", "e2", "n1"): (
"n2:e2:n1": (
torch.LongTensor([0, 2]),
torch.LongTensor([0, 0]),
),
......@@ -484,8 +484,8 @@ def test_sample_neighbors_fanouts(
fanouts, expected_sampled_num1, expected_sampled_num2, labor
):
"""Original graph in COO:
("n1", "e1", "n2"):[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
("n2", "e2", "n1"):[0, 0, 1, 2], [0, 1, 1 ,0]
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
......@@ -494,7 +494,7 @@ def test_sample_neighbors_fanouts(
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
num_nodes = 5
num_edges = 9
......@@ -520,14 +520,8 @@ def test_sample_neighbors_fanouts(
subgraph = sampler(nodes, fanouts)
# Verify in subgraph.
assert (
subgraph.node_pairs[("n1", "e1", "n2")][0].numel()
== expected_sampled_num1
)
assert (
subgraph.node_pairs[("n2", "e2", "n1")][0].numel()
== expected_sampled_num2
)
assert subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1
assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
@unittest.skipIf(
......@@ -542,8 +536,8 @@ def test_sample_neighbors_replace(
replace, expected_sampled_num1, expected_sampled_num2
):
"""Original graph in COO:
("n1", "e1", "n2"):[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
("n2", "e2", "n1"):[0, 0, 1, 2], [0, 1, 1 ,0]
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1
0 0 1 1 1
1 1 0 0 0
......@@ -552,7 +546,7 @@ def test_sample_neighbors_replace(
"""
# Initialize data.
ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
num_nodes = 5
num_edges = 9
......@@ -578,14 +572,8 @@ def test_sample_neighbors_replace(
)
# Verify in subgraph.
assert (
subgraph.node_pairs[("n1", "e1", "n2")][0].numel()
== expected_sampled_num1
)
assert (
subgraph.node_pairs[("n2", "e2", "n1")][0].numel()
== expected_sampled_num2
)
assert subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1
assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
@unittest.skipIf(
......@@ -811,7 +799,7 @@ def test_from_dglgraph_homogeneous():
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 1000]))
assert torch.all(gb_g.type_per_edge == 0)
assert gb_g.metadata.node_type_to_id == {"_N": 0}
assert gb_g.metadata.edge_type_to_id == {("_N", "_E", "_N"): 0}
assert gb_g.metadata.edge_type_to_id == {"_N:_E:_N": 0}
@unittest.skipIf(
......@@ -855,10 +843,10 @@ def test_from_dglgraph_heterogeneous():
"n3": 2,
}
assert gb_g.metadata.edge_type_to_id == {
("n1", "r12", "n2"): 0,
("n1", "r13", "n3"): 1,
("n2", "r21", "n1"): 2,
("n2", "r23", "n3"): 3,
"n1:r12:n2": 0,
"n1:r13:n3": 1,
"n2:r21:n1": 2,
"n2:r23:n3": 3,
}
......@@ -972,9 +960,9 @@ def test_sample_neighbors_hetero_pick_number(
num_edges = 9
ntypes = {"N0": 0, "N1": 1, "N2": 2, "N3": 3}
etypes = {
("N0", "R0", "N1"): 0,
("N0", "R1", "N2"): 1,
("N0", "R2", "N3"): 2,
"N0:R0:N1": 0,
"N0:R1:N2": 1,
"N0:R2:N3": 2,
}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
......
......@@ -151,7 +151,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
......
......@@ -73,7 +73,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
@pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero(reverse_row, reverse_column):
node_pairs = {
("A", "relation", "B"): (
"A:relation:B": (
torch.tensor([0, 1, 2]),
torch.tensor([2, 1, 0]),
)
......@@ -94,7 +94,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
else:
reverse_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2])
reverse_edge_ids = {("A", "relation", "B"): torch.tensor([19, 20, 21])}
reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl(
node_pairs=node_pairs,
reverse_column_node_ids=reverse_column_node_ids,
......@@ -103,14 +103,14 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
)
edges_to_exclude = {
("A", "relation", "B"): (
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
)
}
result = exclude_edges(subgraph, edges_to_exclude)
expected_node_pairs = {
("A", "relation", "B"): (
"A:relation:B": (
torch.tensor([1]),
torch.tensor([1]),
)
......@@ -127,7 +127,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
}
else:
expected_column_node_ids = None
expected_edge_ids = {("A", "relation", "B"): torch.tensor([20])}
expected_edge_ids = {"A:relation:B": torch.tensor([20])}
_assert_container_equal(result.node_pairs, expected_node_pairs)
_assert_container_equal(
......
......@@ -71,7 +71,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
......@@ -120,8 +120,8 @@ def test_FeatureFetcher_with_edges_hetero():
def add_node_and_edge_ids(seeds):
subgraphs = []
reverse_edge_ids = {
("n1", "e1", "n2"): torch.randint(0, 50, (10,)),
("n2", "e2", "n1"): torch.randint(0, 50, (10,)),
"n1:e1:n2": torch.randint(0, 50, (10,)),
"n2:e2:n1": torch.randint(0, 50, (10,)),
}
for _ in range(3):
subgraphs.append(
......
......@@ -62,15 +62,15 @@ def test_unique_and_compact_node_pairs_hetero():
"n3": unique_N3,
}
node_pairs = {
("n1", "e1", "n2"): (
"n1:e1:n2": (
N1[:20],
N2,
),
("n1", "e2", "n3"): (
"n1:e2:n3": (
N1[20:30],
N3,
),
("n2", "e3", "n3"): (
"n2:e3:n3": (
N2[10:],
N3,
),
......@@ -84,7 +84,7 @@ def test_unique_and_compact_node_pairs_hetero():
assert torch.equal(torch.sort(nodes)[0], expected_nodes)
for etype, pair in compacted_node_pairs.items():
u, v = pair
u_type, _, v_type = etype
u_type, _, v_type = gb.etype_str_to_tuple(etype)
u, v = unique_nodes[u_type][u], unique_nodes[v_type][v]
expected_u, expected_v = node_pairs[etype]
assert torch.equal(u, expected_u)
......
......@@ -84,7 +84,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment