Unverified Commit c568581f authored by Strahinja Stamenkovic's avatar Strahinja Stamenkovic Committed by GitHub
Browse files

Fix IndexError with encoder-decoder models when using Custom Paged Attention (#33112)


Signed-off-by: default avatarsstamenk <strahinja.stamenkovic@amd.com>
parent 2d705343
...@@ -330,7 +330,14 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -330,7 +330,14 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache, self.num_kv_heads, self.head_size kv_cache, self.num_kv_heads, self.head_size
) )
if self.kv_sharing_target_layer_name is None: # key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer. # Skip this if sharing KV cache with an earlier attention layer.
...@@ -382,8 +389,8 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -382,8 +389,8 @@ class RocmAttentionImpl(AttentionImpl):
# Compute attention and update output up to `num_actual_tokens`. # Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode( chunked_prefill_paged_decode(
query=query[:num_actual_tokens], query=query[:num_actual_tokens],
key=key[:num_actual_tokens], key=key[:num_actual_tokens] if key is not None else None,
value=value[:num_actual_tokens], value=value[:num_actual_tokens] if value is not None else None,
output=output[:num_actual_tokens], output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache, key_cache=key_cache,
......
...@@ -302,8 +302,9 @@ def chunked_prefill_paged_decode( ...@@ -302,8 +302,9 @@ def chunked_prefill_paged_decode(
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
num_seqs = len(seq_lens) num_seqs = len(seq_lens)
num_query_heads = query.shape[1] num_query_heads = query.shape[1]
num_kv_heads = key.shape[1] # key may be None in cross-attention decode (already cached from encoder)
num_queries_per_kv = query.shape[1] // key.shape[1] num_kv_heads = key.shape[1] if key is not None else key_cache.shape[1]
num_queries_per_kv = num_query_heads // num_kv_heads
head_size = query.shape[2] head_size = query.shape[2]
# Conversion of FP8 Tensor from uint8 storage to # Conversion of FP8 Tensor from uint8 storage to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment