Unverified Commit 29949322 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] convert item list to MiniBatch (#6281)

parent dadce86a
...@@ -2,16 +2,73 @@ ...@@ -2,16 +2,73 @@
from collections.abc import Mapping from collections.abc import Mapping
from functools import partial from functools import partial
from typing import Iterator, Optional from typing import Callable, Iterator, Optional
from torch.utils.data import default_collate from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..base import dgl_warning
from ..batch import batch as dgl_batch from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict from .itemset import ItemSet, ItemSetDict
from .minibatch import MiniBatch
__all__ = ["ItemSampler", "minibatcher_default"]
def minibatcher_default(batch, names):
    """Default minibatcher which maps a list of items to a `MiniBatch`.

    The produced `MiniBatch` carries the batched items under the attribute
    names given in `names`, which are supposed to be provided and align with
    the data attributes of `MiniBatch`. If an unknown item name is provided,
    a warning is raised and the item is still attached to the `MiniBatch`.
    If the names of items are not provided, the item list is returned as is
    and a warning is raised.

    Parameters
    ----------
    batch : list
        List of items.
    names : Tuple[str] or None
        Names of items in `batch` with same length. The order should align
        with `batch`.

    Returns
    -------
    MiniBatch
        A minibatch.
    """
    if names is None:
        # Without names we cannot tell which `MiniBatch` attribute each item
        # belongs to, so fall back to returning the raw batch untouched.
        dgl_warning(
            "Failed to map item list to `MiniBatch` as the names of items are "
            "not provided. Please provide a customized `MiniBatcher`. "
            "The item list is returned as is."
        )
        return batch
    if len(names) == 1:
        # Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =
        # ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor
        # instead of the batch.
        init_data = {names[0]: batch}
    else:
        if isinstance(batch, Mapping):
            # Heterogeneous case: `batch` maps a type key to a sequence of
            # items; regroup so each name maps to a {type: item} dict.
            init_data = {
                name: {k: v[i] for k, v in batch.items()}
                for i, name in enumerate(names)
            }
        else:
            init_data = {name: item for item, name in zip(batch, names)}
    minibatch = MiniBatch()
    for name, item in init_data.items():
        if not hasattr(minibatch, name):
            # Unknown attribute name: warn, but still attach the item so no
            # data is silently dropped.
            dgl_warning(
                f"Unknown item name '{name}' is detected and added into "
                "`MiniBatch`. You probably need to provide a customized "
                "`MiniBatcher`."
            )
        setattr(minibatch, name, item)
    return minibatch
class ItemSampler(IterDataPipe): class ItemSampler(IterDataPipe):
...@@ -32,6 +89,8 @@ class ItemSampler(IterDataPipe): ...@@ -32,6 +89,8 @@ class ItemSampler(IterDataPipe):
Data to be sampled. Data to be sampled.
batch_size : int batch_size : int
The size of each batch. The size of each batch.
minibatcher : Optional[Callable]
A callable that takes in a list of items and returns a `MiniBatch`.
drop_last : bool drop_last : bool
Option to drop the last batch if it's not full. Option to drop the last batch if it's not full.
shuffle : bool shuffle : bool
...@@ -42,41 +101,68 @@ class ItemSampler(IterDataPipe): ...@@ -42,41 +101,68 @@ class ItemSampler(IterDataPipe):
1. Node IDs. 1. Node IDs.
>>> import torch >>> import torch
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10)) >>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False ... item_set, batch_size=4, shuffle=True, drop_last=False
... ) ... )
>>> list(item_sampler) >>> next(iter(item_sampler))
[tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])] MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
2. Node pairs. 2. Node pairs.
>>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20))) >>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False ... item_set, batch_size=4, shuffle=True, drop_last=False
... ) ... )
>>> list(item_sampler) >>> next(iter(item_sampler))
[[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]), MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])] [ 4, 5],
[ 6, 7],
[10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
3. Node pairs and labels. 3. Node pairs and labels.
>>> item_set = gb.ItemSet( >>> item_set = gb.ItemSet(
... (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15)) ... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 15)),
... names=("node_pairs", "labels")
... ) ... )
>>> item_sampler = gb.ItemSampler(item_set, 3) >>> item_sampler = gb.ItemSampler(
>>> list(item_sampler) ... item_set, batch_size=4, shuffle=True, drop_last=False
[[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])], ... )
[tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]] >>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
4. Head, tail and negative tails [4, 5],
>>> heads = torch.arange(0, 5) [0, 1],
>>> tails = torch.arange(5, 10) [6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
>>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1) negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
>>> item_set = gb.ItemSet((heads, tails, negative_tails)) node_features=None, edge_features=None, compacted_node_pairs=None,
>>> item_sampler = gb.ItemSampler(item_set, 3) compacted_negative_srcs=None, compacted_negative_dsts=None)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]), 4. Node pairs and negative destinations.
tensor([[1, 2], [2, 3], [3, 4]])], >>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
[tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]] >>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
[ 6, 7],
[ 2, 3],
[ 8, 9]]), labels=None, negative_srcs=None,
negative_dsts=tensor([[20, 21],
[16, 17],
[12, 13],
[18, 19]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
5. DGLGraphs. 5. DGLGraphs.
>>> import dgl >>> import dgl
...@@ -103,81 +189,96 @@ class ItemSampler(IterDataPipe): ...@@ -103,81 +189,96 @@ class ItemSampler(IterDataPipe):
7. Heterogeneous node IDs. 7. Heterogeneous node IDs.
>>> ids = { >>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5)), ... "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
... "item": gb.ItemSet(torch.arange(0, 6)), ... "item": gb.ItemSet(torch.arange(0, 6), names="seed_nodes"),
... } ... }
>>> item_set = gb.ItemSetDict(ids) >>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, 4) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> list(item_sampler) >>> next(iter(item_sampler))
[{'user': tensor([0, 1, 2, 3])}, MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
{'item': tensor([0, 1, 2]), 'user': tensor([4])}, labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
{'item': tensor([3, 4, 5])}] input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
8. Heterogeneous node pairs. 8. Heterogeneous node pairs.
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5)) >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12)) >>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(node_pairs_like), ... "user:like:item": gb.ItemSet(
... "user:follow:user": gb.ItemSet(node_pairs_follow), ... node_pairs_like, names="node_pairs"),
... "user:follow:user": gb.ItemSet(
... node_pairs_follow, names="node_pairs"),
... }) ... })
>>> item_sampler = gb.ItemSampler(item_set, 4) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> list(item_sampler) >>> next(iter(item_sampler))
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]}, MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
{"user:like:item": [tensor([4]), tensor([4])], [2, 3],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]}, [4, 5],
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}] [6, 7]])}, labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
9. Heterogeneous node pairs and labels. 9. Heterogeneous node pairs and labels.
>>> like = ( >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5)) >>> labels_like = torch.arange(0, 10)
>>> follow = ( >>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6)) >>> labels_follow = torch.arange(10, 20)
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like), ... "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
... "user:follow:user": gb.ItemSet(follow), ... names=("node_pairs", "labels")),
... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
... names=("node_pairs", "labels")),
... }) ... })
>>> item_sampler = gb.ItemSampler(item_set, 4) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> list(item_sampler) >>> next(iter(item_sampler))
[{"user:like:item": MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]}, [2, 3],
{"user:like:item": [tensor([4]), tensor([4]), tensor([4])], [4, 5],
"user:follow:user": [6, 7]])}, labels={'user:like:item': tensor([0, 1, 2, 3])},
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]}, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
{"user:follow:user": input_nodes=None, node_features=None, edge_features=None,
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}] compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
10. Heterogeneous head, tail and negative tails.
>>> like = ( 10. Heterogeneous node pairs and negative destinations.
... torch.arange(0, 5), torch.arange(0, 5), >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
... torch.arange(5, 15).reshape(-1, 2)) >>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
>>> follow = ( >>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
... torch.arange(0, 6), torch.arange(6, 12), >>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
... torch.arange(12, 24).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like), ... "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
... "user:follow:user": gb.ItemSet(follow), ... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet((node_pairs_follow,
... negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... }) ... })
>>> item_sampler = gb.ItemSampler(item_set, 4) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> list(item_sampler) >>> next(iter(item_sampler))
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]}, [2, 3],
{"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])], [4, 5],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8]), [6, 7]])}, labels=None, negative_srcs=None,
tensor([[12, 13], [14, 15], [16, 17]])]}, negative_dsts={'user:like:item': tensor([[10, 11],
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11]), [12, 13],
tensor([[18, 19], [20, 21], [22, 23]])]}] [14, 15],
[16, 17]])}, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
""" """
# NOTE(review): the `ItemSet or ItemSetDict` annotation evaluates to just
# `ItemSet` at class-definition time; `Union[ItemSet, ItemSetDict]` looks
# intended — confirm before changing.
def __init__(
    self,
    item_set: ItemSet or ItemSetDict,
    batch_size: int,
    minibatcher: Optional[Callable] = minibatcher_default,
    drop_last: Optional[bool] = False,
    shuffle: Optional[bool] = False,
) -> None:
    """Store the sampling configuration on the datapipe.

    Parameters
    ----------
    item_set : ItemSet or ItemSetDict
        Data to be sampled.
    batch_size : int
        The size of each batch.
    minibatcher : Optional[Callable]
        A callable that takes in a list of items and returns a `MiniBatch`.
        Defaults to `minibatcher_default`.
    drop_last : Optional[bool]
        Option to drop the last batch if it's not full.
    shuffle : Optional[bool]
        Presumably shuffles the item order before batching — description is
        cut off in this view; see the class docstring.
    """
    super().__init__()
    self._item_set = item_set
    self._batch_size = batch_size
    self._minibatcher = minibatcher
    self._drop_last = drop_last
    self._shuffle = shuffle
...@@ -217,4 +318,9 @@ class ItemSampler(IterDataPipe): ...@@ -217,4 +318,9 @@ class ItemSampler(IterDataPipe):
data_pipe = data_pipe.collate(collate_fn=partial(_collate)) data_pipe = data_pipe.collate(collate_fn=partial(_collate))
# Map to minibatch.
data_pipe = data_pipe.map(
partial(self._minibatcher, names=self._item_set.names)
)
return iter(data_pipe) return iter(data_pipe)
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test") @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo(): def test_CopyTo():
dp = gb.ItemSampler(torch.randn(20), 4) dp = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
dp = gb.CopyTo(dp, "cuda") dp = gb.CopyTo(dp, "cuda")
for data in dp: for data in dp:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment