"graphbolt/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "4546e54f2145b07275065aada9bf1a5f79a8e238"
Commit fd28c880 authored by Myle Ott's avatar Myle Ott
Browse files

Fix LearnedPositionalEmbedding

parent 4db6579a
...@@ -22,7 +22,6 @@ class LearnedPositionalEmbedding(nn.Embedding): ...@@ -22,7 +22,6 @@ class LearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad): def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
super().__init__(num_embeddings, embedding_dim, padding_idx) super().__init__(num_embeddings, embedding_dim, padding_idx)
self.left_pad = left_pad self.left_pad = left_pad
self.register_buffer('range_buf', None)
self._is_incremental_eval = False self._is_incremental_eval = False
def incremental_eval(self, mode=True): def incremental_eval(self, mode=True):
...@@ -44,7 +43,7 @@ class LearnedPositionalEmbedding(nn.Embedding): ...@@ -44,7 +43,7 @@ class LearnedPositionalEmbedding(nn.Embedding):
def make_positions(self, input): def make_positions(self, input):
"""Replace non-padding symbols with their position numbers.""" """Replace non-padding symbols with their position numbers."""
if self.range_buf is None: if not hasattr(self, 'range_buf'):
self.range_buf = input.new() self.range_buf = input.new()
seqlen = input.size(1) seqlen = input.size(1)
if self.range_buf.numel() < seqlen: if self.range_buf.numel() < seqlen:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment