Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
c8ec9ce3
Unverified
Commit
c8ec9ce3
authored
Oct 16, 2023
by
Rhett Ying
Committed by
GitHub
Oct 16, 2023
Browse files
[GraphBolt] enable indexing on ItemSet instance (#6439)
parent
a548805d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
124 additions
and
33 deletions
+124
-33
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+39
-21
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+85
-12
No files found.
python/dgl/graphbolt/itemset.py
View file @
c8ec9ce3
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
from
typing
import
Dict
,
Iterable
,
Iterator
,
Sized
,
Tuple
,
Union
from
typing
import
Dict
,
Iterable
,
Iterator
,
Sized
,
Tuple
,
Union
import
torch
__all__
=
[
"ItemSet"
,
"ItemSetDict"
]
__all__
=
[
"ItemSet"
,
"ItemSetDict"
]
...
@@ -33,7 +35,10 @@ class ItemSet:
...
@@ -33,7 +35,10 @@ class ItemSet:
>>> num = 10
>>> num = 10
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> item_set = gb.ItemSet(num, names="seed_nodes")
>>> list(item_set)
>>> list(item_set)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
tensor(6), tensor(7), tensor(8), tensor(9)]
>>> item_set[torch.arange(0, num)]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> item_set.names
>>> item_set.names
('seed_nodes',)
('seed_nodes',)
...
@@ -42,6 +47,8 @@ class ItemSet:
...
@@ -42,6 +47,8 @@ class ItemSet:
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
>>> list(item_set)
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4])
>>> item_set.names
>>> item_set.names
('seed_nodes',)
('seed_nodes',)
...
@@ -53,6 +60,8 @@ class ItemSet:
...
@@ -53,6 +60,8 @@ class ItemSet:
>>> list(item_set)
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set[:]
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
>>> item_set.names
>>> item_set.names
('seed_nodes', 'labels')
('seed_nodes', 'labels')
...
@@ -67,6 +76,10 @@ class ItemSet:
...
@@ -67,6 +76,10 @@ class ItemSet:
(tensor([4, 5]), tensor([16, 17, 18])),
(tensor([4, 5]), tensor([16, 17, 18])),
(tensor([6, 7]), tensor([19, 20, 21])),
(tensor([6, 7]), tensor([19, 20, 21])),
(tensor([8, 9]), tensor([22, 23, 24]))]
(tensor([8, 9]), tensor([22, 23, 24]))]
>>> item_set[:]
(tensor([[0, 1], [2, 3], [4, 5], [6, 7],[8, 9]]),
tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21],
[22, 23, 24]]))
>>> item_set.names
>>> item_set.names
('node_pairs', 'negative_dsts')
('node_pairs', 'negative_dsts')
"""
"""
...
@@ -76,33 +89,20 @@ class ItemSet:
...
@@ -76,33 +89,20 @@ class ItemSet:
items
:
Union
[
int
,
Iterable
,
Tuple
[
Iterable
]],
items
:
Union
[
int
,
Iterable
,
Tuple
[
Iterable
]],
names
:
Union
[
str
,
Tuple
[
str
]]
=
None
,
names
:
Union
[
str
,
Tuple
[
str
]]
=
None
,
)
->
None
:
)
->
None
:
# Initiated by an integer.
if
isinstance
(
items
,
(
int
,
tuple
)):
if
isinstance
(
items
,
int
):
self
.
_items
=
items
if
names
is
not
None
:
if
isinstance
(
names
,
tuple
):
self
.
_names
=
names
else
:
self
.
_names
=
(
names
,)
assert
(
len
(
self
.
_names
)
==
1
),
"Number of names mustn't exceed 1 when item is an integer."
else
:
self
.
_names
=
None
return
# Otherwise.
if
isinstance
(
items
,
tuple
):
self
.
_items
=
items
self
.
_items
=
items
else
:
else
:
self
.
_items
=
(
items
,)
self
.
_items
=
(
items
,)
if
names
is
not
None
:
if
names
is
not
None
:
num_items
=
(
len
(
self
.
_items
)
if
isinstance
(
self
.
_items
,
tuple
)
else
1
)
if
isinstance
(
names
,
tuple
):
if
isinstance
(
names
,
tuple
):
self
.
_names
=
names
self
.
_names
=
names
else
:
else
:
self
.
_names
=
(
names
,)
self
.
_names
=
(
names
,)
assert
len
(
self
.
_items
)
==
len
(
self
.
_names
),
(
assert
num
_items
==
len
(
self
.
_names
),
(
f
"Number of items (
{
len
(
self
.
_items
)
}
) and "
f
"Number of items (
{
num
_items
}
) and "
f
"names (
{
len
(
self
.
_names
)
}
) must match."
f
"names (
{
len
(
self
.
_names
)
}
) must match."
)
)
else
:
else
:
...
@@ -110,7 +110,7 @@ class ItemSet:
...
@@ -110,7 +110,7 @@ class ItemSet:
def
__iter__
(
self
)
->
Iterator
:
def
__iter__
(
self
)
->
Iterator
:
if
isinstance
(
self
.
_items
,
int
):
if
isinstance
(
self
.
_items
,
int
):
yield
from
range
(
self
.
_items
)
yield
from
torch
.
a
range
(
self
.
_items
)
return
return
if
len
(
self
.
_items
)
==
1
:
if
len
(
self
.
_items
)
==
1
:
...
@@ -143,6 +143,24 @@ class ItemSet:
...
@@ -143,6 +143,24 @@ class ItemSet:
f
"
{
type
(
self
).
__name__
}
instance doesn't have valid length."
f
"
{
type
(
self
).
__name__
}
instance doesn't have valid length."
)
)
def
__getitem__
(
self
,
idx
:
Union
[
int
,
slice
,
Iterable
])
->
Tuple
:
try
:
len
(
self
)
except
TypeError
:
raise
TypeError
(
f
"
{
type
(
self
).
__name__
}
instance doesn't support indexing."
)
if
isinstance
(
self
.
_items
,
int
):
assert
isinstance
(
idx
,
(
int
,
torch
.
Tensor
)),
(
f
"Indexing of integer-initialized
{
type
(
self
).
__name__
}
"
f
"instance must be int or torch.Tensor."
)
# [Warning] Index range is not checked.
return
idx
if
len
(
self
.
_items
)
==
1
:
return
self
.
_items
[
0
][
idx
]
return
tuple
(
item
[
idx
]
for
item
in
self
.
_items
)
@
property
@
property
def
names
(
self
)
->
Tuple
[
str
]:
def
names
(
self
)
->
Tuple
[
str
]:
"""Return the names of the items."""
"""Return the names of the items."""
...
...
tests/python/pytorch/graphbolt/test_itemset.py
View file @
c8ec9ce3
...
@@ -26,9 +26,7 @@ def test_ItemSet_names():
...
@@ -26,9 +26,7 @@ def test_ItemSet_names():
# Integer-initiated ItemSet with excessive names.
# Integer-initiated ItemSet with excessive names.
with
pytest
.
raises
(
with
pytest
.
raises
(
AssertionError
,
AssertionError
,
match
=
re
.
escape
(
match
=
re
.
escape
(
"Number of items (1) and names (2) must match."
),
"Number of names mustn't exceed 1 when item is an integer."
),
):
):
_
=
gb
.
ItemSet
(
5
,
names
=
(
"seed_nodes"
,
"labels"
))
_
=
gb
.
ItemSet
(
5
,
names
=
(
"seed_nodes"
,
"labels"
))
...
@@ -69,61 +67,123 @@ def test_ItemSet_length():
...
@@ -69,61 +67,123 @@ def test_ItemSet_length():
# Single iterable with invalid length.
# Single iterable with invalid length.
item_set
=
gb
.
ItemSet
(
InvalidLength
())
item_set
=
gb
.
ItemSet
(
InvalidLength
())
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSet instance doesn't have valid length."
):
_
=
len
(
item_set
)
_
=
len
(
item_set
)
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSet instance doesn't support indexing."
):
_
=
item_set
[
0
]
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
assert
i
==
item
# Tuple of iterables with invalid length.
# Tuple of iterables with invalid length.
item_set
=
gb
.
ItemSet
((
InvalidLength
(),
InvalidLength
()))
item_set
=
gb
.
ItemSet
((
InvalidLength
(),
InvalidLength
()))
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSet instance doesn't have valid length."
):
_
=
len
(
item_set
)
_
=
len
(
item_set
)
with
pytest
.
raises
(
TypeError
,
match
=
"ItemSet instance doesn't support indexing."
):
_
=
item_set
[
0
]
for
i
,
(
item1
,
item2
)
in
enumerate
(
item_set
):
for
i
,
(
item1
,
item2
)
in
enumerate
(
item_set
):
assert
i
==
item1
assert
i
==
item1
assert
i
==
item2
assert
i
==
item2
def
test_ItemSet_
iteration_
seed_nodes
():
def
test_ItemSet_seed_nodes
():
# Node IDs.
# Node IDs
with tensor
.
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_nodes"
)
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_nodes"
)
assert
item_set
.
names
==
(
"seed_nodes"
,)
assert
item_set
.
names
==
(
"seed_nodes"
,)
# Iterating over ItemSet and indexing one by one.
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
.
item
()
assert
i
==
item_set
[
i
]
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:],
torch
.
arange
(
0
,
5
))
# Indexing with an Iterable.
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)],
torch
.
arange
(
0
,
5
))
# Node IDs with single integer.
item_set
=
gb
.
ItemSet
(
5
,
names
=
"seed_nodes"
)
assert
item_set
.
names
==
(
"seed_nodes"
,)
# Iterating over ItemSet and indexing one by one.
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
.
item
()
assert
i
==
item
.
item
()
assert
i
==
item_set
[
i
]
# Indexing with a slice.
with
pytest
.
raises
(
AssertionError
,
match
=
(
"Indexing of integer-initialized ItemSet instance must be int or "
"torch.Tensor."
),
):
_
=
item_set
[:]
# Indexing with an Tensor.
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)],
torch
.
arange
(
0
,
5
))
def
test_ItemSet_
iteration_
seed_nodes_labels
():
def
test_ItemSet_seed_nodes_labels
():
# Node IDs and labels.
# Node IDs and labels.
seed_nodes
=
torch
.
arange
(
0
,
5
)
seed_nodes
=
torch
.
arange
(
0
,
5
)
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
item_set
=
gb
.
ItemSet
((
seed_nodes
,
labels
),
names
=
(
"seed_nodes"
,
"labels"
))
item_set
=
gb
.
ItemSet
((
seed_nodes
,
labels
),
names
=
(
"seed_nodes"
,
"labels"
))
assert
item_set
.
names
==
(
"seed_nodes"
,
"labels"
)
assert
item_set
.
names
==
(
"seed_nodes"
,
"labels"
)
# Iterating over ItemSet and indexing one by one.
for
i
,
(
seed_node
,
label
)
in
enumerate
(
item_set
):
for
i
,
(
seed_node
,
label
)
in
enumerate
(
item_set
):
assert
seed_node
==
seed_nodes
[
i
]
assert
seed_node
==
seed_nodes
[
i
]
assert
label
==
labels
[
i
]
assert
label
==
labels
[
i
]
assert
seed_node
==
item_set
[
i
][
0
]
assert
label
==
item_set
[
i
][
1
]
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:][
0
],
seed_nodes
)
assert
torch
.
equal
(
item_set
[:][
1
],
labels
)
# Indexing with an Iterable.
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)][
0
],
seed_nodes
)
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)][
1
],
labels
)
def
test_ItemSet_
iteration_
node_pairs
():
def
test_ItemSet_node_pairs
():
# Node pairs.
# Node pairs.
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
item_set
=
gb
.
ItemSet
(
node_pairs
,
names
=
"node_pairs"
)
item_set
=
gb
.
ItemSet
(
node_pairs
,
names
=
"node_pairs"
)
assert
item_set
.
names
==
(
"node_pairs"
,)
assert
item_set
.
names
==
(
"node_pairs"
,)
# Iterating over ItemSet and indexing one by one.
for
i
,
(
src
,
dst
)
in
enumerate
(
item_set
):
for
i
,
(
src
,
dst
)
in
enumerate
(
item_set
):
assert
node_pairs
[
i
][
0
]
==
src
assert
node_pairs
[
i
][
0
]
==
src
assert
node_pairs
[
i
][
1
]
==
dst
assert
node_pairs
[
i
][
1
]
==
dst
assert
node_pairs
[
i
][
0
]
==
item_set
[
i
][
0
]
assert
node_pairs
[
i
][
1
]
==
item_set
[
i
][
1
]
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:],
node_pairs
)
# Indexing with an Iterable.
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)],
node_pairs
)
def
test_ItemSet_
iteration_
node_pairs_labels
():
def
test_ItemSet_node_pairs_labels
():
# Node pairs and labels
# Node pairs and labels
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
item_set
=
gb
.
ItemSet
((
node_pairs
,
labels
),
names
=
(
"node_pairs"
,
"labels"
))
item_set
=
gb
.
ItemSet
((
node_pairs
,
labels
),
names
=
(
"node_pairs"
,
"labels"
))
assert
item_set
.
names
==
(
"node_pairs"
,
"labels"
)
assert
item_set
.
names
==
(
"node_pairs"
,
"labels"
)
# Iterating over ItemSet and indexing one by one.
for
i
,
(
node_pair
,
label
)
in
enumerate
(
item_set
):
for
i
,
(
node_pair
,
label
)
in
enumerate
(
item_set
):
assert
torch
.
equal
(
node_pairs
[
i
],
node_pair
)
assert
torch
.
equal
(
node_pairs
[
i
],
node_pair
)
assert
labels
[
i
]
==
label
assert
labels
[
i
]
==
label
assert
torch
.
equal
(
node_pairs
[
i
],
item_set
[
i
][
0
])
assert
labels
[
i
]
==
item_set
[
i
][
1
]
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:][
0
],
node_pairs
)
assert
torch
.
equal
(
item_set
[:][
1
],
labels
)
# Indexing with an Iterable.
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)][
0
],
node_pairs
)
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)][
1
],
labels
)
def
test_ItemSet_
iteration_
node_pairs_neg_dsts
():
def
test_ItemSet_node_pairs_neg_dsts
():
# Node pairs and negative destinations.
# Node pairs and negative destinations.
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
neg_dsts
=
torch
.
arange
(
10
,
25
).
reshape
(
-
1
,
3
)
neg_dsts
=
torch
.
arange
(
10
,
25
).
reshape
(
-
1
,
3
)
...
@@ -131,18 +191,31 @@ def test_ItemSet_iteration_node_pairs_neg_dsts():
...
@@ -131,18 +191,31 @@ def test_ItemSet_iteration_node_pairs_neg_dsts():
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"negative_dsts"
)
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"negative_dsts"
)
)
)
assert
item_set
.
names
==
(
"node_pairs"
,
"negative_dsts"
)
assert
item_set
.
names
==
(
"node_pairs"
,
"negative_dsts"
)
# Iterating over ItemSet and indexing one by one.
for
i
,
(
node_pair
,
neg_dst
)
in
enumerate
(
item_set
):
for
i
,
(
node_pair
,
neg_dst
)
in
enumerate
(
item_set
):
assert
torch
.
equal
(
node_pairs
[
i
],
node_pair
)
assert
torch
.
equal
(
node_pairs
[
i
],
node_pair
)
assert
torch
.
equal
(
neg_dsts
[
i
],
neg_dst
)
assert
torch
.
equal
(
neg_dsts
[
i
],
neg_dst
)
assert
torch
.
equal
(
node_pairs
[
i
],
item_set
[
i
][
0
])
assert
torch
.
equal
(
neg_dsts
[
i
],
item_set
[
i
][
1
])
# Indexing with a slice.
assert
torch
.
equal
(
item_set
[:][
0
],
node_pairs
)
assert
torch
.
equal
(
item_set
[:][
1
],
neg_dsts
)
# Indexing with an Iterable.
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)][
0
],
node_pairs
)
assert
torch
.
equal
(
item_set
[
torch
.
arange
(
0
,
5
)][
1
],
neg_dsts
)
def
test_ItemSet_
iteration_
graphs
():
def
test_ItemSet_graphs
():
# Graphs.
# Graphs.
graphs
=
[
dgl
.
rand_graph
(
10
,
20
)
for
_
in
range
(
5
)]
graphs
=
[
dgl
.
rand_graph
(
10
,
20
)
for
_
in
range
(
5
)]
item_set
=
gb
.
ItemSet
(
graphs
)
item_set
=
gb
.
ItemSet
(
graphs
)
assert
item_set
.
names
is
None
assert
item_set
.
names
is
None
# Iterating over ItemSet and indexing one by one.
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
graphs
[
i
]
==
item
assert
graphs
[
i
]
==
item
assert
graphs
[
i
]
==
item_set
[
i
]
# Indexing with a slice.
assert
item_set
[:]
==
graphs
def
test_ItemSetDict_names
():
def
test_ItemSetDict_names
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment