Commit 5aa4a627 authored by Myle Ott

Don't use 0-dimensional buffers in sinusoidal positional embeddings

parent c2794070
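
The buffer change is small but worth spelling out: with PyTorch's legacy constructors, torch.FloatTensor() builds a tensor with zero elements (shape torch.Size([0])), while torch.FloatTensor(1) builds a one-element, uninitialized tensor. A minimal sketch of the difference (plain PyTorch, not fairseq code):

    import torch

    empty = torch.FloatTensor()   # legacy constructor, zero elements
    one = torch.FloatTensor(1)    # one uninitialized element

    print(empty.shape, empty.numel())  # torch.Size([0]) 0
    print(one.shape, one.numel())      # torch.Size([1]) 1
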
@@ -244,8 +244,7 @@ class TransformerEncoder(FairseqEncoder):
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
             if 'encoder.embed_positions.weights' in state_dict:
                 del state_dict['encoder.embed_positions.weights']
-            if 'encoder.embed_positions._float_tensor' not in state_dict:
-                state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor()
+            state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
         return state_dict
@@ -340,8 +339,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
             if 'decoder.embed_positions.weights' in state_dict:
                 del state_dict['decoder.embed_positions.weights']
-            if 'decoder.embed_positions._float_tensor' not in state_dict:
-                state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
+            state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor(1)
         for i in range(len(self.layers)):
             # update layer norms
@@ -30,7 +30,7 @@ class SinusoidalPositionalEmbedding(nn.Module):
             embedding_dim,
             padding_idx,
         )
-        self.register_buffer('_float_tensor', torch.FloatTensor())
+        self.register_buffer('_float_tensor', torch.FloatTensor(1))
 
     @staticmethod
     def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
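
For context on why the buffer exists at all: the sinusoidal weights are computed on the fly and are not parameters, so the module keeps a dummy registered buffer purely as a device/dtype anchor; fairseq's forward pass converts the freshly built weights with something like self.weights.type_as(self._float_tensor). A sketch of that pattern under those assumptions (AnchorExample is illustrative, not fairseq code):

    import torch
    import torch.nn as nn

    class AnchorExample(nn.Module):
        def __init__(self):
            super().__init__()
            # Registered buffers follow .to()/.cuda()/.double();
            # plain tensor attributes do not.
            self.register_buffer('_float_tensor', torch.FloatTensor(1))
            self.weights = torch.randn(4)  # deliberately NOT a parameter or buffer

        def forward(self, x):
            # Drag the free-floating weights to the module's device/dtype.
            self.weights = self.weights.type_as(self._float_tensor)
            return x + self.weights.sum()

    m = AnchorExample().double()
    print(m._float_tensor.dtype)  # torch.float64: the buffer moved with the module
    print(m.weights.dtype)        # torch.float32: plain attributes are left behind
    m(torch.zeros(1, dtype=torch.float64))
    print(m.weights.dtype)        # torch.float64 after forward
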