Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
729924e3
"tests/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "f051a5cc693c7d02c10d335c24970730b72277be"
Unverified
Commit
729924e3
authored
Sep 05, 2023
by
Rhett Ying
Committed by
GitHub
Sep 05, 2023
Browse files
[GraphBolt] update docstring of ItemSet (#6282)
parent
feaeb1c2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
20 deletions
+47
-20
python/dgl/graphbolt/itemset.py
python/dgl/graphbolt/itemset.py
+42
-15
tests/python/pytorch/graphbolt/test_itemset.py
tests/python/pytorch/graphbolt/test_itemset.py
+5
-5
No files found.
python/dgl/graphbolt/itemset.py
View file @
729924e3
...
@@ -15,36 +15,49 @@ class ItemSet:
...
@@ -15,36 +15,49 @@ class ItemSet:
Parameters
Parameters
----------
----------
items: Iterable or Tuple[Iterable]
items: Iterable or Tuple[Iterable]
The items to be iterated over. If it is a tuple, each item in the tuple
is an iterable of items.
names: str or Tuple[str], optional
The names of the items. If it is a tuple, each name corresponds to an
item in the tuple.
Examples
Examples
--------
--------
>>> import torch
>>> import torch
>>> from dgl import graphbolt as gb
>>> from dgl import graphbolt as gb
1. Single iterable.
1. Single iterable
: seed nodes
.
>>> node_ids = torch.arange(0, 5)
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids)
>>> item_set = gb.ItemSet(node_ids
, names="seed_nodes"
)
>>> list(item_set)
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set.names
('seed_nodes',)
2. Tuple of iterables with same shape.
2. Tuple of iterables with same shape
: seed nodes and labels
.
>>> node_ids = torch.arange(0, 5)
>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
>>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet((node_ids, labels))
>>> item_set = gb.ItemSet(
... (node_ids, labels), names=("seed_nodes", "labels"))
>>> list(item_set)
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set.names
('seed_nodes', 'labels')
3. Tuple of iterables with different shape.
3. Tuple of iterables with different shape
: node pairs and negative dsts
.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
>>> item_set = gb.ItemSet((node_pairs, neg_dsts))
>>> item_set = gb.ItemSet(
... (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts"))
>>> list(item_set)
>>> list(item_set)
[(tensor([0, 1]), tensor([10, 11, 12])),
[(tensor([0, 1]), tensor([10, 11, 12])),
(tensor([2, 3]), tensor([13, 14, 15])),
(tensor([2, 3]), tensor([13, 14, 15])),
(tensor([4, 5]), tensor([16, 17, 18])),
(tensor([4, 5]), tensor([16, 17, 18])),
(tensor([6, 7]), tensor([19, 20, 21])),
(tensor([6, 7]), tensor([19, 20, 21])),
(tensor([8, 9]), tensor([22, 23, 24]))]
(tensor([8, 9]), tensor([22, 23, 24]))]
>>> item_set.names
('node_pairs', 'negative_dsts')
"""
"""
def
__init__
(
def
__init__
(
...
@@ -104,45 +117,59 @@ class ItemSetDict:
...
@@ -104,45 +117,59 @@ class ItemSetDict:
>>> import torch
>>> import torch
>>> from dgl import graphbolt as gb
>>> from dgl import graphbolt as gb
1. Single iterable.
1. Single iterable
: seed nodes
.
>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10)
>>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
>>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet(node_ids_user),
... "user": gb.ItemSet(node_ids_user
, names="seed_nodes"
),
... "item": gb.ItemSet(node_ids_item)})
... "item": gb.ItemSet(node_ids_item
, names="seed_nodes"
)})
>>> list(item_set)
>>> list(item_set)
[{"user": tensor(0)}, {"user": tensor(1)}, {"user": tensor(2)},
[{"user": tensor(0)}, {"user": tensor(1)}, {"user": tensor(2)},
{"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
{"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
{"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
{"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
{"item": tensor(9)}}]
{"item": tensor(9)}}]
>>> item_set.names
('seed_nodes',)
2. Tuple of iterables with same shape.
2. Tuple of iterables with same shape
: seed nodes and labels
.
>>> node_ids_user = torch.arange(0, 2)
>>> node_ids_user = torch.arange(0, 2)
>>> labels_user = torch.arange(0, 2)
>>> labels_user = torch.arange(0, 2)
>>> node_ids_item = torch.arange(2, 5)
>>> node_ids_item = torch.arange(2, 5)
>>> labels_item = torch.arange(2, 5)
>>> labels_item = torch.arange(2, 5)
>>> item_set = gb.ItemSetDict({
>>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet((node_ids_user, labels_user)),
... "user": gb.ItemSet(
... "item": gb.ItemSet((node_ids_item, labels_item))})
... (node_ids_user, labels_user),
... names=("seed_nodes", "labels")),
... "item": gb.ItemSet(
... (node_ids_item, labels_item),
... names=("seed_nodes", "labels"))})
>>> list(item_set)
>>> list(item_set)
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
{"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
{"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
{"item": (tensor(4), tensor(4))}}]
{"item": (tensor(4), tensor(4))}}]
>>> item_set.names
('seed_nodes', 'labels')
3. Tuple of iterables with different shape.
3. Tuple of iterables with different shape
: node pairs and negative dsts
.
>>> node_pairs_like = torch.arange(0, 4).reshape(-1, 2)
>>> node_pairs_like = torch.arange(0, 4).reshape(-1, 2)
>>> neg_dsts_like = torch.arange(4, 10).reshape(-1, 3)
>>> neg_dsts_like = torch.arange(4, 10).reshape(-1, 3)
>>> node_pairs_follow = torch.arange(0, 6).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(0, 6).reshape(-1, 2)
>>> neg_dsts_follow = torch.arange(6, 15).reshape(-1, 3)
>>> neg_dsts_follow = torch.arange(6, 15).reshape(-1, 3)
>>> item_set = gb.ItemSetDict({
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet((node_pairs_like, neg_dsts_like)),
... "user:like:item": gb.ItemSet(
... "user:follow:user": gb.ItemSet((node_pairs_follow, neg_dsts_follow))})
... (node_pairs_like, neg_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet(
... (node_pairs_follow, neg_dsts_follow),
... names=("node_pairs", "negative_dsts"))})
>>> list(item_set)
>>> list(item_set)
[{"user:like:item": (tensor([0, 1]), tensor([4, 5, 6]))},
[{"user:like:item": (tensor([0, 1]), tensor([4, 5, 6]))},
{"user:like:item": (tensor([2, 3]), tensor([7, 8, 9]))},
{"user:like:item": (tensor([2, 3]), tensor([7, 8, 9]))},
{"user:follow:user": (tensor([0, 1]), tensor([ 6, 7, 8, 9, 10, 11]))},
{"user:follow:user": (tensor([0, 1]), tensor([ 6, 7, 8, 9, 10, 11]))},
{"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))},
{"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))},
{"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}]
{"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}]
>>> item_set.names
('node_pairs', 'negative_dsts')
"""
"""
def
__init__
(
self
,
itemsets
:
Dict
[
str
,
ItemSet
])
->
None
:
def
__init__
(
self
,
itemsets
:
Dict
[
str
,
ItemSet
])
->
None
:
...
...
tests/python/pytorch/graphbolt/test_itemset.py
View file @
729924e3
...
@@ -101,9 +101,9 @@ def test_ItemSet_iteration_node_pairs_neg_dsts():
...
@@ -101,9 +101,9 @@ def test_ItemSet_iteration_node_pairs_neg_dsts():
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
node_pairs
=
torch
.
arange
(
0
,
10
).
reshape
(
-
1
,
2
)
neg_dsts
=
torch
.
arange
(
10
,
25
).
reshape
(
-
1
,
3
)
neg_dsts
=
torch
.
arange
(
10
,
25
).
reshape
(
-
1
,
3
)
item_set
=
gb
.
ItemSet
(
item_set
=
gb
.
ItemSet
(
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg_dsts"
)
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg
ative
_dsts"
)
)
)
assert
item_set
.
names
==
(
"node_pairs"
,
"neg_dsts"
)
assert
item_set
.
names
==
(
"node_pairs"
,
"neg
ative
_dsts"
)
for
i
,
(
node_pair
,
neg_dst
)
in
enumerate
(
item_set
):
for
i
,
(
node_pair
,
neg_dst
)
in
enumerate
(
item_set
):
assert
torch
.
equal
(
node_pairs
[
i
],
node_pair
)
assert
torch
.
equal
(
node_pairs
[
i
],
node_pair
)
assert
torch
.
equal
(
neg_dsts
[
i
],
neg_dst
)
assert
torch
.
equal
(
neg_dsts
[
i
],
neg_dst
)
...
@@ -319,17 +319,17 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
...
@@ -319,17 +319,17 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
neg_dsts
=
torch
.
arange
(
10
,
25
).
reshape
(
-
1
,
3
)
neg_dsts
=
torch
.
arange
(
10
,
25
).
reshape
(
-
1
,
3
)
node_pairs_neg_dsts
=
{
node_pairs_neg_dsts
=
{
"user:like:item"
:
gb
.
ItemSet
(
"user:like:item"
:
gb
.
ItemSet
(
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg_dsts"
)
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg
ative
_dsts"
)
),
),
"user:follow:user"
:
gb
.
ItemSet
(
"user:follow:user"
:
gb
.
ItemSet
(
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg_dsts"
)
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"neg
ative
_dsts"
)
),
),
}
}
expected_data
=
[]
expected_data
=
[]
for
key
,
value
in
node_pairs_neg_dsts
.
items
():
for
key
,
value
in
node_pairs_neg_dsts
.
items
():
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
expected_data
+=
[(
key
,
v
)
for
v
in
value
]
item_set
=
gb
.
ItemSetDict
(
node_pairs_neg_dsts
)
item_set
=
gb
.
ItemSetDict
(
node_pairs_neg_dsts
)
assert
item_set
.
names
==
(
"node_pairs"
,
"neg_dsts"
)
assert
item_set
.
names
==
(
"node_pairs"
,
"neg
ative
_dsts"
)
for
i
,
item
in
enumerate
(
item_set
):
for
i
,
item
in
enumerate
(
item_set
):
assert
len
(
item
)
==
1
assert
len
(
item
)
==
1
assert
isinstance
(
item
,
dict
)
assert
isinstance
(
item
,
dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment