Commit fedc55ec authored by ngimel, committed by Myle Ott

add end-of-stack normalizations in case normalize_before has been set (#244)

parent f7f2dd01
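Context for the change (a sketch, not part of the commit): with `normalize_before` set, each Transformer layer is pre-norm, i.e. LayerNorm is applied to the sublayer input, so the residual stream leaving the last layer is never normalized and its scale grows with depth. A minimal illustration, where `PreNormLayer` is a hypothetical stand-in for `TransformerEncoderLayer`:

```python
import torch
import torch.nn as nn

class PreNormLayer(nn.Module):
    """Hypothetical stand-in for a fairseq layer with normalize_before=True."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        # Pre-norm: normalize the sublayer input, leave the residual stream alone.
        return x + self.ff(self.norm(x))

torch.manual_seed(0)
stack = nn.Sequential(*[PreNormLayer(8) for _ in range(6)])
y = stack(torch.randn(4, 8))
print(y.std().item())                   # grows with the number of layers
print(nn.LayerNorm(8)(y).std().item())  # ~1.0 once an end-of-stack norm is applied
```

This is exactly the gap the commit closes: the stack's output is normalized once after the last layer.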
@@ -200,6 +200,10 @@ class TransformerEncoder(FairseqEncoder):
             TransformerEncoderLayer(args)
             for i in range(args.encoder_layers)
         ])
+        self.register_buffer('version', torch.Tensor([2]))
+        self.normalize = args.encoder_normalize_before
+        if self.normalize:
+            self.layer_norm = LayerNorm(embed_dim)
 
     def forward(self, src_tokens, src_lengths):
         # embed tokens and positions
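A note on the mechanism (not in the diff): `register_buffer` stores the version tag in the module's `state_dict` without making it a learnable parameter, so every checkpoint saved after this commit carries a format version, while older checkpoints simply lack the key. A minimal sketch with an assumed toy module:

```python
import torch
import torch.nn as nn

class Versioned(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        # A buffer travels with the state_dict but is not a parameter:
        # no gradients, and optimizers never see it.
        self.register_buffer('version', torch.Tensor([2]))

m = Versioned()
print(sorted(m.state_dict().keys()))
# ['proj.bias', 'proj.weight', 'version']
```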
@@ -220,6 +224,9 @@ class TransformerEncoder(FairseqEncoder):
         for layer in self.layers:
             x = layer(x, encoder_padding_mask)
 
+        if self.normalize:
+            x = self.layer_norm(x)
+
         return {
             'encoder_out': x,  # T x B x C
             'encoder_padding_mask': encoder_padding_mask,  # B x T
@@ -245,6 +252,11 @@ class TransformerEncoder(FairseqEncoder):
         if 'encoder.embed_positions.weights' in state_dict:
             del state_dict['encoder.embed_positions.weights']
             state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
+        if state_dict.get('encoder.version', torch.Tensor([1]))[0] < 2:
+            # earlier checkpoints did not normalize after the stack of layers
+            self.layer_norm = None
+            self.normalize = False
+            state_dict['encoder.version'] = torch.Tensor([1])
         return state_dict
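The upgrade logic above keys off that buffer: `state_dict.get('encoder.version', torch.Tensor([1]))[0]` defaults to version 1 when the key is missing, so any pre-commit checkpoint takes the `< 2` branch, the freshly constructed `layer_norm` is discarded, and the loaded model reproduces the old un-normalized behaviour. The check in isolation (hypothetical helper name, not fairseq API):

```python
import torch

def is_legacy_checkpoint(state_dict, prefix='encoder'):
    # Missing 'version' key -> default torch.Tensor([1]) -> legacy checkpoint.
    return bool(state_dict.get(prefix + '.version', torch.Tensor([1]))[0] < 2)

old_ckpt = {'encoder.layers.0.fc1.weight': torch.zeros(1)}     # saved before this commit
new_ckpt = {**old_ckpt, 'encoder.version': torch.Tensor([2])}  # saved after it

print(is_legacy_checkpoint(old_ckpt))  # True  -> disable the end-of-stack norm
print(is_legacy_checkpoint(new_ckpt))  # False -> keep it
```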
@@ -285,6 +297,10 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
+        self.register_buffer('version', torch.Tensor([2]))
+        self.normalize = args.decoder_normalize_before
+        if self.normalize:
+            self.layer_norm = LayerNorm(embed_dim)
 
     def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         # embed positions
@@ -317,6 +333,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 incremental_state,
             )
 
+        if self.normalize:
+            x = self.layer_norm(x)
+
         # T x B x C -> B x T x C
         x = x.transpose(0, 1)
@@ -354,6 +373,12 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 if k in state_dict:
                     state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
                     del state_dict[k]
+        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
+            # earlier checkpoints did not normalize after the stack of layers
+            self.layer_norm = None
+            self.normalize = False
+            state_dict['decoder.version'] = torch.Tensor([1])
+
         return state_dict
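Taken together, the decoder changes mirror the encoder's: a `version` buffer stamped into new checkpoints, an end-of-stack `LayerNorm` in `forward`, and an upgrade path that silently falls back for old checkpoints. A round-trip sketch under the same assumptions (toy module, assumed names, not fairseq code):

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    # Assumed toy analogue of the patched TransformerEncoder.
    def __init__(self, dim=4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.register_buffer('version', torch.Tensor([2]))
        self.normalize = True
        self.layer_norm = nn.LayerNorm(dim)

    def upgrade_state_dict(self, state_dict):
        if state_dict.get('version', torch.Tensor([1]))[0] < 2:
            # Legacy checkpoint: drop the norm the constructor just built.
            self.layer_norm = None
            self.normalize = False
            state_dict['version'] = torch.Tensor([1])
        return state_dict

    def forward(self, x):
        x = self.proj(x)
        if self.normalize:
            x = self.layer_norm(x)
        return x

enc = ToyEncoder()
legacy = {'proj.weight': torch.eye(4), 'proj.bias': torch.zeros(4)}  # no 'version' key
enc.load_state_dict(enc.upgrade_state_dict(legacy))
print(enc.normalize)  # False: the old checkpoint keeps its old semantics
```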