Commit 0da93439 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.1rc0' into v0.18.1rc0-ori

parents 25f2f756 298e5108
...@@ -27,6 +27,7 @@ from openai.types.responses import ( ...@@ -27,6 +27,7 @@ from openai.types.responses import (
ResponseReasoningTextDeltaEvent, ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent, ResponseReasoningTextDoneEvent,
ResponseStatus, ResponseStatus,
ResponseTextConfig,
ResponseWebSearchCallCompletedEvent, ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent, ResponseWebSearchCallSearchingEvent,
...@@ -38,20 +39,13 @@ from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreated ...@@ -38,20 +39,13 @@ from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreated
from openai.types.responses import ( from openai.types.responses import (
ResponseInProgressEvent as OpenAIResponseInProgressEvent, ResponseInProgressEvent as OpenAIResponseInProgressEvent,
) )
from openai.types.responses.tool import Tool
from openai_harmony import Message as OpenAIHarmonyMessage
# Backward compatibility for OpenAI client versions
try: # For older openai versions (< 1.100.0)
from openai.types.responses import ResponseTextConfig
except ImportError: # For newer openai versions (>= 1.100.0)
from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
from openai.types.responses.response import IncompleteDetails, ToolChoice from openai.types.responses.response import IncompleteDetails, ToolChoice
from openai.types.responses.response_reasoning_item import ( from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent, Content as ResponseReasoningTextContent,
) )
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning from openai.types.shared import Metadata, Reasoning
from openai_harmony import Message as OpenAIHarmonyMessage
from pydantic import ( from pydantic import (
Field, Field,
ValidationError, ValidationError,
...@@ -258,6 +252,10 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -258,6 +252,10 @@ class ResponsesRequest(OpenAIBaseModel):
"numeric values, used by custom extensions." "numeric values, used by custom extensions."
), ),
) )
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
# --8<-- [end:responses-extra-params] # --8<-- [end:responses-extra-params]
def build_chat_params( def build_chat_params(
...@@ -357,6 +355,10 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -357,6 +355,10 @@ class ResponsesRequest(OpenAIBaseModel):
if isinstance(stop, str): if isinstance(stop, str):
stop = [stop] stop = [stop]
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional( return SamplingParams.from_optional(
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
...@@ -373,7 +375,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -373,7 +375,7 @@ class ResponsesRequest(OpenAIBaseModel):
), ),
structured_outputs=structured_outputs, structured_outputs=structured_outputs,
logit_bias=self.logit_bias, logit_bias=self.logit_bias,
extra_args=self.vllm_xargs or {}, extra_args=extra_args,
skip_clone=True, # Created fresh per request, safe to skip clone skip_clone=True, # Created fresh per request, safe to skip clone
skip_special_tokens=self.skip_special_tokens, skip_special_tokens=self.skip_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
...@@ -494,6 +496,11 @@ class ResponsesResponse(OpenAIBaseModel): ...@@ -494,6 +496,11 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None usage: ResponseUsage | None = None
user: str | None = None user: str | None = None
# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: dict[str, Any] | None = Field(
default=None, description="KVTransfer parameters."
)
# --8<-- [start:responses-response-extra-params] # --8<-- [start:responses-response-extra-params]
# These are populated when enable_response_messages is set to True # These are populated when enable_response_messages is set to True
# NOTE: custom serialization is needed # NOTE: custom serialization is needed
...@@ -537,6 +544,7 @@ class ResponsesResponse(OpenAIBaseModel): ...@@ -537,6 +544,7 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None, usage: ResponseUsage | None = None,
input_messages: ResponseInputOutputMessage | None = None, input_messages: ResponseInputOutputMessage | None = None,
output_messages: ResponseInputOutputMessage | None = None, output_messages: ResponseInputOutputMessage | None = None,
kv_transfer_params: dict[str, Any] | None = None,
) -> "ResponsesResponse": ) -> "ResponsesResponse":
incomplete_details: IncompleteDetails | None = None incomplete_details: IncompleteDetails | None = None
if status == "incomplete": if status == "incomplete":
...@@ -572,6 +580,7 @@ class ResponsesResponse(OpenAIBaseModel): ...@@ -572,6 +580,7 @@ class ResponsesResponse(OpenAIBaseModel):
truncation=request.truncation, truncation=request.truncation,
user=request.user, user=request.user,
usage=usage, usage=usage,
kv_transfer_params=kv_transfer_params,
) )
......
...@@ -5,11 +5,11 @@ import asyncio ...@@ -5,11 +5,11 @@ import asyncio
import time import time
import uuid import uuid
from collections import deque from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping, Sequence
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from copy import copy from copy import copy
from http import HTTPStatus from http import HTTPStatus
from typing import Final from typing import Any, Final
from fastapi import Request from fastapi import Request
from openai.types.responses import ( from openai.types.responses import (
...@@ -46,6 +46,7 @@ from vllm.engine.protocol import EngineClient ...@@ -46,6 +46,7 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
get_tool_call_id_type,
) )
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer from vllm.entrypoints.mcp.tool_server import ToolServer
...@@ -86,6 +87,7 @@ from vllm.entrypoints.openai.responses.protocol import ( ...@@ -86,6 +87,7 @@ from vllm.entrypoints.openai.responses.protocol import (
ResponseCompletedEvent, ResponseCompletedEvent,
ResponseCreatedEvent, ResponseCreatedEvent,
ResponseInProgressEvent, ResponseInProgressEvent,
ResponseInputOutputItem,
ResponseInputOutputMessage, ResponseInputOutputMessage,
ResponseReasoningPartAddedEvent, ResponseReasoningPartAddedEvent,
ResponseReasoningPartDoneEvent, ResponseReasoningPartDoneEvent,
...@@ -105,16 +107,19 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -105,16 +107,19 @@ from vllm.entrypoints.openai.responses.utils import (
construct_tool_dicts, construct_tool_dicts,
extract_tool_types, extract_tool_types,
) )
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs, token_inputs from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
...@@ -165,6 +170,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -165,6 +170,7 @@ class OpenAIServingResponses(OpenAIServing):
self, self,
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
...@@ -185,6 +191,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -185,6 +191,7 @@ class OpenAIServingResponses(OpenAIServing):
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
) )
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
...@@ -235,15 +242,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -235,15 +242,7 @@ class OpenAIServingResponses(OpenAIServing):
get_stop_tokens_for_assistant_actions() get_stop_tokens_for_assistant_actions()
) )
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides) self.tool_call_id_type = get_tool_call_id_type(self.model_config)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.enable_auto_tools = enable_auto_tools self.enable_auto_tools = enable_auto_tools
# HACK(woosuk): This is a hack. We should use a better store. # HACK(woosuk): This is a hack. We should use a better store.
...@@ -587,7 +586,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -587,7 +586,7 @@ class OpenAIServingResponses(OpenAIServing):
prev_response_output=prev_response.output if prev_response else None, prev_response_output=prev_response.output if prev_response else None,
) )
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self.openai_serving_render.preprocess_chat(
request, request,
messages, messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -598,6 +597,109 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -598,6 +597,109 @@ class OpenAIServingResponses(OpenAIServing):
) )
return messages, engine_prompts return messages, engine_prompts
async def _render_next_turn(
    self,
    request: ResponsesRequest,
    messages: list[ResponseInputOutputItem],
    tool_dicts: list[dict[str, Any]] | None,
    tool_parser: Callable[[TokenizerLike], ToolParser] | None,
    chat_template: str | None,
    chat_template_content_format: ChatTemplateContentFormatOption,
):
    """Render the engine prompts for a follow-up turn.

    The response items in ``messages`` are first converted with
    ``construct_input_messages`` and then passed to
    ``self.openai_serving_render.preprocess_chat`` together with the
    template and tool settings supplied by the caller.

    Args:
        request: The Responses API request the turn belongs to.
        messages: Response input/output items to convert into chat messages.
        tool_dicts: Tool definitions forwarded to the renderer, if any.
        tool_parser: Tool parser factory forwarded to the renderer, if any.
        chat_template: Forwarded as the renderer's ``default_template``.
        chat_template_content_format: Forwarded as the renderer's
            ``default_template_content_format``.

    Returns:
        The second element of the pair returned by ``preprocess_chat``
        (the engine prompts).
    """
    chat_messages = construct_input_messages(request_input=messages)

    # No per-call template kwargs are supplied for follow-up turns.
    render_options = {
        "default_template": chat_template,
        "default_template_content_format": chat_template_content_format,
        "default_template_kwargs": None,
        "tool_dicts": tool_dicts,
        "tool_parser": tool_parser,
    }
    _conversation, next_engine_prompts = (
        await self.openai_serving_render.preprocess_chat(
            request,
            chat_messages,
            **render_options,
        )
    )
    return next_engine_prompts
async def _generate_with_builtin_tools(
    self,
    request_id: str,
    engine_prompt: ProcessorInputs,
    sampling_params: SamplingParams,
    context: ConversationContext,
    lora_request: LoRARequest | None = None,
    priority: int = 0,
    trace_headers: Mapping[str, str] | None = None,
):
    """Run generation turns until the context stops asking for a built-in tool.

    Each turn is sent to the engine under the id ``f"{request_id}_{turn}"``.
    Every engine output is appended to ``context``, and ``context`` is
    yielded after each append. When a turn ends and the context requests a
    built-in tool call, the tool is invoked, its output is appended, and a
    new prompt is rendered for the next turn.

    Note that ``sampling_params.max_tokens`` is overwritten in place before
    each follow-up turn.

    Args:
        request_id: Base id; a ``_<turn index>`` suffix is added per turn.
        engine_prompt: Prompt for the first turn.
        sampling_params: Sampling parameters, mutated between turns.
        context: Conversation context receiving outputs and tool results.
        lora_request: Optional LoRA request forwarded to the engine.
        priority: Priority of the first turn.
        trace_headers: Optional tracing headers forwarded to the engine.

    Yields:
        ``context`` itself, once per engine output.
    """
    model_len_limit = self.model_config.max_model_len
    initial_priority = priority
    turn_priority = priority
    turn_index = 0

    while True:
        # A distinct id per turn keeps sub-requests from colliding.
        turn_request_id = f"{request_id}_{turn_index}"

        self._log_inputs(
            turn_request_id,
            engine_prompt,
            params=sampling_params,
            lora_request=lora_request,
        )

        output_stream = self.engine_client.generate(
            engine_prompt,
            sampling_params,
            turn_request_id,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=turn_priority,
        )
        async for engine_output in output_stream:
            context.append_output(engine_output)
            # NOTE(woosuk): The stop condition is handled by the engine.
            yield context

        if not context.need_builtin_tool_call():
            # No tool was requested, so the conversation is finished.
            return

        # Execute the requested tool and record its result in the context.
        tool_result = await context.call_tool()
        context.append_tool_output(tool_result)

        # TODO: uncomment this and enable tool output streaming
        # yield context

        # Prepare the prompt and token budget for the next turn. A context
        # matching neither branch keeps the previous prompt and budget.
        if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
            next_token_ids = context.render_for_completion()
            engine_prompt = token_inputs(next_token_ids)
            sampling_params.max_tokens = model_len_limit - len(next_token_ids)
        elif isinstance(context, ParsableContext):
            # Exactly one rendered prompt is expected here.
            (engine_prompt,) = await self._render_next_turn(
                context.request,
                context.parser.response_messages,
                context.tool_dicts,
                context.tool_parser_cls,
                context.chat_template,
                context.chat_template_content_format,
            )
            sampling_params.max_tokens = get_max_tokens(
                model_len_limit,
                context.request.max_output_tokens,
                self._extract_prompt_len(engine_prompt),
                self.default_sampling_params,  # type: ignore
                self.override_max_tokens,  # type: ignore
            )

        # OPTIMIZATION: every follow-up turn runs at the initial priority
        # minus one (presumably scheduled earlier -- confirm against the
        # engine's priority semantics).
        turn_priority = initial_priority - 1
        turn_index += 1
def _make_request_with_harmony( def _make_request_with_harmony(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
...@@ -771,6 +873,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -771,6 +873,7 @@ class OpenAIServingResponses(OpenAIServing):
output=output, output=output,
status=status, status=status,
usage=usage, usage=usage,
kv_transfer_params=context.kv_transfer_params,
) )
if request.store: if request.store:
...@@ -903,6 +1006,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -903,6 +1006,7 @@ class OpenAIServingResponses(OpenAIServing):
parser = self.parser(tokenizer) parser = self.parser(tokenizer)
return parser.extract_response_outputs( return parser.extract_response_outputs(
model_output=final_output.text, model_output=final_output.text,
model_output_token_ids=final_output.token_ids,
request=request, request=request,
enable_auto_tools=self.enable_auto_tools, enable_auto_tools=self.enable_auto_tools,
tool_call_id_type=self.tool_call_id_type, tool_call_id_type=self.tool_call_id_type,
......
...@@ -191,13 +191,13 @@ def _construct_single_message_from_response_item( ...@@ -191,13 +191,13 @@ def _construct_single_message_from_response_item(
], ],
) )
elif isinstance(item, ResponseReasoningItem): elif isinstance(item, ResponseReasoningItem):
reasoning_content = "" reasoning = ""
if item.encrypted_content: if item.encrypted_content:
raise ValueError("Encrypted content is not supported.") raise ValueError("Encrypted content is not supported.")
elif item.content and len(item.content) >= 1: elif item.content and len(item.content) >= 1:
reasoning_content = item.content[0].text reasoning = item.content[0].text
elif len(item.summary) >= 1: elif len(item.summary) >= 1:
reasoning_content = item.summary[0].text reasoning = item.summary[0].text
logger.warning( logger.warning(
"Using summary text as reasoning content for item %s. " "Using summary text as reasoning content for item %s. "
"Please use content instead of summary for " "Please use content instead of summary for "
...@@ -206,7 +206,7 @@ def _construct_single_message_from_response_item( ...@@ -206,7 +206,7 @@ def _construct_single_message_from_response_item(
) )
return { return {
"role": "assistant", "role": "assistant",
"reasoning": reasoning_content, "reasoning": reasoning,
} }
elif isinstance(item, ResponseOutputMessage): elif isinstance(item, ResponseOutputMessage):
return { return {
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio import asyncio
import base64
import sys import sys
import tempfile import tempfile
from argparse import Namespace from argparse import Namespace
...@@ -13,6 +12,7 @@ from typing import Any, TypeAlias ...@@ -13,6 +12,7 @@ from typing import Any, TypeAlias
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
import pybase64 as base64
import torch import torch
from fastapi import UploadFile from fastapi import UploadFile
from prometheus_client import start_http_server from prometheus_client import start_http_server
...@@ -54,6 +54,7 @@ from vllm.entrypoints.pooling.score.protocol import ( ...@@ -54,6 +54,7 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreResponse, ScoreResponse,
) )
from vllm.entrypoints.utils import create_error_response from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.utils import random_uuid from vllm.utils import random_uuid
...@@ -86,9 +87,10 @@ class BatchTranscriptionRequest(TranscriptionRequest): ...@@ -86,9 +87,10 @@ class BatchTranscriptionRequest(TranscriptionRequest):
def validate_no_file(cls, data: Any): def validate_no_file(cls, data: Any):
"""Ensure file field is not provided in batch requests.""" """Ensure file field is not provided in batch requests."""
if isinstance(data, dict) and "file" in data: if isinstance(data, dict) and "file" in data:
raise ValueError( raise VLLMValidationError(
"The 'file' field is not supported in batch requests. " "The 'file' field is not supported in batch requests. "
"Use 'file_url' instead." "Use 'file_url' instead.",
parameter="file",
) )
return data return data
...@@ -116,9 +118,10 @@ class BatchTranslationRequest(TranslationRequest): ...@@ -116,9 +118,10 @@ class BatchTranslationRequest(TranslationRequest):
def validate_no_file(cls, data: Any): def validate_no_file(cls, data: Any):
"""Ensure file field is not provided in batch requests.""" """Ensure file field is not provided in batch requests."""
if isinstance(data, dict) and "file" in data: if isinstance(data, dict) and "file" in data:
raise ValueError( raise VLLMValidationError(
"The 'file' field is not supported in batch requests. " "The 'file' field is not supported in batch requests. "
"Use 'file_url' instead." "Use 'file_url' instead.",
parameter="file",
) )
return data return data
...@@ -820,7 +823,6 @@ async def main(args: Namespace): ...@@ -820,7 +823,6 @@ async def main(args: Namespace):
async with build_async_engine_client( async with build_async_engine_client(
args, args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER, usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False,
) as engine_client: ) as engine_client:
await run_batch(engine_client, args) await run_batch(engine_client, args)
......
...@@ -371,7 +371,7 @@ async def generation_error_handler(req: Request, exc: GenerationError): ...@@ -371,7 +371,7 @@ async def generation_error_handler(req: Request, exc: GenerationError):
async def exception_handler(req: Request, exc: Exception): async def exception_handler(req: Request, exc: Exception):
if req.app.state.args.log_error_stack: if req.app.state.args.log_error_stack:
logger.exception( logger.error(
"Exception caught. Request id: %s", "Exception caught. Request id: %s",
req.state.request_metadata.request_id req.state.request_metadata.request_id
if hasattr(req.state, "request_metadata") if hasattr(req.state, "request_metadata")
......
...@@ -107,7 +107,7 @@ class TranscriptionRequest(OpenAIBaseModel): ...@@ -107,7 +107,7 @@ class TranscriptionRequest(OpenAIBaseModel):
stream_include_usage: bool | None = False stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False stream_continuous_usage_stats: bool | None = False
vllm_xargs: dict[str, str | int | float] | None = Field( vllm_xargs: dict[str, str | int | float | bool] | None = Field(
default=None, default=None,
description=( description=(
"Additional request parameters with string or " "Additional request parameters with string or "
......
...@@ -42,32 +42,13 @@ from vllm.inputs import EncoderDecoderInputs, ProcessorInputs ...@@ -42,32 +42,13 @@ from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription from vllm.model_executor.models import SupportsTranscription
from vllm.multimodal.audio import split_audio from vllm.multimodal.audio import get_audio_duration, split_audio
from vllm.multimodal.media.audio import extract_audio_from_video_bytes from vllm.multimodal.media.audio import load_audio
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import get_tokenizer from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile as sf
except ImportError:
sf = PlaceholderModule("soundfile") # type: ignore[assignment]
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
# being librosa's main backend. Used to validate if an audio loading error is due to a
# server error vs a client error (invalid audio file).
# 1 = unrecognised format (file is not a supported audio container)
# 3 = malformed file (corrupt or structurally invalid audio)
# 4 = unsupported encoding (codec not supported by this libsndfile build)
_BAD_SF_CODES = {1, 3, 4}
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = ( SpeechToTextResponseVerbose: TypeAlias = (
...@@ -214,32 +195,13 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -214,32 +195,13 @@ class OpenAISpeechToText(OpenAIServing):
# pre-requisite for chunking, as it assumes Whisper SR. # pre-requisite for chunking, as it assumes Whisper SR.
try: try:
with io.BytesIO(audio_data) as buf: with io.BytesIO(audio_data) as buf:
y, sr = librosa.load(buf, sr=self.asr_config.sample_rate) # type: ignore[return-value] y, sr = load_audio(buf, sr=self.asr_config.sample_rate)
except sf.LibsndfileError as exc: except Exception as exc:
# Only fall back for known format-detection failures. raise ValueError("Invalid or unsupported audio file.") from exc
# Re-raise anything else (e.g. corrupt but recognised format).
if exc.code not in _BAD_SF_CODES:
raise
logger.debug(
"librosa/soundfile could not decode audio from BytesIO "
"(code=%s: %s); falling back to pyav in-process decode",
exc.code,
exc,
)
try:
native_y, native_sr = extract_audio_from_video_bytes(audio_data)
sr = self.asr_config.sample_rate
y = librosa.resample(native_y, orig_sr=native_sr, target_sr=sr)
except Exception as pyav_exc:
logger.debug(
"pyAV fallback also failed: %s",
pyav_exc,
)
raise ValueError("Invalid or unsupported audio file.") from pyav_exc
duration = librosa.get_duration(y=y, sr=sr) duration = get_audio_duration(y=y, sr=sr)
do_split_audio = ( do_split_audio = self.asr_config.allow_audio_chunking and (
self.asr_config.allow_audio_chunking self.asr_config.max_audio_clip_s is not None
and duration > self.asr_config.max_audio_clip_s and duration > self.asr_config.max_audio_clip_s
) )
......
...@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING ...@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING
from fastapi import FastAPI from fastapi import FastAPI
from vllm.config import ModelConfig
from vllm.logger import init_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from argparse import Namespace from argparse import Namespace
...@@ -17,9 +20,30 @@ else: ...@@ -17,9 +20,30 @@ else:
RequestLogger = object RequestLogger = object
SupportedTask = object SupportedTask = object
logger = init_logger(__name__)
def enable_scoring_api(
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> bool:
    """Decide whether the Score API should be enabled.

    Args:
        supported_tasks: Tasks supported by the served model.
        model_config: Optional model configuration; only consulted when
            "classify" is among the supported tasks.

    Returns:
        True when "embed" or "token_embed" is supported, or when
        "classify" is supported, a model config is given, and its
        ``hf_config.num_labels`` (defaulting to 0 when absent) equals 1.
        False otherwise.
    """
    if "embed" in supported_tasks or "token_embed" in supported_tasks:
        return True

    # Without a model config the classify path cannot be validated.
    if model_config is None or "classify" not in supported_tasks:
        return False

    label_count = getattr(model_config.hf_config, "num_labels", 0)
    if label_count != 1:
        logger.debug_once("Score API is only enabled for num_labels == 1.")
        return False
    return True
def register_pooling_api_routers( def register_pooling_api_routers(
app: FastAPI, supported_tasks: tuple["SupportedTask", ...] app: FastAPI,
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
): ):
from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
...@@ -37,11 +61,7 @@ def register_pooling_api_routers( ...@@ -37,11 +61,7 @@ def register_pooling_api_routers(
app.include_router(embed_router) app.include_router(embed_router)
# Score API handles score/rerank for: if enable_scoring_api(supported_tasks, model_config):
# - "score" task (score_type: cross-encoder models)
# - "embed" task (score_type: bi-encoder models)
# - "token_embed" task (score_type: late interaction models)
if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
from vllm.entrypoints.pooling.score.api_router import router as score_router from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(score_router) app.include_router(score_router)
...@@ -61,6 +81,8 @@ def init_pooling_state( ...@@ -61,6 +81,8 @@ def init_pooling_state(
from vllm.entrypoints.pooling.score.serving import ServingScores from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.tasks import POOLING_TASKS from vllm.tasks import POOLING_TASKS
model_config = engine_client.model_config
resolved_chat_template = load_chat_template(args.chat_template) resolved_chat_template = load_chat_template(args.chat_template)
state.serving_pooling = ( state.serving_pooling = (
...@@ -68,6 +90,7 @@ def init_pooling_state( ...@@ -68,6 +90,7 @@ def init_pooling_state(
OpenAIServingPooling( OpenAIServingPooling(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
state.openai_serving_render,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
...@@ -101,10 +124,6 @@ def init_pooling_state( ...@@ -101,10 +124,6 @@ def init_pooling_state(
if "classify" in supported_tasks if "classify" in supported_tasks
else None else None
) )
# Score API handles score/rerank for:
# - "score" task (score_type: cross-encoder models)
# - "embed" task (score_type: bi-encoder models)
# - "token_embed" task (score_type: late interaction models)
state.serving_scores = ( state.serving_scores = (
ServingScores( ServingScores(
engine_client, engine_client,
...@@ -113,6 +132,6 @@ def init_pooling_state( ...@@ -113,6 +132,6 @@ def init_pooling_state(
score_template=resolved_chat_template, score_template=resolved_chat_template,
log_error_stack=args.log_error_stack, log_error_stack=args.log_error_stack,
) )
if any(t in supported_tasks for t in ("embed", "score", "token_embed")) if enable_scoring_api(supported_tasks, model_config)
else None else None
) )
...@@ -11,6 +11,7 @@ from vllm.entrypoints.chat_utils import ( ...@@ -11,6 +11,7 @@ from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption, ChatTemplateContentFormatOption,
) )
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError
from vllm.renderers import ChatParams, merge_kwargs from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
...@@ -147,9 +148,9 @@ class ChatRequestMixin(OpenAIBaseModel): ...@@ -147,9 +148,9 @@ class ChatRequestMixin(OpenAIBaseModel):
@classmethod @classmethod
def check_generation_prompt(cls, data): def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get("add_generation_prompt"): if data.get("continue_final_message") and data.get("add_generation_prompt"):
raise ValueError( raise VLLMValidationError(
"Cannot set both `continue_final_message` and " "Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True." "`add_generation_prompt` to True.",
) )
return data return data
......
...@@ -6,13 +6,13 @@ OpenAI: https://platform.openai.com/docs/api-reference/embeddings ...@@ -6,13 +6,13 @@ OpenAI: https://platform.openai.com/docs/api-reference/embeddings
Cohere: https://docs.cohere.com/reference/embed Cohere: https://docs.cohere.com/reference/embed
""" """
import base64
import builtins import builtins
import struct import struct
import time import time
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, TypeAlias from typing import Literal, TypeAlias
import pybase64 as base64
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from vllm import PoolingParams from vllm import PoolingParams
......
...@@ -23,7 +23,7 @@ def init_pooling_io_processors( ...@@ -23,7 +23,7 @@ def init_pooling_io_processors(
if "embed" in supported_tasks: if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
processors.append(("classify", EmbedIOProcessor)) processors.append(("embed", EmbedIOProcessor))
return { return {
task: processor_cls( task: processor_cls(
......
...@@ -32,6 +32,7 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -32,6 +32,7 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64, encode_pooling_output_base64,
encode_pooling_output_float, encode_pooling_output_float,
) )
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.inputs import ProcessorInputs from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
...@@ -47,6 +48,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -47,6 +48,7 @@ class OpenAIServingPooling(OpenAIServing):
self, self,
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
...@@ -59,6 +61,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -59,6 +61,7 @@ class OpenAIServingPooling(OpenAIServing):
request_logger=request_logger, request_logger=request_logger,
) )
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template self.trust_request_chat_template = trust_request_chat_template
...@@ -101,12 +104,12 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -101,12 +104,12 @@ class OpenAIServingPooling(OpenAIServing):
raw_prompts = await self.io_processor.pre_process_async( raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id prompt=validated_prompt, request_id=request_id
) )
engine_prompts = await self._preprocess_cmpl( engine_prompts = await self.openai_serving_render.preprocess_cmpl(
request, request,
prompt_to_seq(raw_prompts), prompt_to_seq(raw_prompts),
) )
elif isinstance(request, PoolingChatRequest): elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template( error_check_ret = self.openai_serving_render.validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs, chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template, trust_request_chat_template=self.trust_request_chat_template,
...@@ -114,7 +117,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -114,7 +117,7 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self.openai_serving_render.preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -122,7 +125,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -122,7 +125,7 @@ class OpenAIServingPooling(OpenAIServing):
default_template_kwargs=None, default_template_kwargs=None,
) )
elif isinstance(request, PoolingCompletionRequest): elif isinstance(request, PoolingCompletionRequest):
engine_prompts = await self._preprocess_completion( engine_prompts = await self.openai_serving_render.preprocess_completion(
request, request,
prompt_input=request.input, prompt_input=request.input,
prompt_embeds=None, prompt_embeds=None,
......
...@@ -35,7 +35,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): ...@@ -35,7 +35,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens_param="max_model_len", max_total_tokens_param="max_model_len",
) )
def to_pooling_params(self, task: PoolingTask = "score"): def to_pooling_params(self, task: PoolingTask = "classify"):
return PoolingParams( return PoolingParams(
task=task, task=task,
use_activation=self.use_activation, use_activation=self.use_activation,
...@@ -111,7 +111,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): ...@@ -111,7 +111,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens_param="max_model_len", max_total_tokens_param="max_model_len",
) )
def to_pooling_params(self, task: PoolingTask = "score"): def to_pooling_params(self, task: PoolingTask = "classify"):
return PoolingParams( return PoolingParams(
task=task, task=task,
use_activation=self.use_activation, use_activation=self.use_activation,
......
...@@ -413,7 +413,7 @@ class ServingScores(OpenAIServing): ...@@ -413,7 +413,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
default_pooling_params = request.to_pooling_params("score") default_pooling_params = request.to_pooling_params("classify")
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
......
...@@ -60,14 +60,6 @@ def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]: ...@@ -60,14 +60,6 @@ def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
return output.outputs.data.tolist() return output.outputs.data.tolist()
def encode_pooling_output_binary(
output: PoolingRequestOutput,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> bytes:
return tensor2binary(output.outputs.data, embed_dtype, endianness)
def encode_pooling_output_base64( def encode_pooling_output_base64(
output: PoolingRequestOutput, output: PoolingRequestOutput,
embed_dtype: EmbedDType, embed_dtype: EmbedDType,
......
...@@ -10,9 +10,11 @@ import pydantic ...@@ -10,9 +10,11 @@ import pydantic
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response from fastapi.responses import JSONResponse, Response
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling import enable_scoring_api
from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health from vllm.entrypoints.serve.instrumentator.health import health
...@@ -25,7 +27,10 @@ GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None] ...@@ -25,7 +27,10 @@ GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]): def get_invocation_types(
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
):
# NOTE: Items defined earlier take higher priority # NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [] INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
...@@ -70,7 +75,7 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]): ...@@ -70,7 +75,7 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
(ClassificationRequest, (classify, create_classify)), (ClassificationRequest, (classify, create_classify)),
] ]
if "score" in supported_tasks: if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.score.protocol import RerankRequest from vllm.entrypoints.pooling.score.protocol import RerankRequest
...@@ -78,7 +83,6 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]): ...@@ -78,7 +83,6 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
(RerankRequest, (rerank, do_rerank)), (RerankRequest, (rerank, do_rerank)),
] ]
if "score" in supported_tasks or "embed" in supported_tasks:
from vllm.entrypoints.pooling.score.api_router import create_score, score from vllm.entrypoints.pooling.score.api_router import create_score, score
from vllm.entrypoints.pooling.score.protocol import ScoreRequest from vllm.entrypoints.pooling.score.protocol import ScoreRequest
...@@ -97,11 +101,15 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]): ...@@ -97,11 +101,15 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
return INVOCATION_TYPES return INVOCATION_TYPES
def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]): def attach_router(
app: FastAPI,
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
):
router = APIRouter() router = APIRouter()
# NOTE: Construct the TypeAdapters only once # NOTE: Construct the TypeAdapters only once
INVOCATION_TYPES = get_invocation_types(supported_tasks) INVOCATION_TYPES = get_invocation_types(supported_tasks, model_config)
INVOCATION_VALIDATORS = [ INVOCATION_VALIDATORS = [
(pydantic.TypeAdapter(request_type), (get_handler, endpoint)) (pydantic.TypeAdapter(request_type), (get_handler, endpoint))
for request_type, (get_handler, endpoint) in INVOCATION_TYPES for request_type, (get_handler, endpoint) in INVOCATION_TYPES
......
...@@ -29,6 +29,7 @@ from vllm.entrypoints.serve.disagg.protocol import ( ...@@ -29,6 +29,7 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse, GenerateResponse,
GenerateResponseChoice, GenerateResponseChoice,
) )
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -45,6 +46,7 @@ class ServingTokens(OpenAIServing): ...@@ -45,6 +46,7 @@ class ServingTokens(OpenAIServing):
self, self,
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
force_no_detokenize: bool = False, force_no_detokenize: bool = False,
...@@ -58,6 +60,7 @@ class ServingTokens(OpenAIServing): ...@@ -58,6 +60,7 @@ class ServingTokens(OpenAIServing):
request_logger=request_logger, request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids, return_tokens_as_token_ids=return_tokens_as_token_ids,
) )
self.openai_serving_render = openai_serving_render
self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
self.force_no_detokenize = force_no_detokenize self.force_no_detokenize = force_no_detokenize
...@@ -96,7 +99,7 @@ class ServingTokens(OpenAIServing): ...@@ -96,7 +99,7 @@ class ServingTokens(OpenAIServing):
if raw_request: if raw_request:
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
engine_prompts = await self._preprocess_completion( engine_prompts = await self.openai_serving_render.preprocess_completion(
request, request,
prompt_input=request.token_ids, prompt_input=request.token_ids,
prompt_embeds=None, prompt_embeds=None,
......
...@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
parse_chat_inputs_to_harmony_messages, parse_chat_inputs_to_harmony_messages,
render_for_completion, render_for_completion,
) )
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.serve.disagg.protocol import ( from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest, GenerateRequest,
MultiModalFeatures, MultiModalFeatures,
...@@ -226,7 +227,7 @@ class OpenAIServingRender: ...@@ -226,7 +227,7 @@ class OpenAIServingRender:
if not self.use_harmony: if not self.use_harmony:
# Common case. # Common case.
error_check_ret = self._validate_chat_template( error_check_ret = self.validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs, chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template, trust_request_chat_template=self.trust_request_chat_template,
...@@ -234,7 +235,7 @@ class OpenAIServingRender: ...@@ -234,7 +235,7 @@ class OpenAIServingRender:
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
conversation, engine_prompts = await self._preprocess_chat( conversation, engine_prompts = await self.preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -328,7 +329,7 @@ class OpenAIServingRender: ...@@ -328,7 +329,7 @@ class OpenAIServingRender:
"prompt_logprobs is not compatible with prompt embeds." "prompt_logprobs is not compatible with prompt embeds."
) )
engine_prompts = await self._preprocess_completion( engine_prompts = await self.preprocess_completion(
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds, prompt_embeds=request.prompt_embeds,
...@@ -426,7 +427,7 @@ class OpenAIServingRender: ...@@ -426,7 +427,7 @@ class OpenAIServingRender:
) -> ErrorResponse | None: ) -> ErrorResponse | None:
return await self.model_registry.check_model(request.model) return await self.model_registry.check_model(request.model)
def _validate_chat_template( def validate_chat_template(
self, self,
request_chat_template: str | None, request_chat_template: str | None,
chat_template_kwargs: dict[str, Any] | None, chat_template_kwargs: dict[str, Any] | None,
...@@ -447,7 +448,7 @@ class OpenAIServingRender: ...@@ -447,7 +448,7 @@ class OpenAIServingRender:
) )
return None return None
async def _preprocess_completion( async def preprocess_completion(
self, self,
request: Any, request: Any,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
...@@ -459,9 +460,9 @@ class OpenAIServingRender: ...@@ -459,9 +460,9 @@ class OpenAIServingRender:
prompts.extend(prompt_to_seq(prompt_embeds)) prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None: if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input)) prompts.extend(prompt_to_seq(prompt_input))
return await self._preprocess_cmpl(request, prompts) return await self.preprocess_cmpl(request, prompts)
async def _preprocess_cmpl( async def preprocess_cmpl(
self, self,
request: Any, request: Any,
prompts: Sequence[PromptType | bytes], prompts: Sequence[PromptType | bytes],
...@@ -490,7 +491,7 @@ class OpenAIServingRender: ...@@ -490,7 +491,7 @@ class OpenAIServingRender:
}, },
) )
async def _preprocess_chat( async def preprocess_chat(
self, self,
request: Any, request: Any,
messages: list[Any], messages: list[Any],
...@@ -500,11 +501,7 @@ class OpenAIServingRender: ...@@ -500,11 +501,7 @@ class OpenAIServingRender:
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: ) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
"""Copied from OpenAIServing._preprocess_chat. """Copied from OpenAIServing._preprocess_chat."""
Differences: isinstance check is ChatCompletionRequest-only
(ResponsesRequest not supported here); TODO comment dropped accordingly.
"""
renderer = self.renderer renderer = self.renderer
mm_config = self.model_config.multimodal_config mm_config = self.model_config.multimodal_config
...@@ -542,11 +539,11 @@ class OpenAIServingRender: ...@@ -542,11 +539,11 @@ class OpenAIServingRender:
if tool_parser is not None: if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none") tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none": if tool_choice != "none":
if not isinstance(request, ChatCompletionRequest): if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = ( msg = (
"Tool usage is only supported " "Tool usage is only supported "
" for ChatCompletionRequest, but got " "for Chat Completions API or Responses API requests, "
f"{type(request).__name__}" f"but got {type(request).__name__}"
) )
raise NotImplementedError(msg) raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer() tokenizer = renderer.get_tokenizer()
......
...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ( ...@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel, OpenAIBaseModel,
) )
from vllm.exceptions import VLLMValidationError
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
...@@ -120,9 +121,9 @@ class TokenizeChatRequest(OpenAIBaseModel): ...@@ -120,9 +121,9 @@ class TokenizeChatRequest(OpenAIBaseModel):
@classmethod @classmethod
def check_generation_prompt(cls, data): def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get("add_generation_prompt"): if data.get("continue_final_message") and data.get("add_generation_prompt"):
raise ValueError( raise VLLMValidationError(
"Cannot set both `continue_final_message` and " "Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True." "`add_generation_prompt` to True.",
) )
return data return data
......
...@@ -11,6 +11,7 @@ from vllm.entrypoints.logger import RequestLogger ...@@ -11,6 +11,7 @@ from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.protocol import ( from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest, DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
...@@ -31,6 +32,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -31,6 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
self, self,
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
...@@ -44,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -44,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger=request_logger, request_logger=request_logger,
) )
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.default_chat_template_kwargs = default_chat_template_kwargs or {} self.default_chat_template_kwargs = default_chat_template_kwargs or {}
...@@ -68,7 +71,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -68,7 +71,7 @@ class OpenAIServingTokenization(OpenAIServing):
if request.tools is None if request.tools is None
else [tool.model_dump() for tool in request.tools] else [tool.model_dump() for tool in request.tools]
) )
error_check_ret = self._validate_chat_template( error_check_ret = self.openai_serving_render.validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs, chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template, trust_request_chat_template=self.trust_request_chat_template,
...@@ -76,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -76,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self.openai_serving_render.preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -85,7 +88,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -85,7 +88,7 @@ class OpenAIServingTokenization(OpenAIServing):
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
) )
else: else:
engine_prompts = await self._preprocess_completion( engine_prompts = await self.openai_serving_render.preprocess_completion(
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=None, prompt_embeds=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment