Commit 7633129b authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130
@@ -19,10 +19,13 @@ of various sequence-to-sequence models, including:
 Fairseq features:
 - multi-GPU (distributed) training on one machine or across multiple machines
-- fast beam search generation on both CPU and GPU
+- fast generation on both CPU and GPU with multiple search algorithms implemented:
+  - beam search
+  - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+  - sampling (unconstrained and top-k)
 - large mini-batch training even on a single GPU via delayed updates
 - fast half-precision floating point (FP16) training
-- extensible: easily register new models, criterions, and tasks
+- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers

 We also provide [pre-trained models](#pre-trained-models) for several benchmark
 translation and language modeling datasets.
@@ -34,7 +37,7 @@ translation and language modeling datasets.
 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
 * Python version 3.6

-Currently fairseq requires PyTorch version >= 0.4.0.
+Currently fairseq requires PyTorch version >= 1.0.0.
 Please follow the instructions here: https://github.com/pytorch/pytorch#installation.

 If you use Docker make sure to increase the shared memory size either with
......
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import socket
import subprocess

from train import main as single_process_main
from fairseq import distributed_utils, options


def main(args):
    if args.distributed_init_method is None and args.distributed_port > 0:
        # We can determine the init method automatically for Slurm.
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError as e:  # Slurm is not installed
                pass

    if args.distributed_init_method is None and args.distributed_port is None:
        raise ValueError('--distributed-init-method or --distributed-port '
                         'must be specified for distributed training')

    args.distributed_rank = distributed_utils.distributed_init(args)
    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
    single_process_main(args)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
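Note that under Slurm the script above needs only `--distributed-port`: the init method, rank and device id are all derived from the `SLURM_*` environment variables. Outside Slurm, `--distributed-init-method` (e.g. `tcp://hostname:port`) would typically be supplied explicitly on each worker, along with `--distributed-rank` and `--distributed-world-size`. A hypothetical Slurm launch, with dataset path and architecture illustrative only:

> srun python distributed_train.py data-bin/wmt14_en_fr \
    --arch fconv_wmt_en_fr --distributed-port 12345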
@@ -6,8 +6,26 @@
 Criterions
 ==========

+Criterions compute the loss function given the model and batch, roughly::
+
+    loss = criterion(model, batch)
+
 .. automodule:: fairseq.criterions
     :members:

 .. autoclass:: fairseq.criterions.FairseqCriterion
     :members:
     :undoc-members:
+
+.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
+    :members:
+    :undoc-members:
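For orientation, a minimal sketch of what a plug-in criterion looks like; the registry name `toy_nll` and the class are hypothetical, and real criterions also handle padding, `add_args`, and logging aggregation:

    from fairseq.criterions import FairseqCriterion, register_criterion


    @register_criterion('toy_nll')  # hypothetical name
    class ToyNLLCriterion(FairseqCriterion):
        """Summed negative log-likelihood of the target tokens."""

        def forward(self, model, sample, reduce=True):
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output).unsqueeze(-1)
            loss = -lprobs.gather(dim=-1, index=target).sum()
            sample_size = sample['ntokens']
            logging_output = {'loss': loss.item(), 'sample_size': sample_size}
            return loss, sample_size, logging_output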
@@ -21,6 +21,20 @@ mini-batches.
 .. autoclass:: fairseq.data.MonolingualDataset
     :members:

+**Helper Datasets**
+
+These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
+provide additional functionality:
+
+.. autoclass:: fairseq.data.BacktranslationDataset
+    :members:
+
+.. autoclass:: fairseq.data.ConcatDataset
+    :members:
+
+.. autoclass:: fairseq.data.RoundRobinZipDatasets
+    :members:
+
+.. autoclass:: fairseq.data.TransformEosDataset
+    :members:
+
 Dictionary
 ----------
@@ -32,6 +46,8 @@ Dictionary
 Iterators
 ---------

+.. autoclass:: fairseq.data.BufferedIterator
+    :members:
 .. autoclass:: fairseq.data.CountingIterator
     :members:

 .. autoclass:: fairseq.data.EpochBatchIterator
......
@@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5:
 > MODEL_DIR=wmt14.en-fr.fconv-py
 > python interactive.py \
   --path $MODEL_DIR/model.pt $MODEL_DIR \
-  --beam 5
+  --beam 5 --source-lang en --target-lang fr
 | loading model(s) from wmt14.en-fr.fconv-py/model.pt
 | [en] dictionary: 44206 types
 | [fr] dictionary: 44463 types
 | Type the input sentence and press return:
 > Why is it rare to discover new marine mam@@ mal species ?
 O Why is it rare to discover new marine mam@@ mal species ?
-H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
-A 0 1 3 3 5 6 6 8 8 8 7 11 12
+H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
+P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288

-This generation script produces four types of outputs: a line prefixed
-with *S* shows the supplied source sentence after applying the
-vocabulary; *O* is a copy of the original source sentence; *H* is the
-hypothesis along with an average log-likelihood; and *A* is the
-attention maxima for each word in the hypothesis, including the
+This generation script produces three types of outputs: a line prefixed
+with *O* is a copy of the original source sentence; *H* is the
+hypothesis along with an average log-likelihood; and *P* is the
+positional score per token position, including the
 end-of-sentence marker which is omitted from the text.

 See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
......
@@ -6,7 +6,29 @@
 Learning Rate Schedulers
 ========================

-TODO
+Learning Rate Schedulers update the learning rate over the course of training.
+Learning rates can be updated after each update via :func:`step_update` or at
+epoch boundaries via :func:`step`.

 .. automodule:: fairseq.optim.lr_scheduler
     :members:

+.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.reduce_angular_lr_scheduler.TriangularSchedule
+    :members:
+    :undoc-members:
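As a concrete sketch of this interface (the name `toy_linear_decay` and the class are hypothetical; `total_updates` would normally be read from `args`):

    from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


    @register_lr_scheduler('toy_linear_decay')  # hypothetical name
    class ToyLinearDecaySchedule(FairseqLRScheduler):
        """Decays the learning rate linearly to zero over a fixed number of updates."""

        def __init__(self, args, optimizer):
            super().__init__(args, optimizer)
            self.init_lr = args.lr[0]
            self.total_updates = 100000  # assumed fixed here for brevity

        def step_update(self, num_updates):
            # called by the trainer after every optimizer update
            lr = self.init_lr * max(0.0, 1.0 - num_updates / self.total_updates)
            self.optimizer.set_lr(lr)
            return lr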
 Modules
 =======

-Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be
-helpful when implementing a new :class:`FairseqModel`.
+Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
+be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.

 .. automodule:: fairseq.modules
     :members:
......
@@ -6,5 +6,27 @@
 Optimizers
 ==========

+Optimizers update the Model parameters based on the gradients.
+
 .. automodule:: fairseq.optim
     :members:

+.. autoclass:: fairseq.optim.FairseqOptimizer
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.adagrad.Adagrad
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.adam.FairseqAdam
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.nag.FairseqNAG
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.sgd.SGD
+    :members:
+    :undoc-members:
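The optimizer plug-in contract mainly consists of constructing a wrapped `torch.optim` optimizer and exposing its construction kwargs via `optimizer_config` (fairseq uses this when restoring optimizer state from checkpoints). A hypothetical sketch:

    import torch.optim

    from fairseq.optim import FairseqOptimizer, register_optimizer


    @register_optimizer('toy_sgd')  # hypothetical name
    class ToySGD(FairseqOptimizer):

        def __init__(self, args, params):
            super().__init__(args, params)
            # the wrapped torch optimizer that actually applies the updates
            self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

        @property
        def optimizer_config(self):
            # kwargs needed to rebuild an equivalent optimizer later
            return {'lr': self.args.lr[0], 'momentum': self.args.momentum}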
@@ -22,12 +22,18 @@ fairseq implements the following high-level training flow::

     for epoch in range(num_epochs):
         itr = task.get_batch_iterator(task.dataset('train'))
         for num_updates, batch in enumerate(itr):
-            loss = criterion(model, batch)
-            optimizer.backward(loss)
+            task.train_step(batch, model, criterion, optimizer)
+            average_and_clip_gradients()
             optimizer.step()
             lr_scheduler.step_update(num_updates)
         lr_scheduler.step(epoch)

+where the default implementation for ``task.train_step`` is roughly::
+
+    def train_step(self, batch, model, criterion, optimizer):
+        loss = criterion(model, batch)
+        optimizer.backward(loss)
+
 **Registering new plug-ins**

 New plug-ins are *registered* through a set of ``@register`` function
......
@@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory.
 -------------------------------

 Finally we can write a short script to evaluate our model on new inputs. Create
-a new file named :file:`eval_classify.py` with the following contents::
+a new file named :file:`eval_classifier.py` with the following contents::

     from fairseq import data, options, tasks, utils
     from fairseq.tokenizer import Tokenizer

     # Parse command-line arguments for generation
-    parser = options.get_generation_parser()
+    parser = options.get_generation_parser(default_task='simple_classification')
     args = options.parse_args_and_arch(parser)

     # Setup task
-    args.task = 'simple_classification'
     task = tasks.setup_task(args)

     # Load model
......
@@ -55,7 +55,9 @@ def main(parsed_args):

     # Load ensemble
     print('| loading model(s) from {}'.format(parsed_args.path))
-    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides))
+    models, args = utils.load_ensemble_for_inference(
+        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
+    )

     for arg in vars(parsed_args).keys():
         if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
@@ -83,9 +85,10 @@ def main(parsed_args):
         max_positions=utils.resolve_max_positions(*[
             model.max_positions() for model in models
         ]),
+        ignore_invalid_inputs=True,
         num_shards=args.num_shards,
         shard_id=args.shard_id,
-        ignore_invalid_inputs=True,
+        num_workers=args.num_workers,
     ).next_epoch_itr(shuffle=False)

     gen_timer = StopwatchMeter()
......
@@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
 from .fairseq_dataset import FairseqDataset
 from .backtranslation_dataset import BacktranslationDataset
 from .concat_dataset import ConcatDataset
-from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
+from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets
@@ -33,7 +33,6 @@ __all__ = [
     'GroupedIterator',
     'IndexedCachedDataset',
     'IndexedDataset',
-    'IndexedInMemoryDataset',
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'MonolingualDataset',
......
@@ -56,6 +56,28 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):

 class BacktranslationDataset(FairseqDataset):
+    """
+    Sets up a backtranslation dataset which takes a tgt batch, generates
+    a src using a tgt-src backtranslation function (*backtranslation_fn*),
+    and returns the corresponding `{generated src, input tgt}` batch.
+
+    Args:
+        tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
+            backtranslated. Only the source side of this dataset will be used.
+            After backtranslation, the source sentences in this dataset will be
+            returned as the targets.
+        backtranslation_fn (callable): function to call to generate
+            backtranslations. This is typically the `generate` method of a
+            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
+        max_len_a, max_len_b (int, int): will be used to compute
+            `maxlen = max_len_a * src_len + max_len_b`, which will be passed
+            into *backtranslation_fn*.
+        output_collater (callable, optional): function to call on the
+            backtranslated samples to create the final batch
+            (default: ``tgt_dataset.collater``).
+        cuda: use GPU for generation
+    """
+
     def __init__(
         self,
         tgt_dataset,
@@ -66,27 +88,6 @@ class BacktranslationDataset(FairseqDataset):
         cuda=True,
         **kwargs
     ):
-        """
-        Sets up a backtranslation dataset which takes a tgt batch, generates
-        a src using a tgt-src backtranslation function (*backtranslation_fn*),
-        and returns the corresponding `{generated src, input tgt}` batch.
-
-        Args:
-            tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
-                backtranslated. Only the source side of this dataset will be
-                used. After backtranslation, the source sentences in this
-                dataset will be returned as the targets.
-            backtranslation_fn (callable): function to call to generate
-                backtranslations. This is typically the `generate` method of a
-                :class:`~fairseq.sequence_generator.SequenceGenerator` object.
-            max_len_a, max_len_b (int, int): will be used to compute
-                `maxlen = max_len_a * src_len + max_len_b`, which will be
-                passed into *backtranslation_fn*.
-            output_collater (callable, optional): function to call on the
-                backtranslated samples to create the final batch (default:
-                ``tgt_dataset.collater``)
-            cuda: use GPU for generation
-        """
         self.tgt_dataset = tgt_dataset
         self.backtranslation_fn = backtranslation_fn
         self.max_len_a = max_len_a
@@ -166,11 +167,10 @@ class BacktranslationDataset(FairseqDataset):
         """
         tgt_size = self.tgt_dataset.size(index)[0]
         return (tgt_size, tgt_size)
-

     @property
     def supports_prefetch(self):
-        return self.tgt_dataset.supports_prefetch()
+        return getattr(self.tgt_dataset, 'supports_prefetch', False)

     def prefetch(self, indices):
         return self.tgt_dataset.prefetch(indices)
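A sketch of how this dataset is typically constructed; `backward_model`, `tgt_dict` and `tgt_dataset` are assumed to already exist, and all values are illustrative:

    from fairseq.data import BacktranslationDataset
    from fairseq.sequence_generator import SequenceGenerator

    # a tgt->src generator built from a trained backward model
    backtranslator = SequenceGenerator([backward_model], tgt_dict)

    bt_dataset = BacktranslationDataset(
        tgt_dataset=tgt_dataset,              # monolingual target-side data
        backtranslation_fn=backtranslator.generate,
        max_len_a=1,                          # maxlen = 1 * src_len + 10
        max_len_b=10,
    )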
@@ -29,18 +29,18 @@ class ConcatDataset(FairseqDataset):
         if isinstance(sample_ratios, int):
             sample_ratios = [sample_ratios] * len(self.datasets)
         self.sample_ratios = sample_ratios
-        self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios)
+        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
         self.real_sizes = [len(d) for d in self.datasets]

     def __len__(self):
-        return self.cummulative_sizes[-1]
+        return self.cumulative_sizes[-1]

     def __getitem__(self, idx):
-        dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
         if dataset_idx == 0:
             sample_idx = idx
         else:
-            sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
         sample_idx = sample_idx % self.real_sizes[dataset_idx]
         return self.datasets[dataset_idx][sample_idx]
@@ -54,7 +54,7 @@ class ConcatDataset(FairseqDataset):
     def prefetch(self, indices):
         frm = 0
-        for to, ds in zip(self.cummulative_sizes, self.datasets):
+        for to, ds in zip(self.cumulative_sizes, self.datasets):
             real_size = len(ds)
             ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
             frm = to
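To make the index arithmetic above concrete: with two datasets of real sizes 3 and 5 and `sample_ratios=[2, 1]`, `cumsum` yields cumulative sizes `[6, 11]`, and indices map as in this self-contained sketch:

    import bisect

    cumulative_sizes = [6, 11]  # dataset 0 (len 3) oversampled 2x, dataset 1 (len 5) 1x
    real_sizes = [3, 5]

    def map_index(idx):
        dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
        sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx % real_sizes[dataset_idx]

    assert map_index(4) == (0, 1)  # 4 % 3 == 1: dataset 0 repeats under oversampling
    assert map_index(7) == (1, 1)  # second dataset, second sample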
@@ -81,8 +81,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
         size_fn (callable): function that returns the size of a given index
         max_positions (tuple): filter elements larger than this size.
             Comparisons are done component-wise.
-        raise_exception (bool, optional): if ``True``, raise an exception
-            if any elements are filtered. Default: ``False``
+        raise_exception (bool, optional): if ``True``, raise an exception if
+            any elements are filtered (default: False).
     """
     def check_size(idx):
         if isinstance(max_positions, float) or isinstance(max_positions, int):
@@ -128,12 +128,12 @@ def batch_by_size(
         indices (List[int]): ordered list of dataset indices
         num_tokens_fn (callable): function that returns the number of tokens at
             a given index
-        max_tokens (int, optional): max number of tokens in each batch.
-            Default: ``None``
-        max_sentences (int, optional): max number of sentences in each
-            batch. Default: ``None``
-        required_batch_size_multiple (int, optional): require batch size to
-            be a multiple of N. Default: ``1``
+        max_tokens (int, optional): max number of tokens in each batch
+            (default: None).
+        max_sentences (int, optional): max number of sentences in each
+            batch (default: None).
+        required_batch_size_multiple (int, optional): require batch size to
+            be a multiple of N (default: 1).
     """
     max_tokens = max_tokens if max_tokens is not None else float('Inf')
     max_sentences = max_sentences if max_sentences is not None else float('Inf')
......
@@ -200,11 +200,15 @@ class Dictionary(object):
             t[-1] = self.eos()
         return t


 class TruncatedDictionary(object):

     def __init__(self, wrapped_dict, length):
-        self.__class__ = type(wrapped_dict.__class__.__name__,
-            (self.__class__, wrapped_dict.__class__), {})
+        self.__class__ = type(
+            wrapped_dict.__class__.__name__,
+            (self.__class__, wrapped_dict.__class__),
+            {}
+        )
         self.__dict__ = wrapped_dict.__dict__
         self.wrapped_dict = wrapped_dict
         self.length = min(len(self.wrapped_dict), length)
......
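The dynamic `type(...)` call above creates, at runtime, a subclass of both `TruncatedDictionary` and the wrapped dictionary's own class, so `isinstance` checks against the original dictionary type continue to pass; `self.__dict__ = wrapped_dict.__dict__` then shares (rather than copies) the wrapped dictionary's state.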
@@ -7,8 +7,6 @@

 import torch.utils.data

-from fairseq.data import data_utils
-

 class FairseqDataset(torch.utils.data.Dataset):
     """A dataset that provides helpers for batching."""
@@ -51,7 +49,9 @@ class FairseqDataset(torch.utils.data.Dataset):

     @property
     def supports_prefetch(self):
+        """Whether this dataset supports prefetching."""
         return False

     def prefetch(self, indices):
+        """Prefetch the data required for this epoch."""
         raise NotImplementedError
@@ -52,13 +52,12 @@ def data_file_path(prefix_path):

 class IndexedDataset(torch.utils.data.Dataset):
     """Loader for TorchNet IndexedDataset"""

-    def __init__(self, path, fix_lua_indexing=False, read_data=True):
+    def __init__(self, path, fix_lua_indexing=False):
         super().__init__()
         self.fix_lua_indexing = fix_lua_indexing
         self.read_index(path)
         self.data_file = None
-        if read_data:
-            self.read_data(path)
+        self.path = path

     def read_index(self, path):
         with open(index_file_path(path), 'rb') as f:
@@ -85,8 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):
         self.data_file.close()

     def __getitem__(self, i):
+        if not self.data_file:
+            self.read_data(self.path)
         self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+        tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
         a = np.empty(tensor_size, dtype=self.dtype)
         self.data_file.seek(self.data_offsets[i] * self.element_size)
         self.data_file.readinto(a)
@@ -98,12 +99,6 @@ class IndexedDataset(torch.utils.data.Dataset):
     def __len__(self):
         return self.size

-    def read_into(self, start, dst):
-        self.data_file.seek(start * self.element_size)
-        self.data_file.readinto(dst)
-        if self.fix_lua_indexing:
-            dst -= 1  # subtract 1 for 0-based indexing
-
     @staticmethod
     def exists(path):
         return (
@@ -111,11 +106,15 @@ class IndexedDataset(torch.utils.data.Dataset):
             os.path.exists(data_file_path(path))
         )

+    @property
+    def supports_prefetch(self):
+        return False  # avoid prefetching to save memory
+

 class IndexedCachedDataset(IndexedDataset):

     def __init__(self, path, fix_lua_indexing=False):
-        super().__init__(path, fix_lua_indexing, True)
+        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
         self.cache = None
         self.cache_index = {}
@@ -126,6 +125,8 @@ class IndexedCachedDataset(IndexedDataset):
     def prefetch(self, indices):
         if all(i in self.cache_index for i in indices):
             return
+        if not self.data_file:
+            self.read_data(self.path)
         indices = sorted(set(indices))
         total_size = 0
         for i in indices:
@@ -153,34 +154,7 @@ class IndexedCachedDataset(IndexedDataset):
         return item


-class IndexedInMemoryDataset(IndexedDataset):
-    """Loader for TorchNet IndexedDataset, keeps all the data in memory"""
-
-    def read_data(self, path):
-        self.data_file = open(data_file_path(path), 'rb')
-        self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
-        self.data_file.readinto(self.buffer)
-        self.data_file.close()
-        if self.fix_lua_indexing:
-            self.buffer -= 1  # subtract 1 for 0-based indexing
-
-    def read_into(self, start, dst):
-        if self.token_blob is None:
-            self.token_blob = [t for l in self.tokens_list for t in l]
-        np.copyto(dst, self.token_blob[start:])
-
-    def __del__(self):
-        pass
-
-    def __getitem__(self, i):
-        self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
-        a = np.empty(tensor_size, dtype=self.dtype)
-        np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
-        return torch.from_numpy(a).long()
-
-
-class IndexedRawTextDataset(IndexedDataset):
+class IndexedRawTextDataset(torch.utils.data.Dataset):
     """Takes a text file as input and binarizes it in memory at instantiation.
     Original lines are also kept in memory"""
@@ -205,6 +179,10 @@ class IndexedRawTextDataset(IndexedDataset):
             self.sizes.append(len(tokens))
         self.sizes = np.array(self.sizes)

+    def check_index(self, i):
+        if i < 0 or i >= self.size:
+            raise IndexError('index out of range')
+
     def __getitem__(self, i):
         self.check_index(i)
         return self.tokens_list[i]
@@ -252,7 +230,7 @@ class IndexedDatasetBuilder(object):
         self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

     def merge_file_(self, another_file):
-        index = IndexedDataset(another_file, read_data=False)
+        index = IndexedDataset(another_file)
         assert index.dtype == self.dtype
         begin = self.data_offsets[-1]
......
@@ -69,17 +69,19 @@ class EpochBatchIterator(object):
         batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
             indices
         seed (int, optional): seed for random number generator for
-            reproducibility. Default: 1
+            reproducibility (default: 1).
         num_shards (int, optional): shard the data iterator into N
-            shards. Default: 1
+            shards (default: 1).
         shard_id (int, optional): which shard of the data iterator to
-            return. Default: 0
+            return (default: 0).
-        buffer_size (int, optional): number of batches to buffer. Default: 5
+        num_workers (int, optional): how many subprocesses to use for data
+            loading. 0 means the data will be loaded in the main process
+            (default: 0).
     """

     def __init__(
         self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-        buffer_size=5,
+        num_workers=0,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -88,14 +90,12 @@ class EpochBatchIterator(object):
         self.seed = seed
         self.num_shards = num_shards
         self.shard_id = shard_id
-        self.buffer_size = buffer_size
+        self.num_workers = num_workers
         self.epoch = 0
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
-        self._supports_prefetch = (
-            hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
-        )
+        self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)

     def __len__(self):
         return len(self.frozen_batches)
@@ -105,11 +105,10 @@ class EpochBatchIterator(object):
         Args:
             shuffle (bool, optional): shuffle batches before returning the
-                iterator. Default: ``True``
+                iterator (default: True).
             fix_batches_to_gpus: ensure that batches are always
                 allocated to the same shards across epochs. Requires
-                that :attr:`dataset` supports prefetching. Default:
-                ``False``
+                that :attr:`dataset` supports prefetching (default: False).
         """
         if self._next_epoch_itr is not None:
             self._cur_epoch_itr = self._next_epoch_itr
@@ -117,7 +116,8 @@ class EpochBatchIterator(object):
         else:
             self.epoch += 1
             self._cur_epoch_itr = self._get_iterator_for_epoch(
-                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus)
+                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
+            )
         return self._cur_epoch_itr

     def end_of_epoch(self):
@@ -179,50 +179,14 @@ class EpochBatchIterator(object):
             batches = self.frozen_batches
         batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
-        return CountingIterator(BufferedIterator(
-            torch.utils.data.DataLoader(
-                self.dataset,
-                collate_fn=self.collate_fn,
-                batch_sampler=batches,
-            ),
-            buffer_size=self.buffer_size,
-        ))
-
-
-class BufferedIterator(object):
-    """Wrapper around an iterable that prefetches items into a buffer.
-
-    Args:
-        iterable (iterable): iterable to wrap
-        buffer_size (int): number of items to prefetch and buffer
-    """
-
-    def __init__(self, iterable, buffer_size):
-        self.iterable = iterable
-        self.q = queue.Queue(maxsize=buffer_size)
-        self.thread = threading.Thread(target=self._load_q, daemon=True)
-        self.thread.start()
-
-    def __len__(self):
-        return len(self.iterable)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        x = self.q.get()
-        if x is None:
-            self.thread.join()
-            raise StopIteration
-        return x[0]
-
-    def _load_q(self):
-        for x in self.iterable:
-            self.q.put([x])  # wrap in list so that it's never None
-        self.q.put(None)
+        return CountingIterator(torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.collate_fn,
+            batch_sampler=batches,
+            num_workers=self.num_workers,
+        ))


 class GroupedIterator(object):
     """Wrapper around an iterable that returns groups (chunks) of items.
@@ -261,7 +225,7 @@ class ShardedIterator(object):
         num_shards (int): number of shards to split the iterable into
         shard_id (int): which shard to iterator over
         fill_value (Any, optional): padding value when the iterable doesn't
-            evenly divide *num_shards*. Default: ``None``
+            evenly divide *num_shards* (default: None).
     """

     def __init__(self, iterable, num_shards, shard_id, fill_value=None):
......
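Assuming the round-robin sharding described in the docstring, a small illustrative sketch of `ShardedIterator` (expected values shown as comments):

    from fairseq.data import ShardedIterator

    elems = list(range(5))
    shard0 = list(ShardedIterator(elems, num_shards=2, shard_id=0, fill_value=-1))
    shard1 = list(ShardedIterator(elems, num_shards=2, shard_id=1, fill_value=-1))
    # shards are padded with fill_value so they all have equal length:
    # shard0 == [0, 2, 4], shard1 == [1, 3, -1]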
@@ -79,23 +79,23 @@ class LanguagePairDataset(FairseqDataset):
         tgt (torch.utils.data.Dataset, optional): target dataset to wrap
         tgt_sizes (List[int], optional): target sentence lengths
         tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
-        left_pad_source (bool, optional): pad source tensors on the left side.
-            Default: ``True``
-        left_pad_target (bool, optional): pad target tensors on the left side.
-            Default: ``False``
-        max_source_positions (int, optional): max number of tokens in the source
-            sentence. Default: ``1024``
-        max_target_positions (int, optional): max number of tokens in the target
-            sentence. Default: ``1024``
-        shuffle (bool, optional): shuffle dataset elements before batching.
-            Default: ``True``
+        left_pad_source (bool, optional): pad source tensors on the left side
+            (default: True).
+        left_pad_target (bool, optional): pad target tensors on the left side
+            (default: False).
+        max_source_positions (int, optional): max number of tokens in the
+            source sentence (default: 1024).
+        max_target_positions (int, optional): max number of tokens in the
+            target sentence (default: 1024).
+        shuffle (bool, optional): shuffle dataset elements before batching
+            (default: True).
         input_feeding (bool, optional): create a shifted version of the targets
-            to be passed into the model for input feeding/teacher forcing.
-            Default: ``True``
-        remove_eos_from_source (bool, optional): if set, removes eos from end of
-            source if it's present. Default: ``False``
+            to be passed into the model for input feeding/teacher forcing
+            (default: True).
+        remove_eos_from_source (bool, optional): if set, removes eos from end
+            of source if it's present (default: False).
         append_eos_to_target (bool, optional): if set, appends eos to end of
-            target if it's absent. Default: ``False``
+            target if it's absent (default: False).
     """

     def __init__(
@@ -223,15 +223,13 @@ class LanguagePairDataset(FairseqDataset):
             indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
         return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

-    def prefetch(self, indices):
-        self.src.prefetch(indices)
-        self.tgt.prefetch(indices)
-
     @property
     def supports_prefetch(self):
         return (
-            hasattr(self.src, 'supports_prefetch')
-            and self.src.supports_prefetch
-            and hasattr(self.tgt, 'supports_prefetch')
-            and self.tgt.supports_prefetch
+            getattr(self.src, 'supports_prefetch', False)
+            and getattr(self.tgt, 'supports_prefetch', False)
         )

+    def prefetch(self, indices):
+        self.src.prefetch(indices)
+        self.tgt.prefetch(indices)