Unverified commit 116f4be4, authored by Matthew Bonanni, committed by GitHub
Browse files

[1/N][Cleanup] Standardize on use of `is_quantized_kv_cache` (#38659)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 7b01d97a
...@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator ...@@ -8,6 +8,7 @@ from pydantic import Field, SkipValidation, field_validator, model_validator
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.torch_utils import is_quantized_kv_cache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -236,7 +237,7 @@ class CacheConfig: ...@@ -236,7 +237,7 @@ class CacheConfig:
@field_validator("cache_dtype", mode="after") @field_validator("cache_dtype", mode="after")
@classmethod @classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
if cache_dtype.startswith("fp8"): if is_quantized_kv_cache(cache_dtype):
logger.info( logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU " "Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. " "memory footprint and boosts the performance. "
......
...@@ -241,6 +241,7 @@ from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory ...@@ -241,6 +241,7 @@ from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down from vllm.utils.math_utils import cdiv, round_down
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
direct_register_custom_op, direct_register_custom_op,
is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
) )
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
...@@ -342,7 +343,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -342,7 +343,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
# Automatically convert fp8 kv-cache format to "fp8_ds_mla" # Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if ( if (
self.attn_backend.get_name() == "FLASHMLA_SPARSE" self.attn_backend.get_name() == "FLASHMLA_SPARSE"
and kv_cache_dtype.startswith("fp8") and is_quantized_kv_cache(kv_cache_dtype)
and kv_cache_dtype != "fp8_ds_mla" and kv_cache_dtype != "fp8_ds_mla"
): ):
assert cache_config is not None assert cache_config is not None
...@@ -356,7 +357,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -356,7 +357,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if ( if (
self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE" self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
and kv_cache_dtype.startswith("fp8") and is_quantized_kv_cache(kv_cache_dtype)
): ):
logger.info_once( logger.info_once(
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla " "Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
...@@ -571,7 +572,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -571,7 +572,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if self.impl.dcp_world_size == -1: if self.impl.dcp_world_size == -1:
self.impl.dcp_world_size = get_dcp_group().world_size self.impl.dcp_world_size = get_dcp_group().world_size
fp8_attention = self.kv_cache_dtype.startswith("fp8") fp8_attention = is_quantized_kv_cache(self.kv_cache_dtype)
num_actual_toks = attn_metadata.num_actual_tokens num_actual_toks = attn_metadata.num_actual_tokens
...@@ -1434,7 +1435,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -1434,7 +1435,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
is enabled, else model dtype. is enabled, else model dtype.
""" """
use_fp8 = ( use_fp8 = (
vllm_config.cache_config.cache_dtype.startswith("fp8") is_quantized_kv_cache(vllm_config.cache_config.cache_dtype)
and vllm_config.attention_config.use_prefill_query_quantization and vllm_config.attention_config.use_prefill_query_quantization
and backend_supports_prefill_query_quantization() and backend_supports_prefill_query_quantization()
) )
......
...@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backend import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -23,14 +23,13 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import ( ...@@ -23,14 +23,13 @@ from vllm.model_executor.layers.attention.kv_transfer_utils import (
) )
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.models.utils import maybe_prefix
from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype from vllm.utils.torch_utils import is_quantized_kv_cache, kv_cache_dtype_str_to_dtype
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
is_quantized_kv_cache,
) )
from vllm.v1.kv_cache_interface import ( from vllm.v1.kv_cache_interface import (
AttentionSpec, AttentionSpec,
......
...@@ -16,7 +16,7 @@ import torch ...@@ -16,7 +16,7 @@ import torch
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import CpuArchEnum, Platform, PlatformEnum from .interface import CpuArchEnum, Platform, PlatformEnum
...@@ -183,7 +183,7 @@ class CpuPlatform(Platform): ...@@ -183,7 +183,7 @@ class CpuPlatform(Platform):
"backend is not compatible with FP8 KV cache." "backend is not compatible with FP8 KV cache."
) )
if cache_config.cache_dtype.startswith("fp8"): if is_quantized_kv_cache(cache_config.cache_dtype):
logger.warning( logger.warning(
"CPU backend doesn't support KV cache quantization fallback to auto." "CPU backend doesn't support KV cache quantization fallback to auto."
) )
......
...@@ -23,6 +23,7 @@ import vllm._C_stable_libtorch # noqa ...@@ -23,6 +23,7 @@ import vllm._C_stable_libtorch # noqa
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
...@@ -87,7 +88,7 @@ def _get_backend_priorities( ...@@ -87,7 +88,7 @@ def _get_backend_priorities(
# Sparse MLA backend priorities # Sparse MLA backend priorities
# See https://github.com/vllm-project/vllm/issues/35807 for # See https://github.com/vllm-project/vllm/issues/35807 for
# benchmark results # benchmark results
if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): if kv_cache_dtype is not None and is_quantized_kv_cache(kv_cache_dtype):
# Prefer FlashInfer for fp8 kv cache # Prefer FlashInfer for fp8 kv cache
sparse_backends = [ sparse_backends = [
AttentionBackendEnum.FLASHINFER_MLA_SPARSE, AttentionBackendEnum.FLASHINFER_MLA_SPARSE,
......
...@@ -61,6 +61,10 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = { ...@@ -61,6 +61,10 @@ MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
T = TypeVar("T") T = TypeVar("T")
# Report whether a KV-cache dtype string names a quantized cache format.
# The check is a plain prefix test: any dtype string that begins with "fp8"
# (for example "fp8" itself or "fp8_ds_mla") yields True; every other string
# (for example "auto", "float16", "bfloat16") yields False.
# The argument must be a str; call sites that may hold None test for None
# before calling this helper.
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8")
def is_strictly_contiguous(t: torch.Tensor) -> bool: def is_strictly_contiguous(t: torch.Tensor) -> bool:
""" """
Check if tensor is contiguous AND has no degenerate strides. Check if tensor is contiguous AND has no degenerate strides.
......
...@@ -954,10 +954,6 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]): ...@@ -954,10 +954,6 @@ class SparseMLAAttentionImpl(AttentionImplBase[T], Generic[T]):
) )
# Report whether a KV-cache dtype string names a quantized cache format.
# This is a prefix test only: it returns True exactly when the string starts
# with "fp8", and False for any other string.
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype.startswith("fp8")
def subclass_attention_backend( def subclass_attention_backend(
name_prefix: str, name_prefix: str,
attention_backend_cls: type[AttentionBackend], attention_backend_cls: type[AttentionBackend],
......
...@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops ...@@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
...@@ -16,7 +17,6 @@ from vllm.v1.attention.backend import ( ...@@ -16,7 +17,6 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
split_decodes_and_prefills, split_decodes_and_prefills,
......
...@@ -10,12 +10,12 @@ import numpy as np ...@@ -10,12 +10,12 @@ import numpy as np
import torch import torch
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_fp8, flash_attn_supports_fp8,
...@@ -177,7 +177,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -177,7 +177,7 @@ class FlashAttentionBackend(AttentionBackend):
def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool: def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
if kv_cache_dtype is None: if kv_cache_dtype is None:
return True return True
if kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(kv_cache_dtype):
return flash_attn_supports_fp8() return flash_attn_supports_fp8()
return kv_cache_dtype in ["auto", "float16", "bfloat16"] return kv_cache_dtype in ["auto", "float16", "bfloat16"]
...@@ -430,7 +430,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -430,7 +430,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
): ):
cache_dtype = self.cache_config.cache_dtype cache_dtype = self.cache_config.cache_dtype
if cache_dtype.startswith("fp8"): if is_quantized_kv_cache(cache_dtype):
qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
cache_dtype cache_dtype
) )
...@@ -726,7 +726,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -726,7 +726,7 @@ class FlashAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before # For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
# queries are quantized in the attention layer # queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype self.kv_cache_dtype
...@@ -978,7 +978,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -978,7 +978,7 @@ class FlashAttentionImpl(AttentionImpl):
) )
# For encoder attention, process FP8 quantization if needed # For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"quantization is not supported for encoder attention" "quantization is not supported for encoder attention"
) )
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import torch import torch
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backend import AttentionType
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import ( from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
...@@ -191,7 +192,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl): ...@@ -191,7 +192,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
key_cache = kv_cache[..., : self.head_size] key_cache = kv_cache[..., : self.head_size]
value_cache = kv_cache[..., self.head_size :] value_cache = kv_cache[..., self.head_size :]
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
# queries are quantized in the attention layer # queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
self.kv_cache_dtype self.kv_cache_dtype
......
...@@ -42,7 +42,7 @@ from vllm.utils.flashinfer import ( ...@@ -42,7 +42,7 @@ from vllm.utils.flashinfer import (
) )
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import is_strictly_contiguous from vllm.utils.torch_utils import is_quantized_kv_cache, is_strictly_contiguous
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -602,7 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -602,7 +602,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self.page_size = self.kv_cache_spec.block_size self.page_size = self.kv_cache_spec.block_size
self.cache_dtype = self.cache_config.cache_dtype self.cache_dtype = self.cache_config.cache_dtype
if self.cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.cache_dtype):
self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( self.kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.cache_dtype self.cache_dtype
) )
...@@ -1269,7 +1269,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1269,7 +1269,7 @@ class FlashInferImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey): def fused_output_quant_supported(self, quant_key: QuantKey):
return ( return (
self.support_trtllm_attn self.support_trtllm_attn
and self.kv_cache_dtype.startswith("fp8") and is_quantized_kv_cache(self.kv_cache_dtype)
and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic) and quant_key in (kFp8StaticTensorSym, kNvfp4Dynamic)
) )
...@@ -1317,12 +1317,12 @@ class FlashInferImpl(AttentionImpl): ...@@ -1317,12 +1317,12 @@ class FlashInferImpl(AttentionImpl):
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = 1.0 self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._v_scale_float self.bmm2_scale *= layer._v_scale_float
prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill) prefill_use_trtllm = isinstance(attn_metadata.prefill, TRTLLMPrefill)
...@@ -1375,8 +1375,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -1375,8 +1375,8 @@ class FlashInferImpl(AttentionImpl):
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8 # to process the cache when the kv_cache_dtype is fp8
if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith( if self.kv_sharing_target_layer_name is None and is_quantized_kv_cache(
"fp8" self.kv_cache_dtype
): ):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype self.kv_cache_dtype
...@@ -1486,9 +1486,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -1486,9 +1486,8 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None assert self.o_sf_scale is None
out = output[num_decode_tokens:] out = output[num_decode_tokens:]
if ( if attn_metadata.q_data_type != FP8_DTYPE and is_quantized_kv_cache(
attn_metadata.q_data_type != FP8_DTYPE self.kv_cache_dtype
and self.kv_cache_dtype.startswith("fp8")
): ):
# TRTLLM prefill attention does not support BF16 Q # TRTLLM prefill attention does not support BF16 Q
# and fp8 kv cache. So to enable prefill attention # and fp8 kv cache. So to enable prefill attention
......
...@@ -27,14 +27,13 @@ from vllm.config.cache import CacheDType ...@@ -27,14 +27,13 @@ from vllm.config.cache import CacheDType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
is_quantized_kv_cache,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
......
...@@ -17,12 +17,12 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -17,12 +17,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -20,12 +20,12 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -20,12 +20,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla, flash_attn_supports_mla,
...@@ -319,7 +319,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): ...@@ -319,7 +319,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
) )
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError("FP8 FlashAttention MLA not yet supported") raise NotImplementedError("FP8 FlashAttention MLA not yet supported")
kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank] kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
......
...@@ -16,12 +16,12 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -16,12 +16,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport, QueryLenSupport,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
from vllm.v1.attention.backends.utils import KVCacheLayoutType from vllm.v1.attention.backends.utils import KVCacheLayoutType
...@@ -184,12 +184,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -184,12 +184,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = 1.0 self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._k_scale_float self.bmm2_scale *= layer._k_scale_float
# Reuse pre-allocated zero-init output buffer to avoid a memset # Reuse pre-allocated zero-init output buffer to avoid a memset
......
...@@ -26,6 +26,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -26,6 +26,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims, get_mla_dims,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -341,11 +342,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata ...@@ -341,11 +342,11 @@ class FlashInferMLASparseImpl(SparseMLAAttentionImpl[FlashInferMLASparseMetadata
if self.bmm1_scale is None: if self.bmm1_scale is None:
self.bmm1_scale = self.scale self.bmm1_scale = self.scale
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float self.bmm1_scale *= layer._q_scale_float * layer._k_scale_float
if self.bmm2_scale is None: if self.bmm2_scale is None:
self.bmm2_scale = 1.0 self.bmm2_scale = 1.0
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
self.bmm2_scale *= layer._k_scale_float self.bmm2_scale *= layer._k_scale_float
o = trtllm_batch_decode_with_kv_cache_mla( o = trtllm_batch_decode_with_kv_cache_mla(
......
...@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -20,6 +20,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
...@@ -128,7 +129,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): ...@@ -128,7 +129,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.cg_buf_tile_scheduler_metadata = None self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8") self.is_fp8_kvcache = is_quantized_kv_cache(
vllm_config.cache_config.cache_dtype
)
num_sms = num_compute_units(self.device.index) num_sms = num_compute_units(self.device.index)
...@@ -269,7 +272,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -269,7 +272,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = reshape_query_for_spec_decode(q, num_decodes) q = reshape_query_for_spec_decode(q, num_decodes)
scheduler_metadata = attn_metadata.decode.scheduler_metadata scheduler_metadata = attn_metadata.decode.scheduler_metadata
if envs.VLLM_BATCH_INVARIANT and not self.kv_cache_dtype.startswith("fp8"): if envs.VLLM_BATCH_INVARIANT and not is_quantized_kv_cache(self.kv_cache_dtype):
device = q.device device = q.device
dtype = torch.int32 dtype = torch.int32
...@@ -299,7 +302,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -299,7 +302,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits scheduler_metadata.num_splits = num_splits
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
o, lse = flash_mla_with_kvcache_fp8( o, lse = flash_mla_with_kvcache_fp8(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -571,7 +572,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]): ...@@ -571,7 +572,7 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
max_tokens = vllm_config.scheduler_config.max_num_batched_tokens max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
q_concat_shape = (max_tokens, num_heads, head_size) q_concat_shape = (max_tokens, num_heads, head_size)
if kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(kv_cache_dtype):
assert kv_cache_dtype == "fp8_ds_mla", ( assert kv_cache_dtype == "fp8_ds_mla", (
"FlashMLA Sparse Attention backend fp8 only supports " "FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype" "fp8_ds_mla kv-cache dtype"
......
...@@ -14,11 +14,11 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -14,11 +14,11 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadata, MLACommonMetadata,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionLayer, AttentionLayer,
AttentionType, AttentionType,
MultipleOf, MultipleOf,
is_quantized_kv_cache,
) )
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment