Unverified Commit 77431529 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[Attention] Refactor `check_and_update_config` (#33600)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent ab33d2a6
......@@ -19,7 +19,6 @@ else:
logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal[
"auto",
"bfloat16",
......@@ -39,13 +38,11 @@ KVOffloadingBackend = Literal["native", "lmcache"]
class CacheConfig:
"""Configuration for the KV cache."""
block_size: SkipValidation[BlockSize] = None # type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens. On CUDA devices,
only block sizes up to 32 are supported.
block_size: SkipValidation[int] = None # type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens.
This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_config()` based on the current
platform."""
This is None until `Platform.check_and_update_config()` sets it based on
the current platform. Always an int by the time the engine starts."""
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
......
......@@ -59,7 +59,6 @@ from vllm.config import (
get_attr_docs,
)
from vllm.config.cache import (
BlockSize,
CacheDType,
KVOffloadingBackend,
MambaCacheMode,
......@@ -431,7 +430,7 @@ class EngineArgs:
max_parallel_loading_workers: int | None = (
ParallelConfig.max_parallel_loading_workers
)
block_size: BlockSize = CacheConfig.block_size
block_size: int = None # type: ignore[assignment]
enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo
......
......@@ -163,8 +163,6 @@ class CudaPlatformBase(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.v1.attention.backends.registry import AttentionBackendEnum
parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config
......@@ -172,112 +170,19 @@ class CudaPlatformBase(Platform):
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
user_specified_block_size = cache_config.block_size is not None
if not user_specified_block_size:
cache_config.block_size = 16
# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing
# Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
# for models with both attention and mamba,
# and doesn't need to be reinitialized here
if (
model_config is not None
and model_config.use_mla
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
use_flashmla = False
use_cutlass_mla = False
use_flashinfer_mla = False
use_flashmla_sparse = False
use_flashinfer_mla_sparse = False
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
if vllm_config.attention_config.backend is None:
# Default case
hf_text_config = model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if (
cls.is_device_capability_family(100)
and not use_sparse
and qk_nope_head_dim == 128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla = True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.FLASHINFER_MLA
)
elif cls.is_device_capability_family(100) and not use_sparse:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla = True
elif is_flashmla_dense_supported()[0]:
# Non-Blackwell with FlashMLA support
use_flashmla = True
else:
# Fallback: will use Triton MLA or other compatible backend
pass
else:
# Forced case
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
use_flashinfer_mla_sparse = (
backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
)
if (
use_flashmla
and is_flashmla_dense_supported()[0]
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
if use_cutlass_mla and cache_config.block_size % 128 != 0:
cache_config.block_size = 128
logger.info(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
)
if (
use_flashinfer_mla
and cache_config.block_size != 32
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
)
if use_sparse:
if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
use_flashmla_sparse = True
if use_flashmla_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
elif use_flashinfer_mla_sparse and cache_config.block_size not in (
32,
64,
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLASparse "
"backend."
)
# Ensure block_size is compatible with the attention backend.
# Note: model_config may be None during testing.
# Skip hybrid (attention+mamba) models — their block_size is
# managed by HybridAttentionMambaModelConfig
if model_config is not None and not model_config.is_hybrid:
cls._update_block_size_for_backend(
vllm_config,
user_specified_block_size,
)
scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
......@@ -293,6 +198,150 @@ class CudaPlatformBase(Platform):
)
scheduler_config.disable_chunked_mm_input = True
@classmethod
def _update_block_size_for_backend(
cls,
vllm_config: "VllmConfig",
user_specified_block_size: bool,
) -> None:
"""Ensure block_size is compatible with the attention backend.
If the user specified --block-size, the selector validates/filters
backends by that block size (raising on incompatibility). Otherwise,
the backend is selected unconstrained and block_size is set to the
backend's preferred value.
"""
from vllm.config.vllm import set_current_vllm_config
from vllm.v1.attention.selector import AttentionSelectorConfig
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
device_capability = cls.get_device_capability()
if device_capability is None:
return
use_mla = model_config.use_mla
attn_selector_config = AttentionSelectorConfig(
head_size=model_config.get_head_size(),
dtype=model_config.dtype, # type: ignore[arg-type]
kv_cache_dtype=cache_config.cache_dtype,
block_size=cache_config.block_size if user_specified_block_size else None,
use_mla=use_mla,
has_sink=False,
use_sparse=use_mla and hasattr(model_config.hf_config, "index_topk"),
use_mm_prefix=model_config.is_mm_prefix_lm,
)
user_specified_backend = vllm_config.attention_config.backend
num_heads = model_config.get_num_attention_heads(
vllm_config.parallel_config,
)
with set_current_vllm_config(vllm_config):
chosen_backend = cls.select_attention_backend(
selected_backend=user_specified_backend,
attn_selector_config=attn_selector_config,
device_capability=device_capability,
# Don't raise here — we produce better errors below.
raise_on_invalid=False,
num_heads=num_heads,
)
# If the user's --block-size forced a non-optimal backend,
# warn them. Only relevant when the user didn't also specify
# --attention-backend (in which case the choice is explicit).
if (
chosen_backend is not None
and user_specified_block_size
and user_specified_backend is None
):
optimal = cls.select_attention_backend(
selected_backend=None,
attn_selector_config=attn_selector_config._replace(
block_size=None,
),
device_capability=device_capability,
raise_on_invalid=False,
num_heads=num_heads,
)
if optimal is not None and optimal != chosen_backend:
logger.warning(
"--block-size %d is not supported by the preferred "
"%s backend. Using %s instead, which may result "
"in reduced performance. Consider removing "
"--block-size to auto-select the optimal "
"block size.",
cache_config.block_size,
optimal.name,
chosen_backend.name,
)
if chosen_backend is not None:
if user_specified_block_size:
# User's block_size is compatible with the chosen
# backend.
return
# User didn't specify --block-size, so auto-select the
# preferred block size for the chosen backend.
try:
backend_class = chosen_backend.get_class()
except ImportError:
return # Will fail later with a better error
preferred = backend_class.get_preferred_block_size(
cache_config.block_size,
)
if cache_config.block_size != preferred:
logger.info(
"Setting kv cache block size to %d for %s backend.",
preferred,
chosen_backend.name,
)
cache_config.block_size = preferred
return
# No valid backend found. If the user didn't constrain the
# selection, defer the error to get_attn_backend_cls where
# the full config (including per-layer settings) is
# available.
if not user_specified_block_size:
return
if user_specified_backend is not None:
# User specified --block-size and --attention-backend
# and they are incompatible.
try:
backend_class = user_specified_backend.get_class()
supported = backend_class.get_supported_kernel_block_sizes()
except ImportError:
supported = None
raise ValueError(
f"User-specified --block-size "
f"{cache_config.block_size} is incompatible with "
f"the specified --attention-backend "
f"{user_specified_backend.name} (supported kernel "
f"block sizes: {supported}). Either remove "
f"--block-size to auto-select, or choose a "
f"compatible value."
)
else:
# User specified --block-size but no backend supports
# it.
_, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
reasons_str = ", ".join(
f"{b.name}: [{', '.join(r)}]" for b, r in invalid_reasons.items()
)
raise ValueError(
f"No valid attention backend found for "
f"--block-size {cache_config.block_size}. "
f"Reasons: {{{reasons_str}}}. Either remove "
f"--block-size to auto-select, or choose a "
f"compatible value."
)
@classmethod
def get_current_memory_usage(
cls, device: torch.types.Device | None = None
......@@ -336,77 +385,125 @@ class CudaPlatformBase(Platform):
return valid_backends_priorities, invalid_reasons
@classmethod
def get_attn_backend_cls(
def select_attention_backend(
cls,
selected_backend: "AttentionBackendEnum",
selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig",
device_capability: "DeviceCapability",
raise_on_invalid: bool = True,
num_heads: int | None = None,
) -> str:
device_capability = cls.get_device_capability()
assert device_capability is not None
attn_selector_config = attn_selector_config._replace(block_size=None)
) -> "AttentionBackendEnum | None":
"""Select the best attention backend for the given configuration.
Args:
selected_backend: User-specified backend, or None for auto-selection
attn_selector_config: Configuration for attention selection
device_capability: Device capability info
raise_on_invalid: If True, raise ValueError when no valid backend
num_heads: Number of attention heads per GPU, used for backend
priority ordering on Blackwell GPUs
Returns:
The selected backend enum, or None if no valid backend found
and raise_on_invalid is False
"""
# First try checking just the selected backend, if there is one.
if selected_backend is not None:
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
validation_errors = backend_class.validate_configuration(
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons = ["ImportError"]
if invalid_reasons:
raise ValueError(
f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {invalid_reasons}"
)
else:
logger.info("Using %s backend.", selected_backend)
return selected_backend.get_path()
validation_errors = ["ImportError"]
if validation_errors:
if raise_on_invalid:
raise ValueError(
f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {validation_errors}"
)
return None
return selected_backend
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
# No selected backend, so find the best valid one.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
if len(valid_backends_priorities) == 0:
raise ValueError(
f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
)
if raise_on_invalid:
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
raise ValueError(
f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
)
return None
# We have found some valid backends. Select the one with the
# highest priority.
sorted_indices = sorted(
range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1],
)
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info_once(
"Using %s attention backend out of potential backends: %s.",
selected_backend.name,
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
scope="local",
# Select the one with the highest priority (lowest index).
sorted_backends = sorted(valid_backends_priorities, key=lambda x: x[1])
return sorted_backends[0][0]
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None,
) -> str:
device_capability = cls.get_device_capability()
assert device_capability is not None
chosen_backend = cls.select_attention_backend(
selected_backend=selected_backend,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
device_capability=device_capability,
raise_on_invalid=True,
)
assert chosen_backend is not None # raise_on_invalid=True guarantees this
# Log the selection
if selected_backend is not None:
logger.info("Using %s backend.", chosen_backend)
else:
# Get all valid backends for logging
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
logger.info_once(
"Using %s attention backend out of potential backends: %s",
chosen_backend.name,
tuple(b[0].name for b in valid_backends_priorities),
scope="local",
)
return selected_backend.get_path()
return chosen_backend.get_path()
@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
......
......@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
import numpy as np
import torch
......@@ -144,15 +144,9 @@ class AttentionBackend(ABC):
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None:
return True
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
if not supported_kernel_block_sizes:
return True
......@@ -167,6 +161,17 @@ class AttentionBackend(ABC):
return True
return False
@classmethod
def get_preferred_block_size(cls, default_block_size: int = 16) -> int:
supported_sizes = cls.get_supported_kernel_block_sizes()
if not supported_sizes:
return default_block_size
if cls.supports_block_size(default_block_size):
return default_block_size
return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)
@classmethod
def is_mla(cls) -> bool:
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment