OpenDAS / Fairseq / Commits / 0a7f9e64

Commit 0a7f9e64, authored Aug 31, 2018 by Myle Ott

Further generalize EpochBatchIterator and move iterators into new file

Parent: 75f6ba05

Showing 6 changed files with 195 additions and 160 deletions (+195, -160)
Changed files:
  fairseq/data/__init__.py        +1    -1
  fairseq/data/data_utils.py      +0    -153
  fairseq/data/iterators.py       +185  -0
  fairseq/tasks/fairseq_task.py   +3    -2
  tests/test_iterators.py         +3    -3
  tests/test_train.py             +3    -1
fairseq/data/__init__.py

@@ -12,4 +12,4 @@
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .token_block_dataset import TokenBlockDataset
-from .data_utils import EpochBatchIterator
+from .iterators import EpochBatchIterator
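Because __init__.py now re-exports EpochBatchIterator from the new iterators module, both import paths below resolve to the same class. A minimal sketch (the ReExported alias exists only for the illustration):

    from fairseq.data import EpochBatchIterator as ReExported   # via fairseq/data/__init__.py
    from fairseq.data.iterators import EpochBatchIterator       # new canonical location

    assert ReExported is EpochBatchIterator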
fairseq/data/data_utils.py

@@ -6,11 +6,9 @@
 # can be found in the PATENTS file in the same directory.
 
 import contextlib
-import itertools
 import os
 
 import numpy as np
-import torch
 
 
 def infer_language_pair(path):

@@ -23,60 +21,6 @@ def infer_language_pair(path):
     return src, dst
 
 
-class ShardedIterator(object):
-    """A sharded wrapper around an iterable (padded to length)."""
-
-    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
-        if shard_id < 0 or shard_id >= num_shards:
-            raise ValueError('shard_id must be between 0 and num_shards')
-
-        self._sharded_len = len(iterable) // num_shards
-        if len(iterable) % num_shards > 0:
-            self._sharded_len += 1
-
-        self.itr = itertools.zip_longest(
-            range(self._sharded_len),
-            itertools.islice(iterable, shard_id, len(iterable), num_shards),
-            fillvalue=fill_value,
-        )
-
-    def __len__(self):
-        return self._sharded_len
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        return next(self.itr)[1]
-
-
-class CountingIterator(object):
-    """Wrapper around an iterable that maintains the iteration count."""
-
-    def __init__(self, iterable):
-        self.iterable = iterable
-        self.count = 0
-        self.itr = iter(self)
-
-    def __len__(self):
-        return len(self.iterable)
-
-    def __iter__(self):
-        for x in self.iterable:
-            self.count += 1
-            yield x
-
-    def __next__(self):
-        return next(self.itr)
-
-    def has_next(self):
-        return self.count < len(self)
-
-    def skip(self, num_to_skip):
-        next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
-        return self
-
-
 def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
     """Convert a list of 1d tensors into a padded 2d tensor."""
     size = max(v.size(0) for v in values)

@@ -96,103 +40,6 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=Fal
     return res
 
 
-class EpochBatchIterator(object):
-    """A multi-epoch iterator over a :class:`~torch.utils.data.Dataset`.
-
-    Compared to :class:`~torch.utils.data.DataLoader`, this iterator:
-    - can be reused across multiple epochs with the :func:`next_epoch_itr`
-      method (optionally shuffled between epochs)
-    - can be serialized/deserialized with the :func:`state_dict` and
-      :func:`load_state_dict` methods
-    - supports sharding with the ``num_shards`` and ``shard_id`` arguments
-
-    Args:
-        dataset (Dataset): dataset from which to load the data
-        batch_sampler (Sampler): an iterator over batches of indices
-        seed (int, optional): seed for random number generator for
-            reproducibility. Default: ``1``
-        num_shards (int, optional): shard the data iterator into N
-            shards. Default: ``1``
-        shard_id (int, optional): which shard of the data iterator to
-            return. Default: ``0``
-    """
-
-    def __init__(self, dataset, batch_sampler, seed=1, num_shards=1, shard_id=0):
-        assert isinstance(dataset, torch.utils.data.Dataset)
-        self.dataset = dataset
-        self.frozen_batches = tuple(batch_sampler)
-        self.seed = seed
-        self.num_shards = num_shards
-        self.shard_id = shard_id
-
-        self.epoch = 0
-        self._cur_epoch_itr = None
-        self._next_epoch_itr = None
-
-    def __len__(self):
-        return len(self.frozen_batches)
-
-    def next_epoch_itr(self, shuffle=True):
-        """
-        Return a new iterator over the dataset.
-
-        Args:
-            shuffle (bool, optional): shuffle batches before returning the
-                iterator. Default: ``True``
-        """
-        if self._next_epoch_itr is not None:
-            self._cur_epoch_itr = self._next_epoch_itr
-            self._next_epoch_itr = None
-        else:
-            self.epoch += 1
-            self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
-        return self._cur_epoch_itr
-
-    def end_of_epoch(self):
-        """Returns whether the most recent epoch iterator has been exhausted"""
-        return not self._cur_epoch_itr.has_next()
-
-    @property
-    def iterations_in_epoch(self):
-        """The number of consumed batches in the current epoch."""
-        if self._cur_epoch_itr is not None:
-            return self._cur_epoch_itr.count
-        elif self._next_epoch_itr is not None:
-            return self._next_epoch_itr.count
-        return 0
-
-    def state_dict(self):
-        return {
-            'epoch': self.epoch,
-            'iterations_in_epoch': self.iterations_in_epoch,
-        }
-
-    def load_state_dict(self, state_dict):
-        self.epoch = state_dict['epoch']
-        itr_pos = state_dict.get('iterations_in_epoch', 0)
-        if itr_pos > 0:
-            # fast-forward epoch iterator
-            itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
-            if itr_pos < len(itr):
-                self._next_epoch_itr = itr.skip(itr_pos)
-
-    def _get_iterator_for_epoch(self, epoch, shuffle):
-        if shuffle:
-            # set seed based on the seed and epoch number so that we get
-            # reproducible results when resuming from checkpoints
-            with numpy_seed(self.seed + epoch):
-                batches = list(self.frozen_batches)  # copy
-                np.random.shuffle(batches)
-        else:
-            batches = self.frozen_batches
-        return CountingIterator(torch.utils.data.DataLoader(
-            self.dataset,
-            collate_fn=self.dataset.collater,
-            batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
-        ))
-
-
 @contextlib.contextmanager
 def numpy_seed(seed):
     """Context manager which seeds the NumPy PRNG with the specified seed and
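The numpy_seed context manager left behind in data_utils.py is what _get_iterator_for_epoch relies on for reproducible shuffling. A minimal sketch of that pattern, assuming only that numpy_seed seeds and then restores the NumPy PRNG state as its docstring begins to describe (the toy seed and epoch values are illustrative):

    import numpy as np
    from fairseq.data import data_utils

    seed, epoch = 1, 3
    order_a = list(range(10))
    order_b = list(range(10))

    with data_utils.numpy_seed(seed + epoch):
        np.random.shuffle(order_a)
    with data_utils.numpy_seed(seed + epoch):
        np.random.shuffle(order_b)

    assert order_a == order_b   # the same seed + epoch always yields the same batch order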
fairseq/data/iterators.py (new file, 0 → 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import itertools

import numpy as np
import torch

from . import data_utils


class CountingIterator(object):
    """Wrapper around an iterable that maintains the iteration count.

    Args:
        iterable (iterable): iterable to wrap

    Attributes:
        count (int): number of elements consumed from this iterator
    """

    def __init__(self, iterable):
        self.iterable = iterable
        self.count = 0
        self.itr = iter(self)

    def __len__(self):
        return len(self.iterable)

    def __iter__(self):
        for x in self.iterable:
            self.count += 1
            yield x

    def __next__(self):
        return next(self.itr)

    def has_next(self):
        """Whether the iterator has been exhausted."""
        return self.count < len(self)

    def skip(self, num_to_skip):
        """Fast-forward the iterator by skipping *num_to_skip* elements."""
        next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
        return self


class EpochBatchIterator(object):
    """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.

    Compared to :class:`torch.utils.data.DataLoader`, this iterator:
    - can be reused across multiple epochs with the :func:`next_epoch_itr`
      method (optionally shuffled between epochs)
    - can be serialized/deserialized with the :func:`state_dict` and
      :func:`load_state_dict` methods
    - supports sharding with the *num_shards* and *shard_id* arguments

    Args:
        dataset (~torch.utils.data.Dataset): dataset from which to load the data
        collate_fn (callable): merges a list of samples to form a mini-batch
        batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
            indices
        seed (int, optional): seed for random number generator for
            reproducibility. Default: ``1``
        num_shards (int, optional): shard the data iterator into N
            shards. Default: ``1``
        shard_id (int, optional): which shard of the data iterator to
            return. Default: ``0``
    """

    def __init__(self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.frozen_batches = tuple(batch_sampler)
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id

        self.epoch = 0
        self._cur_epoch_itr = None
        self._next_epoch_itr = None

    def __len__(self):
        return len(self.frozen_batches)

    def next_epoch_itr(self, shuffle=True):
        """Return a new iterator over the dataset.

        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator. Default: ``True``
        """
        if self._next_epoch_itr is not None:
            self._cur_epoch_itr = self._next_epoch_itr
            self._next_epoch_itr = None
        else:
            self.epoch += 1
            self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
        return self._cur_epoch_itr

    def end_of_epoch(self):
        """Returns whether the most recent epoch iterator has been exhausted."""
        return not self._cur_epoch_itr.has_next()

    @property
    def iterations_in_epoch(self):
        """The number of consumed batches in the current epoch."""
        if self._cur_epoch_itr is not None:
            return self._cur_epoch_itr.count
        elif self._next_epoch_itr is not None:
            return self._next_epoch_itr.count
        return 0

    def state_dict(self):
        """Returns a dictionary containing a whole state of the iterator."""
        return {
            'epoch': self.epoch,
            'iterations_in_epoch': self.iterations_in_epoch,
        }

    def load_state_dict(self, state_dict):
        """Copies the state of the iterator from the given *state_dict*."""
        self.epoch = state_dict['epoch']
        itr_pos = state_dict.get('iterations_in_epoch', 0)
        if itr_pos > 0:
            # fast-forward epoch iterator
            itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
            if itr_pos < len(itr):
                self._next_epoch_itr = itr.skip(itr_pos)

    def _get_iterator_for_epoch(self, epoch, shuffle):
        if shuffle:
            # set seed based on the seed and epoch number so that we get
            # reproducible results when resuming from checkpoints
            with data_utils.numpy_seed(self.seed + epoch):
                batches = list(self.frozen_batches)  # copy
                np.random.shuffle(batches)
        else:
            batches = self.frozen_batches
        return CountingIterator(torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
        ))


class ShardedIterator(object):
    """A sharded wrapper around an iterable, padded to length.

    Args:
        iterable (iterable): iterable to wrap
        num_shards (int): number of shards to split the iterable into
        shard_id (int): which shard to iterate over
        fill_value (Any, optional): padding value when the iterable doesn't
            evenly divide *num_shards*. Default: ``None``
    """

    def __init__(self, iterable, num_shards, shard_id, fill_value=None):
        if shard_id < 0 or shard_id >= num_shards:
            raise ValueError('shard_id must be between 0 and num_shards')

        self._sharded_len = len(iterable) // num_shards
        if len(iterable) % num_shards > 0:
            self._sharded_len += 1

        self.itr = itertools.zip_longest(
            range(self._sharded_len),
            itertools.islice(iterable, shard_id, len(iterable), num_shards),
            fillvalue=fill_value,
        )

    def __len__(self):
        return self._sharded_len

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.itr)[1]
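To make the new module's behavior concrete, here is a minimal usage sketch of EpochBatchIterator covering the checkpoint/resume flow its docstring describes. The toy TensorDataset, identity collate function, and hand-written batch sampler are illustrative assumptions, not part of this commit:

    import torch
    from fairseq.data import iterators

    dataset = torch.utils.data.TensorDataset(torch.arange(8))
    batches = [[0, 1], [2, 3], [4, 5], [6, 7]]

    epoch_itr = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=lambda samples: samples,   # identity collater, just for the sketch
        batch_sampler=batches,
    )

    # Consume part of the first epoch, then snapshot the iterator state.
    itr = epoch_itr.next_epoch_itr(shuffle=True)
    next(itr)
    next(itr)
    state = epoch_itr.state_dict()            # {'epoch': 1, 'iterations_in_epoch': 2}

    # Later (e.g. after reloading a checkpoint), rebuild and fast-forward.
    resumed = iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=lambda samples: samples,
        batch_sampler=batches,
    )
    resumed.load_state_dict(state)
    itr = resumed.next_epoch_itr()            # same seed + epoch gives the same shuffled order,
    print(resumed.iterations_in_epoch)        # already advanced past the first 2 batches: prints 2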
fairseq/tasks/fairseq_task.py

@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from fairseq.data import data_utils, FairseqDataset
+from fairseq.data import data_utils, FairseqDataset, iterators
 
 
 class FairseqTask(object):

@@ -87,8 +87,9 @@ class FairseqTask(object):
         )
 
         # return a reusable, sharded iterator
-        return data_utils.EpochBatchIterator(
+        return iterators.EpochBatchIterator(
             dataset=dataset,
+            collate_fn=dataset.collater,
             batch_sampler=batch_sampler,
             seed=seed,
             num_shards=num_shards,
...
tests/test_
data_util
s.py
→
tests/test_
iterator
s.py
View file @
0a7f9e64
...
...
@@ -7,14 +7,14 @@
import
unittest
from
fairseq.data
import
data_util
s
from
fairseq.data
import
iterator
s
class
Test
DataUtil
s
(
unittest
.
TestCase
):
class
Test
Iterator
s
(
unittest
.
TestCase
):
def
test_counting_iterator
(
self
):
x
=
list
(
range
(
10
))
itr
=
data_util
s
.
CountingIterator
(
x
)
itr
=
iterator
s
.
CountingIterator
(
x
)
self
.
assertTrue
(
itr
.
has_next
())
self
.
assertEqual
(
next
(
itr
),
0
)
self
.
assertEqual
(
next
(
itr
),
1
)
...
...
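The renamed test covers CountingIterator; a companion test for ShardedIterator is not part of this commit, but a hypothetical one might exercise the padding behavior like this:

    import unittest

    from fairseq.data import iterators


    class TestShardedIterator(unittest.TestCase):

        def test_sharded_iterator(self):
            x = list(range(5))
            # Two shards over five elements: each shard is padded to length 3.
            shard0 = list(iterators.ShardedIterator(x, num_shards=2, shard_id=0, fill_value=None))
            shard1 = list(iterators.ShardedIterator(x, num_shards=2, shard_id=1, fill_value=None))
            self.assertEqual(shard0, [0, 2, 4])
            self.assertEqual(shard1, [1, 3, None])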
tests/test_train.py

@@ -42,8 +42,10 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
     tokens = torch.LongTensor(list(range(epoch_size)))
     tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
     trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
+    dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
     epoch_itr = data.EpochBatchIterator(
-        dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
+        dataset=dataset,
+        collate_fn=dataset.collater,
         batch_sampler=[[i] for i in range(epoch_size)],
     )
     return trainer, epoch_itr