Unverified Commit 1a8e0c00 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] remove GraphMeta and update API accordingly (#6709)

parent 50cbb0ec
......@@ -1255,14 +1255,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
graph, _, _, gpb, _, _, _ = load_partition(
part_config, part_id, load_feats=False
)
# Construct GraphMetadata.
_, _, ntypes, etypes = load_partition_book(part_config, part_id)
node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)}
edge_type_to_id = {
_etype_tuple_to_str(etype): etid
for etid, etype in enumerate(etypes)
}
metadata = graphbolt.GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC indtpr and indices.
indptr, indices, _ = graph.adj().csc()
# Initalize type per edge.
......@@ -1275,7 +1273,8 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
indices,
node_type_offset=None,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
)
orig_graph_path = os.path.join(
os.path.dirname(part_config),
......
......@@ -20,7 +20,6 @@ from .sampled_subgraph_impl import (
__all__ = [
"GraphMetadata",
"FusedCSCSamplingGraph",
"from_fused_csc",
"load_from_shared_memory",
......@@ -28,65 +27,6 @@ __all__ = [
]
class GraphMetadata:
r"""Class for metadata of csc sampling graph."""
def __init__(
self,
node_type_to_id: Dict[str, int],
edge_type_to_id: Dict[str, int],
):
"""Initialize the GraphMetadata object.
Parameters
----------
node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs.
edge_type_to_id : Dict[str, int]
Dictionary from edge types to edge type IDs.
Raises
------
AssertionError
If any of the assertions fail.
"""
node_types = list(node_type_to_id.keys())
edge_types = list(edge_type_to_id.keys())
node_type_ids = list(node_type_to_id.values())
edge_type_ids = list(edge_type_to_id.values())
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
self.node_type_to_id = node_type_to_id
self.edge_type_to_id = edge_type_to_id
class FusedCSCSamplingGraph(SamplingGraph):
r"""A sampling graph in CSC format."""
......@@ -148,9 +88,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> print(graph.num_nodes)
{'N0': 2, 'N1': 3}
"""
......@@ -158,14 +100,14 @@ class FusedCSCSamplingGraph(SamplingGraph):
offset = self.node_type_offset
# Homogenous.
if offset is None or self.metadata is None:
if offset is None or self.node_type_to_id is None:
return self._c_csc_graph.num_nodes()
# Heterogenous
else:
num_nodes_per_type = {
_type: (offset[_idx + 1] - offset[_idx]).item()
for _type, _idx in self.metadata.node_type_to_id.items()
for _type, _idx in self.node_type_to_id.items()
}
return num_nodes_per_type
......@@ -358,21 +300,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge attributes dictionary."""
self._c_csc_graph.set_edge_attributes(edge_attributes)
@property
def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph.
[TODO][Rui] This API needs to be updated.
Returns
-------
GraphMetadata or None
If present, returns the metadata of the graph.
"""
if self.node_type_to_id is None or self.edge_type_to_id is None:
return None
return GraphMetadata(self.node_type_to_id, self.edge_type_to_id)
def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> FusedSampledSubgraphImpl:
......@@ -405,14 +332,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = {
... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs)
......@@ -465,10 +394,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
# converted to heterogeneous graphs.
node_pairs = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items():
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
src_ntype_id = self.node_type_to_id[src_ntype]
dst_ntype_id = self.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id
hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = (
......@@ -486,7 +415,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _convert_to_homogeneous_nodes(self, nodes):
homogeneous_nodes = []
for ntype, ids in nodes.items():
ntype_id = self.metadata.node_type_to_id[ntype]
ntype_id = self.node_type_to_id[ntype]
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
return torch.cat(homogeneous_nodes)
......@@ -526,7 +455,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
subgraph_indptr = {}
node_edge_type = defaultdict(list)
original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items():
for etype, etype_id in self.edge_type_to_id.items():
subgraph_indice[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=indices.dtype
)
......@@ -539,7 +468,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Preprocessing saves the type of seed_nodes as the edge type
# of dst_ntype.
_, _, dst_ntype = etype_str_to_tuple(etype)
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype]
dst_ntype_id = self.node_type_to_id[dst_ntype]
node_edge_type[dst_ntype_id].append((etype, etype_id))
# construct subgraphs
for i, seed in enumerate(column):
......@@ -553,7 +482,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
)
for etype, etype_id in node_edge_type[node_type]:
src_ntype, _, _ = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype]
src_ntype_id = self.node_type_to_id[src_ntype]
num_edges = torch.searchsorted(
type_per_edge[l:r], etype_id, right=True
).item()
......@@ -578,7 +507,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
indptr=torch.tensor(subgraph_indptr[etype]),
indices=subgraph_indice[etype],
)
for etype in self.metadata.edge_type_to_id.keys()
for etype in self.edge_type_to_id.keys()
}
return SampledSubgraphImpl(
node_pairs=node_pairs,
......@@ -644,13 +573,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> import torch
>>> ntypes = {"n1": 0, "n2": 1}
>>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
... node_type_offset=node_type_offset, metadata=metadata)
>>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts)
......@@ -673,8 +604,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1
if self.metadata and self.metadata.edge_type_to_id:
expected_fanout_len = len(self.metadata.edge_type_to_id)
if self.edge_type_to_id:
expected_fanout_len = len(self.edge_type_to_id)
assert len(fanouts) in [
expected_fanout_len,
1,
......@@ -830,13 +761,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> import torch
>>> ntypes = {"n1": 0, "n2": 1}
>>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge,
... node_type_offset=node_type_offset, metadata=metadata)
>>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
......@@ -907,7 +840,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "The 'node_type_offset' array is necessary for performing \
negative sampling by edge type."
_, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.metadata.node_type_to_id[dst_node_type]
dst_node_type_id = self.node_type_to_id[dst_node_type]
max_node_id = (
self.node_type_offset[dst_node_type_id + 1]
- self.node_type_offset[dst_node_type_id]
......@@ -965,8 +898,9 @@ def from_fused_csc(
indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None,
node_type_to_id: Optional[Dict[str, int]] = None,
edge_type_to_id: Optional[Dict[str, int]] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None,
metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph:
"""Create a FusedCSCSamplingGraph object from a CSC representation.
......@@ -982,10 +916,12 @@ def from_fused_csc(
Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional
Type ids of each edge in the graph, by default None.
node_type_to_id : Optional[Dict[str, int]], optional
Map node types to ids, by default None.
edge_type_to_id : Optional[Dict[str, int]], optional
Map edge types to ids, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None.
metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None.
Returns
-------
......@@ -995,8 +931,7 @@ def from_fused_csc(
Examples
--------
>>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
>>> etypes = {('n1', 'e1', 'n2'): 0, ('n1', 'e2', 'n3'): 1}
>>> metadata = graphbolt.GraphMetadata(ntypes, etypes)
>>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
>>> csc_indptr = torch.tensor([0, 2, 5, 7])
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3])
......@@ -1004,18 +939,50 @@ def from_fused_csc(
>>> graph = graphbolt.from_fused_csc(csc_indptr, indices,
... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... edge_attributes=None, metadata=metadata)
... node_type_to_id=ntypes, edge_type_to_id=etypes,
... edge_attributes=None,)
>>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]),
total_num_nodes=3, total_num_edges=7)
"""
if metadata and metadata.node_type_to_id and node_type_offset is not None:
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size(
if node_type_to_id is not None and edge_type_to_id is not None:
node_types = list(node_type_to_id.keys())
edge_types = list(edge_type_to_id.keys())
node_type_ids = list(node_type_to_id.values())
edge_type_ids = list(edge_type_to_id.values())
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
if node_type_offset is not None:
assert len(node_type_to_id) + 1 == node_type_offset.size(
0
), "node_type_offset length should be |ntypes| + 1."
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc(
csc_indptr,
......@@ -1091,7 +1058,8 @@ def from_dglgraph(
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
if is_homogeneous:
metadata = None
node_type_to_id = None
edge_type_to_id = None
else:
# Initialize metadata.
node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
......@@ -1099,7 +1067,6 @@ def from_dglgraph(
etype_tuple_to_str(etype): g.get_etype_id(etype)
for etype in g.canonical_etypes
}
metadata = GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC matrix.
indptr, indices, edge_ids = homo_g.adj_tensors("csc")
......@@ -1118,8 +1085,6 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc(
indptr,
......
......@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler):
num_layers = len(self.fanouts)
# Enrich seeds with all node types.
if isinstance(seeds, dict):
ntypes = list(self.graph.metadata.node_type_to_id.keys())
ntypes = list(self.graph.node_type_to_id.keys())
seeds = {
ntype: seeds.get(ntype, torch.LongTensor([]))
for ntype in ntypes
......
......@@ -705,9 +705,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None
assert all(new_g.type_per_edge == 0)
for node_type, type_id in new_g.metadata.node_type_to_id.items():
for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items():
for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
......@@ -736,9 +736,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_indptr, orig_indices, _ = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices)
for node_type, type_id in new_g.metadata.node_type_to_id.items():
for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items():
for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)
......
......@@ -32,7 +32,7 @@ def random_homo_graph(num_nodes, num_edges):
return csc_indptr, indices
def get_metadata(num_ntypes, num_etypes):
def get_type_to_id(num_ntypes, num_etypes):
ntypes = {f"n{i}": i for i in range(num_ntypes)}
etypes = {}
count = 0
......@@ -42,7 +42,7 @@ def get_metadata(num_ntypes, num_etypes):
break
etypes.update({f"n{n1}:e{count}:n{n2}": count})
count += 1
return gb.GraphMetadata(ntypes, etypes)
return ntypes, etypes
def get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes):
......@@ -82,7 +82,8 @@ def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
gb_g.indices,
gb_g.node_type_offset,
gb_g.type_per_edge,
gb_g.metadata,
gb_g.node_type_to_id,
gb_g.edge_type_to_id,
)
......
......@@ -76,7 +76,6 @@ def test_InSubgraphSampler_hetero():
"N1:R2:N0": 2,
"N1:R3:N1": 3,
}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
node_type_offset = torch.LongTensor([0, 3, 6])
......@@ -86,7 +85,8 @@ def test_InSubgraphSampler_hetero():
indices=indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
item_set = gb.ItemSetDict(
......
......@@ -106,7 +106,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
......@@ -116,7 +115,8 @@ def get_hetero_graph():
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
......
......@@ -1026,11 +1026,12 @@ def test_OnDiskDataset_Graph_homogeneous():
assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert (
graph.node_type_offset is None and graph2.node_type_offset is None
)
assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
def test_OnDiskDataset_Graph_heterogeneous():
......@@ -1040,10 +1041,16 @@ def test_OnDiskDataset_Graph_heterogeneous():
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata
csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
)
with tempfile.TemporaryDirectory() as test_dir:
......@@ -1065,8 +1072,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id
assert graph.node_type_to_id == graph2.node_type_to_id
assert graph.edge_type_to_id == graph2.edge_type_to_id
def test_OnDiskDataset_Metadata():
......@@ -1854,15 +1861,17 @@ def test_OnDiskDataset_all_nodes_set_hetero():
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=None,
metadata=metadata,
)
with tempfile.TemporaryDirectory() as test_dir:
......
......@@ -141,7 +141,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
......@@ -151,7 +150,8 @@ def get_hetero_graph():
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
......
......@@ -117,7 +117,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
......@@ -127,7 +126,8 @@ def get_hetero_graph():
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
metadata=metadata,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
)
......@@ -207,7 +207,8 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
indices,
node_type_offset,
type_per_edge,
metadata,
node_type_to_id,
edge_type_to_id,
) = gb_test_utils.random_hetero_graph(
num_nodes, num_edges, num_ntypes, num_etypes
)
......@@ -218,10 +219,11 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
graph = gb.from_fused_csc(
csc_indptr,
indices,
node_type_offset,
type_per_edge,
edge_attributes,
metadata,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
)
itemset = gb.ItemSetDict(
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment