Commit f84e1ed4 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

Disable the final layer norm for the transformer decoder, as it degrades results

parent 81ba4c4c
......@@ -197,7 +197,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
else:
embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
return TransformerLanguageModel(decoder)
......@@ -288,7 +288,7 @@ class TransformerEncoder(FairseqEncoder):
class TransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False):
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
super().__init__(dictionary)
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
......@@ -333,7 +333,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
self.register_buffer('version', torch.Tensor([2]))
self.normalize = args.decoder_normalize_before
self.normalize = args.decoder_normalize_before and final_norm
if self.normalize:
self.layer_norm = LayerNorm(embed_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment