Unverified Commit b003732d authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Update `__repr__` of `TorchBasedFeature` and `TorchBasedFeatureStore` (#6945)

parent 982f2028
"""Torch-based feature store for GraphBolt."""
import textwrap
from typing import Dict, List
import numpy as np
......@@ -169,7 +171,37 @@ class TorchBasedFeature(Feature):
self._tensor = self._tensor.pin_memory()
def __repr__(self) -> str:
return _torch_based_feature_str(self)
ret = (
"TorchBasedFeature(\n"
" feature={feature},\n"
" metadata={metadata},\n"
")"
)
feature_str = str(self._tensor)
feature_str_lines = feature_str.splitlines()
if len(feature_str_lines) > 1:
feature_str = (
feature_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(feature_str_lines[1:]), " " * len(" feature=")
)
)
metadata_str = str(self.metadata())
metadata_str_lines = metadata_str.splitlines()
if len(metadata_str_lines) > 1:
metadata_str = (
metadata_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(metadata_str_lines[1:]),
" " * len(" metadata="),
)
)
return ret.format(feature=feature_str, metadata=metadata_str)
class TorchBasedFeatureStore(BasicFeatureStore):
......@@ -236,40 +268,17 @@ class TorchBasedFeatureStore(BasicFeatureStore):
feature.pin_memory_()
def __repr__(self) -> str:
return _torch_based_feature_store_str(self._features)
def _torch_based_feature_str(feature: TorchBasedFeature) -> str:
final_str = "TorchBasedFeature("
indent_len = len(final_str)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
feature_str = "feature=" + _add_indent(
str(feature._tensor), indent_len + len("feature=")
ret = "TorchBasedFeatureStore(\n" + " {features}\n" + ")"
features_str = str(self._features)
features_str_lines = features_str.splitlines()
if len(features_str_lines) > 1:
features_str = (
features_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(features_str_lines[1:]), " " * len(" ")
)
final_str += feature_str + ",\n" + " " * indent_len
metadata_str = "metadata=" + _add_indent(
str(feature.metadata()), indent_len + len("metadata=")
)
final_str += metadata_str + ",\n)"
return final_str
def _torch_based_feature_store_str(
features: Dict[str, TorchBasedFeature]
) -> str:
final_str = "TorchBasedFeatureStore"
indent_len = len(final_str)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
features_str = _add_indent(str(features), indent_len)
final_str += features_str
return final_str
return ret.format(features=features_str)
......@@ -296,23 +296,27 @@ def test_torch_based_feature_repr(in_memory):
feature_a = gb.TorchBasedFeature(a, metadata=metadata)
feature_b = gb.TorchBasedFeature(b)
expected_str_feature_a = str(
"""TorchBasedFeature(feature=tensor([[1, 2, 3],
[4, 5, 6]]),
metadata={'max_value': 3},
)"""
expected_str_feature_a = (
"TorchBasedFeature(\n"
" feature=tensor([[1, 2, 3],\n"
" [4, 5, 6]]),\n"
" metadata={'max_value': 3},\n"
")"
)
expected_str_feature_b = str(
"""TorchBasedFeature(feature=tensor([[[1, 2],
[3, 4]],
[[4, 5],
[6, 7]]]),
metadata={},
)"""
expected_str_feature_b = (
"TorchBasedFeature(\n"
" feature=tensor([[[1, 2],\n"
" [3, 4]],\n"
"\n"
" [[4, 5],\n"
" [6, 7]]]),\n"
" metadata={},\n"
")"
)
assert str(feature_a) == expected_str_feature_a
assert str(feature_b) == expected_str_feature_b
assert repr(feature_a) == expected_str_feature_a, feature_a
assert repr(feature_b) == expected_str_feature_b, feature_b
a = b = metadata = None
feature_a = feature_b = None
expected_str_feature_a = expected_str_feature_b = None
......@@ -345,21 +349,24 @@ def test_torch_based_feature_store_repr(in_memory):
]
feature_store = gb.TorchBasedFeatureStore(feature_data)
expected_feature_store_str = str(
"""TorchBasedFeatureStore{(<OnDiskFeatureDataDomain.NODE: 'node'>, 'paper', 'a'): TorchBasedFeature(feature=tensor([[1, 2, 4],
[2, 5, 3]]),
metadata={},
), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'paper:cites:paper', 'b'): TorchBasedFeature(feature=tensor([[[1, 2],
[3, 4]],
[[2, 5],
[3, 4]]]),
metadata={},
)}"""
)
assert str(feature_store) == expected_feature_store_str, print(
feature_store
expected_feature_store_str = (
"TorchBasedFeatureStore(\n"
" {(<OnDiskFeatureDataDomain.NODE: 'node'>, 'paper', 'a'): TorchBasedFeature(\n"
" feature=tensor([[1, 2, 4],\n"
" [2, 5, 3]]),\n"
" metadata={},\n"
" ), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'paper:cites:paper', 'b'): TorchBasedFeature(\n"
" feature=tensor([[[1, 2],\n"
" [3, 4]],\n"
"\n"
" [[2, 5],\n"
" [3, 4]]]),\n"
" metadata={},\n"
" )}\n"
")"
)
assert repr(feature_store) == expected_feature_store_str, feature_store
a = b = feature_data = None
feature_store = expected_feature_store_str = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment