Commit dffb1674 authored by Myle Ott, committed by Facebook Github Bot

Updates to model API (#561)

Summary:
- `FairseqModel` -> `FairseqEncoderDecoderModel`
- add `FairseqDecoder.extract_features` and `FairseqDecoder.output_layer`
- `encoder_out_dict` -> `encoder_out`
- rm unused `remove_head` functions
- update docs
Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/561

Differential Revision: D15271142

Pulled By: myleott

fbshipit-source-id: 8e8864e399336020f0271c780598e968ff51a264
parent a0c5f9b8
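The change is easier to read with a concrete picture of the updated API in mind. Below is a minimal, self-contained sketch of downstream code written against it; every name in it (ToyEncoder, ToyDecoder, the 'toy_model' registration) is a hypothetical placeholder rather than anything added by this commit, and register_model_architecture boilerplate is omitted. The decoder implements only the two new hooks, extract_features and output_layer, and inherits the default forward that composes them; the model subclasses the renamed FairseqEncoderDecoderModel.

import torch.nn as nn

from fairseq.models import (
    FairseqDecoder,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    register_model,
)


class ToyEncoder(FairseqEncoder):
    """Hypothetical encoder: embeds tokens and mean-pools over the source."""

    def __init__(self, dictionary, embed_dim=64):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())

    def forward(self, src_tokens, src_lengths=None, **kwargs):
        # src_lengths is now an optional keyword; the encoder returns a dict.
        return {'encoder_out': self.embed(src_tokens).mean(dim=1)}  # (batch, embed_dim)

    def reorder_encoder_out(self, encoder_out, new_order):
        encoder_out['encoder_out'] = encoder_out['encoder_out'].index_select(0, new_order)
        return encoder_out


class ToyDecoder(FairseqDecoder):
    """Hypothetical decoder: forward() is inherited and now calls
    extract_features() followed by output_layer()."""

    def __init__(self, dictionary, embed_dim=64):
        super().__init__(dictionary)
        self.embed = nn.Embedding(len(dictionary), embed_dim, padding_idx=dictionary.pad())
        self.proj = nn.Linear(embed_dim, len(dictionary))

    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
        # Return features of shape (batch, tgt_len, embed_dim) plus an extras dict.
        x = self.embed(prev_output_tokens)
        if encoder_out is not None:
            x = x + encoder_out['encoder_out'].unsqueeze(1)
        return x, {'attn': None}

    def output_layer(self, features, **kwargs):
        # Project features to vocabulary size.
        return self.proj(features)


@register_model('toy_model')  # hypothetical name, for illustration only
class ToyModel(FairseqEncoderDecoderModel):  # previously: FairseqModel

    @classmethod
    def build_model(cls, args, task):
        return cls(
            ToyEncoder(task.source_dictionary),
            ToyDecoder(task.target_dictionary),
        )

The old FairseqModel name is kept as a thin subclass that only emits a deprecation warning (see the class FairseqModel(FairseqEncoderDecoderModel) hunk below), so existing user code keeps running.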
@@ -74,12 +74,18 @@ Adding new models
 .. autoclass:: fairseq.models.BaseFairseqModel
     :members:
     :undoc-members:
-.. autoclass:: fairseq.models.FairseqModel
+.. autoclass:: fairseq.models.FairseqEncoderDecoderModel
+    :members:
+    :undoc-members:
+.. autoclass:: fairseq.models.FairseqEncoderModel
     :members:
     :undoc-members:
 .. autoclass:: fairseq.models.FairseqLanguageModel
     :members:
     :undoc-members:
+.. autoclass:: fairseq.models.FairseqMultiModel
+    :members:
+    :undoc-members:
 .. autoclass:: fairseq.models.FairseqEncoder
     :members:
 .. autoclass:: fairseq.models.CompositeEncoder
......
@@ -2,7 +2,7 @@ Modules
 =======
 
 Fairseq provides several stand-alone :class:`torch.nn.Module` classes that may
-be helpful when implementing a new :class:`~fairseq.models.FairseqModel`.
+be helpful when implementing a new :class:`~fairseq.models.BaseFairseqModel`.
 
 .. automodule:: fairseq.modules
     :members:
......
@@ -41,7 +41,7 @@ New plug-ins are *registered* through a set of ``@register`` function
 decorators, for example::
 
     @register_model('my_lstm')
-    class MyLSTM(FairseqModel):
+    class MyLSTM(FairseqEncoderDecoderModel):
         (...)
 
 Once registered, new plug-ins can be used with the existing :ref:`Command-line
......
@@ -2,9 +2,9 @@ Tutorial: Simple LSTM
 =====================
 
 In this tutorial we will extend fairseq by adding a new
-:class:`~fairseq.models.FairseqModel` that encodes a source sentence with an
-LSTM and then passes the final hidden state to a second LSTM that decodes the
-target sentence (without attention).
+:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
+sentence with an LSTM and then passes the final hidden state to a second LSTM
+that decodes the target sentence (without attention).
 
 This tutorial covers:
@@ -233,18 +233,18 @@ Once the model is registered we'll be able to use it with the existing
 All registered models must implement the
 :class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
 models (i.e., any model with a single Encoder and Decoder), we can instead
-implement the :class:`~fairseq.models.FairseqModel` interface.
+implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.
 
 Create a small wrapper class in the same file and register it in fairseq with
 the name ``'simple_lstm'``::
 
-    from fairseq.models import FairseqModel, register_model
+    from fairseq.models import FairseqEncoderDecoderModel, register_model
 
     # Note: the register_model "decorator" should immediately precede the
     # definition of the Model class.
     @register_model('simple_lstm')
-    class SimpleLSTMModel(FairseqModel):
+    class SimpleLSTMModel(FairseqEncoderDecoderModel):
 
         @staticmethod
         def add_args(parser):
@@ -308,7 +308,7 @@ the name ``'simple_lstm'``::
         # We could override the ``forward()`` if we wanted more control over how
         # the encoder and decoder interact, but it's not necessary for this
         # tutorial since we can inherit the default implementation provided by
-        # the FairseqModel base class, which looks like:
+        # the FairseqEncoderDecoderModel base class, which looks like:
         #
         # def forward(self, src_tokens, src_lengths, prev_output_tokens):
         #     encoder_out = self.encoder(src_tokens, src_lengths)
......
@@ -14,10 +14,11 @@ from .fairseq_encoder import FairseqEncoder
 from .fairseq_incremental_decoder import FairseqIncrementalDecoder
 from .fairseq_model import (
     BaseFairseqModel,
+    FairseqEncoderModel,
+    FairseqEncoderDecoderModel,
+    FairseqLanguageModel,
     FairseqModel,
     FairseqMultiModel,
-    FairseqLanguageModel,
-    FairseqEncoderModel,
 )
 from .composite_encoder import CompositeEncoder
@@ -30,6 +31,7 @@ __all__ = [
     'DistributedFairseqModel',
     'FairseqDecoder',
     'FairseqEncoder',
+    'FairseqEncoderDecoderModel',
     'FairseqEncoderModel',
     'FairseqIncrementalDecoder',
     'FairseqLanguageModel',
@@ -56,12 +58,13 @@ def register_model(name):
     For example::
 
         @register_model('lstm')
-        class LSTM(FairseqModel):
+        class LSTM(FairseqEncoderDecoderModel):
             (...)
 
     .. note:: All models must implement the :class:`BaseFairseqModel` interface.
-        Typically you will extend :class:`FairseqModel` for sequence-to-sequence
-        tasks or :class:`FairseqLanguageModel` for language modeling tasks.
+        Typically you will extend :class:`FairseqEncoderDecoderModel` for
+        sequence-to-sequence tasks or :class:`FairseqLanguageModel` for
+        language modeling tasks.
 
     Args:
         name (str): the name of the model
......
@@ -18,25 +18,40 @@ class FairseqDecoder(nn.Module):
         self.dictionary = dictionary
         self.onnx_trace = False
 
-    def forward(self, prev_output_tokens, encoder_out):
+    def forward(self, prev_output_tokens, encoder_out=None, **kwargs):
         """
         Args:
-            prev_output_tokens (LongTensor): previous decoder outputs of shape
+            prev_output_tokens (LongTensor): shifted output tokens of shape
                 `(batch, tgt_len)`, for input feeding/teacher forcing
-            encoder_out (Tensor, optional): output from the encoder, used for
+            encoder_out (dict, optional): output from the encoder, used for
                 encoder-side attention
 
         Returns:
             tuple:
-                - the last decoder layer's output of shape
-                  `(batch, tgt_len, vocab)`
-                - the last decoder layer's attention weights of shape
-                  `(batch, tgt_len, src_len)`
+                - the decoder's output of shape `(batch, tgt_len, vocab)`
+                - a dictionary with any model-specific outputs
+        """
+        x, extra = self.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
+        x = self.output_layer(x)
+        return x, extra
+
+    def extract_features(self, prev_output_tokens, encoder_out=None, **kwargs):
+        """
+        Returns:
+            tuple:
+                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+                - a dictionary with any model-specific outputs
         """
         raise NotImplementedError
 
-    def prepare_for_onnx_export_(self):
-        self.onnx_trace = True
+    def output_layer(self, features, **kwargs):
+        """
+        Project features to the default output size, e.g., vocabulary size.
+
+        Args:
+            features (Tensor): features returned by *extract_features*.
+        """
+        raise NotImplementedError
 
     def get_normalized_probs(self, net_output, log_probs, sample):
         """Get normalized probabilities (or log probs) from a net's output."""
@@ -63,3 +78,6 @@ class FairseqDecoder(nn.Module):
     def upgrade_state_dict(self, state_dict):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         return state_dict
+
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
@@ -15,7 +15,7 @@ class FairseqEncoder(nn.Module):
         super().__init__()
         self.dictionary = dictionary
 
-    def forward(self, src_tokens, src_lengths):
+    def forward(self, src_tokens, src_lengths=None, **kwargs):
         """
         Args:
             src_tokens (LongTensor): tokens in the source language of shape
......
@@ -12,8 +12,8 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     """Base class for incremental decoders.
 
     Incremental decoding is a special mode at inference time where the Model
-    only receives a single timestep of input corresponding to the immediately
-    previous output token (for input feeding) and must produce the next output
+    only receives a single timestep of input corresponding to the previous
+    output token (for input feeding) and must produce the next output
     *incrementally*. Thus the model must cache any long-term state that is
     needed about the sequence, e.g., hidden states, convolutional states, etc.
@@ -33,22 +33,29 @@ class FairseqIncrementalDecoder(FairseqDecoder):
     def __init__(self, dictionary):
         super().__init__(dictionary)
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
         """
         Args:
-            prev_output_tokens (LongTensor): previous decoder outputs of shape
+            prev_output_tokens (LongTensor): shifted output tokens of shape
                 `(batch, tgt_len)`, for input feeding/teacher forcing
-            encoder_out (Tensor, optional): output from the encoder, used for
+            encoder_out (dict, optional): output from the encoder, used for
                 encoder-side attention
-            incremental_state (dict): dictionary used for storing state during
-                :ref:`Incremental decoding`
+            incremental_state (dict, optional): dictionary used for storing
+                state during :ref:`Incremental decoding`
 
         Returns:
             tuple:
-                - the last decoder layer's output of shape `(batch, tgt_len,
-                  vocab)`
-                - the last decoder layer's attention weights of shape `(batch,
-                  tgt_len, src_len)`
+                - the decoder's output of shape `(batch, tgt_len, vocab)`
+                - a dictionary with any model-specific outputs
+        """
+        raise NotImplementedError
+
+    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **kwargs):
+        """
+        Returns:
+            tuple:
+                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+                - a dictionary with any model-specific outputs
         """
         raise NotImplementedError
......
@@ -4,6 +4,9 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
+"""
+Base classes for various fairseq models.
+"""
 
 from typing import Dict, List, Optional
@@ -11,6 +14,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from fairseq import utils
 from fairseq.data import Dictionary
 from fairseq.models import FairseqDecoder, FairseqEncoder
@@ -30,7 +34,7 @@ class BaseFairseqModel(nn.Module):
     @classmethod
     def build_model(cls, args, task):
         """Build a new model instance."""
-        raise NotImplementedError('FairseqModels must implement the build_model method')
+        raise NotImplementedError('Model must implement the build_model method')
 
     def get_targets(self, sample, net_output):
         """Get targets from either the sample or the net's output."""
@@ -48,14 +52,14 @@ class BaseFairseqModel(nn.Module):
             return F.softmax(logits, dim=-1)
         raise NotImplementedError
 
+    def extract_features(self, *args, **kwargs):
+        """Similar to *forward* but only return features."""
+        return self(*args, **kwargs)
+
     def max_positions(self):
         """Maximum length supported by the model."""
         return None
 
-    def max_decoder_positions(self):
-        """Maximum length supported by the decoder."""
-        return self.decoder.max_positions()
-
     def load_state_dict(self, state_dict, strict=True):
         """Copies parameters and buffers from *state_dict* into this module and
         its descendants.
@@ -139,7 +143,7 @@ class BaseFairseqModel(nn.Module):
         self.apply(apply_prepare_for_onnx_export_)
 
 
-class FairseqModel(BaseFairseqModel):
+class FairseqEncoderDecoderModel(BaseFairseqModel):
     """Base class for encoder-decoder models.
 
     Args:
@@ -155,7 +159,7 @@ class FairseqModel(BaseFairseqModel):
         assert isinstance(self.encoder, FairseqEncoder)
         assert isinstance(self.decoder, FairseqDecoder)
 
-    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
         """
         Run the forward pass for an encoder-decoder model.
@@ -174,19 +178,54 @@ class FairseqModel(BaseFairseqModel):
                 `(batch, tgt_len)`, for input feeding/teacher forcing
 
         Returns:
-            the decoder's output, typically of shape `(batch, tgt_len, vocab)`
+            tuple:
+                - the decoder's output of shape `(batch, tgt_len, vocab)`
+                - a dictionary with any model-specific outputs
         """
-        encoder_out = self.encoder(src_tokens, src_lengths)
-        decoder_out = self.decoder(prev_output_tokens, encoder_out)
+        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+        decoder_out = self.decoder(prev_output_tokens, encoder_out=encoder_out, **kwargs)
         return decoder_out
 
+    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
+        """
+        Similar to *forward* but only return features.
+
+        Returns:
+            tuple:
+                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+                - a dictionary with any model-specific outputs
+        """
+        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
+        features = self.decoder.extract_features(prev_output_tokens, encoder_out=encoder_out, **kwargs)
+        return features
+
+    def output_layer(self, features, **kwargs):
+        """Project features to the default output size (typically vocabulary size)."""
+        return self.decoder.output_layer(features, **kwargs)
+
     def max_positions(self):
         """Maximum length supported by the model."""
         return (self.encoder.max_positions(), self.decoder.max_positions())
 
+    def max_decoder_positions(self):
+        """Maximum length supported by the decoder."""
+        return self.decoder.max_positions()
+
+
+class FairseqModel(FairseqEncoderDecoderModel):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        utils.deprecation_warning(
+            'FairseqModel is deprecated, please use FairseqEncoderDecoderModel '
+            'or BaseFairseqModel instead',
+            stacklevel=4,
+        )
+
 
 class FairseqMultiModel(BaseFairseqModel):
     """Base class for combining multiple encoder-decoder models."""
 
     def __init__(self, encoders, decoders):
         super().__init__()
         assert encoders.keys() == decoders.keys()
@@ -232,11 +271,13 @@ class FairseqMultiModel(BaseFairseqModel):
                 shared_dict, embed_dim, pretrained_embed_path
             )
 
-    def forward(self, src_tokens, src_lengths, prev_output_tokens):
+    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
         decoder_outs = {}
         for key in self.keys:
-            encoder_out = self.models[key].encoder(src_tokens, src_lengths)
-            decoder_outs[key] = self.models[key].decoder(prev_output_tokens, encoder_out)
+            encoder_out = self.models[key].encoder(src_tokens, src_lengths, **kwargs)
+            decoder_outs[key] = self.models[key].decoder(
+                prev_output_tokens, encoder_out, **kwargs,
+            )
         return decoder_outs
 
     def max_positions(self):
@@ -271,7 +312,7 @@ class FairseqLanguageModel(BaseFairseqModel):
         self.decoder = decoder
         assert isinstance(self.decoder, FairseqDecoder)
 
-    def forward(self, src_tokens, src_lengths):
+    def forward(self, src_tokens, **kwargs):
         """
         Run the forward pass for a decoder-only model.
@@ -283,22 +324,39 @@ class FairseqLanguageModel(BaseFairseqModel):
             src_lengths (LongTensor): source sentence lengths of shape `(batch)`
 
         Returns:
-            the decoder's output, typically of shape `(batch, seq_len, vocab)`
+            tuple:
+                - the decoder's output of shape `(batch, seq_len, vocab)`
+                - a dictionary with any model-specific outputs
+        """
+        return self.decoder(src_tokens, **kwargs)
+
+    def extract_features(self, src_tokens, **kwargs):
+        """
+        Similar to *forward* but only return features.
+
+        Returns:
+            tuple:
+                - the decoder's features of shape `(batch, seq_len, embed_dim)`
+                - a dictionary with any model-specific outputs
         """
-        return self.decoder(src_tokens)
+        return self.decoder.extract_features(src_tokens, **kwargs)
+
+    def output_layer(self, features, **kwargs):
+        """Project features to the default output size (typically vocabulary size)."""
+        return self.decoder.output_layer(features, **kwargs)
 
     def max_positions(self):
         """Maximum length supported by the model."""
         return self.decoder.max_positions()
 
+    def max_decoder_positions(self):
+        """Maximum length supported by the decoder."""
+        return self.decoder.max_positions()
+
     @property
     def supported_targets(self):
         return {'future'}
 
-    def remove_head(self):
-        """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
-        raise NotImplementedError()
-
 
 class FairseqEncoderModel(BaseFairseqModel):
     """Base class for encoder-only models.
@@ -316,14 +374,14 @@ class FairseqEncoderModel(BaseFairseqModel):
         """
         Run the forward pass for a encoder-only model.
 
-        Feeds a batch of tokens through the encoder to generate logits.
+        Feeds a batch of tokens through the encoder to generate features.
 
         Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
 
         Returns:
-            the encoder's output, typically of shape `(batch, seq_len, vocab)`
+            the encoder's output, typically of shape `(batch, src_len, features)`
         """
         return self.encoder(src_tokens, src_lengths, **kwargs)
@@ -341,11 +399,3 @@ class FairseqEncoderModel(BaseFairseqModel):
     def max_positions(self):
         """Maximum length supported by the model."""
         return self.encoder.max_positions()
-
-    @property
-    def supported_targets(self):
-        return {'future'}
-
-    def remove_head(self):
-        """Removes the head of the model (e.g. the softmax layer) to conserve space when it is not needed"""
-        raise NotImplementedError()
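Since the encoder-decoder and language-model base classes above now expose extract_features and output_layer directly, callers can separate feature extraction from the vocabulary projection. The helper below is a hypothetical sketch, not part of this commit; it assumes `model` is any FairseqEncoderDecoderModel and that the input tensors come from an existing batch.

import torch.nn.functional as F


def features_and_logits(model, src_tokens, src_lengths, prev_output_tokens):
    # Get decoder features without projecting to the vocabulary...
    features, extra = model.extract_features(src_tokens, src_lengths, prev_output_tokens)
    # ...then apply the projection only when token scores are actually needed.
    logits = model.output_layer(features)
    return features, F.log_softmax(logits, dim=-1), extra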
@@ -14,7 +14,7 @@ from fairseq import utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
-    FairseqModel,
+    FairseqEncoderDecoderModel,
     register_model,
     register_model_architecture,
 )
@@ -25,7 +25,7 @@ from fairseq.modules import (
 
 @register_model('fconv')
-class FConvModel(FairseqModel):
+class FConvModel(FairseqEncoderDecoderModel):
     """
     A fully convolutional model, i.e. a convolutional encoder and a
     convolutional decoder, as described in `"Convolutional Sequence to Sequence
@@ -406,10 +406,10 @@ class FConvDecoder(FairseqIncrementalDecoder):
         else:
             self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
 
-    def forward(self, prev_output_tokens, encoder_out_dict=None, incremental_state=None):
-        if encoder_out_dict is not None:
-            encoder_out = encoder_out_dict['encoder_out']
-            encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
+        if encoder_out is not None:
+            encoder_padding_mask = encoder_out['encoder_padding_mask']
+            encoder_out = encoder_out['encoder_out']
 
             # split and transpose encoder outputs
             encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)
......
@@ -16,7 +16,7 @@ from fairseq.models import (
     CompositeEncoder,
     FairseqDecoder,
     FairseqEncoder,
-    FairseqModel,
+    FairseqEncoderDecoderModel,
     register_model,
     register_model_architecture,
 )
@@ -30,7 +30,7 @@ from fairseq.modules import (
 
 @register_model('fconv_self_att')
-class FConvModelSelfAtt(FairseqModel):
+class FConvModelSelfAtt(FairseqEncoderDecoderModel):
     def __init__(self, encoder, decoder, pretrained_encoder=None):
         super().__init__(encoder, decoder)
         self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
@@ -371,9 +371,9 @@ class FConvDecoder(FairseqDecoder):
             self.pretrained_decoder.fc2.register_forward_hook(save_output())
 
-    def forward(self, prev_output_tokens, encoder_out_dict):
-        encoder_out = encoder_out_dict['encoder']['encoder_out']
-        trained_encoder_out = encoder_out_dict['pretrained'] if self.pretrained else None
+    def forward(self, prev_output_tokens, encoder_out):
+        trained_encoder_out = encoder_out['pretrained'] if self.pretrained else None
+        encoder_out = encoder_out['encoder']['encoder_out']
 
         encoder_a, encoder_b = self._split_encoder_out(encoder_out)
......
@@ -15,7 +15,7 @@ from fairseq import options, utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
-    FairseqModel,
+    FairseqEncoderDecoderModel,
     register_model,
     register_model_architecture,
 )
@@ -31,7 +31,7 @@ from fairseq.modules import (
 
 @register_model('lightconv')
-class LightConvModel(FairseqModel):
+class LightConvModel(FairseqEncoderDecoderModel):
     """
     LightConv and DynamicConv model from `"Pay Less Attention with Lightweight and Dynamic Convolutions" (Wu, et al, 2019)
     <https://openreview.net/pdf?id=SkVhlh09tX>`_.
@@ -213,13 +213,11 @@ class LightConvEncoder(FairseqEncoder):
         if self.normalize:
             self.layer_norm = LayerNorm(embed_dim)
 
-    def forward(self, src_tokens, src_lengths):
+    def forward(self, src_tokens, **unused):
         """
         Args:
             src_tokens (LongTensor): tokens in the source language of shape
                 `(batch, src_len)`
-            src_lengths (torch.LongTensor): lengths of each source sentence of
-                shape `(batch)`
 
         Returns:
             dict:
......
@@ -13,7 +13,7 @@ from fairseq import options, utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
-    FairseqModel,
+    FairseqEncoderDecoderModel,
     register_model,
     register_model_architecture,
 )
@@ -21,7 +21,7 @@ from fairseq.modules import AdaptiveSoftmax
 
 @register_model('lstm')
-class LSTMModel(FairseqModel):
+class LSTMModel(FairseqEncoderDecoderModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
@@ -356,9 +356,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         elif not self.share_input_output_embed:
             self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
-    def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
-        encoder_out = encoder_out_dict['encoder_out']
-        encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
+    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+        encoder_padding_mask = encoder_out['encoder_padding_mask']
+        encoder_out = encoder_out['encoder_out']
 
         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
......
@@ -15,7 +15,7 @@ from fairseq import options, utils
 from fairseq.models import (
     FairseqEncoder,
     FairseqIncrementalDecoder,
-    FairseqModel,
+    FairseqEncoderDecoderModel,
     register_model,
     register_model_architecture,
 )
@@ -29,7 +29,7 @@ from fairseq.modules import (
 
 @register_model('transformer')
-class TransformerModel(FairseqModel):
+class TransformerModel(FairseqEncoderDecoderModel):
     """
     Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017)
     <https://arxiv.org/abs/1706.03762>`_.
@@ -298,7 +298,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         input_embed_dim = embed_tokens.embedding_dim
         embed_dim = args.decoder_embed_dim
-        output_embed_dim = args.decoder_output_dim
+        self.output_embed_dim = args.decoder_output_dim
 
         padding_idx = embed_tokens.padding_idx
         self.max_target_positions = args.max_target_positions
@@ -321,13 +321,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
 
         self.adaptive_softmax = None
 
-        self.project_out_dim = Linear(embed_dim, output_embed_dim, bias=False) \
-            if embed_dim != output_embed_dim and not args.tie_adaptive_weights else None
+        self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \
+            if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None
 
         if args.adaptive_softmax_cutoff is not None:
             self.adaptive_softmax = AdaptiveSoftmax(
                 len(dictionary),
-                output_embed_dim,
+                self.output_embed_dim,
                 options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
                 dropout=args.adaptive_softmax_dropout,
                 adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
@@ -335,14 +335,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 tie_proj=args.tie_adaptive_proj,
             )
         elif not self.share_input_output_embed:
-            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
-            nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
+            self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim))
+            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)
         self.register_buffer('version', torch.Tensor([2]))
         self.normalize = args.decoder_normalize_before and final_norm
         if self.normalize:
             self.layer_norm = LayerNorm(embed_dim)
 
-    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
         """
         Args:
             prev_output_tokens (LongTensor): previous decoder outputs of shape
@@ -354,10 +354,21 @@ class TransformerDecoder(FairseqIncrementalDecoder):
 
         Returns:
             tuple:
-                - the last decoder layer's output of shape `(batch, tgt_len,
-                  vocab)`
-                - the last decoder layer's attention weights of shape `(batch,
-                  tgt_len, src_len)`
+                - the decoder's output of shape `(batch, tgt_len, vocab)`
+                - a dictionary with any model-specific outputs
+        """
+        x, extra = self.extract_features(prev_output_tokens, encoder_out, incremental_state)
+        x = self.output_layer(x)
+        return x, extra
+
+    def extract_features(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
+        """
+        Similar to *forward* but only return features.
+
+        Returns:
+            tuple:
+                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
+                - a dictionary with any model-specific outputs
         """
         # embed positions
         positions = self.embed_positions(
@@ -406,14 +417,18 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if self.project_out_dim is not None:
             x = self.project_out_dim(x)
 
+        return x, {'attn': attn, 'inner_states': inner_states}
+
+    def output_layer(self, features, **kwargs):
+        """Project features to the vocabulary size."""
         if self.adaptive_softmax is None:
             # project back to size of vocabulary
             if self.share_input_output_embed:
-                x = F.linear(x, self.embed_tokens.weight)
+                return F.linear(features, self.embed_tokens.weight)
             else:
-                x = F.linear(x, self.embed_out)
-
-        return x, {'attn': attn, 'inner_states': inner_states}
+                return F.linear(features, self.embed_out)
+        else:
+            return features
 
     def max_positions(self):
         """Maximum output length supported by the decoder."""
......
@@ -83,10 +83,10 @@ class LanguageModelingTask(FairseqTask):
                             help='prepend beginning of sentence token (<s>)')
         # fmt: on
 
-    def __init__(self, args, dictionary, output_dictionary, targets=None):
+    def __init__(self, args, dictionary, output_dictionary=None, targets=None):
         super().__init__(args)
         self.dictionary = dictionary
-        self.output_dictionary = output_dictionary
+        self.output_dictionary = output_dictionary or dictionary
 
         if targets is None:
             targets = ['future']
......
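One small consequence of the LanguageModelingTask change above: output_dictionary is now optional and falls back to the main dictionary. A hypothetical setup sketch follows, assuming (as the hunk suggests) that the constructor only stores its arguments; none of this code is part of the commit.

from argparse import Namespace

from fairseq.data import Dictionary
from fairseq.tasks.language_modeling import LanguageModelingTask

d = Dictionary()
for word in 'hello world'.split():
    d.add_symbol(word)

# output_dictionary defaults to the main dictionary when omitted.
task = LanguageModelingTask(Namespace(), d)
assert task.output_dictionary is d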
@@ -13,8 +13,8 @@ from fairseq.data import Dictionary
 from fairseq.data.language_pair_dataset import collate
 from fairseq.models import (
     FairseqEncoder,
+    FairseqEncoderDecoderModel,
     FairseqIncrementalDecoder,
-    FairseqModel,
 )
 from fairseq.tasks import FairseqTask
@@ -154,7 +154,7 @@ class TestTranslationTask(FairseqTask):
         return self.tgt_dict
 
 
-class TestModel(FairseqModel):
+class TestModel(FairseqEncoderDecoderModel):
     def __init__(self, encoder, decoder):
         super().__init__(encoder, decoder)
@@ -170,7 +170,7 @@ class TestEncoder(FairseqEncoder):
         super().__init__(dictionary)
         self.args = args
 
-    def forward(self, src_tokens, src_lengths):
+    def forward(self, src_tokens, src_lengths=None, **kwargs):
         return src_tokens
 
     def reorder_encoder_out(self, encoder_out, new_order):
@@ -184,7 +184,7 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
         args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
         self.args = args
 
-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
         bbsz = prev_output_tokens.size(0)
......