Unverified Commit 34af2d68 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[GraphBolt] Enable ItemSet to initiate by integer (#6427)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-39-125.ap-northeast-1.compute.internal>
parent a2a3a913
"""GraphBolt Itemset.""" """GraphBolt Itemset."""
from typing import Dict, Iterable, Iterator, Sized, Tuple from typing import Dict, Iterable, Iterator, Sized, Tuple, Union
__all__ = ["ItemSet", "ItemSetDict"] __all__ = ["ItemSet", "ItemSetDict"]
...@@ -14,11 +14,13 @@ class ItemSet: ...@@ -14,11 +14,13 @@ class ItemSet:
Parameters Parameters
---------- ----------
items: Iterable or Tuple[Iterable] items: Union[int, Iterable, Tuple[Iterable]]
The items to be iterated over. If it's multi-dimensional iterable such The items to be iterated over. If it is a single integer, a `range()`
as `torch.Tensor`, it will be iterated over the first dimension. If it object will be created and iterated over. If it's multi-dimensional
is a tuple, each item in the tuple is an iterable of items. iterable such as `torch.Tensor`, it will be iterated over the first
names: str or Tuple[str], optional dimension. If it is a tuple, each item in the tuple is an iterable of
items.
names: Union[str, Tuple[str]], optional
The names of the items. If it is a tuple, each name corresponds to an The names of the items. If it is a tuple, each name corresponds to an
item in the tuple. item in the tuple.
...@@ -27,8 +29,15 @@ class ItemSet: ...@@ -27,8 +29,15 @@ class ItemSet:
>>> import torch >>> import torch
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
1. Single iterable: seed nodes. 1. Integer: number of nodes.
>>> num = 10
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> item_set.names
('seed_nodes',)
2. Single iterable: seed nodes.
>>> node_ids = torch.arange(0, 5) >>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes") >>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
>>> list(item_set) >>> list(item_set)
...@@ -36,8 +45,7 @@ class ItemSet: ...@@ -36,8 +45,7 @@ class ItemSet:
>>> item_set.names >>> item_set.names
('seed_nodes',) ('seed_nodes',)
2. Tuple of iterables with same shape: seed nodes and labels. 3. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5) >>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10) >>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet( >>> item_set = gb.ItemSet(
...@@ -48,8 +56,7 @@ class ItemSet: ...@@ -48,8 +56,7 @@ class ItemSet:
>>> item_set.names >>> item_set.names
('seed_nodes', 'labels') ('seed_nodes', 'labels')
3. Tuple of iterables with different shape: node pairs and negative dsts. 4. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2) >>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3) >>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
>>> item_set = gb.ItemSet( >>> item_set = gb.ItemSet(
...@@ -66,9 +73,25 @@ class ItemSet: ...@@ -66,9 +73,25 @@ class ItemSet:
def __init__( def __init__(
self, self,
items: Iterable or Tuple[Iterable], items: Union[int, Iterable, Tuple[Iterable]],
names: str or Tuple[str] = None, names: Union[str, Tuple[str]] = None,
) -> None: ) -> None:
# Initiated by an integer.
if isinstance(items, int):
self._items = items
if names is not None:
if isinstance(names, tuple):
self._names = names
else:
self._names = (names,)
assert (
len(self._names) == 1
), "Number of names mustn't exceed 1 when item is an integer."
else:
self._names = None
return
# Otherwise.
if isinstance(items, tuple): if isinstance(items, tuple):
self._items = items self._items = items
else: else:
...@@ -86,14 +109,18 @@ class ItemSet: ...@@ -86,14 +109,18 @@ class ItemSet:
self._names = None self._names = None
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
if isinstance(self._items, int):
yield from range(self._items)
return
if len(self._items) == 1: if len(self._items) == 1:
yield from self._items[0] yield from self._items[0]
return return
if isinstance(self._items[0], Sized): if isinstance(self._items[0], Sized):
items_len = len(self._items[0]) items_len = len(self._items[0])
# Use for-loop to iterate over the items. Can avoid a long # Use for-loop to iterate over the items. It can avoid a long
# wait time when the items are torch tensors. Since torch # waiting time when the items are torch tensors. Since torch
# tensors need to call self.unbind(0) to slice themselves. # tensors need to call self.unbind(0) to slice themselves.
# While for-loops are slower than zip, they prevent excessive # While for-loops are slower than zip, they prevent excessive
# wait times during the loading phase, and the impact on overall # wait times during the loading phase, and the impact on overall
...@@ -108,6 +135,8 @@ class ItemSet: ...@@ -108,6 +135,8 @@ class ItemSet:
yield tuple(item) yield tuple(item)
def __len__(self) -> int: def __len__(self) -> int:
if isinstance(self._items, int):
return self._items
if isinstance(self._items[0], Sized): if isinstance(self._items[0], Sized):
return len(self._items[0]) return len(self._items[0])
raise TypeError( raise TypeError(
......
...@@ -23,6 +23,15 @@ def test_ItemSet_names(): ...@@ -23,6 +23,15 @@ def test_ItemSet_names():
item_set = gb.ItemSet(torch.arange(0, 5)) item_set = gb.ItemSet(torch.arange(0, 5))
assert item_set.names is None assert item_set.names is None
# Integer-initiated ItemSet with excessive names.
with pytest.raises(
AssertionError,
match=re.escape(
"Number of names mustn't exceed 1 when item is an integer."
),
):
_ = gb.ItemSet(5, names=("seed_nodes", "labels"))
# ItemSet with mismatched items and names. # ItemSet with mismatched items and names.
with pytest.raises( with pytest.raises(
AssertionError, AssertionError,
...@@ -32,11 +41,18 @@ def test_ItemSet_names(): ...@@ -32,11 +41,18 @@ def test_ItemSet_names():
def test_ItemSet_length(): def test_ItemSet_length():
# Integer with valid length
num = 10
item_set = gb.ItemSet(num)
assert len(item_set) == 10
# Test __iter__() method. Same as below.
for i, item in enumerate(item_set):
assert i == item
# Single iterable with valid length. # Single iterable with valid length.
ids = torch.arange(0, 5) ids = torch.arange(0, 5)
item_set = gb.ItemSet(ids) item_set = gb.ItemSet(ids)
assert len(item_set) == 5 assert len(item_set) == 5
# Test __iter__ method. Same as below.
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert i == item.item() assert i == item.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment