Unverified Commit 1a8e0c00 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] remove GraphMeta and update API accordingly (#6709)

parent 50cbb0ec
...@@ -1255,14 +1255,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1255,14 +1255,12 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
graph, _, _, gpb, _, _, _ = load_partition( graph, _, _, gpb, _, _, _ = load_partition(
part_config, part_id, load_feats=False part_config, part_id, load_feats=False
) )
# Construct GraphMetadata.
_, _, ntypes, etypes = load_partition_book(part_config, part_id) _, _, ntypes, etypes = load_partition_book(part_config, part_id)
node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)} node_type_to_id = {ntype: ntid for ntid, ntype in enumerate(ntypes)}
edge_type_to_id = { edge_type_to_id = {
_etype_tuple_to_str(etype): etid _etype_tuple_to_str(etype): etid
for etid, etype in enumerate(etypes) for etid, etype in enumerate(etypes)
} }
metadata = graphbolt.GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC indtpr and indices. # Obtain CSC indtpr and indices.
indptr, indices, _ = graph.adj().csc() indptr, indices, _ = graph.adj().csc()
# Initalize type per edge. # Initalize type per edge.
...@@ -1275,7 +1273,8 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config): ...@@ -1275,7 +1273,8 @@ def convert_dgl_partition_to_csc_sampling_graph(part_config):
indices, indices,
node_type_offset=None, node_type_offset=None,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
) )
orig_graph_path = os.path.join( orig_graph_path = os.path.join(
os.path.dirname(part_config), os.path.dirname(part_config),
......
...@@ -20,7 +20,6 @@ from .sampled_subgraph_impl import ( ...@@ -20,7 +20,6 @@ from .sampled_subgraph_impl import (
__all__ = [ __all__ = [
"GraphMetadata",
"FusedCSCSamplingGraph", "FusedCSCSamplingGraph",
"from_fused_csc", "from_fused_csc",
"load_from_shared_memory", "load_from_shared_memory",
...@@ -28,65 +27,6 @@ __all__ = [ ...@@ -28,65 +27,6 @@ __all__ = [
] ]
class GraphMetadata:
r"""Class for metadata of csc sampling graph."""
def __init__(
self,
node_type_to_id: Dict[str, int],
edge_type_to_id: Dict[str, int],
):
"""Initialize the GraphMetadata object.
Parameters
----------
node_type_to_id : Dict[str, int]
Dictionary from node types to node type IDs.
edge_type_to_id : Dict[str, int]
Dictionary from edge types to edge type IDs.
Raises
------
AssertionError
If any of the assertions fail.
"""
node_types = list(node_type_to_id.keys())
edge_types = list(edge_type_to_id.keys())
node_type_ids = list(node_type_to_id.values())
edge_type_ids = list(edge_type_to_id.values())
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
self.node_type_to_id = node_type_to_id
self.edge_type_to_id = edge_type_to_id
class FusedCSCSamplingGraph(SamplingGraph): class FusedCSCSamplingGraph(SamplingGraph):
r"""A sampling graph in CSC format.""" r"""A sampling graph in CSC format."""
...@@ -148,9 +88,11 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -148,9 +88,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor( >>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3]) ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> metadata = gb.GraphMetadata(ntypes, etypes) >>> graph = gb.from_fused_csc(indptr, indices,
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge, None, metadata) ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> print(graph.num_nodes) >>> print(graph.num_nodes)
{'N0': 2, 'N1': 3} {'N0': 2, 'N1': 3}
""" """
...@@ -158,14 +100,14 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -158,14 +100,14 @@ class FusedCSCSamplingGraph(SamplingGraph):
offset = self.node_type_offset offset = self.node_type_offset
# Homogenous. # Homogenous.
if offset is None or self.metadata is None: if offset is None or self.node_type_to_id is None:
return self._c_csc_graph.num_nodes() return self._c_csc_graph.num_nodes()
# Heterogenous # Heterogenous
else: else:
num_nodes_per_type = { num_nodes_per_type = {
_type: (offset[_idx + 1] - offset[_idx]).item() _type: (offset[_idx + 1] - offset[_idx]).item()
for _type, _idx in self.metadata.node_type_to_id.items() for _type, _idx in self.node_type_to_id.items()
} }
return num_nodes_per_type return num_nodes_per_type
...@@ -358,21 +300,6 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -358,21 +300,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
"""Sets the edge attributes dictionary.""" """Sets the edge attributes dictionary."""
self._c_csc_graph.set_edge_attributes(edge_attributes) self._c_csc_graph.set_edge_attributes(edge_attributes)
@property
def metadata(self) -> Optional[GraphMetadata]:
"""Returns the metadata of the graph.
[TODO][Rui] This API needs to be updated.
Returns
-------
GraphMetadata or None
If present, returns the metadata of the graph.
"""
if self.node_type_to_id is None or self.edge_type_to_id is None:
return None
return GraphMetadata(self.node_type_to_id, self.edge_type_to_id)
def in_subgraph( def in_subgraph(
self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] self, nodes: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> FusedSampledSubgraphImpl: ) -> FusedSampledSubgraphImpl:
...@@ -405,14 +332,16 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -405,14 +332,16 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> ntypes = {"N0": 0, "N1": 1} >>> ntypes = {"N0": 0, "N1": 1}
>>> etypes = { >>> etypes = {
... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3} ... "N0:R0:N0": 0, "N0:R1:N1": 1, "N1:R2:N0": 2, "N1:R3:N1": 3}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
>>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor( >>> type_per_edge = torch.LongTensor(
... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3]) ... [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
>>> graph = gb.from_fused_csc(indptr, indices, node_type_offset, >>> graph = gb.from_fused_csc(indptr, indices,
... type_per_edge, None, metadata) ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])} >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes) >>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs) >>> print(in_subgraph.node_pairs)
...@@ -465,10 +394,10 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -465,10 +394,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
# converted to heterogeneous graphs. # converted to heterogeneous graphs.
node_pairs = defaultdict(list) node_pairs = defaultdict(list)
original_hetero_edge_ids = {} original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items(): for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype) src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.node_type_to_id[src_ntype]
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype] dst_ntype_id = self.node_type_to_id[dst_ntype]
mask = type_per_edge == etype_id mask = type_per_edge == etype_id
hetero_row = row[mask] - self.node_type_offset[src_ntype_id] hetero_row = row[mask] - self.node_type_offset[src_ntype_id]
hetero_column = ( hetero_column = (
...@@ -486,7 +415,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -486,7 +415,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
def _convert_to_homogeneous_nodes(self, nodes): def _convert_to_homogeneous_nodes(self, nodes):
homogeneous_nodes = [] homogeneous_nodes = []
for ntype, ids in nodes.items(): for ntype, ids in nodes.items():
ntype_id = self.metadata.node_type_to_id[ntype] ntype_id = self.node_type_to_id[ntype]
homogeneous_nodes.append(ids + self.node_type_offset[ntype_id]) homogeneous_nodes.append(ids + self.node_type_offset[ntype_id])
return torch.cat(homogeneous_nodes) return torch.cat(homogeneous_nodes)
...@@ -526,7 +455,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -526,7 +455,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
subgraph_indptr = {} subgraph_indptr = {}
node_edge_type = defaultdict(list) node_edge_type = defaultdict(list)
original_hetero_edge_ids = {} original_hetero_edge_ids = {}
for etype, etype_id in self.metadata.edge_type_to_id.items(): for etype, etype_id in self.edge_type_to_id.items():
subgraph_indice[etype] = torch.empty( subgraph_indice[etype] = torch.empty(
(num.get(etype_id, 0),), dtype=indices.dtype (num.get(etype_id, 0),), dtype=indices.dtype
) )
...@@ -539,7 +468,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -539,7 +468,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
# Preprocessing saves the type of seed_nodes as the edge type # Preprocessing saves the type of seed_nodes as the edge type
# of dst_ntype. # of dst_ntype.
_, _, dst_ntype = etype_str_to_tuple(etype) _, _, dst_ntype = etype_str_to_tuple(etype)
dst_ntype_id = self.metadata.node_type_to_id[dst_ntype] dst_ntype_id = self.node_type_to_id[dst_ntype]
node_edge_type[dst_ntype_id].append((etype, etype_id)) node_edge_type[dst_ntype_id].append((etype, etype_id))
# construct subgraphs # construct subgraphs
for i, seed in enumerate(column): for i, seed in enumerate(column):
...@@ -553,7 +482,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -553,7 +482,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
) )
for etype, etype_id in node_edge_type[node_type]: for etype, etype_id in node_edge_type[node_type]:
src_ntype, _, _ = etype_str_to_tuple(etype) src_ntype, _, _ = etype_str_to_tuple(etype)
src_ntype_id = self.metadata.node_type_to_id[src_ntype] src_ntype_id = self.node_type_to_id[src_ntype]
num_edges = torch.searchsorted( num_edges = torch.searchsorted(
type_per_edge[l:r], etype_id, right=True type_per_edge[l:r], etype_id, right=True
).item() ).item()
...@@ -578,7 +507,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -578,7 +507,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
indptr=torch.tensor(subgraph_indptr[etype]), indptr=torch.tensor(subgraph_indptr[etype]),
indices=subgraph_indice[etype], indices=subgraph_indice[etype],
) )
for etype in self.metadata.edge_type_to_id.keys() for etype in self.edge_type_to_id.keys()
} }
return SampledSubgraphImpl( return SampledSubgraphImpl(
node_pairs=node_pairs, node_pairs=node_pairs,
...@@ -644,13 +573,15 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -644,13 +573,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> import torch >>> import torch
>>> ntypes = {"n1": 0, "n2": 1} >>> ntypes = {"n1": 0, "n2": 1}
>>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge, >>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset, metadata=metadata) ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])} >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts) >>> subgraph = graph.sample_neighbors(nodes, fanouts)
...@@ -673,8 +604,8 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -673,8 +604,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert nodes.dim() == 1, "Nodes should be 1-D tensor." assert nodes.dim() == 1, "Nodes should be 1-D tensor."
assert fanouts.dim() == 1, "Fanouts should be 1-D tensor." assert fanouts.dim() == 1, "Fanouts should be 1-D tensor."
expected_fanout_len = 1 expected_fanout_len = 1
if self.metadata and self.metadata.edge_type_to_id: if self.edge_type_to_id:
expected_fanout_len = len(self.metadata.edge_type_to_id) expected_fanout_len = len(self.edge_type_to_id)
assert len(fanouts) in [ assert len(fanouts) in [
expected_fanout_len, expected_fanout_len,
1, 1,
...@@ -830,13 +761,15 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -830,13 +761,15 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> import torch >>> import torch
>>> ntypes = {"n1": 0, "n2": 1} >>> ntypes = {"n1": 0, "n2": 1}
>>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} >>> etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
>>> metadata = gb.GraphMetadata(ntypes, etypes)
>>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) >>> indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
>>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) >>> indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
>>> node_type_offset = torch.LongTensor([0, 2, 5]) >>> node_type_offset = torch.LongTensor([0, 2, 5])
>>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0]) >>> type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
>>> graph = gb.from_fused_csc(indptr, indices, type_per_edge=type_per_edge, >>> graph = gb.from_fused_csc(indptr, indices,
... node_type_offset=node_type_offset, metadata=metadata) ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge,
... node_type_to_id=ntypes,
... edge_type_to_id=etypes)
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])} >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_layer_neighbors(nodes, fanouts) >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
...@@ -907,7 +840,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -907,7 +840,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
), "The 'node_type_offset' array is necessary for performing \ ), "The 'node_type_offset' array is necessary for performing \
negative sampling by edge type." negative sampling by edge type."
_, _, dst_node_type = etype_str_to_tuple(edge_type) _, _, dst_node_type = etype_str_to_tuple(edge_type)
dst_node_type_id = self.metadata.node_type_to_id[dst_node_type] dst_node_type_id = self.node_type_to_id[dst_node_type]
max_node_id = ( max_node_id = (
self.node_type_offset[dst_node_type_id + 1] self.node_type_offset[dst_node_type_id + 1]
- self.node_type_offset[dst_node_type_id] - self.node_type_offset[dst_node_type_id]
...@@ -965,8 +898,9 @@ def from_fused_csc( ...@@ -965,8 +898,9 @@ def from_fused_csc(
indices: torch.Tensor, indices: torch.Tensor,
node_type_offset: Optional[torch.tensor] = None, node_type_offset: Optional[torch.tensor] = None,
type_per_edge: Optional[torch.tensor] = None, type_per_edge: Optional[torch.tensor] = None,
node_type_to_id: Optional[Dict[str, int]] = None,
edge_type_to_id: Optional[Dict[str, int]] = None,
edge_attributes: Optional[Dict[str, torch.tensor]] = None, edge_attributes: Optional[Dict[str, torch.tensor]] = None,
metadata: Optional[GraphMetadata] = None,
) -> FusedCSCSamplingGraph: ) -> FusedCSCSamplingGraph:
"""Create a FusedCSCSamplingGraph object from a CSC representation. """Create a FusedCSCSamplingGraph object from a CSC representation.
...@@ -982,10 +916,12 @@ def from_fused_csc( ...@@ -982,10 +916,12 @@ def from_fused_csc(
Offset of node types in the graph, by default None. Offset of node types in the graph, by default None.
type_per_edge : Optional[torch.tensor], optional type_per_edge : Optional[torch.tensor], optional
Type ids of each edge in the graph, by default None. Type ids of each edge in the graph, by default None.
node_type_to_id : Optional[Dict[str, int]], optional
Map node types to ids, by default None.
edge_type_to_id : Optional[Dict[str, int]], optional
Map edge types to ids, by default None.
edge_attributes: Optional[Dict[str, torch.tensor]], optional edge_attributes: Optional[Dict[str, torch.tensor]], optional
Edge attributes of the graph, by default None. Edge attributes of the graph, by default None.
metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None.
Returns Returns
------- -------
...@@ -995,8 +931,7 @@ def from_fused_csc( ...@@ -995,8 +931,7 @@ def from_fused_csc(
Examples Examples
-------- --------
>>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2} >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
>>> etypes = {('n1', 'e1', 'n2'): 0, ('n1', 'e2', 'n3'): 1} >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
>>> metadata = graphbolt.GraphMetadata(ntypes, etypes)
>>> csc_indptr = torch.tensor([0, 2, 5, 7]) >>> csc_indptr = torch.tensor([0, 2, 5, 7])
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3]) >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3])
>>> node_type_offset = torch.tensor([0, 1, 2, 3]) >>> node_type_offset = torch.tensor([0, 1, 2, 3])
...@@ -1004,18 +939,50 @@ def from_fused_csc( ...@@ -1004,18 +939,50 @@ def from_fused_csc(
>>> graph = graphbolt.from_fused_csc(csc_indptr, indices, >>> graph = graphbolt.from_fused_csc(csc_indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... edge_attributes=None, metadata=metadata) ... node_type_to_id=ntypes, edge_type_to_id=etypes,
... edge_attributes=None,)
>>> print(graph) >>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]), FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]), indices=tensor([1, 3, 0, 1, 2, 0, 3]),
total_num_nodes=3, total_num_edges=7) total_num_nodes=3, total_num_edges=7)
""" """
if metadata and metadata.node_type_to_id and node_type_offset is not None: if node_type_to_id is not None and edge_type_to_id is not None:
assert len(metadata.node_type_to_id) + 1 == node_type_offset.size( node_types = list(node_type_to_id.keys())
edge_types = list(edge_type_to_id.keys())
node_type_ids = list(node_type_to_id.values())
edge_type_ids = list(edge_type_to_id.values())
# Validate node_type_to_id.
assert all(
isinstance(x, str) for x in node_types
), "Node type name should be string."
assert all(
isinstance(x, int) for x in node_type_ids
), "Node type id should be int."
assert len(node_type_ids) == len(
set(node_type_ids)
), "Multiple node types shoud not be mapped to a same id."
# Validate edge_type_to_id.
for edge_type in edge_types:
src, edge, dst = etype_str_to_tuple(edge_type)
assert isinstance(edge, str), "Edge type name should be string."
assert (
src in node_types
), f"Unrecognized node type {src} in edge type {edge_type}"
assert (
dst in node_types
), f"Unrecognized node type {dst} in edge type {edge_type}"
assert all(
isinstance(x, int) for x in edge_type_ids
), "Edge type id should be int."
assert len(edge_type_ids) == len(
set(edge_type_ids)
), "Multiple edge types shoud not be mapped to a same id."
if node_type_offset is not None:
assert len(node_type_to_id) + 1 == node_type_offset.size(
0 0
), "node_type_offset length should be |ntypes| + 1." ), "node_type_offset length should be |ntypes| + 1."
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc( torch.ops.graphbolt.from_fused_csc(
csc_indptr, csc_indptr,
...@@ -1091,7 +1058,8 @@ def from_dglgraph( ...@@ -1091,7 +1058,8 @@ def from_dglgraph(
homo_g, ntype_count, _ = to_homogeneous(g, return_count=True) homo_g, ntype_count, _ = to_homogeneous(g, return_count=True)
if is_homogeneous: if is_homogeneous:
metadata = None node_type_to_id = None
edge_type_to_id = None
else: else:
# Initialize metadata. # Initialize metadata.
node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes} node_type_to_id = {ntype: g.get_ntype_id(ntype) for ntype in g.ntypes}
...@@ -1099,7 +1067,6 @@ def from_dglgraph( ...@@ -1099,7 +1067,6 @@ def from_dglgraph(
etype_tuple_to_str(etype): g.get_etype_id(etype) etype_tuple_to_str(etype): g.get_etype_id(etype)
for etype in g.canonical_etypes for etype in g.canonical_etypes
} }
metadata = GraphMetadata(node_type_to_id, edge_type_to_id)
# Obtain CSC matrix. # Obtain CSC matrix.
indptr, indices, edge_ids = homo_g.adj_tensors("csc") indptr, indices, edge_ids = homo_g.adj_tensors("csc")
...@@ -1118,8 +1085,6 @@ def from_dglgraph( ...@@ -1118,8 +1085,6 @@ def from_dglgraph(
# Assign edge attributes according to the original eids mapping. # Assign edge attributes according to the original eids mapping.
edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids] edge_attributes[ORIGINAL_EDGE_ID] = homo_g.edata[EID][edge_ids]
node_type_to_id = metadata.node_type_to_id if metadata else None
edge_type_to_id = metadata.edge_type_to_id if metadata else None
return FusedCSCSamplingGraph( return FusedCSCSamplingGraph(
torch.ops.graphbolt.from_fused_csc( torch.ops.graphbolt.from_fused_csc(
indptr, indptr,
......
...@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -112,7 +112,7 @@ class NeighborSampler(SubgraphSampler):
num_layers = len(self.fanouts) num_layers = len(self.fanouts)
# Enrich seeds with all node types. # Enrich seeds with all node types.
if isinstance(seeds, dict): if isinstance(seeds, dict):
ntypes = list(self.graph.metadata.node_type_to_id.keys()) ntypes = list(self.graph.node_type_to_id.keys())
seeds = { seeds = {
ntype: seeds.get(ntype, torch.LongTensor([])) ntype: seeds.get(ntype, torch.LongTensor([]))
for ntype in ntypes for ntype in ntypes
......
...@@ -705,9 +705,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo( ...@@ -705,9 +705,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_homo(
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
assert new_g.node_type_offset is None assert new_g.node_type_offset is None
assert all(new_g.type_per_edge == 0) assert all(new_g.type_per_edge == 0)
for node_type, type_id in new_g.metadata.node_type_to_id.items(): for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items(): for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
...@@ -736,9 +736,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero( ...@@ -736,9 +736,9 @@ def test_convert_dgl_partition_to_csc_sampling_graph_hetero(
orig_indptr, orig_indices, _ = orig_g.adj().csc() orig_indptr, orig_indices, _ = orig_g.adj().csc()
assert th.equal(orig_indptr, new_g.csc_indptr) assert th.equal(orig_indptr, new_g.csc_indptr)
assert th.equal(orig_indices, new_g.indices) assert th.equal(orig_indices, new_g.indices)
for node_type, type_id in new_g.metadata.node_type_to_id.items(): for node_type, type_id in new_g.node_type_to_id.items():
assert g.get_ntype_id(node_type) == type_id assert g.get_ntype_id(node_type) == type_id
for edge_type, type_id in new_g.metadata.edge_type_to_id.items(): for edge_type, type_id in new_g.edge_type_to_id.items():
assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id assert g.get_etype_id(_etype_str_to_tuple(edge_type)) == type_id
assert new_g.node_type_offset is None assert new_g.node_type_offset is None
assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge) assert th.equal(orig_g.edata[dgl.ETYPE], new_g.type_per_edge)
......
...@@ -32,7 +32,7 @@ def random_homo_graph(num_nodes, num_edges): ...@@ -32,7 +32,7 @@ def random_homo_graph(num_nodes, num_edges):
return csc_indptr, indices return csc_indptr, indices
def get_metadata(num_ntypes, num_etypes): def get_type_to_id(num_ntypes, num_etypes):
ntypes = {f"n{i}": i for i in range(num_ntypes)} ntypes = {f"n{i}": i for i in range(num_ntypes)}
etypes = {} etypes = {}
count = 0 count = 0
...@@ -42,7 +42,7 @@ def get_metadata(num_ntypes, num_etypes): ...@@ -42,7 +42,7 @@ def get_metadata(num_ntypes, num_etypes):
break break
etypes.update({f"n{n1}:e{count}:n{n2}": count}) etypes.update({f"n{n1}:e{count}:n{n2}": count})
count += 1 count += 1
return gb.GraphMetadata(ntypes, etypes) return ntypes, etypes
def get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes): def get_ntypes_and_etypes(num_nodes, num_ntypes, num_etypes):
...@@ -82,7 +82,8 @@ def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes): ...@@ -82,7 +82,8 @@ def random_hetero_graph(num_nodes, num_edges, num_ntypes, num_etypes):
gb_g.indices, gb_g.indices,
gb_g.node_type_offset, gb_g.node_type_offset,
gb_g.type_per_edge, gb_g.type_per_edge,
gb_g.metadata, gb_g.node_type_to_id,
gb_g.edge_type_to_id,
) )
......
...@@ -42,7 +42,9 @@ def test_empty_graph(total_num_nodes): ...@@ -42,7 +42,9 @@ def test_empty_graph(total_num_nodes):
def test_hetero_empty_graph(total_num_nodes): def test_hetero_empty_graph(total_num_nodes):
csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int) csc_indptr = torch.zeros((total_num_nodes + 1,), dtype=int)
indices = torch.tensor([]) indices = torch.tensor([])
metadata = gbt.get_metadata(num_ntypes=3, num_etypes=5) node_type_to_id, edge_type_to_id = gbt.get_type_to_id(
num_ntypes=3, num_etypes=5
)
# Some node types have no nodes. # Some node types have no nodes.
if total_num_nodes == 0: if total_num_nodes == 0:
node_type_offset = torch.zeros((4,), dtype=int) node_type_offset = torch.zeros((4,), dtype=int)
...@@ -56,17 +58,18 @@ def test_hetero_empty_graph(total_num_nodes): ...@@ -56,17 +58,18 @@ def test_hetero_empty_graph(total_num_nodes):
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset=node_type_offset,
type_per_edge, type_per_edge=type_per_edge,
None, node_type_to_id=node_type_to_id,
metadata, edge_type_to_id=edge_type_to_id,
edge_attributes=None,
) )
assert graph.total_num_edges == 0 assert graph.total_num_edges == 0
assert graph.total_num_nodes == total_num_nodes assert graph.total_num_nodes == total_num_nodes
assert torch.equal(graph.csc_indptr, csc_indptr) assert torch.equal(graph.csc_indptr, csc_indptr)
assert torch.equal(graph.indices, indices) assert torch.equal(graph.indices, indices)
assert graph.metadata.node_type_to_id == metadata.node_type_to_id assert graph.node_type_to_id == node_type_to_id
assert graph.metadata.edge_type_to_id == metadata.edge_type_to_id assert graph.edge_type_to_id == edge_type_to_id
assert torch.equal(graph.node_type_offset, node_type_offset) assert torch.equal(graph.node_type_offset, node_type_offset)
assert torch.equal(graph.type_per_edge, type_per_edge) assert torch.equal(graph.type_per_edge, type_per_edge)
...@@ -78,9 +81,11 @@ def test_hetero_empty_graph(total_num_nodes): ...@@ -78,9 +81,11 @@ def test_hetero_empty_graph(total_num_nodes):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"ntypes", [{"n1": 1, "n2": 1}, {5: 1, "n2": 2}, {"n1": 1.5, "n2": 2.0}] "ntypes", [{"n1": 1, "n2": 1}, {5: 1, "n2": 2}, {"n1": 1.5, "n2": 2.0}]
) )
def test_metadata_with_ntype_exception(ntypes): def test_type_to_id_with_ntype_exception(ntypes):
with pytest.raises(Exception): with pytest.raises(AssertionError):
gb.GraphMetadata(ntypes, {"n1:e1:n2": 1}) gb.from_fused_csc(
None, None, node_type_to_id=ntypes, edge_type_to_id={"e1": 1}
)
@unittest.skipIf( @unittest.skipIf(
...@@ -99,9 +104,14 @@ def test_metadata_with_ntype_exception(ntypes): ...@@ -99,9 +104,14 @@ def test_metadata_with_ntype_exception(ntypes):
{"n1:e1:n2": 1.5}, {"n1:e1:n2": 1.5},
], ],
) )
def test_metadata_with_etype_exception(etypes): def test_type_to_id_with_etype_exception(etypes):
with pytest.raises(Exception): with pytest.raises(Exception):
gb.GraphMetadata({"n1": 0, "n2": 1, "n3": 2}, etypes) gb.from_fused_csc(
None,
None,
node_type_to_id={"n1": 0, "n2": 1, "n3": 2},
edge_type_to_id=etypes,
)
@unittest.skipIf( @unittest.skipIf(
...@@ -131,9 +141,10 @@ def test_homo_graph(total_num_nodes, total_num_edges): ...@@ -131,9 +141,10 @@ def test_homo_graph(total_num_nodes, total_num_edges):
assert torch.equal(indices, graph.indices) assert torch.equal(indices, graph.indices)
assert graph.edge_attributes == edge_attributes assert graph.edge_attributes == edge_attributes
assert graph.metadata is None
assert graph.node_type_offset is None assert graph.node_type_offset is None
assert graph.type_per_edge is None assert graph.type_per_edge is None
assert graph.node_type_to_id is None
assert graph.edge_type_to_id is None
@unittest.skipIf( @unittest.skipIf(
...@@ -151,7 +162,8 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes): ...@@ -151,7 +162,8 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
...@@ -162,10 +174,11 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes): ...@@ -162,10 +174,11 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset=node_type_offset,
type_per_edge, type_per_edge=type_per_edge,
edge_attributes, node_type_to_id=node_type_to_id,
metadata, edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
) )
assert graph.total_num_nodes == total_num_nodes assert graph.total_num_nodes == total_num_nodes
...@@ -176,8 +189,8 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes): ...@@ -176,8 +189,8 @@ def test_hetero_graph(total_num_nodes, total_num_edges, num_ntypes, num_etypes):
assert torch.equal(node_type_offset, graph.node_type_offset) assert torch.equal(node_type_offset, graph.node_type_offset)
assert torch.equal(type_per_edge, graph.type_per_edge) assert torch.equal(type_per_edge, graph.type_per_edge)
assert graph.edge_attributes == edge_attributes assert graph.edge_attributes == edge_attributes
assert metadata.node_type_to_id == graph.metadata.node_type_to_id assert node_type_to_id == graph.node_type_to_id
assert metadata.edge_type_to_id == graph.metadata.edge_type_to_id assert edge_type_to_id == graph.edge_type_to_id
@unittest.skipIf( @unittest.skipIf(
...@@ -247,9 +260,13 @@ def test_num_nodes_hetero(): ...@@ -247,9 +260,13 @@ def test_num_nodes_hetero():
assert all(type_per_edge < len(etypes)) assert all(type_per_edge < len(etypes))
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
indptr, indices, node_type_offset, type_per_edge, None, metadata indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
# Verify nodes number per node types. # Verify nodes number per node types.
...@@ -283,12 +300,22 @@ def test_num_nodes_hetero(): ...@@ -283,12 +300,22 @@ def test_num_nodes_hetero():
) )
def test_node_type_offset_wrong_legnth(node_type_offset): def test_node_type_offset_wrong_legnth(node_type_offset):
num_ntypes = 3 num_ntypes = 3
csc_indptr, indices, _, type_per_edge, metadata = gbt.random_hetero_graph( (
10, 50, num_ntypes, 5 csc_indptr,
) indices,
_,
type_per_edge,
node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(10, 50, num_ntypes, 5)
with pytest.raises(Exception): with pytest.raises(Exception):
gb.from_fused_csc( gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
) )
...@@ -317,9 +344,11 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges): ...@@ -317,9 +344,11 @@ def test_load_save_homo_graph(total_num_nodes, total_num_edges):
assert torch.equal(graph.csc_indptr, graph2.csc_indptr) assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices) assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert graph.node_type_offset is None and graph2.node_type_offset is None assert graph.node_type_offset is None and graph2.node_type_offset is None
assert graph.type_per_edge is None and graph2.type_per_edge is None assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
assert graph.edge_attributes is None and graph2.edge_attributes is None
@unittest.skipIf( @unittest.skipIf(
...@@ -339,12 +368,18 @@ def test_load_save_hetero_graph( ...@@ -339,12 +368,18 @@ def test_load_save_hetero_graph(
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
...@@ -359,8 +394,8 @@ def test_load_save_hetero_graph( ...@@ -359,8 +394,8 @@ def test_load_save_hetero_graph(
assert torch.equal(graph.indices, graph2.indices) assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset) assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge) assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id assert graph.node_type_to_id == graph2.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id assert graph.edge_type_to_id == graph2.edge_type_to_id
@unittest.skipIf( @unittest.skipIf(
...@@ -410,7 +445,8 @@ def test_pickle_hetero_graph( ...@@ -410,7 +445,8 @@ def test_pickle_hetero_graph(
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
...@@ -421,10 +457,11 @@ def test_pickle_hetero_graph( ...@@ -421,10 +457,11 @@ def test_pickle_hetero_graph(
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset=node_type_offset,
type_per_edge, type_per_edge=type_per_edge,
edge_attributes, node_type_to_id=node_type_to_id,
metadata, edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
) )
serialized = pickle.dumps(graph) serialized = pickle.dumps(graph)
...@@ -466,7 +503,8 @@ def test_multiprocessing(): ...@@ -466,7 +503,8 @@ def test_multiprocessing():
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
...@@ -476,10 +514,11 @@ def test_multiprocessing(): ...@@ -476,10 +514,11 @@ def test_multiprocessing():
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset=node_type_offset,
type_per_edge, type_per_edge=type_per_edge,
edge_attributes, node_type_to_id=node_type_to_id,
metadata, edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
) )
p = mp.Process( p = mp.Process(
...@@ -572,9 +611,13 @@ def test_in_subgraph_heterogeneous(): ...@@ -572,9 +611,13 @@ def test_in_subgraph_heterogeneous():
assert all(type_per_edge < len(etypes)) assert all(type_per_edge < len(etypes))
# Construct FusedCSCSamplingGraph. # Construct FusedCSCSamplingGraph.
metadata = gb.GraphMetadata(ntypes, etypes)
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
indptr, indices, node_type_offset, type_per_edge, None, metadata indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
# Extract in subgraph. # Extract in subgraph.
...@@ -678,8 +721,6 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -678,8 +721,6 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype) indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype) indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
...@@ -694,7 +735,8 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -694,7 +735,8 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
# Sample on both node types. # Sample on both node types.
...@@ -787,8 +829,6 @@ def test_sample_neighbors_fanouts( ...@@ -787,8 +829,6 @@ def test_sample_neighbors_fanouts(
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
...@@ -803,7 +843,8 @@ def test_sample_neighbors_fanouts( ...@@ -803,7 +843,8 @@ def test_sample_neighbors_fanouts(
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
...@@ -845,8 +886,6 @@ def test_sample_neighbors_replace( ...@@ -845,8 +886,6 @@ def test_sample_neighbors_replace(
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
...@@ -861,7 +900,8 @@ def test_sample_neighbors_replace( ...@@ -861,7 +900,8 @@ def test_sample_neighbors_replace(
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
...@@ -888,7 +928,6 @@ def test_sample_neighbors_return_eids_homo(labor): ...@@ -888,7 +928,6 @@ def test_sample_neighbors_return_eids_homo(labor):
1 0 0 0 1 1 0 0 0 1
""" """
# Initialize data. # Initialize data.
total_num_nodes = 5
total_num_edges = 12 total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
...@@ -933,8 +972,6 @@ def test_sample_neighbors_return_eids_hetero(labor): ...@@ -933,8 +972,6 @@ def test_sample_neighbors_return_eids_hetero(labor):
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
...@@ -952,8 +989,9 @@ def test_sample_neighbors_return_eids_hetero(labor): ...@@ -952,8 +989,9 @@ def test_sample_neighbors_return_eids_hetero(labor):
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=ntypes,
edge_type_to_id=etypes,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
metadata=metadata,
) )
# Sample on both node types. # Sample on both node types.
...@@ -991,7 +1029,6 @@ def test_sample_neighbors_probs(replace, labor, probs_name): ...@@ -991,7 +1029,6 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
1 0 0 0 1 1 0 0 0 1
""" """
# Initialize data. # Initialize data.
total_num_nodes = 5
total_num_edges = 12 total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
...@@ -1042,7 +1079,6 @@ def test_sample_neighbors_probs(replace, labor, probs_name): ...@@ -1042,7 +1079,6 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
) )
def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask): def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
# Initialize data. # Initialize data.
total_num_nodes = 5
total_num_edges = 12 total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
...@@ -1143,9 +1179,10 @@ def test_homo_graph_on_shared_memory( ...@@ -1143,9 +1179,10 @@ def test_homo_graph_on_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name] graph1.edge_attributes[name], graph2.edge_attributes[name]
) )
assert graph1.metadata is None and graph2.metadata is None
assert graph1.node_type_offset is None and graph2.node_type_offset is None assert graph1.node_type_offset is None and graph2.node_type_offset is None
assert graph1.type_per_edge is None and graph2.type_per_edge is None assert graph1.type_per_edge is None and graph2.type_per_edge is None
assert graph1.node_type_to_id is None and graph2.node_type_to_id is None
assert graph1.edge_type_to_id is None and graph2.edge_type_to_id is None
@unittest.skipIf( @unittest.skipIf(
...@@ -1168,7 +1205,8 @@ def test_hetero_graph_on_shared_memory( ...@@ -1168,7 +1205,8 @@ def test_hetero_graph_on_shared_memory(
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
...@@ -1185,8 +1223,9 @@ def test_hetero_graph_on_shared_memory( ...@@ -1185,8 +1223,9 @@ def test_hetero_graph_on_shared_memory(
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
metadata=metadata,
) )
shm_name = "test_hetero_g" shm_name = "test_hetero_g"
...@@ -1231,10 +1270,10 @@ def test_hetero_graph_on_shared_memory( ...@@ -1231,10 +1270,10 @@ def test_hetero_graph_on_shared_memory(
graph1.edge_attributes[name], graph2.edge_attributes[name] graph1.edge_attributes[name], graph2.edge_attributes[name]
) )
assert metadata.node_type_to_id == graph1.metadata.node_type_to_id assert node_type_to_id == graph1.node_type_to_id
assert metadata.edge_type_to_id == graph1.metadata.edge_type_to_id assert edge_type_to_id == graph1.edge_type_to_id
assert metadata.node_type_to_id == graph2.metadata.node_type_to_id assert node_type_to_id == graph2.node_type_to_id
assert metadata.edge_type_to_id == graph2.metadata.edge_type_to_id assert edge_type_to_id == graph2.edge_type_to_id
def process_csc_sampling_graph_on_shared_memory(graph, data_queue, flag_queue): def process_csc_sampling_graph_on_shared_memory(graph, data_queue, flag_queue):
...@@ -1293,7 +1332,8 @@ def test_multiprocessing_with_shared_memory(): ...@@ -1293,7 +1332,8 @@ def test_multiprocessing_with_shared_memory():
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph( ) = gbt.random_hetero_graph(
total_num_nodes, total_num_edges, num_ntypes, num_etypes total_num_nodes, total_num_edges, num_ntypes, num_etypes
) )
...@@ -1308,8 +1348,9 @@ def test_multiprocessing_with_shared_memory(): ...@@ -1308,8 +1348,9 @@ def test_multiprocessing_with_shared_memory():
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=None, edge_attributes=None,
metadata=metadata,
) )
ctx = mp.get_context("spawn") # Use spawn method. ctx = mp.get_context("spawn") # Use spawn method.
...@@ -1374,7 +1415,8 @@ def test_from_dglgraph_homogeneous(): ...@@ -1374,7 +1415,8 @@ def test_from_dglgraph_homogeneous():
assert gb_g.total_num_edges == dgl_g.num_edges() assert gb_g.total_num_edges == dgl_g.num_edges()
assert gb_g.node_type_offset is None assert gb_g.node_type_offset is None
assert gb_g.type_per_edge is None assert gb_g.type_per_edge is None
assert gb_g.metadata is None assert gb_g.node_type_to_id is None
assert gb_g.edge_type_to_id is None
@unittest.skipIf( @unittest.skipIf(
...@@ -1432,7 +1474,7 @@ def test_from_dglgraph_heterogeneous(): ...@@ -1432,7 +1474,7 @@ def test_from_dglgraph_heterogeneous():
list( list(
map( map(
lambda ss: ss.split(":")[1], lambda ss: ss.split(":")[1],
gb_g.metadata.edge_type_to_id.keys(), gb_g.edge_type_to_id.keys(),
) )
) )
== dgl_g.etypes == dgl_g.etypes
...@@ -1456,13 +1498,13 @@ def test_from_dglgraph_heterogeneous(): ...@@ -1456,13 +1498,13 @@ def test_from_dglgraph_heterogeneous():
[3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2] [3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 1, 2]
), ),
) )
assert gb_g.metadata.node_type_to_id == { assert gb_g.node_type_to_id == {
"author": 0, "author": 0,
"field": 1, "field": 1,
"institution": 2, "institution": 2,
"paper": 3, "paper": 3,
} }
assert gb_g.metadata.edge_type_to_id == { assert gb_g.edge_type_to_id == {
"author:affiliated_with:institution": 0, "author:affiliated_with:institution": 0,
"author:writes:paper": 1, "author:writes:paper": 1,
"paper:cites:paper": 2, "paper:cites:paper": 2,
...@@ -1584,7 +1626,6 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -1584,7 +1626,6 @@ def test_sample_neighbors_hetero_pick_number(
"N0:R1:N2": 1, "N0:R1:N2": 1,
"N0:R2:N3": 2, "N0:R2:N3": 2,
} }
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]) indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10]) node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
...@@ -1607,7 +1648,8 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -1607,7 +1648,8 @@ def test_sample_neighbors_hetero_pick_number(
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
...@@ -1673,7 +1715,6 @@ def test_csc_sampling_graph_to_device(): ...@@ -1673,7 +1715,6 @@ def test_csc_sampling_graph_to_device():
"N0:R1:N2": 1, "N0:R1:N2": 1,
"N0:R2:N3": 2, "N0:R2:N3": 2,
} }
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]) indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10]) node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
...@@ -1696,7 +1737,8 @@ def test_csc_sampling_graph_to_device(): ...@@ -1696,7 +1737,8 @@ def test_csc_sampling_graph_to_device():
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
# Copy to device. # Copy to device.
...@@ -1725,7 +1767,6 @@ def test_sample_neighbors_homo_csc_format(): ...@@ -1725,7 +1767,6 @@ def test_sample_neighbors_homo_csc_format():
1 0 0 0 1 1 0 0 0 1
""" """
# Initialize data. # Initialize data.
total_num_nodes = 5
total_num_edges = 12 total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
...@@ -1769,8 +1810,6 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -1769,8 +1810,6 @@ def test_sample_neighbors_hetero_csc_format(labor):
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
...@@ -1785,7 +1824,8 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -1785,7 +1824,8 @@ def test_sample_neighbors_hetero_csc_format(labor):
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
# Sample on both node types. # Sample on both node types.
...@@ -1875,8 +1915,6 @@ def test_sample_neighbors_fanouts_csc_format( ...@@ -1875,8 +1915,6 @@ def test_sample_neighbors_fanouts_csc_format(
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
...@@ -1891,7 +1929,8 @@ def test_sample_neighbors_fanouts_csc_format( ...@@ -1891,7 +1929,8 @@ def test_sample_neighbors_fanouts_csc_format(
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
...@@ -1937,8 +1976,6 @@ def test_sample_neighbors_replace_csc_format( ...@@ -1937,8 +1976,6 @@ def test_sample_neighbors_replace_csc_format(
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
...@@ -1953,7 +1990,8 @@ def test_sample_neighbors_replace_csc_format( ...@@ -1953,7 +1990,8 @@ def test_sample_neighbors_replace_csc_format(
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0]), "n2": torch.LongTensor([0])}
...@@ -1986,7 +2024,6 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor): ...@@ -1986,7 +2024,6 @@ def test_sample_neighbors_return_eids_homo_csc_format(labor):
1 0 0 0 1 1 0 0 0 1
""" """
# Initialize data. # Initialize data.
total_num_nodes = 5
total_num_edges = 12 total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
...@@ -2034,8 +2071,6 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor): ...@@ -2034,8 +2071,6 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
# Initialize data. # Initialize data.
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
total_num_nodes = 5
total_num_edges = 9 total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9]) indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
...@@ -2054,7 +2089,8 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor): ...@@ -2054,7 +2089,8 @@ def test_sample_neighbors_return_eids_hetero_csc_format(labor):
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
# Sample on both node types. # Sample on both node types.
...@@ -2092,7 +2128,6 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name): ...@@ -2092,7 +2128,6 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
1 0 0 0 1 1 0 0 0 1
""" """
# Initialize data. # Initialize data.
total_num_nodes = 5
total_num_edges = 12 total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
...@@ -2209,7 +2244,6 @@ def test_sample_neighbors_homo_pick_number_csc_format( ...@@ -2209,7 +2244,6 @@ def test_sample_neighbors_homo_pick_number_csc_format(
0 0 0 0 0 0 0 0 0 0 0 0
""" """
# Initialize data. # Initialize data.
total_num_nodes = 6
total_num_edges = 6 total_num_edges = 6
indptr = torch.LongTensor([0, 6, 6, 6, 6, 6, 6]) indptr = torch.LongTensor([0, 6, 6, 6, 6, 6, 6])
indices = torch.LongTensor([0, 1, 2, 3, 4, 5]) indices = torch.LongTensor([0, 1, 2, 3, 4, 5])
...@@ -2291,7 +2325,6 @@ def test_sample_neighbors_hetero_pick_number_csc_format( ...@@ -2291,7 +2325,6 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
"N2:R1:N0": 1, "N2:R1:N0": 1,
"N3:R2:N0": 2, "N3:R2:N0": 2,
} }
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9]) indptr = torch.LongTensor([0, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9])
indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) indices = torch.LongTensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
node_type_offset = torch.LongTensor([0, 1, 4, 7, 10]) node_type_offset = torch.LongTensor([0, 1, 4, 7, 10])
...@@ -2314,7 +2347,8 @@ def test_sample_neighbors_hetero_pick_number_csc_format( ...@@ -2314,7 +2347,8 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
edge_attributes=edge_attributes, edge_attributes=edge_attributes,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
# Generate subgraph via sample neighbors. # Generate subgraph via sample neighbors.
......
...@@ -76,7 +76,6 @@ def test_InSubgraphSampler_hetero(): ...@@ -76,7 +76,6 @@ def test_InSubgraphSampler_hetero():
"N1:R2:N0": 2, "N1:R2:N0": 2,
"N1:R3:N1": 3, "N1:R3:N1": 3,
} }
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14]) indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4]) indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
node_type_offset = torch.LongTensor([0, 3, 6]) node_type_offset = torch.LongTensor([0, 3, 6])
...@@ -86,7 +85,8 @@ def test_InSubgraphSampler_hetero(): ...@@ -86,7 +85,8 @@ def test_InSubgraphSampler_hetero():
indices=indices, indices=indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
......
...@@ -106,7 +106,6 @@ def get_hetero_graph(): ...@@ -106,7 +106,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
...@@ -116,7 +115,8 @@ def get_hetero_graph(): ...@@ -116,7 +115,8 @@ def get_hetero_graph():
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
......
...@@ -1026,11 +1026,12 @@ def test_OnDiskDataset_Graph_homogeneous(): ...@@ -1026,11 +1026,12 @@ def test_OnDiskDataset_Graph_homogeneous():
assert torch.equal(graph.csc_indptr, graph2.csc_indptr) assert torch.equal(graph.csc_indptr, graph2.csc_indptr)
assert torch.equal(graph.indices, graph2.indices) assert torch.equal(graph.indices, graph2.indices)
assert graph.metadata is None and graph2.metadata is None
assert ( assert (
graph.node_type_offset is None and graph2.node_type_offset is None graph.node_type_offset is None and graph2.node_type_offset is None
) )
assert graph.type_per_edge is None and graph2.type_per_edge is None assert graph.type_per_edge is None and graph2.type_per_edge is None
assert graph.node_type_to_id is None and graph2.node_type_to_id is None
assert graph.edge_type_to_id is None and graph2.edge_type_to_id is None
def test_OnDiskDataset_Graph_heterogeneous(): def test_OnDiskDataset_Graph_heterogeneous():
...@@ -1040,10 +1041,16 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -1040,10 +1041,16 @@ def test_OnDiskDataset_Graph_heterogeneous():
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4) ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, indices, node_type_offset, type_per_edge, None, metadata csc_indptr,
indices,
node_type_offset=node_type_offset,
type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
...@@ -1065,8 +1072,8 @@ def test_OnDiskDataset_Graph_heterogeneous(): ...@@ -1065,8 +1072,8 @@ def test_OnDiskDataset_Graph_heterogeneous():
assert torch.equal(graph.indices, graph2.indices) assert torch.equal(graph.indices, graph2.indices)
assert torch.equal(graph.node_type_offset, graph2.node_type_offset) assert torch.equal(graph.node_type_offset, graph2.node_type_offset)
assert torch.equal(graph.type_per_edge, graph2.type_per_edge) assert torch.equal(graph.type_per_edge, graph2.type_per_edge)
assert graph.metadata.node_type_to_id == graph2.metadata.node_type_to_id assert graph.node_type_to_id == graph2.node_type_to_id
assert graph.metadata.edge_type_to_id == graph2.metadata.edge_type_to_id assert graph.edge_type_to_id == graph2.edge_type_to_id
def test_OnDiskDataset_Metadata(): def test_OnDiskDataset_Metadata():
...@@ -1854,15 +1861,17 @@ def test_OnDiskDataset_all_nodes_set_hetero(): ...@@ -1854,15 +1861,17 @@ def test_OnDiskDataset_all_nodes_set_hetero():
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4) ) = gbt.random_hetero_graph(1000, 10 * 1000, 3, 4)
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
node_type_to_id=node_type_to_id,
edge_type_to_id=edge_type_to_id,
edge_attributes=None, edge_attributes=None,
metadata=metadata,
) )
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
......
...@@ -141,7 +141,6 @@ def get_hetero_graph(): ...@@ -141,7 +141,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
...@@ -151,7 +150,8 @@ def get_hetero_graph(): ...@@ -151,7 +150,8 @@ def get_hetero_graph():
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
......
...@@ -117,7 +117,6 @@ def get_hetero_graph(): ...@@ -117,7 +117,6 @@ def get_hetero_graph():
# num_nodes = 5, num_n1 = 2, num_n2 = 3 # num_nodes = 5, num_n1 = 2, num_n2 = 3
ntypes = {"n1": 0, "n2": 1} ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1} etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
metadata = gb.GraphMetadata(ntypes, etypes)
indptr = torch.LongTensor([0, 2, 4, 6, 8, 10]) indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1]) indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
...@@ -127,7 +126,8 @@ def get_hetero_graph(): ...@@ -127,7 +126,8 @@ def get_hetero_graph():
indices, indices,
node_type_offset=node_type_offset, node_type_offset=node_type_offset,
type_per_edge=type_per_edge, type_per_edge=type_per_edge,
metadata=metadata, node_type_to_id=ntypes,
edge_type_to_id=etypes,
) )
...@@ -207,7 +207,8 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor): ...@@ -207,7 +207,8 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
indices, indices,
node_type_offset, node_type_offset,
type_per_edge, type_per_edge,
metadata, node_type_to_id,
edge_type_to_id,
) = gb_test_utils.random_hetero_graph( ) = gb_test_utils.random_hetero_graph(
num_nodes, num_edges, num_ntypes, num_etypes num_nodes, num_edges, num_ntypes, num_etypes
) )
...@@ -218,10 +219,11 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor): ...@@ -218,10 +219,11 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
graph = gb.from_fused_csc( graph = gb.from_fused_csc(
csc_indptr, csc_indptr,
indices, indices,
node_type_offset, node_type_offset=node_type_offset,
type_per_edge, type_per_edge=type_per_edge,
edge_attributes, node_type_to_id=node_type_to_id,
metadata, edge_type_to_id=edge_type_to_id,
edge_attributes=edge_attributes,
) )
itemset = gb.ItemSetDict( itemset = gb.ItemSetDict(
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment