"git@developer.sourcefind.cn:OpenDAS/deepspeed.git" did not exist on "be1147c08acde34d7a13b73f8d33e55c3f719de4"
Commit 2e507d3c authored by Myle Ott

Clean up FairseqTask so that it's easier to extend/add new tasks

parent 6296de82
@@ -59,11 +59,13 @@ def main(parsed_args):
     assert len(models) > 0
 
-    itr = data.EpochBatchIterator(
+    itr = task.get_batch_iterator(
         dataset=task.dataset(args.gen_subset),
         max_tokens=args.max_tokens or 36000,
         max_sentences=args.max_sentences,
-        max_positions=models[0].max_positions(),
+        max_positions=utils.resolve_max_positions(*[
+            model.max_positions() for model in models
+        ]),
         num_shards=args.num_shards,
         shard_id=args.shard_id,
         ignore_invalid_inputs=True,
......
@@ -12,8 +12,6 @@ import os
 import numpy as np
 import torch
 
-from . import FairseqDataset
-
 
 def infer_language_pair(path):
     """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
@@ -99,42 +97,35 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
 
 class EpochBatchIterator(object):
-    """Iterate over a FairseqDataset and yield batches bucketed by size.
+    """A multi-epoch iterator over a :class:`~torch.utils.data.Dataset`.
+
+    Compared to :class:`~torch.utils.data.DataLoader`, this iterator:
 
-    Batches may contain sequences of different lengths. This iterator can be
-    reused across multiple epochs with the next_epoch_itr() method.
+    - can be reused across multiple epochs with the :func:`next_epoch_itr`
+      method (optionally shuffled between epochs)
+    - can be serialized/deserialized with the :func:`state_dict` and
+      :func:`load_state_dict` methods
+    - supports sharding with the ``num_shards`` and ``shard_id`` arguments
 
     Args:
-        dataset: a FairseqDataset
-        max_tokens: max number of tokens in each batch
-        max_sentences: max number of sentences in each batch
-        max_positions: max sentence length supported by the model
-        ignore_invalid_inputs: don't raise Exception for sentences that are too long
-        required_batch_size_multiple: require batch size to be a multiple of N
-        seed: seed for random number generator for reproducibility
-        num_shards: shard the data iterator into N shards
-        shard_id: which shard of the data iterator to return
+        dataset (Dataset): dataset from which to load the data
+        batch_sampler (Sampler): an iterator over batches of indices
+        seed (int, optional): seed for random number generator for
+            reproducibility. Default: ``1``
+        num_shards (int, optional): shard the data iterator into N
+            shards. Default: ``1``
+        shard_id (int, optional): which shard of the data iterator to
+            return. Default: ``0``
     """
 
-    def __init__(
-        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
-        ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1,
-        num_shards=1, shard_id=0,
-    ):
-        assert isinstance(dataset, FairseqDataset)
+    def __init__(self, dataset, batch_sampler, seed=1, num_shards=1, shard_id=0):
+        assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
-        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
-        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
-        self.max_positions = max_positions
-        self.ignore_invalid_inputs = ignore_invalid_inputs
-        self.bsz_mult = required_batch_size_multiple
+        self.frozen_batches = tuple(batch_sampler)
         self.seed = seed
         self.num_shards = num_shards
         self.shard_id = shard_id
 
-        with numpy_seed(self.seed):
-            self.frozen_batches = tuple(self._batch_generator())
-
         self.epoch = 0
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
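
Commentary (not part of the diff): batching now happens up front, and the iterator simply replays an index list. A minimal sketch with a hypothetical toy dataset, assuming (as elsewhere in fairseq's data loading) that each batch is collated via ``dataset.collater``:

    import torch
    import torch.utils.data

    from fairseq import data


    class ToyDataset(torch.utils.data.Dataset):
        # hypothetical stand-in; real datasets subclass FairseqDataset
        def __getitem__(self, index):
            return torch.tensor([index])

        def __len__(self):
            return 6

        def collater(self, samples):
            # batches are collated by the dataset itself
            return torch.stack(samples)


    batches = [[0, 1], [2, 3], [4, 5]]  # any iterable of index lists
    itr = data.EpochBatchIterator(ToyDataset(), batch_sampler=batches)
    assert len(itr) == 3  # number of frozen batches, not examples
    for batch in itr.next_epoch_itr(shuffle=False):
        print(batch)  # tensor of shape (2, 1)

The same iterator object can then be reused for the next epoch with another next_epoch_itr() call, which is what makes mid-epoch serialization possible.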
@@ -143,7 +134,13 @@ class EpochBatchIterator(object):
         return len(self.frozen_batches)
 
     def next_epoch_itr(self, shuffle=True):
-        """Shuffle batches and return a new iterator over the dataset."""
+        """Return a new iterator over the dataset.
+
+        Args:
+            shuffle (bool, optional): shuffle batches before returning the
+                iterator. Default: ``True``
+        """
         if self._next_epoch_itr is not None:
             self._cur_epoch_itr = self._next_epoch_itr
             self._next_epoch_itr = None
@@ -153,10 +150,12 @@ class EpochBatchIterator(object):
         return self._cur_epoch_itr
 
     def end_of_epoch(self):
+        """Returns whether the most recent epoch iterator has been exhausted"""
         return not self._cur_epoch_itr.has_next()
 
     @property
     def iterations_in_epoch(self):
+        """The number of consumed batches in the current epoch."""
        if self._cur_epoch_itr is not None:
             return self._cur_epoch_itr.count
         elif self._next_epoch_itr is not None:
@@ -193,55 +192,6 @@ class EpochBatchIterator(object):
             batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
         ))
 
-    def _batch_generator(self):
-        batch = []
-
-        def is_batch_full(num_tokens):
-            if len(batch) == 0:
-                return False
-            if len(batch) == self.max_sentences:
-                return True
-            if num_tokens > self.max_tokens:
-                return True
-            return False
-
-        sample_len = 0
-        sample_lens = []
-        ignored = []
-        for idx in self.dataset.ordered_indices():
-            if not self.dataset.valid_size(idx, self.max_positions):
-                if self.ignore_invalid_inputs:
-                    ignored.append(idx)
-                    continue
-                raise Exception((
-                    'Size of sample #{} is invalid, max_positions={}, skip this '
-                    'example with --skip-invalid-size-inputs-valid-test'
-                ).format(idx, self.max_positions))
-
-            sample_lens.append(self.dataset.num_tokens(idx))
-            sample_len = max(sample_len, sample_lens[-1])
-            num_tokens = (len(batch) + 1) * sample_len
-            if is_batch_full(num_tokens):
-                mod_len = max(
-                    self.bsz_mult * (len(batch) // self.bsz_mult),
-                    len(batch) % self.bsz_mult,
-                )
-                yield batch[:mod_len]
-                batch = batch[mod_len:]
-                sample_lens = sample_lens[mod_len:]
-                sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
-
-            batch.append(idx)
-
-        if len(batch) > 0:
-            yield batch
-
-        if len(ignored) > 0:
-            print((
-                '| WARNING: {} samples have invalid sizes and will be skipped, '
-                'max_positions={}, first few sample ids={}'
-            ).format(len(ignored), self.max_positions, ignored[:10]))
 
 @contextlib.contextmanager
 def numpy_seed(seed):
@@ -256,3 +206,112 @@ def numpy_seed(seed):
         yield
     finally:
         np.random.set_state(state)
+
+
+def collect_filtered(function, iterable, filtered):
+    """
+    Similar to :func:`filter` but collects filtered elements in ``filtered``.
+
+    Args:
+        function (callable): function that returns ``False`` for elements that
+            should be filtered
+        iterable (iterable): iterable to filter
+        filtered (list): list to store filtered elements
+    """
+    for el in iterable:
+        if function(el):
+            yield el
+        else:
+            filtered.append(el)
+
+
+def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
+    """
+    Filter indices based on their size.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        size_fn (callable): function that returns the size of a given index
+        max_positions (tuple): filter elements larger than this size.
+            Comparisons are done component-wise.
+        raise_exception (bool, optional): if ``True``, raise an exception
+            if any elements are filtered. Default: ``False``
+    """
+    def check_size(idx):
+        if isinstance(max_positions, float) or isinstance(max_positions, int):
+            return size_fn(idx) < max_positions
+        else:
+            return all(a <= b for a, b in zip(size_fn(idx), max_positions))
+
+    ignored = []
+    itr = collect_filtered(check_size, indices, ignored)
+
+    for idx in itr:
+        if len(ignored) > 0 and raise_exception:
+            raise Exception((
+                'Size of sample #{} is invalid (={}) since max_positions={}, '
+                'skip this example with --skip-invalid-size-inputs-valid-test'
+            ).format(idx, size_fn(idx), max_positions))
+        yield idx
+
+    if len(ignored) > 0:
+        print((
+            '| WARNING: {} samples have invalid sizes and will be skipped, '
+            'max_positions={}, first few sample ids={}'
+        ).format(len(ignored), max_positions, ignored[:10]))
+
+
+def batch_by_size(
+    indices, num_tokens_fn, max_tokens=None, max_sentences=None,
+    required_batch_size_multiple=1,
+):
+    """
+    Yield mini-batches of indices bucketed by size. Batches may contain
+    sequences of different lengths.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices
+        num_tokens_fn (callable): function that returns the number of tokens at
+            a given index
+        max_tokens (int, optional): max number of tokens in each batch.
+            Default: ``None``
+        max_sentences (int, optional): max number of sentences in each
+            batch. Default: ``None``
+        required_batch_size_multiple (int, optional): require batch size to
+            be a multiple of N. Default: ``1``
+    """
+    max_tokens = max_tokens if max_tokens is not None else float('Inf')
+    max_sentences = max_sentences if max_sentences is not None else float('Inf')
+    bsz_mult = required_batch_size_multiple
+
+    batch = []
+
+    def is_batch_full(num_tokens):
+        if len(batch) == 0:
+            return False
+        if len(batch) == max_sentences:
+            return True
+        if num_tokens > max_tokens:
+            return True
+        return False
+
+    sample_len = 0
+    sample_lens = []
+    ignored = []
+    for idx in indices:
+        sample_lens.append(num_tokens_fn(idx))
+        sample_len = max(sample_len, sample_lens[-1])
+        num_tokens = (len(batch) + 1) * sample_len
+        if is_batch_full(num_tokens):
+            mod_len = max(
+                bsz_mult * (len(batch) // bsz_mult),
+                len(batch) % bsz_mult,
+            )
+            yield batch[:mod_len]
+            batch = batch[mod_len:]
+            sample_lens = sample_lens[mod_len:]
+            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+
+        batch.append(idx)
+
+    if len(batch) > 0:
+        yield batch
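
Commentary: the three helpers above reproduce the old _batch_generator pipeline as composable pieces. A sketch with made-up sizes (values hypothetical, not from the commit):

    from fairseq.data import data_utils

    sizes = [3, 7, 5, 12, 4]  # hypothetical per-example lengths

    # order by size, as FairseqDataset.ordered_indices implementations do
    indices = sorted(range(len(sizes)), key=lambda i: sizes[i])

    # drop examples with size >= 10 (component-wise for tuple sizes)
    indices = list(data_utils.filter_by_size(
        indices, lambda i: sizes[i], max_positions=10,
    ))  # -> [0, 4, 2, 1]; example 3 (size 12) is skipped with a warning

    # batch cost is len(batch) * longest-sample-in-batch
    batches = list(data_utils.batch_by_size(
        indices, lambda i: sizes[i], max_tokens=8,
    ))  # -> [[0, 4], [2], [1]]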
@@ -7,6 +7,8 @@
 
 import torch.utils.data
 
+from fairseq.data import data_utils
+
 
 class FairseqDataset(torch.utils.data.Dataset):
     """A dataset that provides helpers for batching."""
@@ -18,7 +20,14 @@ class FairseqDataset(torch.utils.data.Dataset):
         raise NotImplementedError
 
     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Args:
+            samples (List[int]): sample indices to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         raise NotImplementedError
 
     def get_dummy_batch(self, num_tokens, max_positions):
@@ -26,13 +35,16 @@ class FairseqDataset(torch.utils.data.Dataset):
         raise NotImplementedError
 
     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
         raise NotImplementedError
 
-    def ordered_indices(self):
-        """Ordered indices for batching."""
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
         raise NotImplementedError
 
-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
+    def ordered_indices(self):
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         raise NotImplementedError
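
Commentary: the contract above is easiest to read from a concrete implementation; a hypothetical toy subclass (not in the tree):

    import numpy as np
    import torch

    from fairseq.data import FairseqDataset


    class TensorListDataset(FairseqDataset):
        """Toy example: wraps a list of 1D tensors of equal length."""

        def __init__(self, tensors):
            self.tensors = tensors

        def __getitem__(self, index):
            return self.tensors[index]

        def __len__(self):
            return len(self.tensors)

        def collater(self, samples):
            return torch.stack(samples)  # real datasets pad instead

        def num_tokens(self, index):
            return self.tensors[index].numel()  # drives --max-tokens

        def size(self, index):
            return self.tensors[index].numel()  # checked against --max-positions

        def ordered_indices(self):
            # put similarly-sized examples in the same batches
            return np.argsort([t.numel() for t in self.tensors])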
@@ -8,6 +8,8 @@
 
 import numpy as np
 import torch
 
+from fairseq import utils
+
 from . import data_utils, FairseqDataset
@@ -59,7 +61,27 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False):
 
 class LanguagePairDataset(FairseqDataset):
-    """A pair of torch.utils.data.Datasets."""
+    """
+    A pair of torch.utils.data.Datasets.
+
+    Args:
+        src (torch.utils.data.Dataset): source dataset to wrap
+        src_sizes (List[int]): source sentence lengths
+        src_dict (fairseq.data.Dictionary): source vocabulary
+        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
+        tgt_sizes (List[int], optional): target sentence lengths
+        tgt_dict (fairseq.data.Dictionary, optional): target vocabulary
+        left_pad_source (bool, optional): pad source tensors on the left side.
+            Default: ``True``
+        left_pad_target (bool, optional): pad target tensors on the left side.
+            Default: ``False``
+        max_source_positions (int, optional): max number of tokens in the source
+            sentence. Default: ``1024``
+        max_target_positions (int, optional): max number of tokens in the target
+            sentence. Default: ``1024``
+        shuffle (bool, optional): shuffle dataset elements before batching.
+            Default: ``True``
+    """
 
     def __init__(
         self, src, src_sizes, src_dict,
@@ -95,15 +117,43 @@ class LanguagePairDataset(FairseqDataset):
         return len(self.src)
 
     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Returned mini-batches contain the following keys:
+
+        - `id` (torch.LongTensor): example IDs in the original input order
+        - `ntokens` (int): total number of tokens in the batch
+        - `net_input` (dict): the input to the Model, containing keys:
+
+          - `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
+            the source sentence of shape `(bsz, src_len)`. Padding will appear
+            on the left if ``left_pad_source`` is True.
+          - `src_lengths` (torch.LongTensor): 1D Tensor of the unpadded lengths
+            of each source sentence of shape `(bsz)`
+          - `prev_output_tokens` (torch.LongTensor): a padded 2D Tensor of
+            tokens in the target sentence, shifted right by one position for
+            input feeding/teacher forcing, of shape `(bsz, tgt_len)`. Padding
+            will appear on the left if ``left_pad_target`` is True.
+
+        - `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
+          target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
+          left if ``left_pad_target`` is True.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         return collate(
             samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
             left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
         )
 
     def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
-        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
-        src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
+        """Return a dummy batch with a given number of tokens."""
+        src_len, tgt_len = utils.resolve_max_positions(
+            (src_len, tgt_len),
+            max_positions,
+            (self.max_source_positions, self.max_target_positions),
+        )
         bsz = num_tokens // max(src_len, tgt_len)
         return self.collater([
             {
@@ -115,11 +165,18 @@ class LanguagePairDataset(FairseqDataset):
         ])
 
     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
         return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
 
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
+        return (self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
+
     def ordered_indices(self):
-        """Ordered indices for batching."""
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         if self.shuffle:
             indices = np.random.permutation(len(self))
         else:
@@ -127,18 +184,3 @@ class LanguagePairDataset(FairseqDataset):
         if self.tgt_sizes is not None:
             indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
         return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
-
-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
-        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
-        return (
-            self.src_sizes[index] <= max_source_positions
-            and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
-        )
-
-    def _get_max_positions(self, max_positions):
-        if max_positions is None:
-            return self.max_source_positions, self.max_target_positions
-        assert len(max_positions) == 2
-        max_src_pos, max_tgt_pos = max_positions
-        return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
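
Commentary: the removed valid_size/_get_max_positions logic now flows through the generic filter: size() returns a (src_len, tgt_len) tuple that filter_by_size compares component-wise. A sketch, given a LanguagePairDataset instance ``dataset`` and hypothetical limits:

    from fairseq.data import data_utils

    indices = dataset.ordered_indices()
    indices = data_utils.filter_by_size(
        indices, dataset.size, max_positions=(1024, 1024),
    )  # keeps i only if src_len <= 1024 and tgt_len <= 1024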
@@ -31,7 +31,16 @@ def collate(samples, pad_idx, eos_idx):
 
 class MonolingualDataset(FairseqDataset):
-    """A wrapper around torch.utils.data.Dataset for monolingual data."""
+    """
+    A wrapper around torch.utils.data.Dataset for monolingual data.
+
+    Args:
+        dataset (torch.utils.data.Dataset): dataset to wrap
+        sizes (List[int]): sentence lengths
+        vocab (fairseq.data.Dictionary): vocabulary
+        shuffle (bool, optional): shuffle the elements before batching.
+            Default: ``True``
+    """
 
     def __init__(self, dataset, sizes, vocab, shuffle):
         self.dataset = dataset
@@ -47,12 +56,31 @@ class MonolingualDataset(FairseqDataset):
         return len(self.dataset)
 
     def collater(self, samples):
-        """Merge a list of samples to form a mini-batch."""
+        """Merge a list of samples to form a mini-batch.
+
+        Returned mini-batches contain the following keys:
+
+        - `id` (torch.LongTensor): example IDs in the original input order
+        - `ntokens` (int): total number of tokens in the batch
+        - `net_input` (dict): the input to the Model, containing keys:
+
+          - `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
+            the source sentence of shape `(bsz, src_len)`. Padding will appear
+            on the right.
+
+        - `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
+          target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
+          right.
+
+        Args:
+            samples (List[dict]): samples to collate
+
+        Returns:
+            dict: a mini-batch suitable for forwarding with a Model
+        """
         return collate(samples, self.vocab.pad(), self.vocab.eos())
 
     def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
-        assert isinstance(max_positions, float) or isinstance(max_positions, int)
-        tgt_len = min(tgt_len, max_positions)
+        """Return a dummy batch with a given number of tokens."""
+        if isinstance(max_positions, float) or isinstance(max_positions, int):
+            tgt_len = min(tgt_len, max_positions)
         bsz = num_tokens // tgt_len
         target = self.vocab.dummy_sentence(tgt_len + 1)
         source, target = target[:-1], target[1:]
@@ -62,19 +90,21 @@ class MonolingualDataset(FairseqDataset):
         ])
 
     def num_tokens(self, index):
-        """Return an example's length (number of tokens), used for batching."""
+        """Return the number of tokens in a sample. This value is used to
+        enforce ``--max-tokens`` during batching."""
+        return self.sizes[index]
+
+    def size(self, index):
+        """Return an example's size as a float or tuple. This value is used when
+        filtering a dataset with ``--max-positions``."""
         return self.sizes[index]
 
     def ordered_indices(self):
-        """Ordered indices for batching."""
+        """Return an ordered list of indices. Batches will be constructed based
+        on this order."""
         if self.shuffle:
             order = [np.random.permutation(len(self))]
         else:
             order = [np.arange(len(self))]
         order.append(np.flip(self.sizes, 0))
         return np.lexsort(order)
-
-    def valid_size(self, index, max_positions):
-        """Check if an example's size is valid according to max_positions."""
-        assert isinstance(max_positions, float) or isinstance(max_positions, int)
-        return self.sizes[index] <= max_positions
@@ -5,11 +5,13 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
+from fairseq.data import data_utils, FairseqDataset
+
 
 class FairseqTask(object):
     """
-    A Task defines the data format, stores shared state (e.g., dictionaries) and
-    provides helpers for building the model/criterion and calculating the loss.
+    Tasks store dictionaries and provide helpers for loading/iterating over
+    Datasets, initializing the Model/Criterion and calculating the loss.
     """
 
     @staticmethod
@@ -37,6 +39,62 @@ class FairseqTask(object):
             raise TypeError('Datasets are expected to be of type FairseqDataset')
         return self.datasets[split]
+
+    def get_batch_iterator(
+        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
+        ignore_invalid_inputs=False, required_batch_size_multiple=1,
+        seed=1, num_shards=1, shard_id=0,
+    ):
+        """
+        Generate batches of indices.
+
+        Args:
+            dataset (FairseqDataset): dataset to batch
+            max_tokens (int, optional): max number of tokens in each batch.
+                Default: ``None``
+            max_sentences (int, optional): max number of sentences in each
+                batch. Default: ``None``
+            max_positions (optional): max sentence length supported by the
+                model. Default: ``None``
+            ignore_invalid_inputs (bool, optional): don't raise Exception for
+                sentences that are too long. Default: ``False``
+            required_batch_size_multiple (int, optional): require batch size to
+                be a multiple of N. Default: ``1``
+            seed (int, optional): seed for random number generator for
+                reproducibility. Default: ``1``
+            num_shards (int, optional): shard the data iterator into N
+                shards. Default: ``1``
+            shard_id (int, optional): which shard of the data iterator to
+                return. Default: ``0``
+
+        Returns:
+            EpochBatchIterator: a batched iterator over the given dataset split
+        """
+        assert isinstance(dataset, FairseqDataset)
+
+        # get indices ordered by example size
+        with data_utils.numpy_seed(seed):
+            indices = dataset.ordered_indices()
+
+        # filter examples that are too large
+        indices = data_utils.filter_by_size(
+            indices, dataset.size, max_positions, raise_exception=(not ignore_invalid_inputs),
+        )
+
+        # create mini-batches with given size constraints
+        batch_sampler = data_utils.batch_by_size(
+            indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
+            required_batch_size_multiple=required_batch_size_multiple,
+        )
+
+        # return a reusable, sharded iterator
+        return data_utils.EpochBatchIterator(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            seed=seed,
+            num_shards=num_shards,
+            shard_id=shard_id,
+        )
 
     def build_model(self, args):
         from fairseq import models
         return models.build_model(args, self)
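
Commentary: every caller updated below (train.py, generate.py, eval_lm.py, interactive.py) follows the same pattern. A condensed sketch, assuming a task and a model are in scope and using hypothetical argument values:

    from fairseq import utils

    epoch_itr = task.get_batch_iterator(
        dataset=task.dataset('train'),
        max_tokens=4000,  # hypothetical token budget
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=True,
        seed=1,
    )
    for _ in range(3):  # the iterator is reusable across epochs
        for sample in epoch_itr.next_epoch_itr(shuffle=True):
            pass  # forward/backward on one mini-batch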
@@ -48,6 +106,9 @@ class FairseqTask(object):
     def get_loss(self, model, criterion, sample):
         return criterion(model, sample)
 
+    def max_positions(self):
+        return None
+
     @property
     def source_dictionary(self):
         raise NotImplementedError
......
@@ -139,6 +139,9 @@ class TranslationTask(FairseqTask):
             max_target_positions=self.args.max_target_positions,
         )
 
+    def max_positions(self):
+        return (self.args.max_source_positions, self.args.max_target_positions)
+
     @property
     def source_dictionary(self):
         return self.src_dict
......
@@ -150,7 +150,7 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
     ensemble = []
     for state in states:
         args = state['args']
         if model_arg_overrides is not None:
             args = _override_model_args(args, model_arg_overrides)
@@ -399,3 +399,17 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
         idx = int(m.group(1)) if len(m.groups()) > 0 else i
         entries.append((idx, m.group(0)))
     return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
+
+
+def resolve_max_positions(*args):
+    """Resolve max position constraints from multiple sources."""
+    max_positions = None
+    for arg in args:
+        if max_positions is None:
+            max_positions = arg
+        elif arg is not None:
+            if isinstance(arg, float) or isinstance(arg, int):
+                max_positions = min(max_positions, arg)
+            else:
+                max_positions = tuple(map(min, zip(max_positions, arg)))
+    return max_positions
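
Commentary: behavior of the new helper, checked against the code above (example values hypothetical): None entries are skipped, scalars resolve to the minimum, and tuples resolve element-wise:

    from fairseq import utils

    assert utils.resolve_max_positions(1024, 512) == 512
    assert utils.resolve_max_positions(None, 1024) == 1024
    # element-wise min across (src, tgt) constraint tuples
    assert utils.resolve_max_positions((1024, 1024), (800, 2000)) == (800, 1024)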
@@ -54,11 +54,14 @@ def main(args):
     align_dict = utils.load_align_dict(args.replace_unk)
 
     # Load dataset (possibly sharded)
-    itr = data.EpochBatchIterator(
+    itr = task.get_batch_iterator(
         dataset=task.dataset(args.gen_subset),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
-        max_positions=models[0].max_positions(),
+        max_positions=utils.resolve_max_positions(
+            task.max_positions(),
+            *[model.max_positions() for model in models]
+        ),
         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
         required_batch_size_multiple=8,
         num_shards=args.num_shards,
......
@@ -32,14 +32,14 @@ def buffered_read(buffer_size):
         yield buffer
 
-def make_batches(lines, args, src_dict, max_positions):
+def make_batches(lines, args, task, max_positions):
     tokens = [
-        tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
+        tokenizer.Tokenizer.tokenize(src_str, task.source_dictionary, add_if_not_exist=False).long()
         for src_str in lines
     ]
     lengths = np.array([t.numel() for t in tokens])
-    itr = data.EpochBatchIterator(
-        dataset=data.LanguagePairDataset(tokens, lengths, src_dict),
+    itr = task.get_batch_iterator(
+        dataset=data.LanguagePairDataset(tokens, lengths, task.source_dictionary),
         max_tokens=args.max_tokens,
         max_sentences=args.max_sentences,
         max_positions=max_positions,
@@ -76,7 +76,6 @@ def main(args):
     models, model_args = utils.load_ensemble_for_inference(model_paths, task, model_arg_overrides=eval(args.model_overrides))
 
     # Set dictionaries
-    src_dict = task.source_dictionary
     tgt_dict = task.target_dictionary
 
     # Optimize ensemble for generation
@@ -151,13 +150,18 @@ def main(args):
         return [make_result(batch.srcs[i], t) for i, t in enumerate(translations)]
 
+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        *[model.max_positions() for model in models]
+    )
+
     if args.buffer_size > 1:
         print('| Sentence buffer size:', args.buffer_size)
     print('| Type the input sentence and press return:')
     for inputs in buffered_read(args.buffer_size):
         indices = []
         results = []
-        for batch, batch_indices in make_batches(inputs, args, src_dict, models[0].max_positions()):
+        for batch, batch_indices in make_batches(inputs, args, task, max_positions):
             indices.extend(batch_indices)
             results += process_batch(batch)
......
@@ -44,7 +44,7 @@ def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
     trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
     epoch_itr = data.EpochBatchIterator(
         dataset=data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False),
-        max_tokens=1,
+        batch_sampler=[[i] for i in range(epoch_size)],
     )
     return trainer, epoch_itr
......
@@ -12,7 +12,7 @@ import os
 import math
 import torch
 
-from fairseq import data, distributed_utils, options, progress_bar, tasks, utils
+from fairseq import distributed_utils, options, progress_bar, tasks, utils
 from fairseq.fp16_trainer import FP16Trainer
 from fairseq.trainer import Trainer
 from fairseq.meters import AverageMeter, StopwatchMeter
@@ -57,11 +57,14 @@ def main(args):
     ))
 
     # Initialize dataloader
-    max_positions = trainer.get_model().max_positions()
-    epoch_itr = data.EpochBatchIterator(
+    max_positions = utils.resolve_max_positions(
+        task.max_positions(),
+        trainer.get_model().max_positions(),
+    )
+    epoch_itr = task.get_batch_iterator(
         dataset=task.dataset(args.train_subset),
         max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences_valid,
+        max_sentences=args.max_sentences,
         max_positions=max_positions,
         ignore_invalid_inputs=True,
         required_batch_size_multiple=8,
@@ -193,11 +196,14 @@ def validate(args, trainer, task, epoch_itr, subsets):
     valid_losses = []
     for subset in subsets:
         # Initialize data iterator
-        itr = data.EpochBatchIterator(
+        itr = task.get_batch_iterator(
             dataset=task.dataset(subset),
             max_tokens=args.max_tokens,
             max_sentences=args.max_sentences_valid,
-            max_positions=trainer.get_model().max_positions(),
+            max_positions=utils.resolve_max_positions(
+                task.max_positions(),
+                trainer.get_model().max_positions(),
+            ),
             ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
             required_batch_size_multiple=8,
             seed=args.seed,
......