"tests/python/vscode:/vscode.git/clone" did not exist on "b569e4beafb7b7153d1761e1aef807b6e45c291b"
Commit 5dcc855a authored by Myle Ott, committed by Facebook GitHub Bot

Add missing options to TransformerDecoderLayer

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/560

Differential Revision: D15260838

Pulled By: myleott

fbshipit-source-id: 5f80dd82775c10ce46a3e1c451ccaf0ef55bfa31
parent 8a2e6e81
@@ -18,6 +18,7 @@ from torch.serialization import default_restore_location
 from fairseq import tasks
+from fairseq.models import FairseqEncoder, FairseqDecoder


 def load_checkpoint_to_cpu(path):
     """Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
     state = torch.load(
@@ -53,6 +53,8 @@ class TransformerModel(FairseqModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # fmt: off
+        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
+                            help='Which activation function to use')
         parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
         parser.add_argument('--attention-dropout', type=float, metavar='D',
@@ -99,8 +101,6 @@ class TransformerModel(FairseqModel):
                                  'Must be used with adaptive_loss criterion'),
         parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
                             help='sets adaptive softmax dropout for the tail projections')
-        parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
-                            help='Which activation function to use')
         # fmt: on

     @classmethod
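For reference, a minimal standalone argparse sketch (not fairseq code) of how the relocated --activation-fn option behaves; the parser object and the default='relu' below are illustrative assumptions, since in fairseq the default comes from the architecture settings rather than the parser.

import argparse

# Hypothetical standalone parser mirroring the option added above.
parser = argparse.ArgumentParser(description='activation-fn sketch')
parser.add_argument('--activation-fn', choices=['relu', 'gelu', 'gelu_fast'],
                    default='relu', help='Which activation function to use')

args = parser.parse_args(['--activation-fn', 'gelu'])
print(args.activation_fn)  # gelu
# An invalid choice, e.g. --activation-fn tanh, exits with an argparse error.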
@@ -566,12 +566,15 @@ class TransformerDecoderLayer(nn.Module):
             (default: False).
     """

-    def __init__(self, args, no_encoder_attn=False):
+    def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
         self.self_attn = MultiheadAttention(
-            self.embed_dim, args.decoder_attention_heads,
+            embed_dim=self.embed_dim,
+            num_heads=args.decoder_attention_heads,
             dropout=args.attention_dropout,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
         )
         self.dropout = args.dropout
         self.activation_fn = utils.get_activation_fn(
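The two new constructor flags are simply forwarded to MultiheadAttention. As a rough, self-contained illustration of what they do, the sketch below uses torch.nn.MultiheadAttention, which exposes the same add_bias_kv / add_zero_attn arguments; fairseq's MultiheadAttention is assumed to behave analogously. All sizes are arbitrary.

import torch
import torch.nn as nn

embed_dim, num_heads = 16, 4
seq_len, batch = 5, 2

# add_bias_kv appends a learned bias vector to the keys/values;
# add_zero_attn appends an all-zero key/value slot that queries can attend to.
mha = nn.MultiheadAttention(
    embed_dim,
    num_heads,
    dropout=0.1,
    add_bias_kv=True,
    add_zero_attn=True,
)

x = torch.randn(seq_len, batch, embed_dim)  # (seq_len, batch, embed_dim), as in fairseq
out, attn_weights = mha(query=x, key=x, value=x)
print(out.shape)           # torch.Size([5, 2, 16]); output shape is unchanged
print(attn_weights.shape)  # attended source length grows by one per enabled flag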
@@ -606,9 +609,17 @@ class TransformerDecoderLayer(nn.Module):
     def prepare_for_onnx_export_(self):
         self.onnx_trace = True

-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
-                prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
-                self_attn_padding_mask=None):
+    def forward(
+        self,
+        x,
+        encoder_out=None,
+        encoder_padding_mask=None,
+        incremental_state=None,
+        prev_self_attn_state=None,
+        prev_attn_state=None,
+        self_attn_mask=None,
+        self_attn_padding_mask=None,
+    ):
         """
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
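A hypothetical toy function (not fairseq code) illustrating why the signature change is backward compatible: the previously required encoder arguments become keyword arguments with None defaults, so existing positional call sites keep working while decoder-only callers can omit them entirely.

def forward(x, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
    # With no encoder output, fall back to self-attention-only behaviour.
    if encoder_out is None:
        return f"self-attention only on {x}"
    return f"attend over {encoder_out} from {x}"

print(forward("tokens", "enc_states", None, {}))  # old positional style still works
print(forward("tokens"))                          # new decoder-only style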
@@ -626,7 +637,7 @@ class TransformerDecoderLayer(nn.Module):
             prev_key, prev_value = prev_self_attn_state
             saved_state = {"prev_key": prev_key, "prev_value": prev_value}
             self.self_attn._set_input_buffer(incremental_state, saved_state)
-        x, _ = self.self_attn(
+        x, attn = self.self_attn(
             query=x,
             key=x,
             value=x,
@@ -639,7 +650,6 @@ class TransformerDecoderLayer(nn.Module):
         x = residual + x
         x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

-        attn = None
         if self.encoder_attn is not None:
             residual = x
             x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
@@ -670,7 +680,7 @@ class TransformerDecoderLayer(nn.Module):
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
-        if self.onnx_trace:
+        if self.onnx_trace and incremental_state is not None:
             saved_state = self.self_attn._get_input_buffer(incremental_state)
             self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
             return x, attn, self_attn_state