Unverified commit 51c3f42d authored by Yueming Hao, committed by GitHub

Replace inefficient torch.sqrt taking scalar input with numpy.sqrt (#21496)

* fix rsqrt

* fix typo
parent b0d539cc
@@ -519,7 +519,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         )
 
         # scale key vectors
-        key_vectors = self._len_and_dim_norm(query_key_vectors)
+        sqrt_num = np.sqrt(self.attention_head_size)
+        key_vectors = self._len_and_dim_norm(query_key_vectors, sqrt_num)
 
         # set query_vectors to query key vectors if LSH self attention
         query_vectors = query_vectors if query_vectors is not None else query_key_vectors
@@ -969,14 +970,12 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         return indices
 
-    def _len_and_dim_norm(self, vectors):
+    def _len_and_dim_norm(self, vectors, sqrt_num):
         """
         length and attention head size dim normalization
         """
         vectors = self._len_norm(vectors)
-        vectors = vectors * torch.rsqrt(
-            torch.tensor(self.attention_head_size, device=vectors.device, dtype=vectors.dtype)
-        )
+        vectors = vectors / sqrt_num
         return vectors
 
     def _len_norm(self, x, epsilon=1e-6):
@@ -1114,9 +1113,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
         )
 
         # normalize key vectors
-        key_vectors = key_vectors / torch.sqrt(
-            torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
-        )
+        key_vectors = key_vectors / np.sqrt(self.attention_head_size)
 
         # get sequence length indices
         indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
...
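
The rationale behind the change: the old code wrapped a plain Python int (self.attention_head_size) in a 0-d torch tensor and ran torch.rsqrt on it, paying tensor-construction and kernel-dispatch overhead for a single scalar square root. Since x * rsqrt(d) equals x / sqrt(d) up to float rounding, computing the scalar once with numpy and dividing is equivalent. The following is a minimal benchmark sketch (not part of the commit) contrasting the two patterns; the head size of 64 is illustrative, and absolute timings will vary by machine.

import timeit

import numpy as np
import torch

attention_head_size = 64  # illustrative value; typical head size

def old_style():
    # old pattern: wrap the scalar in a 0-d tensor, then torch.rsqrt on it
    return torch.rsqrt(torch.tensor(attention_head_size, dtype=torch.float32))

def new_style():
    # new pattern: one scalar sqrt in numpy; tensor broadcasting handles the rest
    return np.sqrt(attention_head_size)

print("torch scalar rsqrt:", timeit.timeit(old_style, number=10_000))
print("numpy scalar sqrt: ", timeit.timeit(new_style, number=10_000))

Because attention_head_size is a Python int rather than a tensor, the new code also drops the device= and dtype= bookkeeping; dividing a tensor by a numpy scalar broadcasts in the tensor's own dtype and on its own device.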