Unverified Commit ca0379b8 authored by Florian Seiler, committed by GitHub

Fix num_heads in _upad_input (#26490)



* Fix num_heads in _upad_input

The variable num_key_value_heads was incorrectly named num_heads, which led to query_layer being reshaped with the wrong attention head count. (Using the correct variable self.num_heads instead of num_heads would have been enough, but I renamed num_heads to num_key_value_heads for clarity.)

* fixed copies using make fix-copies and ran make fixup

---------
Co-authored-by: fseiler <f.seiler@jerocom.de>
parent 67239f73
@@ -692,13 +692,17 @@ class FalconFlashAttention2(FalconAttention):
     # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
     def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
             )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -553,13 +553,17 @@ class LlamaFlashAttention2(LlamaAttention):
     def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
-        batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
-        key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
-        value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
         if query_length == kv_seq_len:
             query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
             )
             cu_seqlens_q = cu_seqlens_k
             max_seqlen_in_batch_q = max_seqlen_in_batch_k
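
Below is a minimal, self-contained sketch (not part of the commit, with made-up head counts) of the shape error this fix prevents: under grouped-query or multi-query attention, the query tensor has num_heads heads while key_layer.shape only reports num_key_value_heads, so reshaping query_layer with the key/value head count raises a shape mismatch.

# Minimal sketch (not part of the commit): why the query tensor must be reshaped
# with the query head count. The head counts below are hypothetical GQA values.
import torch

batch_size, kv_seq_len, head_dim = 2, 8, 64
num_heads = 32             # query heads, i.e. self.num_heads
num_key_value_heads = 8    # key/value heads reported by key_layer.shape

query_layer = torch.randn(batch_size, kv_seq_len, num_heads, head_dim)

# Fixed behavior: reshape with the query head count, which always matches.
flat_queries = query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim)
print(flat_queries.shape)  # torch.Size([16, 32, 64])

# Pre-fix behavior: reshaping with the key/value head count fails whenever
# num_heads != num_key_value_heads (grouped-query / multi-query attention).
try:
    query_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim)
except RuntimeError as err:
    print(f"RuntimeError: {err}")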