Unverified Commit bdb78758 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] define __len__ for ItemSet/ItemSetDict (#5844)

parent f5330cb6
"""GraphBolt Itemset.""" """GraphBolt Itemset."""
from typing import Dict, Iterable, Iterator, Tuple from typing import Dict, Iterable, Iterator, Sized, Tuple
__all__ = ["ItemSet", "ItemSetDict"] __all__ = ["ItemSet", "ItemSetDict"]
...@@ -49,9 +49,6 @@ class ItemSet: ...@@ -49,9 +49,6 @@ class ItemSet:
def __init__(self, items: Iterable or Tuple[Iterable]) -> None: def __init__(self, items: Iterable or Tuple[Iterable]) -> None:
if isinstance(items, tuple): if isinstance(items, tuple):
assert all(
items[0].size(0) == item.size(0) for item in items
), "Size mismatch between items in tuple."
self._items = items self._items = items
else: else:
self._items = (items,) self._items = (items,)
...@@ -64,6 +61,13 @@ class ItemSet: ...@@ -64,6 +61,13 @@ class ItemSet:
for item in zip_items: for item in zip_items:
yield tuple(item) yield tuple(item)
def __len__(self) -> int:
    """Return the length of the first stored iterable.

    Raises
    ------
    TypeError
        If the first stored iterable is not ``Sized``.
    """
    first = self._items[0]
    if not isinstance(first, Sized):
        message = (
            f"{type(self).__name__} instance doesn't have valid length."
        )
        raise TypeError(message)
    # Only the first iterable is consulted for the length.
    return len(first)
class ItemSetDict: class ItemSetDict:
r"""An iterable ItemsetDict. r"""An iterable ItemsetDict.
...@@ -128,3 +132,6 @@ class ItemSetDict: ...@@ -128,3 +132,6 @@ class ItemSetDict:
for key, itemset in self._itemsets.items(): for key, itemset in self._itemsets.items():
for item in itemset: for item in itemset:
yield {key: item} yield {key: item}
def __len__(self) -> int:
    """Return the combined length of every item set held in the dict.

    Any exception raised by ``len()`` on a member propagates unchanged.
    """
    total = 0
    for member in self._itemsets.values():
        total += len(member)
    return total
...@@ -5,11 +5,86 @@ from dgl import graphbolt as gb ...@@ -5,11 +5,86 @@ from dgl import graphbolt as gb
from torch.testing import assert_close from torch.testing import assert_close
def test_mismatch_size_in_tuple(): def test_ItemSet_valid_length():
# Size mismatch. # Single iterable.
node_pairs = (torch.arange(0, 5), torch.arange(5, 11)) ids = torch.arange(0, 5)
with pytest.raises(AssertionError): item_set = gb.ItemSet(ids)
_ = gb.ItemSet(node_pairs) assert len(item_set) == 5
# Tuple of iterables.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
item_set = gb.ItemSet(node_pairs)
assert len(item_set) == 5
def test_ItemSet_invalid_length():
    """``len()`` raises ``TypeError`` for an ItemSet of unsized iterables."""

    class Unsized:
        """Iterable that defines ``__iter__`` only, with no ``__len__``."""

        def __iter__(self):
            return iter([0, 1, 2])

    # First case: a single iterable. Second case: a tuple of iterables.
    for items in (Unsized(), (Unsized(), Unsized())):
        item_set = gb.ItemSet(items)
        with pytest.raises(TypeError):
            _ = len(item_set)
def test_ItemSetDict_valid_length():
    """``len()`` of an ItemSetDict equals the sum of its members' lengths."""
    # Members that each wrap a single iterable.
    user_ids = torch.arange(0, 5)
    item_ids = torch.arange(0, 5)
    members = {
        "user": gb.ItemSet(user_ids),
        "item": gb.ItemSet(item_ids),
    }
    item_set = gb.ItemSetDict(members)
    expected = len(user_ids) + len(item_ids)
    assert len(item_set) == expected

    # Members that each wrap a tuple of iterables.
    like = (torch.arange(0, 5), torch.arange(0, 5))
    follow = (torch.arange(0, 5), torch.arange(5, 10))
    members = {
        ("user", "like", "item"): gb.ItemSet(like),
        ("user", "follow", "user"): gb.ItemSet(follow),
    }
    item_set = gb.ItemSetDict(members)
    expected = len(like[0]) + len(follow[0])
    assert len(item_set) == expected
def test_ItemSetDict_invalid_length():
    """``len()`` raises ``TypeError`` when members wrap unsized iterables."""

    class Unsized:
        """Iterable that defines ``__iter__`` only, with no ``__len__``."""

        def __iter__(self):
            return iter([0, 1, 2])

    # Members that each wrap a single iterable.
    single = {key: gb.ItemSet(Unsized()) for key in ("user", "item")}
    item_set = gb.ItemSetDict(single)
    with pytest.raises(TypeError):
        _ = len(item_set)

    # Members that each wrap a tuple of iterables.
    keys = (("user", "like", "item"), ("user", "follow", "user"))
    paired = {key: gb.ItemSet((Unsized(), Unsized())) for key in keys}
    item_set = gb.ItemSetDict(paired)
    with pytest.raises(TypeError):
        _ = len(item_set)
def test_ItemSet_node_edge_ids(): def test_ItemSet_node_edge_ids():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment