Unverified Commit d88275ca authored by Rhett Ying, committed by GitHub

[GraphBolt] support DictItemSet in MinibatchSampler (#5803)

parent 945b0e54
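The gist of the change, as a minimal usage sketch (assuming only the `gb.ItemSet`, `gb.DictItemSet`, and `gb.MinibatchSampler` names that appear in the diff below):
import torch
from dgl import graphbolt as gb
# Heterogeneous node IDs keyed by node type; every emitted minibatch is a
# dict holding at most `batch_size` IDs in total, possibly spanning types.
item_set = gb.DictItemSet({
    "user": gb.ItemSet(torch.arange(0, 5)),
    "item": gb.ItemSet(torch.arange(0, 6)),
})
sampler = gb.MinibatchSampler(item_set, batch_size=4)
for minibatch in sampler:
    print(minibatch)  # e.g. {'user': tensor([0, 1, 2, 3])}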
"""Minibatch Sampler"""
from collections.abc import Mapping
from functools import partial
from typing import Optional, Union
from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .itemset import DictItemSet, ItemSet
__all__ = ["MinibatchSampler"]
class MinibatchSampler(IterDataPipe):
"""Minibatch Sampler.
@@ -36,7 +28,7 @@ class MinibatchSampler(IterDataPipe):
Parameters
----------
item_set : ItemSet or DictItemSet
Data to be sampled for mini-batches.
batch_size : int
The size of each batch.
@@ -47,7 +39,7 @@ class MinibatchSampler(IterDataPipe):
Examples
--------
1. Node IDs.
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
@@ -108,11 +100,77 @@ class MinibatchSampler(IterDataPipe):
>>> data_pipe = data_pipe.map(add_one)
>>> list(data_pipe)
[tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]
7. Heterogeneous node IDs.
>>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5)),
... "item": gb.ItemSet(torch.arange(0, 6)),
... }
>>> item_set = gb.DictItemSet(ids)
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{'user': tensor([0, 1, 2, 3])},
{'item': tensor([0, 1, 2]), 'user': tensor([4])},
{'item': tensor([3, 4, 5])}]
8. Heterogeneous node pairs.
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
>>> item_set = gb.DictItemSet({
... ("user", "like", "item"): gb.ItemSet(node_pairs_like),
... ("user", "follow", "user"): gb.ItemSet(node_pairs_follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'): [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4])],
('user', 'follow', 'user'): [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{('user', 'follow', 'user'): [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
9. Heterogeneous node pairs and labels.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
>>> item_set = gb.DictItemSet({
... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'):
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4]), tensor([4])],
('user', 'follow', 'user'):
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]},
{('user', 'follow', 'user'):
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}]
10. Heterogeneous head, tail and negative tails.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5),
... torch.arange(5, 15).reshape(-1, 2))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2))
>>> item_set = gb.DictItemSet({
... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'): [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4]), tensor([[13, 14]])],
('user', 'follow', 'user'): [tensor([0, 1, 2]), tensor([6, 7, 8]),
tensor([[12, 13], [14, 15], [16, 17]])]},
{('user', 'follow', 'user'): [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
tensor([[18, 19], [20, 21], [22, 23]])]}]
"""
def __init__(
self,
item_set: Union[ItemSet, DictItemSet],
batch_size: int,
drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False,
@@ -125,11 +183,35 @@
def __iter__(self):
data_pipe = IterableWrapper(self._item_set)
# Shuffle before batch.
if self._shuffle:
# `torchdata.datapipes.iter.Shuffler` works with stream too.
data_pipe = data_pipe.shuffle()
# Batch.
data_pipe = data_pipe.batch(
batch_size=self._batch_size,
drop_last=self._drop_last,
)
# Collate.
def _collate(batch):
data = next(iter(batch))
if isinstance(data, DGLGraph):
return dgl_batch(batch)
elif isinstance(data, Mapping):
assert len(data) == 1, "Only one type of data is allowed."
# Collect all the keys.
keys = {key for item in batch for key in item.keys()}
# Collate each key.
return {
key: default_collate(
[item[key] for item in batch if key in item]
)
for key in keys
}
return default_collate(batch)
data_pipe = data_pipe.collate(collate_fn=partial(_collate))
return iter(data_pipe)
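For dict items, each element of a batch is a single-key mapping, so the collate step above gathers every key present in the batch and runs `default_collate` per key. A standalone sketch of that per-key collation on toy data (not part of the diff), using only `torch.utils.data.default_collate`:
import torch
from torch.utils.data import default_collate
# A batch as yielded for a DictItemSet: one (type, value) mapping per item.
batch = [
    {"user": torch.tensor(3)},
    {"user": torch.tensor(4)},
    {"item": torch.tensor(0)},
]
# Collect all keys that occur in the batch, then collate each key's values.
keys = {key for item in batch for key in item.keys()}
collated = {
    key: default_collate([item[key] for item in batch if key in item])
    for key in keys
}
print(collated)  # {'user': tensor([3, 4]), 'item': tensor([0])} (key order may vary)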
@@ -8,8 +8,8 @@ from torch.testing import assert_close
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_ids(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
item_set = gb.ItemSet(torch.arange(0, num_ids))
minibatch_sampler = gb.MinibatchSampler(
@@ -210,3 +210,220 @@ def test_append_with_other_datapipes():
for i, (idx, data) in enumerate(data_pipe):
assert i == idx
assert len(data) == batch_size
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 205
ids = {
"user": gb.ItemSet(torch.arange(0, 99)),
"item": gb.ItemSet(torch.arange(99, num_ids)),
}
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = gb.DictItemSet(ids)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, batch in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert isinstance(batch, dict)
ids = []
for _, v in batch.items():
ids.append(v)
ids = torch.cat(ids)
assert len(ids) == expected_batch_size
minibatch_ids.append(ids)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
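The expected-batch-size bookkeeping used in these tests reduces to simple arithmetic; a tiny sketch of the same check in isolation (toy numbers only, independent of the test code):
# 205 items with batch_size=4: 51 full batches plus a trailing batch of 1,
# which is kept when drop_last=False and dropped when drop_last=True.
num_ids, batch_size = 205, 4
full, remainder = divmod(num_ids, batch_size)  # full=51, remainder=1
for drop_last in (False, True):
    num_batches = full + (0 if drop_last or remainder == 0 else 1)
    last_size = batch_size if drop_last or remainder == 0 else remainder
    print(drop_last, num_batches, last_size)  # False 52 1, then True 51 4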
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
total_ids = 2 * num_ids
node_pairs_0 = (
torch.arange(0, num_ids),
torch.arange(num_ids, num_ids * 2),
)
node_pairs_1 = (
torch.arange(num_ids * 2, num_ids * 3),
torch.arange(num_ids * 3, num_ids * 4),
)
node_pairs_dict = {
("user", "like", "item"): gb.ItemSet(node_pairs_0),
("user", "follow", "user"): gb.ItemSet(node_pairs_1),
}
item_set = gb.DictItemSet(node_pairs_dict)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
for i, batch in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = total_ids % batch_size
else:
assert False
src = []
dst = []
for _, (v_src, v_dst) in batch.items():
src.append(v_src)
dst.append(v_dst)
src = torch.cat(src)
dst = torch.cat(dst)
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
src_ids.append(src)
dst_ids.append(dst)
assert torch.equal(src + num_ids, dst)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
total_ids = 2 * num_ids
node_pairs_0 = (
torch.arange(0, num_ids),
torch.arange(num_ids, num_ids * 2),
)
node_pairs_1 = (
torch.arange(num_ids * 2, num_ids * 3),
torch.arange(num_ids * 3, num_ids * 4),
)
labels = torch.arange(0, num_ids)
node_pairs_dict = {
("user", "like", "item"): gb.ItemSet(
(node_pairs_0[0], node_pairs_0[1], labels)
),
("user", "follow", "user"): gb.ItemSet(
(node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2)
),
}
item_set = gb.DictItemSet(node_pairs_dict)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
labels = []
for i, batch in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = total_ids % batch_size
else:
assert False
src = []
dst = []
label = []
for _, (v_src, v_dst, v_label) in batch.items():
src.append(v_src)
dst.append(v_dst)
label.append(v_label)
src = torch.cat(src)
dst = torch.cat(dst)
label = torch.cat(label)
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert len(label) == expected_batch_size
src_ids.append(src)
dst_ids.append(dst)
labels.append(label)
assert torch.equal(src + num_ids, dst)
assert torch.equal(src, label)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
labels = torch.cat(labels)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(labels[:-1] <= labels[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
# Head, tail and negative tails.
num_ids = 103
total_ids = 2 * num_ids
num_negs = 2
heads = torch.arange(0, num_ids)
tails = torch.arange(num_ids, num_ids * 2)
neg_tails = torch.stack((heads + 1, heads + 2), dim=-1)
data_dict = {
("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
}
item_set = gb.DictItemSet(data_dict)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
head_ids = []
tail_ids = []
negs_ids = []
for i, batch in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = total_ids % batch_size
else:
assert False
head = []
tail = []
negs = []
for _, (v_head, v_tail, v_negs) in batch.items():
head.append(v_head)
tail.append(v_tail)
negs.append(v_negs)
head = torch.cat(head)
tail = torch.cat(tail)
negs = torch.cat(negs)
assert len(head) == expected_batch_size
assert len(tail) == expected_batch_size
assert len(negs) == expected_batch_size
head_ids.append(head)
tail_ids.append(tail)
negs_ids.append(negs)
assert negs.dim() == 2
assert negs.shape[0] == expected_batch_size
assert negs.shape[1] == num_negs
assert torch.equal(head + num_ids, tail)
assert torch.equal(head + 1, negs[:, 0])
assert torch.equal(head + 2, negs[:, 1])
head_ids = torch.cat(head_ids)
tail_ids = torch.cat(tail_ids)
negs_ids = torch.cat(negs_ids)
assert torch.all(head_ids[:-1] <= head_ids[1:]) is not shuffle
assert torch.all(tail_ids[:-1] <= tail_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle