"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "592dae31c826cde394f721eac64ce5d4748f4ef0"
Unverified Commit 29949322 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] convert item list to MiniBatch (#6281)

parent dadce86a
...@@ -2,16 +2,73 @@ ...@@ -2,16 +2,73 @@
from collections.abc import Mapping from collections.abc import Mapping
from functools import partial from functools import partial
from typing import Iterator, Optional from typing import Callable, Iterator, Optional
from torch.utils.data import default_collate from torch.utils.data import default_collate
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from ..base import dgl_warning
from ..batch import batch as dgl_batch from ..batch import batch as dgl_batch
from ..heterograph import DGLGraph from ..heterograph import DGLGraph
from .itemset import ItemSet, ItemSetDict from .itemset import ItemSet, ItemSetDict
from .minibatch import MiniBatch
__all__ = ["ItemSampler", "minibatcher_default"]
__all__ = ["ItemSampler"] def minibatcher_default(batch, names):
"""Default minibatcher.
The default minibatcher maps a list of items to a `MiniBatch` with the
same names as the items. The names of items are supposed to be provided
and align with the data attributes of `MiniBatch`. If any unknown item name
is provided, exception will be raised. If the names of items are not
provided, the item list is returned as is and a warning will be raised.
Parameters
----------
batch : list
List of items.
names : Tuple[str] or None
Names of items in `batch` with same length. The order should align
with `batch`.
Returns
-------
MiniBatch
A minibatch.
"""
if names is None:
dgl_warning(
"Failed to map item list to `MiniBatch` as the names of items are "
"not provided. Please provide a customized `MiniBatcher`. "
"The item list is returned as is."
)
return batch
if len(names) == 1:
# Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =
# ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor
# instead of the batch.
init_data = {names[0]: batch}
else:
if isinstance(batch, Mapping):
init_data = {
name: {k: v[i] for k, v in batch.items()}
for i, name in enumerate(names)
}
else:
init_data = {name: item for item, name in zip(batch, names)}
minibatch = MiniBatch()
for name, item in init_data.items():
if not hasattr(minibatch, name):
dgl_warning(
f"Unknown item name '{name}' is detected and added into "
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
)
setattr(minibatch, name, item)
return minibatch
class ItemSampler(IterDataPipe): class ItemSampler(IterDataPipe):
...@@ -32,6 +89,8 @@ class ItemSampler(IterDataPipe): ...@@ -32,6 +89,8 @@ class ItemSampler(IterDataPipe):
Data to be sampled. Data to be sampled.
batch_size : int batch_size : int
The size of each batch. The size of each batch.
minibatcher : Optional[Callable]
A callable that takes in a list of items and returns a `MiniBatch`.
drop_last : bool drop_last : bool
Option to drop the last batch if it's not full. Option to drop the last batch if it's not full.
shuffle : bool shuffle : bool
...@@ -42,41 +101,68 @@ class ItemSampler(IterDataPipe): ...@@ -42,41 +101,68 @@ class ItemSampler(IterDataPipe):
1. Node IDs. 1. Node IDs.
>>> import torch >>> import torch
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10)) >>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False ... item_set, batch_size=4, shuffle=True, drop_last=False
... ) ... )
>>> list(item_sampler) >>> next(iter(item_sampler))
[tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])] MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
2. Node pairs. 2. Node pairs.
>>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20))) >>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False ... item_set, batch_size=4, shuffle=True, drop_last=False
... ) ... )
>>> list(item_sampler) >>> next(iter(item_sampler))
[[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]), MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])] [ 4, 5],
[ 6, 7],
[10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
3. Node pairs and labels. 3. Node pairs and labels.
>>> item_set = gb.ItemSet( >>> item_set = gb.ItemSet(
... (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15)) ... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 15)),
... names=("node_pairs", "labels")
... ) ... )
>>> item_sampler = gb.ItemSampler(item_set, 3) >>> item_sampler = gb.ItemSampler(
>>> list(item_sampler) ... item_set, batch_size=4, shuffle=True, drop_last=False
[[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])], ... )
[tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]] >>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
4. Head, tail and negative tails [4, 5],
>>> heads = torch.arange(0, 5) [0, 1],
>>> tails = torch.arange(5, 10) [6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
>>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1) negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
>>> item_set = gb.ItemSet((heads, tails, negative_tails)) node_features=None, edge_features=None, compacted_node_pairs=None,
>>> item_sampler = gb.ItemSampler(item_set, 3) compacted_negative_srcs=None, compacted_negative_dsts=None)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]), 4. Node pairs and negative destinations.
tensor([[1, 2], [2, 3], [3, 4]])], >>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
[tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]] >>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
[ 6, 7],
[ 2, 3],
[ 8, 9]]), labels=None, negative_srcs=None,
negative_dsts=tensor([[20, 21],
[16, 17],
[12, 13],
[18, 19]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
5. DGLGraphs. 5. DGLGraphs.
>>> import dgl >>> import dgl
...@@ -103,81 +189,96 @@ class ItemSampler(IterDataPipe): ...@@ -103,81 +189,96 @@ class ItemSampler(IterDataPipe):
7. Heterogeneous node IDs. 7. Heterogeneous node IDs.
>>> ids = { >>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5)), ... "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
... "item": gb.ItemSet(torch.arange(0, 6)), ... "item": gb.ItemSet(torch.arange(0, 6), names="seed_nodes"),
... } ... }
>>> item_set = gb.ItemSetDict(ids) >>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, 4) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> list(item_sampler) >>> next(iter(item_sampler))
[{'user': tensor([0, 1, 2, 3])}, MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
{'item': tensor([0, 1, 2]), 'user': tensor([4])}, labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
{'item': tensor([3, 4, 5])}] input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
8. Heterogeneous node pairs. 8. Heterogeneous node pairs.
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5)) >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12)) >>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(node_pairs_like), ... "user:like:item": gb.ItemSet(
... "user:follow:user": gb.ItemSet(node_pairs_follow), ... node_pairs_like, names="node_pairs"),
... "user:follow:user": gb.ItemSet(
... node_pairs_follow, names="node_pairs"),
... }) ... })
>>> item_sampler = gb.ItemSampler(item_set, 4) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> list(item_sampler) >>> next(iter(item_sampler))
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]}, MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
{"user:like:item": [tensor([4]), tensor([4])], [2, 3],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]}, [4, 5],
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}] [6, 7]])}, labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
9. Heterogeneous node pairs and labels. 9. Heterogeneous node pairs and labels.
>>> like = ( >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5)) >>> labels_like = torch.arange(0, 10)
>>> follow = ( >>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6)) >>> labels_follow = torch.arange(10, 20)
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like), ... "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
... "user:follow:user": gb.ItemSet(follow), ... names=("node_pairs", "labels")),
... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
... names=("node_pairs", "labels")),
... }) ... })
>>> item_sampler = gb.ItemSampler(item_set, 4) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> list(item_sampler) >>> next(iter(item_sampler))
[{"user:like:item": MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]}, [2, 3],
{"user:like:item": [tensor([4]), tensor([4]), tensor([4])], [4, 5],
"user:follow:user": [6, 7]])}, labels={'user:like:item': tensor([0, 1, 2, 3])},
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]}, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
{"user:follow:user": input_nodes=None, node_features=None, edge_features=None,
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}] compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
10. Heterogeneous head, tail and negative tails.
>>> like = ( 10. Heterogeneous node pairs and negative destinations.
... torch.arange(0, 5), torch.arange(0, 5), >>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
... torch.arange(5, 15).reshape(-1, 2)) >>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
>>> follow = ( >>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
... torch.arange(0, 6), torch.arange(6, 12), >>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
... torch.arange(12, 24).reshape(-1, 2))
>>> item_set = gb.ItemSetDict({ >>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like), ... "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
... "user:follow:user": gb.ItemSet(follow), ... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet((node_pairs_follow,
... negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... }) ... })
>>> item_sampler = gb.ItemSampler(item_set, 4) >>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> list(item_sampler) >>> next(iter(item_sampler))
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]}, [2, 3],
{"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])], [4, 5],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8]), [6, 7]])}, labels=None, negative_srcs=None,
tensor([[12, 13], [14, 15], [16, 17]])]}, negative_dsts={'user:like:item': tensor([[10, 11],
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11]), [12, 13],
tensor([[18, 19], [20, 21], [22, 23]])]}] [14, 15],
[16, 17]])}, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
""" """
def __init__( def __init__(
self, self,
item_set: ItemSet or ItemSetDict, item_set: ItemSet or ItemSetDict,
batch_size: int, batch_size: int,
minibatcher: Optional[Callable] = minibatcher_default,
drop_last: Optional[bool] = False, drop_last: Optional[bool] = False,
shuffle: Optional[bool] = False, shuffle: Optional[bool] = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self._item_set = item_set self._item_set = item_set
self._batch_size = batch_size self._batch_size = batch_size
self._minibatcher = minibatcher
self._drop_last = drop_last self._drop_last = drop_last
self._shuffle = shuffle self._shuffle = shuffle
...@@ -217,4 +318,9 @@ class ItemSampler(IterDataPipe): ...@@ -217,4 +318,9 @@ class ItemSampler(IterDataPipe):
data_pipe = data_pipe.collate(collate_fn=partial(_collate)) data_pipe = data_pipe.collate(collate_fn=partial(_collate))
# Map to minibatch.
data_pipe = data_pipe.map(
partial(self._minibatcher, names=self._item_set.names)
)
return iter(data_pipe) return iter(data_pipe)
...@@ -10,7 +10,7 @@ import torch ...@@ -10,7 +10,7 @@ import torch
@unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test") @unittest.skipIf(F._default_context_str == "cpu", "CopyTo needs GPU to test")
def test_CopyTo(): def test_CopyTo():
dp = gb.ItemSampler(torch.randn(20), 4) dp = gb.ItemSampler(gb.ItemSet(torch.randn(20)), 4)
dp = gb.CopyTo(dp, "cuda") dp = gb.CopyTo(dp, "cuda")
for data in dp: for data in dp:
......
import re
import dgl import dgl
import pytest import pytest
import torch import torch
...@@ -5,29 +7,124 @@ from dgl import graphbolt as gb ...@@ -5,29 +7,124 @@ from dgl import graphbolt as gb
from torch.testing import assert_close from torch.testing import assert_close
def test_ItemSampler_minibatcher():
# Default minibatcher is used if not specified.
# Warning message is raised if names are not specified.
item_set = gb.ItemSet(torch.arange(0, 10))
item_sampler = gb.ItemSampler(item_set, batch_size=4)
with pytest.warns(
UserWarning,
match=re.escape(
"Failed to map item list to `MiniBatch` as the names of items are "
"not provided. Please provide a customized `MiniBatcher`. The "
"item list is returned as is."
),
):
minibatch = next(iter(item_sampler))
assert not isinstance(minibatch, gb.MiniBatch)
# Default minibatcher is used if not specified.
# Warning message is raised if unrecognized names are specified.
item_set = gb.ItemSet(torch.arange(0, 10), names="unknown_name")
item_sampler = gb.ItemSampler(item_set, batch_size=4)
with pytest.warns(
UserWarning,
match=re.escape(
"Unknown item name 'unknown_name' is detected and added into "
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
),
):
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.unknown_name is not None
# Default minibatcher is used if not specified.
# `MiniBatch` is returned if expected names are specified.
item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
item_sampler = gb.ItemSampler(item_set, batch_size=4)
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert len(minibatch.seed_nodes) == 4
# Customized minibatcher is used if specified.
def minibatcher(batch, names):
return gb.MiniBatch(seed_nodes=batch)
item_sampler = gb.ItemSampler(
item_set, batch_size=4, minibatcher=minibatcher
)
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert len(minibatch.seed_nodes) == 4
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_ids(batch_size, shuffle, drop_last): def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs. # Node IDs.
num_ids = 103 num_ids = 103
item_set = gb.ItemSet(torch.arange(0, num_ids)) seed_nodes = torch.arange(0, num_ids)
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
assert len(minibatch.seed_nodes) == batch_size
else:
if not drop_last:
assert len(minibatch.seed_nodes) == num_ids % batch_size
else:
assert False
minibatch_ids.append(minibatch.seed_nodes)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
seed_nodes = torch.arange(0, num_ids)
labels = torch.arange(0, num_ids)
item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
minibatch_ids = [] minibatch_ids = []
minibatch_labels = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.labels is not None
assert len(minibatch.seed_nodes) == len(minibatch.labels)
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
assert len(minibatch) == batch_size assert len(minibatch.seed_nodes) == batch_size
else: else:
if not drop_last: if not drop_last:
assert len(minibatch) == num_ids % batch_size assert len(minibatch.seed_nodes) == num_ids % batch_size
else: else:
assert False assert False
minibatch_ids.append(minibatch) minibatch_ids.append(minibatch.seed_nodes)
minibatch_labels.append(minibatch.labels)
minibatch_ids = torch.cat(minibatch_ids) minibatch_ids = torch.cat(minibatch_ids)
minibatch_labels = torch.cat(minibatch_labels)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
assert (
torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]) is not shuffle
)
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
...@@ -77,14 +174,18 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last): ...@@ -77,14 +174,18 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last):
def test_ItemSet_node_pairs(batch_size, shuffle, drop_last): def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
# Node pairs. # Node pairs.
num_ids = 103 num_ids = 103
node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2)) node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
item_set = gb.ItemSet(node_pairs) item_set = gb.ItemSet(node_pairs, names="node_pairs")
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
src_ids = [] src_ids = []
dst_ids = [] dst_ids = []
for i, (src, dst) in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert minibatch.labels is None
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -96,7 +197,7 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last): ...@@ -96,7 +197,7 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
assert len(src) == expected_batch_size assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size assert len(dst) == expected_batch_size
# Verify src and dst IDs match. # Verify src and dst IDs match.
assert torch.equal(src + num_ids, dst) assert torch.equal(src + 1, dst)
# Archive batch. # Archive batch.
src_ids.append(src) src_ids.append(src)
dst_ids.append(dst) dst_ids.append(dst)
...@@ -112,16 +213,22 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last): ...@@ -112,16 +213,22 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last): def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels # Node pairs and labels
num_ids = 103 num_ids = 103
node_pairs = (torch.arange(0, num_ids), torch.arange(num_ids, num_ids * 2)) node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
labels = torch.arange(0, num_ids) labels = node_pairs[:, 0]
item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels)) item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
src_ids = [] src_ids = []
dst_ids = [] dst_ids = []
labels = [] labels = []
for i, (src, dst, label) in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert minibatch.labels is not None
assert len(minibatch.node_pairs) == len(minibatch.labels)
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
label = minibatch.labels
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -134,7 +241,7 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -134,7 +241,7 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
assert len(dst) == expected_batch_size assert len(dst) == expected_batch_size
assert len(label) == expected_batch_size assert len(label) == expected_batch_size
# Verify src/dst IDs and labels match. # Verify src/dst IDs and labels match.
assert torch.equal(src + num_ids, dst) assert torch.equal(src + 1, dst)
assert torch.equal(src, label) assert torch.equal(src, label)
# Archive batch. # Archive batch.
src_ids.append(src) src_ids.append(src)
...@@ -151,25 +258,29 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -151,25 +258,29 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last): def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
# Head, tail and negative tails. # Node pairs and negative destinations.
num_ids = 103 num_ids = 103
num_negs = 2 num_negs = 2
heads = torch.arange(0, num_ids) node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
tails = torch.arange(num_ids, num_ids * 2) neg_dsts = torch.arange(
neg_tails = torch.stack((heads + 1, heads + 2), dim=-1) 2 * num_ids, 2 * num_ids + num_ids * num_negs
item_set = gb.ItemSet((heads, tails, neg_tails)) ).reshape(-1, num_negs)
for i, (head, tail, negs) in enumerate(item_set): item_set = gb.ItemSet(
assert heads[i] == head (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
assert tails[i] == tail )
assert torch.equal(neg_tails[i], negs)
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
head_ids = [] src_ids = []
tail_ids = [] dst_ids = []
negs_ids = [] negs_ids = []
for i, (head, tail, negs) in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert minibatch.node_pairs is not None
assert minibatch.negative_dsts is not None
src = minibatch.node_pairs[:, 0]
dst = minibatch.node_pairs[:, 1]
negs = minibatch.negative_dsts
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -178,24 +289,23 @@ def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last): ...@@ -178,24 +289,23 @@ def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
expected_batch_size = num_ids % batch_size expected_batch_size = num_ids % batch_size
else: else:
assert False assert False
assert len(head) == expected_batch_size assert len(src) == expected_batch_size
assert len(tail) == expected_batch_size assert len(dst) == expected_batch_size
assert negs.dim() == 2 assert negs.dim() == 2
assert negs.shape[0] == expected_batch_size assert negs.shape[0] == expected_batch_size
assert negs.shape[1] == num_negs assert negs.shape[1] == num_negs
# Verify head/tail and negatie tails match. # Verify node pairs and negative destinations.
assert torch.equal(head + num_ids, tail) assert torch.equal(src + 1, dst)
assert torch.equal(head + 1, negs[:, 0]) assert torch.equal(negs[:, 0] + 1, negs[:, 1])
assert torch.equal(head + 2, negs[:, 1])
# Archive batch. # Archive batch.
head_ids.append(head) src_ids.append(src)
tail_ids.append(tail) dst_ids.append(dst)
negs_ids.append(negs) negs_ids.append(negs)
head_ids = torch.cat(head_ids) src_ids = torch.cat(src_ids)
tail_ids = torch.cat(tail_ids) dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids) negs_ids = torch.cat(negs_ids)
assert torch.all(head_ids[:-1] <= head_ids[1:]) is not shuffle assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(tail_ids[:-1] <= tail_ids[1:]) is not shuffle assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1, 0] <= negs_ids[1:, 0]) is not shuffle assert torch.all(negs_ids[:-1, 0] <= negs_ids[1:, 0]) is not shuffle
assert torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1]) is not shuffle assert torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1]) is not shuffle
...@@ -215,12 +325,57 @@ def test_append_with_other_datapipes(): ...@@ -215,12 +325,57 @@ def test_append_with_other_datapipes():
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last): def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 205
ids = {
"user": gb.ItemSet(torch.arange(0, 99), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(99, num_ids), names="seed_nodes"),
}
chained_ids = []
for key, value in ids.items():
chained_ids += [(key, v) for v in value]
item_set = gb.ItemSetDict(ids)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
minibatch_ids = []
for i, minibatch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
ids = []
for _, v in minibatch.seed_nodes.items():
ids.append(v)
ids = torch.cat(ids)
assert len(ids) == expected_batch_size
minibatch_ids.append(ids)
minibatch_ids = torch.cat(minibatch_ids)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
# Node IDs. # Node IDs.
num_ids = 205 num_ids = 205
ids = { ids = {
"user": gb.ItemSet(torch.arange(0, 99)), "user": gb.ItemSet(
"item": gb.ItemSet(torch.arange(99, num_ids)), (torch.arange(0, 99), torch.arange(0, 99)),
names=("seed_nodes", "labels"),
),
"item": gb.ItemSet(
(torch.arange(99, num_ids), torch.arange(99, num_ids)),
names=("seed_nodes", "labels"),
),
} }
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
...@@ -230,7 +385,11 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last): ...@@ -230,7 +385,11 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
minibatch_ids = [] minibatch_ids = []
for i, batch in enumerate(item_sampler): minibatch_labels = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seed_nodes is not None
assert minibatch.labels is not None
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -239,15 +398,24 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last): ...@@ -239,15 +398,24 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
expected_batch_size = num_ids % batch_size expected_batch_size = num_ids % batch_size
else: else:
assert False assert False
assert isinstance(batch, dict)
ids = [] ids = []
for _, v in batch.items(): for _, v in minibatch.seed_nodes.items():
ids.append(v) ids.append(v)
ids = torch.cat(ids) ids = torch.cat(ids)
assert len(ids) == expected_batch_size assert len(ids) == expected_batch_size
minibatch_ids.append(ids) minibatch_ids.append(ids)
labels = []
for _, v in minibatch.labels.items():
labels.append(v)
labels = torch.cat(labels)
assert len(labels) == expected_batch_size
minibatch_labels.append(labels)
minibatch_ids = torch.cat(minibatch_ids) minibatch_ids = torch.cat(minibatch_ids)
minibatch_labels = torch.cat(minibatch_labels)
assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
assert (
torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]) is not shuffle
)
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
...@@ -256,18 +424,12 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last): ...@@ -256,18 +424,12 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last): def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
# Node pairs. # Node pairs.
num_ids = 103 num_ids = 103
total_ids = 2 * num_ids total_pairs = 2 * num_ids
node_pairs_0 = ( node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
torch.arange(0, num_ids), node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
torch.arange(num_ids, num_ids * 2),
)
node_pairs_1 = (
torch.arange(num_ids * 2, num_ids * 3),
torch.arange(num_ids * 3, num_ids * 4),
)
node_pairs_dict = { node_pairs_dict = {
"user:like:item": gb.ItemSet(node_pairs_0), "user:like:item": gb.ItemSet(node_pairs_like, names="node_pairs"),
"user:follow:user": gb.ItemSet(node_pairs_1), "user:follow:user": gb.ItemSet(node_pairs_follow, names="node_pairs"),
} }
item_set = gb.ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
...@@ -275,27 +437,30 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last): ...@@ -275,27 +437,30 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
) )
src_ids = [] src_ids = []
dst_ids = [] dst_ids = []
for i, batch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
is_last = (i + 1) * batch_size >= total_ids assert isinstance(minibatch, gb.MiniBatch)
if not is_last or total_ids % batch_size == 0: assert minibatch.node_pairs is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= total_pairs
if not is_last or total_pairs % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
else: else:
if not drop_last: if not drop_last:
expected_batch_size = total_ids % batch_size expected_batch_size = total_pairs % batch_size
else: else:
assert False assert False
src = [] src = []
dst = [] dst = []
for _, (v_src, v_dst) in batch.items(): for _, node_pairs in minibatch.node_pairs.items():
src.append(v_src) src.append(node_pairs[:, 0])
dst.append(v_dst) dst.append(node_pairs[:, 1])
src = torch.cat(src) src = torch.cat(src)
dst = torch.cat(dst) dst = torch.cat(dst)
assert len(src) == expected_batch_size assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size assert len(dst) == expected_batch_size
src_ids.append(src) src_ids.append(src)
dst_ids.append(dst) dst_ids.append(dst)
assert torch.equal(src + num_ids, dst) assert torch.equal(src + 1, dst)
src_ids = torch.cat(src_ids) src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids) dst_ids = torch.cat(dst_ids)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
...@@ -309,21 +474,17 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -309,21 +474,17 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels # Node pairs and labels
num_ids = 103 num_ids = 103
total_ids = 2 * num_ids total_ids = 2 * num_ids
node_pairs_0 = ( node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
torch.arange(0, num_ids), node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
torch.arange(num_ids, num_ids * 2),
)
node_pairs_1 = (
torch.arange(num_ids * 2, num_ids * 3),
torch.arange(num_ids * 3, num_ids * 4),
)
labels = torch.arange(0, num_ids) labels = torch.arange(0, num_ids)
node_pairs_dict = { node_pairs_dict = {
"user:like:item": gb.ItemSet( "user:like:item": gb.ItemSet(
(node_pairs_0[0], node_pairs_0[1], labels) (node_pairs_like, node_pairs_like[:, 0]),
names=("node_pairs", "labels"),
), ),
"user:follow:user": gb.ItemSet( "user:follow:user": gb.ItemSet(
(node_pairs_1[0], node_pairs_1[1], labels + num_ids * 2) (node_pairs_follow, node_pairs_follow[:, 0]),
names=("node_pairs", "labels"),
), ),
} }
item_set = gb.ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
...@@ -333,7 +494,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -333,7 +494,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src_ids = [] src_ids = []
dst_ids = [] dst_ids = []
labels = [] labels = []
for i, batch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None
assert minibatch.labels is not None
is_last = (i + 1) * batch_size >= total_ids is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0: if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -345,9 +509,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -345,9 +509,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src = [] src = []
dst = [] dst = []
label = [] label = []
for _, (v_src, v_dst, v_label) in batch.items(): for _, node_pairs in minibatch.node_pairs.items():
src.append(v_src) src.append(node_pairs[:, 0])
dst.append(v_dst) dst.append(node_pairs[:, 1])
for _, v_label in minibatch.labels.items():
label.append(v_label) label.append(v_label)
src = torch.cat(src) src = torch.cat(src)
dst = torch.cat(dst) dst = torch.cat(dst)
...@@ -358,7 +523,7 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -358,7 +523,7 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src_ids.append(src) src_ids.append(src)
dst_ids.append(dst) dst_ids.append(dst)
labels.append(label) labels.append(label)
assert torch.equal(src + num_ids, dst) assert torch.equal(src + 1, dst)
assert torch.equal(src, label) assert torch.equal(src, label)
src_ids = torch.cat(src_ids) src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids) dst_ids = torch.cat(dst_ids)
...@@ -371,26 +536,40 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -371,26 +536,40 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last): def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
# Head, tail and negative tails. # Head, tail and negative tails.
num_ids = 103 num_ids = 103
total_ids = 2 * num_ids total_ids = 2 * num_ids
num_negs = 2 num_negs = 2
heads = torch.arange(0, num_ids) node_paris_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
tails = torch.arange(num_ids, num_ids * 2) node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
neg_tails = torch.stack((heads + 1, heads + 2), dim=-1) neg_dsts_like = torch.arange(
num_ids * 4, num_ids * 4 + num_ids * num_negs
).reshape(-1, num_negs)
neg_dsts_follow = torch.arange(
num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2
).reshape(-1, num_negs)
data_dict = { data_dict = {
"user:like:item": gb.ItemSet((heads, tails, neg_tails)), "user:like:item": gb.ItemSet(
"user:follow:user": gb.ItemSet((heads, tails, neg_tails)), (node_paris_like, neg_dsts_like),
names=("node_pairs", "negative_dsts"),
),
"user:follow:user": gb.ItemSet(
(node_pairs_follow, neg_dsts_follow),
names=("node_pairs", "negative_dsts"),
),
} }
item_set = gb.ItemSetDict(data_dict) item_set = gb.ItemSetDict(data_dict)
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
head_ids = [] src_ids = []
tail_ids = [] dst_ids = []
negs_ids = [] negs_ids = []
for i, batch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.node_pairs is not None
assert minibatch.negative_dsts is not None
is_last = (i + 1) * batch_size >= total_ids is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0: if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -399,31 +578,31 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last): ...@@ -399,31 +578,31 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
expected_batch_size = total_ids % batch_size expected_batch_size = total_ids % batch_size
else: else:
assert False assert False
head = [] src = []
tail = [] dst = []
negs = [] negs = []
for _, (v_head, v_tail, v_negs) in batch.items(): for _, node_pairs in minibatch.node_pairs.items():
head.append(v_head) src.append(node_pairs[:, 0])
tail.append(v_tail) dst.append(node_pairs[:, 1])
for _, v_negs in minibatch.negative_dsts.items():
negs.append(v_negs) negs.append(v_negs)
head = torch.cat(head) src = torch.cat(src)
tail = torch.cat(tail) dst = torch.cat(dst)
negs = torch.cat(negs) negs = torch.cat(negs)
assert len(head) == expected_batch_size assert len(src) == expected_batch_size
assert len(tail) == expected_batch_size assert len(dst) == expected_batch_size
assert len(negs) == expected_batch_size assert len(negs) == expected_batch_size
head_ids.append(head) src_ids.append(src)
tail_ids.append(tail) dst_ids.append(dst)
negs_ids.append(negs) negs_ids.append(negs)
assert negs.dim() == 2 assert negs.dim() == 2
assert negs.shape[0] == expected_batch_size assert negs.shape[0] == expected_batch_size
assert negs.shape[1] == num_negs assert negs.shape[1] == num_negs
assert torch.equal(head + num_ids, tail) assert torch.equal(src + 1, dst)
assert torch.equal(head + 1, negs[:, 0]) assert torch.equal(negs[:, 0] + 1, negs[:, 1])
assert torch.equal(head + 2, negs[:, 1]) src_ids = torch.cat(src_ids)
head_ids = torch.cat(head_ids) dst_ids = torch.cat(dst_ids)
tail_ids = torch.cat(tail_ids)
negs_ids = torch.cat(negs_ids) negs_ids = torch.cat(negs_ids)
assert torch.all(head_ids[:-1] <= head_ids[1:]) is not shuffle assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(tail_ids[:-1] <= tail_ids[1:]) is not shuffle assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment