"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "1ebda734e2a9edaccd89095c4bfdacc20d693a3c"
Commit 0a7f9e64 authored by Myle Ott

Further generalize EpochBatchIterator and move iterators into new file

parent 75f6ba05
@@ -12,4 +12,4 @@ from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
-from .data_utils import EpochBatchIterator
+from .iterators import EpochBatchIterator
@@ -6,11 +6,9 @@
# can be found in the PATENTS file in the same directory.
import contextlib
-import itertools
import os
import numpy as np
-import torch
def infer_language_pair(path):
@@ -23,60 +21,6 @@ def infer_language_pair(path):
return src, dst
class ShardedIterator(object):
"""A sharded wrapper around an iterable (padded to length)."""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError('shard_id must be between 0 and num_shards')
self._sharded_len = len(iterable) // num_shards
if len(iterable) % num_shards > 0:
self._sharded_len += 1
self.itr = itertools.zip_longest(
range(self._sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
)
def __len__(self):
return self._sharded_len
def __iter__(self):
return self
def __next__(self):
return next(self.itr)[1]
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count."""
def __init__(self, iterable):
self.iterable = iterable
self.count = 0
self.itr = iter(self)
def __len__(self):
return len(self.iterable)
def __iter__(self):
for x in self.iterable:
self.count += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
return self.count < len(self)
def skip(self, num_to_skip):
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size = max(v.size(0) for v in values)
@@ -96,103 +40,6 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=Fal
return res
class EpochBatchIterator(object):
"""A multi-epoch iterator over a :class:`~torch.utils.data.Dataset`.
Compared to :class:`~torch.utils.data.DataLoader`, this iterator:
- can be reused across multiple epochs with the :func:`next_epoch_itr`
method (optionally shuffled between epochs)
- can be serialized/deserialized with the :func:`state_dict` and
:func:`load_state_dict` methods
- supports sharding with the ``num_shards`` and ``shard_id`` arguments
Args:
dataset (Dataset): dataset from which to load the data
batch_sampler (Sampler): an iterator over batches of indices
seed (int, optional): seed for random number generator for
reproducibility. Default: ``1``
num_shards (int, optional): shard the data iterator into N
shards. Default: ``1``
shard_id (int, optional): which shard of the data iterator to
return. Default: ``0``
"""
def __init__(self, dataset, batch_sampler, seed=1, num_shards=1, shard_id=0):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.frozen_batches = tuple(batch_sampler)
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
def __len__(self):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""
Return a new iterator over the dataset.
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True``
"""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
return self._cur_epoch_itr
def end_of_epoch(self):
"""Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
"""The number of consumed batches in the current epoch."""
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return 0
def state_dict(self):
return {
'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch,
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
itr_pos = state_dict.get('iterations_in_epoch', 0)
if itr_pos > 0:
# fast-forward epoch iterator
itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
def _get_iterator_for_epoch(self, epoch, shuffle):
if shuffle:
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with numpy_seed(self.seed + epoch):
batches = list(self.frozen_batches) # copy
np.random.shuffle(batches)
else:
batches = self.frozen_batches
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.dataset.collater,
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import torch
from . import data_utils
class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count.
Args:
iterable (iterable): iterable to wrap
Attributes:
count (int): number of elements consumed from this iterator
"""
def __init__(self, iterable):
self.iterable = iterable
self.count = 0
self.itr = iter(self)
def __len__(self):
return len(self.iterable)
def __iter__(self):
for x in self.iterable:
self.count += 1
yield x
def __next__(self):
return next(self.itr)
def has_next(self):
"""Whether the iterator has been exhausted."""
return self.count < len(self)
def skip(self, num_to_skip):
"""Fast-forward the iterator by skipping *num_to_skip* elements."""
next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
return self
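A minimal usage sketch of CountingIterator (illustrative only, not part of this commit; it assumes the new module is importable as fairseq.data.iterators). The skip() trick works because itertools.islice(itr, n, n) consumes n elements without yielding any:

from fairseq.data.iterators import CountingIterator

itr = CountingIterator(list(range(10)))
assert len(itr) == 10 and itr.has_next()
assert next(itr) == 0 and itr.count == 1   # one element consumed

# skip() drains elements through itertools.islice without returning them
itr.skip(4)
assert itr.count == 5                      # skipped elements still count as consumed
assert next(itr) == 5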
class EpochBatchIterator(object):
"""A multi-epoch iterator over a :class:`torch.utils.data.Dataset`.
Compared to :class:`torch.utils.data.DataLoader`, this iterator:
- can be reused across multiple epochs with the :func:`next_epoch_itr`
method (optionally shuffled between epochs)
- can be serialized/deserialized with the :func:`state_dict` and
:func:`load_state_dict` methods
- supports sharding with the *num_shards* and *shard_id* arguments
Args:
dataset (~torch.utils.data.Dataset): dataset from which to load the data
collate_fn (callable): merges a list of samples to form a mini-batch
batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
indices
seed (int, optional): seed for random number generator for
reproducibility. Default: ``1``
num_shards (int, optional): shard the data iterator into N
shards. Default: ``1``
shard_id (int, optional): which shard of the data iterator to
return. Default: ``0``
"""
def __init__(self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
self.collate_fn = collate_fn
self.frozen_batches = tuple(batch_sampler)
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
def __len__(self):
return len(self.frozen_batches)
def next_epoch_itr(self, shuffle=True):
"""Return a new iterator over the dataset.
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True``
"""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
return self._cur_epoch_itr
def end_of_epoch(self):
"""Returns whether the most recent epoch iterator has been exhausted"""
return not self._cur_epoch_itr.has_next()
@property
def iterations_in_epoch(self):
"""The number of consumed batches in the current epoch."""
if self._cur_epoch_itr is not None:
return self._cur_epoch_itr.count
elif self._next_epoch_itr is not None:
return self._next_epoch_itr.count
return 0
def state_dict(self):
"""Returns a dictionary containing a whole state of the iterator."""
return {
'epoch': self.epoch,
'iterations_in_epoch': self.iterations_in_epoch,
}
def load_state_dict(self, state_dict):
"""Copies the state of the iterator from the given *state_dict*."""
self.epoch = state_dict['epoch']
itr_pos = state_dict.get('iterations_in_epoch', 0)
if itr_pos > 0:
# fast-forward epoch iterator
itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
def _get_iterator_for_epoch(self, epoch, shuffle):
if shuffle:
# set seed based on the seed and epoch number so that we get
# reproducible results when resuming from checkpoints
with data_utils.numpy_seed(self.seed + epoch):
batches = list(self.frozen_batches) # copy
np.random.shuffle(batches)
else:
batches = self.frozen_batches
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
))
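An illustrative sketch of the epoch/checkpoint workflow (not part of the commit; ToyDataset and collate below are stand-in names), showing next_epoch_itr(), state_dict(), and resuming mid-epoch with load_state_dict():

import torch

from fairseq.data.iterators import EpochBatchIterator


class ToyDataset(torch.utils.data.Dataset):
    # stand-in map-style dataset: item i is the 1-d tensor [i]
    def __getitem__(self, index):
        return torch.tensor([index])

    def __len__(self):
        return 8


def collate(samples):
    # stand-in collate_fn: stack 1-d tensors into a 2-d batch
    return torch.stack(samples)


batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
epoch_itr = EpochBatchIterator(ToyDataset(), collate, batches, seed=1)

itr = epoch_itr.next_epoch_itr(shuffle=False)   # starts epoch 1
next(itr)                                       # consume one batch
state = epoch_itr.state_dict()                  # {'epoch': 1, 'iterations_in_epoch': 1}

# restore into a fresh iterator and fast-forward past the consumed batch
resumed = EpochBatchIterator(ToyDataset(), collate, batches, seed=1)
resumed.load_state_dict(state)
itr = resumed.next_epoch_itr(shuffle=False)
assert resumed.iterations_in_epoch == 1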
class ShardedIterator(object):
"""A sharded wrapper around an iterable, padded to length.
Args:
iterable (iterable): iterable to wrap
num_shards (int): number of shards to split the iterable into
shard_id (int): which shard to iterate over
fill_value (Any, optional): padding value used when the length of the
iterable is not evenly divisible by *num_shards*. Default: ``None``
"""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
if shard_id < 0 or shard_id >= num_shards:
raise ValueError('shard_id must be between 0 and num_shards')
self._sharded_len = len(iterable) // num_shards
if len(iterable) % num_shards > 0:
self._sharded_len += 1
self.itr = itertools.zip_longest(
range(self._sharded_len),
itertools.islice(iterable, shard_id, len(iterable), num_shards),
fillvalue=fill_value,
)
def __len__(self):
return self._sharded_len
def __iter__(self):
return self
def __next__(self):
return next(self.itr)[1]
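A quick sketch of the sharding behavior (illustrative, not part of the commit): ten elements split across four shards, with each shard padded to the same length using fill_value:

from fairseq.data.iterators import ShardedIterator

items = list(range(10))
shards = [list(ShardedIterator(items, num_shards=4, shard_id=i, fill_value=-1))
          for i in range(4)]

# every shard has the same padded length: ceil(10 / 4) == 3
assert shards == [
    [0, 4, 8],
    [1, 5, 9],
    [2, 6, -1],   # padded with fill_value
    [3, 7, -1],
]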
@@ -5,7 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
-from fairseq.data import data_utils, FairseqDataset
+from fairseq.data import data_utils, FairseqDataset, iterators
class FairseqTask(object):
@@ -87,8 +87,9 @@ class FairseqTask(object):
)
# return a reusable, sharded iterator
-return data_utils.EpochBatchIterator(
+return iterators.EpochBatchIterator(
dataset=dataset,
+collate_fn=dataset.collater,
batch_sampler=batch_sampler,
seed=seed,
num_shards=num_shards,
...
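The explicit collate_fn argument is what the commit title means by "further generalize": the iterator no longer reaches for self.dataset.collater internally (as the removed data_utils.EpochBatchIterator did), so the caller can supply any callable. A hedged sketch of the call site, with the trailing arguments elided as in the hunk above:

# before: the collater was looked up on the dataset inside the iterator
# after: the caller passes it explicitly, so non-FairseqDataset datasets work too
epoch_itr = iterators.EpochBatchIterator(
    dataset=dataset,
    collate_fn=dataset.collater,   # or any other callable that merges samples
    batch_sampler=batch_sampler,
    seed=seed,
    num_shards=num_shards,
    # ...
)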
@@ -7,14 +7,14 @@
import unittest
-from fairseq.data import data_utils
+from fairseq.data import iterators
-class TestDataUtils(unittest.TestCase):
+class TestIterators(unittest.TestCase):
def test_counting_iterator(self):
x = list(range(10))
-itr = data_utils.CountingIterator(x)
+itr = iterators.CountingIterator(x)
self.assertTrue(itr.has_next())
self.assertEqual(next(itr), 0)
self.assertEqual(next(itr), 1)
...
@@ -42,8 +42,10 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoc
tokens = torch.LongTensor(list(range(epoch_size)))
tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
+dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
epoch_itr = data.EpochBatchIterator(
-dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
+dataset=dataset,
+collate_fn=dataset.collater,
batch_sampler=[[i] for i in range(epoch_size)],
)
return trainer, epoch_itr
...