"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "09e777a3e13cf811e35da57abfe6ce239d9b0f15"
Unverified Commit fe78093f authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Update `__repr__` of `FusedCSCSamplingGraph` (#6956)

parent 2da6acef
"""CSC format sampling graph.""" """CSC format sampling graph."""
import textwrap
# pylint: disable= invalid-name # pylint: disable= invalid-name
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
...@@ -26,7 +29,38 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -26,7 +29,38 @@ class FusedCSCSamplingGraph(SamplingGraph):
r"""A sampling graph in CSC format.""" r"""A sampling graph in CSC format."""
def __repr__(self): def __repr__(self):
return _csc_sampling_graph_str(self) final_str = (
"{classname}(csc_indptr={csc_indptr},\n"
"indices={indices},\n"
"{metadata})"
)
classname_str = self.__class__.__name__
csc_indptr_str = str(self.csc_indptr)
indices_str = str(self.indices)
meta_str = f"total_num_nodes={self.total_num_nodes}, num_edges={self.num_edges},"
if self.node_type_offset is not None:
meta_str += f"\nnode_type_offset={self.node_type_offset},"
if self.type_per_edge is not None:
meta_str += f"\ntype_per_edge={self.type_per_edge},"
if self.node_type_to_id is not None:
meta_str += f"\nnode_type_to_id={self.node_type_to_id},"
if self.edge_type_to_id is not None:
meta_str += f"\nedge_type_to_id={self.edge_type_to_id},"
if self.node_attributes is not None:
meta_str += f"\nnode_attributes={self.node_attributes},"
if self.edge_attributes is not None:
meta_str += f"\nedge_attributes={self.edge_attributes},"
final_str = final_str.format(
classname=classname_str,
csc_indptr=csc_indptr_str,
indices=indices_str,
metadata=meta_str,
)
return textwrap.indent(
final_str, " " * (len(classname_str) + 1)
).strip()
def __init__( def __init__(
self, self,
...@@ -1120,19 +1154,23 @@ def fused_csc_sampling_graph( ...@@ -1120,19 +1154,23 @@ def fused_csc_sampling_graph(
-------- --------
>>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2} >>> ntypes = {'n1': 0, 'n2': 1, 'n3': 2}
>>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1} >>> etypes = {'n1:e1:n2': 0, 'n1:e2:n3': 1}
>>> csc_indptr = torch.tensor([0, 2, 5, 7]) >>> csc_indptr = torch.tensor([0, 2, 5, 7, 8])
>>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3]) >>> indices = torch.tensor([1, 3, 0, 1, 2, 0, 3, 2])
>>> node_type_offset = torch.tensor([0, 1, 2, 3]) >>> node_type_offset = torch.tensor([0, 1, 2, 4])
>>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0]) >>> type_per_edge = torch.tensor([0, 1, 0, 1, 1, 0, 0, 0])
>>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices, >>> graph = graphbolt.fused_csc_sampling_graph(csc_indptr, indices,
... node_type_offset=node_type_offset, ... node_type_offset=node_type_offset,
... type_per_edge=type_per_edge, ... type_per_edge=type_per_edge,
... node_type_to_id=ntypes, edge_type_to_id=etypes, ... node_type_to_id=ntypes, edge_type_to_id=etypes,
... node_attributes=None, edge_attributes=None,) ... node_attributes=None, edge_attributes=None,)
>>> print(graph) >>> print(graph)
FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7]), FusedCSCSamplingGraph(csc_indptr=tensor([0, 2, 5, 7, 8]),
indices=tensor([1, 3, 0, 1, 2, 0, 3]), indices=tensor([1, 3, 0, 1, 2, 0, 3, 2]),
total_num_nodes=3, total_num_edges=7) total_num_nodes=4, num_edges={'n1:e1:n2': 5, 'n1:e2:n3': 3},
node_type_offset=tensor([0, 1, 2, 4]),
type_per_edge=tensor([0, 1, 0, 1, 1, 0, 0, 0]),
node_type_to_id={'n1': 0, 'n2': 1, 'n3': 2},
edge_type_to_id={'n1:e1:n2': 0, 'n1:e2:n3': 1},)
""" """
if node_type_to_id is not None and edge_type_to_id is not None: if node_type_to_id is not None and edge_type_to_id is not None:
node_types = list(node_type_to_id.keys()) node_types = list(node_type_to_id.keys())
...@@ -1205,48 +1243,6 @@ def load_from_shared_memory( ...@@ -1205,48 +1243,6 @@ def load_from_shared_memory(
) )
def _csc_sampling_graph_str(graph: FusedCSCSamplingGraph) -> str:
"""Internal function for converting a csc sampling graph to string
representation.
"""
csc_indptr_str = str(graph.csc_indptr)
indices_str = str(graph.indices)
meta_str = f"num_nodes={graph.total_num_nodes}, num_edges={graph.num_edges}"
if graph.node_type_offset is not None:
meta_str += f", node_type_offset={graph.node_type_offset}"
if graph.type_per_edge is not None:
meta_str += f", type_per_edge={graph.type_per_edge}"
if graph.node_type_to_id is not None:
meta_str += f", node_type_to_id={graph.node_type_to_id}"
if graph.edge_type_to_id is not None:
meta_str += f", edge_type_to_id={graph.edge_type_to_id}"
if graph.node_attributes is not None:
meta_str += f", node_attributes={graph.node_attributes}"
if graph.edge_attributes is not None:
meta_str += f", edge_attributes={graph.edge_attributes}"
prefix = f"{type(graph).__name__}("
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
final_str = (
"csc_indptr="
+ _add_indent(csc_indptr_str, len("csc_indptr="))
+ ",\n"
+ "indices="
+ _add_indent(indices_str, len("indices="))
+ ",\n"
+ meta_str
+ ")"
)
final_str = prefix + _add_indent(final_str, len(prefix))
return final_str
def from_dglgraph( def from_dglgraph(
g: DGLGraph, g: DGLGraph,
is_homogeneous: bool = False, is_homogeneous: bool = False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment