Unverified Commit 905321f8 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Modify `__repr__` (#6953)

parent 80f36134
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import json import json
import os import os
import shutil import shutil
import textwrap
from copy import deepcopy from copy import deepcopy
from typing import Dict, List, Union from typing import Dict, List, Union
...@@ -339,7 +340,24 @@ class OnDiskTask: ...@@ -339,7 +340,24 @@ class OnDiskTask:
return self._test_set return self._test_set
def __repr__(self) -> str: def __repr__(self) -> str:
return _ondisk_task_str(self) ret = "{Classname}({attributes})"
attributes_str = ""
attributes = get_attributes(self)
attributes.reverse()
for attribute in attributes:
if attribute[0] == "_":
continue
value = getattr(self, attribute)
attributes_str += f"{attribute}={value},\n"
attributes_str = textwrap.indent(
attributes_str, " " * len("OnDiskTask(")
).strip()
return ret.format(
Classname=self.__class__.__name__, attributes=attributes_str
)
class OnDiskDataset(Dataset): class OnDiskDataset(Dataset):
...@@ -752,25 +770,3 @@ class BuiltinDataset(OnDiskDataset): ...@@ -752,25 +770,3 @@ class BuiltinDataset(OnDiskDataset):
extract_archive(zip_file_path, root, overwrite=True) extract_archive(zip_file_path, root, overwrite=True)
os.remove(zip_file_path) os.remove(zip_file_path)
super().__init__(dataset_dir, force_preprocess=False) super().__init__(dataset_dir, force_preprocess=False)
def _ondisk_task_str(task: OnDiskTask) -> str:
final_str = "OnDiskTask("
indent_len = len(final_str)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
attributes = get_attributes(task)
attributes.reverse()
for name in attributes:
if name[0] == "_":
continue
val = getattr(task, name)
final_str += (
f"{name}={_add_indent(str(val), indent_len + len(name) + 1)},\n"
+ " " * indent_len
)
return final_str[:-indent_len] + ")"
...@@ -172,36 +172,24 @@ class TorchBasedFeature(Feature): ...@@ -172,36 +172,24 @@ class TorchBasedFeature(Feature):
def __repr__(self) -> str: def __repr__(self) -> str:
ret = ( ret = (
"TorchBasedFeature(\n" "{Classname}(\n"
" feature={feature},\n" " feature={feature},\n"
" metadata={metadata},\n" " metadata={metadata},\n"
")" ")"
) )
feature_str = str(self._tensor) feature_str = textwrap.indent(
feature_str_lines = feature_str.splitlines() str(self._tensor), " " * len(" feature=")
if len(feature_str_lines) > 1: ).strip()
feature_str = ( metadata_str = textwrap.indent(
feature_str_lines[0] str(self.metadata()), " " * len(" metadata=")
+ "\n" ).strip()
+ textwrap.indent(
"\n".join(feature_str_lines[1:]), " " * len(" feature=") return ret.format(
) Classname=self.__class__.__name__,
) feature=feature_str,
metadata=metadata_str,
metadata_str = str(self.metadata()) )
metadata_str_lines = metadata_str.splitlines()
if len(metadata_str_lines) > 1:
metadata_str = (
metadata_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(metadata_str_lines[1:]),
" " * len(" metadata="),
)
)
return ret.format(feature=feature_str, metadata=metadata_str)
class TorchBasedFeatureStore(BasicFeatureStore): class TorchBasedFeatureStore(BasicFeatureStore):
...@@ -268,17 +256,8 @@ class TorchBasedFeatureStore(BasicFeatureStore): ...@@ -268,17 +256,8 @@ class TorchBasedFeatureStore(BasicFeatureStore):
feature.pin_memory_() feature.pin_memory_()
def __repr__(self) -> str: def __repr__(self) -> str:
ret = "TorchBasedFeatureStore(\n" + " {features}\n" + ")" ret = "{Classname}(\n" + " {features}\n" + ")"
features_str = textwrap.indent(str(self._features), " ").strip()
features_str = str(self._features) return ret.format(
features_str_lines = features_str.splitlines() Classname=self.__class__.__name__, features=features_str
if len(features_str_lines) > 1: )
features_str = (
features_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(features_str_lines[1:]), " " * len(" ")
)
)
return ret.format(features=features_str)
...@@ -180,7 +180,7 @@ class ItemSet: ...@@ -180,7 +180,7 @@ class ItemSet:
def __repr__(self) -> str: def __repr__(self) -> str:
ret = ( ret = (
f"ItemSet(\n" f"{self.__class__.__name__}(\n"
f" items={self._items},\n" f" items={self._items},\n"
f" names={self._names},\n" f" names={self._names},\n"
f")" f")"
...@@ -342,18 +342,18 @@ class ItemSetDict: ...@@ -342,18 +342,18 @@ class ItemSetDict:
def __repr__(self) -> str: def __repr__(self) -> str:
ret = ( ret = (
"ItemSetDict(\n" "{Classname}(\n"
" itemsets={itemsets},\n" " itemsets={itemsets},\n"
" names={names},\n" " names={names},\n"
")" ")"
) )
itemsets_str = repr(self._itemsets) itemsets_str = textwrap.indent(
lines = itemsets_str.splitlines() repr(self._itemsets), " " * len(" itemsets=")
itemsets_str = ( ).strip()
lines[0]
+ "\n"
+ textwrap.indent("\n".join(lines[1:]), " " * len(" itemsets="))
)
return ret.format(itemsets=itemsets_str, names=self._names) return ret.format(
Classname=self.__class__.__name__,
itemsets=itemsets_str,
names=self._names,
)
...@@ -2570,21 +2570,20 @@ def test_OnDiskTask_repr_homogeneous(): ...@@ -2570,21 +2570,20 @@ def test_OnDiskTask_repr_homogeneous():
task = gb.OnDiskTask(metadata, item_set, item_set, item_set) task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
expected_str = ( expected_str = (
"OnDiskTask(validation_set=ItemSet(\n" "OnDiskTask(validation_set=ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n" " names=('seed_nodes', 'labels'),\n"
" ),\n" " ),\n"
" train_set=ItemSet(\n" " train_set=ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n" " names=('seed_nodes', 'labels'),\n"
" ),\n" " ),\n"
" test_set=ItemSet(\n" " test_set=ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n" " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n" " names=('seed_nodes', 'labels'),\n"
" ),\n" " ),\n"
" metadata={'name': 'node_classification'},\n" " metadata={'name': 'node_classification'},)"
")"
) )
assert str(task) == expected_str, print(task) assert repr(task) == expected_str, task
def test_OnDiskTask_repr_heterogeneous(): def test_OnDiskTask_repr_heterogeneous():
...@@ -2598,39 +2597,38 @@ def test_OnDiskTask_repr_heterogeneous(): ...@@ -2598,39 +2597,38 @@ def test_OnDiskTask_repr_heterogeneous():
task = gb.OnDiskTask(metadata, item_set, item_set, item_set) task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
expected_str = ( expected_str = (
"OnDiskTask(validation_set=ItemSetDict(\n" "OnDiskTask(validation_set=ItemSetDict(\n"
" itemsets={'user': ItemSet(\n" " itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n" " items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" ), 'item': ItemSet(\n" " ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n" " items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" )},\n" " )},\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" ),\n" " ),\n"
" train_set=ItemSetDict(\n" " train_set=ItemSetDict(\n"
" itemsets={'user': ItemSet(\n" " itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n" " items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" ), 'item': ItemSet(\n" " ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n" " items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" )},\n" " )},\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" ),\n" " ),\n"
" test_set=ItemSetDict(\n" " test_set=ItemSetDict(\n"
" itemsets={'user': ItemSet(\n" " itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n" " items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" ), 'item': ItemSet(\n" " ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n" " items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" )},\n" " )},\n"
" names=('seed_nodes',),\n" " names=('seed_nodes',),\n"
" ),\n" " ),\n"
" metadata={'name': 'node_classification'},\n" " metadata={'name': 'node_classification'},)"
")"
) )
assert str(task) == expected_str, print(task) assert repr(task) == expected_str, task
def test_OnDiskDataset_load_tasks_selectively(): def test_OnDiskDataset_load_tasks_selectively():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment