Commit 7633129b authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot

Merge internal changes (#283)

Summary:
Pull Request resolved: https://github.com/pytorch/translate/pull/283

Pull Request resolved: https://github.com/pytorch/fairseq/pull/428

Differential Revision: D13564190

Pulled By: myleott

fbshipit-source-id: 3b62282d7069c288f5bdd1dd2c120788cee4abb5
parent 0cb87130
......@@ -19,10 +19,13 @@ of various sequence-to-sequence models, including:
Fairseq features:
- multi-GPU (distributed) training on one machine or across multiple machines
- fast beam search generation on both CPU and GPU
- fast generation on both CPU and GPU with multiple search algorithms implemented:
- beam search
- Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
- sampling (unconstrained and top-k)
- large mini-batch training even on a single GPU via delayed updates
- fast half-precision floating point (FP16) training
- extensible: easily register new models, criterions, and tasks
- extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers
We also provide [pre-trained models](#pre-trained-models) for several benchmark
translation and language modeling datasets.
......@@ -34,7 +37,7 @@ translation and language modeling datasets.
* For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
* Python version 3.6
Currently fairseq requires PyTorch version >= 0.4.0.
Currently fairseq requires PyTorch version >= 1.0.0.
Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
If you use Docker make sure to increase the shared memory size either with
......
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import socket
import subprocess
from train import main as single_process_main
from fairseq import distributed_utils, options
def main(args):
if args.distributed_init_method is None and args.distributed_port > 0:
# We can determine the init method automatically for Slurm.
node_list = os.environ.get('SLURM_JOB_NODELIST')
if node_list is not None:
try:
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
args.distributed_init_method = 'tcp://{host}:{port}'.format(
host=hostnames.split()[0].decode('utf-8'),
port=args.distributed_port)
args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
args.device_id = int(os.environ.get('SLURM_LOCALID'))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError as e: # Slurm is not installed
pass
if args.distributed_init_method is None and args.distributed_port is None:
raise ValueError('--distributed-init-method or --distributed-port '
'must be specified for distributed training')
args.distributed_rank = distributed_utils.distributed_init(args)
print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
single_process_main(args)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
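A hedged usage sketch (assuming this launcher is saved as distributed_train.py and that the --distributed-port / --distributed-world-size flags from fairseq.options apply; the data path and world size are placeholders to adjust for your cluster):

    # one process per GPU under Slurm; the init method is derived from SLURM_JOB_NODELIST
    srun python distributed_train.py data-bin/wmt14_en_fr \
        --arch fconv_wmt_en_fr \
        --distributed-port 12345 --distributed-world-size 16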
......@@ -6,8 +6,26 @@
Criterions
==========
Criterions compute the loss function given the model and batch, roughly::
loss = criterion(model, batch)
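As a rough sketch (assuming the ``register_criterion`` decorator and the ``(loss, sample_size, logging_output)`` return convention used by the built-in criterions documented below), a new criterion could look like::

    import torch.nn.functional as F

    from fairseq.criterions import FairseqCriterion, register_criterion

    @register_criterion('my_nll')  # hypothetical name
    class MyNLLCriterion(FairseqCriterion):

        def forward(self, model, sample, reduce=True):
            # run the model and compute a summed token-level NLL loss
            net_output = model(**sample['net_input'])
            lprobs = model.get_normalized_probs(net_output, log_probs=True)
            target = model.get_targets(sample, net_output)
            loss = F.nll_loss(
                lprobs.view(-1, lprobs.size(-1)), target.view(-1),
                ignore_index=self.padding_idx,
                reduction='sum' if reduce else 'none',
            )
            sample_size = sample['ntokens']
            logging_output = {'loss': loss.data, 'sample_size': sample_size}
            return loss, sample_size, logging_output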
.. automodule:: fairseq.criterions
:members:
.. autoclass:: fairseq.criterions.FairseqCriterion
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
:members:
:undoc-members:
.. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
:members:
:undoc-members:
......@@ -21,6 +21,20 @@ mini-batches.
.. autoclass:: fairseq.data.MonolingualDataset
:members:
**Helper Datasets**
These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
provide additional functionality:
.. autoclass:: fairseq.data.BacktranslationDataset
:members:
.. autoclass:: fairseq.data.ConcatDataset
:members:
.. autoclass:: fairseq.data.RoundRobinZipDatasets
:members:
.. autoclass:: fairseq.data.TransformEosDataset
:members:
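For example, a hedged sketch of composing datasets (``dataset_a`` and ``dataset_b`` are hypothetical :class:`fairseq.data.FairseqDataset` instances)::

    from fairseq.data import ConcatDataset

    # indexing walks through dataset_a first, then dataset_b
    combined = ConcatDataset([dataset_a, dataset_b])
    assert len(combined) == len(dataset_a) + len(dataset_b)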
Dictionary
----------
......@@ -32,6 +46,8 @@ Dictionary
Iterators
---------
.. autoclass:: fairseq.data.BufferedIterator
:members:
.. autoclass:: fairseq.data.CountingIterator
:members:
.. autoclass:: fairseq.data.EpochBatchIterator
......
......@@ -27,21 +27,20 @@ interactively. Here, we use a beam size of 5:
> MODEL_DIR=wmt14.en-fr.fconv-py
> python interactive.py \
--path $MODEL_DIR/model.pt $MODEL_DIR \
--beam 5
--beam 5 --source-lang en --target-lang fr
| loading model(s) from wmt14.en-fr.fconv-py/model.pt
| [en] dictionary: 44206 types
| [fr] dictionary: 44463 types
| Type the input sentence and press return:
> Why is it rare to discover new marine mam@@ mal species ?
O Why is it rare to discover new marine mam@@ mal species ?
H -0.06429661810398102 Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins ?
A 0 1 3 3 5 6 6 8 8 8 7 11 12
This generation script produces four types of outputs: a line prefixed
with *S* shows the supplied source sentence after applying the
vocabulary; *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *A* is the
attention maxima for each word in the hypothesis, including the
H -0.1525060087442398 Pourquoi est @-@ il rare de découvrir de nouvelles espèces de mammifères marins ?
P -0.2221 -0.3122 -0.1289 -0.2673 -0.1711 -0.1930 -0.1101 -0.1660 -0.1003 -0.0740 -0.1101 -0.0814 -0.1238 -0.0985 -0.1288
This generation script produces three types of outputs: a line prefixed
with *O* is a copy of the original source sentence; *H* is the
hypothesis along with an average log-likelihood; and *P* is the
positional score per token position, including the
end-of-sentence marker which is omitted from the text.
See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
......
......@@ -6,7 +6,29 @@
Learning Rate Schedulers
========================
TODO
Learning Rate Schedulers update the learning rate over the course of training.
Learning rates can be updated after each update via :func:`step_update` or at
epoch boundaries via :func:`step`.
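A hedged sketch of adding a scheduler (assuming the ``register_lr_scheduler`` decorator and the ``step``/``step_update`` hooks described above; the name and decay rule are made up)::

    from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler

    @register_lr_scheduler('halving')  # hypothetical name
    class HalvingSchedule(FairseqLRScheduler):

        def __init__(self, args, optimizer):
            super().__init__(args, optimizer)
            self.base_lr = args.lr[0]

        def step(self, epoch, val_loss=None):
            # called at epoch boundaries: halve the learning rate every epoch
            super().step(epoch, val_loss)
            self.optimizer.set_lr(self.base_lr / (2 ** epoch))
            return self.optimizer.get_lr()

        def step_update(self, num_updates):
            # called after each optimizer update: keep the per-epoch value
            return self.optimizer.get_lr()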
.. automodule:: fairseq.optim.lr_scheduler
:members:
.. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
:members:
:undoc-members:
.. autoclass:: fairseq.optim.lr_scheduler.reduce_angular_lr_scheduler.TriangularSchedule
:members:
:undoc-members:
Modules
=======
Fairseq provides several stand-alone :class:`torch.nn.Module` s that may be
helpful when implementing a new :class:`FairseqModel`.
Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.
.. automodule:: fairseq.modules
:members:
......
......@@ -6,5 +6,27 @@
Optimizers
==========
Optimizers update the Model parameters based on the gradients.
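A hedged sketch of registering a new optimizer (assuming the ``register_optimizer`` decorator and that subclasses wrap a ``torch.optim`` instance in ``self._optimizer``, as the classes below do)::

    import torch.optim

    from fairseq.optim import FairseqOptimizer, register_optimizer

    @register_optimizer('plain_sgd')  # hypothetical name
    class PlainSGD(FairseqOptimizer):

        def __init__(self, args, params):
            super().__init__(args, params)
            # the wrapped torch optimizer that actually applies the updates
            self._optimizer = torch.optim.SGD(params, **self.optimizer_config)

        @property
        def optimizer_config(self):
            # kwargs forwarded to torch.optim.SGD, read from the parsed args
            return {
                'lr': self.args.lr[0],
                'momentum': getattr(self.args, 'momentum', 0.0),
                'weight_decay': getattr(self.args, 'weight_decay', 0.0),
            }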
.. automodule:: fairseq.optim
:members:
.. autoclass:: fairseq.optim.FairseqOptimizer
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adagrad.Adagrad
:members:
:undoc-members:
.. autoclass:: fairseq.optim.adam.FairseqAdam
:members:
:undoc-members:
.. autoclass:: fairseq.optim.fp16_optimizer.FP16Optimizer
:members:
:undoc-members:
.. autoclass:: fairseq.optim.nag.FairseqNAG
:members:
:undoc-members:
.. autoclass:: fairseq.optim.sgd.SGD
:members:
:undoc-members:
......@@ -22,12 +22,18 @@ fairseq implements the following high-level training flow::
for epoch in range(num_epochs):
itr = task.get_batch_iterator(task.dataset('train'))
for num_updates, batch in enumerate(itr):
loss = criterion(model, batch)
optimizer.backward(loss)
task.train_step(batch, model, criterion, optimizer)
average_and_clip_gradients()
optimizer.step()
lr_scheduler.step_update(num_updates)
lr_scheduler.step(epoch)
where the default implementation for ``train.train_step`` is roughly::
def train_step(self, batch, model, criterion, optimizer):
loss = criterion(model, batch)
optimizer.backward(loss)
**Registering new plug-ins**
New plug-ins are *registered* through a set of ``@register`` function
......
......@@ -353,17 +353,16 @@ The model files should appear in the :file:`checkpoints/` directory.
-------------------------------
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classify.py` with the following contents::
a new file named :file:`eval_classifier.py` with the following contents::
from fairseq import data, options, tasks, utils
from fairseq.tokenizer import Tokenizer
# Parse command-line arguments for generation
parser = options.get_generation_parser()
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)
# Setup task
args.task = 'simple_classification'
task = tasks.setup_task(args)
# Load model
......
......@@ -55,7 +55,9 @@ def main(parsed_args):
# Load ensemble
print('| loading model(s) from {}'.format(parsed_args.path))
models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides))
models, args = utils.load_ensemble_for_inference(
parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
)
for arg in vars(parsed_args).keys():
if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
......@@ -83,9 +85,10 @@ def main(parsed_args):
max_positions=utils.resolve_max_positions(*[
model.max_positions() for model in models
]),
ignore_invalid_inputs=True,
num_shards=args.num_shards,
shard_id=args.shard_id,
ignore_invalid_inputs=True,
num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)
gen_timer = StopwatchMeter()
......
......@@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
from .backtranslation_dataset import BacktranslationDataset
from .concat_dataset import ConcatDataset
from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .round_robin_zip_datasets import RoundRobinZipDatasets
......@@ -33,7 +33,6 @@ __all__ = [
'GroupedIterator',
'IndexedCachedDataset',
'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'MonolingualDataset',
......
......@@ -56,16 +56,6 @@ def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
class BacktranslationDataset(FairseqDataset):
def __init__(
self,
tgt_dataset,
backtranslation_fn,
max_len_a,
max_len_b,
output_collater=None,
cuda=True,
**kwargs
):
"""
Sets up a backtranslation dataset which takes a tgt batch, generates
a src using a tgt-src backtranslation function (*backtranslation_fn*),
......@@ -73,20 +63,31 @@ class BacktranslationDataset(FairseqDataset):
Args:
tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
backtranslated. Only the source side of this dataset will be
used. After backtranslation, the source sentences in this
dataset will be returned as the targets.
backtranslated. Only the source side of this dataset will be used.
After backtranslation, the source sentences in this dataset will be
returned as the targets.
backtranslation_fn (callable): function to call to generate
backtranslations. This is typically the `generate` method of a
:class:`~fairseq.sequence_generator.SequenceGenerator` object.
max_len_a, max_len_b (int, int): will be used to compute
`maxlen = max_len_a * src_len + max_len_b`, which will be
passed into *backtranslation_fn*.
`maxlen = max_len_a * src_len + max_len_b`, which will be passed
into *backtranslation_fn*.
output_collater (callable, optional): function to call on the
backtranslated samples to create the final batch (default:
``tgt_dataset.collater``)
backtranslated samples to create the final batch
(default: ``tgt_dataset.collater``).
cuda (bool, optional): use GPU for generation (default: True).
"""
def __init__(
self,
tgt_dataset,
backtranslation_fn,
max_len_a,
max_len_b,
output_collater=None,
cuda=True,
**kwargs
):
self.tgt_dataset = tgt_dataset
self.backtranslation_fn = backtranslation_fn
self.max_len_a = max_len_a
......@@ -169,8 +170,7 @@ class BacktranslationDataset(FairseqDataset):
@property
def supports_prefetch(self):
return self.tgt_dataset.supports_prefetch()
return getattr(self.tgt_dataset, 'supports_prefetch', False)
def prefetch(self, indices):
return self.tgt_dataset.prefetch(indices)
......@@ -29,18 +29,18 @@ class ConcatDataset(FairseqDataset):
if isinstance(sample_ratios, int):
sample_ratios = [sample_ratios] * len(self.datasets)
self.sample_ratios = sample_ratios
self.cummulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
self.real_sizes = [len(d) for d in self.datasets]
def __len__(self):
return self.cummulative_sizes[-1]
return self.cumulative_sizes[-1]
def __getitem__(self, idx):
dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
if dataset_idx == 0:
sample_idx = idx
else:
sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
sample_idx = sample_idx % self.real_sizes[dataset_idx]
return self.datasets[dataset_idx][sample_idx]
......@@ -54,7 +54,7 @@ class ConcatDataset(FairseqDataset):
def prefetch(self, indices):
frm = 0
for to, ds in zip(self.cummulative_sizes, self.datasets):
for to, ds in zip(self.cumulative_sizes, self.datasets):
real_size = len(ds)
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
frm = to
......@@ -81,8 +81,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
size_fn (callable): function that returns the size of a given index
max_positions (tuple): filter elements larger than this size.
Comparisons are done component-wise.
raise_exception (bool, optional): if ``True``, raise an exception
if any elements are filtered. Default: ``False``
raise_exception (bool, optional): if ``True``, raise an exception if
any elements are filtered (default: False).
"""
def check_size(idx):
if isinstance(max_positions, float) or isinstance(max_positions, int):
......@@ -128,12 +128,12 @@ def batch_by_size(
indices (List[int]): ordered list of dataset indices
num_tokens_fn (callable): function that returns the number of tokens at
a given index
max_tokens (int, optional): max number of tokens in each batch.
Default: ``None``
max_tokens (int, optional): max number of tokens in each batch
(default: None).
max_sentences (int, optional): max number of sentences in each
batch. Default: ``None``
batch (default: None).
required_batch_size_multiple (int, optional): require batch size to
be a multiple of N. Default: ``1``
be a multiple of N (default: 1).
"""
max_tokens = max_tokens if max_tokens is not None else float('Inf')
max_sentences = max_sentences if max_sentences is not None else float('Inf')
......
......@@ -200,11 +200,15 @@ class Dictionary(object):
t[-1] = self.eos()
return t
class TruncatedDictionary(object):
def __init__(self, wrapped_dict, length):
self.__class__ = type(wrapped_dict.__class__.__name__,
(self.__class__, wrapped_dict.__class__), {})
self.__class__ = type(
wrapped_dict.__class__.__name__,
(self.__class__, wrapped_dict.__class__),
{}
)
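        # (the dynamic type() call above re-parents this instance so it passes
        # isinstance checks for the wrapped dictionary's class, while sharing
        # its attributes through __dict__ below)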
self.__dict__ = wrapped_dict.__dict__
self.wrapped_dict = wrapped_dict
self.length = min(len(self.wrapped_dict), length)
......
......@@ -7,8 +7,6 @@
import torch.utils.data
from fairseq.data import data_utils
class FairseqDataset(torch.utils.data.Dataset):
"""A dataset that provides helpers for batching."""
......@@ -51,7 +49,9 @@ class FairseqDataset(torch.utils.data.Dataset):
@property
def supports_prefetch(self):
"""Whether this dataset supports prefetching."""
return False
def prefetch(self, indices):
"""Prefetch the data required for this epoch."""
raise NotImplementedError
......@@ -52,13 +52,12 @@ def data_file_path(prefix_path):
class IndexedDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path, fix_lua_indexing=False, read_data=True):
def __init__(self, path, fix_lua_indexing=False):
super().__init__()
self.fix_lua_indexing = fix_lua_indexing
self.read_index(path)
self.data_file = None
if read_data:
self.read_data(path)
self.path = path
def read_index(self, path):
with open(index_file_path(path), 'rb') as f:
......@@ -85,8 +84,10 @@ class IndexedDataset(torch.utils.data.Dataset):
self.data_file.close()
def __getitem__(self, i):
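        # open the data file lazily on first access (e.g. so the dataset can be
        # constructed cheaply and the file handle is created inside the process
        # that actually reads from it)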
if not self.data_file:
self.read_data(self.path)
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
tensor_size = int(self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]])
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
......@@ -98,12 +99,6 @@ class IndexedDataset(torch.utils.data.Dataset):
def __len__(self):
return self.size
def read_into(self, start, dst):
self.data_file.seek(start * self.element_size)
self.data_file.readinto(dst)
if self.fix_lua_indexing:
dst -= 1 # subtract 1 for 0-based indexing
@staticmethod
def exists(path):
return (
......@@ -111,11 +106,15 @@ class IndexedDataset(torch.utils.data.Dataset):
os.path.exists(data_file_path(path))
)
@property
def supports_prefetch(self):
return False # avoid prefetching to save memory
class IndexedCachedDataset(IndexedDataset):
def __init__(self, path, fix_lua_indexing=False):
super().__init__(path, fix_lua_indexing, True)
super().__init__(path, fix_lua_indexing=fix_lua_indexing)
self.cache = None
self.cache_index = {}
......@@ -126,6 +125,8 @@ class IndexedCachedDataset(IndexedDataset):
def prefetch(self, indices):
if all(i in self.cache_index for i in indices):
return
if not self.data_file:
self.read_data(self.path)
indices = sorted(set(indices))
total_size = 0
for i in indices:
......@@ -153,34 +154,7 @@ class IndexedCachedDataset(IndexedDataset):
return item
class IndexedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
def read_data(self, path):
self.data_file = open(data_file_path(path), 'rb')
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
self.data_file.close()
if self.fix_lua_indexing:
self.buffer -= 1 # subtract 1 for 0-based indexing
def read_into(self, start, dst):
if self.token_blob is None:
self.token_blob = [t for l in self.tokens_list for t in l]
np.copyto(dst, self.token_blob[start:])
def __del__(self):
pass
def __getitem__(self, i):
self.check_index(i)
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
return torch.from_numpy(a).long()
class IndexedRawTextDataset(IndexedDataset):
class IndexedRawTextDataset(torch.utils.data.Dataset):
"""Takes a text file as input and binarizes it in memory at instantiation.
Original lines are also kept in memory"""
......@@ -205,6 +179,10 @@ class IndexedRawTextDataset(IndexedDataset):
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
def check_index(self, i):
if i < 0 or i >= self.size:
raise IndexError('index out of range')
def __getitem__(self, i):
self.check_index(i)
return self.tokens_list[i]
......@@ -252,7 +230,7 @@ class IndexedDatasetBuilder(object):
self.dim_offsets.append(self.dim_offsets[-1] + len(tensor.size()))
def merge_file_(self, another_file):
index = IndexedDataset(another_file, read_data=False)
index = IndexedDataset(another_file)
assert index.dtype == self.dtype
begin = self.data_offsets[-1]
......
......@@ -69,17 +69,19 @@ class EpochBatchIterator(object):
batch_sampler (~torch.utils.data.Sampler): an iterator over batches of
indices
seed (int, optional): seed for random number generator for
reproducibility. Default: 1
reproducibility (default: 1).
num_shards (int, optional): shard the data iterator into N
shards. Default: 1
shards (default: 1).
shard_id (int, optional): which shard of the data iterator to
return. Default: 0
buffer_size (int, optional): number of batches to buffer. Default: 5
return (default: 0).
num_workers (int, optional): how many subprocesses to use for data
loading. 0 means the data will be loaded in the main process
(default: 0).
"""
def __init__(
self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
buffer_size=5,
num_workers=0,
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
......@@ -88,14 +90,12 @@ class EpochBatchIterator(object):
self.seed = seed
self.num_shards = num_shards
self.shard_id = shard_id
self.buffer_size = buffer_size
self.num_workers = num_workers
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
self._supports_prefetch = (
hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
)
self._supports_prefetch = getattr(dataset, 'supports_prefetch', False)
def __len__(self):
return len(self.frozen_batches)
......@@ -105,11 +105,10 @@ class EpochBatchIterator(object):
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True``
iterator (default: True).
fix_batches_to_gpus: ensure that batches are always
allocated to the same shards across epochs. Requires
that :attr:`dataset` supports prefetching. Default:
``False``
that :attr:`dataset` supports prefetching (default: False).
"""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
......@@ -117,7 +116,8 @@ class EpochBatchIterator(object):
else:
self.epoch += 1
self._cur_epoch_itr = self._get_iterator_for_epoch(
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus)
self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus,
)
return self._cur_epoch_itr
def end_of_epoch(self):
......@@ -179,50 +179,14 @@ class EpochBatchIterator(object):
batches = self.frozen_batches
batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
return CountingIterator(BufferedIterator(
torch.utils.data.DataLoader(
return CountingIterator(torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_sampler=batches,
),
buffer_size=self.buffer_size,
num_workers=self.num_workers,
))
class BufferedIterator(object):
"""Wrapper around an iterable that prefetches items into a buffer.
Args:
iterable (iterable): iterable to wrap
buffer_size (int): number of items to prefetch and buffer
"""
def __init__(self, iterable, buffer_size):
self.iterable = iterable
self.q = queue.Queue(maxsize=buffer_size)
self.thread = threading.Thread(target=self._load_q, daemon=True)
self.thread.start()
def __len__(self):
return len(self.iterable)
def __iter__(self):
return self
def __next__(self):
x = self.q.get()
if x is None:
self.thread.join()
raise StopIteration
return x[0]
def _load_q(self):
for x in self.iterable:
self.q.put([x]) # wrap in list so that it's never None
self.q.put(None)
class GroupedIterator(object):
"""Wrapper around an iterable that returns groups (chunks) of items.
......@@ -261,7 +225,7 @@ class ShardedIterator(object):
num_shards (int): number of shards to split the iterable into
shard_id (int): which shard to iterator over
fill_value (Any, optional): padding value when the iterable doesn't
evenly divide *num_shards*. Default: ``None``
evenly divide *num_shards* (default: None).
"""
def __init__(self, iterable, num_shards, shard_id, fill_value=None):
......
......@@ -79,23 +79,23 @@ class LanguagePairDataset(FairseqDataset):
tgt (torch.utils.data.Dataset, optional): target dataset to wrap
tgt_sizes (List[int], optional): target sentence lengths
tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
left_pad_source (bool, optional): pad source tensors on the left side.
Default: ``True``
left_pad_target (bool, optional): pad target tensors on the left side.
Default: ``False``
max_source_positions (int, optional): max number of tokens in the source
sentence. Default: ``1024``
max_target_positions (int, optional): max number of tokens in the target
sentence. Default: ``1024``
shuffle (bool, optional): shuffle dataset elements before batching.
Default: ``True``
left_pad_source (bool, optional): pad source tensors on the left side
(default: True).
left_pad_target (bool, optional): pad target tensors on the left side
(default: False).
max_source_positions (int, optional): max number of tokens in the
source sentence (default: 1024).
max_target_positions (int, optional): max number of tokens in the
target sentence (default: 1024).
shuffle (bool, optional): shuffle dataset elements before batching
(default: True).
input_feeding (bool, optional): create a shifted version of the targets
to be passed into the model for input feeding/teacher forcing.
Default: ``True``
remove_eos_from_source (bool, optional): if set, removes eos from end of
source if it's present. Default: ``False``
to be passed into the model for input feeding/teacher forcing
(default: True).
remove_eos_from_source (bool, optional): if set, removes eos from end
of source if it's present (default: False).
append_eos_to_target (bool, optional): if set, appends eos to end of
target if it's absent. Default: ``False``
target if it's absent (default: False).
"""
def __init__(
......@@ -223,15 +223,13 @@ class LanguagePairDataset(FairseqDataset):
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)
@property
def supports_prefetch(self):
return (
hasattr(self.src, 'supports_prefetch')
and self.src.supports_prefetch
and hasattr(self.tgt, 'supports_prefetch')
and self.tgt.supports_prefetch
getattr(self.src, 'supports_prefetch', False)
and getattr(self.tgt, 'supports_prefetch', False)
)
def prefetch(self, indices):
self.src.prefetch(indices)
self.tgt.prefetch(indices)