Commit 7633129b authored by Myle Ott, committed by Facebook GitHub Bot

Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130
@@ -19,10 +19,13 @@ of various sequence-to-sequence models, including:

 Fairseq features:
 - multi-GPU (distributed) training on one machine or across multiple machines
-- fast beam search generation on both CPU and GPU
+- fast generation on both CPU and GPU with multiple search algorithms implemented:
+  - beam search
+  - Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
+  - sampling (unconstrained and top-k)
 - large mini-batch training even on a single GPU via delayed updates
 - fast half-precision floating point (FP16) training
-- extensible: easily register new models, criterions, and tasks
+- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers

 We also provide [pre-trained models](#pre-trained-models) for several benchmark
 translation and language modeling datasets.

@@ -34,7 +37,7 @@ translation and language modeling datasets.

 * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
 * Python version 3.6

-Currently fairseq requires PyTorch version >= 0.4.0.
+Currently fairseq requires PyTorch version >= 1.0.0.
 Please follow the instructions here: https://github.com/pytorch/pytorch#installation.

 If you use Docker make sure to increase the shared memory size either with
...
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import os
import socket
import subprocess

from train import main as single_process_main
from fairseq import distributed_utils, options


def main(args):
    if args.distributed_init_method is None and args.distributed_port > 0:
        # We can determine the init method automatically for Slurm.
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError:  # Slurm is not installed
                pass

    if args.distributed_init_method is None and args.distributed_port is None:
        raise ValueError('--distributed-init-method or --distributed-port '
                         'must be specified for distributed training')

    args.distributed_rank = distributed_utils.distributed_init(args)
    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
    single_process_main(args)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
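As context for the Slurm branch above, a small illustrative sketch (not part of the commit) of what the auto-detection computes: the node list is expanded via `scontrol show hostnames`, and the first host plus `--distributed-port` become the TCP rendezvous address. All values below are made up:

    # Illustrative only: mimics the Slurm branch above for a fake two-node job.
    hostnames = b'node01\nnode02\n'  # what `scontrol show hostnames node[01-02]` might print

    distributed_init_method = 'tcp://{host}:{port}'.format(
        host=hostnames.split()[0].decode('utf-8'),  # rank 0 host acts as rendezvous
        port=12345,                                 # value of --distributed-port
    )
    assert distributed_init_method == 'tcp://node01:12345'

    # Each process then learns its global rank and local GPU index from Slurm:
    env = {'SLURM_PROCID': '3', 'SLURM_LOCALID': '1'}  # example values
    distributed_rank = int(env['SLURM_PROCID'])
    device_id = int(env['SLURM_LOCALID'])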
@@ -6,8 +6,26 @@

 Criterions
 ==========

+Criterions compute the loss function given the model and batch, roughly::
+
+    loss = criterion(model, batch)
+
 .. automodule:: fairseq.criterions
     :members:

 .. autoclass:: fairseq.criterions.FairseqCriterion
     :members:
     :undoc-members:
+
+.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
+    :members:
+    :undoc-members:
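To make the criterion contract concrete, here is a minimal sketch of a custom criterion (not from this commit; the criterion name is hypothetical and details of the `FairseqCriterion` base class vary across fairseq versions, but the `register_criterion` decorator and the `(loss, sample_size, logging_output)` return convention are the documented pattern):

    import torch.nn.functional as F

    from fairseq.criterions import FairseqCriterion, register_criterion


    @register_criterion('simple_cross_entropy')  # hypothetical name, for illustration
    class SimpleCrossEntropyCriterion(FairseqCriterion):
        """Token-level cross entropy, following the FairseqCriterion contract."""

        def forward(self, model, sample, reduce=True):
            # Run the model and normalize its outputs to log-probabilities.
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            lprobs = lprobs.view(-1, lprobs.size(-1))
            target = model.get_targets(sample, net_output).view(-1)
            loss = F.nll_loss(
                lprobs, target,
                ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            # Criterions return (loss, sample_size, logging_output).
            sample_size = sample['ntokens']
            logging_output = {
                'loss': loss.data,
                'ntokens': sample['ntokens'],
                'sample_size': sample_size,
            }
            return loss, sample_size, logging_output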
@@ -21,6 +21,20 @@ mini-batches.

 .. autoclass:: fairseq.data.MonolingualDataset
     :members:

+**Helper Datasets**
+
+These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
+provide additional functionality:
+
+.. autoclass:: fairseq.data.BacktranslationDataset
+    :members:
+
+.. autoclass:: fairseq.data.ConcatDataset
+    :members:
+
+.. autoclass:: fairseq.data.RoundRobinZipDatasets
+    :members:
+
+.. autoclass:: fairseq.data.TransformEosDataset
+    :members:
+
 Dictionary
 ----------

@@ -32,6 +46,8 @@ Dictionary

 Iterators
 ---------

-.. autoclass:: fairseq.data.BufferedIterator
-    :members:
 .. autoclass:: fairseq.data.CountingIterator
     :members:
 .. autoclass:: fairseq.data.EpochBatchIterator
...
@@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5:

     > MODEL_DIR=wmt14.en-fr.fconv-py
     > python interactive.py \
       --path $MODEL_DIR/model.pt $MODEL_DIR \
-      --beam 5
+      --beam 5 --source-lang en --target-lang fr
     | loading model(s) from wmt14.en-fr.fconv-py/model.pt
     | [en] dictionary: 44206 types
     | [fr] dictionary: 44463 types
     | Type the input sentence and press return:
     > Why is it rare to discover new marine mam@@ mal species ?
     O Why is it rare to discover new marine mam@@ mal species ?
-    H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
-    A 0 1 3 3 5 6 6 8 8 8 7 11 12
+    H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
+    P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288

-This generation script produces four types of outputs: a line prefixed
-with *S* shows the supplied source sentence after applying the
-vocabulary; *O* is a copy of the original source sentence; *H* is the
-hypothesis along with an average log-likelihood; and *A* is the
-attention maxima for each word in the hypothesis, including the
-end-of-sentence marker which is omitted from the text.
+This generation script produces three types of outputs: a line prefixed
+with *O* is a copy of the original source sentence; *H* is the
+hypothesis along with an average log-likelihood; and *P* is the
+positional score per token position, including the
+end-of-sentence marker which is omitted from the text.

 See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
...
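To make the output format above concrete, a small parsing sketch (illustrative only, not part of fairseq) that splits the `O`, `H`, and `P` lines the way the new documentation describes them:

    # Illustrative parser for interactive.py output lines (not part of fairseq).
    output = """\
    O Why is it rare to discover new marine mam@@ mal species ?
    H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
    P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288"""

    for line in output.splitlines():
        tag, rest = line.strip().split(' ', 1)
        if tag == 'O':                      # original source sentence
            source = rest
        elif tag == 'H':                    # average log-likelihood + hypothesis
            score, hypothesis = rest.split(' ', 1)
            score = float(score)
        elif tag == 'P':                    # per-token positional scores
            positional = [float(x) for x in rest.split()]

    # The last positional score belongs to the end-of-sentence marker,
    # which does not appear in the hypothesis text itself.
    assert len(positional) == len(hypothesis.split()) + 1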
@@ -6,7 +6,29 @@

 Learning Rate Schedulers
 ========================

-TODO
+Learning Rate Schedulers update the learning rate over the course of training.
+Learning rates can be updated after each update via :func:`step_update` or at
+epoch boundaries via :func:`step`.

 .. automodule:: fairseq.optim.lr_scheduler
     :members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.lr_scheduler.reduce_angular_lr_scheduler.TriangularSchedule
+    :members:
+    :undoc-members:
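To illustrate the two hooks named above, a minimal custom scheduler sketch (the scheduler name and decay constant are hypothetical; `register_lr_scheduler` and `FairseqLRScheduler` exist in fairseq of this era, but exact base-class details vary by version):

    from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler


    @register_lr_scheduler('linear_decay')  # hypothetical scheduler name
    class LinearDecaySchedule(FairseqLRScheduler):
        """Decay the LR linearly over a fixed number of updates."""

        def __init__(self, args, optimizer):
            super().__init__(args, optimizer)
            self.base_lr = args.lr[0]
            self.total_updates = 100000  # illustrative constant

        def step(self, epoch, val_loss=None):
            """Called at epoch boundaries."""
            super().step(epoch, val_loss)
            return self.optimizer.get_lr()

        def step_update(self, num_updates):
            """Called after each optimizer update."""
            decay = max(0.0, 1.0 - num_updates / self.total_updates)
            self.optimizer.set_lr(self.base_lr * decay)
            return self.optimizer.get_lr()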
...

 Modules
 =======

-Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be
-helpful when implementing a new :class:`FairseqModel`.
+Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
+be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.

 .. automodule:: fairseq.modules
     :members:
...
@@ -6,5 +6,27 @@

 Optimizers
 ==========

+Optimizers update the Model parameters based on the gradients.
+
 .. automodule:: fairseq.optim
     :members:

+.. autoclass:: fairseq.optim.FairseqOptimizer
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.adagrad.Adagrad
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.adam.FairseqAdam
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.nag.FairseqNAG
+    :members:
+    :undoc-members:
+
+.. autoclass:: fairseq.optim.sgd.SGD
+    :members:
+    :undoc-members:
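As a sketch of the optimizer plug-in pattern documented above (the name is hypothetical; `register_optimizer` and `FairseqOptimizer` exist, but the constructor signature has changed across fairseq versions, so treat this as illustrative):

    import torch.optim

    from fairseq.optim import FairseqOptimizer, register_optimizer


    @register_optimizer('plain_sgd')  # hypothetical name
    class PlainSGD(FairseqOptimizer):
        """Thin fairseq wrapper around torch.optim.SGD."""

        def __init__(self, args, params):
            super().__init__(args, params)
            self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

        @property
        def optimizer_config(self):
            # Maps fairseq command-line args to torch.optim constructor kwargs.
            return {
                'lr': self.args.lr[0],
                'momentum': getattr(self.args, 'momentum', 0.0),
            }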
@@ -22,12 +22,18 @@ fairseq implements the following high-level training flow::

     for epoch in range(num_epochs):
         itr = task.get_batch_iterator(task.dataset('train'))
         for num_updates, batch in enumerate(itr):
-            loss = criterion(model, batch)
-            optimizer.backward(loss)
+            task.train_step(batch, model, criterion, optimizer)
+            average_and_clip_gradients()
             optimizer.step()
             lr_scheduler.step_update(num_updates)
         lr_scheduler.step(epoch)

+where the default implementation for ``task.train_step`` is roughly::
+
+    def train_step(self, batch, model, criterion, optimizer):
+        loss = criterion(model, batch)
+        optimizer.backward(loss)
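A task can override this hook to customize the forward/backward pass; a minimal sketch (hypothetical task name; the exact signature has varied across fairseq versions, though the `ignore_grad` flag and the three-tuple return follow the era's `FairseqTask.train_step`):

    from fairseq.tasks import FairseqTask, register_task


    @register_task('translation_with_loss_scaling')  # hypothetical
    class TranslationWithLossScaling(FairseqTask):

        def train_step(self, batch, model, criterion, optimizer, ignore_grad=False):
            """Zero out the loss when this update should be discarded."""
            model.train()
            loss, sample_size, logging_output = criterion(model, batch)
            if ignore_grad:
                loss *= 0
            optimizer.backward(loss)
            return loss, sample_size, logging_output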

 **Registering new plug-ins**

 New plug-ins are *registered* through a set of ``@register`` function
...
@@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory.

 -------------------------------

 Finally we can write a short script to evaluate our model on new inputs. Create
-a new file named :file:`eval_classify.py` with the following contents::
+a new file named :file:`eval_classifier.py` with the following contents::

     from fairseq import data, options, tasks, utils
     from fairseq.tokenizer import Tokenizer

     # Parse command-line arguments for generation
-    parser = options.get_generation_parser()
+    parser = options.get_generation_parser(default_task='simple_classification')
     args = options.parse_args_and_arch(parser)

     # Setup task
-    args.task = 'simple_classification'
     task = tasks.setup_task(args)

     # Load model
...
@@ -55,7 +55,9 @@ def main(parsed_args):

     # Load ensemble
     print('| loading model(s) from {}'.format(parsed_args.path))
-    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides))
+    models, args = utils.load_ensemble_for_inference(
+        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
+    )

     for arg in vars(parsed_args).keys():
         if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:

@@ -83,9 +85,10 @@ def main(parsed_args):

         max_positions=utils.resolve_max_positions(*[
             model.max_positions() for model in models
         ]),
+        ignore_invalid_inputs=True,
         num_shards=args.num_shards,
         shard_id=args.shard_id,
-        ignore_invalid_inputs=True,
+        num_workers=args.num_workers,
     ).next_epoch_itr(shuffle=False)

     gen_timer = StopwatchMeter()
...
@@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary

 from .fairseq_dataset import FairseqDataset
 from .backtranslation_dataset import BacktranslationDataset
 from .concat_dataset import ConcatDataset
-from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
+from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .round_robin_zip_datasets import RoundRobinZipDatasets

@@ -33,7 +33,6 @@ __all__ = [

     'GroupedIterator',
     'IndexedCachedDataset',
     'IndexedDataset',
-    'IndexedInMemoryDataset',
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'MonolingualDataset',
...
@@ -56,16 +56,6 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):

 class BacktranslationDataset(FairseqDataset):
-    def __init__(
-        self,
-        tgt_dataset,
-        backtranslation_fn,
-        max_len_a,
-        max_len_b,
-        output_collater=None,
-        cuda=True,
-        **kwargs
-    ):
     """
     Sets up a backtranslation dataset which takes a tgt batch, generates
     a src using a tgt-src backtranslation function (*backtranslation_fn*),

@@ -73,20 +63,31 @@ class BacktranslationDataset(FairseqDataset):

     Args:
         tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
-            backtranslated. Only the source side of this dataset will be
-            used. After backtranslation, the source sentences in this
-            dataset will be returned as the targets.
+            backtranslated. Only the source side of this dataset will be used.
+            After backtranslation, the source sentences in this dataset will be
+            returned as the targets.
         backtranslation_fn (callable): function to call to generate
             backtranslations. This is typically the `generate` method of a
             :class:`~fairseq.sequence_generator.SequenceGenerator` object.
         max_len_a, max_len_b (int, int): will be used to compute
-            `maxlen = max_len_a * src_len + max_len_b`, which will be
-            passed into *backtranslation_fn*.
+            `maxlen = max_len_a * src_len + max_len_b`, which will be passed
+            into *backtranslation_fn*.
         output_collater (callable, optional): function to call on the
-            backtranslated samples to create the final batch (default:
-            ``tgt_dataset.collater``)
+            backtranslated samples to create the final batch
+            (default: ``tgt_dataset.collater``).
         cuda: use GPU for generation
     """

+    def __init__(
+        self,
+        tgt_dataset,
+        backtranslation_fn,
+        max_len_a,
+        max_len_b,
+        output_collater=None,
+        cuda=True,
+        **kwargs
+    ):
         self.tgt_dataset = tgt_dataset
         self.backtranslation_fn = backtranslation_fn
         self.max_len_a = max_len_a

@@ -169,8 +170,7 @@ class BacktranslationDataset(FairseqDataset):

     @property
     def supports_prefetch(self):
-        return self.tgt_dataset.supports_prefetch()
+        return getattr(self.tgt_dataset, 'supports_prefetch', False)

     def prefetch(self, indices):
         return self.tgt_dataset.prefetch(indices)
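A usage sketch of the constructor documented above (illustrative only: the stubs stand in for a real `FairseqDataset` and a `SequenceGenerator.generate` method; per the diff, `__init__` simply stores these arguments):

    from fairseq.data import BacktranslationDataset


    class StubTargetDataset:
        """Placeholder for a real target-side FairseqDataset."""

        def __len__(self):
            return 0

        def collater(self, samples):
            return samples


    def stub_backtranslation_fn(sample, maxlen=None):
        """Placeholder for SequenceGenerator.generate (tgt -> synthetic src)."""
        return []


    dataset = BacktranslationDataset(
        tgt_dataset=StubTargetDataset(),
        backtranslation_fn=stub_backtranslation_fn,
        max_len_a=1.3,  # maxlen = max_len_a * src_len + max_len_b
        max_len_b=5,
    )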
@@ -29,18 +29,18 @@ class ConcatDataset(FairseqDataset):

         if isinstance(sample_ratios, int):
             sample_ratios = [sample_ratios] * len(self.datasets)
         self.sample_ratios = sample_ratios
-        self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios)
+        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
         self.real_sizes = [len(d) for d in self.datasets]

     def __len__(self):
-        return self.cummulative_sizes[-1]
+        return self.cumulative_sizes[-1]

     def __getitem__(self, idx):
-        dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
+        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
         if dataset_idx == 0:
             sample_idx = idx
         else:
-            sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
+            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
         sample_idx = sample_idx % self.real_sizes[dataset_idx]
         return self.datasets[dataset_idx][sample_idx]

@@ -54,7 +54,7 @@ class ConcatDataset(FairseqDataset):

     def prefetch(self, indices):
         frm = 0
-        for to, ds in zip(self.cummulative_sizes, self.datasets):
+        for to, ds in zip(self.cumulative_sizes, self.datasets):
             real_size = len(ds)
             ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
             frm = to
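The index arithmetic in `__getitem__` above is easiest to see with concrete numbers; a standalone sketch in plain Python (independent of fairseq, assuming `cumsum` scales each dataset's length by its sample ratio):

    import bisect

    # Two datasets of lengths 3 and 2, with sample_ratios [1, 2]:
    # cumulative_sizes = [3, 3 + 2*2] = [3, 7]
    real_sizes = [3, 2]
    cumulative_sizes = [3, 7]

    def locate(idx):
        """Map a global index to (dataset_idx, sample_idx), as in __getitem__."""
        dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
        sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
        return dataset_idx, sample_idx % real_sizes[dataset_idx]

    # Indices 0..2 come from the first dataset; 3..6 cycle through the
    # second dataset twice because of its sample ratio of 2.
    assert [locate(i) for i in range(7)] == [
        (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 0), (1, 1),
    ]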
@@ -81,8 +81,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):

         size_fn (callable): function that returns the size of a given index
         max_positions (tuple): filter elements larger than this size.
             Comparisons are done component-wise.
-        raise_exception (bool, optional): if ``True``, raise an exception
-            if any elements are filtered. Default: ``False``
+        raise_exception (bool, optional): if ``True``, raise an exception if
+            any elements are filtered (default: False).
     """
     def check_size(idx):
         if isinstance(max_positions, float) or isinstance(max_positions, int):

@@ -128,12 +128,12 @@ def batch_by_size(

         indices (List[int]): ordered list of dataset indices
         num_tokens_fn (callable): function that returns the number of tokens at
             a given index
-        max_tokens (int, optional): max number of tokens in each batch.
-            Default: ``None``
-        max_sentences (int, optional): max number of sentences in each
-            batch. Default: ``None``
-        required_batch_size_multiple (int, optional): require batch size to
-            be a multiple of N. Default: ``1``
+        max_tokens (int, optional): max number of tokens in each batch
+            (default: None).
+        max_sentences (int, optional): max number of sentences in each
+            batch (default: None).
+        required_batch_size_multiple (int, optional): require batch size to
+            be a multiple of N (default: 1).
     """
     max_tokens = max_tokens if max_tokens is not None else float('Inf')
     max_sentences = max_sentences if max_sentences is not None else float('Inf')
...
@@ -200,11 +200,15 @@ class Dictionary(object):

         t[-1] = self.eos()
         return t


 class TruncatedDictionary(object):

     def __init__(self, wrapped_dict, length):
-        self.__class__ = type(wrapped_dict.__class__.__name__,
-                              (self.__class__, wrapped_dict.__class__), {})
+        self.__class__ = type(
+            wrapped_dict.__class__.__name__,
+            (self.__class__, wrapped_dict.__class__),
+            {}
+        )
         self.__dict__ = wrapped_dict.__dict__
         self.wrapped_dict = wrapped_dict
         self.length = min(len(self.wrapped_dict), length)
...
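The `type(...)` call being reformatted above builds a subclass on the fly so the wrapper still passes `isinstance` checks for the wrapped dictionary's class. A self-contained sketch of the same pattern in generic Python (not fairseq code):

    class Wrapped:
        def __len__(self):
            return 10


    class Truncator(object):
        """Wraps an object and re-parents itself so isinstance() still works."""

        def __init__(self, wrapped, length):
            # Create a new class inheriting from both the wrapper and the
            # wrapped object's class, then adopt it as our own class.
            self.__class__ = type(
                wrapped.__class__.__name__,
                (self.__class__, wrapped.__class__),
                {},
            )
            self.__dict__ = wrapped.__dict__
            self.length = min(len(wrapped), length)

        def __len__(self):
            return self.length


    t = Truncator(Wrapped(), 3)
    assert isinstance(t, Wrapped)  # passes thanks to the dynamic subclass
    assert len(t) == 3             # wrapper's __len__ wins in the MRO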
@@ -7,8 +7,6 @@

 import torch.utils.data

-from fairseq.data import data_utils


 class FairseqDataset(torch.utils.data.Dataset):
     """A dataset that provides helpers for batching."""

@@ -51,7 +49,9 @@ class FairseqDataset(torch.utils.data.Dataset):

     @property
     def supports_prefetch(self):
+        """Whether this dataset supports prefetching."""
         return False

     def prefetch(self, indices):
+        """Prefetch the data required for this epoch."""
         raise NotImplementedError
@@ -52,13 +52,12 @@ def data_file_path(prefix_path):

 class IndexedDataset(torch.utils.data.Dataset):
     """Loader for TorchNet IndexedDataset"""

-    def __init__(self, path, fix_lua_indexing=False, read_data=True):
+    def __init__(self, path, fix_lua_indexing=False):
         super().__init__()
         self.fix_lua_indexing = fix_lua_indexing
         self.read_index(path)
         self.data_file = None
-        if read_data:
-            self.read_data(path)
+        self.path = path

     def read_index(self, path):
         with open(index_file_path(path), 'rb') as f:

@@ -85,8 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):

         self.data_file.close()

     def __getitem__(self, i):
+        if not self.data_file:
+            self.read_data(self.path)
         self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+        tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
         a = np.empty(tensor_size, dtype=self.dtype)
         self.data_file.seek(self.data_offsets[i] * self.element_size)
         self.data_file.readinto(a)

@@ -98,12 +99,6 @@ class IndexedDataset(torch.utils.data.Dataset):

     def __len__(self):
         return self.size

-    def read_into(self, start, dst):
-        self.data_file.seek(start * self.element_size)
-        self.data_file.readinto(dst)
-        if self.fix_lua_indexing:
-            dst -= 1  # subtract 1 for 0-based indexing

     @staticmethod
     def exists(path):
         return (

@@ -111,11 +106,15 @@ class IndexedDataset(torch.utils.data.Dataset):

             os.path.exists(data_file_path(path))
         )

+    @property
+    def supports_prefetch(self):
+        return False  # avoid prefetching to save memory


 class IndexedCachedDataset(IndexedDataset):

     def __init__(self, path, fix_lua_indexing=False):
-        super().__init__(path, fix_lua_indexing, True)
+        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
         self.cache = None
         self.cache_index = {}

@@ -126,6 +125,8 @@ class IndexedCachedDataset(IndexedDataset):

     def prefetch(self, indices):
         if all(i in self.cache_index for i in indices):
             return
+        if not self.data_file:
+            self.read_data(self.path)
         indices = sorted(set(indices))
         total_size = 0
         for i in indices:

@@ -153,34 +154,7 @@ class IndexedCachedDataset(IndexedDataset):

         return item

-class IndexedInMemoryDataset(IndexedDataset):
-    """Loader for TorchNet IndexedDataset, keeps all the data in memory"""
-
-    def read_data(self, path):
-        self.data_file = open(data_file_path(path), 'rb')
-        self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
-        self.data_file.readinto(self.buffer)
-        self.data_file.close()
-        if self.fix_lua_indexing:
-            self.buffer -= 1  # subtract 1 for 0-based indexing
-
-    def read_into(self, start, dst):
-        if self.token_blob is None:
-            self.token_blob = [t for l in self.tokens_list for t in l]
-        np.copyto(dst, self.token_blob[start:])
-
-    def __del__(self):
-        pass
-
-    def __getitem__(self, i):
-        self.check_index(i)
-        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
-        a = np.empty(tensor_size, dtype=self.dtype)
-        np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
-        return torch.from_numpy(a).long()

-class IndexedRawTextDataset(IndexedDataset):
+class IndexedRawTextDataset(torch.utils.data.Dataset):
     """Takes a text file as input and binarizes it in memory at instantiation.
     Original lines are also kept in memory"""

@@ -205,6 +179,10 @@ class IndexedRawTextDataset(IndexedDataset):

         self.sizes.append(len(tokens))
         self.sizes = np.array(self.sizes)

+    def check_index(self, i):
+        if i < 0 or i >= self.size:
+            raise IndexError('index out of range')

     def __getitem__(self, i):
         self.check_index(i)
         return self.tokens_list[i]

@@ -252,7 +230,7 @@ class IndexedDatasetBuilder(object):

         self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

     def merge_file_(self, another_file):
-        index = IndexedDataset(another_file, read_data=False)
+        index = IndexedDataset(another_file)
         assert index.dtype == self.dtype
         begin = self.data_offsets[-1]
...
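The recurring `if not self.data_file: self.read_data(self.path)` added above defers opening the data file until first access. This matters once batches are produced by DataLoader worker subprocesses (see the iterators change below), since open file handles do not transfer cleanly across process boundaries. A generic sketch of the pattern (plain Python, not fairseq code):

    import tempfile


    class LazyFileDataset:
        """Opens its backing file on first access instead of in __init__.

        Keeping the handle unopened makes the object cheap and safe to send
        to DataLoader worker processes; each worker opens its own handle.
        """

        def __init__(self, path):
            self.path = path
            self.data_file = None  # opened lazily, as in IndexedDataset above

        def __getitem__(self, i):
            if self.data_file is None:
                self.data_file = open(self.path, 'rb')
            self.data_file.seek(i)
            return self.data_file.read(1)


    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(b'abcdef')

    ds = LazyFileDataset(f.name)  # no file handle yet
    assert ds[2] == b'c'          # handle created on first access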
@@ -69,17 +69,19 @@ class EpochBatchIterator(object):

         batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
             indices
         seed (int, optional): seed for random number generator for
-            reproducibility. Default: 1
+            reproducibility (default: 1).
         num_shards (int, optional): shard the data iterator into N
-            shards. Default: 1
+            shards (default: 1).
         shard_id (int, optional): which shard of the data iterator to
-            return. Default: 0
+            return (default: 0).
-        buffer_size (int, optional): number of batches to buffer. Default: 5
+        num_workers (int, optional): how many subprocesses to use for data
+            loading. 0 means the data will be loaded in the main process
+            (default: 0).
     """

     def __init__(
         self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-        buffer_size=5,
+        num_workers=0,
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset

@@ -88,14 +90,12 @@ class EpochBatchIterator(object):

         self.seed = seed
         self.num_shards = num_shards
         self.shard_id = shard_id
-        self.buffer_size = buffer_size
+        self.num_workers = num_workers
         self.epoch = 0
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
-        self._supports_prefetch = (
-            hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
-        )
+        self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)

     def __len__(self):
         return len(self.frozen_batches)

@@ -105,11 +105,10 @@ class EpochBatchIterator(object):

         Args:
             shuffle (bool, optional): shuffle batches before returning the
-                iterator. Default: ``True``
+                iterator (default: True).
             fix_batches_to_gpus: ensure that batches are always
                 allocated to the same shards across epochs. Requires
-                that :attr:`dataset` supports prefetching. Default:
-                ``False``
+                that :attr:`dataset` supports prefetching (default: False).
         """
         if self._next_epoch_itr is not None:
             self._cur_epoch_itr = self._next_epoch_itr

@@ -117,7 +116,8 @@ class EpochBatchIterator(object):

         else:
             self.epoch += 1
             self._cur_epoch_itr = self._get_iterator_for_epoch(
-                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus)
+                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
+            )
         return self._cur_epoch_itr

     def end_of_epoch(self):

@@ -179,50 +179,14 @@ class EpochBatchIterator(object):

         batches = self.frozen_batches
         batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
-        return CountingIterator(BufferedIterator(
-            torch.utils.data.DataLoader(
-                self.dataset,
-                collate_fn=self.collate_fn,
-                batch_sampler=batches,
-            ),
-            buffer_size=self.buffer_size,
-        ))
+        return CountingIterator(torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.collate_fn,
+            batch_sampler=batches,
+            num_workers=self.num_workers,
+        ))


-class BufferedIterator(object):
-    """Wrapper around an iterable that prefetches items into a buffer.
-
-    Args:
-        iterable (iterable): iterable to wrap
-        buffer_size (int): number of items to prefetch and buffer
-    """
-
-    def __init__(self, iterable, buffer_size):
-        self.iterable = iterable
-        self.q = queue.Queue(maxsize=buffer_size)
-        self.thread = threading.Thread(target=self._load_q, daemon=True)
-        self.thread.start()
-
-    def __len__(self):
-        return len(self.iterable)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        x = self.q.get()
-        if x is None:
-            self.thread.join()
-            raise StopIteration
-        return x[0]
-
-    def _load_q(self):
-        for x in self.iterable:
-            self.q.put([x])  # wrap in list so that it's never None
-        self.q.put(None)


 class GroupedIterator(object):
     """Wrapper around an iterable that returns groups (chunks) of items.

@@ -261,7 +225,7 @@ class ShardedIterator(object):

         num_shards (int): number of shards to split the iterable into
         shard_id (int): which shard to iterator over
         fill_value (Any, optional): padding value when the iterable doesn't
-            evenly divide *num_shards*. Default: ``None``
+            evenly divide *num_shards* (default: None).
     """

     def __init__(self, iterable, num_shards, shard_id, fill_value=None):
...
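The net effect of this change is that batch prefetching now happens in `torch.utils.data.DataLoader` worker subprocesses rather than in the removed background thread. A toy sketch of the same DataLoader configuration (illustrative names; independent of fairseq):

    import torch
    from torch.utils.data import DataLoader, Dataset


    class SquaresDataset(Dataset):
        """Toy map-style dataset standing in for a FairseqDataset."""

        def __len__(self):
            return 8

        def __getitem__(self, i):
            return torch.tensor([i, i * i])


    if __name__ == '__main__':
        batches = [[0, 1], [2, 3], [4, 5], [6, 7]]  # precomputed batch sampler
        loader = DataLoader(
            SquaresDataset(),
            collate_fn=torch.stack,  # each batch arrives as a list of tensors
            batch_sampler=batches,
            num_workers=2,           # subprocess loading replaces BufferedIterator
        )
        for batch in loader:
            print(batch.shape)  # torch.Size([2, 2])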
@@ -79,23 +79,23 @@ class LanguagePairDataset(FairseqDataset):

         tgt (torch.utils.data.Dataset, optional): target dataset to wrap
         tgt_sizes (List[int], optional): target sentence lengths
         tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
-        left_pad_source (bool, optional): pad source tensors on the left side.
-            Default: ``True``
-        left_pad_target (bool, optional): pad target tensors on the left side.
-            Default: ``False``
-        max_source_positions (int, optional): max number of tokens in the source
-            sentence. Default: ``1024``
-        max_target_positions (int, optional): max number of tokens in the target
-            sentence. Default: ``1024``
-        shuffle (bool, optional): shuffle dataset elements before batching.
-            Default: ``True``
-        input_feeding (bool, optional): create a shifted version of the targets
-            to be passed into the model for input feeding/teacher forcing.
-            Default: ``True``
-        remove_eos_from_source (bool, optional): if set, removes eos from end of
-            source if it's present. Default: ``False``
-        append_eos_to_target (bool, optional): if set, appends eos to end of
-            target if it's absent. Default: ``False``
+        left_pad_source (bool, optional): pad source tensors on the left side
+            (default: True).
+        left_pad_target (bool, optional): pad target tensors on the left side
+            (default: False).
+        max_source_positions (int, optional): max number of tokens in the
+            source sentence (default: 1024).
+        max_target_positions (int, optional): max number of tokens in the
+            target sentence (default: 1024).
+        shuffle (bool, optional): shuffle dataset elements before batching
+            (default: True).
+        input_feeding (bool, optional): create a shifted version of the targets
+            to be passed into the model for input feeding/teacher forcing
+            (default: True).
+        remove_eos_from_source (bool, optional): if set, removes eos from end
+            of source if it's present (default: False).
+        append_eos_to_target (bool, optional): if set, appends eos to end of
+            target if it's absent (default: False).
     """

     def __init__(

@@ -223,15 +223,13 @@ class LanguagePairDataset(FairseqDataset):

             indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
         return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

-    def prefetch(self, indices):
-        self.src.prefetch(indices)
-        self.tgt.prefetch(indices)

     @property
     def supports_prefetch(self):
         return (
-            hasattr(self.src, 'supports_prefetch')
-            and self.src.supports_prefetch
-            and hasattr(self.tgt, 'supports_prefetch')
-            and self.tgt.supports_prefetch
+            getattr(self.src, 'supports_prefetch', False)
+            and getattr(self.tgt, 'supports_prefetch', False)
         )

+    def prefetch(self, indices):
+        self.src.prefetch(indices)
+        self.tgt.prefetch(indices)