Unverified Commit 1328baf7 authored by LastWhisper's avatar LastWhisper Committed by GitHub
Browse files

[Graphbolt] Use `for` to implement `ItemSet.__iter__` (#6293)

Optimize the ItemSet.__iter__ function implementation
parent ce8a7dd3
......@@ -86,9 +86,23 @@ class ItemSet:
if len(self._items) == 1:
yield from self._items[0]
return
zip_items = zip(*self._items)
for item in zip_items:
yield tuple(item)
if isinstance(self._items[0], Sized):
items_len = len(self._items[0])
# Use an index-based for-loop to iterate over the items. This avoids
# a long wait time when the items are torch tensors, because torch
# tensors need to call self.unbind(0) to slice themselves when they
# are iterated. Although for-loops are slower than zip, they prevent
# excessive wait times during the loading phase, and the impact on
# overall performance during the training/testing stage is minimal.
# For more details, see https://github.com/dmlc/dgl/pull/6293.
for i in range(items_len):
yield tuple(item[i] for item in self._items)
else:
# If the items are not Sized, we use zip to iterate over them.
zip_items = zip(*self._items)
for item in zip_items:
yield tuple(item)
def __len__(self) -> int:
if isinstance(self._items[0], Sized):
......
......@@ -36,10 +36,16 @@ def test_ItemSet_length():
ids = torch.arange(0, 5)
item_set = gb.ItemSet(ids)
assert len(item_set) == 5
# Test the __iter__ method. The same check is repeated for the cases below.
for i, item in enumerate(item_set):
assert i == item.item()
# Tuple of iterables with valid length.
item_set = gb.ItemSet((torch.arange(0, 5), torch.arange(5, 10)))
assert len(item_set) == 5
for i, (item1, item2) in enumerate(item_set):
assert i == item1.item()
assert i + 5 == item2.item()
class InvalidLength:
def __iter__(self):
......@@ -49,11 +55,16 @@ def test_ItemSet_length():
item_set = gb.ItemSet(InvalidLength())
with pytest.raises(TypeError):
_ = len(item_set)
for i, item in enumerate(item_set):
assert i == item
# Tuple of iterables with invalid length.
item_set = gb.ItemSet((InvalidLength(), InvalidLength()))
with pytest.raises(TypeError):
_ = len(item_set)
for i, (item1, item2) in enumerate(item_set):
assert i == item1
assert i == item2
def test_ItemSet_iteration_seed_nodes():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment