Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a4ad9db5
Unverified
Commit
a4ad9db5
authored
Mar 13, 2026
by
Rohan Potdar
Committed by
GitHub
Mar 13, 2026
Browse files
Enable RoPE+KV cache fusion for ROCm AITER FA (non-shuffle layout) (#35786)
Signed-off-by:
Rohan138
<
rohanpotdar138@gmail.com
>
parent
b373b510
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
1 deletion
+47
-1
tests/compile/passes/test_rope_kvcache_fusion.py
tests/compile/passes/test_rope_kvcache_fusion.py
+1
-0
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+46
-1
No files found.
tests/compile/passes/test_rope_kvcache_fusion.py
View file @
a4ad9db5
...
@@ -196,6 +196,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
...
@@ -196,6 +196,7 @@ class QKRoPEKVCacheTestModel(torch.nn.Module):
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_UNIFIED_ATTN
,
AttentionBackendEnum
.
TRITON_ATTN
,
AttentionBackendEnum
.
TRITON_ATTN
,
AttentionBackendEnum
.
ROCM_ATTN
,
AttentionBackendEnum
.
ROCM_ATTN
,
AttentionBackendEnum
.
ROCM_AITER_FA
,
],
],
)
)
@
pytest
.
mark
.
parametrize
(
"enable_rope_custom_op"
,
[
True
])
# [True, False])
@
pytest
.
mark
.
parametrize
(
"enable_rope_custom_op"
,
[
True
])
# [True, False])
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
a4ad9db5
...
@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import (
...
@@ -20,6 +20,7 @@ from vllm.v1.attention.backend import (
AttentionBackend
,
AttentionBackend
,
AttentionCGSupport
,
AttentionCGSupport
,
AttentionImpl
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
,
AttentionType
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
...
@@ -1308,7 +1309,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -1308,7 +1309,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
def
do_kv_cache_update
(
def
do_kv_cache_update
(
self
,
self
,
layer
:
Attention
,
layer
:
Attention
Layer
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
...
@@ -1359,3 +1360,47 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -1359,3 +1360,47 @@ class AiterFlashAttentionImpl(AttentionImpl):
layer
.
_k_scale
,
layer
.
_k_scale
,
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
def fused_rope_kvcache_supported(self):
    """Tell whether the fused RoPE + KV-cache update path can be used.

    Fusion is offered only when the AITER ops are enabled and the shuffle
    KV cache layout is not in use; the shuffle layout goes through a
    different cache update path.
    """
    aiter_enabled = rocm_aiter_ops.is_enabled()
    if not aiter_enabled:
        # Hand back the falsy result itself, exactly as `and` would.
        return aiter_enabled
    return not rocm_aiter_ops.is_shuffle_kv_cache_enabled()
def do_rope_and_kv_cache_update(
    self,
    layer: AttentionLayer,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    kv_cache: torch.Tensor,
    layer_slot_mapping: torch.Tensor,
):
    """Run the fused RoPE + KV-cache update through the AITER Triton op.

    ``kv_cache`` is split along dimension 0 into a key cache and a value
    cache. When ``self.kv_cache_dtype`` starts with ``"fp8"``, both halves
    are reinterpreted as the platform fp8 dtype before the call. All
    inputs, the two cache halves, the slot mapping and the layer's
    ``_k_scale`` / ``_v_scale`` are then forwarded positionally to
    ``rocm_aiter_ops.triton_rope_and_cache``.

    Nothing is returned; the result of the op call is discarded.
    NOTE(review): presumably the op updates ``query``/``key`` and the
    caches in place -- confirm against the op's implementation.
    """
    k_cache, v_cache = kv_cache.unbind(0)

    fp8_cache = self.kv_cache_dtype.startswith("fp8")
    if fp8_cache:
        # Reinterpret the cache storage as the platform's fp8 dtype.
        k_cache = k_cache.view(current_platform.fp8_dtype())
        v_cache = v_cache.view(current_platform.fp8_dtype())

    rocm_aiter_ops.triton_rope_and_cache(
        query,
        key,
        value,
        positions,
        cos_sin_cache,
        is_neox,
        k_cache,
        v_cache,
        layer_slot_mapping,
        layer._k_scale,
        layer._v_scale,
        True,  # flash_layout: this path always passes True
        fp8_cache,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment