Unverified Commit 9c756a5e authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] refine ItemSet/ItemSetDict and add examples (#5834)

parent c9778b55
......@@ -13,10 +13,43 @@ class ItemSet:
Parameters
----------
items: Iterable or Tuple[Iterable]
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids)
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
2. Tuple of iterables with same shape.
>>> node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
>>> item_set = gb.ItemSet(node_pairs)
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
3. Tuple of iterables with different shape.
>>> heads = torch.arange(0, 5)
>>> tails = torch.arange(5, 10)
>>> neg_tails = torch.arange(10, 20).reshape(5, 2)
>>> item_set = gb.ItemSet((heads, tails, neg_tails))
>>> list(item_set)
[(tensor(0), tensor(5), tensor([10, 11])),
(tensor(1), tensor(6), tensor([12, 13])),
(tensor(2), tensor(7), tensor([14, 15])),
(tensor(3), tensor(8), tensor([16, 17])),
(tensor(4), tensor(9), tensor([18, 19]))]
"""
def __init__(self, items):
    """Store ``items`` internally as a tuple of iterables.

    Parameters
    ----------
    items: Iterable or Tuple[Iterable]
        Either a single iterable, which is wrapped into a 1-tuple, or a
        tuple of objects exposing ``size(0)`` that all agree on that value.

    Raises
    ------
    AssertionError
        If ``items`` is a tuple whose members differ in ``size(0)``.
    """
    if isinstance(items, tuple):
        # Raise explicitly instead of using an ``assert`` statement: asserts
        # are stripped under ``python -O``, which would silently accept
        # mismatched inputs. The exception type and message are unchanged.
        if not all(items[0].size(0) == item.size(0) for item in items):
            raise AssertionError("Size mismatch between items in tuple.")
        self._items = items
    else:
        # A lone iterable is normalised to a 1-tuple.
        self._items = (items,)
......@@ -29,12 +62,6 @@ class ItemSet:
for item in zip_items:
yield tuple(item)
def __getitem__(self, _):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class ItemSetDict:
r"""An iterable ItemsetDict.
......@@ -45,6 +72,51 @@ class ItemSetDict:
Parameters
----------
itemsets: Dict[Union[str, Tuple[str, str, str]], ItemSet]
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Single iterable.
>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... 'user': gb.ItemSet(node_ids_user),
... 'item': gb.ItemSet(node_ids_item)})
>>> list(item_set)
[{'user': tensor(0)}, {'user': tensor(1)}, {'user': tensor(2)},
{'user': tensor(3)}, {'user': tensor(4)}, {'item': tensor(5)},
{'item': tensor(6)}, {'item': tensor(7)}, {'item': tensor(8)},
{'item': tensor(9)}]
2. Tuple of iterables with same shape.
>>> node_pairs_like = (torch.arange(0, 2), torch.arange(0, 2))
>>> node_pairs_follow = (torch.arange(0, 3), torch.arange(3, 6))
>>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(node_pairs_like),
... ('user', 'follow', 'user'): gb.ItemSet(node_pairs_follow)})
>>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0))},
{('user', 'like', 'item'): (tensor(1), tensor(1))},
{('user', 'follow', 'user'): (tensor(0), tensor(3))},
{('user', 'follow', 'user'): (tensor(1), tensor(4))},
{('user', 'follow', 'user'): (tensor(2), tensor(5))}]
3. Tuple of iterables with different shape.
>>> like = (torch.arange(0, 2), torch.arange(0, 2),
... torch.arange(0, 4).reshape(-1, 2))
>>> follow = (torch.arange(0, 3), torch.arange(3, 6),
... torch.arange(0, 6).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({
... ('user', 'like', 'item'): gb.ItemSet(like),
... ('user', 'follow', 'user'): gb.ItemSet(follow)})
>>> list(item_set)
[{('user', 'like', 'item'): (tensor(0), tensor(0), tensor([0, 1]))},
{('user', 'like', 'item'): (tensor(1), tensor(1), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(0), tensor(3), tensor([0, 1]))},
{('user', 'follow', 'user'): (tensor(1), tensor(4), tensor([2, 3]))},
{('user', 'follow', 'user'): (tensor(2), tensor(5), tensor([4, 5]))}]
"""
def __init__(self, itemsets):
......@@ -54,9 +126,3 @@ class ItemSetDict:
for key, itemset in self._itemsets.items():
for item in itemset:
yield {key: item}
def __getitem__(self, _):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
import dgl
import pytest
import torch
from dgl import graphbolt as gb
from torch.testing import assert_close
from dgl.graphbolt import *
def test_mismatch_size_in_tuple():
    """ItemSet must reject a tuple whose members differ in length."""
    # Five IDs on one side, six on the other: the sizes disagree.
    sources = torch.arange(0, 5)
    destinations = torch.arange(5, 11)
    with pytest.raises(AssertionError):
        gb.ItemSet((sources, destinations))
def test_ItemSet_node_edge_ids():
    """Iterating an ItemSet built from one ID tensor yields the IDs in order."""
    # Node or edge IDs.
    # Construct once through the ``gb`` namespace; the earlier bare
    # ``ItemSet(...)`` construction was immediately overwritten (dead code).
    item_set = gb.ItemSet(torch.arange(0, 5))
    for i, item in enumerate(item_set):
        assert i == item.item()
......@@ -14,7 +22,7 @@ def test_ItemSet_node_edge_ids():
def test_ItemSet_graphs():
    """An ItemSet over a list of graphs yields those graphs in order."""
    # Graphs.
    graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
    # Construct once through the ``gb`` namespace; the earlier bare
    # ``ItemSet(...)`` construction was immediately overwritten (dead code).
    item_set = gb.ItemSet(graphs)
    for i, item in enumerate(item_set):
        assert graphs[i] == item
......@@ -22,7 +30,7 @@ def test_ItemSet_graphs():
def test_ItemSet_node_pairs():
    """An ItemSet over a (src, dst) tuple yields aligned pairs in order."""
    # Node pairs.
    node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
    # Construct once through the ``gb`` namespace; the earlier bare
    # ``ItemSet(...)`` construction was immediately overwritten (dead code).
    item_set = gb.ItemSet(node_pairs)
    for i, (src, dst) in enumerate(item_set):
        assert node_pairs[0][i] == src
        assert node_pairs[1][i] == dst
......@@ -32,7 +40,7 @@ def test_ItemSet_node_pairs_labels():
# Node pairs and labels
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,))
item_set = ItemSet((node_pairs[0], node_pairs[1], labels))
item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
for i, (src, dst, label) in enumerate(item_set):
assert node_pairs[0][i] == src
assert node_pairs[1][i] == dst
......@@ -44,7 +52,7 @@ def test_ItemSet_head_tail_neg_tails():
heads = torch.arange(0, 5)
tails = torch.arange(5, 10)
neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = ItemSet((heads, tails, neg_tails))
item_set = gb.ItemSet((heads, tails, neg_tails))
for i, (head, tail, negs) in enumerate(item_set):
assert heads[i] == head
assert tails[i] == tail
......@@ -54,13 +62,13 @@ def test_ItemSet_head_tail_neg_tails():
def test_ItemSetDict_node_edge_ids():
# Node or edge IDs
ids = {
("user", "like", "item"): ItemSet(torch.arange(0, 5)),
("user", "follow", "user"): ItemSet(torch.arange(0, 5)),
("user", "like", "item"): gb.ItemSet(torch.arange(0, 5)),
("user", "follow", "user"): gb.ItemSet(torch.arange(0, 5)),
}
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = ItemSetDict(ids)
item_set = gb.ItemSetDict(ids)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......@@ -72,13 +80,13 @@ def test_ItemSetDict_node_pairs():
# Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
node_pairs_dict = {
("user", "like", "item"): ItemSet(node_pairs),
("user", "follow", "user"): ItemSet(node_pairs),
("user", "like", "item"): gb.ItemSet(node_pairs),
("user", "follow", "user"): gb.ItemSet(node_pairs),
}
expected_data = []
for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value]
item_set = ItemSetDict(node_pairs_dict)
item_set = gb.ItemSetDict(node_pairs_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......@@ -91,17 +99,17 @@ def test_ItemSetDict_node_pairs_labels():
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,))
node_pairs_dict = {
("user", "like", "item"): ItemSet(
("user", "like", "item"): gb.ItemSet(
(node_pairs[0], node_pairs[1], labels)
),
("user", "follow", "user"): ItemSet(
("user", "follow", "user"): gb.ItemSet(
(node_pairs[0], node_pairs[1], labels)
),
}
expected_data = []
for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value]
item_set = ItemSetDict(node_pairs_dict)
item_set = gb.ItemSetDict(node_pairs_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......@@ -114,15 +122,15 @@ def test_ItemSetDict_head_tail_neg_tails():
heads = torch.arange(0, 5)
tails = torch.arange(5, 10)
neg_tails = torch.arange(10, 20).reshape(5, 2)
item_set = ItemSet((heads, tails, neg_tails))
item_set = gb.ItemSet((heads, tails, neg_tails))
data_dict = {
("user", "like", "item"): ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): ItemSet((heads, tails, neg_tails)),
("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
}
expected_data = []
for key, value in data_dict.items():
expected_data += [(key, v) for v in value]
item_set = ItemSetDict(data_dict)
item_set = gb.ItemSetDict(data_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment