Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
29949322
Unverified
Commit
29949322
authored
Sep 05, 2023
by
Rhett Ying
Committed by
GitHub
Sep 05, 2023
Browse files
[GraphBolt] convert item list to MiniBatch (#6281)
parent
dadce86a
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
468 additions
and
183 deletions
+468
-183
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+182
-76
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+1
-1
tests/python/pytorch/graphbolt/test_item_sampler.py
tests/python/pytorch/graphbolt/test_item_sampler.py
+285
-106
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
29949322
...
...
@@ -2,16 +2,73 @@
from
collections.abc
import
Mapping
from
functools
import
partial
from
typing
import
Iterator
,
Optional
from
typing
import
Callable
,
Iterator
,
Optional
from
torch.utils.data
import
default_collate
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
from
..base
import
dgl_warning
from
..batch
import
batch
as
dgl_batch
from
..heterograph
import
DGLGraph
from
.itemset
import
ItemSet
,
ItemSetDict
from
.minibatch
import
MiniBatch
__all__
=
[
"ItemSampler"
,
"minibatcher_default"
]
def
minibatcher_default
(
batch
,
names
):
"""Default minibatcher.
The default minibatcher maps a list of items to a `MiniBatch` with the
same names as the items. The names of items are supposed to be provided
and align with the data attributes of `MiniBatch`. If any unknown item name
is provided, exception will be raised. If the names of items are not
provided, the item list is returned as is and a warning will be raised.
Parameters
----------
batch : list
List of items.
names : Tuple[str] or None
Names of items in `batch` with same length. The order should align
with `batch`.
__all__
=
[
"ItemSampler"
]
Returns
-------
MiniBatch
A minibatch.
"""
if
names
is
None
:
dgl_warning
(
"Failed to map item list to `MiniBatch` as the names of items are "
"not provided. Please provide a customized `MiniBatcher`. "
"The item list is returned as is."
)
return
batch
if
len
(
names
)
==
1
:
# Handle the case of single item: batch = tensor([0, 1, 2, 3]), names =
# ("seed_nodes",) as `zip(batch, names)` will iterate over the tensor
# instead of the batch.
init_data
=
{
names
[
0
]:
batch
}
else
:
if
isinstance
(
batch
,
Mapping
):
init_data
=
{
name
:
{
k
:
v
[
i
]
for
k
,
v
in
batch
.
items
()}
for
i
,
name
in
enumerate
(
names
)
}
else
:
init_data
=
{
name
:
item
for
item
,
name
in
zip
(
batch
,
names
)}
minibatch
=
MiniBatch
()
for
name
,
item
in
init_data
.
items
():
if
not
hasattr
(
minibatch
,
name
):
dgl_warning
(
f
"Unknown item name '
{
name
}
' is detected and added into "
"`MiniBatch`. You probably need to provide a customized "
"`MiniBatcher`."
)
setattr
(
minibatch
,
name
,
item
)
return
minibatch
class
ItemSampler
(
IterDataPipe
):
...
...
@@ -32,6 +89,8 @@ class ItemSampler(IterDataPipe):
Data to be sampled.
batch_size : int
The size of each batch.
minibatcher : Optional[Callable]
A callable that takes in a list of items and returns a `MiniBatch`.
drop_last : bool
Option to drop the last batch if it's not full.
shuffle : bool
...
...
@@ -42,41 +101,68 @@ class ItemSampler(IterDataPipe):
1. Node IDs.
>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> item_set = gb.ItemSet(torch.arange(0, 10)
, names="seed_nodes"
)
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(item_sampler)
[tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=tensor([9, 0, 7, 2]), node_pairs=None, labels=None,
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
2. Node pairs.
>>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
... names="node_pairs")
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> list(item_sampler)
[[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[16, 17],
[ 4, 5],
[ 6, 7],
[10, 11]]), labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
3. Node pairs and labels.
>>> item_set = gb.ItemSet(
... (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
... (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 15)),
... names=("node_pairs", "labels")
... )
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
[tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]
4. Head, tail and negative tails
>>> heads = torch.arange(0, 5)
>>> tails = torch.arange(5, 10)
>>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
>>> item_set = gb.ItemSet((heads, tails, negative_tails))
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[[tensor([0, 1, 2]), tensor([5, 6, 7]),
tensor([[1, 2], [2, 3], [3, 4]])],
[tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[8, 9],
[4, 5],
[0, 1],
[6, 7]]), labels=tensor([14, 12, 10, 13]), negative_srcs=None,
negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
4. Node pairs and negative destinations.
>>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
>>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
... "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
... item_set, batch_size=4, shuffle=True, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs=tensor([[10, 11],
[ 6, 7],
[ 2, 3],
[ 8, 9]]), labels=None, negative_srcs=None,
negative_dsts=tensor([[20, 21],
[16, 17],
[12, 13],
[18, 19]]), sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
5. DGLGraphs.
>>> import dgl
...
...
@@ -103,81 +189,96 @@ class ItemSampler(IterDataPipe):
7. Heterogeneous node IDs.
>>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5)),
... "item": gb.ItemSet(torch.arange(0, 6)),
... "user": gb.ItemSet(torch.arange(0, 5)
, names="seed_nodes"
),
... "item": gb.ItemSet(torch.arange(0, 6)
, names="seed_nodes"
),
... }
>>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{'user': tensor([0, 1, 2, 3])},
{'item': tensor([0, 1, 2]), 'user': tensor([4])},
{'item': tensor([3, 4, 5])}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
8. Heterogeneous node pairs.
>>> node_pairs_like =
(
torch.arange(0,
5), torch.arange(0, 5)
)
>>> node_pairs_follow =
(
torch.arange(0,
6), torch.arange(6
,
1
2)
)
>>> node_pairs_like = torch.arange(0,
10).reshape(-1, 2
)
>>> node_pairs_follow = torch.arange(
1
0,
20).reshape(-1
, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(node_pairs_like),
... "user:follow:user": gb.ItemSet(node_pairs_follow),
... "user:like:item": gb.ItemSet(
... node_pairs_like, names="node_pairs"),
... "user:follow:user": gb.ItemSet(
... node_pairs_follow, names="node_pairs"),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None, negative_dsts=None,
sampled_subgraphs=None, input_nodes=None, node_features=None,
edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
9. Heterogeneous node pairs and labels.
>>>
like = (
... torch.arange(0, 5), torch.arange(0, 5),
torch.arange(0,
5)
)
>>>
follow = (
... torch.arange(0, 6), torch.arange(6, 12),
torch.arange(0,
6)
)
>>>
node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like =
torch.arange(0,
10
)
>>>
node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow =
torch.arange(
1
0,
20
)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
... names=("node_pairs", "labels")),
... "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
... names=("node_pairs", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item":
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([4])],
"user:follow:user":
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]},
{"user:follow:user":
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}]
10. Heterogeneous head, tail and negative tails.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5),
... torch.arange(5, 15).reshape(-1, 2))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2))
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels={'user:like:item': tensor([0, 1, 2, 3])},
negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
input_nodes=None, node_features=None, edge_features=None,
compacted_node_pairs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None)
10. Heterogeneous node pairs and negative destinations.
>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
>>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(like),
... "user:follow:user": gb.ItemSet(follow),
... "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
... names=("node_pairs", "negative_dsts")),
... "user:follow:user": gb.ItemSet((node_pairs_follow,
... negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... })
>>> item_sampler = gb.ItemSampler(item_set, 4)
>>> list(item_sampler)
[{"user:like:item": [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{"user:like:item": [tensor([4]), tensor([4]), tensor([[13, 14]])],
"user:follow:user": [tensor([0, 1, 2]), tensor([6, 7, 8]),
tensor([[12, 13], [14, 15], [16, 17]])]},
{"user:follow:user": [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
tensor([[18, 19], [20, 21], [22, 23]])]}]
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None, node_pairs={'user:like:item': tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7]])}, labels=None, negative_srcs=None,
negative_dsts={'user:like:item': tensor([[10, 11],
[12, 13],
[14, 15],
[16, 17]])}, sampled_subgraphs=None, input_nodes=None,
node_features=None, edge_features=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_dsts=None)
"""
def
__init__
(
self
,
item_set
:
ItemSet
or
ItemSetDict
,
batch_size
:
int
,
minibatcher
:
Optional
[
Callable
]
=
minibatcher_default
,
drop_last
:
Optional
[
bool
]
=
False
,
shuffle
:
Optional
[
bool
]
=
False
,
)
->
None
:
super
().
__init__
()
self
.
_item_set
=
item_set
self
.
_batch_size
=
batch_size
self
.
_minibatcher
=
minibatcher
self
.
_drop_last
=
drop_last
self
.
_shuffle
=
shuffle
...
...
@@ -217,4 +318,9 @@ class ItemSampler(IterDataPipe):
data_pipe
=
data_pipe
.
collate
(
collate_fn
=
partial
(
_collate
))
# Map to minibatch.
data_pipe
=
data_pipe
.
map
(
partial
(
self
.
_minibatcher
,
names
=
self
.
_item_set
.
names
)
)
return
iter
(
data_pipe
)
tests/python/pytorch/graphbolt/test_base.py
View file @
29949322
...
...
@@ -10,7 +10,7 @@ import torch
@
unittest
.
skipIf
(
F
.
_default_context_str
==
"cpu"
,
"CopyTo needs GPU to test"
)
def
test_CopyTo
():
dp
=
gb
.
ItemSampler
(
torch
.
randn
(
20
),
4
)
dp
=
gb
.
ItemSampler
(
gb
.
ItemSet
(
torch
.
randn
(
20
)
)
,
4
)
dp
=
gb
.
CopyTo
(
dp
,
"cuda"
)
for
data
in
dp
:
...
...
tests/python/pytorch/graphbolt/test_item_sampler.py
View file @
29949322
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment