Commit a47630e1 authored by Liezl Puzon, committed by Facebook GitHub Bot

Fix hybrid transformer state dict update after encoder layernorm rename (#633)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/633

Pull Request resolved: https://github.com/pytorch/translate/pull/456

This diff moves the layer norm key renaming into TransformerEncoderLayer.upgrade_state_dict_named, which makes it easier to upgrade the state dict for components that use TransformerEncoderLayer (such as the hybrid transformer in pytorch/translate).
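
For illustration only (not part of this commit): a minimal, self-contained sketch of the delegation pattern the diff enables. DummyEncoderLayer and HybridEncoder below are hypothetical stand-ins for fairseq's TransformerEncoderLayer and for a wrapping encoder such as translate's hybrid transformer; only the upgrade_state_dict_named hook is reproduced. The point is that a wrapping model no longer duplicates the layer_norms.{0,1} -> {self_attn,final}_layer_norm rename; it just forwards the per-layer key prefix.

import torch

# Hypothetical stand-in for fairseq's TransformerEncoderLayer: only the
# upgrade_state_dict_named hook added in this diff is reproduced here.
class DummyEncoderLayer:
    def upgrade_state_dict_named(self, state_dict, name):
        # Rename `<name>.layer_norms.{0,1}.{weight,bias}` to the new
        # `self_attn_layer_norm` / `final_layer_norm` keys, as in the diff.
        layer_norm_map = {'0': 'self_attn_layer_norm', '1': 'final_layer_norm'}
        for old, new in layer_norm_map.items():
            for m in ('weight', 'bias'):
                k = f'{name}.layer_norms.{old}.{m}'
                if k in state_dict:
                    state_dict[f'{name}.{new}.{m}'] = state_dict.pop(k)


# Hypothetical wrapping component (e.g. a hybrid transformer encoder) that
# owns TransformerEncoderLayer-style sublayers and delegates the upgrade.
class HybridEncoder:
    def __init__(self, num_layers=2):
        self.layers = [DummyEncoderLayer() for _ in range(num_layers)]

    def upgrade_state_dict_named(self, state_dict, name):
        for i in range(len(self.layers)):
            self.layers[i].upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")


if __name__ == '__main__':
    # Old-style checkpoint keys that predate the layer norm rename.
    state_dict = {
        'encoder.layers.0.layer_norms.0.weight': torch.ones(4),
        'encoder.layers.0.layer_norms.1.bias': torch.zeros(4),
    }
    HybridEncoder(num_layers=1).upgrade_state_dict_named(state_dict, 'encoder')
    print(sorted(state_dict))
    # ['encoder.layers.0.final_layer_norm.bias',
    #  'encoder.layers.0.self_attn_layer_norm.weight']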

Reviewed By: jhcross

Differential Revision: D14916941

fbshipit-source-id: 6d0258c8a9492a720684dadce59c90fc87cbf5cf
parent 58b912f6
@@ -368,18 +368,8 @@ class TransformerEncoder(FairseqEncoder):
         state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
         for i in range(len(self.layers)):
             # update layer norms
-            layer_norm_map = {
-                '0': 'self_attn_layer_norm',
-                '1': 'final_layer_norm'
-            }
-            for old, new in layer_norm_map.items():
-                for m in ('weight', 'bias'):
-                    k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
-                    if k in state_dict:
-                        state_dict[
-                            '{}.layers.{}.{}.{}'.format(name, i, new, m)
-                        ] = state_dict[k]
-                        del state_dict[k]
+            self.layers[i].upgrade_state_dict_named(state_dict, f"{name}.layers.{i}")
+
         version_key = '{}.version'.format(name)
         if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
             # earlier checkpoints did not normalize after the stack of layers
@@ -605,6 +595,25 @@ class TransformerEncoderLayer(nn.Module):
         self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
         self.final_layer_norm = LayerNorm(self.embed_dim)
 
+    def upgrade_state_dict_named(self, state_dict, name):
+        """
+        Rename layer norm states from `...layer_norms.0.weight` to
+        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
+        `...final_layer_norm.weight`
+        """
+        layer_norm_map = {
+            '0': 'self_attn_layer_norm',
+            '1': 'final_layer_norm'
+        }
+        for old, new in layer_norm_map.items():
+            for m in ('weight', 'bias'):
+                k = f'{name}.layer_norms.{old}.{m}'
+                if k in state_dict:
+                    state_dict[
+                        f'{name}.{new}.{m}'
+                    ] = state_dict[k]
+                    del state_dict[k]
+
     def forward(self, x, encoder_padding_mask):
         """
         Args: