Unverified Commit 9c756a5e authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] refine ItemSet/ItemSetDict and add examples (#5834)

parent c9778b55
...@@ -13,10 +13,43 @@ class ItemSet: ...@@ -13,10 +13,43 @@ class ItemSet:
Parameters Parameters
---------- ----------
items: Iterable or Tuple[Iterable] items: Iterable or Tuple[Iterable]
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids)
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
2. Tuple of iterables with same shape.
>>> node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
>>> item_set = gb.ItemSet(node_pairs)
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
3. Tuple of iterables with different shape.
>>> heads = torch.arange(0, 5)
>>> tails = torch.arange(5, 10)
>>> neg_tails = torch.arange(10, 20).reshape(5, 2)
>>> item_set = gb.ItemSet((heads, tails, neg_tails))
>>> list(item_set)
[(tensor(0), tensor(5), tensor([10, 11])),
(tensor(1), tensor(6), tensor([12, 13])),
(tensor(2), tensor(7), tensor([14, 15])),
(tensor(3), tensor(8), tensor([16, 17])),
(tensor(4), tensor(9), tensor([18, 19]))]
""" """
def __init__(self, items): def __init__(self, items):
if isinstance(items, tuple): if isinstance(items, tuple):
assert all(
items[0].size(0) == item.size(0) for item in items
), "Size mismatch between items in tuple."
self._items = items self._items = items
else: else:
self._items = (items,) self._items = (items,)
...@@ -29,12 +62,6 @@ class ItemSet: ...@@ -29,12 +62,6 @@ class ItemSet:
for item in zip_items: for item in zip_items:
yield tuple(item) yield tuple(item)
def __getitem__(self, _):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class ItemSetDict: class ItemSetDict:
r"""An iterable ItemsetDict. r"""An iterable ItemsetDict.
...@@ -45,6 +72,51 @@ class ItemSetDict: ...@@ -45,6 +72,51 @@ class ItemSetDict:
Parameters Parameters
---------- ----------
itemsets: Dict[str, ItemSet] itemsets: Dict[str, ItemSet]
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable.
>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... 'user': gb.ItemSet(node_ids_user),
... 'item': gb.ItemSet(node_ids_item)})
>>> list(item_set)
[{'user': tensor(0)}, {'user': tensor(1)}, {'user': tensor(2)},
{'user': tensor(3)}, {'user': tensor(4)}, {'item': tensor(5)},
{'item': tensor(6)}, {'item': tensor(7)}, {'item': tensor(8)},
{'item': tensor(9)}]
2. Tuple of iterables with same shape.
>>> node_pairs_like = (torch.arange(0, 2), torch.arange(0, 2))
>>> node_pairs_follow = (torch.arange(0, 3), torch.arange(3, 6))
>>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(node_pairs_like),
... ('user', 'follow', 'user'): gb.ItemSet(node_pairs_follow)})
>>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0))},
{('user', 'like', 'item'): (tensor(1), tensor(1))},
{('user', 'follow', 'user'): (tensor(0), tensor(3))},
{('user', 'follow', 'user'): (tensor(1), tensor(4))},
{('user', 'follow', 'user'): (tensor(2), tensor(5))}]
3. Tuple of iterables with different shape.
>>> like = (torch.arange(0, 2), torch.arange(0, 2),
... torch.arange(0, 4).reshape(-1, 2))
>>> follow = (torch.arange(0, 3), torch.arange(3, 6),
... torch.arange(0, 6).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(like),
... ('user', 'follow', 'user'): gb.ItemSet(follow)})
>>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0), tensor([0, 1]))},
{('user', 'like', 'item'): (tensor(1), tensor(1), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(0), tensor(3), tensor([0, 1]))},
{('user', 'follow', 'user'): (tensor(1), tensor(4), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}]
""" """
def __init__(self, itemsets): def __init__(self, itemsets):
...@@ -54,9 +126,3 @@ class ItemSetDict: ...@@ -54,9 +126,3 @@ class ItemSetDict:
for key, itemset in self._itemsets.items(): for key, itemset in self._itemsets.items():
for item in itemset: for item in itemset:
yield {key: item} yield {key: item}
def __getitem__(self, _):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
import dgl import dgl
import pytest
import torch import torch
from dgl import graphbolt as gb
from torch.testing import assert_close from torch.testing import assert_close
from dgl.graphbolt import *
def test_mismatch_size_in_tuple():
# Size mismatch.
node_pairs = (torch.arange(0, 5), torch.arange(5, 11))
with pytest.raises(AssertionError):
_ = gb.ItemSet(node_pairs)
def test_ItemSet_node_edge_ids(): def test_ItemSet_node_edge_ids():
# Node or edge IDs. # Node or edge IDs.
item_set = ItemSet(torch.arange(0, 5)) item_set = gb.ItemSet(torch.arange(0, 5))
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert i == item.item() assert i == item.item()
...@@ -14,7 +22,7 @@ def test_ItemSet_node_edge_ids(): ...@@ -14,7 +22,7 @@ def test_ItemSet_node_edge_ids():
def test_ItemSet_graphs(): def test_ItemSet_graphs():
# Graphs. # Graphs.
graphs = [dgl.rand_graph(10, 20) for _ in range(5)] graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
item_set = ItemSet(graphs) item_set = gb.ItemSet(graphs)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert graphs[i] == item assert graphs[i] == item
...@@ -22,7 +30,7 @@ def test_ItemSet_graphs(): ...@@ -22,7 +30,7 @@ def test_ItemSet_graphs():
def test_ItemSet_node_pairs(): def test_ItemSet_node_pairs():
# Node pairs. # Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
item_set = ItemSet(node_pairs) item_set = gb.ItemSet(node_pairs)
for i, (src, dst) in enumerate(item_set): for i, (src, dst) in enumerate(item_set):
assert node_pairs[0][i] == src assert node_pairs[0][i] == src
assert node_pairs[1][i] == dst assert node_pairs[1][i] == dst
...@@ -32,7 +40,7 @@ def test_ItemSet_node_pairs_labels(): ...@@ -32,7 +40,7 @@ def test_ItemSet_node_pairs_labels():
# Node pairs and labels # Node pairs and labels
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
item_set = ItemSet((node_pairs[0], node_pairs[1], labels)) item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
for i, (src, dst, label) in enumerate(item_set): for i, (src, dst, label) in enumerate(item_set):
assert node_pairs[0][i] == src assert node_pairs[0][i] == src
assert node_pairs[1][i] == dst assert node_pairs[1][i] == dst
...@@ -44,7 +52,7 @@ def test_ItemSet_head_tail_neg_tails(): ...@@ -44,7 +52,7 @@ def test_ItemSet_head_tail_neg_tails():
heads = torch.arange(0, 5) heads = torch.arange(0, 5)
tails = torch.arange(5, 10) tails = torch.arange(5, 10)
neg_tails = torch.arange(10, 20).reshape(5, 2) neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = ItemSet((heads, tails, neg_tails)) item_set = gb.ItemSet((heads, tails, neg_tails))
for i, (head, tail, negs) in enumerate(item_set): for i, (head, tail, negs) in enumerate(item_set):
assert heads[i] == head assert heads[i] == head
assert tails[i] == tail assert tails[i] == tail
...@@ -54,13 +62,13 @@ def test_ItemSet_head_tail_neg_tails(): ...@@ -54,13 +62,13 @@ def test_ItemSet_head_tail_neg_tails():
def test_ItemSetDict_node_edge_ids(): def test_ItemSetDict_node_edge_ids():
# Node or edge IDs # Node or edge IDs
ids = { ids = {
("user", "like", "item"): ItemSet(torch.arange(0, 5)), ("user", "like", "item"): gb.ItemSet(torch.arange(0, 5)),
("user", "follow", "user"): ItemSet(torch.arange(0, 5)), ("user", "follow", "user"): gb.ItemSet(torch.arange(0, 5)),
} }
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
chained_ids += [(key, v) for v in value] chained_ids += [(key, v) for v in value]
item_set = ItemSetDict(ids) item_set = gb.ItemSetDict(ids)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -72,13 +80,13 @@ def test_ItemSetDict_node_pairs(): ...@@ -72,13 +80,13 @@ def test_ItemSetDict_node_pairs():
# Node pairs. # Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
node_pairs_dict = { node_pairs_dict = {
("user", "like", "item"): ItemSet(node_pairs), ("user", "like", "item"): gb.ItemSet(node_pairs),
("user", "follow", "user"): ItemSet(node_pairs), ("user", "follow", "user"): gb.ItemSet(node_pairs),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -91,17 +99,17 @@ def test_ItemSetDict_node_pairs_labels(): ...@@ -91,17 +99,17 @@ def test_ItemSetDict_node_pairs_labels():
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
node_pairs_dict = { node_pairs_dict = {
("user", "like", "item"): ItemSet( ("user", "like", "item"): gb.ItemSet(
(node_pairs[0], node_pairs[1], labels) (node_pairs[0], node_pairs[1], labels)
), ),
("user", "follow", "user"): ItemSet( ("user", "follow", "user"): gb.ItemSet(
(node_pairs[0], node_pairs[1], labels) (node_pairs[0], node_pairs[1], labels)
), ),
} }
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -114,15 +122,15 @@ def test_ItemSetDict_head_tail_neg_tails(): ...@@ -114,15 +122,15 @@ def test_ItemSetDict_head_tail_neg_tails():
heads = torch.arange(0, 5) heads = torch.arange(0, 5)
tails = torch.arange(5, 10) tails = torch.arange(5, 10)
neg_tails = torch.arange(10, 20).reshape(5, 2) neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = ItemSet((heads, tails, neg_tails)) item_set = gb.ItemSet((heads, tails, neg_tails))
data_dict = { data_dict = {
("user", "like", "item"): ItemSet((heads, tails, neg_tails)), ("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): ItemSet((heads, tails, neg_tails)), ("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
} }
expected_data = [] expected_data = []
for key, value in data_dict.items(): for key, value in data_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = ItemSetDict(data_dict) item_set = gb.ItemSetDict(data_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment