Commit 2c16c7a4 authored by zhuwenwen
Browse files

Fix accuracy error in the V0 eager-mode FA-PA (flash-attention paged-attention) decode path

parent 6ca3d790
......@@ -943,12 +943,13 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len)
assert max_seq_len is not None
assert max_seq_len is not None
if use_custom:
max_num_partitions = (
(max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
......@@ -1002,6 +1003,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
if envs.VLLM_USE_FLASH_ATTN_PA:
from flash_attn import vllm_flash_attn_with_kvcache
if decode_meta.use_cuda_graph:
max_seq_len = 0
if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:")
print(f"q.shape = {decode_query.unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}")
......@@ -1024,6 +1027,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
k_scale=layer._k_scale,
v_scale=layer._v_scale,
kv_cache_dtype=self.kv_cache_dtype,
max_seqlen_k=max_seq_len,
).squeeze(1)
else:
out_pa[:] = paged_attn.forward_decode(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment