Unverified Commit 90e57e74 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Update `__repr__` of `ItemSet` and `ItemSetDict` (#6944)

parent b003732d
"""GraphBolt Itemset."""
import textwrap
from typing import Dict, Iterable, Iterator, Sized, Tuple, Union
import torch
......@@ -175,7 +176,14 @@ class ItemSet:
return self._names
def __repr__(self) -> str:
    """Return a multi-line description of this item set.

    The result opens with ``ItemSet(``, lists ``items`` and ``names`` on
    their own lines (each followed by a trailing comma), and closes with
    ``)`` on a final line.

    The earlier ``return _itemset_str(self, "ItemSet")`` statement is
    dropped: it ran first and made the layout below unreachable.
    """
    # NOTE(review): leading whitespace inside the literals is copied exactly
    # as it appears in this view; confirm the intended indent width against
    # the repository file.
    return (
        f"ItemSet(\n"
        f" items={self._items},\n"
        f" names={self._names},\n"
        f")"
    )
class ItemSetDict:
......@@ -330,31 +338,19 @@ class ItemSetDict:
return self._names
def __repr__(self) -> str:
    """Return a multi-line description of this item-set dictionary.

    The result opens with ``ItemSetDict(``, shows ``itemsets`` and ``names``
    on their own lines, and closes with ``)``. When ``repr(self._itemsets)``
    spans several lines, every continuation line is shifted right by the
    width of the ``itemsets=`` prefix so nested output lines up under the
    first line.

    The earlier ``return _itemset_str(self, "ItemSetDict")`` statement is
    dropped: it ran first and made the layout below unreachable.
    """
    # NOTE(review): leading whitespace inside the literals is copied exactly
    # as it appears in this view; confirm the intended indent width against
    # the repository file.
    template = (
        "ItemSetDict(\n"
        " itemsets={itemsets},\n"
        " names={names},\n"
        ")"
    )
    first_line, *continuation = repr(self._itemsets).splitlines() or [""]
    itemsets_str = first_line
    # Only append the newline when there is something to put after it;
    # otherwise a one-line repr would be followed by a dangling "\n" before
    # the trailing comma.
    if continuation:
        padding = " " * len(" itemsets=")
        itemsets_str += "\n" + textwrap.indent("\n".join(continuation), padding)
    return template.format(itemsets=itemsets_str, names=self._names)


def _itemset_str(itemset: "Union[ItemSet, ItemSetDict]", name: str) -> str:
    """Format ``itemset`` as ``name(items=..., names=...,\\n)``.

    Legacy formatter kept for any caller that still references it.
    Continuation lines of the ``items`` and ``names`` values are indented so
    that they line up under the first line of the value.

    Parameters
    ----------
    itemset : Union[ItemSet, ItemSetDict]
        Object to describe; ``_items`` is read from an ``ItemSet`` and
        ``_itemsets`` from anything else.
    name : str
        Text placed before the opening parenthesis.

    Returns
    -------
    str
        The formatted description.
    """
    final_str = f"{name}("
    indent_len = len(final_str)

    def _add_indent(_str, indent):
        # Leave the first line alone; pad every following line.
        lines = _str.split("\n")
        lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
        return "\n".join(lines)

    items = (
        itemset._items if isinstance(itemset, ItemSet) else itemset._itemsets
    )
    item_str = (
        "items="
        + _add_indent(str(items), indent_len + len("items="))
        + ",\n"
        + " " * indent_len
    )
    name_str = (
        "names="
        + _add_indent(str(itemset._names), indent_len + len("items="))
        + ",\n)"
    )
    final_str += item_str + name_str
    return final_str
......@@ -2348,18 +2348,21 @@ def test_OnDiskTask_repr_homogeneous():
)
metadata = {"name": "node_classification"}
task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
expected_str = str(
"""OnDiskTask(validation_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
),
train_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
),
test_set=ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
),
metadata={'name': 'node_classification'},
)"""
expected_str = (
"OnDiskTask(validation_set=ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n"
" ),\n"
" train_set=ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n"
" ),\n"
" test_set=ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n"
" ),\n"
" metadata={'name': 'node_classification'},\n"
")"
)
assert str(task) == expected_str, print(task)
......@@ -2373,30 +2376,39 @@ def test_OnDiskTask_repr_heterogeneous():
)
metadata = {"name": "node_classification"}
task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
expected_str = str(
"""OnDiskTask(validation_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),),
names=('seed_nodes',),
)},
names=('seed_nodes',),
),
train_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),),
names=('seed_nodes',),
)},
names=('seed_nodes',),
),
test_set=ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),),
names=('seed_nodes',),
)},
names=('seed_nodes',),
),
metadata={'name': 'node_classification'},
)"""
expected_str = (
"OnDiskTask(validation_set=ItemSetDict(\n"
" itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n"
" ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n"
" )},\n"
" names=('seed_nodes',),\n"
" ),\n"
" train_set=ItemSetDict(\n"
" itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n"
" ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n"
" )},\n"
" names=('seed_nodes',),\n"
" ),\n"
" test_set=ItemSetDict(\n"
" itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n"
" ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n"
" )},\n"
" names=('seed_nodes',),\n"
" ),\n"
" metadata={'name': 'node_classification'},\n"
")"
)
assert str(task) == expected_str, print(task)
......
......@@ -529,24 +529,27 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
def test_ItemSet_repr():
    """Check the multi-line string form of ``ItemSet``.

    Covers an item set built from one tensor with a single name and one built
    from a tuple of two tensors with two names. On failure the assertion
    message is the item set itself (``print(...)`` would evaluate to ``None``
    and contribute nothing to the message).
    """
    # NOTE(review): leading whitespace inside the expected literals is copied
    # exactly as it appears in this view; confirm against the repository file.

    # ItemSet with single name.
    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
    expected_str = (
        "ItemSet(\n"
        " items=(tensor([0, 1, 2, 3, 4]),),\n"
        " names=('seed_nodes',),\n"
        ")"
    )
    assert str(item_set) == expected_str, item_set

    # ItemSet with multiple names.
    item_set = gb.ItemSet(
        (torch.arange(0, 5), torch.arange(5, 10)),
        names=("seed_nodes", "labels"),
    )
    expected_str = (
        "ItemSet(\n"
        " items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
        " names=('seed_nodes', 'labels'),\n"
        ")"
    )
    assert str(item_set) == expected_str, item_set
def test_ItemSetDict_repr():
......@@ -557,16 +560,19 @@ def test_ItemSetDict_repr():
"item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
}
)
expected_str = str(
"""ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),),
names=('seed_nodes',),
)},
names=('seed_nodes',),
)"""
expected_str = (
"ItemSetDict(\n"
" itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]),),\n"
" names=('seed_nodes',),\n"
" ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]),),\n"
" names=('seed_nodes',),\n"
" )},\n"
" names=('seed_nodes',),\n"
")"
)
assert str(item_set) == expected_str, print(item_set)
assert str(item_set) == expected_str, item_set
# ItemSetDict with multiple names.
item_set = gb.ItemSetDict(
......@@ -581,13 +587,16 @@ def test_ItemSetDict_repr():
),
}
)
expected_str = str(
"""ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),
names=('seed_nodes', 'labels'),
)},
names=('seed_nodes', 'labels'),
)"""
expected_str = (
"ItemSetDict(\n"
" itemsets={'user': ItemSet(\n"
" items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
" names=('seed_nodes', 'labels'),\n"
" ), 'item': ItemSet(\n"
" items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n"
" names=('seed_nodes', 'labels'),\n"
" )},\n"
" names=('seed_nodes', 'labels'),\n"
")"
)
assert str(item_set) == expected_str, print(item_set)
assert str(item_set) == expected_str, item_set
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment