Unverified Commit 77a73458 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

Reapply [Attention] Refactor `check_and_update_config` (#35122)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent 5578f2a4
......@@ -162,7 +162,7 @@ class XPUPlatform(Platform):
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
# in V1(or with chunked prefill) block_size is 64
if cache_config and cache_config.block_size is None:
if cache_config and not cache_config.user_specified_block_size:
cache_config.block_size = 64
# lazy import to avoid circular import
......@@ -227,6 +227,12 @@ class XPUPlatform(Platform):
# ref. https://openucx.readthedocs.io/en/master/faq.html
os.environ["UCX_MEMTYPE_CACHE"] = "n"
@classmethod
def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None:
# TODO: XPU still sets block_size in check_and_update_config.
# Move that logic here so block_size is chosen by the backend.
pass
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
return True
......
......@@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
import numpy as np
import torch
......@@ -144,15 +144,9 @@ class AttentionBackend(ABC):
@classmethod
def supports_block_size(cls, block_size: int | None) -> bool:
from vllm.config.cache import BlockSize
if block_size is None:
return True
valid_sizes = get_args(BlockSize)
if block_size not in valid_sizes:
return False
supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
if not supported_kernel_block_sizes:
return True
......@@ -167,6 +161,17 @@ class AttentionBackend(ABC):
return True
return False
@classmethod
def get_preferred_block_size(cls, default_block_size: int) -> int:
supported_sizes = cls.get_supported_kernel_block_sizes()
if not supported_sizes:
return default_block_size
if cls.supports_block_size(default_block_size):
return default_block_size
return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)
@classmethod
def is_mla(cls) -> bool:
return False
......@@ -210,7 +215,7 @@ class AttentionBackend(ABC):
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int,
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
......@@ -224,7 +229,7 @@ class AttentionBackend(ABC):
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int,
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
......
......@@ -75,7 +75,7 @@ class FlashAttnMLABackend(MLACommonBackend):
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
......
......@@ -69,7 +69,7 @@ class FlashInferMLABackend(MLACommonBackend):
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
......
......@@ -106,7 +106,7 @@ class FlashInferMLASparseBackend(AttentionBackend):
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
......
......@@ -80,7 +80,7 @@ class FlashMLABackend(MLACommonBackend):
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
......
......@@ -49,7 +49,6 @@ def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
......@@ -71,6 +70,12 @@ def get_attn_backend(
vllm_config = get_current_vllm_config()
cache_config = vllm_config.cache_config
if cache_config is not None and cache_config.user_specified_block_size:
block_size = cache_config.block_size
else:
block_size = None
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
......
......@@ -122,7 +122,11 @@ class EngineCore:
num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
vllm_config
)
if kv_cache_config.kv_cache_groups:
vllm_config.cache_config.block_size = min(
g.kv_cache_spec.block_size for g in kv_cache_config.kv_cache_groups
)
vllm_config.validate_block_size()
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks))
......
......@@ -42,6 +42,7 @@ from vllm.distributed.parallel_state import (
)
from vllm.envs import enable_envs_cache
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.tracing import instrument, maybe_init_worker_tracer
from vllm.utils.network_utils import (
get_distributed_init_method,
......@@ -617,6 +618,9 @@ class WorkerProc:
)
self.worker.load_model()
# Set block size based on the attention backends
current_platform.update_block_size_for_backend(vllm_config)
# Initialize message queues after init_device() since multi-node setups
# (nnodes_within_dp > 1) require distributed groups to be initialized
self._init_message_queues(input_shm_handle, vllm_config)
......
......@@ -387,6 +387,11 @@ class RayDistributedExecutor(Executor):
self.collective_rpc("init_device")
self.collective_rpc("load_model")
def _update_block_size(worker):
current_platform.update_block_size_for_backend(worker.vllm_config)
self.collective_rpc(_update_block_size)
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
self.pp_tp_workers.append([])
for tp_rank in range(self.parallel_config.tensor_parallel_size):
......
......@@ -12,6 +12,7 @@ import torch.distributed as dist
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.executor.abstract import Executor
......@@ -47,6 +48,7 @@ class UniProcExecutor(Executor):
if not is_eep_new_worker:
self.driver_worker.init_device()
self.driver_worker.load_model()
current_platform.update_block_size_for_backend(self.vllm_config)
def _distributed_args(self) -> tuple[str, int, int]:
"""Return (distributed_init_method, rank, local_rank)."""
......
......@@ -32,6 +32,7 @@ from vllm.config import (
set_current_vllm_config,
update_config,
)
from vllm.config.cache import CacheConfig
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
......@@ -586,6 +587,11 @@ class GPUModelRunner(
custom_logitsprocs: Sequence[str | type[LogitsProcessor]] = (
tuple(logits_processors) if logits_processors is not None else ()
)
placeholder_block_size = (
self.cache_config.block_size or CacheConfig.DEFAULT_BLOCK_SIZE
)
self._init_block_sizes = [placeholder_block_size]
self._init_kernel_block_sizes = [placeholder_block_size]
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
# We need to use the encoder length for encoder-decoder
......@@ -595,8 +601,8 @@ class GPUModelRunner(
device=self.device,
pin_memory=self.pin_memory,
vocab_size=self.model_config.get_vocab_size(),
block_sizes=[self.cache_config.block_size],
kernel_block_sizes=[self.cache_config.block_size],
block_sizes=[placeholder_block_size],
kernel_block_sizes=[placeholder_block_size],
is_spec_decode=bool(self.vllm_config.speculative_config),
logitsprocs=build_logitsprocs(
self.vllm_config,
......@@ -6112,8 +6118,10 @@ class GPUModelRunner(
) -> None:
"""
Re-initialize the input batch if the block sizes are different from
`[self.cache_config.block_size]`. This usually happens when there
are multiple KV cache groups.
what it was originally created with. This happens when the final
block size (determined after model loading) differs from the
placeholder used during __init__, or when there are multiple
KV cache groups.
Args:
kv_cache_config: The KV cache configuration.
......@@ -6138,14 +6146,17 @@ class GPUModelRunner(
) + kv_cache_group.kv_cache_spec.num_speculative_blocks
max_num_blocks.append(max_num_blocks_per_req)
if block_sizes != [self.cache_config.block_size] or kernel_block_sizes != [
self.cache_config.block_size
]:
if (
block_sizes != self._init_block_sizes
or kernel_block_sizes != self._init_kernel_block_sizes
):
assert self.offload_config.uva.cpu_offload_gb == 0, (
"Cannot re-initialize the input batch when CPU weight "
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
"for more details."
)
self._init_block_sizes = block_sizes
self._init_kernel_block_sizes = kernel_block_sizes
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=max_model_len,
......@@ -6162,6 +6173,15 @@ class GPUModelRunner(
is_pooling_model=self.is_pooling_model,
)
assert self._init_block_sizes == block_sizes, (
f"InputBatch block_sizes {self._init_block_sizes} != "
f"kv_cache block_sizes {block_sizes}"
)
assert self._init_kernel_block_sizes == kernel_block_sizes, (
f"InputBatch kernel_block_sizes {self._init_kernel_block_sizes} "
f"!= kv_cache kernel_block_sizes {kernel_block_sizes}"
)
def _allocate_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig
) -> dict[str, torch.Tensor]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment