"vscode:/vscode.git/clone" did not exist on "a2908d88d093a05bcdebe9011ba69b6dd48156ba"
Commit 8ca05802 authored by Myle Ott, committed by Facebook GitHub Bot

Replace --decoder-final-norm with --no-decoder-final-norm

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/620

Differential Revision: D15569440

Pulled By: myleott

fbshipit-source-id: c4681f1c72467c04cd2654e87bc724c94b76e3fb
parent 38e82904
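
In short, the opt-in --decoder-final-norm flag is replaced by an opt-out --no-decoder-final-norm flag: the final decoder layer norm is now applied whenever --decoder-normalize-before is set, unless it is explicitly disabled. A minimal standalone sketch of the new semantics (plain argparse, no fairseq imports; the names mirror the diff below):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--decoder-normalize-before', action='store_true')
    # New opt-out flag; the old opt-in '--decoder-final-norm' is gone.
    parser.add_argument('--no-decoder-final-norm', action='store_true')
    args = parser.parse_args(['--decoder-normalize-before'])

    # Mirrors the condition used in TransformerDecoder.__init__ below.
    apply_final_norm = (
        args.decoder_normalize_before
        and not getattr(args, 'no_decoder_final_norm', False)
    )
    print(apply_final_norm)  # True; add --no-decoder-final-norm to disable it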
fairseq/models/transformer.py

@@ -74,8 +74,6 @@ class TransformerModel(FairseqEncoderDecoderModel):
                             help='num encoder attention heads')
         parser.add_argument('--encoder-normalize-before', action='store_true',
                             help='apply layernorm before each encoder block')
-        parser.add_argument('--decoder-final-norm', default=False, action='store_true',
-                            help='apply layernorm before each decoder block')
         parser.add_argument('--encoder-learned-pos', action='store_true',
                             help='use learned positional embeddings in the encoder')
         parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
@@ -170,6 +168,8 @@ class TransformerEncoder(FairseqEncoder):
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(dictionary)
+        self.register_buffer('version', torch.Tensor([3]))
+
         self.dropout = args.dropout
 
         embed_dim = embed_tokens.embedding_dim
@@ -188,10 +188,11 @@ class TransformerEncoder(FairseqEncoder):
             TransformerEncoderLayer(args)
             for i in range(args.encoder_layers)
         ])
-        self.register_buffer('version', torch.Tensor([2]))
 
-        self.normalize = args.encoder_normalize_before
-        if self.normalize:
+        if args.encoder_normalize_before:
             self.layer_norm = LayerNorm(embed_dim)
+        else:
+            self.layer_norm = None
 
     def forward(self, src_tokens, src_lengths):
         """
@@ -226,7 +227,7 @@ class TransformerEncoder(FairseqEncoder):
         for layer in self.layers:
             x = layer(x, encoder_padding_mask)
 
-        if self.normalize:
+        if self.layer_norm:
             x = self.layer_norm(x)
 
         return {
@@ -290,12 +291,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         embed_tokens (torch.nn.Embedding): output embedding
         no_encoder_attn (bool, optional): whether to attend to encoder outputs
             (default: False).
-        final_norm (bool, optional): apply layer norm to the output of the
-            final decoder layer (default: True).
     """

-    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True):
+    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
         super().__init__(dictionary)
+        self.register_buffer('version', torch.Tensor([3]))
+
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed
@@ -340,10 +341,11 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)
-        self.register_buffer('version', torch.Tensor([2]))
 
-        self.normalize = args.decoder_normalize_before and final_norm
-        if self.normalize:
+        if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False):
             self.layer_norm = LayerNorm(embed_dim)
+        else:
+            self.layer_norm = None
 
     def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused):
         """
@@ -411,7 +413,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             )
             inner_states.append(x)
 
-        if self.normalize:
+        if self.layer_norm:
             x = self.layer_norm(x)
 
         # T x B x C -> B x T x C
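
The boolean self.normalize flag is gone; encoder and decoder now keep self.layer_norm as either a module or None and branch on that in forward(). The pattern in isolation (a sketch with torch.nn.LayerNorm standing in for fairseq's LayerNorm helper):

    import torch
    import torch.nn as nn

    class FinalNormSketch(nn.Module):
        # Keep the LayerNorm module (or None) instead of a separate boolean flag.
        def __init__(self, embed_dim, normalize_before):
            super().__init__()
            self.layer_norm = nn.LayerNorm(embed_dim) if normalize_before else None

        def forward(self, x):
            if self.layer_norm:
                x = self.layer_norm(x)
            return x

    x = torch.randn(2, 5, 8)
    print(FinalNormSketch(8, normalize_before=True)(x).shape)  # torch.Size([2, 5, 8])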
fairseq/models/transformer_from_pretrained_xlm.py

@@ -21,6 +21,7 @@ from fairseq.models.transformer import (
 @register_model("transformer_from_pretrained_xlm")
 class TransformerFromPretrainedXLMModel(TransformerModel):
+
     @staticmethod
     def add_args(parser):
         """Add model-specific arguments to the parser."""
@@ -110,6 +111,7 @@ def upgrade_state_dict_with_xlm_weights(
 class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
     def __init__(self, args, dictionary, embed_tokens):
         super().__init__(args, dictionary, embed_tokens)
+
         if getattr(args, 'init_decoder_only', False):
@@ -128,12 +130,9 @@ class TransformerEncoderFromPretrainedXLM(TransformerEncoder):
 class TransformerDecoderFromPretrainedXLM(TransformerDecoder):
-    def __init__(
-        self, args, dictionary, embed_tokens, no_encoder_attn=False, final_norm=True
-    ):
-        super().__init__(
-            args, dictionary, embed_tokens, no_encoder_attn, final_norm
-        )
+    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
+        super().__init__(args, dictionary, embed_tokens, no_encoder_attn)
+
         if getattr(args, 'init_encoder_only', False):
             # Don't load XLM weights for decoder if --init-encoder-only
             return
fairseq/models/transformer_lm.py

@@ -40,8 +40,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='dropout probability')
         parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
                             help='dropout probability for attention weights')
-        parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
-                            help='dropout probability after ReLU in FFN')
+        parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D',
+                            help='dropout probability after activation in FFN.')
         parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension')
         parser.add_argument('--decoder-output-dim', type=int, metavar='N',
@@ -56,8 +56,8 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='num decoder attention heads')
         parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
                             help='apply layernorm before each decoder block')
-        parser.add_argument('--decoder-final-norm', default=False, action='store_true',
-                            help='apply layernorm before each decoder block')
+        parser.add_argument('--no-decoder-final-norm', action='store_true',
+                            help='don\'t add an extra layernorm after the last decoder block')
         parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                             help='comma separated list of adaptive softmax cutoff points. '
                                  'Must be used with adaptive_loss criterion')
@@ -67,7 +67,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
                             help='adaptive input factor')
         parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
                             help='if set, disables positional embeddings (outside self attention)')
-        parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
+        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                             help='share decoder input and output embeddings')
         parser.add_argument('--character-embeddings', default=False, action='store_true',
                             help='if set, uses character embedding convolutions to produce token embeddings')
@@ -132,7 +132,6 @@ class TransformerLanguageModel(FairseqLanguageModel):
         decoder = TransformerDecoder(
             args, task.target_dictionary, embed_tokens, no_encoder_attn=True,
-            final_norm=args.decoder_final_norm,
         )
         return TransformerLanguageModel(decoder)
@@ -150,12 +149,13 @@ def base_lm_architecture(args):
     args.activation_fn = getattr(args, 'activation_fn', 'relu')
 
     args.add_bos_token = getattr(args, 'add_bos_token', False)
+    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
     args.character_embeddings = getattr(args, 'character_embeddings', False)
 
     args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
     args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
 
-    # The model training is not stable without this
+    # Model training is not stable without this
     args.decoder_normalize_before = True
     args.adaptive_input = getattr(args, 'adaptive_input', False)
@@ -176,7 +176,8 @@ def transformer_lm_big(args):
 @register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
-def transformer_lm_wiki103(args):
+@register_model_architecture('transformer_lm', 'transformer_lm_baevski_wiki103')
+def transformer_lm_baevski_wiki103(args):
     args.decoder_layers = getattr(args, 'decoder_layers', 16)
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
     args.dropout = getattr(args, 'dropout', 0.3)
@@ -187,14 +188,18 @@ def transformer_lm_wiki103(args):
     args.adaptive_softmax_dropout = getattr(args, 'adaptive_softmax_dropout', 0.2)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
     args.activation_dropout = getattr(args, 'activation_dropout', 0.1)
+    args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', True)
+    args.tie_adaptive_proj = getattr(args, 'tie_adaptive_proj', True)
     transformer_lm_big(args)
 
 
 @register_model_architecture('transformer_lm', 'transformer_lm_gbw')
-def transformer_lm_gbw(args):
+@register_model_architecture('transformer_lm', 'transformer_lm_baevski_gbw')
+def transformer_lm_baevski_gbw(args):
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
+    args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', True)
     transformer_lm_big(args)
@@ -206,7 +211,6 @@ def transformer_lm_gpt(args):
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 12)
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
     args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
@@ -219,7 +223,6 @@ def transformer_lm_gpt2_small(args):
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
     args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
@@ -232,7 +235,6 @@ def transformer_lm_gpt2_medium(args):
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 20)
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
     args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
@@ -245,6 +247,5 @@ def transformer_lm_gpt2_big(args):
     args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 25)
     args.dropout = getattr(args, 'dropout', 0.1)
     args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
-    args.decoder_final_norm = getattr(args, 'decoder_final_norm', True)
     args.activation_fn = getattr(args, 'activation_fn', 'gelu')
     base_lm_architecture(args)
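
The architecture presets only set no_decoder_final_norm when the attribute is absent, so an explicit command-line value still wins. Assuming flags left unset do not appear on the namespace (as with argparse.SUPPRESS), the precedence looks like this (hypothetical stand-in for a preset function):

    from argparse import Namespace

    def preset_sketch(args):
        # Default to True only if the user did not already set the attribute.
        args.no_decoder_final_norm = getattr(args, 'no_decoder_final_norm', True)
        return args

    print(preset_sketch(Namespace()).no_decoder_final_norm)                          # True
    print(preset_sketch(Namespace(no_decoder_final_norm=False)).no_decoder_final_norm)  # False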
fairseq/tasks/language_modeling.py

@@ -81,6 +81,8 @@ class LanguageModelingTask(FairseqTask):
                             help='include past target')
         parser.add_argument('--add-bos-token', action='store_true',
                             help='prepend beginning of sentence token (<s>)')
+        parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
+                            help='max number of tokens in the target sequence')
         # fmt: on
 
     def __init__(self, args, dictionary, output_dictionary=None, targets=None):