"tests/vscode:/vscode.git/clone" did not exist on "ab79863e6c4f4df652328af6901be2ee208dacec"
Commit 2c16c7a4 authored by zhuwenwen
Browse files

fix v0 eager fa-pa acc error

parent 6ca3d790
...@@ -944,11 +944,12 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -944,11 +944,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta.max_decode_seq_len, self.sliding_window, decode_meta.max_decode_seq_len, self.sliding_window,
self.kv_cache_dtype, self.alibi_slopes) self.kv_cache_dtype, self.alibi_slopes)
if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len) decode_meta.max_encoder_seq_len)
assert max_seq_len is not None assert max_seq_len is not None
if use_custom:
max_num_partitions = ( max_num_partitions = (
(max_seq_len + _PARTITION_SIZE_ROCM - 1) // (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM) _PARTITION_SIZE_ROCM)
...@@ -1002,6 +1003,8 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -1002,6 +1003,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
if envs.VLLM_USE_FLASH_ATTN_PA: if envs.VLLM_USE_FLASH_ATTN_PA:
from flash_attn import vllm_flash_attn_with_kvcache from flash_attn import vllm_flash_attn_with_kvcache
if decode_meta.use_cuda_graph:
max_seq_len = 0
if envs.VLLM_USE_PA_PRINT_PARAM: if envs.VLLM_USE_PA_PRINT_PARAM:
print("PA SIZE:") print("PA SIZE:")
print(f"q.shape = {decode_query.unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}") print(f"q.shape = {decode_query.unsqueeze(1).shape}, key_cache.shape = {key_cache.shape}, value_cache.shape = {value_cache.shape}, kv_cache_dtype = {self.kv_cache_dtype}")
...@@ -1024,6 +1027,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -1024,6 +1027,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
k_scale=layer._k_scale, k_scale=layer._k_scale,
v_scale=layer._v_scale, v_scale=layer._v_scale,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
max_seqlen_k=max_seq_len,
).squeeze(1) ).squeeze(1)
else: else:
out_pa[:] = paged_attn.forward_decode( out_pa[:] = paged_attn.forward_decode(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment