Unverified Commit 77431529 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Refactor `check_and_update_config` (#33600)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent ab33d2a6
...@@ -19,7 +19,6 @@ else: ...@@ -19,7 +19,6 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal[ CacheDType = Literal[
"auto", "auto",
"bfloat16", "bfloat16",
...@@ -39,13 +38,11 @@ KVOffloadingBackend = Literal["native", "lmcache"] ...@@ -39,13 +38,11 @@ KVOffloadingBackend = Literal["native", "lmcache"]
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache.""" """Configuration for the KV cache."""
block_size: SkipValidation[BlockSize] = None # type: ignore[assignment] block_size: SkipValidation[int] = None # type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens. On CUDA devices, """Size of a contiguous cache block in number of tokens.
only block sizes up to 32 are supported.
This config has no static default. If left unspecified by the user, it will This is None until `Platform.check_and_update_config()` sets it based on
be set in `Platform.check_and_update_config()` based on the current the current platform. Always an int by the time the engine starts."""
platform."""
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can """The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
......
...@@ -59,7 +59,6 @@ from vllm.config import ( ...@@ -59,7 +59,6 @@ from vllm.config import (
get_attr_docs, get_attr_docs,
) )
from vllm.config.cache import ( from vllm.config.cache import (
BlockSize,
CacheDType, CacheDType,
KVOffloadingBackend, KVOffloadingBackend,
MambaCacheMode, MambaCacheMode,
...@@ -431,7 +430,7 @@ class EngineArgs: ...@@ -431,7 +430,7 @@ class EngineArgs:
max_parallel_loading_workers: int | None = ( max_parallel_loading_workers: int | None = (
ParallelConfig.max_parallel_loading_workers ParallelConfig.max_parallel_loading_workers
) )
block_size: BlockSize = CacheConfig.block_size block_size: int = None # type: ignore[assignment]
enable_prefix_caching: bool | None = None enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = ( prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo CacheConfig.prefix_caching_hash_algo
......
...@@ -163,8 +163,6 @@ class CudaPlatformBase(Platform): ...@@ -163,8 +163,6 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.v1.attention.backends.registry import AttentionBackendEnum
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
...@@ -172,112 +170,19 @@ class CudaPlatformBase(Platform): ...@@ -172,112 +170,19 @@ class CudaPlatformBase(Platform):
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None: user_specified_block_size = cache_config.block_size is not None
if not user_specified_block_size:
cache_config.block_size = 16 cache_config.block_size = 16
# TODO(lucas): handle this more gracefully # Ensure block_size is compatible with the attention backend.
# Note: model_config may be None during testing # Note: model_config may be None during testing.
# Note: block_size is initialized in # Skip hybrid (attention+mamba) models — their block_size is
# HybridAttentionMambaModelConfig.verify_and_update_config # managed by HybridAttentionMambaModelConfig
# for models with both attention and mamba, if model_config is not None and not model_config.is_hybrid:
# and doesn't need to be reinitialized here cls._update_block_size_for_backend(
if ( vllm_config,
model_config is not None user_specified_block_size,
and model_config.use_mla )
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
use_flashmla = False
use_cutlass_mla = False
use_flashinfer_mla = False
use_flashmla_sparse = False
use_flashinfer_mla_sparse = False
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
if vllm_config.attention_config.backend is None:
# Default case
hf_text_config = model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if (
cls.is_device_capability_family(100)
and not use_sparse
and qk_nope_head_dim == 128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla = True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.FLASHINFER_MLA
)
elif cls.is_device_capability_family(100) and not use_sparse:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla = True
elif is_flashmla_dense_supported()[0]:
# Non-Blackwell with FlashMLA support
use_flashmla = True
else:
# Fallback: will use Triton MLA or other compatible backend
pass
else:
# Forced case
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
use_flashinfer_mla_sparse = (
backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
)
if (
use_flashmla
and is_flashmla_dense_supported()[0]
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
if use_cutlass_mla and cache_config.block_size % 128 != 0:
cache_config.block_size = 128
logger.info(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
)
if (
use_flashinfer_mla
and cache_config.block_size != 32
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
)
if use_sparse:
if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
use_flashmla_sparse = True
if use_flashmla_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
elif use_flashinfer_mla_sparse and cache_config.block_size not in (
32,
64,
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLASparse "
"backend."
)
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing # Note: model_config may be None during testing
...@@ -293,6 +198,150 @@ class CudaPlatformBase(Platform): ...@@ -293,6 +198,150 @@ class CudaPlatformBase(Platform):
) )
scheduler_config.disable_chunked_mm_input = True scheduler_config.disable_chunked_mm_input = True
@classmethod
def _update_block_size_for_backend(
cls,
vllm_config: "VllmConfig",
user_specified_block_size: bool,
) -> None:
"""Ensure block_size is compatible with the attention backend.
If the user specified --block-size, the selector validates/filters
backends by that block size (raising on incompatibility). Otherwise,
the backend is selected unconstrained and block_size is set to the
backend's preferred value.
"""
from vllm.config.vllm import set_current_vllm_config
from vllm.v1.attention.selector import AttentionSelectorConfig
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
device_capability = cls.get_device_capability()
if device_capability is None:
return
use_mla = model_config.use_mla
attn_selector_config = AttentionSelectorConfig(
head_size=model_config.get_head_size(),
dtype=model_config.dtype, # type: ignore[arg-type]
kv_cache_dtype=cache_config.cache_dtype,
block_size=cache_config.block_size if user_specified_block_size else None,
use_mla=use_mla,
has_sink=False,
use_sparse=use_mla and hasattr(model_config.hf_config, "index_topk"),
use_mm_prefix=model_config.is_mm_prefix_lm,
)
user_specified_backend = vllm_config.attention_config.backend
num_heads = model_config.get_num_attention_heads(
vllm_config.parallel_config,
)
with set_current_vllm_config(vllm_config):
chosen_backend = cls.select_attention_backend(
selected_backend=user_specified_backend,
attn_selector_config=attn_selector_config,
device_capability=device_capability,
# Don't raise here — we produce better errors below.
raise_on_invalid=False,
num_heads=num_heads,
)
# If the user's --block-size forced a non-optimal backend,
# warn them. Only relevant when the user didn't also specify
# --attention-backend (in which case the choice is explicit).
if (
chosen_backend is not None
and user_specified_block_size
and user_specified_backend is None
):
optimal = cls.select_attention_backend(
selected_backend=None,
attn_selector_config=attn_selector_config._replace(
block_size=None,
),
device_capability=device_capability,
raise_on_invalid=False,
num_heads=num_heads,
)
if optimal is not None and optimal != chosen_backend:
logger.warning(
"--block-size %d is not supported by the preferred "
"%s backend. Using %s instead, which may result "
"in reduced performance. Consider removing "
"--block-size to auto-select the optimal "
"block size.",
cache_config.block_size,
optimal.name,
chosen_backend.name,
)
if chosen_backend is not None:
if user_specified_block_size:
# User's block_size is compatible with the chosen
# backend.
return
# User didn't specify --block-size, so auto-select the
# preferred block size for the chosen backend.
try:
backend_class = chosen_backend.get_class()
except ImportError:
return # Will fail later with a better error
preferred = backend_class.get_preferred_block_size(
cache_config.block_size,
)
if cache_config.block_size != preferred:
logger.info(
"Setting kv cache block size to %d for %s backend.",
preferred,
chosen_backend.name,
)
cache_config.block_size = preferred
return
# No valid backend found. If the user didn't constrain the
# selection, defer the error to get_attn_backend_cls where
# the full config (including per-layer settings) is
# available.
if not user_specified_block_size:
return
if user_specified_backend is not None:
# User specified --block-size and --attention-backend
# and they are incompatible.
try:
backend_class = user_specified_backend.get_class()
supported = backend_class.get_supported_kernel_block_sizes()
except ImportError:
supported = None
raise ValueError(
f"User-specified --block-size "
f"{cache_config.block_size} is incompatible with "
f"the specified --attention-backend "
f"{user_specified_backend.name} (supported kernel "
f"block sizes: {supported}). Either remove "
f"--block-size to auto-select, or choose a "
f"compatible value."
)
else:
# User specified --block-size but no backend supports
# it.
_, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
reasons_str = ", ".join(
f"{b.name}: [{', '.join(r)}]" for b, r in invalid_reasons.items()
)
raise ValueError(
f"No valid attention backend found for "
f"--block-size {cache_config.block_size}. "
f"Reasons: {{{reasons_str}}}. Either remove "
f"--block-size to auto-select, or choose a "
f"compatible value."
)
@classmethod @classmethod
def get_current_memory_usage( def get_current_memory_usage(
cls, device: torch.types.Device | None = None cls, device: torch.types.Device | None = None
...@@ -336,77 +385,125 @@ class CudaPlatformBase(Platform): ...@@ -336,77 +385,125 @@ class CudaPlatformBase(Platform):
return valid_backends_priorities, invalid_reasons return valid_backends_priorities, invalid_reasons
@classmethod @classmethod
def get_attn_backend_cls( def select_attention_backend(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig", attn_selector_config: "AttentionSelectorConfig",
device_capability: "DeviceCapability",
raise_on_invalid: bool = True,
num_heads: int | None = None, num_heads: int | None = None,
) -> str: ) -> "AttentionBackendEnum | None":
device_capability = cls.get_device_capability() """Select the best attention backend for the given configuration.
assert device_capability is not None
Args:
attn_selector_config = attn_selector_config._replace(block_size=None) selected_backend: User-specified backend, or None for auto-selection
attn_selector_config: Configuration for attention selection
device_capability: Device capability info
raise_on_invalid: If True, raise ValueError when no valid backend
num_heads: Number of attention heads per GPU, used for backend
priority ordering on Blackwell GPUs
Returns:
The selected backend enum, or None if no valid backend found
and raise_on_invalid is False
"""
# First try checking just the selected backend, if there is one. # First try checking just the selected backend, if there is one.
if selected_backend is not None: if selected_backend is not None:
try: try:
backend_class = selected_backend.get_class() backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration( validation_errors = backend_class.validate_configuration(
device_capability=device_capability, device_capability=device_capability,
**attn_selector_config._asdict(), **attn_selector_config._asdict(),
) )
except ImportError: except ImportError:
invalid_reasons = ["ImportError"] validation_errors = ["ImportError"]
if invalid_reasons: if validation_errors:
raise ValueError( if raise_on_invalid:
f"Selected backend {selected_backend} is not valid for " raise ValueError(
f"this configuration. Reason: {invalid_reasons}" f"Selected backend {selected_backend} is not valid for "
) f"this configuration. Reason: {validation_errors}"
else: )
logger.info("Using %s backend.", selected_backend) return None
return selected_backend.get_path() return selected_backend
# No selected backend or the selected backend is invalid, # No selected backend, so find the best valid one.
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends( valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability, device_capability=device_capability,
attn_selector_config=attn_selector_config, attn_selector_config=attn_selector_config,
num_heads=num_heads, num_heads=num_heads,
) )
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
if len(valid_backends_priorities) == 0: if len(valid_backends_priorities) == 0:
raise ValueError( if raise_on_invalid:
f"No valid attention backend found for {cls.device_name} " reasons_str = (
f"with {config_str}. Reasons: {reasons_str}." "{"
) + ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
raise ValueError(
f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
)
return None
# We have found some valid backends. Select the one with the # Select the one with the highest priority (lowest index).
# highest priority. sorted_backends = sorted(valid_backends_priorities, key=lambda x: x[1])
sorted_indices = sorted( return sorted_backends[0][0]
range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1], @classmethod
) def get_attn_backend_cls(
selected_index = sorted_indices[0] cls,
selected_backend = valid_backends_priorities[selected_index][0] selected_backend: "AttentionBackendEnum | None",
logger.info_once( attn_selector_config: "AttentionSelectorConfig",
"Using %s attention backend out of potential backends: %s.", num_heads: int | None = None,
selected_backend.name, ) -> str:
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]", device_capability = cls.get_device_capability()
scope="local", assert device_capability is not None
chosen_backend = cls.select_attention_backend(
selected_backend=selected_backend,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
device_capability=device_capability,
raise_on_invalid=True,
) )
assert chosen_backend is not None # raise_on_invalid=True guarantees this
# Log the selection
if selected_backend is not None:
logger.info("Using %s backend.", chosen_backend)
else:
# Get all valid backends for logging
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
logger.info_once(
"Using %s attention backend out of potential backends: %s",
chosen_backend.name,
tuple(b[0].name for b in valid_backends_priorities),
scope="local",
)
return selected_backend.get_path() return chosen_backend.get_path()
@classmethod @classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
import numpy as np import numpy as np
import torch import torch
...@@ -144,15 +144,9 @@ class AttentionBackend(ABC): ...@@ -144,15 +144,9 @@ class AttentionBackend(ABC):
@classmethod @classmethod
def supports_block_size(cls, block_size: int | None) -> bool: def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None: if block_size is None:
return True return True
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes() supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
if not supported_kernel_block_sizes: if not supported_kernel_block_sizes:
return True return True
...@@ -167,6 +161,17 @@ class AttentionBackend(ABC): ...@@ -167,6 +161,17 @@ class AttentionBackend(ABC):
return True return True
return False return False
@classmethod
def get_preferred_block_size(cls, default_block_size: int = 16) -> int:
supported_sizes = cls.get_supported_kernel_block_sizes()
if not supported_sizes:
return default_block_size
if cls.supports_block_size(default_block_size):
return default_block_size
return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)
@classmethod @classmethod
def is_mla(cls) -> bool: def is_mla(cls) -> bool:
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment