"tests/python/vscode:/vscode.git/clone" did not exist on "308bd6f5b245929211f365396ca2007ac151b8e7"
Unverified Commit 905321f8 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Modify `__repr__` (#6953)

parent 80f36134
......@@ -3,6 +3,7 @@
import json
import os
import shutil
import textwrap
from copy import deepcopy
from typing import Dict, List, Union
......@@ -339,7 +340,24 @@ class OnDiskTask:
return self._test_set
def __repr__(self) -> str:
return _ondisk_task_str(self)
ret = "{Classname}({attributes})"
attributes_str = ""
attributes = get_attributes(self)
attributes.reverse()
for attribute in attributes:
if attribute[0] == "_":
continue
value = getattr(self, attribute)
attributes_str += f"{attribute}={value},\n"
attributes_str = textwrap.indent(
attributes_str, " " * len("OnDiskTask(")
).strip()
return ret.format(
Classname=self.__class__.__name__, attributes=attributes_str
)
class OnDiskDataset(Dataset):
......@@ -752,25 +770,3 @@ class BuiltinDataset(OnDiskDataset):
extract_archive(zip_file_path, root, overwrite=True)
os.remove(zip_file_path)
super().__init__(dataset_dir, force_preprocess=False)
def _ondisk_task_str(task: OnDiskTask) -> str:
final_str = "OnDiskTask("
indent_len = len(final_str)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
attributes = get_attributes(task)
attributes.reverse()
for name in attributes:
if name[0] == "_":
continue
val = getattr(task, name)
final_str += (
f"{name}={_add_indent(str(val), indent_len + len(name) + 1)},\n"
+ " " * indent_len
)
return final_str[:-indent_len] + ")"
......@@ -172,37 +172,25 @@ class TorchBasedFeature(Feature):
def __repr__(self) -> str:
ret = (
"TorchBasedFeature(\n"
"{Classname}(\n"
" feature={feature},\n"
" metadata={metadata},\n"
")"
)
feature_str = str(self._tensor)
feature_str_lines = feature_str.splitlines()
if len(feature_str_lines) > 1:
feature_str = (
feature_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(feature_str_lines[1:]), " " * len(" feature=")
)
)
metadata_str = str(self.metadata())
metadata_str_lines = metadata_str.splitlines()
if len(metadata_str_lines) > 1:
metadata_str = (
metadata_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(metadata_str_lines[1:]),
" " * len(" metadata="),
)
feature_str = textwrap.indent(
str(self._tensor), " " * len(" feature=")
).strip()
metadata_str = textwrap.indent(
str(self.metadata()), " " * len(" metadata=")
).strip()
return ret.format(
Classname=self.__class__.__name__,
feature=feature_str,
metadata=metadata_str,
)
return ret.format(feature=feature_str, metadata=metadata_str)
class TorchBasedFeatureStore(BasicFeatureStore):
r"""A store to manage multiple pytorch based feature for access.
......@@ -268,17 +256,8 @@ class TorchBasedFeatureStore(BasicFeatureStore):
feature.pin_memory_()
def __repr__(self) -> str:
ret = "TorchBasedFeatureStore(\n" + " {features}\n" + ")"
features_str = str(self._features)
features_str_lines = features_str.splitlines()
if len(features_str_lines) > 1:
features_str = (
features_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(features_str_lines[1:]), " " * len(" ")
)
ret = "{Classname}(\n" + " {features}\n" + ")"
features_str = textwrap.indent(str(self._features), " ").strip()
return ret.format(
Classname=self.__class__.__name__, features=features_str
)
return ret.format(features=features_str)
......@@ -180,7 +180,7 @@ class ItemSet:
def __repr__(self) -> str:
ret = (
f"ItemSet(\n"
f"{self.__class__.__name__}(\n"
f" items={self._items},\n"
f" names={self._names},\n"
f")"
......@@ -342,18 +342,18 @@ class ItemSetDict:
def __repr__(self) -> str:
ret = (
"ItemSetDict(\n"
"{Classname}(\n"
" itemsets={itemsets},\n"
" names={names},\n"
")"
)
itemsets_str = repr(self._itemsets)
lines = itemsets_str.splitlines()
itemsets_str = (
lines[0]
+ "\n"
+ textwrap.indent("\n".join(lines[1:]), " " * len(" itemsets="))
)
itemsets_str = textwrap.indent(
repr(self._itemsets), " " * len(" itemsets=")
).strip()
return ret.format(itemsets=itemsets_str, names=self._names)
return ret.format(
Classname=self.__class__.__name__,
itemsets=itemsets_str,
names=self._names,
)
......@@ -2581,10 +2581,9 @@ def test_OnDiskTask_repr_homogeneous():
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n"
" ),\n"
" metadata={'name': 'node_classification'},\n"
")"
" metadata={'name': 'node_classification'},)"
)
assert str(task) == expected_str, print(task)
assert repr(task) == expected_str, task
def test_OnDiskTask_repr_heterogeneous():
......@@ -2627,10 +2626,9 @@ def test_OnDiskTask_repr_heterogeneous():
" )},\n"
" names=('seed_nodes',),\n"
" ),\n"
" metadata={'name': 'node_classification'},\n"
")"
" metadata={'name': 'node_classification'},)"
)
assert str(task) == expected_str, print(task)
assert repr(task) == expected_str, task
def test_OnDiskDataset_load_tasks_selectively():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment