"vscode:/vscode.git/clone" did not exist on "7c5fef81e0aecff65c041a9dfb23aff22bf64f4b"
Commit f84e1ed4 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

disable final layer norm for transformer decoder as it makes things worse

parent 81ba4c4c
......@@ -197,7 +197,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
return TransformerLanguageModel(decoder)
......@@ -288,7 +288,7 @@ class TransformerEncoder(FairseqEncoder):
class TransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False):
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
super().__init__(dictionary)
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
......@@ -333,7 +333,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
self.register_buffer('version', torch.Tensor([2]))
self.normalize = args.decoder_normalize_before
self.normalize = args.decoder_normalize_before and final_norm
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment