Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d88275ca
Unverified
Commit
d88275ca
authored
Jun 08, 2023
by
Rhett Ying
Committed by
GitHub
Jun 08, 2023
Browse files
[GraphBolt] support DictItemSet in MinibatchSampler (#5803)
parent
945b0e54
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
317 additions
and
18 deletions
+317
-18
python/dgl/graphbolt/minibatch_sampler.py
python/dgl/graphbolt/minibatch_sampler.py
+98
-16
tests/python/pytorch/graphbolt/test_minibatch_sampler.py
tests/python/pytorch/graphbolt/test_minibatch_sampler.py
+219
-2
No files found.
python/dgl/graphbolt/minibatch_sampler.py
View file @
d88275ca
"""Minibatch Sampler"""
"""Minibatch Sampler"""
from
typing
import
Mapping
,
Optional
from
collections.abc
import
Mapping
from
functools
import
partial
from
typing
import
Optional
from
torch.utils.data
import
default_collate
from
torch.utils.data
import
default_collate
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
from
..batch
import
batch
as
dgl_batch
from
..batch
import
batch
as
dgl_batch
from
..heterograph
import
DGLGraph
from
..heterograph
import
DGLGraph
from
.itemset
import
ItemSet
from
.itemset
import
DictItemSet
,
ItemSet
__all__
=
[
"MinibatchSampler"
]
__all__
=
[
"MinibatchSampler"
]
def _collate(batch):
    """Collate a list of sampled items into one batched item.

    Parameters
    ----------
    batch : list
        Items yielded by the upstream batcher. All items are expected to be
        of the same kind (the first item decides the collation strategy).

    Returns
    -------
    object
        - A batch of ``DGLGraph``s is merged with ``dgl.batch``.
        - A batch of mappings (heterogeneous items keyed by node/edge type)
          is collated per key; items that lack a key simply do not
          contribute to that key's output.
        - Anything else is delegated to
          ``torch.utils.data.default_collate``.
    """
    # The first item decides how the whole batch is collated.
    data = next(iter(batch))
    if isinstance(data, DGLGraph):
        return dgl_batch(batch)
    if isinstance(data, Mapping):
        assert len(data) == 1, "Only one type of data is allowed."
        # Collect every key present anywhere in the batch; boundary batches
        # may mix items of different types.
        keys = {key for item in batch for key in item.keys()}
        # Collate each key over only the items that carry it.
        return {
            key: default_collate([item[key] for item in batch if key in item])
            for key in keys
        }
    return default_collate(batch)
class
MinibatchSampler
(
IterDataPipe
):
class
MinibatchSampler
(
IterDataPipe
):
"""Minibatch Sampler.
"""Minibatch Sampler.
...
@@ -36,7 +28,7 @@ class MinibatchSampler(IterDataPipe):
...
@@ -36,7 +28,7 @@ class MinibatchSampler(IterDataPipe):
Parameters
Parameters
----------
----------
item_set : ItemSet
item_set : ItemSet
or DictItemSet
Data to be sampled for mini-batches.
Data to be sampled for mini-batches.
batch_size : int
batch_size : int
The size of each batch.
The size of each batch.
...
@@ -47,7 +39,7 @@ class MinibatchSampler(IterDataPipe):
...
@@ -47,7 +39,7 @@ class MinibatchSampler(IterDataPipe):
Examples
Examples
--------
--------
1. Node
/edge
IDs.
1. Node IDs.
>>> import torch
>>> import torch
>>> from dgl import graphbolt as gb
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> item_set = gb.ItemSet(torch.arange(0, 10))
...
@@ -108,11 +100,77 @@ class MinibatchSampler(IterDataPipe):
...
@@ -108,11 +100,77 @@ class MinibatchSampler(IterDataPipe):
>>> data_pipe = data_pipe.map(add_one)
>>> data_pipe = data_pipe.map(add_one)
>>> list(data_pipe)
>>> list(data_pipe)
[tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]
[tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]
7. Heterogeneous node IDs.
>>> ids = {
... "user": gb.ItemSet(torch.arange(0, 5)),
... "item": gb.ItemSet(torch.arange(0, 6)),
... }
>>> item_set = gb.DictItemSet(ids)
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{'user': tensor([0, 1, 2, 3])},
{'item': tensor([0, 1, 2]), 'user': tensor([4])},
{'item': tensor([3, 4, 5])}]
8. Heterogeneous node pairs.
>>> node_pairs_like = (torch.arange(0, 5), torch.arange(0, 5))
>>> node_pairs_follow = (torch.arange(0, 6), torch.arange(6, 12))
>>> item_set = gb.DictItemSet({
... ("user", "like", "item"): gb.ItemSet(node_pairs_like),
... ("user", "follow", "user"): gb.ItemSet(node_pairs_follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'): [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4])],
('user', 'follow', 'user'): [tensor([0, 1, 2]), tensor([6, 7, 8])]},
{('user', 'follow', 'user'): [tensor([3, 4, 5]), tensor([ 9, 10, 11])]}]
9. Heterogeneous node pairs and labels.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5), torch.arange(0, 5))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12), torch.arange(0, 6))
>>> item_set = gb.DictItemSet({
... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'):
[tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4]), tensor([4])],
('user', 'follow', 'user'):
[tensor([0, 1, 2]), tensor([6, 7, 8]), tensor([0, 1, 2])]},
{('user', 'follow', 'user'):
[tensor([3, 4, 5]), tensor([ 9, 10, 11]), tensor([3, 4, 5])]}]
10. Heterogeneous head, tail and negative tails.
>>> like = (
... torch.arange(0, 5), torch.arange(0, 5),
... torch.arange(5, 15).reshape(-1, 2))
>>> follow = (
... torch.arange(0, 6), torch.arange(6, 12),
... torch.arange(12, 24).reshape(-1, 2))
>>> item_set = gb.DictItemSet({
... ("user", "like", "item"): gb.ItemSet(like),
... ("user", "follow", "user"): gb.ItemSet(follow),
... })
>>> minibatch_sampler = gb.MinibatchSampler(item_set, 4)
>>> list(minibatch_sampler)
[{('user', 'like', 'item'): [tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3]),
tensor([[ 5, 6], [ 7, 8], [ 9, 10], [11, 12]])]},
{('user', 'like', 'item'): [tensor([4]), tensor([4]), tensor([[13, 14]])],
('user', 'follow', 'user'): [tensor([0, 1, 2]), tensor([6, 7, 8]),
tensor([[12, 13], [14, 15], [16, 17]])]},
{('user', 'follow', 'user'): [tensor([3, 4, 5]), tensor([ 9, 10, 11]),
tensor([[18, 19], [20, 21], [22, 23]])]}]
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
item_set
:
ItemSet
,
item_set
:
ItemSet
or
DictItemSet
,
batch_size
:
int
,
batch_size
:
int
,
drop_last
:
Optional
[
bool
]
=
False
,
drop_last
:
Optional
[
bool
]
=
False
,
shuffle
:
Optional
[
bool
]
=
False
,
shuffle
:
Optional
[
bool
]
=
False
,
...
@@ -125,11 +183,35 @@ class MinibatchSampler(IterDataPipe):
...
@@ -125,11 +183,35 @@ class MinibatchSampler(IterDataPipe):
    def __iter__(self):
        """Build and iterate the underlying datapipe.

        Pipeline: wrap the item set -> (optionally) shuffle -> batch ->
        collate, then hand back an iterator over the collated minibatches.
        """
        data_pipe = IterableWrapper(self._item_set)
        # Shuffle before batch.
        if self._shuffle:
            # `torchdata.datapipes.iter.Shuffler` works with stream too.
            data_pipe = data_pipe.shuffle()
        # Batch.
        data_pipe = data_pipe.batch(
            batch_size=self._batch_size,
            drop_last=self._drop_last,
        )

        # Collate.
        def _collate(batch):
            # The first item decides how the whole batch is collated.
            data = next(iter(batch))
            if isinstance(data, DGLGraph):
                return dgl_batch(batch)
            elif isinstance(data, Mapping):
                assert len(data) == 1, "Only one type of data is allowed."
                # Collect all the keys: boundary batches may mix items
                # belonging to different node/edge types.
                keys = {key for item in batch for key in item.keys()}
                # Collate each key over only the items that carry it.
                return {
                    key: default_collate(
                        [item[key] for item in batch if key in item]
                    )
                    for key in keys
                }
            return default_collate(batch)

        # NOTE(review): `partial(_collate)` binds no arguments, so it is a
        # no-op wrapper — passing `_collate` directly would behave the same.
        data_pipe = data_pipe.collate(collate_fn=partial(_collate))

        return iter(data_pipe)
tests/python/pytorch/graphbolt/test_minibatch_sampler.py
View file @
d88275ca
...
@@ -8,8 +8,8 @@ from torch.testing import assert_close
...
@@ -8,8 +8,8 @@ from torch.testing import assert_close
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
def
test_ItemSet_node_
edge_
ids
(
batch_size
,
shuffle
,
drop_last
):
def
test_ItemSet_node_ids
(
batch_size
,
shuffle
,
drop_last
):
# Node
or edge
IDs.
# Node IDs.
num_ids
=
103
num_ids
=
103
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
num_ids
))
item_set
=
gb
.
ItemSet
(
torch
.
arange
(
0
,
num_ids
))
minibatch_sampler
=
gb
.
MinibatchSampler
(
minibatch_sampler
=
gb
.
MinibatchSampler
(
...
@@ -210,3 +210,220 @@ def test_append_with_other_datapipes():
...
@@ -210,3 +210,220 @@ def test_append_with_other_datapipes():
for
i
,
(
idx
,
data
)
in
enumerate
(
data_pipe
):
for
i
,
(
idx
,
data
)
in
enumerate
(
data_pipe
):
assert
i
==
idx
assert
i
==
idx
assert
len
(
data
)
==
batch_size
assert
len
(
data
)
==
batch_size
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_ids(batch_size, shuffle, drop_last):
    """Sample a DictItemSet of node IDs; verify batch sizes and ordering."""
    # Node IDs: 99 "user" IDs followed by 106 "item" IDs (205 total).
    num_ids = 205
    ids = {
        "user": gb.ItemSet(torch.arange(0, 99)),
        "item": gb.ItemSet(torch.arange(99, num_ids)),
    }
    item_set = gb.DictItemSet(ids)
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    # IDs seen so far, grouped per node type.
    ids_per_type = {}
    for i, batch in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            expected_batch_size = num_ids % batch_size
        else:
            # With drop_last=True the incomplete trailing batch must not
            # be yielded at all.
            assert False, "Incomplete trailing batch should have been dropped."
        assert isinstance(batch, dict)
        batch_ids = []
        for key, value in batch.items():
            batch_ids.append(value)
            ids_per_type.setdefault(key, []).append(value)
        assert len(torch.cat(batch_ids)) == expected_batch_size
    # Within each node type, IDs arrive in ascending order iff not shuffled.
    # We compare per type because the key order of a mixed boundary batch is
    # unspecified (the collated dict's keys come from a set), so a cross-type
    # concatenation is not order-comparable. Note `bool(...)`: `torch.all`
    # returns a Tensor, and `tensor is not shuffle` would always be True.
    for type_ids in ids_per_type.values():
        seen = torch.cat(type_ids)
        assert bool(torch.all(seen[:-1] <= seen[1:])) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_pairs(batch_size, shuffle, drop_last):
    """Sample a DictItemSet of node pairs; verify sizes, pairing, ordering."""
    num_ids = 103
    total_ids = 2 * num_ids
    # Two edge types; in both, dst == src + num_ids elementwise.
    node_pairs_0 = (
        torch.arange(0, num_ids),
        torch.arange(num_ids, num_ids * 2),
    )
    node_pairs_1 = (
        torch.arange(num_ids * 2, num_ids * 3),
        torch.arange(num_ids * 3, num_ids * 4),
    )
    node_pairs_dict = {
        ("user", "like", "item"): gb.ItemSet(node_pairs_0),
        ("user", "follow", "user"): gb.ItemSet(node_pairs_1),
    }
    item_set = gb.DictItemSet(node_pairs_dict)
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    # (src, dst) chunks seen so far, grouped per edge type.
    src_per_type = {}
    dst_per_type = {}
    for i, batch in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            expected_batch_size = total_ids % batch_size
        else:
            assert False, "Incomplete trailing batch should have been dropped."
        src = []
        dst = []
        for key, (v_src, v_dst) in batch.items():
            src.append(v_src)
            dst.append(v_dst)
            src_per_type.setdefault(key, []).append(v_src)
            dst_per_type.setdefault(key, []).append(v_dst)
        src = torch.cat(src)
        dst = torch.cat(dst)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        # src/dst pairing must survive batching and shuffling.
        assert torch.equal(src + num_ids, dst)
    # Within each edge type, pairs arrive in ascending order iff not
    # shuffled. Compared per type (mixed boundary batches have unspecified
    # key order) and via `bool(...)` — `torch.all` returns a Tensor, so an
    # `is not shuffle` identity test on it would always pass.
    for key in src_per_type:
        srcs = torch.cat(src_per_type[key])
        dsts = torch.cat(dst_per_type[key])
        assert bool(torch.all(srcs[:-1] <= srcs[1:])) is not shuffle
        assert bool(torch.all(dsts[:-1] <= dsts[1:])) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
    """Sample a DictItemSet of (src, dst, label) triplets; verify them."""
    num_ids = 103
    total_ids = 2 * num_ids
    node_pairs_0 = (
        torch.arange(0, num_ids),
        torch.arange(num_ids, num_ids * 2),
    )
    node_pairs_1 = (
        torch.arange(num_ids * 2, num_ids * 3),
        torch.arange(num_ids * 3, num_ids * 4),
    )
    # Labels are chosen to equal the source IDs of each edge type.
    seed_labels = torch.arange(0, num_ids)
    node_pairs_dict = {
        ("user", "like", "item"): gb.ItemSet(
            (node_pairs_0[0], node_pairs_0[1], seed_labels)
        ),
        ("user", "follow", "user"): gb.ItemSet(
            (node_pairs_1[0], node_pairs_1[1], seed_labels + num_ids * 2)
        ),
    }
    item_set = gb.DictItemSet(node_pairs_dict)
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    # Chunks seen so far, grouped per edge type.
    src_per_type = {}
    dst_per_type = {}
    label_per_type = {}
    for i, batch in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            expected_batch_size = total_ids % batch_size
        else:
            assert False, "Incomplete trailing batch should have been dropped."
        src = []
        dst = []
        label = []
        for key, (v_src, v_dst, v_label) in batch.items():
            src.append(v_src)
            dst.append(v_dst)
            label.append(v_label)
            src_per_type.setdefault(key, []).append(v_src)
            dst_per_type.setdefault(key, []).append(v_dst)
            label_per_type.setdefault(key, []).append(v_label)
        src = torch.cat(src)
        dst = torch.cat(dst)
        label = torch.cat(label)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(label) == expected_batch_size
        # Triplet alignment must survive batching and shuffling.
        assert torch.equal(src + num_ids, dst)
        assert torch.equal(src, label)
    # Within each edge type, items arrive in ascending order iff not
    # shuffled. Compared per type (mixed boundary batches have unspecified
    # key order) and via `bool(...)` — `torch.all` returns a Tensor, so an
    # `is not shuffle` identity test on it would always pass.
    for key in src_per_type:
        for chunks in (
            src_per_type[key],
            dst_per_type[key],
            label_per_type[key],
        ):
            seen = torch.cat(chunks)
            assert bool(torch.all(seen[:-1] <= seen[1:])) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_DictItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
    """Sample a DictItemSet of (head, tail, neg_tails); verify them."""
    num_ids = 103
    total_ids = 2 * num_ids
    num_negs = 2
    heads = torch.arange(0, num_ids)
    tails = torch.arange(num_ids, num_ids * 2)
    # Negative tails: column 0 is head + 1, column 1 is head + 2.
    neg_tails = torch.stack((heads + 1, heads + 2), dim=-1)
    data_dict = {
        ("user", "like", "item"): gb.ItemSet((heads, tails, neg_tails)),
        ("user", "follow", "user"): gb.ItemSet((heads, tails, neg_tails)),
    }
    item_set = gb.DictItemSet(data_dict)
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    # Chunks seen so far, grouped per edge type.
    head_per_type = {}
    tail_per_type = {}
    for i, batch in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        elif not drop_last:
            expected_batch_size = total_ids % batch_size
        else:
            assert False, "Incomplete trailing batch should have been dropped."
        head = []
        tail = []
        negs = []
        for key, (v_head, v_tail, v_negs) in batch.items():
            head.append(v_head)
            tail.append(v_tail)
            negs.append(v_negs)
            head_per_type.setdefault(key, []).append(v_head)
            tail_per_type.setdefault(key, []).append(v_tail)
        head = torch.cat(head)
        tail = torch.cat(tail)
        negs = torch.cat(negs)
        assert len(head) == expected_batch_size
        assert len(tail) == expected_batch_size
        assert len(negs) == expected_batch_size
        assert negs.dim() == 2
        assert negs.shape[0] == expected_batch_size
        assert negs.shape[1] == num_negs
        # Alignment must survive batching and shuffling.
        assert torch.equal(head + num_ids, tail)
        assert torch.equal(head + 1, negs[:, 0])
        assert torch.equal(head + 2, negs[:, 1])
    # Within each edge type, items arrive in ascending order iff not
    # shuffled. Both edge types contain identical head sequences, so a
    # cross-type concatenation is never sorted — ordering can only be
    # checked per type. `bool(...)` is required: `torch.all` returns a
    # Tensor, and `tensor is not shuffle` would always pass.
    for key in head_per_type:
        heads_seen = torch.cat(head_per_type[key])
        tails_seen = torch.cat(tail_per_type[key])
        assert bool(torch.all(heads_seen[:-1] <= heads_seen[1:])) is not shuffle
        assert bool(torch.all(tails_seen[:-1] <= tails_seen[1:])) is not shuffle
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment