Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
29949322
Unverified
Commit
29949322
authored
Sep 05, 2023
by
Rhett Ying
Committed by
GitHub
Sep 05, 2023
Browse files
[GraphBolt] convert item list to MiniBatch (#6281)
parent
dadce86a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
468 additions
and
183 deletions
+468
-183
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+182
-76
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+1
-1
tests/python/pytorch/graphbolt/test_item_sampler.py
tests/python/pytorch/graphbolt/test_item_sampler.py
+285
-106
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
29949322
...
...
@@ -2,16 +2,73 @@
from
collections.abc
import
Mapping
from
functools
import
partial
from
typing
import
Iterator
,
Optional
from
typing
import
Callable
,
Iterator
,
Optional
from
torch.utils.data
import
default_collate
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
from
..base
import
dgl_warning
from
..batch
import
batch
as
dgl_batch
from
..heterograph
import
DGLGraph
from
.itemset
import
ItemSet
,
ItemSetDict
from
.minibatch
import
MiniBatch
__all__
=
[
"ItemSampler"
,
"minibatcher_default"
]
def minibatcher_default(batch, names):
    """Default minibatcher.

    The default minibatcher maps a list of items to a `MiniBatch` with the
    same names as the items. The names of items are supposed to be provided
    and align with the data attributes of `MiniBatch`. If any unknown item name
    is provided, a warning is raised and the item is still attached to the
    `MiniBatch` under that name. If the names of items are not provided, the
    item list is returned as is and a warning will be raised.

    Parameters
    ----------
    batch : list or Mapping
        List of items. A mapping is also accepted; when several names are
        given, each value of the mapping is indexed by the position of the
        name to build one mapping per name.
    names : Tuple[str] or None
        Names of items in `batch` with same length. The order should align
        with `batch`.

    Returns
    -------
    MiniBatch
        A minibatch. When `names` is None, `batch` itself is returned
        unchanged instead.
    """
    if names is None:
        dgl_warning(
            "Failed to map item list to `MiniBatch` as the names of items are "
            "not provided. Please provide a customized `MiniBatcher`. "
            "The item list is returned as is."
        )
        return batch
    if len(names) == 1:
        # Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =
        # ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor
        # instead of the batch.
        init_data = {names[0]: batch}
    elif isinstance(batch, Mapping):
        # Split every value of the mapping by position so that each name
        # receives its own mapping keyed like `batch`.
        init_data = {
            name: {k: v[i] for k, v in batch.items()}
            for i, name in enumerate(names)
        }
    else:
        init_data = dict(zip(names, batch))
    minibatch = MiniBatch()
    for name, item in init_data.items():
        if not hasattr(minibatch, name):
            # Unknown names are not fatal: warn, then attach the item anyway.
            dgl_warning(
                f"Unknown item name '{name}' is detected and added into "
                "`MiniBatch`. You probably need to provide a customized "
                "`MiniBatcher`."
            )
        setattr(minibatch, name, item)
    return minibatch
class
ItemSampler
(
IterDataPipe
):
...
...
@@ -32,6 +89,8 @@ class ItemSampler(IterDataPipe):
Data to be sampled.
batch_size : int
The size of each batch.
minibatcher : Optional[Callable]
A callable that takes in a list of items and returns a `MiniBatch`.
drop_last : bool
Option to drop the last batch if it's not full.
shuffle : bool
...
...
@@ -42,41 +101,68 @@ class ItemSampler(IterDataPipe):
1. Node IDs.
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> item_set = gb.ItemSet(torch.arange(0, 10)
, names="seed_nodes"
)
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(item_sampler)
[tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
2. Node pairs.
>>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(item_sampler)
[[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
[ 4, 5],
[ 6, 7],
[10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
... names=("node_pairs", "labels")
... )
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
[tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]
4. Head, tail and negative tails
>>> heads = torch.arange(0, 5)
>>> tails = torch.arange(5, 10)
>>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
>>> item_set = gb.ItemSet((heads, tails, negative_tails))
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]),
tensor([[1, 2], [2, 3], [3, 4]])],
[tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
[4, 5],
[0, 1],
[6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
4. Node pairs and negative destinations.
>>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
>>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
[ 6, 7],
[ 2, 3],
[ 8, 9]]), labels=None, negative_srcs=None,
negative_dsts=tensor([[20, 21],
[16, 17],
[12, 13],
[18, 19]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
5. DGLGraphs.
>>> import dgl
...
...
@@ -103,81 +189,96 @@ class ItemSampler(IterDataPipe):
7. Heterogeneous node IDs.
>>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5)),
... "item": gb.ItemSet(torch.arange(0, 6)),
... "user": gb.ItemSet(torch.arange(0, 5)
, names="seed_nodes"
),
... "item": gb.ItemSet(torch.arange(0, 6)
, names="seed_nodes"
),
... }
>>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{'user': tensor([0, 1, 2, 3])},
{'item': tensor([0, 1, 2]), 'user': tensor([4])},
{'item': tensor([3, 4, 5])}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
8. Heterogeneous node pairs.
>>> node_pairs_like =
(
torch.arange(0,
5), torch.arange(0, 5)
)
>>> node_pairs_follow =
(
torch.arange(0,
6), torch.arange(6
,
1
2)
)
>>> node_pairs_like = torch.arange(0,
10).reshape(-1, 2
)
>>> node_pairs_follow = torch.arange(
1
0,
20).reshape(-1
, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(node_pairs_like),
... "user:follow:user": gb.ItemSet(node_pairs_follow),
... "user:like:item": gb.ItemSet(
... node_pairs_like, names="node_pairs"),
... "user:follow:user": gb.ItemSet(
... node_pairs_follow, names="node_pairs"),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
9. Heterogeneous node pairs and labels.
>>>
like = (
... torch.arange(0, 5), torch.arange(0, 5),
torch.arange(0,
5)
)
>>>
follow = (
... torch.arange(0, 6), torch.arange(6, 12),
torch.arange(0,
6)
)
>>>
node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.arange(0, 5)
>>>
node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
... names=("node_pairs", "labels")),
... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
... names=("node_pairs", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item":
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([4])],
"user:follow:user":
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]},
{"user:follow:user":
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}]
10. Heterogeneous head, tail and negative tails.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5),
... torch.arange(5, 15).reshape(-1, 2))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2))
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels={'user:like:item': tensor([0, 1, 2, 3])},
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
10. Heterogeneous node pairs and negative destinations.
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
>>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet((node_pairs_follow,
... negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8]),
tensor([[12, 13], [14, 15], [16, 17]])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
tensor([[18, 19], [20, 21], [22, 23]])]}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None,
negative_dsts={'user:like:item': tensor([[10, 11],
[12, 13],
[14, 15],
[16, 17]])}, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
"""
def
__init__
(
self
,
item_set
:
ItemSet
or
ItemSetDict
,
batch_size
:
int
,
minibatcher
:
Optional
[
Callable
]
=
minibatcher_default
,
drop_last
:
Optional
[
bool
]
=
False
,
shuffle
:
Optional
[
bool
]
=
False
,
)
->
None
:
super
().
__init__
()
self
.
_item_set
=
item_set
self
.
_batch_size
=
batch_size
self
.
_minibatcher
=
minibatcher
self
.
_drop_last
=
drop_last
self
.
_shuffle
=
shuffle
...
...
@@ -217,4 +318,9 @@ class ItemSampler(IterDataPipe):
data_pipe
=
data_pipe
.
collate
(
collate_fn
=
partial
(
_collate
))
# Map to minibatch.
data_pipe
=
data_pipe
.
map
(
partial
(
self
.
_minibatcher
,
names
=
self
.
_item_set
.
names
)
)
return
iter
(
data_pipe
)
tests/python/pytorch/graphbolt/test_base.py
View file @
29949322
...
...
@@ -10,7 +10,7 @@ import torch
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
"CopyTo needs GPU to test"
)
def
test_CopyTo
():
dp
=
gb
.
ItemSampler
(
torch
.
randn
(
20
),
4
)
dp
=
gb
.
ItemSampler
(
gb
.
ItemSet
(
torch
.
randn
(
20
)
)
,
4
)
dp
=
gb
.
CopyTo
(
dp
,
"cuda"
)
for
data
in
dp
:
...
...
tests/python/pytorch/graphbolt/test_item_sampler.py
View file @
29949322
import
re
import
dgl
import
pytest
import
torch
...
...
@@ -5,31 +7,126 @@ from dgl import graphbolt as gb
from
torch.testing
import
assert_close
def test_ItemSampler_minibatcher():
    """Check the default minibatcher and a user-supplied one on `ItemSampler`."""
    # Case 1: no names on the item set. The default minibatcher warns and
    # hands the raw item list back instead of a `MiniBatch`.
    unnamed_sampler = gb.ItemSampler(
        gb.ItemSet(torch.arange(0, 10)), batch_size=4
    )
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "Failed to map item list to `MiniBatch` as the names of items are "
            "not provided. Please provide a customized `MiniBatcher`. The "
            "item list is returned as is."
        ),
    ):
        first_batch = next(iter(unnamed_sampler))
        assert not isinstance(first_batch, gb.MiniBatch)

    # Case 2: an unrecognized name. The default minibatcher warns but still
    # stores the data on the `MiniBatch` under that name.
    unknown_sampler = gb.ItemSampler(
        gb.ItemSet(torch.arange(0, 10), names="unknown_name"), batch_size=4
    )
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "Unknown item name 'unknown_name' is detected and added into "
            "`MiniBatch`. You probably need to provide a customized "
            "`MiniBatcher`."
        ),
    ):
        first_batch = next(iter(unknown_sampler))
        assert isinstance(first_batch, gb.MiniBatch)
        assert first_batch.unknown_name is not None

    # Case 3: a recognized name. A populated `MiniBatch` comes back.
    named_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
    first_batch = next(iter(gb.ItemSampler(named_set, batch_size=4)))
    assert isinstance(first_batch, gb.MiniBatch)
    assert first_batch.seed_nodes is not None
    assert len(first_batch.seed_nodes) == 4

    # Case 4: a customized minibatcher replaces the default one.
    def minibatcher(batch, names):
        return gb.MiniBatch(seed_nodes=batch)

    custom_sampler = gb.ItemSampler(
        named_set, batch_size=4, minibatcher=minibatcher
    )
    first_batch = next(iter(custom_sampler))
    assert isinstance(first_batch, gb.MiniBatch)
    assert first_batch.seed_nodes is not None
    assert len(first_batch.seed_nodes) == 4
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
    """Sample seed nodes from an `ItemSet` and check batch sizes and order."""
    # Node IDs.
    num_ids = 103
    seed_nodes = torch.arange(0, num_ids)
    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seed_nodes) == batch_size
        else:
            # A partial trailing batch must only show up when it is kept.
            assert not drop_last
            assert len(minibatch.seed_nodes) == num_ids % batch_size
        minibatch_ids.append(minibatch.seed_nodes)
    minibatch_ids = torch.cat(minibatch_ids)
    # `torch.all` returns a tensor; convert it to a Python bool so that the
    # comparison with `shuffle` is a real check rather than an identity test
    # between a tensor and a bool (which is always true).
    is_sorted = bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]))
    assert is_sorted is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
    """Sample seed nodes with labels and check batch sizes and order."""
    # Node IDs and labels.
    num_ids = 103
    seed_nodes = torch.arange(0, num_ids)
    labels = torch.arange(0, num_ids)
    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    minibatch_labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        assert minibatch.labels is not None
        assert len(minibatch.seed_nodes) == len(minibatch.labels)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seed_nodes) == batch_size
        else:
            # A partial trailing batch must only show up when it is kept.
            assert not drop_last
            assert len(minibatch.seed_nodes) == num_ids % batch_size
        minibatch_ids.append(minibatch.seed_nodes)
        minibatch_labels.append(minibatch.labels)
    minibatch_ids = torch.cat(minibatch_ids)
    minibatch_labels = torch.cat(minibatch_labels)
    # `torch.all` returns a tensor; convert it to a Python bool so that the
    # comparison with `shuffle` is a real check rather than an identity test
    # between a tensor and a bool (which is always true).
    ids_sorted = bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]))
    labels_sorted = bool(
        torch.all(minibatch_labels[:-1] <= minibatch_labels[1:])
    )
    assert ids_sorted is not shuffle
    assert labels_sorted is not shuffle
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
...
...
@@ -77,14 +174,18 @@ def test_ItemSet_graphs(batch_size, shuffle, drop_last):
def
test_ItemSet_node_pairs
(
batch_size
,
shuffle
,
drop_last
):
# Node pairs.
num_ids
=
103
node_pairs
=
(
torch
.
arange
(
0
,
num_ids
)
,
torch
.
arange
(
num_ids
,
num_ids
*
2
)
)
item_set
=
gb
.
ItemSet
(
node_pairs
)
node_pairs
=
torch
.
arange
(
0
,
2
*
num_ids
)
.
reshape
(
-
1
,
2
)
item_set
=
gb
.
ItemSet
(
node_pairs
,
names
=
"node_pairs"
)
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
src_ids
=
[]
dst_ids
=
[]
for
i
,
(
src
,
dst
)
in
enumerate
(
item_sampler
):
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
minibatch
.
node_pairs
is
not
None
assert
minibatch
.
labels
is
None
src
=
minibatch
.
node_pairs
[:,
0
]
dst
=
minibatch
.
node_pairs
[:,
1
]
is_last
=
(
i
+
1
)
*
batch_size
>=
num_ids
if
not
is_last
or
num_ids
%
batch_size
==
0
:
expected_batch_size
=
batch_size
...
...
@@ -96,7 +197,7 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
assert
len
(
src
)
==
expected_batch_size
assert
len
(
dst
)
==
expected_batch_size
# Verify src and dst IDs match.
assert
torch
.
equal
(
src
+
num_ids
,
dst
)
assert
torch
.
equal
(
src
+
1
,
dst
)
# Archive batch.
src_ids
.
append
(
src
)
dst_ids
.
append
(
dst
)
...
...
@@ -112,16 +213,22 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
def
test_ItemSet_node_pairs_labels
(
batch_size
,
shuffle
,
drop_last
):
# Node pairs and labels
num_ids
=
103
node_pairs
=
(
torch
.
arange
(
0
,
num_ids
)
,
torch
.
arange
(
num_ids
,
num_ids
*
2
)
)
labels
=
torch
.
arange
(
0
,
num_ids
)
item_set
=
gb
.
ItemSet
((
node_pairs
[
0
],
node_pairs
[
1
]
,
labels
))
node_pairs
=
torch
.
arange
(
0
,
2
*
num_ids
)
.
reshape
(
-
1
,
2
)
labels
=
node_pairs
[:,
0
]
item_set
=
gb
.
ItemSet
((
node_pairs
,
labels
),
names
=
(
"
node_pairs
"
,
"
labels
"
))
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
src_ids
=
[]
dst_ids
=
[]
labels
=
[]
for
i
,
(
src
,
dst
,
label
)
in
enumerate
(
item_sampler
):
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
minibatch
.
node_pairs
is
not
None
assert
minibatch
.
labels
is
not
None
assert
len
(
minibatch
.
node_pairs
)
==
len
(
minibatch
.
labels
)
src
=
minibatch
.
node_pairs
[:,
0
]
dst
=
minibatch
.
node_pairs
[:,
1
]
label
=
minibatch
.
labels
is_last
=
(
i
+
1
)
*
batch_size
>=
num_ids
if
not
is_last
or
num_ids
%
batch_size
==
0
:
expected_batch_size
=
batch_size
...
...
@@ -134,7 +241,7 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
assert
len
(
dst
)
==
expected_batch_size
assert
len
(
label
)
==
expected_batch_size
# Verify src/dst IDs and labels match.
assert
torch
.
equal
(
src
+
num_ids
,
dst
)
assert
torch
.
equal
(
src
+
1
,
dst
)
assert
torch
.
equal
(
src
,
label
)
# Archive batch.
src_ids
.
append
(
src
)
...
...
@@ -151,25 +258,29 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
def
test_ItemSet_
head_tail_neg_tail
s
(
batch_size
,
shuffle
,
drop_last
):
#
Head, tail
and negative
tail
s.
def
test_ItemSet_
node_pairs_negative_dst
s
(
batch_size
,
shuffle
,
drop_last
):
#
Node pairs
and negative
destination
s.
num_ids
=
103
num_negs
=
2
heads
=
torch
.
arange
(
0
,
num_ids
)
tails
=
torch
.
arange
(
num_ids
,
num_ids
*
2
)
neg_tails
=
torch
.
stack
((
heads
+
1
,
heads
+
2
),
dim
=-
1
)
item_set
=
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
))
for
i
,
(
head
,
tail
,
negs
)
in
enumerate
(
item_set
):
assert
heads
[
i
]
==
head
assert
tails
[
i
]
==
tail
assert
torch
.
equal
(
neg_tails
[
i
],
negs
)
node_pairs
=
torch
.
arange
(
0
,
2
*
num_ids
).
reshape
(
-
1
,
2
)
neg_dsts
=
torch
.
arange
(
2
*
num_ids
,
2
*
num_ids
+
num_ids
*
num_negs
).
reshape
(
-
1
,
num_negs
)
item_set
=
gb
.
ItemSet
(
(
node_pairs
,
neg_dsts
),
names
=
(
"node_pairs"
,
"negative_dsts"
)
)
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
head
_ids
=
[]
tail
_ids
=
[]
src
_ids
=
[]
dst
_ids
=
[]
negs_ids
=
[]
for
i
,
(
head
,
tail
,
negs
)
in
enumerate
(
item_sampler
):
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
minibatch
.
node_pairs
is
not
None
assert
minibatch
.
negative_dsts
is
not
None
src
=
minibatch
.
node_pairs
[:,
0
]
dst
=
minibatch
.
node_pairs
[:,
1
]
negs
=
minibatch
.
negative_dsts
is_last
=
(
i
+
1
)
*
batch_size
>=
num_ids
if
not
is_last
or
num_ids
%
batch_size
==
0
:
expected_batch_size
=
batch_size
...
...
@@ -178,24 +289,23 @@ def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
expected_batch_size
=
num_ids
%
batch_size
else
:
assert
False
assert
len
(
head
)
==
expected_batch_size
assert
len
(
tail
)
==
expected_batch_size
assert
len
(
src
)
==
expected_batch_size
assert
len
(
dst
)
==
expected_batch_size
assert
negs
.
dim
()
==
2
assert
negs
.
shape
[
0
]
==
expected_batch_size
assert
negs
.
shape
[
1
]
==
num_negs
# Verify head/tail and negative tails match.
assert
torch
.
equal
(
head
+
num_ids
,
tail
)
assert
torch
.
equal
(
head
+
1
,
negs
[:,
0
])
assert
torch
.
equal
(
head
+
2
,
negs
[:,
1
])
# Verify node pairs and negative destinations.
assert
torch
.
equal
(
src
+
1
,
dst
)
assert
torch
.
equal
(
negs
[:,
0
]
+
1
,
negs
[:,
1
])
# Archive batch.
head
_ids
.
append
(
head
)
tail
_ids
.
append
(
tail
)
src
_ids
.
append
(
src
)
dst
_ids
.
append
(
dst
)
negs_ids
.
append
(
negs
)
head
_ids
=
torch
.
cat
(
head
_ids
)
tail
_ids
=
torch
.
cat
(
tail
_ids
)
src
_ids
=
torch
.
cat
(
src
_ids
)
dst
_ids
=
torch
.
cat
(
dst
_ids
)
negs_ids
=
torch
.
cat
(
negs_ids
)
assert
torch
.
all
(
head
_ids
[:
-
1
]
<=
head
_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
tail
_ids
[:
-
1
]
<=
tail
_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
src
_ids
[:
-
1
]
<=
src
_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
dst
_ids
[:
-
1
]
<=
dst
_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
negs_ids
[:
-
1
,
0
]
<=
negs_ids
[
1
:,
0
])
is
not
shuffle
assert
torch
.
all
(
negs_ids
[:
-
1
,
1
]
<=
negs_ids
[
1
:,
1
])
is
not
shuffle
...
...
@@ -215,12 +325,57 @@ def test_append_with_other_datapipes():
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
    """Sample seed nodes from an `ItemSetDict` and check batch sizes."""
    # Node IDs.
    num_ids = 205
    ids = {
        "user": gb.ItemSet(torch.arange(0, 99), names="seed_nodes"),
        "item": gb.ItemSet(torch.arange(99, num_ids), names="seed_nodes"),
    }
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        # `seed_nodes` is keyed by node type here; flatten it so that the
        # batch size can be checked across all types at once.
        batch_ids = torch.cat(list(minibatch.seed_nodes.values()))
        assert len(batch_ids) == expected_batch_size
        minibatch_ids.append(batch_ids)
    minibatch_ids = torch.cat(minibatch_ids)
    # NOTE(review): `torch.all` returns a tensor, so this identity comparison
    # with the bool `shuffle` is always true. It is kept as written because a
    # strict ordering check would depend on the order in which node types
    # appear inside a mixed minibatch -- confirm that order before tightening.
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
def
test_ItemSetDict_seed_nodes_labels
(
batch_size
,
shuffle
,
drop_last
):
# Node IDs.
num_ids
=
205
ids
=
{
"user"
:
gb
.
ItemSet
(
torch
.
arange
(
0
,
99
)),
"item"
:
gb
.
ItemSet
(
torch
.
arange
(
99
,
num_ids
)),
"user"
:
gb
.
ItemSet
(
(
torch
.
arange
(
0
,
99
),
torch
.
arange
(
0
,
99
)),
names
=
(
"seed_nodes"
,
"labels"
),
),
"item"
:
gb
.
ItemSet
(
(
torch
.
arange
(
99
,
num_ids
),
torch
.
arange
(
99
,
num_ids
)),
names
=
(
"seed_nodes"
,
"labels"
),
),
}
chained_ids
=
[]
for
key
,
value
in
ids
.
items
():
...
...
@@ -230,7 +385,11 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
item_set
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
minibatch_ids
=
[]
for
i
,
batch
in
enumerate
(
item_sampler
):
minibatch_labels
=
[]
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
isinstance
(
minibatch
,
gb
.
MiniBatch
)
assert
minibatch
.
seed_nodes
is
not
None
assert
minibatch
.
labels
is
not
None
is_last
=
(
i
+
1
)
*
batch_size
>=
num_ids
if
not
is_last
or
num_ids
%
batch_size
==
0
:
expected_batch_size
=
batch_size
...
...
@@ -239,15 +398,24 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
expected_batch_size
=
num_ids
%
batch_size
else
:
assert
False
assert
isinstance
(
batch
,
dict
)
ids
=
[]
for
_
,
v
in
batch
.
items
():
for
_
,
v
in
mini
batch
.
seed_nodes
.
items
():
ids
.
append
(
v
)
ids
=
torch
.
cat
(
ids
)
assert
len
(
ids
)
==
expected_batch_size
minibatch_ids
.
append
(
ids
)
labels
=
[]
for
_
,
v
in
minibatch
.
labels
.
items
():
labels
.
append
(
v
)
labels
=
torch
.
cat
(
labels
)
assert
len
(
labels
)
==
expected_batch_size
minibatch_labels
.
append
(
labels
)
minibatch_ids
=
torch
.
cat
(
minibatch_ids
)
minibatch_labels
=
torch
.
cat
(
minibatch_labels
)
assert
torch
.
all
(
minibatch_ids
[:
-
1
]
<=
minibatch_ids
[
1
:])
is
not
shuffle
assert
(
torch
.
all
(
minibatch_labels
[:
-
1
]
<=
minibatch_labels
[
1
:])
is
not
shuffle
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
...
...
@@ -256,18 +424,12 @@ def test_ItemSetDict_node_ids(batch_size, shuffle, drop_last):
def
test_ItemSetDict_node_pairs
(
batch_size
,
shuffle
,
drop_last
):
# Node pairs.
num_ids
=
103
total_ids
=
2
*
num_ids
node_pairs_0
=
(
torch
.
arange
(
0
,
num_ids
),
torch
.
arange
(
num_ids
,
num_ids
*
2
),
)
node_pairs_1
=
(
torch
.
arange
(
num_ids
*
2
,
num_ids
*
3
),
torch
.
arange
(
num_ids
*
3
,
num_ids
*
4
),
)
total_pairs
=
2
*
num_ids
node_pairs_like
=
torch
.
arange
(
0
,
num_ids
*
2
).
reshape
(
-
1
,
2
)
node_pairs_follow
=
torch
.
arange
(
num_ids
*
2
,
num_ids
*
4
).
reshape
(
-
1
,
2
)
node_pairs_dict
=
{
"user:like:item"
:
gb
.
ItemSet
(
node_pairs_
0
),
"user:follow:user"
:
gb
.
ItemSet
(
node_pairs_
1
),
"user:like:item"
:
gb
.
ItemSet
(
node_pairs_
like
,
names
=
"node_pairs"
),
"user:follow:user"
:
gb
.
ItemSet
(
node_pairs_
follow
,
names
=
"node_pairs"
),
}
item_set
=
gb
.
ItemSetDict
(
node_pairs_dict
)
item_sampler
=
gb
.
ItemSampler
(
...
...
@@ -275,27 +437,30 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
)
src_ids
=
[]
dst_ids
=
[]
for
i
,
batch
in
enumerate
(
item_sampler
):
is_last
=
(
i
+
1
)
*
batch_size
>=
total_ids
if
not
is_last
or
total_ids
%
batch_size
==
0
:
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
isinstance
(
minibatch
,
gb
.
MiniBatch
)
assert
minibatch
.
node_pairs
is
not
None
assert
minibatch
.
labels
is
None
is_last
=
(
i
+
1
)
*
batch_size
>=
total_pairs
if
not
is_last
or
total_pairs
%
batch_size
==
0
:
expected_batch_size
=
batch_size
else
:
if
not
drop_last
:
expected_batch_size
=
total_
id
s
%
batch_size
expected_batch_size
=
total_
pair
s
%
batch_size
else
:
assert
False
src
=
[]
dst
=
[]
for
_
,
(
v_src
,
v_dst
)
in
batch
.
items
():
src
.
append
(
v_src
)
dst
.
append
(
v_dst
)
for
_
,
node_pairs
in
minibatch
.
node_pairs
.
items
():
src
.
append
(
node_pairs
[:,
0
]
)
dst
.
append
(
node_pairs
[:,
1
]
)
src
=
torch
.
cat
(
src
)
dst
=
torch
.
cat
(
dst
)
assert
len
(
src
)
==
expected_batch_size
assert
len
(
dst
)
==
expected_batch_size
src_ids
.
append
(
src
)
dst_ids
.
append
(
dst
)
assert
torch
.
equal
(
src
+
num_ids
,
dst
)
assert
torch
.
equal
(
src
+
1
,
dst
)
src_ids
=
torch
.
cat
(
src_ids
)
dst_ids
=
torch
.
cat
(
dst_ids
)
assert
torch
.
all
(
src_ids
[:
-
1
]
<=
src_ids
[
1
:])
is
not
shuffle
...
...
@@ -309,21 +474,17 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids
=
103
total_ids
=
2
*
num_ids
node_pairs_0
=
(
torch
.
arange
(
0
,
num_ids
),
torch
.
arange
(
num_ids
,
num_ids
*
2
),
)
node_pairs_1
=
(
torch
.
arange
(
num_ids
*
2
,
num_ids
*
3
),
torch
.
arange
(
num_ids
*
3
,
num_ids
*
4
),
)
node_pairs_like
=
torch
.
arange
(
0
,
num_ids
*
2
).
reshape
(
-
1
,
2
)
node_pairs_follow
=
torch
.
arange
(
num_ids
*
2
,
num_ids
*
4
).
reshape
(
-
1
,
2
)
labels
=
torch
.
arange
(
0
,
num_ids
)
node_pairs_dict
=
{
"user:like:item"
:
gb
.
ItemSet
(
(
node_pairs_0
[
0
],
node_pairs_0
[
1
],
labels
)
(
node_pairs_like
,
node_pairs_like
[:,
0
]),
names
=
(
"node_pairs"
,
"labels"
),
),
"user:follow:user"
:
gb
.
ItemSet
(
(
node_pairs_1
[
0
],
node_pairs_1
[
1
],
labels
+
num_ids
*
2
)
(
node_pairs_follow
,
node_pairs_follow
[:,
0
]),
names
=
(
"node_pairs"
,
"labels"
),
),
}
item_set
=
gb
.
ItemSetDict
(
node_pairs_dict
)
...
...
@@ -333,7 +494,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src_ids
=
[]
dst_ids
=
[]
labels
=
[]
for
i
,
batch
in
enumerate
(
item_sampler
):
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
isinstance
(
minibatch
,
gb
.
MiniBatch
)
assert
minibatch
.
node_pairs
is
not
None
assert
minibatch
.
labels
is
not
None
is_last
=
(
i
+
1
)
*
batch_size
>=
total_ids
if
not
is_last
or
total_ids
%
batch_size
==
0
:
expected_batch_size
=
batch_size
...
...
@@ -345,9 +509,10 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src
=
[]
dst
=
[]
label
=
[]
for
_
,
(
v_src
,
v_dst
,
v_label
)
in
batch
.
items
():
src
.
append
(
v_src
)
dst
.
append
(
v_dst
)
for
_
,
node_pairs
in
minibatch
.
node_pairs
.
items
():
src
.
append
(
node_pairs
[:,
0
])
dst
.
append
(
node_pairs
[:,
1
])
for
_
,
v_label
in
minibatch
.
labels
.
items
():
label
.
append
(
v_label
)
src
=
torch
.
cat
(
src
)
dst
=
torch
.
cat
(
dst
)
...
...
@@ -358,7 +523,7 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
src_ids
.
append
(
src
)
dst_ids
.
append
(
dst
)
labels
.
append
(
label
)
assert
torch
.
equal
(
src
+
num_ids
,
dst
)
assert
torch
.
equal
(
src
+
1
,
dst
)
assert
torch
.
equal
(
src
,
label
)
src_ids
=
torch
.
cat
(
src_ids
)
dst_ids
=
torch
.
cat
(
dst_ids
)
...
...
@@ -371,26 +536,40 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
def
test_ItemSetDict_
head_tail_neg_tail
s
(
batch_size
,
shuffle
,
drop_last
):
def
test_ItemSetDict_
node_pairs_negative_dst
s
(
batch_size
,
shuffle
,
drop_last
):
# Node pairs and negative dsts.
num_ids
=
103
total_ids
=
2
*
num_ids
num_negs
=
2
heads
=
torch
.
arange
(
0
,
num_ids
)
tails
=
torch
.
arange
(
num_ids
,
num_ids
*
2
)
neg_tails
=
torch
.
stack
((
heads
+
1
,
heads
+
2
),
dim
=-
1
)
node_paris_like
=
torch
.
arange
(
0
,
num_ids
*
2
).
reshape
(
-
1
,
2
)
node_pairs_follow
=
torch
.
arange
(
num_ids
*
2
,
num_ids
*
4
).
reshape
(
-
1
,
2
)
neg_dsts_like
=
torch
.
arange
(
num_ids
*
4
,
num_ids
*
4
+
num_ids
*
num_negs
).
reshape
(
-
1
,
num_negs
)
neg_dsts_follow
=
torch
.
arange
(
num_ids
*
4
+
num_ids
*
num_negs
,
num_ids
*
4
+
num_ids
*
num_negs
*
2
).
reshape
(
-
1
,
num_negs
)
data_dict
=
{
"user:like:item"
:
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
)),
"user:follow:user"
:
gb
.
ItemSet
((
heads
,
tails
,
neg_tails
)),
"user:like:item"
:
gb
.
ItemSet
(
(
node_paris_like
,
neg_dsts_like
),
names
=
(
"node_pairs"
,
"negative_dsts"
),
),
"user:follow:user"
:
gb
.
ItemSet
(
(
node_pairs_follow
,
neg_dsts_follow
),
names
=
(
"node_pairs"
,
"negative_dsts"
),
),
}
item_set
=
gb
.
ItemSetDict
(
data_dict
)
item_sampler
=
gb
.
ItemSampler
(
item_set
,
batch_size
=
batch_size
,
shuffle
=
shuffle
,
drop_last
=
drop_last
)
head
_ids
=
[]
tail
_ids
=
[]
src
_ids
=
[]
dst
_ids
=
[]
negs_ids
=
[]
for
i
,
batch
in
enumerate
(
item_sampler
):
for
i
,
minibatch
in
enumerate
(
item_sampler
):
assert
isinstance
(
minibatch
,
gb
.
MiniBatch
)
assert
minibatch
.
node_pairs
is
not
None
assert
minibatch
.
negative_dsts
is
not
None
is_last
=
(
i
+
1
)
*
batch_size
>=
total_ids
if
not
is_last
or
total_ids
%
batch_size
==
0
:
expected_batch_size
=
batch_size
...
...
@@ -399,31 +578,31 @@ def test_ItemSetDict_head_tail_neg_tails(batch_size, shuffle, drop_last):
expected_batch_size
=
total_ids
%
batch_size
else
:
assert
False
head
=
[]
tail
=
[]
src
=
[]
dst
=
[]
negs
=
[]
for
_
,
(
v_head
,
v_tail
,
v_negs
)
in
batch
.
items
():
head
.
append
(
v_head
)
tail
.
append
(
v_tail
)
for
_
,
node_pairs
in
minibatch
.
node_pairs
.
items
():
src
.
append
(
node_pairs
[:,
0
])
dst
.
append
(
node_pairs
[:,
1
])
for
_
,
v_negs
in
minibatch
.
negative_dsts
.
items
():
negs
.
append
(
v_negs
)
head
=
torch
.
cat
(
head
)
tail
=
torch
.
cat
(
tail
)
src
=
torch
.
cat
(
src
)
dst
=
torch
.
cat
(
dst
)
negs
=
torch
.
cat
(
negs
)
assert
len
(
head
)
==
expected_batch_size
assert
len
(
tail
)
==
expected_batch_size
assert
len
(
src
)
==
expected_batch_size
assert
len
(
dst
)
==
expected_batch_size
assert
len
(
negs
)
==
expected_batch_size
head
_ids
.
append
(
head
)
tail
_ids
.
append
(
tail
)
src
_ids
.
append
(
src
)
dst
_ids
.
append
(
dst
)
negs_ids
.
append
(
negs
)
assert
negs
.
dim
()
==
2
assert
negs
.
shape
[
0
]
==
expected_batch_size
assert
negs
.
shape
[
1
]
==
num_negs
assert
torch
.
equal
(
head
+
num_ids
,
tail
)
assert
torch
.
equal
(
head
+
1
,
negs
[:,
0
])
assert
torch
.
equal
(
head
+
2
,
negs
[:,
1
])
head_ids
=
torch
.
cat
(
head_ids
)
tail_ids
=
torch
.
cat
(
tail_ids
)
assert
torch
.
equal
(
src
+
1
,
dst
)
assert
torch
.
equal
(
negs
[:,
0
]
+
1
,
negs
[:,
1
])
src_ids
=
torch
.
cat
(
src_ids
)
dst_ids
=
torch
.
cat
(
dst_ids
)
negs_ids
=
torch
.
cat
(
negs_ids
)
assert
torch
.
all
(
head
_ids
[:
-
1
]
<=
head
_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
tail
_ids
[:
-
1
]
<=
tail
_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
src
_ids
[:
-
1
]
<=
src
_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
dst
_ids
[:
-
1
]
<=
dst
_ids
[
1
:])
is
not
shuffle
assert
torch
.
all
(
negs_ids
[:
-
1
]
<=
negs_ids
[
1
:])
is
not
shuffle
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment