"vllm/vscode:/vscode.git/clone" did not exist on "517b769b5858a8d8d233d277f54461acfc9def63"
Commit 0da93439 authored by zhuwenwen
Browse files

Merge tag 'v0.18.1rc0' into v0.18.1rc0-ori

parents 25f2f756 298e5108
......@@ -27,6 +27,7 @@ from openai.types.responses import (
ResponseReasoningTextDeltaEvent,
ResponseReasoningTextDoneEvent,
ResponseStatus,
ResponseTextConfig,
ResponseWebSearchCallCompletedEvent,
ResponseWebSearchCallInProgressEvent,
ResponseWebSearchCallSearchingEvent,
......@@ -38,20 +39,13 @@ from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreated
from openai.types.responses import (
ResponseInProgressEvent as OpenAIResponseInProgressEvent,
)
from openai.types.responses.tool import Tool
from openai_harmony import Message as OpenAIHarmonyMessage
# Backward compatibility for OpenAI client versions
try: # For older openai versions (< 1.100.0)
from openai.types.responses import ResponseTextConfig
except ImportError: # For newer openai versions (>= 1.100.0)
from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig
from openai.types.responses.response import IncompleteDetails, ToolChoice
from openai.types.responses.response_reasoning_item import (
Content as ResponseReasoningTextContent,
)
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
from openai_harmony import Message as OpenAIHarmonyMessage
from pydantic import (
Field,
ValidationError,
......@@ -258,6 +252,10 @@ class ResponsesRequest(OpenAIBaseModel):
"numeric values, used by custom extensions."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
# --8<-- [end:responses-extra-params]
def build_chat_params(
......@@ -357,6 +355,10 @@ class ResponsesRequest(OpenAIBaseModel):
if isinstance(stop, str):
stop = [stop]
extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {}
if self.kv_transfer_params:
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
temperature=temperature,
top_p=top_p,
......@@ -373,7 +375,7 @@ class ResponsesRequest(OpenAIBaseModel):
),
structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
extra_args=self.vllm_xargs or {},
extra_args=extra_args,
skip_clone=True, # Created fresh per request, safe to skip clone
skip_special_tokens=self.skip_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
......@@ -494,6 +496,11 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None
user: str | None = None
# vLLM-specific fields that are not in OpenAI spec
kv_transfer_params: dict[str, Any] | None = Field(
default=None, description="KVTransfer parameters."
)
# --8<-- [start:responses-response-extra-params]
# These are populated when enable_response_messages is set to True
# NOTE: custom serialization is needed
......@@ -537,6 +544,7 @@ class ResponsesResponse(OpenAIBaseModel):
usage: ResponseUsage | None = None,
input_messages: ResponseInputOutputMessage | None = None,
output_messages: ResponseInputOutputMessage | None = None,
kv_transfer_params: dict[str, Any] | None = None,
) -> "ResponsesResponse":
incomplete_details: IncompleteDetails | None = None
if status == "incomplete":
......@@ -572,6 +580,7 @@ class ResponsesResponse(OpenAIBaseModel):
truncation=request.truncation,
user=request.user,
usage=usage,
kv_transfer_params=kv_transfer_params,
)
......
......@@ -5,11 +5,11 @@ import asyncio
import time
import uuid
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Mapping, Sequence
from contextlib import AsyncExitStack
from copy import copy
from http import HTTPStatus
from typing import Final
from typing import Any, Final
from fastapi import Request
from openai.types.responses import (
......@@ -46,6 +46,7 @@ from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
get_tool_call_id_type,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.mcp.tool_server import ToolServer
......@@ -86,6 +87,7 @@ from vllm.entrypoints.openai.responses.protocol import (
ResponseCompletedEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseInputOutputItem,
ResponseInputOutputMessage,
ResponseReasoningPartAddedEvent,
ResponseReasoningPartDoneEvent,
......@@ -105,16 +107,19 @@ from vllm.entrypoints.openai.responses.utils import (
construct_tool_dicts,
extract_tool_types,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs, token_inputs
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput
from vllm.parser import ParserManager
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils import random_uuid
from vllm.utils.collection_utils import as_list
......@@ -165,6 +170,7 @@ class OpenAIServingResponses(OpenAIServing):
self,
engine_client: EngineClient,
models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
......@@ -185,6 +191,7 @@ class OpenAIServingResponses(OpenAIServing):
return_tokens_as_token_ids=return_tokens_as_token_ids,
)
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.enable_log_outputs = enable_log_outputs
......@@ -235,15 +242,7 @@ class OpenAIServingResponses(OpenAIServing):
get_stop_tokens_for_assistant_actions()
)
# Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
hf_overrides = getattr(self.model_config, "hf_overrides", None)
if self.model_config.hf_text_config.model_type == "kimi_k2" or (
isinstance(hf_overrides, dict)
and hf_overrides.get("model_type") == "kimi_k2"
):
self.tool_call_id_type = "kimi_k2"
else:
self.tool_call_id_type = "random"
self.tool_call_id_type = get_tool_call_id_type(self.model_config)
self.enable_auto_tools = enable_auto_tools
# HACK(woosuk): This is a hack. We should use a better store.
......@@ -587,7 +586,7 @@ class OpenAIServingResponses(OpenAIServing):
prev_response_output=prev_response.output if prev_response else None,
)
_, engine_prompts = await self._preprocess_chat(
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
request,
messages,
default_template=self.chat_template,
......@@ -598,6 +597,109 @@ class OpenAIServingResponses(OpenAIServing):
)
return messages, engine_prompts
async def _render_next_turn(
    self,
    request: ResponsesRequest,
    messages: list[ResponseInputOutputItem],
    tool_dicts: list[dict[str, Any]] | None,
    tool_parser: Callable[[TokenizerLike], ToolParser] | None,
    chat_template: str | None,
    chat_template_content_format: ChatTemplateContentFormatOption,
):
    """Render the engine prompts for the next turn of a conversation.

    The accumulated response items in ``messages`` are converted with
    ``construct_input_messages`` and then passed through
    ``self.openai_serving_render.preprocess_chat`` using the supplied
    template and tool settings (no default template kwargs).

    Returns:
        The engine prompts produced by ``preprocess_chat``; the first
        element of its result pair is discarded.
    """
    next_turn_messages = construct_input_messages(
        request_input=messages,
    )

    # Only the rendered prompts are needed by the caller.
    _conversation, next_turn_prompts = (
        await self.openai_serving_render.preprocess_chat(
            request,
            next_turn_messages,
            default_template=chat_template,
            default_template_content_format=chat_template_content_format,
            default_template_kwargs=None,
            tool_dicts=tool_dicts,
            tool_parser=tool_parser,
        )
    )
    return next_turn_prompts
async def _generate_with_builtin_tools(
    self,
    request_id: str,
    engine_prompt: ProcessorInputs,
    sampling_params: SamplingParams,
    context: ConversationContext,
    lora_request: LoRARequest | None = None,
    priority: int = 0,
    trace_headers: Mapping[str, str] | None = None,
):
    """Generate across turns while the context asks for built-in tools.

    This is an async generator. Each turn is submitted to the engine
    under the id ``f"{request_id}_{n}"`` (``n`` counts turns from 0);
    every engine output is appended to ``context`` and ``context`` is
    yielded. When a turn ends and ``context.need_builtin_tool_call()``
    is false, the generator finishes. Otherwise the tool is called, its
    output is appended to ``context``, and the prompt for the next turn
    is rebuilt.

    Note that ``sampling_params.max_tokens`` is updated in place for
    each follow-up turn, and follow-up turns are submitted with
    ``priority - 1``.
    """
    context_window = self.model_config.max_model_len
    initial_priority = priority
    turn = 0

    while True:
        # Ensure that each sub-request has a unique request id.
        turn_request_id = f"{request_id}_{turn}"

        self._log_inputs(
            turn_request_id,
            engine_prompt,
            params=sampling_params,
            lora_request=lora_request,
        )

        stream = self.engine_client.generate(
            engine_prompt,
            sampling_params,
            turn_request_id,
            lora_request=lora_request,
            trace_headers=trace_headers,
            priority=priority,
        )

        async for output in stream:
            context.append_output(output)
            # NOTE(woosuk): The stop condition is handled by the engine.
            yield context

        if not context.need_builtin_tool_call():
            # The model did not ask for a tool call, so we're done.
            return

        # Call the tool and update the context with the result.
        tool_result = await context.call_tool()
        context.append_tool_output(tool_result)

        # TODO: uncomment this and enable tool output streaming
        # yield context

        # Create inputs for the next turn.
        # Render the next prompt token ids and update sampling_params.
        if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
            prompt_token_ids = context.render_for_completion()
            engine_prompt = token_inputs(prompt_token_ids)

            sampling_params.max_tokens = context_window - len(prompt_token_ids)
        elif isinstance(context, ParsableContext):
            # Exactly one rendered prompt is expected here.
            (engine_prompt,) = await self._render_next_turn(
                context.request,
                context.parser.response_messages,
                context.tool_dicts,
                context.tool_parser_cls,
                context.chat_template,
                context.chat_template_content_format,
            )

            sampling_params.max_tokens = get_max_tokens(
                context_window,
                context.request.max_output_tokens,
                self._extract_prompt_len(engine_prompt),
                self.default_sampling_params,  # type: ignore
                self.override_max_tokens,  # type: ignore
            )

        # OPTIMIZATION
        # NOTE(review): presumably a smaller value is scheduled earlier,
        # so follow-up turns jump ahead -- confirm against the scheduler.
        priority = initial_priority - 1
        turn += 1
def _make_request_with_harmony(
self,
request: ResponsesRequest,
......@@ -771,6 +873,7 @@ class OpenAIServingResponses(OpenAIServing):
output=output,
status=status,
usage=usage,
kv_transfer_params=context.kv_transfer_params,
)
if request.store:
......@@ -903,6 +1006,7 @@ class OpenAIServingResponses(OpenAIServing):
parser = self.parser(tokenizer)
return parser.extract_response_outputs(
model_output=final_output.text,
model_output_token_ids=final_output.token_ids,
request=request,
enable_auto_tools=self.enable_auto_tools,
tool_call_id_type=self.tool_call_id_type,
......
......@@ -191,13 +191,13 @@ def _construct_single_message_from_response_item(
],
)
elif isinstance(item, ResponseReasoningItem):
reasoning_content = ""
reasoning = ""
if item.encrypted_content:
raise ValueError("Encrypted content is not supported.")
elif item.content and len(item.content) >= 1:
reasoning_content = item.content[0].text
reasoning = item.content[0].text
elif len(item.summary) >= 1:
reasoning_content = item.summary[0].text
reasoning = item.summary[0].text
logger.warning(
"Using summary text as reasoning content for item %s. "
"Please use content instead of summary for "
......@@ -206,7 +206,7 @@ def _construct_single_message_from_response_item(
)
return {
"role": "assistant",
"reasoning": reasoning_content,
"reasoning": reasoning,
}
elif isinstance(item, ResponseOutputMessage):
return {
......
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import base64
import sys
import tempfile
from argparse import Namespace
......@@ -13,6 +12,7 @@ from typing import Any, TypeAlias
from urllib.parse import urlparse
import aiohttp
import pybase64 as base64
import torch
from fastapi import UploadFile
from prometheus_client import start_http_server
......@@ -54,6 +54,7 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreResponse,
)
from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.utils import random_uuid
......@@ -86,9 +87,10 @@ class BatchTranscriptionRequest(TranscriptionRequest):
def validate_no_file(cls, data: Any):
"""Ensure file field is not provided in batch requests."""
if isinstance(data, dict) and "file" in data:
raise ValueError(
raise VLLMValidationError(
"The 'file' field is not supported in batch requests. "
"Use 'file_url' instead."
"Use 'file_url' instead.",
parameter="file",
)
return data
......@@ -116,9 +118,10 @@ class BatchTranslationRequest(TranslationRequest):
def validate_no_file(cls, data: Any):
"""Ensure file field is not provided in batch requests."""
if isinstance(data, dict) and "file" in data:
raise ValueError(
raise VLLMValidationError(
"The 'file' field is not supported in batch requests. "
"Use 'file_url' instead."
"Use 'file_url' instead.",
parameter="file",
)
return data
......@@ -820,7 +823,6 @@ async def main(args: Namespace):
async with build_async_engine_client(
args,
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
disable_frontend_multiprocessing=False,
) as engine_client:
await run_batch(engine_client, args)
......
......@@ -371,7 +371,7 @@ async def generation_error_handler(req: Request, exc: GenerationError):
async def exception_handler(req: Request, exc: Exception):
if req.app.state.args.log_error_stack:
logger.exception(
logger.error(
"Exception caught. Request id: %s",
req.state.request_metadata.request_id
if hasattr(req.state, "request_metadata")
......
......@@ -107,7 +107,7 @@ class TranscriptionRequest(OpenAIBaseModel):
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
vllm_xargs: dict[str, str | int | float] | None = Field(
vllm_xargs: dict[str, str | int | float | bool] | None = Field(
default=None,
description=(
"Additional request parameters with string or "
......
......@@ -42,32 +42,13 @@ from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription
from vllm.multimodal.audio import split_audio
from vllm.multimodal.media.audio import extract_audio_from_video_bytes
from vllm.multimodal.audio import get_audio_duration, split_audio
from vllm.multimodal.media.audio import load_audio
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import get_tokenizer
from vllm.utils.import_utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile as sf
except ImportError:
sf = PlaceholderModule("soundfile") # type: ignore[assignment]
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
# being librosa's main backend. Used to validate if an audio loading error is due to a
# server error vs a client error (invalid audio file).
# 1 = unrecognised format (file is not a supported audio container)
# 3 = malformed file (corrupt or structurally invalid audio)
# 4 = unsupported encoding (codec not supported by this libsndfile build)
_BAD_SF_CODES = {1, 3, 4}
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
......@@ -214,32 +195,13 @@ class OpenAISpeechToText(OpenAIServing):
# pre-requisite for chunking, as it assumes Whisper SR.
try:
with io.BytesIO(audio_data) as buf:
y, sr = librosa.load(buf, sr=self.asr_config.sample_rate) # type: ignore[return-value]
except sf.LibsndfileError as exc:
# Only fall back for known format-detection failures.
# Re-raise anything else (e.g. corrupt but recognised format).
if exc.code not in _BAD_SF_CODES:
raise
logger.debug(
"librosa/soundfile could not decode audio from BytesIO "
"(code=%s: %s); falling back to pyav in-process decode",
exc.code,
exc,
)
try:
native_y, native_sr = extract_audio_from_video_bytes(audio_data)
sr = self.asr_config.sample_rate
y = librosa.resample(native_y, orig_sr=native_sr, target_sr=sr)
except Exception as pyav_exc:
logger.debug(
"pyAV fallback also failed: %s",
pyav_exc,
)
raise ValueError("Invalid or unsupported audio file.") from pyav_exc
y, sr = load_audio(buf, sr=self.asr_config.sample_rate)
except Exception as exc:
raise ValueError("Invalid or unsupported audio file.") from exc
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (
self.asr_config.allow_audio_chunking
duration = get_audio_duration(y=y, sr=sr)
do_split_audio = self.asr_config.allow_audio_chunking and (
self.asr_config.max_audio_clip_s is not None
and duration > self.asr_config.max_audio_clip_s
)
......
......@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING
from fastapi import FastAPI
from vllm.config import ModelConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
from argparse import Namespace
......@@ -17,9 +20,30 @@ else:
RequestLogger = object
SupportedTask = object
logger = init_logger(__name__)
def enable_scoring_api(
    supported_tasks: tuple["SupportedTask", ...],
    model_config: ModelConfig | None = None,
) -> bool:
    """Decide whether the score API should be enabled.

    Args:
        supported_tasks: Tasks supported by the served model.
        model_config: Optional model configuration; consulted only when
            ``"classify"`` is among the supported tasks.

    Returns:
        ``True`` when ``"embed"`` or ``"token_embed"`` is supported.
        Otherwise ``True`` only if a ``model_config`` is given,
        ``"classify"`` is supported, and ``hf_config.num_labels``
        (treated as 0 when absent) equals 1. ``False`` in all other
        cases.
    """
    if "embed" in supported_tasks or "token_embed" in supported_tasks:
        return True

    if model_config is None or "classify" not in supported_tasks:
        return False

    label_count = getattr(model_config.hf_config, "num_labels", 0)
    if label_count == 1:
        return True

    logger.debug_once("Score API is only enabled for num_labels == 1.")
    return False
def register_pooling_api_routers(
app: FastAPI, supported_tasks: tuple["SupportedTask", ...]
app: FastAPI,
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
):
from vllm.entrypoints.pooling.pooling.api_router import router as pooling_router
......@@ -37,11 +61,7 @@ def register_pooling_api_routers(
app.include_router(embed_router)
# Score API handles score/rerank for:
# - "score" task (score_type: cross-encoder models)
# - "embed" task (score_type: bi-encoder models)
# - "token_embed" task (score_type: late interaction models)
if any(t in supported_tasks for t in ("score", "embed", "token_embed")):
if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.score.api_router import router as score_router
app.include_router(score_router)
......@@ -61,6 +81,8 @@ def init_pooling_state(
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.tasks import POOLING_TASKS
model_config = engine_client.model_config
resolved_chat_template = load_chat_template(args.chat_template)
state.serving_pooling = (
......@@ -68,6 +90,7 @@ def init_pooling_state(
OpenAIServingPooling(
engine_client,
state.openai_serving_models,
state.openai_serving_render,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
......@@ -101,10 +124,6 @@ def init_pooling_state(
if "classify" in supported_tasks
else None
)
# Score API handles score/rerank for:
# - "score" task (score_type: cross-encoder models)
# - "embed" task (score_type: bi-encoder models)
# - "token_embed" task (score_type: late interaction models)
state.serving_scores = (
ServingScores(
engine_client,
......@@ -113,6 +132,6 @@ def init_pooling_state(
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
)
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
if enable_scoring_api(supported_tasks, model_config)
else None
)
......@@ -11,6 +11,7 @@ from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.openai.engine.protocol import OpenAIBaseModel
from vllm.exceptions import VLLMValidationError
from vllm.renderers import ChatParams, merge_kwargs
from vllm.utils import random_uuid
from vllm.utils.serial_utils import EmbedDType, EncodingFormat, Endianness
......@@ -147,9 +148,9 @@ class ChatRequestMixin(OpenAIBaseModel):
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get("add_generation_prompt"):
raise ValueError(
raise VLLMValidationError(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
"`add_generation_prompt` to True.",
)
return data
......
......@@ -6,13 +6,13 @@ OpenAI: https://platform.openai.com/docs/api-reference/embeddings
Cohere: https://docs.cohere.com/reference/embed
"""
import base64
import builtins
import struct
import time
from collections.abc import Sequence
from typing import Literal, TypeAlias
import pybase64 as base64
from pydantic import BaseModel, Field
from vllm import PoolingParams
......
......@@ -23,7 +23,7 @@ def init_pooling_io_processors(
if "embed" in supported_tasks:
from vllm.entrypoints.pooling.embed.io_processor import EmbedIOProcessor
processors.append(("classify", EmbedIOProcessor))
processors.append(("embed", EmbedIOProcessor))
return {
task: processor_cls(
......
......@@ -32,6 +32,7 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_base64,
encode_pooling_output_float,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
......@@ -47,6 +48,7 @@ class OpenAIServingPooling(OpenAIServing):
self,
engine_client: EngineClient,
models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
......@@ -59,6 +61,7 @@ class OpenAIServingPooling(OpenAIServing):
request_logger=request_logger,
)
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
......@@ -101,12 +104,12 @@ class OpenAIServingPooling(OpenAIServing):
raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id
)
engine_prompts = await self._preprocess_cmpl(
engine_prompts = await self.openai_serving_render.preprocess_cmpl(
request,
prompt_to_seq(raw_prompts),
)
elif isinstance(request, PoolingChatRequest):
error_check_ret = self._validate_chat_template(
error_check_ret = self.openai_serving_render.validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
......@@ -114,7 +117,7 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
request,
request.messages,
default_template=self.chat_template,
......@@ -122,7 +125,7 @@ class OpenAIServingPooling(OpenAIServing):
default_template_kwargs=None,
)
elif isinstance(request, PoolingCompletionRequest):
engine_prompts = await self._preprocess_completion(
engine_prompts = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.input,
prompt_embeds=None,
......
......@@ -35,7 +35,7 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self, task: PoolingTask = "score"):
def to_pooling_params(self, task: PoolingTask = "classify"):
return PoolingParams(
task=task,
use_activation=self.use_activation,
......@@ -111,7 +111,7 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
max_total_tokens_param="max_model_len",
)
def to_pooling_params(self, task: PoolingTask = "score"):
def to_pooling_params(self, task: PoolingTask = "classify"):
return PoolingParams(
task=task,
use_activation=self.use_activation,
......
......@@ -413,7 +413,7 @@ class ServingScores(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
default_pooling_params = request.to_pooling_params("score")
default_pooling_params = request.to_pooling_params("classify")
for i, engine_prompt in enumerate(engine_prompts):
request_id_item = f"{request_id}-{i}"
......
......@@ -60,14 +60,6 @@ def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
return output.outputs.data.tolist()
def encode_pooling_output_binary(
output: PoolingRequestOutput,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> bytes:
return tensor2binary(output.outputs.data, embed_dtype, endianness)
def encode_pooling_output_base64(
output: PoolingRequestOutput,
embed_dtype: EmbedDType,
......
......@@ -10,9 +10,11 @@ import pydantic
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from vllm.config import ModelConfig
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling import enable_scoring_api
from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
......@@ -25,7 +27,10 @@ GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
def get_invocation_types(
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
):
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = []
......@@ -70,7 +75,7 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
(ClassificationRequest, (classify, create_classify)),
]
if "score" in supported_tasks:
if enable_scoring_api(supported_tasks, model_config):
from vllm.entrypoints.pooling.score.api_router import do_rerank, rerank
from vllm.entrypoints.pooling.score.protocol import RerankRequest
......@@ -78,7 +83,6 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
(RerankRequest, (rerank, do_rerank)),
]
if "score" in supported_tasks or "embed" in supported_tasks:
from vllm.entrypoints.pooling.score.api_router import create_score, score
from vllm.entrypoints.pooling.score.protocol import ScoreRequest
......@@ -97,11 +101,15 @@ def get_invocation_types(supported_tasks: tuple["SupportedTask", ...]):
return INVOCATION_TYPES
def attach_router(app: FastAPI, supported_tasks: tuple["SupportedTask", ...]):
def attach_router(
app: FastAPI,
supported_tasks: tuple["SupportedTask", ...],
model_config: ModelConfig | None = None,
):
router = APIRouter()
# NOTE: Construct the TypeAdapters only once
INVOCATION_TYPES = get_invocation_types(supported_tasks)
INVOCATION_TYPES = get_invocation_types(supported_tasks, model_config)
INVOCATION_VALIDATORS = [
(pydantic.TypeAdapter(request_type), (get_handler, endpoint))
for request_type, (get_handler, endpoint) in INVOCATION_TYPES
......
......@@ -29,6 +29,7 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse,
GenerateResponseChoice,
)
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
......@@ -45,6 +46,7 @@ class ServingTokens(OpenAIServing):
self,
engine_client: EngineClient,
models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*,
request_logger: RequestLogger | None,
force_no_detokenize: bool = False,
......@@ -58,6 +60,7 @@ class ServingTokens(OpenAIServing):
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
)
self.openai_serving_render = openai_serving_render
self.enable_prompt_tokens_details = enable_prompt_tokens_details
self.enable_log_outputs = enable_log_outputs
self.force_no_detokenize = force_no_detokenize
......@@ -96,7 +99,7 @@ class ServingTokens(OpenAIServing):
if raw_request:
raw_request.state.request_metadata = request_metadata
engine_prompts = await self._preprocess_completion(
engine_prompts = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.token_ids,
prompt_embeds=None,
......
......@@ -24,6 +24,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
parse_chat_inputs_to_harmony_messages,
render_for_completion,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
MultiModalFeatures,
......@@ -226,7 +227,7 @@ class OpenAIServingRender:
if not self.use_harmony:
# Common case.
error_check_ret = self._validate_chat_template(
error_check_ret = self.validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
......@@ -234,7 +235,7 @@ class OpenAIServingRender:
if error_check_ret is not None:
return error_check_ret
conversation, engine_prompts = await self._preprocess_chat(
conversation, engine_prompts = await self.preprocess_chat(
request,
request.messages,
default_template=self.chat_template,
......@@ -328,7 +329,7 @@ class OpenAIServingRender:
"prompt_logprobs is not compatible with prompt embeds."
)
engine_prompts = await self._preprocess_completion(
engine_prompts = await self.preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds,
......@@ -426,7 +427,7 @@ class OpenAIServingRender:
) -> ErrorResponse | None:
return await self.model_registry.check_model(request.model)
def _validate_chat_template(
def validate_chat_template(
self,
request_chat_template: str | None,
chat_template_kwargs: dict[str, Any] | None,
......@@ -447,7 +448,7 @@ class OpenAIServingRender:
)
return None
async def _preprocess_completion(
async def preprocess_completion(
self,
request: Any,
prompt_input: str | list[str] | list[int] | list[list[int]] | None,
......@@ -459,9 +460,9 @@ class OpenAIServingRender:
prompts.extend(prompt_to_seq(prompt_embeds))
if prompt_input is not None:
prompts.extend(prompt_to_seq(prompt_input))
return await self._preprocess_cmpl(request, prompts)
return await self.preprocess_cmpl(request, prompts)
async def _preprocess_cmpl(
async def preprocess_cmpl(
self,
request: Any,
prompts: Sequence[PromptType | bytes],
......@@ -490,7 +491,7 @@ class OpenAIServingRender:
},
)
async def _preprocess_chat(
async def preprocess_chat(
self,
request: Any,
messages: list[Any],
......@@ -500,11 +501,7 @@ class OpenAIServingRender:
tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]:
"""Copied from OpenAIServing._preprocess_chat.
Differences: isinstance check is ChatCompletionRequest-only
(ResponsesRequest not supported here); TODO comment dropped accordingly.
"""
"""Copied from OpenAIServing._preprocess_chat."""
renderer = self.renderer
mm_config = self.model_config.multimodal_config
......@@ -542,11 +539,11 @@ class OpenAIServingRender:
if tool_parser is not None:
tool_choice = getattr(request, "tool_choice", "none")
if tool_choice != "none":
if not isinstance(request, ChatCompletionRequest):
if not isinstance(request, ChatCompletionRequest | ResponsesRequest):
msg = (
"Tool usage is only supported "
" for ChatCompletionRequest, but got "
f"{type(request).__name__}"
"for Chat Completions API or Responses API requests, "
f"but got {type(request).__name__}"
)
raise NotImplementedError(msg)
tokenizer = renderer.get_tokenizer()
......
......@@ -17,6 +17,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import (
from vllm.entrypoints.openai.engine.protocol import (
OpenAIBaseModel,
)
from vllm.exceptions import VLLMValidationError
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
......@@ -120,9 +121,9 @@ class TokenizeChatRequest(OpenAIBaseModel):
@classmethod
def check_generation_prompt(cls, data):
if data.get("continue_final_message") and data.get("add_generation_prompt"):
raise ValueError(
raise VLLMValidationError(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
"`add_generation_prompt` to True.",
)
return data
......
......@@ -11,6 +11,7 @@ from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest,
DetokenizeResponse,
......@@ -31,6 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
self,
engine_client: EngineClient,
models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
......@@ -44,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger=request_logger,
)
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.default_chat_template_kwargs = default_chat_template_kwargs or {}
......@@ -68,7 +71,7 @@ class OpenAIServingTokenization(OpenAIServing):
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
error_check_ret = self._validate_chat_template(
error_check_ret = self.openai_serving_render.validate_chat_template(
request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template,
......@@ -76,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
_, engine_prompts = await self._preprocess_chat(
_, engine_prompts = await self.openai_serving_render.preprocess_chat(
request,
request.messages,
default_template=self.chat_template,
......@@ -85,7 +88,7 @@ class OpenAIServingTokenization(OpenAIServing):
tool_dicts=tool_dicts,
)
else:
engine_prompts = await self._preprocess_completion(
engine_prompts = await self.openai_serving_render.preprocess_completion(
request,
prompt_input=request.prompt,
prompt_embeds=None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment