Unverified Commit aaefc58e authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[CI] Revert PRs 34818 and 33600 (#34979)

parent f24b2de3
...@@ -13,7 +13,6 @@ import torch.nn as nn ...@@ -13,7 +13,6 @@ import torch.nn as nn
from PIL import Image from PIL import Image
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.config.cache import CacheConfig
from vllm.config.multimodal import ( from vllm.config.multimodal import (
AudioDummyOptions, AudioDummyOptions,
BaseDummyOptions, BaseDummyOptions,
...@@ -132,9 +131,7 @@ def initialize_dummy_model( ...@@ -132,9 +131,7 @@ def initialize_dummy_model(
): ):
temp_file = tempfile.mkstemp()[1] temp_file = tempfile.mkstemp()[1]
current_device = torch.get_default_device() current_device = torch.get_default_device()
vllm_config = VllmConfig( vllm_config = VllmConfig(model_config=model_config)
model_config=model_config, cache_config=CacheConfig(block_size=16)
)
with set_current_vllm_config(vllm_config=vllm_config): with set_current_vllm_config(vllm_config=vllm_config):
init_distributed_environment( init_distributed_environment(
world_size=1, world_size=1,
......
...@@ -457,9 +457,6 @@ def dummy_hf_overrides( ...@@ -457,9 +457,6 @@ def dummy_hf_overrides(
# Kimi uses `num_expert_group` instead of `n_group`. # Kimi uses `num_expert_group` instead of `n_group`.
if n_group is None: if n_group is None:
n_group = getattr(text_config, "num_expert_group", None) n_group = getattr(text_config, "num_expert_group", None)
# InternS1Pro uses `router_n_groups` instead of `n_group`.
if n_group is None:
n_group = getattr(text_config, "router_n_groups", None)
num_experts = n_group * 2 if n_group is not None else 2 num_experts = n_group * 2 if n_group is not None else 2
# we use three layers for Gemma-3n to check # we use three layers for Gemma-3n to check
...@@ -489,14 +486,12 @@ def dummy_hf_overrides( ...@@ -489,14 +486,12 @@ def dummy_hf_overrides(
# Only set MoE related config when the model has MoE layers. # Only set MoE related config when the model has MoE layers.
# Otherwise all models detected as MoE by _get_transformers_backend_cls. # Otherwise all models detected as MoE by _get_transformers_backend_cls.
if model_arch_config.num_experts > 0: if model_arch_config.num_experts > 0:
orig_topk = getattr(text_config, "num_experts_per_tok", 2)
topk = min(orig_topk, 2)
update_dict.update( update_dict.update(
{ {
"num_experts": num_experts, "num_experts": num_experts,
"num_experts_per_tok": topk, "num_experts_per_tok": 2,
# Kimi uses `num_experts_per_token`. # Kimi uses `num_experts_per_token`.
"num_experts_per_token": topk, "num_experts_per_token": 2,
"num_local_experts": num_experts, "num_local_experts": num_experts,
# Otherwise there will not be any expert layers # Otherwise there will not be any expert layers
"first_k_dense_replace": 0, "first_k_dense_replace": 0,
......
...@@ -78,7 +78,7 @@ def _create_proposer( ...@@ -78,7 +78,7 @@ def _create_proposer(
device = current_platform.device_type device = current_platform.device_type
vllm_config = VllmConfig( vllm_config = VllmConfig(
model_config=model_config, model_config=model_config,
cache_config=CacheConfig(block_size=16), cache_config=CacheConfig(),
speculative_config=speculative_config, speculative_config=speculative_config,
device_config=DeviceConfig(device=device), device_config=DeviceConfig(device=device),
parallel_config=ParallelConfig(), parallel_config=ParallelConfig(),
......
...@@ -19,6 +19,7 @@ else: ...@@ -19,6 +19,7 @@ else:
logger = init_logger(__name__) logger = init_logger(__name__)
BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
CacheDType = Literal[ CacheDType = Literal[
"auto", "auto",
"bfloat16", "bfloat16",
...@@ -38,11 +39,13 @@ KVOffloadingBackend = Literal["native", "lmcache"] ...@@ -38,11 +39,13 @@ KVOffloadingBackend = Literal["native", "lmcache"]
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache.""" """Configuration for the KV cache."""
block_size: SkipValidation[int] = None # type: ignore[assignment] block_size: SkipValidation[BlockSize] = None # type: ignore[assignment]
"""Size of a contiguous cache block in number of tokens. """Size of a contiguous cache block in number of tokens. On CUDA devices,
only block sizes up to 32 are supported.
This is None until the platform sets it. Always an int by the time This config has no static default. If left unspecified by the user, it will
the engine starts.""" be set in `Platform.check_and_update_config()` based on the current
platform."""
gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
"""The fraction of GPU memory to be used for the model executor, which can """The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
......
...@@ -915,6 +915,32 @@ class VllmConfig: ...@@ -915,6 +915,32 @@ class VllmConfig:
) )
current_platform.check_and_update_config(self) current_platform.check_and_update_config(self)
# If DCP, ensure the block size is right.
if self.parallel_config.decode_context_parallel_size > 1:
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
self.parallel_config.cp_kv_cache_interleave_size
!= self.parallel_config.dcp_kv_cache_interleave_size
):
self.parallel_config.cp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
logger.warning_once(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.cp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be greater "
"than or equal to and divisible by cp_kv_cache_interleave_size "
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
# Do this after all the updates to compilation_config.mode # Do this after all the updates to compilation_config.mode
effective_dp_size = ( effective_dp_size = (
self.parallel_config.data_parallel_size self.parallel_config.data_parallel_size
...@@ -1082,6 +1108,26 @@ class VllmConfig: ...@@ -1082,6 +1108,26 @@ class VllmConfig:
# Default to enable HMA if not explicitly disabled by user or logic above. # Default to enable HMA if not explicitly disabled by user or logic above.
self.scheduler_config.disable_hybrid_kv_cache_manager = False self.scheduler_config.disable_hybrid_kv_cache_manager = False
if self.cache_config.mamba_cache_mode == "align":
assert (
self.cache_config.block_size
<= self.scheduler_config.max_num_batched_tokens
), (
"In Mamba cache align mode, block_size "
f"({self.cache_config.block_size}) must be <= "
"max_num_batched_tokens "
f"({self.scheduler_config.max_num_batched_tokens})."
)
if self.scheduler_config.long_prefill_token_threshold > 0:
assert (
self.scheduler_config.long_prefill_token_threshold
>= self.cache_config.block_size
)
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility to "
"schedule a multiple of block_size tokens even if they are in the "
"middle of a mm input"
)
if self.compilation_config.debug_dump_path: if self.compilation_config.debug_dump_path:
self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path = (
self.compilation_config.debug_dump_path.absolute().expanduser() self.compilation_config.debug_dump_path.absolute().expanduser()
...@@ -1442,57 +1488,6 @@ class VllmConfig: ...@@ -1442,57 +1488,6 @@ class VllmConfig:
f"compilation_config={self.compilation_config!r}" f"compilation_config={self.compilation_config!r}"
) )
def validate_block_size(self) -> None:
"""Validate block_size against DCP and mamba constraints.
Called after Platform.update_block_size_for_backend() has
finalised block_size, so that the checks see the real value
rather than the initial None sentinel.
"""
block_size = self.cache_config.block_size
assert block_size is not None, (
"validate_block_size called before block_size was set"
)
# DCP interleave-size compatibility
if self.parallel_config.decode_context_parallel_size > 1:
if self.parallel_config.dcp_kv_cache_interleave_size > 1 and (
self.parallel_config.cp_kv_cache_interleave_size
!= self.parallel_config.dcp_kv_cache_interleave_size
):
self.parallel_config.cp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)
logger.warning_once(
"cp_kv_cache_interleave_size is overridden by dcp_kv_cache"
"_interleave_size. And dcp-kv-cache-interleave-size will be "
"deprecated when PCP is fully supported."
)
assert (
self.parallel_config.cp_kv_cache_interleave_size <= block_size
and block_size % self.parallel_config.cp_kv_cache_interleave_size == 0
), (
f"Block_size({block_size}) should be greater "
"than or equal to and divisible by cp_kv_cache_interleave_size "
f"({self.parallel_config.cp_kv_cache_interleave_size})."
)
# Mamba cache align-mode constraints
if self.cache_config.mamba_cache_mode == "align":
assert block_size <= self.scheduler_config.max_num_batched_tokens, (
"In Mamba cache align mode, block_size "
f"({block_size}) must be <= "
"max_num_batched_tokens "
f"({self.scheduler_config.max_num_batched_tokens})."
)
if self.scheduler_config.long_prefill_token_threshold > 0:
assert self.scheduler_config.long_prefill_token_threshold >= block_size
assert not self.scheduler_config.disable_chunked_mm_input, (
"Chunked MM input is required because we need the flexibility "
"to schedule a multiple of block_size tokens even if they are "
"in the middle of a mm input"
)
@model_validator(mode="after") @model_validator(mode="after")
def validate_mamba_block_size(self) -> "VllmConfig": def validate_mamba_block_size(self) -> "VllmConfig":
if self.model_config is None: if self.model_config is None:
......
...@@ -59,6 +59,7 @@ from vllm.config import ( ...@@ -59,6 +59,7 @@ from vllm.config import (
get_attr_docs, get_attr_docs,
) )
from vllm.config.cache import ( from vllm.config.cache import (
BlockSize,
CacheDType, CacheDType,
KVOffloadingBackend, KVOffloadingBackend,
MambaCacheMode, MambaCacheMode,
...@@ -430,7 +431,7 @@ class EngineArgs: ...@@ -430,7 +431,7 @@ class EngineArgs:
max_parallel_loading_workers: int | None = ( max_parallel_loading_workers: int | None = (
ParallelConfig.max_parallel_loading_workers ParallelConfig.max_parallel_loading_workers
) )
block_size: int = None # type: ignore[assignment] block_size: BlockSize = CacheConfig.block_size
enable_prefix_caching: bool | None = None enable_prefix_caching: bool | None = None
prefix_caching_hash_algo: PrefixCachingHashAlgo = ( prefix_caching_hash_algo: PrefixCachingHashAlgo = (
CacheConfig.prefix_caching_hash_algo CacheConfig.prefix_caching_hash_algo
......
...@@ -30,8 +30,9 @@ from vllm.v1.kv_cache_interface import ( ...@@ -30,8 +30,9 @@ from vllm.v1.kv_cache_interface import (
def create_chunked_local_attention_backend( def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend, underlying_attn_backend: AttentionBackend,
attention_chunk_size: int, attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_" prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
assert issubclass(underlying_builder, AttentionMetadataBuilder) assert issubclass(underlying_builder, AttentionMetadataBuilder)
...@@ -54,9 +55,7 @@ def create_chunked_local_attention_backend( ...@@ -54,9 +55,7 @@ def create_chunked_local_attention_backend(
fast_build: bool = False, fast_build: bool = False,
): ):
cm, make_virtual_batches_block_table = make_local_attention_virtual_batches( cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, attention_chunk_size, common_attn_metadata, block_size
common_attn_metadata,
self.kv_cache_spec.block_size,
) )
metadata = super().build(common_prefix_len, cm, fast_build) metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
...@@ -98,13 +97,13 @@ class ChunkedLocalAttention(Attention): ...@@ -98,13 +97,13 @@ class ChunkedLocalAttention(Attention):
block_size = cache_config.block_size block_size = cache_config.block_size
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = None block_size = 16
underlying_attn_backend = get_attn_backend( underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size head_size, dtype, kv_cache_dtype, block_size
) )
attn_backend = create_chunked_local_attention_backend( attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size underlying_attn_backend, attention_chunk_size, block_size
) )
super().__init__( super().__init__(
......
...@@ -407,24 +407,17 @@ class MLAAttention(nn.Module, AttentionLayerBase): ...@@ -407,24 +407,17 @@ class MLAAttention(nn.Module, AttentionLayerBase):
) )
# Attributes for forward_impl method # Attributes for forward_impl method
self._vllm_config = get_current_vllm_config() self.chunked_prefill_workspace_size = (
self._chunked_prefill_workspace_size: int | None = None MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
get_current_vllm_config()
)
)
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True, static=True,
group_shape=GroupShape.PER_TENSOR, group_shape=GroupShape.PER_TENSOR,
compile_native=True, compile_native=True,
) )
@property
def chunked_prefill_workspace_size(self) -> int:
if self._chunked_prefill_workspace_size is None:
self._chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
self._vllm_config
)
)
return self._chunked_prefill_workspace_size
def forward( def forward(
self, self,
q: torch.Tensor, q: torch.Tensor,
......
...@@ -163,68 +163,135 @@ class CudaPlatformBase(Platform): ...@@ -163,68 +163,135 @@ class CudaPlatformBase(Platform):
@classmethod @classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.v1.attention.backends.registry import AttentionBackendEnum
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
scheduler_config = vllm_config.scheduler_config cache_config = vllm_config.cache_config
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16
# TODO(lucas): handle this more gracefully
# Note: model_config may be None during testing # Note: model_config may be None during testing
# Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
# for models with both attention and mamba,
# and doesn't need to be reinitialized here
if ( if (
model_config is not None model_config is not None
and model_config.is_mm_prefix_lm and model_config.use_mla
and scheduler_config.is_multimodal_model and cache_config.block_size is not None
and not scheduler_config.disable_chunked_mm_input
): ):
logger.warning( use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
"Forcing --disable_chunked_mm_input for models " # If `--attention-config.backend` is not set and we are using MLA,
"with multimodal-bidirectional attention." # then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
use_flashmla = False
use_cutlass_mla = False
use_flashinfer_mla = False
use_flashmla_sparse = False
use_flashinfer_mla_sparse = False
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
if vllm_config.attention_config.backend is None:
# Default case
hf_text_config = model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if (
cls.is_device_capability_family(100)
and not use_sparse
and qk_nope_head_dim == 128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla = True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.FLASHINFER_MLA
)
elif cls.is_device_capability_family(100) and not use_sparse:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla = True
elif is_flashmla_dense_supported()[0]:
# Non-Blackwell with FlashMLA support
use_flashmla = True
else:
# Fallback: will use Triton MLA or other compatible backend
pass
else:
# Forced case
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
use_flashinfer_mla_sparse = (
backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
) )
scheduler_config.disable_chunked_mm_input = True
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
cache_config = vllm_config.cache_config
if cache_config.block_size is not None:
# User specified --block-size; keep it.
return
model_config = vllm_config.model_config if (
# model_config may be None during testing. use_flashmla
# Skip hybrid models — their block_size is managed by and is_flashmla_dense_supported()[0]
# HybridAttentionMambaModelConfig. and cache_config.block_size % 64 != 0
if model_config is None or model_config.is_hybrid: ):
cache_config.block_size = 16 cache_config.block_size = 64
return logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
from vllm.config.vllm import ( if use_cutlass_mla and cache_config.block_size % 128 != 0:
get_layers_from_vllm_config, cache_config.block_size = 128
set_current_vllm_config, logger.info(
) "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
from vllm.model_executor.layers.attention_layer_base import (
AttentionLayerBase,
) )
attn_layers = get_layers_from_vllm_config( if (
vllm_config, use_flashinfer_mla
AttentionLayerBase, and cache_config.block_size != 32
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
) )
if not attn_layers:
cache_config.block_size = 16
return
first_layer = next(iter(attn_layers.values())) if use_sparse:
backend_cls = first_layer.get_attn_backend() if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
with set_current_vllm_config(vllm_config): use_flashmla_sparse = True
preferred = backend_cls.get_preferred_block_size(16)
if preferred != 16: if use_flashmla_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info( logger.info(
"Setting kv cache block size to %d for %s backend.", "Forcing kv cache block size to 64 for FlashMLASparse backend."
preferred, )
backend_cls.get_name(), elif use_flashinfer_mla_sparse and cache_config.block_size not in (
32,
64,
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLASparse "
"backend."
) )
cache_config.block_size = preferred
scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
if (
model_config is not None
and model_config.is_mm_prefix_lm
and scheduler_config.is_multimodal_model
and not scheduler_config.disable_chunked_mm_input
):
logger.warning(
"Forcing --disable_chunked_mm_input for models "
"with multimodal-bidirectional attention."
)
scheduler_config.disable_chunked_mm_input = True
@classmethod @classmethod
def get_current_memory_usage( def get_current_memory_usage(
...@@ -242,10 +309,10 @@ class CudaPlatformBase(Platform): ...@@ -242,10 +309,10 @@ class CudaPlatformBase(Platform):
num_heads: int | None = None, num_heads: int | None = None,
) -> tuple[ ) -> tuple[
list[tuple["AttentionBackendEnum", int]], list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", tuple[int, list[str]]], dict["AttentionBackendEnum", list[str]],
]: ]:
valid_backends_priorities = [] valid_backends_priorities = []
invalid_reasons: dict[AttentionBackendEnum, tuple[int, list[str]]] = {} invalid_reasons = {}
backend_priorities = _get_backend_priorities( backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, attn_selector_config.use_mla,
...@@ -262,155 +329,84 @@ class CudaPlatformBase(Platform): ...@@ -262,155 +329,84 @@ class CudaPlatformBase(Platform):
except ImportError: except ImportError:
invalid_reasons_i = ["ImportError"] invalid_reasons_i = ["ImportError"]
if invalid_reasons_i: if invalid_reasons_i:
invalid_reasons[backend] = (priority, invalid_reasons_i) invalid_reasons[backend] = invalid_reasons_i
else: else:
valid_backends_priorities.append((backend, priority)) valid_backends_priorities.append((backend, priority))
return valid_backends_priorities, invalid_reasons return valid_backends_priorities, invalid_reasons
@classmethod @classmethod
def select_attention_backend( def get_attn_backend_cls(
cls, cls,
selected_backend: "AttentionBackendEnum | None", selected_backend: "AttentionBackendEnum",
attn_selector_config: "AttentionSelectorConfig", attn_selector_config: "AttentionSelectorConfig",
device_capability: "DeviceCapability",
raise_on_invalid: bool = True,
num_heads: int | None = None, num_heads: int | None = None,
) -> "AttentionBackendEnum | None": ) -> str:
"""Select the best attention backend for the given configuration. device_capability = cls.get_device_capability()
assert device_capability is not None
Args:
selected_backend: User-specified backend, or None for auto-selection attn_selector_config = attn_selector_config._replace(block_size=None)
attn_selector_config: Configuration for attention selection
device_capability: Device capability info
raise_on_invalid: If True, raise ValueError when no valid backend
num_heads: Number of attention heads per GPU, used for backend
priority ordering on Blackwell GPUs
Returns:
The selected backend enum, or None if no valid backend found
and raise_on_invalid is False
"""
# First try checking just the selected backend, if there is one. # First try checking just the selected backend, if there is one.
if selected_backend is not None: if selected_backend is not None:
try: try:
backend_class = selected_backend.get_class() backend_class = selected_backend.get_class()
validation_errors = backend_class.validate_configuration( invalid_reasons = backend_class.validate_configuration(
device_capability=device_capability, device_capability=device_capability,
**attn_selector_config._asdict(), **attn_selector_config._asdict(),
) )
except ImportError: except ImportError:
validation_errors = ["ImportError"] invalid_reasons = ["ImportError"]
if validation_errors: if invalid_reasons:
if raise_on_invalid:
raise ValueError( raise ValueError(
f"Selected backend {selected_backend} is not valid for " f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {validation_errors}" f"this configuration. Reason: {invalid_reasons}"
) )
return None else:
return selected_backend logger.info("Using %s backend.", selected_backend)
return selected_backend.get_path()
# No selected backend, so find the best valid one. # No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends( valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability, device_capability=device_capability,
attn_selector_config=attn_selector_config, attn_selector_config=attn_selector_config,
num_heads=num_heads, num_heads=num_heads,
) )
if len(valid_backends_priorities) == 0:
if raise_on_invalid:
reasons_str = ( reasons_str = (
"{" "{"
+ ", ".join( + ", ".join(
f"{backend.name}: [{', '.join(reasons)}]" f"{backend.name}: [{', '.join(reasons)}]"
for backend, (_, reasons) in invalid_reasons.items() for backend, reasons in invalid_reasons.items()
) )
+ "}" + "}"
) )
config_str = attn_selector_config.__repr__() config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
if len(valid_backends_priorities) == 0:
raise ValueError( raise ValueError(
f"No valid attention backend found for {cls.device_name} " f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}." f"with {config_str}. Reasons: {reasons_str}."
) )
return None
# Select the one with the highest priority (lowest index). # We have found some valid backends. Select the one with the
sorted_backends = sorted(valid_backends_priorities, key=lambda x: x[1]) # highest priority.
chosen_backend, chosen_priority = sorted_backends[0] sorted_indices = sorted(
range(len(valid_backends_priorities)),
# If the user specified --block-size (but not --attention-backend), key=lambda i: valid_backends_priorities[i][1],
# check whether that constraint precluded any higher-priority backends.
if attn_selector_config.block_size is not None:
excluded = [
backend
for backend, (priority, reasons) in invalid_reasons.items()
if priority < chosen_priority
and reasons == ["block_size not supported"]
]
if excluded:
names = ", ".join(b.name for b in excluded)
logger.warning(
"--block-size %d excluded higher-priority backend(s) "
"%s. Using %s instead, which may result in reduced "
"performance. Consider removing --block-size to "
"auto-select the optimal block size.",
attn_selector_config.block_size,
names,
chosen_backend.name,
)
return chosen_backend
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None,
) -> str:
device_capability = cls.get_device_capability()
assert device_capability is not None
chosen_backend = cls.select_attention_backend(
selected_backend=selected_backend,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
device_capability=device_capability,
raise_on_invalid=True,
)
assert chosen_backend is not None # raise_on_invalid=True guarantees this
# Log the selection
if selected_backend is not None:
logger.info("Using %s backend.", chosen_backend)
else:
# Get all valid backends for logging
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, (_, reasons) in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
) )
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info_once( logger.info_once(
"Using %s attention backend out of potential backends: %s", "Using %s attention backend out of potential backends: %s.",
chosen_backend.name, selected_backend.name,
tuple(backend.name for backend, _ in valid_backends_priorities), "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
scope="local", scope="local",
) )
return chosen_backend.get_path() return selected_backend.get_path()
@classmethod @classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
......
...@@ -406,13 +406,6 @@ class Platform: ...@@ -406,13 +406,6 @@ class Platform:
""" """
pass pass
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
"""
Ensure block_size is compatible with the attention backend.
"""
pass
@classmethod @classmethod
def verify_model_arch(cls, model_arch: str) -> None: def verify_model_arch(cls, model_arch: str) -> None:
""" """
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
import numpy as np import numpy as np
import torch import torch
...@@ -144,9 +144,15 @@ class AttentionBackend(ABC): ...@@ -144,9 +144,15 @@ class AttentionBackend(ABC):
@classmethod @classmethod
def supports_block_size(cls, block_size: int | None) -> bool: def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None: if block_size is None:
return True return True
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes() supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
if not supported_kernel_block_sizes: if not supported_kernel_block_sizes:
return True return True
...@@ -161,17 +167,6 @@ class AttentionBackend(ABC): ...@@ -161,17 +167,6 @@ class AttentionBackend(ABC):
return True return True
return False return False
@classmethod
def get_preferred_block_size(cls, default_block_size: int = 16) -> int:
supported_sizes = cls.get_supported_kernel_block_sizes()
if not supported_sizes:
return default_block_size
if cls.supports_block_size(default_block_size):
return default_block_size
return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)
@classmethod @classmethod
def is_mla(cls) -> bool: def is_mla(cls) -> bool:
return False return False
......
...@@ -114,14 +114,7 @@ class EngineCore: ...@@ -114,14 +114,7 @@ class EngineCore:
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config vllm_config
) )
if kv_cache_config.kv_cache_groups:
vllm_config.cache_config.block_size = min(
g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups
)
elif vllm_config.cache_config.block_size is None:
# Attention-free models (encoder-only, SSM) — use default.
vllm_config.cache_config.block_size = 16
vllm_config.validate_block_size()
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
......
...@@ -41,7 +41,6 @@ from vllm.distributed.parallel_state import ( ...@@ -41,7 +41,6 @@ from vllm.distributed.parallel_state import (
) )
from vllm.envs import enable_envs_cache from vllm.envs import enable_envs_cache
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.tracing import instrument, maybe_init_worker_tracer from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.utils.network_utils import ( from vllm.utils.network_utils import (
get_distributed_init_method, get_distributed_init_method,
...@@ -580,9 +579,6 @@ class WorkerProc: ...@@ -580,9 +579,6 @@ class WorkerProc:
self._init_message_queues(input_shm_handle, vllm_config) self._init_message_queues(input_shm_handle, vllm_config)
self.worker.load_model() self.worker.load_model()
# Set block size based on the attention backends
current_platform.update_block_size_for_backend(vllm_config)
# Enable environment variable cache (e.g. assume no more # Enable environment variable cache (e.g. assume no more
# environment variable overrides after this point) # environment variable overrides after this point)
enable_envs_cache() enable_envs_cache()
......
...@@ -385,11 +385,6 @@ class RayDistributedExecutor(Executor): ...@@ -385,11 +385,6 @@ class RayDistributedExecutor(Executor):
self.collective_rpc("init_device") self.collective_rpc("init_device")
self.collective_rpc("load_model") self.collective_rpc("load_model")
def _update_block_size(worker):
current_platform.update_block_size_for_backend(worker.vllm_config)
self.collective_rpc(_update_block_size)
for pp_rank in range(self.parallel_config.pipeline_parallel_size): for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([]) self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size): for tp_rank in range(self.parallel_config.tensor_parallel_size):
......
...@@ -12,7 +12,6 @@ import torch.distributed as dist ...@@ -12,7 +12,6 @@ import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
...@@ -47,7 +46,6 @@ class UniProcExecutor(Executor): ...@@ -47,7 +46,6 @@ class UniProcExecutor(Executor):
self.driver_worker.init_worker(all_kwargs=[kwargs]) self.driver_worker.init_worker(all_kwargs=[kwargs])
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
current_platform.update_block_size_for_backend(self.vllm_config)
def _distributed_args(self) -> tuple[str, int, int]: def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank).""" """Return (distributed_init_method, rank, local_rank)."""
......
...@@ -513,7 +513,6 @@ class GPUModelRunner( ...@@ -513,7 +513,6 @@ class GPUModelRunner(
custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = ( custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
tuple(logits_processors) if logits_processors is not None else () tuple(logits_processors) if logits_processors is not None else ()
) )
placeholder_block_size = self.cache_config.block_size or 16
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoder # We need to use the encoder length for encoder-decoder
...@@ -523,8 +522,8 @@ class GPUModelRunner( ...@@ -523,8 +522,8 @@ class GPUModelRunner(
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
block_sizes=[placeholder_block_size], block_sizes=[self.cache_config.block_size],
kernel_block_sizes=[placeholder_block_size], kernel_block_sizes=[self.cache_config.block_size],
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs( logitsprocs=build_logitsprocs(
self.vllm_config, self.vllm_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment