Commit 7633129b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot

Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130
......@@ -19,10 +19,13 @@ of various sequence-to-sequence models, including:
Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
- fast beam search generation on both CPU and GPU
- fast generation on both CPU and GPU with multiple search algorithms implemented:
- beam search
- Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
- sampling (unconstrained and top-k)
- large mini-batch training even on a single GPU via delayed updates
- fast half-precision floating point (FP16) training
- extensible: easily register new models, criterions, and tasks
- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
We also provide [pre-trained models](#pre-trained-models) for several benchmark
translation and language modeling datasets.
......@@ -34,7 +37,7 @@ translation and language modeling datasets.
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6
Currently fairseq requires PyTorch version >= 0.4.0.
Currently fairseq requires PyTorch version >= 1.0.0.
Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
If you use Docker make sure to increase the shared memory size either with
......
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import socket
import subprocess
from train import main as single_process_main
from fairseq import distributed_utils, options
def main(args):
    if args.distributed_init_method is None and args.distributed_port > 0:
        # We can determine the init method automatically for Slurm.
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError as e:  # scontrol failed
                raise e
            except FileNotFoundError:  # Slurm is not installed
                pass
    if args.distributed_init_method is None and args.distributed_port is None:
        raise ValueError('--distributed-init-method or --distributed-port '
                         'must be specified for distributed training')

    args.distributed_rank = distributed_utils.distributed_init(args)
    print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
    single_process_main(args)


if __name__ == '__main__':
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
......@@ -6,8 +6,26 @@
Criterions
==========
Criterions compute the loss function given the model and batch, roughly::

    loss = criterion(model, batch)
.. automodule:: fairseq.criterions
    :members:

.. autoclass:: fairseq.criterions.FairseqCriterion
    :members:
    :undoc-members:

.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
    :members:
    :undoc-members:

.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
    :members:
    :undoc-members:

.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
    :members:
    :undoc-members:

.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
    :members:
    :undoc-members:
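To make the API above concrete, here is a minimal sketch of a custom criterion. The ``loss, sample_size, logging_output`` return convention follows :class:`FairseqCriterion`, but the class name and loss details are illustrative assumptions, not part of this commit::

    import torch.nn.functional as F

    from fairseq.criterions import FairseqCriterion, register_criterion

    @register_criterion('simple_cross_entropy')
    class SimpleCrossEntropyCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            # Run the model and compute a summed token-level NLL loss
            # (the `reduce` flag is kept for API compatibility; this
            # sketch always sums).
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            loss = F.nll_loss(
                lprobs.view(-1, lprobs.size(-1)), target.view(-1),
                ignore_index=self.padding_idx, reduction='sum',
            )
            sample_size = sample['ntokens']
            logging_output = {'loss': loss.item(), 'sample_size': sample_size}
            return loss, sample_size, logging_output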
......@@ -21,6 +21,20 @@ mini-batches.
.. autoclass:: fairseq.data.MonolingualDataset
    :members:

**Helper Datasets**

These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
provide additional functionality; a small usage sketch follows this list:

.. autoclass:: fairseq.data.BacktranslationDataset
    :members:

.. autoclass:: fairseq.data.ConcatDataset
    :members:

.. autoclass:: fairseq.data.RoundRobinZipDatasets
    :members:

.. autoclass:: fairseq.data.TransformEosDataset
    :members:
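For example, :class:`fairseq.data.ConcatDataset` can combine datasets and oversample one of them. A small sketch, where ``ds1`` and ``ds2`` are assumed to be existing :class:`~fairseq.data.FairseqDataset` instances::

    from fairseq.data import ConcatDataset

    # sample_ratios=[1, 2] counts each item of the second dataset twice,
    # so one epoch covers len(ds1) + 2 * len(ds2) examples.
    combined = ConcatDataset([ds1, ds2], sample_ratios=[1, 2])
    assert len(combined) == len(ds1) + 2 * len(ds2)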
Dictionary
----------
......@@ -32,6 +46,8 @@ Dictionary
Iterators
---------
.. autoclass:: fairseq.data.BufferedIterator
    :members:

.. autoclass:: fairseq.data.CountingIterator
    :members:

.. autoclass:: fairseq.data.EpochBatchIterator
......
......@@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5:
> MODEL_DIR=wmt14.en-fr.fconv-py
> python interactive.py \
    --path $MODEL_DIR/model.pt $MODEL_DIR \
    --beam 5
    --beam 5 --source-lang en --target-lang fr
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
> Why is it rare to discover new marine mam@@ mal species ?
O Why is it rare to discover new marine mam@@ mal species ?
H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
A 0 1 3 3 5 6 6 8 8 8 7 11 12
This generation script produces four types of outputs: a line prefixed
with *S* shows the supplied source sentence after applying the
vocabulary; *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *A* is the
attention maxima for each word in the hypothesis, including the
H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288
This generation script produces three types of outputs: a line prefixed
with *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker which is omitted from the text.
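Because each record is line-oriented, this output is easy to post-process. A small unofficial sketch, which assumes the fields on *H* lines are tab-separated (they render above with spaces)::

    hyps, pos_scores = [], []
    with open('gen.out') as f:
        for line in f:
            if line.startswith('H'):
                _, avg_logprob, text = line.rstrip('\n').split('\t', 2)
                hyps.append((float(avg_logprob), text))
            elif line.startswith('P'):
                pos_scores.append([float(s) for s in line.split()[1:]])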
See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
......
......@@ -6,7 +6,29 @@
Learning Rate Schedulers
========================
TODO
Learning Rate Schedulers update the learning rate over the course of training.
Learning rates can be updated after each update via :func:`step_update` or at
epoch boundaries via :func:`step`.
.. automodule:: fairseq.optim.lr_scheduler
    :members:

.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
    :members:
    :undoc-members:
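Concretely, the training loop drives both hooks. A sketch consistent with the training flow shown in the overview (the loop variables are assumptions for illustration)::

    for epoch in range(num_epochs):
        for num_updates, batch in enumerate(itr):
            task.train_step(batch, model, criterion, optimizer)
            optimizer.step()
            lr_scheduler.step_update(num_updates)  # per-update schedules (e.g. inverse_sqrt)
        lr_scheduler.step(epoch)                   # per-epoch schedules (e.g. fixed)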
Modules
=======
Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be
helpful when implementing a new :class:`FairseqModel`.
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.
.. automodule:: fairseq.modules
    :members:
......
......@@ -6,5 +6,27 @@
Optimizers
==========
Optimizers update the Model parameters based on the gradients.
.. automodule:: fairseq.optim
    :members:

.. autoclass:: fairseq.optim.FairseqOptimizer
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.adagrad.Adagrad
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.adam.FairseqAdam
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.nag.FairseqNAG
    :members:
    :undoc-members:

.. autoclass:: fairseq.optim.sgd.SGD
    :members:
    :undoc-members:
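All of these expose the same small surface to the training loop. A rough sketch: ``backward`` and ``step`` appear in the training flow in the overview, while ``clip_grad_norm`` is an assumption about the :class:`fairseq.optim.FairseqOptimizer` API::

    loss = criterion(model, batch)
    optimizer.backward(loss)            # FP16Optimizer also applies loss scaling here
    optimizer.clip_grad_norm(max_norm)  # assumed helper for gradient clipping
    optimizer.step()
    optimizer.zero_grad()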
......@@ -22,12 +22,18 @@ fairseq implements the following high-level training flow::
    for epoch in range(num_epochs):
        itr = task.get_batch_iterator(task.dataset('train'))
        for num_updates, batch in enumerate(itr):
            loss = criterion(model, batch)
            optimizer.backward(loss)
            task.train_step(batch, model, criterion, optimizer)
            average_and_clip_gradients()
            optimizer.step()
            lr_scheduler.step_update(num_updates)
        lr_scheduler.step(epoch)

where the default implementation for ``task.train_step`` is roughly::

    def train_step(self, batch, model, criterion, optimizer):
        loss = criterion(model, batch)
        optimizer.backward(loss)
**Registering new plug-ins**
New plug-ins are *registered* through a set of ``@register`` function
decorators, for example::
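    from fairseq.models import FairseqModel, register_model

    # The name 'my_lstm' and the class are illustrative; a real model
    # also implements add_args() and build_model().
    @register_model('my_lstm')
    class MyLSTM(FairseqModel):
        ...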
......
......@@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory.
-------------------------------
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classify.py` with the following contents::
a new file named :file:`eval_classifier.py` with the following contents::
from fairseq import data, options, tasks, utils
from fairseq.tokenizer import Tokenizer
# Parse command-line arguments for generation
parser = options.get_generation_parser()
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)
# Setup task
args.task = 'simple_classification'
task = tasks.setup_task(args)
# Load model
......
......@@ -55,7 +55,9 @@ def main(parsed_args):
    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
......@@ -83,9 +85,10 @@ def main(parsed_args):
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
......
......@@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
......@@ -33,7 +33,6 @@ __all__ = [
    'GroupedIterator',
    'IndexedCachedDataset',
    'IndexedDataset',
    'IndexedInMemoryDataset',
    'IndexedRawTextDataset',
    'LanguagePairDataset',
    'MonolingualDataset',
......
......@@ -56,6 +56,28 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
class BacktranslationDataset(FairseqDataset):
    """
    Sets up a backtranslation dataset which takes a tgt batch, generates
    a src using a tgt-src backtranslation function (*backtranslation_fn*),
    and returns the corresponding `{generated src, input tgt}` batch.

    Args:
        tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
            backtranslated. Only the source side of this dataset will be used.
            After backtranslation, the source sentences in this dataset will be
            returned as the targets.
        backtranslation_fn (callable): function to call to generate
            backtranslations. This is typically the `generate` method of a
            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
        max_len_a, max_len_b (int, int): will be used to compute
            `maxlen = max_len_a * src_len + max_len_b`, which will be passed
            into *backtranslation_fn*.
        output_collater (callable, optional): function to call on the
            backtranslated samples to create the final batch
            (default: ``tgt_dataset.collater``).
        cuda (bool, optional): use GPU for generation (default: True).
    """

    def __init__(
        self,
        tgt_dataset,
......@@ -66,27 +88,6 @@ class BacktranslationDataset(FairseqDataset):
        cuda=True,
        **kwargs
    ):
        """
        Sets up a backtranslation dataset which takes a tgt batch, generates
        a src using a tgt-src backtranslation function (*backtranslation_fn*),
        and returns the corresponding `{generated src, input tgt}` batch.

        Args:
            tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
                backtranslated. Only the source side of this dataset will be
                used. After backtranslation, the source sentences in this
                dataset will be returned as the targets.
            backtranslation_fn (callable): function to call to generate
                backtranslations. This is typically the `generate` method of a
                :class:`~fairseq.sequence_generator.SequenceGenerator` object.
            max_len_a, max_len_b (int, int): will be used to compute
                `maxlen = max_len_a * src_len + max_len_b`, which will be
                passed into *backtranslation_fn*.
            output_collater (callable, optional): function to call on the
                backtranslated samples to create the final batch (default:
                ``tgt_dataset.collater``)
            cuda: use GPU for generation
        """
        self.tgt_dataset = tgt_dataset
        self.backtranslation_fn = backtranslation_fn
        self.max_len_a = max_len_a
......@@ -166,11 +167,10 @@ class BacktranslationDataset(FairseqDataset):
"""
tgt_size = self.tgt_dataset.size(index)[0]
return (tgt_size, tgt_size)
@property
def supports_prefetch(self):
return self.tgt_dataset.supports_prefetch()
return getattr(self.tgt_dataset, 'supports_prefetch', False)
def prefetch(self, indices):
return self.tgt_dataset.prefetch(indices)
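For context, a hedged construction sketch; the variable names are assumptions, while the argument names follow the class docstring above::

    from fairseq.data import BacktranslationDataset

    bt_dataset = BacktranslationDataset(
        tgt_dataset=mono_tgt_dataset,            # FairseqDataset of target-side text
        backtranslation_fn=generator.generate,   # SequenceGenerator over the tgt->src model
        max_len_a=1.3,
        max_len_b=5,
    )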
......@@ -29,18 +29,18 @@ class ConcatDataset(FairseqDataset):
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cummulative_sizes[-1]
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return self.datasets[dataset_idx][sample_idx]
......@@ -54,7 +54,7 @@ class ConcatDataset(FairseqDataset):
    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cummulative_sizes, self.datasets):
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to
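The ``bisect_right`` mapping in ``__getitem__`` is easiest to see with concrete numbers. A standalone sketch of the same index arithmetic, independent of fairseq::

    import bisect

    cumulative_sizes = [3, 8]   # two datasets of sizes 3 and 5
    real_sizes = [3, 5]

    def locate(idx):
        dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
        offset = 0 if dataset_idx == 0 else cumulative_sizes[dataset_idx - 1]
        return dataset_idx, (idx - offset) % real_sizes[dataset_idx]

    assert locate(2) == (0, 2)  # last item of the first dataset
    assert locate(3) == (1, 0)  # first item of the second dataset
    assert locate(7) == (1, 4)  # last item of the second dataset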
......@@ -81,8 +81,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
        size_fn (callable): function that returns the size of a given index
        max_positions (tuple): filter elements larger than this size.
            Comparisons are done component-wise.
        raise_exception (bool, optional): if ``True``, raise an exception
            if any elements are filtered. Default: ``False``
        raise_exception (bool, optional): if ``True``, raise an exception if
            any elements are filtered (default: False).
    """
    def check_size(idx):
        if isinstance(max_positions, float) or isinstance(max_positions, int):
......@@ -128,12 +128,12 @@ def batch_by_size(
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch.
            Default: ``None``
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch. Default: ``None``
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N. Default: ``1``
            be a multiple of N (default: 1).
    """
    max_tokens = max_tokens if max_tokens is not None else float('Inf')
    max_sentences = max_sentences if max_sentences is not None else float('Inf')
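A hedged usage sketch of this function; the ``dataset.ordered_indices()`` and ``dataset.num_tokens`` method names are assumptions about the dataset object, while the keyword arguments follow the docstring above::

    from fairseq.data import data_utils

    indices = dataset.ordered_indices()       # indices sorted by example length
    batches = data_utils.batch_by_size(
        indices,
        num_tokens_fn=dataset.num_tokens,     # tokens at a given index
        max_tokens=4000,                      # cap on tokens per batch
        required_batch_size_multiple=8,
    )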
......
......@@ -200,11 +200,15 @@ class Dictionary(object):
        t[-1] = self.eos()
        return t


class TruncatedDictionary(object):

    def __init__(self, wrapped_dict, length):
        self.__class__ = type(wrapped_dict.__class__.__name__,
                              (self.__class__, wrapped_dict.__class__), {})
        self.__class__ = type(
            wrapped_dict.__class__.__name__,
            (self.__class__, wrapped_dict.__class__),
            {}
        )
        self.__dict__ = wrapped_dict.__dict__
        self.wrapped_dict = wrapped_dict
        self.length = min(len(self.wrapped_dict), length)
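The dynamic ``type(...)`` call above makes the wrapper an instance of the wrapped dictionary's class, so ``isinstance`` checks keep working. A toy demonstration of the same trick, independent of fairseq::

    class Base:
        pass

    class Wrapper:
        def __init__(self, wrapped):
            # Build a new class inheriting from both Wrapper and the
            # wrapped object's class, then rebind self to it.
            self.__class__ = type(
                wrapped.__class__.__name__,
                (self.__class__, wrapped.__class__),
                {},
            )
            self.wrapped = wrapped

    w = Wrapper(Base())
    assert isinstance(w, Base) and isinstance(w, Wrapper)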
......
......@@ -7,8 +7,6 @@
import torch.utils.data

from fairseq.data import data_utils


class FairseqDataset(torch.utils.data.Dataset):
    """A dataset that provides helpers for batching."""
......@@ -51,7 +49,9 @@ class FairseqDataset(torch.utils.data.Dataset):
    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return False

    def prefetch(self, indices):
        """Prefetch the data required for this epoch."""
        raise NotImplementedError
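A minimal sketch of a dataset opting into this prefetch protocol; the class name and caching scheme are hypothetical, and a real dataset would also implement ``collater`` and friends::

    class CachedDataset(FairseqDataset):

        def __init__(self, items):
            self.items = items
            self.cache = {}

        def __getitem__(self, index):
            return self.cache.get(index, self.items[index])

        def __len__(self):
            return len(self.items)

        @property
        def supports_prefetch(self):
            return True

        def prefetch(self, indices):
            # Materialize everything the upcoming epoch will touch.
            for i in indices:
                self.cache[i] = self.items[i]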
......@@ -52,13 +52,12 @@ def data_file_path(prefix_path):
class IndexedDataset(torch.utils.data.Dataset):
    """Loader for TorchNet IndexedDataset"""

    def __init__(self, path, fix_lua_indexing=False, read_data=True):
    def __init__(self, path, fix_lua_indexing=False):
        super().__init__()
        self.fix_lua_indexing = fix_lua_indexing
        self.read_index(path)
        self.data_file = None
        if read_data:
            self.read_data(path)
        self.path = path

    def read_index(self, path):
        with open(index_file_path(path), 'rb') as f:
......@@ -85,8 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):
            self.data_file.close()

    def __getitem__(self, i):
        if not self.data_file:
            self.read_data(self.path)
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
        a = np.empty(tensor_size, dtype=self.dtype)
        self.data_file.seek(self.data_offsets[i] * self.element_size)
        self.data_file.readinto(a)
......@@ -98,12 +99,6 @@ class IndexedDataset(torch.utils.data.Dataset):
    def __len__(self):
        return self.size

    def read_into(self, start, dst):
        self.data_file.seek(start * self.element_size)
        self.data_file.readinto(dst)
        if self.fix_lua_indexing:
            dst -= 1  # subtract 1 for 0-based indexing

    @staticmethod
    def exists(path):
        return (
......@@ -111,11 +106,15 @@ class IndexedDataset(torch.utils.data.Dataset):
            os.path.exists(data_file_path(path))
        )

    @property
    def supports_prefetch(self):
        return False  # avoid prefetching to save memory


class IndexedCachedDataset(IndexedDataset):

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__(path, fix_lua_indexing, True)
        super().__init__(path, fix_lua_indexing=fix_lua_indexing)
        self.cache = None
        self.cache_index = {}
......@@ -126,6 +125,8 @@ class IndexedCachedDataset(IndexedDataset):
    def prefetch(self, indices):
        if all(i in self.cache_index for i in indices):
            return
        if not self.data_file:
            self.read_data(self.path)
        indices = sorted(set(indices))
        total_size = 0
        for i in indices:
......@@ -153,34 +154,7 @@ class IndexedCachedDataset(IndexedDataset):
        return item


class IndexedInMemoryDataset(IndexedDataset):
    """Loader for TorchNet IndexedDataset, keeps all the data in memory"""

    def read_data(self, path):
        self.data_file = open(data_file_path(path), 'rb')
        self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
        self.data_file.readinto(self.buffer)
        self.data_file.close()
        if self.fix_lua_indexing:
            self.buffer -= 1  # subtract 1 for 0-based indexing

    def read_into(self, start, dst):
        if self.token_blob is None:
            self.token_blob = [t for l in self.tokens_list for t in l]
        np.copyto(dst, self.token_blob[start:])

    def __del__(self):
        pass

    def __getitem__(self, i):
        self.check_index(i)
        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
        a = np.empty(tensor_size, dtype=self.dtype)
        np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
        return torch.from_numpy(a).long()


class IndexedRawTextDataset(IndexedDataset):
class IndexedRawTextDataset(torch.utils.data.Dataset):
    """Takes a text file as input and binarizes it in memory at instantiation.
    Original lines are also kept in memory"""
......@@ -205,6 +179,10 @@ class IndexedRawTextDataset(IndexedDataset):
            self.sizes.append(len(tokens))
        self.sizes = np.array(self.sizes)

    def check_index(self, i):
        if i < 0 or i >= self.size:
            raise IndexError('index out of range')

    def __getitem__(self, i):
        self.check_index(i)
        return self.tokens_list[i]
......@@ -252,7 +230,7 @@ class IndexedDatasetBuilder(object):
        self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))

    def merge_file_(self, another_file):
        index = IndexedDataset(another_file, read_data=False)
        index = IndexedDataset(another_file)
        assert index.dtype == self.dtype
        begin = self.data_offsets[-1]
......
......@@ -69,17 +69,19 @@ class EpochBatchIterator(object):
        batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
            indices
        seed (int, optional): seed for random number generator for
            reproducibility. Default: 1
            reproducibility (default: 1).
        num_shards (int, optional): shard the data iterator into N
            shards. Default: 1
            shards (default: 1).
        shard_id (int, optional): which shard of the data iterator to
            return. Default: 0
        buffer_size (int, optional): number of batches to buffer. Default: 5
            return (default: 0).
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means the data will be loaded in the main process
            (default: 0).
    """

    def __init__(
        self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
        buffer_size=5,
        num_workers=0,
    ):
        assert isinstance(dataset, torch.utils.data.Dataset)
        self.dataset = dataset
......@@ -88,14 +90,12 @@ class EpochBatchIterator(object):
        self.seed = seed
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.buffer_size = buffer_size
        self.num_workers = num_workers
        self.epoch = 0
        self._cur_epoch_itr = None
        self._next_epoch_itr = None
        self._supports_prefetch = (
            hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
        )
        self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)

    def __len__(self):
        return len(self.frozen_batches)
......@@ -105,11 +105,10 @@ class EpochBatchIterator(object):
        Args:
            shuffle (bool, optional): shuffle batches before returning the
                iterator. Default: ``True``
                iterator (default: True).
            fix_batches_to_gpus: ensure that batches are always
                allocated to the same shards across epochs. Requires
                that :attr:`dataset` supports prefetching. Default:
                ``False``
                that :attr:`dataset` supports prefetching (default: False).
        """
        if self._next_epoch_itr is not None:
            self._cur_epoch_itr = self._next_epoch_itr
......@@ -117,7 +116,8 @@ class EpochBatchIterator(object):
        else:
            self.epoch += 1
            self._cur_epoch_itr = self._get_iterator_for_epoch(
                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus)
                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
            )
        return self._cur_epoch_itr

    def end_of_epoch(self):
......@@ -179,50 +179,14 @@ class EpochBatchIterator(object):
            batches = self.frozen_batches
        batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])

        return CountingIterator(BufferedIterator(
            torch.utils.data.DataLoader(
                self.dataset,
                collate_fn=self.collate_fn,
                batch_sampler=batches,
            ),
            buffer_size=self.buffer_size,
        return CountingIterator(torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
            batch_sampler=batches,
            num_workers=self.num_workers,
        ))


class BufferedIterator(object):
    """Wrapper around an iterable that prefetches items into a buffer.

    Args:
        iterable (iterable): iterable to wrap
        buffer_size (int): number of items to prefetch and buffer
    """

    def __init__(self, iterable, buffer_size):
        self.iterable = iterable
        self.q = queue.Queue(maxsize=buffer_size)
        self.thread = threading.Thread(target=self._load_q, daemon=True)
        self.thread.start()

    def __len__(self):
        return len(self.iterable)

    def __iter__(self):
        return self

    def __next__(self):
        x = self.q.get()
        if x is None:
            self.thread.join()
            raise StopIteration
        return x[0]

    def _load_q(self):
        for x in self.iterable:
            self.q.put([x])  # wrap in list so that it's never None
        self.q.put(None)


class GroupedIterator(object):
    """Wrapper around an iterable that returns groups (chunks) of items.
......@@ -261,7 +225,7 @@ class ShardedIterator(object):
num_shards (int): number of shards to split the iterable into
shard_id (int): which shard to iterator over
fill_value (Any, optional): padding value when the iterable doesn't
evenly divide *num_shards*. Default: ``None``
evenly divide *num_shards* (default: None).
"""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
......
......@@ -79,23 +79,23 @@ class LanguagePairDataset(FairseqDataset):
        tgt (torch.utils.data.Dataset, optional): target dataset to wrap
        tgt_sizes (List[int], optional): target sentence lengths
        tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
        left_pad_source (bool, optional): pad source tensors on the left side.
            Default: ``True``
        left_pad_target (bool, optional): pad target tensors on the left side.
            Default: ``False``
        max_source_positions (int, optional): max number of tokens in the source
            sentence. Default: ``1024``
        max_target_positions (int, optional): max number of tokens in the target
            sentence. Default: ``1024``
        shuffle (bool, optional): shuffle dataset elements before batching.
            Default: ``True``
        left_pad_source (bool, optional): pad source tensors on the left side
            (default: True).
        left_pad_target (bool, optional): pad target tensors on the left side
            (default: False).
        max_source_positions (int, optional): max number of tokens in the
            source sentence (default: 1024).
        max_target_positions (int, optional): max number of tokens in the
            target sentence (default: 1024).
        shuffle (bool, optional): shuffle dataset elements before batching
            (default: True).
        input_feeding (bool, optional): create a shifted version of the targets
            to be passed into the model for input feeding/teacher forcing.
            Default: ``True``
        remove_eos_from_source (bool, optional): if set, removes eos from end of
            source if it's present. Default: ``False``
            to be passed into the model for input feeding/teacher forcing
            (default: True).
        remove_eos_from_source (bool, optional): if set, removes eos from end
            of source if it's present (default: False).
        append_eos_to_target (bool, optional): if set, appends eos to end of
            target if it's absent. Default: ``False``
            target if it's absent (default: False).
    """

    def __init__(
......@@ -223,15 +223,13 @@ class LanguagePairDataset(FairseqDataset):
        indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
        return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]

    def prefetch(self, indices):
        self.src.prefetch(indices)
        self.tgt.prefetch(indices)

    @property
    def supports_prefetch(self):
        return (
            hasattr(self.src, 'supports_prefetch')
            and self.src.supports_prefetch
            and hasattr(self.tgt, 'supports_prefetch')
            and self.tgt.supports_prefetch
            getattr(self.src, 'supports_prefetch', False)
            and getattr(self.tgt, 'supports_prefetch', False)
        )

    def prefetch(self, indices):
        self.src.prefetch(indices)
        self.tgt.prefetch(indices)
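A hedged construction sketch for this class; the positional arguments follow the docstring above, while the dataset and dictionary variables are assumptions for illustration::

    from fairseq.data import LanguagePairDataset

    pair_dataset = LanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        left_pad_source=True,
        left_pad_target=False,
        max_source_positions=1024,
        max_target_positions=1024,
    )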