"vllm/vscode:/vscode.git/clone" did not exist on "20e4497be23f8e74882bfb0bd0db3d30dd821afc"
Unverified Commit 77a73458 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

Reapply [Attention] Refactor `check_and_update_config` (#35122)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 5578f2a4
......@@ -6,7 +6,12 @@ from unittest.mock import patch
import pytest
import torch
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
from vllm.config import (
AttentionConfig,
CacheConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
......@@ -84,12 +89,15 @@ def test_backend_selection(
"""Test attention backend selection with valid device-backend pairs."""
# Create AttentionConfig with the specified backend
attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
vllm_config = VllmConfig(attention_config=attention_config)
cache_config = CacheConfig(block_size=block_size)
vllm_config = VllmConfig(
attention_config=attention_config, cache_config=cache_config
)
with set_current_vllm_config(vllm_config):
if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size)
backend = get_attn_backend(16, torch.float16, None)
assert backend.get_name() == "CPU_ATTN"
elif device == "hip":
......@@ -104,20 +112,16 @@ def test_backend_selection(
if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError):
get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
)
get_attn_backend(576, torch.float16, None, use_mla=use_mla)
else:
# Valid backend-block_size combination
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, use_mla=use_mla
)
expected = name
assert backend.get_name() == expected
else:
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla
)
backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
expected = "ROCM_ATTN"
assert backend.get_name() == expected
......@@ -141,7 +145,7 @@ def test_backend_selection(
if capability[0] != 10:
pytest.skip("CUTLASS MLA is not supported on this platform")
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, use_mla=use_mla
)
expected = "CUTLASS_MLA"
assert backend.get_name() == expected
......@@ -156,7 +160,7 @@ def test_backend_selection(
"FlashInfer MLA only supports block_size 32 or 64"
)
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, use_mla=use_mla
)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
......@@ -175,7 +179,6 @@ def test_backend_selection(
576,
torch.float16,
None,
block_size,
use_mla=use_mla,
)
expected = name
......@@ -190,27 +193,23 @@ def test_backend_selection(
"FlashAttention MLA not supported on this platform"
)
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, use_mla=use_mla
)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla
576, torch.float16, None, use_mla=use_mla
)
expected = "TRITON_MLA"
assert backend.get_name() == expected
elif name == "FLASHINFER":
backend = get_attn_backend(
64, torch.float16, None, block_size, use_mla=use_mla
)
backend = get_attn_backend(64, torch.float16, None, use_mla=use_mla)
expected = "FLASHINFER"
assert backend.get_name() == expected
elif name == "FLASH_ATTN":
backend = get_attn_backend(
32, torch.float16, None, block_size, use_mla=use_mla
)
backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
expected = "FLASH_ATTN"
assert backend.get_name() == expected
......@@ -224,12 +223,12 @@ def test_fp32_fallback(device: str):
with set_current_vllm_config(vllm_config):
if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16)
backend = get_attn_backend(16, torch.float32, None)
assert backend.get_name() == "CPU_ATTN"
elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16)
backend = get_attn_backend(16, torch.float32, None)
assert backend.get_name() == "FLEX_ATTENTION"
......@@ -241,35 +240,40 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
)
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
vllm_config = VllmConfig(attention_config=attention_config)
cache_config = CacheConfig(block_size=16)
vllm_config = VllmConfig(
attention_config=attention_config, cache_config=cache_config
)
with set_current_vllm_config(vllm_config):
# Unsupported CUDA arch
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16)
backend = get_attn_backend(16, torch.float16, None)
assert backend.get_name() != "FLASH_ATTN"
# Reset the monkeypatch for subsequent tests
monkeypatch.undo()
# Unsupported data type
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
backend = get_attn_backend(16, torch.float8_e4m3fn, None)
assert backend.get_name() != "FLASH_ATTN"
# Unsupported kv cache data type
backend = get_attn_backend(16, torch.float16, "fp8", 16)
backend = get_attn_backend(16, torch.float16, "fp8")
assert backend.get_name() != "FLASH_ATTN"
# Unsupported block size
backend = get_attn_backend(16, torch.float16, None, 8)
vllm_config.cache_config.block_size = 8
backend = get_attn_backend(16, torch.float16, None)
assert backend.get_name() != "FLASH_ATTN"
# flash-attn is not installed
import sys
vllm_config.cache_config.block_size = 16
original_module = sys.modules.get("vllm_flash_attn")
monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
backend = get_attn_backend(16, torch.float16, None, 16)
backend = get_attn_backend(16, torch.float16, None)
assert backend.get_name() != "FLASH_ATTN"
# Restore the original module if it existed
......@@ -279,7 +283,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)
# Unsupported head size
backend = get_attn_backend(17, torch.float16, None, 16)
backend = get_attn_backend(17, torch.float16, None)
assert backend.get_name() != "FLASH_ATTN"
......@@ -320,7 +324,7 @@ def test_auto_backend_selection_behavior():
set_current_vllm_config(vllm_config_auto),
patch("vllm.platforms.current_platform", CpuPlatform()),
):
backend_auto = get_attn_backend(16, torch.float16, None, 16)
backend_auto = get_attn_backend(16, torch.float16, None)
_cached_get_attn_backend.cache_clear()
......@@ -328,7 +332,7 @@ def test_auto_backend_selection_behavior():
set_current_vllm_config(vllm_config_none),
patch("vllm.platforms.current_platform", CpuPlatform()),
):
backend_none = get_attn_backend(16, torch.float16, None, 16)
backend_none = get_attn_backend(16, torch.float16, None)
# Both should select the same backend
assert backend_auto.get_name() == backend_none.get_name()
......@@ -358,7 +362,10 @@ def test_per_head_quant_scales_backend_selection(
backend=AttentionBackendEnum[backend_name],
flash_attn_version=flash_attn_version,
)
vllm_config = VllmConfig(attention_config=attention_config)
cache_config = CacheConfig(block_size=64)
vllm_config = VllmConfig(
attention_config=attention_config, cache_config=cache_config
)
with (
set_current_vllm_config(vllm_config),
......@@ -376,7 +383,6 @@ def test_per_head_quant_scales_backend_selection(
head_size=128,
dtype=torch.float16,
kv_cache_dtype="fp8",
block_size=64,
use_per_head_quant_scales=True,
)
assert backend.get_name() == backend_name
......@@ -386,7 +392,6 @@ def test_per_head_quant_scales_backend_selection(
head_size=128,
dtype=torch.float16,
kv_cache_dtype="fp8",
block_size=64,
use_per_head_quant_scales=True,
)
assert backend_name in str(exc_info.value)
......@@ -13,6 +13,7 @@ import torch.nn as nn
from PIL import Image
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.config.cache import CacheConfig
from vllm.config.multimodal import (
AudioDummyOptions,
BaseDummyOptions,
......@@ -131,7 +132,9 @@ def initialize_dummy_model(
):
temp_file = tempfile.mkstemp()[1]
current_device = torch.get_default_device()
vllm_config = VllmConfig(model_config=model_config)
vllm_config = VllmConfig(
model_config=model_config, cache_config=CacheConfig(block_size=16)
)
with set_current_vllm_config(vllm_config=vllm_config):
init_distributed_environment(
world_size=1,
......
......@@ -80,7 +80,7 @@ def _create_proposer(
device = current_platform.device_type
vllm_config = VllmConfig(
model_config=model_config,
cache_config=CacheConfig(),
cache_config=CacheConfig(block_size=16),
speculative_config=speculative_config,
device_config=DeviceConfig(device=device),
parallel_config=ParallelConfig(),
......
......@@ -2,16 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import field
from typing import Literal
from typing import ClassVar, Literal
from pydantic import Field, SkipValidation, field_validator
from pydantic import Field, SkipValidation, field_validator, model_validator
from vllm.config.utils import config
from vllm.logger import init_logger
logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal[
"auto",
"bfloat16",
......@@ -31,12 +30,13 @@ KVOffloadingBackend = Literal["native", "lmcache"]
class CacheConfig:
"""Configuration for the KV cache."""
block_size: SkipValidation[BlockSize] = None # type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens.
DEFAULT_BLOCK_SIZE: ClassVar[int] = 16
This config has no static default. If left unspecified by the user, it will
be set in `Platform.check_and_update_config()` based on the current
platform."""
block_size: SkipValidation[int] = None # type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens.
Accepts None (meaning "use default"). After construction, always int."""
user_specified_block_size: bool = field(default=False, init=False)
"""Whether block_size was explicitly provided. Derived automatically."""
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
......@@ -169,6 +169,8 @@ class CacheConfig:
"prefix_caching_hash_algo",
"cpu_kvcache_space_bytes",
"mamba_page_size_padded",
"user_specified_block_size",
"_block_size_resolved",
# Post-init/derived counters
"num_gpu_blocks",
"num_cpu_blocks",
......@@ -186,6 +188,22 @@ class CacheConfig:
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}
_block_size_resolved: bool = field(default=False, init=False)
"""Guard against pydantic re-running _apply_block_size_default."""
@model_validator(mode="after")
def _apply_block_size_default(self) -> "CacheConfig":
# Pydantic re-runs validators when CacheConfig is nested inside
# another pydantic model (e.g. VllmConfig). Guard against that.
if self._block_size_resolved:
return self
object.__setattr__(self, "_block_size_resolved", True)
if self.block_size is None:
object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE)
else:
object.__setattr__(self, "user_specified_block_size", True)
return self
@field_validator("cache_dtype", mode="after")
@classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
......
......@@ -1026,32 +1026,6 @@ class VllmConfig:
)
current_platform.check_and_update_config(self)
# If DCP, ensure the block size is right.
if self.parallel_config.decode_context_parallel_size > 1:
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
self.parallel_config.cp_kv_cache_interleave_size
!= self.parallel_config.dcp_kv_cache_interleave_size
):
self.parallel_config.cp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
logger.warning_once(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.cp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be greater "
"than or equal to and divisible by cp_kv_cache_interleave_size "
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
# Do this after all the updates to compilation_config.mode
effective_dp_size = (
self.parallel_config.data_parallel_size
......@@ -1219,26 +1193,6 @@ class VllmConfig:
# Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False
if self.cache_config.mamba_cache_mode == "align":
assert (
self.cache_config.block_size
<= self.scheduler_config.max_num_batched_tokens
), (
"In Mamba cache align mode, block_size "
f"({self.cache_config.block_size}) must be <= "
"max_num_batched_tokens "
f"({self.scheduler_config.max_num_batched_tokens})."
)
if self.scheduler_config.long_prefill_token_threshold > 0:
assert (
self.scheduler_config.long_prefill_token_threshold
>= self.cache_config.block_size
)
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility to "
"schedule a multiple of block_size tokens even if they are in the "
"middle of a mm input"
)
if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = (
self.compilation_config.debug_dump_path.absolute().expanduser()
......@@ -1673,6 +1627,53 @@ class VllmConfig:
f"compilation_config={self.compilation_config!r}"
)
def validate_block_size(self) -> None:
"""Validate block_size against DCP and mamba constraints.
Called after Platform.update_block_size_for_backend() has
finalised block_size.
"""
block_size = self.cache_config.block_size
# DCP interleave-size compatibility
if self.parallel_config.decode_context_parallel_size > 1:
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
self.parallel_config.cp_kv_cache_interleave_size
!= self.parallel_config.dcp_kv_cache_interleave_size
):
self.parallel_config.cp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
logger.warning_once(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size <= block_size
and block_size % self.parallel_config.cp_kv_cache_interleave_size == 0
), (
f"Block_size({block_size}) should be greater "
"than or equal to and divisible by cp_kv_cache_interleave_size "
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
# Mamba cache align-mode constraints
if self.cache_config.mamba_cache_mode == "align":
assert block_size <= self.scheduler_config.max_num_batched_tokens, (
"In Mamba cache align mode, block_size "
f"({block_size}) must be <= "
"max_num_batched_tokens "
f"({self.scheduler_config.max_num_batched_tokens})."
)
if self.scheduler_config.long_prefill_token_threshold > 0:
assert self.scheduler_config.long_prefill_token_threshold >= block_size
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility "
"to schedule a multiple of block_size tokens even if they are "
"in the middle of a mm input"
)
@model_validator(mode="after")
def validate_mamba_block_size(self) -> "VllmConfig":
if self.model_config is None:
......
......@@ -500,7 +500,6 @@ def get_current_attn_backend(vllm_config: VllmConfig):
head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype,
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
block_size=vllm_config.cache_config.block_size,
use_mla=vllm_config.model_config.use_mla,
)
return backend
......@@ -726,7 +726,6 @@ class MoRIIOConnectorWorker:
self.model_config.get_head_size(),
self.model_config.dtype,
self.cache_config.cache_dtype,
self.block_size,
use_mla=self.use_mla,
)
......
......@@ -62,7 +62,6 @@ from vllm.config import (
get_attr_docs,
)
from vllm.config.cache import (
BlockSize,
CacheDType,
KVOffloadingBackend,
MambaCacheMode,
......@@ -440,7 +439,7 @@ class EngineArgs:
max_parallel_loading_workers: int | None = (
ParallelConfig.max_parallel_loading_workers
)
block_size: BlockSize = CacheConfig.block_size
block_size: int | None = None
enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo
......@@ -1521,7 +1520,7 @@ class EngineArgs:
)
cache_config = CacheConfig(
block_size=self.block_size,
block_size=self.block_size, # type: ignore[arg-type]
gpu_memory_utilization=self.gpu_memory_utilization,
kv_cache_memory_bytes=self.kv_cache_memory_bytes,
cache_dtype=resolved_cache_dtype, # type: ignore[arg-type]
......
......@@ -221,11 +221,9 @@ class Attention(nn.Module, AttentionLayerBase):
vllm_config = get_current_vllm_config()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# llm-compressor mdls need to set cache_dtype to "fp8" manually.
......@@ -275,7 +273,6 @@ class Attention(nn.Module, AttentionLayerBase):
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
......
......@@ -30,9 +30,8 @@ from vllm.v1.kv_cache_interface import (
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"
underlying_builder = underlying_attn_backend.get_builder_cls()
assert issubclass(underlying_builder, AttentionMetadataBuilder)
......@@ -55,7 +54,9 @@ def create_chunked_local_attention_backend(
fast_build: bool = False,
):
cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size
attention_chunk_size,
common_attn_metadata,
self.kv_cache_spec.block_size,
)
metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
......@@ -94,16 +95,12 @@ class ChunkedLocalAttention(Attention):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size
underlying_attn_backend, attention_chunk_size
)
super().__init__(
......
......@@ -188,10 +188,8 @@ class CrossAttention(Attention):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
......@@ -202,7 +200,6 @@ class CrossAttention(Attention):
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_DECODER,
)
attn_backend = create_cross_attention_backend(underlying_attn_backend)
......
......@@ -66,16 +66,13 @@ class EncoderOnlyAttention(Attention):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
)
......
......@@ -323,11 +323,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.quant_config = quant_config
......@@ -336,7 +334,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
num_heads=self.num_heads,
......@@ -449,17 +446,24 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
# Attributes for forward_impl method
self.chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
get_current_vllm_config()
)
)
self._vllm_config = get_current_vllm_config()
self._chunked_prefill_workspace_size: int | None = None
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)
@property
def chunked_prefill_workspace_size(self) -> int:
if self._chunked_prefill_workspace_size is None:
self._chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
self._vllm_config
)
)
return self._chunked_prefill_workspace_size
def forward(
self,
q: torch.Tensor,
......
......@@ -126,17 +126,13 @@ class StaticSinkAttention(Attention, CustomOp):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if attn_backend is not None:
underlying_attn_backend = attn_backend
else:
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
attn_backend = create_static_sink_attention_backend(
underlying_attn_backend, # type: ignore[arg-type]
sink_len=sink_len,
......@@ -153,7 +149,6 @@ class StaticSinkAttention(Attention, CustomOp):
CustomOp.__init__(self)
self.sink_len = sink_len
self.block_size = block_size
self.sink_populated = False
self.sink_key = None
self.sink_value = None
......@@ -212,12 +207,12 @@ class StaticSinkAttention(Attention, CustomOp):
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
self.block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
return SinkFullAttentionSpec(
block_size=block_size,
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
head_size_v=self.head_size_v,
......
......@@ -217,10 +217,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
# override attention block size if it is too small,
# even if the user has explicitly set it
if cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "
......
......@@ -290,16 +290,13 @@ class WhisperCausalAttentionWithBlockPooling(Attention):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=attn_type,
)
attn_backend = create_whisper_attention_backend_with_block_pooling(
......
......@@ -185,7 +185,7 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config
if cache_config.block_size is None:
if not cache_config.user_specified_block_size:
cache_config.block_size = 128
if cache_config.block_size % 32 != 0:
......@@ -361,6 +361,12 @@ class CpuPlatform(Platform):
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
# TODO: CPU still sets block_size in check_and_update_config.
# Move that logic here so block_size is chosen by the backend.
pass
@classmethod
def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]:
assert platform.system() == "Linux"
......
......@@ -166,122 +166,12 @@ class CudaPlatformBase(Platform):
@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.v1.attention.backends.registry import AttentionBackendEnum
parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing
# Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
# for models with both attention and mamba,
# and doesn't need to be reinitialized here
if (
model_config is not None
and model_config.use_mla
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
use_flashmla = False
use_cutlass_mla = False
use_flashinfer_mla = False
use_flashmla_sparse = False
use_flashinfer_mla_sparse = False
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
if vllm_config.attention_config.backend is None:
# Default case
hf_text_config = model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if (
cls.is_device_capability_family(100)
and not use_sparse
and qk_nope_head_dim == 128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla = True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.FLASHINFER_MLA
)
elif cls.is_device_capability_family(100) and not use_sparse:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla = True
elif is_flashmla_dense_supported()[0]:
# Non-Blackwell with FlashMLA support
use_flashmla = True
else:
# Fallback: will use Triton MLA or other compatible backend
pass
else:
# Forced case
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
use_flashinfer_mla_sparse = (
backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
)
if (
use_flashmla
and is_flashmla_dense_supported()[0]
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
if use_cutlass_mla and cache_config.block_size % 128 != 0:
cache_config.block_size = 128
logger.info(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
)
if (
use_flashinfer_mla
and cache_config.block_size != 32
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
)
if use_sparse:
if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
use_flashmla_sparse = True
if use_flashmla_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
elif use_flashinfer_mla_sparse and cache_config.block_size not in (
32,
64,
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLASparse "
"backend."
)
scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
if (
......@@ -312,10 +202,10 @@ class CudaPlatformBase(Platform):
num_heads: int | None = None,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
dict["AttentionBackendEnum", tuple[int, list[str]]],
]:
valid_backends_priorities = []
invalid_reasons = {}
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla,
......@@ -332,7 +222,7 @@ class CudaPlatformBase(Platform):
except ImportError:
invalid_reasons_i = ["ImportError"]
if invalid_reasons_i:
invalid_reasons[backend] = invalid_reasons_i
invalid_reasons[backend] = (priority, invalid_reasons_i)
else:
valid_backends_priorities.append((backend, priority))
......@@ -341,14 +231,13 @@ class CudaPlatformBase(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None,
) -> str:
device_capability = cls.get_device_capability()
assert device_capability is not None
attn_selector_config = attn_selector_config._replace(block_size=None)
# First try checking just the selected backend, if there is one.
if selected_backend is not None:
try:
......@@ -370,7 +259,7 @@ class CudaPlatformBase(Platform):
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
valid_backends_priorities, all_invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
......@@ -379,7 +268,7 @@ class CudaPlatformBase(Platform):
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
for backend, (_, reasons) in all_invalid_reasons.items()
)
+ "}"
)
......@@ -402,6 +291,29 @@ class CudaPlatformBase(Platform):
)
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
selected_priority = valid_backends_priorities[selected_index][1]
# If the user specified --block-size (but not --attention-backend),
# check whether that constraint precluded any higher-priority backends.
if attn_selector_config.block_size is not None:
excluded = [
backend
for backend, (priority, reasons) in all_invalid_reasons.items()
if priority < selected_priority
and reasons == ["block_size not supported"]
]
if excluded:
names = ", ".join(b.name for b in excluded)
logger.warning(
"--block-size %d precluded higher-priority backend(s) "
"%s. Using %s instead, which may result in reduced "
"performance. Consider removing --block-size to "
"auto-select the optimal block size.",
attn_selector_config.block_size,
names,
selected_backend.name,
)
logger.info_once(
"Using %s attention backend out of potential backends: %s.",
selected_backend.name,
......
......@@ -420,6 +420,56 @@ class Platform:
"""
pass
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
"""
Ensure block_size is compatible with the attention backend.
"""
from vllm.config.cache import CacheConfig
cache_config = vllm_config.cache_config
if cache_config.user_specified_block_size:
# User specified --block-size; keep it.
return
model_config = vllm_config.model_config
# model_config may be None during testing.
# Skip hybrid models — their block_size is managed by
# HybridAttentionMambaModelConfig.
if model_config is None or model_config.is_hybrid:
cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE
return
from vllm.config.vllm import (
get_layers_from_vllm_config,
set_current_vllm_config,
)
from vllm.model_executor.layers.attention_layer_base import (
AttentionLayerBase,
)
attn_layers = get_layers_from_vllm_config(
vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
)
if not attn_layers:
cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE
return
first_layer = next(iter(attn_layers.values()))
backend_cls = first_layer.get_attn_backend()
with set_current_vllm_config(vllm_config):
preferred = backend_cls.get_preferred_block_size(
CacheConfig.DEFAULT_BLOCK_SIZE
)
if preferred != CacheConfig.DEFAULT_BLOCK_SIZE:
logger.info(
"Setting kv cache block size to %d for %s backend.",
preferred,
backend_cls.get_name(),
)
cache_config.block_size = preferred
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
"""
......
......@@ -687,7 +687,7 @@ class RocmPlatform(Platform):
)
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
if cache_config and cache_config.block_size is None:
if cache_config and not cache_config.user_specified_block_size:
if (
envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
# NOTE: This block has been deprecated
......@@ -707,6 +707,12 @@ class RocmPlatform(Platform):
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
# TODO: ROCm still sets block_size in check_and_update_config.
# Move that logic here so block_size is chosen by the backend.
pass
@classmethod
def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment