Unverified Commit 29949322 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] convert item list to MiniBatch (#6281)

parent dadce86a
......@@ -2,16 +2,73 @@
from collections.abc import Mapping
from functools import partial
from typing import Iterator, Optional
from typing import Callable, Iterator, Optional
from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..base import dgl_warning
from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict
from .minibatch import MiniBatch
__all__ = ["ItemSampler", "minibatcher_default"]
def minibatcher_default(batch, names):
"""Default minibatcher.
The default minibatcher maps a list of items to a `MiniBatch` with the
same names as the items. The names of items are supposed to be provided
and align with the data attributes of `MiniBatch`. If any unknown item name
is provided, exception will be raised. If the names of items are not
provided, the item list is returned as is and a warning will be raised.
Parameters
----------
batch : list
List of items.
names : Tuple[str] or None
Names of items in `batch` with same length. The order should align
with `batch`.
__all__ = ["ItemSampler"]
Returns
-------
MiniBatch
A minibatch.
"""
if names is None:
dgl_warning(
"Failed to map item list to `MiniBatch` as the names of items are "
"not provided. Please provide a customized `MiniBatcher`. "
"The item list is returned as is."
)
return batch
if len(names) == 1:
# Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =
# ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor
# instead of the batch.
init_data = {names[0]: batch}
else:
if isinstance(batch, Mapping):
init_data = {
name: {k: v[i] for k, v in batch.items()}
for i, name in enumerate(names)
}
else:
init_data = {name: item for item, name in zip(batch, names)}
minibatch = MiniBatch()
for name, item in init_data.items():
if not hasattr(minibatch, name):
dgl_warning(
f"Unknown item name '{name}' is detected and added into "
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
)
setattr(minibatch, name, item)
return minibatch
class ItemSampler(IterDataPipe):
......@@ -32,6 +89,8 @@ class ItemSampler(IterDataPipe):
Data to be sampled.
batch_size : int
The size of each batch.
minibatcher : Optional[Callable]
A callable that takes in a list of items and returns a `MiniBatch`.
drop_last : bool
Option to drop the last batch if it's not full.
shuffle : bool
......@@ -42,41 +101,68 @@ class ItemSampler(IterDataPipe):
1. Node IDs.
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(item_sampler)
[tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
2. Node pairs.
>>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(item_sampler)
[[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
[ 4, 5],
[ 6, 7],
[10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 15)),
... names=("node_pairs", "labels")
... )
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
[tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]
4. Head, tail and negative tails
>>> heads = torch.arange(0, 5)
>>> tails = torch.arange(5, 10)
>>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
>>> item_set = gb.ItemSet((heads, tails, negative_tails))
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]),
tensor([[1, 2], [2, 3], [3, 4]])],
[tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
[4, 5],
[0, 1],
[6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
4. Node pairs and negative destinations.
>>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
>>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
[ 6, 7],
[ 2, 3],
[ 8, 9]]), labels=None, negative_srcs=None,
negative_dsts=tensor([[20, 21],
[16, 17],
[12, 13],
[18, 19]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
5. DGLGraphs.
>>> import dgl
......@@ -103,81 +189,96 @@ class ItemSampler(IterDataPipe):
7. Heterogeneous node IDs.
>>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5)),
... "item": gb.ItemSet(torch.arange(0, 6)),
... "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
... "item": gb.ItemSet(torch.arange(0, 6), names="seed_nodes"),
... }
>>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{'user': tensor([0, 1, 2, 3])},
{'item': tensor([0, 1, 2]), 'user': tensor([4])},
{'item': tensor([3, 4, 5])}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
8. Heterogeneous node pairs.
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(node_pairs_like),
... "user:follow:user": gb.ItemSet(node_pairs_follow),
... "user:like:item": gb.ItemSet(
... node_pairs_like, names="node_pairs"),
... "user:follow:user": gb.ItemSet(
... node_pairs_follow, names="node_pairs"),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
9. Heterogeneous node pairs and labels.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.arange(0, 10)
>>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow = torch.arange(10, 20)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
... names=("node_pairs", "labels")),
... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
... names=("node_pairs", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item":
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([4])],
"user:follow:user":
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]},
{"user:follow:user":
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}]
10. Heterogeneous head, tail and negative tails.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5),
... torch.arange(5, 15).reshape(-1, 2))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2))
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels={'user:like:item': tensor([0, 1, 2, 3])},
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
10. Heterogeneous node pairs and negative destinations.
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
>>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet((node_pairs_follow,
... negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8]),
tensor([[12, 13], [14, 15], [16, 17]])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
tensor([[18, 19], [20, 21], [22, 23]])]}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None,
negative_dsts={'user:like:item': tensor([[10, 11],
[12, 13],
[14, 15],
[16, 17]])}, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
"""
def __init__(
self,
item_set: ItemSet or ItemSetDict,
batch_size: int,
minibatcher: Optional[Callable] = minibatcher_default,
drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False,
) -> None:
super().__init__()
self._item_set = item_set
self._batch_size = batch_size
self._minibatcher = minibatcher
self._drop_last = drop_last
self._shuffle = shuffle
......@@ -217,4 +318,9 @@ class ItemSampler(IterDataPipe):
data_pipe = data_pipe.collate(collate_fn=partial(_collate))
# Map to minibatch.
data_pipe = data_pipe.map(
partial(self._minibatcher, names=self._item_set.names)
)
return iter(data_pipe)
......@@ -10,7 +10,7 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo():
dp = gb.ItemSampler(torch.randn(20), 4)
dp = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
dp = gb.CopyTo(dp, "cuda")
for data in dp:
......
import re
import dgl
import pytest
import torch
......@@ -5,31 +7,126 @@ from dgl import graphbolt as gb
from torch.testing import assert_close
def test_ItemSampler_minibatcher():
# Default minibatcher is used if not specified.
# Warning message is raised if names are not specified.
item_set = gb.ItemSet(torch.arange(0, 10))
item_sampler = gb.ItemSampler(item_set, batch_size=4)
with pytest.warns(
UserWarning,
match=re.escape(
"Failed to map item list to `MiniBatch` as the names of items are "
"not provided. Please provide a customized `MiniBatcher`. The "
"item list is returned as is."
),
):
minibatch = next(iter(item_sampler))
assert not isinstance(minibatch, gb.MiniBatch)
# Default minibatcher is used if not specified.
# Warning message is raised if unrecognized names are specified.
item_set = gb.ItemSet(torch.arange(0, 10), names="unknown_name")
item_sampler = gb.ItemSampler(item_set, batch_size=4)
with pytest.warns(
UserWarning,
match=re.escape(
"Unknown item name 'unknown_name' is detected and added into "
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
),
):
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.unknown_name is not None
# Default minibatcher is used if not specified.
# `MiniBatch` is returned if expected names are specified.
item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
item_sampler = gb.ItemSampler(item_set, batch_size=4)
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert len(minibatch.seed_nodes) == 4
# Customized minibatcher is used if specified.
def minibatcher(batch, names):
return gb.MiniBatch(seed_nodes=batch)
item_sampler = gb.ItemSampler(
item_set, batch_size=4, minibatcher=minibatcher
)
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert len(minibatch.seed_nodes) == 4
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_ids(batch_size, shuffle, drop_last):
def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
item_set = gb.ItemSet(torch.arange(0, num_ids))
seed_nodes = torch.arange(0, num_ids)
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch) == batch_size
assert len(minibatch.seed_nodes) == batch_size
else:
if not drop_last:
assert len(minibatch) == num_ids % batch_size
assert len(minibatch.seed_nodes) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch)
minibatch_ids.append(minibatch.seed_nodes)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
seed_nodes = torch.arange(0, num_ids)
labels = torch.arange(0, num_ids)
item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
minibatch_labels = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.labels is not None
assert len(minibatch.seed_nodes) == len(minibatch.labels)
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size
else:
if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch.seed_nodes)
minibatch_labels.append(minibatch.labels)
minibatch_ids = torch.cat(minibatch_ids)
minibatch_labels = torch.cat(minibatch_labels)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
assert (
torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]) is not shuffle
)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
......@@ -77,14 +174,18 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last):
def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
item_set = gb.ItemSet(node_pairs)
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
for i, (src, dst) in enumerate(item_sampler):
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert minibatch.labels is None
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -96,7 +197,7 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
# Verify src and dst IDs match.
assert torch.equal(src + num_ids, dst)
assert torch.equal(src + 1, dst)
# Archive batch.
src_ids.append(src)
dst_ids.append(dst)
......@@ -112,16 +213,22 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2))
labels = torch.arange(0, num_ids)
item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
labels = node_pairs[:, 0]
item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
src_ids = []
dst_ids = []
labels = []
for i, (src, dst, label) in enumerate(item_sampler):
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert minibatch.labels is not None
assert len(minibatch.node_pairs) == len(minibatch.labels)
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
label = minibatch.labels
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -134,7 +241,7 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
assert len(dst) == expected_batch_size
assert len(label) == expected_batch_size
# Verify src/dst IDs and labels match.
assert torch.equal(src + num_ids, dst)
assert torch.equal(src + 1, dst)
assert torch.equal(src, label)
# Archive batch.
src_ids.append(src)
......@@ -151,25 +258,29 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
# Head, tail and negative tails.
def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
# Node pairs and negative destinations.
num_ids = 103
num_negs = 2
heads = torch.arange(0, num_ids)
tails = torch.arange(num_ids, num_ids * 2)
neg_tails = torch.stack((heads + 1, heads + 2), dim=-1)
item_set = gb.ItemSet((heads, tails, neg_tails))
for i, (head, tail, negs) in enumerate(item_set):
assert heads[i] == head
assert tails[i] == tail
assert torch.equal(neg_tails[i], negs)
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
neg_dsts = torch.arange(
2 * num_ids, 2 * num_ids + num_ids * num_negs
).reshape(-1, num_negs)
item_set = gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
head_ids = []
tail_ids = []
src_ids = []
dst_ids = []
negs_ids = []
for i, (head, tail, negs) in enumerate(item_sampler):
for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert minibatch.negative_dsts is not None
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
negs = minibatch.negative_dsts
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -178,24 +289,23 @@ def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
expected_batch_size = num_ids % batch_size
else:
assert False
assert len(head) == expected_batch_size
assert len(tail) == expected_batch_size
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert negs.dim() == 2
assert negs.shape[0] == expected_batch_size
assert negs.shape[1] == num_negs
# Verify head/tail and negatie tails match.
assert torch.equal(head + num_ids, tail)
assert torch.equal(head + 1, negs[:, 0])
assert torch.equal(head + 2, negs[:, 1])
# Verify node pairs and negative destinations.
assert torch.equal(src + 1, dst)
assert torch.equal(negs[:, 0] + 1, negs[:, 1])
# Archive batch.
head_ids.append(head)
tail_ids.append(tail)
src_ids.append(src)
dst_ids.append(dst)
negs_ids.append(negs)
head_ids = torch.cat(head_ids)
tail_ids = torch.cat(tail_ids)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids)
assert torch.all(head_ids[:-1] <= head_ids[1:]) is not shuffle
assert torch.all(tail_ids[:-1] <= tail_ids[1:]) is not shuffle
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1, 0] <= negs_ids[1:, 0]) is not shuffle
assert torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1]) is not shuffle
......@@ -215,12 +325,57 @@ def test_append_with_other_datapipes():
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 205
ids = {
"user": gb.ItemSet(torch.arange(0, 99), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(99, num_ids), names="seed_nodes"),
}
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
ids = []
for _, v in minibatch.seed_nodes.items():
ids.append(v)
ids = torch.cat(ids)
assert len(ids) == expected_batch_size
minibatch_ids.append(ids)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 205
ids = {
"user": gb.ItemSet(torch.arange(0, 99)),
"item": gb.ItemSet(torch.arange(99, num_ids)),
"user": gb.ItemSet(
(torch.arange(0, 99), torch.arange(0, 99)),
names=("seed_nodes", "labels"),
),
"item": gb.ItemSet(
(torch.arange(99, num_ids), torch.arange(99, num_ids)),
names=("seed_nodes", "labels"),
),
}
chained_ids = []
for key, value in ids.items():
......@@ -230,7 +385,11 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, batch in enumerate(item_sampler):
minibatch_labels = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.labels is not None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -239,15 +398,24 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
expected_batch_size = num_ids % batch_size
else:
assert False
assert isinstance(batch, dict)
ids = []
for _, v in batch.items():
for _, v in minibatch.seed_nodes.items():
ids.append(v)
ids = torch.cat(ids)
assert len(ids) == expected_batch_size
minibatch_ids.append(ids)
labels = []
for _, v in minibatch.labels.items():
labels.append(v)
labels = torch.cat(labels)
assert len(labels) == expected_batch_size
minibatch_labels.append(labels)
minibatch_ids = torch.cat(minibatch_ids)
minibatch_labels = torch.cat(minibatch_labels)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
assert (
torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]) is not shuffle
)
@pytest.mark.parametrize("batch_size", [1, 4])
......@@ -256,18 +424,12 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
total_ids = 2 * num_ids
node_pairs_0 = (
torch.arange(0, num_ids),
torch.arange(num_ids, num_ids * 2),
)
node_pairs_1 = (
torch.arange(num_ids * 2, num_ids * 3),
torch.arange(num_ids * 3, num_ids * 4),
)
total_pairs = 2 * num_ids
node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
node_pairs_dict = {
"user:like:item": gb.ItemSet(node_pairs_0),
"user:follow:user": gb.ItemSet(node_pairs_1),
"user:like:item": gb.ItemSet(node_pairs_like, names="node_pairs"),
"user:follow:user": gb.ItemSet(node_pairs_follow, names="node_pairs"),
}
item_set = gb.ItemSetDict(node_pairs_dict)
item_sampler = gb.ItemSampler(
......@@ -275,27 +437,30 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
)
src_ids = []
dst_ids = []
for i, batch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= total_pairs
if not is_last or total_pairs % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = total_ids % batch_size
expected_batch_size = total_pairs % batch_size
else:
assert False
src = []
dst = []
for _, (v_src, v_dst) in batch.items():
src.append(v_src)
dst.append(v_dst)
for _, node_pairs in minibatch.node_pairs.items():
src.append(node_pairs[:, 0])
dst.append(node_pairs[:, 1])
src = torch.cat(src)
dst = torch.cat(dst)
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
src_ids.append(src)
dst_ids.append(dst)
assert torch.equal(src + num_ids, dst)
assert torch.equal(src + 1, dst)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
......@@ -309,21 +474,17 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
total_ids = 2 * num_ids
node_pairs_0 = (
torch.arange(0, num_ids),
torch.arange(num_ids, num_ids * 2),
)
node_pairs_1 = (
torch.arange(num_ids * 2, num_ids * 3),
torch.arange(num_ids * 3, num_ids * 4),
)
node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
labels = torch.arange(0, num_ids)
node_pairs_dict = {
"user:like:item": gb.ItemSet(
(node_pairs_0[0], node_pairs_0[1], labels)
(node_pairs_like, node_pairs_like[:, 0]),
names=("node_pairs", "labels"),
),
"user:follow:user": gb.ItemSet(
(node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2)
(node_pairs_follow, node_pairs_follow[:, 0]),
names=("node_pairs", "labels"),
),
}
item_set = gb.ItemSetDict(node_pairs_dict)
......@@ -333,7 +494,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src_ids = []
dst_ids = []
labels = []
for i, batch in enumerate(item_sampler):
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None
assert minibatch.labels is not None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -345,9 +509,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src = []
dst = []
label = []
for _, (v_src, v_dst, v_label) in batch.items():
src.append(v_src)
dst.append(v_dst)
for _, node_pairs in minibatch.node_pairs.items():
src.append(node_pairs[:, 0])
dst.append(node_pairs[:, 1])
for _, v_label in minibatch.labels.items():
label.append(v_label)
src = torch.cat(src)
dst = torch.cat(dst)
......@@ -358,7 +523,7 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src_ids.append(src)
dst_ids.append(dst)
labels.append(label)
assert torch.equal(src + num_ids, dst)
assert torch.equal(src + 1, dst)
assert torch.equal(src, label)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
......@@ -371,26 +536,40 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
# Head, tail and negative tails.
num_ids = 103
total_ids = 2 * num_ids
num_negs = 2
heads = torch.arange(0, num_ids)
tails = torch.arange(num_ids, num_ids * 2)
neg_tails = torch.stack((heads + 1, heads + 2), dim=-1)
node_paris_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
neg_dsts_like = torch.arange(
num_ids * 4, num_ids * 4 + num_ids * num_negs
).reshape(-1, num_negs)
neg_dsts_follow = torch.arange(
num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2
).reshape(-1, num_negs)
data_dict = {
"user:like:item": gb.ItemSet((heads, tails, neg_tails)),
"user:follow:user": gb.ItemSet((heads, tails, neg_tails)),
"user:like:item": gb.ItemSet(
(node_paris_like, neg_dsts_like),
names=("node_pairs", "negative_dsts"),
),
"user:follow:user": gb.ItemSet(
(node_pairs_follow, neg_dsts_follow),
names=("node_pairs", "negative_dsts"),
),
}
item_set = gb.ItemSetDict(data_dict)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
head_ids = []
tail_ids = []
src_ids = []
dst_ids = []
negs_ids = []
for i, batch in enumerate(item_sampler):
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None
assert minibatch.negative_dsts is not None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -399,31 +578,31 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
expected_batch_size = total_ids % batch_size
else:
assert False
head = []
tail = []
src = []
dst = []
negs = []
for _, (v_head, v_tail, v_negs) in batch.items():
head.append(v_head)
tail.append(v_tail)
for _, node_pairs in minibatch.node_pairs.items():
src.append(node_pairs[:, 0])
dst.append(node_pairs[:, 1])
for _, v_negs in minibatch.negative_dsts.items():
negs.append(v_negs)
head = torch.cat(head)
tail = torch.cat(tail)
src = torch.cat(src)
dst = torch.cat(dst)
negs = torch.cat(negs)
assert len(head) == expected_batch_size
assert len(tail) == expected_batch_size
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert len(negs) == expected_batch_size
head_ids.append(head)
tail_ids.append(tail)
src_ids.append(src)
dst_ids.append(dst)
negs_ids.append(negs)
assert negs.dim() == 2
assert negs.shape[0] == expected_batch_size
assert negs.shape[1] == num_negs
assert torch.equal(head + num_ids, tail)
assert torch.equal(head + 1, negs[:, 0])
assert torch.equal(head + 2, negs[:, 1])
head_ids = torch.cat(head_ids)
tail_ids = torch.cat(tail_ids)
assert torch.equal(src + 1, dst)
assert torch.equal(negs[:, 0] + 1, negs[:, 1])
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids)
assert torch.all(head_ids[:-1] <= head_ids[1:]) is not shuffle
assert torch.all(tail_ids[:-1] <= tail_ids[1:]) is not shuffle
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment