Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
116f4be4
Unverified
Commit
116f4be4
authored
Apr 01, 2026
by
Matthew Bonanni
Committed by
GitHub
Apr 01, 2026
Browse files
[1/N][Cleanup] Standardize on use of `is_quantized_kv_cache` (#38659)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
7b01d97a
Changes
28
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
37 additions
and
28 deletions
+37
-28
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
+2
-1
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+8
-7
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
+3
-2
vllm/v1/attention/backends/rocm_attn.py
vllm/v1/attention/backends/rocm_attn.py
+4
-3
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+5
-4
vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
+11
-8
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-1
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+2
-2
No files found.
vllm/v1/attention/backends/mla/xpu_mla_sparse.py
View file @
116f4be4
...
...
@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.attention.mla_attention
import
(
get_mla_dims
,
)
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
...
...
@@ -231,7 +232,7 @@ class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]):
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FP8 kv is not supported with XPU MLA Sparse yet"
)
# Concatenate q if it's a tuple (ql_nope, q_pe)
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
116f4be4
...
...
@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
...
...
@@ -291,7 +292,7 @@ if current_platform.is_rocm():
new_key_cache
=
key_cache
.
view_as
(
k_cache_template
)
new_value_cache
=
value_cache
.
view_as
(
v_cache_template
)
QUANT
=
False
if
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
kv_cache_dtype
):
QUANT
=
True
grid
=
(
num_tokens
,
...
...
@@ -494,7 +495,7 @@ class AiterFlashAttentionMetadataBuilder(
if
(
rocm_aiter_ops
.
is_shuffle_kv_cache_enabled
()
and
self
.
scale
.
numel
()
==
1
and
self
.
vllm_config
.
cache_config
.
cache_dtype
.
startswith
(
"fp8"
)
and
is_quantized_kv_cache
(
self
.
vllm_config
.
cache_config
.
cache_dtype
)
):
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
first_layer_name
=
[
k
for
k
in
layers
][
0
]
...
...
@@ -887,7 +888,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv
=
swa_cu_seqlens
,
token_to_batch
=
swa_token_to_batch
,
seq_starts
=
swa_seq_starts
,
dequant
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
),
dequant
=
is_quantized_kv_cache
(
self
.
kv_cache_dtype
),
kv_cache_layout
=
"NHD"
,
total_tokens
=
swa_total_tokens
,
)
...
...
@@ -982,7 +983,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv
=
cu_seqlens_kv
[
chunk_idx
],
token_to_batch
=
token_to_batch
[
chunk_idx
],
seq_starts
=
chunk_starts
[
chunk_idx
],
dequant
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
),
dequant
=
is_quantized_kv_cache
(
self
.
kv_cache_dtype
),
kv_cache_layout
=
"SHUFFLE"
if
rocm_aiter_ops
.
is_shuffle_kv_cache_enabled
()
else
"NHD"
,
...
...
@@ -1081,7 +1082,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
key_cache
=
key_cache
.
view
(
current_platform
.
fp8_dtype
())
value_cache
=
value_cache
.
view
(
current_platform
.
fp8_dtype
())
...
...
@@ -1370,7 +1371,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
key_cache
=
key_cache
.
view
(
current_platform
.
fp8_dtype
())
value_cache
=
value_cache
.
view
(
current_platform
.
fp8_dtype
())
# Reshape the input keys and values and store them in the cache.
...
...
@@ -1436,7 +1437,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
flash_layout
=
True
is_fp8_kv_cache
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
is_fp8_kv_cache
=
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
if
is_fp8_kv_cache
:
key_cache
=
key_cache
.
view
(
current_platform
.
fp8_dtype
())
value_cache
=
value_cache
.
view
(
current_platform
.
fp8_dtype
())
...
...
vllm/v1/attention/backends/rocm_aiter_unified_attn.py
View file @
116f4be4
...
...
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey
,
kFp8StaticTensorSym
,
)
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
AttentionLayer
,
AttentionType
,
MultipleOf
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.rocm_attn
import
(
...
...
@@ -200,7 +201,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
softmax_scale
=
self
.
scale
fp8_post_attn_v_rescale
=
False
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
# When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul).
...
...
@@ -299,7 +300,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
flash_layout
=
True
is_fp8_kv_cache
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
is_fp8_kv_cache
=
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
if
is_fp8_kv_cache
:
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
...
...
vllm/v1/attention/backends/rocm_attn.py
View file @
116f4be4
...
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
...
...
@@ -315,7 +316,7 @@ class RocmAttentionImpl(AttentionImpl):
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"quantization is not supported for encoder attention"
)
...
...
@@ -406,7 +407,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache
,
self
.
num_kv_heads
,
self
.
head_size
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
assert
layer
.
_q_scale_float
==
1.0
,
(
...
...
@@ -513,7 +514,7 @@ class RocmAttentionImpl(AttentionImpl):
)
flash_layout
=
False
is_fp8_kv_cache
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
is_fp8_kv_cache
=
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
if
is_fp8_kv_cache
:
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
116f4be4
...
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.math_utils
import
next_power_of_2
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionCGSupport
,
...
...
@@ -472,7 +473,7 @@ class TritonAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
key_cache
,
value_cache
=
kv_cache
.
unbind
(
1
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
key_cache
.
dtype
!=
self
.
fp8_dtype
:
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
...
...
@@ -546,7 +547,7 @@ class TritonAttentionImpl(AttentionImpl):
layer: The attention layer
"""
# For encoder attention, process FP8 quantization if needed
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"quantization is not supported for encoder attention"
)
...
...
@@ -588,7 +589,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache
,
value_cache
=
kv_cache
.
unbind
(
1
)
# Reshape the input keys and values and store them in the cache.
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
# triton kernel does not support uint8 kv_cache
...
...
@@ -623,7 +624,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache
,
value_cache
=
kv_cache
.
unbind
(
1
)
flash_layout
=
True
is_fp8_kv_cache
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
is_fp8_kv_cache
=
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
if
is_fp8_kv_cache
:
key_cache
=
key_cache
.
view
(
self
.
fp8_dtype
)
value_cache
=
value_cache
.
view
(
self
.
fp8_dtype
)
...
...
vllm/v1/attention/ops/triton_reshape_and_cache_flash.py
View file @
116f4be4
...
...
@@ -5,6 +5,7 @@ import torch
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
@
triton
.
jit
...
...
@@ -145,16 +146,18 @@ def triton_reshape_and_cache_flash(
block_stride
=
key_cache
.
stride
()[
0
]
page_stride
=
key_cache
.
stride
()[
1
]
assert
kv_cache_dtype
==
"auto"
or
kv_cache_dtype
.
startswith
(
"fp8"
),
(
assert
kv_cache_dtype
==
"auto"
or
is_quantized_kv_cache
(
kv_cache_dtype
),
(
f
"unsupported kv_cache_dtype (str), got
{
kv_cache_dtype
}
."
)
kv_cache_torch_dtype
=
(
current_platform
.
fp8_dtype
()
if
kv_cache_dtype
.
startswith
(
"fp8"
)
if
is_quantized_kv_cache
(
kv_cache_dtype
)
else
key_cache
.
dtype
)
if
key_cache
.
dtype
!=
kv_cache_torch_dtype
and
kv_cache_dtype
.
startswith
(
"fp8"
):
if
key_cache
.
dtype
!=
kv_cache_torch_dtype
and
is_quantized_kv_cache
(
kv_cache_dtype
):
# to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache
=
key_cache
.
view
(
kv_cache_torch_dtype
)
...
...
@@ -164,7 +167,7 @@ def triton_reshape_and_cache_flash(
"uint8 is not supported by triton reshape_and_cache_flash"
)
FP8_KV_CACHE
=
kv_cache_dtype
.
startswith
(
"fp8"
)
FP8_KV_CACHE
=
is_quantized_kv_cache
(
kv_cache_dtype
)
assert
(
not
FP8_KV_CACHE
)
or
kv_cache_torch_dtype
in
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
,
...
...
@@ -323,16 +326,16 @@ def triton_reshape_and_cache_flash_diffkv(
block_stride
=
kv_cache
.
stride
()[
0
]
page_stride
=
kv_cache
.
stride
()[
1
]
assert
kv_cache_dtype
==
"auto"
or
kv_cache_dtype
.
startswith
(
"fp8"
),
(
assert
kv_cache_dtype
==
"auto"
or
is_quantized_kv_cache
(
kv_cache_dtype
),
(
f
"unsupported kv_cache_dtype (str), got
{
kv_cache_dtype
}
."
)
kv_cache_torch_dtype
=
(
current_platform
.
fp8_dtype
()
if
kv_cache_dtype
.
startswith
(
"fp8"
)
if
is_quantized_kv_cache
(
kv_cache_dtype
)
else
kv_cache
.
dtype
)
if
kv_cache
.
dtype
!=
kv_cache_torch_dtype
and
kv_cache_dtype
.
startswith
(
"fp8"
):
if
kv_cache
.
dtype
!=
kv_cache_torch_dtype
and
is_quantized_kv_cache
(
kv_cache_dtype
):
# to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
kv_cache
=
kv_cache
.
view
(
kv_cache_torch_dtype
)
...
...
@@ -341,7 +344,7 @@ def triton_reshape_and_cache_flash_diffkv(
"uint8 is not supported by triton reshape_and_cache_flash_diffkv"
)
FP8_KV_CACHE
=
kv_cache_dtype
.
startswith
(
"fp8"
)
FP8_KV_CACHE
=
is_quantized_kv_cache
(
kv_cache_dtype
)
assert
(
not
FP8_KV_CACHE
)
or
kv_cache_torch_dtype
in
[
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
116f4be4
...
...
@@ -109,6 +109,7 @@ from vllm.utils.nvtx_pytorch_hooks import PytHooks
from
vllm.utils.platform_utils
import
is_pin_memory_available
,
num_compute_units
from
vllm.utils.torch_utils
import
(
get_dtype_size
,
is_quantized_kv_cache
,
kv_cache_dtype_str_to_dtype
,
)
from
vllm.v1.attention.backend
import
(
...
...
@@ -896,7 +897,7 @@ class GPUModelRunner(
If these are left at 0.0 (default after wake_up), all KV cache values
become effectively zero, causing gibberish output.
"""
if
not
self
.
cache_config
.
cache_dtype
.
startswith
(
"fp8"
):
if
not
is_quantized_kv_cache
(
self
.
cache_config
.
cache_dtype
):
return
kv_caches
=
getattr
(
self
,
"kv_caches"
,
[])
...
...
vllm/v1/worker/gpu_worker.py
View file @
116f4be4
...
...
@@ -46,7 +46,7 @@ from vllm.tasks import SupportedTask
from
vllm.tracing
import
instrument
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_utils
import
MemorySnapshot
,
format_gib
,
memory_profiling
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
,
set_random_seed
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
,
KVCacheSpec
from
vllm.v1.outputs
import
(
...
...
@@ -197,7 +197,7 @@ class Worker(WorkerBase):
# especially the FP8 scaling factor.
if
(
(
tags
is
None
or
"kv_cache"
in
tags
)
and
self
.
cache_config
.
cache_dtype
.
startswith
(
"fp8"
)
and
is_quantized_kv_cache
(
self
.
cache_config
.
cache_dtype
)
and
hasattr
(
self
.
model_runner
,
"init_fp8_kv_scales"
)
):
self
.
model_runner
.
init_fp8_kv_scales
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment