Commit 81b47e7e authored by Myle Ott

Fix buffers in sinusoidal positional embeddings

parent 5935fe2f
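
Note on the change: the sinusoidal weight table used to be registered as a buffer, so it was written into every checkpoint even though it is fully deterministic and gets re-grown on demand. After this commit the table is a plain attribute, and an empty '_float_tensor' buffer is registered instead; its only job is to track the module's current device and dtype so the table can follow .cuda()/.half()/.to() via type_as(). A minimal sketch of the pattern, assuming hypothetical names (LazySinusoidalTable, _build) rather than fairseq's actual module:

import math

import torch
import torch.nn as nn


class LazySinusoidalTable(nn.Module):
    """Illustrative module (not fairseq's): a lazily grown sinusoidal table kept
    as a plain attribute, plus an empty buffer that tracks device/dtype."""

    def __init__(self, embedding_dim, init_size=16):
        super().__init__()
        assert embedding_dim % 2 == 0  # sketch assumes an even dimension
        self.embedding_dim = embedding_dim
        # Plain attribute: not saved in checkpoints, rebuilt whenever it is too small.
        self.weights = self._build(init_size, embedding_dim)
        # Empty buffer: .to()/.cuda()/.half() move it, so it always mirrors the
        # device and dtype the module currently lives on.
        self.register_buffer('_float_tensor', torch.FloatTensor())

    @staticmethod
    def _build(num_positions, dim):
        pos = torch.arange(num_positions, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float) * (-math.log(10000.0) / dim))
        table = torch.zeros(num_positions, dim)
        table[:, 0::2] = torch.sin(pos * inv_freq)
        table[:, 1::2] = torch.cos(pos * inv_freq)
        return table

    def forward(self, positions):
        # positions: LongTensor of position indices, any shape
        needed = int(positions.max()) + 1
        if needed > self.weights.size(0):
            self.weights = self._build(needed, self.embedding_dim)
        # Cast/move the table to wherever the module currently is, via the empty buffer.
        self.weights = self.weights.type_as(self._float_tensor)
        return self.weights[positions]

After model.half().cuda(), '_float_tensor' is a half-precision CUDA tensor, so the next forward() casts the table to match; nothing about the table itself is persisted.
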
@@ -150,6 +150,14 @@ class TransformerEncoder(FairseqEncoder):
         """Maximum input length supported by the encoder."""
         return self.embed_positions.max_positions()
 
+    def upgrade_state_dict(self, state_dict):
+        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
+            if 'encoder.embed_positions.weights' in state_dict:
+                del state_dict['encoder.embed_positions.weights']
+            if 'encoder.embed_positions._float_tensor' not in state_dict:
+                state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor()
+        return state_dict
+
 
 class TransformerDecoder(FairseqDecoder):
     """Transformer decoder."""
@@ -222,6 +230,14 @@ class TransformerDecoder(FairseqDecoder):
         """Maximum output length supported by the decoder."""
         return self.embed_positions.max_positions()
 
+    def upgrade_state_dict(self, state_dict):
+        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
+            if 'decoder.embed_positions.weights' in state_dict:
+                del state_dict['decoder.embed_positions.weights']
+            if 'decoder.embed_positions._float_tensor' not in state_dict:
+                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
+        return state_dict
+
 
 class TransformerEncoderLayer(nn.Module):
     """Encoder layer block.
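
Checkpoints written before this change still contain the old 'encoder.embed_positions.weights' / 'decoder.embed_positions.weights' buffers and lack the new '_float_tensor' entry, so a strict state-dict load would fail. The upgrade_state_dict hooks above rewrite such checkpoints on load. A small illustrative helper (not part of the commit) showing the same key migration on a flat state dict:

import torch

def upgrade_old_positional_keys(state_dict, prefix):
    # Illustrative only: mirrors what the new TransformerEncoder/TransformerDecoder
    # upgrade_state_dict() hooks do for sinusoidal positional embeddings.
    stale = prefix + '.embed_positions.weights'
    marker = prefix + '.embed_positions._float_tensor'
    if stale in state_dict:
        del state_dict[stale]                      # drop the old persistent table
    if marker not in state_dict:
        state_dict[marker] = torch.FloatTensor()   # satisfy the new empty buffer
    return state_dict

old = {'encoder.embed_positions.weights': torch.zeros(1026, 512)}
print(sorted(upgrade_old_positional_keys(old, 'encoder')))
# -> ['encoder.embed_positions._float_tensor']
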
@@ -26,14 +26,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
         self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.left_pad = left_pad
-        self.register_buffer(
-            'weights',
-            SinusoidalPositionalEmbedding.get_embedding(
-                init_size,
-                embedding_dim,
-                padding_idx,
-            ),
+        self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            init_size,
+            embedding_dim,
+            padding_idx,
         )
+        self.register_buffer('_float_tensor', torch.FloatTensor())
 
     @staticmethod
     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
@@ -65,6 +63,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
                 self.embedding_dim,
                 self.padding_idx,
             ).type_as(self.weights)
+        self.weights = self.weights.type_as(self._float_tensor)
         weights = Variable(self.weights)
         if incremental_state is not None: