Commit fd28c880 authored by Myle Ott's avatar Myle Ott
Browse files

Fix LearnedPositionalEmbedding

parent 4db6579a
......@@ -22,7 +22,6 @@ class LearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.left_pad = left_pad
self.register_buffer('range_buf', None)
self._is_incremental_eval = False
def incremental_eval(self, mode=True):
......@@ -44,7 +43,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
def make_positions(self, input):
"""Replace non-padding symbols with their position numbers."""
if self.range_buf is None:
if not hasattr(self, 'range_buf'):
self.range_buf = input.new()
seqlen = input.size(1)
if self.range_buf.numel() < seqlen:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment