OpenDAS / dgl · Commit b08c446d (unverified)

[GraphBolt] Improve ItemSampler via indexing (#6453)

Authored Oct 18, 2023 by Rhett Ying; committed by GitHub on Oct 18, 2023.
Parent: 5f327ff4
Showing 2 changed files, with 139 additions and 6 deletions:

- python/dgl/graphbolt/item_sampler.py (+106, -6)
- tests/python/pytorch/graphbolt/test_item_sampler.py (+33, -0)
python/dgl/graphbolt/item_sampler.py
```diff
@@ -4,6 +4,8 @@ from collections.abc import Mapping
 from functools import partial
 from typing import Callable, Iterator, Optional
 
+import numpy as np
+import torch
 import torch.distributed as dist
 from torch.utils.data import default_collate
 from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
```
```diff
@@ -77,6 +79,72 @@ def minibatcher_default(batch, names):
     return minibatch
 
 
+class ItemShufflerAndBatcher:
+    """A shuffler to shuffle items and create batches.
+
+    This class is used internally by :class:`ItemSampler` to shuffle items
+    and create batches. It is not supposed to be used directly. The intention
+    of this class is to avoid time-consuming iteration over :class:`ItemSet`.
+    As an optimization, it slices from the :class:`ItemSet` via indexing
+    first, then shuffles and creates batches.
+
+    Parameters
+    ----------
+    item_set : ItemSet
+        Data to be iterated.
+    shuffle : bool
+        Option to shuffle before batching.
+    batch_size : int
+        The size of each batch.
+    drop_last : bool
+        Option to drop the last batch if it's not full.
+    buffer_size : int
+        The size of the buffer to store items sliced from the
+        :class:`ItemSet`.
+    """
+
+    def __init__(
+        self,
+        item_set: ItemSet,
+        shuffle: bool,
+        batch_size: int,
+        drop_last: bool,
+        buffer_size: Optional[int] = 10 * 1000,
+    ):
+        self._item_set = item_set
+        self._shuffle = shuffle
+        self._batch_size = batch_size
+        self._drop_last = drop_last
+        self._buffer_size = max(buffer_size, 20 * batch_size)
+        # Round up the buffer size to the nearest multiple of batch size.
+        self._buffer_size = (
+            (self._buffer_size + batch_size - 1) // batch_size * batch_size
+        )
+
+    def __iter__(self):
+        buffer = None
+        num_items = len(self._item_set)
+        start = 0
+        while start < num_items:
+            end = min(start + self._buffer_size, num_items)
+            buffer = self._item_set[start:end]
+            indices = torch.arange(end - start)
+            if self._shuffle:
+                np.random.shuffle(indices.numpy())
+            for i in range(0, len(indices), self._batch_size):
+                if self._drop_last and i + self._batch_size > len(indices):
+                    break
+                batch_indices = indices[i : i + self._batch_size]
+                if len(self._item_set._items) == 1:
+                    if isinstance(buffer[0], DGLGraph):
+                        yield dgl_batch(
+                            [buffer[idx] for idx in batch_indices]
+                        )
+                    else:
+                        yield buffer[batch_indices]
+                else:
+                    yield tuple(item[batch_indices] for item in buffer)
+            buffer = None
+            start = end
+
+
 class ItemSampler(IterDataPipe):
     """A sampler to iterate over input items and create subsets.
```
```diff
@@ -287,14 +355,28 @@ class ItemSampler(IterDataPipe):
         minibatcher: Optional[Callable] = minibatcher_default,
         drop_last: Optional[bool] = False,
         shuffle: Optional[bool] = False,
+        # [TODO][Rui] For now, it's a temporary knob to disable indexing. In
+        # the future, we will enable indexing for all the item sets.
+        use_indexing: Optional[bool] = True,
     ) -> None:
         super().__init__()
         self._names = item_set.names
-        self._item_set = IterableWrapper(item_set)
+        # Check if the item set supports indexing.
+        try:
+            item_set[0]
+        except TypeError:
+            use_indexing = False
+        # [TODO][Rui] For now, we disable indexing for ItemSetDict.
+        use_indexing = (not isinstance(item_set, ItemSetDict)) and use_indexing
+        self._use_indexing = use_indexing
+        self._item_set = (
+            item_set if self._use_indexing else IterableWrapper(item_set)
+        )
         self._batch_size = batch_size
         self._minibatcher = minibatcher
         self._drop_last = drop_last
         self._shuffle = shuffle
+        self._use_indexing = use_indexing
 
     def _organize_items(self, data_pipe) -> None:
         # Shuffle before batch.
```
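The `try`/`except` above is a capability probe: subscripting an iterable-only item set raises `TypeError`, which flips the sampler back to the datapipe iteration path. The same pattern in isolation (class and function names here are illustrative, not DGL code):

```python
class IterableOnly:
    """An object that supports iteration but not indexing."""

    def __iter__(self):
        return iter(range(10))


def supports_indexing(obj) -> bool:
    # Probe __getitem__ the way ItemSampler.__init__ does; objects without
    # it raise TypeError on obj[0].
    try:
        obj[0]
    except TypeError:
        return False
    return True


assert supports_indexing(list(range(10)))
assert not supports_indexing(IterableOnly())
```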
```diff
@@ -333,11 +415,21 @@ class ItemSampler(IterDataPipe):
         return default_collate(batch)
 
     def __iter__(self) -> Iterator:
-        # Organize items.
-        data_pipe = self._organize_items(self._item_set)
+        if self._use_indexing:
+            data_pipe = IterableWrapper(
+                ItemShufflerAndBatcher(
+                    self._item_set,
+                    self._shuffle,
+                    self._batch_size,
+                    self._drop_last,
+                )
+            )
+        else:
+            # Organize items.
+            data_pipe = self._organize_items(self._item_set)
         # Collate.
         data_pipe = data_pipe.collate(collate_fn=self._collate)
         # Map to minibatch.
         data_pipe = data_pipe.map(partial(self._minibatcher, names=self._names))
```
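With indexing enabled (now the default for a plain `ItemSet`), the shuffle and batch stages of the datapipe are replaced by a single `ItemShufflerAndBatcher` wrapped back into a datapipe; collation and minibatching are unchanged. A usage sketch mirroring the tests in this commit (sizes are illustrative):

```python
import dgl.graphbolt as gb
import torch

item_set = gb.ItemSet(torch.arange(0, 1024), names="seed_nodes")
item_sampler = gb.ItemSampler(
    item_set, batch_size=128, shuffle=True, drop_last=False
)
for minibatch in item_sampler:
    # 1024 divides evenly by 128, so every batch is full.
    assert isinstance(minibatch, gb.MiniBatch)
    assert len(minibatch.seed_nodes) == 128
```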
```diff
@@ -504,7 +596,15 @@ class DistributedItemSampler(ItemSampler):
         num_replicas: Optional[int] = None,
         drop_uneven_inputs: Optional[bool] = False,
     ) -> None:
-        super().__init__(item_set, batch_size, minibatcher, drop_last, shuffle)
+        # [TODO][Rui] For now, always set use_indexing to False.
+        super().__init__(
+            item_set,
+            batch_size,
+            minibatcher,
+            drop_last,
+            shuffle,
+            use_indexing=False,
+        )
         self._drop_uneven_inputs = drop_uneven_inputs
         # Apply a sharding filter to distribute the items.
         self._item_set = self._item_set.sharding_filter()
```
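`DistributedItemSampler` keeps the iteration path for now, since its sharding filter operates on the wrapped datapipe. A hypothetical multi-rank sketch, assuming the default process group is already initialized by the launcher (e.g. `torchrun`):

```python
import dgl.graphbolt as gb
import torch

# Run under torchrun so torch.distributed is initialized on every rank.
item_set = gb.ItemSet(torch.arange(0, 1024), names="seed_nodes")
item_sampler = gb.DistributedItemSampler(
    item_set, batch_size=128, shuffle=True, drop_uneven_inputs=False
)
for minibatch in item_sampler:
    # Each rank receives a disjoint shard of the 1024 items.
    assert isinstance(minibatch, gb.MiniBatch)
```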
tests/python/pytorch/graphbolt/test_item_sampler.py
```diff
@@ -65,6 +65,39 @@ def test_ItemSampler_minibatcher():
     assert len(minibatch.seed_nodes) == 4
 
 
+@pytest.mark.parametrize("batch_size", [1, 4])
+@pytest.mark.parametrize("shuffle", [True, False])
+@pytest.mark.parametrize("drop_last", [True, False])
+def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
+    num_ids = 103
+
+    class InvalidLength:
+        def __iter__(self):
+            return iter(torch.arange(0, num_ids))
+
+    seed_nodes = gb.ItemSet(InvalidLength())
+    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
+    item_sampler = gb.ItemSampler(
+        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
+    )
+    minibatch_ids = []
+    for i, minibatch in enumerate(item_sampler):
+        assert isinstance(minibatch, gb.MiniBatch)
+        assert minibatch.seed_nodes is not None
+        assert minibatch.labels is None
+        is_last = (i + 1) * batch_size >= num_ids
+        if not is_last or num_ids % batch_size == 0:
+            assert len(minibatch.seed_nodes) == batch_size
+        else:
+            if not drop_last:
+                assert len(minibatch.seed_nodes) == num_ids % batch_size
+            else:
+                assert False
+        minibatch_ids.append(minibatch.seed_nodes)
+    minibatch_ids = torch.cat(minibatch_ids)
+    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
+
+
 @pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize("shuffle", [True, False])
 @pytest.mark.parametrize("drop_last", [True, False])
```
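The closing assertion uses sortedness of the concatenated IDs as a proxy for whether shuffling occurred, since the source yields 0..102 in order. That check in isolation (a sketch, assuming distinct ordered inputs):

```python
import torch


def is_sorted(t: torch.Tensor) -> bool:
    # True iff every element is <= its successor.
    return bool(torch.all(t[:-1] <= t[1:]))


ids = torch.arange(103)
assert is_sorted(ids)
# A random permutation of 103 distinct values is unsorted with
# overwhelming probability (1 - 1/103!).
assert not is_sorted(ids[torch.randperm(103)])
```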