Commit 31330101 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-dev

parents e8933c34 dc1b4a6f
......@@ -69,12 +69,12 @@ class CpuPlatform(Platform):
cache_config = vllm_config.cache_config
ipex_avaliable = find_spec("intel_extension_for_pytorch") is not None
ipex_available = find_spec("intel_extension_for_pytorch") is not None
if cache_config and cache_config.block_size is None:
cache_config.block_size = 128 if ipex_avaliable else 16
cache_config.block_size = 128 if ipex_available else 16
if not ipex_avaliable and cache_config.block_size != 16:
if not ipex_available and cache_config.block_size != 16:
raise RuntimeError(
f"--block-size={cache_config.block_size} requires"
" intel_extension_for_pytorch")
......
......@@ -46,15 +46,15 @@ class HpuPlatform(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
scheduler_config = vllm_config.scheduler_config
parallel_config = vllm_config.parallel_config
if scheduler_config.is_multi_step:
raise NotImplementedError(
"Multi-step execution is not implemented for HPU")
parallel_config.worker_cls = \
"vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker"
if vllm_config.speculative_config is not None:
raise NotImplementedError(
"Speculative decoding is not implemented for HPU")
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker"
......
# SPDX-License-Identifier: Apache-2.0
import enum
import platform
import random
......@@ -9,14 +8,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
from vllm.inputs import PromptType
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.utils import FlexibleArgumentParser
else:
ModelConfig = None
VllmConfig = None
LoRARequest = None
PoolingParams = None
SamplingParams = None
FlexibleArgumentParser = None
logger = init_logger(__name__)
......@@ -231,7 +237,7 @@ class Platform:
parser: Optional[FlexibleArgumentParser] = None
) -> None:
"""
Do some pre-registeration or update action for the current platform.
Do some pre-registration or update action for the current platform.
This function is called before global VllmConfig is initialized or cli
arguments are parsed. It's used for out-of-tree platforms to register or
......@@ -386,6 +392,14 @@ class Platform:
"""
return False
@classmethod
def validate_request(
cls,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
) -> None:
"""Raises if this request is unsupported on this platform"""
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
import vllm.envs as envs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams, SamplingType
from .interface import Platform, PlatformEnum, _Backend
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from vllm.pooling_params import PoolingParams
else:
ModelConfig = None
VllmConfig = None
PoolingParams = None
logger = init_logger(__name__)
......@@ -116,6 +120,13 @@ class TpuPlatform(Platform):
assert not vllm_config.speculative_config, (
"Speculative decoding is not yet supported for TPU backend")
if scheduler_config.is_multimodal_model and not \
scheduler_config.disable_chunked_mm_input:
logger.warning("TPU does not support running Multimodal models"\
" without setting `--disable_chunked_mm_input`. " \
"Forcing --disable_chunked_mm_input.")
scheduler_config.disable_chunked_mm_input = True
@classmethod
def is_pin_memory_available(cls):
logger.warning("Pin memory is not supported on TPU.")
......@@ -133,3 +144,18 @@ class TpuPlatform(Platform):
def supports_v1(cls, model_config: ModelConfig) -> bool:
# V1 support on TPU is experimental
return True
@classmethod
def validate_request(
cls,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
) -> None:
"""Raises if this request is unsupported on this platform"""
if isinstance(params, SamplingParams):
if params.guided_decoding is not None:
raise ValueError("Structured output is not supported on "
f"{cls.device_name}.")
if params.sampling_type == SamplingType.RANDOM_SEED:
raise ValueError(
"Torch XLA does not support per-request seed.")
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
import msgspec
if TYPE_CHECKING:
from vllm.config import ModelConfig
class PoolingParams(
msgspec.Struct,
......@@ -12,14 +15,30 @@ class PoolingParams(
"""API parameters for pooling models. This is currently a placeholder.
Attributes:
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
additional_data: Any additional data needed for pooling.
"""
dimensions: Optional[int] = None
additional_data: Optional[Any] = None
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
def verify(self, model_config: "ModelConfig") -> None:
if self.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
f'Model "{model_config.served_model_name}" does not '
f'support matryoshka representation, '
f'changing output dimensions will lead to poor results.')
if self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
def __repr__(self) -> str:
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"additional_metadata={self.additional_data})")
......@@ -60,7 +60,7 @@ class GraniteReasoningParser(ReasoningParser):
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionReqest): Request being processed.
request (ChatCompletionRequest): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
......
......@@ -101,7 +101,7 @@ class RequestOutputKind(Enum):
CUMULATIVE = 0
# Return only deltas in each RequestOutput
DELTA = 1
# Do not return intermediate RequestOuputs
# Do not return intermediate RequestOutput
FINAL_ONLY = 2
......@@ -385,9 +385,10 @@ class SamplingParams(
if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.")
if not 0.0 < self.repetition_penalty <= 2.0:
raise ValueError("repetition_penalty must be in (0, 2], got "
f"{self.repetition_penalty}.")
if self.repetition_penalty <= 0.0:
raise ValueError(
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}.")
if self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}.")
......
......@@ -1119,7 +1119,7 @@ class _PrintableStructure(Structure):
e.g. class that has _field_ 'hex_value', c_uint could be formatted with
_fmt_ = {"hex_value" : "%08X"}
to produce nicer output.
Default fomratting string for all fields can be set with key "<default>" like:
Default formatting string for all fields can be set with key "<default>" like:
_fmt_ = {"<default>" : "%d MHz"} # e.g all values are numbers in MHz.
If not set it's assumed to be just "%s"
......
......@@ -712,6 +712,7 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
def get_hf_image_processor_config(
model: Union[str, Path],
hf_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
**kwargs,
) -> Dict[str, Any]:
......@@ -721,7 +722,10 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models
if check_gguf_file(model):
model = Path(model).parent
return get_image_processor_config(model, revision=revision, **kwargs)
return get_image_processor_config(model,
token=hf_token,
revision=revision,
**kwargs)
def get_hf_text_config(config: PretrainedConfig):
......
......@@ -5,6 +5,7 @@ from typing import Optional, Union
from transformers import AutoConfig, PretrainedConfig
import vllm.envs as envs
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
......@@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig):
self.truncated_vocab_size = self.model.vocab_size if \
truncated_vocab_size is None else truncated_vocab_size
if "architectures" not in kwargs:
if not envs.VLLM_USE_V1:
kwargs["architectures"] = ["EAGLEModel"]
else:
kwargs["architectures"] = ["EagleLlamaForCausalLM"]
super().__init__(**kwargs)
......
# SPDX-License-Identifier: Apache-2.0
from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
truncate_tool_call_ids)
truncate_tool_call_ids, validate_request_params)
__all__ = [
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids",
"validate_request_params"
]
......@@ -98,6 +98,13 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
request.messages[i]["tool_call_id"] = tool_call_id
def validate_request_params(request: "ChatCompletionRequest"):
if (request.skip_special_tokens is not None
and not request.skip_special_tokens):
raise ValueError("skip_special_tokens=False is not supported "
"for Mistral tokenizers.")
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
......
# SPDX-License-Identifier: Apache-2.0
import json
from functools import cache
from os import PathLike
from pathlib import Path
......@@ -51,6 +52,26 @@ def modelscope_list_repo_files(
return files
def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]:
with open(path) as f:
try:
return json.loads(f.read())
except Exception:
return dict[str, str]()
def _maybe_space_split_dict(path: Union[str, PathLike]) -> dict[str, str]:
parsed_dict = dict[str, str]()
with open(path) as f:
for line in f.readlines():
try:
model_name, redirect_name = line.strip().split()
parsed_dict[model_name] = redirect_name
except Exception:
pass
return parsed_dict
@cache
def maybe_model_redirect(model: str) -> str:
"""
......@@ -68,16 +89,10 @@ def maybe_model_redirect(model: str) -> str:
if not Path(model_redirect_path).exists():
return model
with open(model_redirect_path) as f:
for line in f.readlines():
try:
model_name, redirect_name = line.split("\t")
if model == model_name:
redirect_name = redirect_name.strip()
logger.info("model redirect: [ %s ] -> [ %s ]", model,
redirect_name)
return redirect_name
except Exception:
pass
redirect_dict = (_maybe_json_dict(model_redirect_path)
or _maybe_space_split_dict(model_redirect_path))
if (redirect_model := redirect_dict.get(model)):
logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model)
return redirect_model
return model
......@@ -2,7 +2,6 @@
from __future__ import annotations
import argparse
import asyncio
import concurrent
import contextlib
......@@ -25,6 +24,7 @@ import socket
import subprocess
import sys
import tempfile
import textwrap
import threading
import time
import traceback
......@@ -32,6 +32,8 @@ import types
import uuid
import warnings
import weakref
from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
ArgumentTypeError)
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
......@@ -40,7 +42,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Type, TypeVar, Union, cast, overload)
Optional, Tuple, Type, TypeVar, Union, cast, overload)
from uuid import uuid4
import cachetools
......@@ -53,6 +55,7 @@ import torch.types
import yaml
import zmq
import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
......@@ -1209,7 +1212,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
return wrapper
class StoreBoolean(argparse.Action):
class StoreBoolean(Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
......@@ -1221,15 +1224,28 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'.")
class SortedHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def _split_lines(self, text, width):
"""
1. Sentences split across lines have their single newlines removed.
2. Paragraphs and explicit newlines are split into separate lines.
3. Each line is wrapped to the specified width (width of terminal).
"""
# The patterns also include whitespace after the newline
single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*")
multiple_newlines = re.compile(r"\n{2,}\s*")
text = single_newline.sub(' ', text)
lines = re.split(multiple_newlines, text)
return sum([textwrap.wrap(line, width) for line in lines], [])
def add_arguments(self, actions):
actions = sorted(actions, key=lambda x: x.option_strings)
super().add_arguments(actions)
class FlexibleArgumentParser(argparse.ArgumentParser):
class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
def __init__(self, *args, **kwargs):
......@@ -1280,11 +1296,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
value = int(value)
except ValueError:
msg = "Port must be an integer"
raise argparse.ArgumentTypeError(msg) from None
raise ArgumentTypeError(msg) from None
if not (1024 <= value <= 65535):
raise argparse.ArgumentTypeError(
"Port must be between 1024 and 65535")
raise ArgumentTypeError("Port must be between 1024 and 65535")
return value
......@@ -2060,12 +2075,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: list[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
op_name: str,
op_func: Callable,
mutates_args: list[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
tags: Tuple[torch.Tag, ...] = (),
):
"""
`torch.library.custom_op` can have significant overhead because it
......@@ -2104,7 +2120,7 @@ def direct_register_custom_op(
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str)
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
......@@ -2689,3 +2705,20 @@ def sha256(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")
def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try:
torch_version = version.parse(str(torch.__version__))
return torch_version >= version.parse(target)
except Exception:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return Version(importlib.metadata.version('torch')) >= Version(target)
......@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
......@@ -164,9 +164,9 @@ def make_local_attention_virtual_batches(
attn_chunk_size: int,
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.tensor,
block_table: torch.Tensor,
page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
......@@ -264,7 +264,7 @@ def make_local_attention_virtual_batches(
np.arange(pages_per_local_batch, dtype=np.int32),
(virtual_batches, pages_per_local_batch)) \
+ np.expand_dims(block_starts, axis=1)
block_indices = block_indices.flatten()
block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch)
block_table_local = block_table[batch_indices, block_indices]\
......
......@@ -83,8 +83,8 @@ spda_o = scaled_dot_product_attention(
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatnated per head
`q_b_proj` is [W_UQ; W_QR] concatnated per head
`kv_b_proj` is [W_UK; W_UV] concatenated per head
`q_b_proj` is [W_UQ; W_QR] concatenated per head
`out_proj` is W_O
......@@ -195,7 +195,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
......
......@@ -10,6 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.logger import init_logger
logger = init_logger(__name__)
class PallasAttentionBackend(AttentionBackend):
......@@ -80,7 +83,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
use_irope: bool = False,
) -> None:
if use_irope:
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")
if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.")
......
......@@ -67,11 +67,11 @@ class BlockPool:
Returns:
The cached block if it exists, or None.
"""
if block_hash in self.cached_block_hash_to_block:
first_block_id = list(
self.cached_block_hash_to_block[block_hash].keys())[0]
return self.cached_block_hash_to_block[block_hash][first_block_id]
return None
cached_blocks = self.cached_block_hash_to_block.get(block_hash)
if not cached_blocks:
return None
first_block_id = next(iter(cached_blocks))
return cached_blocks[first_block_id]
def cache_full_blocks(
self,
......
......@@ -133,6 +133,14 @@ def _compute_encoder_budget_multimodal(
_, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])
if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
> scheduler_config.max_num_batched_tokens):
raise ValueError(
"Chunked MM input disabled but max_tokens_per_mm_item "
f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
f" ({scheduler_config.max_num_batched_tokens}). Please increase "
"max_num_batched_tokens.")
encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
max_tokens_per_mm_item)
encoder_cache_size = max(scheduler_config.encoder_cache_size,
......
......@@ -126,44 +126,46 @@ class KVCacheManager:
self.req_to_block_hashes[request.request_id] = block_hashes
self.prefix_cache_stats.requests += 1
if request.sampling_params.prompt_logprobs is None:
if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
last_block_hash = None
# When the request requires prompt logprobs, we skip prefix caching.
if request.sampling_params.prompt_logprobs is not None:
return [], 0
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
last_block_hash = None
if last_block_hash is not None:
# Add back the last block hash if it was removed.
block_hashes.append(last_block_hash)
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
if last_block_hash is not None:
# Add back the last block hash if it was removed.
# NOTE: Because block_hashes is cached in req_to_block_hashes,
# we shouldn't modify it directly.
block_hashes.append(last_block_hash)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
else:
# Skip cache hits for prompt logprobs
return [], 0
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
def allocate_slots(
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]:
"""Add slots for a request with new tokens to append.
......@@ -173,6 +175,9 @@ class KVCacheManager:
not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
Blocks layout:
-----------------------------------------------------------------------
......@@ -210,8 +215,9 @@ class KVCacheManager:
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size)
num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
......@@ -245,8 +251,11 @@ class KVCacheManager:
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_preallocate_blocks = max(
0, self.num_preallocate_blocks -
num_lookahead_tokens // self.block_size)
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
num_new_blocks + num_preallocate_blocks,
self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment