Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
......@@ -12,10 +12,10 @@ from openai.types.responses.response_reasoning_item import (
)
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
......
......@@ -320,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel):
max_tool_calls: int | None = None
metadata: Metadata | None = None
model: str | None = None
logit_bias: dict[str, float] | None = None
parallel_tool_calls: bool | None = True
previous_response_id: str | None = None
prompt: ResponsePrompt | None = None
......@@ -333,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel):
tools: list[Tool] = Field(default_factory=list)
top_logprobs: int | None = 0
top_p: float | None = None
top_k: int | None = None
truncation: Literal["auto", "disabled"] | None = "disabled"
user: str | None = None
......@@ -387,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel):
_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
"top_k": 0,
}
def to_sampling_params(
......@@ -408,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel):
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
)
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
stop_token_ids = default_sampling_params.get("stop_token_ids")
# Structured output
......@@ -428,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel):
return SamplingParams.from_optional(
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_tokens=max_tokens,
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
stop_token_ids=stop_token_ids,
......@@ -435,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel):
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
),
structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
)
def is_include_output_logprobs(self) -> bool:
......
......@@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant,
get_system_message,
parse_chat_inputs_to_harmony_messages,
parse_chat_output,
parse_input_to_harmony_message,
render_for_completion,
)
from vllm.entrypoints.openai.protocol import (
......@@ -51,13 +51,15 @@ from vllm.entrypoints.openai.protocol import (
ToolCall,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput
......@@ -69,6 +71,8 @@ from vllm.tokenizers.mistral import (
truncate_tool_call_ids,
validate_request_params,
)
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
......@@ -230,11 +234,7 @@ class OpenAIServingChat(OpenAIServing):
)
if error_check_ret is not None:
return error_check_ret
(
conversation,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
conversation, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
request.messages,
......@@ -250,11 +250,7 @@ class OpenAIServingChat(OpenAIServing):
)
else:
# For GPT-OSS.
(
conversation,
request_prompts,
engine_prompts,
) = self._make_request_with_harmony(request)
conversation, engine_prompts = self._make_request_with_harmony(request)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
......@@ -274,7 +270,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = []
try:
for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(request_prompts[i])
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids.
sub_request_id = (
......@@ -309,7 +305,7 @@ class OpenAIServingChat(OpenAIServing):
self._log_inputs(
sub_request_id,
request_prompts[i],
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
......@@ -380,6 +376,8 @@ class OpenAIServingChat(OpenAIServing):
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......@@ -531,7 +529,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
......@@ -585,6 +583,11 @@ class OpenAIServingChat(OpenAIServing):
try:
if self.reasoning_parser:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
reasoning_parser = self.reasoning_parser(
tokenizer,
chat_template_kwargs=request.chat_template_kwargs, # type: ignore
......@@ -598,6 +601,11 @@ class OpenAIServingChat(OpenAIServing):
# Prepare the tool parser if it's needed
try:
if tool_choice_auto and self.tool_parser:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
tool_parsers: list[ToolParser | None] = [
self.tool_parser(tokenizer)
] * num_choices
......@@ -816,6 +824,9 @@ class OpenAIServingChat(OpenAIServing):
if delta_message is not None:
harmony_tools_streamed[i] = True
elif cur_channel == "commentary":
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
else:
delta_message = None
# handle streaming deltas for tools with named tool_choice
......@@ -953,21 +964,9 @@ class OpenAIServingChat(OpenAIServing):
assert reasoning_end_arr is not None
output_token_ids = as_list(output.token_ids)
if not reasoning_end_arr[i]:
delta_message = (
reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output_token_ids,
)
)
# When encountering think end id in prompt_token_ids
# i.e {"enable_thinking": False},
# set reasoning status to end.
# Remove the text and token ids related
# to 'reasoning'.
if (
res.prompt_token_ids
and reasoning_parser.is_reasoning_end(
......@@ -976,30 +975,38 @@ class OpenAIServingChat(OpenAIServing):
):
reasoning_end_arr[i] = True
current_token_ids = output_token_ids
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
# When encountering think end id in delta_token_ids,
# set reasoning status to end.
# Remove the text and token ids related
# to 'reasoning'.
if reasoning_parser.is_reasoning_end(output_token_ids):
reasoning_end_arr[i] = True
current_token_ids = (
reasoning_parser.extract_content_ids(
output_token_ids
# Don't update current_text, keep it as is from delta
else:
delta_message = (
reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output_token_ids,
)
)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
# When encountering think end id in delta_token_ids,
# set reasoning status to end.
# Remove the text and token ids related
# to 'reasoning'.
if reasoning_parser.is_reasoning_end(output_token_ids):
reasoning_end_arr[i] = True
current_token_ids = (
reasoning_parser.extract_content_ids(
output_token_ids
)
)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
# handle tool calls only after reasoning is done,
else:
if reasoning_end_arr[i]:
delta_token_ids = output_token_ids
# First time to tool call,
# add the remaining text and token ids
......@@ -1120,6 +1127,10 @@ class OpenAIServingChat(OpenAIServing):
# if the model is finished generating
else:
# check for error finish reason and abort streaming
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request_id)
# check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously
# matched by partial json parsing
......@@ -1287,6 +1298,8 @@ class OpenAIServingChat(OpenAIServing):
delta=False,
)
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.")
......@@ -1302,7 +1315,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str,
model_name: str,
conversation: list[ConversationMessage],
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse:
created_time = int(time.time())
......@@ -1327,6 +1340,9 @@ class OpenAIServingChat(OpenAIServing):
role = self.get_chat_request_role(request)
for output in final_res.outputs:
# check for error finish reason and raise GenerationError
# finish_reason='error' indicates a retryable request-level internal error
self._raise_if_error(output.finish_reason, request_id)
token_ids = output.token_ids
out_logprobs = output.logprobs
tool_call_info = None
......@@ -1349,6 +1365,11 @@ class OpenAIServingChat(OpenAIServing):
reasoning = None
if self.tool_parser is not None:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
tool_parser = self.tool_parser(tokenizer)
# NOTE: We use token_ids for openai tool parser
tool_call_info = tool_parser.extract_tool_calls(
......@@ -1391,6 +1412,11 @@ class OpenAIServingChat(OpenAIServing):
if self.reasoning_parser:
try:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
reasoning_parser = self.reasoning_parser(
tokenizer,
chat_template_kwargs=request.chat_template_kwargs, # type: ignore
......@@ -1630,7 +1656,7 @@ class OpenAIServingChat(OpenAIServing):
self,
logprobs: dict[int, Logprob],
top_logprobs: int | None,
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
should_return_as_token_id: bool,
) -> list[ChatCompletionLogProb]:
return [
......@@ -1654,7 +1680,7 @@ class OpenAIServingChat(OpenAIServing):
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None],
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
num_output_top_logprobs: int | None = None,
return_as_token_id: bool | None = None,
) -> ChatCompletionLogProbs:
......@@ -1672,6 +1698,11 @@ class OpenAIServingChat(OpenAIServing):
if should_return_as_token_id:
token = f"token_id:{token_id}"
else:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id)
logprobs_content.append(
......@@ -1755,6 +1786,11 @@ class OpenAIServingChat(OpenAIServing):
):
messages: list[OpenAIMessage] = []
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
# Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default
# if the model supports it. TODO: Support browsing.
......@@ -1773,15 +1809,14 @@ class OpenAIServingChat(OpenAIServing):
messages.append(dev_msg)
# Add user message.
for chat_msg in request.messages:
messages.extend(parse_input_to_harmony_message(chat_msg))
messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))
# Render prompt token ids.
prompt_token_ids = render_for_completion(messages)
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
# Add cache_salt if provided in the request
if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return messages, [prompt_token_ids], [engine_prompt]
return messages, [engine_prompt]
......@@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import (
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage
......@@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
......@@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing):
finish_reason = output.finish_reason
stop_reason = output.stop_reason
self._raise_if_error(finish_reason, request_id)
chunk = CompletionStreamResponse(
id=request_id,
created=created_time,
......@@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing):
# report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = final_usage_info
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
yield "data: [DONE]\n\n"
......@@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing):
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in final_res.outputs:
self._raise_if_error(output.finish_reason, request_id)
assert request.max_tokens is not None
if request.echo:
if request.return_token_ids:
......
......@@ -5,60 +5,19 @@ import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import numpy as np
import torch
from fastapi import Request
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs
from vllm.entrypoints.context import (
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
from openai.types.responses import (
ToolChoiceFunction,
)
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
......@@ -72,7 +31,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages_futures,
resolve_chat_template_content_format,
)
from vllm.entrypoints.context import ConversationContext
from vllm.entrypoints.context import (
ConversationContext,
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam,
......@@ -83,7 +47,10 @@ from vllm.entrypoints.openai.protocol import (
DetokenizeRequest,
ErrorInfo,
ErrorResponse,
FunctionCall,
FunctionDefinition,
ResponseInputOutputItem,
ResponsesRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeResponse,
......@@ -92,15 +59,34 @@ from vllm.entrypoints.openai.protocol import (
TranslationRequest,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.responses_utils import (
construct_input_messages,
)
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.parse import (
PromptComponents,
get_prompt_components,
......@@ -109,15 +95,15 @@ from vllm.inputs.parse import (
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
MultiModalDataDict,
MultiModalUUIDDict,
)
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
......@@ -133,6 +119,15 @@ from vllm.utils.async_utils import (
from vllm.utils.collection_utils import is_list_of
from vllm.v1.engine import EngineCoreRequest
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
def __init__(self, message: str = "Internal server error"):
super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
logger = init_logger(__name__)
CompletionLikeRequest: TypeAlias = (
......@@ -174,34 +169,6 @@ AnyResponse: TypeAlias = (
)
class TextTokensPrompt(TypedDict):
prompt: str
prompt_token_ids: list[int]
class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor
RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt
def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
return (
isinstance(prompt, dict)
and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt
)
def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
return (
isinstance(prompt, dict)
and "prompt_token_ids" not in prompt
and "prompt_embeds" in prompt
)
RequestT = TypeVar("RequestT", bound=AnyRequest)
......@@ -212,8 +179,7 @@ class RequestProcessingMixin:
handling prompt preparation and engine input.
"""
request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list)
engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True)
......@@ -414,7 +380,7 @@ class OpenAIServing:
prompts_batch, lora_req_batch = zip(
*[
(
EngineTokensPrompt(
TokensPrompt(
prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs,
......@@ -456,6 +422,29 @@ class OpenAIServing:
# Iterate through all beam inference results
for i, result in enumerate(output):
current_beam = all_beams[i]
# check for error finish reason and abort beam search
if result.outputs[0].finish_reason == "error":
# yield error output and terminate beam search
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
return
if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0]
all_beams_token_id.extend(list(logprobs.keys()))
......@@ -780,6 +769,35 @@ class OpenAIServing:
)
return json_str
def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
"""Raise GenerationError if finish_reason indicates an error."""
if finish_reason == "error":
logger.error(
"Request %s failed with an internal error during generation",
request_id,
)
raise GenerationError("Internal server error")
def _convert_generation_error_to_response(
self, e: GenerationError
) -> ErrorResponse:
"""Convert GenerationError to ErrorResponse."""
return self.create_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
def _convert_generation_error_to_streaming_response(
self, e: GenerationError
) -> str:
"""Convert GenerationError to streaming error response."""
return self.create_streaming_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
async def _check_model(
self,
request: AnyRequest,
......@@ -884,7 +902,7 @@ class OpenAIServing:
prompt: str,
tokenizer: TokenizerLike,
add_special_tokens: bool,
) -> TextTokensPrompt:
) -> TokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
if (
......@@ -925,7 +943,7 @@ class OpenAIServing:
request: AnyRequest,
prompt_ids: list[int],
tokenizer: TokenizerLike | None,
) -> TextTokensPrompt:
) -> TokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
if truncate_prompt_tokens is None:
......@@ -948,7 +966,7 @@ class OpenAIServing:
request: AnyRequest,
input_ids: list[int],
input_text: str,
) -> TextTokensPrompt:
) -> TokensPrompt:
token_num = len(input_ids)
# Note: EmbeddingRequest, ClassificationRequest,
......@@ -979,7 +997,7 @@ class OpenAIServing:
f"{token_num} tokens in the input for {operation}. "
f"Please reduce the length of the input."
)
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
......@@ -987,7 +1005,7 @@ class OpenAIServing:
request,
(TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
):
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# chat completion endpoint supports max_completion_tokens
if isinstance(request, ChatCompletionRequest):
......@@ -1015,7 +1033,7 @@ class OpenAIServing:
f" - {token_num})."
)
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
async def _tokenize_prompt_input_async(
self,
......@@ -1023,7 +1041,7 @@ class OpenAIServing:
tokenizer: TokenizerLike,
prompt_input: str | list[int],
add_special_tokens: bool = True,
) -> TextTokensPrompt:
) -> TokensPrompt:
"""
A simpler implementation that tokenizes a single prompt input.
"""
......@@ -1042,7 +1060,7 @@ class OpenAIServing:
tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]:
) -> AsyncGenerator[TokensPrompt, None]:
"""
A simpler implementation that tokenizes multiple prompt inputs.
"""
......@@ -1095,11 +1113,7 @@ class OpenAIServing:
chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False,
) -> tuple[
list[ConversationMessage],
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
......@@ -1172,9 +1186,7 @@ class OpenAIServing:
"Prompt has to be a string",
"when the tokenizer is not initialised",
)
prompt_inputs = TextTokensPrompt(
prompt=request_prompt, prompt_token_ids=[1]
)
prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
elif isinstance(request_prompt, str):
prompt_inputs = await self._tokenize_prompt_input_async(
request,
......@@ -1187,14 +1199,15 @@ class OpenAIServing:
assert is_list_of(request_prompt, int), (
"Prompt has to be either a string or a list of token ids"
)
prompt_inputs = TextTokensPrompt(
prompt_inputs = TokensPrompt(
prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt,
)
engine_prompt = EngineTokensPrompt(
prompt_token_ids=prompt_inputs["prompt_token_ids"]
)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
if "prompt" in prompt_inputs:
engine_prompt["prompt"] = prompt_inputs["prompt"]
if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data
......@@ -1207,7 +1220,7 @@ class OpenAIServing:
if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return conversation, [request_prompt], [engine_prompt]
return conversation, [engine_prompt]
async def _process_inputs(
self,
......@@ -1239,7 +1252,7 @@ class OpenAIServing:
async def _render_next_turn(
self,
request: ResponsesRequest,
tokenizer: AnyTokenizer,
tokenizer: TokenizerLike | None,
messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None,
tool_parser,
......@@ -1250,7 +1263,7 @@ class OpenAIServing:
request_input=messages,
)
_, request_prompts, engine_prompts = await self._preprocess_chat(
_, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
new_messages,
......@@ -1259,20 +1272,20 @@ class OpenAIServing:
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
)
return request_prompts, engine_prompts
return engine_prompts
async def _generate_with_builtin_tools(
self,
request_id: str,
request_prompt: RequestPrompt,
engine_prompt: EngineTokensPrompt,
engine_prompt: TokensPrompt,
sampling_params: SamplingParams,
context: ConversationContext,
lora_request: LoRARequest | None = None,
priority: int = 0,
**kwargs,
):
prompt_text, _, _ = self._get_prompt_components(request_prompt)
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
orig_priority = priority
sub_request = 0
while True:
......@@ -1280,7 +1293,7 @@ class OpenAIServing:
sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs(
sub_request_id,
request_prompt,
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
......@@ -1325,10 +1338,9 @@ class OpenAIServing:
# Render the next prompt token ids.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
prompt_token_ids = context.render_for_completion()
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
elif isinstance(context, ParsableContext):
request_prompts, engine_prompts = await self._render_next_turn(
engine_prompts = await self._render_next_turn(
context.request,
context.tokenizer,
context.parser.response_messages,
......@@ -1338,8 +1350,7 @@ class OpenAIServing:
context.chat_template_content_format,
)
engine_prompt = engine_prompts[0]
request_prompt = request_prompts[0]
prompt_text, _, _ = self._get_prompt_components(request_prompt)
prompt_text, _, _ = self._get_prompt_components(engine_prompt)
# Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len(
......@@ -1349,19 +1360,13 @@ class OpenAIServing:
priority = orig_priority - 1
sub_request += 1
def _get_prompt_components(
self,
prompt: RequestPrompt | PromptType,
) -> PromptComponents:
if isinstance(prompt, list):
return PromptComponents(token_ids=prompt)
return get_prompt_components(prompt) # type: ignore[arg-type]
def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
return get_prompt_components(prompt)
def _log_inputs(
self,
request_id: str,
inputs: RequestPrompt | PromptType,
inputs: PromptType,
params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None,
) -> None:
......@@ -1423,7 +1428,7 @@ class OpenAIServing:
@staticmethod
def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest,
tokenizer: TokenizerLike,
tokenizer: TokenizerLike | None,
enable_auto_tools: bool,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
content: str | None = None,
......@@ -1463,6 +1468,11 @@ class OpenAIServing:
and enable_auto_tools
and (request.tool_choice == "auto" or request.tool_choice is None)
):
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
# Automatic Tool Call Parsing
try:
tool_parser = tool_parser_cls(tokenizer)
......
......@@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import (
)
from openai.types.responses.tool import Mcp, Tool
from openai_harmony import Message as OpenAIHarmonyMessage
from pydantic import TypeAdapter
from vllm import envs
from vllm.engine.protocol import EngineClient
......@@ -94,7 +95,10 @@ from vllm.entrypoints.openai.protocol import (
ResponseUsage,
StreamingResponsesResponse,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import (
construct_input_messages,
......@@ -103,7 +107,7 @@ from vllm.entrypoints.responses_utils import (
make_response_output_items_from_parsable_context,
)
from vllm.entrypoints.tool_server import ToolServer
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs
......@@ -254,7 +258,7 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server
def _validate_generator_input(
self, engine_prompt: EngineTokensPrompt
self, engine_prompt: TokensPrompt
) -> ErrorResponse | None:
"""Add validations to the input to the generator here."""
if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
......@@ -349,11 +353,11 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer()
if self.use_harmony:
messages, request_prompts, engine_prompts = (
self._make_request_with_harmony(request, prev_response)
messages, engine_prompts = self._make_request_with_harmony(
request, prev_response
)
else:
messages, request_prompts, engine_prompts = await self._make_request(
messages, engine_prompts = await self._make_request(
request, prev_response, tokenizer
)
......@@ -389,7 +393,7 @@ class OpenAIServingResponses(OpenAIServing):
assert len(builtin_tool_list) == 0
available_tools = []
try:
for i, engine_prompt in enumerate(engine_prompts):
for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None:
return maybe_error
......@@ -416,7 +420,7 @@ class OpenAIServingResponses(OpenAIServing):
context = HarmonyContext(messages, available_tools)
else:
if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT:
# This is an feature in development for parsing
# This is a feature in development for parsing
# tokens during generation instead of at the end
context = ParsableContext(
response_messages=messages,
......@@ -445,7 +449,6 @@ class OpenAIServingResponses(OpenAIServing):
)
generator = self._generate_with_builtin_tools(
request_id=request.request_id,
request_prompt=request_prompts[i],
engine_prompt=engine_prompt,
sampling_params=sampling_params,
context=context,
......@@ -541,6 +544,8 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer,
request_metadata,
)
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e:
return self.create_error_response(str(e))
......@@ -558,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
prev_response_output=prev_response.output if prev_response else None,
)
_, request_prompts, engine_prompts = await self._preprocess_chat(
_, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
messages,
......@@ -567,7 +572,7 @@ class OpenAIServingResponses(OpenAIServing):
chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format,
)
return messages, request_prompts, engine_prompts
return messages, engine_prompts
def _make_request_with_harmony(
self,
......@@ -580,13 +585,13 @@ class OpenAIServingResponses(OpenAIServing):
)
messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages)
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
# Add cache_salt if provided in the request
if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return messages, [prompt_token_ids], [engine_prompt]
return messages, [engine_prompt]
async def _initialize_tool_sessions(
self,
......@@ -648,6 +653,8 @@ class OpenAIServingResponses(OpenAIServing):
status = "incomplete"
elif context.finish_reason == "abort":
status = "cancelled"
else:
self._raise_if_error(context.finish_reason, request.request_id)
else:
status = "incomplete"
elif isinstance(context, ParsableContext):
......@@ -673,6 +680,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(final_res.outputs) == 1
final_output = final_res.outputs[0]
# finish_reason='error' indicates retryable internal error
self._raise_if_error(final_output.finish_reason, request.request_id)
output = self._make_response_output_items(request, final_output, tokenizer)
if request.enable_response_messages:
......@@ -1066,6 +1076,8 @@ class OpenAIServingResponses(OpenAIServing):
async for event in generator:
event_deque.append(event)
new_event_signal.set() # Signal new event available
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
......@@ -1089,6 +1101,8 @@ class OpenAIServingResponses(OpenAIServing):
):
try:
response = await self.responses_full_generator(request, *args, **kwargs)
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e:
logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e))
......@@ -1227,6 +1241,8 @@ class OpenAIServingResponses(OpenAIServing):
continue
if ctx.last_output.outputs:
output = ctx.last_output.outputs[0]
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request.request_id)
if reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text,
......@@ -1522,6 +1538,9 @@ class OpenAIServingResponses(OpenAIServing):
async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext)
# finish_reason='error' indicates a retryable error
self._raise_if_error(ctx.finish_reason, request.request_id)
if ctx.is_expecting_start():
current_output_index += 1
sent_output_item_added = False
......@@ -2016,18 +2035,25 @@ class OpenAIServingResponses(OpenAIServing):
)
)
async for event_data in processer(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
_increment_sequence_number_and_return,
):
yield event_data
try:
async for event_data in processer(
request,
sampling_params,
result_generator,
context,
model_name,
tokenizer,
request_metadata,
created_time,
_increment_sequence_number_and_return,
):
yield event_data
except GenerationError as e:
error_json = self._convert_generation_error_to_streaming_response(e)
yield _increment_sequence_number_and_return(
TypeAdapter(StreamingResponsesResponse).validate_json(error_json)
)
return
async def empty_async_generator():
# A hack to trick Python to think this is a generator but
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
ToolParserManager,
)
import warnings
__all__ = ["ToolParser", "ToolParserManager"]
def __getattr__(name: str):
if name == "ToolParser":
from vllm.tool_parsers import ToolParser
"""
Register a lazy module mapping.
warnings.warn(
"`vllm.entrypoints.openai.tool_parsers.ToolParser` has been moved to "
"`vllm.tool_parsers.ToolParser`. "
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)
Example:
ToolParserManager.register_lazy_module(
name="kimi_k2",
module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser",
class_name="KimiK2ToolParser",
)
"""
return ToolParser
if name == "ToolParserManager":
from vllm.tool_parsers import ToolParserManager
warnings.warn(
"`vllm.entrypoints.openai.tool_parsers.ToolParserManager` "
"has been moved to `vllm.tool_parsers.ToolParserManager`. "
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)
_TOOL_PARSERS_TO_REGISTER = {
"deepseek_v3": ( # name
"deepseekv3_tool_parser", # filename
"DeepSeekV3ToolParser", # class_name
),
"deepseek_v31": (
"deepseekv31_tool_parser",
"DeepSeekV31ToolParser",
),
"deepseek_v32": (
"deepseekv32_tool_parser",
"DeepSeekV32ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",
),
"glm45": (
"glm4_moe_tool_parser",
"Glm4MoeModelToolParser",
),
"granite-20b-fc": (
"granite_20b_fc_tool_parser",
"Granite20bFCToolParser",
),
"granite": (
"granite_tool_parser",
"GraniteToolParser",
),
"hermes": (
"hermes_tool_parser",
"Hermes2ProToolParser",
),
"hunyuan_a13b": (
"hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser",
),
"internlm": (
"internlm2_tool_parser",
"Internlm2ToolParser",
),
"jamba": (
"jamba_tool_parser",
"JambaToolParser",
),
"kimi_k2": (
"kimi_k2_tool_parser",
"KimiK2ToolParser",
),
"llama3_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_pythonic": (
"llama4_pythonic_tool_parser",
"Llama4PythonicToolParser",
),
"longcat": (
"longcat_tool_parser",
"LongcatFlashToolParser",
),
"minimax_m2": (
"minimax_m2_tool_parser",
"MinimaxM2ToolParser",
),
"minimax": (
"minimax_tool_parser",
"MinimaxToolParser",
),
"mistral": (
"mistral_tool_parser",
"MistralToolParser",
),
"olmo3": (
"olmo3_tool_parser",
"Olmo3PythonicToolParser",
),
"openai": (
"openai_tool_parser",
"OpenAIToolParser",
),
"phi4_mini_json": (
"phi4mini_tool_parser",
"Phi4MiniJsonToolParser",
),
"pythonic": (
"pythonic_tool_parser",
"PythonicToolParser",
),
"qwen3_coder": (
"qwen3coder_tool_parser",
"Qwen3CoderToolParser",
),
"qwen3_xml": (
"qwen3xml_tool_parser",
"Qwen3XMLToolParser",
),
"seed_oss": (
"seed_oss_tool_parser",
"SeedOssToolParser",
),
"step3": (
"step3_tool_parser",
"Step3ToolParser",
),
"xlam": (
"xlam_tool_parser",
"xLAMToolParser",
),
"gigachat3": (
"gigachat3_tool_parser",
"GigaChat3ToolParser",
),
}
return ToolParserManager
def register_lazy_tool_parsers():
for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items():
module_path = f"vllm.entrypoints.openai.tool_parsers.{file_name}"
ToolParserManager.register_lazy_module(name, module_path, class_name)
register_lazy_tool_parsers()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
......@@ -72,11 +72,7 @@ class ClassificationMixin(OpenAIServing):
if ret:
return ret
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
_, engine_prompts = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request),
ctx.tokenizer,
messages,
......
......@@ -20,7 +20,6 @@ from vllm.entrypoints.openai.serving_engine import (
EmbeddingServeContext,
OpenAIServing,
ServeContext,
TextTokensPrompt,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import (
......@@ -32,7 +31,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponseData,
)
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (
EmbeddingRequestOutput,
......@@ -83,11 +82,7 @@ class EmbeddingMixin(OpenAIServing):
renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest):
(
_,
_,
ctx.engine_prompts,
) = await self._preprocess_chat(
_, ctx.engine_prompts = await self._preprocess_chat(
ctx.request,
tokenizer,
ctx.request.messages,
......@@ -209,14 +204,13 @@ class EmbeddingMixin(OpenAIServing):
async def _process_chunked_request(
self,
ctx: EmbeddingServeContext,
original_prompt: TextTokensPrompt,
token_ids: list[int],
pooling_params,
trace_headers,
prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
token_ids = original_prompt["prompt_token_ids"]
# Split into chunks using max_position_embeddings
max_pos_embeddings = self._get_max_position_embeddings()
......@@ -228,18 +222,12 @@ class EmbeddingMixin(OpenAIServing):
chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
# Create engine prompt for this chunk
chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens)
# Create chunk request prompt for logging
chunk_text = ""
chunk_request_prompt = TextTokensPrompt(
prompt=chunk_text, prompt_token_ids=chunk_tokens
)
chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens)
# Log the chunk
self._log_inputs(
chunk_request_id,
chunk_request_prompt,
chunk_engine_prompt,
params=pooling_params,
lora_request=ctx.lora_request,
)
......@@ -263,7 +251,7 @@ class EmbeddingMixin(OpenAIServing):
request,
input_ids: list[int],
input_text: str,
) -> TextTokensPrompt:
) -> TokensPrompt:
"""Override to support chunked processing for embedding requests."""
token_num = len(input_ids)
......@@ -328,23 +316,15 @@ class EmbeddingMixin(OpenAIServing):
)
)
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# For other request types, use the parent's implementation
return super()._validate_input(request, input_ids, input_text)
def _is_text_tokens_prompt(self, prompt) -> bool:
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
return (
isinstance(prompt, dict)
and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt
)
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: EngineTokensPrompt,
engine_prompt: TokensPrompt,
pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None,
prompt_index: int,
......@@ -413,14 +393,16 @@ class EmbeddingMixin(OpenAIServing):
for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing
if self._is_text_tokens_prompt(engine_prompt):
# Cast to TextTokensPrompt since we've verified
# prompt_token_ids
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings:
if "prompt_token_ids" in engine_prompt:
prompt_token_ids = engine_prompt["prompt_token_ids"]
if len(prompt_token_ids) > max_pos_embeddings:
# Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request(
ctx, text_tokens_prompt, pooling_params, trace_headers, i
ctx,
prompt_token_ids,
pooling_params,
trace_headers,
i,
)
generators.extend(chunk_generators)
continue
......@@ -578,14 +560,13 @@ class EmbeddingMixin(OpenAIServing):
# Get original prompt token IDs for this prompt
original_prompt = ctx.engine_prompts[prompt_idx]
if not self._is_text_tokens_prompt(original_prompt):
if "prompt_token_ids" not in original_prompt:
return self.create_error_response(
f"Chunked prompt {prompt_idx} is not a TextTokensPrompt"
f"Chunked prompt {prompt_idx} does not contain "
"token IDs"
)
original_token_ids = cast(TextTokensPrompt, original_prompt)[
"prompt_token_ids"
]
original_token_ids = original_prompt["prompt_token_ids"]
pooling_request_output = PoolingRequestOutput(
request_id=aggregator["request_id"],
......
......@@ -137,11 +137,8 @@ class OpenAIServingPooling(OpenAIServing):
)
if error_check_ret is not None:
return error_check_ret
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
_, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
request.messages,
......
......@@ -120,6 +120,7 @@ class RerankResult(BaseModel):
class RerankUsage(BaseModel):
prompt_tokens: int
total_tokens: int
......
......@@ -38,7 +38,8 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__)
......@@ -501,5 +502,7 @@ class ServingScores(OpenAIServing):
id=request_id,
model=model_name,
results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens),
usage=RerankUsage(
total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
),
)
......@@ -12,9 +12,7 @@ import torch
from pydantic import Field
from vllm.config import ModelConfig
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TextPrompt as EngineTextPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer
......@@ -97,7 +95,7 @@ class BaseRenderer(ABC):
*,
prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
config: RenderConfig,
) -> list[EngineTokensPrompt]:
) -> list[TokensPrompt]:
"""
Convert text or token inputs into engine-ready TokensPrompt objects.
......@@ -115,7 +113,7 @@ class BaseRenderer(ABC):
(e.g., tokenization and length handling).
Returns:
list[EngineTokensPrompt]: Engine-ready token prompts.
list[TokensPrompt]: Engine-ready token prompts.
Raises:
ValueError: If input formats are invalid or length limits exceeded.
......@@ -129,7 +127,7 @@ class BaseRenderer(ABC):
prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
config: RenderConfig,
) -> list[EngineTokensPrompt | EngineEmbedsPrompt]:
) -> list[TokensPrompt | EmbedsPrompt]:
"""
Convert text/token and/or base64-encoded embeddings inputs into
engine-ready prompt objects using a unified RenderConfig.
......@@ -146,7 +144,7 @@ class BaseRenderer(ABC):
(e.g., tokenization and length handling).
Returns:
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
list[Union[TokensPrompt, EmbedsPrompt]]:
Engine-ready prompt objects.
Raises:
......@@ -161,31 +159,34 @@ class BaseRenderer(ABC):
prompt_embeds: bytes | list[bytes],
truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
cache_salt: str | None = None,
) -> list[EngineEmbedsPrompt]:
) -> list[EmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds:
raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`."
)
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
# Enable sparse tensor integrity checks to prevent out-of-bounds
# writes from maliciously crafted tensors
with torch.sparse.check_sparse_tensor_invariants():
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
embeds_prompt = EmbedsPrompt(prompt_embeds=tensor)
if cache_salt is not None:
embeds_prompt["cache_salt"] = cache_salt
return embeds_prompt
......@@ -213,7 +214,7 @@ class CompletionRenderer(BaseRenderer):
*,
prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
config: RenderConfig,
) -> list[EngineTokensPrompt]:
) -> list[TokensPrompt]:
"""Implementation of prompt rendering for completion-style requests.
Uses async tokenizer pooling for improved performance. See base class
......@@ -240,7 +241,7 @@ class CompletionRenderer(BaseRenderer):
prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None,
config: RenderConfig,
) -> list[EngineTokensPrompt | EngineEmbedsPrompt]:
) -> list[TokensPrompt | EmbedsPrompt]:
"""
Render text/token prompts and/or precomputed embedding prompts. At
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
......@@ -249,7 +250,7 @@ class CompletionRenderer(BaseRenderer):
if truncate_prompt_tokens == 0:
return []
rendered: list[EngineTokensPrompt | EngineEmbedsPrompt] = []
rendered: list[TokensPrompt | EmbedsPrompt] = []
if prompt_embeds is not None:
rendered.extend(
......@@ -281,10 +282,10 @@ class CompletionRenderer(BaseRenderer):
async def _create_prompt(
self,
prompt_input: EngineTextPrompt | EngineTokensPrompt,
prompt_input: TextPrompt | TokensPrompt,
config: RenderConfig,
truncate_prompt_tokens: int | None,
) -> EngineTokensPrompt:
) -> TokensPrompt:
prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
if prompt_token_ids is not None:
......@@ -317,7 +318,7 @@ class CompletionRenderer(BaseRenderer):
truncate_prompt_tokens: int | None,
add_special_tokens: bool,
cache_salt: str | None,
) -> EngineTokensPrompt:
) -> TokensPrompt:
"""Tokenize text input asynchronously."""
async_tokenizer = self._get_async_tokenizer()
......@@ -350,7 +351,7 @@ class CompletionRenderer(BaseRenderer):
truncate_prompt_tokens: int | None,
cache_salt: str | None,
needs_detokenization: bool | None = False,
) -> EngineTokensPrompt:
) -> TokensPrompt:
"""Optionally detokenize token IDs and build a tokens prompt."""
token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)
......@@ -392,8 +393,8 @@ class CompletionRenderer(BaseRenderer):
max_length: int | None = None,
cache_salt: str | None = None,
prompt: str | None = None,
) -> EngineTokensPrompt:
"""Create validated EngineTokensPrompt."""
) -> TokensPrompt:
"""Create validated TokensPrompt."""
if max_length is not None and len(token_ids) > max_length:
raise ValueError(
f"This model's maximum context length is {max_length} tokens. "
......@@ -401,7 +402,7 @@ class CompletionRenderer(BaseRenderer):
"Please reduce the length of the input messages."
)
tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
if cache_salt is not None:
tokens_prompt["cache_salt"] = cache_salt
if prompt is not None:
......
......@@ -27,7 +27,7 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse,
GenerateResponseChoice,
)
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
......@@ -99,7 +99,7 @@ class ServingTokens(OpenAIServing):
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
# completed
engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids)
engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids)
if request.features is not None:
engine_prompt["multi_modal_data"] = None
......@@ -115,7 +115,7 @@ class ServingTokens(OpenAIServing):
self._log_inputs(
request_id,
request.token_ids,
TokensPrompt(prompt_token_ids=request.token_ids),
params=sampling_params,
lora_request=lora_request,
)
......
......@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
......@@ -80,11 +81,8 @@ class OpenAIServingTokenization(OpenAIServing):
)
if error_check_ret is not None:
return error_check_ret
(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
_, engine_prompts = await self._preprocess_chat(
request,
tokenizer,
request.messages,
......@@ -141,7 +139,10 @@ class OpenAIServingTokenization(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs(
request_id, request.tokens, params=None, lora_request=lora_request
request_id,
TokensPrompt(prompt_token_ids=request.tokens),
params=None,
lora_request=lora_request,
)
prompt_input = await self._tokenize_prompt_input_async(
......
......@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.tokenizers import MistralTokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__)
......
......@@ -72,10 +72,9 @@ if TYPE_CHECKING:
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MEDIA_CONNECTOR: str = "http"
VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["ieee", "tf32"] = "ieee"
MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False
......@@ -240,6 +239,7 @@ if TYPE_CHECKING:
VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DEBUG_WORKSPACE: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
......@@ -458,11 +458,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
or "12.9",
# Controls PyTorch float32 matmul precision mode within vLLM workers.
# Valid options mirror torch.set_float32_matmul_precision
# Accepted values:
# - "ieee" (default): force full IEEE FP32 matmul precision.
# - "tf32": enable TensorFloat32-based fast matmul.
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
"VLLM_FLOAT32_MATMUL_PRECISION",
"highest",
["highest", "high", "medium"],
"ieee",
["ieee", "tf32"],
case_sensitive=False,
),
# Maximum number of compilation jobs to run in parallel.
......@@ -787,9 +789,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# imported at runtime.
# If a non-existing backend is used, an AssertionError will be thrown.
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
# Default is 4 GiB per API process + 4 GiB per engine core process
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser(
......@@ -1540,6 +1539,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Debug workspace allocations.
# logging of workspace resize operations.
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
# Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
......@@ -1584,6 +1586,12 @@ def __getattr__(name: str):
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def _is_envs_cache_enabled() -> bool:
"""Checked if __getattr__ is wrapped with functools.cache"""
global __getattr__
return hasattr(__getattr__, "cache_clear")
def enable_envs_cache() -> None:
"""
Enables caching of environment variables. This is useful for performance
......@@ -1594,6 +1602,9 @@ def enable_envs_cache() -> None:
runtime overhead. This also means that environment variables should NOT
be updated after the service is initialized.
"""
if _is_envs_cache_enabled():
# Avoid wrapping functools.cache multiple times
return
# Tag __getattr__ with functools.cache
global __getattr__
__getattr__ = functools.cache(__getattr__)
......@@ -1603,6 +1614,17 @@ def enable_envs_cache() -> None:
__getattr__(key)
def disable_envs_cache() -> None:
"""
Resets the environment variables cache. It could be used to isolate environments
between unit tests.
"""
global __getattr__
# If __getattr__ is wrapped by functions.cache, unwrap the caching layer.
if _is_envs_cache_enabled():
__getattr__ = __getattr__.__wrapped__
def __dir__():
return list(environment_variables.keys())
......@@ -1665,7 +1687,6 @@ def compile_factors() -> dict[str, object]:
"VLLM_MEDIA_CONNECTOR",
"VLLM_ASSETS_CACHE",
"VLLM_ASSETS_CACHE_MODEL_CLEAN",
"VLLM_MM_INPUT_CACHE_GIB",
"VLLM_WORKER_MULTIPROC_METHOD",
"VLLM_ENABLE_V1_MULTIPROCESSING",
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
......
......@@ -33,22 +33,31 @@ def parse_raw_prompts(
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
# case 2: array of strings
if is_list_of(prompt, str):
# case 2: array of strings
prompt = cast(list[str], prompt)
return [TextPrompt(prompt=elem) for elem in prompt]
# case 3: array of tokens
if is_list_of(prompt, int):
# case 3: array of tokens
prompt = cast(list[int], prompt)
return [TokensPrompt(prompt_token_ids=prompt)]
# case 4: array of token arrays
if is_list_of(prompt, list):
prompt = cast(list[list[int]], prompt)
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
first = prompt[0]
if not isinstance(first, list):
raise ValueError("prompt expected to be a list of lists")
if is_list_of(prompt[0], int):
# case 4: array of token arrays
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
if len(first) == 0:
raise ValueError("Please provide at least one prompt")
# strict validation: every nested list must be list[int]
if not all(is_list_of(elem, int) for elem in prompt):
raise TypeError("Nested lists must contain only integers")
prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
raise TypeError(
"prompt must be a string, array of strings, "
......
......@@ -229,6 +229,11 @@ def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]:
# guaranteed by the Python GIL.
_configure_vllm_root_logger()
# Transformers uses httpx to access the Hugging Face Hub. httpx is quite verbose,
# so we set its logging level to WARNING when vLLM's logging level is INFO.
if envs.VLLM_LOGGING_LEVEL == "INFO":
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = init_logger(__name__)
......
......@@ -38,8 +38,9 @@ class CustomOp(nn.Module):
)
return super().__new__(op_cls_to_instantiate)
def __init__(self):
def __init__(self, enforce_enable: bool = False):
super().__init__()
self._enforce_enable = enforce_enable
self._forward_method = self.dispatch_forward()
def forward(self, *args, **kwargs):
......@@ -84,7 +85,11 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
compilation_config = get_cached_compilation_config()
enabled = self.enabled()
# CustomOp object can be enforce enabled, e.g., enable device-specific
# kernels in ViT models when enabling graph mode. By default, it will
# follow the compilation_config to determine whether enable itself.
enabled = self._enforce_enable or self.enabled()
if enabled:
compilation_config.enabled_custom_ops.update([self.__class__.name])
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment