Unverified Commit 729924e3 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] update docstring of ItemSet (#6282)

parent feaeb1c2
......@@ -15,36 +15,49 @@ class ItemSet:
Parameters
----------
items: Iterable or Tuple[Iterable]
The items to be iterated over. If it is a tuple, each item in the tuple
is an iterable of items.
names: str or Tuple[str], optional
The names of the items. If it is a tuple, each name corresponds to an
item in the tuple.
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable.
1. Single iterable: seed nodes.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids)
>>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set.names
('seed_nodes',)
2. Tuple of iterables with same shape.
2. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet((node_ids, labels))
>>> item_set = gb.ItemSet(
... (node_ids, labels), names=("seed_nodes", "labels"))
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set.names
('seed_nodes', 'labels')
3. Tuple of iterables with different shape.
3. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
>>> item_set = gb.ItemSet((node_pairs, neg_dsts))
>>> item_set = gb.ItemSet(
... (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts"))
>>> list(item_set)
[(tensor([0, 1]), tensor([10, 11, 12])),
(tensor([2, 3]), tensor([13, 14, 15])),
(tensor([4, 5]), tensor([16, 17, 18])),
(tensor([6, 7]), tensor([19, 20, 21])),
(tensor([8, 9]), tensor([22, 23, 24]))]
>>> item_set.names
('node_pairs', 'negative_dsts')
"""
def __init__(
......@@ -104,45 +117,59 @@ class ItemSetDict:
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable.
1. Single iterable: seed nodes.
>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet(node_ids_user),
... "item": gb.ItemSet(node_ids_item)})
... "user": gb.ItemSet(node_ids_user, names="seed_nodes"),
... "item": gb.ItemSet(node_ids_item, names="seed_nodes")})
>>> list(item_set)
[{"user": tensor(0)}, {"user": tensor(1)}, {"user": tensor(2)},
{"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
{"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
{"item": tensor(9)}}]
>>> item_set.names
('seed_nodes',)
2. Tuple of iterables with same shape.
2. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids_user = torch.arange(0, 2)
>>> labels_user = torch.arange(0, 2)
>>> node_ids_item = torch.arange(2, 5)
>>> labels_item = torch.arange(2, 5)
>>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet((node_ids_user, labels_user)),
... "item": gb.ItemSet((node_ids_item, labels_item))})
... "user": gb.ItemSet(
... (node_ids_user, labels_user),
... names=("seed_nodes", "labels")),
... "item": gb.ItemSet(
... (node_ids_item, labels_item),
... names=("seed_nodes", "labels"))})
>>> list(item_set)
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
{"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
{"item": (tensor(4), tensor(4))}}]
>>> item_set.names
('seed_nodes', 'labels')
3. Tuple of iterables with different shape.
3. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs_like = torch.arange(0, 4).reshape(-1, 2)
>>> neg_dsts_like = torch.arange(4, 10).reshape(-1, 3)
>>> node_pairs_follow = torch.arange(0, 6).reshape(-1, 2)
>>> neg_dsts_follow = torch.arange(6, 15).reshape(-1, 3)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet((node_pairs_like, neg_dsts_like)),
... "user:follow:user": gb.ItemSet((node_pairs_follow, neg_dsts_follow))})
... "user:like:item": gb.ItemSet(
... (node_pairs_like, neg_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet(
... (node_pairs_follow, neg_dsts_follow),
... names=("node_pairs", "negative_dsts"))})
>>> list(item_set)
[{"user:like:item": (tensor([0, 1]), tensor([4, 5, 6]))},
{"user:like:item": (tensor([2, 3]), tensor([7, 8, 9]))},
{"user:follow:user": (tensor([0, 1]), tensor([ 6, 7, 8, 9, 10, 11]))},
{"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))},
{"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}]
>>> item_set.names
('node_pairs', 'negative_dsts')
"""
def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
......
......@@ -101,9 +101,9 @@ def test_ItemSet_iteration_node_pairs_neg_dsts():
node_pairs = torch.arange(0, 10).reshape(-1, 2)
neg_dsts = torch.arange(10, 25).reshape(-1, 3)
item_set = gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "neg_dsts")
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
)
assert item_set.names == ("node_pairs", "neg_dsts")
assert item_set.names == ("node_pairs", "negative_dsts")
for i, (node_pair, neg_dst) in enumerate(item_set):
assert torch.equal(node_pairs[i], node_pair)
assert torch.equal(neg_dsts[i], neg_dst)
......@@ -319,17 +319,17 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
neg_dsts = torch.arange(10, 25).reshape(-1, 3)
node_pairs_neg_dsts = {
"user:like:item": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "neg_dsts")
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
),
"user:follow:user": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "neg_dsts")
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
),
}
expected_data = []
for key, value in node_pairs_neg_dsts.items():
expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_neg_dsts)
assert item_set.names == ("node_pairs", "neg_dsts")
assert item_set.names == ("node_pairs", "negative_dsts")
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment