Unverified Commit 90e57e74 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Update `__repr__` of `ItemSet` and `ItemSetDict` (#6944)

parent b003732d
"""GraphBolt Itemset.""" """GraphBolt Itemset."""
import textwrap
from typing import Dict, Iterable, Iterator, Sized, Tuple, Union from typing import Dict, Iterable, Iterator, Sized, Tuple, Union
import torch import torch
...@@ -175,7 +176,14 @@ class ItemSet: ...@@ -175,7 +176,14 @@ class ItemSet:
return self._names return self._names
def __repr__(self) -> str: def __repr__(self) -> str:
return _itemset_str(self, "ItemSet") ret = (
f"ItemSet(\n"
f" items={self._items},\n"
f" names={self._names},\n"
f")"
)
return ret
class ItemSetDict: class ItemSetDict:
...@@ -330,31 +338,19 @@ class ItemSetDict: ...@@ -330,31 +338,19 @@ class ItemSetDict:
return self._names return self._names
def __repr__(self) -> str: def __repr__(self) -> str:
return _itemset_str(self, "ItemSetDict") ret = (
"ItemSetDict(\n"
" itemsets={itemsets},\n"
def _itemset_str(itemset: Union[ItemSet, ItemSetDict], name) -> str: " names={names},\n"
final_str = f"{name}(" ")"
indent_len = len(final_str)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
items = (
itemset._items if isinstance(itemset, ItemSet) else itemset._itemsets
)
item_str = (
"items="
+ _add_indent(str(items), indent_len + len("items="))
+ ",\n"
+ " " * indent_len
) )
name_str = (
"names=" itemsets_str = repr(self._itemsets)
+ _add_indent(str(itemset._names), indent_len + len("items=")) lines = itemsets_str.splitlines()
+ ",\n)" itemsets_str = (
lines[0]
+ "\n"
+ textwrap.indent("\n".join(lines[1:]), " " * len(" itemsets="))
) )
final_str += item_str + name_str
return final_str return ret.format(itemsets=itemsets_str, names=self._names)
...@@ -2348,18 +2348,21 @@ def test_OnDiskTask_repr_homogeneous(): ...@@ -2348,18 +2348,21 @@ def test_OnDiskTask_repr_homogeneous():
) )
metadata = {"name": "node_classification"} metadata = {"name": "node_classification"}
task = gb.OnDiskTask(metadata, item_set, item_set, item_set) task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
expected_str = str( expected_str = (
"""OnDiskTask(validation_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])), "OnDiskTask(validation_set=ItemSet(\n"
names=('seed_nodes', 'labels'), " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
), " names=('seed_nodes', 'labels'),\n"
train_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])), " ),\n"
names=('seed_nodes', 'labels'), " train_set=ItemSet(\n"
), " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
test_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])), " names=('seed_nodes', 'labels'),\n"
names=('seed_nodes', 'labels'), " ),\n"
), " test_set=ItemSet(\n"
metadata={'name': 'node_classification'}, " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
)""" " names=('seed_nodes', 'labels'),\n"
" ),\n"
" metadata={'name': 'node_classification'},\n"
")"
) )
assert str(task) == expected_str, print(task) assert str(task) == expected_str, print(task)
...@@ -2373,30 +2376,39 @@ def test_OnDiskTask_repr_heterogeneous(): ...@@ -2373,30 +2376,39 @@ def test_OnDiskTask_repr_heterogeneous():
) )
metadata = {"name": "node_classification"} metadata = {"name": "node_classification"}
task = gb.OnDiskTask(metadata, item_set, item_set, item_set) task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
expected_str = str( expected_str = (
"""OnDiskTask(validation_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),), "OnDiskTask(validation_set=ItemSetDict(\n"
names=('seed_nodes',), " itemsets={'user': ItemSet(\n"
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),), " items=(tensor([0, 1, 2, 3, 4]),),\n"
names=('seed_nodes',), " names=('seed_nodes',),\n"
)}, " ), 'item': ItemSet(\n"
names=('seed_nodes',), " items=(tensor([5, 6, 7, 8, 9]),),\n"
), " names=('seed_nodes',),\n"
train_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),), " )},\n"
names=('seed_nodes',), " names=('seed_nodes',),\n"
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),), " ),\n"
names=('seed_nodes',), " train_set=ItemSetDict(\n"
)}, " itemsets={'user': ItemSet(\n"
names=('seed_nodes',), " items=(tensor([0, 1, 2, 3, 4]),),\n"
), " names=('seed_nodes',),\n"
test_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),), " ), 'item': ItemSet(\n"
names=('seed_nodes',), " items=(tensor([5, 6, 7, 8, 9]),),\n"
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),), " names=('seed_nodes',),\n"
names=('seed_nodes',), " )},\n"
)}, " names=('seed_nodes',),\n"
names=('seed_nodes',), " ),\n"
), " test_set=ItemSetDict(\n"
metadata={'name': 'node_classification'}, " itemsets={'user': ItemSet(\n"
)""" " items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n"
" ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n"
" )},\n"
" names=('seed_nodes',),\n"
" ),\n"
" metadata={'name': 'node_classification'},\n"
")"
) )
assert str(task) == expected_str, print(task) assert str(task) == expected_str, print(task)
......
...@@ -529,24 +529,27 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts(): ...@@ -529,24 +529,27 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
def test_ItemSet_repr(): def test_ItemSet_repr():
# ItemSet with single name. # ItemSet with single name.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes") item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
expected_str = str( expected_str = (
"""ItemSet(items=(tensor([0, 1, 2, 3, 4]),), "ItemSet(\n"
names=('seed_nodes',), " items=(tensor([0, 1, 2, 3, 4]),),\n"
)""" " names=('seed_nodes',),\n"
")"
) )
assert str(item_set) == expected_str, print(item_set)
assert str(item_set) == expected_str, item_set
# ItemSet with multiple names. # ItemSet with multiple names.
item_set = gb.ItemSet( item_set = gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)), (torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"), names=("seed_nodes", "labels"),
) )
expected_str = str( expected_str = (
"""ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])), "ItemSet(\n"
names=('seed_nodes', 'labels'), " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
)""" " names=('seed_nodes', 'labels'),\n"
")"
) )
assert str(item_set) == expected_str, print(item_set) assert str(item_set) == expected_str, item_set
def test_ItemSetDict_repr(): def test_ItemSetDict_repr():
...@@ -557,16 +560,19 @@ def test_ItemSetDict_repr(): ...@@ -557,16 +560,19 @@ def test_ItemSetDict_repr():
"item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"), "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
} }
) )
expected_str = str( expected_str = (
"""ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),), "ItemSetDict(\n"
names=('seed_nodes',), " itemsets={'user': ItemSet(\n"
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),), " items=(tensor([0, 1, 2, 3, 4]),),\n"
names=('seed_nodes',), " names=('seed_nodes',),\n"
)}, " ), 'item': ItemSet(\n"
names=('seed_nodes',), " items=(tensor([5, 6, 7, 8, 9]),),\n"
)""" " names=('seed_nodes',),\n"
" )},\n"
" names=('seed_nodes',),\n"
")"
) )
assert str(item_set) == expected_str, print(item_set) assert str(item_set) == expected_str, item_set
# ItemSetDict with multiple names. # ItemSetDict with multiple names.
item_set = gb.ItemSetDict( item_set = gb.ItemSetDict(
...@@ -581,13 +587,16 @@ def test_ItemSetDict_repr(): ...@@ -581,13 +587,16 @@ def test_ItemSetDict_repr():
), ),
} }
) )
expected_str = str( expected_str = (
"""ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])), "ItemSetDict(\n"
names=('seed_nodes', 'labels'), " itemsets={'user': ItemSet(\n"
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])), " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
names=('seed_nodes', 'labels'), " names=('seed_nodes', 'labels'),\n"
)}, " ), 'item': ItemSet(\n"
names=('seed_nodes', 'labels'), " items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n"
)""" " names=('seed_nodes', 'labels'),\n"
" )},\n"
" names=('seed_nodes', 'labels'),\n"
")"
) )
assert str(item_set) == expected_str, print(item_set) assert str(item_set) == expected_str, item_set
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment