Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
b08c446d
Unverified
Commit
b08c446d
authored
Oct 18, 2023
by
Rhett Ying
Committed by
GitHub
Oct 18, 2023
Browse files
[GraphBolt] Improve ItemSampler via indexing (#6453)
parent
5f327ff4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
139 additions
and
6 deletions
+139
-6
python/dgl/graphbolt/item_sampler.py
python/dgl/graphbolt/item_sampler.py
+106
-6
tests/python/pytorch/graphbolt/test_item_sampler.py
tests/python/pytorch/graphbolt/test_item_sampler.py
+33
-0
No files found.
python/dgl/graphbolt/item_sampler.py
View file @
b08c446d
...
...
@@ -4,6 +4,8 @@ from collections.abc import Mapping
from
functools
import
partial
from
typing
import
Callable
,
Iterator
,
Optional
import
numpy
as
np
import
torch
import
torch.distributed
as
dist
from
torch.utils.data
import
default_collate
from
torchdata.datapipes.iter
import
IterableWrapper
,
IterDataPipe
...
...
@@ -77,6 +79,72 @@ def minibatcher_default(batch, names):
return
minibatch
class ItemShufflerAndBatcher:
    """A shuffler to shuffle items and create batches.

    This class is used internally by :class:`ItemSampler` to shuffle items and
    create batches. It is not supposed to be used directly. The intention of
    this class is to avoid time-consuming iteration over :class:`ItemSet`. As
    an optimization, it slices from the :class:`ItemSet` via indexing first,
    then shuffles and creates batches.

    Parameters
    ----------
    item_set : ItemSet
        Data to be iterated.
    shuffle : bool
        Option to shuffle before batching.
    batch_size : int
        The size of each batch.
    drop_last : bool
        Option to drop the last batch if it's not full.
    buffer_size : int, optional
        The size of the buffer to store items sliced from the
        :class:`ItemSet`. If ``None``, the default of ``10 * 1000`` is used.
    """

    def __init__(
        self,
        item_set: ItemSet,
        shuffle: bool,
        batch_size: int,
        drop_last: bool,
        buffer_size: Optional[int] = 10 * 1000,
    ):
        self._item_set = item_set
        self._shuffle = shuffle
        self._batch_size = batch_size
        self._drop_last = drop_last
        # [Fix] The annotation advertises Optional[int], but passing ``None``
        # would raise ``TypeError`` inside ``max(None, ...)``. Fall back to
        # the documented default instead.
        if buffer_size is None:
            buffer_size = 10 * 1000
        # Keep the buffer large enough to cover several batches so the cost
        # of each slice is amortized over many yielded batches.
        self._buffer_size = max(buffer_size, 20 * batch_size)
        # Round up the buffer size to the nearest multiple of batch size.
        self._buffer_size = (
            (self._buffer_size + batch_size - 1) // batch_size * batch_size
        )

    def __iter__(self):
        buffer = None
        num_items = len(self._item_set)
        start = 0
        while start < num_items:
            end = min(start + self._buffer_size, num_items)
            # Slice a chunk from the item set via indexing rather than
            # iterating item by item.
            buffer = self._item_set[start:end]
            indices = torch.arange(end - start)
            if self._shuffle:
                # Shuffle in place through the numpy view, which shares
                # memory with the torch tensor.
                np.random.shuffle(indices.numpy())
            for i in range(0, len(indices), self._batch_size):
                if self._drop_last and i + self._batch_size > len(indices):
                    break
                batch_indices = indices[i : i + self._batch_size]
                if len(self._item_set._items) == 1:
                    if isinstance(buffer[0], DGLGraph):
                        # Graphs cannot be fancy-indexed; gather then batch.
                        yield dgl_batch(
                            [buffer[idx] for idx in batch_indices]
                        )
                    else:
                        yield buffer[batch_indices]
                else:
                    yield tuple(item[batch_indices] for item in buffer)
            # Release the buffer before slicing the next chunk to keep peak
            # memory bounded by a single buffer.
            buffer = None
            start = end
class
ItemSampler
(
IterDataPipe
):
"""A sampler to iterate over input items and create subsets.
...
...
@@ -287,14 +355,28 @@ class ItemSampler(IterDataPipe):
minibatcher
:
Optional
[
Callable
]
=
minibatcher_default
,
drop_last
:
Optional
[
bool
]
=
False
,
shuffle
:
Optional
[
bool
]
=
False
,
# [TODO][Rui] For now, it's a temporary knob to disable indexing. In
# the future, we will enable indexing for all the item sets.
use_indexing
:
Optional
[
bool
]
=
True
,
)
->
None
:
super
().
__init__
()
self
.
_names
=
item_set
.
names
self
.
_item_set
=
IterableWrapper
(
item_set
)
# Check if the item set supports indexing.
try
:
item_set
[
0
]
except
TypeError
:
use_indexing
=
False
# [TODO][Rui] For now, we disable indexing for ItemSetDict.
use_indexing
=
(
not
isinstance
(
item_set
,
ItemSetDict
))
and
use_indexing
self
.
_use_indexing
=
use_indexing
self
.
_item_set
=
(
item_set
if
self
.
_use_indexing
else
IterableWrapper
(
item_set
)
)
self
.
_batch_size
=
batch_size
self
.
_minibatcher
=
minibatcher
self
.
_drop_last
=
drop_last
self
.
_shuffle
=
shuffle
self
.
_use_indexing
=
use_indexing
def
_organize_items
(
self
,
data_pipe
)
->
None
:
# Shuffle before batch.
...
...
@@ -333,6 +415,16 @@ class ItemSampler(IterDataPipe):
return
default_collate
(
batch
)
def
__iter__
(
self
)
->
Iterator
:
if
self
.
_use_indexing
:
data_pipe
=
IterableWrapper
(
ItemShufflerAndBatcher
(
self
.
_item_set
,
self
.
_shuffle
,
self
.
_batch_size
,
self
.
_drop_last
,
)
)
else
:
# Organize items.
data_pipe
=
self
.
_organize_items
(
self
.
_item_set
)
...
...
@@ -504,7 +596,15 @@ class DistributedItemSampler(ItemSampler):
num_replicas
:
Optional
[
int
]
=
None
,
drop_uneven_inputs
:
Optional
[
bool
]
=
False
,
)
->
None
:
super
().
__init__
(
item_set
,
batch_size
,
minibatcher
,
drop_last
,
shuffle
)
# [TODO][Rui] For now, always set use_indexing to False.
super
().
__init__
(
item_set
,
batch_size
,
minibatcher
,
drop_last
,
shuffle
,
use_indexing
=
False
,
)
self
.
_drop_uneven_inputs
=
drop_uneven_inputs
# Apply a sharding filter to distribute the items.
self
.
_item_set
=
self
.
_item_set
.
sharding_filter
()
...
...
tests/python/pytorch/graphbolt/test_item_sampler.py
View file @
b08c446d
...
...
@@ -65,6 +65,39 @@ def test_ItemSampler_minibatcher():
assert
len
(
minibatch
.
seed_nodes
)
==
4
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
    """Sample from an ItemSet that is iterable-only (no ``__len__``).

    Such an item set cannot be indexed, so the sampler is expected to fall
    back to the iteration-based path and still produce correct batches.
    """
    num_ids = 103

    # Iterable-only wrapper: defines ``__iter__`` but no ``__len__`` and no
    # ``__getitem__``, so indexing into it is impossible.
    class InvalidLength:
        def __iter__(self):
            return iter(torch.arange(0, num_ids))

    seed_nodes = gb.ItemSet(InvalidLength())
    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seed_nodes is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            # Every full batch has exactly `batch_size` items.
            assert len(minibatch.seed_nodes) == batch_size
        else:
            if not drop_last:
                # The final partial batch holds the remainder.
                assert len(minibatch.seed_nodes) == num_ids % batch_size
            else:
                # With drop_last, a partial batch must never be yielded.
                assert False
        minibatch_ids.append(minibatch.seed_nodes)
    minibatch_ids = torch.cat(minibatch_ids)
    # [Fix] The original `torch.all(...) is not shuffle` compared a 0-dim
    # Tensor to a Python bool by identity, which is always True, so the
    # ordering check could never fail. Compare the boolean value instead:
    # the concatenated ids are sorted iff shuffling is disabled.
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]).item() == (
        not shuffle
    )
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"shuffle"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"drop_last"
,
[
True
,
False
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment