Commit 81b47e7e authored by Myle Ott

Fix buffers in sinusoidal positional embeddings

parent 5935fe2f
@@ -150,6 +150,14 @@ class TransformerEncoder(FairseqEncoder):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
+
+    def upgrade_state_dict(self, state_dict):
+        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
+            if 'encoder.embed_positions.weights' in state_dict:
+                del state_dict['encoder.embed_positions.weights']
+            if 'encoder.embed_positions._float_tensor' not in state_dict:
+                state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor()
+        return state_dict

class TransformerDecoder(FairseqDecoder):
    """Transformer decoder."""

@@ -222,6 +230,14 @@ class TransformerDecoder(FairseqDecoder):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()
+
+    def upgrade_state_dict(self, state_dict):
+        if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
+            if 'decoder.embed_positions.weights' in state_dict:
+                del state_dict['decoder.embed_positions.weights']
+            if 'decoder.embed_positions._float_tensor' not in state_dict:
+                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
+        return state_dict

class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.
...
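
For context, here is a rough usage sketch of the new upgrade_state_dict() methods added above. It is not fairseq code: the `model` object, its `.encoder`/`.decoder` attributes, and the checkpoint path are hypothetical stand-ins for however the transformer is actually built and saved; the point is only that a pre-commit checkpoint still carries the old 'embed_positions.weights' buffer and lacks the new '_float_tensor' one.

import torch

# Hypothetical usage sketch (not fairseq code): `model` is assumed to be an
# already-built transformer exposing the encoder/decoder from the diff above, and
# 'old_checkpoint.pt' a state dict saved before this commit, i.e. one that still
# contains the registered buffer '<encoder|decoder>.embed_positions.weights'.
state_dict = torch.load('old_checkpoint.pt', map_location='cpu')

# Rewrite stale keys so load_state_dict() sees exactly the buffers the updated
# modules register: the saved sinusoidal tables are dropped (they are recomputed
# on the fly anyway) and the empty '_float_tensor' placeholders are added.
state_dict = model.encoder.upgrade_state_dict(state_dict)
state_dict = model.decoder.upgrade_state_dict(state_dict)
model.load_state_dict(state_dict)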

@@ -26,14 +26,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
-        self.register_buffer(
-            'weights',
-            SinusoidalPositionalEmbedding.get_embedding(
-                init_size,
-                embedding_dim,
-                padding_idx,
-            ),
-        )
+        self.weights = SinusoidalPositionalEmbedding.get_embedding(
+            init_size,
+            embedding_dim,
+            padding_idx,
+        )
+        self.register_buffer('_float_tensor', torch.FloatTensor())

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
@@ -65,6 +63,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
                self.embedding_dim,
                self.padding_idx,
            ).type_as(self.weights)
+        self.weights = self.weights.type_as(self._float_tensor)
        weights = Variable(self.weights)
        if incremental_state is not None:
...
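
The change above boils down to a small PyTorch pattern: a tensor that is recomputed or resized at runtime (like the sinusoidal table) should not be a registered buffer, or it ends up in the state dict with whatever size it had at save time; instead, an empty registered buffer can serve purely as a device/dtype tracker that the plain attribute follows via type_as() on every forward. Below is a minimal standalone sketch of that pattern; the LazyTable class and all of its names are made up for illustration and are not part of fairseq.

import torch
import torch.nn as nn

class LazyTable(nn.Module):
    """Made-up module demonstrating the empty-buffer device/dtype tracker pattern."""

    def __init__(self, dim=8, init_size=16):
        super().__init__()
        # Plain attribute: excluded from state_dict(), free to be resized/recomputed.
        self.table = torch.randn(init_size, dim)
        # Empty buffer: .to()/.cuda()/.half() move and cast it, so it always records
        # the device and dtype the module currently lives on.
        self.register_buffer('_float_tensor', torch.FloatTensor())

    def forward(self, length):
        # Lazily cast/move the table to match the module before using it.
        self.table = self.table.type_as(self._float_tensor)
        return self.table[:length]

m = LazyTable()
print(list(m.state_dict().keys()))  # only ['_float_tensor']; the table is not saved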