Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
116f4be4
Unverified
Commit
116f4be4
authored
Apr 01, 2026
by
Matthew Bonanni
Committed by
GitHub
Apr 01, 2026
Browse files
[1/N][Cleanup] Standardize on use of `is_quantized_kv_cache` (#38659)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
7b01d97a
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
53 additions
and
47 deletions
+53
-47
vllm/config/cache.py
vllm/config/cache.py
+2
-1
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+5
-4
vllm/model_executor/layers/quantization/kv_cache.py
vllm/model_executor/layers/quantization/kv_cache.py
+1
-1
vllm/model_executor/models/extract_hidden_states.py
vllm/model_executor/models/extract_hidden_states.py
+1
-2
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+2
-2
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+2
-1
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+4
-0
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+0
-4
vllm/v1/attention/backends/cpu_attn.py
vllm/v1/attention/backends/cpu_attn.py
+1
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-5
vllm/v1/attention/backends/flash_attn_diffkv.py
vllm/v1/attention/backends/flash_attn_diffkv.py
+2
-1
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+9
-10
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+1
-2
vllm/v1/attention/backends/mla/cutlass_mla.py
vllm/v1/attention/backends/mla/cutlass_mla.py
+1
-1
vllm/v1/attention/backends/mla/flashattn_mla.py
vllm/v1/attention/backends/mla/flashattn_mla.py
+2
-2
vllm/v1/attention/backends/mla/flashinfer_mla.py
vllm/v1/attention/backends/mla/flashinfer_mla.py
+3
-3
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
+3
-2
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+6
-3
vllm/v1/attention/backends/mla/flashmla_sparse.py
vllm/v1/attention/backends/mla/flashmla_sparse.py
+2
-1
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+1
-1
No files found.
vllm/config/cache.py
View file @
116f4be4
...
@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
...
@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
from
vllm.config.utils
import
config
from
vllm.config.utils
import
config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -236,7 +237,7 @@ class CacheConfig:
...
@@ -236,7 +237,7 @@ class CacheConfig:
@
field_validator
(
"cache_dtype"
,
mode
=
"after"
)
@
field_validator
(
"cache_dtype"
,
mode
=
"after"
)
@
classmethod
@
classmethod
def
_validate_cache_dtype
(
cls
,
cache_dtype
:
CacheDType
)
->
CacheDType
:
def
_validate_cache_dtype
(
cls
,
cache_dtype
:
CacheDType
)
->
CacheDType
:
if
cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
cache_dtype
):
logger
.
info
(
logger
.
info
(
"Using fp8 data type to store kv cache. It reduces the GPU "
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"memory footprint and boosts the performance. "
...
...
vllm/model_executor/layers/attention/mla_attention.py
View file @
116f4be4
...
@@ -241,6 +241,7 @@ from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
...
@@ -241,6 +241,7 @@ from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
from
vllm.utils.math_utils
import
cdiv
,
round_down
from
vllm.utils.math_utils
import
cdiv
,
round_down
from
vllm.utils.torch_utils
import
(
from
vllm.utils.torch_utils
import
(
direct_register_custom_op
,
direct_register_custom_op
,
is_quantized_kv_cache
,
kv_cache_dtype_str_to_dtype
,
kv_cache_dtype_str_to_dtype
,
)
)
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
...
@@ -342,7 +343,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -342,7 +343,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if
(
if
(
self
.
attn_backend
.
get_name
()
==
"FLASHMLA_SPARSE"
self
.
attn_backend
.
get_name
()
==
"FLASHMLA_SPARSE"
and
kv_cache_dtype
.
startswith
(
"fp8"
)
and
is_quantized_kv_cache
(
kv_cache_dtype
)
and
kv_cache_dtype
!=
"fp8_ds_mla"
and
kv_cache_dtype
!=
"fp8_ds_mla"
):
):
assert
cache_config
is
not
None
assert
cache_config
is
not
None
...
@@ -356,7 +357,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -356,7 +357,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if
(
if
(
self
.
attn_backend
.
get_name
()
==
"FLASHINFER_MLA_SPARSE"
self
.
attn_backend
.
get_name
()
==
"FLASHINFER_MLA_SPARSE"
and
kv_cache_dtype
.
startswith
(
"fp8"
)
and
is_quantized_kv_cache
(
kv_cache_dtype
)
):
):
logger
.
info_once
(
logger
.
info_once
(
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
...
@@ -571,7 +572,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
...
@@ -571,7 +572,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if
self
.
impl
.
dcp_world_size
==
-
1
:
if
self
.
impl
.
dcp_world_size
==
-
1
:
self
.
impl
.
dcp_world_size
=
get_dcp_group
().
world_size
self
.
impl
.
dcp_world_size
=
get_dcp_group
().
world_size
fp8_attention
=
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
fp8_attention
=
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
num_actual_toks
=
attn_metadata
.
num_actual_tokens
num_actual_toks
=
attn_metadata
.
num_actual_tokens
...
@@ -1434,7 +1435,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -1434,7 +1435,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
is enabled, else model dtype.
is enabled, else model dtype.
"""
"""
use_fp8
=
(
use_fp8
=
(
vllm_config
.
cache_config
.
cache_dtype
.
startswith
(
"fp8"
)
is_quantized_kv_cache
(
vllm_config
.
cache_config
.
cache_dtype
)
and
vllm_config
.
attention_config
.
use_prefill_query_quantization
and
vllm_config
.
attention_config
.
use_prefill_query_quantization
and
backend_supports_prefill_query_quantization
()
and
backend_supports_prefill_query_quantization
()
)
)
...
...
vllm/model_executor/layers/quantization/kv_cache.py
View file @
116f4be4
...
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.
v1.attention.backend
import
is_quantized_kv_cache
from
vllm.
utils.torch_utils
import
is_quantized_kv_cache
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/extract_hidden_states.py
View file @
116f4be4
...
@@ -23,14 +23,13 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
...
@@ -23,14 +23,13 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
)
)
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.models.utils
import
maybe_prefix
from
vllm.model_executor.models.utils
import
maybe_prefix
from
vllm.utils.torch_utils
import
kv_cache_dtype_str_to_dtype
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
,
kv_cache_dtype_str_to_dtype
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionImpl
,
AttentionImpl
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
,
AttentionType
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
is_quantized_kv_cache
,
)
)
from
vllm.v1.kv_cache_interface
import
(
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
AttentionSpec
,
...
...
vllm/platforms/cpu.py
View file @
116f4be4
...
@@ -16,7 +16,7 @@ import torch
...
@@ -16,7 +16,7 @@ import torch
from
vllm
import
envs
from
vllm
import
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.
v1.attention.backend
import
is_quantized_kv_cache
from
vllm.
utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
.interface
import
CpuArchEnum
,
Platform
,
PlatformEnum
from
.interface
import
CpuArchEnum
,
Platform
,
PlatformEnum
...
@@ -183,7 +183,7 @@ class CpuPlatform(Platform):
...
@@ -183,7 +183,7 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache."
"backend is not compatible with FP8 KV cache."
)
)
if
cache_config
.
cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
cache_config
.
cache_dtype
):
logger
.
warning
(
logger
.
warning
(
"CPU backend doesn't support KV cache quantization fallback to auto."
"CPU backend doesn't support KV cache quantization fallback to auto."
)
)
...
...
vllm/platforms/cuda.py
View file @
116f4be4
...
@@ -23,6 +23,7 @@ import vllm._C_stable_libtorch # noqa
...
@@ -23,6 +23,7 @@ import vllm._C_stable_libtorch # noqa
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils.import_utils
import
import_pynvml
from
vllm.utils.import_utils
import
import_pynvml
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
...
@@ -87,7 +88,7 @@ def _get_backend_priorities(
...
@@ -87,7 +88,7 @@ def _get_backend_priorities(
# Sparse MLA backend priorities
# Sparse MLA backend priorities
# See https://github.com/vllm-project/vllm/issues/35807 for
# See https://github.com/vllm-project/vllm/issues/35807 for
# benchmark results
# benchmark results
if
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
):
if
kv_cache_dtype
is
not
None
and
is_quantized_kv_cache
(
kv_cache_dtype
):
# Prefer FlashInfer for fp8 kv cache
# Prefer FlashInfer for fp8 kv cache
sparse_backends
=
[
sparse_backends
=
[
AttentionBackendEnum
.
FLASHINFER_MLA_SPARSE
,
AttentionBackendEnum
.
FLASHINFER_MLA_SPARSE
,
...
...
vllm/utils/torch_utils.py
View file @
116f4be4
...
@@ -61,6 +61,10 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
...
@@ -61,6 +61,10 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
T
=
TypeVar
(
"T"
)
T
=
TypeVar
(
"T"
)
def
is_quantized_kv_cache
(
kv_cache_dtype
:
str
)
->
bool
:
return
kv_cache_dtype
.
startswith
(
"fp8"
)
def
is_strictly_contiguous
(
t
:
torch
.
Tensor
)
->
bool
:
def
is_strictly_contiguous
(
t
:
torch
.
Tensor
)
->
bool
:
"""
"""
Check if tensor is contiguous AND has no degenerate strides.
Check if tensor is contiguous AND has no degenerate strides.
...
...
vllm/v1/attention/backend.py
View file @
116f4be4
...
@@ -954,10 +954,6 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
...
@@ -954,10 +954,6 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
)
)
def
is_quantized_kv_cache
(
kv_cache_dtype
:
str
)
->
bool
:
return
kv_cache_dtype
.
startswith
(
"fp8"
)
def
subclass_attention_backend
(
def
subclass_attention_backend
(
name_prefix
:
str
,
name_prefix
:
str
,
attention_backend_cls
:
type
[
AttentionBackend
],
attention_backend_cls
:
type
[
AttentionBackend
],
...
...
vllm/v1/attention/backends/cpu_attn.py
View file @
116f4be4
...
@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
...
@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionImpl
,
AttentionImpl
,
...
@@ -16,7 +17,6 @@ from vllm.v1.attention.backend import (
...
@@ -16,7 +17,6 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
,
AttentionType
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
is_quantized_kv_cache
,
)
)
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
split_decodes_and_prefills
,
split_decodes_and_prefills
,
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
116f4be4
...
@@ -10,12 +10,12 @@ import numpy as np
...
@@ -10,12 +10,12 @@ import numpy as np
import
torch
import
torch
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionImpl
,
AttentionImpl
,
AttentionType
,
AttentionType
,
MultipleOf
,
MultipleOf
,
is_quantized_kv_cache
,
)
)
from
vllm.v1.attention.backends.fa_utils
import
(
from
vllm.v1.attention.backends.fa_utils
import
(
flash_attn_supports_fp8
,
flash_attn_supports_fp8
,
...
@@ -177,7 +177,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -177,7 +177,7 @@ class FlashAttentionBackend(AttentionBackend):
def
supports_kv_cache_dtype
(
cls
,
kv_cache_dtype
:
CacheDType
|
None
)
->
bool
:
def
supports_kv_cache_dtype
(
cls
,
kv_cache_dtype
:
CacheDType
|
None
)
->
bool
:
if
kv_cache_dtype
is
None
:
if
kv_cache_dtype
is
None
:
return
True
return
True
if
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
kv_cache_dtype
):
return
flash_attn_supports_fp8
()
return
flash_attn_supports_fp8
()
return
kv_cache_dtype
in
[
"auto"
,
"float16"
,
"bfloat16"
]
return
kv_cache_dtype
in
[
"auto"
,
"float16"
,
"bfloat16"
]
...
@@ -430,7 +430,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
...
@@ -430,7 +430,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
batch_size
,
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
batch_size
,
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
):
):
cache_dtype
=
self
.
cache_config
.
cache_dtype
cache_dtype
=
self
.
cache_config
.
cache_dtype
if
cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
cache_dtype
):
qkv_dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
qkv_dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
cache_dtype
cache_dtype
)
)
...
@@ -726,7 +726,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -726,7 +726,7 @@ class FlashAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
# For decoder and cross-attention, use KV cache as before
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
# queries are quantized in the attention layer
# queries are quantized in the attention layer
dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
self
.
kv_cache_dtype
self
.
kv_cache_dtype
...
@@ -978,7 +978,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -978,7 +978,7 @@ class FlashAttentionImpl(AttentionImpl):
)
)
# For encoder attention, process FP8 quantization if needed
# For encoder attention, process FP8 quantization if needed
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"quantization is not supported for encoder attention"
"quantization is not supported for encoder attention"
)
)
...
...
vllm/v1/attention/backends/flash_attn_diffkv.py
View file @
116f4be4
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
torch
import
torch
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
AttentionType
from
vllm.v1.attention.backend
import
AttentionType
from
vllm.v1.attention.backends.fa_utils
import
is_flash_attn_varlen_func_available
from
vllm.v1.attention.backends.fa_utils
import
is_flash_attn_varlen_func_available
from
vllm.v1.attention.ops.triton_reshape_and_cache_flash
import
(
from
vllm.v1.attention.ops.triton_reshape_and_cache_flash
import
(
...
@@ -191,7 +192,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
...
@@ -191,7 +192,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
key_cache
=
kv_cache
[...,
:
self
.
head_size
]
key_cache
=
kv_cache
[...,
:
self
.
head_size
]
value_cache
=
kv_cache
[...,
self
.
head_size
:]
value_cache
=
kv_cache
[...,
self
.
head_size
:]
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
# queries are quantized in the attention layer
# queries are quantized in the attention layer
dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
dtype
=
FlashAttentionBackend
.
get_fp8_dtype_for_flashattn
(
self
.
kv_cache_dtype
self
.
kv_cache_dtype
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
116f4be4
...
@@ -42,7 +42,7 @@ from vllm.utils.flashinfer import (
...
@@ -42,7 +42,7 @@ from vllm.utils.flashinfer import (
)
)
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.platform_utils
import
is_pin_memory_available
from
vllm.utils.torch_utils
import
is_strictly_contiguous
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
,
is_strictly_contiguous
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionCGSupport
,
AttentionCGSupport
,
...
@@ -602,7 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
...
@@ -602,7 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self
.
page_size
=
self
.
kv_cache_spec
.
block_size
self
.
page_size
=
self
.
kv_cache_spec
.
block_size
self
.
cache_dtype
=
self
.
cache_config
.
cache_dtype
self
.
cache_dtype
=
self
.
cache_config
.
cache_dtype
if
self
.
cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
cache_dtype
):
self
.
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
kv_cache_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
cache_dtype
self
.
cache_dtype
)
)
...
@@ -1269,7 +1269,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1269,7 +1269,7 @@ class FlashInferImpl(AttentionImpl):
def
fused_output_quant_supported
(
self
,
quant_key
:
QuantKey
):
def
fused_output_quant_supported
(
self
,
quant_key
:
QuantKey
):
return
(
return
(
self
.
support_trtllm_attn
self
.
support_trtllm_attn
and
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
and
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
and
quant_key
in
(
kFp8StaticTensorSym
,
kNvfp4Dynamic
)
and
quant_key
in
(
kFp8StaticTensorSym
,
kNvfp4Dynamic
)
)
)
...
@@ -1317,12 +1317,12 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1317,12 +1317,12 @@ class FlashInferImpl(AttentionImpl):
if
self
.
bmm1_scale
is
None
:
if
self
.
bmm1_scale
is
None
:
self
.
bmm1_scale
=
self
.
scale
self
.
bmm1_scale
=
self
.
scale
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
self
.
bmm1_scale
*=
layer
.
_q_scale_float
*
layer
.
_k_scale_float
self
.
bmm1_scale
*=
layer
.
_q_scale_float
*
layer
.
_k_scale_float
if
self
.
bmm2_scale
is
None
:
if
self
.
bmm2_scale
is
None
:
self
.
bmm2_scale
=
1.0
self
.
bmm2_scale
=
1.0
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
self
.
bmm2_scale
*=
layer
.
_v_scale_float
self
.
bmm2_scale
*=
layer
.
_v_scale_float
prefill_use_trtllm
=
isinstance
(
attn_metadata
.
prefill
,
TRTLLMPrefill
)
prefill_use_trtllm
=
isinstance
(
attn_metadata
.
prefill
,
TRTLLMPrefill
)
...
@@ -1375,8 +1375,8 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1375,8 +1375,8 @@ class FlashInferImpl(AttentionImpl):
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
# to process the cache when the kv_cache_dtype is fp8
if
self
.
kv_sharing_target_layer_name
is
None
and
self
.
kv_cache_dtype
.
startswith
(
if
self
.
kv_sharing_target_layer_name
is
None
and
is_quantized_kv_cache
(
"fp8"
self
.
kv_cache_dtype
):
):
torch_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
torch_dtype
=
FlashInferBackend
.
get_fp8_dtype_for_flashinfer
(
self
.
kv_cache_dtype
self
.
kv_cache_dtype
...
@@ -1486,9 +1486,8 @@ class FlashInferImpl(AttentionImpl):
...
@@ -1486,9 +1486,8 @@ class FlashInferImpl(AttentionImpl):
assert
self
.
o_sf_scale
is
None
assert
self
.
o_sf_scale
is
None
out
=
output
[
num_decode_tokens
:]
out
=
output
[
num_decode_tokens
:]
if
(
if
attn_metadata
.
q_data_type
!=
FP8_DTYPE
and
is_quantized_kv_cache
(
attn_metadata
.
q_data_type
!=
FP8_DTYPE
self
.
kv_cache_dtype
and
self
.
kv_cache_dtype
.
startswith
(
"fp8"
)
):
):
# TRTLLM prefill attention does not support BF16 Q
# TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention
# and fp8 kv cache. So to enable prefill attention
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
116f4be4
...
@@ -27,14 +27,13 @@ from vllm.config.cache import CacheDType
...
@@ -27,14 +27,13 @@ from vllm.config.cache import CacheDType
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
,
is_torch_equal_or_newer
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionImpl
,
AttentionImpl
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
AttentionType
,
AttentionType
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
is_quantized_kv_cache
,
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
vllm/v1/attention/backends/mla/cutlass_mla.py
View file @
116f4be4
...
@@ -17,12 +17,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -17,12 +17,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
)
)
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionCGSupport
,
AttentionCGSupport
,
AttentionLayer
,
AttentionLayer
,
AttentionType
,
AttentionType
,
MultipleOf
,
MultipleOf
,
is_quantized_kv_cache
,
)
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/v1/attention/backends/mla/flashattn_mla.py
View file @
116f4be4
...
@@ -20,12 +20,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -20,12 +20,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
)
)
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionCGSupport
,
AttentionCGSupport
,
AttentionLayer
,
AttentionLayer
,
AttentionType
,
AttentionType
,
MultipleOf
,
MultipleOf
,
is_quantized_kv_cache
,
)
)
from
vllm.v1.attention.backends.fa_utils
import
(
from
vllm.v1.attention.backends.fa_utils
import
(
flash_attn_supports_mla
,
flash_attn_supports_mla
,
...
@@ -319,7 +319,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
...
@@ -319,7 +319,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
q
,
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
q
,
[
self
.
kv_lora_rank
,
self
.
qk_rope_head_dim
],
dim
=-
1
)
)
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FP8 FlashAttention MLA not yet supported"
)
raise
NotImplementedError
(
"FP8 FlashAttention MLA not yet supported"
)
kv_c_cache
=
kv_c_and_k_pe_cache
[...,
:
self
.
kv_lora_rank
]
kv_c_cache
=
kv_c_and_k_pe_cache
[...,
:
self
.
kv_lora_rank
]
...
...
vllm/v1/attention/backends/mla/flashinfer_mla.py
View file @
116f4be4
...
@@ -16,12 +16,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -16,12 +16,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport
,
QueryLenSupport
,
)
)
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionCGSupport
,
AttentionCGSupport
,
AttentionLayer
,
AttentionLayer
,
AttentionType
,
AttentionType
,
MultipleOf
,
MultipleOf
,
is_quantized_kv_cache
,
)
)
from
vllm.v1.attention.backends.utils
import
KVCacheLayoutType
from
vllm.v1.attention.backends.utils
import
KVCacheLayoutType
...
@@ -184,12 +184,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
...
@@ -184,12 +184,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
if
self
.
bmm1_scale
is
None
:
if
self
.
bmm1_scale
is
None
:
self
.
bmm1_scale
=
self
.
scale
self
.
bmm1_scale
=
self
.
scale
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
self
.
bmm1_scale
*=
layer
.
_q_scale_float
*
layer
.
_k_scale_float
self
.
bmm1_scale
*=
layer
.
_q_scale_float
*
layer
.
_k_scale_float
if
self
.
bmm2_scale
is
None
:
if
self
.
bmm2_scale
is
None
:
self
.
bmm2_scale
=
1.0
self
.
bmm2_scale
=
1.0
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
self
.
bmm2_scale
*=
layer
.
_k_scale_float
self
.
bmm2_scale
*=
layer
.
_k_scale_float
# Reuse pre-allocated zero-init output buffer to avoid a memset
# Reuse pre-allocated zero-init output buffer to avoid a memset
...
...
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
View file @
116f4be4
...
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims
,
get_mla_dims
,
)
)
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionCGSupport
,
AttentionCGSupport
,
...
@@ -341,11 +342,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
...
@@ -341,11 +342,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
if
self
.
bmm1_scale
is
None
:
if
self
.
bmm1_scale
is
None
:
self
.
bmm1_scale
=
self
.
scale
self
.
bmm1_scale
=
self
.
scale
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
self
.
bmm1_scale
*=
layer
.
_q_scale_float
*
layer
.
_k_scale_float
self
.
bmm1_scale
*=
layer
.
_q_scale_float
*
layer
.
_k_scale_float
if
self
.
bmm2_scale
is
None
:
if
self
.
bmm2_scale
is
None
:
self
.
bmm2_scale
=
1.0
self
.
bmm2_scale
=
1.0
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
self
.
bmm2_scale
*=
layer
.
_k_scale_float
self
.
bmm2_scale
*=
layer
.
_k_scale_float
o
=
trtllm_batch_decode_with_kv_cache_mla
(
o
=
trtllm_batch_decode_with_kv_cache_mla
(
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
116f4be4
...
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
)
)
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionCGSupport
,
AttentionCGSupport
,
AttentionLayer
,
AttentionLayer
,
...
@@ -128,7 +129,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
...
@@ -128,7 +129,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self
.
cg_buf_tile_scheduler_metadata
=
None
self
.
cg_buf_tile_scheduler_metadata
=
None
self
.
cg_buf_num_splits
=
None
self
.
cg_buf_num_splits
=
None
self
.
is_fp8_kvcache
=
vllm_config
.
cache_config
.
cache_dtype
.
startswith
(
"fp8"
)
self
.
is_fp8_kvcache
=
is_quantized_kv_cache
(
vllm_config
.
cache_config
.
cache_dtype
)
num_sms
=
num_compute_units
(
self
.
device
.
index
)
num_sms
=
num_compute_units
(
self
.
device
.
index
)
...
@@ -269,7 +272,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -269,7 +272,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q
=
reshape_query_for_spec_decode
(
q
,
num_decodes
)
q
=
reshape_query_for_spec_decode
(
q
,
num_decodes
)
scheduler_metadata
=
attn_metadata
.
decode
.
scheduler_metadata
scheduler_metadata
=
attn_metadata
.
decode
.
scheduler_metadata
if
envs
.
VLLM_BATCH_INVARIANT
and
not
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
envs
.
VLLM_BATCH_INVARIANT
and
not
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
device
=
q
.
device
device
=
q
.
device
dtype
=
torch
.
int32
dtype
=
torch
.
int32
...
@@ -299,7 +302,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -299,7 +302,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
scheduler_metadata
.
tile_scheduler_metadata
=
tile_scheduler_metadata
scheduler_metadata
.
tile_scheduler_metadata
=
tile_scheduler_metadata
scheduler_metadata
.
num_splits
=
num_splits
scheduler_metadata
.
num_splits
=
num_splits
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
o
,
lse
=
flash_mla_with_kvcache_fp8
(
o
,
lse
=
flash_mla_with_kvcache_fp8
(
q
=
q
,
q
=
q
,
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
k_cache
=
kv_c_and_k_pe_cache
.
unsqueeze
(
-
2
),
# Add head dim of 1
...
...
vllm/v1/attention/backends/mla/flashmla_sparse.py
View file @
116f4be4
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.utils.platform_utils
import
num_compute_units
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionBackend
,
AttentionBackend
,
AttentionCGSupport
,
AttentionCGSupport
,
...
@@ -571,7 +572,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
...
@@ -571,7 +572,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
max_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
max_tokens
=
vllm_config
.
scheduler_config
.
max_num_batched_tokens
q_concat_shape
=
(
max_tokens
,
num_heads
,
head_size
)
q_concat_shape
=
(
max_tokens
,
num_heads
,
head_size
)
if
kv_cache_dtype
.
startswith
(
"fp8"
):
if
is_quantized_kv_cache
(
kv_cache_dtype
):
assert
kv_cache_dtype
==
"fp8_ds_mla"
,
(
assert
kv_cache_dtype
==
"fp8_ds_mla"
,
(
"FlashMLA Sparse Attention backend fp8 only supports "
"FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype"
"fp8_ds_mla kv-cache dtype"
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
116f4be4
...
@@ -14,11 +14,11 @@ from vllm.model_executor.layers.attention.mla_attention import (
...
@@ -14,11 +14,11 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadata
,
MLACommonMetadata
,
)
)
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.platforms.interface
import
DeviceCapability
from
vllm.utils.torch_utils
import
is_quantized_kv_cache
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionLayer
,
AttentionLayer
,
AttentionType
,
AttentionType
,
MultipleOf
,
MultipleOf
,
is_quantized_kv_cache
,
)
)
from
vllm.v1.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.v1.attention.ops.triton_decode_attention
import
decode_attention_fwd
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment