Commit 0a7f9e64 authored by Myle Ott

Further generalize EpochBatchIterator and move iterators into new file

parent 75f6ba05
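In practice the change means call sites construct the iterator from `fairseq.data.iterators` and pass the collate function explicitly, instead of the old `data_utils.EpochBatchIterator` reaching into `dataset.collater` itself. A hedged sketch of the new call shape, assuming a dataset that exposes a `collater` method (as the `FairseqDataset` subclasses in this diff do); the real call-site change is in the `FairseqTask` hunk further down:

```python
from fairseq.data import iterators

def build_epoch_itr(dataset, batch_sampler, seed=1, num_shards=1, shard_id=0):
    """New-style construction: the collate function is passed explicitly.
    (Before this commit, data_utils.EpochBatchIterator took no collate_fn
    argument and always called dataset.collater internally.)"""
    return iterators.EpochBatchIterator(
        dataset=dataset,
        collate_fn=dataset.collater,   # any callable works now
        batch_sampler=batch_sampler,
        seed=seed,
        num_shards=num_shards,
        shard_id=shard_id,
    )
```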
@@ -12,4 +12,4 @@ from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
from .data_utils import EpochBatchIterator
from .iterators import EpochBatchIterator
@@ -6,11 +6,9 @@
# can be found in the PATENTS file in the same directory.
import contextlib
import itertools
import os
import numpy as np
import torch
def infer_language_pair(path):
@@ -23,60 +21,6 @@ def infer_language_pair(path):
return src, dst
class ShardedIterator(object):
"""A sharded wrapper around an iterable (padded to length)."""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError('shard_id must be between 0 and num_shards')
self._sharded_len = len(iterable) // num_shards
if len(iterable) % num_shards > 0:
self._sharded_len += 1
self.itr = itertools.zip_longest(
range(self._sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
)
def __len__(self):
return self._sharded_len
def __iter__(self):
return self
def __next__(self):
return next(self.itr)[1]
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count."""
def __init__(self, iterable):
self.iterable = iterable
self.count = 0
self.itr = iter(self)
def __len__(self):
return len(self.iterable)
def __iter__(self):
for x in self.iterable:
self.count += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
return self.count < len(self)
def skip(self, num_to_skip):
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
@@ -96,103 +40,6 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=Fal
return res
class EpochBatchIterator(object):
"""A multi-epoch iterator over a :class:`~torch.utils.data.Dataset`.
Compared to :class:`~torch.utils.data.DataLoader`, this iterator:
- can be reused across multiple epochs with the :func:`next_epoch_itr`
method (optionally shuffled between epochs)
- can be serialized/deserialized with the :func:`state_dict` and
:func:`load_state_dict` methods
- supports sharding with the ``num_shards`` and ``shard_id`` arguments
Args:
dataset (Dataset): dataset from which to load the data
batch_sampler (Sampler): an iterator over batches of indices
seed (int, optional): seed for random number generator for
reproducibility. Default: ``1``
num_shards (int, optional): shard the data iterator into N
shards. Default: ``1``
shard_id (int, optional): which shard of the data iterator to
return. Default: ``0``
"""
def __init__(self, dataset, batch_sampler, seed=1, num_shards=1, shard_id=0):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.frozen_batches = tuple(batch_sampler)
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
def __len__(self):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""
Return a new iterator over the dataset.
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True``
"""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
return self._cur_epoch_itr
def end_of_epoch(self):
"""Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
"""The number of consumed batches in the current epoch."""
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return 0
def state_dict(self):
return {
'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
itr_pos = state_dict.get('iterations_in_epoch', 0)
if itr_pos > 0:
# fast-forward epoch iterator
itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
def _get_iterator_for_epoch(self, epoch, shuffle):
if shuffle:
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(self.seed + epoch):
batches = list(self.frozen_batches) # copy
np.random.shuffle(batches)
else:
batches = self.frozen_batches
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.dataset.collater,
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
......
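For context before the new module: `numpy_seed` stays behind in `data_utils` and is what the relocated `EpochBatchIterator` uses (seeded with `seed + epoch`) to make per-epoch shuffling reproducible. A small sketch of that property, with an illustrative `shuffled` helper:

```python
import numpy as np
from fairseq.data import data_utils

def shuffled(seed):
    """Shuffle a fixed list of batch indices under data_utils.numpy_seed."""
    batches = list(range(6))
    with data_utils.numpy_seed(seed):
        np.random.shuffle(batches)
    return batches

# Same seed -> same permutation. EpochBatchIterator relies on this by
# seeding with (seed + epoch), so resuming from a checkpoint recreates
# the batch order of the interrupted epoch.
assert shuffled(42) == shuffled(42)
```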
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import torch
from . import data_utils
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count.
Args:
iterable (iterable): iterable to wrap
Attributes:
count (int): number of elements consumed from this iterator
"""
def __init__(self, iterable):
self.iterable = iterable
self.count = 0
self.itr = iter(self)
def __len__(self):
return len(self.iterable)
def __iter__(self):
for x in self.iterable:
self.count += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
"""Whether the iterator has been exhausted."""
return self.count < len(self)
def skip(self, num_to_skip):
"""Fast-forward the iterator by skipping *num_to_skip* elements."""
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
class EpochBatchIterator(object):
"""A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.
Compared to :class:`torch.utils.data.DataLoader`, this iterator:
- can be reused across multiple epochs with the :func:`next_epoch_itr`
method (optionally shuffled between epochs)
- can be serialized/deserialized with the :func:`state_dict` and
:func:`load_state_dict` methods
- supports sharding with the *num_shards* and *shard_id* arguments
Args:
dataset (~torch.utils.data.Dataset): dataset from which to load the data
collate_fn (callable): merges a list of samples to form a mini-batch
batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
indices
seed (int, optional): seed for random number generator for
reproducibility. Default: ``1``
num_shards (int, optional): shard the data iterator into N
shards. Default: ``1``
shard_id (int, optional): which shard of the data iterator to
return. Default: ``0``
"""
def __init__(self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.collate_fn = collate_fn
self.frozen_batches = tuple(batch_sampler)
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
def __len__(self):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""Return a new iterator over the dataset.
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True``
"""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
return self._cur_epoch_itr
def end_of_epoch(self):
"""Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
"""The number of consumed batches in the current epoch."""
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return 0
def state_dict(self):
"""Returns a dictionary containing a whole state of the iterator."""
return {
'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch,
}
def load_state_dict(self, state_dict):
"""Copies the state of the iterator from the given *state_dict*."""
self.epoch = state_dict['epoch']
itr_pos = state_dict.get('iterations_in_epoch', 0)
if itr_pos > 0:
# fast-forward epoch iterator
itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
def _get_iterator_for_epoch(self, epoch, shuffle):
if shuffle:
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with data_utils.numpy_seed(self.seed + epoch):
batches = list(self.frozen_batches) # copy
np.random.shuffle(batches)
else:
batches = self.frozen_batches
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
class ShardedIterator(object):
"""A sharded wrapper around an iterable, padded to length.
Args:
iterable (iterable): iterable to wrap
num_shards (int): number of shards to split the iterable into
shard_id (int): which shard to iterate over
fill_value (Any, optional): padding value when the iterable doesn't
evenly divide *num_shards*. Default: ``None``
"""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError('shard_id must be between 0 and num_shards')
self._sharded_len = len(iterable) // num_shards
if len(iterable) % num_shards > 0:
self._sharded_len += 1
self.itr = itertools.zip_longest(
range(self._sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
)
def __len__(self):
return self._sharded_len
def __iter__(self):
return self
def __next__(self):
return next(self.itr)[1]
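As a standalone illustration (not part of the diff), the two helper classes above behave as follows on plain Python lists; the expected values follow directly from their definitions:

```python
from fairseq.data import iterators

x = list(range(10))

# CountingIterator tracks how many elements were consumed and can
# fast-forward with skip(), which is what makes mid-epoch resume possible.
itr = iterators.CountingIterator(x)
assert next(itr) == 0 and next(itr) == 1
assert itr.count == 2
itr.skip(3)                # silently consumes 2, 3 and 4
assert next(itr) == 5
assert itr.has_next()      # 6..9 are still pending

# ShardedIterator deals elements round-robin across shards and pads the
# shorter shards with fill_value so every shard has the same length.
shard0 = list(iterators.ShardedIterator(x, num_shards=3, shard_id=0))
shard2 = list(iterators.ShardedIterator(x, num_shards=3, shard_id=2))
assert shard0 == [0, 3, 6, 9]
assert shard2 == [2, 5, 8, None]   # padded with the default fill_value
```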
@@ -5,7 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from fairseq.data import data_utils, FairseqDataset
from fairseq.data import data_utils, FairseqDataset, iterators
class FairseqTask(object):
@@ -87,8 +87,9 @@ class FairseqTask(object):
)
# return a reusable, sharded iterator
return data_utils.EpochBatchIterator(
return iterators.EpochBatchIterator(
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=batch_sampler,
seed=seed,
num_shards=num_shards,
......
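Since the task code above now hands back this reusable, sharded iterator, a training loop can checkpoint mid-epoch and resume where it left off via `state_dict`/`load_state_dict`. A minimal sketch with a hypothetical `ToyDataset` and a plain `collate_fn` standing in for `dataset.collater`:

```python
import torch
from fairseq.data import iterators

class ToyDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in dataset: item i is just the integer i."""
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def __getitem__(self, index):
        return index

def make_epoch_itr():
    return iterators.EpochBatchIterator(
        dataset=ToyDataset(8),
        collate_fn=lambda samples: torch.LongTensor(samples),
        batch_sampler=[[i, i + 1] for i in range(0, 8, 2)],  # 4 batches of 2
    )

epoch_itr = make_epoch_itr()
itr = epoch_itr.next_epoch_itr(shuffle=False)
next(itr); next(itr)                      # consume 2 of the 4 batches
state = epoch_itr.state_dict()            # {'epoch': 1, 'iterations_in_epoch': 2}

# Later, e.g. after restarting from a checkpoint: restore and fast-forward.
# Note that on restore the batch order is re-shuffled deterministically
# (seeded with seed + epoch), matching what a shuffled epoch would have done.
resumed = make_epoch_itr()
resumed.load_state_dict(state)
itr = resumed.next_epoch_itr()
assert resumed.iterations_in_epoch == 2   # first two batches were skipped
```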
@@ -7,14 +7,14 @@
import unittest
from fairseq.data import data_utils
from fairseq.data import iterators
class TestDataUtils(unittest.TestCase):
class TestIterators(unittest.TestCase):
def test_counting_iterator(self):
x = list(range(10))
itr = data_utils.CountingIterator(x)
itr = iterators.CountingIterator(x)
self.assertTrue(itr.has_next())
self.assertEqual(next(itr), 0)
self.assertEqual(next(itr), 1)
......
@@ -42,8 +42,10 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
tokens = torch.LongTensor(list(range(epoch_size)))
tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
epoch_itr = data.EpochBatchIterator(
dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
dataset=dataset,
collate_fn=dataset.collater,
batch_sampler=[[i] for i in range(epoch_size)],
)
return trainer, epoch_itr
......