"docs/source/en/using-diffusers/freeu.md" did not exist on "e46ec5f88fec23870538df782258c59271b010fd"
Commit 0a628401 authored by Peng-Jen Chen, committed by Facebook GitHub Bot

Update upgrade_state_dict in transformer.py to upgrade_state_dict_named (#317)

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/317

When upgrading the `state_dict` variable, the `upgrade_state_dict` function in TransformerEncoder/TransformerDecoder does not handle models with multiple encoders/decoders; D10052908, however, introduces exactly such a case.

Before this change, we hit error message [1] when loading a checkpoint for the multilingual_transformer model from D10052908. This diff fixes it.
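
To make the failure mode concrete, here is a minimal sketch (hypothetical names, not fairseq code) of why hard-coded `'encoder.'` keys break once a checkpoint contains several encoders, and how threading a `name` prefix through the upgrade fixes it:

```python
import torch

# Hypothetical illustration: a multilingual checkpoint stores each
# encoder's parameters under a model-specific prefix, so the hard-coded
# 'encoder.' keys in the old upgrade_state_dict never match.
state_dict = {
    'models.en-de.encoder.embed_positions.weights': torch.zeros(10, 4),
    'models.en-fr.encoder.embed_positions.weights': torch.zeros(10, 4),
}

def upgrade(state_dict, name):
    # Build keys from the caller-supplied module name instead of assuming
    # the component is registered as 'encoder'.
    weights_key = '{}.embed_positions.weights'.format(name)
    if weights_key in state_dict:
        del state_dict[weights_key]
    state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
    return state_dict

# Each sub-encoder is upgraded under its own prefix.
for prefix in ('models.en-de.encoder', 'models.en-fr.encoder'):
    upgrade(state_dict, prefix)
```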

Reviewed By: myleott, liezl200

Differential Revision: D10375418

fbshipit-source-id: 7104c1a463e78f3fa33d8479a37c51608be50610
parent 0eea6923
@@ -322,17 +322,19 @@ class TransformerEncoder(FairseqEncoder):
             return self.max_source_positions
         return min(self.max_source_positions, self.embed_positions.max_positions())
 
-    def upgrade_state_dict(self, state_dict):
+    def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            if 'encoder.embed_positions.weights' in state_dict:
-                del state_dict['encoder.embed_positions.weights']
-            state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
-        if utils.item(state_dict.get('encoder.version', torch.Tensor([1]))[0]) < 2:
+            weights_key = '{}.embed_positions.weights'.format(name)
+            if weights_key in state_dict:
+                del state_dict[weights_key]
+            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
+        version_key = '{}.version'.format(name)
+        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
             # earlier checkpoints did not normalize after the stack of layers
             self.layer_norm = None
             self.normalize = False
-            state_dict['encoder.version'] = torch.Tensor([1])
+            state_dict[version_key] = torch.Tensor([1])
         return state_dict
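
The version check at the end of this hunk gates a behavioral change: a checkpoint whose `{name}.version` entry is missing or below 2 predates normalization after the layer stack, so the upgrade disables the final layer norm. A minimal standalone sketch of that gate, under hypothetical names:

```python
import torch

# Sketch of the version gate: a missing '<name>.version' entry defaults
# to 1, marking a checkpoint saved before the final layer norm existed,
# so the upgrade must disable normalization.
def predates_final_layer_norm(state_dict, name):
    version_key = '{}.version'.format(name)
    return state_dict.get(version_key, torch.Tensor([1]))[0].item() < 2

old_ckpt = {}                                      # no version key at all
new_ckpt = {'encoder.version': torch.Tensor([2])}
assert predates_final_layer_norm(old_ckpt, 'encoder')
assert not predates_final_layer_norm(new_ckpt, 'encoder')
```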
@@ -487,12 +489,13 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
         return self._future_mask[:dim, :dim]
 
-    def upgrade_state_dict(self, state_dict):
+    def upgrade_state_dict_named(self, state_dict, name):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
-            if 'decoder.embed_positions.weights' in state_dict:
-                del state_dict['decoder.embed_positions.weights']
-            state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
+            weights_key = '{}.embed_positions.weights'.format(name)
+            if weights_key in state_dict:
+                del state_dict[weights_key]
+            state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1)
 
         for i in range(len(self.layers)):
             # update layer norms
@@ -503,15 +506,15 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             }
             for old, new in layer_norm_map.items():
                 for m in ('weight', 'bias'):
-                    k = 'decoder.layers.{}.layer_norms.{}.{}'.format(i, old, m)
+                    k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
                     if k in state_dict:
-                        state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
+                        state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
                         del state_dict[k]
-        if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2:
+        if utils.item(state_dict.get('{}.version'.format(name), torch.Tensor([1]))[0]) < 2:
             # earlier checkpoints did not normalize after the stack of layers
             self.layer_norm = None
             self.normalize = False
-            state_dict['decoder.version'] = torch.Tensor([1])
+            state_dict['{}.version'.format(name)] = torch.Tensor([1])
         return state_dict
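
The remapping above migrates old fused `layer_norms.<index>` checkpoint keys to per-purpose attribute names, now built from the caller-supplied prefix instead of a hard-coded `'decoder'`. A self-contained sketch of the mechanics (the real `layer_norm_map` entries are elided in the hunk, so the ones below are placeholders):

```python
import torch

# Sketch of the decoder layer-norm key migration. The actual contents of
# layer_norm_map are elided in the hunk above; these entries are
# placeholders chosen only to show the mechanics.
layer_norm_map = {'0': 'self_attn_layer_norm', '1': 'final_layer_norm'}

state_dict = {'decoder.layers.0.layer_norms.0.weight': torch.ones(4)}
name, i = 'decoder', 0
for old, new in layer_norm_map.items():
    for m in ('weight', 'bias'):
        k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m)
        if k in state_dict:
            state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k]
            del state_dict[k]

assert 'decoder.layers.0.self_attn_layer_norm.weight' in state_dict
```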