Commit 6381cc97 authored by Myle Ott

Add documentation

parent 0e101e9c
@@ -8,3 +8,10 @@
from .multiprocessing_pdb import pdb
__all__ = ['pdb']
import fairseq.criterions
import fairseq.models
import fairseq.modules
import fairseq.optim
import fairseq.optim.lr_scheduler
import fairseq.tasks
@@ -7,9 +7,23 @@
from .dictionary import Dictionary
from .fairseq_dataset import FairseqDataset
from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
from .iterators import CountingIterator, EpochBatchIterator, ShardedIterator
__all__ = [
'CountingIterator',
'Dictionary',
'EpochBatchIterator',
'FairseqDataset',
'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset',
'LanguagePairDataset',
'MonolingualDataset',
'TokenBlockDataset',
'ShardedIterator',
]
@@ -74,10 +74,10 @@ class LanguagePairDataset(FairseqDataset):
Args:
src (torch.utils.data.Dataset): source dataset to wrap
src_sizes (List[int]): source sentence lengths
src_dict (~fairseq.data.Dictionary): source vocabulary
tgt (torch.utils.data.Dataset, optional): target dataset to wrap
tgt_sizes (List[int], optional): target sentence lengths
tgt_dict (~fairseq.data.Dictionary, optional): target vocabulary
left_pad_source (bool, optional): pad source tensors on the left side.
Default: ``True``
left_pad_target (bool, optional): pad target tensors on the left side.
@@ -130,29 +130,31 @@ class LanguagePairDataset(FairseqDataset):
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch with the following keys:
- `id` (LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape `(bsz, src_len)`. Padding will
appear on the left if *left_pad_source* is ``True``.
- `src_lengths` (LongTensor): 1D Tensor of the unpadded
lengths of each source sentence of shape `(bsz)`
- `prev_output_tokens` (LongTensor): a padded 2D Tensor of
tokens in the target sentence, shifted right by one position
for input feeding/teacher forcing, of shape `(bsz,
tgt_len)`. This key will not be present if *input_feeding*
is ``False``. Padding will appear on the left if
*left_pad_target* is ``True``.
- `target` (LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear
on the left if *left_pad_target* is ``True``.
""" """
return collate( return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(), samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
......
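To make the collater output documented above concrete, here is a minimal hand-built sketch of the mini-batch structure. Key names and shapes follow the docstring; the tensor values are toy placeholders, assuming the defaults left_pad_source=True and left_pad_target=False.

import torch

bsz, src_len, tgt_len, pad, eos = 2, 5, 4, 1, 2

batch = {
    'id': torch.LongTensor([3, 7]),          # example IDs in the original input order
    'ntokens': 7,                            # total number of (target) tokens in the batch
    'net_input': {
        # left-padded source tokens, shape (bsz, src_len)
        'src_tokens': torch.LongTensor([[pad, pad, 4, 5, eos],
                                        [9, 8, 7, 6, eos]]),
        # unpadded source lengths, shape (bsz,)
        'src_lengths': torch.LongTensor([3, 5]),
        # target shifted right by one for input feeding/teacher forcing
        'prev_output_tokens': torch.LongTensor([[eos, 4, 5, pad],
                                                [eos, 9, 8, 7]]),
    },
    # right-padded target tokens, shape (bsz, tgt_len)
    'target': torch.LongTensor([[4, 5, eos, pad],
                                [9, 8, 7, eos]]),
}

assert batch['net_input']['src_tokens'].shape == (bsz, src_len)
assert batch['target'].shape == (bsz, tgt_len)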
@@ -40,7 +40,7 @@ class MonolingualDataset(FairseqDataset):
Args:
dataset (torch.utils.data.Dataset): dataset to wrap
sizes (List[int]): sentence lengths
vocab (~fairseq.data.Dictionary): vocabulary
shuffle (bool, optional): shuffle the elements before batching.
Default: ``True``
"""
@@ -61,22 +61,23 @@ class MonolingualDataset(FairseqDataset):
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Args:
samples (List[dict]): samples to collate
Returns:
dict: a mini-batch with the following keys:
- `id` (LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape `(bsz, src_len)`. Padding will
appear on the right.
- `target` (LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear
on the right.
""" """
return collate(samples, self.vocab.pad(), self.vocab.eos()) return collate(samples, self.vocab.pad(), self.vocab.eos())
......
@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
import importlib
import os
@@ -18,6 +19,7 @@ from .composite_encoder import CompositeEncoder  # noqa: F401
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_MODEL_INV_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
@@ -26,7 +28,23 @@ def build_model(args, task):
def register_model(name):
"""
New model types can be added to fairseq with the :func:`register_model`
function decorator.
For example::
@register_model('lstm')
class LSTM(FairseqModel):
(...)
.. note:: All models must implement the :class:`BaseFairseqModel` interface.
Typically you will extend :class:`FairseqModel` for sequence-to-sequence
tasks or :class:`FairseqLanguageModel` for language modeling tasks.
Args:
name (str): the name of the model
"""
def register_model_cls(cls):
if name in MODEL_REGISTRY:
@@ -40,7 +58,29 @@ def register_model(name):
def register_model_architecture(model_name, arch_name):
"""
New model architectures can be added to fairseq with the
:func:`register_model_architecture` function decorator. After registration,
model architectures can be selected with the ``--arch`` command-line
argument.
For example::
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
(...)
The decorated function should take a single argument *args*, which is a
:class:`argparse.Namespace` of arguments parsed from the command-line. The
decorated function should modify these arguments in-place to match the
desired architecture.
Args:
model_name (str): the name of the Model (Model must already be
registered)
arch_name (str): the name of the model architecture (``--arch``)
"""
def register_model_arch_fn(fn):
if model_name not in MODEL_REGISTRY:
@@ -50,6 +90,7 @@ def register_model_architecture(model_name, arch_name):
if not callable(fn):
raise ValueError('Model architecture must be callable ({})'.format(arch_name))
ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name)
ARCH_CONFIG_REGISTRY[arch_name] = fn
return fn
@@ -59,5 +100,14 @@ def register_model_architecture(model_name, arch_name):
# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
model_name = file[:file.find('.py')]
module = importlib.import_module('fairseq.models.' + model_name)
# extra `model_parser` for sphinx
if model_name in MODEL_REGISTRY:
parser = argparse.ArgumentParser(add_help=False)
group_archs = parser.add_argument_group('Named architectures')
group_archs.add_argument('--arch', choices=ARCH_MODEL_INV_REGISTRY[model_name])
group_args = parser.add_argument_group('Additional command-line arguments')
MODEL_REGISTRY[model_name].add_args(group_args)
globals()[model_name + '_parser'] = parser
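A hedged sketch of how the two decorators documented above fit together. The names 'my_lstm' and 'my_lstm_big', and the bodies of add_args and build_model, are invented for illustration; only the decorator interfaces come from the code above.

from fairseq.models import (
    FairseqModel, register_model, register_model_architecture,
)

@register_model('my_lstm')
class MyLSTMModel(FairseqModel):

    @staticmethod
    def add_args(parser):
        # model-specific command-line arguments (surfaced by the sphinx parsers above)
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')

    @classmethod
    def build_model(cls, args, task):
        # construct encoder/decoder from args and task, then return cls(encoder, decoder)
        raise NotImplementedError

@register_model_architecture('my_lstm', 'my_lstm_big')
def my_lstm_big(args):
    # fill in any architecture defaults not already set on the command line
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)

After this runs, ``--arch my_lstm_big`` selects the registered model and applies the architecture defaults.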
@@ -10,8 +10,13 @@ from . import FairseqEncoder
class CompositeEncoder(FairseqEncoder):
"""
A wrapper around a dictionary of :class:`FairseqEncoder` objects.
We run forward on each encoder and return a dictionary of outputs. The first
encoder's dictionary is used for initialization.
Args:
encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
"""
def __init__(self, encoders):
@@ -21,6 +26,17 @@ class CompositeEncoder(FairseqEncoder):
self.add_module(key, self.encoders[key])
def forward(self, src_tokens, src_lengths):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (LongTensor): lengths of each source sentence of shape
`(batch)`
Returns:
dict:
the outputs from each Encoder
"""
encoder_out = {}
for key in self.encoders:
encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
...
@@ -17,6 +17,20 @@ class FairseqDecoder(nn.Module):
self.dictionary = dictionary
def forward(self, prev_output_tokens, encoder_out):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
Returns:
tuple:
- the last decoder layer's output of shape
`(batch, tgt_len, vocab)`
- the last decoder layer's attention weights of shape
`(batch, tgt_len, src_len)`
"""
raise NotImplementedError
def get_normalized_probs(self, net_output, log_probs, sample):
@@ -35,7 +49,8 @@ class FairseqDecoder(nn.Module):
def max_positions(self):
"""Maximum input length supported by the decoder."""
return 1e6  # an arbitrary large number
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict
@@ -16,15 +16,32 @@ class FairseqEncoder(nn.Module):
self.dictionary = dictionary
def forward(self, src_tokens, src_lengths):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (LongTensor): lengths of each source sentence of shape
`(batch)`
"""
raise NotImplementedError
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to `new_order`.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
`encoder_out` rearranged according to `new_order`
"""
raise NotImplementedError
def max_positions(self):
"""Maximum input length supported by the encoder."""
return 1e6  # an arbitrary large number
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
return state_dict
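As a concrete, if trivial, illustration of the encoder interface documented above, here is a hedged sketch of a bag-of-embeddings encoder; the class and its output layout are invented for illustration, only the method signatures follow the interface.

import torch
import torch.nn as nn
from fairseq.models import FairseqEncoder

class ToyEncoder(FairseqEncoder):

    def __init__(self, dictionary, embed_dim=128):
        super().__init__(dictionary)
        self.embed_tokens = nn.Embedding(len(dictionary), embed_dim,
                                         padding_idx=dictionary.pad())

    def forward(self, src_tokens, src_lengths):
        x = self.embed_tokens(src_tokens)        # (batch, src_len, embed_dim)
        return {
            'encoder_out': x.mean(dim=1),        # (batch, embed_dim)
            'encoder_padding_mask': src_tokens.eq(self.dictionary.pad()),
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        # reorder batched outputs to follow `new_order` (used during beam search)
        encoder_out['encoder_out'] = encoder_out['encoder_out'].index_select(0, new_order)
        encoder_out['encoder_padding_mask'] = \
            encoder_out['encoder_padding_mask'].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        return 1024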
@@ -9,12 +9,44 @@ from . import FairseqDecoder
class FairseqIncrementalDecoder(FairseqDecoder):
"""Base class for incremental decoders.
Incremental decoding is a special mode at inference time where the Model
only receives a single timestep of input corresponding to the immediately
previous output token (for input feeding) and must produce the next output
*incrementally*. Thus the model must cache any long-term state that is
needed about the sequence, e.g., hidden states, convolutional states, etc.
Compared to the standard :class:`FairseqDecoder` interface, the incremental
decoder interface allows :func:`forward` functions to take an extra keyword
argument (*incremental_state*) that can be used to cache state across
time-steps.
The :class:`FairseqIncrementalDecoder` interface also defines the
:func:`reorder_incremental_state` method, which is used during beam search
to select and reorder the incremental state based on the selection of beams.
"""
def __init__(self, dictionary):
super().__init__(dictionary)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
"""
raise NotImplementedError
def reorder_incremental_state(self, incremental_state, new_order):
...
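A hedged sketch of the pattern this interface describes: cache running state in *incremental_state* during step-wise generation and reorder it when beam search reorders hypotheses. The class, the 'sum' key, and the plain-dict caching scheme are simplified placeholders rather than fairseq's own state helpers.

import torch
from fairseq.models import FairseqIncrementalDecoder

class ToyIncrementalDecoder(FairseqIncrementalDecoder):

    def __init__(self, dictionary, embed_dim=32):
        super().__init__(dictionary)
        self.embed = torch.nn.Embedding(len(dictionary), embed_dim)
        self.proj = torch.nn.Linear(embed_dim, len(dictionary))

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        # encoder_out is unused in this toy, decoder-only example
        if incremental_state is not None:
            # incremental mode: only the newest token arrives; the running sum of
            # embeddings is read from and written back to incremental_state
            # instead of being recomputed from scratch every step
            prev_output_tokens = prev_output_tokens[:, -1:]
            state = incremental_state.get('sum', 0) + self.embed(prev_output_tokens)
            incremental_state['sum'] = state            # (batch, 1, embed_dim)
        else:
            # teacher-forced mode: all timesteps at once, same function as above
            state = self.embed(prev_output_tokens).cumsum(dim=1)
        return self.proj(state), None                   # (output, attn)

    def reorder_incremental_state(self, incremental_state, new_order):
        # keep cached state aligned with the reordered/pruned beam hypotheses
        if 'sum' in incremental_state:
            incremental_state['sum'] = incremental_state['sum'].index_select(0, new_order)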
@@ -54,16 +54,17 @@ class BaseFairseqModel(nn.Module):
return self.decoder.max_positions()
def load_state_dict(self, state_dict, strict=True):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
self.upgrade_state_dict(state_dict)
super().load_state_dict(state_dict, strict)
def upgrade_state_dict(self, state_dict):
"""Upgrade old state dicts to work with newer code."""
assert state_dict is not None
def do_upgrade(m, prefix):
@@ -119,7 +120,12 @@ class BaseFairseqModel(nn.Module):
class FairseqModel(BaseFairseqModel):
"""Base class for encoder-decoder models.
Args:
encoder (FairseqEncoder): the encoder
decoder (FairseqDecoder): the decoder
"""
def __init__(self, encoder, decoder):
super().__init__()
@@ -130,6 +136,26 @@ class FairseqModel(BaseFairseqModel):
assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
"""
Run the forward pass for an encoder-decoder model.
First feed a batch of source tokens through the encoder. Then, feed the
encoder output and previous decoder outputs (i.e., input feeding/teacher
forcing) to the decoder to produce the next outputs::
encoder_out = self.encoder(src_tokens, src_lengths)
return self.decoder(prev_output_tokens, encoder_out)
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (LongTensor): source sentence lengths of shape `(batch)`
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
Returns:
the decoder's output, typically of shape `(batch, tgt_len, vocab)`
"""
encoder_out = self.encoder(src_tokens, src_lengths)
decoder_out = self.decoder(prev_output_tokens, encoder_out)
return decoder_out
@@ -140,7 +166,11 @@ class FairseqModel(BaseFairseqModel):
class FairseqLanguageModel(BaseFairseqModel):
"""Base class for decoder-only models.
Args:
decoder (FairseqDecoder): the decoder
"""
def __init__(self, decoder):
super().__init__()
@@ -148,6 +178,19 @@ class FairseqLanguageModel(BaseFairseqModel):
assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens, src_lengths):
"""
Run the forward pass for a decoder-only model.
Feeds a batch of tokens through the decoder to predict the next tokens.
Args:
src_tokens (LongTensor): tokens on which to condition the decoder,
of shape `(batch, tgt_len)`
src_lengths (LongTensor): source sentence lengths of shape `(batch)`
Returns:
the decoder's output, typically of shape `(batch, seq_len, vocab)`
"""
return self.decoder(src_tokens)
def max_positions(self):
...
@@ -24,6 +24,23 @@ from . import (
@register_model('fconv')
class FConvModel(FairseqModel):
"""
A fully convolutional model, i.e. a convolutional encoder and a
convolutional decoder, as described in `"Convolutional Sequence to Sequence
Learning" (Gehring et al., 2017) <https://arxiv.org/abs/1705.03122>`_.
Args:
encoder (FConvEncoder): the encoder
decoder (FConvDecoder): the decoder
The Convolutional model provides the following named architectures and
command-line arguments:
.. argparse::
:ref: fairseq.models.fconv_parser
:prog:
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
@@ -145,7 +162,26 @@ class FConvLanguageModel(FairseqLanguageModel):
class FConvEncoder(FairseqEncoder):
"""
Convolutional encoder consisting of `len(convolutions)` layers.
Args:
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_dim (int, optional): embedding dimension
embed_dict (str, optional): filename from which to load pre-trained
embeddings
max_positions (int, optional): maximum supported input sequence length
convolutions (list, optional): the convolutional layer structure. Each
list item `i` corresponds to convolutional layer `i`. Layers are
given as ``(out_channels, kernel_width, [residual])``. Residual
connections are added between layers when ``residual=1`` (which is
the default behavior).
dropout (float, optional): dropout to be applied before each conv layer
normalization_constant (float, optional): multiplies the result of the
residual block by sqrt(value)
left_pad (bool, optional): whether the input is left-padded. Default:
``True``
"""
def __init__(
self, dictionary, embed_dim=512, embed_dict=None, max_positions=1024,
@@ -198,6 +234,23 @@ class FConvEncoder(FairseqEncoder):
self.fc2 = Linear(in_channels, embed_dim)
def forward(self, src_tokens, src_lengths):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (LongTensor): lengths of each source sentence of shape
`(batch)`
Returns:
dict:
- **encoder_out** (tuple): a tuple with two elements, where the
first element is the last encoder layer's output and the
second element is the same quantity summed with the input
embedding (used for attention). The shape of both tensors is
`(batch, src_len, embed_dim)`.
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
"""
# embed tokens and positions
x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
...
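The `convolutions` argument documented in the FConvEncoder docstring above is a list of per-layer specs. A hedged example of the expected format (the layer sizes are arbitrary):

# 4 layers of 256 channels with kernel width 3, followed by 2 layers of
# 512 channels with kernel width 5 whose residual connections are disabled
convolutions = [(256, 3)] * 4 + [(512, 5, 0)] * 2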
@@ -27,6 +27,22 @@ from . import (
@register_model('transformer')
class TransformerModel(FairseqModel):
"""
Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
<https://arxiv.org/abs/1706.03762>`_.
Args:
encoder (TransformerEncoder): the encoder
decoder (TransformerDecoder): the decoder
The Transformer model provides the following named architectures and
command-line arguments:
.. argparse::
:ref: fairseq.models.transformer_parser
:prog:
"""
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@@ -202,7 +218,17 @@ class TransformerLanguageModel(FairseqLanguageModel):
class TransformerEncoder(FairseqEncoder):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
left_pad (bool, optional): whether the input is left-padded. Default:
``True``
"""
def __init__(self, args, dictionary, embed_tokens, left_pad=True):
super().__init__(dictionary)
@@ -231,6 +257,20 @@ class TransformerEncoder(FairseqEncoder):
self.layer_norm = LayerNorm(embed_dim)
def forward(self, src_tokens, src_lengths):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
"""
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
@@ -258,6 +298,16 @@ class TransformerEncoder(FairseqEncoder):
}
def reorder_encoder_out(self, encoder_out, new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if encoder_out['encoder_out'] is not None:
encoder_out['encoder_out'] = \
encoder_out['encoder_out'].index_select(1, new_order)
@@ -273,6 +323,7 @@ class TransformerEncoder(FairseqEncoder):
return min(self.max_source_positions, self.embed_positions.max_positions())
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
if 'encoder.embed_positions.weights' in state_dict:
del state_dict['encoder.embed_positions.weights']
@@ -286,7 +337,19 @@ class TransformerEncoder(FairseqEncoder):
class TransformerDecoder(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
left_pad (bool, optional): whether the input is left-padded. Default:
``False``
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
super().__init__(dictionary)
@@ -338,6 +401,22 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.layer_norm = LayerNorm(embed_dim)
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for input feeding/teacher forcing
encoder_out (Tensor, optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
Returns:
tuple:
- the last decoder layer's output of shape `(batch, tgt_len,
vocab)`
- the last decoder layer's attention weights of shape `(batch,
tgt_len, src_len)`
"""
# embed positions
positions = self.embed_positions(
prev_output_tokens,
@@ -397,6 +476,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
return min(self.max_target_positions, self.embed_positions.max_positions())
def upgrade_state_dict(self, state_dict):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
if 'decoder.embed_positions.weights' in state_dict:
del state_dict['decoder.embed_positions.weights']
@@ -429,12 +509,15 @@ class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def __init__(self, args):
@@ -452,6 +535,15 @@ class TransformerEncoderLayer(nn.Module):
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
def forward(self, x, encoder_padding_mask):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
Returns:
encoded output of shape `(batch, src_len, embed_dim)`
"""
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
@@ -478,7 +570,21 @@ class TransformerEncoderLayer(nn.Module):
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs.
Default: ``False``
"""
def __init__(self, args, no_encoder_attn=False):
super().__init__()
@@ -510,6 +616,15 @@ class TransformerDecoderLayer(nn.Module):
self.need_attn = True
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor): binary ByteTensor of shape
`(batch, src_len)` where padding elements are indicated by ``1``.
Returns:
encoded output of shape `(batch, src_len, embed_dim)`
"""
residual = x
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
x, _ = self.self_attn(
...
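A hedged, framework-free sketch of the two residual orderings discussed in the encoder/decoder layer docstrings above; here `sublayer` stands for self-attention, encoder attention, or the FFN, and `layer_norm` and `dropout` are callables supplied by the caller.

def post_norm_block(x, sublayer, layer_norm, dropout):
    # paper default: dropout -> add residual -> layernorm
    return layer_norm(x + dropout(sublayer(x)))

def pre_norm_block(x, sublayer, layer_norm, dropout):
    # tensor2tensor variant (normalize_before): layernorm first, then dropout -> add residual
    return x + dropout(sublayer(layer_norm(x)))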
@@ -13,8 +13,8 @@ from torch import nn
class Highway(torch.nn.Module):
"""
A `Highway layer <https://arxiv.org/abs/1505.00387>`_.
Adopted from the AllenNLP implementation.
"""
def __init__(
...
@@ -29,13 +29,13 @@ class LinearizedConvolution(ConvTBC):
def forward(self, input, incremental_state=None):
"""
Args:
incremental_state: Used to buffer signal; if not None, then input is
expected to contain a single frame. If the input order changes
between time steps, call reorder_incremental_state.
Input:
Time x Batch x Channel during training
Batch x Time x Channel during inference
"""
if incremental_state is None:
output = super().forward(input)
...
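A small sketch of the layout convention noted above (toy tensor, arbitrary sizes):

import torch

x_btc = torch.randn(2, 7, 512)               # inference layout: Batch x Time x Channel
x_tbc = x_btc.transpose(0, 1).contiguous()   # training layout:  Time x Batch x Channel
assert x_tbc.shape == (7, 2, 512)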
@@ -35,6 +35,10 @@ def get_generation_parser(interactive=False, default_task='translation'):
return parser
def get_interactive_generation_parser(default_task='translation'):
return get_generation_parser(interactive=True, default_task=default_task)
def get_eval_lm_parser(default_task='language_modeling'):
parser = get_parser('Evaluate Language Model', default_task)
add_dataset_args(parser, gen=True)
@@ -115,8 +119,7 @@ def parse_args_and_arch(parser, input_args=None, parse_known=False):
def get_parser(desc, default_task='translation'):
parser = argparse.ArgumentParser()
parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
help='log progress every N batches (when progress bar is disabled)')
@@ -128,8 +131,9 @@ def get_parser(desc, default_task='translation'):
# Task definitions can be found under fairseq/tasks/
parser.add_argument(
'--task', metavar='TASK', default=default_task,
choices=TASK_REGISTRY.keys(),
help='task',
)
return parser
@@ -199,7 +203,7 @@ def add_optimization_args(parser):
# Optimizer definitions can be found under fairseq/optim/
group.add_argument('--optimizer', default='nag', metavar='OPT',
choices=OPTIMIZER_REGISTRY.keys(),
help='Optimizer')
group.add_argument('--lr', '--learning-rate', default='0.25', metavar='LR_1,LR_2,...,LR_N',
help='learning rate for the first N epochs; all epochs >N using LR_N'
' (note: this may be interpreted differently depending on --lr-scheduler)')
@@ -210,8 +214,8 @@ def add_optimization_args(parser):
# Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
choices=LR_SCHEDULER_REGISTRY.keys(),
help='Learning Rate Scheduler')
group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
@@ -337,16 +341,14 @@ def add_model_args(parser):
group.add_argument(
'--arch', '-a', default='fconv', metavar='ARCH', required=True,
choices=ARCH_MODEL_REGISTRY.keys(),
help='Model Architecture',
)
# Criterion definitions can be found under fairseq/criterions/
group.add_argument(
'--criterion', default='cross_entropy', metavar='CRIT',
choices=CRITERION_REGISTRY.keys(),
help='Training Criterion',
)
return group
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import argparse
import importlib
import os
@@ -20,7 +20,26 @@ def setup_task(args):
def register_task(name):
"""
New tasks can be added to fairseq with the
:func:`~fairseq.tasks.register_task` function decorator.
For example::
@register_task('classification')
class ClassificationTask(FairseqTask):
(...)
.. note::
All Tasks must implement the :class:`~fairseq.tasks.FairseqTask`
interface.
Args:
name (str): the name of the task
"""
def register_task_cls(cls):
if name in TASK_REGISTRY:
@@ -39,5 +58,17 @@ def register_task(name):
# automatically import any Python files in the tasks/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
task_name = file[:file.find('.py')]
importlib.import_module('fairseq.tasks.' + task_name)
# expose `task_parser` for sphinx
if task_name in TASK_REGISTRY:
parser = argparse.ArgumentParser(add_help=False)
group_task = parser.add_argument_group('Task name')
group_task.add_argument(
'--task', metavar=task_name,
help='Enable this task with: ``--task=' + task_name + '``'
)
group_args = parser.add_argument_group('Additional command-line arguments')
TASK_REGISTRY[task_name].add_args(group_args)
globals()[task_name + '_parser'] = parser
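A hedged sketch of registering a new task with the decorator documented above; the task name, class, and method bodies are invented for illustration, only the interface comes from the code above.

from fairseq.tasks import FairseqTask, register_task

@register_task('toy_classification')
class ToyClassificationTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        # task-specific command-line arguments (surfaced by the sphinx parsers above)
        parser.add_argument('data', help='path to data directory')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # e.g. load dictionaries here before constructing the task
        return cls(args)

    def load_dataset(self, split, combine=False):
        # populate self.datasets[split] with a FairseqDataset
        raise NotImplementedError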
@@ -25,13 +25,31 @@ class FairseqTask(object):
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
return cls(args)
def load_dataset(self, split, combine=False):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
raise NotImplementedError
def dataset(self, split):
"""
Return a loaded dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
Returns:
a :class:`~fairseq.data.FairseqDataset` corresponding to *split*
"""
from fairseq.data import FairseqDataset
if split not in self.datasets:
raise KeyError('Dataset not loaded: ' + split)
@@ -48,7 +66,7 @@ class FairseqTask(object):
Get an iterator that yields batches of data from the given dataset.
Args:
dataset (~fairseq.data.FairseqDataset): dataset to batch
max_tokens (int, optional): max number of tokens in each batch.
Default: ``None``
max_sentences (int, optional): max number of sentences in each
@@ -67,7 +85,8 @@ class FairseqTask(object):
return. Default: ``0``
Returns:
~fairseq.iterators.EpochBatchIterator: a batched iterator over the
given dataset split
"""
assert isinstance(dataset, FairseqDataset)
@@ -97,23 +116,58 @@ class FairseqTask(object):
)
def build_model(self, args):
"""
Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
task.
Args:
args (argparse.Namespace): parsed command-line arguments
Returns:
a :class:`~fairseq.models.BaseFairseqModel` instance
"""
from fairseq import models
return models.build_model(args, self)
def build_criterion(self, args):
"""
Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
this task.
Args:
args (argparse.Namespace): parsed command-line arguments
Returns:
a :class:`~fairseq.criterions.FairseqCriterion` instance
"""
from fairseq import criterions
return criterions.build_criterion(args, self)
def get_loss(self, model, criterion, sample):
"""
Return the loss as computed by *criterion* for the given *model* and
*sample*.
Args:
model (~fairseq.models.BaseFairseqModel): the model
criterion (~fairseq.criterions.FairseqCriterion): the criterion
sample (dict): the mini-batch. The format is defined by the
:class:`~fairseq.data.FairseqDataset`.
"""
return criterion(model, sample)
def max_positions(self):
"""Return the max input length allowed by the task."""
return None
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
raise NotImplementedError
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary` (if applicable
for this task)."""
raise NotImplementedError
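Taken together, the FairseqTask methods documented above are typically chained roughly as follows. This is a hedged sketch: get_training_parser and next_epoch_itr are assumed helpers from fairseq.options and the epoch batch iterator, and checkpointing, validation, and the optimizer step are omitted.

from fairseq import options, tasks

parser = options.get_training_parser()        # assumed training CLI helper
args = options.parse_args_and_arch(parser)

task = tasks.setup_task(args)                 # e.g. TranslationTask.setup_task(args)
task.load_dataset('train')

model = task.build_model(args)
criterion = task.build_criterion(args)

itr = task.get_batch_iterator(
    dataset=task.dataset('train'),
    max_tokens=args.max_tokens,
)
for sample in itr.next_epoch_itr():           # assumed EpochBatchIterator API
    # by convention the criterion returns (loss, sample_size, logging_output)
    loss, sample_size, logging_output = task.get_loss(model, criterion, sample)
    loss.backward()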
@@ -21,18 +21,37 @@ from . import FairseqTask, register_task
@register_task('language_modeling')
class LanguageModelingTask(FairseqTask):
"""
Train a language model.
Args:
dictionary (Dictionary): the dictionary for the language model
.. note::
The language modeling task is compatible with :mod:`train.py <train>`,
:mod:`generate.py <generate>`, :mod:`interactive.py <interactive>` and
:mod:`eval_lm.py <eval_lm>`.
The language modeling task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.language_modeling_parser
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', help='path to data directory')
parser.add_argument('--sample-break-mode',
choices=['none', 'complete', 'eos'],
help='If omitted or "none", fills each sample with tokens-per-sample '
'tokens. If set to "complete", splits samples only at the end '
'of sentence, but may include multiple sentences per sample. '
'If set to "eos", includes only one sentence per sample.')
parser.add_argument('--tokens-per-sample', default=1024, type=int,
help='max number of tokens per sample for LM dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
@@ -43,12 +62,21 @@ class LanguageModelingTask(FairseqTask):
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split, combine=False):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
loaded_datasets = []
@@ -90,4 +118,6 @@ class LanguageModelingTask(FairseqTask):
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
model."""
return self.dictionary
@@ -22,11 +22,30 @@ from . import FairseqTask, register_task
@register_task('translation')
class TranslationTask(FairseqTask):
"""
Translate from one (source) language to another (target) language.
Args:
src_dict (Dictionary): dictionary for the source language
tgt_dict (Dictionary): dictionary for the target language
.. note::
The translation task is compatible with :mod:`train.py <train>`,
:mod:`generate.py <generate>` and :mod:`interactive.py <interactive>`.
The translation task provides the following additional command-line
arguments:
.. argparse::
:ref: fairseq.tasks.translation_parser
:prog:
"""
@staticmethod
def add_args(parser):
"""Add task-specific arguments to the parser."""
parser.add_argument('data', help='path to data directory')
parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
help='source language')
parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
@@ -34,9 +53,9 @@ class TranslationTask(FairseqTask):
parser.add_argument('--raw-text', action='store_true',
help='load raw text dataset')
parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
help='pad the source on the left')
parser.add_argument('--left-pad-target', default='False', type=str, metavar='BOOL',
help='pad the target on the left')
parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
help='max number of tokens in the source sequence')
parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
@@ -51,6 +70,11 @@ class TranslationTask(FairseqTask):
@classmethod
def setup_task(cls, args, **kwargs):
"""Setup the task (e.g., load dictionaries).
Args:
args (argparse.Namespace): parsed command-line arguments
"""
args.left_pad_source = options.eval_bool(args.left_pad_source)
args.left_pad_target = options.eval_bool(args.left_pad_target)
@@ -72,7 +96,11 @@ class TranslationTask(FairseqTask):
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split, combine=False):
"""Load a given dataset split.
Args:
split (str): name of the split (e.g., train, valid, test)
"""
def split_exists(split, src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
@@ -140,12 +168,15 @@ class TranslationTask(FairseqTask):
)
def max_positions(self):
"""Return the max sentence length allowed by the task."""
return (self.args.max_source_positions, self.args.max_target_positions)
@property
def source_dictionary(self):
"""Return the source :class:`~fairseq.data.Dictionary`."""
return self.src_dict
@property
def target_dictionary(self):
"""Return the target :class:`~fairseq.data.Dictionary`."""
return self.tgt_dict
@@ -5,6 +5,9 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Translate pre-processed data with a trained model.
"""
import torch
...