Unverified commit 51c3f42d authored by Yueming Hao, committed by GitHub

Replace inefficient torch.sqrt taking scalar input with numpy.sqrt (#21496)

* fix rsqrt

* fix typo
parent b0d539cc
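
The core of the change: `torch.sqrt`/`torch.rsqrt` cannot take a plain Python scalar, so the old code built a 0-d tensor (with a device and dtype) on every forward pass just to take one square root, whereas `numpy.sqrt` computes the value once on the host and ordinary broadcasting handles the rest. A minimal sketch of the equivalence, with an illustrative head size and tensor shape (not taken from the PR):

```python
import numpy as np
import torch

attention_head_size = 64                     # illustrative value
vectors = torch.randn(2, 8, attention_head_size)

# Old pattern: wrap the scalar in a tensor just to take its reciprocal sqrt.
old = vectors * torch.rsqrt(
    torch.tensor(attention_head_size, device=vectors.device, dtype=vectors.dtype)
)

# New pattern: take the sqrt once on the host; the scalar broadcasts in the division.
new = vectors / np.sqrt(attention_head_size)

assert torch.allclose(old, new)
```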
@@ -519,7 +519,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             )

         # scale key vectors
-        key_vectors = self._len_and_dim_norm(query_key_vectors)
+        sqrt_num = np.sqrt(self.attention_head_size)
+        key_vectors = self._len_and_dim_norm(query_key_vectors, sqrt_num)

         # set query_vectors to query key vectors if LSH self attention
         query_vectors = query_vectors if query_vectors is not None else query_key_vectors
@@ -969,14 +970,12 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         return indices

-    def _len_and_dim_norm(self, vectors):
+    def _len_and_dim_norm(self, vectors, sqrt_num):
         """
         length and attention head size dim normalization
         """
         vectors = self._len_norm(vectors)
-        vectors = vectors * torch.rsqrt(
-            torch.tensor(self.attention_head_size, device=vectors.device, dtype=vectors.dtype)
-        )
+        vectors = vectors / sqrt_num
         return vectors

     def _len_norm(self, x, epsilon=1e-6):
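
With the helper taking `sqrt_num` as an argument, `forward` computes the scalar square root once per pass instead of constructing a fresh tensor inside the helper on every call. A rough host-side timing of the two scalar-sqrt idioms, purely illustrative and not from the PR:

```python
import timeit

import numpy as np
import torch

d = 64          # stand-in for attention_head_size
n = 10_000

t_torch = timeit.timeit(lambda: torch.rsqrt(torch.tensor(float(d))), number=n)
t_numpy = timeit.timeit(lambda: 1.0 / np.sqrt(d), number=n)
print(f"torch 0-d tensor rsqrt: {t_torch:.4f}s | numpy scalar sqrt: {t_numpy:.4f}s")
```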
@@ -1114,9 +1113,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
         )

         # normalize key vectors
-        key_vectors = key_vectors / torch.sqrt(
-            torch.tensor(self.attention_head_size, device=key_vectors.device, dtype=key_vectors.dtype)
-        )
+        key_vectors = key_vectors / np.sqrt(self.attention_head_size)

         # get sequence length indices
         indices = torch.arange(sequence_length, device=query_vectors.device).repeat(
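
One subtlety worth checking (my observation, not stated in the PR): `np.sqrt` returns a NumPy float64 scalar, but PyTorch treats it as a weakly typed number in binary ops, so the division keeps the tensor's dtype and device and the simpler form is a drop-in replacement:

```python
import numpy as np
import torch

attention_head_size = 64                     # illustrative value
key_vectors = torch.randn(2, 8, attention_head_size, dtype=torch.float16)

# The float64 scalar does not upcast the half-precision tensor.
out = key_vectors / np.sqrt(attention_head_size)
print(out.dtype)  # torch.float16
```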