Unverified Commit 729924e3 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] update docstring of ItemSet (#6282)

parent feaeb1c2
...@@ -15,36 +15,49 @@ class ItemSet: ...@@ -15,36 +15,49 @@ class ItemSet:
Parameters Parameters
---------- ----------
items: Iterable or Tuple[Iterable] items: Iterable or Tuple[Iterable]
The items to be iterated over. If it is a tuple, each item in the tuple
is an iterable of items.
names: str or Tuple[str], optional
The names of the items. If it is a tuple, each name corresponds to an
item in the tuple.
Examples Examples
-------- --------
>>> import torch >>> import torch
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
1. Single iterable. 1. Single iterable: seed nodes.
>>> node_ids = torch.arange(0, 5) >>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids) >>> item_set = gb.ItemSet(node_ids, names="seed_nodes")
>>> list(item_set) >>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)] [tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set.names
('seed_nodes',)
2. Tuple of iterables with same shape. 2. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5) >>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10) >>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet((node_ids, labels)) >>> item_set = gb.ItemSet(
... (node_ids, labels), names=("seed_nodes", "labels"))
>>> list(item_set) >>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)), [(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))] (tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set.names
('seed_nodes', 'labels')
3. Tuple of iterables with different shape. 3. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs = torch.arange(0, 10).reshape(-1, 2) >>> node_pairs = torch.arange(0, 10).reshape(-1, 2)
>>> neg_dsts = torch.arange(10, 25).reshape(-1, 3) >>> neg_dsts = torch.arange(10, 25).reshape(-1, 3)
>>> item_set = gb.ItemSet((node_pairs, neg_dsts)) >>> item_set = gb.ItemSet(
... (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts"))
>>> list(item_set) >>> list(item_set)
[(tensor([0, 1]), tensor([10, 11, 12])), [(tensor([0, 1]), tensor([10, 11, 12])),
(tensor([2, 3]), tensor([13, 14, 15])), (tensor([2, 3]), tensor([13, 14, 15])),
(tensor([4, 5]), tensor([16, 17, 18])), (tensor([4, 5]), tensor([16, 17, 18])),
(tensor([6, 7]), tensor([19, 20, 21])), (tensor([6, 7]), tensor([19, 20, 21])),
(tensor([8, 9]), tensor([22, 23, 24]))] (tensor([8, 9]), tensor([22, 23, 24]))]
>>> item_set.names
('node_pairs', 'negative_dsts')
""" """
def __init__( def __init__(
...@@ -104,45 +117,59 @@ class ItemSetDict: ...@@ -104,45 +117,59 @@ class ItemSetDict:
>>> import torch >>> import torch
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
1. Single iterable. 1. Single iterable: seed nodes.
>>> node_ids_user = torch.arange(0, 5) >>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10) >>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet(node_ids_user), ... "user": gb.ItemSet(node_ids_user, names="seed_nodes"),
... "item": gb.ItemSet(node_ids_item)}) ... "item": gb.ItemSet(node_ids_item, names="seed_nodes")})
>>> list(item_set) >>> list(item_set)
[{"user": tensor(0)}, {"user": tensor(1)}, {"user": tensor(2)}, [{"user": tensor(0)}, {"user": tensor(1)}, {"user": tensor(2)},
{"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)}, {"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
{"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)}, {"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
{"item": tensor(9)}}] {"item": tensor(9)}}]
>>> item_set.names
('seed_nodes',)
2. Tuple of iterables with same shape. 2. Tuple of iterables with same shape: seed nodes and labels.
>>> node_ids_user = torch.arange(0, 2) >>> node_ids_user = torch.arange(0, 2)
>>> labels_user = torch.arange(0, 2) >>> labels_user = torch.arange(0, 2)
>>> node_ids_item = torch.arange(2, 5) >>> node_ids_item = torch.arange(2, 5)
>>> labels_item = torch.arange(2, 5) >>> labels_item = torch.arange(2, 5)
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet((node_ids_user, labels_user)), ... "user": gb.ItemSet(
... "item": gb.ItemSet((node_ids_item, labels_item))}) ... (node_ids_user, labels_user),
... names=("seed_nodes", "labels")),
... "item": gb.ItemSet(
... (node_ids_item, labels_item),
... names=("seed_nodes", "labels"))})
>>> list(item_set) >>> list(item_set)
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))}, [{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
{"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))}, {"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
{"item": (tensor(4), tensor(4))}}] {"item": (tensor(4), tensor(4))}}]
>>> item_set.names
('seed_nodes', 'labels')
3. Tuple of iterables with different shape. 3. Tuple of iterables with different shape: node pairs and negative dsts.
>>> node_pairs_like = torch.arange(0, 4).reshape(-1, 2) >>> node_pairs_like = torch.arange(0, 4).reshape(-1, 2)
>>> neg_dsts_like = torch.arange(4, 10).reshape(-1, 3) >>> neg_dsts_like = torch.arange(4, 10).reshape(-1, 3)
>>> node_pairs_follow = torch.arange(0, 6).reshape(-1, 2) >>> node_pairs_follow = torch.arange(0, 6).reshape(-1, 2)
>>> neg_dsts_follow = torch.arange(6, 15).reshape(-1, 3) >>> neg_dsts_follow = torch.arange(6, 15).reshape(-1, 3)
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet((node_pairs_like, neg_dsts_like)), ... "user:like:item": gb.ItemSet(
... "user:follow:user": gb.ItemSet((node_pairs_follow, neg_dsts_follow))}) ... (node_pairs_like, neg_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet(
... (node_pairs_follow, neg_dsts_follow),
... names=("node_pairs", "negative_dsts"))})
>>> list(item_set) >>> list(item_set)
[{"user:like:item": (tensor([0, 1]), tensor([4, 5, 6]))}, [{"user:like:item": (tensor([0, 1]), tensor([4, 5, 6]))},
{"user:like:item": (tensor([2, 3]), tensor([7, 8, 9]))}, {"user:like:item": (tensor([2, 3]), tensor([7, 8, 9]))},
{"user:follow:user": (tensor([0, 1]), tensor([ 6, 7, 8, 9, 10, 11]))}, {"user:follow:user": (tensor([0, 1]), tensor([ 6, 7, 8, 9, 10, 11]))},
{"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))}, {"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))},
{"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}] {"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}]
>>> item_set.names
('node_pairs', 'negative_dsts')
""" """
def __init__(self, itemsets: Dict[str, ItemSet]) -> None: def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
......
...@@ -101,9 +101,9 @@ def test_ItemSet_iteration_node_pairs_neg_dsts(): ...@@ -101,9 +101,9 @@ def test_ItemSet_iteration_node_pairs_neg_dsts():
node_pairs = torch.arange(0, 10).reshape(-1, 2) node_pairs = torch.arange(0, 10).reshape(-1, 2)
neg_dsts = torch.arange(10, 25).reshape(-1, 3) neg_dsts = torch.arange(10, 25).reshape(-1, 3)
item_set = gb.ItemSet( item_set = gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "neg_dsts") (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
) )
assert item_set.names == ("node_pairs", "neg_dsts") assert item_set.names == ("node_pairs", "negative_dsts")
for i, (node_pair, neg_dst) in enumerate(item_set): for i, (node_pair, neg_dst) in enumerate(item_set):
assert torch.equal(node_pairs[i], node_pair) assert torch.equal(node_pairs[i], node_pair)
assert torch.equal(neg_dsts[i], neg_dst) assert torch.equal(neg_dsts[i], neg_dst)
...@@ -319,17 +319,17 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts(): ...@@ -319,17 +319,17 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
neg_dsts = torch.arange(10, 25).reshape(-1, 3) neg_dsts = torch.arange(10, 25).reshape(-1, 3)
node_pairs_neg_dsts = { node_pairs_neg_dsts = {
"user:like:item": gb.ItemSet( "user:like:item": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "neg_dsts") (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
), ),
"user:follow:user": gb.ItemSet( "user:follow:user": gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "neg_dsts") (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
), ),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_neg_dsts.items(): for key, value in node_pairs_neg_dsts.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = gb.ItemSetDict(node_pairs_neg_dsts) item_set = gb.ItemSetDict(node_pairs_neg_dsts)
assert item_set.names == ("node_pairs", "neg_dsts") assert item_set.names == ("node_pairs", "negative_dsts")
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment