Unverified Commit b003732d authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Update `__repr__` of `TorchBasedFeature` and `TorchBasedFeatureStore` (#6945)

parent 982f2028
"""Torch-based feature store for GraphBolt.""" """Torch-based feature store for GraphBolt."""
import textwrap
from typing import Dict, List from typing import Dict, List
import numpy as np import numpy as np
...@@ -169,7 +171,37 @@ class TorchBasedFeature(Feature): ...@@ -169,7 +171,37 @@ class TorchBasedFeature(Feature):
self._tensor = self._tensor.pin_memory() self._tensor = self._tensor.pin_memory()
def __repr__(self) -> str: def __repr__(self) -> str:
return _torch_based_feature_str(self) ret = (
"TorchBasedFeature(\n"
" feature={feature},\n"
" metadata={metadata},\n"
")"
)
feature_str = str(self._tensor)
feature_str_lines = feature_str.splitlines()
if len(feature_str_lines) > 1:
feature_str = (
feature_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(feature_str_lines[1:]), " " * len(" feature=")
)
)
metadata_str = str(self.metadata())
metadata_str_lines = metadata_str.splitlines()
if len(metadata_str_lines) > 1:
metadata_str = (
metadata_str_lines[0]
+ "\n"
+ textwrap.indent(
"\n".join(metadata_str_lines[1:]),
" " * len(" metadata="),
)
)
return ret.format(feature=feature_str, metadata=metadata_str)
class TorchBasedFeatureStore(BasicFeatureStore): class TorchBasedFeatureStore(BasicFeatureStore):
...@@ -236,40 +268,17 @@ class TorchBasedFeatureStore(BasicFeatureStore): ...@@ -236,40 +268,17 @@ class TorchBasedFeatureStore(BasicFeatureStore):
feature.pin_memory_() feature.pin_memory_()
def __repr__(self) -> str: def __repr__(self) -> str:
return _torch_based_feature_store_str(self._features) ret = "TorchBasedFeatureStore(\n" + " {features}\n" + ")"
features_str = str(self._features)
def _torch_based_feature_str(feature: TorchBasedFeature) -> str: features_str_lines = features_str.splitlines()
final_str = "TorchBasedFeature(" if len(features_str_lines) > 1:
indent_len = len(final_str) features_str = (
features_str_lines[0]
def _add_indent(_str, indent): + "\n"
lines = _str.split("\n") + textwrap.indent(
lines = [lines[0]] + [" " * indent + line for line in lines[1:]] "\n".join(features_str_lines[1:]), " " * len(" ")
return "\n".join(lines)
feature_str = "feature=" + _add_indent(
str(feature._tensor), indent_len + len("feature=")
) )
final_str += feature_str + ",\n" + " " * indent_len
metadata_str = "metadata=" + _add_indent(
str(feature.metadata()), indent_len + len("metadata=")
) )
final_str += metadata_str + ",\n)"
return final_str
def _torch_based_feature_store_str(
features: Dict[str, TorchBasedFeature]
) -> str:
final_str = "TorchBasedFeatureStore"
indent_len = len(final_str)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
features_str = _add_indent(str(features), indent_len) return ret.format(features=features_str)
final_str += features_str
return final_str
...@@ -296,23 +296,27 @@ def test_torch_based_feature_repr(in_memory): ...@@ -296,23 +296,27 @@ def test_torch_based_feature_repr(in_memory):
feature_a = gb.TorchBasedFeature(a, metadata=metadata) feature_a = gb.TorchBasedFeature(a, metadata=metadata)
feature_b = gb.TorchBasedFeature(b) feature_b = gb.TorchBasedFeature(b)
expected_str_feature_a = str( expected_str_feature_a = (
"""TorchBasedFeature(feature=tensor([[1, 2, 3], "TorchBasedFeature(\n"
[4, 5, 6]]), " feature=tensor([[1, 2, 3],\n"
metadata={'max_value': 3}, " [4, 5, 6]]),\n"
)""" " metadata={'max_value': 3},\n"
")"
) )
expected_str_feature_b = str( expected_str_feature_b = (
"""TorchBasedFeature(feature=tensor([[[1, 2], "TorchBasedFeature(\n"
[3, 4]], " feature=tensor([[[1, 2],\n"
" [3, 4]],\n"
[[4, 5], "\n"
[6, 7]]]), " [[4, 5],\n"
metadata={}, " [6, 7]]]),\n"
)""" " metadata={},\n"
")"
) )
assert str(feature_a) == expected_str_feature_a
assert str(feature_b) == expected_str_feature_b assert repr(feature_a) == expected_str_feature_a, feature_a
assert repr(feature_b) == expected_str_feature_b, feature_b
a = b = metadata = None a = b = metadata = None
feature_a = feature_b = None feature_a = feature_b = None
expected_str_feature_a = expected_str_feature_b = None expected_str_feature_a = expected_str_feature_b = None
...@@ -345,21 +349,24 @@ def test_torch_based_feature_store_repr(in_memory): ...@@ -345,21 +349,24 @@ def test_torch_based_feature_store_repr(in_memory):
] ]
feature_store = gb.TorchBasedFeatureStore(feature_data) feature_store = gb.TorchBasedFeatureStore(feature_data)
expected_feature_store_str = str( expected_feature_store_str = (
"""TorchBasedFeatureStore{(<OnDiskFeatureDataDomain.NODE: 'node'>, 'paper', 'a'): TorchBasedFeature(feature=tensor([[1, 2, 4], "TorchBasedFeatureStore(\n"
[2, 5, 3]]), " {(<OnDiskFeatureDataDomain.NODE: 'node'>, 'paper', 'a'): TorchBasedFeature(\n"
metadata={}, " feature=tensor([[1, 2, 4],\n"
), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'paper:cites:paper', 'b'): TorchBasedFeature(feature=tensor([[[1, 2], " [2, 5, 3]]),\n"
[3, 4]], " metadata={},\n"
" ), (<OnDiskFeatureDataDomain.EDGE: 'edge'>, 'paper:cites:paper', 'b'): TorchBasedFeature(\n"
[[2, 5], " feature=tensor([[[1, 2],\n"
[3, 4]]]), " [3, 4]],\n"
metadata={}, "\n"
)}""" " [[2, 5],\n"
) " [3, 4]]]),\n"
assert str(feature_store) == expected_feature_store_str, print( " metadata={},\n"
feature_store " )}\n"
")"
) )
assert repr(feature_store) == expected_feature_store_str, feature_store
a = b = feature_data = None a = b = feature_data = None
feature_store = expected_feature_store_str = None feature_store = expected_feature_store_str = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment