Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
de3869bb
Unverified
Commit
de3869bb
authored
Feb 07, 2026
by
Rohan Potdar
Committed by
GitHub
Feb 07, 2026
Browse files
move checks out of `unified_kv_cache_update` custom op (#33943)
Signed-off-by:
Rohan138
<
rohanpotdar138@gmail.com
>
parent
ce9b3cd3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
80 additions
and
101 deletions
+80
-101
vllm/model_executor/layers/attention/attention.py
vllm/model_executor/layers/attention/attention.py
+14
-6
vllm/model_executor/layers/attention/cross_attention.py
vllm/model_executor/layers/attention/cross_attention.py
+3
-0
vllm/model_executor/models/whisper_causal.py
vllm/model_executor/models/whisper_causal.py
+3
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+0
-10
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+11
-20
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+32
-42
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+17
-23
No files found.
vllm/model_executor/layers/attention/attention.py
View file @
de3869bb
...
@@ -422,9 +422,15 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -422,9 +422,15 @@ class Attention(nn.Module, AttentionLayerBase):
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size
)
if
value
is
not
None
:
if
value
is
not
None
:
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size_v
)
value
=
value
.
view
(
-
1
,
self
.
num_kv_heads
,
self
.
head_size_v
)
kv_cache_dummy_dep
=
None
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
kv_cache_dummy_dep
=
None
# Skip this if sharing KV cache with an earlier attention layer.
if
not
self
.
attn_backend
.
forward_includes_kv_cache_update
:
if
(
not
self
.
attn_backend
.
forward_includes_kv_cache_update
and
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
kv_cache_dummy_dep
=
unified_kv_cache_update
(
kv_cache_dummy_dep
=
unified_kv_cache_update
(
key
,
value
,
self
.
layer_name
key
,
value
,
self
.
layer_name
)
)
...
@@ -437,10 +443,12 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -437,10 +443,12 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
kv_cache_dummy_dep
=
kv_cache_dummy_dep
,
)
)
else
:
else
:
kv_cache_dummy_dep
=
None
# Skip this if sharing KV cache with an earlier attention layer.
if
not
self
.
attn_backend
.
forward_includes_kv_cache_update
and
(
if
(
# torch can only dispatch custom op if a tensor is passed
not
self
.
attn_backend
.
forward_includes_kv_cache_update
key
is
not
None
or
value
is
not
None
and
self
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
):
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
kv_cache_dummy_dep
=
torch
.
ops
.
vllm
.
unified_kv_cache_update
(
key
,
value
,
self
.
layer_name
key
,
value
,
self
.
layer_name
...
...
vllm/model_executor/layers/attention/cross_attention.py
View file @
de3869bb
...
@@ -136,6 +136,9 @@ def create_cross_attention_backend(
...
@@ -136,6 +136,9 @@ def create_cross_attention_backend(
if
(
if
(
not
underlying_attn_backend
.
forward_includes_kv_cache_update
not
underlying_attn_backend
.
forward_includes_kv_cache_update
and
attn_metadata
is
not
None
and
attn_metadata
is
not
None
and
layer
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
):
self
.
do_kv_cache_update
(
self
.
do_kv_cache_update
(
layer
,
key
,
value
,
kv_cache
,
attn_metadata
.
slot_mapping
layer
,
key
,
value
,
kv_cache
,
attn_metadata
.
slot_mapping
...
...
vllm/model_executor/models/whisper_causal.py
View file @
de3869bb
...
@@ -172,6 +172,9 @@ def create_whisper_attention_backend_with_block_pooling(
...
@@ -172,6 +172,9 @@ def create_whisper_attention_backend_with_block_pooling(
if
(
if
(
not
underlying_attn_backend
.
forward_includes_kv_cache_update
not
underlying_attn_backend
.
forward_includes_kv_cache_update
and
attn_metadata
is
not
None
and
attn_metadata
is
not
None
and
layer
.
kv_sharing_target_layer_name
is
None
and
key
is
not
None
and
value
is
not
None
):
):
self
.
do_kv_cache_update
(
self
.
do_kv_cache_update
(
layer
,
key
,
value
,
kv_cache
,
attn_metadata
.
slot_mapping
layer
,
key
,
value
,
kv_cache
,
attn_metadata
.
slot_mapping
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
de3869bb
...
@@ -771,16 +771,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -771,16 +771,6 @@ class FlashAttentionImpl(AttentionImpl):
# we use direct Q, K, V tensors without caching
# we use direct Q, K, V tensors without caching
return
return
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if
(
self
.
kv_sharing_target_layer_name
is
not
None
or
key
is
None
or
value
is
None
):
return
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
...
...
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
View file @
de3869bb
...
@@ -196,23 +196,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
...
@@ -196,23 +196,14 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
):
):
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
# key and value may be None in the case of cross attention. They are
# Reshape the input keys and values and store them in the cache.
# calculated once based on the output from the encoder and then cached
ops
.
reshape_and_cache_flash
(
# in KV cache.
key
,
if
(
value
,
self
.
kv_sharing_target_layer_name
is
None
key_cache
,
and
key
is
not
None
value_cache
,
and
value
is
not
None
slot_mapping
,
):
self
.
kv_cache_dtype
,
# Reshape the input keys and values and store them in the cache.
layer
.
_k_scale
,
# Skip this if sharing KV cache with an earlier attention layer.
layer
.
_v_scale
,
ops
.
reshape_and_cache_flash
(
)
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
vllm/v1/attention/backends/rocm_attn.py
View file @
de3869bb
...
@@ -383,45 +383,35 @@ class RocmAttentionImpl(AttentionImpl):
...
@@ -383,45 +383,35 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
)
# key and value may be None in the case of cross attention. They are
# Reshape the input keys and values and store them in the cache.
# calculated once based on the output from the encoder and then cached
# Get the actual block_size from value_cache
# in KV cache.
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
if
(
block_size
=
value_cache
.
shape
[
3
]
self
.
kv_sharing_target_layer_name
is
None
# Determine if it is a power of 2
and
key
is
not
None
is_pow2
=
block_size
>
0
and
(
block_size
&
(
block_size
-
1
)
==
0
)
and
value
is
not
None
):
if
is_pow2
:
# Reshape the input keys and values and store them in the cache.
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
# Skip this if sharing KV cache with an earlier attention layer.
PagedAttention
.
write_to_paged_cache
(
key
,
# Get the actual block_size from value_cache
value
,
# value_cache shape: [num_blocks, num_heads, head_size, block_size]
key_cache
,
block_size
=
value_cache
.
shape
[
3
]
value_cache
,
# Determine if it is a power of 2
slot_mapping
,
is_pow2
=
block_size
>
0
and
(
block_size
&
(
block_size
-
1
)
==
0
)
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
if
is_pow2
:
layer
.
_v_scale
,
# Normal 16, 32, 64, etc., use vLLM native HIP C++ logic
)
PagedAttention
.
write_to_paged_cache
(
else
:
key
,
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
value
,
# force using our modified Triton logic
key_cache
,
triton_reshape_and_cache_flash
(
value_cache
,
key
,
slot_mapping
,
value
,
self
.
kv_cache_dtype
,
key_cache
,
layer
.
_k_scale
,
value_cache
,
layer
.
_v_scale
,
slot_mapping
,
)
self
.
kv_cache_dtype
,
else
:
layer
.
_k_scale
,
# Case B: Non-standard blocks (e.g., 544 in Qwen3),
layer
.
_v_scale
,
# force using our modified Triton logic
)
triton_reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
vllm/v1/attention/backends/triton_attn.py
View file @
de3869bb
...
@@ -579,26 +579,20 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -579,26 +579,20 @@ class TritonAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
# For decoder and cross-attention, use KV cache as before
key_cache
,
value_cache
=
kv_cache
.
unbind
(
1
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
1
)
if
(
# Reshape the input keys and values and store them in the cache.
self
.
kv_sharing_target_layer_name
is
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
and
key
is
not
None
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
and
value
is
not
None
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
):
# triton kernel does not support uint8 kv_cache
# Reshape the input keys and values and store them in the cache.
# (because some explicit casts (e.g. float8_e4m3fnuz)
# Skip this if sharing KV cache with an earlier attention layer.
# are not supported)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
triton_reshape_and_cache_flash
(
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
key
,
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
value
,
# triton kernel does not support uint8 kv_cache
key_cache
,
# (because some explicit casts (e.g. float8_e4m3fnuz)
value_cache
,
# are not supported)
slot_mapping
,
triton_reshape_and_cache_flash
(
self
.
kv_cache_dtype
,
key
,
layer
.
_k_scale
,
value
,
layer
.
_v_scale
,
key_cache
,
)
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment