Commit a1051f00 authored by Rhett Ying, committed by GitHub

[GraphBolt] add MinibatchSampler which supports ItemSet (#5793)

parent b56552ff
@@ -7,6 +7,7 @@ import torch
 from .._ffi import libinfo
 from .graph_storage import *
 from .itemset import *
+from .minibatch_sampler import *
 def load_graphbolt():
...
"""Minibatch Sampler"""
from typing import Mapping
from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .itemset import ItemSet
__all__ = ["MinibatchSampler"]
def _collate(batch):
    """Collate a list of sampled items into a single batched object."""
    data = next(iter(batch))
    if isinstance(data, DGLGraph):
        # A batch of graphs is merged into one batched DGLGraph.
        return dgl_batch(batch)
    elif isinstance(data, Mapping):
        # Heterogeneous items (mappings) are not supported yet.
        raise NotImplementedError
    return default_collate(batch)
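
# Illustrative sketch (not part of the patch): how `_collate` dispatches on
# the type of the first element. The `dgl` import is assumed for the graph
# case.
#
#   _collate([torch.tensor(0), torch.tensor(1)])
#       -> tensor([0, 1])            # via default_collate
#   _collate([dgl.rand_graph(2, 3), dgl.rand_graph(2, 3)])
#       -> one batched DGLGraph      # via dgl.batch
#   _collate([{"user": torch.tensor(0)}])
#       -> raises NotImplementedError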
class MinibatchSampler(IterDataPipe):
"""Minibatch Sampler.
    Creates mini-batches of data from an `ItemSet`, which can hold node or
    edge IDs, node pairs with or without labels, heads/tails/negative tails,
    DGLGraphs, or their heterogeneous counterparts.
    Note: `MinibatchSampler` is deliberately not decorated with
    `torchdata.datapipes.functional_datapipe`, so it cannot be invoked in the
    functional style. Any iterable DataPipe from `torchdata` can still be
    appended to it, though.
Parameters
----------
item_set : ItemSet
Data to be sampled for mini-batches.
batch_size : int
The size of each batch.
drop_last : bool
Option to drop the last batch if it's not full.
shuffle : bool
        Option to shuffle before sampling.
Examples
--------
1. Node/edge IDs.
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> minibatch_sampler = gb.MinibatchSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(minibatch_sampler)
[tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]
2. Node pairs.
>>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
>>> minibatch_sampler = gb.MinibatchSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(minibatch_sampler)
[[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
    tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]]
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
... )
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
>>> list(minibatch_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
[tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]
    4. Heads, tails and negative tails.
>>> heads = torch.arange(0, 5)
>>> tails = torch.arange(5, 10)
>>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
>>> item_set = gb.ItemSet((heads, tails, negative_tails))
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
>>> list(minibatch_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]),
tensor([[1, 2], [2, 3], [3, 4]])],
[tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]
5. DGLGraphs.
>>> import dgl
    >>> graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
>>> item_set = gb.ItemSet(graphs)
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
>>> list(minibatch_sampler)
[Graph(num_nodes=30, num_edges=60,
ndata_schemes={}
edata_schemes={}),
Graph(num_nodes=20, num_edges=40,
ndata_schemes={}
edata_schemes={})]
6. Further process batches with other datapipes such as
`torchdata.datapipes.iter.Mapper`.
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> data_pipe = gb.MinibatchSampler(item_set, 4)
>>> def add_one(batch):
... return batch + 1
>>> data_pipe = data_pipe.map(add_one)
>>> list(data_pipe)
[tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]
"""
def __init__(
self,
item_set: ItemSet,
batch_size: int,
        drop_last: bool = False,
        shuffle: bool = False,
):
super().__init__()
self._item_set = item_set
self._batch_size = batch_size
self._drop_last = drop_last
self._shuffle = shuffle
def __iter__(self):
data_pipe = IterableWrapper(self._item_set)
if self._shuffle:
            # `torchdata.datapipes.iter.Shuffler` also works on streaming data.
data_pipe = data_pipe.shuffle()
data_pipe = data_pipe.batch(
batch_size=self._batch_size,
drop_last=self._drop_last,
).collate(collate_fn=_collate)
return iter(data_pipe)
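
# For reference (a sketch, not part of the public API): `__iter__` above is
# roughly equivalent to composing the torchdata primitives by hand:
#
#   from torchdata.datapipes.iter import IterableWrapper
#   dp = IterableWrapper(item_set)
#   if shuffle:
#       dp = dp.shuffle()
#   dp = dp.batch(batch_size=batch_size, drop_last=drop_last)
#   dp = dp.collate(collate_fn=_collate)
#   list(dp)  # same mini-batches as list(MinibatchSampler(item_set, ...))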
import dgl
import pytest
import torch
from dgl import graphbolt as gb
from torch.testing import assert_close
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_edge_ids(batch_size, shuffle, drop_last):
# Node or edge IDs.
num_ids = 103
item_set = gb.ItemSet(torch.arange(0, num_ids))
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, minibatch in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch) == batch_size
else:
            if not drop_last:
                assert len(minibatch) == num_ids % batch_size
            else:
                # With drop_last=True, a partial batch must never be emitted.
                assert False
minibatch_ids.append(minibatch)
minibatch_ids = torch.cat(minibatch_ids)
    # `torch.all` returns a 0-dim tensor; cast to bool so the comparison
    # with `shuffle` is meaningful (a tensor `is not True/False` always).
    assert (
        bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:])) is not shuffle
    )
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_graphs(batch_size, shuffle, drop_last):
# Graphs.
num_graphs = 103
num_nodes = 10
num_edges = 20
graphs = [
dgl.rand_graph(num_nodes * (i + 1), num_edges * (i + 1))
for i in range(num_graphs)
]
item_set = gb.ItemSet(graphs)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_num_nodes = []
minibatch_num_edges = []
for i, minibatch in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= num_graphs
if not is_last or num_graphs % batch_size == 0:
assert minibatch.batch_size == batch_size
else:
if not drop_last:
assert minibatch.batch_size == num_graphs % batch_size
else:
assert False
minibatch_num_nodes.append(minibatch.batch_num_nodes())
minibatch_num_edges.append(minibatch.batch_num_edges())
minibatch_num_nodes = torch.cat(minibatch_num_nodes)
minibatch_num_edges = torch.cat(minibatch_num_edges)
    assert (
        bool(torch.all(minibatch_num_nodes[:-1] <= minibatch_num_nodes[1:]))
        is not shuffle
    )
    assert (
        bool(torch.all(minibatch_num_edges[:-1] <= minibatch_num_edges[1:]))
        is not shuffle
    )
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
item_set = gb.ItemSet(node_pairs)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
for i, (src, dst) in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
# Verify src and dst IDs match.
assert torch.equal(src + num_ids, dst)
# Archive batch.
src_ids.append(src)
dst_ids.append(dst)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) is not shuffle
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
labels = torch.arange(0, num_ids)
item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
labels = []
for i, (src, dst, label) in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert len(label) == expected_batch_size
# Verify src/dst IDs and labels match.
assert torch.equal(src + num_ids, dst)
assert torch.equal(src, label)
# Archive batch.
src_ids.append(src)
dst_ids.append(dst)
labels.append(label)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
labels = torch.cat(labels)
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) is not shuffle
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) is not shuffle
    assert bool(torch.all(labels[:-1] <= labels[1:])) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
# Head, tail and negative tails.
num_ids = 103
num_negs = 2
heads = torch.arange(0, num_ids)
tails = torch.arange(num_ids, num_ids * 2)
neg_tails = torch.stack((heads + 1, heads + 2), dim=-1)
item_set = gb.ItemSet((heads, tails, neg_tails))
for i, (head, tail, negs) in enumerate(item_set):
assert heads[i] == head
assert tails[i] == tail
assert torch.equal(neg_tails[i], negs)
minibatch_sampler = gb.MinibatchSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
head_ids = []
tail_ids = []
negs_ids = []
for i, (head, tail, negs) in enumerate(minibatch_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert len(head) == expected_batch_size
assert len(tail) == expected_batch_size
assert negs.dim() == 2
assert negs.shape[0] == expected_batch_size
assert negs.shape[1] == num_negs
        # Verify heads/tails and negative tails match.
assert torch.equal(head + num_ids, tail)
assert torch.equal(head + 1, negs[:, 0])
assert torch.equal(head + 2, negs[:, 1])
# Archive batch.
head_ids.append(head)
tail_ids.append(tail)
negs_ids.append(negs)
head_ids = torch.cat(head_ids)
tail_ids = torch.cat(tail_ids)
negs_ids = torch.cat(negs_ids)
    assert bool(torch.all(head_ids[:-1] <= head_ids[1:])) is not shuffle
    assert bool(torch.all(tail_ids[:-1] <= tail_ids[1:])) is not shuffle
    assert bool(torch.all(negs_ids[:-1, 0] <= negs_ids[1:, 0])) is not shuffle
    assert bool(torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1])) is not shuffle
def test_append_with_other_datapipes():
num_ids = 100
batch_size = 4
item_set = gb.ItemSet(torch.arange(0, num_ids))
data_pipe = gb.MinibatchSampler(item_set, batch_size)
# torchdata.datapipes.iter.Enumerator
data_pipe = data_pipe.enumerate()
for i, (idx, data) in enumerate(data_pipe):
assert i == idx
assert len(data) == batch_size
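
# A possible companion test (a sketch, not part of the original patch) that
# mirrors docstring example 6 by appending `torchdata.datapipes.iter.Mapper`
# through the functional `.map()` form. The test name is hypothetical.
def test_append_with_mapper_sketch():
    num_ids = 100
    batch_size = 4
    item_set = gb.ItemSet(torch.arange(0, num_ids))
    data_pipe = gb.MinibatchSampler(item_set, batch_size)
    # torchdata.datapipes.iter.Mapper: apply a function to every batch.
    data_pipe = data_pipe.map(lambda batch: batch + 1)
    for i, batch in enumerate(data_pipe):
        expected = torch.arange(i * batch_size, (i + 1) * batch_size) + 1
        assert torch.equal(batch, expected)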