Commit 081057de authored by zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
......@@ -46,7 +46,7 @@ class KVTransferAgent:
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
"TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector(
self.connector = KVConnectorFactory.create_connector_v0(
rank, local_rank, config)
def send_kv_caches_and_hidden_states(
......
......@@ -70,7 +70,7 @@ class MooncakeStore(KVStoreBufferBase):
):
try:
from mooncake_vllm_adaptor import MooncakeDistributedStore
from mooncake.store import MooncakeDistributedStore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
......
......@@ -2,6 +2,7 @@
import json
import os
import struct
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Union
......@@ -57,14 +58,14 @@ class MooncakeTransferEngine:
def __init__(self, kv_rank: int, local_rank: int):
try:
import mooncake_vllm_adaptor as mva
from mooncake.engine import TransferEngine
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
self.engine = mva.mooncake_vllm_adaptor()
self.engine = TransferEngine()
self.local_rank = local_rank
try:
......@@ -115,14 +116,14 @@ class MooncakeTransferEngine:
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
if kv_rank == 0:
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
else:
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}")
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
def initialize(self, local_hostname: str, metadata_server: str,
......@@ -140,12 +141,12 @@ class MooncakeTransferEngine:
"Mooncake Configuration error. `metadata_backend`"
f" should be one of {supported_backend}.")
self.engine.initializeExt(local_hostname, metadata_server,
protocol, device_name, metadata_backend)
self.engine.initialize_ext(local_hostname, metadata_server,
protocol, device_name, metadata_backend)
def allocate_managed_buffer(self, length: int) -> int:
"""Allocate a managed buffer of the specified length."""
ret = self.engine.allocateManagedBuffer(length)
ret = self.engine.allocate_managed_buffer(length)
if ret <= 0:
logger.error("Allocation Return Error")
raise Exception("Allocation Return Error")
......@@ -153,13 +154,13 @@ class MooncakeTransferEngine:
def free_managed_buffer(self, buffer: int, length: int) -> int:
"""Free a previously allocated managed buffer."""
return self.engine.freeManagedBuffer(buffer, length)
return self.engine.free_managed_buffer(buffer, length)
def transfer_sync(self, buffer: int, peer_buffer_address: int,
length: int) -> int:
"""Synchronously transfer data to the specified address."""
ret = self.engine.transferSync(self.remote_url, buffer,
peer_buffer_address, length)
ret = self.engine.transfer_sync_read(self.remote_url, buffer,
peer_buffer_address, length)
if ret < 0:
logger.error("Transfer Return Error")
raise Exception("Transfer Return Error")
......@@ -168,15 +169,15 @@ class MooncakeTransferEngine:
def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
length: int) -> int:
"""Write bytes to the allocated buffer."""
return self.engine.writeBytesToBuffer(buffer, user_data, length)
return self.engine.write_bytes_to_buffer(buffer, user_data, length)
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
"""Read bytes from the allocated buffer."""
return self.engine.readBytesFromBuffer(buffer, length)
return self.engine.read_bytes_from_buffer(buffer, length)
def wait_for_ack(self, src_ptr: int, length: int) -> None:
"""Asynchronously wait for ACK from the receiver."""
ack = self.sender_ack.recv_pyobj()
ack = self.sender_ack.recv()
if ack != b'ACK':
logger.error("Failed to receive ACK from the receiver")
......@@ -187,18 +188,22 @@ class MooncakeTransferEngine:
length = len(user_data)
src_ptr = self.allocate_managed_buffer(length)
self.write_bytes_to_buffer(src_ptr, user_data, length)
self.sender_socket.send_pyobj((src_ptr, length))
self.sender_socket.send_multipart(
[struct.pack("!Q", src_ptr),
struct.pack("!Q", length)])
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
def recv_bytes(self) -> bytes:
"""Receive bytes from the remote process."""
src_ptr, length = self.receiver_socket.recv_pyobj()
data = self.receiver_socket.recv_multipart()
src_ptr = struct.unpack("!Q", data[0])[0]
length = struct.unpack("!Q", data[1])[0]
dst_ptr = self.allocate_managed_buffer(length)
self.transfer_sync(dst_ptr, src_ptr, length)
ret = self.read_bytes_from_buffer(dst_ptr, length)
# Buffer cleanup
self.receiver_ack.send_pyobj(b'ACK')
self.receiver_ack.send(b'ACK')
self.free_managed_buffer(dst_ptr, length)
return ret
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
if TYPE_CHECKING:
from vllm.config import VllmConfig
# Process-wide KV connector instance. Stays None until
# ensure_kv_transfer_initialized() creates a connector.
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
def get_kv_transfer_group() -> KVConnectorBaseType:
    """Return the process-wide KV connector.

    Returns:
        The connector stored in the module-level ``_KV_CONNECTOR_AGENT``.

    Raises:
        AssertionError: If no connector has been created yet, i.e.
            ``_KV_CONNECTOR_AGENT`` is still ``None``.
    """
    assert _KV_CONNECTOR_AGENT is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return _KV_CONNECTOR_AGENT
def has_kv_transfer_group() -> bool:
    """Return True if the process-wide KV connector has been created."""
    return _KV_CONNECTOR_AGENT is not None
def is_v1_kv_transfer_group(
        connector: Optional[KVConnectorBaseType] = None) -> bool:
    """Check if the KV connector is the v1 connector.

    Args:
        connector: The KV connector to check. If None, the global KV
            connector is checked instead.

    Returns:
        True if the connector under test is a ``KVConnectorBase_V1``
        instance. False otherwise, including when neither the argument nor
        the global connector is set.

    Note:
        This function will no longer be needed after the v1 KV connector
        becomes the default.
    """
    # Fall back to the global connector only when the caller passed nothing.
    candidate = _KV_CONNECTOR_AGENT if connector is None else connector
    if candidate is None:
        return False
    return isinstance(candidate, KVConnectorBase_V1)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """Initialize the KV cache transfer connector if needed.

    The call is a no-op when ``vllm_config.kv_transfer_config`` is None,
    when that config is not a KV transfer instance, or when a connector has
    already been created.

    Otherwise the module-level ``_KV_CONNECTOR_AGENT`` is populated: with a
    v1 connector in the WORKER role when ``envs.VLLM_USE_V1`` is truthy, or
    with a v0 connector built from the world group's rank and local rank.

    Args:
        vllm_config: The vLLM configuration to read ``kv_transfer_config``
            from; it is also forwarded to the connector factory.
    """
    global _KV_CONNECTOR_AGENT

    kv_transfer_config = vllm_config.kv_transfer_config
    if kv_transfer_config is None:
        return
    # Nothing to do for non-transfer instances or when already initialized.
    if (not kv_transfer_config.is_kv_transfer_instance
            or _KV_CONNECTOR_AGENT is not None):
        return

    if envs.VLLM_USE_V1:
        _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
            config=vllm_config, role=KVConnectorRole.WORKER)
    else:
        world_group = get_world_group()
        _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
            rank=world_group.rank,
            local_rank=world_group.local_rank,
            config=vllm_config,
        )
......@@ -29,15 +29,13 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase)
......@@ -46,9 +44,6 @@ from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
supports_custom_op)
if TYPE_CHECKING:
from vllm.config import VllmConfig
@dataclass
class GraphCaptureContext:
......@@ -118,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor)
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
                   group_name: str) -> torch.Tensor:
    """Run reduce-scatter on ``tensor`` through the group named ``group_name``.

    Args:
        tensor: Input tensor.
        dim: Dimension forwarded to the group's ``reduce_scatter``.
        world_size: Not used by this implementation; it is part of the
            signature shared with ``reduce_scatter_fake``.
        group_name: Key of the group in the module-level ``_groups`` table.

    Returns:
        Whatever the group's ``reduce_scatter(tensor, dim)`` returns.

    Raises:
        AssertionError: If ``group_name`` is not present in ``_groups``.
        ValueError: If the ``_groups`` entry resolves to None (the group has
            been destroyed).
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # Entries are callables that yield the group or None
    # (presumably weak references -- confirm where _groups is populated).
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group.reduce_scatter(tensor, dim)
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
                        group_name: str) -> torch.Tensor:
    """Shape-only stand-in for ``reduce_scatter``.

    No communication is performed. The result is an uninitialized tensor
    with the same dtype and device as ``tensor`` whose size along ``dim`` is
    the input size floor-divided by ``world_size``.

    Args:
        tensor: Tensor whose shape, dtype and device are mirrored.
        dim: Dimension that is shrunk.
        world_size: Divisor applied to the size of ``dim``.
        group_name: Unused; present to match ``reduce_scatter``'s signature.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
               group_name: str) -> torch.Tensor:
    """Run all-gather on ``tensor`` through the group named ``group_name``.

    Args:
        tensor: Input tensor.
        dim: Dimension forwarded to the group's ``all_gather``.
        world_size: Not used by this implementation; it is part of the
            signature shared with ``all_gather_fake``.
        group_name: Key of the group in the module-level ``_groups`` table.

    Returns:
        Whatever the group's ``all_gather(tensor, dim)`` returns.

    Raises:
        AssertionError: If ``group_name`` is not present in ``_groups``.
        ValueError: If the ``_groups`` entry resolves to None (the group has
            been destroyed).
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # Entries are callables that yield the group or None
    # (presumably weak references -- confirm where _groups is populated).
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group.all_gather(tensor, dim)
def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    """Shape-only stand-in for ``all_gather``.

    No communication is performed. The result is an uninitialized tensor
    with the same dtype and device as ``tensor`` whose size along ``dim`` is
    the input size multiplied by ``world_size``.

    Args:
        tensor: Tensor whose shape, dtype and device are mirrored.
        dim: Dimension that is enlarged.
        world_size: Factor applied to the size of ``dim``.
        group_name: Unused; present to match ``all_gather``'s signature.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
if supports_custom_op():
from vllm.platforms import current_platform
direct_register_custom_op(
......@@ -128,6 +155,20 @@ if supports_custom_op():
dispatch_key=current_platform.dispatch_key,
)
# Register reduce_scatter as a custom op. reduce_scatter_fake is supplied as
# its fake implementation and no arguments are declared as mutated.
direct_register_custom_op(
op_name="reduce_scatter",
op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake,
)
# Same registration pattern for all_gather, paired with all_gather_fake.
direct_register_custom_op(
op_name="all_gather",
op_func=all_gather,
mutates_args=[],
fake_impl=all_gather_fake,
)
class GroupCoordinator:
"""
......@@ -327,6 +368,18 @@ class GroupCoordinator:
return self.device_communicator.all_gather(input_, dim)
def reduce_scatter(self,
                   input_: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` across this group.

    Args:
        input_: Tensor to reduce-scatter.
        dim: Dimension to operate on; negative values index from the end.

    Returns:
        ``input_`` itself when the group's world size is 1, otherwise the
        result of the device communicator's ``reduce_scatter``.

    Raises:
        AssertionError: If ``dim`` is outside ``[-input_.dim(),
            input_.dim())`` and the world size is greater than 1.
    """
    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    return self.device_communicator.reduce_scatter(input_, dim)
def gather(self,
input_: torch.Tensor,
dst: int = 0,
......@@ -772,14 +825,6 @@ def get_pp_group() -> GroupCoordinator:
# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
assert _KV_TRANSFER is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_TRANSFER
@contextmanager
def graph_capture(device: torch.device):
......@@ -962,26 +1007,6 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize KV cache transfer parallel group.
"""
global _KV_TRANSFER
if vllm_config.kv_transfer_config is None:
return
if all([
vllm_config.kv_transfer_config.is_kv_transfer_instance,
_KV_TRANSFER is None
]):
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config)
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
......
......@@ -7,6 +7,7 @@
import dataclasses
import datetime
import pickle
import socket
import time
from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
......@@ -123,6 +124,10 @@ class StatelessProcessGroup:
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
# stores a reference to the socket so that the file descriptor stays alive
socket: Optional[socket.socket]
data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter
......@@ -234,18 +239,33 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
launch_server = rank == 0
if launch_server:
# listen on the specified interface (instead of 0.0.0.0)
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listen_socket.bind((host, port))
listen_socket.listen()
listen_fd = listen_socket.fileno()
else:
listen_socket = None
listen_fd = None
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
is_master=launch_server,
timeout=datetime.timedelta(seconds=store_timeout),
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd=listen_fd,
)
return StatelessProcessGroup(
rank=rank,
world_size=world_size,
store=store,
socket=listen_socket,
data_expiration_seconds=data_expiration_seconds)
......
# SPDX-License-Identifier: Apache-2.0
# yapf: disable
import argparse
import dataclasses
import json
import re
import threading
from dataclasses import MISSING, dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast, get_args, get_origin)
from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
TypeVar, Union, cast, get_args, get_origin)
import torch
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs
from vllm import version
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig, HfOverrides,
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, DecodingConfig, Device,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackendV1, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig, TaskOption,
TokenizerPoolConfig, VllmConfig, get_attr_docs)
ModelConfig, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerPoolConfig, VllmConfig,
get_attr_docs, get_field)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
......@@ -28,33 +34,42 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
if TYPE_CHECKING:
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
# yapf: enable
logger = init_logger(__name__)
# NOTE(review): presumably the accepted values for detailed trace collection;
# no consumer of this list is visible in this chunk -- confirm.
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
# Device name options.
# NOTE(review): no consumer of this list is visible in this chunk -- confirm
# it is still referenced.
DEVICE_OPTIONS = [
"auto",
"cuda",
"neuron",
"cpu",
"tpu",
"xpu",
"hpu",
]
# Aliases used to annotate type hints handled by the helpers below.
# object is used to allow for special typing forms
T = TypeVar("T")
TypeHint = Union[type[Any], object]
TypeHintT = Union[type[T], object]
def nullable_str(val: str):
    """Return ``val`` unchanged, or None for empty input or the text "None".

    Args:
        val: Raw string value.

    Returns:
        None when ``val`` is falsy or equal to ``"None"``; otherwise ``val``.
    """
    if not val or val == "None":
        return None
    return val
def optional_type(
        return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
    """Wrap a string converter so that empty input converts to None.

    Args:
        return_type: Callable that converts a string to the target type.

    Returns:
        A converter that returns None for ``""`` and ``"None"`` and
        otherwise delegates to ``return_type``. When ``return_type`` is
        ``json.loads`` and the value does not look like a JSON object, the
        value is parsed with ``nullable_kvs`` (legacy comma separated
        key=value form) instead.

        The returned converter raises ``argparse.ArgumentTypeError`` when
        the underlying conversion raises ``ValueError``.
    """

    def _optional_type(val: str) -> Optional[T]:
        if val == "" or val == "None":
            return None
        try:
            # (?s) lets "." span newlines and \s* tolerates padding, so
            # pretty-printed or padded JSON objects are parsed as JSON
            # rather than being sent to the key=value fallback.
            if return_type is json.loads and not re.match(
                    r"(?s)^\s*\{.*\}\s*$", val):
                return cast(T, nullable_kvs(val))
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e

    return _optional_type
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
@deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
"string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
"""Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary.
......@@ -64,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
Returns:
Dictionary with parsed values.
"""
if len(val) == 0:
return None
out_dict: Dict[str, int] = {}
out_dict: dict[str, int] = {}
for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2:
......@@ -89,6 +101,105 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
return out_dict
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Check if the type hint is a specific type.

    A hint matches when it is ``type`` itself or when its origin (as reported
    by ``get_origin``) is ``type``, e.g. ``list[int]`` matches ``list``.
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type
def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Check if the type hints contain a specific type.

    Returns True if at least one element of ``type_hints`` satisfies
    ``is_type`` for ``type``; False otherwise (including for an empty set).
    """
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return True
    return False
def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Get the specific type from the type hints.

    Returns the first element of ``type_hints`` (in set iteration order) that
    satisfies ``is_type`` for ``type``, or None when no element matches.
    """
    for type_hint in type_hints:
        if is_type(type_hint, type):
            return type_hint
    return None
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Check if the class is not a built-in type.

    The decision is based solely on ``type_hint.__module__`` being different
    from ``"builtins"``.
    """
    module_name = type_hint.__module__
    return module_name != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Build per-field argparse keyword arguments for a config dataclass.

    For every dataclass field of ``cls`` an entry ``name -> kwargs`` is
    produced. Each entry always holds ``default`` and ``help`` and, depending
    on the field's type hints, ``action``, ``choices``, ``type`` and
    ``nargs``.

    Args:
        cls: Config dataclass whose fields are inspected. Help texts are
            taken from ``get_attr_docs(cls)``.

    Returns:
        Mapping from field name to a dict of keyword arguments.

    Raises:
        AssertionError: If Literal choices are of mixed types, tuple element
            types differ, or a list hint does not have exactly one type
            argument.
        ValueError: If none of the supported type categories applies.
    """
    cls_docs = get_attr_docs(cls)
    kwargs = {}
    for field in fields(cls):
        # Get the default value of the field
        default = field.default
        if field.default_factory is not MISSING:
            default = field.default_factory()

        # Get the help text for the field, escaping % for argparse
        name = field.name
        help_text = cls_docs[name].replace("%", "%%")

        # Initialise the kwargs dictionary for the field
        field_kwargs: dict[str, Any] = {"default": default, "help": help_text}
        kwargs[name] = field_kwargs

        # Get the set of possible types for the field
        type_hints: set[TypeHint] = set()
        if get_origin(field.type) is Union:
            type_hints.update(get_args(field.type))
        else:
            type_hints.add(field.type)

        # Set other kwargs based on the type hints
        if contains_type(type_hints, bool):
            # Creates --no-<name> and --<name> flags
            field_kwargs["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            # Creates choices from Literal arguments
            type_hint = get_type(type_hints, Literal)
            choices = list(get_args(type_hint))
            choice_type = type(choices[0])
            # Validate before sorting: sorting unorderable mixed types would
            # raise a bare TypeError and hide this message.
            assert all(type(c) is choice_type for c in choices), (
                "All choices must be of the same type. "
                f"Got {choices} with types {[type(c) for c in choices]}")
            choices.sort()
            field_kwargs["choices"] = choices
            field_kwargs["type"] = choice_type
        elif contains_type(type_hints, tuple):
            type_hint = get_type(type_hints, tuple)
            types = get_args(type_hint)
            tuple_type = types[0]
            assert all(t is tuple_type for t in types if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {types}.")
            field_kwargs["type"] = tuple_type
            # Variadic tuples accept one or more values, fixed ones an
            # exact count.
            field_kwargs["nargs"] = "+" if Ellipsis in types else len(types)
        elif contains_type(type_hints, list):
            type_hint = get_type(type_hints, list)
            types = get_args(type_hint)
            assert len(types) == 1, (
                "List type must have exactly one type. Got "
                f"{type_hint} with types {types}")
            field_kwargs["type"] = types[0]
            field_kwargs["nargs"] = "+"
        elif contains_type(type_hints, int):
            field_kwargs["type"] = int
        elif contains_type(type_hints, float):
            field_kwargs["type"] = float
        elif contains_type(type_hints, dict):
            # Dict arguments will always be optional
            field_kwargs["type"] = optional_type(json.loads)
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            field_kwargs["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # If None is in type_hints, make the argument optional.
        # But not if it's a bool, argparse will handle this better.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            field_kwargs["type"] = optional_type(field_kwargs["type"])
            if field_kwargs.get("choices"):
                field_kwargs["choices"].append("None")
    return kwargs
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
......@@ -105,14 +216,15 @@ class EngineArgs:
load_format: str = LoadConfig.load_format
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: Optional[int] = None
max_model_len: Optional[int] = None
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: Optional[Union[
str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
DistributedExecutorBackend,
Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
......@@ -120,20 +232,23 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[int] = None
enable_prefix_caching: Optional[bool] = None
prefix_caching_hash_algo: str = "builtin"
block_size: Optional[BlockSize] = CacheConfig.block_size
enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
prefix_caching_hash_algo: PrefixCachingHashAlgo = \
CacheConfig.prefix_caching_hash_algo
disable_sliding_window: bool = False
disable_cascade_attn: bool = False
use_v2_block_manager: bool = True
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_partial_prefills: Optional[int] = 1
max_long_partial_prefills: Optional[int] = 1
long_prefill_token_threshold: Optional[int] = 0
max_num_seqs: Optional[int] = None
swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
max_num_batched_tokens: Optional[
int] = SchedulerConfig.max_num_batched_tokens
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
long_prefill_token_threshold: int = \
SchedulerConfig.long_prefill_token_threshold
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
revision: Optional[str] = None
......@@ -147,42 +262,51 @@ class EngineArgs:
enforce_eager: Optional[bool] = None
max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
tokenizer_pool_size: int = 0
# Note: Specifying a tokenizer pool by passing a class
# is intended for expert use only. The API may change without
# notice.
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
# The following three fields are deprecated and will be removed in a future
# release. Setting them will have no effect. Please remove them from your
# configurations.
tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
tokenizer_pool_extra_config: dict = \
get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
mm_processor_kwargs: Optional[Dict[str, Any]] = None
disable_mm_preprocessor_cache: bool = False
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = LoRAConfig.max_lora_rank
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
long_lora_scaling_factors: Optional[tuple[float, ...]] = \
LoRAConfig.long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1
max_prompt_adapter_token: int = 0
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = True
max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = \
PromptAdapterConfig.max_prompt_adapter_token
device: Device = DeviceConfig.device
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
model_loader_extra_config: Optional[
dict] = LoadConfig.model_loader_extra_config
num_gpu_blocks_override: Optional[
int] = CacheConfig.num_gpu_blocks_override
num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
model_loader_extra_config: dict = \
get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: Optional[Union[str,
List[str]]] = LoadConfig.ignore_patterns
preemption_mode: Optional[str] = None
preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
scheduler_delay_factor: float = 0.0
enable_chunked_prefill: Optional[bool] = None
disable_chunked_mm_input: bool = False
scheduler_delay_factor: float = SchedulerConfig.delay_factor
enable_chunked_prefill: Optional[
bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
logits_processor_pattern: Optional[str] = None
......@@ -194,8 +318,8 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler"
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None
......@@ -210,11 +334,11 @@ class EngineArgs:
enable_sleep_mode: bool = False
model_impl: str = "auto"
calculate_kv_scales: Optional[bool] = None
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None
reasoning_parser: Optional[str] = None
reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
def __post_init__(self):
......@@ -236,38 +360,6 @@ class EngineArgs:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""
def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
"""Check if the class is a type in a union type."""
return get_origin(cls) is Union and type in get_args(cls)
def is_optional(cls: type[Any]) -> bool:
"""Check if the class is an optional type."""
return is_type_in_union(cls, type(None))
def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
    """Build ``add_argument`` keyword arguments for every field of ``cls``.

    Args:
        cls: A dataclass type; its fields are enumerated with
            ``dataclasses.fields``.

    Returns:
        A mapping from field name to a kwargs dict. Each entry holds
        ``default`` and ``help`` (looked up by field name in the mapping
        returned by ``get_attr_docs``) and, for every field whose type is
        not exactly ``bool``, a ``type`` converter.
    """
    cls_docs = get_attr_docs(cls)
    kwargs: Dict[str, Any] = {}
    for field in fields(cls):
        name = field.name
        if field.default is not MISSING:
            default = field.default
        elif field.default_factory is not MISSING:
            # ``default_factory`` is a zero-argument callable; the CLI
            # default must be the value it produces, not the callable.
            default = field.default_factory()
        else:
            # Field with neither a default nor a factory: keep the
            # ``MISSING`` sentinel, as the previous implementation did.
            default = MISSING
        kwargs[name] = {"default": default, "help": cls_docs[name]}
        # When using action="store_true"
        # add_argument doesn't accept type
        if field.type is bool:
            continue
        # Handle optional fields
        if is_optional(field.type):
            kwargs[name]["type"] = nullable_str
            continue
        # Handle str in union fields
        if is_type_in_union(field.type, str):
            kwargs[name]["type"] = str
            continue
        kwargs[name]["type"] = field.type
    return kwargs
# Model arguments
parser.add_argument(
'--model',
......@@ -285,13 +377,13 @@ class EngineArgs:
'which task to use.')
parser.add_argument(
'--tokenizer',
type=nullable_str,
type=optional_type(str),
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
"--hf-config-path",
type=nullable_str,
type=optional_type(str),
default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
......@@ -303,21 +395,21 @@ class EngineArgs:
'the input. The generated output will contain token ids.')
parser.add_argument(
'--revision',
type=nullable_str,
type=optional_type(str),
default=None,
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=nullable_str,
type=optional_type(str),
default=None,
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=nullable_str,
type=optional_type(str),
default=None,
help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. '
......@@ -357,7 +449,6 @@ class EngineArgs:
load_group.add_argument('--model-loader-extra-config',
**load_kwargs["model_loader_extra_config"])
load_group.add_argument('--use-tqdm-on-load',
action=argparse.BooleanOptionalAction,
**load_kwargs["use_tqdm_on_load"])
parser.add_argument(
......@@ -382,14 +473,6 @@ class EngineArgs:
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.')
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument('--max-model-len',
type=human_readable_int,
default=EngineArgs.max_model_len,
......@@ -399,21 +482,25 @@ class EngineArgs:
'Examples:\n'
'- 1k → 1000\n'
'- 1K → 1024\n')
parser.add_argument(
# Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig)
guided_decoding_group = parser.add_argument_group(
title="DecodingConfig",
description=DecodingConfig.__doc__,
)
guided_decoding_group.add_argument(
'--guided-decoding-backend',
type=str,
default=DecodingConfig.guided_decoding_backend,
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc) by default. Currently support '
'https://github.com/mlc-ai/xgrammar and '
'https://github.com/guidance-ai/llguidance.'
'Valid backend values are "xgrammar", "guidance", and "auto". '
'With "auto", we will make opinionated choices based on request '
'contents and what the backend libraries currently support, so '
'the behavior is subject to change in each release.')
**guided_decoding_kwargs["guided_decoding_backend"])
guided_decoding_group.add_argument(
"--reasoning-parser",
# `choices` is a special case because it is not static (it is built at runtime)
choices=list(ReasoningParserManager.reasoning_parsers),
**guided_decoding_kwargs["reasoning_backend"])
parser.add_argument(
'--logits-processor-pattern',
type=nullable_str,
type=optional_type(str),
default=None,
help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` '
......@@ -439,7 +526,6 @@ class EngineArgs:
)
parallel_group.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp', 'uni', 'external_launcher'],
**parallel_kwargs["distributed_executor_backend"])
parallel_group.add_argument(
'--pipeline-parallel-size', '-pp',
......@@ -450,46 +536,40 @@ class EngineArgs:
**parallel_kwargs["data_parallel_size"])
parallel_group.add_argument(
'--enable-expert-parallel',
action='store_true',
**parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument(
'--max-parallel-loading-workers',
**parallel_kwargs["max_parallel_loading_workers"])
parallel_group.add_argument(
'--ray-workers-use-nsight',
action='store_true',
**parallel_kwargs["ray_workers_use_nsight"])
parallel_group.add_argument(
'--disable-custom-all-reduce',
action='store_true',
**parallel_kwargs["disable_custom_all_reduce"])
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 64, 128],
help='Token block size for contiguous chunks of '
'tokens. This is ignored on neuron devices and '
'set to ``--max-model-len``. On CUDA devices, '
'only block sizes up to 32 are supported. '
'On HPU devices, block size defaults to 128.')
parser.add_argument(
"--enable-prefix-caching",
action=argparse.BooleanOptionalAction,
default=EngineArgs.enable_prefix_caching,
help="Enables automatic prefix caching. "
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
)
parser.add_argument(
"--prefix-caching-hash-algo",
type=str,
choices=["builtin", "sha256"],
default=EngineArgs.prefix_caching_hash_algo,
help="Set the hash algorithm for prefix caching. "
"Options are 'builtin' (Python's built-in hash) or 'sha256' "
"(collision resistant but with certain overheads).",
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group(
title="CacheConfig",
description=CacheConfig.__doc__,
)
cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
cache_group.add_argument('--gpu-memory-utilization',
**cache_kwargs["gpu_memory_utilization"])
cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
cache_group.add_argument('--kv-cache-dtype',
**cache_kwargs["cache_dtype"])
cache_group.add_argument('--num-gpu-blocks-override',
**cache_kwargs["num_gpu_blocks_override"])
cache_group.add_argument("--enable-prefix-caching",
**cache_kwargs["enable_prefix_caching"])
cache_group.add_argument("--prefix-caching-hash-algo",
**cache_kwargs["prefix_caching_hash_algo"])
cache_group.add_argument('--cpu-offload-gb',
**cache_kwargs["cpu_offload_gb"])
cache_group.add_argument('--calculate-kv-scales',
**cache_kwargs["calculate_kv_scales"])
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
......@@ -502,86 +582,11 @@ class EngineArgs:
'block manager v2) is now the default. '
'Setting this flag to True or False'
' has no effect on vLLM behavior.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
default=EngineArgs.num_lookahead_slots,
help='Experimental scheduling config necessary for '
'speculative decoding. This will be replaced by '
'speculative config in the future; it is present '
'to enable correctness tests until then.')
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='Random seed for operations.')
parser.add_argument('--swap-space',
type=float,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--cpu-offload-gb',
type=float,
default=0,
help='The space in GiB to offload to CPU, per GPU. '
'Default is 0, which means no offloading. Intuitively, '
'this argument can be seen as a virtual way to increase '
'the GPU memory size. For example, if you have one 24 GB '
'GPU and set this to 10, virtually you can think of it as '
'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
'which requires at least 26GB GPU memory. Note that this '
'requires fast CPU-GPU interconnect, as part of the model is '
'loaded from CPU memory to GPU memory on the fly in each '
'model forward pass.')
parser.add_argument(
'--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9. This is a per-instance '
'limit, and only applies to the current vLLM instance.'
'It does not matter if you have another vLLM instance running '
'on the same GPU. For example, if you have two vLLM instances '
'running on the same GPU, you can set the GPU memory utilization '
'to 0.5 for each instance.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
' of GPU blocks. Used for testing preemption.')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, the max number of concurrent \
partial prefills.")
parser.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help="For chunked prefill, the maximum number of prompts longer "
"than --long-prefill-token-threshold that will be prefilled "
"concurrently. Setting this less than --max-num-partial-prefills "
"will allow shorter prompts to jump the queue in front of longer "
"prompts in some cases, improving latency.")
parser.add_argument(
"--long-prefill-token-threshold",
type=float,
default=EngineArgs.long_prefill_token_threshold,
help="For chunked prefill, a request is considered long if the "
"prompt is longer than this number of tokens.")
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument(
'--max-logprobs',
type=int,
......@@ -594,7 +599,7 @@ class EngineArgs:
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=nullable_str,
type=optional_type(str),
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
......@@ -645,154 +650,108 @@ class EngineArgs:
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.')
parser.add_argument('--tokenizer-pool-size',
type=int,
default=EngineArgs.tokenizer_pool_size,
help='Size of tokenizer pool to use for '
'asynchronous tokenization. If 0, will '
'use synchronous tokenization.')
parser.add_argument('--tokenizer-pool-type',
type=str,
default=EngineArgs.tokenizer_pool_type,
help='Type of tokenizer pool to use for '
'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config',
type=nullable_str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# Tokenizer arguments
tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
tokenizer_group = parser.add_argument_group(
title="TokenizerPoolConfig",
description=TokenizerPoolConfig.__doc__,
)
tokenizer_group.add_argument('--tokenizer-pool-size',
**tokenizer_kwargs["pool_size"])
tokenizer_group.add_argument('--tokenizer-pool-type',
**tokenizer_kwargs["pool_type"])
tokenizer_group.add_argument('--tokenizer-pool-extra-config',
**tokenizer_kwargs["extra_config"])
# Multimodal related configs
parser.add_argument(
'--limit-mm-per-prompt',
type=nullable_kvs,
default=EngineArgs.limit_mm_per_prompt,
# The default value is given in
# MultiModalConfig.get_default_limit_per_prompt
help=('For each multimodal plugin, limit how many '
'input instances to allow for each prompt. '
'Expects a comma-separated list of items, '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to '
'1 (V0) or 999 (V1) for each modality.'))
multimodal_kwargs = get_kwargs(MultiModalConfig)
multimodal_group = parser.add_argument_group(
title="MultiModalConfig",
description=MultiModalConfig.__doc__,
)
multimodal_group.add_argument('--limit-mm-per-prompt',
**multimodal_kwargs["limit_per_prompt"])
parser.add_argument(
'--mm-processor-kwargs',
default=None,
type=json.loads,
help=('Overrides for the multimodal input mapping/processing, '
'e.g., image processor. For example: ``{"num_crops": 4}``.'))
help=('Overrides for the multi-modal processor obtained from '
'``AutoProcessor.from_pretrained``. The available overrides '
'depend on the model that is being run.'
'For example, for Phi-3-Vision: ``{"num_crops": 4}``.'))
parser.add_argument(
'--disable-mm-preprocessor-cache',
action='store_true',
help='If true, then disables caching of the multi-modal '
'preprocessor/mapper. (not recommended)')
help='If True, disable caching of the processed multi-modal '
'inputs.')
# LoRA related configs
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--enable-lora-bias',
action='store_true',
help='If True, enable bias for LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
help='Max number of LoRAs in a single batch.')
parser.add_argument('--max-lora-rank',
type=int,
default=EngineArgs.max_lora_rank,
help='Max LoRA rank.')
parser.add_argument(
'--lora-extra-vocab-size',
type=int,
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
lora_kwargs = get_kwargs(LoRAConfig)
lora_group = parser.add_argument_group(
title="LoRAConfig",
description=LoRAConfig.__doc__,
)
lora_group.add_argument(
'--enable-lora',
action=argparse.BooleanOptionalAction,
help='If True, enable handling of LoRA adapters.')
lora_group.add_argument('--enable-lora-bias',
**lora_kwargs["bias_enabled"])
lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
lora_group.add_argument('--max-lora-rank',
**lora_kwargs["max_lora_rank"])
lora_group.add_argument('--lora-extra-vocab-size',
**lora_kwargs["lora_extra_vocab_size"])
lora_group.add_argument(
'--lora-dtype',
type=str,
default=EngineArgs.lora_dtype,
choices=['auto', 'float16', 'bfloat16'],
help=('Data type for LoRA. If auto, will default to '
'base model dtype.'))
parser.add_argument(
'--long-lora-scaling-factors',
type=nullable_str,
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can '
'be different from base model scaling factor '
'- see eg. Long LoRA) to allow for multiple '
'LoRA adapters trained with those scaling '
'factors to be used at the same time. If not '
'specified, only adapters trained with the '
'base model scaling factor are allowed.'))
parser.add_argument(
'--max-cpu-loras',
type=int,
default=EngineArgs.max_cpu_loras,
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_loras.'))
parser.add_argument(
'--fully-sharded-loras',
action='store_true',
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument('--enable-prompt-adapter',
action='store_true',
help='If True, enable handling of PromptAdapters.')
parser.add_argument('--max-prompt-adapters',
type=int,
default=EngineArgs.max_prompt_adapters,
help='Max number of PromptAdapters in a batch.')
parser.add_argument('--max-prompt-adapter-token',
type=int,
default=EngineArgs.max_prompt_adapter_token,
help='Max number of PromptAdapters tokens')
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=DEVICE_OPTIONS,
help='Device type for vLLM execution.')
parser.add_argument('--num-scheduler-steps',
type=int,
default=1,
help=('Maximum number of forward steps per '
'scheduler call.'))
**lora_kwargs["lora_dtype"],
)
lora_group.add_argument('--long-lora-scaling-factors',
**lora_kwargs["long_lora_scaling_factors"])
lora_group.add_argument('--max-cpu-loras',
**lora_kwargs["max_cpu_loras"])
lora_group.add_argument('--fully-sharded-loras',
**lora_kwargs["fully_sharded_loras"])
# PromptAdapter related configs
prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
prompt_adapter_group = parser.add_argument_group(
title="PromptAdapterConfig",
description=PromptAdapterConfig.__doc__,
)
prompt_adapter_group.add_argument(
'--enable-prompt-adapter',
action=argparse.BooleanOptionalAction,
help='If True, enable handling of PromptAdapters.')
prompt_adapter_group.add_argument(
'--max-prompt-adapters',
**prompt_adapter_kwargs["max_prompt_adapters"])
prompt_adapter_group.add_argument(
'--max-prompt-adapter-token',
**prompt_adapter_kwargs["max_prompt_adapter_token"])
# Device arguments
device_kwargs = get_kwargs(DeviceConfig)
device_group = parser.add_argument_group(
title="DeviceConfig",
description=DeviceConfig.__doc__,
)
device_group.add_argument("--device", **device_kwargs["device"])
# Speculative arguments
speculative_group = parser.add_argument_group(
title="SpeculativeConfig",
description=SpeculativeConfig.__doc__,
)
speculative_group.add_argument(
'--speculative-config',
type=json.loads,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument(
'--multi-step-stream-outputs',
action=StoreBoolean,
default=EngineArgs.multi_step_stream_outputs,
nargs="?",
const="True",
help='If False, then multi-step will stream outputs at the end '
'of all steps')
parser.add_argument(
'--scheduler-delay-factor',
type=float,
default=EngineArgs.scheduler_delay_factor,
help='Apply a delay (of delay factor multiplied by previous '
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
action=StoreBoolean,
default=EngineArgs.enable_chunked_prefill,
nargs="?",
const="True",
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=json.loads,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument(
'--ignore-patterns',
action="append",
......@@ -801,13 +760,6 @@ class EngineArgs:
help="The pattern(s) to ignore when loading the model."
"Default to `original/**/*` to avoid repeated loading of llama's "
"checkpoints.")
parser.add_argument(
'--preemption-mode',
type=str,
default=None,
help='If \'recompute\', the engine performs preemption by '
'recomputing; If \'swap\', the engine performs preemption by '
'block swapping.')
parser.add_argument(
"--served-model-name",
......@@ -863,22 +815,47 @@ class EngineArgs:
help="Disable async output processing. This may result in "
"lower performance.")
parser.add_argument(
'--scheduling-policy',
choices=['fcfs', 'priority'],
default="fcfs",
help='The scheduling policy to use. "fcfs" (first come first served'
', i.e. requests are handled in order of arrival; default) '
'or "priority" (requests are handled based on given '
'priority (lower value means earlier handling) and time of '
'arrival deciding any ties).')
parser.add_argument(
'--scheduler-cls',
default=EngineArgs.scheduler_cls,
help='The scheduler class to use. "vllm.core.scheduler.Scheduler" '
'is the default scheduler. Can be a class directly or the path to '
'a class of form "mod.custom_class".')
# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
scheduler_group = parser.add_argument_group(
title="SchedulerConfig",
description=SchedulerConfig.__doc__,
)
scheduler_group.add_argument(
'--max-num-batched-tokens',
**scheduler_kwargs["max_num_batched_tokens"])
scheduler_group.add_argument('--max-num-seqs',
**scheduler_kwargs["max_num_seqs"])
scheduler_group.add_argument(
"--max-num-partial-prefills",
**scheduler_kwargs["max_num_partial_prefills"])
scheduler_group.add_argument(
"--max-long-partial-prefills",
**scheduler_kwargs["max_long_partial_prefills"])
scheduler_group.add_argument(
"--long-prefill-token-threshold",
**scheduler_kwargs["long_prefill_token_threshold"])
scheduler_group.add_argument('--num-lookahead-slots',
**scheduler_kwargs["num_lookahead_slots"])
scheduler_group.add_argument('--scheduler-delay-factor',
**scheduler_kwargs["delay_factor"])
scheduler_group.add_argument('--preemption-mode',
**scheduler_kwargs["preemption_mode"])
scheduler_group.add_argument('--num-scheduler-steps',
**scheduler_kwargs["num_scheduler_steps"])
scheduler_group.add_argument(
'--multi-step-stream-outputs',
**scheduler_kwargs["multi_step_stream_outputs"])
scheduler_group.add_argument('--scheduling-policy',
**scheduler_kwargs["policy"])
scheduler_group.add_argument(
'--enable-chunked-prefill',
**scheduler_kwargs["enable_chunked_prefill"])
scheduler_group.add_argument(
"--disable-chunked-mm-input",
**scheduler_kwargs["disable_chunked_mm_input"])
parser.add_argument('--scheduler-cls',
**scheduler_kwargs["scheduler_cls"])
parser.add_argument(
'--override-neuron-config',
......@@ -905,10 +882,11 @@ class EngineArgs:
'testing only. level 3 is the recommended level '
'for production.\n'
'To specify the full compilation config, '
'use a JSON string.\n'
'use a JSON string, e.g. ``{"level": 3, '
'"cudagraph_capture_sizes": [1, 2, 4, 8]}``\n'
'Following the convention of traditional '
'compilers, using -O without space is also '
'supported. -O3 is equivalent to -O 3.')
'compilers, using ``-O`` without space is also '
'supported. ``-O3`` is equivalent to ``-O 3``.')
parser.add_argument('--kv-transfer-config',
type=KVTransferConfig.from_cli,
......@@ -930,7 +908,7 @@ class EngineArgs:
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=nullable_str,
type=optional_type(str),
default="auto",
help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from "
......@@ -957,15 +935,6 @@ class EngineArgs:
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")
parser.add_argument(
'--calculate-kv-scales',
action='store_true',
help='This enables dynamic calculation of '
'k_scale and v_scale when kv-cache-dtype is fp8. '
'If calculate-kv-scales is false, the scales will '
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
parser.add_argument(
"--additional-config",
type=json.loads,
......@@ -983,16 +952,6 @@ class EngineArgs:
"If enabled, the model will be able to generate reasoning content."
)
parser.add_argument(
"--reasoning-parser",
type=str,
choices=list(ReasoningParserManager.reasoning_parsers),
default=None,
help=
"Select the reasoning parser depending on the model that you're "
"using. This is used to parse the reasoning content into OpenAI "
"API format. Required for ``--enable-reasoning``.")
parser.add_argument(
"--disable-cascade-attn",
action="store_true",
......@@ -1003,20 +962,6 @@ class EngineArgs:
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.")
parser.add_argument(
"--disable-chunked-mm-input",
action=StoreBoolean,
default=EngineArgs.disable_chunked_mm_input,
nargs="?",
const="True",
help="Disable multimodal input chunking attention for V1. "
"If set to true and chunked prefill is enabled, we do not want to"
" partially schedule a multimodal item. This ensures that if a "
"request has a mixed prompt (like text tokens TTTT followed by "
"image tokens IIIIIIIIII) where only some image tokens can be "
"scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
"as TTTT in one step and IIIIIIIIII in the next.")
return parser
@classmethod
......@@ -1210,11 +1155,6 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce,
tokenizer_pool_config=TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
),
ray_workers_use_nsight=self.ray_workers_use_nsight,
placement_group=placement_group,
distributed_executor_backend=self.distributed_executor_backend,
......@@ -1288,8 +1228,6 @@ class EngineArgs:
if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
......@@ -1370,7 +1308,7 @@ class EngineArgs:
recommend_to_remove=False)
return False
if self.preemption_mode != EngineArgs.preemption_mode:
if self.preemption_mode != SchedulerConfig.preemption_mode:
_raise_or_fallback(feature_name="--preemption-mode",
recommend_to_remove=True)
return False
......@@ -1381,34 +1319,28 @@ class EngineArgs:
recommend_to_remove=True)
return False
if self.scheduling_policy != EngineArgs.scheduling_policy:
if self.scheduling_policy != SchedulerConfig.policy:
_raise_or_fallback(feature_name="--scheduling-policy",
recommend_to_remove=False)
return False
if self.num_scheduler_steps != EngineArgs.num_scheduler_steps:
if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
_raise_or_fallback(feature_name="--num-scheduler-steps",
recommend_to_remove=True)
return False
if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor:
if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
_raise_or_fallback(feature_name="--scheduler-delay-factor",
recommend_to_remove=True)
return False
if self.additional_config != EngineArgs.additional_config:
_raise_or_fallback(feature_name="--additional-config",
recommend_to_remove=False)
return False
# Xgrammar and Guidance are supported.
SUPPORTED_GUIDED_DECODING = [
"xgrammar", "xgrammar:disable-any-whitespace", "guidance",
"guidance:disable-any-whitespace", "auto"
]
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
_raise_or_fallback(feature_name="--guided-decoding-backend",
recommend_to_remove=False)
# remove backend options when doing this check
if self.guided_decoding_backend.split(':')[0] \
not in get_args(GuidedDecodingBackendV1):
_raise_or_fallback(
feature_name=
f"--guided-decoding-backend={self.guided_decoding_backend}",
recommend_to_remove=False)
return False
# Need at least Ampere for now (FA support required).
......@@ -1432,7 +1364,7 @@ class EngineArgs:
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False
if fp8_attention and will_use_fa:
from vllm.vllm_flash_attn.fa_utils import (
from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8)
supported = flash_attn_supports_fp8()
if not supported:
......@@ -1475,9 +1407,9 @@ class EngineArgs:
# No Concurrent Partial Prefills so far.
if (self.max_num_partial_prefills
!= EngineArgs.max_num_partial_prefills
!= SchedulerConfig.max_num_partial_prefills
or self.max_long_partial_prefills
!= EngineArgs.max_long_partial_prefills):
!= SchedulerConfig.max_long_partial_prefills):
_raise_or_fallback(feature_name="Concurrent Partial Prefill",
recommend_to_remove=False)
return False
......@@ -1497,7 +1429,7 @@ class EngineArgs:
if speculative_method:
if speculative_method in ("ngram", "[ngram]"):
is_ngram_enabled = True
elif speculative_method == "eagle":
elif speculative_method in ("eagle", "eagle3"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
......@@ -1509,16 +1441,17 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No Disaggregated Prefill so far.
if self.kv_transfer_config != EngineArgs.kv_transfer_config:
_raise_or_fallback(feature_name="--kv-transfer-config",
recommend_to_remove=False)
return False
# No FlashInfer or XFormers so far.
# No XFormers so far.
V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA"
"FLASH_ATTN_VLLM_V1",
"FLASH_ATTN",
"PALLAS",
"PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
......@@ -1620,9 +1553,7 @@ class EngineArgs:
self.enable_prefix_caching = False
# VLLM_V0 only supports builtin hash algo for prefix caching.
if self.prefix_caching_hash_algo is None:
self.prefix_caching_hash_algo = "builtin"
elif self.prefix_caching_hash_algo == "sha256":
if self.prefix_caching_hash_algo == "sha256":
raise ValueError(
"sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'.")
......@@ -1641,10 +1572,6 @@ class EngineArgs:
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
self.prefix_caching_hash_algo = "builtin"
# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls:
......@@ -1661,13 +1588,13 @@ class EngineArgs:
# values for non-H100/H200 GPUs.
try:
from vllm.platforms import current_platform
device_name = current_platform.get_device_name().lower()
device_memory = current_platform.get_device_total_memory()
except Exception:
# This is only used to set default_max_num_batched_tokens
device_name = "no-device"
device_memory = 0
if "h100" in device_name or "h200" in device_name:
# For H100 and H200, we use larger default values.
if device_memory >= 70 * GiB_bytes:
# For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 16384,
UsageContext.OPENAI_API_SERVER: 8192,
......
......@@ -493,12 +493,11 @@ class _AsyncLLMEngine(LLMEngine):
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
......@@ -526,10 +525,15 @@ class _AsyncLLMEngine(LLMEngine):
)
async def check_health_async(self) -> None:
"""Run the health checks of the engine's components.

Calls ``check_health()`` on the tokenizer when one is set, then on the
model executor. Nothing is returned; this method does not catch anything
those calls may raise.
"""
# The tokenizer attribute may be unset/falsy, in which case it is skipped.
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health()
async def collective_rpc_async(
    self,
    method: str,
    timeout: Optional[float] = None,
    args: tuple = (),
    kwargs: Optional[dict] = None,
):
    """Asynchronous collective RPC entry point; not implemented here.

    All arguments are accepted but unused.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
......@@ -1167,6 +1171,10 @@ class AsyncLLMEngine(EngineClient):
exception=asyncio.CancelledError,
verbose=self.log_requests)
async def get_vllm_config(self) -> VllmConfig:
    """Return the vLLM configuration reported by the wrapped engine."""
    vllm_config = self.engine.get_vllm_config()
    return vllm_config
async def get_model_config(self) -> ModelConfig:
    """Return the model configuration reported by the wrapped engine."""
    model_config = self.engine.get_model_config()
    return model_config
......@@ -1234,6 +1242,17 @@ class AsyncLLMEngine(EngineClient):
async def add_lora(self, lora_request: LoRARequest) -> None:
    """Forward ``lora_request`` to the wrapped engine's ``add_lora``.

    Whatever the engine call returns is discarded; this coroutine always
    resolves to ``None``.
    """
    engine = self.engine
    engine.add_lora(lora_request)
async def collective_rpc(
    self,
    method: str,
    timeout: Optional[float] = None,
    args: tuple = (),
    kwargs: Optional[dict] = None,
):
    """Perform a collective RPC call by delegating to the wrapped engine.

    All four arguments are passed positionally, unchanged, to
    ``self.engine.collective_rpc_async``; the awaited result of that call
    is returned as-is.
    """
    pending_call = self.engine.collective_rpc_async(method, timeout, args,
                                                    kwargs)
    return await pending_call
# TODO(v1): Remove this class proxy when V1 goes default.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
......
......@@ -29,8 +29,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputs)
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
......@@ -55,7 +54,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import (Counter, Device, deprecate_kwargs,
......@@ -66,7 +65,6 @@ from vllm.worker.model_runner_base import InputProcessingError
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
_R = TypeVar("_R", default=Any)
......@@ -205,7 +203,7 @@ class LLMEngine:
return outputs_
tokenizer: Optional[BaseTokenizerGroup]
tokenizer: Optional[TokenizerGroup]
def __init__(
self,
......@@ -214,7 +212,6 @@ class LLMEngine:
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
......@@ -275,11 +272,7 @@ class LLMEngine:
self.tokenizer,
mm_registry)
self.input_registry = input_registry
self.input_processor = input_registry.create_input_processor(
self.model_config)
self.model_executor = executor_class(vllm_config=vllm_config, )
self.model_executor = executor_class(vllm_config=vllm_config)
if self.model_config.runner_type != "pooling":
self._initialize_kv_caches()
......@@ -321,11 +314,6 @@ class LLMEngine:
self.parallel_config.disable_custom_all_reduce,
})
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [
SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size)
......@@ -537,21 +525,12 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
def get_tokenizer_group(
self,
group_type: Type[_G] = BaseTokenizerGroup,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True")
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return tokenizer_group
return self.tokenizer
def get_tokenizer(
self,
......@@ -559,11 +538,10 @@ class LLMEngine:
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> BaseTokenizerGroup:
def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
lora_config=self.lora_config)
def _verify_args(self) -> None:
......@@ -778,12 +756,11 @@ class LLMEngine:
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))
preprocessed_inputs = self.input_preprocessor.preprocess(
processed_inputs = self.input_preprocessor.preprocess(
prompt,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
processed_inputs = self.input_processor(preprocessed_inputs)
self._add_processed_request(
request_id=request_id,
......@@ -914,6 +891,10 @@ class LLMEngine:
scheduler.abort_seq_group(
request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
def get_vllm_config(self) -> VllmConfig:
    """Return the ``vllm_config`` object stored on this engine."""
    vllm_config = self.vllm_config
    return vllm_config
def get_model_config(self) -> ModelConfig:
    """Return the ``model_config`` object stored on this engine."""
    model_config = self.model_config
    return model_config
......@@ -1948,8 +1929,6 @@ class LLMEngine:
return self.model_executor.is_sleeping
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health()
def is_tracing_enabled(self) -> bool:
......@@ -2058,7 +2037,7 @@ class LLMEngine:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len:
if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
......
......@@ -140,16 +140,13 @@ class Metrics:
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
if not vllm_config.model_config.enforce_eager:
buckets = vllm_config.compilation_config.\
cudagraph_capture_sizes.copy()
buckets.sort()
self.histogram_iteration_tokens = self._histogram_cls(
name="vllm:iteration_tokens_total",
documentation="Histogram of number of tokens per engine_step.",
labelnames=labelnames,
buckets=buckets)
buckets=[
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384
])
self.histogram_time_to_first_token = self._histogram_cls(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
......
......@@ -93,6 +93,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None
# Get the configs.
self.vllm_config = engine_config
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
......@@ -100,7 +101,6 @@ class MQLLMEngineClient(EngineClient):
self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=engine_config.scheduler_config,
parallel_config=engine_config.parallel_config,
lora_config=engine_config.lora_config)
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer)
......@@ -377,6 +377,9 @@ class MQLLMEngineClient(EngineClient):
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
    """Return the tokenizer for ``lora_request`` (which may be ``None``).

    Delegates to ``self.tokenizer.get_lora_tokenizer_async`` and returns
    its awaited result unchanged.
    """
    tokenizer_group = self.tokenizer
    lora_tokenizer = await tokenizer_group.get_lora_tokenizer_async(
        lora_request)
    return lora_tokenizer
async def get_vllm_config(self) -> VllmConfig:
    """Return the ``vllm_config`` object cached on this client."""
    vllm_config = self.vllm_config
    return vllm_config
async def get_decoding_config(self) -> DecodingConfig:
    """Return the ``decoding_config`` object cached on this client."""
    decoding_config = self.decoding_config
    return decoding_config
......
......@@ -178,7 +178,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# generates a fixed number of tokens without evaluating stopping
# conditions within the block. This can cause an eos token to be
# unintentionally ignored.
if not sampling_params.ignore_eos:
if not sampling_params.ignore_eos and self.detokenizer:
eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
# Avoiding .index calls as exception throwing in the happy path
# is expensive.
......
......@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
......@@ -220,6 +220,11 @@ class EngineClient(ABC):
"""
...
@abstractmethod
async def get_vllm_config(self) -> VllmConfig:
    """Get the vllm configuration of the vLLM engine.

    Abstract hook: this declaration has no behavior of its own and
    resolves to ``None`` if awaited directly.
    """
@abstractmethod
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
......
......@@ -27,10 +27,11 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
from pydantic import TypeAdapter
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin)
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig
......@@ -482,11 +483,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if modality in ("image", "image_embeds"):
if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>"
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
if model_type in ("phi3_v", "phi4mm"):
return f"<|image_{current_count}|>"
if model_type == "phi4mm":
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)"
if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
......@@ -506,20 +504,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|IMAGE|><|vision_end|>"
if model_type == "molmo":
return ""
if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3":
return "<start_of_image>"
if model_type == "kimi_vl":
return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501
raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
if model_type in ("ultravox", "granite_speech"):
return "<|audio|>"
if model_type == "phi4mm":
return "<|endoftext11|>" # 200011 (see vocab.json in hf model)
if model_type == "qwen2_audio":
return f"<|audio_{current_count}|>"
if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "minicpmo":
......@@ -528,6 +530,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif modality == "video":
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|VIDEO|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)"
if model_type.startswith("llava"):
......@@ -876,12 +880,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
# Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
......@@ -1092,7 +1097,11 @@ def _parse_chat_message_content(
if role == 'assistant':
parsed_msg = _AssistantParser(message)
if "tool_calls" in parsed_msg:
# The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly
# follow the OpenAI spec.
if ("tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None):
result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool":
parsed_msg = _ToolParser(message)
......@@ -1189,14 +1198,25 @@ def apply_hf_chat_template(
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize,
**kwargs,
)
try:
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize,
**kwargs,
)
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template")
raise ValueError from e
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
......@@ -1205,6 +1225,8 @@ def apply_mistral_chat_template(
tools: Optional[list[dict[str, Any]]],
**kwargs: Any,
) -> list[int]:
from mistral_common.exceptions import MistralCommonException
# The return value of resolve_mistral_chat_template is always None,
# and we won't use it.
resolve_mistral_chat_template(
......@@ -1222,5 +1244,16 @@ def apply_mistral_chat_template(
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# properly caught in the preprocessing_input step
except AssertionError as e:
except (AssertionError, MistralCommonException) as e:
raise ValueError from e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat "
"template")
raise ValueError from e
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.benchmarks.latency import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
    """The `latency` subcommand for vllm bench.

    A thin adapter: argument registration and execution are both handed
    to the ``add_cli_args`` / ``main`` functions imported from
    ``vllm.benchmarks.latency``.
    """

    def __init__(self):
        # The name is assigned before the base initializer runs; that
        # ordering is kept exactly as-is.
        self.name = "latency"
        super().__init__()

    @property
    def help(self) -> str:
        """One-line help text for this subcommand."""
        help_text = "Benchmark the latency of a single batch of requests."
        return help_text

    def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
        """Register the latency benchmark's options on ``parser``."""
        # Inside a method body the bare name resolves to the module-level
        # import, not to this method, so this is not a recursive call.
        add_cli_args(parser)

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the latency benchmark with the parsed ``args``."""
        main(args)
def cmd_init() -> list[CLISubcommand]:
    """Return the subcommands this module contributes (just `latency`)."""
    latency_subcommand = BenchmarkLatencySubcommand()
    return [latency_subcommand]
# SPDX-License-Identifier: Apache-2.0
import argparse
import vllm.entrypoints.cli.benchmark.latency
import vllm.entrypoints.cli.benchmark.serve
import vllm.entrypoints.cli.benchmark.throughput
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser
# TODO: Add any remaining benchmark subcommands here as they are
# implemented (latency, serve and throughput are already registered).
BENCHMARK_CMD_MODULES = [
vllm.entrypoints.cli.benchmark.latency,
vllm.entrypoints.cli.benchmark.serve,
vllm.entrypoints.cli.benchmark.throughput,
]
......
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.benchmarks.throughput import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
    """The `throughput` subcommand for vllm bench.

    A thin adapter: argument registration and execution are both handed
    to the ``add_cli_args`` / ``main`` functions imported from
    ``vllm.benchmarks.throughput``.
    """

    def __init__(self):
        # The name is assigned before the base initializer runs; that
        # ordering is kept exactly as-is.
        self.name = "throughput"
        super().__init__()

    @property
    def help(self) -> str:
        """One-line help text for this subcommand."""
        help_text = "Benchmark offline inference throughput."
        return help_text

    def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
        """Register the throughput benchmark's options on ``parser``."""
        # Inside a method body the bare name resolves to the module-level
        # import, not to this method, so this is not a recursive call.
        add_cli_args(parser)

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the throughput benchmark with the parsed ``args``."""
        main(args)
def cmd_init() -> list[CLISubcommand]:
    """Return the subcommands this module contributes (just `throughput`)."""
    throughput_subcommand = BenchmarkThroughputSubcommand()
    return [throughput_subcommand]
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.collect_env import main as collect_env_main
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser
class CollectEnvSubcommand(CLISubcommand):
    """The `collect-env` subcommand for the vLLM CLI."""

    def __init__(self):
        # The name is assigned before the base initializer runs; that
        # ordering is kept exactly as-is.
        self.name = "collect-env"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Collect information about the environment.

        ``args`` is accepted but not used; ``collect_env_main`` is called
        without arguments.
        """
        collect_env_main()

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the `collect-env` parser on ``subparsers`` and return it.

        The new parser is passed through ``make_arg_parser`` and that
        call's result is what gets returned.
        """
        collect_env_parser = subparsers.add_parser(
            "collect-env",
            help="Start collecting environment information.",
            description="Start collecting environment information.",
            usage="vllm collect-env")
        # NOTE(review): make_arg_parser is imported from the OpenAI server
        # CLI module, so it presumably attaches server options that this
        # subcommand never reads -- confirm whether that is intended.
        return make_arg_parser(collect_env_parser)
def cmd_init() -> list[CLISubcommand]:
    """Return the subcommands this module contributes (just `collect-env`)."""
    collect_env_subcommand = CollectEnvSubcommand()
    return [collect_env_subcommand]
......@@ -5,6 +5,7 @@ import signal
import sys
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve
import vllm.version
......@@ -15,6 +16,7 @@ CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
]
......
......@@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger(__name__)
......@@ -40,6 +42,8 @@ async def serve_http(app: FastAPI,
loop = asyncio.get_running_loop()
watchdog_task = loop.create_task(
watchdog_loop(server, app.state.engine_client))
server_task = loop.create_task(
server.serve(sockets=[sock] if sock else None))
......@@ -52,6 +56,7 @@ async def serve_http(app: FastAPI,
def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early
server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher:
ssl_cert_refresher.stop()
......@@ -73,48 +78,69 @@ async def serve_http(app: FastAPI,
port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
finally:
watchdog_task.cancel()
async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
    """Background watchdog that polls the engine for an errored state.

    Needed to trigger shutdown if an exception arises in a
    StreamingResponse() generator. The loop has no exit condition of its
    own; it keeps sleeping and re-checking until the task is cancelled.
    """
    poll_interval_s = 5.0
    while True:
        # Sleep first, then check, so the first check happens one full
        # interval after the task starts.
        await asyncio.sleep(poll_interval_s)
        terminate_if_errored(server, engine)
def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
    """Ask the uvicorn server to exit when the engine has errored out.

    The server is flagged only when ``engine.errored`` is truthy,
    ``engine.is_running`` is falsy, and
    ``envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH`` is not set.

    See discussions here on shutting down a uvicorn server
    https://github.com/encode/uvicorn/discussions/1103
    In this case we cannot await the server shutdown here
    because handler must first return to close the connection
    for this request.
    """
    engine_is_dead = engine.errored and not engine.is_running
    if envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH or not engine_is_dead:
        return
    # Only set the flag; the server performs the actual shutdown later.
    server.should_exit = True
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server"""
"""
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
background task to check for errored state.
"""
@app.exception_handler(RuntimeError)
async def runtime_error_handler(request: Request, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine = request.app.state.engine_client
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "
"process")
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(AsyncEngineDeadError)
async def async_engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("AsyncLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(MQEngineDeadError)
async def mq_engine_dead_handler(_, __):
"""Kill the server if the mq engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("MQLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
@app.exception_handler(EngineDeadError)
@app.exception_handler(EngineGenerateError)
async def runtime_exception_handler(request: Request, __):
terminate_if_errored(
server=server,
engine=request.app.state.engine_client,
)
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment