Commit e5ba94ab authored by Liezl Puzon, committed by Facebook Github Bot

Make TransformerEncoderLayer layer norm names more descriptive

Summary:
I added an upgrade_state_dict function so that loading old models will still work; the layer norms are renamed as shown below.

layer_norms[0] --> self_attn_layer_norm
layer_norms[1] --> final_layer_norm
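
As a concrete illustration of the rename (the 'encoder.' prefix is hypothetical; the real prefix is whatever `name` the upgrade hook receives), an old checkpoint key such as

encoder.layers.0.layer_norms.0.weight  -->  encoder.layers.0.self_attn_layer_norm.weight
encoder.layers.0.layer_norms.1.bias    -->  encoder.layers.0.final_layer_norm.bias

is rewritten in place before the state dict is loaded; a runnable sketch of the rename loop follows the first diff hunk below.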

Reviewed By: pipibjc

Differential Revision: D14689849

fbshipit-source-id: b2809262c11fe9d083e571fa31044798aefd48ce
parent 309f2511
@@ -366,6 +366,20 @@ class TransformerEncoder(FairseqEncoder):
             if weights_key in state_dict:
                 del state_dict[weights_key]
             state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
+        for i in range(len(self.layers)):
+            # update layer norms
+            layer_norm_map = {
+                '0': 'self_attn_layer_norm',
+                '1': 'final_layer_norm'
+            }
+            for old, new in layer_norm_map.items():
+                for m in ('weight', 'bias'):
+                    k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
+                    if k in state_dict:
+                        state_dict[
+                            '{}.layers.{}.{}.{}'.format(name, i, new, m)
+                        ] = state_dict[k]
+                        del state_dict[k]
         version_key = '{}.version'.format(name)
         if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
             # earlier checkpoints did not normalize after the stack of layers
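
For reference, a minimal standalone sketch of the key rename performed by the new upgrade code, run against a fabricated old-style state dict (the 'encoder' prefix, the single layer, and the tensor shapes are invented for illustration; only the mapping logic mirrors the diff):

import torch

# Fabricated old-style checkpoint fragment (names and shapes are illustrative only).
state_dict = {
    'encoder.layers.0.layer_norms.0.weight': torch.ones(512),
    'encoder.layers.0.layer_norms.0.bias': torch.zeros(512),
    'encoder.layers.0.layer_norms.1.weight': torch.ones(512),
    'encoder.layers.0.layer_norms.1.bias': torch.zeros(512),
}

layer_norm_map = {'0': 'self_attn_layer_norm', '1': 'final_layer_norm'}
name, num_layers = 'encoder', 1  # assumed prefix and layer count

for i in range(num_layers):
    for old, new in layer_norm_map.items():
        for m in ('weight', 'bias'):
            k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
            if k in state_dict:
                # move the tensor to the new, more descriptive key
                state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
                del state_dict[k]

print(sorted(state_dict.keys()))
# ['encoder.layers.0.final_layer_norm.bias', 'encoder.layers.0.final_layer_norm.weight',
#  'encoder.layers.0.self_attn_layer_norm.bias', 'encoder.layers.0.self_attn_layer_norm.weight']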
@@ -583,12 +597,13 @@ class TransformerEncoderLayer(nn.Module):
             self.embed_dim, args.encoder_attention_heads,
             dropout=args.attention_dropout,
         )
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
         self.dropout = args.dropout
         self.relu_dropout = args.relu_dropout
         self.normalize_before = args.encoder_normalize_before
         self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
         self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
+        self.final_layer_norm = LayerNorm(self.embed_dim)

     def forward(self, x, encoder_padding_mask):
         """
@@ -601,26 +616,26 @@ class TransformerEncoderLayer(nn.Module):
             encoded output of shape `(batch, src_len, embed_dim)`
         """
         residual = x
-        x = self.maybe_layer_norm(0, x, before=True)
+        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
         x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.maybe_layer_norm(0, x, after=True)
+        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

         residual = x
-        x = self.maybe_layer_norm(1, x, before=True)
+        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
         x = F.relu(self.fc1(x))
         x = F.dropout(x, p=self.relu_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.maybe_layer_norm(1, x, after=True)
+        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
         return x

-    def maybe_layer_norm(self, i, x, before=False, after=False):
+    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
         assert before ^ after
         if after ^ self.normalize_before:
-            return self.layer_norms[i](x)
+            return layer_norm(x)
         else:
             return x
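
To make the maybe_layer_norm gating concrete, a small standalone sketch (a free function rather than the class method; tensor sizes are arbitrary) showing that with normalize_before=True the norm fires on the before call (pre-norm) and with normalize_before=False it fires on the after call (post-norm):

import torch
import torch.nn as nn

def maybe_layer_norm(layer_norm, x, normalize_before, before=False, after=False):
    # Apply the norm on exactly one side of the sub-layer, chosen by normalize_before.
    assert before ^ after
    if after ^ normalize_before:
        return layer_norm(x)
    return x

ln = nn.LayerNorm(4)
x = torch.randn(2, 4)

# pre-norm: the norm runs on the 'before' call, the 'after' call is a no-op
assert not torch.equal(maybe_layer_norm(ln, x, True, before=True), x)
assert torch.equal(maybe_layer_norm(ln, x, True, after=True), x)

# post-norm: the 'before' call is a no-op, the norm runs on the 'after' call
assert torch.equal(maybe_layer_norm(ln, x, False, before=True), x)
assert not torch.equal(maybe_layer_norm(ln, x, False, after=True), x)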