Unverified commit 77a73458, authored by Matthew Bonanni and committed by GitHub
Browse files

Reapply [Attention] Refactor `check_and_update_config` (#35122)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 5578f2a4
...@@ -6,7 +6,12 @@ from unittest.mock import patch ...@@ -6,7 +6,12 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config from vllm.config import (
AttentionConfig,
CacheConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
...@@ -84,12 +89,15 @@ def test_backend_selection( ...@@ -84,12 +89,15 @@ def test_backend_selection(
"""Test attention backend selection with valid device-backend pairs.""" """Test attention backend selection with valid device-backend pairs."""
# Create AttentionConfig with the specified backend # Create AttentionConfig with the specified backend
attention_config = AttentionConfig(backend=AttentionBackendEnum[name]) attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
vllm_config = VllmConfig(attention_config=attention_config) cache_config = CacheConfig(block_size=block_size)
vllm_config = VllmConfig(
attention_config=attention_config, cache_config=cache_config
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float16, None, block_size) backend = get_attn_backend(16, torch.float16, None)
assert backend.get_name() == "CPU_ATTN" assert backend.get_name() == "CPU_ATTN"
elif device == "hip": elif device == "hip":
...@@ -104,20 +112,16 @@ def test_backend_selection( ...@@ -104,20 +112,16 @@ def test_backend_selection(
if name == "TRITON_MLA" and block_size == 1: if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1 # TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError): with pytest.raises(ValueError):
get_attn_backend( get_attn_backend(576, torch.float16, None, use_mla=use_mla)
576, torch.float16, None, block_size, use_mla=use_mla
)
else: else:
# Valid backend-block_size combination # Valid backend-block_size combination
backend = get_attn_backend( backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, use_mla=use_mla
) )
expected = name expected = name
assert backend.get_name() == expected assert backend.get_name() == expected
else: else:
backend = get_attn_backend( backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
32, torch.float16, None, block_size, use_mla=use_mla
)
expected = "ROCM_ATTN" expected = "ROCM_ATTN"
assert backend.get_name() == expected assert backend.get_name() == expected
...@@ -141,7 +145,7 @@ def test_backend_selection( ...@@ -141,7 +145,7 @@ def test_backend_selection(
if capability[0] != 10: if capability[0] != 10:
pytest.skip("CUTLASS MLA is not supported on this platform") pytest.skip("CUTLASS MLA is not supported on this platform")
backend = get_attn_backend( backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, use_mla=use_mla
) )
expected = "CUTLASS_MLA" expected = "CUTLASS_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
...@@ -156,7 +160,7 @@ def test_backend_selection( ...@@ -156,7 +160,7 @@ def test_backend_selection(
"FlashInfer MLA only supports block_size 32 or 64" "FlashInfer MLA only supports block_size 32 or 64"
) )
backend = get_attn_backend( backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, use_mla=use_mla
) )
expected = "FLASHINFER_MLA" expected = "FLASHINFER_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
...@@ -175,7 +179,6 @@ def test_backend_selection( ...@@ -175,7 +179,6 @@ def test_backend_selection(
576, 576,
torch.float16, torch.float16,
None, None,
block_size,
use_mla=use_mla, use_mla=use_mla,
) )
expected = name expected = name
...@@ -190,27 +193,23 @@ def test_backend_selection( ...@@ -190,27 +193,23 @@ def test_backend_selection(
"FlashAttention MLA not supported on this platform" "FlashAttention MLA not supported on this platform"
) )
backend = get_attn_backend( backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, use_mla=use_mla
) )
expected = "FLASH_ATTN_MLA" expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
else: else:
# TRITON_MLA or other fallback # TRITON_MLA or other fallback
backend = get_attn_backend( backend = get_attn_backend(
576, torch.float16, None, block_size, use_mla=use_mla 576, torch.float16, None, use_mla=use_mla
) )
expected = "TRITON_MLA" expected = "TRITON_MLA"
assert backend.get_name() == expected assert backend.get_name() == expected
elif name == "FLASHINFER": elif name == "FLASHINFER":
backend = get_attn_backend( backend = get_attn_backend(64, torch.float16, None, use_mla=use_mla)
64, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASHINFER" expected = "FLASHINFER"
assert backend.get_name() == expected assert backend.get_name() == expected
elif name == "FLASH_ATTN": elif name == "FLASH_ATTN":
backend = get_attn_backend( backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
32, torch.float16, None, block_size, use_mla=use_mla
)
expected = "FLASH_ATTN" expected = "FLASH_ATTN"
assert backend.get_name() == expected assert backend.get_name() == expected
...@@ -224,12 +223,12 @@ def test_fp32_fallback(device: str): ...@@ -224,12 +223,12 @@ def test_fp32_fallback(device: str):
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
if device == "cpu": if device == "cpu":
with patch("vllm.platforms.current_platform", CpuPlatform()): with patch("vllm.platforms.current_platform", CpuPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16) backend = get_attn_backend(16, torch.float32, None)
assert backend.get_name() == "CPU_ATTN" assert backend.get_name() == "CPU_ATTN"
elif device == "cuda": elif device == "cuda":
with patch("vllm.platforms.current_platform", CudaPlatform()): with patch("vllm.platforms.current_platform", CudaPlatform()):
backend = get_attn_backend(16, torch.float32, None, 16) backend = get_attn_backend(16, torch.float32, None)
assert backend.get_name() == "FLEX_ATTENTION" assert backend.get_name() == "FLEX_ATTENTION"
...@@ -241,35 +240,40 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): ...@@ -241,35 +240,40 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
) )
attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN) attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
vllm_config = VllmConfig(attention_config=attention_config) cache_config = CacheConfig(block_size=16)
vllm_config = VllmConfig(
attention_config=attention_config, cache_config=cache_config
)
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
# Unsupported CUDA arch # Unsupported CUDA arch
monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5)) monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
backend = get_attn_backend(16, torch.float16, None, 16) backend = get_attn_backend(16, torch.float16, None)
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
# Reset the monkeypatch for subsequent tests # Reset the monkeypatch for subsequent tests
monkeypatch.undo() monkeypatch.undo()
# Unsupported data type # Unsupported data type
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16) backend = get_attn_backend(16, torch.float8_e4m3fn, None)
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
# Unsupported kv cache data type # Unsupported kv cache data type
backend = get_attn_backend(16, torch.float16, "fp8", 16) backend = get_attn_backend(16, torch.float16, "fp8")
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
# Unsupported block size # Unsupported block size
backend = get_attn_backend(16, torch.float16, None, 8) vllm_config.cache_config.block_size = 8
backend = get_attn_backend(16, torch.float16, None)
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
# flash-attn is not installed # flash-attn is not installed
import sys import sys
vllm_config.cache_config.block_size = 16
original_module = sys.modules.get("vllm_flash_attn") original_module = sys.modules.get("vllm_flash_attn")
monkeypatch.setitem(sys.modules, "vllm_flash_attn", None) monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
backend = get_attn_backend(16, torch.float16, None, 16) backend = get_attn_backend(16, torch.float16, None)
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
# Restore the original module if it existed # Restore the original module if it existed
...@@ -279,7 +283,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): ...@@ -279,7 +283,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False) monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)
# Unsupported head size # Unsupported head size
backend = get_attn_backend(17, torch.float16, None, 16) backend = get_attn_backend(17, torch.float16, None)
assert backend.get_name() != "FLASH_ATTN" assert backend.get_name() != "FLASH_ATTN"
...@@ -320,7 +324,7 @@ def test_auto_backend_selection_behavior(): ...@@ -320,7 +324,7 @@ def test_auto_backend_selection_behavior():
set_current_vllm_config(vllm_config_auto), set_current_vllm_config(vllm_config_auto),
patch("vllm.platforms.current_platform", CpuPlatform()), patch("vllm.platforms.current_platform", CpuPlatform()),
): ):
backend_auto = get_attn_backend(16, torch.float16, None, 16) backend_auto = get_attn_backend(16, torch.float16, None)
_cached_get_attn_backend.cache_clear() _cached_get_attn_backend.cache_clear()
...@@ -328,7 +332,7 @@ def test_auto_backend_selection_behavior(): ...@@ -328,7 +332,7 @@ def test_auto_backend_selection_behavior():
set_current_vllm_config(vllm_config_none), set_current_vllm_config(vllm_config_none),
patch("vllm.platforms.current_platform", CpuPlatform()), patch("vllm.platforms.current_platform", CpuPlatform()),
): ):
backend_none = get_attn_backend(16, torch.float16, None, 16) backend_none = get_attn_backend(16, torch.float16, None)
# Both should select the same backend # Both should select the same backend
assert backend_auto.get_name() == backend_none.get_name() assert backend_auto.get_name() == backend_none.get_name()
...@@ -358,7 +362,10 @@ def test_per_head_quant_scales_backend_selection( ...@@ -358,7 +362,10 @@ def test_per_head_quant_scales_backend_selection(
backend=AttentionBackendEnum[backend_name], backend=AttentionBackendEnum[backend_name],
flash_attn_version=flash_attn_version, flash_attn_version=flash_attn_version,
) )
vllm_config = VllmConfig(attention_config=attention_config) cache_config = CacheConfig(block_size=64)
vllm_config = VllmConfig(
attention_config=attention_config, cache_config=cache_config
)
with ( with (
set_current_vllm_config(vllm_config), set_current_vllm_config(vllm_config),
...@@ -376,7 +383,6 @@ def test_per_head_quant_scales_backend_selection( ...@@ -376,7 +383,6 @@ def test_per_head_quant_scales_backend_selection(
head_size=128, head_size=128,
dtype=torch.float16, dtype=torch.float16,
kv_cache_dtype="fp8", kv_cache_dtype="fp8",
block_size=64,
use_per_head_quant_scales=True, use_per_head_quant_scales=True,
) )
assert backend.get_name() == backend_name assert backend.get_name() == backend_name
...@@ -386,7 +392,6 @@ def test_per_head_quant_scales_backend_selection( ...@@ -386,7 +392,6 @@ def test_per_head_quant_scales_backend_selection(
head_size=128, head_size=128,
dtype=torch.float16, dtype=torch.float16,
kv_cache_dtype="fp8", kv_cache_dtype="fp8",
block_size=64,
use_per_head_quant_scales=True, use_per_head_quant_scales=True,
) )
assert backend_name in str(exc_info.value) assert backend_name in str(exc_info.value)
...@@ -13,6 +13,7 @@ import torch.nn as nn ...@@ -13,6 +13,7 @@ import torch.nn as nn
from PIL import Image from PIL import Image
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.config.cache import CacheConfig
from vllm.config.multimodal import ( from vllm.config.multimodal import (
AudioDummyOptions, AudioDummyOptions,
BaseDummyOptions, BaseDummyOptions,
...@@ -131,7 +132,9 @@ def initialize_dummy_model( ...@@ -131,7 +132,9 @@ def initialize_dummy_model(
): ):
temp_file = tempfile.mkstemp()[1] temp_file = tempfile.mkstemp()[1]
current_device = torch.get_default_device() current_device = torch.get_default_device()
vllm_config = VllmConfig(model_config=model_config) vllm_config = VllmConfig(
model_config=model_config, cache_config=CacheConfig(block_size=16)
)
with set_current_vllm_config(vllm_config=vllm_config): with set_current_vllm_config(vllm_config=vllm_config):
init_distributed_environment( init_distributed_environment(
world_size=1, world_size=1,
......
...@@ -80,7 +80,7 @@ def _create_proposer( ...@@ -80,7 +80,7 @@ def _create_proposer(
device = current_platform.device_type device = current_platform.device_type
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(), cache_config=CacheConfig(block_size=16),
speculative_config=speculative_config, speculative_config=speculative_config,
device_config=DeviceConfig(device=device), device_config=DeviceConfig(device=device),
parallel_config=ParallelConfig(), parallel_config=ParallelConfig(),
......
...@@ -2,16 +2,15 @@ ...@@ -2,16 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import field from dataclasses import field
from typing import Literal from typing import ClassVar, Literal
from pydantic import Field, SkipValidation, field_validator from pydantic import Field, SkipValidation, field_validator, model_validator
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal[ CacheDType = Literal[
"auto", "auto",
"bfloat16", "bfloat16",
...@@ -31,12 +30,13 @@ KVOffloadingBackend = Literal["native", "lmcache"] ...@@ -31,12 +30,13 @@ KVOffloadingBackend = Literal["native", "lmcache"]
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache.""" """Configuration for the KV cache."""
block_size: SkipValidation[BlockSize] = None # type: ignore[assignment] DEFAULT_BLOCK_SIZE: ClassVar[int] = 16
"""Size of a contiguous cache block in number of tokens.
This config has no static default. If left unspecified by the user, it will block_size: SkipValidation[int] = None # type: ignore[assignment]
be set in `Platform.check_and_update_config()` based on the current """Size of a contiguous cache block in number of tokens.
platform.""" Accepts None (meaning "use default"). After construction, always int."""
user_specified_block_size: bool = field(default=False, init=False)
"""Whether block_size was explicitly provided. Derived automatically."""
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can """The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
...@@ -169,6 +169,8 @@ class CacheConfig: ...@@ -169,6 +169,8 @@ class CacheConfig:
"prefix_caching_hash_algo", "prefix_caching_hash_algo",
"cpu_kvcache_space_bytes", "cpu_kvcache_space_bytes",
"mamba_page_size_padded", "mamba_page_size_padded",
"user_specified_block_size",
"_block_size_resolved",
# Post-init/derived counters # Post-init/derived counters
"num_gpu_blocks", "num_gpu_blocks",
"num_cpu_blocks", "num_cpu_blocks",
...@@ -186,6 +188,22 @@ class CacheConfig: ...@@ -186,6 +188,22 @@ class CacheConfig:
# metrics info # metrics info
return {key: str(value) for key, value in self.__dict__.items()} return {key: str(value) for key, value in self.__dict__.items()}
_block_size_resolved: bool = field(default=False, init=False)
"""Guard against pydantic re-running _apply_block_size_default."""
@model_validator(mode="after")
def _apply_block_size_default(self) -> "CacheConfig":
    """Resolve ``block_size`` once and record whether it was user-provided.

    On the first run, a ``block_size`` of ``None`` is replaced with
    ``DEFAULT_BLOCK_SIZE``; any other value is kept and
    ``user_specified_block_size`` is set to ``True``. Later runs are
    no-ops.

    Returns:
        This same instance, as required for an "after" model validator.
    """
    # Pydantic re-runs validators when CacheConfig is nested inside
    # another pydantic model (e.g. VllmConfig). Without this guard, the
    # second pass would see the already-filled default and wrongly flag
    # it as user-specified.
    if not self._block_size_resolved:
        object.__setattr__(self, "_block_size_resolved", True)
        if self.block_size is not None:
            object.__setattr__(self, "user_specified_block_size", True)
        else:
            object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE)
    return self
@field_validator("cache_dtype", mode="after") @field_validator("cache_dtype", mode="after")
@classmethod @classmethod
def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType:
......
...@@ -1026,32 +1026,6 @@ class VllmConfig: ...@@ -1026,32 +1026,6 @@ class VllmConfig:
) )
current_platform.check_and_update_config(self) current_platform.check_and_update_config(self)
# If DCP, ensure the block size is right.
if self.parallel_config.decode_context_parallel_size > 1:
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
self.parallel_config.cp_kv_cache_interleave_size
!= self.parallel_config.dcp_kv_cache_interleave_size
):
self.parallel_config.cp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
logger.warning_once(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.cp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be greater "
"than or equal to and divisible by cp_kv_cache_interleave_size "
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
# Do this after all the updates to compilation_config.mode # Do this after all the updates to compilation_config.mode
effective_dp_size = ( effective_dp_size = (
self.parallel_config.data_parallel_size self.parallel_config.data_parallel_size
...@@ -1219,26 +1193,6 @@ class VllmConfig: ...@@ -1219,26 +1193,6 @@ class VllmConfig:
# Default to enable HMA if not explicitly disabled by user or logic above. # Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False self.scheduler_config.disable_hybrid_kv_cache_manager = False
if self.cache_config.mamba_cache_mode == "align":
assert (
self.cache_config.block_size
<= self.scheduler_config.max_num_batched_tokens
), (
"In Mamba cache align mode, block_size "
f"({self.cache_config.block_size}) must be <= "
"max_num_batched_tokens "
f"({self.scheduler_config.max_num_batched_tokens})."
)
if self.scheduler_config.long_prefill_token_threshold > 0:
assert (
self.scheduler_config.long_prefill_token_threshold
>= self.cache_config.block_size
)
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility to "
"schedule a multiple of block_size tokens even if they are in the "
"middle of a mm input"
)
if self.compilation_config.debug_dump_path: if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path = (
self.compilation_config.debug_dump_path.absolute().expanduser() self.compilation_config.debug_dump_path.absolute().expanduser()
...@@ -1673,6 +1627,53 @@ class VllmConfig: ...@@ -1673,6 +1627,53 @@ class VllmConfig:
f"compilation_config={self.compilation_config!r}" f"compilation_config={self.compilation_config!r}"
) )
def validate_block_size(self) -> None:
    """Check the final block_size against DCP and Mamba constraints.

    Called after Platform.update_block_size_for_backend() has
    finalised block_size.

    Side effects:
        When decode context parallelism is enabled and
        ``dcp_kv_cache_interleave_size`` is greater than 1 and differs from
        ``cp_kv_cache_interleave_size``, the latter is overwritten with the
        former and a one-time warning is logged.

    Raises:
        AssertionError: If block_size is smaller than, or not a multiple
            of, ``cp_kv_cache_interleave_size`` (DCP enabled), or if any
            of the Mamba "align" cache-mode requirements is violated.
    """
    parallel = self.parallel_config
    block_size = self.cache_config.block_size

    # DCP interleave-size compatibility
    if parallel.decode_context_parallel_size > 1:
        dcp_interleave = parallel.dcp_kv_cache_interleave_size
        if (
            dcp_interleave > 1
            and parallel.cp_kv_cache_interleave_size != dcp_interleave
        ):
            parallel.cp_kv_cache_interleave_size = dcp_interleave
            logger.warning_once(
                "cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
                "_interleave_size. And dcp-kv-cache-interleave-size will be "
                "deprecated when PCP is fully supported."
            )

        # Read after the possible override above so the check (and the
        # message) reflect the effective interleave size.
        cp_interleave = parallel.cp_kv_cache_interleave_size
        assert cp_interleave <= block_size and block_size % cp_interleave == 0, (
            f"Block_size({block_size}) should be greater "
            "than or equal to and divisible by cp_kv_cache_interleave_size "
            f"({cp_interleave})."
        )

    # Mamba cache align-mode constraints
    if self.cache_config.mamba_cache_mode != "align":
        return

    scheduler = self.scheduler_config
    max_batched_tokens = scheduler.max_num_batched_tokens
    assert block_size <= max_batched_tokens, (
        "In Mamba cache align mode, block_size "
        f"({block_size}) must be <= "
        "max_num_batched_tokens "
        f"({max_batched_tokens})."
    )

    long_prefill_threshold = scheduler.long_prefill_token_threshold
    if long_prefill_threshold > 0:
        assert long_prefill_threshold >= block_size

    assert not scheduler.disable_chunked_mm_input, (
        "Chunked MM input is required because we need the flexibility "
        "to schedule a multiple of block_size tokens even if they are "
        "in the middle of a mm input"
    )
@model_validator(mode="after") @model_validator(mode="after")
def validate_mamba_block_size(self) -> "VllmConfig": def validate_mamba_block_size(self) -> "VllmConfig":
if self.model_config is None: if self.model_config is None:
......
...@@ -500,7 +500,6 @@ def get_current_attn_backend(vllm_config: VllmConfig): ...@@ -500,7 +500,6 @@ def get_current_attn_backend(vllm_config: VllmConfig):
head_size=vllm_config.model_config.get_head_size(), head_size=vllm_config.model_config.get_head_size(),
dtype=vllm_config.model_config.dtype, dtype=vllm_config.model_config.dtype,
kv_cache_dtype=vllm_config.cache_config.cache_dtype, kv_cache_dtype=vllm_config.cache_config.cache_dtype,
block_size=vllm_config.cache_config.block_size,
use_mla=vllm_config.model_config.use_mla, use_mla=vllm_config.model_config.use_mla,
) )
return backend return backend
...@@ -726,7 +726,6 @@ class MoRIIOConnectorWorker: ...@@ -726,7 +726,6 @@ class MoRIIOConnectorWorker:
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.dtype, self.model_config.dtype,
self.cache_config.cache_dtype, self.cache_config.cache_dtype,
self.block_size,
use_mla=self.use_mla, use_mla=self.use_mla,
) )
......
...@@ -62,7 +62,6 @@ from vllm.config import ( ...@@ -62,7 +62,6 @@ from vllm.config import (
get_attr_docs, get_attr_docs,
) )
from vllm.config.cache import ( from vllm.config.cache import (
BlockSize,
CacheDType, CacheDType,
KVOffloadingBackend, KVOffloadingBackend,
MambaCacheMode, MambaCacheMode,
...@@ -440,7 +439,7 @@ class EngineArgs: ...@@ -440,7 +439,7 @@ class EngineArgs:
max_parallel_loading_workers: int | None = ( max_parallel_loading_workers: int | None = (
ParallelConfig.max_parallel_loading_workers ParallelConfig.max_parallel_loading_workers
) )
block_size: BlockSize = CacheConfig.block_size block_size: int | None = None
enable_prefix_caching: bool | None = None enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = ( prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo CacheConfig.prefix_caching_hash_algo
...@@ -1521,7 +1520,7 @@ class EngineArgs: ...@@ -1521,7 +1520,7 @@ class EngineArgs:
) )
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size, block_size=self.block_size, # type: ignore[arg-type]
gpu_memory_utilization=self.gpu_memory_utilization, gpu_memory_utilization=self.gpu_memory_utilization,
kv_cache_memory_bytes=self.kv_cache_memory_bytes, kv_cache_memory_bytes=self.kv_cache_memory_bytes,
cache_dtype=resolved_cache_dtype, # type: ignore[arg-type] cache_dtype=resolved_cache_dtype, # type: ignore[arg-type]
......
...@@ -221,11 +221,9 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -221,11 +221,9 @@ class Attention(nn.Module, AttentionLayerBase):
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
if cache_config is not None: if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales calculate_kv_scales = cache_config.calculate_kv_scales
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False calculate_kv_scales = False
# llm-compressor mdls need to set cache_dtype to "fp8" manually. # llm-compressor mdls need to set cache_dtype to "fp8" manually.
...@@ -275,7 +273,6 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -275,7 +273,6 @@ class Attention(nn.Module, AttentionLayerBase):
head_size, head_size,
dtype, dtype,
kv_cache_dtype, kv_cache_dtype,
block_size,
use_mla=False, use_mla=False,
has_sink=self.has_sink, has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix, use_mm_prefix=self.use_mm_prefix,
......
...@@ -30,9 +30,8 @@ from vllm.v1.kv_cache_interface import ( ...@@ -30,9 +30,8 @@ from vllm.v1.kv_cache_interface import (
def create_chunked_local_attention_backend( def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend, underlying_attn_backend: AttentionBackend,
attention_chunk_size: int, attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_" prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
assert issubclass(underlying_builder, AttentionMetadataBuilder) assert issubclass(underlying_builder, AttentionMetadataBuilder)
...@@ -55,7 +54,9 @@ def create_chunked_local_attention_backend( ...@@ -55,7 +54,9 @@ def create_chunked_local_attention_backend(
fast_build: bool = False, fast_build: bool = False,
): ):
cm, make_virtual_batches_block_table = make_local_attention_virtual_batches( cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size attention_chunk_size,
common_attn_metadata,
self.kv_cache_spec.block_size,
) )
metadata = super().build(common_prefix_len, cm, fast_build) metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
...@@ -94,16 +95,12 @@ class ChunkedLocalAttention(Attention): ...@@ -94,16 +95,12 @@ class ChunkedLocalAttention(Attention):
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
if cache_config is not None: if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend( underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_chunked_local_attention_backend( attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size underlying_attn_backend, attention_chunk_size
) )
super().__init__( super().__init__(
......
...@@ -188,10 +188,8 @@ class CrossAttention(Attention): ...@@ -188,10 +188,8 @@ class CrossAttention(Attention):
if cache_config is not None: if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16
if attn_type is not None: if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, ( assert attn_type == AttentionType.ENCODER_DECODER, (
...@@ -202,7 +200,6 @@ class CrossAttention(Attention): ...@@ -202,7 +200,6 @@ class CrossAttention(Attention):
head_size, head_size,
dtype, dtype,
kv_cache_dtype, kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_DECODER, attn_type=AttentionType.ENCODER_DECODER,
) )
attn_backend = create_cross_attention_backend(underlying_attn_backend) attn_backend = create_cross_attention_backend(underlying_attn_backend)
......
...@@ -66,16 +66,13 @@ class EncoderOnlyAttention(Attention): ...@@ -66,16 +66,13 @@ class EncoderOnlyAttention(Attention):
if cache_config is not None: if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend( underlying_attn_backend = get_attn_backend(
head_size, head_size,
dtype, dtype,
kv_cache_dtype, kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY, attn_type=AttentionType.ENCODER_ONLY,
) )
......
...@@ -323,11 +323,9 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -323,11 +323,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if cache_config is not None: if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales calculate_kv_scales = cache_config.calculate_kv_scales
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False calculate_kv_scales = False
self.quant_config = quant_config self.quant_config = quant_config
...@@ -336,7 +334,6 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -336,7 +334,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.head_size, self.head_size,
dtype, dtype,
kv_cache_dtype, kv_cache_dtype,
block_size,
use_mla=True, use_mla=True,
use_sparse=use_sparse, use_sparse=use_sparse,
num_heads=self.num_heads, num_heads=self.num_heads,
...@@ -449,17 +446,24 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -449,17 +446,24 @@ class MLAAttention(nn.Module, AttentionLayerBase):
) )
# Attributes for forward_impl method # Attributes for forward_impl method
self.chunked_prefill_workspace_size = ( self._vllm_config = get_current_vllm_config()
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( self._chunked_prefill_workspace_size: int | None = None
get_current_vllm_config()
)
)
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True, static=True,
group_shape=GroupShape.PER_TENSOR, group_shape=GroupShape.PER_TENSOR,
compile_native=True, compile_native=True,
) )
@property
def chunked_prefill_workspace_size(self) -> int:
    """Return the chunked-prefill workspace size, computing it on first use.

    The value is obtained from
    ``MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size``
    using the config captured in ``self._vllm_config`` and is then stored in
    ``self._chunked_prefill_workspace_size`` so later reads reuse it.
    """
    size = self._chunked_prefill_workspace_size
    # Compare against None (not truthiness) so a stored 0 is still a hit.
    if size is not None:
        return size

    size = MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
        self._vllm_config
    )
    self._chunked_prefill_workspace_size = size
    return size
def forward( def forward(
self, self,
q: torch.Tensor, q: torch.Tensor,
......
...@@ -126,17 +126,13 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -126,17 +126,13 @@ class StaticSinkAttention(Attention, CustomOp):
if cache_config is not None: if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16
if attn_backend is not None: if attn_backend is not None:
underlying_attn_backend = attn_backend underlying_attn_backend = attn_backend
else: else:
underlying_attn_backend = get_attn_backend( underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_static_sink_attention_backend( attn_backend = create_static_sink_attention_backend(
underlying_attn_backend, # type: ignore[arg-type] underlying_attn_backend, # type: ignore[arg-type]
sink_len=sink_len, sink_len=sink_len,
...@@ -153,7 +149,6 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -153,7 +149,6 @@ class StaticSinkAttention(Attention, CustomOp):
CustomOp.__init__(self) CustomOp.__init__(self)
self.sink_len = sink_len self.sink_len = sink_len
self.block_size = block_size
self.sink_populated = False self.sink_populated = False
self.sink_key = None self.sink_key = None
self.sink_value = None self.sink_value = None
...@@ -212,12 +207,12 @@ class StaticSinkAttention(Attention, CustomOp): ...@@ -212,12 +207,12 @@ class StaticSinkAttention(Attention, CustomOp):
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it # Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention. # Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER assert self.attn_type == AttentionType.DECODER
return SinkFullAttentionSpec( return SinkFullAttentionSpec(
block_size=block_size, block_size=self.block_size,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
head_size=self.head_size, head_size=self.head_size,
head_size_v=self.head_size_v, head_size_v=self.head_size_v,
......
...@@ -217,10 +217,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): ...@@ -217,10 +217,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
) )
# override attention block size if either (a) the # override attention block size if it is too small,
# user has not set it or (b) the user has set it # even if the user has explicitly set it
# too small. if cache_config.block_size < attn_block_size:
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size cache_config.block_size = attn_block_size
logger.info( logger.info(
"Setting attention block size to %d tokens " "Setting attention block size to %d tokens "
......
...@@ -290,16 +290,13 @@ class WhisperCausalAttentionWithBlockPooling(Attention): ...@@ -290,16 +290,13 @@ class WhisperCausalAttentionWithBlockPooling(Attention):
if cache_config is not None: if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend( underlying_attn_backend = get_attn_backend(
head_size, head_size,
dtype, dtype,
kv_cache_dtype, kv_cache_dtype,
block_size,
attn_type=attn_type, attn_type=attn_type,
) )
attn_backend = create_whisper_attention_backend_with_block_pooling( attn_backend = create_whisper_attention_backend_with_block_pooling(
......
...@@ -185,7 +185,7 @@ class CpuPlatform(Platform): ...@@ -185,7 +185,7 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
if cache_config.block_size is None: if not cache_config.user_specified_block_size:
cache_config.block_size = 128 cache_config.block_size = 128
if cache_config.block_size % 32 != 0: if cache_config.block_size % 32 != 0:
...@@ -361,6 +361,12 @@ class CpuPlatform(Platform): ...@@ -361,6 +361,12 @@ class CpuPlatform(Platform):
vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS,
) )
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """No-op override: the CPU platform does not adjust block_size here.

    TODO: CPU still sets block_size in check_and_update_config.
    Move that logic here so block_size is chosen by the backend.
    """
    return None
@classmethod @classmethod
def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]: def get_allowed_cpu_core_node_list(cls) -> tuple[list[int], list[LogicalCPUInfo]]:
assert platform.system() == "Linux" assert platform.system() == "Linux"
......
...@@ -166,122 +166,12 @@ class CudaPlatformBase(Platform): ...@@ -166,122 +166,12 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.v1.attention.backends.registry import AttentionBackendEnum
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing
# Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
# for models with both attention and mamba,
# and doesn't need to be reinitialized here
if (
model_config is not None
and model_config.use_mla
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
use_flashmla = False
use_cutlass_mla = False
use_flashinfer_mla = False
use_flashmla_sparse = False
use_flashinfer_mla_sparse = False
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
if vllm_config.attention_config.backend is None:
# Default case
hf_text_config = model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if (
cls.is_device_capability_family(100)
and not use_sparse
and qk_nope_head_dim == 128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla = True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.FLASHINFER_MLA
)
elif cls.is_device_capability_family(100) and not use_sparse:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla = True
elif is_flashmla_dense_supported()[0]:
# Non-Blackwell with FlashMLA support
use_flashmla = True
else:
# Fallback: will use Triton MLA or other compatible backend
pass
else:
# Forced case
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
use_flashinfer_mla_sparse = (
backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
)
if (
use_flashmla
and is_flashmla_dense_supported()[0]
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
if use_cutlass_mla and cache_config.block_size % 128 != 0:
cache_config.block_size = 128
logger.info(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
)
if (
use_flashinfer_mla
and cache_config.block_size != 32
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
)
if use_sparse:
if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
use_flashmla_sparse = True
if use_flashmla_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
elif use_flashinfer_mla_sparse and cache_config.block_size not in (
32,
64,
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLASparse "
"backend."
)
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing # Note: model_config may be None during testing
if ( if (
...@@ -312,10 +202,10 @@ class CudaPlatformBase(Platform): ...@@ -312,10 +202,10 @@ class CudaPlatformBase(Platform):
num_heads: int | None = None, num_heads: int | None = None,
) -> tuple[ ) -> tuple[
list[tuple["AttentionBackendEnum", int]], list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]], dict["AttentionBackendEnum", tuple[int, list[str]]],
]: ]:
valid_backends_priorities = [] valid_backends_priorities = []
invalid_reasons = {} invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {}
backend_priorities = _get_backend_priorities( backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, attn_selector_config.use_mla,
...@@ -332,7 +222,7 @@ class CudaPlatformBase(Platform): ...@@ -332,7 +222,7 @@ class CudaPlatformBase(Platform):
except ImportError: except ImportError:
invalid_reasons_i = ["ImportError"] invalid_reasons_i = ["ImportError"]
if invalid_reasons_i: if invalid_reasons_i:
invalid_reasons[backend] = invalid_reasons_i invalid_reasons[backend] = (priority, invalid_reasons_i)
else: else:
valid_backends_priorities.append((backend, priority)) valid_backends_priorities.append((backend, priority))
...@@ -341,14 +231,13 @@ class CudaPlatformBase(Platform): ...@@ -341,14 +231,13 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def get_attn_backend_cls( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum", selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig", attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None, num_heads: int | None = None,
) -> str: ) -> str:
device_capability = cls.get_device_capability() device_capability = cls.get_device_capability()
assert device_capability is not None assert device_capability is not None
attn_selector_config = attn_selector_config._replace(block_size=None)
# First try checking just the selected backend, if there is one. # First try checking just the selected backend, if there is one.
if selected_backend is not None: if selected_backend is not None:
try: try:
...@@ -370,7 +259,7 @@ class CudaPlatformBase(Platform): ...@@ -370,7 +259,7 @@ class CudaPlatformBase(Platform):
# No selected backend or the selected backend is invalid, # No selected backend or the selected backend is invalid,
# so we try finding a valid backend. # so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends( valid_backends_priorities, all_invalid_reasons = cls.get_valid_backends(
device_capability=device_capability, device_capability=device_capability,
attn_selector_config=attn_selector_config, attn_selector_config=attn_selector_config,
num_heads=num_heads, num_heads=num_heads,
...@@ -379,7 +268,7 @@ class CudaPlatformBase(Platform): ...@@ -379,7 +268,7 @@ class CudaPlatformBase(Platform):
"{" "{"
+ ", ".join( + ", ".join(
f"{backend.name}: [{', '.join(reasons)}]" f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items() for backend, (_, reasons) in all_invalid_reasons.items()
) )
+ "}" + "}"
) )
...@@ -402,6 +291,29 @@ class CudaPlatformBase(Platform): ...@@ -402,6 +291,29 @@ class CudaPlatformBase(Platform):
) )
selected_index = sorted_indices[0] selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0] selected_backend = valid_backends_priorities[selected_index][0]
selected_priority = valid_backends_priorities[selected_index][1]
# If the user specified --block-size (but not --attention-backend),
# check whether that constraint precluded any higher-priority backends.
if attn_selector_config.block_size is not None:
excluded = [
backend
for backend, (priority, reasons) in all_invalid_reasons.items()
if priority < selected_priority
and reasons == ["block_size not supported"]
]
if excluded:
names = ", ".join(b.name for b in excluded)
logger.warning(
"--block-size %d precluded higher-priority backend(s) "
"%s. Using %s instead, which may result in reduced "
"performance. Consider removing --block-size to "
"auto-select the optimal block size.",
attn_selector_config.block_size,
names,
selected_backend.name,
)
logger.info_once( logger.info_once(
"Using %s attention backend out of potential backends: %s.", "Using %s attention backend out of potential backends: %s.",
selected_backend.name, selected_backend.name,
......
...@@ -420,6 +420,56 @@ class Platform: ...@@ -420,6 +420,56 @@ class Platform:
""" """
pass pass
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """Align the KV-cache block size with the attention backend in use.

    A block size that the user set explicitly is left untouched. Otherwise
    the default block size is applied when there is no model config, when
    the model is hybrid, or when no attention layers are found. In every
    other case the backend of the first attention layer is asked for its
    preferred block size, and that value is written to
    ``vllm_config.cache_config.block_size``.

    Args:
        vllm_config: Configuration whose ``cache_config.block_size`` is
            updated in place.
    """
    from vllm.config.cache import CacheConfig

    kv_cfg = vllm_config.cache_config
    if kv_cfg.user_specified_block_size:
        # An explicit --block-size always wins.
        return

    fallback = CacheConfig.DEFAULT_BLOCK_SIZE
    mdl_cfg = vllm_config.model_config
    # model_config may be None during testing. Hybrid models are not probed
    # here; their block_size is managed by HybridAttentionMambaModelConfig.
    # NOTE(review): the default is still written for hybrid models; presumably
    # the hybrid config adjusts it afterwards. Confirm the call order.
    if mdl_cfg is None or mdl_cfg.is_hybrid:
        kv_cfg.block_size = fallback
        return

    # Imported only once they are actually needed, after the early exits.
    from vllm.config.vllm import (
        get_layers_from_vllm_config,
        set_current_vllm_config,
    )
    from vllm.model_executor.layers.attention_layer_base import (
        AttentionLayerBase,
    )

    layer_map = get_layers_from_vllm_config(
        vllm_config,
        AttentionLayerBase,  # type: ignore[type-abstract]
    )
    if not layer_map:
        kv_cfg.block_size = fallback
        return

    # Only the first attention layer's backend is consulted.
    probe_layer = next(iter(layer_map.values()))
    backend = probe_layer.get_attn_backend()
    with set_current_vllm_config(vllm_config):
        chosen = backend.get_preferred_block_size(fallback)

    # Log only when the backend moves away from the default.
    if chosen != fallback:
        logger.info(
            "Setting kv cache block size to %d for %s backend.",
            chosen,
            backend.get_name(),
        )
    kv_cfg.block_size = chosen
@classmethod @classmethod
def verify_model_arch(cls, model_arch: str) -> None: def verify_model_arch(cls, model_arch: str) -> None:
""" """
......
...@@ -687,7 +687,7 @@ class RocmPlatform(Platform): ...@@ -687,7 +687,7 @@ class RocmPlatform(Platform):
) )
compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
if cache_config and cache_config.block_size is None: if cache_config and not cache_config.user_specified_block_size:
if ( if (
envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION and envs.VLLM_ROCM_USE_AITER
# NOTE: This block has been deprecated # NOTE: This block has been deprecated
...@@ -707,6 +707,12 @@ class RocmPlatform(Platform): ...@@ -707,6 +707,12 @@ class RocmPlatform(Platform):
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """No-op override: the ROCm platform does not adjust block_size here.

    TODO: ROCm still sets block_size in check_and_update_config.
    Move that logic here so block_size is chosen by the backend.
    """
    return None
@classmethod @classmethod
def verify_model_arch(cls, model_arch: str) -> None: def verify_model_arch(cls, model_arch: str) -> None:
if model_arch in _ROCM_UNSUPPORTED_MODELS: if model_arch in _ROCM_UNSUPPORTED_MODELS:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment