Unverified Commit 504ff7bb authored by Stas Bekman, committed by GitHub

2 SinusoidalPositionalEmbedding fixes (#8226)

parent f744b815
@@ -1328,8 +1328,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
     def __init__(self, num_positions, embedding_dim, padding_idx=None):
         super().__init__(num_positions, embedding_dim)
-        if embedding_dim % 2 != 0:
-            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
         self.weight = self._init_weight(self.weight)

     @staticmethod
@@ -1342,10 +1340,11 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
         position_enc = np.array(
             [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
         )
-        out[:, 0 : dim // 2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # This line breaks for odd n_pos
-        out[:, dim // 2 :] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
+        out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
+        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
+        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
+        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
         out.detach_()
-        out.requires_grad = False
         return out

     @torch.no_grad()
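For reference, a minimal standalone sketch of the patched initializer logic (the wrapper function name and toy shapes below are illustrative, not part of the commit). The key point: for an odd dim, the even-indexed slice 0::2 yields (dim + 1) // 2 sine columns, so the sentinel shifts the sin/cos boundary one column right and both assignments line up; setting requires_grad = False before the in-place writes avoids the autograd error on leaf tensors in pytorch-1.8+.

# Sketch only; init_sinusoidal_weight is a hypothetical name, not the library API.
import numpy as np
import torch
import torch.nn as nn


def init_sinusoidal_weight(out: nn.Parameter) -> nn.Parameter:
    n_pos, dim = out.shape
    position_enc = np.array(
        [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
    )
    out.requires_grad = False  # set early to avoid an error in pytorch-1.8+
    # odd dim: 0::2 selects (dim + 1) // 2 columns, so the sin half gets the extra one
    sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
    out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
    out.detach_()
    return out


weight = init_sinusoidal_weight(nn.Parameter(torch.empty(4, 5)))  # odd dim=5 now works
print(weight.shape)  # torch.Size([4, 5])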
@@ -620,8 +620,8 @@ class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
         self.assertListEqual(no_cache[-1].tolist(), yes_cache[0][0].tolist())

     def test_odd_embed_dim(self):
-        with self.assertRaises(NotImplementedError):
-            SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)
+        # odd embedding_dim is allowed
+        SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0).to(torch_device)

         # odd num_positions is allowed
         SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0).to(torch_device)
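A quick smoke check mirroring the updated test (assuming SinusoidalPositionalEmbedding is imported from the patched module): both an odd embedding_dim and an odd num_positions now construct cleanly.

# Hypothetical interactive check; assumes SinusoidalPositionalEmbedding is in scope.
emb_odd_dim = SinusoidalPositionalEmbedding(num_positions=4, embedding_dim=5, padding_idx=0)
emb_odd_pos = SinusoidalPositionalEmbedding(num_positions=5, embedding_dim=4, padding_idx=0)
print(emb_odd_dim.weight.shape)  # torch.Size([4, 5]) -- raised NotImplementedError before this commit
print(emb_odd_pos.weight.shape)  # torch.Size([5, 4])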