Commit 081057de authored by zhuwenwen
Browse files

Merge tag 'v0.8.5' into v0.8.5-ori

parents 7cf5d5c4 ba41cc90
...@@ -46,7 +46,7 @@ class KVTransferAgent: ...@@ -46,7 +46,7 @@ class KVTransferAgent:
assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\
"TransferAgent should only be used when kv_connector is set." "TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector( self.connector = KVConnectorFactory.create_connector_v0(
rank, local_rank, config) rank, local_rank, config)
def send_kv_caches_and_hidden_states( def send_kv_caches_and_hidden_states(
......
...@@ -70,7 +70,7 @@ class MooncakeStore(KVStoreBufferBase): ...@@ -70,7 +70,7 @@ class MooncakeStore(KVStoreBufferBase):
): ):
try: try:
from mooncake_vllm_adaptor import MooncakeDistributedStore from mooncake.store import MooncakeDistributedStore
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Please install mooncake by following the instructions at " "Please install mooncake by following the instructions at "
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import json import json
import os import os
import struct
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Optional, Union
...@@ -57,14 +58,14 @@ class MooncakeTransferEngine: ...@@ -57,14 +58,14 @@ class MooncakeTransferEngine:
def __init__(self, kv_rank: int, local_rank: int): def __init__(self, kv_rank: int, local_rank: int):
try: try:
import mooncake_vllm_adaptor as mva from mooncake.engine import TransferEngine
except ImportError as e: except ImportError as e:
raise ImportError( raise ImportError(
"Please install mooncake by following the instructions at " "Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e "to run vLLM with MooncakeConnector.") from e
self.engine = mva.mooncake_vllm_adaptor() self.engine = TransferEngine()
self.local_rank = local_rank self.local_rank = local_rank
try: try:
...@@ -115,14 +116,14 @@ class MooncakeTransferEngine: ...@@ -115,14 +116,14 @@ class MooncakeTransferEngine:
p_rank_offset = int(p_port) + 8 + self.local_rank * 2 p_rank_offset = int(p_port) + 8 + self.local_rank * 2
d_rank_offset = int(d_port) + 8 + self.local_rank * 2 d_rank_offset = int(d_port) + 8 + self.local_rank * 2
if kv_rank == 0: if kv_rank == 0:
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}") self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}") self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
else: else:
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}") self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}") self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
def initialize(self, local_hostname: str, metadata_server: str, def initialize(self, local_hostname: str, metadata_server: str,
...@@ -140,12 +141,12 @@ class MooncakeTransferEngine: ...@@ -140,12 +141,12 @@ class MooncakeTransferEngine:
"Mooncake Configuration error. `metadata_backend`" "Mooncake Configuration error. `metadata_backend`"
f" should be one of {supported_backend}.") f" should be one of {supported_backend}.")
self.engine.initializeExt(local_hostname, metadata_server, self.engine.initialize_ext(local_hostname, metadata_server,
protocol, device_name, metadata_backend) protocol, device_name, metadata_backend)
def allocate_managed_buffer(self, length: int) -> int: def allocate_managed_buffer(self, length: int) -> int:
"""Allocate a managed buffer of the specified length.""" """Allocate a managed buffer of the specified length."""
ret = self.engine.allocateManagedBuffer(length) ret = self.engine.allocate_managed_buffer(length)
if ret <= 0: if ret <= 0:
logger.error("Allocation Return Error") logger.error("Allocation Return Error")
raise Exception("Allocation Return Error") raise Exception("Allocation Return Error")
...@@ -153,13 +154,13 @@ class MooncakeTransferEngine: ...@@ -153,13 +154,13 @@ class MooncakeTransferEngine:
def free_managed_buffer(self, buffer: int, length: int) -> int: def free_managed_buffer(self, buffer: int, length: int) -> int:
"""Free a previously allocated managed buffer.""" """Free a previously allocated managed buffer."""
return self.engine.freeManagedBuffer(buffer, length) return self.engine.free_managed_buffer(buffer, length)
def transfer_sync(self, buffer: int, peer_buffer_address: int, def transfer_sync(self, buffer: int, peer_buffer_address: int,
length: int) -> int: length: int) -> int:
"""Synchronously transfer data to the specified address.""" """Synchronously transfer data to the specified address."""
ret = self.engine.transferSync(self.remote_url, buffer, ret = self.engine.transfer_sync_read(self.remote_url, buffer,
peer_buffer_address, length) peer_buffer_address, length)
if ret < 0: if ret < 0:
logger.error("Transfer Return Error") logger.error("Transfer Return Error")
raise Exception("Transfer Return Error") raise Exception("Transfer Return Error")
...@@ -168,15 +169,15 @@ class MooncakeTransferEngine: ...@@ -168,15 +169,15 @@ class MooncakeTransferEngine:
def write_bytes_to_buffer(self, buffer: int, user_data: bytes, def write_bytes_to_buffer(self, buffer: int, user_data: bytes,
length: int) -> int: length: int) -> int:
"""Write bytes to the allocated buffer.""" """Write bytes to the allocated buffer."""
return self.engine.writeBytesToBuffer(buffer, user_data, length) return self.engine.write_bytes_to_buffer(buffer, user_data, length)
def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes: def read_bytes_from_buffer(self, buffer: int, length: int) -> bytes:
"""Read bytes from the allocated buffer.""" """Read bytes from the allocated buffer."""
return self.engine.readBytesFromBuffer(buffer, length) return self.engine.read_bytes_from_buffer(buffer, length)
def wait_for_ack(self, src_ptr: int, length: int) -> None: def wait_for_ack(self, src_ptr: int, length: int) -> None:
"""Asynchronously wait for ACK from the receiver.""" """Asynchronously wait for ACK from the receiver."""
ack = self.sender_ack.recv_pyobj() ack = self.sender_ack.recv()
if ack != b'ACK': if ack != b'ACK':
logger.error("Failed to receive ACK from the receiver") logger.error("Failed to receive ACK from the receiver")
...@@ -187,18 +188,22 @@ class MooncakeTransferEngine: ...@@ -187,18 +188,22 @@ class MooncakeTransferEngine:
length = len(user_data) length = len(user_data)
src_ptr = self.allocate_managed_buffer(length) src_ptr = self.allocate_managed_buffer(length)
self.write_bytes_to_buffer(src_ptr, user_data, length) self.write_bytes_to_buffer(src_ptr, user_data, length)
self.sender_socket.send_pyobj((src_ptr, length)) self.sender_socket.send_multipart(
[struct.pack("!Q", src_ptr),
struct.pack("!Q", length)])
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
def recv_bytes(self) -> bytes: def recv_bytes(self) -> bytes:
"""Receive bytes from the remote process.""" """Receive bytes from the remote process."""
src_ptr, length = self.receiver_socket.recv_pyobj() data = self.receiver_socket.recv_multipart()
src_ptr = struct.unpack("!Q", data[0])[0]
length = struct.unpack("!Q", data[1])[0]
dst_ptr = self.allocate_managed_buffer(length) dst_ptr = self.allocate_managed_buffer(length)
self.transfer_sync(dst_ptr, src_ptr, length) self.transfer_sync(dst_ptr, src_ptr, length)
ret = self.read_bytes_from_buffer(dst_ptr, length) ret = self.read_bytes_from_buffer(dst_ptr, length)
# Buffer cleanup # Buffer cleanup
self.receiver_ack.send_pyobj(b'ACK') self.receiver_ack.send(b'ACK')
self.free_managed_buffer(dst_ptr, length) self.free_managed_buffer(dst_ptr, length)
return ret return ret
......
# SPDX-License-Identifier: Apache-2.0
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
from vllm.distributed.parallel_state import get_world_group
if TYPE_CHECKING:
from vllm.config import VllmConfig
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None
def get_kv_transfer_group() -> KVConnectorBaseType:
    """Return the module-level KV connector.

    Returns:
        The object currently stored in ``_KV_CONNECTOR_AGENT``.

    Raises:
        AssertionError: If ``_KV_CONNECTOR_AGENT`` is still ``None``, i.e. no
            connector has been created in this process.
    """
    agent = _KV_CONNECTOR_AGENT
    assert agent is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return agent
def has_kv_transfer_group() -> bool:
    """Report whether a KV connector is stored in ``_KV_CONNECTOR_AGENT``."""
    if _KV_CONNECTOR_AGENT is None:
        return False
    return True
def is_v1_kv_transfer_group(
        connector: Optional[KVConnectorBaseType] = None) -> bool:
    """Tell whether a KV connector is a v1 connector.

    Args:
        connector: The connector to inspect. When ``None``, the module-level
            connector (``_KV_CONNECTOR_AGENT``) is inspected instead.

    Returns:
        ``True`` if the inspected connector is an instance of
        ``KVConnectorBase_V1``; ``False`` otherwise, including when no
        connector is available at all.

    Note:
        This helper is expected to become unnecessary once the v1 KV
        connector is the default.
    """
    # Fall back to the global connector only when the caller passed nothing.
    candidate = _KV_CONNECTOR_AGENT if connector is None else connector
    return candidate is not None and isinstance(candidate,
                                                KVConnectorBase_V1)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """Create the module-level KV connector if it is needed and missing.

    Nothing happens when ``vllm_config.kv_transfer_config`` is ``None``, when
    that config is not a KV transfer instance, or when a connector is already
    stored in ``_KV_CONNECTOR_AGENT``.

    Args:
        vllm_config: The configuration handed to the connector factory.
    """
    global _KV_CONNECTOR_AGENT

    transfer_config = vllm_config.kv_transfer_config
    if transfer_config is None:
        return
    # Check the config flag before looking at the global, as the original
    # short-circuiting condition did.
    if not transfer_config.is_kv_transfer_instance:
        return
    if _KV_CONNECTOR_AGENT is not None:
        return

    if envs.VLLM_USE_V1:
        # The v1 factory takes the config plus a role; WORKER is used here.
        _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1(
            config=vllm_config, role=KVConnectorRole.WORKER)
        return

    # The v0 factory is keyed on this process's ranks in the world group.
    _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0(
        rank=get_world_group().rank,
        local_rank=get_world_group().local_rank,
        config=vllm_config,
    )
...@@ -29,15 +29,13 @@ from collections import namedtuple ...@@ -29,15 +29,13 @@ from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, from typing import Any, Callable, Dict, List, Optional, Tuple, Union
Union)
from unittest.mock import patch from unittest.mock import patch
import torch import torch
import torch.distributed import torch.distributed
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import ( from vllm.distributed.device_communicators.base_device_communicator import (
DeviceCommunicatorBase) DeviceCommunicatorBase)
...@@ -46,9 +44,6 @@ from vllm.logger import init_logger ...@@ -46,9 +44,6 @@ from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname, from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
supports_custom_op) supports_custom_op)
if TYPE_CHECKING:
from vllm.config import VllmConfig
@dataclass @dataclass
class GraphCaptureContext: class GraphCaptureContext:
...@@ -118,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: ...@@ -118,6 +113,38 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
return torch.empty_like(tensor) return torch.empty_like(tensor)
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
                   group_name: str) -> torch.Tensor:
    """Run ``reduce_scatter`` on the group registered under ``group_name``.

    Args:
        tensor: Input tensor forwarded to the group.
        dim: Dimension forwarded to the group's ``reduce_scatter``.
        world_size: Not read by this function.
            NOTE(review): presumably kept so the signature matches
            ``reduce_scatter_fake`` -- confirm.
        group_name: Key into the module-level ``_groups`` mapping.

    Returns:
        Whatever the group's ``reduce_scatter(tensor, dim)`` returns.

    Raises:
        AssertionError: If ``group_name`` is not a key of ``_groups``.
        ValueError: If the stored entry resolves to ``None``.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # Entries in _groups are callables that yield the group or None.
    # NOTE(review): presumably weak references -- confirm.
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator.reduce_scatter(tensor, dim)
    raise ValueError(f"Group {group_name} is destroyed.")
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
                        group_name: str) -> torch.Tensor:
    """Produce an uninitialized tensor shaped like a reduce-scatter result.

    Args:
        tensor: Tensor whose shape, dtype and device are used as a template.
        dim: Dimension whose size is floor-divided by ``world_size``.
        world_size: Divisor applied to the size of ``dim``.
        group_name: Not read by this function.

    Returns:
        A new ``torch.empty`` tensor with ``tensor``'s dtype and device and
        with ``shape[dim]`` replaced by ``shape[dim] // world_size``.
    """
    out_shape = list(tensor.shape)
    # In-place floor division keeps negative ``dim`` indexing working.
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
               group_name: str) -> torch.Tensor:
    """Run ``all_gather`` on the group registered under ``group_name``.

    Args:
        tensor: Input tensor forwarded to the group.
        dim: Dimension forwarded to the group's ``all_gather``.
        world_size: Not read by this function.
            NOTE(review): presumably kept so the signature matches
            ``all_gather_fake`` -- confirm.
        group_name: Key into the module-level ``_groups`` mapping.

    Returns:
        Whatever the group's ``all_gather(tensor, dim)`` returns.

    Raises:
        AssertionError: If ``group_name`` is not a key of ``_groups``.
        ValueError: If the stored entry resolves to ``None``.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # Entries in _groups are callables that yield the group or None.
    # NOTE(review): presumably weak references -- confirm.
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator.all_gather(tensor, dim)
    raise ValueError(f"Group {group_name} is destroyed.")
def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    """Produce an uninitialized tensor shaped like an all-gather result.

    Args:
        tensor: Tensor whose shape, dtype and device are used as a template.
        dim: Dimension whose size is multiplied by ``world_size``.
        world_size: Factor applied to the size of ``dim``.
        group_name: Not read by this function.

    Returns:
        A new ``torch.empty`` tensor with ``tensor``'s dtype and device and
        with ``shape[dim]`` replaced by ``shape[dim] * world_size``.
    """
    out_shape = list(tensor.shape)
    # In-place multiplication keeps negative ``dim`` indexing working.
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
if supports_custom_op(): if supports_custom_op():
from vllm.platforms import current_platform from vllm.platforms import current_platform
direct_register_custom_op( direct_register_custom_op(
...@@ -128,6 +155,20 @@ if supports_custom_op(): ...@@ -128,6 +155,20 @@ if supports_custom_op():
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
direct_register_custom_op(
op_name="reduce_scatter",
op_func=reduce_scatter,
mutates_args=[],
fake_impl=reduce_scatter_fake,
)
direct_register_custom_op(
op_name="all_gather",
op_func=all_gather,
mutates_args=[],
fake_impl=all_gather_fake,
)
class GroupCoordinator: class GroupCoordinator:
""" """
...@@ -327,6 +368,18 @@ class GroupCoordinator: ...@@ -327,6 +368,18 @@ class GroupCoordinator:
return self.device_communicator.all_gather(input_, dim) return self.device_communicator.all_gather(input_, dim)
def reduce_scatter(self,
                   input_: torch.Tensor,
                   dim: int = -1) -> torch.Tensor:
    """Reduce-scatter ``input_`` along ``dim`` via the device communicator.

    Args:
        input_: Tensor to reduce-scatter.
        dim: Dimension to operate on; must lie in
            ``[-input_.dim(), input_.dim())``. Defaults to the last one.

    Returns:
        ``input_`` itself when ``self.world_size`` is 1; otherwise the result
        of ``self.device_communicator.reduce_scatter(input_, dim)``.

    Raises:
        AssertionError: If ``dim`` is out of range for ``input_``.
    """
    # A single-rank group has nothing to exchange: hand the input back.
    if self.world_size == 1:
        return input_
    ndim = input_.dim()
    assert -ndim <= dim < ndim, (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    return self.device_communicator.reduce_scatter(input_, dim)
def gather(self, def gather(self,
input_: torch.Tensor, input_: torch.Tensor,
dst: int = 0, dst: int = 0,
...@@ -772,14 +825,6 @@ def get_pp_group() -> GroupCoordinator: ...@@ -772,14 +825,6 @@ def get_pp_group() -> GroupCoordinator:
# kept for backward compatibility # kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group get_pipeline_model_parallel_group = get_pp_group
_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None
def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
assert _KV_TRANSFER is not None, (
"disaggregated KV cache transfer parallel group is not initialized")
return _KV_TRANSFER
@contextmanager @contextmanager
def graph_capture(device: torch.device): def graph_capture(device: torch.device):
...@@ -962,26 +1007,6 @@ def initialize_model_parallel( ...@@ -962,26 +1007,6 @@ def initialize_model_parallel(
_DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group) _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
Initialize KV cache transfer parallel group.
"""
global _KV_TRANSFER
if vllm_config.kv_transfer_config is None:
return
if all([
vllm_config.kv_transfer_config.is_kv_transfer_instance,
_KV_TRANSFER is None
]):
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config)
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
tensor_model_parallel_size: int, tensor_model_parallel_size: int,
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
import dataclasses import dataclasses
import datetime import datetime
import pickle import pickle
import socket
import time import time
from collections import deque from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple from typing import Any, Deque, Dict, Optional, Sequence, Tuple
...@@ -123,6 +124,10 @@ class StatelessProcessGroup: ...@@ -123,6 +124,10 @@ class StatelessProcessGroup:
rank: int rank: int
world_size: int world_size: int
store: torch._C._distributed_c10d.Store store: torch._C._distributed_c10d.Store
# stores a reference to the socket so that the file descriptor stays alive
socket: Optional[socket.socket]
data_expiration_seconds: int = 3600 # 1 hour data_expiration_seconds: int = 3600 # 1 hour
# dst rank -> counter # dst rank -> counter
...@@ -234,18 +239,33 @@ class StatelessProcessGroup: ...@@ -234,18 +239,33 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B, can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group. C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa """ # noqa
launch_server = rank == 0
if launch_server:
# listen on the specified interface (instead of 0.0.0.0)
listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listen_socket.bind((host, port))
listen_socket.listen()
listen_fd = listen_socket.fileno()
else:
listen_socket = None
listen_fd = None
store = TCPStore( store = TCPStore(
host_name=host, host_name=host,
port=port, port=port,
world_size=world_size, world_size=world_size,
is_master=(rank == 0), is_master=launch_server,
timeout=datetime.timedelta(seconds=store_timeout), timeout=datetime.timedelta(seconds=store_timeout),
use_libuv=False, # for now: github.com/pytorch/pytorch/pull/150215
master_listen_fd=listen_fd,
) )
return StatelessProcessGroup( return StatelessProcessGroup(
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
store=store, store=store,
socket=listen_socket,
data_expiration_seconds=data_expiration_seconds) data_expiration_seconds=data_expiration_seconds)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# yapf: disable
import argparse import argparse
import dataclasses import dataclasses
import json import json
import re import re
import threading import threading
from dataclasses import MISSING, dataclass, fields from dataclasses import MISSING, dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, from typing import (Any, Callable, Dict, List, Literal, Optional, Type,
Tuple, Type, Union, cast, get_args, get_origin) TypeVar, Union, cast, get_args, get_origin)
import torch import torch
from typing_extensions import TypeIs, deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm import version from vllm import version
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DecodingConfig, DeviceConfig, HfOverrides, ConfigFormat, ConfigType, DecodingConfig, Device,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackendV1, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, ObservabilityConfig, ModelConfig, ModelImpl, MultiModalConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig, ObservabilityConfig, ParallelConfig, PoolerConfig,
SchedulerConfig, SpeculativeConfig, TaskOption, PrefixCachingHashAlgo, PromptAdapterConfig,
TokenizerPoolConfig, VllmConfig, get_attr_docs) SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerPoolConfig, VllmConfig,
get_attr_docs, get_field)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
...@@ -28,33 +34,42 @@ from vllm.reasoning import ReasoningParserManager ...@@ -28,33 +34,42 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, StoreBoolean, is_in_ray_actor from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
if TYPE_CHECKING: # yapf: enable
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
logger = init_logger(__name__) logger = init_logger(__name__)
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
DEVICE_OPTIONS = [ # object is used to allow for special typing forms
"auto", T = TypeVar("T")
"cuda", TypeHint = Union[type[Any], object]
"neuron", TypeHintT = Union[type[T], object]
"cpu",
"tpu",
"xpu",
"hpu",
]
def nullable_str(val: str): def optional_type(
if not val or val == "None": return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]:
return None
return val def _optional_type(val: str) -> Optional[T]:
if val == "" or val == "None":
return None
try:
if return_type is json.loads and not re.match("^{.*}$", val):
return cast(T, nullable_kvs(val))
return return_type(val)
except ValueError as e:
raise argparse.ArgumentTypeError(
f"Value {val} cannot be converted to {return_type}.") from e
return _optional_type
def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: @deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
"string instead.")
def nullable_kvs(val: str) -> dict[str, int]:
"""Parses a string containing comma separate key [str] to value [int] """Parses a string containing comma separate key [str] to value [int]
pairs into a dictionary. pairs into a dictionary.
...@@ -64,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: ...@@ -64,10 +79,7 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
Returns: Returns:
Dictionary with parsed values. Dictionary with parsed values.
""" """
if len(val) == 0: out_dict: dict[str, int] = {}
return None
out_dict: Dict[str, int] = {}
for item in val.split(","): for item in val.split(","):
kv_parts = [part.lower().strip() for part in item.split("=")] kv_parts = [part.lower().strip() for part in item.split("=")]
if len(kv_parts) != 2: if len(kv_parts) != 2:
...@@ -89,6 +101,105 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: ...@@ -89,6 +101,105 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
return out_dict return out_dict
def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]:
    """Tell whether ``type_hint`` is ``type`` itself or a subscripted form of it.

    A hint matches either by identity or because ``typing.get_origin`` of the
    hint is ``type`` (e.g. ``list[int]`` matches ``list``).
    """
    if type_hint is type:
        return True
    return get_origin(type_hint) is type
def contains_type(type_hints: set[TypeHint], type: TypeHintT) -> bool:
    """Tell whether any hint in ``type_hints`` matches ``type`` per ``is_type``."""
    for candidate in type_hints:
        if is_type(candidate, type):
            return True
    return False
def get_type(type_hints: set[TypeHint], type: TypeHintT) -> TypeHintT:
    """Return the first hint in ``type_hints`` that matches ``type``.

    Matching follows ``is_type``. ``None`` is returned when no hint matches.
    """
    for candidate in type_hints:
        if is_type(candidate, type):
            return candidate
    return None
def is_not_builtin(type_hint: TypeHint) -> bool:
    """Tell whether ``type_hint`` is defined outside the ``builtins`` module.

    The decision is based solely on the hint's ``__module__`` attribute.
    """
    defining_module = type_hint.__module__
    return defining_module != "builtins"
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
    """Derive per-field argparse keyword arguments from a config dataclass.

    For every dataclass field of ``cls`` an entry is built containing its
    ``default`` and ``help`` text (taken from ``get_attr_docs(cls)``), plus
    ``action``, ``choices``, ``type`` and/or ``nargs`` depending on the
    field's type annotation.

    Args:
        cls: The config dataclass to inspect.

    Returns:
        A mapping from field name to the keyword-argument dict for that field.

    Raises:
        ValueError: If a field's annotation matches none of the handled kinds.
        AssertionError: If a ``Literal`` mixes value types, a ``tuple`` mixes
            element types, or a ``list`` does not have exactly one type
            argument.
    """
    attr_docs = get_attr_docs(cls)
    kwargs = {}
    for dc_field in fields(cls):
        # Prefer the default_factory's product when a factory is declared.
        if dc_field.default_factory is not MISSING:
            default = dc_field.default_factory()
        else:
            default = dc_field.default

        name = dc_field.name
        # argparse %-formats help strings, so literal "%" must be doubled.
        help_text = attr_docs[name].replace("%", "%%")

        # ``entry`` aliases kwargs[name]; every update below lands in kwargs.
        entry = {"default": default, "help": help_text}
        kwargs[name] = entry

        # Flatten a Union annotation into the set of its member types.
        declared = dc_field.type
        if get_origin(declared) is Union:
            type_hints: set[TypeHint] = set(get_args(declared))
        else:
            type_hints = {declared}

        # Branch order matters: the first matching kind wins.
        if contains_type(type_hints, bool):
            # Yields both --<name> and --no-<name> flags.
            entry["action"] = argparse.BooleanOptionalAction
        elif contains_type(type_hints, Literal):
            # The Literal's arguments become the allowed choices.
            literal_hint = get_type(type_hints, Literal)
            choices = sorted(get_args(literal_hint))
            entry["choices"] = choices
            choice_type = type(choices[0])
            assert all(type(c) is choice_type for c in choices), (
                "All choices must be of the same type. "
                f"Got {choices} with types {[type(c) for c in choices]}")
            entry["type"] = choice_type
        elif contains_type(type_hints, tuple):
            tuple_hint = get_type(type_hints, tuple)
            element_types = get_args(tuple_hint)
            tuple_type = element_types[0]
            assert all(t is tuple_type for t in element_types
                       if t is not Ellipsis), (
                "All non-Ellipsis tuple elements must be of the same "
                f"type. Got {element_types}.")
            entry["type"] = tuple_type
            # Variadic tuples take one or more values, fixed ones an exact count.
            entry["nargs"] = ("+" if Ellipsis in element_types else
                              len(element_types))
        elif contains_type(type_hints, list):
            list_hint = get_type(type_hints, list)
            element_types = get_args(list_hint)
            assert len(element_types) == 1, (
                "List type must have exactly one type. Got "
                f"{list_hint} with types {element_types}")
            entry["type"] = element_types[0]
            entry["nargs"] = "+"
        elif contains_type(type_hints, int):
            entry["type"] = int
        elif contains_type(type_hints, float):
            entry["type"] = float
        elif contains_type(type_hints, dict):
            # Dict-valued arguments are always optional and parsed as JSON.
            entry["type"] = optional_type(json.loads)
        elif (contains_type(type_hints, str)
              or any(is_not_builtin(th) for th in type_hints)):
            entry["type"] = str
        else:
            raise ValueError(
                f"Unsupported type {type_hints} for argument {name}.")

        # A None member makes the argument optional. Booleans are excluded
        # because argparse already handles them through the action above.
        if type(None) in type_hints and not contains_type(type_hints, bool):
            entry["type"] = optional_type(entry["type"])
            if entry.get("choices"):
                entry["choices"].append("None")
    return kwargs
@dataclass @dataclass
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
...@@ -105,14 +216,15 @@ class EngineArgs: ...@@ -105,14 +216,15 @@ class EngineArgs:
load_format: str = LoadConfig.load_format load_format: str = LoadConfig.load_format
config_format: ConfigFormat = ConfigFormat.AUTO config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto' dtype: str = 'auto'
kv_cache_dtype: str = 'auto' kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: Optional[int] = None seed: Optional[int] = None
max_model_len: Optional[int] = None max_model_len: Optional[int] = None
# Note: Specifying a custom executor backend by passing a class # Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without # is intended for expert use only. The API may change without
# notice. # notice.
distributed_executor_backend: Optional[Union[ distributed_executor_backend: Optional[Union[
str, Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend DistributedExecutorBackend,
Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers # number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
...@@ -120,20 +232,23 @@ class EngineArgs: ...@@ -120,20 +232,23 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[ max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[int] = None block_size: Optional[BlockSize] = CacheConfig.block_size
enable_prefix_caching: Optional[bool] = None enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
prefix_caching_hash_algo: str = "builtin" prefix_caching_hash_algo: PrefixCachingHashAlgo = \
CacheConfig.prefix_caching_hash_algo
disable_sliding_window: bool = False disable_sliding_window: bool = False
disable_cascade_attn: bool = False disable_cascade_attn: bool = False
use_v2_block_manager: bool = True use_v2_block_manager: bool = True
swap_space: float = 4 # GiB swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = 0 # GiB cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[
max_num_partial_prefills: Optional[int] = 1 int] = SchedulerConfig.max_num_batched_tokens
max_long_partial_prefills: Optional[int] = 1 max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
long_prefill_token_threshold: Optional[int] = 0 max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
max_num_seqs: Optional[int] = None long_prefill_token_threshold: int = \
SchedulerConfig.long_prefill_token_threshold
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = None
...@@ -147,42 +262,51 @@ class EngineArgs: ...@@ -147,42 +262,51 @@ class EngineArgs:
enforce_eager: Optional[bool] = None enforce_eager: Optional[bool] = None
max_seq_len_to_capture: int = 8192 max_seq_len_to_capture: int = 8192
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
tokenizer_pool_size: int = 0 # The following three fields are deprecated and will be removed in a future
# Note: Specifying a tokenizer pool by passing a class # release. Setting them will have no effect. Please remove them from your
# is intended for expert use only. The API may change without # configurations.
# notice. tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" tokenizer_pool_type: str = TokenizerPoolConfig.pool_type
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None tokenizer_pool_extra_config: dict = \
limit_mm_per_prompt: Optional[Mapping[str, int]] = None get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None
disable_mm_preprocessor_cache: bool = False disable_mm_preprocessor_cache: bool = False
# LoRA fields
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = False enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = 1 max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = 16 max_lora_rank: int = LoRAConfig.max_lora_rank
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size
long_lora_scaling_factors: Optional[tuple[float, ...]] = \
LoRAConfig.long_lora_scaling_factors
# PromptAdapter fields
enable_prompt_adapter: bool = False enable_prompt_adapter: bool = False
max_prompt_adapters: int = 1 max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters
max_prompt_adapter_token: int = 0 max_prompt_adapter_token: int = \
fully_sharded_loras: bool = False PromptAdapterConfig.max_prompt_adapter_token
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None device: Device = DeviceConfig.device
lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
max_cpu_loras: Optional[int] = None multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
device: str = 'auto'
num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = True
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: Optional[int] = None num_gpu_blocks_override: Optional[
num_lookahead_slots: int = 0 int] = CacheConfig.num_gpu_blocks_override
model_loader_extra_config: Optional[ num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
dict] = LoadConfig.model_loader_extra_config model_loader_extra_config: dict = \
get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: Optional[Union[str, ignore_patterns: Optional[Union[str,
List[str]]] = LoadConfig.ignore_patterns List[str]]] = LoadConfig.ignore_patterns
preemption_mode: Optional[str] = None preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
scheduler_delay_factor: float = 0.0 scheduler_delay_factor: float = SchedulerConfig.delay_factor
enable_chunked_prefill: Optional[bool] = None enable_chunked_prefill: Optional[
disable_chunked_mm_input: bool = False bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
guided_decoding_backend: str = DecodingConfig.guided_decoding_backend guided_decoding_backend: str = DecodingConfig.guided_decoding_backend
logits_processor_pattern: Optional[str] = None logits_processor_pattern: Optional[str] = None
...@@ -194,8 +318,8 @@ class EngineArgs: ...@@ -194,8 +318,8 @@ class EngineArgs:
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = False
scheduling_policy: Literal["fcfs", "priority"] = "fcfs" scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = "vllm.core.scheduler.Scheduler" scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None override_pooler_config: Optional[PoolerConfig] = None
...@@ -210,11 +334,11 @@ class EngineArgs: ...@@ -210,11 +334,11 @@ class EngineArgs:
enable_sleep_mode: bool = False enable_sleep_mode: bool = False
model_impl: str = "auto" model_impl: str = "auto"
calculate_kv_scales: Optional[bool] = None calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
additional_config: Optional[Dict[str, Any]] = None additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None enable_reasoning: Optional[bool] = None
reasoning_parser: Optional[str] = None reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
def __post_init__(self): def __post_init__(self):
...@@ -236,38 +360,6 @@ class EngineArgs: ...@@ -236,38 +360,6 @@ class EngineArgs:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
def is_type_in_union(cls: type[Any], type: type[Any]) -> bool:
    """Tell whether ``type`` is one of the members of the union ``cls``.

    Returns False when ``cls`` is not a ``typing.Union`` at all.
    """
    # Anything whose origin is not ``Union`` (plain classes, ``list[int]``,
    # ...) has no union members to search.
    if get_origin(cls) is not Union:
        return False
    members = get_args(cls)
    return type in members
def is_optional(cls: type[Any]) -> bool:
    """Tell whether ``cls`` is a union type that has ``None`` as a member."""
    none_type = type(None)
    return is_type_in_union(cls, none_type)
def get_kwargs(cls: type[Any]) -> Dict[str, Any]:
    """Build ``add_argument`` keyword arguments for each field of ``cls``.

    Args:
        cls: Dataclass whose fields are turned into argument specs.

    Returns:
        A dict keyed by field name. Each value is a dict with ``default``
        and ``help`` (looked up in ``get_attr_docs(cls)``) and, for every
        field that is not annotated as plain ``bool``, a ``type`` converter.

    Raises:
        KeyError: If ``get_attr_docs(cls)`` has no entry for a field.
    """
    cls_docs = get_attr_docs(cls)
    kwargs = {}
    for field in fields(cls):
        name = field.name
        # Resolve the default value. A ``default_factory`` must be invoked
        # here: argparse never calls a callable ``default``, so passing the
        # factory through would surface the factory object itself as the
        # parsed value instead of the value it produces.
        default = field.default
        if default is MISSING and field.default_factory is not MISSING:
            default = field.default_factory()
        kwargs[name] = {"default": default, "help": cls_docs[name]}

        field_type = field.type
        if field_type is bool:
            # Boolean flags are registered with an action such as
            # "store_true", and add_argument rejects ``type`` alongside it.
            continue
        if is_optional(field_type):
            # Optional fields are parsed with the nullable string converter.
            kwargs[name]["type"] = nullable_str
        elif is_type_in_union(field_type, str):
            # Unions that accept ``str`` are parsed as plain strings.
            kwargs[name]["type"] = str
        else:
            kwargs[name]["type"] = field_type
    return kwargs
# Model arguments # Model arguments
parser.add_argument( parser.add_argument(
'--model', '--model',
...@@ -285,13 +377,13 @@ class EngineArgs: ...@@ -285,13 +377,13 @@ class EngineArgs:
'which task to use.') 'which task to use.')
parser.add_argument( parser.add_argument(
'--tokenizer', '--tokenizer',
type=nullable_str, type=optional_type(str),
default=EngineArgs.tokenizer, default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. ' help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.') 'If unspecified, model name or path will be used.')
parser.add_argument( parser.add_argument(
"--hf-config-path", "--hf-config-path",
type=nullable_str, type=optional_type(str),
default=EngineArgs.hf_config_path, default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. ' help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.') 'If unspecified, model name or path will be used.')
...@@ -303,21 +395,21 @@ class EngineArgs: ...@@ -303,21 +395,21 @@ class EngineArgs:
'the input. The generated output will contain token ids.') 'the input. The generated output will contain token ids.')
parser.add_argument( parser.add_argument(
'--revision', '--revision',
type=nullable_str, type=optional_type(str),
default=None, default=None,
help='The specific model version to use. It can be a branch ' help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
'the default version.') 'the default version.')
parser.add_argument( parser.add_argument(
'--code-revision', '--code-revision',
type=nullable_str, type=optional_type(str),
default=None, default=None,
help='The specific revision to use for the model code on ' help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.') 'commit id. If unspecified, will use the default version.')
parser.add_argument( parser.add_argument(
'--tokenizer-revision', '--tokenizer-revision',
type=nullable_str, type=optional_type(str),
default=None, default=None,
help='Revision of the huggingface tokenizer to use. ' help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. ' 'It can be a branch name, a tag name, or a commit id. '
...@@ -357,7 +449,6 @@ class EngineArgs: ...@@ -357,7 +449,6 @@ class EngineArgs:
load_group.add_argument('--model-loader-extra-config', load_group.add_argument('--model-loader-extra-config',
**load_kwargs["model_loader_extra_config"]) **load_kwargs["model_loader_extra_config"])
load_group.add_argument('--use-tqdm-on-load', load_group.add_argument('--use-tqdm-on-load',
action=argparse.BooleanOptionalAction,
**load_kwargs["use_tqdm_on_load"]) **load_kwargs["use_tqdm_on_load"])
parser.add_argument( parser.add_argument(
...@@ -382,14 +473,6 @@ class EngineArgs: ...@@ -382,14 +473,6 @@ class EngineArgs:
'* "bfloat16" for a balance between precision and range.\n' '* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n' '* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.') '* "float32" for FP32 precision.')
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument('--max-model-len', parser.add_argument('--max-model-len',
type=human_readable_int, type=human_readable_int,
default=EngineArgs.max_model_len, default=EngineArgs.max_model_len,
...@@ -399,21 +482,25 @@ class EngineArgs: ...@@ -399,21 +482,25 @@ class EngineArgs:
'Examples:\n' 'Examples:\n'
'- 1k → 1000\n' '- 1k → 1000\n'
'- 1K → 1024\n') '- 1K → 1024\n')
parser.add_argument(
# Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig)
guided_decoding_group = parser.add_argument_group(
title="DecodingConfig",
description=DecodingConfig.__doc__,
)
guided_decoding_group.add_argument(
'--guided-decoding-backend', '--guided-decoding-backend',
type=str, **guided_decoding_kwargs["guided_decoding_backend"])
default=DecodingConfig.guided_decoding_backend, guided_decoding_group.add_argument(
help='Which engine will be used for guided decoding' "--reasoning-parser",
' (JSON schema / regex etc) by default. Currently support ' # This choices is a special case because it's not static
'https://github.com/mlc-ai/xgrammar and ' choices=list(ReasoningParserManager.reasoning_parsers),
'https://github.com/guidance-ai/llguidance.' **guided_decoding_kwargs["reasoning_backend"])
'Valid backend values are "xgrammar", "guidance", and "auto". '
'With "auto", we will make opinionated choices based on request '
'contents and what the backend libraries currently support, so '
'the behavior is subject to change in each release.')
parser.add_argument( parser.add_argument(
'--logits-processor-pattern', '--logits-processor-pattern',
type=nullable_str, type=optional_type(str),
default=None, default=None,
help='Optional regex pattern specifying valid logits processor ' help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` ' 'qualified names that can be passed with the `logits_processors` '
...@@ -439,7 +526,6 @@ class EngineArgs: ...@@ -439,7 +526,6 @@ class EngineArgs:
) )
parallel_group.add_argument( parallel_group.add_argument(
'--distributed-executor-backend', '--distributed-executor-backend',
choices=['ray', 'mp', 'uni', 'external_launcher'],
**parallel_kwargs["distributed_executor_backend"]) **parallel_kwargs["distributed_executor_backend"])
parallel_group.add_argument( parallel_group.add_argument(
'--pipeline-parallel-size', '-pp', '--pipeline-parallel-size', '-pp',
...@@ -450,46 +536,40 @@ class EngineArgs: ...@@ -450,46 +536,40 @@ class EngineArgs:
**parallel_kwargs["data_parallel_size"]) **parallel_kwargs["data_parallel_size"])
parallel_group.add_argument( parallel_group.add_argument(
'--enable-expert-parallel', '--enable-expert-parallel',
action='store_true',
**parallel_kwargs["enable_expert_parallel"]) **parallel_kwargs["enable_expert_parallel"])
parallel_group.add_argument( parallel_group.add_argument(
'--max-parallel-loading-workers', '--max-parallel-loading-workers',
**parallel_kwargs["max_parallel_loading_workers"]) **parallel_kwargs["max_parallel_loading_workers"])
parallel_group.add_argument( parallel_group.add_argument(
'--ray-workers-use-nsight', '--ray-workers-use-nsight',
action='store_true',
**parallel_kwargs["ray_workers_use_nsight"]) **parallel_kwargs["ray_workers_use_nsight"])
parallel_group.add_argument( parallel_group.add_argument(
'--disable-custom-all-reduce', '--disable-custom-all-reduce',
action='store_true',
**parallel_kwargs["disable_custom_all_reduce"]) **parallel_kwargs["disable_custom_all_reduce"])
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 64, 128],
help='Token block size for contiguous chunks of '
'tokens. This is ignored on neuron devices and '
'set to ``--max-model-len``. On CUDA devices, '
'only block sizes up to 32 are supported. '
'On HPU devices, block size defaults to 128.')
parser.add_argument( # KV cache arguments
"--enable-prefix-caching", cache_kwargs = get_kwargs(CacheConfig)
action=argparse.BooleanOptionalAction, cache_group = parser.add_argument_group(
default=EngineArgs.enable_prefix_caching, title="CacheConfig",
help="Enables automatic prefix caching. " description=CacheConfig.__doc__,
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
)
parser.add_argument(
"--prefix-caching-hash-algo",
type=str,
choices=["builtin", "sha256"],
default=EngineArgs.prefix_caching_hash_algo,
help="Set the hash algorithm for prefix caching. "
"Options are 'builtin' (Python's built-in hash) or 'sha256' "
"(collision resistant but with certain overheads).",
) )
cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
cache_group.add_argument('--gpu-memory-utilization',
**cache_kwargs["gpu_memory_utilization"])
cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
cache_group.add_argument('--kv-cache-dtype',
**cache_kwargs["cache_dtype"])
cache_group.add_argument('--num-gpu-blocks-override',
**cache_kwargs["num_gpu_blocks_override"])
cache_group.add_argument("--enable-prefix-caching",
**cache_kwargs["enable_prefix_caching"])
cache_group.add_argument("--prefix-caching-hash-algo",
**cache_kwargs["prefix_caching_hash_algo"])
cache_group.add_argument('--cpu-offload-gb',
**cache_kwargs["cpu_offload_gb"])
cache_group.add_argument('--calculate-kv-scales',
**cache_kwargs["calculate_kv_scales"])
parser.add_argument('--disable-sliding-window', parser.add_argument('--disable-sliding-window',
action='store_true', action='store_true',
help='Disables sliding window, ' help='Disables sliding window, '
...@@ -502,86 +582,11 @@ class EngineArgs: ...@@ -502,86 +582,11 @@ class EngineArgs:
'block manager v2) is now the default. ' 'block manager v2) is now the default. '
'Setting this flag to True or False' 'Setting this flag to True or False'
' has no effect on vLLM behavior.') ' has no effect on vLLM behavior.')
parser.add_argument(
'--num-lookahead-slots',
type=int,
default=EngineArgs.num_lookahead_slots,
help='Experimental scheduling config necessary for '
'speculative decoding. This will be replaced by '
'speculative config in the future; it is present '
'to enable correctness tests until then.')
parser.add_argument('--seed', parser.add_argument('--seed',
type=int, type=int,
default=EngineArgs.seed, default=EngineArgs.seed,
help='Random seed for operations.') help='Random seed for operations.')
parser.add_argument('--swap-space',
type=float,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--cpu-offload-gb',
type=float,
default=0,
help='The space in GiB to offload to CPU, per GPU. '
'Default is 0, which means no offloading. Intuitively, '
'this argument can be seen as a virtual way to increase '
'the GPU memory size. For example, if you have one 24 GB '
'GPU and set this to 10, virtually you can think of it as '
'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
'which requires at least 26GB GPU memory. Note that this '
'requires fast CPU-GPU interconnect, as part of the model is '
'loaded from CPU memory to GPU memory on the fly in each '
'model forward pass.')
parser.add_argument(
'--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9. This is a per-instance '
'limit, and only applies to the current vLLM instance.'
'It does not matter if you have another vLLM instance running '
'on the same GPU. For example, if you have two vLLM instances '
'running on the same GPU, you can set the GPU memory utilization '
'to 0.5 for each instance.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
' of GPU blocks. Used for testing preemption.')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
help='Maximum number of batched tokens per '
'iteration.')
parser.add_argument(
"--max-num-partial-prefills",
type=int,
default=EngineArgs.max_num_partial_prefills,
help="For chunked prefill, the max number of concurrent \
partial prefills.")
parser.add_argument(
"--max-long-partial-prefills",
type=int,
default=EngineArgs.max_long_partial_prefills,
help="For chunked prefill, the maximum number of prompts longer "
"than --long-prefill-token-threshold that will be prefilled "
"concurrently. Setting this less than --max-num-partial-prefills "
"will allow shorter prompts to jump the queue in front of longer "
"prompts in some cases, improving latency.")
parser.add_argument(
"--long-prefill-token-threshold",
type=float,
default=EngineArgs.long_prefill_token_threshold,
help="For chunked prefill, a request is considered long if the "
"prompt is longer than this number of tokens.")
parser.add_argument('--max-num-seqs',
type=int,
default=EngineArgs.max_num_seqs,
help='Maximum number of sequences per iteration.')
parser.add_argument( parser.add_argument(
'--max-logprobs', '--max-logprobs',
type=int, type=int,
...@@ -594,7 +599,7 @@ class EngineArgs: ...@@ -594,7 +599,7 @@ class EngineArgs:
# Quantization settings. # Quantization settings.
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
type=nullable_str, type=optional_type(str),
choices=[*QUANTIZATION_METHODS, None], choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization, default=EngineArgs.quantization,
help='Method used to quantize the weights. If ' help='Method used to quantize the weights. If '
...@@ -645,154 +650,108 @@ class EngineArgs: ...@@ -645,154 +650,108 @@ class EngineArgs:
'Additionally for encoder-decoder models, if the ' 'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger ' 'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.') 'than this, we fall back to the eager mode.')
parser.add_argument('--tokenizer-pool-size',
type=int, # Tokenizer arguments
default=EngineArgs.tokenizer_pool_size, tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
help='Size of tokenizer pool to use for ' tokenizer_group = parser.add_argument_group(
'asynchronous tokenization. If 0, will ' title="TokenizerPoolConfig",
'use synchronous tokenization.') description=TokenizerPoolConfig.__doc__,
parser.add_argument('--tokenizer-pool-type', )
type=str, tokenizer_group.add_argument('--tokenizer-pool-size',
default=EngineArgs.tokenizer_pool_type, **tokenizer_kwargs["pool_size"])
help='Type of tokenizer pool to use for ' tokenizer_group.add_argument('--tokenizer-pool-type',
'asynchronous tokenization. Ignored ' **tokenizer_kwargs["pool_type"])
'if tokenizer_pool_size is 0.') tokenizer_group.add_argument('--tokenizer-pool-extra-config',
parser.add_argument('--tokenizer-pool-extra-config', **tokenizer_kwargs["extra_config"])
type=nullable_str,
default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. '
'This should be a JSON string that will be '
'parsed into a dictionary. Ignored if '
'tokenizer_pool_size is 0.')
# Multimodal related configs # Multimodal related configs
parser.add_argument( multimodal_kwargs = get_kwargs(MultiModalConfig)
'--limit-mm-per-prompt', multimodal_group = parser.add_argument_group(
type=nullable_kvs, title="MultiModalConfig",
default=EngineArgs.limit_mm_per_prompt, description=MultiModalConfig.__doc__,
# The default value is given in )
# MultiModalConfig.get_default_limit_per_prompt multimodal_group.add_argument('--limit-mm-per-prompt',
help=('For each multimodal plugin, limit how many ' **multimodal_kwargs["limit_per_prompt"])
'input instances to allow for each prompt. '
'Expects a comma-separated list of items, '
'e.g.: `image=16,video=2` allows a maximum of 16 '
'images and 2 videos per prompt. Defaults to '
'1 (V0) or 999 (V1) for each modality.'))
parser.add_argument( parser.add_argument(
'--mm-processor-kwargs', '--mm-processor-kwargs',
default=None, default=None,
type=json.loads, type=json.loads,
help=('Overrides for the multimodal input mapping/processing, ' help=('Overrides for the multi-modal processor obtained from '
'e.g., image processor. For example: ``{"num_crops": 4}``.')) '``AutoProcessor.from_pretrained``. The available overrides '
'depend on the model that is being run.'
'For example, for Phi-3-Vision: ``{"num_crops": 4}``.'))
parser.add_argument( parser.add_argument(
'--disable-mm-preprocessor-cache', '--disable-mm-preprocessor-cache',
action='store_true', action='store_true',
help='If true, then disables caching of the multi-modal ' help='If True, disable caching of the processed multi-modal '
'preprocessor/mapper. (not recommended)') 'inputs.')
# LoRA related configs # LoRA related configs
parser.add_argument('--enable-lora', lora_kwargs = get_kwargs(LoRAConfig)
action='store_true', lora_group = parser.add_argument_group(
help='If True, enable handling of LoRA adapters.') title="LoRAConfig",
parser.add_argument('--enable-lora-bias', description=LoRAConfig.__doc__,
action='store_true', )
help='If True, enable bias for LoRA adapters.') lora_group.add_argument(
parser.add_argument('--max-loras', '--enable-lora',
type=int, action=argparse.BooleanOptionalAction,
default=EngineArgs.max_loras, help='If True, enable handling of LoRA adapters.')
help='Max number of LoRAs in a single batch.') lora_group.add_argument('--enable-lora-bias',
parser.add_argument('--max-lora-rank', **lora_kwargs["bias_enabled"])
type=int, lora_group.add_argument('--max-loras', **lora_kwargs["max_loras"])
default=EngineArgs.max_lora_rank, lora_group.add_argument('--max-lora-rank',
help='Max LoRA rank.') **lora_kwargs["max_lora_rank"])
parser.add_argument( lora_group.add_argument('--lora-extra-vocab-size',
'--lora-extra-vocab-size', **lora_kwargs["lora_extra_vocab_size"])
type=int, lora_group.add_argument(
default=EngineArgs.lora_extra_vocab_size,
help=('Maximum size of extra vocabulary that can be '
'present in a LoRA adapter (added to the base '
'model vocabulary).'))
parser.add_argument(
'--lora-dtype', '--lora-dtype',
type=str, **lora_kwargs["lora_dtype"],
default=EngineArgs.lora_dtype, )
choices=['auto', 'float16', 'bfloat16'], lora_group.add_argument('--long-lora-scaling-factors',
help=('Data type for LoRA. If auto, will default to ' **lora_kwargs["long_lora_scaling_factors"])
'base model dtype.')) lora_group.add_argument('--max-cpu-loras',
parser.add_argument( **lora_kwargs["max_cpu_loras"])
'--long-lora-scaling-factors', lora_group.add_argument('--fully-sharded-loras',
type=nullable_str, **lora_kwargs["fully_sharded_loras"])
default=EngineArgs.long_lora_scaling_factors,
help=('Specify multiple scaling factors (which can ' # PromptAdapter related configs
'be different from base model scaling factor ' prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
'- see eg. Long LoRA) to allow for multiple ' prompt_adapter_group = parser.add_argument_group(
'LoRA adapters trained with those scaling ' title="PromptAdapterConfig",
'factors to be used at the same time. If not ' description=PromptAdapterConfig.__doc__,
'specified, only adapters trained with the ' )
'base model scaling factor are allowed.')) prompt_adapter_group.add_argument(
parser.add_argument( '--enable-prompt-adapter',
'--max-cpu-loras', action=argparse.BooleanOptionalAction,
type=int, help='If True, enable handling of PromptAdapters.')
default=EngineArgs.max_cpu_loras, prompt_adapter_group.add_argument(
help=('Maximum number of LoRAs to store in CPU memory. ' '--max-prompt-adapters',
'Must be >= than max_loras.')) **prompt_adapter_kwargs["max_prompt_adapters"])
parser.add_argument( prompt_adapter_group.add_argument(
'--fully-sharded-loras', '--max-prompt-adapter-token',
action='store_true', **prompt_adapter_kwargs["max_prompt_adapter_token"])
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. ' # Device arguments
'Enabling this will use the fully sharded layers. ' device_kwargs = get_kwargs(DeviceConfig)
'At high sequence length, max rank or ' device_group = parser.add_argument_group(
'tensor parallel size, this is likely faster.')) title="DeviceConfig",
parser.add_argument('--enable-prompt-adapter', description=DeviceConfig.__doc__,
action='store_true', )
help='If True, enable handling of PromptAdapters.') device_group.add_argument("--device", **device_kwargs["device"])
parser.add_argument('--max-prompt-adapters',
type=int, # Speculative arguments
default=EngineArgs.max_prompt_adapters, speculative_group = parser.add_argument_group(
help='Max number of PromptAdapters in a batch.') title="SpeculativeConfig",
parser.add_argument('--max-prompt-adapter-token', description=SpeculativeConfig.__doc__,
type=int, )
default=EngineArgs.max_prompt_adapter_token, speculative_group.add_argument(
help='Max number of PromptAdapters tokens') '--speculative-config',
parser.add_argument("--device", type=json.loads,
type=str, default=None,
default=EngineArgs.device, help='The configurations for speculative decoding.'
choices=DEVICE_OPTIONS, ' Should be a JSON string.')
help='Device type for vLLM execution.')
parser.add_argument('--num-scheduler-steps',
type=int,
default=1,
help=('Maximum number of forward steps per '
'scheduler call.'))
parser.add_argument(
'--multi-step-stream-outputs',
action=StoreBoolean,
default=EngineArgs.multi_step_stream_outputs,
nargs="?",
const="True",
help='If False, then multi-step will stream outputs at the end '
'of all steps')
parser.add_argument(
'--scheduler-delay-factor',
type=float,
default=EngineArgs.scheduler_delay_factor,
help='Apply a delay (of delay factor multiplied by previous '
'prompt latency) before scheduling next prompt.')
parser.add_argument(
'--enable-chunked-prefill',
action=StoreBoolean,
default=EngineArgs.enable_chunked_prefill,
nargs="?",
const="True",
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=json.loads,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument( parser.add_argument(
'--ignore-patterns', '--ignore-patterns',
action="append", action="append",
...@@ -801,13 +760,6 @@ class EngineArgs: ...@@ -801,13 +760,6 @@ class EngineArgs:
help="The pattern(s) to ignore when loading the model." help="The pattern(s) to ignore when loading the model."
"Default to `original/**/*` to avoid repeated loading of llama's " "Default to `original/**/*` to avoid repeated loading of llama's "
"checkpoints.") "checkpoints.")
parser.add_argument(
'--preemption-mode',
type=str,
default=None,
help='If \'recompute\', the engine performs preemption by '
'recomputing; If \'swap\', the engine performs preemption by '
'block swapping.')
parser.add_argument( parser.add_argument(
"--served-model-name", "--served-model-name",
...@@ -863,22 +815,47 @@ class EngineArgs: ...@@ -863,22 +815,47 @@ class EngineArgs:
help="Disable async output processing. This may result in " help="Disable async output processing. This may result in "
"lower performance.") "lower performance.")
parser.add_argument( # Scheduler arguments
'--scheduling-policy', scheduler_kwargs = get_kwargs(SchedulerConfig)
choices=['fcfs', 'priority'], scheduler_group = parser.add_argument_group(
default="fcfs", title="SchedulerConfig",
help='The scheduling policy to use. "fcfs" (first come first served' description=SchedulerConfig.__doc__,
', i.e. requests are handled in order of arrival; default) ' )
'or "priority" (requests are handled based on given ' scheduler_group.add_argument(
'priority (lower value means earlier handling) and time of ' '--max-num-batched-tokens',
'arrival deciding any ties).') **scheduler_kwargs["max_num_batched_tokens"])
scheduler_group.add_argument('--max-num-seqs',
parser.add_argument( **scheduler_kwargs["max_num_seqs"])
'--scheduler-cls', scheduler_group.add_argument(
default=EngineArgs.scheduler_cls, "--max-num-partial-prefills",
help='The scheduler class to use. "vllm.core.scheduler.Scheduler" ' **scheduler_kwargs["max_num_partial_prefills"])
'is the default scheduler. Can be a class directly or the path to ' scheduler_group.add_argument(
'a class of form "mod.custom_class".') "--max-long-partial-prefills",
**scheduler_kwargs["max_long_partial_prefills"])
scheduler_group.add_argument(
"--long-prefill-token-threshold",
**scheduler_kwargs["long_prefill_token_threshold"])
scheduler_group.add_argument('--num-lookahead-slots',
**scheduler_kwargs["num_lookahead_slots"])
scheduler_group.add_argument('--scheduler-delay-factor',
**scheduler_kwargs["delay_factor"])
scheduler_group.add_argument('--preemption-mode',
**scheduler_kwargs["preemption_mode"])
scheduler_group.add_argument('--num-scheduler-steps',
**scheduler_kwargs["num_scheduler_steps"])
scheduler_group.add_argument(
'--multi-step-stream-outputs',
**scheduler_kwargs["multi_step_stream_outputs"])
scheduler_group.add_argument('--scheduling-policy',
**scheduler_kwargs["policy"])
scheduler_group.add_argument(
'--enable-chunked-prefill',
**scheduler_kwargs["enable_chunked_prefill"])
scheduler_group.add_argument(
"--disable-chunked-mm-input",
**scheduler_kwargs["disable_chunked_mm_input"])
parser.add_argument('--scheduler-cls',
**scheduler_kwargs["scheduler_cls"])
parser.add_argument( parser.add_argument(
'--override-neuron-config', '--override-neuron-config',
...@@ -905,10 +882,11 @@ class EngineArgs: ...@@ -905,10 +882,11 @@ class EngineArgs:
'testing only. level 3 is the recommended level ' 'testing only. level 3 is the recommended level '
'for production.\n' 'for production.\n'
'To specify the full compilation config, ' 'To specify the full compilation config, '
'use a JSON string.\n' 'use a JSON string, e.g. ``{"level": 3, '
'"cudagraph_capture_sizes": [1, 2, 4, 8]}``\n'
'Following the convention of traditional ' 'Following the convention of traditional '
'compilers, using -O without space is also ' 'compilers, using ``-O`` without space is also '
'supported. -O3 is equivalent to -O 3.') 'supported. ``-O3`` is equivalent to ``-O 3``.')
parser.add_argument('--kv-transfer-config', parser.add_argument('--kv-transfer-config',
type=KVTransferConfig.from_cli, type=KVTransferConfig.from_cli,
...@@ -930,7 +908,7 @@ class EngineArgs: ...@@ -930,7 +908,7 @@ class EngineArgs:
'class without changing the existing functions.') 'class without changing the existing functions.')
parser.add_argument( parser.add_argument(
"--generation-config", "--generation-config",
type=nullable_str, type=optional_type(str),
default="auto", default="auto",
help="The folder path to the generation config. " help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from " "Defaults to 'auto', the generation config will be loaded from "
...@@ -957,15 +935,6 @@ class EngineArgs: ...@@ -957,15 +935,6 @@ class EngineArgs:
help="Enable sleep mode for the engine. " help="Enable sleep mode for the engine. "
"(only cuda platform is supported)") "(only cuda platform is supported)")
parser.add_argument(
'--calculate-kv-scales',
action='store_true',
help='This enables dynamic calculation of '
'k_scale and v_scale when kv-cache-dtype is fp8. '
'If calculate-kv-scales is false, the scales will '
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
parser.add_argument( parser.add_argument(
"--additional-config", "--additional-config",
type=json.loads, type=json.loads,
...@@ -983,16 +952,6 @@ class EngineArgs: ...@@ -983,16 +952,6 @@ class EngineArgs:
"If enabled, the model will be able to generate reasoning content." "If enabled, the model will be able to generate reasoning content."
) )
parser.add_argument(
"--reasoning-parser",
type=str,
choices=list(ReasoningParserManager.reasoning_parsers),
default=None,
help=
"Select the reasoning parser depending on the model that you're "
"using. This is used to parse the reasoning content into OpenAI "
"API format. Required for ``--enable-reasoning``.")
parser.add_argument( parser.add_argument(
"--disable-cascade-attn", "--disable-cascade-attn",
action="store_true", action="store_true",
...@@ -1003,20 +962,6 @@ class EngineArgs: ...@@ -1003,20 +962,6 @@ class EngineArgs:
"Note that even if this is set to False, cascade attention will be " "Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.") "only used when the heuristic tells that it's beneficial.")
parser.add_argument(
"--disable-chunked-mm-input",
action=StoreBoolean,
default=EngineArgs.disable_chunked_mm_input,
nargs="?",
const="True",
help="Disable multimodal input chunking attention for V1. "
"If set to true and chunked prefill is enabled, we do not want to"
" partially schedule a multimodal item. This ensures that if a "
"request has a mixed prompt (like text tokens TTTT followed by "
"image tokens IIIIIIIIII) where only some image tokens can be "
"scheduled (like TTTTIIIII, leaving IIIII), it will be scheduled "
"as TTTT in one step and IIIIIIIIII in the next.")
return parser return parser
@classmethod @classmethod
...@@ -1210,11 +1155,6 @@ class EngineArgs: ...@@ -1210,11 +1155,6 @@ class EngineArgs:
enable_expert_parallel=self.enable_expert_parallel, enable_expert_parallel=self.enable_expert_parallel,
max_parallel_loading_workers=self.max_parallel_loading_workers, max_parallel_loading_workers=self.max_parallel_loading_workers,
disable_custom_all_reduce=self.disable_custom_all_reduce, disable_custom_all_reduce=self.disable_custom_all_reduce,
tokenizer_pool_config=TokenizerPoolConfig.create_config(
self.tokenizer_pool_size,
self.tokenizer_pool_type,
self.tokenizer_pool_extra_config,
),
ray_workers_use_nsight=self.ray_workers_use_nsight, ray_workers_use_nsight=self.ray_workers_use_nsight,
placement_group=placement_group, placement_group=placement_group,
distributed_executor_backend=self.distributed_executor_backend, distributed_executor_backend=self.distributed_executor_backend,
...@@ -1288,8 +1228,6 @@ class EngineArgs: ...@@ -1288,8 +1228,6 @@ class EngineArgs:
if self.qlora_adapter_name_or_path is not None and \ if self.qlora_adapter_name_or_path is not None and \
self.qlora_adapter_name_or_path != "": self.qlora_adapter_name_or_path != "":
if self.model_loader_extra_config is None:
self.model_loader_extra_config = {}
self.model_loader_extra_config[ self.model_loader_extra_config[
"qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
...@@ -1370,7 +1308,7 @@ class EngineArgs: ...@@ -1370,7 +1308,7 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
if self.preemption_mode != EngineArgs.preemption_mode: if self.preemption_mode != SchedulerConfig.preemption_mode:
_raise_or_fallback(feature_name="--preemption-mode", _raise_or_fallback(feature_name="--preemption-mode",
recommend_to_remove=True) recommend_to_remove=True)
return False return False
...@@ -1381,34 +1319,28 @@ class EngineArgs: ...@@ -1381,34 +1319,28 @@ class EngineArgs:
recommend_to_remove=True) recommend_to_remove=True)
return False return False
if self.scheduling_policy != EngineArgs.scheduling_policy: if self.scheduling_policy != SchedulerConfig.policy:
_raise_or_fallback(feature_name="--scheduling-policy", _raise_or_fallback(feature_name="--scheduling-policy",
recommend_to_remove=False) recommend_to_remove=False)
return False return False
if self.num_scheduler_steps != EngineArgs.num_scheduler_steps: if self.num_scheduler_steps != SchedulerConfig.num_scheduler_steps:
_raise_or_fallback(feature_name="--num-scheduler-steps", _raise_or_fallback(feature_name="--num-scheduler-steps",
recommend_to_remove=True) recommend_to_remove=True)
return False return False
if self.scheduler_delay_factor != EngineArgs.scheduler_delay_factor: if self.scheduler_delay_factor != SchedulerConfig.delay_factor:
_raise_or_fallback(feature_name="--scheduler-delay-factor", _raise_or_fallback(feature_name="--scheduler-delay-factor",
recommend_to_remove=True) recommend_to_remove=True)
return False return False
if self.additional_config != EngineArgs.additional_config: # remove backend options when doing this check
_raise_or_fallback(feature_name="--additional-config", if self.guided_decoding_backend.split(':')[0] \
recommend_to_remove=False) not in get_args(GuidedDecodingBackendV1):
return False _raise_or_fallback(
feature_name=
# Xgrammar and Guidance are supported. f"--guided-decoding-backend={self.guided_decoding_backend}",
SUPPORTED_GUIDED_DECODING = [ recommend_to_remove=False)
"xgrammar", "xgrammar:disable-any-whitespace", "guidance",
"guidance:disable-any-whitespace", "auto"
]
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
_raise_or_fallback(feature_name="--guided-decoding-backend",
recommend_to_remove=False)
return False return False
# Need at least Ampere for now (FA support required). # Need at least Ampere for now (FA support required).
...@@ -1432,7 +1364,7 @@ class EngineArgs: ...@@ -1432,7 +1364,7 @@ class EngineArgs:
) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
supported = False supported = False
if fp8_attention and will_use_fa: if fp8_attention and will_use_fa:
from vllm.vllm_flash_attn.fa_utils import ( from vllm.attention.utils.fa_utils import (
flash_attn_supports_fp8) flash_attn_supports_fp8)
supported = flash_attn_supports_fp8() supported = flash_attn_supports_fp8()
if not supported: if not supported:
...@@ -1475,9 +1407,9 @@ class EngineArgs: ...@@ -1475,9 +1407,9 @@ class EngineArgs:
# No Concurrent Partial Prefills so far. # No Concurrent Partial Prefills so far.
if (self.max_num_partial_prefills if (self.max_num_partial_prefills
!= EngineArgs.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills
or self.max_long_partial_prefills or self.max_long_partial_prefills
!= EngineArgs.max_long_partial_prefills): != SchedulerConfig.max_long_partial_prefills):
_raise_or_fallback(feature_name="Concurrent Partial Prefill", _raise_or_fallback(feature_name="Concurrent Partial Prefill",
recommend_to_remove=False) recommend_to_remove=False)
return False return False
...@@ -1497,7 +1429,7 @@ class EngineArgs: ...@@ -1497,7 +1429,7 @@ class EngineArgs:
if speculative_method: if speculative_method:
if speculative_method in ("ngram", "[ngram]"): if speculative_method in ("ngram", "[ngram]"):
is_ngram_enabled = True is_ngram_enabled = True
elif speculative_method == "eagle": elif speculative_method in ("eagle", "eagle3"):
is_eagle_enabled = True is_eagle_enabled = True
else: else:
speculative_model = self.speculative_config.get("model") speculative_model = self.speculative_config.get("model")
...@@ -1509,16 +1441,17 @@ class EngineArgs: ...@@ -1509,16 +1441,17 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# No Disaggregated Prefill so far. # No XFormers so far.
if self.kv_transfer_config != EngineArgs.kv_transfer_config:
_raise_or_fallback(feature_name="--kv-transfer-config",
recommend_to_remove=False)
return False
# No FlashInfer or XFormers so far.
V1_BACKENDS = [ V1_BACKENDS = [
"FLASH_ATTN_VLLM_V1", "FLASH_ATTN", "PALLAS", "PALLAS_VLLM_V1", "FLASH_ATTN_VLLM_V1",
"TRITON_ATTN_VLLM_V1", "TRITON_MLA", "FLASHMLA" "FLASH_ATTN",
"PALLAS",
"PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1",
"TRITON_MLA",
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
] ]
if (envs.is_set("VLLM_ATTENTION_BACKEND") if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
...@@ -1620,9 +1553,7 @@ class EngineArgs: ...@@ -1620,9 +1553,7 @@ class EngineArgs:
self.enable_prefix_caching = False self.enable_prefix_caching = False
# VLLM_V0 only supports builtin hash algo for prefix caching. # VLLM_V0 only supports builtin hash algo for prefix caching.
if self.prefix_caching_hash_algo is None: if self.prefix_caching_hash_algo == "sha256":
self.prefix_caching_hash_algo = "builtin"
elif self.prefix_caching_hash_algo == "sha256":
raise ValueError( raise ValueError(
"sha256 is not supported for prefix caching in V0 engine. " "sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'.") "Please use 'builtin'.")
...@@ -1641,10 +1572,6 @@ class EngineArgs: ...@@ -1641,10 +1572,6 @@ class EngineArgs:
if self.enable_prefix_caching is None: if self.enable_prefix_caching is None:
self.enable_prefix_caching = True self.enable_prefix_caching = True
# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
self.prefix_caching_hash_algo = "builtin"
# V1 should use the new scheduler by default. # V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default # Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls: if self.scheduler_cls == EngineArgs.scheduler_cls:
...@@ -1661,13 +1588,13 @@ class EngineArgs: ...@@ -1661,13 +1588,13 @@ class EngineArgs:
# values for non-H100/H200 GPUs. # values for non-H100/H200 GPUs.
try: try:
from vllm.platforms import current_platform from vllm.platforms import current_platform
device_name = current_platform.get_device_name().lower() device_memory = current_platform.get_device_total_memory()
except Exception: except Exception:
# This is only used to set default_max_num_batched_tokens # This is only used to set default_max_num_batched_tokens
device_name = "no-device" device_memory = 0
if "h100" in device_name or "h200" in device_name: if device_memory >= 70 * GiB_bytes:
# For H100 and H200, we use larger default values. # For GPUs like H100 and MI300x, use larger default values.
default_max_num_batched_tokens = { default_max_num_batched_tokens = {
UsageContext.LLM_CLASS: 16384, UsageContext.LLM_CLASS: 16384,
UsageContext.OPENAI_API_SERVER: 8192, UsageContext.OPENAI_API_SERVER: 8192,
......
...@@ -493,12 +493,11 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -493,12 +493,11 @@ class _AsyncLLMEngine(LLMEngine):
tokenizer = await self.get_tokenizer_async(lora_request) tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer) self._validate_token_prompt(prompt, tokenizer=tokenizer)
preprocessed_inputs = await self.input_preprocessor.preprocess_async( processed_inputs = await self.input_preprocessor.preprocess_async(
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs)
if isinstance(params, SamplingParams) and \ if isinstance(params, SamplingParams) and \
params.guided_decoding is not None: params.guided_decoding is not None:
...@@ -526,10 +525,15 @@ class _AsyncLLMEngine(LLMEngine): ...@@ -526,10 +525,15 @@ class _AsyncLLMEngine(LLMEngine):
) )
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()
async def collective_rpc_async(self,
method: str,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None):
raise NotImplementedError
async def build_guided_decoding_logits_processor_async( async def build_guided_decoding_logits_processor_async(
sampling_params: SamplingParams, tokenizer: AnyTokenizer, sampling_params: SamplingParams, tokenizer: AnyTokenizer,
...@@ -1167,6 +1171,10 @@ class AsyncLLMEngine(EngineClient): ...@@ -1167,6 +1171,10 @@ class AsyncLLMEngine(EngineClient):
exception=asyncio.CancelledError, exception=asyncio.CancelledError,
verbose=self.log_requests) verbose=self.log_requests)
async def get_vllm_config(self) -> VllmConfig:
"""Get the vllm configuration of the vLLM engine."""
return self.engine.get_vllm_config()
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
return self.engine.get_model_config() return self.engine.get_model_config()
...@@ -1234,6 +1242,17 @@ class AsyncLLMEngine(EngineClient): ...@@ -1234,6 +1242,17 @@ class AsyncLLMEngine(EngineClient):
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
self.engine.add_lora(lora_request) self.engine.add_lora(lora_request)
async def collective_rpc(self,
method: str,
timeout: Optional[float] = None,
args: tuple = (),
kwargs: Optional[dict] = None):
"""
Perform a collective RPC call to the given path.
"""
return await self.engine.collective_rpc_async(method, timeout, args,
kwargs)
# TODO(v1): Remove this class proxy when V1 goes default. # TODO(v1): Remove this class proxy when V1 goes default.
if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1:
......
...@@ -29,8 +29,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group ...@@ -29,8 +29,7 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import ( from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors) get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
PromptType, SingletonInputs)
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -55,7 +54,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, ...@@ -55,7 +54,7 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import ( from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs) TokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import (Counter, Device, deprecate_kwargs, from vllm.utils import (Counter, Device, deprecate_kwargs,
...@@ -66,7 +65,6 @@ from vllm.worker.model_runner_base import InputProcessingError ...@@ -66,7 +65,6 @@ from vllm.worker.model_runner_base import InputProcessingError
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, PoolingRequestOutput) _O = TypeVar("_O", RequestOutput, PoolingRequestOutput)
_R = TypeVar("_R", default=Any) _R = TypeVar("_R", default=Any)
...@@ -205,7 +203,7 @@ class LLMEngine: ...@@ -205,7 +203,7 @@ class LLMEngine:
return outputs_ return outputs_
tokenizer: Optional[BaseTokenizerGroup] tokenizer: Optional[TokenizerGroup]
def __init__( def __init__(
self, self,
...@@ -214,7 +212,6 @@ class LLMEngine: ...@@ -214,7 +212,6 @@ class LLMEngine:
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
use_cached_outputs: bool = False, use_cached_outputs: bool = False,
) -> None: ) -> None:
...@@ -275,11 +272,7 @@ class LLMEngine: ...@@ -275,11 +272,7 @@ class LLMEngine:
self.tokenizer, self.tokenizer,
mm_registry) mm_registry)
self.input_registry = input_registry self.model_executor = executor_class(vllm_config=vllm_config)
self.input_processor = input_registry.create_input_processor(
self.model_config)
self.model_executor = executor_class(vllm_config=vllm_config, )
if self.model_config.runner_type != "pooling": if self.model_config.runner_type != "pooling":
self._initialize_kv_caches() self._initialize_kv_caches()
...@@ -321,11 +314,6 @@ class LLMEngine: ...@@ -321,11 +314,6 @@ class LLMEngine:
self.parallel_config.disable_custom_all_reduce, self.parallel_config.disable_custom_all_reduce,
}) })
if self.tokenizer:
# Ping the tokenizer to ensure liveness if it runs in a
# different process.
self.tokenizer.ping()
self.cached_scheduler_outputs = [ self.cached_scheduler_outputs = [
SchedulerOutputState() SchedulerOutputState()
for _ in range(self.parallel_config.pipeline_parallel_size) for _ in range(self.parallel_config.pipeline_parallel_size)
...@@ -537,21 +525,12 @@ class LLMEngine: ...@@ -537,21 +525,12 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None): if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown() model_executor.shutdown()
def get_tokenizer_group( def get_tokenizer_group(self) -> TokenizerGroup:
self, if self.tokenizer is None:
group_type: Type[_G] = BaseTokenizerGroup,
) -> _G:
tokenizer_group = self.tokenizer
if tokenizer_group is None:
raise ValueError("Unable to get tokenizer because " raise ValueError("Unable to get tokenizer because "
"skip_tokenizer_init is True") "skip_tokenizer_init is True")
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")
return tokenizer_group return self.tokenizer
def get_tokenizer( def get_tokenizer(
self, self,
...@@ -559,11 +538,10 @@ class LLMEngine: ...@@ -559,11 +538,10 @@ class LLMEngine:
) -> AnyTokenizer: ) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request) return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
def _init_tokenizer(self) -> BaseTokenizerGroup: def _init_tokenizer(self) -> TokenizerGroup:
return init_tokenizer_from_configs( return init_tokenizer_from_configs(
model_config=self.model_config, model_config=self.model_config,
scheduler_config=self.scheduler_config, scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
lora_config=self.lora_config) lora_config=self.lora_config)
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -778,12 +756,11 @@ class LLMEngine: ...@@ -778,12 +756,11 @@ class LLMEngine:
prompt, prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request)) tokenizer=self.get_tokenizer(lora_request=lora_request))
preprocessed_inputs = self.input_preprocessor.preprocess( processed_inputs = self.input_preprocessor.preprocess(
prompt, prompt,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
processed_inputs = self.input_processor(preprocessed_inputs)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
...@@ -914,6 +891,10 @@ class LLMEngine: ...@@ -914,6 +891,10 @@ class LLMEngine:
scheduler.abort_seq_group( scheduler.abort_seq_group(
request_id, seq_id_to_seq_group=self.seq_id_to_seq_group) request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
def get_vllm_config(self) -> VllmConfig:
"""Gets the vllm configuration."""
return self.vllm_config
def get_model_config(self) -> ModelConfig: def get_model_config(self) -> ModelConfig:
"""Gets the model configuration.""" """Gets the model configuration."""
return self.model_config return self.model_config
...@@ -1948,8 +1929,6 @@ class LLMEngine: ...@@ -1948,8 +1929,6 @@ class LLMEngine:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping
def check_health(self) -> None: def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health() self.model_executor.check_health()
def is_tracing_enabled(self) -> bool: def is_tracing_enabled(self) -> bool:
...@@ -2058,7 +2037,7 @@ class LLMEngine: ...@@ -2058,7 +2037,7 @@ class LLMEngine:
raise ValueError(f"The {prompt_type} prompt cannot be empty") raise ValueError(f"The {prompt_type} prompt cannot be empty")
max_prompt_len = self.model_config.max_model_len max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) >= max_prompt_len: if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model: if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor( mm_processor = mm_registry.create_processor(
......
...@@ -140,16 +140,13 @@ class Metrics: ...@@ -140,16 +140,13 @@ class Metrics:
name="vllm:generation_tokens_total", name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.", documentation="Number of generation tokens processed.",
labelnames=labelnames) labelnames=labelnames)
buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096]
if not vllm_config.model_config.enforce_eager:
buckets = vllm_config.compilation_config.\
cudagraph_capture_sizes.copy()
buckets.sort()
self.histogram_iteration_tokens = self._histogram_cls( self.histogram_iteration_tokens = self._histogram_cls(
name="vllm:iteration_tokens_total", name="vllm:iteration_tokens_total",
documentation="Histogram of number of tokens per engine_step.", documentation="Histogram of number of tokens per engine_step.",
labelnames=labelnames, labelnames=labelnames,
buckets=buckets) buckets=[
1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384
])
self.histogram_time_to_first_token = self._histogram_cls( self.histogram_time_to_first_token = self._histogram_cls(
name="vllm:time_to_first_token_seconds", name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.", documentation="Histogram of time to first token in seconds.",
......
...@@ -93,6 +93,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -93,6 +93,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None self._errored_with: Optional[BaseException] = None
# Get the configs. # Get the configs.
self.vllm_config = engine_config
self.model_config = engine_config.model_config self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config self.decoding_config = engine_config.decoding_config
...@@ -100,7 +101,6 @@ class MQLLMEngineClient(EngineClient): ...@@ -100,7 +101,6 @@ class MQLLMEngineClient(EngineClient):
self.tokenizer = init_tokenizer_from_configs( self.tokenizer = init_tokenizer_from_configs(
model_config=self.model_config, model_config=self.model_config,
scheduler_config=engine_config.scheduler_config, scheduler_config=engine_config.scheduler_config,
parallel_config=engine_config.parallel_config,
lora_config=engine_config.lora_config) lora_config=engine_config.lora_config)
self.input_preprocessor = InputPreprocessor(self.model_config, self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer) self.tokenizer)
...@@ -377,6 +377,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -377,6 +377,9 @@ class MQLLMEngineClient(EngineClient):
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request) return await self.tokenizer.get_lora_tokenizer_async(lora_request)
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config
async def get_decoding_config(self) -> DecodingConfig: async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config return self.decoding_config
......
...@@ -178,7 +178,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): ...@@ -178,7 +178,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# generates a fixed number of tokens without evaluating stopping # generates a fixed number of tokens without evaluating stopping
# conditions within the block. This can cause an eos token to be # conditions within the block. This can cause an eos token to be
# unintentionally ignored. # unintentionally ignored.
if not sampling_params.ignore_eos: if not sampling_params.ignore_eos and self.detokenizer:
eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id
# Avoiding .index calls as exception throwing in the happy path # Avoiding .index calls as exception throwing in the happy path
# is expensive. # is expensive.
......
...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod ...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Mapping, Optional from typing import AsyncGenerator, List, Mapping, Optional
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType, TokensPrompt from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
...@@ -220,6 +220,11 @@ class EngineClient(ABC): ...@@ -220,6 +220,11 @@ class EngineClient(ABC):
""" """
... ...
@abstractmethod
async def get_vllm_config(self) -> VllmConfig:
"""Get the vllm configuration of the vLLM engine."""
...
@abstractmethod @abstractmethod
async def get_model_config(self) -> ModelConfig: async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine.""" """Get the model configuration of the vLLM engine."""
......
...@@ -27,10 +27,11 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam, ...@@ -27,10 +27,11 @@ from openai.types.chat import (ChatCompletionMessageToolCallParam,
ChatCompletionToolMessageParam) ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import ( from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio) InputAudio)
from pydantic import TypeAdapter
# yapf: enable # yapf: enable
# pydantic needs the TypedDict from typing_extensions
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast, from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin) ProcessorMixin)
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypeAlias, TypedDict from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -482,11 +483,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -482,11 +483,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if modality in ("image", "image_embeds"): if modality in ("image", "image_embeds"):
if model_type == "chatglm": if model_type == "chatglm":
return "<|begin_of_image|><|endoftext|><|end_of_image|>" return "<|begin_of_image|><|endoftext|><|end_of_image|>"
if model_type == "phi3_v": if model_type in ("phi3_v", "phi4mm"):
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>" return f"<|image_{current_count}|>"
if model_type == "phi4mm":
return "<|endoftext10|>" # 200010 (see vocab.json in hf model)
if model_type in ("minicpmo", "minicpmv"): if model_type in ("minicpmo", "minicpmv"):
return "(<image>./</image>)" return "(<image>./</image>)"
if model_type in ("blip-2", "florence2", "fuyu", "paligemma", if model_type in ("blip-2", "florence2", "fuyu", "paligemma",
...@@ -506,20 +504,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -506,20 +504,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
return "<|image|>" return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"): if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>" return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|IMAGE|><|vision_end|>"
if model_type == "molmo": if model_type == "molmo":
return "" return ""
if model_type == "aria": if model_type == "aria":
return "<|fim_prefix|><|img|><|fim_suffix|>" return "<|fim_prefix|><|img|><|fim_suffix|>"
if model_type == "gemma3": if model_type == "gemma3":
return "<start_of_image>" return "<start_of_image>"
if model_type == "kimi_vl":
return "<|media_start|>image<|media_content|><|media_pad|><|media_end|>" # noqa: E501
raise TypeError(f"Unknown {modality} model type: {model_type}") raise TypeError(f"Unknown {modality} model type: {model_type}")
elif modality == "audio": elif modality == "audio":
if model_type == "ultravox": if model_type in ("ultravox", "granite_speech"):
return "<|audio|>" return "<|audio|>"
if model_type == "phi4mm": if model_type == "phi4mm":
return "<|endoftext11|>" # 200011 (see vocab.json in hf model) return f"<|audio_{current_count}|>"
if model_type == "qwen2_audio": if model_type in ("qwen2_audio", "qwen2_5_omni"):
return (f"Audio {current_count}: " return (f"Audio {current_count}: "
f"<|audio_bos|><|AUDIO|><|audio_eos|>") f"<|audio_bos|><|AUDIO|><|audio_eos|>")
if model_type == "minicpmo": if model_type == "minicpmo":
...@@ -528,6 +530,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -528,6 +530,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
elif modality == "video": elif modality == "video":
if model_type in ("qwen2_vl", "qwen2_5_vl"): if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>" return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type == "qwen2_5_omni":
return "<|vision_start|><|VIDEO|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"): if model_type in ("minicpmo", "minicpmv"):
return "(<video>./</video>)" return "(<video>./</video>)"
if model_type.startswith("llava"): if model_type.startswith("llava"):
...@@ -876,12 +880,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], ...@@ -876,12 +880,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# No need to validate using Pydantic again # No need to validate using Pydantic again
_TextParser = partial(cast, ChatCompletionContentPartTextParam) _TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam) _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam) # Need to validate url objects
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio] _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]
...@@ -1092,7 +1097,11 @@ def _parse_chat_message_content( ...@@ -1092,7 +1097,11 @@ def _parse_chat_message_content(
if role == 'assistant': if role == 'assistant':
parsed_msg = _AssistantParser(message) parsed_msg = _AssistantParser(message)
if "tool_calls" in parsed_msg: # The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly
# follow the OpenAI spec.
if ("tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None):
result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool": elif role == "tool":
parsed_msg = _ToolParser(message) parsed_msg = _ToolParser(message)
...@@ -1189,14 +1198,25 @@ def apply_hf_chat_template( ...@@ -1189,14 +1198,25 @@ def apply_hf_chat_template(
"allowed, so you must provide a chat template if the tokenizer " "allowed, so you must provide a chat template if the tokenizer "
"does not define one.") "does not define one.")
return tokenizer.apply_chat_template( try:
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type] return tokenizer.apply_chat_template(
chat_template=hf_chat_template, conversation=conversation, # type: ignore[arg-type]
tokenize=tokenize, tools=tools, # type: ignore[arg-type]
**kwargs, chat_template=hf_chat_template,
) tokenize=tokenize,
**kwargs,
)
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `transformers` while applying chat template")
raise ValueError from e
def apply_mistral_chat_template( def apply_mistral_chat_template(
tokenizer: MistralTokenizer, tokenizer: MistralTokenizer,
...@@ -1205,6 +1225,8 @@ def apply_mistral_chat_template( ...@@ -1205,6 +1225,8 @@ def apply_mistral_chat_template(
tools: Optional[list[dict[str, Any]]], tools: Optional[list[dict[str, Any]]],
**kwargs: Any, **kwargs: Any,
) -> list[int]: ) -> list[int]:
from mistral_common.exceptions import MistralCommonException
# The return value of resolve_mistral_chat_template is always None, # The return value of resolve_mistral_chat_template is always None,
# and we won't use it. # and we won't use it.
resolve_mistral_chat_template( resolve_mistral_chat_template(
...@@ -1222,5 +1244,16 @@ def apply_mistral_chat_template( ...@@ -1222,5 +1244,16 @@ def apply_mistral_chat_template(
# if input does not comply with the expected format. # if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be # We convert those assertion errors to ValueErrors so they can be
# are properly caught in the preprocessing_input step # are properly caught in the preprocessing_input step
except AssertionError as e: except (AssertionError, MistralCommonException) as e:
raise ValueError from e
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
except Exception as e:
# Log and report any library-related exceptions for further
# investigation.
logger.exception(
"An error occurred in `mistral_common` while applying chat "
"template")
raise ValueError from e raise ValueError from e
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.benchmarks.latency import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkLatencySubcommand(BenchmarkSubcommandBase):
    """CLI plumbing for the `latency` subcommand of `vllm bench`.

    Each method is a thin adapter around the `add_cli_args` and `main`
    functions imported from `vllm.benchmarks.latency`.
    """

    def __init__(self):
        # The name is assigned before delegating to the base initializer;
        # that statement order is kept exactly as in the original.
        self.name = "latency"
        super().__init__()

    @property
    def help(self) -> str:
        """Return the one-line description of this subcommand."""
        summary = "Benchmark the latency of a single batch of requests."
        return summary

    def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
        """Forward `parser` to `vllm.benchmarks.latency.add_cli_args`."""
        # Inside a method body this name resolves to the module-level
        # import, not to this method, so there is no recursion.
        add_cli_args(parser)

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Forward the parsed `args` to `vllm.benchmarks.latency.main`."""
        main(args)
def cmd_init() -> list[CLISubcommand]:
    """Return the subcommand instances contributed by this module.

    The list holds a single freshly constructed `BenchmarkLatencySubcommand`.
    """
    subcommands: list[CLISubcommand] = [BenchmarkLatencySubcommand()]
    return subcommands
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import vllm.entrypoints.cli.benchmark.latency
import vllm.entrypoints.cli.benchmark.serve import vllm.entrypoints.cli.benchmark.serve
import vllm.entrypoints.cli.benchmark.throughput
from vllm.entrypoints.cli.types import CLISubcommand from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
# TODO: Add the rest of the benchmark subcommands here,
# e.g., throughput, latency, etc.
BENCHMARK_CMD_MODULES = [ BENCHMARK_CMD_MODULES = [
vllm.entrypoints.cli.benchmark.latency,
vllm.entrypoints.cli.benchmark.serve, vllm.entrypoints.cli.benchmark.serve,
vllm.entrypoints.cli.benchmark.throughput,
] ]
......
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.benchmarks.throughput import add_cli_args, main
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
class BenchmarkThroughputSubcommand(BenchmarkSubcommandBase):
    """CLI plumbing for the `throughput` subcommand of `vllm bench`.

    Each method is a thin adapter around the `add_cli_args` and `main`
    functions imported from `vllm.benchmarks.throughput`.
    """

    def __init__(self):
        # The name is assigned before delegating to the base initializer;
        # that statement order is kept exactly as in the original.
        self.name = "throughput"
        super().__init__()

    @property
    def help(self) -> str:
        """Return the one-line description of this subcommand."""
        summary = "Benchmark offline inference throughput."
        return summary

    def add_cli_args(self, parser: argparse.ArgumentParser) -> None:
        """Forward `parser` to `vllm.benchmarks.throughput.add_cli_args`."""
        # Inside a method body this name resolves to the module-level
        # import, not to this method, so there is no recursion.
        add_cli_args(parser)

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Forward the parsed `args` to `vllm.benchmarks.throughput.main`."""
        main(args)
def cmd_init() -> list[CLISubcommand]:
    """Return the subcommand instances contributed by this module.

    The list holds a single freshly constructed
    `BenchmarkThroughputSubcommand`.
    """
    subcommands: list[CLISubcommand] = [BenchmarkThroughputSubcommand()]
    return subcommands
# SPDX-License-Identifier: Apache-2.0
import argparse
from vllm.collect_env import main as collect_env_main
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils import FlexibleArgumentParser
class CollectEnvSubcommand(CLISubcommand):
    """The `collect-env` subcommand for the vLLM CLI."""

    def __init__(self):
        # The name is assigned before delegating to the base initializer;
        # that statement order is kept exactly as in the original.
        self.name = "collect-env"
        super().__init__()

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Collect information about the environment.

        `args` is accepted to match the subcommand interface but is not
        read; `collect_env_main` is invoked without arguments.
        """
        collect_env_main()

    def subparser_init(
            self,
            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the `collect-env` parser on `subparsers` and return it.

        The newly added parser is passed through `make_arg_parser` and that
        call's result is returned.
        """
        collect_env_parser = subparsers.add_parser(
            "collect-env",
            help="Start collecting environment information.",
            description="Start collecting environment information.",
            usage="vllm collect-env",
        )
        # NOTE(review): make_arg_parser comes from
        # vllm.entrypoints.openai.cli_args and presumably attaches the
        # OpenAI server options, yet `cmd` above ignores `args` entirely.
        # Confirm whether those options are really wanted on this command.
        return make_arg_parser(collect_env_parser)
def cmd_init() -> list[CLISubcommand]:
    """Return the subcommand instances contributed by this module.

    The list holds a single freshly constructed `CollectEnvSubcommand`.
    """
    subcommands: list[CLISubcommand] = [CollectEnvSubcommand()]
    return subcommands
...@@ -5,6 +5,7 @@ import signal ...@@ -5,6 +5,7 @@ import signal
import sys import sys
import vllm.entrypoints.cli.benchmark.main import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve import vllm.entrypoints.cli.serve
import vllm.version import vllm.version
...@@ -15,6 +16,7 @@ CMD_MODULES = [ ...@@ -15,6 +16,7 @@ CMD_MODULES = [
vllm.entrypoints.cli.openai, vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve, vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main, vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
] ]
......
...@@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response ...@@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response
from vllm import envs from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.ssl import SSLCertRefresher from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import find_process_using_port from vllm.utils import find_process_using_port
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -40,6 +42,8 @@ async def serve_http(app: FastAPI, ...@@ -40,6 +42,8 @@ async def serve_http(app: FastAPI,
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
watchdog_task = loop.create_task(
watchdog_loop(server, app.state.engine_client))
server_task = loop.create_task( server_task = loop.create_task(
server.serve(sockets=[sock] if sock else None)) server.serve(sockets=[sock] if sock else None))
...@@ -52,6 +56,7 @@ async def serve_http(app: FastAPI, ...@@ -52,6 +56,7 @@ async def serve_http(app: FastAPI,
def signal_handler() -> None: def signal_handler() -> None:
# prevents the uvicorn signal handler to exit early # prevents the uvicorn signal handler to exit early
server_task.cancel() server_task.cancel()
watchdog_task.cancel()
if ssl_cert_refresher: if ssl_cert_refresher:
ssl_cert_refresher.stop() ssl_cert_refresher.stop()
...@@ -73,48 +78,69 @@ async def serve_http(app: FastAPI, ...@@ -73,48 +78,69 @@ async def serve_http(app: FastAPI,
port, process, " ".join(process.cmdline())) port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.") logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown() return server.shutdown()
finally:
watchdog_task.cancel()
async def watchdog_loop(server: uvicorn.Server, engine: EngineClient):
    """Background watchdog that periodically checks the engine for errors.

    Needed to trigger shutdown if an exception arises in a
    StreamingResponse() generator. The loop never returns on its own; it
    ends only when the task running it is cancelled or when
    `terminate_if_errored` raises.
    """
    poll_interval_s = 5.0
    while True:
        # Sleep first, then check: the first check happens one full
        # interval after the task starts.
        await asyncio.sleep(poll_interval_s)
        terminate_if_errored(server, engine)
def terminate_if_errored(server: uvicorn.Server, engine: EngineClient):
    """Flag `server` for exit when the engine has errored and stopped.

    Nothing happens if `envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH` is truthy, or
    if the engine is not both errored and not running.

    See discussions here on shutting down a uvicorn server
    https://github.com/encode/uvicorn/discussions/1103
    In this case we cannot await the server shutdown here
    because the handler must first return to close the connection
    for this request.
    """
    # The engine state is read before the environment flag, as before.
    engine_is_dead = engine.errored and not engine.is_running
    if envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH or not engine_is_dead:
        return
    # Only the flag is set here; the shutdown itself is not awaited.
    server.should_exit = True
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server""" """
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
background task to check for errored state.
"""
@app.exception_handler(RuntimeError) @app.exception_handler(RuntimeError)
async def runtime_error_handler(request: Request, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine = request.app.state.engine_client
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "
"process")
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(AsyncEngineDeadError) @app.exception_handler(AsyncEngineDeadError)
async def async_engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("AsyncLLMEngine is already dead, terminating server "
"process")
server.should_exit = True
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
@app.exception_handler(MQEngineDeadError) @app.exception_handler(MQEngineDeadError)
async def mq_engine_dead_handler(_, __): @app.exception_handler(EngineDeadError)
"""Kill the server if the mq engine is already dead. It will @app.exception_handler(EngineGenerateError)
not handle any further requests.""" async def runtime_exception_handler(request: Request, __):
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: terminate_if_errored(
logger.fatal("MQLLMEngine is already dead, terminating server " server=server,
"process") engine=request.app.state.engine_client,
server.should_exit = True )
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment