Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
...@@ -20,7 +20,7 @@ from typing import Callable, Optional ...@@ -20,7 +20,7 @@ from typing import Callable, Optional
import torch import torch
from vllm.config import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
......
...@@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ...@@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
config=vllm_config, role=KVConnectorRole.WORKER) config=vllm_config, role=KVConnectorRole.WORKER)
else: else:
raise ValueError("V0 is no longer supported") raise ValueError("V0 is no longer supported")
def ensure_kv_transfer_shutdown() -> None:
global _KV_CONNECTOR_AGENT
if _KV_CONNECTOR_AGENT is not None:
_KV_CONNECTOR_AGENT.shutdown()
_KV_CONNECTOR_AGENT = None
...@@ -29,6 +29,7 @@ import weakref ...@@ -29,6 +29,7 @@ import weakref
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
from unittest.mock import patch from unittest.mock import patch
...@@ -904,6 +905,18 @@ def get_tensor_model_parallel_group(): ...@@ -904,6 +905,18 @@ def get_tensor_model_parallel_group():
return get_tp_group() return get_tp_group()
_DCP: Optional[GroupCoordinator] = None
def get_dcp_group() -> GroupCoordinator:
assert _DCP is not None, (
"decode context model parallel group is not initialized")
return _DCP
# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group
_PP: Optional[GroupCoordinator] = None _PP: Optional[GroupCoordinator] = None
_DP: Optional[GroupCoordinator] = None _DP: Optional[GroupCoordinator] = None
...@@ -939,8 +952,8 @@ def get_pipeline_model_parallel_group(): ...@@ -939,8 +952,8 @@ def get_pipeline_model_parallel_group():
def graph_capture(device: torch.device): def graph_capture(device: torch.device):
""" """
`graph_capture` is a context manager which should surround the code that `graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the is capturing the CUDA graph. Its main purpose is to ensure that some
some operations will be run after the graph is captured, before the graph operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the stream that the graph capture is running on. This stream is set to the
...@@ -966,13 +979,12 @@ def set_custom_all_reduce(enable: bool): ...@@ -966,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable _ENABLE_CUSTOM_ALL_REDUCE = enable
def init_distributed_environment( def init_distributed_environment(world_size: int = -1,
world_size: int = -1, rank: int = -1,
rank: int = -1, distributed_init_method: str = "env://",
distributed_init_method: str = "env://", local_rank: int = -1,
local_rank: int = -1, backend: str = "nccl",
backend: str = "nccl", timeout: Optional[timedelta] = None):
):
logger.debug( logger.debug(
"world_size=%d rank=%d local_rank=%d " "world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s", world_size, rank, local_rank, "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
...@@ -1008,7 +1020,8 @@ def init_distributed_environment( ...@@ -1008,7 +1020,8 @@ def init_distributed_environment(
backend=backend, backend=backend,
init_method=distributed_init_method, init_method=distributed_init_method,
world_size=world_size, world_size=world_size,
rank=rank) rank=rank,
timeout=timeout)
# set the local rank # set the local rank
# local_rank is not available in torch ProcessGroup, # local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816 # see https://github.com/pytorch/pytorch/issues/122816
...@@ -1034,6 +1047,7 @@ def init_distributed_environment( ...@@ -1034,6 +1047,7 @@ def init_distributed_environment(
def initialize_model_parallel( def initialize_model_parallel(
tensor_model_parallel_size: int = 1, tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None, backend: Optional[str] = None,
) -> None: ) -> None:
""" """
...@@ -1098,6 +1112,23 @@ def initialize_model_parallel( ...@@ -1098,6 +1112,23 @@ def initialize_model_parallel(
use_message_queue_broadcaster=True, use_message_queue_broadcaster=True,
group_name="tp") group_name="tp")
# Build the DCP model-parallel groups.
global _DCP
assert _DCP is None, (
"decode context model parallel group is already initialized")
# Note(hc): In the current implementation of decode context parallel,
# dcp_size must not exceed tp_size, because the world size does not
# change by DCP, it simply reuses the GPUs of TP group, and split one
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(
-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
use_message_queue_broadcaster=True,
group_name="dcp")
# Build the pipeline model-parallel groups. # Build the pipeline model-parallel groups.
global _PP global _PP
assert _PP is None, ( assert _PP is None, (
...@@ -1141,6 +1172,7 @@ def initialize_model_parallel( ...@@ -1141,6 +1172,7 @@ def initialize_model_parallel(
def ensure_model_parallel_initialized( def ensure_model_parallel_initialized(
tensor_model_parallel_size: int, tensor_model_parallel_size: int,
pipeline_model_parallel_size: int, pipeline_model_parallel_size: int,
decode_context_model_parallel_size: Optional[int] = 1,
backend: Optional[str] = None, backend: Optional[str] = None,
) -> None: ) -> None:
"""Helper to initialize model parallel groups if they are not initialized, """Helper to initialize model parallel groups if they are not initialized,
...@@ -1151,7 +1183,8 @@ def ensure_model_parallel_initialized( ...@@ -1151,7 +1183,8 @@ def ensure_model_parallel_initialized(
get_world_group().device_group) get_world_group().device_group)
if not model_parallel_is_initialized(): if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size, initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size, backend) pipeline_model_parallel_size,
decode_context_model_parallel_size, backend)
return return
assert ( assert (
...@@ -1226,6 +1259,16 @@ def get_tensor_model_parallel_rank(): ...@@ -1226,6 +1259,16 @@ def get_tensor_model_parallel_rank():
return get_tp_group().rank_in_group return get_tp_group().rank_in_group
def get_decode_context_model_parallel_world_size():
"""Return world size for the decode context model parallel group."""
return get_dcp_group().world_size
def get_decode_context_model_parallel_rank():
"""Return my rank for the decode context model parallel group."""
return get_dcp_group().rank_in_group
def get_node_count() -> int: def get_node_count() -> int:
"""Return the total number of nodes in the distributed environment. """ """Return the total number of nodes in the distributed environment. """
assert _NODE_COUNT is not None, ( assert _NODE_COUNT is not None, (
...@@ -1246,6 +1289,11 @@ def destroy_model_parallel(): ...@@ -1246,6 +1289,11 @@ def destroy_model_parallel():
_PP.destroy() _PP.destroy()
_PP = None _PP = None
global _DCP
if _DCP:
_DCP.destroy()
_DCP = None
global _DP global _DP
if _DP: if _DP:
_DP.destroy() _DP.destroy()
......
...@@ -22,9 +22,9 @@ from typing_extensions import TypeIs, deprecated ...@@ -22,9 +22,9 @@ from typing_extensions import TypeIs, deprecated
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, ConvertOption, ConfigType, ConvertOption, DecodingConfig,
DecodingConfig, DetailedTraceModules, Device, DetailedTraceModules, Device, DeviceConfig,
DeviceConfig, DistributedExecutorBackend, EPLBConfig, DistributedExecutorBackend, EPLBConfig,
GuidedDecodingBackend, HfOverrides, KVEventsConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode, KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
...@@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: ...@@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
elif contains_type(type_hints, int): elif contains_type(type_hints, int):
kwargs[name]["type"] = int kwargs[name]["type"] = int
# Special case for large integers # Special case for large integers
if name in {"max_model_len", "max_num_batched_tokens"}: human_readable_ints = {
"max_model_len",
"max_num_batched_tokens",
"kv_cache_memory_bytes",
}
if name in human_readable_ints:
kwargs[name]["type"] = human_readable_int kwargs[name]["type"] = human_readable_int
kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}"
elif contains_type(type_hints, float): elif contains_type(type_hints, float):
kwargs[name]["type"] = float kwargs[name]["type"] = float
elif (contains_type(type_hints, dict) elif (contains_type(type_hints, dict)
...@@ -289,6 +295,7 @@ class EngineArgs: ...@@ -289,6 +295,7 @@ class EngineArgs:
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format load_format: Union[str, LoadFormats] = LoadConfig.load_format
config_format: str = ModelConfig.config_format config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype dtype: ModelDType = ModelConfig.dtype
...@@ -306,6 +313,8 @@ class EngineArgs: ...@@ -306,6 +313,8 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers # number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = \
ParallelConfig.decode_context_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None data_parallel_rank: Optional[int] = None
data_parallel_start_rank: Optional[int] = None data_parallel_start_rank: Optional[int] = None
...@@ -332,6 +341,7 @@ class EngineArgs: ...@@ -332,6 +341,7 @@ class EngineArgs:
swap_space: float = CacheConfig.swap_space swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes
max_num_batched_tokens: Optional[ max_num_batched_tokens: Optional[
int] = SchedulerConfig.max_num_batched_tokens int] = SchedulerConfig.max_num_batched_tokens
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
...@@ -417,8 +427,6 @@ class EngineArgs: ...@@ -417,8 +427,6 @@ class EngineArgs:
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: dict[str, Any] = \
get_field(ModelConfig, "override_neuron_config")
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
ModelConfig.override_pooler_config ModelConfig.override_pooler_config
compilation_config: CompilationConfig = \ compilation_config: CompilationConfig = \
...@@ -547,7 +555,6 @@ class EngineArgs: ...@@ -547,7 +555,6 @@ class EngineArgs:
help="Disable async output processing. This may result in " help="Disable async output processing. This may result in "
"lower performance.") "lower performance.")
model_group.add_argument("--config-format", model_group.add_argument("--config-format",
choices=[f.value for f in ConfigFormat],
**model_kwargs["config_format"]) **model_kwargs["config_format"])
# This one is a special case because it can bool # This one is a special case because it can bool
# or str. TODO: Handle this in get_kwargs # or str. TODO: Handle this in get_kwargs
...@@ -559,8 +566,6 @@ class EngineArgs: ...@@ -559,8 +566,6 @@ class EngineArgs:
help=model_kwargs["hf_token"]["help"]) help=model_kwargs["hf_token"]["help"])
model_group.add_argument("--hf-overrides", model_group.add_argument("--hf-overrides",
**model_kwargs["hf_overrides"]) **model_kwargs["hf_overrides"])
model_group.add_argument("--override-neuron-config",
**model_kwargs["override_neuron_config"])
model_group.add_argument("--override-pooler-config", model_group.add_argument("--override-pooler-config",
**model_kwargs["override_pooler_config"]) **model_kwargs["override_pooler_config"])
model_group.add_argument("--logits-processor-pattern", model_group.add_argument("--logits-processor-pattern",
...@@ -590,6 +595,8 @@ class EngineArgs: ...@@ -590,6 +595,8 @@ class EngineArgs:
load_group.add_argument("--load-format", **load_kwargs["load_format"]) load_group.add_argument("--load-format", **load_kwargs["load_format"])
load_group.add_argument("--download-dir", load_group.add_argument("--download-dir",
**load_kwargs["download_dir"]) **load_kwargs["download_dir"])
load_group.add_argument("--safetensors-load-strategy",
**load_kwargs["safetensors_load_strategy"])
load_group.add_argument("--model-loader-extra-config", load_group.add_argument("--model-loader-extra-config",
**load_kwargs["model_loader_extra_config"]) **load_kwargs["model_loader_extra_config"])
load_group.add_argument("--ignore-patterns", load_group.add_argument("--ignore-patterns",
...@@ -636,6 +643,9 @@ class EngineArgs: ...@@ -636,6 +643,9 @@ class EngineArgs:
**parallel_kwargs["pipeline_parallel_size"]) **parallel_kwargs["pipeline_parallel_size"])
parallel_group.add_argument("--tensor-parallel-size", "-tp", parallel_group.add_argument("--tensor-parallel-size", "-tp",
**parallel_kwargs["tensor_parallel_size"]) **parallel_kwargs["tensor_parallel_size"])
parallel_group.add_argument(
"--decode-context-parallel-size", "-dcp",
**parallel_kwargs["decode_context_parallel_size"])
parallel_group.add_argument("--data-parallel-size", "-dp", parallel_group.add_argument("--data-parallel-size", "-dp",
**parallel_kwargs["data_parallel_size"]) **parallel_kwargs["data_parallel_size"])
parallel_group.add_argument( parallel_group.add_argument(
...@@ -731,6 +741,8 @@ class EngineArgs: ...@@ -731,6 +741,8 @@ class EngineArgs:
cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) cache_group.add_argument("--block-size", **cache_kwargs["block_size"])
cache_group.add_argument("--gpu-memory-utilization", cache_group.add_argument("--gpu-memory-utilization",
**cache_kwargs["gpu_memory_utilization"]) **cache_kwargs["gpu_memory_utilization"])
cache_group.add_argument("--kv-cache-memory-bytes",
**cache_kwargs["kv_cache_memory_bytes"])
cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"])
cache_group.add_argument("--kv-cache-dtype", cache_group.add_argument("--kv-cache-dtype",
**cache_kwargs["cache_dtype"]) **cache_kwargs["cache_dtype"])
...@@ -987,7 +999,6 @@ class EngineArgs: ...@@ -987,7 +999,6 @@ class EngineArgs:
mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb, mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode, mm_encoder_tp_mode=self.mm_encoder_tp_mode,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config, override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern, logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config, generation_config=self.generation_config,
...@@ -1024,6 +1035,7 @@ class EngineArgs: ...@@ -1024,6 +1035,7 @@ class EngineArgs:
return LoadConfig( return LoadConfig(
load_format=self.load_format, load_format=self.load_format,
download_dir=self.download_dir, download_dir=self.download_dir,
safetensors_load_strategy=self.safetensors_load_strategy,
device="cpu" device="cpu"
if is_online_quantization(self.quantization) else None, if is_online_quantization(self.quantization) else None,
model_loader_extra_config=self.model_loader_extra_config, model_loader_extra_config=self.model_loader_extra_config,
...@@ -1053,9 +1065,10 @@ class EngineArgs: ...@@ -1053,9 +1065,10 @@ class EngineArgs:
SpeculatorsConfig) SpeculatorsConfig)
if self.speculative_config is None: if self.speculative_config is None:
hf_config = get_config(self.hf_config_path or self.model, hf_config = get_config(
self.trust_remote_code, self.revision, self.hf_config_path or target_model_config.model,
self.code_revision, self.config_format) self.trust_remote_code, self.revision, self.code_revision,
self.config_format)
# if loading a SpeculatorsConfig, load the speculative_config # if loading a SpeculatorsConfig, load the speculative_config
# details from the config directly # details from the config directly
...@@ -1065,7 +1078,7 @@ class EngineArgs: ...@@ -1065,7 +1078,7 @@ class EngineArgs:
self.speculative_config = {} self.speculative_config = {}
self.speculative_config[ self.speculative_config[
"num_speculative_tokens"] = hf_config.num_lookahead_tokens "num_speculative_tokens"] = hf_config.num_lookahead_tokens
self.speculative_config["model"] = self.model self.speculative_config["model"] = target_model_config.model
self.speculative_config["method"] = hf_config.method self.speculative_config["method"] = hf_config.method
else: else:
return None return None
...@@ -1156,9 +1169,21 @@ class EngineArgs: ...@@ -1156,9 +1169,21 @@ class EngineArgs:
# global layers in interleaved sliding window models. # global layers in interleaved sliding window models.
sliding_window = model_config.get_sliding_window() sliding_window = model_config.get_sliding_window()
# Note(hc): In the current implementation of decode context
# parallel(DCP), tp_size needs to be divisible by dcp_size,
# because the world size does not change by dcp, it simply
# reuses the GPUs of TP group, and split one TP group into
# tp_size//dcp_size DCP groups.
assert self.tensor_parallel_size % self.decode_context_parallel_size \
== 0, (
f"tp_size={self.tensor_parallel_size} must be divisible by"
f"dcp_size={self.decode_context_parallel_size}."
)
cache_config = CacheConfig( cache_config = CacheConfig(
block_size=self.block_size, block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization, gpu_memory_utilization=self.gpu_memory_utilization,
kv_cache_memory_bytes=self.kv_cache_memory_bytes,
swap_space=self.swap_space, swap_space=self.swap_space,
cache_dtype=self.kv_cache_dtype, cache_dtype=self.kv_cache_dtype,
is_attention_free=model_config.is_attention_free, is_attention_free=model_config.is_attention_free,
...@@ -1306,6 +1331,7 @@ class EngineArgs: ...@@ -1306,6 +1331,7 @@ class EngineArgs:
distributed_executor_backend=self.distributed_executor_backend, distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls, worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls, worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
) )
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
...@@ -1436,17 +1462,6 @@ class EngineArgs: ...@@ -1436,17 +1462,6 @@ class EngineArgs:
recommend_to_remove=True) recommend_to_remove=True)
return False return False
# Triton v3.3 has f16 conversion regression issue on Turing and Volta,
# which broke fp16 inference
# see: https://github.com/triton-lang/triton/issues/6698
if (current_platform.is_cuda()
and not current_platform.has_device_capability(80)
and model_config.dtype == torch.float16):
_raise_or_fallback(
feature_name="Compute Capability < 8.0 with FP16",
recommend_to_remove=False)
return False
if self.kv_cache_dtype != "auto": if self.kv_cache_dtype != "auto":
supported = current_platform.is_kv_cache_dtype_supported( supported = current_platform.is_kv_cache_dtype_supported(
self.kv_cache_dtype, model_config) self.kv_cache_dtype, model_config)
...@@ -1476,12 +1491,6 @@ class EngineArgs: ...@@ -1476,12 +1491,6 @@ class EngineArgs:
recommend_to_remove=False) recommend_to_remove=False)
return False return False
# No OTLP observability so far.
if (self.otlp_traces_endpoint or self.collect_detailed_traces):
_raise_or_fallback(feature_name="--otlp-traces-endpoint",
recommend_to_remove=False)
return False
# V1 supports N-gram, Medusa, and Eagle speculative decoding. # V1 supports N-gram, Medusa, and Eagle speculative decoding.
if (self.speculative_config is not None if (self.speculative_config is not None
and self.speculative_config.get("method") == "draft_model"): and self.speculative_config.get("method") == "draft_model"):
...@@ -1499,8 +1508,11 @@ class EngineArgs: ...@@ -1499,8 +1508,11 @@ class EngineArgs:
"TRITON_MLA", "TRITON_MLA",
"CUTLASS_MLA", "CUTLASS_MLA",
"FLASHMLA", "FLASHMLA",
"FLASHMLA_VLLM_V1",
"FLASH_ATTN_MLA",
"FLASHINFER", "FLASHINFER",
"FLASHINFER_VLLM_V1", "FLASHINFER_VLLM_V1",
"FLASHINFER_MLA",
"ROCM_AITER_MLA", "ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1", "TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION", "FLEX_ATTENTION",
...@@ -1589,20 +1601,12 @@ class EngineArgs: ...@@ -1589,20 +1601,12 @@ class EngineArgs:
"in low performance due to small KV cache size. Consider " "in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value.", max_model_len) "setting --max-model-len to a smaller value.", max_model_len)
# if using prefix caching, we must set a hash algo # Disable prefix caching for multimodal models for VLLM_V0.
if self.enable_prefix_caching: if self.enable_prefix_caching and model_config.is_multimodal_model:
# Disable prefix caching for multimodal models for VLLM_V0. logger.warning(
if model_config.is_multimodal_model: "--enable-prefix-caching is not supported for multimodal "
logger.warning( "models in V0 and has been disabled.")
"--enable-prefix-caching is not supported for multimodal " self.enable_prefix_caching = False
"models in V0 and has been disabled.")
self.enable_prefix_caching = False
# VLLM_V0 only supports builtin hash algo for prefix caching.
if self.prefix_caching_hash_algo == "sha256":
raise ValueError(
"sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'.")
# Set max_num_seqs to 256 for VLLM_V0. # Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None: if self.max_num_seqs is None:
......
...@@ -10,8 +10,9 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, ...@@ -10,8 +10,9 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
from weakref import ReferenceType from weakref import ReferenceType
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, ModelConfig, ParallelConfig,
ParallelConfig, SchedulerConfig, VllmConfig) SchedulerConfig, VllmConfig)
from vllm.config.lora import LoRAConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.async_timeout import asyncio_timeout
...@@ -717,7 +718,7 @@ class AsyncLLMEngine(EngineClient): ...@@ -717,7 +718,7 @@ class AsyncLLMEngine(EngineClient):
# Stop the execute model loop in parallel workers until there # Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting # are more requests to process. This avoids waiting
# indefinitely in torch.distributed ops which may otherwise # indefinitely in torch.distributed ops which may otherwise
# timeout, and unblocks the RPC thread in the workers so that # time out, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages, # they can process any other queued control plane messages,
# such as add/remove lora adapters. # such as add/remove lora adapters.
await engine.engine.stop_remote_worker_execution_loop_async() await engine.engine.stop_remote_worker_execution_loop_async()
......
...@@ -16,9 +16,9 @@ import torch ...@@ -16,9 +16,9 @@ import torch
from typing_extensions import TypeVar from typing_extensions import TypeVar
import vllm.envs as envs import vllm.envs as envs
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, from vllm.config import (DecodingConfig, ModelConfig, ObservabilityConfig,
ObservabilityConfig, ParallelConfig, SchedulerConfig, ParallelConfig, SchedulerConfig, VllmConfig)
VllmConfig) from vllm.config.lora import LoRAConfig
from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs from vllm.core.scheduler import ScheduledSequenceGroup, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics_types import StatLoggerBase, Stats from vllm.engine.metrics_types import StatLoggerBase, Stats
...@@ -278,7 +278,8 @@ class LLMEngine: ...@@ -278,7 +278,8 @@ class LLMEngine:
self.cache_config.block_size, self.cache_config.block_size,
"gpu_memory_utilization": "gpu_memory_utilization":
self.cache_config.gpu_memory_utilization, self.cache_config.gpu_memory_utilization,
"kv_cache_memory_bytes":
self.cache_config.kv_cache_memory_bytes,
# Quantization # Quantization
"quantization": "quantization":
self.model_config.quantization, self.model_config.quantization,
...@@ -1414,7 +1415,7 @@ class LLMEngine: ...@@ -1414,7 +1415,7 @@ class LLMEngine:
num_generation_tokens_iter = 0 num_generation_tokens_iter = 0
num_tokens_iter = 0 num_tokens_iter = 0
time_to_first_tokens_iter: List[float] = [] time_to_first_tokens_iter: List[float] = []
time_per_output_tokens_iter: List[float] = [] inter_token_latencies_iter: List[float] = []
num_preemption_iter = (0 if scheduler_outputs is None else num_preemption_iter = (0 if scheduler_outputs is None else
scheduler_outputs.preempted) scheduler_outputs.preempted)
...@@ -1498,9 +1499,9 @@ class LLMEngine: ...@@ -1498,9 +1499,9 @@ class LLMEngine:
num_generation_tokens_from_prefill_groups += ( num_generation_tokens_from_prefill_groups += (
seq_group.num_seqs()) seq_group.num_seqs())
else: else:
# TPOTs. # ITLs
latency = seq_group.get_last_token_latency() latency = seq_group.get_last_token_latency()
time_per_output_tokens_iter.append(latency) inter_token_latencies_iter.append(latency)
if seq_group.state.current_step == 0: if seq_group.state.current_step == 0:
# For async_output_proc, the do_log_stats() # For async_output_proc, the do_log_stats()
# is called following init_multi_step(), which # is called following init_multi_step(), which
...@@ -1582,7 +1583,7 @@ class LLMEngine: ...@@ -1582,7 +1583,7 @@ class LLMEngine:
num_generation_tokens_iter=num_generation_tokens_iter, num_generation_tokens_iter=num_generation_tokens_iter,
num_tokens_iter=num_tokens_iter, num_tokens_iter=num_tokens_iter,
time_to_first_tokens_iter=time_to_first_tokens_iter, time_to_first_tokens_iter=time_to_first_tokens_iter,
time_per_output_tokens_iter=time_per_output_tokens_iter, inter_token_latencies_iter=inter_token_latencies_iter,
num_preemption_iter=num_preemption_iter, num_preemption_iter=num_preemption_iter,
# Request stats # Request stats
......
...@@ -113,9 +113,21 @@ class Metrics: ...@@ -113,9 +113,21 @@ class Metrics:
0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0,
2560.0 2560.0
]) ])
# Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds
# TODO: in 0.12, only enable if show_hidden_metrics=True
self.histogram_time_per_output_token = self._histogram_cls( self.histogram_time_per_output_token = self._histogram_cls(
name="vllm:time_per_output_token_seconds", name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.", documentation=(
"Histogram of time per output token in seconds."
"DEPRECATED: Use vllm:inter_token_latency_seconds instead."),
labelnames=labelnames,
buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0
])
self.histogram_inter_token_latency = self._histogram_cls(
name="vllm:inter_token_latency_seconds",
documentation="Histogram of inter token latency in seconds.",
labelnames=labelnames, labelnames=labelnames,
buckets=[ buckets=[
0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
...@@ -491,7 +503,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -491,7 +503,9 @@ class PrometheusStatLogger(StatLoggerBase):
self._log_histogram(self.metrics.histogram_time_to_first_token, self._log_histogram(self.metrics.histogram_time_to_first_token,
stats.time_to_first_tokens_iter) stats.time_to_first_tokens_iter)
self._log_histogram(self.metrics.histogram_time_per_output_token, self._log_histogram(self.metrics.histogram_time_per_output_token,
stats.time_per_output_tokens_iter) stats.inter_token_latencies_iter)
self._log_histogram(self.metrics.histogram_inter_token_latency,
stats.inter_token_latencies_iter)
# Request level data # Request level data
# Latency # Latency
......
...@@ -43,7 +43,7 @@ class Stats: ...@@ -43,7 +43,7 @@ class Stats:
num_generation_tokens_iter: int num_generation_tokens_iter: int
num_tokens_iter: int num_tokens_iter: int
time_to_first_tokens_iter: List[float] time_to_first_tokens_iter: List[float]
time_per_output_tokens_iter: List[float] inter_token_latencies_iter: List[float]
num_preemption_iter: int num_preemption_iter: int
# Request stats (should have _requests suffix) # Request stats (should have _requests suffix)
......
...@@ -235,7 +235,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -235,7 +235,7 @@ class MQLLMEngineClient(EngineClient):
# therefore we have to inform that the current # therefore we have to inform that the current
# processed requests failed as well. Send back a dead # processed requests failed as well. Send back a dead
# engine error give this feedback and also give a # engine error give this feedback and also give a
# 'hint' to the server to shutdown next. # 'hint' to the server to shut down next.
exception = self.dead_error exception = self.dead_error
if request_id is None: if request_id is None:
...@@ -270,7 +270,7 @@ class MQLLMEngineClient(EngineClient): ...@@ -270,7 +270,7 @@ class MQLLMEngineClient(EngineClient):
queue.put_nowait(request_output) queue.put_nowait(request_output)
async def setup(self): async def setup(self):
"""Setup the client before it starts sending server requests.""" """Set up the client before it starts sending server requests."""
# Start output_loop # Start output_loop
if self.output_loop is None: if self.output_loop is None:
......
...@@ -49,7 +49,7 @@ class MQLLMEngine: ...@@ -49,7 +49,7 @@ class MQLLMEngine:
This class is used to wrap the This class is used to wrap the
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
in concurrnet manner. It runs a background loop and uses zeromq to in concurrent manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc. receive new requests and stream outputs incrementally via ipc.
The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
......
...@@ -78,6 +78,7 @@ class EngineClient(ABC): ...@@ -78,6 +78,7 @@ class EngineClient(ABC):
preprocessor = await self.get_input_preprocessor() preprocessor = await self.get_input_preprocessor()
tokenizer_group = preprocessor.get_tokenizer_group() tokenizer_group = preprocessor.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async() tokenizer = await tokenizer_group.get_lora_tokenizer_async()
eos_token_id = tokenizer.eos_token_id
if is_explicit_encoder_decoder_prompt(prompt): if is_explicit_encoder_decoder_prompt(prompt):
raise NotImplementedError raise NotImplementedError
...@@ -104,7 +105,7 @@ class EngineClient(ABC): ...@@ -104,7 +105,7 @@ class EngineClient(ABC):
tokenized_length = len(prompt_token_ids) tokenized_length = len(prompt_token_ids)
sort_beams_key = create_sort_beams_key_function( sort_beams_key = create_sort_beams_key_function(
tokenizer.eos_token_id, length_penalty) eos_token_id, length_penalty)
beam_search_params = SamplingParams( beam_search_params = SamplingParams(
logprobs=2 * beam_width, logprobs=2 * beam_width,
...@@ -154,7 +155,7 @@ class EngineClient(ABC): ...@@ -154,7 +155,7 @@ class EngineClient(ABC):
if result.outputs[0].logprobs is not None: if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0] logprobs = result.outputs[0].logprobs[0]
for token_id, logprob_obj in logprobs.items(): for token_id, logprob_obj in logprobs.items():
if token_id == tokenizer.eos_token_id and \ if token_id == eos_token_id and \
not ignore_eos: not ignore_eos:
completed.append( completed.append(
BeamSearchSequence( BeamSearchSequence(
...@@ -166,7 +167,7 @@ class EngineClient(ABC): ...@@ -166,7 +167,7 @@ class EngineClient(ABC):
cum_logprob=current_beam.cum_logprob + cum_logprob=current_beam.cum_logprob +
logprob_obj.logprob, logprob_obj.logprob,
finish_reason="stop", finish_reason="stop",
stop_reason=tokenizer.eos_token_id)) stop_reason=eos_token_id))
else: else:
new_beams.append( new_beams.append(
BeamSearchSequence( BeamSearchSequence(
...@@ -189,14 +190,14 @@ class EngineClient(ABC): ...@@ -189,14 +190,14 @@ class EngineClient(ABC):
best_beams = sorted_completed[:beam_width] best_beams = sorted_completed[:beam_width]
for beam in best_beams: for beam in best_beams:
if (beam.tokens[-1] == tokenizer.eos_token_id and not ignore_eos): if (beam.tokens[-1] == eos_token_id and not ignore_eos):
# Skip the eos token in the text. # Skip the eos token in the text.
tokens = beam.tokens[tokenized_length:-1] tokens = beam.tokens[tokenized_length:-1]
else: else:
tokens = beam.tokens[tokenized_length:] tokens = beam.tokens[tokenized_length:]
beam.text = tokenizer.decode(tokens) beam.text = tokenizer.decode(tokens)
beam_search_output = RequestOutput( yield RequestOutput(
request_id=request_id, request_id=request_id,
prompt=prompt_text, prompt=prompt_text,
outputs=[ outputs=[
...@@ -214,8 +215,6 @@ class EngineClient(ABC): ...@@ -214,8 +215,6 @@ class EngineClient(ABC):
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
prompt_logprobs=None) prompt_logprobs=None)
yield beam_search_output
@abstractmethod @abstractmethod
def encode( def encode(
self, self,
......
...@@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict ...@@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalUUIDDict)
from vllm.multimodal.utils import MediaConnector from vllm.multimodal.utils import MediaConnector
# yapf: disable # yapf: disable
from vllm.transformers_utils.chat_templates import ( from vllm.transformers_utils.chat_templates import (
...@@ -72,6 +73,11 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False): ...@@ -72,6 +73,11 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
type: Required[Literal["audio_url"]] type: Required[Literal["audio_url"]]
"""The type of the content part.""" """The type of the content part."""
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
...@@ -83,6 +89,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False): ...@@ -83,6 +89,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
""" """
type: Required[Literal["image_embeds"]] type: Required[Literal["image_embeds"]]
"""The type of the content part.""" """The type of the content part."""
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class VideoURL(TypedDict, total=False): class VideoURL(TypedDict, total=False):
...@@ -97,12 +108,18 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False): ...@@ -97,12 +108,18 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
type: Required[Literal["video_url"]] type: Required[Literal["video_url"]]
"""The type of the content part.""" """The type of the content part."""
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class PILImage(BaseModel): class PILImage(BaseModel):
""" """
A PIL.Image.Image object. A PIL.Image.Image object.
""" """
image_pil: Image.Image image_pil: Image.Image
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
...@@ -115,7 +132,13 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False): ...@@ -115,7 +132,13 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
"image_pil": ImageAsset('cherry_blossom').pil_image "image_pil": ImageAsset('cherry_blossom').pil_image
} }
""" """
image_pil: Required[PILImage] image_pil: Required[PILImage]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
...@@ -127,7 +150,13 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): ...@@ -127,7 +150,13 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"image_url": "https://example.com/image.jpg" "image_url": "https://example.com/image.jpg"
} }
""" """
image_url: Required[str] image_url: Required[str]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
...@@ -138,6 +167,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): ...@@ -138,6 +167,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"audio_url": "https://example.com/audio.mp3" "audio_url": "https://example.com/audio.mp3"
} }
""" """
audio_url: Required[str] audio_url: Required[str]
...@@ -149,7 +179,13 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): ...@@ -149,7 +179,13 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"video_url": "https://example.com/video.mp4" "video_url": "https://example.com/video.mp4"
} }
""" """
video_url: Required[str] video_url: Required[str]
uuid: Optional[str]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class CustomThinkCompletionContentParam(TypedDict, total=False): class CustomThinkCompletionContentParam(TypedDict, total=False):
...@@ -174,19 +210,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): ...@@ -174,19 +210,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
ChatCompletionContentPartParam: TypeAlias = Union[ ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam, ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam, ChatCompletionContentPartVideoParam,
ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentPILImageParam, CustomChatCompletionContentPILImageParam,
CustomChatCompletionContentSimpleImageParam, CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam, ChatCompletionContentPartImageEmbedsParam,
CustomChatCompletionContentSimpleAudioParam, CustomChatCompletionContentSimpleAudioParam,
CustomChatCompletionContentSimpleVideoParam, str, CustomChatCompletionContentSimpleVideoParam,
CustomThinkCompletionContentParam] str,
CustomThinkCompletionContentParam,
]
class CustomChatCompletionMessageParam(TypedDict, total=False): class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API.""" """Enables custom roles in the Chat Completion API."""
role: Required[str] role: Required[str]
"""The role of the message's author.""" """The role of the message's author."""
...@@ -207,9 +248,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): ...@@ -207,9 +248,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls.""" """The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, ChatCompletionMessageParam = Union[
CustomChatCompletionMessageParam, OpenAIChatCompletionMessageParam,
OpenAIHarmonyMessage] CustomChatCompletionMessageParam,
OpenAIHarmonyMessage,
]
# TODO: Make fields ReadOnly once mypy supports it # TODO: Make fields ReadOnly once mypy supports it
...@@ -262,13 +305,13 @@ def _is_var_or_elems_access( ...@@ -262,13 +305,13 @@ def _is_var_or_elems_access(
key: Optional[str] = None, key: Optional[str] = None,
) -> bool: ) -> bool:
if isinstance(node, jinja2.nodes.Filter): if isinstance(node, jinja2.nodes.Filter):
return (node.node is not None return node.node is not None and _is_var_or_elems_access(
and _is_var_or_elems_access(node.node, varname, key)) node.node, varname, key)
if isinstance(node, jinja2.nodes.Test): if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key) return _is_var_or_elems_access(node.node, varname, key)
if (isinstance(node, jinja2.nodes.Getitem) if isinstance(node, jinja2.nodes.Getitem) and isinstance(
and isinstance(node.arg, jinja2.nodes.Slice)): node.arg, jinja2.nodes.Slice):
return _is_var_or_elems_access(node.node, varname, key) return _is_var_or_elems_access(node.node, varname, key)
# yapf: disable # yapf: disable
...@@ -373,15 +416,18 @@ def resolve_mistral_chat_template( ...@@ -373,15 +416,18 @@ def resolve_mistral_chat_template(
) -> Optional[str]: ) -> Optional[str]:
if chat_template is not None: if chat_template is not None:
logger.warning_once( logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer.") "'chat_template' cannot be overridden for mistral tokenizer."
)
if "add_generation_prompt" in kwargs: if "add_generation_prompt" in kwargs:
logger.warning_once( logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, " "'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored.") "so it will be ignored."
)
if "continue_final_message" in kwargs: if "continue_final_message" in kwargs:
logger.warning_once( logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, " "'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored.") "so it will be ignored."
)
return None return None
...@@ -401,23 +447,35 @@ def resolve_hf_chat_template( ...@@ -401,23 +447,35 @@ def resolve_hf_chat_template(
try: try:
processor = cached_get_processor( processor = cached_get_processor(
tokenizer.name_or_path, tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, processor_cls=(
ProcessorMixin), PreTrainedTokenizer,
PreTrainedTokenizerFast,
ProcessorMixin,
),
trust_remote_code=model_config.trust_remote_code, trust_remote_code=model_config.trust_remote_code,
) )
if isinstance(processor, ProcessorMixin) and \ if (
hasattr(processor, 'chat_template') and \ isinstance(processor, ProcessorMixin)
processor.chat_template is not None: and hasattr(processor, "chat_template")
and processor.chat_template is not None
):
return processor.chat_template return processor.chat_template
except Exception: except Exception:
logger.debug("Failed to load AutoProcessor chat template for %s", tokenizer.name_or_path, exc_info=True) # noqa: E501 logger.debug(
"Failed to load AutoProcessor chat template for %s",
tokenizer.name_or_path,
exc_info=True,
) # noqa: E501
# 3rd priority: AutoTokenizer chat template # 3rd priority: AutoTokenizer chat template
try: try:
return tokenizer.get_chat_template(chat_template, tools=tools) return tokenizer.get_chat_template(chat_template, tools=tools)
except Exception: except Exception:
logger.debug("Failed to load AutoTokenizer chat template for %s", logger.debug(
tokenizer.name_or_path, exc_info=True) "Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path,
exc_info=True,
)
# 4th priority: Predefined fallbacks # 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path( path = get_chat_template_fallback_path(
...@@ -425,12 +483,16 @@ def resolve_hf_chat_template( ...@@ -425,12 +483,16 @@ def resolve_hf_chat_template(
tokenizer_name_or_path=model_config.tokenizer, tokenizer_name_or_path=model_config.tokenizer,
) )
if path is not None: if path is not None:
logger.info("Loading chat template fallback for %s as there isn't one " logger.info(
"defined on HF Hub.", tokenizer.name_or_path) "Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.",
tokenizer.name_or_path,
)
chat_template = load_chat_template(path) chat_template = load_chat_template(path)
else: else:
logger.debug("There is no chat template fallback for %s", logger.debug(
tokenizer.name_or_path) "There is no chat template fallback for %s", tokenizer.name_or_path
)
return chat_template return chat_template
...@@ -452,11 +514,17 @@ def _resolve_chat_template_content_format( ...@@ -452,11 +514,17 @@ def _resolve_chat_template_content_format(
else: else:
hf_chat_template = None hf_chat_template = None
jinja_text = (hf_chat_template if isinstance(hf_chat_template, str) jinja_text = (
else load_chat_template(chat_template, is_literal=True)) hf_chat_template
if isinstance(hf_chat_template, str)
else load_chat_template(chat_template, is_literal=True)
)
detected_format = ("string" if jinja_text is None else detected_format = (
_detect_content_format(jinja_text, default="string")) "string"
if jinja_text is None
else _detect_content_format(jinja_text, default="string")
)
return detected_format return detected_format
...@@ -512,7 +580,6 @@ def resolve_chat_template_content_format( ...@@ -512,7 +580,6 @@ def resolve_chat_template_content_format(
return detected_format return detected_format
ModalityStr = Literal["image", "audio", "video", "image_embeds"] ModalityStr = Literal["image", "audio", "video", "image_embeds"]
_T = TypeVar("_T") _T = TypeVar("_T")
...@@ -531,6 +598,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -531,6 +598,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._items_by_modality = defaultdict[str, list[_T]](list) self._items_by_modality = defaultdict[str, list[_T]](list)
self._uuids_by_modality = defaultdict[str, list[Optional[str]]](list)
@property @property
def model_config(self) -> ModelConfig: def model_config(self) -> ModelConfig:
...@@ -539,6 +607,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -539,6 +607,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
@cached_property @cached_property
def model_cls(self) -> type[SupportsMultiModal]: def model_cls(self) -> type[SupportsMultiModal]:
from vllm.model_executor.model_loader import get_model_cls from vllm.model_executor.model_loader import get_model_cls
model_cls = get_model_cls(self.model_config) model_cls = get_model_cls(self.model_config)
return cast(type[SupportsMultiModal], model_cls) return cast(type[SupportsMultiModal], model_cls)
...@@ -554,10 +623,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -554,10 +623,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def mm_processor(self): def mm_processor(self):
return self.mm_registry.create_processor(self.model_config) return self.mm_registry.create_processor(self.model_config)
def add(self, modality: ModalityStr, item: _T) -> Optional[str]: def add(
self, modality: ModalityStr, item: _T, uuid: Optional[str] = None
) -> Optional[str]:
""" """
Add a multi-modal item to the current prompt and returns the Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any. placeholder string to use, if any.
An optional uuid can be added which serves as a unique identifier of the
media.
""" """
input_modality = modality.replace("_embeds", "") input_modality = modality.replace("_embeds", "")
num_items = len(self._items_by_modality[modality]) + 1 num_items = len(self._items_by_modality[modality]) + 1
...@@ -565,37 +639,64 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]): ...@@ -565,37 +639,64 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self.mm_processor.validate_num_items(input_modality, num_items) self.mm_processor.validate_num_items(input_modality, num_items)
self._items_by_modality[modality].append(item) self._items_by_modality[modality].append(item)
self._uuids_by_modality[modality].append(uuid)
return self.model_cls.get_placeholder_str(modality, num_items) return self.model_cls.get_placeholder_str(modality, num_items)
def all_mm_uuids(self) -> Optional[MultiModalUUIDDict]:
if not self._items_by_modality:
return None
mm_uuids = {}
uuids_by_modality = dict(self._uuids_by_modality)
if "image" in uuids_by_modality and "image_embeds" in uuids_by_modality:
raise ValueError(
"Mixing raw image and embedding inputs is not allowed"
)
if "image_embeds" in uuids_by_modality:
image_embeds_uuids = uuids_by_modality["image_embeds"]
if len(image_embeds_uuids) > 1:
raise ValueError(
"Only one message can have {'type': 'image_embeds'}"
)
mm_uuids["image"] = uuids_by_modality["image_embeds"]
if "image" in uuids_by_modality:
mm_uuids["image"] = uuids_by_modality["image"] # UUIDs of images
if "audio" in uuids_by_modality:
mm_uuids["audio"] = uuids_by_modality["audio"] # UUIDs of audios
if "video" in uuids_by_modality:
mm_uuids["video"] = uuids_by_modality["video"] # UUIDs of videos
return mm_uuids
@abstractmethod @abstractmethod
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
raise NotImplementedError raise NotImplementedError
class MultiModalItemTracker(BaseMultiModalItemTracker[object]): class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def all_mm_data(self) -> Optional[MultiModalDataDict]: def all_mm_data(self) -> Optional[MultiModalDataDict]:
if not self._items_by_modality: if not self._items_by_modality:
return None return None
mm_inputs = {} mm_inputs = {}
items_by_modality = dict(self._items_by_modality) items_by_modality = dict(self._items_by_modality)
if "image" in items_by_modality and "image_embeds" in items_by_modality: if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError(\ raise ValueError(
"Mixing raw image and embedding inputs is not allowed") "Mixing raw image and embedding inputs is not allowed"
)
if "image_embeds" in items_by_modality: if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"] image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1: if len(image_embeds_lst) > 1:
raise ValueError(\ raise ValueError(
"Only one message can have {'type': 'image_embeds'}") "Only one message can have {'type': 'image_embeds'}"
)
mm_inputs["image"] = image_embeds_lst[0] mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
...@@ -603,32 +704,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]): ...@@ -603,32 +704,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async def all_mm_data(self) -> Optional[MultiModalDataDict]: async def all_mm_data(self) -> Optional[MultiModalDataDict]:
if not self._items_by_modality: if not self._items_by_modality:
return None return None
mm_inputs = {} mm_inputs = {}
items_by_modality = { items_by_modality = {
modality: await asyncio.gather(*items) modality: await asyncio.gather(*items)
for modality, items in self._items_by_modality.items() for modality, items in self._items_by_modality.items()
} }
if "image" in items_by_modality and "image_embeds" in items_by_modality: if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError( raise ValueError(
"Mixing raw image and embedding inputs is not allowed") "Mixing raw image and embedding inputs is not allowed"
)
if "image_embeds" in items_by_modality: if "image_embeds" in items_by_modality:
image_embeds_lst = items_by_modality["image_embeds"] image_embeds_lst = items_by_modality["image_embeds"]
if len(image_embeds_lst) > 1: if len(image_embeds_lst) > 1:
raise ValueError( raise ValueError(
"Only one message can have {'type': 'image_embeds'}") "Only one message can have {'type': 'image_embeds'}"
)
mm_inputs["image"] = image_embeds_lst[0] mm_inputs["image"] = image_embeds_lst[0]
if "image" in items_by_modality: if "image" in items_by_modality:
mm_inputs["image"] = items_by_modality["image"] # A list of images mm_inputs["image"] = items_by_modality["image"] # A list of images
if "audio" in items_by_modality: if "audio" in items_by_modality:
mm_inputs["audio"] = items_by_modality["audio"] # A list of audios mm_inputs["audio"] = items_by_modality["audio"] # A list of audios
if "video" in items_by_modality: if "video" in items_by_modality:
mm_inputs["video"] = items_by_modality["video"] # A list of videos mm_inputs["video"] = items_by_modality["video"] # A list of videos
return mm_inputs return mm_inputs
def create_parser(self) -> "BaseMultiModalContentParser": def create_parser(self) -> "BaseMultiModalContentParser":
...@@ -636,7 +738,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]): ...@@ -636,7 +738,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
class BaseMultiModalContentParser(ABC): class BaseMultiModalContentParser(ABC):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -648,8 +749,9 @@ class BaseMultiModalContentParser(ABC): ...@@ -648,8 +749,9 @@ class BaseMultiModalContentParser(ABC):
# } # }
self._placeholder_storage: dict[str, list] = defaultdict(list) self._placeholder_storage: dict[str, list] = defaultdict(list)
def _add_placeholder(self, modality: ModalityStr, def _add_placeholder(
placeholder: Optional[str]): self, modality: ModalityStr, placeholder: Optional[str]
):
mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality]
if placeholder: if placeholder:
self._placeholder_storage[mod_placeholder].append(placeholder) self._placeholder_storage[mod_placeholder].append(placeholder)
...@@ -658,33 +760,39 @@ class BaseMultiModalContentParser(ABC): ...@@ -658,33 +760,39 @@ class BaseMultiModalContentParser(ABC):
return dict(self._placeholder_storage) return dict(self._placeholder_storage)
@abstractmethod @abstractmethod
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_image_embeds(self, def parse_image_embeds(
image_embeds: Union[str, dict[str, str]]) -> None: self,
image_embeds: Union[str, dict[str, str]],
uuid: Optional[str] = None,
) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_image_pil(self, image_pil: Image.Image) -> None: def parse_image_pil(
self, image_pil: Image.Image, uuid: Optional[str] = None
) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_input_audio(self, input_audio: InputAudio) -> None: def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
) -> None:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
raise NotImplementedError raise NotImplementedError
class MultiModalContentParser(BaseMultiModalContentParser): class MultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: MultiModalItemTracker) -> None: def __init__(self, tracker: MultiModalItemTracker) -> None:
super().__init__() super().__init__()
...@@ -695,70 +803,79 @@ class MultiModalContentParser(BaseMultiModalContentParser): ...@@ -695,70 +803,79 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path=tracker.allowed_local_media_path, allowed_local_media_path=tracker.allowed_local_media_path,
) )
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
image = self._connector.fetch_image(image_url) image = self._connector.fetch_image(image_url)
placeholder = self._tracker.add("image", image) placeholder = self._tracker.add("image", image, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_embeds(self, def parse_image_embeds(
image_embeds: Union[str, dict[str, str]]) -> None: self,
image_embeds: Union[str, dict[str, str]],
uuid: Optional[str] = None,
) -> None:
if isinstance(image_embeds, dict): if isinstance(image_embeds, dict):
embeds = { embeds = {
k: self._connector.fetch_image_embedding(v) k: self._connector.fetch_image_embedding(v)
for k, v in image_embeds.items() for k, v in image_embeds.items()
} }
placeholder = self._tracker.add("image_embeds", embeds) placeholder = self._tracker.add("image_embeds", embeds, uuid)
if isinstance(image_embeds, str): if isinstance(image_embeds, str):
embedding = self._connector.fetch_image_embedding(image_embeds) embedding = self._connector.fetch_image_embedding(image_embeds)
placeholder = self._tracker.add("image_embeds", embedding) placeholder = self._tracker.add("image_embeds", embedding, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None: def parse_image_pil(
placeholder = self._tracker.add("image", image_pil) self, image_pil: Image.Image, uuid: Optional[str] = None
) -> None:
placeholder = self._tracker.add("image", image_pil, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
audio = self._connector.fetch_audio(audio_url) audio = self._connector.fetch_audio(audio_url)
placeholder = self._tracker.add("audio", audio) placeholder = self._tracker.add("audio", audio, uuid)
self._add_placeholder("audio", placeholder) self._add_placeholder("audio", placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None: def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
) -> None:
audio_data = input_audio.get("data", "") audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
return self.parse_audio(audio_url) return self.parse_audio(audio_url, uuid)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
video = self._connector.fetch_video(video_url=video_url) video = self._connector.fetch_video(video_url=video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video, uuid)
self._add_placeholder("video", placeholder) self._add_placeholder("video", placeholder)
class AsyncMultiModalContentParser(BaseMultiModalContentParser): class AsyncMultiModalContentParser(BaseMultiModalContentParser):
def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
super().__init__() super().__init__()
self._tracker = tracker self._tracker = tracker
self._connector = MediaConnector( self._connector = MediaConnector(
media_io_kwargs=self._tracker._model_config.media_io_kwargs, media_io_kwargs=self._tracker._model_config.media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path allowed_local_media_path=tracker.allowed_local_media_path,
) )
def parse_image(self, image_url: str) -> None: def parse_image(self, image_url: str, uuid: Optional[str] = None) -> None:
image_coro = self._connector.fetch_image_async(image_url) image_coro = self._connector.fetch_image_async(image_url)
placeholder = self._tracker.add("image", image_coro) placeholder = self._tracker.add("image", image_coro, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_embeds(self, def parse_image_embeds(
image_embeds: Union[str, dict[str, str]]) -> None: self,
image_embeds: Union[str, dict[str, str]],
uuid: Optional[str] = None,
) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future() future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
if isinstance(image_embeds, dict): if isinstance(image_embeds, dict):
...@@ -769,37 +886,40 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): ...@@ -769,37 +886,40 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
future.set_result(embeds) future.set_result(embeds)
if isinstance(image_embeds, str): if isinstance(image_embeds, str):
embedding = self._connector.\ embedding = self._connector.fetch_image_embedding(image_embeds)
fetch_image_embedding(image_embeds)
future.set_result(embedding) future.set_result(embedding)
placeholder = self._tracker.add("image_embeds", future) placeholder = self._tracker.add("image_embeds", future, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_image_pil(self, image_pil: Image.Image) -> None: def parse_image_pil(
self, image_pil: Image.Image, uuid: Optional[str] = None
) -> None:
future: asyncio.Future[Image.Image] = asyncio.Future() future: asyncio.Future[Image.Image] = asyncio.Future()
future.set_result(image_pil) future.set_result(image_pil)
placeholder = self._tracker.add("image", future) placeholder = self._tracker.add("image", future, uuid)
self._add_placeholder("image", placeholder) self._add_placeholder("image", placeholder)
def parse_audio(self, audio_url: str) -> None: def parse_audio(self, audio_url: str, uuid: Optional[str] = None) -> None:
audio_coro = self._connector.fetch_audio_async(audio_url) audio_coro = self._connector.fetch_audio_async(audio_url)
placeholder = self._tracker.add("audio", audio_coro) placeholder = self._tracker.add("audio", audio_coro, uuid)
self._add_placeholder("audio", placeholder) self._add_placeholder("audio", placeholder)
def parse_input_audio(self, input_audio: InputAudio) -> None: def parse_input_audio(
self, input_audio: InputAudio, uuid: Optional[str] = None
) -> None:
audio_data = input_audio.get("data", "") audio_data = input_audio.get("data", "")
audio_format = input_audio.get("format", "") audio_format = input_audio.get("format", "")
audio_url = f"data:audio/{audio_format};base64,{audio_data}" audio_url = f"data:audio/{audio_format};base64,{audio_data}"
return self.parse_audio(audio_url) return self.parse_audio(audio_url, uuid)
def parse_video(self, video_url: str) -> None: def parse_video(self, video_url: str, uuid: Optional[str] = None) -> None:
video = self._connector.fetch_video_async(video_url=video_url) video = self._connector.fetch_video_async(video_url=video_url)
placeholder = self._tracker.add("video", video) placeholder = self._tracker.add("video", video, uuid)
self._add_placeholder("video", placeholder) self._add_placeholder("video", placeholder)
...@@ -809,20 +929,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]): ...@@ -809,20 +929,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
return return
elif isinstance(chat_template, Path) and not chat_template.exists(): elif isinstance(chat_template, Path) and not chat_template.exists():
raise FileNotFoundError( raise FileNotFoundError("the supplied chat template path doesn't exist")
"the supplied chat template path doesn't exist")
elif isinstance(chat_template, str): elif isinstance(chat_template, str):
JINJA_CHARS = "{}\n" JINJA_CHARS = "{}\n"
if not any(c in chat_template if (
for c in JINJA_CHARS) and not Path(chat_template).exists(): not any(c in chat_template for c in JINJA_CHARS)
and not Path(chat_template).exists()
):
raise ValueError( raise ValueError(
f"The supplied chat template string ({chat_template}) " f"The supplied chat template string ({chat_template}) "
f"appears path-like, but doesn't exist!") f"appears path-like, but doesn't exist!"
)
else: else:
raise TypeError( raise TypeError(
f"{type(chat_template)} is not a valid chat template type") f"{type(chat_template)} is not a valid chat template type"
)
def _load_chat_template( def _load_chat_template(
...@@ -835,8 +958,9 @@ def _load_chat_template( ...@@ -835,8 +958,9 @@ def _load_chat_template(
if is_literal: if is_literal:
if isinstance(chat_template, Path): if isinstance(chat_template, Path):
raise TypeError("chat_template is expected to be read directly " raise TypeError(
"from its value") "chat_template is expected to be read directly from its value"
)
return chat_template return chat_template
...@@ -849,9 +973,11 @@ def _load_chat_template( ...@@ -849,9 +973,11 @@ def _load_chat_template(
JINJA_CHARS = "{}\n" JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS): if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) " msg = (
f"looks like a file path, but it failed to be " f"The supplied chat template ({chat_template}) "
f"opened. Reason: {e}") f"looks like a file path, but it failed to be "
f"opened. Reason: {e}"
)
raise ValueError(msg) from e raise ValueError(msg) from e
# If opening a file fails, set chat template to be args to # If opening a file fails, set chat template to be args to
...@@ -870,8 +996,9 @@ def load_chat_template( ...@@ -870,8 +996,9 @@ def load_chat_template(
return _cached_load_chat_template(chat_template, is_literal=is_literal) return _cached_load_chat_template(chat_template, is_literal=is_literal)
def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], def _get_interleaved_text_prompt(
texts: list[str]) -> str: placeholder_storage: dict[str, list], texts: list[str]
) -> str:
for idx, elem in enumerate(texts): for idx, elem in enumerate(texts):
if elem in placeholder_storage: if elem in placeholder_storage:
texts[idx] = placeholder_storage[elem].pop(0) texts[idx] = placeholder_storage[elem].pop(0)
...@@ -881,10 +1008,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], ...@@ -881,10 +1008,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
# TODO: Let user specify how to insert multimodal tokens into prompt # TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template) # (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], def _get_full_multimodal_text_prompt(
texts: list[str], placeholder_storage: dict[str, list],
interleave_strings: bool texts: list[str],
) -> str: interleave_strings: bool,
) -> str:
"""Combine multimodal prompts for a multimodal language model.""" """Combine multimodal prompts for a multimodal language model."""
# flatten storage to make it looks like # flatten storage to make it looks like
...@@ -907,7 +1035,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], ...@@ -907,7 +1035,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
# Look through the text prompt to check for missing placeholders # Look through the text prompt to check for missing placeholders
missing_placeholders: list[str] = [] missing_placeholders: list[str] = []
for placeholder in placeholder_counts: for placeholder in placeholder_counts:
# For any existing placeholder in the text prompt, we leave it as is # For any existing placeholder in the text prompt, we leave it as is
placeholder_counts[placeholder] -= text_prompt.count(placeholder) placeholder_counts[placeholder] -= text_prompt.count(placeholder)
...@@ -916,15 +1043,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], ...@@ -916,15 +1043,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
"Placeholder count is negative! " "Placeholder count is negative! "
"Ensure that the 'interleave_strings' flag is disabled " "Ensure that the 'interleave_strings' flag is disabled "
"(current value: %s) " "(current value: %s) "
"when manually placing image placeholders.", interleave_strings "when manually placing image placeholders.",
interleave_strings,
) )
logger.debug("Input prompt: %s", text_prompt) logger.debug("Input prompt: %s", text_prompt)
raise ValueError( raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than " f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.") "actual multimodal data items."
)
missing_placeholders.extend([placeholder] * missing_placeholders.extend(
placeholder_counts[placeholder]) [placeholder] * placeholder_counts[placeholder]
)
# NOTE: Default behaviour: we always add missing placeholders # NOTE: Default behaviour: we always add missing placeholders
# at the front of the prompt, if interleave_strings=False # at the front of the prompt, if interleave_strings=False
...@@ -944,7 +1074,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python ...@@ -944,7 +1074,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
_ResponsesInputImageParser = TypeAdapter( _ResponsesInputImageParser = TypeAdapter(
ResponseInputImageParam).validate_python ResponseInputImageParam
).validate_python
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage]
# Define a mapping from part types to their corresponding parsing functions. # Define a mapping from part types to their corresponding parsing functions.
...@@ -952,32 +1083,35 @@ MM_PARSER_MAP: dict[ ...@@ -952,32 +1083,35 @@ MM_PARSER_MAP: dict[
str, str,
Callable[[ChatCompletionContentPartParam], _ContentPart], Callable[[ChatCompletionContentPartParam], _ContentPart],
] = { ] = {
"text": "text": lambda part: _TextParser(part).get("text", None),
lambda part: _TextParser(part).get("text", None), "thinking": lambda part: _ThinkParser(part).get("thinking", None),
"thinking": "input_text": lambda part: _TextParser(part).get("text", None),
lambda part: _ThinkParser(part).get("thinking", None), "input_image": lambda part: _ResponsesInputImageParser(part).get(
"input_text": "image_url", None
lambda part: _TextParser(part).get("text", None), ),
"input_image": "image_url": lambda part: _ImageParser(part)
lambda part: _ResponsesInputImageParser(part).get("image_url", None), .get("image_url", {})
"image_url": .get("url", None),
lambda part: _ImageParser(part).get("image_url", {}).get("url", None), "image_embeds": lambda part: _ImageEmbedsParser(part).get(
"image_embeds": "image_embeds", None
lambda part: _ImageEmbedsParser(part).get("image_embeds", None), ),
"image_pil": lambda part: _PILImageParser(part).get("image_pil", None), "image_pil": lambda part: _PILImageParser(part).get("image_pil", None),
"audio_url": "audio_url": lambda part: _AudioParser(part)
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None), .get("audio_url", {})
"input_audio": .get("url", None),
lambda part: _InputAudioParser(part).get("input_audio", None), "input_audio": lambda part: _InputAudioParser(part).get(
"refusal": "input_audio", None
lambda part: _RefusalParser(part).get("refusal", None), ),
"video_url": "refusal": lambda part: _RefusalParser(part).get("refusal", None),
lambda part: _VideoParser(part).get("video_url", {}).get("url", None), "video_url": lambda part: _VideoParser(part)
.get("video_url", {})
.get("url", None),
} }
def _parse_chat_message_content_mm_part( def _parse_chat_message_content_mm_part(
part: ChatCompletionContentPartParam) -> tuple[str, _ContentPart]: part: ChatCompletionContentPartParam,
) -> tuple[str, _ContentPart]:
""" """
Parses a given multi-modal content part based on its type. Parses a given multi-modal content part based on its type.
...@@ -993,7 +1127,8 @@ def _parse_chat_message_content_mm_part( ...@@ -993,7 +1127,8 @@ def _parse_chat_message_content_mm_part(
ValueError: If the 'type' field is missing and no direct URL is found. ValueError: If the 'type' field is missing and no direct URL is found.
""" """
assert isinstance( assert isinstance(
part, dict) # This is needed to avoid mypy errors: part.get() from str part, dict
) # This is needed to avoid mypy errors: part.get() from str
part_type = part.get("type", None) part_type = part.get("type", None)
if isinstance(part_type, str) and part_type in MM_PARSER_MAP: if isinstance(part_type, str) and part_type in MM_PARSER_MAP:
...@@ -1002,8 +1137,10 @@ def _parse_chat_message_content_mm_part( ...@@ -1002,8 +1137,10 @@ def _parse_chat_message_content_mm_part(
# Special case for 'image_url.detail' # Special case for 'image_url.detail'
# We only support 'auto', which is the default # We only support 'auto', which is the default
if part_type == "image_url" and part.get("detail", "auto") != "auto": if part_type == "image_url" and part.get("detail", "auto") != "auto":
logger.warning("'image_url.detail' is currently not supported " logger.warning(
"and will be ignored.") "'image_url.detail' is currently not supported "
"and will be ignored."
)
return part_type, content return part_type, content
...@@ -1011,19 +1148,22 @@ def _parse_chat_message_content_mm_part( ...@@ -1011,19 +1148,22 @@ def _parse_chat_message_content_mm_part(
# 'type' is required field by pydantic # 'type' is required field by pydantic
if part_type is None: if part_type is None:
if part.get("image_url") is not None: if part.get("image_url") is not None:
image_params = cast(CustomChatCompletionContentSimpleImageParam, image_params = cast(
part) CustomChatCompletionContentSimpleImageParam, part
)
return "image_url", image_params.get("image_url", "") return "image_url", image_params.get("image_url", "")
if part.get("audio_url") is not None: if part.get("audio_url") is not None:
audio_params = cast(CustomChatCompletionContentSimpleAudioParam, audio_params = cast(
part) CustomChatCompletionContentSimpleAudioParam, part
)
return "audio_url", audio_params.get("audio_url", "") return "audio_url", audio_params.get("audio_url", "")
if part.get("input_audio") is not None: if part.get("input_audio") is not None:
input_audio_params = cast(dict[str, str], part) input_audio_params = cast(dict[str, str], part)
return "input_audio", input_audio_params return "input_audio", input_audio_params
if part.get("video_url") is not None: if part.get("video_url") is not None:
video_params = cast(CustomChatCompletionContentSimpleVideoParam, video_params = cast(
part) CustomChatCompletionContentSimpleVideoParam, part
)
return "video_url", video_params.get("video_url", "") return "video_url", video_params.get("video_url", "")
# Raise an error if no 'type' or direct URL is found. # Raise an error if no 'type' or direct URL is found.
raise ValueError("Missing 'type' field in multimodal part.") raise ValueError("Missing 'type' field in multimodal part.")
...@@ -1033,9 +1173,16 @@ def _parse_chat_message_content_mm_part( ...@@ -1033,9 +1173,16 @@ def _parse_chat_message_content_mm_part(
return part_type, "unknown part_type content" return part_type, "unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", VALID_MESSAGE_CONTENT_MM_PART_TYPES = (
"image_embeds", "image_pil", "text",
"audio_url", "input_audio", "video_url") "refusal",
"image_url",
"image_embeds",
"image_pil",
"audio_url",
"input_audio",
"video_url",
)
def _parse_chat_message_content_parts( def _parse_chat_message_content_parts(
...@@ -1055,21 +1202,20 @@ def _parse_chat_message_content_parts( ...@@ -1055,21 +1202,20 @@ def _parse_chat_message_content_parts(
part, part,
mm_parser, mm_parser,
wrap_dicts=wrap_dicts, wrap_dicts=wrap_dicts,
interleave_strings=interleave_strings interleave_strings=interleave_strings,
) )
if parse_res: if parse_res:
content.append(parse_res) content.append(parse_res)
if wrap_dicts: if wrap_dicts:
# Parsing wraps images and texts as interleaved dictionaries # Parsing wraps images and texts as interleaved dictionaries
return [ConversationMessage(role=role, return [ConversationMessage(role=role, content=content)] # type: ignore
content=content)] # type: ignore
texts = cast(list[str], content) texts = cast(list[str], content)
mm_placeholder_storage = mm_parser.mm_placeholder_storage() mm_placeholder_storage = mm_parser.mm_placeholder_storage()
if mm_placeholder_storage: if mm_placeholder_storage:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, text_prompt = _get_full_multimodal_text_prompt(
texts, mm_placeholder_storage, texts, interleave_strings
interleave_strings) )
else: else:
text_prompt = "\n".join(texts) text_prompt = "\n".join(texts)
...@@ -1099,46 +1245,59 @@ def _parse_chat_message_content_part( ...@@ -1099,46 +1245,59 @@ def _parse_chat_message_content_part(
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
logger.warning( logger.warning(
"Skipping multimodal part '%s' (type: '%s') " "Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content.", part, part_type) "with empty / unparsable content.",
part,
part_type,
)
return None return None
if part_type in ("text", "input_text", "refusal", "thinking"): if part_type in ("text", "input_text", "refusal", "thinking"):
str_content = cast(str, content) str_content = cast(str, content)
if wrap_dicts: if wrap_dicts:
return {'type': 'text', 'text': str_content} return {"type": "text", "text": str_content}
else: else:
return str_content return str_content
# For media items, if a user has provided one, use it. Otherwise, insert
# a placeholder empty uuid.
uuid = part.get("uuid", None)
if uuid is not None:
uuid = str(uuid)
modality = None modality = None
if part_type == "image_pil": if part_type == "image_pil":
image_content = cast(Image.Image, content) image_content = cast(Image.Image, content)
mm_parser.parse_image_pil(image_content) mm_parser.parse_image_pil(image_content, uuid)
modality = "image" modality = "image"
elif part_type in ("image_url", "input_image"): elif part_type in ("image_url", "input_image"):
str_content = cast(str, content) str_content = cast(str, content)
mm_parser.parse_image(str_content) mm_parser.parse_image(str_content, uuid)
modality = "image" modality = "image"
elif part_type == "image_embeds": elif part_type == "image_embeds":
content = cast(Union[str, dict[str, str]], content) content = cast(Union[str, dict[str, str]], content)
mm_parser.parse_image_embeds(content) mm_parser.parse_image_embeds(content, uuid)
modality = "image" modality = "image"
elif part_type == "audio_url": elif part_type == "audio_url":
str_content = cast(str, content) str_content = cast(str, content)
mm_parser.parse_audio(str_content) mm_parser.parse_audio(str_content, uuid)
modality = "audio" modality = "audio"
elif part_type == "input_audio": elif part_type == "input_audio":
dict_content = cast(InputAudio, content) dict_content = cast(InputAudio, content)
mm_parser.parse_input_audio(dict_content) mm_parser.parse_input_audio(dict_content, uuid)
modality = "audio" modality = "audio"
elif part_type == "video_url": elif part_type == "video_url":
str_content = cast(str, content) str_content = cast(str, content)
mm_parser.parse_video(str_content) mm_parser.parse_video(str_content, uuid)
modality = "video" modality = "video"
else: else:
raise NotImplementedError(f"Unknown part type: {part_type}") raise NotImplementedError(f"Unknown part type: {part_type}")
return {'type': modality} if wrap_dicts else ( return (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None {"type": modality}
if wrap_dicts
else (
MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None
)
) )
...@@ -1171,14 +1330,16 @@ def _parse_chat_message_content( ...@@ -1171,14 +1330,16 @@ def _parse_chat_message_content(
) )
for result_msg in result: for result_msg in result:
if role == 'assistant': if role == "assistant":
parsed_msg = _AssistantParser(message) parsed_msg = _AssistantParser(message)
# The 'tool_calls' is not None check ensures compatibility. # The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly # It's needed only if downstream code doesn't strictly
# follow the OpenAI spec. # follow the OpenAI spec.
if ("tool_calls" in parsed_msg if (
and parsed_msg["tool_calls"] is not None): "tool_calls" in parsed_msg
and parsed_msg["tool_calls"] is not None
):
result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) result_msg["tool_calls"] = list(parsed_msg["tool_calls"])
elif role == "tool": elif role == "tool":
parsed_msg = _ToolParser(message) parsed_msg = _ToolParser(message)
...@@ -1198,12 +1359,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: ...@@ -1198,12 +1359,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# so, for messages that have tool_calls, parse the string (which we get # so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict # from openAI format) to dict
for message in messages: for message in messages:
if (message["role"] == "assistant" and "tool_calls" in message if (
and isinstance(message["tool_calls"], list)): message["role"] == "assistant"
and "tool_calls" in message
and isinstance(message["tool_calls"], list)
):
for item in message["tool_calls"]: for item in message["tool_calls"]:
item["function"]["arguments"] = json.loads( item["function"]["arguments"] = json.loads(
item["function"]["arguments"]) item["function"]["arguments"]
)
def parse_chat_messages( def parse_chat_messages(
...@@ -1211,7 +1375,11 @@ def parse_chat_messages( ...@@ -1211,7 +1375,11 @@ def parse_chat_messages(
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[list[ConversationMessage], Optional[MultiModalDataDict]]: ) -> tuple[
list[ConversationMessage],
Optional[MultiModalDataDict],
Optional[MultiModalUUIDDict],
]:
conversation: list[ConversationMessage] = [] conversation: list[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer) mm_tracker = MultiModalItemTracker(model_config, tokenizer)
...@@ -1224,14 +1392,14 @@ def parse_chat_messages( ...@@ -1224,14 +1392,14 @@ def parse_chat_messages(
content_format == "string" content_format == "string"
and model_config.multimodal_config is not None and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings and model_config.multimodal_config.interleave_mm_strings
) ),
) )
conversation.extend(sub_messages) conversation.extend(sub_messages)
_postprocess_messages(conversation) _postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data() return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
def parse_chat_messages_futures( def parse_chat_messages_futures(
...@@ -1239,7 +1407,11 @@ def parse_chat_messages_futures( ...@@ -1239,7 +1407,11 @@ def parse_chat_messages_futures(
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: AnyTokenizer, tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat, content_format: _ChatTemplateContentFormat,
) -> tuple[list[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: ) -> tuple[
list[ConversationMessage],
Awaitable[Optional[MultiModalDataDict]],
Optional[MultiModalUUIDDict],
]:
conversation: list[ConversationMessage] = [] conversation: list[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
...@@ -1252,14 +1424,14 @@ def parse_chat_messages_futures( ...@@ -1252,14 +1424,14 @@ def parse_chat_messages_futures(
content_format == "string" content_format == "string"
and model_config.multimodal_config is not None and model_config.multimodal_config is not None
and model_config.multimodal_config.interleave_mm_strings and model_config.multimodal_config.interleave_mm_strings
) ),
) )
conversation.extend(sub_messages) conversation.extend(sub_messages)
_postprocess_messages(conversation) _postprocess_messages(conversation)
return conversation, mm_tracker.all_mm_data() return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
def apply_hf_chat_template( def apply_hf_chat_template(
...@@ -1283,10 +1455,10 @@ def apply_hf_chat_template( ...@@ -1283,10 +1455,10 @@ def apply_hf_chat_template(
raise ValueError( raise ValueError(
"As of transformers v4.44, default chat template is no longer " "As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer " "allowed, so you must provide a chat template if the tokenizer "
"does not define one.") "does not define one."
)
try: try:
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type] conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type] tools=tools, # type: ignore[arg-type]
...@@ -1298,13 +1470,14 @@ def apply_hf_chat_template( ...@@ -1298,13 +1470,14 @@ def apply_hf_chat_template(
# External library exceptions can sometimes occur despite the framework's # External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities. # internal exception management capabilities.
except Exception as e: except Exception as e:
# Log and report any library-related exceptions for further # Log and report any library-related exceptions for further
# investigation. # investigation.
logger.exception( logger.exception(
"An error occurred in `transformers` while applying chat template") "An error occurred in `transformers` while applying chat template"
)
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
def apply_mistral_chat_template( def apply_mistral_chat_template(
tokenizer: MistralTokenizer, tokenizer: MistralTokenizer,
messages: list[ChatCompletionMessageParam], messages: list[ChatCompletionMessageParam],
...@@ -1337,26 +1510,26 @@ def apply_mistral_chat_template( ...@@ -1337,26 +1510,26 @@ def apply_mistral_chat_template(
# External library exceptions can sometimes occur despite the framework's # External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities. # internal exception management capabilities.
except Exception as e: except Exception as e:
# Log and report any library-related exceptions for further # Log and report any library-related exceptions for further
# investigation. # investigation.
logger.exception( logger.exception(
"An error occurred in `mistral_common` while applying chat " "An error occurred in `mistral_common` while applying chat template"
"template") )
raise ValueError(str(e)) from e raise ValueError(str(e)) from e
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]): def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
idx = 0 idx = 0
for msg in conversation: for msg in conversation:
if msg['role'] == 'assistant': if msg["role"] == "assistant":
tool_calls = msg.get('tool_calls') tool_calls = msg.get("tool_calls")
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx return idx
def make_tool_call_id(id_type:str='random', func_name=None, idx=None):
if id_type=='kimi_k2': def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
return f'functions.{func_name}:{idx}' if id_type == "kimi_k2":
return f"functions.{func_name}:{idx}"
else: else:
# by default return random # by default return random
return f"chatcmpl-tool-{random_uuid()}" return f"chatcmpl-tool-{random_uuid()}"
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import json import json
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
...@@ -21,6 +22,23 @@ if TYPE_CHECKING: ...@@ -21,6 +22,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TurnTokens:
"""Tracks token counts for a single conversation turn."""
def __init__(self, input_tokens=0, output_tokens=0):
self.input_tokens = input_tokens
self.output_tokens = output_tokens
def reset(self):
"""Reset counters for a new turn."""
self.input_tokens = 0
self.output_tokens = 0
def copy(self):
"""Create a copy of this turn's token counts."""
return TurnTokens(self.input_tokens, self.output_tokens)
class ConversationContext(ABC): class ConversationContext(ABC):
@abstractmethod @abstractmethod
...@@ -41,17 +59,32 @@ class ConversationContext(ABC): ...@@ -41,17 +59,32 @@ class ConversationContext(ABC):
@abstractmethod @abstractmethod
async def init_tool_sessions(self, tool_server: Optional[ToolServer], async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack) -> None: exit_stack: AsyncExitStack,
request_id: str) -> None:
pass pass
@abstractmethod
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
class SimpleContext(ConversationContext): class SimpleContext(ConversationContext):
def __init__(self): def __init__(self):
self.last_output = None self.last_output = None
self.num_prompt_tokens = 0
self.num_output_tokens = 0
self.num_cached_tokens = 0
# todo num_reasoning_tokens is not implemented yet.
self.num_reasoning_tokens = 0
def append_output(self, output) -> None: def append_output(self, output) -> None:
self.last_output = output self.last_output = output
if not isinstance(output, RequestOutput):
raise ValueError("SimpleContext only supports RequestOutput.")
self.num_prompt_tokens = len(output.prompt_token_ids or [])
self.num_cached_tokens = output.num_cached_tokens or 0
self.num_output_tokens += len(output.outputs[0].token_ids or [])
def need_builtin_tool_call(self) -> bool: def need_builtin_tool_call(self) -> bool:
return False return False
...@@ -63,9 +96,13 @@ class SimpleContext(ConversationContext): ...@@ -63,9 +96,13 @@ class SimpleContext(ConversationContext):
raise NotImplementedError("Should not be called.") raise NotImplementedError("Should not be called.")
async def init_tool_sessions(self, tool_server: Optional[ToolServer], async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack) -> None: exit_stack: AsyncExitStack,
request_id: str) -> None:
pass pass
async def cleanup_session(self) -> None:
raise NotImplementedError("Should not be called.")
class HarmonyContext(ConversationContext): class HarmonyContext(ConversationContext):
...@@ -77,39 +114,130 @@ class HarmonyContext(ConversationContext): ...@@ -77,39 +114,130 @@ class HarmonyContext(ConversationContext):
self._messages = messages self._messages = messages
self.available_tools = available_tools self.available_tools = available_tools
self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {} self._tool_sessions: dict[str, Union[ClientSession, Tool]] = {}
self.called_tools: set[str] = set()
self.parser = get_streamable_parser_for_assistant() self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages) self.num_init_messages = len(messages)
self.num_prompt_tokens = 0 self.num_prompt_tokens = 0
self.num_output_tokens = 0 self.num_output_tokens = 0
# TODO(woosuk): Implement the following fields.
self.num_cached_tokens = 0 self.num_cached_tokens = 0
self.num_reasoning_tokens = 0 self.num_reasoning_tokens = 0
self.num_tool_output_tokens = 0
def _update_num_prompt_tokens(self, output: RequestOutput): # Turn tracking - replaces multiple individual tracking variables
if output.prompt_token_ids and len(output.prompt_token_ids) > 0: self.current_turn = TurnTokens()
# NOTE: with built-in tools, there might be multiple rounds in self.previous_turn = TurnTokens()
# the conversation, with the full conversation being resent self.is_first_turn = True
# as new prompt each time. Hence the sum. self.first_tok_of_message = True # For streaming support
self.num_prompt_tokens += len(output.prompt_token_ids)
def _update_num_output_tokens(self, token_ids: Sequence[int]): def _update_num_reasoning_tokens(self):
self.num_output_tokens += len(token_ids) # Count all analysis and commentary channels as reasoning tokens
if self.parser.current_channel in {"analysis", "commentary"}:
self.num_reasoning_tokens += 1
def append_output(self, output) -> None: def append_output(self, output) -> None:
if isinstance(output, RequestOutput): if isinstance(output, RequestOutput):
self._update_num_prompt_tokens(output)
output_token_ids = output.outputs[0].token_ids output_token_ids = output.outputs[0].token_ids
self._update_num_output_tokens(output_token_ids)
self.parser = get_streamable_parser_for_assistant() self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids: for token_id in output_token_ids:
self.parser.process(token_id) self.parser.process(token_id)
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self._update_prefill_token_usage(output)
# Reset current turn output tokens for this turn
self.current_turn.output_tokens = 0
self._update_decode_token_usage(output)
# Move current turn to previous turn for next turn's calculations
self.previous_turn = self.current_turn.copy()
output_msgs = self.parser.messages output_msgs = self.parser.messages
else: else:
# Tool output. # Tool output.
output_msgs = output output_msgs = output
self._messages.extend(output_msgs) self._messages.extend(output_msgs)
def _update_prefill_token_usage(self, output: RequestOutput) -> None:
"""Update token usage statistics for the prefill phase of generation.
The prefill phase processes the input prompt tokens. This method:
1. Counts the prompt tokens for this turn
2. Calculates tool output tokens for multi-turn conversations
3. Updates cached token counts
4. Tracks state for next turn calculations
Tool output tokens are calculated as:
current_prompt_tokens - last_turn_prompt_tokens -
last_turn_output_tokens
This represents tokens added between turns (typically tool responses).
Args:
output: The RequestOutput containing prompt token information
"""
if output.prompt_token_ids is not None:
this_turn_input_tokens = len(output.prompt_token_ids)
else:
this_turn_input_tokens = 0
logger.error(
"RequestOutput appended contains no prompt_token_ids.")
# Update current turn input tokens
self.current_turn.input_tokens = this_turn_input_tokens
self.num_prompt_tokens += this_turn_input_tokens
# Calculate tool tokens (except on first turn)
if self.is_first_turn:
self.is_first_turn = False
else:
# start counting tool after first turn
# tool tokens = this turn prefill - last turn prefill -
# last turn decode
this_turn_tool_tokens = (self.current_turn.input_tokens -
self.previous_turn.input_tokens -
self.previous_turn.output_tokens)
# Handle negative tool token counts (shouldn't happen in normal
# cases)
if this_turn_tool_tokens < 0:
logger.error(
"Negative tool output tokens calculated: %d "
"(current_input=%d, previous_input=%d, "
"previous_output=%d). Setting to 0.",
this_turn_tool_tokens, self.current_turn.input_tokens,
self.previous_turn.input_tokens,
self.previous_turn.output_tokens)
this_turn_tool_tokens = 0
self.num_tool_output_tokens += this_turn_tool_tokens
# Update cached tokens
if output.num_cached_tokens is not None:
self.num_cached_tokens += output.num_cached_tokens
def _update_decode_token_usage(self, output: RequestOutput) -> int:
"""Update token usage statistics for the decode phase of generation.
The decode phase processes the generated output tokens. This method:
1. Counts output tokens from all completion outputs
2. Updates the total output token count
3. Tracks tokens generated in the current turn
In streaming mode, this is called for each token generated.
In non-streaming mode, this is called once with all output tokens.
Args:
output: The RequestOutput containing generated token information
Returns:
int: Number of output tokens processed in this call
"""
updated_output_token_count = 0
if output.outputs:
for completion_output in output.outputs:
# only keep last round
updated_output_token_count += len(completion_output.token_ids)
self.num_output_tokens += updated_output_token_count
self.current_turn.output_tokens += updated_output_token_count
return updated_output_token_count
@property @property
def messages(self) -> list: def messages(self) -> list:
return self._messages return self._messages
...@@ -118,7 +246,8 @@ class HarmonyContext(ConversationContext): ...@@ -118,7 +246,8 @@ class HarmonyContext(ConversationContext):
last_msg = self.messages[-1] last_msg = self.messages[-1]
recipient = last_msg.recipient recipient = last_msg.recipient
return recipient is not None and (recipient.startswith("browser.") return recipient is not None and (recipient.startswith("browser.")
or recipient.startswith("python")) or recipient.startswith("python") or
recipient.startswith("container."))
async def call_tool(self) -> list[Message]: async def call_tool(self) -> list[Message]:
if not self.messages: if not self.messages:
...@@ -132,6 +261,9 @@ class HarmonyContext(ConversationContext): ...@@ -132,6 +261,9 @@ class HarmonyContext(ConversationContext):
elif recipient.startswith("python"): elif recipient.startswith("python"):
return await self.call_python_tool( return await self.call_python_tool(
self._tool_sessions["python"], last_msg) self._tool_sessions["python"], last_msg)
elif recipient.startswith("container."):
return await self.call_container_tool(
self._tool_sessions["container"], last_msg)
raise ValueError("No tool call found") raise ValueError("No tool call found")
def render_for_completion(self) -> list[int]: def render_for_completion(self) -> list[int]:
...@@ -140,6 +272,7 @@ class HarmonyContext(ConversationContext): ...@@ -140,6 +272,7 @@ class HarmonyContext(ConversationContext):
async def call_search_tool(self, tool_session: Union["ClientSession", async def call_search_tool(self, tool_session: Union["ClientSession",
Tool], Tool],
last_msg: Message) -> list[Message]: last_msg: Message) -> list[Message]:
self.called_tools.add("browser")
if isinstance(tool_session, Tool): if isinstance(tool_session, Tool):
return await tool_session.get_result(self) return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1] tool_name = last_msg.recipient.split(".")[1]
...@@ -149,12 +282,16 @@ class HarmonyContext(ConversationContext): ...@@ -149,12 +282,16 @@ class HarmonyContext(ConversationContext):
content = TextContent(text=result_str) content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient) author = Author(role=Role.TOOL, name=last_msg.recipient)
return [ return [
Message(author=author, content=[content], recipient=Role.ASSISTANT) Message(author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel)
] ]
async def call_python_tool(self, tool_session: Union["ClientSession", async def call_python_tool(self, tool_session: Union["ClientSession",
Tool], Tool],
last_msg: Message) -> list[Message]: last_msg: Message) -> list[Message]:
self.called_tools.add("python")
if isinstance(tool_session, Tool): if isinstance(tool_session, Tool):
return await tool_session.get_result(self) return await tool_session.get_result(self)
param = { param = {
...@@ -174,13 +311,63 @@ class HarmonyContext(ConversationContext): ...@@ -174,13 +311,63 @@ class HarmonyContext(ConversationContext):
] ]
async def init_tool_sessions(self, tool_server: Optional[ToolServer], async def init_tool_sessions(self, tool_server: Optional[ToolServer],
exit_stack: AsyncExitStack) -> None: exit_stack: AsyncExitStack,
request_id: str) -> None:
if tool_server: if tool_server:
for tool_name in self.available_tools: for tool_name in self.available_tools:
if tool_name not in self._tool_sessions: if tool_name not in self._tool_sessions:
self._tool_sessions[ tool_session = await exit_stack.enter_async_context(
tool_name] = await exit_stack.enter_async_context( tool_server.new_session(tool_name, request_id))
tool_server.new_session(tool_name)) self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)
async def call_container_tool(self, tool_session: Union["ClientSession",
Tool],
last_msg: Message) -> list[Message]:
"""
Call container tool. Expect this to be run in a stateful docker
with command line terminal.
The official container tool would at least
expect the following format:
- for tool name: exec
- args:
{
"cmd":List[str] "command to execute",
"workdir":optional[str] "current working directory",
"env":optional[object/dict] "environment variables",
"session_name":optional[str] "session name",
"timeout":optional[int] "timeout in seconds",
"user":optional[str] "user name",
}
"""
self.called_tools.add("container")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
args = json.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
author = Author(role=Role.TOOL, name=last_msg.recipient)
return [
Message(author=author,
content=[content],
recipient=Role.ASSISTANT,
channel=last_msg.channel)
]
async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info("Cleaning up tool session for %s",
tool_session._client_info)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})
await asyncio.gather(*(cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools))
class StreamingHarmonyContext(HarmonyContext): class StreamingHarmonyContext(HarmonyContext):
...@@ -203,15 +390,22 @@ class StreamingHarmonyContext(HarmonyContext): ...@@ -203,15 +390,22 @@ class StreamingHarmonyContext(HarmonyContext):
# append_output is called for each output token in streaming case, # append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message. # so we only want to add the prompt tokens once for each message.
if self.first_tok_of_message: if self.first_tok_of_message:
self._update_num_prompt_tokens(output) self._update_prefill_token_usage(output)
self.current_turn.output_tokens = 0
# Reset self.first_tok_of_message if needed: # Reset self.first_tok_of_message if needed:
# if the current token is the last one of the current message # if the current token is the last one of the current message
# (finished=True), then the next token processed will mark the # (finished=True), then the next token processed will mark the
# beginning of a new message # beginning of a new message
self.first_tok_of_message = output.finished self.first_tok_of_message = output.finished
tok = output.outputs[0].token_ids[0] for tok in output.outputs[0].token_ids:
self.parser.process(tok) self.parser.process(tok)
self._update_num_output_tokens(output.outputs[0].token_ids) self._update_decode_token_usage(output)
# For streaming, update previous turn when message is complete
if output.finished:
self.previous_turn = self.current_turn.copy()
# Check if the current token is part of reasoning content
self._update_num_reasoning_tokens()
self.last_tok = tok self.last_tok = tok
else: else:
# Handle the case of tool output in direct message format # Handle the case of tool output in direct message format
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import datetime import datetime
import json import json
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
...@@ -13,12 +16,15 @@ from openai.types.responses.response_function_web_search import ( ...@@ -13,12 +16,15 @@ from openai.types.responses.response_function_web_search import (
from openai.types.responses.response_reasoning_item import ( from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent) Content as ResponseReasoningTextContent)
from openai.types.responses.tool import Tool from openai.types.responses.tool import Tool
from openai_harmony import (Author, Conversation, DeveloperContent, from openai_harmony import (Author, ChannelConfig, Conversation,
HarmonyEncodingName, Message, ReasoningEffort, DeveloperContent, HarmonyEncodingName, Message,
Role, StreamableParser, SystemContent, TextContent, ReasoningEffort, Role, StreamableParser,
ToolDescription, load_harmony_encoding) SystemContent, TextContent, ToolDescription,
load_harmony_encoding)
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem
from vllm import envs
from vllm.entrypoints.openai.protocol import (ChatCompletionToolsParam,
ResponseInputOutputItem)
from vllm.utils import random_uuid from vllm.utils import random_uuid
REASONING_EFFORT = { REASONING_EFFORT = {
...@@ -29,6 +35,20 @@ REASONING_EFFORT = { ...@@ -29,6 +35,20 @@ REASONING_EFFORT = {
_harmony_encoding = None _harmony_encoding = None
# Builtin tools that should be included in the system message when
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified.
BUILTIN_TOOLS = {
"web_search_preview",
"code_interpreter",
"container",
}
def has_custom_tools(tool_types: list[str]) -> bool:
return not set(tool_types).issubset(BUILTIN_TOOLS)
def get_encoding(): def get_encoding():
global _harmony_encoding global _harmony_encoding
...@@ -44,10 +64,19 @@ def get_system_message( ...@@ -44,10 +64,19 @@ def get_system_message(
start_date: Optional[str] = None, start_date: Optional[str] = None,
browser_description: Optional[str] = None, browser_description: Optional[str] = None,
python_description: Optional[str] = None, python_description: Optional[str] = None,
container_description: Optional[str] = None,
instructions: Optional[str] = None,
with_custom_tools: bool = False,
) -> Message: ) -> Message:
sys_msg_content = SystemContent.new() sys_msg_content = SystemContent.new()
if model_identity is not None: if model_identity is not None:
sys_msg_content = sys_msg_content.with_model_identity(model_identity) sys_msg_content = sys_msg_content.with_model_identity(model_identity)
if (instructions is not None
and envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
current_identity = sys_msg_content.model_identity
new_identity = (f'{current_identity}\n{instructions}'
if current_identity else instructions)
sys_msg_content = sys_msg_content.with_model_identity(new_identity)
if reasoning_effort is not None: if reasoning_effort is not None:
sys_msg_content = sys_msg_content.with_reasoning_effort( sys_msg_content = sys_msg_content.with_reasoning_effort(
REASONING_EFFORT[reasoning_effort]) REASONING_EFFORT[reasoning_effort])
...@@ -59,32 +88,55 @@ def get_system_message( ...@@ -59,32 +88,55 @@ def get_system_message(
sys_msg_content = sys_msg_content.with_tools(browser_description) sys_msg_content = sys_msg_content.with_tools(browser_description)
if python_description is not None: if python_description is not None:
sys_msg_content = sys_msg_content.with_tools(python_description) sys_msg_content = sys_msg_content.with_tools(python_description)
if container_description is not None:
sys_msg_content = sys_msg_content.with_tools(container_description)
if not with_custom_tools:
channel_config = sys_msg_content.channel_config
invalid_channel = "commentary"
new_config = ChannelConfig.require_channels(
[c for c in channel_config.valid_channels if c != invalid_channel])
sys_msg_content = sys_msg_content.with_channel_config(new_config)
sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content)
return sys_msg return sys_msg
def get_developer_message(instructions: Optional[str] = None, def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
tools: Optional[list[Tool]] = None) -> Message: if isinstance(tool, ChatCompletionToolsParam):
return ToolDescription.new(
name=tool.function.name,
description=tool.function.description,
parameters=tool.function.parameters,
)
return ToolDescription.new(
name=tool.name,
description=tool.description,
parameters=tool.parameters,
)
def get_developer_message(
instructions: Optional[str] = None,
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
) -> Message:
dev_msg_content = DeveloperContent.new() dev_msg_content = DeveloperContent.new()
if instructions is not None: if (instructions is not None
and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS):
dev_msg_content = dev_msg_content.with_instructions(instructions) dev_msg_content = dev_msg_content.with_instructions(instructions)
if tools is not None: if tools is not None:
function_tools = [] function_tools: list[Union[Tool, ChatCompletionToolsParam]] = []
for tool in tools: for tool in tools:
if tool.type in ("web_search_preview", "code_interpreter"): if tool.type in ("web_search_preview", "code_interpreter",
"container"):
# These are built-in tools that are added to the system message. # These are built-in tools that are added to the system message.
pass pass
elif tool.type == "function": elif tool.type == "function":
function_tools.append(tool) function_tools.append(tool)
else: else:
raise ValueError(f"tool type {tool.type} not supported") raise ValueError(f"tool type {tool.type} not supported")
if function_tools: if function_tools:
function_tool_descriptions = [ function_tool_descriptions = [
ToolDescription.new( create_tool_definition(tool) for tool in function_tools
name=tool.name,
description=tool.description,
parameters=tool.parameters,
) for tool in function_tools
] ]
dev_msg_content = dev_msg_content.with_function_tools( dev_msg_content = dev_msg_content.with_function_tools(
function_tool_descriptions) function_tool_descriptions)
...@@ -120,6 +172,8 @@ def parse_response_input( ...@@ -120,6 +172,8 @@ def parse_response_input(
TextContent(text=text_prefix + c["text"]) for c in content TextContent(text=text_prefix + c["text"]) for c in content
] ]
msg = Message.from_role_and_contents(role, contents) msg = Message.from_role_and_contents(role, contents)
if role == "assistant":
msg = msg.with_channel("final")
elif response_msg["type"] == "function_call_output": elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"] call_id = response_msg["call_id"]
call_response: Optional[ResponseFunctionToolCall] = None call_response: Optional[ResponseFunctionToolCall] = None
...@@ -148,16 +202,46 @@ def parse_response_input( ...@@ -148,16 +202,46 @@ def parse_response_input(
return msg return msg
def parse_chat_input(chat_msg) -> Message: def parse_chat_input(chat_msg) -> list[Message]:
role = chat_msg["role"] if not isinstance(chat_msg, dict):
content = chat_msg["content"] # Handle Pydantic models
chat_msg = chat_msg.model_dump(exclude_none=True)
role = chat_msg.get("role")
# Assistant message with tool calls
tool_calls = chat_msg.get("tool_calls")
if role == "assistant" and tool_calls:
msgs: list[Message] = []
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{name}")
msg = msg.with_content_type("json")
msgs.append(msg)
return msgs
# Tool role message (tool output)
if role == "tool":
name = chat_msg.get("name", "")
content = chat_msg.get("content", "") or ""
msg = Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"),
content).with_channel("commentary")
return [msg]
# Default: user/assistant/system messages with content
content = chat_msg.get("content", "")
if isinstance(content, str): if isinstance(content, str):
contents = [TextContent(text=content)] contents = [TextContent(text=content)]
else: else:
# TODO: Support refusal. # TODO: Support refusal.
contents = [TextContent(text=c.get("text", "")) for c in content] contents = [TextContent(text=c.get("text", "")) for c in content]
msg = Message.from_role_and_contents(role, contents) msg = Message.from_role_and_contents(role, contents)
return msg return [msg]
def render_for_completion(messages: list[Message]) -> list[int]: def render_for_completion(messages: list[Message]) -> list[int]:
...@@ -227,7 +311,7 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]: ...@@ -227,7 +311,7 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
call_id=f"call_{random_id}", call_id=f"call_{random_id}",
type="function_call", type="function_call",
name=function_name, name=function_name,
id=f"ft_{random_id}", id=f"fc_{random_id}",
) )
output_items.append(response_item) output_items.append(response_item)
elif recipient is not None and (recipient.startswith("python") elif recipient is not None and (recipient.startswith("python")
......
...@@ -95,7 +95,7 @@ async def serve_http(app: FastAPI, ...@@ -95,7 +95,7 @@ async def serve_http(app: FastAPI,
port = uvicorn_kwargs["port"] port = uvicorn_kwargs["port"]
process = find_process_using_port(port) process = find_process_using_port(port)
if process is not None: if process is not None:
logger.debug( logger.warning(
"port %s is used by process %s launched with command:\n%s", "port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline())) port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.") logger.info("Shutting down FastAPI HTTP server.")
......
...@@ -110,6 +110,14 @@ class LLM: ...@@ -110,6 +110,14 @@ class LLM:
values will increase the KV cache size and thus improve the model's values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of- throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors. memory (OOM) errors.
kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
this is set to None and vllm can automatically infer the kv cache
size based on gpu_memory_utilization. However, users may want to
manually specify the kv cache memory size. kv_cache_memory_bytes
allows more fine-grain control of how much memory gets used when
compared with using gpu_memory_memory_utilization. Note that
kv_cache_memory_bytes (when not-None) ignores
gpu_memory_utilization
swap_space: The size (GiB) of CPU memory per GPU to use as swap space. swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all when their `best_of` sampling parameters are larger than 1. If all
...@@ -184,6 +192,7 @@ class LLM: ...@@ -184,6 +192,7 @@ class LLM:
hf_overrides: Optional[HfOverrides] = None, hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None, mm_processor_kwargs: Optional[dict[str, Any]] = None,
override_pooler_config: Optional[PoolerConfig] = None, override_pooler_config: Optional[PoolerConfig] = None,
kv_cache_memory_bytes: Optional[int] = None,
compilation_config: Optional[Union[int, dict[str, Any], compilation_config: Optional[Union[int, dict[str, Any],
CompilationConfig]] = None, CompilationConfig]] = None,
logits_processors: Optional[list[Union[str, logits_processors: Optional[list[Union[str,
...@@ -204,7 +213,7 @@ class LLM: ...@@ -204,7 +213,7 @@ class LLM:
if "kv_transfer_config" in kwargs and isinstance( if "kv_transfer_config" in kwargs and isinstance(
kwargs["kv_transfer_config"], dict): kwargs["kv_transfer_config"], dict):
from vllm.config import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
raw_config_dict = kwargs["kv_transfer_config"] raw_config_dict = kwargs["kv_transfer_config"]
try: try:
kwargs["kv_transfer_config"] = KVTransferConfig( kwargs["kv_transfer_config"] = KVTransferConfig(
...@@ -251,6 +260,7 @@ class LLM: ...@@ -251,6 +260,7 @@ class LLM:
tokenizer_revision=tokenizer_revision, tokenizer_revision=tokenizer_revision,
seed=seed, seed=seed,
gpu_memory_utilization=gpu_memory_utilization, gpu_memory_utilization=gpu_memory_utilization,
kv_cache_memory_bytes=kv_cache_memory_bytes,
swap_space=swap_space, swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb, cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
...@@ -796,7 +806,7 @@ class LLM: ...@@ -796,7 +806,7 @@ class LLM:
# NOTE: _parse_chat_message_content_parts() currently doesn't # NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in # handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it. # the chat message parsing for it.
conversation, mm_data = parse_chat_messages( conversation, mm_data, mm_uuids = parse_chat_messages(
msgs, msgs,
model_config, model_config,
tokenizer, tokenizer,
...@@ -826,6 +836,9 @@ class LLM: ...@@ -826,6 +836,9 @@ class LLM:
if mm_data is not None: if mm_data is not None:
prompt["multi_modal_data"] = mm_data prompt["multi_modal_data"] = mm_data
if mm_uuids is not None:
prompt["multi_modal_uuids"] = mm_uuids
if mm_processor_kwargs is not None: if mm_processor_kwargs is not None:
prompt["mm_processor_kwargs"] = mm_processor_kwargs prompt["mm_processor_kwargs"] = mm_processor_kwargs
......
...@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request): ...@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
@router.get("/v1/responses/{response_id}") @router.get("/v1/responses/{response_id}")
async def retrieve_responses(response_id: str, raw_request: Request): async def retrieve_responses(
response_id: str,
raw_request: Request,
starting_after: Optional[int] = None,
stream: Optional[bool] = False,
):
handler = responses(raw_request) handler = responses(raw_request)
if handler is None: if handler is None:
return base(raw_request).create_error_response( return base(raw_request).create_error_response(
message="The model does not support Responses API") message="The model does not support Responses API")
try: try:
response = await handler.retrieve_responses(response_id) response = await handler.retrieve_responses(
response_id,
starting_after=starting_after,
stream=stream,
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
detail=str(e)) from e detail=str(e)) from e
...@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request): ...@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request):
if isinstance(response, ErrorResponse): if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(), return JSONResponse(content=response.model_dump(),
status_code=response.error.code) status_code=response.error.code)
elif stream:
return StreamingResponse(content=response,
media_type="text/event-stream")
return JSONResponse(content=response.model_dump()) return JSONResponse(content=response.model_dump())
...@@ -1705,6 +1717,8 @@ async def init_app_state( ...@@ -1705,6 +1717,8 @@ async def init_app_state(
if args.tool_server == "demo": if args.tool_server == "demo":
tool_server: Optional[ToolServer] = DemoToolServer() tool_server: Optional[ToolServer] = DemoToolServer()
assert isinstance(tool_server, DemoToolServer)
await tool_server.init_and_validate()
elif args.tool_server: elif args.tool_server:
tool_server = MCPToolServer() tool_server = MCPToolServer()
await tool_server.add_tool_server(args.tool_server) await tool_server.add_tool_server(args.tool_server)
......
...@@ -134,14 +134,13 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" ...@@ -134,14 +134,13 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
"""If specified, will run the OpenAI frontend server in the same process as """If specified, will run the OpenAI frontend server in the same process as
the model serving engine.""" the model serving engine."""
enable_request_id_headers: bool = False enable_request_id_headers: bool = False
"""If specified, API server will add X-Request-Id header to responses. """If specified, API server will add X-Request-Id header to responses."""
Caution: this hurts performance at high QPS."""
enable_auto_tool_choice: bool = False enable_auto_tool_choice: bool = False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
exclude_tools_when_tool_choice_none: bool = False
"""Enable auto tool choice for supported models. Use `--tool-call-parser` """Enable auto tool choice for supported models. Use `--tool-call-parser`
to specify which parser to use.""" to specify which parser to use."""
exclude_tools_when_tool_choice_none: bool = False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
tool_call_parser: Optional[str] = None tool_call_parser: Optional[str] = None
"""Select the tool call parser depending on the model that you're using. """Select the tool call parser depending on the model that you're using.
This is used to parse the model-generated tool call into OpenAI API format. This is used to parse the model-generated tool call into OpenAI API format.
...@@ -204,7 +203,7 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`""" ...@@ -204,7 +203,7 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
frontend_kwargs["lora_modules"]["type"] = optional_type(str) frontend_kwargs["lora_modules"]["type"] = optional_type(str)
frontend_kwargs["lora_modules"]["action"] = LoRAParserAction frontend_kwargs["lora_modules"]["action"] = LoRAParserAction
# Special case: Middleware needs append action # Special case: Middleware needs to append action
frontend_kwargs["middleware"]["action"] = "append" frontend_kwargs["middleware"]["action"] = "append"
frontend_kwargs["middleware"]["type"] = str frontend_kwargs["middleware"]["type"] = str
if "nargs" in frontend_kwargs["middleware"]: if "nargs" in frontend_kwargs["middleware"]:
......
...@@ -43,10 +43,10 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, ...@@ -43,10 +43,10 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
from vllm.entrypoints.score_utils import (ScoreContentPartParam, from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam) ScoreMultiModalParam)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams) RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid, resolve_obj_by_qualname from vllm.utils import random_uuid, resolve_obj_by_qualname
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -1270,9 +1270,20 @@ class CompletionRequest(OpenAIBaseModel): ...@@ -1270,9 +1270,20 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def validate_prompt_and_prompt_embeds(cls, data): def validate_prompt_and_prompt_embeds(cls, data):
if data.get("prompt") is None and data.get("prompt_embeds") is None: prompt = data.get("prompt")
prompt_embeds = data.get("prompt_embeds")
prompt_is_empty = (prompt is None
or (isinstance(prompt, str) and prompt == ""))
embeds_is_empty = (prompt_embeds is None
or (isinstance(prompt_embeds, list)
and len(prompt_embeds) == 0))
if prompt_is_empty and embeds_is_empty:
raise ValueError( raise ValueError(
"At least one of `prompt` or `prompt_embeds` must be set.") "Either prompt or prompt_embeds must be provided and non-empty."
)
return data return data
@model_validator(mode="before") @model_validator(mode="before")
...@@ -1342,6 +1353,14 @@ class EmbeddingChatRequest(OpenAIBaseModel): ...@@ -1342,6 +1353,14 @@ class EmbeddingChatRequest(OpenAIBaseModel):
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
# --8<-- [start:chat-embedding-extra-params] # --8<-- [start:chat-embedding-extra-params]
add_generation_prompt: bool = Field(
default=False,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
add_special_tokens: bool = Field( add_special_tokens: bool = Field(
default=False, default=False,
description=( description=(
...@@ -1424,9 +1443,10 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): ...@@ -1424,9 +1443,10 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
When using plugins IOProcessor plugins, the actual input is processed When using plugins IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data by the plugin itself. Hence, we use a generic type for the request data
""" """
softmax: bool = True
def to_pooling_params(self): def to_pooling_params(self):
return PoolingParams(task="encode") return PoolingParams(task="encode", softmax=self.softmax)
class IOProcessorResponse(OpenAIBaseModel, Generic[T]): class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
...@@ -1832,7 +1852,8 @@ class InputTokensDetails(OpenAIBaseModel): ...@@ -1832,7 +1852,8 @@ class InputTokensDetails(OpenAIBaseModel):
class OutputTokensDetails(OpenAIBaseModel): class OutputTokensDetails(OpenAIBaseModel):
reasoning_tokens: int reasoning_tokens: int = 0
tool_output_tokens: int = 0
class ResponseUsage(OpenAIBaseModel): class ResponseUsage(OpenAIBaseModel):
...@@ -2175,6 +2196,13 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -2175,6 +2196,13 @@ class TranscriptionRequest(OpenAIBaseModel):
) )
# --8<-- [end:transcription-extra-params] # --8<-- [end:transcription-extra-params]
to_language: Optional[str] = None
"""The language of the output audio we transcribe to.
Please note that this is not currently used by supported models at this
time, but it is a placeholder for future use, matching translation api.
"""
# --8<-- [start:transcription-sampling-params] # --8<-- [start:transcription-sampling-params]
temperature: float = Field(default=0.0) temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1. """The sampling temperature, between 0 and 1.
...@@ -2408,6 +2436,9 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -2408,6 +2436,9 @@ class TranslationRequest(OpenAIBaseModel):
# TODO support additional sampling parameters # TODO support additional sampling parameters
# --8<-- [start:translation-sampling-params] # --8<-- [start:translation-sampling-params]
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
"""The seed to use for sampling."""
temperature: float = Field(default=0.0) temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1. """The sampling temperature, between 0 and 1.
...@@ -2427,6 +2458,14 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -2427,6 +2458,14 @@ class TranslationRequest(OpenAIBaseModel):
will improve accuracy. will improve accuracy.
""" """
to_language: Optional[str] = None
"""The language of the input audio we translate to.
Please note that this is not supported by all models, refer to the specific
model documentation for more details.
For instance, Whisper only supports `to_language=en`.
"""
stream: Optional[bool] = False stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set, """Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat it will enable output to be streamed in a similar fashion as the Chat
...@@ -2458,6 +2497,7 @@ class TranslationRequest(OpenAIBaseModel): ...@@ -2458,6 +2497,7 @@ class TranslationRequest(OpenAIBaseModel):
return SamplingParams.from_optional(temperature=temperature, return SamplingParams.from_optional(temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
seed=self.seed,
output_kind=RequestOutputKind.DELTA output_kind=RequestOutputKind.DELTA
if self.stream \ if self.stream \
else RequestOutputKind.FINAL_ONLY) else RequestOutputKind.FINAL_ONLY)
......
...@@ -161,7 +161,7 @@ async def write_local_file(output_path: str, ...@@ -161,7 +161,7 @@ async def write_local_file(output_path: str,
batch_outputs: The list of batch outputs to write. batch_outputs: The list of batch outputs to write.
""" """
# We should make this async, but as long as run_batch runs as a # We should make this async, but as long as run_batch runs as a
# standalone program, blocking the event loop won't effect performance. # standalone program, blocking the event loop won't affect performance.
with open(output_path, "w", encoding="utf-8") as f: with open(output_path, "w", encoding="utf-8") as f:
for o in batch_outputs: for o in batch_outputs:
print(o.model_dump_json(), file=f) print(o.model_dump_json(), file=f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment