Commit f84e1ed4 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

disable final layer norm for transformer decoder as it makes things worse

parent 81ba4c4c
...@@ -197,7 +197,7 @@ class TransformerLanguageModel(FairseqLanguageModel): ...@@ -197,7 +197,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
else:
    embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
return TransformerLanguageModel(decoder)
...@@ -288,7 +288,7 @@ class TransformerEncoder(FairseqEncoder): ...@@ -288,7 +288,7 @@ class TransformerEncoder(FairseqEncoder):
class TransformerDecoder(FairseqIncrementalDecoder):
    """Transformer decoder."""
    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False, final_norm=True):
        super().__init__(dictionary)
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
...@@ -333,7 +333,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -333,7 +333,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), output_embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=output_embed_dim ** -0.5)
self.register_buffer('version', torch.Tensor([2]))
self.normalize = args.decoder_normalize_before and final_norm
if self.normalize:
    self.layer_norm = LayerNorm(embed_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment