Unverified Commit bdb78758 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] define __len__ for ItemSet/ItemSetDict (#5844)

parent f5330cb6
"""GraphBolt Itemset."""
from typing import Dict, Iterable, Iterator, Tuple
from typing import Dict, Iterable, Iterator, Sized, Tuple
__all__ = ["ItemSet", "ItemSetDict"]
......@@ -49,9 +49,6 @@ class ItemSet:
def __init__(self, items: Iterable or Tuple[Iterable]) -> None:
if isinstance(items, tuple):
assert all(
items[0].size(0) == item.size(0) for item in items
), "Size mismatch between items in tuple."
self._items = items
else:
self._items = (items,)
......@@ -64,6 +61,13 @@ class ItemSet:
for item in zip_items:
yield tuple(item)
def __len__(self) -> int:
if isinstance(self._items[0], Sized):
return len(self._items[0])
raise TypeError(
f"{type(self).__name__} instance doesn't have valid length."
)
class ItemSetDict:
r"""An iterable ItemsetDict.
......@@ -128,3 +132,6 @@ class ItemSetDict:
for key, itemset in self._itemsets.items():
for item in itemset:
yield {key: item}
def __len__(self) -> int:
return sum(len(itemset) for itemset in self._itemsets.values())
......@@ -5,11 +5,86 @@ from dgl import graphbolt as gb
from torch.testing import assert_close
def test_mismatch_size_in_tuple():
# Size mismatch.
node_pairs = (torch.arange(0, 5), torch.arange(5, 11))
with pytest.raises(AssertionError):
_ = gb.ItemSet(node_pairs)
def test_ItemSet_valid_length():
# Single iterable.
ids = torch.arange(0, 5)
item_set = gb.ItemSet(ids)
assert len(item_set) == 5
# Tuple of iterables.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
item_set = gb.ItemSet(node_pairs)
assert len(item_set) == 5
def test_ItemSet_invalid_length():
class InvalidLength:
def __iter__(self):
return iter([0, 1, 2])
# Single iterable.
item_set = gb.ItemSet(InvalidLength())
with pytest.raises(TypeError):
_ = len(item_set)
# Tuple of iterables.
item_set = gb.ItemSet((InvalidLength(), InvalidLength()))
with pytest.raises(TypeError):
_ = len(item_set)
def test_ItemSetDict_valid_length():
# Single iterable.
user_ids = torch.arange(0, 5)
item_ids = torch.arange(0, 5)
item_set = gb.ItemSetDict(
{
"user": gb.ItemSet(user_ids),
"item": gb.ItemSet(item_ids),
}
)
assert len(item_set) == len(user_ids) + len(item_ids)
# Tuple of iterables.
like = (torch.arange(0, 5), torch.arange(0, 5))
follow = (torch.arange(0, 5), torch.arange(5, 10))
item_set = gb.ItemSetDict(
{
("user", "like", "item"): gb.ItemSet(like),
("user", "follow", "user"): gb.ItemSet(follow),
}
)
assert len(item_set) == len(like[0]) + len(follow[0])
def test_ItemSetDict_invalid_length():
class InvalidLength:
def __iter__(self):
return iter([0, 1, 2])
# Single iterable.
item_set = gb.ItemSetDict(
{
"user": gb.ItemSet(InvalidLength()),
"item": gb.ItemSet(InvalidLength()),
}
)
with pytest.raises(TypeError):
_ = len(item_set)
# Tuple of iterables.
item_set = gb.ItemSetDict(
{
("user", "like", "item"): gb.ItemSet(
(InvalidLength(), InvalidLength())
),
("user", "follow", "user"): gb.ItemSet(
(InvalidLength(), InvalidLength())
),
}
)
with pytest.raises(TypeError):
_ = len(item_set)
def test_ItemSet_node_edge_ids():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment