Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
bdb78758
You need to sign in or sign up before continuing.
Unverified
Commit
bdb78758
authored
Jun 12, 2023
by
Rhett Ying
Committed by
GitHub
Jun 12, 2023
Browse files
[GraphBolt] define __len__ for ItemSet/ItemSetDict (#5844)
parent
f5330cb6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
91 additions
and
9 deletions
+91
-9
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+11
-4
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+80
-5
No files found.
python/dgl/graphbolt/itemset.py
View file @
bdb78758
"""GraphBolt Itemset."""
"""GraphBolt Itemset."""
from
typing
import
Dict
,
Iterable
,
Iterator
,
Tuple
from
typing
import
Dict
,
Iterable
,
Iterator
,
Sized
,
Tuple
__all__
=
[
"ItemSet"
,
"ItemSetDict"
]
__all__
=
[
"ItemSet"
,
"ItemSetDict"
]
...
@@ -49,9 +49,6 @@ class ItemSet:
...
@@ -49,9 +49,6 @@ class ItemSet:
def
__init__
(
self
,
items
:
Iterable
or
Tuple
[
Iterable
])
->
None
:
def
__init__
(
self
,
items
:
Iterable
or
Tuple
[
Iterable
])
->
None
:
if
isinstance
(
items
,
tuple
):
if
isinstance
(
items
,
tuple
):
assert
all
(
items
[
0
].
size
(
0
)
==
item
.
size
(
0
)
for
item
in
items
),
"Size mismatch between items in tuple."
self
.
_items
=
items
self
.
_items
=
items
else
:
else
:
self
.
_items
=
(
items
,)
self
.
_items
=
(
items
,)
...
@@ -64,6 +61,13 @@ class ItemSet:
...
@@ -64,6 +61,13 @@ class ItemSet:
for
item
in
zip_items
:
for
item
in
zip_items
:
yield
tuple
(
item
)
yield
tuple
(
item
)
def
__len__
(
self
)
->
int
:
if
isinstance
(
self
.
_items
[
0
],
Sized
):
return
len
(
self
.
_items
[
0
])
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
instance doesn't have valid length."
)
class
ItemSetDict
:
class
ItemSetDict
:
r
"""An iterable ItemsetDict.
r
"""An iterable ItemsetDict.
...
@@ -128,3 +132,6 @@ class ItemSetDict:
...
@@ -128,3 +132,6 @@ class ItemSetDict:
for
key
,
itemset
in
self
.
_itemsets
.
items
():
for
key
,
itemset
in
self
.
_itemsets
.
items
():
for
item
in
itemset
:
for
item
in
itemset
:
yield
{
key
:
item
}
yield
{
key
:
item
}
def
__len__
(
self
)
->
int
:
return
sum
(
len
(
itemset
)
for
itemset
in
self
.
_itemsets
.
values
())
tests/python/pytorch/graphbolt/test_itemset.py
View file @
bdb78758
...
@@ -5,11 +5,86 @@ from dgl import graphbolt as gb
...
@@ -5,11 +5,86 @@ from dgl import graphbolt as gb
from
torch.testing
import
assert_close
from
torch.testing
import
assert_close
def
test_mismatch_size_in_tuple
():
def
test_ItemSet_valid_length
():
# Size mismatch.
# Single iterable.
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
11
))
ids
=
torch
.
arange
(
0
,
5
)
with
pytest
.
raises
(
AssertionError
):
item_set
=
gb
.
ItemSet
(
ids
)
_
=
gb
.
ItemSet
(
node_pairs
)
assert
len
(
item_set
)
==
5
# Tuple of iterables.
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
item_set
=
gb
.
ItemSet
(
node_pairs
)
assert
len
(
item_set
)
==
5
def
test_ItemSet_invalid_length
():
class
InvalidLength
:
def
__iter__
(
self
):
return
iter
([
0
,
1
,
2
])
# Single iterable.
item_set
=
gb
.
ItemSet
(
InvalidLength
())
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
# Tuple of iterables.
item_set
=
gb
.
ItemSet
((
InvalidLength
(),
InvalidLength
()))
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
def
test_ItemSetDict_valid_length
():
# Single iterable.
user_ids
=
torch
.
arange
(
0
,
5
)
item_ids
=
torch
.
arange
(
0
,
5
)
item_set
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
user_ids
),
"item"
:
gb
.
ItemSet
(
item_ids
),
}
)
assert
len
(
item_set
)
==
len
(
user_ids
)
+
len
(
item_ids
)
# Tuple of iterables.
like
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
0
,
5
))
follow
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
item_set
=
gb
.
ItemSetDict
(
{
(
"user"
,
"like"
,
"item"
):
gb
.
ItemSet
(
like
),
(
"user"
,
"follow"
,
"user"
):
gb
.
ItemSet
(
follow
),
}
)
assert
len
(
item_set
)
==
len
(
like
[
0
])
+
len
(
follow
[
0
])
def
test_ItemSetDict_invalid_length
():
class
InvalidLength
:
def
__iter__
(
self
):
return
iter
([
0
,
1
,
2
])
# Single iterable.
item_set
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
InvalidLength
()),
"item"
:
gb
.
ItemSet
(
InvalidLength
()),
}
)
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
# Tuple of iterables.
item_set
=
gb
.
ItemSetDict
(
{
(
"user"
,
"like"
,
"item"
):
gb
.
ItemSet
(
(
InvalidLength
(),
InvalidLength
())
),
(
"user"
,
"follow"
,
"user"
):
gb
.
ItemSet
(
(
InvalidLength
(),
InvalidLength
())
),
}
)
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
def
test_ItemSet_node_edge_ids
():
def
test_ItemSet_node_edge_ids
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment