Unverified Commit 29949322 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] convert item list to MiniBatch (#6281)

parent dadce86a
......@@ -2,16 +2,73 @@
from collections.abc import Mapping
from functools import partial
from typing import Iterator, Optional
from typing import Callable, Iterator, Optional
from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..base import dgl_warning
from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict
from .minibatch import MiniBatch
__all__ = ["ItemSampler", "minibatcher_default"]
__all__ = ["ItemSampler"]
def minibatcher_default(batch, names):
"""Default minibatcher.
The default minibatcher maps a list of items to a `MiniBatch` with the
same names as the items. The names of items are supposed to be provided
and align with the data attributes of `MiniBatch`. If any unknown item name
is provided, exception will be raised. If the names of items are not
provided, the item list is returned as is and a warning will be raised.
Parameters
----------
batch : list
List of items.
names : Tuple[str] or None
Names of items in `batch` with same length. The order should align
with `batch`.
Returns
-------
MiniBatch
A minibatch.
"""
if names is None:
dgl_warning(
"Failed to map item list to `MiniBatch` as the names of items are "
"not provided. Please provide a customized `MiniBatcher`. "
"The item list is returned as is."
)
return batch
if len(names) == 1:
# Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =
# ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor
# instead of the batch.
init_data = {names[0]: batch}
else:
if isinstance(batch, Mapping):
init_data = {
name: {k: v[i] for k, v in batch.items()}
for i, name in enumerate(names)
}
else:
init_data = {name: item for item, name in zip(batch, names)}
minibatch = MiniBatch()
for name, item in init_data.items():
if not hasattr(minibatch, name):
dgl_warning(
f"Unknown item name '{name}' is detected and added into "
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
)
setattr(minibatch, name, item)
return minibatch
class ItemSampler(IterDataPipe):
......@@ -32,6 +89,8 @@ class ItemSampler(IterDataPipe):
Data to be sampled.
batch_size : int
The size of each batch.
minibatcher : Optional[Callable]
A callable that takes in a list of items and returns a `MiniBatch`.
drop_last : bool
Option to drop the last batch if it's not full.
shuffle : bool
......@@ -42,41 +101,68 @@ class ItemSampler(IterDataPipe):
1. Node IDs.
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(item_sampler)
[tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
2. Node pairs.
>>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(item_sampler)
[[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
[ 4, 5],
[ 6, 7],
[10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 15)),
... names=("node_pairs", "labels")
... )
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
[tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]
4. Head, tail and negative tails
>>> heads = torch.arange(0, 5)
>>> tails = torch.arange(5, 10)
>>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
>>> item_set = gb.ItemSet((heads, tails, negative_tails))
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]),
tensor([[1, 2], [2, 3], [3, 4]])],
[tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
[4, 5],
[0, 1],
[6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
4. Node pairs and negative destinations.
>>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
>>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
[ 6, 7],
[ 2, 3],
[ 8, 9]]), labels=None, negative_srcs=None,
negative_dsts=tensor([[20, 21],
[16, 17],
[12, 13],
[18, 19]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
5. DGLGraphs.
>>> import dgl
......@@ -103,81 +189,96 @@ class ItemSampler(IterDataPipe):
7. Heterogeneous node IDs.
>>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5)),
... "item": gb.ItemSet(torch.arange(0, 6)),
... "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
... "item": gb.ItemSet(torch.arange(0, 6), names="seed_nodes"),
... }
>>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{'user': tensor([0, 1, 2, 3])},
{'item': tensor([0, 1, 2]), 'user': tensor([4])},
{'item': tensor([3, 4, 5])}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
8. Heterogeneous node pairs.
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(node_pairs_like),
... "user:follow:user": gb.ItemSet(node_pairs_follow),
... "user:like:item": gb.ItemSet(
... node_pairs_like, names="node_pairs"),
... "user:follow:user": gb.ItemSet(
... node_pairs_follow, names="node_pairs"),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
9. Heterogeneous node pairs and labels.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.arange(0, 10)
>>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow = torch.arange(10, 20)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
... names=("node_pairs", "labels")),
... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
... names=("node_pairs", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item":
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([4])],
"user:follow:user":
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]},
{"user:follow:user":
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}]
10. Heterogeneous head, tail and negative tails.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5),
... torch.arange(5, 15).reshape(-1, 2))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2))
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels={'user:like:item': tensor([0, 1, 2, 3])},
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
10. Heterogeneous node pairs and negative destinations.
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
>>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet((node_pairs_follow,
... negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8]),
tensor([[12, 13], [14, 15], [16, 17]])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
tensor([[18, 19], [20, 21], [22, 23]])]}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None,
negative_dsts={'user:like:item': tensor([[10, 11],
[12, 13],
[14, 15],
[16, 17]])}, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
"""
def __init__(
self,
item_set: ItemSet or ItemSetDict,
batch_size: int,
minibatcher: Optional[Callable] = minibatcher_default,
drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False,
) -> None:
super().__init__()
self._item_set = item_set
self._batch_size = batch_size
self._minibatcher = minibatcher
self._drop_last = drop_last
self._shuffle = shuffle
......@@ -217,4 +318,9 @@ class ItemSampler(IterDataPipe):
data_pipe = data_pipe.collate(collate_fn=partial(_collate))
# Map to minibatch.
data_pipe = data_pipe.map(
partial(self._minibatcher, names=self._item_set.names)
)
return iter(data_pipe)
......@@ -10,7 +10,7 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
dp = gb.ItemSampler(torch.randn(20), 4)
dp = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
dp = gb.CopyTo(dp, "cuda")
for data in dp:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment