Commit b9956a6a authored by Myle Ott

Fix FP16 version comparison

parent 753935ef
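For context: when a model is trained in FP16, tensors in the checkpoint's state dict can be half precision, so the stored version entry was being compared as a tensor rather than as a plain number. The fix below routes the value through utils.item(...) so the `< 2` check becomes an ordinary Python comparison. A minimal sketch of such a helper follows; fairseq's actual utils.item may differ in detail:

def item(tensor):
    # Hypothetical item()-style helper: reduce a one-element or
    # zero-dim tensor to a plain Python number so that version
    # checks like `item(v) < 2` do not depend on the tensor's
    # dtype (e.g. a half-precision checkpoint entry).
    if hasattr(tensor, 'item'):
        return tensor.item()
    if hasattr(tensor, '__getitem__'):
        return tensor[0]
    return tensor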
@@ -497,7 +497,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
         return self.embed_positions.max_positions() if self.embed_positions is not None else float('inf')

     def upgrade_state_dict(self, state_dict):
-        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
+        if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2:
             # old models use incorrect weight norm dimension
             for i, conv in enumerate(self.convolutions):
                 # reconfigure weight norm
@@ -277,7 +277,7 @@ class TransformerEncoder(FairseqEncoder):
         if 'encoder.embed_positions.weights' in state_dict:
             del state_dict['encoder.embed_positions.weights']
             state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
-        if state_dict.get('encoder.version', torch.Tensor([1]))[0] < 2:
+        if utils.item(state_dict.get('encoder.version', torch.Tensor([1]))[0]) < 2:
             # earlier checkpoints did not normalize after the stack of layers
             self.layer_norm = None
             self.normalize = False
@@ -415,7 +415,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                     if k in state_dict:
                         state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
                         del state_dict[k]
-        if state_dict.get('decoder.version', torch.Tensor([1]))[0] < 2:
+        if utils.item(state_dict.get('decoder.version', torch.Tensor([1]))[0]) < 2:
             # earlier checkpoints did not normalize after the stack of layers
             self.layer_norm = None
             self.normalize = False
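As an illustration of why the conversion matters (the values below are made up, not from the commit): a checkpoint saved during half-precision training can hold its version entry as a HalfTensor, and indexing it yields a tensor rather than a number, so the original `< 2` check was a dtype-dependent tensor comparison:

import torch

version = torch.Tensor([1]).half()  # FP16 checkpoint entry
print(type(version[0]))             # a Tensor on recent PyTorch, a float on very old versions
print(float(version[0]) < 2)        # True: plain float comparison after conversion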