Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
1328baf7
Unverified
Commit
1328baf7
authored
Sep 11, 2023
by
LastWhisper
Committed by
GitHub
Sep 11, 2023
Browse files
[Graphbolt] Use `for` to implement `ItemSet.__iter__` (#6293)
Optimize the ItemSet.__iter__ function implementation
parent
ce8a7dd3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
28 additions
and
3 deletions
+28
-3
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+17
-3
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+11
-0
No files found.
python/dgl/graphbolt/itemset.py
View file @
1328baf7
...
...
@@ -86,9 +86,23 @@ class ItemSet:
if
len
(
self
.
_items
)
==
1
:
yield
from
self
.
_items
[
0
]
return
zip_items
=
zip
(
*
self
.
_items
)
for
item
in
zip_items
:
yield
tuple
(
item
)
if
isinstance
(
self
.
_items
[
0
],
Sized
):
items_len
=
len
(
self
.
_items
[
0
])
# Use for-loop to iterate over the items. Can avoid a long
# wait time when the items are torch tensors. Since torch
# tensors need to call self.unbind(0) to slice themselves.
# While for-loops are slower than zip, they prevent excessive
# wait times during the loading phase, and the impact on overall
# performance during the training/testing stage is minimal.
# For more details, see https://github.com/dmlc/dgl/pull/6293.
for
i
in
range
(
items_len
):
yield
tuple
(
item
[
i
]
for
item
in
self
.
_items
)
else
:
# If the items are not Sized, we use zip to iterate over them.
zip_items
=
zip
(
*
self
.
_items
)
for
item
in
zip_items
:
yield
tuple
(
item
)
def
__len__
(
self
)
->
int
:
if
isinstance
(
self
.
_items
[
0
],
Sized
):
...
...
tests/python/pytorch/graphbolt/test_itemset.py
View file @
1328baf7
...
...
@@ -36,10 +36,16 @@ def test_ItemSet_length():
ids
=
torch
.
arange
(
0
,
5
)
item_set
=
gb
.
ItemSet
(
ids
)
assert
len
(
item_set
)
==
5
# Test __iter__ method. Same as below.
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
.
item
()
# Tuple of iterables with valid length.
item_set
=
gb
.
ItemSet
((
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)))
assert
len
(
item_set
)
==
5
for
i
,
(
item1
,
item2
)
in
enumerate
(
item_set
):
assert
i
==
item1
.
item
()
assert
i
+
5
==
item2
.
item
()
class
InvalidLength
:
def
__iter__
(
self
):
...
...
@@ -49,11 +55,16 @@ def test_ItemSet_length():
item_set
=
gb
.
ItemSet
(
InvalidLength
())
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
# Tuple of iterables with invalid length.
item_set
=
gb
.
ItemSet
((
InvalidLength
(),
InvalidLength
()))
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
for
i
,
(
item1
,
item2
)
in
enumerate
(
item_set
):
assert
i
==
item1
assert
i
==
item2
def
test_ItemSet_iteration_seed_nodes
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment