Unverified commit 2782aada authored by Sanchit Gandhi, committed by GitHub

[modelling] remove un-necessary transpose for fa2 attention (#31749)

* [whisper] remove un-necessary transpose for fa2 attention

* propagate
parent f83c6f1d
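
The pattern removed here is the query-side transpose round trip in the FlashAttention-2 paths: FlashAttention-2 consumes tensors in [batch_size, seq_length, num_heads, head_dim] layout, and the query never goes through the KV cache, so transposing it to [batch_size, num_heads, seq_length, head_dim] only to transpose it back is a no-op. A minimal sketch of the equivalence (illustrative shapes and names, not code from this commit):

import torch

bsz, q_len, num_heads, head_dim = 2, 16, 8, 64  # illustrative sizes
hidden = torch.randn(bsz, q_len, num_heads * head_dim)  # stand-in for the projected query

# Before: view to [b, s, h, d], transpose to [b, h, s, d], then transpose back for FA2
q_before = hidden.view(bsz, q_len, num_heads, head_dim).transpose(1, 2).transpose(1, 2)

# After: keep the [b, s, h, d] layout FlashAttention-2 expects from the start
q_after = hidden.view(bsz, q_len, num_heads, head_dim)

assert torch.equal(q_before, q_after)

The key/value tensors keep their transposes because the KV cache path still expects the [batch_size, num_heads, seq_length, head_dim] layout, as the TODO comments in the diff note.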
@@ -301,7 +301,7 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
@@ -311,7 +311,6 @@ class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
@@ -817,7 +816,7 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
         key_states = self.k_proj(torch.cat([context, latents], dim=-2))
         value_states = self.v_proj(torch.cat([context, latents], dim=-2))
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -882,7 +881,6 @@ class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
         value_states = value_states.to(target_dtype)
         # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
......
@@ -406,7 +406,7 @@ class JambaFlashAttention2(JambaAttention):
         # Flash attention requires the input to have the shape
         # batch_size x seq_length x head_dim x hidden_dim
         # therefore we just need to keep the original shape
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -469,7 +469,6 @@ class JambaFlashAttention2(JambaAttention):
         value_states = value_states.to(target_dtype)
         # Reashape to the expected shape for Flash Attention
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
......
@@ -387,7 +387,7 @@ class WhisperFlashAttention2(WhisperAttention):
         bsz, tgt_len, _ = hidden_states.size()
         # get query proj
-        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+        query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
         if past_key_value is not None:
             is_updated = past_key_value.is_updated.get(self.layer_idx)
@@ -416,7 +416,6 @@ class WhisperFlashAttention2(WhisperAttention):
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
         # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
-        query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)
......
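
For Whisper the same change also folds the query reshape into a single call: instead of going through the _shape helper (which returns the [batch_size, num_heads, seq_length, head_dim] layout) and transposing back later, the projection is reshaped directly into the FA2 layout. A rough sketch of the equivalence (illustrative shapes, not code from this commit):

import torch

bsz, tgt_len, num_heads, head_dim = 2, 10, 4, 32  # illustrative sizes
proj = torch.randn(bsz, tgt_len, num_heads * head_dim)  # stand-in for self.q_proj(hidden_states)

# Old path: _shape-style view + transpose to [b, h, s, d], then a later transpose back for FA2
old = proj.view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2).transpose(1, 2)

# New path: reshape once, straight into [b, s, h, d]
new = torch.reshape(proj, (bsz, tgt_len, num_heads, head_dim))

assert torch.equal(old, new)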