Commit df877aad authored by zhuwenwen
Browse files

update use_rocm_custom_paged_attention

parent 1e302221
...@@ -922,10 +922,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -922,10 +922,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads gqa_ratio = num_heads // self.num_kv_heads
# use_custom = use_rocm_custom_paged_attention( use_custom = use_rocm_custom_paged_attention(
# decode_query.dtype, head_size, block_size, gqa_ratio, decode_query.dtype, head_size, block_size, gqa_ratio,
# decode_meta.max_decode_seq_len, self.sliding_window) decode_meta.max_decode_seq_len, self.sliding_window)
use_custom = False
if use_custom: if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
......
...@@ -110,15 +110,16 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, ...@@ -110,15 +110,16 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
# rocm custom page attention not support on gfx1* # rocm custom page attention not support on gfx1*
# custom paged attn always supported on V0. On V1, requires sliding window # custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy. # disabled due to observed numerical discrepancy.
return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0 return False
or sliding_window == (-1, -1)) # return (on_mi250_mi300() and (not envs.VLLM_USE_V1 or sliding_window == 0
and (qtype == torch.half or qtype == torch.bfloat16) # or sliding_window == (-1, -1))
and (head_size == 64 or head_size == 128) # and (qtype == torch.half or qtype == torch.bfloat16)
and (block_size == 16 or block_size == 32) # and (head_size == 64 or head_size == 128)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768 # and (block_size == 16 or block_size == 32)
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) # and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN # and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and envs.VLLM_ROCM_USE_AITER)) # and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
# and envs.VLLM_ROCM_USE_AITER))
class RocmPlatform(Platform): class RocmPlatform(Platform):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment