Unverified commit 5ada3fc9, authored by Rhett Ying, committed by GitHub
Browse files

[GraphBolt] rename DictItemSet as ItemSetDict (#5806)

parent d88275ca
"""GraphBolt Itemset.""" """GraphBolt Itemset."""
__all__ = ["ItemSet", "DictItemSet"] __all__ = ["ItemSet", "ItemSetDict"]
class ItemSet: class ItemSet:
...@@ -36,8 +36,8 @@ class ItemSet: ...@@ -36,8 +36,8 @@ class ItemSet:
raise NotImplementedError raise NotImplementedError
class DictItemSet: class ItemSetDict:
r"""Itemset wrapping multiple itemsets with keys. r"""An iterable ItemsetDict.
Each item is retrieved by iterating over each itemset and returned with Each item is retrieved by iterating over each itemset and returned with
corresponding key as a dict. corresponding key as a dict.
......
...@@ -9,7 +9,7 @@ from torchdata.datapipes.iter import IterableWrapper, IterDataPipe ...@@ -9,7 +9,7 @@ from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..batch import batch as dgl_batch from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
from .itemset import DictItemSet, ItemSet from .itemset import ItemSet, ItemSetDict
__all__ = ["MinibatchSampler"] __all__ = ["MinibatchSampler"]
...@@ -28,7 +28,7 @@ class MinibatchSampler(IterDataPipe): ...@@ -28,7 +28,7 @@ class MinibatchSampler(IterDataPipe):
Parameters Parameters
---------- ----------
item_set : ItemSet or DictItemSet item_set : ItemSet or ItemSetDict
Data to be sampled for mini-batches. Data to be sampled for mini-batches.
batch_size : int batch_size : int
The size of each batch. The size of each batch.
...@@ -106,7 +106,7 @@ class MinibatchSampler(IterDataPipe): ...@@ -106,7 +106,7 @@ class MinibatchSampler(IterDataPipe):
... "user": gb.ItemSet(torch.arange(0, 5)), ... "user": gb.ItemSet(torch.arange(0, 5)),
... "item": gb.ItemSet(torch.arange(0, 6)), ... "item": gb.ItemSet(torch.arange(0, 6)),
... } ... }
>>> item_set = gb.DictItemSet(ids) >>> item_set = gb.ItemSetDict(ids)
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4) >>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler) >>> list(minibatch_sampler)
[{'user': tensor([0, 1, 2, 3])}, [{'user': tensor([0, 1, 2, 3])},
...@@ -116,7 +116,7 @@ class MinibatchSampler(IterDataPipe): ...@@ -116,7 +116,7 @@ class MinibatchSampler(IterDataPipe):
8. Heterogeneous node pairs. 8. Heterogeneous node pairs.
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5)) >>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12)) >>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
>>> item_set = gb.DictItemSet({ >>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(node_pairs_like), ... ("user", "like", "item"): gb.ItemSet(node_pairs_like),
... ("user", "follow", "user"): gb.ItemSet(node_pairs_follow), ... ("user", "follow", "user"): gb.ItemSet(node_pairs_follow),
... }) ... })
...@@ -132,7 +132,7 @@ class MinibatchSampler(IterDataPipe): ...@@ -132,7 +132,7 @@ class MinibatchSampler(IterDataPipe):
... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5)) ... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5))
>>> follow = ( >>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6)) ... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
>>> item_set = gb.DictItemSet({ >>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(like), ... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow), ... ("user", "follow", "user"): gb.ItemSet(follow),
... }) ... })
...@@ -153,7 +153,7 @@ class MinibatchSampler(IterDataPipe): ...@@ -153,7 +153,7 @@ class MinibatchSampler(IterDataPipe):
>>> follow = ( >>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), ... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2)) ... torch.arange(12, 24).reshape(-1, 2))
>>> item_set = gb.DictItemSet({ >>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(like), ... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow), ... ("user", "follow", "user"): gb.ItemSet(follow),
... }) ... })
...@@ -170,7 +170,7 @@ class MinibatchSampler(IterDataPipe): ...@@ -170,7 +170,7 @@ class MinibatchSampler(IterDataPipe):
def __init__( def __init__(
self, self,
item_set: ItemSet or DictItemSet, item_set: ItemSet or ItemSetDict,
batch_size: int, batch_size: int,
drop_last: Optional[bool] = False, drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False, shuffle: Optional[bool] = False,
......
...@@ -51,7 +51,7 @@ def test_ItemSet_head_tail_neg_tails(): ...@@ -51,7 +51,7 @@ def test_ItemSet_head_tail_neg_tails():
assert_close(neg_tails[i], negs) assert_close(neg_tails[i], negs)
def test_DictItemSet_node_edge_ids(): def test_ItemSetDict_node_edge_ids():
# Node or edge IDs # Node or edge IDs
ids = { ids = {
("user", "like", "item"): ItemSet(torch.arange(0, 5)), ("user", "like", "item"): ItemSet(torch.arange(0, 5)),
...@@ -60,7 +60,7 @@ def test_DictItemSet_node_edge_ids(): ...@@ -60,7 +60,7 @@ def test_DictItemSet_node_edge_ids():
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
chained_ids += [(key, v) for v in value] chained_ids += [(key, v) for v in value]
item_set = DictItemSet(ids) item_set = ItemSetDict(ids)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -68,7 +68,7 @@ def test_DictItemSet_node_edge_ids(): ...@@ -68,7 +68,7 @@ def test_DictItemSet_node_edge_ids():
assert item[chained_ids[i][0]] == chained_ids[i][1] assert item[chained_ids[i][0]] == chained_ids[i][1]
def test_DictItemSet_node_pairs(): def test_ItemSetDict_node_pairs():
# Node pairs. # Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
node_pairs_dict = { node_pairs_dict = {
...@@ -78,7 +78,7 @@ def test_DictItemSet_node_pairs(): ...@@ -78,7 +78,7 @@ def test_DictItemSet_node_pairs():
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = DictItemSet(node_pairs_dict) item_set = ItemSetDict(node_pairs_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -86,7 +86,7 @@ def test_DictItemSet_node_pairs(): ...@@ -86,7 +86,7 @@ def test_DictItemSet_node_pairs():
assert item[expected_data[i][0]] == expected_data[i][1] assert item[expected_data[i][0]] == expected_data[i][1]
def test_DictItemSet_node_pairs_labels(): def test_ItemSetDict_node_pairs_labels():
# Node pairs and labels # Node pairs and labels
node_pairs = (torch.arange(0, 5), torch.arange(5, 10)) node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,)) labels = torch.randint(0, 3, (5,))
...@@ -101,7 +101,7 @@ def test_DictItemSet_node_pairs_labels(): ...@@ -101,7 +101,7 @@ def test_DictItemSet_node_pairs_labels():
expected_data = [] expected_data = []
for key, value in node_pairs_dict.items(): for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = DictItemSet(node_pairs_dict) item_set = ItemSetDict(node_pairs_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
...@@ -109,7 +109,7 @@ def test_DictItemSet_node_pairs_labels(): ...@@ -109,7 +109,7 @@ def test_DictItemSet_node_pairs_labels():
assert item[expected_data[i][0]] == expected_data[i][1] assert item[expected_data[i][0]] == expected_data[i][1]
def test_DictItemSet_head_tail_neg_tails(): def test_ItemSetDict_head_tail_neg_tails():
# Head, tail and negative tails. # Head, tail and negative tails.
heads = torch.arange(0, 5) heads = torch.arange(0, 5)
tails = torch.arange(5, 10) tails = torch.arange(5, 10)
...@@ -122,7 +122,7 @@ def test_DictItemSet_head_tail_neg_tails(): ...@@ -122,7 +122,7 @@ def test_DictItemSet_head_tail_neg_tails():
expected_data = [] expected_data = []
for key, value in data_dict.items(): for key, value in data_dict.items():
expected_data += [(key, v) for v in value] expected_data += [(key, v) for v in value]
item_set = DictItemSet(data_dict) item_set = ItemSetDict(data_dict)
for i, item in enumerate(item_set): for i, item in enumerate(item_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
......
...@@ -215,7 +215,7 @@ def test_append_with_other_datapipes(): ...@@ -215,7 +215,7 @@ def test_append_with_other_datapipes():
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_ids(batch_size, shuffle, drop_last): def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
# Node IDs. # Node IDs.
num_ids = 205 num_ids = 205
ids = { ids = {
...@@ -225,7 +225,7 @@ def test_DictItemSet_node_ids(batch_size, shuffle, drop_last): ...@@ -225,7 +225,7 @@ def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
chained_ids += [(key, v) for v in value] chained_ids += [(key, v) for v in value]
item_set = gb.DictItemSet(ids) item_set = gb.ItemSetDict(ids)
minibatch_sampler = gb.MinibatchSampler( minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -253,7 +253,7 @@ def test_DictItemSet_node_ids(batch_size, shuffle, drop_last): ...@@ -253,7 +253,7 @@ def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last): def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
# Node pairs. # Node pairs.
num_ids = 103 num_ids = 103
total_ids = 2 * num_ids total_ids = 2 * num_ids
...@@ -269,7 +269,7 @@ def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last): ...@@ -269,7 +269,7 @@ def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
("user", "like", "item"): gb.ItemSet(node_pairs_0), ("user", "like", "item"): gb.ItemSet(node_pairs_0),
("user", "follow", "user"): gb.ItemSet(node_pairs_1), ("user", "follow", "user"): gb.ItemSet(node_pairs_1),
} }
item_set = gb.DictItemSet(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
minibatch_sampler = gb.MinibatchSampler( minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -305,7 +305,7 @@ def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last): ...@@ -305,7 +305,7 @@ def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last): def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels # Node pairs and labels
num_ids = 103 num_ids = 103
total_ids = 2 * num_ids total_ids = 2 * num_ids
...@@ -326,7 +326,7 @@ def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -326,7 +326,7 @@ def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
(node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2) (node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2)
), ),
} }
item_set = gb.DictItemSet(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
minibatch_sampler = gb.MinibatchSampler( minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -371,7 +371,7 @@ def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -371,7 +371,7 @@ def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last): def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
# Head, tail and negative tails. # Head, tail and negative tails.
num_ids = 103 num_ids = 103
total_ids = 2 * num_ids total_ids = 2 * num_ids
...@@ -383,7 +383,7 @@ def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last): ...@@ -383,7 +383,7 @@ def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)), ("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)), ("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
} }
item_set = gb.DictItemSet(data_dict) item_set = gb.ItemSetDict(data_dict)
minibatch_sampler = gb.MinibatchSampler( minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment