Unverified Commit 116f4be4 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[1/N][Cleanup] Standardize on use of `is_quantized_kv_cache` (#38659)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 7b01d97a
......@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.torch_utils import is_quantized_kv_cache
logger = init_logger(__name__)
......@@ -236,7 +237,7 @@ class CacheConfig:
@field_validator("cache_dtype", mode="after")
@classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
if cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(cache_dtype):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
......
......@@ -241,6 +241,7 @@ from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.torch_utils import (
direct_register_custom_op,
is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype,
)
from vllm.v1.attention.backend import (
......@@ -342,7 +343,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if (
self.attn_backend.get_name() == "FLASHMLA_SPARSE"
and kv_cache_dtype.startswith("fp8")
and is_quantized_kv_cache(kv_cache_dtype)
and kv_cache_dtype != "fp8_ds_mla"
):
assert cache_config is not None
......@@ -356,7 +357,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if (
self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
and kv_cache_dtype.startswith("fp8")
and is_quantized_kv_cache(kv_cache_dtype)
):
logger.info_once(
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
......@@ -571,7 +572,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if self.impl.dcp_world_size == -1:
self.impl.dcp_world_size = get_dcp_group().world_size
fp8_attention = self.kv_cache_dtype.startswith("fp8")
fp8_attention = is_quantized_kv_cache(self.kv_cache_dtype)
num_actual_toks = attn_metadata.num_actual_tokens
......@@ -1434,7 +1435,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
is enabled, else model dtype.
"""
use_fp8 = (
vllm_config.cache_config.cache_dtype.startswith("fp8")
is_quantized_kv_cache(vllm_config.cache_config.cache_dtype)
and vllm_config.attention_config.use_prefill_query_quantization
and backend_supports_prefill_query_quantization()
)
......
......@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.platforms import current_platform
from vllm.v1.attention.backend import is_quantized_kv_cache
from vllm.utils.torch_utils import is_quantized_kv_cache
logger = init_logger(__name__)
......
......@@ -23,14 +23,13 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.utils import maybe_prefix
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
from vllm.utils.torch_utils import is_quantized_kv_cache, kv_cache_dtype_str_to_dtype
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache,
)
from vllm.v1.kv_cache_interface import (
AttentionSpec,
......
......@@ -16,7 +16,7 @@ import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.v1.attention.backend import is_quantized_kv_cache
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import CpuArchEnum, Platform, PlatformEnum
......@@ -183,7 +183,7 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache."
)
if cache_config.cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(cache_config.cache_dtype):
logger.warning(
"CPU backend doesn't support KV cache quantization fallback to auto."
)
......
......@@ -23,6 +23,7 @@ import vllm._C_stable_libtorch # noqa
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum
......@@ -87,7 +88,7 @@ def _get_backend_priorities(
# Sparse MLA backend priorities
# See https://github.com/vllm-project/vllm/issues/35807 for
# benchmark results
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
if kv_cache_dtype is not None and is_quantized_kv_cache(kv_cache_dtype):
# Prefer FlashInfer for fp8 kv cache
sparse_backends = [
AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
......
......@@ -61,6 +61,10 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
T = TypeVar("T")
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8")
def is_strictly_contiguous(t: torch.Tensor) -> bool:
"""
Check if tensor is contiguous AND has no degenerate strides.
......
......@@ -954,10 +954,6 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
)
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8")
def subclass_attention_backend(
name_prefix: str,
attention_backend_cls: type[AttentionBackend],
......
......@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
......@@ -16,7 +17,6 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills,
......
......@@ -10,12 +10,12 @@ import numpy as np
import torch
from vllm.model_executor.layers.attention import Attention
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_fp8,
......@@ -177,7 +177,7 @@ class FlashAttentionBackend(AttentionBackend):
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
if kv_cache_dtype is None:
return True
if kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(kv_cache_dtype):
return flash_attn_supports_fp8()
return kv_cache_dtype in ["auto", "float16", "bfloat16"]
......@@ -430,7 +430,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
):
cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(cache_dtype):
qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
cache_dtype
)
......@@ -726,7 +726,7 @@ class FlashAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
# queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype
......@@ -978,7 +978,7 @@ class FlashAttentionImpl(AttentionImpl):
)
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"quantization is not supported for encoder attention"
)
......
......@@ -4,6 +4,7 @@
import torch
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
......@@ -191,7 +192,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :]
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
# queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype
......
......@@ -42,7 +42,7 @@ from vllm.utils.flashinfer import (
)
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import is_strictly_contiguous
from vllm.utils.torch_utils import is_quantized_kv_cache, is_strictly_contiguous
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -602,7 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.page_size = self.kv_cache_spec.block_size
self.cache_dtype = self.cache_config.cache_dtype
if self.cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.cache_dtype):
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.cache_dtype
)
......@@ -1269,7 +1269,7 @@ class FlashInferImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return (
self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8")
and is_quantized_kv_cache(self.kv_cache_dtype)
and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
)
......@@ -1317,12 +1317,12 @@ class FlashInferImpl(AttentionImpl):
if self.bmm1_scale is None:
self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None:
self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._v_scale_float
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
......@@ -1375,8 +1375,8 @@ class FlashInferImpl(AttentionImpl):
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
"fp8"
if self.kv_sharing_target_layer_name is None and is_quantized_kv_cache(
self.kv_cache_dtype
):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
......@@ -1486,9 +1486,8 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None
out = output[num_decode_tokens:]
if (
attn_metadata.q_data_type != FP8_DTYPE
and self.kv_cache_dtype.startswith("fp8")
if attn_metadata.q_data_type != FP8_DTYPE and is_quantized_kv_cache(
self.kv_cache_dtype
):
# TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention
......
......@@ -27,14 +27,13 @@ from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionImpl,
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
is_quantized_kv_cache,
)
from vllm.v1.kv_cache_interface import AttentionSpec
......
......@@ -17,12 +17,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
logger = init_logger(__name__)
......
......@@ -20,12 +20,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla,
......@@ -319,7 +319,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
......
......@@ -16,12 +16,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.utils import KVCacheLayoutType
......@@ -184,12 +184,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
if self.bmm1_scale is None:
self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None:
self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._k_scale_float
# Reuse pre-allocated zero-init output buffer to avoid a memset
......
......@@ -26,6 +26,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -341,11 +342,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
if self.bmm1_scale is None:
self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None:
self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._k_scale_float
o = trtllm_batch_decode_with_kv_cache_mla(
......
......@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
......@@ -128,7 +129,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
self.is_fp8_kvcache = is_quantized_kv_cache(
vllm_config.cache_config.cache_dtype
)
num_sms = num_compute_units(self.device.index)
......@@ -269,7 +272,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = reshape_query_for_spec_decode(q, num_decodes)
scheduler_metadata = attn_metadata.decode.scheduler_metadata
if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"):
if envs.VLLM_BATCH_INVARIANT and not is_quantized_kv_cache(self.kv_cache_dtype):
device = q.device
dtype = torch.int32
......@@ -299,7 +302,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits
if self.kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(self.kv_cache_dtype):
o, lse = flash_mla_with_kvcache_fp8(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
......@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
......@@ -571,7 +572,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
vllm_config = get_current_vllm_config()
max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
q_concat_shape = (max_tokens, num_heads, head_size)
if kv_cache_dtype.startswith("fp8"):
if is_quantized_kv_cache(kv_cache_dtype):
assert kv_cache_dtype == "fp8_ds_mla", (
"FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype"
......
......@@ -14,11 +14,11 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadata,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import (
AttentionLayer,
AttentionType,
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment