Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
a1051f00
Unverified
Commit
a1051f00
authored
Jun 07, 2023
by
Rhett Ying
Committed by
GitHub
Jun 07, 2023
Browse files
[GraphBolt] add MinibatchSampler which supports ItemSet (#5793)
parent
b56552ff
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
348 additions
and
0 deletions
+348
-0
python/dgl/graphbolt/__init__.py
python/dgl/graphbolt/__init__.py
+1
-0
python/dgl/graphbolt/minibatch_sampler.py
python/dgl/graphbolt/minibatch_sampler.py
+135
-0
tests/python/pytorch/graphbolt/test_minibatch_sampler.py
tests/python/pytorch/graphbolt/test_minibatch_sampler.py
+212
-0
No files found.
python/dgl/graphbolt/__init__.py
View file @
a1051f00
...
...
@@ -7,6 +7,7 @@ import torch
from
.._ffi
import
libinfo
from
.graph_storage
import
*
from
.itemset
import
*
from
.minibatch_sampler
import
*
def
load_graphbolt
():
...
...
python/dgl/graphbolt/minibatch_sampler.py
0 → 100644
View file @
a1051f00
"""Minibatch Sampler"""
from
typing
import
Mapping
,
Optional
from
torch.utils.data
import
default_collate
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
from
..batch
import
batch
as
dgl_batch
from
..heterograph
import
DGLGraph
from
.itemset
import
ItemSet
__all__
=
[
"MinibatchSampler"
]
def _collate(batch):
    """Collate a list of samples into a single batch.

    DGLGraphs are merged with ``dgl.batch``; Mapping samples are not yet
    supported; anything else falls through to torch's ``default_collate``.
    """
    first = next(iter(batch))
    if isinstance(first, DGLGraph):
        return dgl_batch(batch)
    if isinstance(first, Mapping):
        raise NotImplementedError
    return default_collate(batch)
class MinibatchSampler(IterDataPipe):
    """Sampler that yields mini-batches drawn from an ``ItemSet``.

    The produced mini-batches may contain node/edge IDs, node pairs with or
    without labels, head/tail/negative_tails tuples, or DGLGraphs, including
    their heterogeneous counterparts.

    Note: This class `MinibatchSampler` is not decorated with
    `torchdata.datapipes.functional_datapipe` on purpose. This indicates it
    does not support function-like call. But any iterable datapipes from
    `torchdata` can be further appended.

    Parameters
    ----------
    item_set : ItemSet
        Data to be sampled for mini-batches.
    batch_size : int
        The size of each batch.
    drop_last : bool
        Option to drop the last batch if it's not full.
    shuffle : bool
        Option to shuffle before sample.

    Examples
    --------
    1. Node/edge IDs.
    >>> import torch
    >>> from dgl import graphbolt as gb
    >>> item_set = gb.ItemSet(torch.arange(0, 10))
    >>> minibatch_sampler = gb.MinibatchSampler(
    ...     item_set, batch_size=4, shuffle=True, drop_last=False
    ... )
    >>> list(minibatch_sampler)
    [tensor([1, 2, 5, 7]), tensor([3, 0, 9, 4]), tensor([6, 8])]

    2. Node pairs.
    >>> item_set = gb.ItemSet((torch.arange(0, 10), torch.arange(10, 20)))
    >>> minibatch_sampler = gb.MinibatchSampler(
    ...     item_set, batch_size=4, shuffle=True, drop_last=False
    ... )
    >>> list(minibatch_sampler)
    [[tensor([9, 8, 3, 1]), tensor([19, 18, 13, 11])], [tensor([2, 5, 7, 4]),
    tensor([12, 15, 17, 14])], [tensor([0, 6]), tensor([10, 16])]

    3. Node pairs and labels.
    >>> item_set = gb.ItemSet(
    ...     (torch.arange(0, 5), torch.arange(5, 10), torch.arange(10, 15))
    ... )
    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
    >>> list(minibatch_sampler)
    [[tensor([0, 1, 2]), tensor([5, 6, 7]), tensor([10, 11, 12])],
    [tensor([3, 4]), tensor([8, 9]), tensor([13, 14])]]

    4. Head, tail and negative tails.
    >>> heads = torch.arange(0, 5)
    >>> tails = torch.arange(5, 10)
    >>> negative_tails = torch.stack((heads + 1, heads + 2), dim=-1)
    >>> item_set = gb.ItemSet((heads, tails, negative_tails))
    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
    >>> list(minibatch_sampler)
    [[tensor([0, 1, 2]), tensor([5, 6, 7]),
    tensor([[1, 2], [2, 3], [3, 4]])],
    [tensor([3, 4]), tensor([8, 9]), tensor([[4, 5], [5, 6]])]]

    5. DGLGraphs.
    >>> import dgl
    >>> graphs = [ dgl.rand_graph(10, 20) for _ in range(5) ]
    >>> item_set = gb.ItemSet(graphs)
    >>> minibatch_sampler = gb.MinibatchSampler(item_set, 3)
    >>> list(minibatch_sampler)
    [Graph(num_nodes=30, num_edges=60,
    ndata_schemes={}
    edata_schemes={}),
    Graph(num_nodes=20, num_edges=40,
    ndata_schemes={}
    edata_schemes={})]

    6. Further process batches with other datapipes such as
    `torchdata.datapipes.iter.Mapper`.
    >>> item_set = gb.ItemSet(torch.arange(0, 10))
    >>> data_pipe = gb.MinibatchSampler(item_set, 4)
    >>> def add_one(batch):
    ...     return batch + 1
    >>> data_pipe = data_pipe.map(add_one)
    >>> list(data_pipe)
    [tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]
    """

    def __init__(
        self,
        item_set: ItemSet,
        batch_size: int,
        drop_last: Optional[bool] = False,
        shuffle: Optional[bool] = False,
    ):
        super().__init__()
        # Stash the configuration; the datapipe chain is built lazily in
        # ``__iter__`` so each epoch gets a fresh (re-)shuffled pipeline.
        self._item_set = item_set
        self._batch_size = batch_size
        self._drop_last = drop_last
        self._shuffle = shuffle

    def __iter__(self):
        pipe = IterableWrapper(self._item_set)
        if self._shuffle:
            # `torchdata.datapipes.iter.Shuffler` works with stream too.
            pipe = pipe.shuffle()
        pipe = pipe.batch(
            batch_size=self._batch_size,
            drop_last=self._drop_last,
        )
        pipe = pipe.collate(collate_fn=_collate)
        yield from pipe
tests/python/pytorch/graphbolt/test_minibatch_sampler.py
0 → 100644
View file @
a1051f00
import
dgl
import
pytest
import
torch
from
dgl
import
graphbolt
as
gb
from
torch.testing
import
assert_close
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_edge_ids(batch_size, shuffle, drop_last):
    """Sample node/edge IDs and verify batch sizes and element ordering."""
    # Node or edge IDs.
    num_ids = 103
    item_set = gb.ItemSet(torch.arange(0, num_ids))
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch) == batch_size
        else:
            if not drop_last:
                assert len(minibatch) == num_ids % batch_size
            else:
                # A partial batch must never be yielded when drop_last=True.
                assert False
        minibatch_ids.append(minibatch)
    minibatch_ids = torch.cat(minibatch_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor, so `... is not shuffle` was
    # vacuously true (a Tensor is never the Python bool object). Convert to
    # bool and compare so the ordering check actually tests something.
    is_sorted = bool(torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]))
    assert is_sorted == (not shuffle)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_graphs(batch_size, shuffle, drop_last):
    """Sample DGLGraphs and verify batched graph sizes and ordering."""
    # Graphs with strictly increasing sizes so order is detectable.
    num_graphs = 103
    num_nodes = 10
    num_edges = 20
    graphs = [
        dgl.rand_graph(num_nodes * (i + 1), num_edges * (i + 1))
        for i in range(num_graphs)
    ]
    item_set = gb.ItemSet(graphs)
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_num_nodes = []
    minibatch_num_edges = []
    for i, minibatch in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= num_graphs
        if not is_last or num_graphs % batch_size == 0:
            assert minibatch.batch_size == batch_size
        else:
            if not drop_last:
                assert minibatch.batch_size == num_graphs % batch_size
            else:
                # A partial batch must never be yielded when drop_last=True.
                assert False
        minibatch_num_nodes.append(minibatch.batch_num_nodes())
        minibatch_num_edges.append(minibatch.batch_num_edges())
    minibatch_num_nodes = torch.cat(minibatch_num_nodes)
    minibatch_num_edges = torch.cat(minibatch_num_edges)
    # BUGFIX: `torch.all(...)` returns a Tensor, so the former
    # `torch.all(...) is not shuffle` comparison was vacuously true.
    # Convert to bool so the ordering check is meaningful.
    nodes_sorted = bool(
        torch.all(minibatch_num_nodes[:-1] <= minibatch_num_nodes[1:])
    )
    edges_sorted = bool(
        torch.all(minibatch_num_edges[:-1] <= minibatch_num_edges[1:])
    )
    assert nodes_sorted == (not shuffle)
    assert edges_sorted == (not shuffle)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
    """Sample (src, dst) node pairs and verify size, content and ordering."""
    # Node pairs.
    num_ids = 103
    node_pairs = (
        torch.arange(0, num_ids),
        torch.arange(num_ids, num_ids * 2),
    )
    item_set = gb.ItemSet(node_pairs)
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    for i, (src, dst) in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # A partial batch must never be yielded when drop_last=True.
                assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        # Verify src and dst IDs match.
        assert torch.equal(src + num_ids, dst)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor, so the former
    # `torch.all(...) is not shuffle` comparison was vacuously true.
    # Convert to bool so the ordering check is meaningful.
    assert bool(torch.all(src_ids[:-1] <= src_ids[1:])) == (not shuffle)
    assert bool(torch.all(dst_ids[:-1] <= dst_ids[1:])) == (not shuffle)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
    """Sample (src, dst, label) triplets and verify size, content, ordering."""
    # Node pairs and labels.
    num_ids = 103
    node_pairs = (
        torch.arange(0, num_ids),
        torch.arange(num_ids, num_ids * 2),
    )
    labels = torch.arange(0, num_ids)
    item_set = gb.ItemSet((node_pairs[0], node_pairs[1], labels))
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    # Renamed from `labels` to avoid shadowing the input tensor above.
    label_ids = []
    for i, (src, dst, label) in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # A partial batch must never be yielded when drop_last=True.
                assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(label) == expected_batch_size
        # Verify src/dst IDs and labels match.
        assert torch.equal(src + num_ids, dst)
        assert torch.equal(src, label)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
        label_ids.append(label)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    label_ids = torch.cat(label_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor, so the former
    # `torch.all(...) is not shuffle` comparison was vacuously true.
    # Convert to bool so the ordering check is meaningful.
    for ids in (src_ids, dst_ids, label_ids):
        assert bool(torch.all(ids[:-1] <= ids[1:])) == (not shuffle)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_head_tail_neg_tails(batch_size, shuffle, drop_last):
    """Sample (head, tail, negative_tails) and verify size, content, order."""
    # Head, tail and negative tails.
    num_ids = 103
    num_negs = 2
    heads = torch.arange(0, num_ids)
    tails = torch.arange(num_ids, num_ids * 2)
    neg_tails = torch.stack((heads + 1, heads + 2), dim=-1)
    item_set = gb.ItemSet((heads, tails, neg_tails))
    # Sanity-check direct iteration over the item set before sampling.
    for i, (head, tail, negs) in enumerate(item_set):
        assert heads[i] == head
        assert tails[i] == tail
        assert torch.equal(neg_tails[i], negs)
    minibatch_sampler = gb.MinibatchSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    head_ids = []
    tail_ids = []
    negs_ids = []
    for i, (head, tail, negs) in enumerate(minibatch_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                # A partial batch must never be yielded when drop_last=True.
                assert False
        assert len(head) == expected_batch_size
        assert len(tail) == expected_batch_size
        assert negs.dim() == 2
        assert negs.shape[0] == expected_batch_size
        assert negs.shape[1] == num_negs
        # Verify head/tail and negative tails match.
        assert torch.equal(head + num_ids, tail)
        assert torch.equal(head + 1, negs[:, 0])
        assert torch.equal(head + 2, negs[:, 1])
        # Archive batch.
        head_ids.append(head)
        tail_ids.append(tail)
        negs_ids.append(negs)
    head_ids = torch.cat(head_ids)
    tail_ids = torch.cat(tail_ids)
    negs_ids = torch.cat(negs_ids)
    # BUGFIX: `torch.all(...)` returns a Tensor, so the former
    # `torch.all(...) is not shuffle` comparison was vacuously true.
    # Convert to bool so the ordering check is meaningful.
    expected_sorted = not shuffle
    assert bool(torch.all(head_ids[:-1] <= head_ids[1:])) == expected_sorted
    assert bool(torch.all(tail_ids[:-1] <= tail_ids[1:])) == expected_sorted
    assert (
        bool(torch.all(negs_ids[:-1, 0] <= negs_ids[1:, 0])) == expected_sorted
    )
    assert (
        bool(torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1])) == expected_sorted
    )
def test_append_with_other_datapipes():
    """MinibatchSampler composes with downstream torchdata datapipes."""
    num_ids = 100
    batch_size = 4
    item_set = gb.ItemSet(torch.arange(0, num_ids))
    data_pipe = gb.MinibatchSampler(item_set, batch_size)
    # Append a torchdata.datapipes.iter.Enumerator via its functional form.
    data_pipe = data_pipe.enumerate()
    for expected_idx, (idx, batch) in enumerate(data_pipe):
        assert idx == expected_idx
        assert len(batch) == batch_size
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment