"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "7199699d78211ffc3ef754ec1d6d0d1de846ec87"
Unverified Commit 6b99f328 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] use str etype in graphs (#6235)

parent 911c4aba
...@@ -66,10 +66,7 @@ class FeatureFetcher(Mapper): ...@@ -66,10 +66,7 @@ class FeatureFetcher(Mapper):
edges = ( edges = (
subgraph.reverse_edge_ids subgraph.reverse_edge_ids
if not type_name if not type_name
# TODO(#6211): Clean up the edge type converter. else subgraph.reverse_edge_ids.get(type_name, None)
else subgraph.reverse_edge_ids.get(
tuple(type_name.split(":")), None
)
) )
if edges is not None: if edges is not None:
data.edge_feature[i][ data.edge_feature[i][
......
...@@ -4,14 +4,14 @@ import os ...@@ -4,14 +4,14 @@ import os
import tarfile import tarfile
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from typing import Dict, Optional, Tuple, Union from typing import Dict, Optional, Union
import torch import torch
from ...base import ETYPE from ...base import ETYPE
from ...convert import to_homogeneous from ...convert import to_homogeneous
from ...heterograph import DGLGraph from ...heterograph import DGLGraph
from ..base import etype_str_to_tuple from ..base import etype_str_to_tuple, etype_tuple_to_str
from .sampled_subgraph_impl import SampledSubgraphImpl from .sampled_subgraph_impl import SampledSubgraphImpl
...@@ -21,7 +21,7 @@ class GraphMetadata: ...@@ -21,7 +21,7 @@ class GraphMetadata:
def __init__( def __init__(
self, self,
node_type_to_id: Dict[str, int], node_type_to_id: Dict[str, int],
edge_type_to_id: Dict[Tuple[str, str, str], int], edge_type_to_id: Dict[str, int],
): ):
"""Initialize the GraphMetadata object. """Initialize the GraphMetadata object.
...@@ -29,7 +29,7 @@ class GraphMetadata: ...@@ -29,7 +29,7 @@ class GraphMetadata:
---------- ----------
node_type_to_id : Dict[str, int] node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs. Dictionary from node types to node type IDs.
edge_type_to_id : Dict[Tuple[str, str, str], int] edge_type_to_id : Dict[str, int]
Dictionary from edge types to edge type IDs. Dictionary from edge types to edge type IDs.
Raises Raises
...@@ -55,7 +55,7 @@ class GraphMetadata: ...@@ -55,7 +55,7 @@ class GraphMetadata:
), "Multiple node types shoud not be mapped to a same id." ), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id. # Validate edge_type_to_id.
for edge_type in edge_types: for edge_type in edge_types:
src, edge, dst = edge_type src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string." assert isinstance(edge, str), "Edge type name should be string."
assert ( assert (
src in node_types src in node_types
...@@ -238,7 +238,7 @@ class CSCSamplingGraph: ...@@ -238,7 +238,7 @@ class CSCSamplingGraph:
# converted to heterogeneous graphs. # converted to heterogeneous graphs.
node_pairs = defaultdict(list) node_pairs = defaultdict(list)
for etype, etype_id in self.metadata.edge_type_to_id.items(): for etype, etype_id in self.metadata.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.metadata.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype] dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id mask = type_per_edge == etype_id
...@@ -719,7 +719,8 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph: ...@@ -719,7 +719,8 @@ def from_dglgraph(g: DGLGraph, is_homogeneous=False) -> CSCSamplingGraph:
# Initialize metadata. # Initialize metadata.
node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes} node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
edge_type_to_id = { edge_type_to_id = {
etype: g.get_etype_id(etype) for etype in g.canonical_etypes etype_tuple_to_str(etype): g.get_etype_id(etype)
for etype in g.canonical_etypes
} }
metadata = GraphMetadata(node_type_to_id, edge_type_to_id) metadata = GraphMetadata(node_type_to_id, edge_type_to_id)
......
...@@ -5,6 +5,7 @@ from typing import Dict, Tuple, Union ...@@ -5,6 +5,7 @@ from typing import Dict, Tuple, Union
import torch import torch
from ..base import etype_str_to_tuple
from ..sampled_subgraph import SampledSubgraph from ..sampled_subgraph import SampledSubgraph
...@@ -14,11 +15,11 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -14,11 +15,11 @@ class SampledSubgraphImpl(SampledSubgraph):
Examples Examples
-------- --------
>>> node_pairs = {('A', 'relation', 'B'): (torch.tensor([0, 1, 2]), >>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))} ... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])} >>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])} >>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {('A', 'relation', 'B'): torch.tensor([19, 20, 21])} >>> reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl( >>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs, ... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids, ... reverse_column_node_ids=reverse_column_node_ids,
...@@ -26,33 +27,29 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -26,33 +27,29 @@ class SampledSubgraphImpl(SampledSubgraph):
... reverse_edge_ids=reverse_edge_ids ... reverse_edge_ids=reverse_edge_ids
... ) ... )
>>> print(subgraph.node_pairs) >>> print(subgraph.node_pairs)
{('A', 'relation', 'B'): (tensor([0, 1, 2]), tensor([0, 1, 2]))} {"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.reverse_column_node_ids) >>> print(subgraph.reverse_column_node_ids)
{'B': tensor([10, 11, 12])} {'B': tensor([10, 11, 12])}
>>> print(subgraph.reverse_row_node_ids) >>> print(subgraph.reverse_row_node_ids)
{'A': tensor([13, 14, 15])} {'A': tensor([13, 14, 15])}
>>> print(subgraph.reverse_edge_ids) >>> print(subgraph.reverse_edge_ids)
{('A', 'relation', 'B'): tensor([19, 20, 21])} {"A:relation:B": tensor([19, 20, 21])}
""" """
node_pairs: Union[ node_pairs: Union[
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
] = None ] = None
reverse_column_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None reverse_column_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
reverse_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None reverse_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
reverse_edge_ids: Union[ reverse_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
Dict[Tuple[str, str, str], torch.Tensor], torch.Tensor
] = None
def __post_init__(self): def __post_init__(self):
if isinstance(self.node_pairs, dict): if isinstance(self.node_pairs, dict):
for etype, pair in self.node_pairs.items(): for etype, pair in self.node_pairs.items():
assert ( assert (
isinstance(etype, tuple) and len(etype) == 3 isinstance(etype, str)
), "Edge type should be a triplet of strings (str, str, str)." and len(etype_str_to_tuple(etype)) == 3
assert all( ), "Edge type should be a string in format of str:str:str."
isinstance(item, str) for item in etype
), "Edge type should be a triplet of strings (str, str, str)."
assert ( assert (
isinstance(pair, tuple) and len(pair) == 2 isinstance(pair, tuple) and len(pair) == 2
), "Node pair should be a source-destination tuple (u, v)." ), "Node pair should be a source-destination tuple (u, v)."
...@@ -127,7 +124,7 @@ def _slice_subgraph(subgraph: SampledSubgraphImpl, index: torch.Tensor): ...@@ -127,7 +124,7 @@ def _slice_subgraph(subgraph: SampledSubgraphImpl, index: torch.Tensor):
def exclude_edges( def exclude_edges(
subgraph: SampledSubgraphImpl, subgraph: SampledSubgraphImpl,
edges: Union[ edges: Union[
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
], ],
) -> SampledSubgraphImpl: ) -> SampledSubgraphImpl:
...@@ -142,8 +139,8 @@ def exclude_edges( ...@@ -142,8 +139,8 @@ def exclude_edges(
---------- ----------
subgraph : SampledSubgraphImpl subgraph : SampledSubgraphImpl
The sampled subgraph. The sampled subgraph.
edges : Union[Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]], edges : Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]] Tuple[torch.Tensor, torch.Tensor]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges` Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If should be a pair of tensors representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a dictionary sampled subgraph is heterogeneous, then `edges` should be a dictionary
...@@ -156,11 +153,11 @@ def exclude_edges( ...@@ -156,11 +153,11 @@ def exclude_edges(
Examples Examples
-------- --------
>>> node_pairs = {('A', 'relation', 'B'): (torch.tensor([0, 1, 2]), >>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))} ... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])} >>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])} >>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {('A', 'relation', 'B'): torch.tensor([19, 20, 21])} >>> reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl( >>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs, ... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids, ... reverse_column_node_ids=reverse_column_node_ids,
...@@ -170,13 +167,13 @@ def exclude_edges( ...@@ -170,13 +167,13 @@ def exclude_edges(
>>> exclude_edges = (torch.tensor([14, 15]), torch.tensor([11, 12])) >>> exclude_edges = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> result = gb.exclude_edges(subgraph, exclude_edges) >>> result = gb.exclude_edges(subgraph, exclude_edges)
>>> print(result.node_pairs) >>> print(result.node_pairs)
{('A', 'relation', 'B'): (tensor([0]), tensor([0]))} {"A:relation:B": (tensor([0]), tensor([0]))}
>>> print(result.reverse_column_node_ids) >>> print(result.reverse_column_node_ids)
{'B': tensor([10, 11, 12])} {'B': tensor([10, 11, 12])}
>>> print(result.reverse_row_node_ids) >>> print(result.reverse_row_node_ids)
{'A': tensor([13, 14, 15])} {'A': tensor([13, 14, 15])}
>>> print(result.reverse_edge_ids) >>> print(result.reverse_edge_ids)
{('A', 'relation', 'B'): tensor([19])} {"A:relation:B": tensor([19])}
""" """
assert isinstance(subgraph.node_pairs, tuple) == isinstance(edges, tuple), ( assert isinstance(subgraph.node_pairs, tuple) == isinstance(edges, tuple), (
"The sampled subgraph and the edges to exclude should be both " "The sampled subgraph and the edges to exclude should be both "
...@@ -197,15 +194,16 @@ def exclude_edges( ...@@ -197,15 +194,16 @@ def exclude_edges(
else: else:
index = {} index = {}
for etype, pair in subgraph.node_pairs.items(): for etype, pair in subgraph.node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
reverse_row_node_ids = ( reverse_row_node_ids = (
None None
if subgraph.reverse_row_node_ids is None if subgraph.reverse_row_node_ids is None
else subgraph.reverse_row_node_ids.get(etype[0]) else subgraph.reverse_row_node_ids.get(src_type)
) )
reverse_column_node_ids = ( reverse_column_node_ids = (
None None
if subgraph.reverse_column_node_ids is None if subgraph.reverse_column_node_ids is None
else subgraph.reverse_column_node_ids.get(etype[2]) else subgraph.reverse_column_node_ids.get(dst_type)
) )
reverse_edges = _to_reverse_ids( reverse_edges = _to_reverse_ids(
pair, pair,
......
"""Graphbolt sampled subgraph.""" """Graphbolt sampled subgraph."""
# pylint: disable= invalid-name # pylint: disable= invalid-name
from typing import Dict, Tuple from typing import Dict, Tuple, Union
import torch import torch
...@@ -14,7 +14,10 @@ class SampledSubgraph: ...@@ -14,7 +14,10 @@ class SampledSubgraph:
@property @property
def node_pairs( def node_pairs(
self, self,
) -> Tuple[torch.Tensor] or Dict[(str, str, str), Tuple[torch.Tensor]]: ) -> Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
]:
"""Returns the node pairs representing source-destination edges. """Returns the node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It should be in the format ('u', 'v') - If `node_pairs` is a tuple: It should be in the format ('u', 'v')
representing source and destination pairs. representing source and destination pairs.
...@@ -26,7 +29,7 @@ class SampledSubgraph: ...@@ -26,7 +29,7 @@ class SampledSubgraph:
@property @property
def reverse_column_node_ids( def reverse_column_node_ids(
self, self,
) -> torch.Tensor or Dict[str, torch.Tensor]: ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse column node ids in the original graph. """Returns corresponding reverse column node ids in the original graph.
Column's reverse node ids in the original graph. A graph structure Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is can be treated as a coordinated row and column pair, and this is
...@@ -42,7 +45,9 @@ class SampledSubgraph: ...@@ -42,7 +45,9 @@ class SampledSubgraph:
return None return None
@property @property
def reverse_row_node_ids(self) -> torch.Tensor or Dict[str, torch.Tensor]: def reverse_row_node_ids(
self,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse row node ids in the original graph. """Returns corresponding reverse row node ids in the original graph.
Row's reverse node ids in the original graph. A graph structure Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is can be treated as a coordinated row and column pair, and this is
...@@ -57,7 +62,7 @@ class SampledSubgraph: ...@@ -57,7 +62,7 @@ class SampledSubgraph:
return None return None
@property @property
def reverse_edge_ids(self) -> torch.Tensor or Dict[str, torch.Tensor]: def reverse_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse edge ids in the original graph. """Returns corresponding reverse edge ids in the original graph.
Reverse edge ids in the original graph. This is useful when edge Reverse edge ids in the original graph. This is useful when edge
features are needed. features are needed.
......
...@@ -5,6 +5,8 @@ from typing import Dict, List, Tuple, Union ...@@ -5,6 +5,8 @@ from typing import Dict, List, Tuple, Union
import torch import torch
from ..base import etype_str_to_tuple
def unique_and_compact( def unique_and_compact(
nodes: Union[ nodes: Union[
...@@ -61,7 +63,7 @@ def unique_and_compact( ...@@ -61,7 +63,7 @@ def unique_and_compact(
def unique_and_compact_node_pairs( def unique_and_compact_node_pairs(
node_pairs: Union[ node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
Dict[Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
], ],
unique_dst_nodes: Union[ unique_dst_nodes: Union[
torch.Tensor, torch.Tensor,
...@@ -73,8 +75,8 @@ def unique_and_compact_node_pairs( ...@@ -73,8 +75,8 @@ def unique_and_compact_node_pairs(
Parameters Parameters
---------- ----------
node_pairs : Tuple[torch.Tensor, torch.Tensor] or \ node_pairs : Union[Tuple[torch.Tensor, torch.Tensor],
Dict(Tuple[str, str, str], Tuple[torch.Tensor, torch.Tensor]) Dict[str, Tuple[torch.Tensor, torch.Tensor]]]
Node pairs representing source-destination edges. Node pairs representing source-destination edges.
- If `node_pairs` is a tuple: It means the graph is homogeneous. - If `node_pairs` is a tuple: It means the graph is homogeneous.
Also, it should be in the format ('u', 'v') representing source Also, it should be in the format ('u', 'v') representing source
...@@ -102,20 +104,20 @@ def unique_and_compact_node_pairs( ...@@ -102,20 +104,20 @@ def unique_and_compact_node_pairs(
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> N1 = torch.LongTensor([1, 2, 2]) >>> N1 = torch.LongTensor([1, 2, 2])
>>> N2 = torch.LongTensor([5, 6, 5]) >>> N2 = torch.LongTensor([5, 6, 5])
>>> node_pairs = {("n1", "e1", "n2"): (N1, N2), >>> node_pairs = {"n1:e1:n2": (N1, N2),
... ("n2", "e2", "n1"): (N2, N1)} ... "n2:e2:n1": (N2, N1)}
>>> unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs( >>> unique_nodes, compacted_node_pairs = gb.unique_and_compact_node_pairs(
... node_pairs ... node_pairs
... ) ... )
>>> print(unique_nodes) >>> print(unique_nodes)
{'n1': tensor([1, 2]), 'n2': tensor([5, 6])} {'n1': tensor([1, 2]), 'n2': tensor([5, 6])}
>>> print(compacted_node_pairs) >>> print(compacted_node_pairs)
{('n1', 'e1', 'n2'): (tensor([0, 1, 1]), tensor([0, 1, 0])), {"n1:e1:n2": (tensor([0, 1, 1]), tensor([0, 1, 0])),
('n2', 'e2', 'n1'): (tensor([0, 1, 0]), tensor([0, 1, 1]))} "n2:e2:n1": (tensor([0, 1, 0]), tensor([0, 1, 1]))}
""" """
is_homogeneous = not isinstance(node_pairs, dict) is_homogeneous = not isinstance(node_pairs, dict)
if is_homogeneous: if is_homogeneous:
node_pairs = {("_N", "_E", "_N"): node_pairs} node_pairs = {"_N:_E:_N": node_pairs}
if unique_dst_nodes is not None: if unique_dst_nodes is not None:
assert isinstance( assert isinstance(
unique_dst_nodes, torch.Tensor unique_dst_nodes, torch.Tensor
...@@ -126,8 +128,9 @@ def unique_and_compact_node_pairs( ...@@ -126,8 +128,9 @@ def unique_and_compact_node_pairs(
src_nodes = defaultdict(list) src_nodes = defaultdict(list)
dst_nodes = defaultdict(list) dst_nodes = defaultdict(list)
for etype, (src_node, dst_node) in node_pairs.items(): for etype, (src_node, dst_node) in node_pairs.items():
src_nodes[etype[0]].append(src_node) src_type, _, dst_type = etype_str_to_tuple(etype)
dst_nodes[etype[2]].append(dst_node) src_nodes[src_type].append(src_node)
dst_nodes[dst_type].append(dst_node)
src_nodes = {ntype: torch.cat(nodes) for ntype, nodes in src_nodes.items()} src_nodes = {ntype: torch.cat(nodes) for ntype, nodes in src_nodes.items()}
dst_nodes = {ntype: torch.cat(nodes) for ntype, nodes in dst_nodes.items()} dst_nodes = {ntype: torch.cat(nodes) for ntype, nodes in dst_nodes.items()}
# Compute unique destination nodes if not provided. # Compute unique destination nodes if not provided.
...@@ -156,7 +159,7 @@ def unique_and_compact_node_pairs( ...@@ -156,7 +159,7 @@ def unique_and_compact_node_pairs(
# Map back with the same order. # Map back with the same order.
for etype, pair in node_pairs.items(): for etype, pair in node_pairs.items():
num_elem = pair[0].size(0) num_elem = pair[0].size(0)
src_type, _, dst_type = etype src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted_src[src_type][:num_elem] src = compacted_src[src_type][:num_elem]
dst = compacted_dst[dst_type][:num_elem] dst = compacted_dst[dst_type][:num_elem]
compacted_node_pairs[etype] = (src, dst) compacted_node_pairs[etype] = (src, dst)
......
...@@ -43,7 +43,7 @@ def get_metadata(num_ntypes, num_etypes): ...@@ -43,7 +43,7 @@ def get_metadata(num_ntypes, num_etypes):
for n2 in range(n1, num_ntypes): for n2 in range(n1, num_ntypes):
if count >= num_etypes: if count >= num_etypes:
break break
etypes.update({(f"n{n1}", f"e{count}", f"n{n2}"): count}) etypes.update({f"n{n1}:e{count}:n{n2}": count})
count += 1 count += 1
return gb.GraphMetadata(ntypes, etypes) return gb.GraphMetadata(ntypes, etypes)
......
...@@ -73,7 +73,7 @@ def test_hetero_empty_graph(num_nodes): ...@@ -73,7 +73,7 @@ def test_hetero_empty_graph(num_nodes):
) )
def test_metadata_with_ntype_exception(ntypes): def test_metadata_with_ntype_exception(ntypes):
with pytest.raises(Exception): with pytest.raises(Exception):
gb.GraphMetadata(ntypes, {("n1", "e1", "n2"): 1}) gb.GraphMetadata(ntypes, {"n1:e1:n2": 1})
@unittest.skipIf( @unittest.skipIf(
...@@ -87,9 +87,9 @@ def test_metadata_with_ntype_exception(ntypes): ...@@ -87,9 +87,9 @@ def test_metadata_with_ntype_exception(ntypes):
{"e1": 1}, {"e1": 1},
{("n1", "e1"): 1}, {("n1", "e1"): 1},
{("n1", "e1", 10): 1}, {("n1", "e1", 10): 1},
{("n1", "e1", "n2"): 1, ("n1", "e2", "n3"): 1}, {"n1:e1:n2": 1, ("n1", "e2", "n3"): 1},
{("n1", "e1", "n10"): 1}, {("n1", "e1", "n10"): 1},
{("n1", "e1", "n2"): 1.5}, {"n1:e1:n2": 1.5},
], ],
) )
def test_metadata_with_etype_exception(etypes): def test_metadata_with_etype_exception(etypes):
...@@ -320,10 +320,10 @@ def test_in_subgraph_heterogeneous(): ...@@ -320,10 +320,10 @@ def test_in_subgraph_heterogeneous():
"N1": 1, "N1": 1,
} }
etypes = { etypes = {
("N0", "R0", "N0"): 0, "N0:R0:N0": 0,
("N0", "R1", "N1"): 1, "N0:R1:N1": 1,
("N1", "R2", "N0"): 2, "N1:R2:N0": 2,
("N1", "R3", "N1"): 3, "N1:R3:N1": 3,
} }
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
...@@ -403,8 +403,8 @@ def test_sample_neighbors_homo(): ...@@ -403,8 +403,8 @@ def test_sample_neighbors_homo():
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor): def test_sample_neighbors_hetero(labor):
"""Original graph in COO: """Original graph in COO:
("n1", "e1", "n2"):[0, 0, 1, 1, 1], [0, 2, 0, 1, 2] "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
("n2", "e2", "n1"):[0, 0, 1, 2], [0, 1, 1 ,0] "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1 0 0 1 0 1
0 0 1 1 1 0 0 1 1 1
1 1 0 0 0 1 1 0 0 0
...@@ -413,7 +413,7 @@ def test_sample_neighbors_hetero(labor): ...@@ -413,7 +413,7 @@ def test_sample_neighbors_hetero(labor):
""" """
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
num_nodes = 5 num_nodes = 5
num_edges = 9 num_edges = 9
...@@ -441,11 +441,11 @@ def test_sample_neighbors_hetero(labor): ...@@ -441,11 +441,11 @@ def test_sample_neighbors_hetero(labor):
# Verify in subgraph. # Verify in subgraph.
expected_node_pairs = { expected_node_pairs = {
("n1", "e1", "n2"): ( "n1:e1:n2": (
torch.LongTensor([0, 1]), torch.LongTensor([0, 1]),
torch.LongTensor([0, 0]), torch.LongTensor([0, 0]),
), ),
("n2", "e2", "n1"): ( "n2:e2:n1": (
torch.LongTensor([0, 2]), torch.LongTensor([0, 2]),
torch.LongTensor([0, 0]), torch.LongTensor([0, 0]),
), ),
...@@ -484,8 +484,8 @@ def test_sample_neighbors_fanouts( ...@@ -484,8 +484,8 @@ def test_sample_neighbors_fanouts(
fanouts, expected_sampled_num1, expected_sampled_num2, labor fanouts, expected_sampled_num1, expected_sampled_num2, labor
): ):
"""Original graph in COO: """Original graph in COO:
("n1", "e1", "n2"):[0, 0, 1, 1, 1], [0, 2, 0, 1, 2] "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
("n2", "e2", "n1"):[0, 0, 1, 2], [0, 1, 1 ,0] "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1 0 0 1 0 1
0 0 1 1 1 0 0 1 1 1
1 1 0 0 0 1 1 0 0 0
...@@ -494,7 +494,7 @@ def test_sample_neighbors_fanouts( ...@@ -494,7 +494,7 @@ def test_sample_neighbors_fanouts(
""" """
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
num_nodes = 5 num_nodes = 5
num_edges = 9 num_edges = 9
...@@ -520,14 +520,8 @@ def test_sample_neighbors_fanouts( ...@@ -520,14 +520,8 @@ def test_sample_neighbors_fanouts(
subgraph = sampler(nodes, fanouts) subgraph = sampler(nodes, fanouts)
# Verify in subgraph. # Verify in subgraph.
assert ( assert subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1
subgraph.node_pairs[("n1", "e1", "n2")][0].numel() assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
== expected_sampled_num1
)
assert (
subgraph.node_pairs[("n2", "e2", "n1")][0].numel()
== expected_sampled_num2
)
@unittest.skipIf( @unittest.skipIf(
...@@ -542,8 +536,8 @@ def test_sample_neighbors_replace( ...@@ -542,8 +536,8 @@ def test_sample_neighbors_replace(
replace, expected_sampled_num1, expected_sampled_num2 replace, expected_sampled_num1, expected_sampled_num2
): ):
"""Original graph in COO: """Original graph in COO:
("n1", "e1", "n2"):[0, 0, 1, 1, 1], [0, 2, 0, 1, 2] "n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
("n2", "e2", "n1"):[0, 0, 1, 2], [0, 1, 1 ,0] "n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
0 0 1 0 1 0 0 1 0 1
0 0 1 1 1 0 0 1 1 1
1 1 0 0 0 1 1 0 0 0
...@@ -552,7 +546,7 @@ def test_sample_neighbors_replace( ...@@ -552,7 +546,7 @@ def test_sample_neighbors_replace(
""" """
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
num_nodes = 5 num_nodes = 5
num_edges = 9 num_edges = 9
...@@ -578,14 +572,8 @@ def test_sample_neighbors_replace( ...@@ -578,14 +572,8 @@ def test_sample_neighbors_replace(
) )
# Verify in subgraph. # Verify in subgraph.
assert ( assert subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1
subgraph.node_pairs[("n1", "e1", "n2")][0].numel() assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
== expected_sampled_num1
)
assert (
subgraph.node_pairs[("n2", "e2", "n1")][0].numel()
== expected_sampled_num2
)
@unittest.skipIf( @unittest.skipIf(
...@@ -811,7 +799,7 @@ def test_from_dglgraph_homogeneous(): ...@@ -811,7 +799,7 @@ def test_from_dglgraph_homogeneous():
assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 1000])) assert torch.equal(gb_g.node_type_offset, torch.tensor([0, 1000]))
assert torch.all(gb_g.type_per_edge == 0) assert torch.all(gb_g.type_per_edge == 0)
assert gb_g.metadata.node_type_to_id == {"_N": 0} assert gb_g.metadata.node_type_to_id == {"_N": 0}
assert gb_g.metadata.edge_type_to_id == {("_N", "_E", "_N"): 0} assert gb_g.metadata.edge_type_to_id == {"_N:_E:_N": 0}
@unittest.skipIf( @unittest.skipIf(
...@@ -855,10 +843,10 @@ def test_from_dglgraph_heterogeneous(): ...@@ -855,10 +843,10 @@ def test_from_dglgraph_heterogeneous():
"n3": 2, "n3": 2,
} }
assert gb_g.metadata.edge_type_to_id == { assert gb_g.metadata.edge_type_to_id == {
("n1", "r12", "n2"): 0, "n1:r12:n2": 0,
("n1", "r13", "n3"): 1, "n1:r13:n3": 1,
("n2", "r21", "n1"): 2, "n2:r21:n1": 2,
("n2", "r23", "n3"): 3, "n2:r23:n3": 3,
} }
...@@ -972,9 +960,9 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -972,9 +960,9 @@ def test_sample_neighbors_hetero_pick_number(
num_edges = 9 num_edges = 9
ntypes = {"N0": 0, "N1": 1, "N2": 2, "N3": 3} ntypes = {"N0": 0, "N1": 1, "N2": 2, "N3": 3}
etypes = { etypes = {
("N0", "R0", "N1"): 0, "N0:R0:N1": 0,
("N0", "R1", "N2"): 1, "N0:R1:N2": 1,
("N0", "R2", "N3"): 2, "N0:R2:N3": 2,
} }
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]) indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
......
...@@ -151,7 +151,7 @@ def get_hetero_graph(): ...@@ -151,7 +151,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type. # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
......
...@@ -73,7 +73,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column): ...@@ -73,7 +73,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
@pytest.mark.parametrize("reverse_column", [True, False]) @pytest.mark.parametrize("reverse_column", [True, False])
def test_exclude_edges_hetero(reverse_row, reverse_column): def test_exclude_edges_hetero(reverse_row, reverse_column):
node_pairs = { node_pairs = {
("A", "relation", "B"): ( "A:relation:B": (
torch.tensor([0, 1, 2]), torch.tensor([0, 1, 2]),
torch.tensor([2, 1, 0]), torch.tensor([2, 1, 0]),
) )
...@@ -94,7 +94,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column): ...@@ -94,7 +94,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
else: else:
reverse_column_node_ids = None reverse_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2]) dst_to_exclude = torch.tensor([0, 2])
reverse_edge_ids = {("A", "relation", "B"): torch.tensor([19, 20, 21])} reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=node_pairs, node_pairs=node_pairs,
reverse_column_node_ids=reverse_column_node_ids, reverse_column_node_ids=reverse_column_node_ids,
...@@ -103,14 +103,14 @@ def test_exclude_edges_hetero(reverse_row, reverse_column): ...@@ -103,14 +103,14 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
) )
edges_to_exclude = { edges_to_exclude = {
("A", "relation", "B"): ( "A:relation:B": (
src_to_exclude, src_to_exclude,
dst_to_exclude, dst_to_exclude,
) )
} }
result = exclude_edges(subgraph, edges_to_exclude) result = exclude_edges(subgraph, edges_to_exclude)
expected_node_pairs = { expected_node_pairs = {
("A", "relation", "B"): ( "A:relation:B": (
torch.tensor([1]), torch.tensor([1]),
torch.tensor([1]), torch.tensor([1]),
) )
...@@ -127,7 +127,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column): ...@@ -127,7 +127,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
} }
else: else:
expected_column_node_ids = None expected_column_node_ids = None
expected_edge_ids = {("A", "relation", "B"): torch.tensor([20])} expected_edge_ids = {"A:relation:B": torch.tensor([20])}
_assert_container_equal(result.node_pairs, expected_node_pairs) _assert_container_equal(result.node_pairs, expected_node_pairs)
_assert_container_equal( _assert_container_equal(
......
...@@ -71,7 +71,7 @@ def get_hetero_graph(): ...@@ -71,7 +71,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type. # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
...@@ -120,8 +120,8 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -120,8 +120,8 @@ def test_FeatureFetcher_with_edges_hetero():
def add_node_and_edge_ids(seeds): def add_node_and_edge_ids(seeds):
subgraphs = [] subgraphs = []
reverse_edge_ids = { reverse_edge_ids = {
("n1", "e1", "n2"): torch.randint(0, 50, (10,)), "n1:e1:n2": torch.randint(0, 50, (10,)),
("n2", "e2", "n1"): torch.randint(0, 50, (10,)), "n2:e2:n1": torch.randint(0, 50, (10,)),
} }
for _ in range(3): for _ in range(3):
subgraphs.append( subgraphs.append(
......
...@@ -62,15 +62,15 @@ def test_unique_and_compact_node_pairs_hetero(): ...@@ -62,15 +62,15 @@ def test_unique_and_compact_node_pairs_hetero():
"n3": unique_N3, "n3": unique_N3,
} }
node_pairs = { node_pairs = {
("n1", "e1", "n2"): ( "n1:e1:n2": (
N1[:20], N1[:20],
N2, N2,
), ),
("n1", "e2", "n3"): ( "n1:e2:n3": (
N1[20:30], N1[20:30],
N3, N3,
), ),
("n2", "e3", "n3"): ( "n2:e3:n3": (
N2[10:], N2[10:],
N3, N3,
), ),
...@@ -84,7 +84,7 @@ def test_unique_and_compact_node_pairs_hetero(): ...@@ -84,7 +84,7 @@ def test_unique_and_compact_node_pairs_hetero():
assert torch.equal(torch.sort(nodes)[0], expected_nodes) assert torch.equal(torch.sort(nodes)[0], expected_nodes)
for etype, pair in compacted_node_pairs.items(): for etype, pair in compacted_node_pairs.items():
u, v = pair u, v = pair
u_type, _, v_type = etype u_type, _, v_type = gb.etype_str_to_tuple(etype)
u, v = unique_nodes[u_type][u], unique_nodes[v_type][v] u, v = unique_nodes[u_type][u], unique_nodes[v_type][v]
expected_u, expected_v = node_pairs[etype] expected_u, expected_v = node_pairs[etype]
assert torch.equal(u, expected_u) assert torch.equal(u, expected_u)
......
...@@ -84,7 +84,7 @@ def get_hetero_graph(): ...@@ -84,7 +84,7 @@ def get_hetero_graph():
# [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type. # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {("n1", "e1", "n2"): 0, ("n2", "e2", "n1"): 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes) metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment