Commit b9956a6a authored by Myle Ott
Browse files

Fix FP16 version comparison

parent 753935ef
...@@ -497,7 +497,7 @@ class FConvDecoder(FairseqIncrementalDecoder): ...@@ -497,7 +497,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
return self.embed_positions.max_positions() if self.embed_positions is not None else float('inf') return self.embed_positions.max_positions() if self.embed_positions is not None else float('inf')
def upgrade_state_dict(self, state_dict): def upgrade_state_dict(self, state_dict):
if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2: if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2:
# old models use incorrect weight norm dimension # old models use incorrect weight norm dimension
for i, conv in enumerate(self.convolutions): for i, conv in enumerate(self.convolutions):
# reconfigure weight norm # reconfigure weight norm
......
...@@ -277,7 +277,7 @@ class TransformerEncoder(FairseqEncoder): ...@@ -277,7 +277,7 @@ class TransformerEncoder(FairseqEncoder):
if 'encoder.embed_positions.weights' in state_dict: if 'encoder.embed_positions.weights' in state_dict:
del state_dict['encoder.embed_positions.weights'] del state_dict['encoder.embed_positions.weights']
state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1) state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
if state_dict.get('encoder.version', torch.Tensor([1]))[0] < 2: if utils.item(state_dict.get('encoder.version', torch.Tensor([1]))[0]) < 2:
# earlier checkpoints did not normalize after the stack of layers # earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None self.layer_norm = None
self.normalize = False self.normalize = False
...@@ -415,7 +415,7 @@ class TransformerDecoder(FairseqIncrementalDecoder): ...@@ -415,7 +415,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
if k in state_dict: if k in state_dict:
state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k] state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
del state_dict[k] del state_dict[k]
if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2: if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2:
# earlier checkpoints did not normalize after the stack of layers # earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None self.layer_norm = None
self.normalize = False self.normalize = False
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment