Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
fa17fd09
Unverified
Commit
fa17fd09
authored
Sep 04, 2023
by
Rhett Ying
Committed by
GitHub
Sep 04, 2023
Browse files
[GraphBolt] update tests of ItemSet to cover canonical cases (#6272)
parent
ac49220c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
191 additions
and
134 deletions
+191
-134
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+191
-134
No files found.
tests/python/pytorch/graphbolt/test_itemset.py
View file @
fa17fd09
...
...
@@ -9,16 +9,17 @@ from torch.testing import assert_close
def
test_ItemSet_names
():
# ItemSet with single name.
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_node"
)
assert
item_set
.
names
==
(
"seed_node"
,)
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_node
s
"
)
assert
item_set
.
names
==
(
"seed_node
s
"
,)
# ItemSet with multiple names.
item_set
=
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)),
names
=
(
"seed_node"
,
"label"
)
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)),
names
=
(
"seed_nodes"
,
"labels"
),
)
assert
item_set
.
names
==
(
"seed_node"
,
"label"
)
assert
item_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
)
# ItemSet with
no
name.
# ItemSet with
out
name.
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
))
assert
item_set
.
names
is
None
...
...
@@ -27,33 +28,120 @@ def test_ItemSet_names():
AssertionError
,
match
=
re
.
escape
(
"Number of items (1) and names (2) must match."
),
):
_
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
(
"seed_node"
,
"label"
))
_
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
(
"seed_nodes"
,
"labels"
))
def
test_ItemSet_length
():
# Single iterable with valid length.
ids
=
torch
.
arange
(
0
,
5
)
item_set
=
gb
.
ItemSet
(
ids
)
assert
len
(
item_set
)
==
5
# Tuple of iterables with valid length.
item_set
=
gb
.
ItemSet
((
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)))
assert
len
(
item_set
)
==
5
class
InvalidLength
:
def
__iter__
(
self
):
return
iter
([
0
,
1
,
2
])
# Single iterable with invalid length.
item_set
=
gb
.
ItemSet
(
InvalidLength
())
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
# Tuple of iterables with invalid length.
item_set
=
gb
.
ItemSet
((
InvalidLength
(),
InvalidLength
()))
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
def
test_ItemSet_iteration_seed_nodes
():
# Node IDs.
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_nodes"
)
assert
item_set
.
names
==
(
"seed_nodes"
,)
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
.
item
()
def
test_ItemSet_iteration_seed_nodes_labels
():
# Node IDs and labels.
seed_nodes
=
torch
.
arange
(
0
,
5
)
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
item_set
=
gb
.
ItemSet
((
seed_nodes
,
labels
),
names
=
(
"seed_nodes"
,
"labels"
))
assert
item_set
.
names
==
(
"seed_nodes"
,
"labels"
)
for
i
,
(
seed_node
,
label
)
in
enumerate
(
item_set
):
assert
seed_node
==
seed_nodes
[
i
]
assert
label
==
labels
[
i
]
def
test_ItemSet_iteration_node_pairs
():
# Node pairs.
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
item_set
=
gb
.
ItemSet
(
node_pairs
,
names
=
"node_pairs"
)
assert
item_set
.
names
==
(
"node_pairs"
,)
for
i
,
(
src
,
dst
)
in
enumerate
(
item_set
):
assert
node_pairs
[
i
][
0
]
==
src
assert
node_pairs
[
i
][
1
]
==
dst
def
test_ItemSet_iteration_node_pairs_labels
():
# Node pairs and labels
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
item_set
=
gb
.
ItemSet
((
node_pairs
,
labels
),
names
=
(
"node_pairs"
,
"labels"
))
assert
item_set
.
names
==
(
"node_pairs"
,
"labels"
)
for
i
,
(
node_pair
,
label
)
in
enumerate
(
item_set
):
assert
torch
.
equal
(
node_pairs
[
i
],
node_pair
)
assert
labels
[
i
]
==
label
def
test_ItemSet_iteration_node_pairs_neg_dsts
():
# Node pairs and negative destinations.
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
neg_dsts
=
torch
.
arange
(
10
,
25
).
reshape
(
-
1
,
3
)
item_set
=
gb
.
ItemSet
(
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg_dsts"
)
)
assert
item_set
.
names
==
(
"node_pairs"
,
"neg_dsts"
)
for
i
,
(
node_pair
,
neg_dst
)
in
enumerate
(
item_set
):
assert
torch
.
equal
(
node_pairs
[
i
],
node_pair
)
assert
torch
.
equal
(
neg_dsts
[
i
],
neg_dst
)
def
test_ItemSet_iteration_graphs
():
# Graphs.
graphs
=
[
dgl
.
rand_graph
(
10
,
20
)
for
_
in
range
(
5
)]
item_set
=
gb
.
ItemSet
(
graphs
)
assert
item_set
.
names
is
None
for
i
,
item
in
enumerate
(
item_set
):
assert
graphs
[
i
]
==
item
def
test_ItemSetDict_names
():
# ItemSetDict with single name.
item_set
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_node"
),
"item"
:
gb
.
ItemSet
(
torch
.
arange
(
5
,
10
),
names
=
"seed_node"
),
"user"
:
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
),
names
=
"seed_node
s
"
),
"item"
:
gb
.
ItemSet
(
torch
.
arange
(
5
,
10
),
names
=
"seed_node
s
"
),
}
)
assert
item_set
.
names
==
(
"seed_node"
,)
assert
item_set
.
names
==
(
"seed_node
s
"
,)
# ItemSetDict with multiple names.
item_set
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)),
names
=
(
"seed_node"
,
"label"
),
names
=
(
"seed_node
s
"
,
"label
s
"
),
),
"item"
:
gb
.
ItemSet
(
(
torch
.
arange
(
5
,
10
),
torch
.
arange
(
10
,
15
)),
names
=
(
"seed_node"
,
"label"
),
names
=
(
"seed_node
s
"
,
"label
s
"
),
),
}
)
assert
item_set
.
names
==
(
"seed_node"
,
"label"
)
assert
item_set
.
names
==
(
"seed_node
s
"
,
"label
s
"
)
# ItemSetDict with no name.
item_set
=
gb
.
ItemSetDict
(
...
...
@@ -73,45 +161,17 @@ def test_ItemSetDict_names():
{
"user"
:
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)),
names
=
(
"seed_node"
,
"label"
),
names
=
(
"seed_node
s
"
,
"label
s
"
),
),
"item"
:
gb
.
ItemSet
(
(
torch
.
arange
(
5
,
10
),),
names
=
(
"seed_node"
,)
(
torch
.
arange
(
5
,
10
),),
names
=
(
"seed_node
s
"
,)
),
}
)
def
test_ItemSet_valid_length
():
# Single iterable.
ids
=
torch
.
arange
(
0
,
5
)
item_set
=
gb
.
ItemSet
(
ids
)
assert
len
(
item_set
)
==
5
# Tuple of iterables.
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
item_set
=
gb
.
ItemSet
(
node_pairs
)
assert
len
(
item_set
)
==
5
def
test_ItemSet_invalid_length
():
class
InvalidLength
:
def
__iter__
(
self
):
return
iter
([
0
,
1
,
2
])
# Single iterable.
item_set
=
gb
.
ItemSet
(
InvalidLength
())
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
# Tuple of iterables.
item_set
=
gb
.
ItemSet
((
InvalidLength
(),
InvalidLength
()))
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
def
test_ItemSetDict_valid_length
():
# Single iterable.
def
test_ItemSetDict_length
():
# Single iterable with valid length.
user_ids
=
torch
.
arange
(
0
,
5
)
item_ids
=
torch
.
arange
(
0
,
5
)
item_set
=
gb
.
ItemSetDict
(
...
...
@@ -122,24 +182,26 @@ def test_ItemSetDict_valid_length():
)
assert
len
(
item_set
)
==
len
(
user_ids
)
+
len
(
item_ids
)
# Tuple of iterables.
like
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
0
,
5
))
follow
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
# Tuple of iterables with valid length.
node_pairs_like
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
neg_dsts_like
=
torch
.
arange
(
10
,
20
).
reshape
(
-
1
,
2
)
node_pairs_follow
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
neg_dsts_follow
=
torch
.
arange
(
10
,
20
).
reshape
(
-
1
,
2
)
item_set
=
gb
.
ItemSetDict
(
{
"user:like:item"
:
gb
.
ItemSet
(
like
),
"user:follow:user"
:
gb
.
ItemSet
(
follow
),
"user:like:item"
:
gb
.
ItemSet
((
node_pairs_like
,
neg_dsts_like
)),
"user:follow:user"
:
gb
.
ItemSet
(
(
node_pairs_follow
,
neg_dsts_follow
)
),
}
)
assert
len
(
item_set
)
==
len
(
like
[
0
])
+
len
(
follow
[
0
])
assert
len
(
item_set
)
==
node_pairs_like
.
size
(
0
)
+
node_pairs_follow
.
size
(
0
)
def
test_ItemSetDict_invalid_length
():
class
InvalidLength
:
def
__iter__
(
self
):
return
iter
([
0
,
1
,
2
])
# Single iterable.
# Single iterable
with invalid length
.
item_set
=
gb
.
ItemSetDict
(
{
"user"
:
gb
.
ItemSet
(
InvalidLength
()),
...
...
@@ -149,7 +211,7 @@ def test_ItemSetDict_invalid_length():
with
pytest
.
raises
(
TypeError
):
_
=
len
(
item_set
)
# Tuple of iterables.
# Tuple of iterables
with invalid length
.
item_set
=
gb
.
ItemSetDict
(
{
"user:like:item"
:
gb
.
ItemSet
((
InvalidLength
(),
InvalidLength
())),
...
...
@@ -160,63 +222,19 @@ def test_ItemSetDict_invalid_length():
_
=
len
(
item_set
)
def
test_ItemSet_node_edge_ids
():
# Node or edge IDs.
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
))
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
.
item
()
def
test_ItemSet_graphs
():
# Graphs.
graphs
=
[
dgl
.
rand_graph
(
10
,
20
)
for
_
in
range
(
5
)]
item_set
=
gb
.
ItemSet
(
graphs
)
for
i
,
item
in
enumerate
(
item_set
):
assert
graphs
[
i
]
==
item
def
test_ItemSet_node_pairs
():
# Node pairs.
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
item_set
=
gb
.
ItemSet
(
node_pairs
)
for
i
,
(
src
,
dst
)
in
enumerate
(
item_set
):
assert
node_pairs
[
0
][
i
]
==
src
assert
node_pairs
[
1
][
i
]
==
dst
def
test_ItemSet_node_pairs_labels
():
# Node pairs and labels
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
item_set
=
gb
.
ItemSet
((
node_pairs
[
0
],
node_pairs
[
1
],
labels
))
for
i
,
(
src
,
dst
,
label
)
in
enumerate
(
item_set
):
assert
node_pairs
[
0
][
i
]
==
src
assert
node_pairs
[
1
][
i
]
==
dst
assert
labels
[
i
]
==
label
def
test_ItemSet_head_tail_neg_tails
():
# Head, tail and negative tails.
heads
=
torch
.
arange
(
0
,
5
)
tails
=
torch
.
arange
(
5
,
10
)
neg_tails
=
torch
.
arange
(
10
,
20
).
reshape
(
5
,
2
)
item_set
=
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
))
for
i
,
(
head
,
tail
,
negs
)
in
enumerate
(
item_set
):
assert
heads
[
i
]
==
head
assert
tails
[
i
]
==
tail
assert_close
(
neg_tails
[
i
],
negs
)
def
test_ItemSetDict_node_edge_ids
():
# Node or edge IDs
def
test_ItemSetDict_iteration_seed_nodes
():
# Node IDs.
user_ids
=
torch
.
arange
(
0
,
5
)
item_ids
=
torch
.
arange
(
5
,
10
)
ids
=
{
"user
:like:item
"
:
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
)
),
"
user:follow:user"
:
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
)
),
"user"
:
gb
.
ItemSet
(
user_ids
,
names
=
"seed_nodes"
),
"
item"
:
gb
.
ItemSet
(
item_ids
,
names
=
"seed_nodes"
),
}
chained_ids
=
[]
for
key
,
value
in
ids
.
items
():
chained_ids
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
ids
)
assert
item_set
.
names
==
(
"seed_nodes"
,)
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
...
...
@@ -224,59 +242,98 @@ def test_ItemSetDict_node_edge_ids():
assert
item
[
chained_ids
[
i
][
0
]]
==
chained_ids
[
i
][
1
]
def
test_ItemSetDict_node_pairs
():
def
test_ItemSetDict_iteration_seed_nodes_labels
():
# Node IDs and labels.
user_ids
=
torch
.
arange
(
0
,
5
)
user_labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
item_ids
=
torch
.
arange
(
5
,
10
)
item_labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
ids_labels
=
{
"user"
:
gb
.
ItemSet
(
(
user_ids
,
user_labels
),
names
=
(
"seed_nodes"
,
"labels"
)
),
"item"
:
gb
.
ItemSet
(
(
item_ids
,
item_labels
),
names
=
(
"seed_nodes"
,
"labels"
)
),
}
chained_ids
=
[]
for
key
,
value
in
ids_labels
.
items
():
chained_ids
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
ids_labels
)
assert
item_set
.
names
==
(
"seed_nodes"
,
"labels"
)
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
chained_ids
[
i
][
0
]
in
item
assert
item
[
chained_ids
[
i
][
0
]]
==
chained_ids
[
i
][
1
]
def
test_ItemSetDict_iteration_node_pairs
():
# Node pairs.
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)
)
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
node_pairs_dict
=
{
"user:like:item"
:
gb
.
ItemSet
(
node_pairs
),
"user:follow:user"
:
gb
.
ItemSet
(
node_pairs
),
"user:like:item"
:
gb
.
ItemSet
(
node_pairs
,
names
=
"node_pairs"
),
"user:follow:user"
:
gb
.
ItemSet
(
node_pairs
,
names
=
"node_pairs"
),
}
expected_data
=
[]
for
key
,
value
in
node_pairs_dict
.
items
():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
node_pairs_dict
)
assert
item_set
.
names
==
(
"node_pairs"
,)
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
expected_data
[
i
][
0
]
in
item
assert
item
[
expected_data
[
i
][
0
]]
==
expected_data
[
i
][
1
]
assert
torch
.
equal
(
item
[
expected_data
[
i
][
0
]]
,
expected_data
[
i
][
1
]
)
def
test_ItemSetDict_node_pairs_labels
():
def
test_ItemSetDict_
iteration_
node_pairs_labels
():
# Node pairs and labels
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
)
)
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
node_pairs_dict
=
{
"user:like:item"
:
gb
.
ItemSet
((
node_pairs
[
0
],
node_pairs
[
1
],
labels
)),
"user:follow:user"
:
gb
.
ItemSet
((
node_pairs
[
0
],
node_pairs
[
1
],
labels
)),
node_pairs_labels
=
{
"user:like:item"
:
gb
.
ItemSet
(
(
node_pairs
,
labels
),
names
=
(
"node_pairs"
,
"labels"
)
),
"user:follow:user"
:
gb
.
ItemSet
(
(
node_pairs
,
labels
),
names
=
(
"node_pairs"
,
"labels"
)
),
}
expected_data
=
[]
for
key
,
value
in
node_pairs_
dict
.
items
():
for
key
,
value
in
node_pairs_
labels
.
items
():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
node_pairs_dict
)
item_set
=
gb
.
ItemSetDict
(
node_pairs_labels
)
assert
item_set
.
names
==
(
"node_pairs"
,
"labels"
)
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
expected_data
[
i
][
0
]
in
item
assert
item
[
expected_data
[
i
][
0
]]
==
expected_data
[
i
][
1
]
def
test_ItemSetDict_head_tail_neg_tails
():
# Head, tail and negative tails.
heads
=
torch
.
arange
(
0
,
5
)
tails
=
torch
.
arange
(
5
,
10
)
neg_tails
=
torch
.
arange
(
10
,
20
).
reshape
(
5
,
2
)
item_set
=
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
))
data_dict
=
{
"user:like:item"
:
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
)),
"user:follow:user"
:
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
)),
key
,
value
=
expected_data
[
i
]
assert
key
in
item
assert
torch
.
equal
(
item
[
key
][
0
],
value
[
0
])
assert
item
[
key
][
1
]
==
value
[
1
]
def
test_ItemSetDict_iteration_node_pairs_neg_dsts
():
# Node pairs and negative destinations.
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
neg_dsts
=
torch
.
arange
(
10
,
25
).
reshape
(
-
1
,
3
)
node_pairs_neg_dsts
=
{
"user:like:item"
:
gb
.
ItemSet
(
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg_dsts"
)
),
"user:follow:user"
:
gb
.
ItemSet
(
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg_dsts"
)
),
}
expected_data
=
[]
for
key
,
value
in
data_dict
.
items
():
for
key
,
value
in
node_pairs_neg_dsts
.
items
():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
data_dict
)
item_set
=
gb
.
ItemSetDict
(
node_pairs_neg_dsts
)
assert
item_set
.
names
==
(
"node_pairs"
,
"neg_dsts"
)
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
expected_data
[
i
][
0
]
in
item
assert_close
(
item
[
expected_data
[
i
][
0
]],
expected_data
[
i
][
1
])
key
,
value
=
expected_data
[
i
]
assert
key
in
item
assert
torch
.
equal
(
item
[
key
][
0
],
value
[
0
])
assert
torch
.
equal
(
item
[
key
][
1
],
value
[
1
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment