Unverified commit 34af2d68, authored by Mingbang Wang, committed by GitHub
Browse files

[GraphBolt] Enable ItemSet to initiate by integer (#6427)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-125.ap-northeast-1.compute.internal>
parent a2a3a913
"""GraphBolt Itemset."""
from typing import Dict, Iterable, Iterator, Sized, Tuple
from typing import Dict, Iterable, Iterator, Sized, Tuple, Union
__all__ = ["ItemSet", "ItemSetDict"]
......@@ -14,11 +14,13 @@ class ItemSet:
Parameters
----------
items: Iterable or Tuple[Iterable]
The items to be iterated over. If it's multi-dimensional iterable such
as `torch.Tensor`, it will be iterated over the first dimension. If it
is a tuple, each item in the tuple is an iterable of items.
names: str or Tuple[str], optional
items: Union[int, Iterable, Tuple[Iterable]]
The items to be iterated over. If it is a single integer, a `range()`
object will be created and iterated over. If it is a multi-dimensional
iterable such as `torch.Tensor`, it will be iterated over the first
dimension. If it is a tuple, each item in the tuple is an iterable of
items.
names: Union[str, Tuple[str]], optional
The names of the items. If it is a tuple, each name corresponds to an
item in the tuple.
......@@ -27,8 +29,15 @@ class ItemSet:
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable: seed nodes.
1. Integer: number of nodes.
>>> num = 10
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
>>> item_set.names
('seed_nodes',)
2. Single iterable: seed nodes.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
>>> list(item_set)
......@@ -36,8 +45,7 @@ class ItemSet:
>>> item_set.names
('seed_nodes',)
2. Tuple of iterables with same shape: seed nodes and labels.
3. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet(
......@@ -48,8 +56,7 @@ class ItemSet:
>>> item_set.names
('seed_nodes', 'labels')
3. Tuple of iterables with different shape: node pairs and negative dsts.
4. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
>>> item_set = gb.ItemSet(
......@@ -66,9 +73,25 @@ class ItemSet:
def __init__(
self,
items: Iterable or Tuple[Iterable],
names: str or Tuple[str] = None,
items: Union[int, Iterable, Tuple[Iterable]],
names: Union[str, Tuple[str]] = None,
) -> None:
# Initialized with an integer.
if isinstance(items, int):
self._items = items
if names is not None:
if isinstance(names, tuple):
self._names = names
else:
self._names = (names,)
assert (
len(self._names) == 1
), "Number of names mustn't exceed 1 when item is an integer."
else:
self._names = None
return
# Otherwise.
if isinstance(items, tuple):
self._items = items
else:
......@@ -86,14 +109,18 @@ class ItemSet:
self._names = None
def __iter__(self) -> Iterator:
if isinstance(self._items, int):
yield from range(self._items)
return
if len(self._items) == 1:
yield from self._items[0]
return
if isinstance(self._items[0], Sized):
items_len = len(self._items[0])
# Use for-loop to iterate over the items. Can avoid a long
# wait time when the items are torch tensors. Since torch
# Use a for-loop to iterate over the items. This avoids a long
# wait when the items are torch tensors, since torch tensors
# need to call self.unbind(0) to slice themselves.
# While for-loops are slower than zip, they prevent excessive
# wait times during the loading phase, and the impact on overall
......@@ -108,6 +135,8 @@ class ItemSet:
yield tuple(item)
def __len__(self) -> int:
if isinstance(self._items, int):
return self._items
if isinstance(self._items[0], Sized):
return len(self._items[0])
raise TypeError(
......
......@@ -23,6 +23,15 @@ def test_ItemSet_names():
item_set = gb.ItemSet(torch.arange(0, 5))
assert item_set.names is None
# Integer-initiated ItemSet with excessive names.
with pytest.raises(
AssertionError,
match=re.escape(
"Number of names mustn't exceed 1 when item is an integer."
),
):
_ = gb.ItemSet(5, names=("seed_nodes", "labels"))
# ItemSet with mismatched items and names.
with pytest.raises(
AssertionError,
......@@ -32,11 +41,18 @@ def test_ItemSet_names():
def test_ItemSet_length():
# Integer with valid length
num = 10
item_set = gb.ItemSet(num)
assert len(item_set) == 10
# Test __iter__() method. Same as below.
for i, item in enumerate(item_set):
assert i == item
# Single iterable with valid length.
ids = torch.arange(0, 5)
item_set = gb.ItemSet(ids)
assert len(item_set) == 5
# Test __iter__ method. Same as below.
for i, item in enumerate(item_set):
assert i == item.item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment