Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9d5b897a
Unverified
Commit
9d5b897a
authored
Oct 18, 2023
by
Rhett Ying
Committed by
GitHub
Oct 18, 2023
Browse files
[GraphBolt] enable indexing for ItemSetDict (#6459)
parent
8b37564b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
143 additions
and
2 deletions
+143
-2
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+58
-0
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+85
-2
No files found.
python/dgl/graphbolt/itemset.py
View file @
9d5b897a
...
...
@@ -202,6 +202,8 @@ class ItemSetDict:
{"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
{"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
{"item": tensor(9)}]
>>> item_set[:]
{"user": tensor([0, 1, 2, 3, 4]), "item": tensor([5, 6, 7, 8, 9])}
>>> item_set.names
('seed_nodes',)
...
...
@@ -222,6 +224,9 @@ class ItemSetDict:
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
{"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
{"item": (tensor(4), tensor(4))}]
>>> item_set[:]
{"user": (tensor([0, 1]), tensor([0, 1])),
"item": (tensor([2, 3, 4]), tensor([2, 3, 4]))}
>>> item_set.names
('seed_nodes', 'labels')
...
...
@@ -244,6 +249,13 @@ class ItemSetDict:
{"user:follow:user": (tensor([0, 1]), tensor([ 6, 7, 8, 9, 10, 11]))},
{"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))},
{"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}]
>>> item_set[:]
{"user:like:item": (tensor([[0, 1], [2, 3]]),
tensor([[4, 5, 6], [7, 8, 9]])),
"user:follow:user": (tensor([[0, 1], [2, 3], [4, 5]]),
tensor([[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17],
[18, 19, 20, 21, 22, 23]]))}
>>> item_set.names
('node_pairs', 'negative_dsts')
"""
...
...
@@ -254,6 +266,15 @@ class ItemSetDict:
assert
all
(
self
.
_names
==
itemset
.
names
for
itemset
in
itemsets
.
values
()
),
"All itemsets must have the same names."
try
:
# For indexable itemsets, we compute the offsets for each itemset
# in advance to speed up indexing.
offsets
=
[
0
]
+
[
len
(
itemset
)
for
itemset
in
self
.
_itemsets
.
values
()
]
self
.
_offsets
=
torch
.
tensor
(
offsets
).
cumsum
(
0
)
except
TypeError
:
self
.
_offsets
=
None
def
__iter__
(
self
)
->
Iterator
:
for
key
,
itemset
in
self
.
_itemsets
.
items
():
...
...
@@ -263,6 +284,43 @@ class ItemSetDict:
def __len__(self) -> int:
    """Return the total number of items held across all itemsets.

    The lengths of the individual itemsets stored in ``self._itemsets`` are
    added up; any exception raised by ``len()`` on one of them propagates
    unchanged to the caller.
    """
    total = 0
    for itemset in self._itemsets.values():
        total += len(itemset)
    return total
def __getitem__(self, idx: Union[int, slice]) -> Dict[str, Tuple]:
    """Return the item(s) at ``idx``, keyed by the itemset that owns them.

    The itemsets are treated as one concatenated sequence, in the iteration
    order of ``self._itemsets``. ``self._offsets`` holds the cumulative
    start position of every itemset, followed by the total length.

    Parameters
    ----------
    idx : int or slice
        An integer (negative values count from the end) or a slice whose
        step is 1 and whose start is smaller than its stop.

    Returns
    -------
    dict
        For an integer, a single ``{key: item}`` entry. For a slice, one
        entry per itemset that contributes at least one item to the
        requested range; itemsets contributing nothing are omitted.

    Raises
    ------
    TypeError
        If indexing is unsupported (``self._offsets`` is ``None``) or if
        ``idx`` is neither an int nor a slice.
    IndexError
        If an integer index is out of range.
    AssertionError
        If the slice step is not 1 or start is not smaller than stop.
    """
    if self._offsets is None:
        raise TypeError(
            f"{type(self).__name__} instance doesn't support indexing."
        )
    # Use plain Python ints so no 0-dim tensors leak into the index
    # arithmetic or into the lookups on the underlying itemsets.
    offsets = self._offsets.tolist()
    total_num = offsets[-1]
    keys = list(self._itemsets.keys())

    if isinstance(idx, int):
        if idx < 0:
            idx += total_num
        if idx < 0 or idx >= total_num:
            raise IndexError(f"{type(self).__name__} index out of range.")
        # right=True yields the first boundary strictly greater than idx;
        # the itemset owning idx starts one boundary earlier.
        offset_idx = (
            int(torch.searchsorted(self._offsets, idx, right=True)) - 1
        )
        key = keys[offset_idx]
        return {key: self._itemsets[key][idx - offsets[offset_idx]]}

    if isinstance(idx, slice):
        start, stop, step = idx.indices(total_num)
        # Kept as assertions: callers and tests expect AssertionError here.
        assert step == 1, "Step must be 1."
        assert start < stop, "Start must be smaller than stop."
        data = {}
        # right=True selects the first itemset whose end lies strictly
        # beyond start. With right=False, a start sitting exactly on a
        # boundary picked the preceding itemset and produced a spurious
        # empty entry for it.
        first = int(torch.searchsorted(self._offsets, start, right=True))
        for offset_idx in range(first, len(offsets)):
            begin = offsets[offset_idx - 1]
            end = offsets[offset_idx]
            if begin == end:
                # An empty itemset contributes nothing to the range.
                continue
            key = keys[offset_idx - 1]
            data[key] = self._itemsets[key][
                max(0, start - begin) : stop - begin
            ]
            if stop <= end:
                break
        return data

    raise TypeError(
        f"{type(self).__name__} indices must be int or slice."
    )
@
property
def
names
(
self
)
->
Tuple
[
str
]:
"""Return the names of the items."""
...
...
tests/python/pytorch/graphbolt/test_itemset.py
View file @
9d5b897a
...
...
@@ -312,8 +312,14 @@ def test_ItemSetDict_length():
"item"
:
gb
.
ItemSet
(
InvalidLength
()),
}
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSet instance doesn't have valid length."
):
_
=
len
(
item_set
)
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSetDict instance doesn't support indexing."
):
_
=
item_set
[
0
]
# Tuple of iterables with invalid length.
item_set
=
gb
.
ItemSetDict
(
...
...
@@ -322,8 +328,14 @@ def test_ItemSetDict_length():
"user:follow:user"
:
gb
.
ItemSet
((
InvalidLength
(),
InvalidLength
())),
}
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSet instance doesn't have valid length."
):
_
=
len
(
item_set
)
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSetDict instance doesn't support indexing."
):
_
=
item_set
[
0
]
def
test_ItemSetDict_iteration_seed_nodes
():
...
...
@@ -339,11 +351,48 @@ def test_ItemSetDict_iteration_seed_nodes():
chained_ids
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
ids
)
assert
item_set
.
names
==
(
"seed_nodes"
,)
# Iterating over ItemSetDict and indexing one by one.
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
chained_ids
[
i
][
0
]
in
item
assert
item
[
chained_ids
[
i
][
0
]]
==
chained_ids
[
i
][
1
]
assert
item_set
[
i
]
==
item
assert
item_set
[
i
-
len
(
item_set
)]
==
item
# Indexing all with a slice.
assert
torch
.
equal
(
item_set
[:][
"user"
],
user_ids
)
assert
torch
.
equal
(
item_set
[:][
"item"
],
item_ids
)
# Indexing partial with a slice.
partial_data
=
item_set
[:
3
]
assert
len
(
list
(
partial_data
.
keys
()))
==
1
assert
torch
.
equal
(
partial_data
[
"user"
],
user_ids
[:
3
])
partial_data
=
item_set
[
7
:]
assert
len
(
list
(
partial_data
.
keys
()))
==
1
assert
torch
.
equal
(
partial_data
[
"item"
],
item_ids
[
2
:])
partial_data
=
item_set
[
3
:
7
]
assert
len
(
list
(
partial_data
.
keys
()))
==
2
assert
torch
.
equal
(
partial_data
[
"user"
],
user_ids
[
3
:
5
])
assert
torch
.
equal
(
partial_data
[
"item"
],
item_ids
[:
2
])
# Exception cases.
with
pytest
.
raises
(
AssertionError
,
match
=
"Step must be 1."
):
_
=
item_set
[::
2
]
with
pytest
.
raises
(
AssertionError
,
match
=
"Start must be smaller than stop."
):
_
=
item_set
[
5
:
3
]
with
pytest
.
raises
(
AssertionError
,
match
=
"Start must be smaller than stop."
):
_
=
item_set
[
-
1
:
3
]
with
pytest
.
raises
(
IndexError
,
match
=
"ItemSetDict index out of range."
):
_
=
item_set
[
20
]
with
pytest
.
raises
(
IndexError
,
match
=
"ItemSetDict index out of range."
):
_
=
item_set
[
-
20
]
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSetDict indices must be int or slice."
):
_
=
item_set
[
torch
.
arange
(
3
)]
def
test_ItemSetDict_iteration_seed_nodes_labels
():
...
...
@@ -365,11 +414,18 @@ def test_ItemSetDict_iteration_seed_nodes_labels():
chained_ids
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
ids_labels
)
assert
item_set
.
names
==
(
"seed_nodes"
,
"labels"
)
# Iterating over ItemSetDict and indexing one by one.
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
chained_ids
[
i
][
0
]
in
item
assert
item
[
chained_ids
[
i
][
0
]]
==
chained_ids
[
i
][
1
]
assert
item_set
[
i
]
==
item
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:][
"user"
][
0
],
user_ids
)
assert
torch
.
equal
(
item_set
[:][
"user"
][
1
],
user_labels
)
assert
torch
.
equal
(
item_set
[:][
"item"
][
0
],
item_ids
)
assert
torch
.
equal
(
item_set
[:][
"item"
][
1
],
item_labels
)
def
test_ItemSetDict_iteration_node_pairs
():
...
...
@@ -384,11 +440,18 @@ def test_ItemSetDict_iteration_node_pairs():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
node_pairs_dict
)
assert
item_set
.
names
==
(
"node_pairs"
,)
# Iterating over ItemSetDict and indexing one by one.
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
expected_data
[
i
][
0
]
in
item
assert
torch
.
equal
(
item
[
expected_data
[
i
][
0
]],
expected_data
[
i
][
1
])
assert
item_set
[
i
].
keys
()
==
item
.
keys
()
key
=
list
(
item
.
keys
())[
0
]
assert
torch
.
equal
(
item_set
[
i
][
key
],
item
[
key
])
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:][
"user:like:item"
],
node_pairs
)
assert
torch
.
equal
(
item_set
[:][
"user:follow:user"
],
node_pairs
)
def
test_ItemSetDict_iteration_node_pairs_labels
():
...
...
@@ -408,6 +471,7 @@ def test_ItemSetDict_iteration_node_pairs_labels():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
node_pairs_labels
)
assert
item_set
.
names
==
(
"node_pairs"
,
"labels"
)
# Iterating over ItemSetDict and indexing one by one.
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
...
...
@@ -415,6 +479,15 @@ def test_ItemSetDict_iteration_node_pairs_labels():
assert
key
in
item
assert
torch
.
equal
(
item
[
key
][
0
],
value
[
0
])
assert
item
[
key
][
1
]
==
value
[
1
]
assert
item_set
[
i
].
keys
()
==
item
.
keys
()
key
=
list
(
item
.
keys
())[
0
]
assert
torch
.
equal
(
item_set
[
i
][
key
][
0
],
item
[
key
][
0
])
assert
torch
.
equal
(
item_set
[
i
][
key
][
1
],
item
[
key
][
1
])
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:][
"user:like:item"
][
0
],
node_pairs
)
assert
torch
.
equal
(
item_set
[:][
"user:like:item"
][
1
],
labels
)
assert
torch
.
equal
(
item_set
[:][
"user:follow:user"
][
0
],
node_pairs
)
assert
torch
.
equal
(
item_set
[:][
"user:follow:user"
][
1
],
labels
)
def
test_ItemSetDict_iteration_node_pairs_neg_dsts
():
...
...
@@ -434,6 +507,7 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
node_pairs_neg_dsts
)
assert
item_set
.
names
==
(
"node_pairs"
,
"negative_dsts"
)
# Iterating over ItemSetDict and indexing one by one.
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
...
...
@@ -441,3 +515,12 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
assert
key
in
item
assert
torch
.
equal
(
item
[
key
][
0
],
value
[
0
])
assert
torch
.
equal
(
item
[
key
][
1
],
value
[
1
])
assert
item_set
[
i
].
keys
()
==
item
.
keys
()
key
=
list
(
item
.
keys
())[
0
]
assert
torch
.
equal
(
item_set
[
i
][
key
][
0
],
item
[
key
][
0
])
assert
torch
.
equal
(
item_set
[
i
][
key
][
1
],
item
[
key
][
1
])
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:][
"user:like:item"
][
0
],
node_pairs
)
assert
torch
.
equal
(
item_set
[:][
"user:like:item"
][
1
],
neg_dsts
)
assert
torch
.
equal
(
item_set
[:][
"user:follow:user"
][
0
],
node_pairs
)
assert
torch
.
equal
(
item_set
[:][
"user:follow:user"
][
1
],
neg_dsts
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment