Unverified Commit 1a8e0c00 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] remove GraphMeta and update API accordingly (#6709)

parent 50cbb0ec
...@@ -1255,14 +1255,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1255,14 +1255,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
graph, _, _, gpb, _, _, _ = load_partition( graph, _, _, gpb, _, _, _ = load_partition(
part_config, part_id, load_feats=False part_config, part_id, load_feats=False
) )
# Construct GraphMetadata.
_, _, ntypes, etypes = load_partition_book(part_config, part_id) _, _, ntypes, etypes = load_partition_book(part_config, part_id)
node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)} node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)}
edge_type_to_id = { edge_type_to_id = {
_etype_tuple_to_str(etype): etid _etype_tuple_to_str(etype): etid
for etid, etype in enumerate(etypes) for etid, etype in enumerate(etypes)
} }
metadata = graphbolt.GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC indptr and indices. # Obtain CSC indptr and indices.
indptr, indices, _ = graph.adj().csc() indptr, indices, _ = graph.adj().csc()
# Initialize type per edge. # Initialize type per edge.
...@@ -1275,7 +1273,8 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1275,7 +1273,8 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
indices, indices,
node_type_offset=None, node_type_offset=None,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
) )
orig_graph_path = os.path.join( orig_graph_path = os.path.join(
os.path.dirname(part_config), os.path.dirname(part_config),
......
...@@ -20,7 +20,6 @@ from .sampled_subgraph_impl import ( ...@@ -20,7 +20,6 @@ from .sampled_subgraph_impl import (
__all__ = [ __all__ = [
"GraphMetadata",
"FusedCSCSamplingGraph", "FusedCSCSamplingGraph",
"from_fused_csc", "from_fused_csc",
"load_from_shared_memory", "load_from_shared_memory",
...@@ -28,65 +27,6 @@ __all__ = [ ...@@ -28,65 +27,6 @@ __all__ = [
] ]
class GraphMetadata:
r"""Class for metadata of csc sampling graph."""
def __init__(
self,
node_type_to_id: Dict[str, int],
edge_type_to_id: Dict[str, int],
):
"""Initialize the GraphMetadata object.
Parameters
----------
node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs.
edge_type_to_id : Dict[str, int]
Dictionary from edge types to edge type IDs.
Raises
------
AssertionError
If any of the assertions fail.
"""
node_types = list(node_type_to_id.keys())
edge_types = list(edge_type_to_id.keys())
node_type_ids = list(node_type_to_id.values())
edge_type_ids = list(edge_type_to_id.values())
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
self.node_type_to_id = node_type_to_id
self.edge_type_to_id = edge_type_to_id
class FusedCSCSamplingGraph(SamplingGraph): class FusedCSCSamplingGraph(SamplingGraph):
r"""A sampling graph in CSC format.""" r"""A sampling graph in CSC format."""
...@@ -148,9 +88,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -148,9 +88,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor( >>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3]) ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> metadata = gb.GraphMetadata(ntypes, etypes) >>> graph = gb.from_fused_csc(indptr, indices,
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge, None, metadata) ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> print(graph.num_nodes) >>> print(graph.num_nodes)
{'N0': 2, 'N1': 3} {'N0': 2, 'N1': 3}
""" """
...@@ -158,14 +100,14 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -158,14 +100,14 @@ class FusedCSCSamplingGraph(SamplingGraph):
offset = self.node_type_offset offset = self.node_type_offset
# Homogeneous. # Homogeneous.
if offset is None or self.metadata is None: if offset is None or self.node_type_to_id is None:
return self._c_csc_graph.num_nodes() return self._c_csc_graph.num_nodes()
# Heterogeneous # Heterogeneous
else: else:
num_nodes_per_type = { num_nodes_per_type = {
_type: (offset[_idx + 1] - offset[_idx]).item() _type: (offset[_idx + 1] - offset[_idx]).item()
for _type, _idx in self.metadata.node_type_to_id.items() for _type, _idx in self.node_type_to_id.items()
} }
return num_nodes_per_type return num_nodes_per_type
...@@ -358,21 +300,6 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -358,21 +300,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge attributes dictionary.""" """Sets the edge attributes dictionary."""
self._c_csc_graph.set_edge_attributes(edge_attributes) self._c_csc_graph.set_edge_attributes(edge_attributes)
@property
def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph.
[TODO][Rui] This API needs to be updated.
Returns
-------
GraphMetadata or None
If present, returns the metadata of the graph.
"""
if self.node_type_to_id is None or self.edge_type_to_id is None:
return None
return GraphMetadata(self.node_type_to_id, self.edge_type_to_id)
def in_subgraph( def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> FusedSampledSubgraphImpl: ) -> FusedSampledSubgraphImpl:
...@@ -405,14 +332,16 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -405,14 +332,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> ntypes = {"N0": 0, "N1": 1} >>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = { >>> etypes = {
... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3} ... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor( >>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3]) ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset, >>> graph = gb.from_fused_csc(indptr, indices,
... type_per_edge, None, metadata) ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])} >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes) >>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs) >>> print(in_subgraph.node_pairs)
...@@ -465,10 +394,10 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -465,10 +394,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
# converted to heterogeneous graphs. # converted to heterogeneous graphs.
node_pairs = defaultdict(list) node_pairs = defaultdict(list)
original_hetero_edge_ids = {} original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items(): for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype) src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype] dst_ntype_id = self.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id mask = type_per_edge == etype_id
hetero_row = row[mask] - self.node_type_offset[src_ntype_id] hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = ( hetero_column = (
...@@ -486,7 +415,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -486,7 +415,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _convert_to_homogeneous_nodes(self, nodes): def _convert_to_homogeneous_nodes(self, nodes):
homogeneous_nodes = [] homogeneous_nodes = []
for ntype, ids in nodes.items(): for ntype, ids in nodes.items():
ntype_id = self.metadata.node_type_to_id[ntype] ntype_id = self.node_type_to_id[ntype]
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id]) homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
return torch.cat(homogeneous_nodes) return torch.cat(homogeneous_nodes)
...@@ -526,7 +455,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -526,7 +455,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
subgraph_indptr = {} subgraph_indptr = {}
node_edge_type = defaultdict(list) node_edge_type = defaultdict(list)
original_hetero_edge_ids = {} original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items(): for etype, etype_id in self.edge_type_to_id.items():
subgraph_indice[etype] = torch.empty( subgraph_indice[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=indices.dtype (num.get(etype_id, 0),), dtype=indices.dtype
) )
...@@ -539,7 +468,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -539,7 +468,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Preprocessing saves the type of seed_nodes as the edge type # Preprocessing saves the type of seed_nodes as the edge type
# of dst_ntype. # of dst_ntype.
_, _, dst_ntype = etype_str_to_tuple(etype) _, _, dst_ntype = etype_str_to_tuple(etype)
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype] dst_ntype_id = self.node_type_to_id[dst_ntype]
node_edge_type[dst_ntype_id].append((etype, etype_id)) node_edge_type[dst_ntype_id].append((etype, etype_id))
# construct subgraphs # construct subgraphs
for i, seed in enumerate(column): for i, seed in enumerate(column):
...@@ -553,7 +482,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -553,7 +482,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
) )
for etype, etype_id in node_edge_type[node_type]: for etype, etype_id in node_edge_type[node_type]:
src_ntype, _, _ = etype_str_to_tuple(etype) src_ntype, _, _ = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.node_type_to_id[src_ntype]
num_edges = torch.searchsorted( num_edges = torch.searchsorted(
type_per_edge[l:r], etype_id, right=True type_per_edge[l:r], etype_id, right=True
).item() ).item()
...@@ -578,7 +507,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -578,7 +507,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
indptr=torch.tensor(subgraph_indptr[etype]), indptr=torch.tensor(subgraph_indptr[etype]),
indices=subgraph_indice[etype], indices=subgraph_indice[etype],
) )
for etype in self.metadata.edge_type_to_id.keys() for etype in self.edge_type_to_id.keys()
} }
return SampledSubgraphImpl( return SampledSubgraphImpl(
node_pairs=node_pairs, node_pairs=node_pairs,
...@@ -644,13 +573,15 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -644,13 +573,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> import torch >>> import torch
>>> ntypes = {"n1": 0, "n2": 1} >>> ntypes = {"n1": 0, "n2": 1}
>>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge, >>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset, metadata=metadata) ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])} >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts) >>> subgraph = graph.sample_neighbors(nodes, fanouts)
...@@ -673,8 +604,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -673,8 +604,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor." assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1 expected_fanout_len = 1
if self.metadata and self.metadata.edge_type_to_id: if self.edge_type_to_id:
expected_fanout_len = len(self.metadata.edge_type_to_id) expected_fanout_len = len(self.edge_type_to_id)
assert len(fanouts) in [ assert len(fanouts) in [
expected_fanout_len, expected_fanout_len,
1, 1,
...@@ -830,13 +761,15 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -830,13 +761,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> import torch >>> import torch
>>> ntypes = {"n1": 0, "n2": 1} >>> ntypes = {"n1": 0, "n2": 1}
>>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge, >>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset, metadata=metadata) ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])} >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_layer_neighbors(nodes, fanouts) >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
...@@ -907,7 +840,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -907,7 +840,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "The 'node_type_offset' array is necessary for performing \ ), "The 'node_type_offset' array is necessary for performing \
negative sampling by edge type." negative sampling by edge type."
_, _, dst_node_type = etype_str_to_tuple(edge_type) _, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.metadata.node_type_to_id[dst_node_type] dst_node_type_id = self.node_type_to_id[dst_node_type]
max_node_id = ( max_node_id = (
self.node_type_offset[dst_node_type_id + 1] self.node_type_offset[dst_node_type_id + 1]
- self.node_type_offset[dst_node_type_id] - self.node_type_offset[dst_node_type_id]
...@@ -965,8 +898,9 @@ def from_fused_csc( ...@@ -965,8 +898,9 @@ def from_fused_csc(
indices: torch.Tensor, indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None, node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None, type_per_edge: Optional[torch.tensor] = None,
node_type_to_id: Optional[Dict[str, int]] = None,
edge_type_to_id: Optional[Dict[str, int]] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None, edge_attributes: Optional[Dict[str, torch.tensor]] = None,
metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Create a FusedCSCSamplingGraph object from a CSC representation. """Create a FusedCSCSamplingGraph object from a CSC representation.
...@@ -982,10 +916,12 @@ def from_fused_csc( ...@@ -982,10 +916,12 @@ def from_fused_csc(
Offset of node types in the graph, by default None. Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional type_per_edge : Optional[torch.tensor], optional
Type ids of each edge in the graph, by default None. Type ids of each edge in the graph, by default None.
node_type_to_id : Optional[Dict[str, int]], optional
Map node types to ids, by default None.
edge_type_to_id : Optional[Dict[str, int]], optional
Map edge types to ids, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None. Edge attributes of the graph, by default None.
metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None.
Returns Returns
------- -------
...@@ -995,8 +931,7 @@ def from_fused_csc( ...@@ -995,8 +931,7 @@ def from_fused_csc(
Examples Examples
-------- --------
>>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2} >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
>>> etypes = {('n1', 'e1', 'n2'): 0, ('n1', 'e2', 'n3'): 1} >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
>>> metadata = graphbolt.GraphMetadata(ntypes, etypes)
>>> csc_indptr = torch.tensor([0, 2, 5, 7]) >>> csc_indptr = torch.tensor([0, 2, 5, 7])
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3]) >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3]) >>> node_type_offset = torch.tensor([0, 1, 2, 3])
...@@ -1004,18 +939,50 @@ def from_fused_csc( ...@@ -1004,18 +939,50 @@ def from_fused_csc(
>>> graph = graphbolt.from_fused_csc(csc_indptr, indices, >>> graph = graphbolt.from_fused_csc(csc_indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... edge_attributes=None, metadata=metadata) ... node_type_to_id=ntypes, edge_type_to_id=etypes,
... edge_attributes=None,)
>>> print(graph) >>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]), FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]), indices=tensor([1, 3, 0, 1, 2, 0, 3]),
total_num_nodes=3, total_num_edges=7) total_num_nodes=3, total_num_edges=7)
""" """
if metadata and metadata.node_type_to_id and node_type_offset is not None: if node_type_to_id is not None and edge_type_to_id is not None:
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size( node_types = list(node_type_to_id.keys())
0 edge_types = list(edge_type_to_id.keys())
), "node_type_offset length should be |ntypes| + 1." node_type_ids = list(node_type_to_id.values())
node_type_to_id = metadata.node_type_to_id if metadata else None edge_type_ids = list(edge_type_to_id.values())
edge_type_to_id = metadata.edge_type_to_id if metadata else None
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
if node_type_offset is not None:
assert len(node_type_to_id) + 1 == node_type_offset.size(
0
), "node_type_offset length should be |ntypes| + 1."
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc( torch.ops.graphbolt.from_fused_csc(
csc_indptr, csc_indptr,
...@@ -1091,7 +1058,8 @@ def from_dglgraph( ...@@ -1091,7 +1058,8 @@ def from_dglgraph(
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True) homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
if is_homogeneous: if is_homogeneous:
metadata = None node_type_to_id = None
edge_type_to_id = None
else: else:
# Initialize metadata. # Initialize metadata.
node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes} node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
...@@ -1099,7 +1067,6 @@ def from_dglgraph( ...@@ -1099,7 +1067,6 @@ def from_dglgraph(
etype_tuple_to_str(etype): g.get_etype_id(etype) etype_tuple_to_str(etype): g.get_etype_id(etype)
for etype in g.canonical_etypes for etype in g.canonical_etypes
} }
metadata = GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC matrix. # Obtain CSC matrix.
indptr, indices, edge_ids = homo_g.adj_tensors("csc") indptr, indices, edge_ids = homo_g.adj_tensors("csc")
...@@ -1118,8 +1085,6 @@ def from_dglgraph( ...@@ -1118,8 +1085,6 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping. # Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids] edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc( torch.ops.graphbolt.from_fused_csc(
indptr, indptr,
......
...@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler):
num_layers = len(self.fanouts) num_layers = len(self.fanouts)
# Enrich seeds with all node types. # Enrich seeds with all node types.
if isinstance(seeds, dict): if isinstance(seeds, dict):
ntypes = list(self.graph.metadata.node_type_to_id.keys()) ntypes = list(self.graph.node_type_to_id.keys())
seeds = { seeds = {
ntype: seeds.get(ntype, torch.LongTensor([])) ntype: seeds.get(ntype, torch.LongTensor([]))
for ntype in ntypes for ntype in ntypes
......
...@@ -705,9 +705,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo( ...@@ -705,9 +705,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None assert new_g.node_type_offset is None
assert all(new_g.type_per_edge == 0) assert all(new_g.type_per_edge == 0)
for node_type, type_id in new_g.metadata.node_type_to_id.items(): for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items(): for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
...@@ -736,9 +736,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero( ...@@ -736,9 +736,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_indptr, orig_indices, _ = orig_g.adj().csc() orig_indptr, orig_indices, _ = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
for node_type, type_id in new_g.metadata.node_type_to_id.items(): for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items(): for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge) assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)
......
...@@ -32,7 +32,7 @@ def random_homo_graph(num_nodes, num_edges): ...@@ -32,7 +32,7 @@ def random_homo_graph(num_nodes, num_edges):
return csc_indptr, indices return csc_indptr, indices
def get_metadata(num_ntypes, num_etypes): def get_type_to_id(num_ntypes, num_etypes):
ntypes = {f"n{i}": i for i in range(num_ntypes)} ntypes = {f"n{i}": i for i in range(num_ntypes)}
etypes = {} etypes = {}
count = 0 count = 0
...@@ -42,7 +42,7 @@ def get_metadata(num_ntypes, num_etypes): ...@@ -42,7 +42,7 @@ def get_metadata(num_ntypes, num_etypes):
break break
etypes.update({f"n{n1}:e{count}:n{n2}": count}) etypes.update({f"n{n1}:e{count}:n{n2}": count})
count += 1 count += 1
return gb.GraphMetadata(ntypes, etypes) return ntypes, etypes
def get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes): def get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes):
...@@ -82,7 +82,8 @@ def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes): ...@@ -82,7 +82,8 @@ def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
gb_g.indices, gb_g.indices,
gb_g.node_type_offset, gb_g.node_type_offset,
gb_g.type_per_edge, gb_g.type_per_edge,
gb_g.metadata, gb_g.node_type_to_id,
gb_g.edge_type_to_id,
) )
......
...@@ -76,7 +76,6 @@ def test_InSubgraphSampler_hetero(): ...@@ -76,7 +76,6 @@ def test_InSubgraphSampler_hetero():
"N1:R2:N0": 2, "N1:R2:N0": 2,
"N1:R3:N1": 3, "N1:R3:N1": 3,
} }
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
node_type_offset = torch.LongTensor([0, 3, 6]) node_type_offset = torch.LongTensor([0, 3, 6])
...@@ -86,7 +85,8 @@ def test_InSubgraphSampler_hetero(): ...@@ -86,7 +85,8 @@ def test_InSubgraphSampler_hetero():
indices=indices, indices=indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
......
...@@ -106,7 +106,6 @@ def get_hetero_graph(): ...@@ -106,7 +106,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
...@@ -116,7 +115,8 @@ def get_hetero_graph(): ...@@ -116,7 +115,8 @@ def get_hetero_graph():
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
......
...@@ -1026,11 +1026,12 @@ def test_OnDiskDataset_Graph_homogeneous(): ...@@ -1026,11 +1026,12 @@ def test_OnDiskDataset_Graph_homogeneous():
assert torch.equal(graph.csc_indptr, graph2.csc_indptr) assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices) assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert ( assert (
graph.node_type_offset is None and graph2.node_type_offset is None graph.node_type_offset is None and graph2.node_type_offset is None
) )
assert graph.type_per_edge is None and graph2.type_per_edge is None assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
def test_OnDiskDataset_Graph_heterogeneous(): def test_OnDiskDataset_Graph_heterogeneous():
...@@ -1040,10 +1041,16 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -1040,10 +1041,16 @@ def test_OnDiskDataset_Graph_heterogeneous():
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4) ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
...@@ -1065,8 +1072,8 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -1065,8 +1072,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
assert torch.equal(graph.indices, graph2.indices) assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset) assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge) assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id assert graph.node_type_to_id == graph2.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id assert graph.edge_type_to_id == graph2.edge_type_to_id
def test_OnDiskDataset_Metadata(): def test_OnDiskDataset_Metadata():
...@@ -1854,15 +1861,17 @@ def test_OnDiskDataset_all_nodes_set_hetero(): ...@@ -1854,15 +1861,17 @@ def test_OnDiskDataset_all_nodes_set_hetero():
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4) ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=None, edge_attributes=None,
metadata=metadata,
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
......
...@@ -141,7 +141,6 @@ def get_hetero_graph(): ...@@ -141,7 +141,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
...@@ -151,7 +150,8 @@ def get_hetero_graph(): ...@@ -151,7 +150,8 @@ def get_hetero_graph():
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
......
...@@ -117,7 +117,6 @@ def get_hetero_graph(): ...@@ -117,7 +117,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
...@@ -127,7 +126,8 @@ def get_hetero_graph(): ...@@ -127,7 +126,8 @@ def get_hetero_graph():
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
...@@ -207,7 +207,8 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor): ...@@ -207,7 +207,8 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gb_test_utils.random_hetero_graph( ) = gb_test_utils.random_hetero_graph(
num_nodes, num_edges, num_ntypes, num_etypes num_nodes, num_edges, num_ntypes, num_etypes
) )
...@@ -218,10 +219,11 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor): ...@@ -218,10 +219,11 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset=node_type_offset,
type_per_edge, type_per_edge=type_per_edge,
edge_attributes, node_type_to_id=node_type_to_id,
metadata, edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
) )
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment