Commit 5aa4a627 authored by Myle Ott

Don't use 0-dimensional buffers in sinusoidal positional embeddings

parent c2794070
@@ -244,8 +244,7 @@ class TransformerEncoder(FairseqEncoder):
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
             if 'encoder.embed_positions.weights' in state_dict:
                 del state_dict['encoder.embed_positions.weights']
-            if 'encoder.embed_positions._float_tensor' not in state_dict:
-                state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor()
+            state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
         return state_dict
@@ -340,8 +339,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
             if 'decoder.embed_positions.weights' in state_dict:
                 del state_dict['decoder.embed_positions.weights']
-            if 'decoder.embed_positions._float_tensor' not in state_dict:
-                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
+            state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
         for i in range(len(self.layers)):
             # update layer norms
...
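For context, the upgrade now writes the '_float_tensor' entry unconditionally rather than only when missing: an old checkpoint may already contain a 0-element '_float_tensor', and copying that into the new 1-element buffer would fail with a shape mismatch. A minimal standalone sketch of that failure and the fix (the Toy module here is hypothetical, not fairseq code):

# Hypothetical repro: loading a state dict whose '_float_tensor' is empty
# into a module whose buffer now has one element raises a size mismatch.
import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

old_state = {'_float_tensor': torch.FloatTensor()}  # 0-element tensor, as in old checkpoints
try:
    Toy().load_state_dict(old_state)
except RuntimeError as err:
    print(err)  # size mismatch for _float_tensor: torch.Size([0]) vs torch.Size([1])

# What the updated upgrade_state_dict does: overwrite the entry before loading.
old_state['_float_tensor'] = torch.FloatTensor(1)
Toy().load_state_dict(old_state)  # succeeds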
@@ -30,7 +30,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
             embedding_dim,
             padding_idx,
         )
-        self.register_buffer('_float_tensor', torch.FloatTensor())
+        self.register_buffer('_float_tensor', torch.FloatTensor(1))

     @staticmethod
     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
...
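The '_float_tensor' buffer carries no embedding data; it serves as a proxy so the lazily built sinusoid table can follow whatever device and dtype the module is moved to (via Tensor.to). A hedged sketch of that pattern, with illustrative forward logic that is not copied from fairseq:

# Sketch of the buffer-as-device/dtype-proxy pattern (names follow the diff;
# the table and forward logic are illustrative stand-ins).
import torch
import torch.nn as nn

class SinusoidalSketch(nn.Module):
    def __init__(self, embedding_dim, init_size=16):
        super().__init__()
        # Plain attribute, not a buffer: stand-in for the sinusoid table.
        self.weights = torch.randn(init_size, embedding_dim)
        # One-element buffer, as after this commit; .cuda()/.half() on the
        # module will move/convert it, and the table follows it in forward().
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    def forward(self, positions):
        # Match the buffer's device and dtype, i.e. the module's current placement.
        self.weights = self.weights.to(self._float_tensor)
        return self.weights[positions]

m = SinusoidalSketch(4)
out = m(torch.tensor([0, 1, 2]))
print(out.shape, out.device)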