Commit 5dcc855a authored by Myle Ott, committed by Facebook Github Bot

Add missing options to TransformerDecoderLayer

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/560

Differential Revision: D15260838

Pulled By: myleott

fbshipit-source-id: 5f80dd82775c10ce46a3e1c451ccaf0ef55bfa31
parent 8a2e6e81
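The decoder-layer changes below thread two attention options, add_bias_kv and add_zero_attn, through to MultiheadAttention. As background, here is a minimal sketch of what those two flags do, shown on torch.nn.MultiheadAttention rather than on fairseq's own attention class (illustration only; shapes and hyperparameters are arbitrary):

import torch
import torch.nn as nn

# add_bias_kv appends a learned bias vector to the key/value sequences;
# add_zero_attn appends an extra all-zero attention slot.
attn = nn.MultiheadAttention(
    embed_dim=16,
    num_heads=4,
    dropout=0.1,
    add_bias_kv=True,
    add_zero_attn=True,
)

x = torch.randn(5, 2, 16)  # (seq_len, batch, embed_dim)
out, attn_weights = attn(query=x, key=x, value=x)
print(out.shape)           # torch.Size([5, 2, 16])
print(attn_weights.shape)  # torch.Size([2, 5, 7]): 5 source positions + bias_kv slot + zero_attn slot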
@@ -18,6 +18,7 @@ from torch.serialization import default_restore_location
 from fairseq import tasks
 from fairseq.models import FairseqEncoder, FairseqDecoder
 def load_checkpoint_to_cpu(path):
     """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
     state = torch.load(
...
@@ -53,6 +53,8 @@ class TransformerModel(FairseqModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # fmt: off
+        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
+                            help='Which activation function to use')
         parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
         parser.add_argument('--attention-dropout', type=float, metavar='D',
@@ -99,8 +101,6 @@ class TransformerModel(FairseqModel):
                                  'Must be used with adaptive_loss criterion'),
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
-        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
-                            help='Which activation function to use')
         # fmt: on
     @classmethod
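The --activation-fn flag that moves up in the two hunks above is an ordinary argparse choice option; a self-contained sketch of parsing it outside fairseq's CLI wiring (the default value here is added only so the sketch runs on its own; fairseq sets its defaults elsewhere):

import argparse

parser = argparse.ArgumentParser()
# Same flag and choices as the add_argument call that moves above.
parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
                    default='relu', help='Which activation function to use')

args = parser.parse_args(['--activation-fn', 'gelu'])
print(args.activation_fn)  # 'gelu' (argparse maps --activation-fn to args.activation_fn)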
@@ -566,12 +566,15 @@ class TransformerDecoderLayer(nn.Module):
             (default: False).
     """
-    def __init__(self, args, no_encoder_attn=False):
+    def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
         self.self_attn = MultiheadAttention(
-            self.embed_dim, args.decoder_attention_heads,
+            embed_dim=self.embed_dim,
+            num_heads=args.decoder_attention_heads,
             dropout=args.attention_dropout,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
         )
         self.dropout = args.dropout
         self.activation_fn = utils.get_activation_fn(
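The last context line of this hunk routes args.activation_fn through utils.get_activation_fn; below is a rough stand-in for that lookup, limited to the three choices exposed by --activation-fn (not fairseq's actual implementation; in particular, 'gelu_fast' here just reuses F.gelu instead of a separate approximation):

import torch
import torch.nn.functional as F

def lookup_activation_fn(name):
    # Illustrative stand-in for fairseq's utils.get_activation_fn.
    return {'relu': F.relu, 'gelu': F.gelu, 'gelu_fast': F.gelu}[name]

activation_fn = lookup_activation_fn('gelu')
print(activation_fn(torch.randn(4)))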
@@ -606,9 +609,17 @@ class TransformerDecoderLayer(nn.Module):
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
-                prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
-                self_attn_padding_mask=None):
+    def forward(
+        self,
+        x,
+        encoder_out=None,
+        encoder_padding_mask=None,
+        incremental_state=None,
+        prev_self_attn_state=None,
+        prev_attn_state=None,
+        self_attn_mask=None,
+        self_attn_padding_mask=None,
+    ):
         """
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -626,7 +637,7 @@ class TransformerDecoderLayer(nn.Module):
             prev_key, prev_value = prev_self_attn_state
             saved_state = {"prev_key": prev_key, "prev_value": prev_value}
             self.self_attn._set_input_buffer(incremental_state, saved_state)
-        x, _ = self.self_attn(
+        x, attn = self.self_attn(
             query=x,
             key=x,
             value=x,
@@ -639,7 +650,6 @@ class TransformerDecoderLayer(nn.Module):
         x = residual + x
         x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
-        attn = None
         if self.encoder_attn is not None:
             residual = x
             x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
@@ -670,7 +680,7 @@ class TransformerDecoderLayer(nn.Module):
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
-        if self.onnx_trace:
+        if self.onnx_trace and incremental_state is not None:
             saved_state = self.self_attn._get_input_buffer(incremental_state)
             self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
             return x, attn, self_attn_state
...