Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
64f570ab
Unverified
Commit
64f570ab
authored
Feb 12, 2026
by
kliuae
Committed by
GitHub
Feb 11, 2026
Browse files
[ROCm] [aiter] Split KV cache update for AiterFlashAttention (#33681)
Signed-off-by:
kliuae
<
kuanfu.liu@embeddedllm.com
>
parent
fd618871
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
40 deletions
+68
-40
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+68
-40
No files found.
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
64f570ab
...
@@ -11,6 +11,7 @@ from vllm._aiter_ops import rocm_aiter_ops
...
@@ -11,6 +11,7 @@ from vllm._aiter_ops import rocm_aiter_ops
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention.attention
import
get_attention_context
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
get_cu_count
from
vllm.utils.platform_utils
import
get_cu_count
...
@@ -687,6 +688,8 @@ class AiterFlashAttentionBackend(AttentionBackend):
...
@@ -687,6 +688,8 @@ class AiterFlashAttentionBackend(AttentionBackend):
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
def
get_supported_head_sizes
(
cls
)
->
list
[
int
]:
return
[
64
,
128
,
256
]
return
[
64
,
128
,
256
]
forward_includes_kv_cache_update
:
bool
=
False
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
return
"FLASH_ATTN"
return
"FLASH_ATTN"
...
@@ -982,49 +985,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -982,49 +985,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
# performance to make sure it does not introduce any overhead.
# performance to make sure it does not introduce any overhead.
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
key_cache
=
key_cache
.
view
(
current_platform
.
fp8_dtype
())
key_cache
=
key_cache
.
view
(
current_platform
.
fp8_dtype
())
value_cache
=
value_cache
.
view
(
current_platform
.
fp8_dtype
())
value_cache
=
value_cache
.
view
(
current_platform
.
fp8_dtype
())
if
(
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.
if
rocm_aiter_ops
.
is_shuffle_kv_cache_enabled
():
# We may calculate per token quant scale in
# reshape_and_cache_shuffle_triton which might differ from
# vllm's style when shuffle layout is used.
reshape_and_cache_shuffle_triton
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
attn_metadata
.
k_scale
,
attn_metadata
.
v_scale
,
)
else
:
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
# decode:extend:prefill
# decode:extend:prefill
query
=
query
[:
num_actual_tokens
]
query
=
query
[:
num_actual_tokens
]
...
@@ -1215,3 +1179,67 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -1215,3 +1179,67 @@ class AiterFlashAttentionImpl(AttentionImpl):
)
)
return
output
return
output
def do_kv_cache_update(
    self,
    layer: Attention,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """Store the given key/value tensors into the KV cache.

    Nothing is written when no attention metadata is available for
    ``layer`` (profiling run), when this layer shares its KV cache with
    another layer (``kv_sharing_target_layer_name`` is set), or when
    ``key`` or ``value`` is ``None``.

    Args:
        layer: Attention layer; provides ``layer_name`` for the metadata
            lookup and ``_k_scale`` / ``_v_scale`` for the non-shuffled
            write.
        key: Keys to store. May be ``None`` (e.g. cross attention, where
            keys are computed once from the encoder output and cached).
        value: Values to store. May be ``None`` for the same reason.
        kv_cache: Cache tensor that is split along dim 0 into the key
            cache and the value cache.
        slot_mapping: Slot indices passed through to the cache-write op.
    """
    metadata, _, _ = get_attention_context(layer.layer_name)
    if metadata is None:
        # Profiling run.
        return

    k_cache, v_cache = kv_cache.unbind(0)
    if self.kv_cache_dtype.startswith("fp8"):
        # Reinterpret both caches as the platform's fp8 dtype.
        fp8_dtype = current_platform.fp8_dtype()
        k_cache = k_cache.view(fp8_dtype)
        v_cache = v_cache.view(fp8_dtype)

    # Skip the write when sharing the KV cache with an earlier attention
    # layer, or when there is no key/value to store.
    if (
        self.kv_sharing_target_layer_name is not None
        or key is None
        or value is None
    ):
        return

    # NOTE(woosuk): key and value are padded while slot_mapping is not.
    # Slicing key[:num_actual_tokens] / value[:num_actual_tokens] is not
    # needed because the reshape_and_cache_flash op uses the
    # slot_mapping's shape to determine the number of actual tokens.
    if not rocm_aiter_ops.is_shuffle_kv_cache_enabled():
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key,
            value,
            k_cache,
            v_cache,
            slot_mapping,
            self.kv_cache_dtype,
            layer._k_scale,
            layer._v_scale,
        )
        return

    # Shuffle layout: reshape_and_cache_shuffle_triton may calculate a
    # per-token quant scale, which might differ from vllm's style, so the
    # scales come from the attention metadata rather than from the layer.
    k_scale = metadata.k_scale
    v_scale = metadata.v_scale
    assert k_scale is not None and v_scale is not None, (
        "k_scale and v_scale are required for shuffled update"
    )
    reshape_and_cache_shuffle_triton(
        key,
        value,
        k_cache,
        v_cache,
        slot_mapping,
        self.kv_cache_dtype,
        k_scale,
        v_scale,
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment