Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
9c756a5e
Unverified
Commit
9c756a5e
authored
Jun 09, 2023
by
Rhett Ying
Committed by
GitHub
Jun 09, 2023
Browse files
[GraphBolt] refine ItemSet/ItemSetDict and add examples (#5834)
parent
c9778b55
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
105 additions
and
31 deletions
+105
-31
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+78
-12
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+27
-19
No files found.
python/dgl/graphbolt/itemset.py
View file @
9c756a5e
...
@@ -13,10 +13,43 @@ class ItemSet:
...
@@ -13,10 +13,43 @@ class ItemSet:
Parameters
Parameters
----------
----------
items: Iterable or Tuple[Iterable]
items: Iterable or Tuple[Iterable]
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids)
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
2. Tuple of iterables with same shape.
>>> node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
>>> item_set = gb.ItemSet(node_pairs)
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
3. Tuple of iterables with different shape.
>>> heads = torch.arange(0, 5)
>>> tails = torch.arange(5, 10)
>>> neg_tails = torch.arange(10, 20).reshape(5, 2)
>>> item_set = gb.ItemSet((heads, tails, neg_tails))
>>> list(item_set)
[(tensor(0), tensor(5), tensor([10, 11])),
(tensor(1), tensor(6), tensor([12, 13])),
(tensor(2), tensor(7), tensor([14, 15])),
(tensor(3), tensor(8), tensor([16, 17])),
(tensor(4), tensor(9), tensor([18, 19]))]
"""
"""
def
__init__
(
self
,
items
):
def
__init__
(
self
,
items
):
if
isinstance
(
items
,
tuple
):
if
isinstance
(
items
,
tuple
):
assert
all
(
items
[
0
].
size
(
0
)
==
item
.
size
(
0
)
for
item
in
items
),
"Size mismatch between items in tuple."
self
.
_items
=
items
self
.
_items
=
items
else
:
else
:
self
.
_items
=
(
items
,)
self
.
_items
=
(
items
,)
...
@@ -29,12 +62,6 @@ class ItemSet:
...
@@ -29,12 +62,6 @@ class ItemSet:
for
item
in
zip_items
:
for
item
in
zip_items
:
yield
tuple
(
item
)
yield
tuple
(
item
)
def
__getitem__
(
self
,
_
):
raise
NotImplementedError
def
__len__
(
self
):
raise
NotImplementedError
class
ItemSetDict
:
class
ItemSetDict
:
r
"""An iterable ItemsetDict.
r
"""An iterable ItemsetDict.
...
@@ -45,6 +72,51 @@ class ItemSetDict:
...
@@ -45,6 +72,51 @@ class ItemSetDict:
Parameters
Parameters
----------
----------
itemsets: Dict[str, ItemSet]
itemsets: Dict[str, ItemSet]
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable.
>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... 'user': gb.ItemSet(node_ids_user),
... 'item': gb.ItemSet(node_ids_item)})
>>> list(item_set)
[{'user': tensor(0)}, {'user': tensor(1)}, {'user': tensor(2)},
{'user': tensor(3)}, {'user': tensor(4)}, {'item': tensor(5)},
{'item': tensor(6)}, {'item': tensor(7)}, {'item': tensor(8)},
{'item': tensor(9)}]
2. Tuple of iterables with same shape.
>>> node_pairs_like = (torch.arange(0, 2), torch.arange(0, 2))
>>> node_pairs_follow = (torch.arange(0, 3), torch.arange(3, 6))
>>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(node_pairs_like),
... ('user', 'follow', 'user'): gb.ItemSet(node_pairs_follow)})
>>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0))},
{('user', 'like', 'item'): (tensor(1), tensor(1))},
{('user', 'follow', 'user'): (tensor(0), tensor(3))},
{('user', 'follow', 'user'): (tensor(1), tensor(4))},
{('user', 'follow', 'user'): (tensor(2), tensor(5))}]
3. Tuple of iterables with different shape.
>>> like = (torch.arange(0, 2), torch.arange(0, 2),
... torch.arange(0, 4).reshape(-1, 2))
>>> follow = (torch.arange(0, 3), torch.arange(3, 6),
... torch.arange(0, 6).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(like),
... ('user', 'follow', 'user'): gb.ItemSet(follow)})
>>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0), tensor([0, 1]))},
{('user', 'like', 'item'): (tensor(1), tensor(1), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(0), tensor(3), tensor([0, 1]))},
{('user', 'follow', 'user'): (tensor(1), tensor(4), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}]
"""
"""
def
__init__
(
self
,
itemsets
):
def
__init__
(
self
,
itemsets
):
...
@@ -54,9 +126,3 @@ class ItemSetDict:
...
@@ -54,9 +126,3 @@ class ItemSetDict:
for
key
,
itemset
in
self
.
_itemsets
.
items
():
for
key
,
itemset
in
self
.
_itemsets
.
items
():
for
item
in
itemset
:
for
item
in
itemset
:
yield
{
key
:
item
}
yield
{
key
:
item
}
def
__getitem__
(
self
,
_
):
raise
NotImplementedError
def
__len__
(
self
):
raise
NotImplementedError
tests/python/pytorch/graphbolt/test_itemset.py
View file @
9c756a5e
import
dgl
import
dgl
import
pytest
import
torch
import
torch
from
dgl
import
graphbolt
as
gb
from
torch.testing
import
assert_close
from
torch.testing
import
assert_close
from
dgl.graphbolt
import
*
def
test_mismatch_size_in_tuple
():
# Size mismatch.
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
11
))
with
pytest
.
raises
(
AssertionError
):
_
=
gb
.
ItemSet
(
node_pairs
)
def
test_ItemSet_node_edge_ids
():
def
test_ItemSet_node_edge_ids
():
# Node or edge IDs.
# Node or edge IDs.
item_set
=
ItemSet
(
torch
.
arange
(
0
,
5
))
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
))
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
i
==
item
.
item
()
assert
i
==
item
.
item
()
...
@@ -14,7 +22,7 @@ def test_ItemSet_node_edge_ids():
...
@@ -14,7 +22,7 @@ def test_ItemSet_node_edge_ids():
def
test_ItemSet_graphs
():
def
test_ItemSet_graphs
():
# Graphs.
# Graphs.
graphs
=
[
dgl
.
rand_graph
(
10
,
20
)
for
_
in
range
(
5
)]
graphs
=
[
dgl
.
rand_graph
(
10
,
20
)
for
_
in
range
(
5
)]
item_set
=
ItemSet
(
graphs
)
item_set
=
gb
.
ItemSet
(
graphs
)
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
graphs
[
i
]
==
item
assert
graphs
[
i
]
==
item
...
@@ -22,7 +30,7 @@ def test_ItemSet_graphs():
...
@@ -22,7 +30,7 @@ def test_ItemSet_graphs():
def
test_ItemSet_node_pairs
():
def
test_ItemSet_node_pairs
():
# Node pairs.
# Node pairs.
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
item_set
=
ItemSet
(
node_pairs
)
item_set
=
gb
.
ItemSet
(
node_pairs
)
for
i
,
(
src
,
dst
)
in
enumerate
(
item_set
):
for
i
,
(
src
,
dst
)
in
enumerate
(
item_set
):
assert
node_pairs
[
0
][
i
]
==
src
assert
node_pairs
[
0
][
i
]
==
src
assert
node_pairs
[
1
][
i
]
==
dst
assert
node_pairs
[
1
][
i
]
==
dst
...
@@ -32,7 +40,7 @@ def test_ItemSet_node_pairs_labels():
...
@@ -32,7 +40,7 @@ def test_ItemSet_node_pairs_labels():
# Node pairs and labels
# Node pairs and labels
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
item_set
=
ItemSet
((
node_pairs
[
0
],
node_pairs
[
1
],
labels
))
item_set
=
gb
.
ItemSet
((
node_pairs
[
0
],
node_pairs
[
1
],
labels
))
for
i
,
(
src
,
dst
,
label
)
in
enumerate
(
item_set
):
for
i
,
(
src
,
dst
,
label
)
in
enumerate
(
item_set
):
assert
node_pairs
[
0
][
i
]
==
src
assert
node_pairs
[
0
][
i
]
==
src
assert
node_pairs
[
1
][
i
]
==
dst
assert
node_pairs
[
1
][
i
]
==
dst
...
@@ -44,7 +52,7 @@ def test_ItemSet_head_tail_neg_tails():
...
@@ -44,7 +52,7 @@ def test_ItemSet_head_tail_neg_tails():
heads
=
torch
.
arange
(
0
,
5
)
heads
=
torch
.
arange
(
0
,
5
)
tails
=
torch
.
arange
(
5
,
10
)
tails
=
torch
.
arange
(
5
,
10
)
neg_tails
=
torch
.
arange
(
10
,
20
).
reshape
(
5
,
2
)
neg_tails
=
torch
.
arange
(
10
,
20
).
reshape
(
5
,
2
)
item_set
=
ItemSet
((
heads
,
tails
,
neg_tails
))
item_set
=
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
))
for
i
,
(
head
,
tail
,
negs
)
in
enumerate
(
item_set
):
for
i
,
(
head
,
tail
,
negs
)
in
enumerate
(
item_set
):
assert
heads
[
i
]
==
head
assert
heads
[
i
]
==
head
assert
tails
[
i
]
==
tail
assert
tails
[
i
]
==
tail
...
@@ -54,13 +62,13 @@ def test_ItemSet_head_tail_neg_tails():
...
@@ -54,13 +62,13 @@ def test_ItemSet_head_tail_neg_tails():
def
test_ItemSetDict_node_edge_ids
():
def
test_ItemSetDict_node_edge_ids
():
# Node or edge IDs
# Node or edge IDs
ids
=
{
ids
=
{
(
"user"
,
"like"
,
"item"
):
ItemSet
(
torch
.
arange
(
0
,
5
)),
(
"user"
,
"like"
,
"item"
):
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
)),
(
"user"
,
"follow"
,
"user"
):
ItemSet
(
torch
.
arange
(
0
,
5
)),
(
"user"
,
"follow"
,
"user"
):
gb
.
ItemSet
(
torch
.
arange
(
0
,
5
)),
}
}
chained_ids
=
[]
chained_ids
=
[]
for
key
,
value
in
ids
.
items
():
for
key
,
value
in
ids
.
items
():
chained_ids
+=
[(
key
,
v
)
for
v
in
value
]
chained_ids
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
ItemSetDict
(
ids
)
item_set
=
gb
.
ItemSetDict
(
ids
)
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
isinstance
(
item
,
dict
)
...
@@ -72,13 +80,13 @@ def test_ItemSetDict_node_pairs():
...
@@ -72,13 +80,13 @@ def test_ItemSetDict_node_pairs():
# Node pairs.
# Node pairs.
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
node_pairs_dict
=
{
node_pairs_dict
=
{
(
"user"
,
"like"
,
"item"
):
ItemSet
(
node_pairs
),
(
"user"
,
"like"
,
"item"
):
gb
.
ItemSet
(
node_pairs
),
(
"user"
,
"follow"
,
"user"
):
ItemSet
(
node_pairs
),
(
"user"
,
"follow"
,
"user"
):
gb
.
ItemSet
(
node_pairs
),
}
}
expected_data
=
[]
expected_data
=
[]
for
key
,
value
in
node_pairs_dict
.
items
():
for
key
,
value
in
node_pairs_dict
.
items
():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
ItemSetDict
(
node_pairs_dict
)
item_set
=
gb
.
ItemSetDict
(
node_pairs_dict
)
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
isinstance
(
item
,
dict
)
...
@@ -91,17 +99,17 @@ def test_ItemSetDict_node_pairs_labels():
...
@@ -91,17 +99,17 @@ def test_ItemSetDict_node_pairs_labels():
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
node_pairs
=
(
torch
.
arange
(
0
,
5
),
torch
.
arange
(
5
,
10
))
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
labels
=
torch
.
randint
(
0
,
3
,
(
5
,))
node_pairs_dict
=
{
node_pairs_dict
=
{
(
"user"
,
"like"
,
"item"
):
ItemSet
(
(
"user"
,
"like"
,
"item"
):
gb
.
ItemSet
(
(
node_pairs
[
0
],
node_pairs
[
1
],
labels
)
(
node_pairs
[
0
],
node_pairs
[
1
],
labels
)
),
),
(
"user"
,
"follow"
,
"user"
):
ItemSet
(
(
"user"
,
"follow"
,
"user"
):
gb
.
ItemSet
(
(
node_pairs
[
0
],
node_pairs
[
1
],
labels
)
(
node_pairs
[
0
],
node_pairs
[
1
],
labels
)
),
),
}
}
expected_data
=
[]
expected_data
=
[]
for
key
,
value
in
node_pairs_dict
.
items
():
for
key
,
value
in
node_pairs_dict
.
items
():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
ItemSetDict
(
node_pairs_dict
)
item_set
=
gb
.
ItemSetDict
(
node_pairs_dict
)
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
isinstance
(
item
,
dict
)
...
@@ -114,15 +122,15 @@ def test_ItemSetDict_head_tail_neg_tails():
...
@@ -114,15 +122,15 @@ def test_ItemSetDict_head_tail_neg_tails():
heads
=
torch
.
arange
(
0
,
5
)
heads
=
torch
.
arange
(
0
,
5
)
tails
=
torch
.
arange
(
5
,
10
)
tails
=
torch
.
arange
(
5
,
10
)
neg_tails
=
torch
.
arange
(
10
,
20
).
reshape
(
5
,
2
)
neg_tails
=
torch
.
arange
(
10
,
20
).
reshape
(
5
,
2
)
item_set
=
ItemSet
((
heads
,
tails
,
neg_tails
))
item_set
=
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
))
data_dict
=
{
data_dict
=
{
(
"user"
,
"like"
,
"item"
):
ItemSet
((
heads
,
tails
,
neg_tails
)),
(
"user"
,
"like"
,
"item"
):
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
)),
(
"user"
,
"follow"
,
"user"
):
ItemSet
((
heads
,
tails
,
neg_tails
)),
(
"user"
,
"follow"
,
"user"
):
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
)),
}
}
expected_data
=
[]
expected_data
=
[]
for
key
,
value
in
data_dict
.
items
():
for
key
,
value
in
data_dict
.
items
():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
ItemSetDict
(
data_dict
)
item_set
=
gb
.
ItemSetDict
(
data_dict
)
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
isinstance
(
item
,
dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment