Unverified Commit 955b43a5 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Bugfix][Attention] Explicitly report support for kv_cache_dtype bfloat16 (#32795)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 744ef304
...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backend import is_quantized_kv_cache
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -52,11 +53,14 @@ class BaseKVCacheMethod(QuantizeMethodBase): ...@@ -52,11 +53,14 @@ class BaseKVCacheMethod(QuantizeMethodBase):
assert not hasattr(layer, "prob_scale") assert not hasattr(layer, "prob_scale")
return return
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # If the kv-cache is not quantized, we enforce the k/v_scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint. # regardless whether the kv-scale is available in the checkpoint.
# No need to process kv scales after loading if we are going to # No need to process kv scales after loading if we are going to
# calculate them on the fly. # calculate them on the fly.
if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: if (
is_quantized_kv_cache(layer.kv_cache_dtype)
and not layer.calculate_kv_scales
):
if layer.k_scale > 0.0 and layer.v_scale > 0.0: if layer.k_scale > 0.0 and layer.v_scale > 0.0:
# We prefer to use separate k_scale and v_scale if present # We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist() k_scale = layer.k_scale.to("cpu").tolist()
......
...@@ -16,6 +16,7 @@ import torch ...@@ -16,6 +16,7 @@ import torch
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backend import is_quantized_kv_cache
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from .interface import CpuArchEnum, Platform, PlatformEnum from .interface import CpuArchEnum, Platform, PlatformEnum
...@@ -198,13 +199,13 @@ class CpuPlatform(Platform): ...@@ -198,13 +199,13 @@ class CpuPlatform(Platform):
if ( if (
scheduler_config.enable_chunked_prefill scheduler_config.enable_chunked_prefill
or cache_config.enable_prefix_caching or cache_config.enable_prefix_caching
) and cache_config.cache_dtype != "auto": ) and is_quantized_kv_cache(cache_config.cache_dtype):
raise RuntimeError( raise RuntimeError(
"Chunked-prefill and prefix-cache on the CPU " "Chunked-prefill and prefix-cache on the CPU "
"backend is not compatible with FP8 KV cache." "backend is not compatible with FP8 KV cache."
) )
if cache_config.cache_dtype != "auto": if cache_config.cache_dtype.startswith("fp8"):
logger.warning( logger.warning(
"CPU backend doesn't support KV cache quantization fallback to auto." "CPU backend doesn't support KV cache quantization fallback to auto."
) )
......
...@@ -51,7 +51,7 @@ class AttentionBackend(ABC): ...@@ -51,7 +51,7 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph. # makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False accept_output_buffer: bool = False
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"] supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
...@@ -747,7 +747,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]): ...@@ -747,7 +747,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
def is_quantized_kv_cache(kv_cache_dtype: str) -> bool: def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype != "auto" return kv_cache_dtype.startswith("fp8")
def subclass_attention_backend( def subclass_attention_backend(
......
...@@ -151,7 +151,7 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -151,7 +151,7 @@ class FlashAttentionBackend(AttentionBackend):
return True return True
if kv_cache_dtype.startswith("fp8"): if kv_cache_dtype.startswith("fp8"):
return flash_attn_supports_fp8() return flash_attn_supports_fp8()
return kv_cache_dtype in ["auto"] return kv_cache_dtype in ["auto", "bfloat16"]
@classmethod @classmethod
def supports_sink(cls) -> bool: def supports_sink(cls) -> bool:
......
...@@ -281,6 +281,7 @@ class FlashInferBackend(AttentionBackend): ...@@ -281,6 +281,7 @@ class FlashInferBackend(AttentionBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
"fp8_e5m2", "fp8_e5m2",
......
...@@ -80,7 +80,7 @@ class FlexAttentionBackend(AttentionBackend): ...@@ -80,7 +80,7 @@ class FlexAttentionBackend(AttentionBackend):
torch.bfloat16, torch.bfloat16,
torch.float32, torch.float32,
] ]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
......
...@@ -38,6 +38,7 @@ class CutlassMLABackend(MLACommonBackend): ...@@ -38,6 +38,7 @@ class CutlassMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
] ]
......
...@@ -43,7 +43,10 @@ logger = init_logger(__name__) ...@@ -43,7 +43,10 @@ logger = init_logger(__name__)
class FlashAttnMLABackend(MLACommonBackend): class FlashAttnMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
......
...@@ -38,6 +38,7 @@ class FlashInferMLABackend(MLACommonBackend): ...@@ -38,6 +38,7 @@ class FlashInferMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
] ]
......
...@@ -48,6 +48,7 @@ class FlashMLABackend(MLACommonBackend): ...@@ -48,6 +48,7 @@ class FlashMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
] ]
......
...@@ -76,7 +76,11 @@ structured as: ...@@ -76,7 +76,11 @@ structured as:
class FlashMLASparseBackend(AttentionBackend): class FlashMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8_ds_mla",
]
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
......
...@@ -28,7 +28,10 @@ logger = init_logger(__name__) ...@@ -28,7 +28,10 @@ logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend): class TritonMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"] supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
......
...@@ -259,6 +259,7 @@ class TritonAttentionBackend(AttentionBackend): ...@@ -259,6 +259,7 @@ class TritonAttentionBackend(AttentionBackend):
] ]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto", "auto",
"bfloat16",
"fp8", "fp8",
"fp8_e4m3", "fp8_e4m3",
"fp8_e5m2", "fp8_e5m2",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment