"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "f02caa53307c4c157c210c9fbe3fffb97ac2e635"
Unverified Commit 2b966da2 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add `__repr__` to ItemSet. (#6790)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 67d93458
...@@ -174,6 +174,9 @@ class ItemSet: ...@@ -174,6 +174,9 @@ class ItemSet:
"""Return the names of the items.""" """Return the names of the items."""
return self._names return self._names
def __repr__(self) -> str:
return _itemset_str(self, "ItemSet")
class ItemSetDict: class ItemSetDict:
r"""Dictionary wrapper of **ItemSet**. r"""Dictionary wrapper of **ItemSet**.
...@@ -325,3 +328,33 @@ class ItemSetDict: ...@@ -325,3 +328,33 @@ class ItemSetDict:
def names(self) -> Tuple[str]: def names(self) -> Tuple[str]:
"""Return the names of the items.""" """Return the names of the items."""
return self._names return self._names
def __repr__(self) -> str:
return _itemset_str(self, "ItemSetDict")
def _itemset_str(itemset: Union[ItemSet, ItemSetDict], name) -> str:
final_str = f"{name}("
indent_len = len(final_str)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
items = (
itemset._items if isinstance(itemset, ItemSet) else itemset._itemsets
)
item_str = (
"items="
+ _add_indent(str(items), indent_len + len("items="))
+ ",\n"
+ " " * indent_len
)
name_str = (
"names="
+ _add_indent(str(itemset._names), indent_len + len("items="))
+ ",\n)"
)
final_str += item_str + name_str
return final_str
...@@ -524,3 +524,70 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts(): ...@@ -524,3 +524,70 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
assert torch.equal(item_set[:]["user:like:item"][1], neg_dsts) assert torch.equal(item_set[:]["user:like:item"][1], neg_dsts)
assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs) assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs)
assert torch.equal(item_set[:]["user:follow:user"][1], neg_dsts) assert torch.equal(item_set[:]["user:follow:user"][1], neg_dsts)
def test_ItemSet_repr():
# ItemSet with single name.
item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
expected_str = str(
"""ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
)"""
)
assert str(item_set) == expected_str, print(item_set)
# ItemSet with multiple names.
item_set = gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
)
expected_str = str(
"""ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
)"""
)
assert str(item_set) == expected_str, print(item_set)
def test_ItemSetDict_repr():
# ItemSetDict with single name.
item_set = gb.ItemSetDict(
{
"user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
}
)
expected_str = str(
"""ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]),),
names=('seed_nodes',),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]),),
names=('seed_nodes',),
)},
names=('seed_nodes',),
)"""
)
assert str(item_set) == expected_str, print(item_set)
# ItemSetDict with multiple names.
item_set = gb.ItemSetDict(
{
"user": gb.ItemSet(
(torch.arange(0, 5), torch.arange(5, 10)),
names=("seed_nodes", "labels"),
),
"item": gb.ItemSet(
(torch.arange(5, 10), torch.arange(10, 15)),
names=("seed_nodes", "labels"),
),
}
)
expected_str = str(
"""ItemSetDict(items={'user': ItemSet(items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),
names=('seed_nodes', 'labels'),
), 'item': ItemSet(items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),
names=('seed_nodes', 'labels'),
)},
names=('seed_nodes', 'labels'),
)"""
)
assert str(item_set) == expected_str, print(item_set)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment