Unverified Commit 5ada3fc9 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] rename DictItemSet as ItemSetDict (#5806)

parent d88275ca
"""GraphBolt Itemset."""
__all__ = ["ItemSet", "DictItemSet"]
__all__ = ["ItemSet", "ItemSetDict"]
class ItemSet:
......@@ -36,8 +36,8 @@ class ItemSet:
raise NotImplementedError
class DictItemSet:
r"""Itemset wrapping multiple itemsets with keys.
class ItemSetDict:
r"""An iterable ItemsetDict.
Each item is retrieved by iterating over each itemset and returned with
corresponding key as a dict.
......
......@@ -9,7 +9,7 @@ from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .itemset import DictItemSet, ItemSet
from .itemset import ItemSet, ItemSetDict
__all__ = ["MinibatchSampler"]
......@@ -28,7 +28,7 @@ class MinibatchSampler(IterDataPipe):
Parameters
----------
item_set : ItemSet or DictItemSet
item_set : ItemSet or ItemSetDict
Data to be sampled for mini-batches.
batch_size : int
The size of each batch.
......@@ -106,7 +106,7 @@ class MinibatchSampler(IterDataPipe):
... "user": gb.ItemSet(torch.arange(0, 5)),
... "item": gb.ItemSet(torch.arange(0, 6)),
... }
>>> item_set = gb.DictItemSet(ids)
>>> item_set = gb.ItemSetDict(ids)
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{'user': tensor([0, 1, 2, 3])},
......@@ -116,7 +116,7 @@ class MinibatchSampler(IterDataPipe):
8. Heterogeneous node pairs.
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
>>> item_set = gb.DictItemSet({
>>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(node_pairs_like),
... ("user", "follow", "user"): gb.ItemSet(node_pairs_follow),
... })
......@@ -132,7 +132,7 @@ class MinibatchSampler(IterDataPipe):
... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
>>> item_set = gb.DictItemSet({
>>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow),
... })
......@@ -153,7 +153,7 @@ class MinibatchSampler(IterDataPipe):
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2))
>>> item_set = gb.DictItemSet({
>>> item_set = gb.ItemSetDict({
... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow),
... })
......@@ -170,7 +170,7 @@ class MinibatchSampler(IterDataPipe):
def __init__(
self,
item_set: ItemSet or DictItemSet,
item_set: ItemSet or ItemSetDict,
batch_size: int,
drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False,
......
......@@ -51,7 +51,7 @@ def test_ItemSet_head_tail_neg_tails():
assert_close(neg_tails[i], negs)
def test_DictItemSet_node_edge_ids():
def test_ItemSetDict_node_edge_ids():
# Node or edge IDs
ids = {
("user", "like", "item"): ItemSet(torch.arange(0, 5)),
......@@ -60,7 +60,7 @@ def test_DictItemSet_node_edge_ids():
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = DictItemSet(ids)
item_set = ItemSetDict(ids)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......@@ -68,7 +68,7 @@ def test_DictItemSet_node_edge_ids():
assert item[chained_ids[i][0]] == chained_ids[i][1]
def test_DictItemSet_node_pairs():
def test_ItemSetDict_node_pairs():
# Node pairs.
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
node_pairs_dict = {
......@@ -78,7 +78,7 @@ def test_DictItemSet_node_pairs():
expected_data = []
for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value]
item_set = DictItemSet(node_pairs_dict)
item_set = ItemSetDict(node_pairs_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......@@ -86,7 +86,7 @@ def test_DictItemSet_node_pairs():
assert item[expected_data[i][0]] == expected_data[i][1]
def test_DictItemSet_node_pairs_labels():
def test_ItemSetDict_node_pairs_labels():
# Node pairs and labels
node_pairs = (torch.arange(0, 5), torch.arange(5, 10))
labels = torch.randint(0, 3, (5,))
......@@ -101,7 +101,7 @@ def test_DictItemSet_node_pairs_labels():
expected_data = []
for key, value in node_pairs_dict.items():
expected_data += [(key, v) for v in value]
item_set = DictItemSet(node_pairs_dict)
item_set = ItemSetDict(node_pairs_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......@@ -109,7 +109,7 @@ def test_DictItemSet_node_pairs_labels():
assert item[expected_data[i][0]] == expected_data[i][1]
def test_DictItemSet_head_tail_neg_tails():
def test_ItemSetDict_head_tail_neg_tails():
# Head, tail and negative tails.
heads = torch.arange(0, 5)
tails = torch.arange(5, 10)
......@@ -122,7 +122,7 @@ def test_DictItemSet_head_tail_neg_tails():
expected_data = []
for key, value in data_dict.items():
expected_data += [(key, v) for v in value]
item_set = DictItemSet(data_dict)
item_set = ItemSetDict(data_dict)
for i, item in enumerate(item_set):
assert len(item) == 1
assert isinstance(item, dict)
......
......@@ -215,7 +215,7 @@ def test_append_with_other_datapipes():
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 205
ids = {
......@@ -225,7 +225,7 @@ def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = gb.DictItemSet(ids)
item_set = gb.ItemSetDict(ids)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -253,7 +253,7 @@ def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
total_ids = 2 * num_ids
......@@ -269,7 +269,7 @@ def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
("user", "like", "item"): gb.ItemSet(node_pairs_0),
("user", "follow", "user"): gb.ItemSet(node_pairs_1),
}
item_set = gb.DictItemSet(node_pairs_dict)
item_set = gb.ItemSetDict(node_pairs_dict)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -305,7 +305,7 @@ def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
total_ids = 2 * num_ids
......@@ -326,7 +326,7 @@ def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
(node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2)
),
}
item_set = gb.DictItemSet(node_pairs_dict)
item_set = gb.ItemSetDict(node_pairs_dict)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -371,7 +371,7 @@ def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
# Head, tail and negative tails.
num_ids = 103
total_ids = 2 * num_ids
......@@ -383,7 +383,7 @@ def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
}
item_set = gb.DictItemSet(data_dict)
item_set = gb.ItemSetDict(data_dict)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment