Unverified commit 77a73458, authored by Matthew Bonanni and committed by GitHub
Browse files

Reapply [Attention] Refactor `check_and_update_config` (#35122)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 5578f2a4
...@@ -162,7 +162,7 @@ class XPUPlatform(Platform): ...@@ -162,7 +162,7 @@ class XPUPlatform(Platform):
model_config = vllm_config.model_config model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
# in V1(or with chunked prefill) block_size is 64 # in V1(or with chunked prefill) block_size is 64
if cache_config and cache_config.block_size is None: if cache_config and not cache_config.user_specified_block_size:
cache_config.block_size = 64 cache_config.block_size = 64
# lazy import to avoid circular import # lazy import to avoid circular import
...@@ -227,6 +227,12 @@ class XPUPlatform(Platform): ...@@ -227,6 +227,12 @@ class XPUPlatform(Platform):
# ref. https://openucx.readthedocs.io/en/master/faq.html # ref. https://openucx.readthedocs.io/en/master/faq.html
os.environ["UCX_MEMTYPE_CACHE"] = "n" os.environ["UCX_MEMTYPE_CACHE"] = "n"
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
    """Platform hook for choosing the block size from the attention backend.

    Currently a deliberate no-op for this platform: nothing in
    ``vllm_config`` is read or modified here.

    Args:
        vllm_config: The engine configuration; unused for now.
    """
    # TODO: XPU still sets block_size in check_and_update_config.
    # Move that logic here so block_size is chosen by the backend.
    return None
@classmethod @classmethod
def support_hybrid_kv_cache(cls) -> bool: def support_hybrid_kv_cache(cls) -> bool:
return True return True
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
import numpy as np import numpy as np
import torch import torch
...@@ -144,15 +144,9 @@ class AttentionBackend(ABC): ...@@ -144,15 +144,9 @@ class AttentionBackend(ABC):
@classmethod @classmethod
def supports_block_size(cls, block_size: int | None) -> bool: def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None: if block_size is None:
return True return True
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes() supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
if not supported_kernel_block_sizes: if not supported_kernel_block_sizes:
return True return True
...@@ -167,6 +161,17 @@ class AttentionBackend(ABC): ...@@ -167,6 +161,17 @@ class AttentionBackend(ABC):
return True return True
return False return False
@classmethod
def get_preferred_block_size(cls, default_block_size: int) -> int:
    """Pick the block size this backend would like to run with.

    Args:
        default_block_size: The block size to use when the backend has no
            objection to it.

    Returns:
        ``default_block_size`` when the backend declares no supported kernel
        block sizes, or when ``supports_block_size`` accepts the default.
        Otherwise the smallest declared size, where a ``MultipleOf`` entry
        contributes its ``base``.
    """
    kernel_sizes = cls.get_supported_kernel_block_sizes()

    # Short-circuit keeps the original order: supports_block_size() is only
    # consulted when the backend actually declares some kernel block sizes.
    if not kernel_sizes or cls.supports_block_size(default_block_size):
        return default_block_size

    # The default was rejected: fall back to the smallest size the backend
    # lists, reducing each MultipleOf constraint to its base value.
    candidates = [
        size.base if isinstance(size, MultipleOf) else size
        for size in kernel_sizes
    ]
    return min(candidates)
@classmethod @classmethod
def is_mla(cls) -> bool: def is_mla(cls) -> bool:
return False return False
...@@ -210,7 +215,7 @@ class AttentionBackend(ABC): ...@@ -210,7 +215,7 @@ class AttentionBackend(ABC):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None", kv_cache_dtype: "CacheDType | None",
block_size: int, block_size: int | None,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
...@@ -224,7 +229,7 @@ class AttentionBackend(ABC): ...@@ -224,7 +229,7 @@ class AttentionBackend(ABC):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None", kv_cache_dtype: "CacheDType | None",
block_size: int, block_size: int | None,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
......
...@@ -75,7 +75,7 @@ class FlashAttnMLABackend(MLACommonBackend): ...@@ -75,7 +75,7 @@ class FlashAttnMLABackend(MLACommonBackend):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: CacheDType | None, kv_cache_dtype: CacheDType | None,
block_size: int, block_size: int | None,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
......
...@@ -69,7 +69,7 @@ class FlashInferMLABackend(MLACommonBackend): ...@@ -69,7 +69,7 @@ class FlashInferMLABackend(MLACommonBackend):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: CacheDType | None, kv_cache_dtype: CacheDType | None,
block_size: int, block_size: int | None,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
......
...@@ -106,7 +106,7 @@ class FlashInferMLASparseBackend(AttentionBackend): ...@@ -106,7 +106,7 @@ class FlashInferMLASparseBackend(AttentionBackend):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: CacheDType | None, kv_cache_dtype: CacheDType | None,
block_size: int, block_size: int | None,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
......
...@@ -80,7 +80,7 @@ class FlashMLABackend(MLACommonBackend): ...@@ -80,7 +80,7 @@ class FlashMLABackend(MLACommonBackend):
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: CacheDType | None, kv_cache_dtype: CacheDType | None,
block_size: int, block_size: int | None,
use_mla: bool, use_mla: bool,
has_sink: bool, has_sink: bool,
use_sparse: bool, use_sparse: bool,
......
...@@ -49,7 +49,6 @@ def get_attn_backend( ...@@ -49,7 +49,6 @@ def get_attn_backend(
head_size: int, head_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_cache_dtype: str | None, kv_cache_dtype: str | None,
block_size: int | None,
use_mla: bool = False, use_mla: bool = False,
has_sink: bool = False, has_sink: bool = False,
use_sparse: bool = False, use_sparse: bool = False,
...@@ -71,6 +70,12 @@ def get_attn_backend( ...@@ -71,6 +70,12 @@ def get_attn_backend(
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
cache_config = vllm_config.cache_config
if cache_config is not None and cache_config.user_specified_block_size:
block_size = cache_config.block_size
else:
block_size = None
attn_selector_config = AttentionSelectorConfig( attn_selector_config = AttentionSelectorConfig(
head_size=head_size, head_size=head_size,
dtype=dtype, dtype=dtype,
......
...@@ -122,7 +122,11 @@ class EngineCore: ...@@ -122,7 +122,11 @@ class EngineCore:
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches( num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config vllm_config
) )
if kv_cache_config.kv_cache_groups:
vllm_config.cache_config.block_size = min(
g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups
)
vllm_config.validate_block_size()
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
......
...@@ -42,6 +42,7 @@ from vllm.distributed.parallel_state import ( ...@@ -42,6 +42,7 @@ from vllm.distributed.parallel_state import (
) )
from vllm.envs import enable_envs_cache from vllm.envs import enable_envs_cache
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.tracing import instrument, maybe_init_worker_tracer from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.utils.network_utils import ( from vllm.utils.network_utils import (
get_distributed_init_method, get_distributed_init_method,
...@@ -617,6 +618,9 @@ class WorkerProc: ...@@ -617,6 +618,9 @@ class WorkerProc:
) )
self.worker.load_model() self.worker.load_model()
# Set block size based on the attention backends
current_platform.update_block_size_for_backend(vllm_config)
# Initialize message queues after init_device() since multi-node setups # Initialize message queues after init_device() since multi-node setups
# (nnodes_within_dp > 1) require distributed groups to be initialized # (nnodes_within_dp > 1) require distributed groups to be initialized
self._init_message_queues(input_shm_handle, vllm_config) self._init_message_queues(input_shm_handle, vllm_config)
......
...@@ -387,6 +387,11 @@ class RayDistributedExecutor(Executor): ...@@ -387,6 +387,11 @@ class RayDistributedExecutor(Executor):
self.collective_rpc("init_device") self.collective_rpc("init_device")
self.collective_rpc("load_model") self.collective_rpc("load_model")
def _update_block_size(worker):
    """Run the platform's block-size hook against this worker's config."""
    worker_config = worker.vllm_config
    current_platform.update_block_size_for_backend(worker_config)
self.collective_rpc(_update_block_size)
for pp_rank in range(self.parallel_config.pipeline_parallel_size): for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([]) self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size): for tp_rank in range(self.parallel_config.tensor_parallel_size):
......
...@@ -12,6 +12,7 @@ import torch.distributed as dist ...@@ -12,6 +12,7 @@ import torch.distributed as dist
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
...@@ -47,6 +48,7 @@ class UniProcExecutor(Executor): ...@@ -47,6 +48,7 @@ class UniProcExecutor(Executor):
if not is_eep_new_worker: if not is_eep_new_worker:
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
current_platform.update_block_size_for_backend(self.vllm_config)
def _distributed_args(self) -> tuple[str, int, int]: def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank).""" """Return (distributed_init_method, rank, local_rank)."""
......
...@@ -32,6 +32,7 @@ from vllm.config import ( ...@@ -32,6 +32,7 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
update_config, update_config,
) )
from vllm.config.cache import CacheConfig
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
...@@ -586,6 +587,11 @@ class GPUModelRunner( ...@@ -586,6 +587,11 @@ class GPUModelRunner(
custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = ( custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
tuple(logits_processors) if logits_processors is not None else () tuple(logits_processors) if logits_processors is not None else ()
) )
placeholder_block_size = (
self.cache_config.block_size or CacheConfig.DEFAULT_BLOCK_SIZE
)
self._init_block_sizes = [placeholder_block_size]
self._init_kernel_block_sizes = [placeholder_block_size]
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoder # We need to use the encoder length for encoder-decoder
...@@ -595,8 +601,8 @@ class GPUModelRunner( ...@@ -595,8 +601,8 @@ class GPUModelRunner(
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(), vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size], block_sizes=[placeholder_block_size],
kernel_block_sizes=[self.cache_config.block_size], kernel_block_sizes=[placeholder_block_size],
is_spec_decode=bool(self.vllm_config.speculative_config), is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs( logitsprocs=build_logitsprocs(
self.vllm_config, self.vllm_config,
...@@ -6112,8 +6118,10 @@ class GPUModelRunner( ...@@ -6112,8 +6118,10 @@ class GPUModelRunner(
) -> None: ) -> None:
""" """
Re-initialize the input batch if the block sizes are different from Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there what it was originally created with. This happens when the final
are multiple KV cache groups. block size (determined after model loading) differs from the
placeholder used during __init__, or when there are multiple
KV cache groups.
Args: Args:
kv_cache_config: The KV cache configuration. kv_cache_config: The KV cache configuration.
...@@ -6138,14 +6146,17 @@ class GPUModelRunner( ...@@ -6138,14 +6146,17 @@ class GPUModelRunner(
) + kv_cache_group.kv_cache_spec.num_speculative_blocks ) + kv_cache_group.kv_cache_spec.num_speculative_blocks
max_num_blocks.append(max_num_blocks_per_req) max_num_blocks.append(max_num_blocks_per_req)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [ if (
self.cache_config.block_size block_sizes != self._init_block_sizes
]: or kernel_block_sizes != self._init_kernel_block_sizes
):
assert self.offload_config.uva.cpu_offload_gb == 0, ( assert self.offload_config.uva.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight " "Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501 "offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details." "for more details."
) )
self._init_block_sizes = block_sizes
self._init_kernel_block_sizes = kernel_block_sizes
self.input_batch = InputBatch( self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs, max_num_reqs=self.max_num_reqs,
max_model_len=max_model_len, max_model_len=max_model_len,
...@@ -6162,6 +6173,15 @@ class GPUModelRunner( ...@@ -6162,6 +6173,15 @@ class GPUModelRunner(
is_pooling_model=self.is_pooling_model, is_pooling_model=self.is_pooling_model,
) )
assert self._init_block_sizes == block_sizes, (
f"InputBatch block_sizes {self._init_block_sizes} != "
f"kv_cache block_sizes {block_sizes}"
)
assert self._init_kernel_block_sizes == kernel_block_sizes, (
f"InputBatch kernel_block_sizes {self._init_kernel_block_sizes} "
f"!= kv_cache kernel_block_sizes {kernel_block_sizes}"
)
def _allocate_kv_cache_tensors( def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig self, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]: ) -> dict[str, torch.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment