Unverified Commit a4ad9db5 authored by Rohan Potdar's avatar Rohan Potdar Committed by GitHub
Browse files

Enable RoPE+KV cache fusion for ROCm AITER FA (non-shuffle layout) (#35786)


Signed-off-by: default avatarRohan138 <rohanpotdar138@gmail.com>
parent b373b510
...@@ -196,6 +196,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module): ...@@ -196,6 +196,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN, AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.ROCM_ATTN, AttentionBackendEnum.ROCM_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
], ],
) )
@pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False]) @pytest.mark.parametrize("enable_rope_custom_op", [True]) # [True, False])
......
...@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import ( ...@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
AttentionImpl, AttentionImpl,
AttentionLayer,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
...@@ -1308,7 +1309,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1308,7 +1309,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
def do_kv_cache_update( def do_kv_cache_update(
self, self,
layer: Attention, layer: AttentionLayer,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
...@@ -1359,3 +1360,47 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1359,3 +1360,47 @@ class AiterFlashAttentionImpl(AttentionImpl):
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
) )
def fused_rope_kvcache_supported(self):
# Only support fusion when shuffle KV cache layout is not used;
# shuffle layout uses a different cache update path.
return (
rocm_aiter_ops.is_enabled()
and not rocm_aiter_ops.is_shuffle_kv_cache_enabled()
)
def do_rope_and_kv_cache_update(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
is_neox: bool,
kv_cache: torch.Tensor,
layer_slot_mapping: torch.Tensor,
):
key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())
rocm_aiter_ops.triton_rope_and_cache(
query,
key,
value,
positions,
cos_sin_cache,
is_neox,
key_cache,
value_cache,
layer_slot_mapping,
layer._k_scale,
layer._v_scale,
flash_layout,
is_fp8_kv_cache,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment