Commit 9c4ecf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.4' into v0.8.4-ori

parents bfc2d6f7 dc1b4a6f
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
import msgspec
if TYPE_CHECKING:
from vllm.config import ModelConfig
class PoolingParams(
msgspec.Struct,
......@@ -12,14 +15,30 @@ class PoolingParams(
"""API parameters for pooling models. This is currently a placeholder.
Attributes:
dimensions: Reduce the dimensions of embeddings
if model support matryoshka representation.
additional_data: Any additional data needed for pooling.
"""
dimensions: Optional[int] = None
additional_data: Optional[Any] = None
def clone(self) -> "PoolingParams":
"""Returns a deep copy of the PoolingParams instance."""
return PoolingParams(additional_data=self.additional_data)
return PoolingParams(dimensions=self.dimensions,
additional_data=self.additional_data)
def verify(self, model_config: "ModelConfig") -> None:
if self.dimensions is not None:
if not model_config.is_matryoshka:
raise ValueError(
f'Model "{model_config.served_model_name}" does not '
f'support matryoshka representation, '
f'changing output dimensions will lead to poor results.')
if self.dimensions < 1:
raise ValueError("Dimensions must be greater than 0")
def __repr__(self) -> str:
return (f"PoolingParams("
f"dimensions={self.dimensions}, "
f"additional_metadata={self.additional_data})")
......@@ -60,7 +60,7 @@ class GraniteReasoningParser(ReasoningParser):
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionReqest): Request being processed.
request (ChatCompletionRequest): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
......
......@@ -101,7 +101,7 @@ class RequestOutputKind(Enum):
CUMULATIVE = 0
# Return only deltas in each RequestOutput
DELTA = 1
# Do not return intermediate RequestOuputs
# Do not return intermediate RequestOutput
FINAL_ONLY = 2
......@@ -385,9 +385,10 @@ class SamplingParams(
if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.")
if not 0.0 < self.repetition_penalty <= 2.0:
raise ValueError("repetition_penalty must be in (0, 2], got "
f"{self.repetition_penalty}.")
if self.repetition_penalty <= 0.0:
raise ValueError(
"repetition_penalty must be greater than zero, got "
f"{self.repetition_penalty}.")
if self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}.")
......
......@@ -1119,7 +1119,7 @@ class _PrintableStructure(Structure):
e.g. class that has _field_ 'hex_value', c_uint could be formatted with
_fmt_ = {"hex_value" : "%08X"}
to produce nicer output.
Default fomratting string for all fields can be set with key "<default>" like:
Default formatting string for all fields can be set with key "<default>" like:
_fmt_ = {"<default>" : "%d MHz"} # e.g all values are numbers in MHz.
If not set it's assumed to be just "%s"
......
......@@ -712,6 +712,7 @@ def load_params_config(model: Union[str, Path], revision: Optional[str],
def get_hf_image_processor_config(
model: Union[str, Path],
hf_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
**kwargs,
) -> Dict[str, Any]:
......@@ -721,7 +722,10 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models
if check_gguf_file(model):
model = Path(model).parent
return get_image_processor_config(model, revision=revision, **kwargs)
return get_image_processor_config(model,
token=hf_token,
revision=revision,
**kwargs)
def get_hf_text_config(config: PretrainedConfig):
......
......@@ -5,6 +5,7 @@ from typing import Optional, Union
from transformers import AutoConfig, PretrainedConfig
import vllm.envs as envs
from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
......@@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig):
self.truncated_vocab_size = self.model.vocab_size if \
truncated_vocab_size is None else truncated_vocab_size
if "architectures" not in kwargs:
if not envs.VLLM_USE_V1:
kwargs["architectures"] = ["EAGLEModel"]
else:
kwargs["architectures"] = ["EagleLlamaForCausalLM"]
super().__init__(**kwargs)
......
# SPDX-License-Identifier: Apache-2.0
from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
truncate_tool_call_ids)
truncate_tool_call_ids, validate_request_params)
__all__ = [
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids",
"validate_request_params"
]
......@@ -98,6 +98,13 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
request.messages[i]["tool_call_id"] = tool_call_id
def validate_request_params(request: "ChatCompletionRequest"):
if (request.skip_special_tokens is not None
and not request.skip_special_tokens):
raise ValueError("skip_special_tokens=False is not supported "
"for Mistral tokenizers.")
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
......
# SPDX-License-Identifier: Apache-2.0
import json
from functools import cache
from os import PathLike
from pathlib import Path
......@@ -51,6 +52,26 @@ def modelscope_list_repo_files(
return files
def _maybe_json_dict(path: Union[str, PathLike]) -> dict[str, str]:
with open(path) as f:
try:
return json.loads(f.read())
except Exception:
return dict[str, str]()
def _maybe_space_split_dict(path: Union[str, PathLike]) -> dict[str, str]:
parsed_dict = dict[str, str]()
with open(path) as f:
for line in f.readlines():
try:
model_name, redirect_name = line.strip().split()
parsed_dict[model_name] = redirect_name
except Exception:
pass
return parsed_dict
@cache
def maybe_model_redirect(model: str) -> str:
"""
......@@ -68,16 +89,10 @@ def maybe_model_redirect(model: str) -> str:
if not Path(model_redirect_path).exists():
return model
with open(model_redirect_path) as f:
for line in f.readlines():
try:
model_name, redirect_name = line.split("\t")
if model == model_name:
redirect_name = redirect_name.strip()
logger.info("model redirect: [ %s ] -> [ %s ]", model,
redirect_name)
return redirect_name
except Exception:
pass
redirect_dict = (_maybe_json_dict(model_redirect_path)
or _maybe_space_split_dict(model_redirect_path))
if (redirect_model := redirect_dict.get(model)):
logger.info("model redirect: [ %s ] -> [ %s ]", model, redirect_model)
return redirect_model
return model
......@@ -2,7 +2,6 @@
from __future__ import annotations
import argparse
import asyncio
import concurrent
import contextlib
......@@ -25,6 +24,7 @@ import socket
import subprocess
import sys
import tempfile
import textwrap
import threading
import time
import traceback
......@@ -32,6 +32,8 @@ import types
import uuid
import warnings
import weakref
from argparse import (Action, ArgumentDefaultsHelpFormatter, ArgumentParser,
ArgumentTypeError)
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
......@@ -40,7 +42,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Type, TypeVar, Union, cast, overload)
Optional, Tuple, Type, TypeVar, Union, cast, overload)
from uuid import uuid4
import cachetools
......@@ -53,6 +55,7 @@ import torch.types
import yaml
import zmq
import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
......@@ -1208,7 +1211,7 @@ def run_once(f: Callable[P, None]) -> Callable[P, None]:
return wrapper
class StoreBoolean(argparse.Action):
class StoreBoolean(Action):
def __call__(self, parser, namespace, values, option_string=None):
if values.lower() == "true":
......@@ -1220,15 +1223,28 @@ class StoreBoolean(argparse.Action):
"Expected 'true' or 'false'.")
class SortedHelpFormatter(argparse.ArgumentDefaultsHelpFormatter):
class SortedHelpFormatter(ArgumentDefaultsHelpFormatter):
"""SortedHelpFormatter that sorts arguments by their option strings."""
def _split_lines(self, text, width):
"""
1. Sentences split across lines have their single newlines removed.
2. Paragraphs and explicit newlines are split into separate lines.
3. Each line is wrapped to the specified width (width of terminal).
"""
# The patterns also include whitespace after the newline
single_newline = re.compile(r"(?<!\n)\n(?!\n)\s*")
multiple_newlines = re.compile(r"\n{2,}\s*")
text = single_newline.sub(' ', text)
lines = re.split(multiple_newlines, text)
return sum([textwrap.wrap(line, width) for line in lines], [])
def add_arguments(self, actions):
actions = sorted(actions, key=lambda x: x.option_strings)
super().add_arguments(actions)
class FlexibleArgumentParser(argparse.ArgumentParser):
class FlexibleArgumentParser(ArgumentParser):
"""ArgumentParser that allows both underscore and dash in names."""
def __init__(self, *args, **kwargs):
......@@ -1279,11 +1295,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
value = int(value)
except ValueError:
msg = "Port must be an integer"
raise argparse.ArgumentTypeError(msg) from None
raise ArgumentTypeError(msg) from None
if not (1024 <= value <= 65535):
raise argparse.ArgumentTypeError(
"Port must be between 1024 and 65535")
raise ArgumentTypeError("Port must be between 1024 and 65535")
return value
......@@ -1935,12 +1950,13 @@ vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: list[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
op_name: str,
op_func: Callable,
mutates_args: list[str],
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
tags: Tuple[torch.Tag, ...] = (),
):
"""
`torch.library.custom_op` can have significant overhead because it
......@@ -1979,7 +1995,7 @@ def direct_register_custom_op(
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str)
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
......@@ -2564,3 +2580,20 @@ def sha256(input) -> int:
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")
def is_torch_equal_or_newer(target: str) -> bool:
"""Check if the installed torch version is >= the target version.
Args:
target: a version string, like "2.6.0".
Returns:
Whether the condition meets.
"""
try:
torch_version = version.parse(str(torch.__version__))
return torch_version >= version.parse(target)
except Exception:
# Fallback to PKG-INFO to load the package info, needed by the doc gen.
return Version(importlib.metadata.version('torch')) >= Version(target)
......@@ -10,7 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
......@@ -164,9 +164,9 @@ def make_local_attention_virtual_batches(
attn_chunk_size: int,
query_start_loc_np: np.ndarray,
seq_lens_np: np.ndarray,
block_table: torch.tensor,
block_table: torch.Tensor,
page_size: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.tensor]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
actual_batch_size = seq_lens_np.shape[0]
......@@ -264,7 +264,7 @@ def make_local_attention_virtual_batches(
np.arange(pages_per_local_batch, dtype=np.int32),
(virtual_batches, pages_per_local_batch)) \
+ np.expand_dims(block_starts, axis=1)
block_indices = block_indices.flatten()
block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1)
batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32),
local_blocks * pages_per_local_batch)
block_table_local = block_table[batch_indices, block_indices]\
......
......@@ -83,8 +83,8 @@ spda_o = scaled_dot_product_attention(
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatnated per head
`q_b_proj` is [W_UQ; W_QR] concatnated per head
`kv_b_proj` is [W_UK; W_UV] concatenated per head
`q_b_proj` is [W_UQ; W_QR] concatenated per head
`out_proj` is W_O
......@@ -195,7 +195,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
......
......@@ -10,6 +10,9 @@ import torch_xla.experimental.custom_kernel # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.logger import init_logger
logger = init_logger(__name__)
class PallasAttentionBackend(AttentionBackend):
......@@ -80,7 +83,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
use_irope: bool = False,
) -> None:
if use_irope:
logger.warning_once(
"Using irope in Pallas is not supported yet, it will fall back "
"to global attention for long context.")
if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.")
......
......@@ -67,11 +67,11 @@ class BlockPool:
Returns:
The cached block if it exists, or None.
"""
if block_hash in self.cached_block_hash_to_block:
first_block_id = list(
self.cached_block_hash_to_block[block_hash].keys())[0]
return self.cached_block_hash_to_block[block_hash][first_block_id]
return None
cached_blocks = self.cached_block_hash_to_block.get(block_hash)
if not cached_blocks:
return None
first_block_id = next(iter(cached_blocks))
return cached_blocks[first_block_id]
def cache_full_blocks(
self,
......
......@@ -133,6 +133,14 @@ def _compute_encoder_budget_multimodal(
_, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])
if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
> scheduler_config.max_num_batched_tokens):
raise ValueError(
"Chunked MM input disabled but max_tokens_per_mm_item "
f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
f" ({scheduler_config.max_num_batched_tokens}). Please increase "
"max_num_batched_tokens.")
encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
max_tokens_per_mm_item)
encoder_cache_size = max(scheduler_config.encoder_cache_size,
......
......@@ -126,44 +126,46 @@ class KVCacheManager:
self.req_to_block_hashes[request.request_id] = block_hashes
self.prefix_cache_stats.requests += 1
if request.sampling_params.prompt_logprobs is None:
if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
last_block_hash = None
# When the request requires prompt logprobs, we skip prefix caching.
if request.sampling_params.prompt_logprobs is not None:
return [], 0
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
if len(block_hashes) * self.block_size == request.num_tokens:
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token. This
# have to be achieved by re-computing an entire block because
# allocate_slots() assumes num_computed_tokens is always a
# multiple of the block size. To achieve this, remove the last
# block hash from the block_hashes for find_longest_cache_hit
# This limitation can potentially be removed in the future to
# slightly improve the performance.
last_block_hash = block_hashes.pop()
else:
last_block_hash = None
if last_block_hash is not None:
# Add back the last block hash if it was removed.
block_hashes.append(last_block_hash)
computed_blocks = (
self.specialized_manager.find_longest_cache_hit(block_hashes))
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
self.prefix_cache_stats.queries += len(block_hashes)
self.prefix_cache_stats.hits += len(computed_blocks)
if last_block_hash is not None:
# Add back the last block hash if it was removed.
# NOTE: Because block_hashes is cached in req_to_block_hashes,
# we shouldn't modify it directly.
block_hashes.append(last_block_hash)
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
else:
# Skip cache hits for prompt logprobs
return [], 0
# NOTE(woosuk): Since incomplete blocks are not eligible for
# sharing, `num_computed_tokens` is always a multiple of
# `block_size`.
num_computed_tokens = len(computed_blocks) * self.block_size
return computed_blocks, num_computed_tokens
def allocate_slots(
self,
request: Request,
num_tokens: int,
new_computed_blocks: Optional[list[KVCacheBlock]] = None
new_computed_blocks: Optional[list[KVCacheBlock]] = None,
num_lookahead_tokens: int = 0,
) -> Optional[list[KVCacheBlock]]:
"""Add slots for a request with new tokens to append.
......@@ -173,6 +175,9 @@ class KVCacheManager:
not include the tokens that have already been computed.
new_computed_blocks: A list of new computed blocks just hitting the
prefix caching.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
Blocks layout:
-----------------------------------------------------------------------
......@@ -210,8 +215,9 @@ class KVCacheManager:
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
len(new_computed_blocks) * self.block_size)
num_required_blocks = cdiv(num_computed_tokens + num_tokens,
self.block_size)
num_required_blocks = cdiv(
num_computed_tokens + num_tokens + num_lookahead_tokens,
self.block_size)
num_new_blocks = (num_required_blocks - len(req_blocks) -
len(new_computed_blocks))
......@@ -245,8 +251,11 @@ class KVCacheManager:
else:
# Get new blocks from the free block pool considering
# preallocated blocks.
num_preallocate_blocks = max(
0, self.num_preallocate_blocks -
num_lookahead_tokens // self.block_size)
num_new_blocks = min(
num_new_blocks + self.num_preallocate_blocks,
num_new_blocks + num_preallocate_blocks,
self.block_pool.get_num_free_blocks(),
# Should not exceed the maximum number of blocks per request.
# This is especially because the block table has the shape
......
......@@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.utils import GiB_bytes, sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
......@@ -310,8 +310,7 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1]["offset"] + mm_positions[-1][
"length"] < start_token_idx:
if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
return extra_keys, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
......@@ -322,8 +321,8 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx]["offset"]
length = mm_positions[curr_mm_idx]["length"]
offset = mm_positions[curr_mm_idx].offset
length = mm_positions[curr_mm_idx].length
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
......@@ -460,6 +459,54 @@ def hash_request_tokens(hash_function: Any, block_size: int,
return ret
def estimate_max_model_len(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> int:
"""
Estimates the maximum model length that can fit in the available memory
using binary search.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The estimated maximum model length that can fit in the available memory.
"""
# Define a function to check if a given model length fits in memory
def fits_in_memory(model_len: int) -> bool:
# Modify the max_model_len for this calculation
vllm_config.model_config.max_model_len = model_len
# Calculate memory needed for the given model length
memory_needed = sum(
(layer_spec.max_memory_usage_bytes(vllm_config)
for layer_spec in kv_cache_spec.values()),
start=0,
)
return memory_needed <= available_memory
# Binary search for the maximum model length
current_max = vllm_config.model_config.max_model_len
left, right = 1, current_max
# If even the smallest model length doesn't fit, return 0
if not fits_in_memory(left):
return 0
# Binary search for the maximum model length that fits
result = 1
while left <= right:
mid = (left + right) // 2
if fits_in_memory(mid):
result = mid
left = mid + 1
else:
right = mid - 1
return result
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
......@@ -487,12 +534,21 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
needed_memory += layer_spec.max_memory_usage_bytes(vllm_config)
if needed_memory > available_memory:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
available_memory)
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = " Based on the available memory,"
f" the estimated maximum model length is {estimated_max_len}."
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GiB KV "
f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/1024/1024/1024:.2f} GiB). Try "
f"increasing `gpu_memory_utilization` or decreasing "
f"memory ({available_memory/GiB_bytes:.2f} GiB)."
f"{estimated_msg} "
f" Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")
......
......@@ -7,7 +7,8 @@ from collections import deque
from collections.abc import Iterable
from typing import Optional, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
......@@ -39,6 +40,7 @@ class Scheduler(SchedulerInterface):
lora_config: Optional[LoRAConfig],
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
speculative_config: SpeculativeConfig = None,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
......@@ -112,6 +114,11 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)
self.num_lookahead_tokens = 0
if speculative_config and speculative_config.method == "eagle":
self.num_lookahead_tokens = \
speculative_config.num_speculative_tokens
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
......@@ -188,7 +195,9 @@ class Scheduler(SchedulerInterface):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens)
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
......@@ -505,8 +514,8 @@ class Scheduler(SchedulerInterface):
assert mm_positions is not None
assert len(mm_positions) > 0
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info["offset"]
num_encoder_tokens = pos_info["length"]
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
......@@ -522,6 +531,17 @@ class Scheduler(SchedulerInterface):
if self.encoder_cache_manager.has_cache(request, i):
# The encoder input is already computed and cached.
continue
# If no encoder input chunking is allowed, we do not want to
# partially schedule a multimodal item. If the scheduled range would
# only cover part of the mm input, roll back to before the mm item.
if (self.scheduler_config.disable_chunked_mm_input
and num_computed_tokens < start_pos
and (num_computed_tokens + num_new_tokens)
< (start_pos + num_encoder_tokens)):
num_new_tokens = start_pos - num_computed_tokens
break
if (not self.encoder_cache_manager.can_allocate(request, i)
or num_encoder_tokens > encoder_budget):
# The encoder cache is full or the encoder budget is exhausted.
......@@ -596,8 +616,8 @@ class Scheduler(SchedulerInterface):
if cached_encoder_input_ids:
for input_id in list(cached_encoder_input_ids):
mm_positions = request.mm_positions[input_id]
start_pos = mm_positions["offset"]
num_tokens = mm_positions["length"]
start_pos = mm_positions.offset
num_tokens = mm_positions.length
if start_pos + num_tokens <= request.num_computed_tokens:
# The encoder output is already processed and stored
# in the decoder's KV cache.
......
......@@ -2,6 +2,7 @@
import enum
import time
from collections.abc import Sequence
from typing import Any, Optional, Union
import msgspec
......@@ -52,7 +53,7 @@ class EngineCoreRequest(
# Detokenizer, but set to None when it is added to EngineCoreClient.
prompt: Optional[str]
prompt_token_ids: list[int]
mm_inputs: Optional[list[MultiModalKwargs]]
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
mm_hashes: Optional[list[str]]
mm_placeholders: Optional[list[PlaceholderRange]]
sampling_params: SamplingParams
......
......@@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
......@@ -98,6 +98,7 @@ class EngineCore:
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
kv_cache_config=kv_cache_config,
speculative_config=vllm_config.speculative_config,
structured_output_manager=self.structured_output_manager,
include_finished_set=vllm_config.parallel_config.data_parallel_size
> 1,
......@@ -105,7 +106,7 @@ class EngineCore:
)
# Setup MM Input Mapper.
self.mm_input_cache_server = MMInputCacheServer(
self.mm_input_cache_server = MirroredProcessingCache(
vllm_config.model_config)
# Setup batch queue for pipeline parallelism.
......@@ -173,7 +174,7 @@ class EngineCore:
# anything that has a hash must have a HIT cache entry here
# as well.
assert request.mm_inputs is not None
request.mm_inputs = self.mm_input_cache_server.get_and_update(
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
......@@ -486,14 +487,14 @@ class EngineCoreProc(EngineCore):
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
type_frame, *data_frames = socket.recv_multipart(copy=False)
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)
request = decoder.decode(data_frames)
# Push to input queue for core busy loop.
self.input_queue.put_nowait((request_type, request))
......@@ -510,8 +511,8 @@ class EngineCoreProc(EngineCore):
while True:
outputs = self.output_queue.get()
outputs.engine_index = engine_index
encoder.encode_into(outputs, buffer)
socket.send(buffer, copy=False)
buffers = encoder.encode_into(outputs, buffer)
socket.send_multipart(buffers, copy=False)
ENGINE_PAUSED_OUTPUTS = EngineCoreOutputs(engine_paused=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment