Commit a3f8d5dd authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori

parents 8d75f22e f34eca5f
...@@ -12,10 +12,10 @@ from openai.types.responses.response_reasoning_item import ( ...@@ -12,10 +12,10 @@ from openai.types.responses.response_reasoning_item import (
) )
from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike from vllm.tokenizers.protocol import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import ToolParser
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
......
...@@ -320,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -320,6 +320,7 @@ class ResponsesRequest(OpenAIBaseModel):
max_tool_calls: int | None = None max_tool_calls: int | None = None
metadata: Metadata | None = None metadata: Metadata | None = None
model: str | None = None model: str | None = None
logit_bias: dict[str, float] | None = None
parallel_tool_calls: bool | None = True parallel_tool_calls: bool | None = True
previous_response_id: str | None = None previous_response_id: str | None = None
prompt: ResponsePrompt | None = None prompt: ResponsePrompt | None = None
...@@ -333,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -333,6 +334,7 @@ class ResponsesRequest(OpenAIBaseModel):
tools: list[Tool] = Field(default_factory=list) tools: list[Tool] = Field(default_factory=list)
top_logprobs: int | None = 0 top_logprobs: int | None = 0
top_p: float | None = None top_p: float | None = None
top_k: int | None = None
truncation: Literal["auto", "disabled"] | None = "disabled" truncation: Literal["auto", "disabled"] | None = "disabled"
user: str | None = None user: str | None = None
...@@ -387,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -387,6 +389,7 @@ class ResponsesRequest(OpenAIBaseModel):
_DEFAULT_SAMPLING_PARAMS = { _DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0, "temperature": 1.0,
"top_p": 1.0, "top_p": 1.0,
"top_k": 0,
} }
def to_sampling_params( def to_sampling_params(
...@@ -408,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -408,6 +411,10 @@ class ResponsesRequest(OpenAIBaseModel):
top_p = default_sampling_params.get( top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]
) )
if (top_k := self.top_k) is None:
top_k = default_sampling_params.get(
"top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]
)
stop_token_ids = default_sampling_params.get("stop_token_ids") stop_token_ids = default_sampling_params.get("stop_token_ids")
# Structured output # Structured output
...@@ -428,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -428,6 +435,7 @@ class ResponsesRequest(OpenAIBaseModel):
return SamplingParams.from_optional( return SamplingParams.from_optional(
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k,
max_tokens=max_tokens, max_tokens=max_tokens,
logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, logprobs=self.top_logprobs if self.is_include_output_logprobs() else None,
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
...@@ -435,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel): ...@@ -435,6 +443,7 @@ class ResponsesRequest(OpenAIBaseModel):
RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY
), ),
structured_outputs=structured_outputs, structured_outputs=structured_outputs,
logit_bias=self.logit_bias,
) )
def is_include_output_logprobs(self) -> bool: def is_include_output_logprobs(self) -> bool:
......
...@@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -27,8 +27,8 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
get_stop_tokens_for_assistant_actions, get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant, get_streamable_parser_for_assistant,
get_system_message, get_system_message,
parse_chat_inputs_to_harmony_messages,
parse_chat_output, parse_chat_output,
parse_input_to_harmony_message,
render_for_completion, render_for_completion,
) )
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
...@@ -51,13 +51,15 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -51,13 +51,15 @@ from vllm.entrypoints.openai.protocol import (
ToolCall, ToolCall,
UsageInfo, UsageInfo,
) )
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
...@@ -69,6 +71,8 @@ from vllm.tokenizers.mistral import ( ...@@ -69,6 +71,8 @@ from vllm.tokenizers.mistral import (
truncate_tool_call_ids, truncate_tool_call_ids,
validate_request_params, validate_request_params,
) )
from vllm.tool_parsers import ToolParser
from vllm.tool_parsers.mistral_tool_parser import MistralToolCall
from vllm.utils.collection_utils import as_list from vllm.utils.collection_utils import as_list
from vllm.v1.sample.logits_processor import validate_logits_processors_parameters from vllm.v1.sample.logits_processor import validate_logits_processors_parameters
...@@ -230,11 +234,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -230,11 +234,7 @@ class OpenAIServingChat(OpenAIServing):
) )
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
( conversation, engine_prompts = await self._preprocess_chat(
conversation,
request_prompts,
engine_prompts,
) = await self._preprocess_chat(
request, request,
tokenizer, tokenizer,
request.messages, request.messages,
...@@ -250,11 +250,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -250,11 +250,7 @@ class OpenAIServingChat(OpenAIServing):
) )
else: else:
# For GPT-OSS. # For GPT-OSS.
( conversation, engine_prompts = self._make_request_with_harmony(request)
conversation,
request_prompts,
engine_prompts,
) = self._make_request_with_harmony(request)
except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e: except (ValueError, TypeError, RuntimeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs") logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}") return self.create_error_response(f"{e} {e.__cause__}")
...@@ -274,7 +270,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -274,7 +270,7 @@ class OpenAIServingChat(OpenAIServing):
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for i, engine_prompt in enumerate(engine_prompts):
prompt_text, _, _ = self._get_prompt_components(request_prompts[i]) prompt_text, _, _ = self._get_prompt_components(engine_prompt)
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
sub_request_id = ( sub_request_id = (
...@@ -309,7 +305,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -309,7 +305,7 @@ class OpenAIServingChat(OpenAIServing):
self._log_inputs( self._log_inputs(
sub_request_id, sub_request_id,
request_prompts[i], engine_prompt,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -380,6 +376,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -380,6 +376,8 @@ class OpenAIServingChat(OpenAIServing):
tokenizer, tokenizer,
request_metadata, request_metadata,
) )
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -531,7 +529,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -531,7 +529,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
created_time = int(time.time()) created_time = int(time.time())
...@@ -585,6 +583,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -585,6 +583,11 @@ class OpenAIServingChat(OpenAIServing):
try: try:
if self.reasoning_parser: if self.reasoning_parser:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
reasoning_parser = self.reasoning_parser( reasoning_parser = self.reasoning_parser(
tokenizer, tokenizer,
chat_template_kwargs=request.chat_template_kwargs, # type: ignore chat_template_kwargs=request.chat_template_kwargs, # type: ignore
...@@ -598,6 +601,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -598,6 +601,11 @@ class OpenAIServingChat(OpenAIServing):
# Prepare the tool parser if it's needed # Prepare the tool parser if it's needed
try: try:
if tool_choice_auto and self.tool_parser: if tool_choice_auto and self.tool_parser:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
tool_parsers: list[ToolParser | None] = [ tool_parsers: list[ToolParser | None] = [
self.tool_parser(tokenizer) self.tool_parser(tokenizer)
] * num_choices ] * num_choices
...@@ -816,6 +824,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -816,6 +824,9 @@ class OpenAIServingChat(OpenAIServing):
if delta_message is not None: if delta_message is not None:
harmony_tools_streamed[i] = True harmony_tools_streamed[i] = True
elif cur_channel == "commentary":
# Tool call preambles meant to be shown to the user
delta_message = DeltaMessage(content=delta_text)
else: else:
delta_message = None delta_message = None
# handle streaming deltas for tools with named tool_choice # handle streaming deltas for tools with named tool_choice
...@@ -953,21 +964,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -953,21 +964,9 @@ class OpenAIServingChat(OpenAIServing):
assert reasoning_end_arr is not None assert reasoning_end_arr is not None
output_token_ids = as_list(output.token_ids) output_token_ids = as_list(output.token_ids)
if not reasoning_end_arr[i]: if not reasoning_end_arr[i]:
delta_message = (
reasoning_parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
output_token_ids,
)
)
# When encountering think end id in prompt_token_ids # When encountering think end id in prompt_token_ids
# i.e {"enable_thinking": False}, # i.e {"enable_thinking": False},
# set reasoning status to end. # set reasoning status to end.
# Remove the text and token ids related
# to 'reasoning'.
if ( if (
res.prompt_token_ids res.prompt_token_ids
and reasoning_parser.is_reasoning_end( and reasoning_parser.is_reasoning_end(
...@@ -976,30 +975,38 @@ class OpenAIServingChat(OpenAIServing): ...@@ -976,30 +975,38 @@ class OpenAIServingChat(OpenAIServing):
): ):
reasoning_end_arr[i] = True reasoning_end_arr[i] = True
current_token_ids = output_token_ids current_token_ids = output_token_ids
if delta_message and delta_message.content: # Don't update current_text, keep it as is from delta
current_text = delta_message.content else:
delta_message.content = None delta_message = (
else: reasoning_parser.extract_reasoning_streaming(
current_text = "" previous_text,
# When encountering think end id in delta_token_ids, current_text,
# set reasoning status to end. delta_text,
# Remove the text and token ids related previous_token_ids,
# to 'reasoning'. current_token_ids,
if reasoning_parser.is_reasoning_end(output_token_ids): output_token_ids,
reasoning_end_arr[i] = True
current_token_ids = (
reasoning_parser.extract_content_ids(
output_token_ids
) )
) )
if delta_message and delta_message.content:
current_text = delta_message.content # When encountering think end id in delta_token_ids,
delta_message.content = None # set reasoning status to end.
else: # Remove the text and token ids related
current_text = "" # to 'reasoning'.
if reasoning_parser.is_reasoning_end(output_token_ids):
reasoning_end_arr[i] = True
current_token_ids = (
reasoning_parser.extract_content_ids(
output_token_ids
)
)
if delta_message and delta_message.content:
current_text = delta_message.content
delta_message.content = None
else:
current_text = ""
# handle tool calls only after reasoning is done, # handle tool calls only after reasoning is done,
else: if reasoning_end_arr[i]:
delta_token_ids = output_token_ids delta_token_ids = output_token_ids
# First time to tool call, # First time to tool call,
# add the remaining text and token ids # add the remaining text and token ids
...@@ -1120,6 +1127,10 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1120,6 +1127,10 @@ class OpenAIServingChat(OpenAIServing):
# if the model is finished generating # if the model is finished generating
else: else:
# check for error finish reason and abort streaming
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request_id)
# check to make sure we haven't "forgotten" to stream # check to make sure we haven't "forgotten" to stream
# any tokens that were generated but previously # any tokens that were generated but previously
# matched by partial json parsing # matched by partial json parsing
...@@ -1287,6 +1298,8 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1287,6 +1298,8 @@ class OpenAIServingChat(OpenAIServing):
delta=False, delta=False,
) )
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e: except Exception as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.") logger.exception("Error in chat completion stream generator.")
...@@ -1302,7 +1315,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1302,7 +1315,7 @@ class OpenAIServingChat(OpenAIServing):
request_id: str, request_id: str,
model_name: str, model_name: str,
conversation: list[ConversationMessage], conversation: list[ConversationMessage],
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
request_metadata: RequestResponseMetadata, request_metadata: RequestResponseMetadata,
) -> ErrorResponse | ChatCompletionResponse: ) -> ErrorResponse | ChatCompletionResponse:
created_time = int(time.time()) created_time = int(time.time())
...@@ -1327,6 +1340,9 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1327,6 +1340,9 @@ class OpenAIServingChat(OpenAIServing):
role = self.get_chat_request_role(request) role = self.get_chat_request_role(request)
for output in final_res.outputs: for output in final_res.outputs:
# check for error finish reason and raise GenerationError
# finish_reason='error' indicates a retryable request-level internal error
self._raise_if_error(output.finish_reason, request_id)
token_ids = output.token_ids token_ids = output.token_ids
out_logprobs = output.logprobs out_logprobs = output.logprobs
tool_call_info = None tool_call_info = None
...@@ -1349,6 +1365,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1349,6 +1365,11 @@ class OpenAIServingChat(OpenAIServing):
reasoning = None reasoning = None
if self.tool_parser is not None: if self.tool_parser is not None:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
tool_parser = self.tool_parser(tokenizer) tool_parser = self.tool_parser(tokenizer)
# NOTE: We use token_ids for openai tool parser # NOTE: We use token_ids for openai tool parser
tool_call_info = tool_parser.extract_tool_calls( tool_call_info = tool_parser.extract_tool_calls(
...@@ -1391,6 +1412,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1391,6 +1412,11 @@ class OpenAIServingChat(OpenAIServing):
if self.reasoning_parser: if self.reasoning_parser:
try: try:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
reasoning_parser = self.reasoning_parser( reasoning_parser = self.reasoning_parser(
tokenizer, tokenizer,
chat_template_kwargs=request.chat_template_kwargs, # type: ignore chat_template_kwargs=request.chat_template_kwargs, # type: ignore
...@@ -1630,7 +1656,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1630,7 +1656,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
logprobs: dict[int, Logprob], logprobs: dict[int, Logprob],
top_logprobs: int | None, top_logprobs: int | None,
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
should_return_as_token_id: bool, should_return_as_token_id: bool,
) -> list[ChatCompletionLogProb]: ) -> list[ChatCompletionLogProb]:
return [ return [
...@@ -1654,7 +1680,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1654,7 +1680,7 @@ class OpenAIServingChat(OpenAIServing):
self, self,
token_ids: GenericSequence[int], token_ids: GenericSequence[int],
top_logprobs: GenericSequence[dict[int, Logprob] | None], top_logprobs: GenericSequence[dict[int, Logprob] | None],
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
num_output_top_logprobs: int | None = None, num_output_top_logprobs: int | None = None,
return_as_token_id: bool | None = None, return_as_token_id: bool | None = None,
) -> ChatCompletionLogProbs: ) -> ChatCompletionLogProbs:
...@@ -1672,6 +1698,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1672,6 +1698,11 @@ class OpenAIServingChat(OpenAIServing):
if should_return_as_token_id: if should_return_as_token_id:
token = f"token_id:{token_id}" token = f"token_id:{token_id}"
else: else:
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
token = tokenizer.decode(token_id) token = tokenizer.decode(token_id)
logprobs_content.append( logprobs_content.append(
...@@ -1755,6 +1786,11 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1755,6 +1786,11 @@ class OpenAIServingChat(OpenAIServing):
): ):
messages: list[OpenAIMessage] = [] messages: list[OpenAIMessage] = []
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
maybe_serialize_tool_calls(request)
# Add system message. # Add system message.
# NOTE: In Chat Completion API, browsing is enabled by default # NOTE: In Chat Completion API, browsing is enabled by default
# if the model supports it. TODO: Support browsing. # if the model supports it. TODO: Support browsing.
...@@ -1773,15 +1809,14 @@ class OpenAIServingChat(OpenAIServing): ...@@ -1773,15 +1809,14 @@ class OpenAIServingChat(OpenAIServing):
messages.append(dev_msg) messages.append(dev_msg)
# Add user message. # Add user message.
for chat_msg in request.messages: messages.extend(parse_chat_inputs_to_harmony_messages(request.messages))
messages.extend(parse_input_to_harmony_message(chat_msg))
# Render prompt token ids. # Render prompt token ids.
prompt_token_ids = render_for_completion(messages) prompt_token_ids = render_for_completion(messages)
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
# Add cache_salt if provided in the request # Add cache_salt if provided in the request
if request.cache_salt is not None: if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt engine_prompt["cache_salt"] = request.cache_salt
return messages, [prompt_token_ids], [engine_prompt] return messages, [engine_prompt]
...@@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -24,7 +24,11 @@ from vllm.entrypoints.openai.protocol import (
RequestResponseMetadata, RequestResponseMetadata,
UsageInfo, UsageInfo,
) )
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
clamp_prompt_logprobs,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
...@@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -300,6 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
) )
except asyncio.CancelledError: except asyncio.CancelledError:
return self.create_error_response("Client disconnected") return self.create_error_response("Client disconnected")
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -437,6 +443,8 @@ class OpenAIServingCompletion(OpenAIServing):
finish_reason = output.finish_reason finish_reason = output.finish_reason
stop_reason = output.stop_reason stop_reason = output.stop_reason
self._raise_if_error(finish_reason, request_id)
chunk = CompletionStreamResponse( chunk = CompletionStreamResponse(
id=request_id, id=request_id,
created=created_time, created=created_time,
...@@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -498,8 +506,11 @@ class OpenAIServingCompletion(OpenAIServing):
# report to FastAPI middleware aggregate usage across all choices # report to FastAPI middleware aggregate usage across all choices
request_metadata.final_usage_info = final_usage_info request_metadata.final_usage_info = final_usage_info
except GenerationError as e:
yield f"data: {self._convert_generation_error_to_streaming_response(e)}\n\n"
except Exception as e: except Exception as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
logger.exception("Error in completion stream generator.")
data = self.create_streaming_error_response(str(e)) data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n" yield f"data: {data}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
...@@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -530,6 +541,8 @@ class OpenAIServingCompletion(OpenAIServing):
out_logprobs: GenericSequence[dict[int, Logprob] | None] | None out_logprobs: GenericSequence[dict[int, Logprob] | None] | None
for output in final_res.outputs: for output in final_res.outputs:
self._raise_if_error(output.finish_reason, request_id)
assert request.max_tokens is not None assert request.max_tokens is not None
if request.echo: if request.echo:
if request.return_token_ids: if request.return_token_ids:
......
...@@ -5,60 +5,19 @@ import json ...@@ -5,60 +5,19 @@ import json
import sys import sys
import time import time
import traceback import traceback
from collections.abc import AsyncGenerator, Callable, Iterable, Mapping, Sequence from collections.abc import AsyncGenerator, Callable, Iterable, Mapping
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field from dataclasses import dataclass, field
from http import HTTPStatus from http import HTTPStatus
from typing import Any, ClassVar, Generic, TypeAlias, TypeVar from typing import Any, ClassVar, Generic, TypeAlias, TypeVar
import numpy as np import numpy as np
import torch
from fastapi import Request from fastapi import Request
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
from typing_extensions import TypeIs
from vllm.entrypoints.context import (
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponsesRequest,
)
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
from openai.types.responses import ( from openai.types.responses import (
ToolChoiceFunction, ToolChoiceFunction,
) )
from pydantic import ConfigDict, TypeAdapter
from starlette.datastructures import Headers
import vllm.envs as envs import vllm.envs as envs
from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function
...@@ -72,7 +31,12 @@ from vllm.entrypoints.chat_utils import ( ...@@ -72,7 +31,12 @@ from vllm.entrypoints.chat_utils import (
parse_chat_messages_futures, parse_chat_messages_futures,
resolve_chat_template_content_format, resolve_chat_template_content_format,
) )
from vllm.entrypoints.context import ConversationContext from vllm.entrypoints.context import (
ConversationContext,
HarmonyContext,
ParsableContext,
StreamingHarmonyContext,
)
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionNamedToolChoiceParam,
...@@ -83,7 +47,10 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -83,7 +47,10 @@ from vllm.entrypoints.openai.protocol import (
DetokenizeRequest, DetokenizeRequest,
ErrorInfo, ErrorInfo,
ErrorResponse, ErrorResponse,
FunctionCall,
FunctionDefinition, FunctionDefinition,
ResponseInputOutputItem,
ResponsesRequest,
TokenizeChatRequest, TokenizeChatRequest,
TokenizeCompletionRequest, TokenizeCompletionRequest,
TokenizeResponse, TokenizeResponse,
...@@ -92,15 +59,34 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -92,15 +59,34 @@ from vllm.entrypoints.openai.protocol import (
TranslationRequest, TranslationRequest,
) )
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.pooling.classify.protocol import (
ClassificationChatRequest,
ClassificationCompletionRequest,
ClassificationRequest,
ClassificationResponse,
)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse,
)
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
PoolingResponse,
)
from vllm.entrypoints.pooling.score.protocol import (
RerankRequest,
ScoreRequest,
ScoreResponse,
)
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.responses_utils import ( from vllm.entrypoints.responses_utils import (
construct_input_messages, construct_input_messages,
) )
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType from vllm.inputs.data import PromptType, TokensPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import ( from vllm.inputs.parse import (
PromptComponents, PromptComponents,
get_prompt_components, get_prompt_components,
...@@ -109,15 +95,15 @@ from vllm.inputs.parse import ( ...@@ -109,15 +95,15 @@ from vllm.inputs.parse import (
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin from vllm.multimodal import MultiModalDataDict
MultiModalDataDict,
MultiModalUUIDDict,
)
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.tokenizers import DeepseekV32Tokenizer, MistralTokenizer, TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.deepseek_v32 import DeepseekV32Tokenizer
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.tool_parsers import ToolParser, ToolParserManager
from vllm.tracing import ( from vllm.tracing import (
contains_trace_headers, contains_trace_headers,
extract_trace_headers, extract_trace_headers,
...@@ -133,6 +119,15 @@ from vllm.utils.async_utils import ( ...@@ -133,6 +119,15 @@ from vllm.utils.async_utils import (
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine import EngineCoreRequest
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
def __init__(self, message: str = "Internal server error"):
super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
logger = init_logger(__name__) logger = init_logger(__name__)
CompletionLikeRequest: TypeAlias = ( CompletionLikeRequest: TypeAlias = (
...@@ -174,34 +169,6 @@ AnyResponse: TypeAlias = ( ...@@ -174,34 +169,6 @@ AnyResponse: TypeAlias = (
) )
class TextTokensPrompt(TypedDict):
prompt: str
prompt_token_ids: list[int]
class EmbedsPrompt(TypedDict):
prompt_embeds: torch.Tensor
RequestPrompt: TypeAlias = list[int] | str | TextTokensPrompt | EmbedsPrompt
def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
return (
isinstance(prompt, dict)
and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt
)
def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
return (
isinstance(prompt, dict)
and "prompt_token_ids" not in prompt
and "prompt_embeds" in prompt
)
RequestT = TypeVar("RequestT", bound=AnyRequest) RequestT = TypeVar("RequestT", bound=AnyRequest)
...@@ -212,8 +179,7 @@ class RequestProcessingMixin: ...@@ -212,8 +179,7 @@ class RequestProcessingMixin:
handling prompt preparation and engine input. handling prompt preparation and engine input.
""" """
request_prompts: Sequence[RequestPrompt] | None = field(default_factory=list) engine_prompts: list[TokensPrompt] | None = field(default_factory=list)
engine_prompts: list[EngineTokensPrompt] | None = field(default_factory=list)
@dataclass(kw_only=True) @dataclass(kw_only=True)
...@@ -414,7 +380,7 @@ class OpenAIServing: ...@@ -414,7 +380,7 @@ class OpenAIServing:
prompts_batch, lora_req_batch = zip( prompts_batch, lora_req_batch = zip(
*[ *[
( (
EngineTokensPrompt( TokensPrompt(
prompt_token_ids=beam.tokens, prompt_token_ids=beam.tokens,
multi_modal_data=beam.multi_modal_data, multi_modal_data=beam.multi_modal_data,
mm_processor_kwargs=beam.mm_processor_kwargs, mm_processor_kwargs=beam.mm_processor_kwargs,
...@@ -456,6 +422,29 @@ class OpenAIServing: ...@@ -456,6 +422,29 @@ class OpenAIServing:
# Iterate through all beam inference results # Iterate through all beam inference results
for i, result in enumerate(output): for i, result in enumerate(output):
current_beam = all_beams[i] current_beam = all_beams[i]
# check for error finish reason and abort beam search
if result.outputs[0].finish_reason == "error":
# yield error output and terminate beam search
yield RequestOutput(
request_id=request_id,
prompt=prompt_text,
outputs=[
CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason="error",
)
],
finished=True,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None,
)
return
if result.outputs[0].logprobs is not None: if result.outputs[0].logprobs is not None:
logprobs = result.outputs[0].logprobs[0] logprobs = result.outputs[0].logprobs[0]
all_beams_token_id.extend(list(logprobs.keys())) all_beams_token_id.extend(list(logprobs.keys()))
...@@ -780,6 +769,35 @@ class OpenAIServing: ...@@ -780,6 +769,35 @@ class OpenAIServing:
) )
return json_str return json_str
def _raise_if_error(self, finish_reason: str | None, request_id: str) -> None:
"""Raise GenerationError if finish_reason indicates an error."""
if finish_reason == "error":
logger.error(
"Request %s failed with an internal error during generation",
request_id,
)
raise GenerationError("Internal server error")
def _convert_generation_error_to_response(
self, e: GenerationError
) -> ErrorResponse:
"""Convert GenerationError to ErrorResponse."""
return self.create_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
def _convert_generation_error_to_streaming_response(
self, e: GenerationError
) -> str:
"""Convert GenerationError to streaming error response."""
return self.create_streaming_error_response(
str(e),
err_type="InternalServerError",
status_code=e.status_code,
)
async def _check_model( async def _check_model(
self, self,
request: AnyRequest, request: AnyRequest,
...@@ -884,7 +902,7 @@ class OpenAIServing: ...@@ -884,7 +902,7 @@ class OpenAIServing:
prompt: str, prompt: str,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
add_special_tokens: bool, add_special_tokens: bool,
) -> TextTokensPrompt: ) -> TokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer) async_tokenizer = self._get_async_tokenizer(tokenizer)
if ( if (
...@@ -925,7 +943,7 @@ class OpenAIServing: ...@@ -925,7 +943,7 @@ class OpenAIServing:
request: AnyRequest, request: AnyRequest,
prompt_ids: list[int], prompt_ids: list[int],
tokenizer: TokenizerLike | None, tokenizer: TokenizerLike | None,
) -> TextTokensPrompt: ) -> TokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None) truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
if truncate_prompt_tokens is None: if truncate_prompt_tokens is None:
...@@ -948,7 +966,7 @@ class OpenAIServing: ...@@ -948,7 +966,7 @@ class OpenAIServing:
request: AnyRequest, request: AnyRequest,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TextTokensPrompt: ) -> TokensPrompt:
token_num = len(input_ids) token_num = len(input_ids)
# Note: EmbeddingRequest, ClassificationRequest, # Note: EmbeddingRequest, ClassificationRequest,
...@@ -979,7 +997,7 @@ class OpenAIServing: ...@@ -979,7 +997,7 @@ class OpenAIServing:
f"{token_num} tokens in the input for {operation}. " f"{token_num} tokens in the input for {operation}. "
f"Please reduce the length of the input." f"Please reduce the length of the input."
) )
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation # and does not require model context length validation
...@@ -987,7 +1005,7 @@ class OpenAIServing: ...@@ -987,7 +1005,7 @@ class OpenAIServing:
request, request,
(TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest), (TokenizeCompletionRequest, TokenizeChatRequest, DetokenizeRequest),
): ):
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# chat completion endpoint supports max_completion_tokens # chat completion endpoint supports max_completion_tokens
if isinstance(request, ChatCompletionRequest): if isinstance(request, ChatCompletionRequest):
...@@ -1015,7 +1033,7 @@ class OpenAIServing: ...@@ -1015,7 +1033,7 @@ class OpenAIServing:
f" - {token_num})." f" - {token_num})."
) )
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
async def _tokenize_prompt_input_async( async def _tokenize_prompt_input_async(
self, self,
...@@ -1023,7 +1041,7 @@ class OpenAIServing: ...@@ -1023,7 +1041,7 @@ class OpenAIServing:
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
prompt_input: str | list[int], prompt_input: str | list[int],
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> TextTokensPrompt: ) -> TokensPrompt:
""" """
A simpler implementation that tokenizes a single prompt input. A simpler implementation that tokenizes a single prompt input.
""" """
...@@ -1042,7 +1060,7 @@ class OpenAIServing: ...@@ -1042,7 +1060,7 @@ class OpenAIServing:
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
prompt_inputs: Iterable[str | list[int]], prompt_inputs: Iterable[str | list[int]],
add_special_tokens: bool = True, add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]: ) -> AsyncGenerator[TokensPrompt, None]:
""" """
A simpler implementation that tokenizes multiple prompt inputs. A simpler implementation that tokenizes multiple prompt inputs.
""" """
...@@ -1095,11 +1113,7 @@ class OpenAIServing: ...@@ -1095,11 +1113,7 @@ class OpenAIServing:
chat_template_kwargs: dict[str, Any] | None = None, chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> tuple[ ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
list[ConversationMessage],
Sequence[RequestPrompt],
list[EngineTokensPrompt],
]:
model_config = self.model_config model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format( resolved_content_format = resolve_chat_template_content_format(
...@@ -1172,9 +1186,7 @@ class OpenAIServing: ...@@ -1172,9 +1186,7 @@ class OpenAIServing:
"Prompt has to be a string", "Prompt has to be a string",
"when the tokenizer is not initialised", "when the tokenizer is not initialised",
) )
prompt_inputs = TextTokensPrompt( prompt_inputs = TokensPrompt(prompt=request_prompt, prompt_token_ids=[1])
prompt=request_prompt, prompt_token_ids=[1]
)
elif isinstance(request_prompt, str): elif isinstance(request_prompt, str):
prompt_inputs = await self._tokenize_prompt_input_async( prompt_inputs = await self._tokenize_prompt_input_async(
request, request,
...@@ -1187,14 +1199,15 @@ class OpenAIServing: ...@@ -1187,14 +1199,15 @@ class OpenAIServing:
assert is_list_of(request_prompt, int), ( assert is_list_of(request_prompt, int), (
"Prompt has to be either a string or a list of token ids" "Prompt has to be either a string or a list of token ids"
) )
prompt_inputs = TextTokensPrompt( prompt_inputs = TokensPrompt(
prompt=tokenizer.decode(request_prompt), prompt=tokenizer.decode(request_prompt),
prompt_token_ids=request_prompt, prompt_token_ids=request_prompt,
) )
engine_prompt = EngineTokensPrompt( engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["prompt_token_ids"])
prompt_token_ids=prompt_inputs["prompt_token_ids"] if "prompt" in prompt_inputs:
) engine_prompt["prompt"] = prompt_inputs["prompt"]
if mm_data is not None: if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data engine_prompt["multi_modal_data"] = mm_data
...@@ -1207,7 +1220,7 @@ class OpenAIServing: ...@@ -1207,7 +1220,7 @@ class OpenAIServing:
if hasattr(request, "cache_salt") and request.cache_salt is not None: if hasattr(request, "cache_salt") and request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt engine_prompt["cache_salt"] = request.cache_salt
return conversation, [request_prompt], [engine_prompt] return conversation, [engine_prompt]
async def _process_inputs( async def _process_inputs(
self, self,
...@@ -1239,7 +1252,7 @@ class OpenAIServing: ...@@ -1239,7 +1252,7 @@ class OpenAIServing:
async def _render_next_turn( async def _render_next_turn(
self, self,
request: ResponsesRequest, request: ResponsesRequest,
tokenizer: AnyTokenizer, tokenizer: TokenizerLike | None,
messages: list[ResponseInputOutputItem], messages: list[ResponseInputOutputItem],
tool_dicts: list[dict[str, Any]] | None, tool_dicts: list[dict[str, Any]] | None,
tool_parser, tool_parser,
...@@ -1250,7 +1263,7 @@ class OpenAIServing: ...@@ -1250,7 +1263,7 @@ class OpenAIServing:
request_input=messages, request_input=messages,
) )
_, request_prompts, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
request, request,
tokenizer, tokenizer,
new_messages, new_messages,
...@@ -1259,20 +1272,20 @@ class OpenAIServing: ...@@ -1259,20 +1272,20 @@ class OpenAIServing:
chat_template=chat_template, chat_template=chat_template,
chat_template_content_format=chat_template_content_format, chat_template_content_format=chat_template_content_format,
) )
return request_prompts, engine_prompts return engine_prompts
async def _generate_with_builtin_tools( async def _generate_with_builtin_tools(
self, self,
request_id: str, request_id: str,
request_prompt: RequestPrompt, engine_prompt: TokensPrompt,
engine_prompt: EngineTokensPrompt,
sampling_params: SamplingParams, sampling_params: SamplingParams,
context: ConversationContext, context: ConversationContext,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
priority: int = 0, priority: int = 0,
**kwargs, **kwargs,
): ):
prompt_text, _, _ = self._get_prompt_components(request_prompt) prompt_text, _, _ = self._get_prompt_components(engine_prompt)
orig_priority = priority orig_priority = priority
sub_request = 0 sub_request = 0
while True: while True:
...@@ -1280,7 +1293,7 @@ class OpenAIServing: ...@@ -1280,7 +1293,7 @@ class OpenAIServing:
sub_request_id = f"{request_id}_{sub_request}" sub_request_id = f"{request_id}_{sub_request}"
self._log_inputs( self._log_inputs(
sub_request_id, sub_request_id,
request_prompt, engine_prompt,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -1325,10 +1338,9 @@ class OpenAIServing: ...@@ -1325,10 +1338,9 @@ class OpenAIServing:
# Render the next prompt token ids. # Render the next prompt token ids.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
prompt_token_ids = context.render_for_completion() prompt_token_ids = context.render_for_completion()
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
request_prompt = prompt_token_ids
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
request_prompts, engine_prompts = await self._render_next_turn( engine_prompts = await self._render_next_turn(
context.request, context.request,
context.tokenizer, context.tokenizer,
context.parser.response_messages, context.parser.response_messages,
...@@ -1338,8 +1350,7 @@ class OpenAIServing: ...@@ -1338,8 +1350,7 @@ class OpenAIServing:
context.chat_template_content_format, context.chat_template_content_format,
) )
engine_prompt = engine_prompts[0] engine_prompt = engine_prompts[0]
request_prompt = request_prompts[0] prompt_text, _, _ = self._get_prompt_components(engine_prompt)
prompt_text, _, _ = self._get_prompt_components(request_prompt)
# Update the sampling params. # Update the sampling params.
sampling_params.max_tokens = self.max_model_len - len( sampling_params.max_tokens = self.max_model_len - len(
...@@ -1349,19 +1360,13 @@ class OpenAIServing: ...@@ -1349,19 +1360,13 @@ class OpenAIServing:
priority = orig_priority - 1 priority = orig_priority - 1
sub_request += 1 sub_request += 1
def _get_prompt_components( def _get_prompt_components(self, prompt: PromptType) -> PromptComponents:
self, return get_prompt_components(prompt)
prompt: RequestPrompt | PromptType,
) -> PromptComponents:
if isinstance(prompt, list):
return PromptComponents(token_ids=prompt)
return get_prompt_components(prompt) # type: ignore[arg-type]
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
inputs: RequestPrompt | PromptType, inputs: PromptType,
params: SamplingParams | PoolingParams | BeamSearchParams | None, params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
) -> None: ) -> None:
...@@ -1423,7 +1428,7 @@ class OpenAIServing: ...@@ -1423,7 +1428,7 @@ class OpenAIServing:
@staticmethod @staticmethod
def _parse_tool_calls_from_content( def _parse_tool_calls_from_content(
request: ResponsesRequest | ChatCompletionRequest, request: ResponsesRequest | ChatCompletionRequest,
tokenizer: TokenizerLike, tokenizer: TokenizerLike | None,
enable_auto_tools: bool, enable_auto_tools: bool,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None, tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
content: str | None = None, content: str | None = None,
...@@ -1463,6 +1468,11 @@ class OpenAIServing: ...@@ -1463,6 +1468,11 @@ class OpenAIServing:
and enable_auto_tools and enable_auto_tools
and (request.tool_choice == "auto" or request.tool_choice is None) and (request.tool_choice == "auto" or request.tool_choice is None)
): ):
if tokenizer is None:
raise ValueError(
"Tokenizer not available when `skip_tokenizer_init=True`"
)
# Automatic Tool Call Parsing # Automatic Tool Call Parsing
try: try:
tool_parser = tool_parser_cls(tokenizer) tool_parser = tool_parser_cls(tokenizer)
......
...@@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import ( ...@@ -50,6 +50,7 @@ from openai.types.responses.response_reasoning_item import (
) )
from openai.types.responses.tool import Mcp, Tool from openai.types.responses.tool import Mcp, Tool
from openai_harmony import Message as OpenAIHarmonyMessage from openai_harmony import Message as OpenAIHarmonyMessage
from pydantic import TypeAdapter
from vllm import envs from vllm import envs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
...@@ -94,7 +95,10 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -94,7 +95,10 @@ from vllm.entrypoints.openai.protocol import (
ResponseUsage, ResponseUsage,
StreamingResponsesResponse, StreamingResponsesResponse,
) )
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import (
GenerationError,
OpenAIServing,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.responses_utils import ( from vllm.entrypoints.responses_utils import (
construct_input_messages, construct_input_messages,
...@@ -103,7 +107,7 @@ from vllm.entrypoints.responses_utils import ( ...@@ -103,7 +107,7 @@ from vllm.entrypoints.responses_utils import (
make_response_output_items_from_parsable_context, make_response_output_items_from_parsable_context,
) )
from vllm.entrypoints.tool_server import ToolServer from vllm.entrypoints.tool_server import ToolServer
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
...@@ -254,7 +258,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -254,7 +258,7 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server self.tool_server = tool_server
def _validate_generator_input( def _validate_generator_input(
self, engine_prompt: EngineTokensPrompt self, engine_prompt: TokensPrompt
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Add validations to the input to the generator here.""" """Add validations to the input to the generator here."""
if self.max_model_len <= len(engine_prompt["prompt_token_ids"]): if self.max_model_len <= len(engine_prompt["prompt_token_ids"]):
...@@ -349,11 +353,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -349,11 +353,11 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer() tokenizer = await self.engine_client.get_tokenizer()
if self.use_harmony: if self.use_harmony:
messages, request_prompts, engine_prompts = ( messages, engine_prompts = self._make_request_with_harmony(
self._make_request_with_harmony(request, prev_response) request, prev_response
) )
else: else:
messages, request_prompts, engine_prompts = await self._make_request( messages, engine_prompts = await self._make_request(
request, prev_response, tokenizer request, prev_response, tokenizer
) )
...@@ -389,7 +393,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -389,7 +393,7 @@ class OpenAIServingResponses(OpenAIServing):
assert len(builtin_tool_list) == 0 assert len(builtin_tool_list) == 0
available_tools = [] available_tools = []
try: try:
for i, engine_prompt in enumerate(engine_prompts): for engine_prompt in engine_prompts:
maybe_error = self._validate_generator_input(engine_prompt) maybe_error = self._validate_generator_input(engine_prompt)
if maybe_error is not None: if maybe_error is not None:
return maybe_error return maybe_error
...@@ -416,7 +420,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -416,7 +420,7 @@ class OpenAIServingResponses(OpenAIServing):
context = HarmonyContext(messages, available_tools) context = HarmonyContext(messages, available_tools)
else: else:
if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: if envs.VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT:
# This is an feature in development for parsing # This is a feature in development for parsing
# tokens during generation instead of at the end # tokens during generation instead of at the end
context = ParsableContext( context = ParsableContext(
response_messages=messages, response_messages=messages,
...@@ -445,7 +449,6 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -445,7 +449,6 @@ class OpenAIServingResponses(OpenAIServing):
) )
generator = self._generate_with_builtin_tools( generator = self._generate_with_builtin_tools(
request_id=request.request_id, request_id=request.request_id,
request_prompt=request_prompts[i],
engine_prompt=engine_prompt, engine_prompt=engine_prompt,
sampling_params=sampling_params, sampling_params=sampling_params,
context=context, context=context,
...@@ -541,6 +544,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -541,6 +544,8 @@ class OpenAIServingResponses(OpenAIServing):
tokenizer, tokenizer,
request_metadata, request_metadata,
) )
except GenerationError as e:
return self._convert_generation_error_to_response(e)
except Exception as e: except Exception as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
...@@ -558,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -558,7 +563,7 @@ class OpenAIServingResponses(OpenAIServing):
prev_msg=self.msg_store.get(prev_response.id) if prev_response else None, prev_msg=self.msg_store.get(prev_response.id) if prev_response else None,
prev_response_output=prev_response.output if prev_response else None, prev_response_output=prev_response.output if prev_response else None,
) )
_, request_prompts, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self._preprocess_chat(
request, request,
tokenizer, tokenizer,
messages, messages,
...@@ -567,7 +572,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -567,7 +572,7 @@ class OpenAIServingResponses(OpenAIServing):
chat_template=self.chat_template, chat_template=self.chat_template,
chat_template_content_format=self.chat_template_content_format, chat_template_content_format=self.chat_template_content_format,
) )
return messages, request_prompts, engine_prompts return messages, engine_prompts
def _make_request_with_harmony( def _make_request_with_harmony(
self, self,
...@@ -580,13 +585,13 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -580,13 +585,13 @@ class OpenAIServingResponses(OpenAIServing):
) )
messages = self._construct_input_messages_with_harmony(request, prev_response) messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages) prompt_token_ids = render_for_completion(messages)
engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids) engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
# Add cache_salt if provided in the request # Add cache_salt if provided in the request
if request.cache_salt is not None: if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt engine_prompt["cache_salt"] = request.cache_salt
return messages, [prompt_token_ids], [engine_prompt] return messages, [engine_prompt]
async def _initialize_tool_sessions( async def _initialize_tool_sessions(
self, self,
...@@ -648,6 +653,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -648,6 +653,8 @@ class OpenAIServingResponses(OpenAIServing):
status = "incomplete" status = "incomplete"
elif context.finish_reason == "abort": elif context.finish_reason == "abort":
status = "cancelled" status = "cancelled"
else:
self._raise_if_error(context.finish_reason, request.request_id)
else: else:
status = "incomplete" status = "incomplete"
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
...@@ -673,6 +680,9 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -673,6 +680,9 @@ class OpenAIServingResponses(OpenAIServing):
assert len(final_res.outputs) == 1 assert len(final_res.outputs) == 1
final_output = final_res.outputs[0] final_output = final_res.outputs[0]
# finish_reason='error' indicates retryable internal error
self._raise_if_error(final_output.finish_reason, request.request_id)
output = self._make_response_output_items(request, final_output, tokenizer) output = self._make_response_output_items(request, final_output, tokenizer)
if request.enable_response_messages: if request.enable_response_messages:
...@@ -1066,6 +1076,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1066,6 +1076,8 @@ class OpenAIServingResponses(OpenAIServing):
async for event in generator: async for event in generator:
event_deque.append(event) event_deque.append(event)
new_event_signal.set() # Signal new event available new_event_signal.set() # Signal new event available
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e: except Exception as e:
logger.exception("Background request failed for %s", request.request_id) logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e)) response = self.create_error_response(str(e))
...@@ -1089,6 +1101,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1089,6 +1101,8 @@ class OpenAIServingResponses(OpenAIServing):
): ):
try: try:
response = await self.responses_full_generator(request, *args, **kwargs) response = await self.responses_full_generator(request, *args, **kwargs)
except GenerationError as e:
response = self._convert_generation_error_to_response(e)
except Exception as e: except Exception as e:
logger.exception("Background request failed for %s", request.request_id) logger.exception("Background request failed for %s", request.request_id)
response = self.create_error_response(str(e)) response = self.create_error_response(str(e))
...@@ -1227,6 +1241,8 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1227,6 +1241,8 @@ class OpenAIServingResponses(OpenAIServing):
continue continue
if ctx.last_output.outputs: if ctx.last_output.outputs:
output = ctx.last_output.outputs[0] output = ctx.last_output.outputs[0]
# finish_reason='error' indicates a retryable error
self._raise_if_error(output.finish_reason, request.request_id)
if reasoning_parser: if reasoning_parser:
delta_message = reasoning_parser.extract_reasoning_streaming( delta_message = reasoning_parser.extract_reasoning_streaming(
previous_text=previous_text, previous_text=previous_text,
...@@ -1522,6 +1538,9 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -1522,6 +1538,9 @@ class OpenAIServingResponses(OpenAIServing):
async for ctx in result_generator: async for ctx in result_generator:
assert isinstance(ctx, StreamingHarmonyContext) assert isinstance(ctx, StreamingHarmonyContext)
# finish_reason='error' indicates a retryable error
self._raise_if_error(ctx.finish_reason, request.request_id)
if ctx.is_expecting_start(): if ctx.is_expecting_start():
current_output_index += 1 current_output_index += 1
sent_output_item_added = False sent_output_item_added = False
...@@ -2016,18 +2035,25 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -2016,18 +2035,25 @@ class OpenAIServingResponses(OpenAIServing):
) )
) )
async for event_data in processer( try:
request, async for event_data in processer(
sampling_params, request,
result_generator, sampling_params,
context, result_generator,
model_name, context,
tokenizer, model_name,
request_metadata, tokenizer,
created_time, request_metadata,
_increment_sequence_number_and_return, created_time,
): _increment_sequence_number_and_return,
yield event_data ):
yield event_data
except GenerationError as e:
error_json = self._convert_generation_error_to_streaming_response(e)
yield _increment_sequence_number_and_return(
TypeAdapter(StreamingResponsesResponse).validate_json(error_json)
)
return
async def empty_async_generator(): async def empty_async_generator():
# A hack to trick Python to think this is a generator but # A hack to trick Python to think this is a generator but
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( import warnings
ToolParser,
ToolParserManager,
)
__all__ = ["ToolParser", "ToolParserManager"]
def __getattr__(name: str):
if name == "ToolParser":
from vllm.tool_parsers import ToolParser
""" warnings.warn(
Register a lazy module mapping. "`vllm.entrypoints.openai.tool_parsers.ToolParser` has been moved to "
"`vllm.tool_parsers.ToolParser`. "
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)
Example: return ToolParser
ToolParserManager.register_lazy_module( if name == "ToolParserManager":
name="kimi_k2", from vllm.tool_parsers import ToolParserManager
module_path="vllm.entrypoints.openai.tool_parsers.kimi_k2_parser",
class_name="KimiK2ToolParser",
)
"""
warnings.warn(
"`vllm.entrypoints.openai.tool_parsers.ToolParserManager` "
"has been moved to `vllm.tool_parsers.ToolParserManager`. "
"The old name will be removed in v0.14.",
DeprecationWarning,
stacklevel=2,
)
_TOOL_PARSERS_TO_REGISTER = { return ToolParserManager
"deepseek_v3": ( # name
"deepseekv3_tool_parser", # filename
"DeepSeekV3ToolParser", # class_name
),
"deepseek_v31": (
"deepseekv31_tool_parser",
"DeepSeekV31ToolParser",
),
"deepseek_v32": (
"deepseekv32_tool_parser",
"DeepSeekV32ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",
),
"glm45": (
"glm4_moe_tool_parser",
"Glm4MoeModelToolParser",
),
"granite-20b-fc": (
"granite_20b_fc_tool_parser",
"Granite20bFCToolParser",
),
"granite": (
"granite_tool_parser",
"GraniteToolParser",
),
"hermes": (
"hermes_tool_parser",
"Hermes2ProToolParser",
),
"hunyuan_a13b": (
"hunyuan_a13b_tool_parser",
"HunyuanA13BToolParser",
),
"internlm": (
"internlm2_tool_parser",
"Internlm2ToolParser",
),
"jamba": (
"jamba_tool_parser",
"JambaToolParser",
),
"kimi_k2": (
"kimi_k2_tool_parser",
"KimiK2ToolParser",
),
"llama3_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_json": (
"llama_tool_parser",
"Llama3JsonToolParser",
),
"llama4_pythonic": (
"llama4_pythonic_tool_parser",
"Llama4PythonicToolParser",
),
"longcat": (
"longcat_tool_parser",
"LongcatFlashToolParser",
),
"minimax_m2": (
"minimax_m2_tool_parser",
"MinimaxM2ToolParser",
),
"minimax": (
"minimax_tool_parser",
"MinimaxToolParser",
),
"mistral": (
"mistral_tool_parser",
"MistralToolParser",
),
"olmo3": (
"olmo3_tool_parser",
"Olmo3PythonicToolParser",
),
"openai": (
"openai_tool_parser",
"OpenAIToolParser",
),
"phi4_mini_json": (
"phi4mini_tool_parser",
"Phi4MiniJsonToolParser",
),
"pythonic": (
"pythonic_tool_parser",
"PythonicToolParser",
),
"qwen3_coder": (
"qwen3coder_tool_parser",
"Qwen3CoderToolParser",
),
"qwen3_xml": (
"qwen3xml_tool_parser",
"Qwen3XMLToolParser",
),
"seed_oss": (
"seed_oss_tool_parser",
"SeedOssToolParser",
),
"step3": (
"step3_tool_parser",
"Step3ToolParser",
),
"xlam": (
"xlam_tool_parser",
"xLAMToolParser",
),
"gigachat3": (
"gigachat3_tool_parser",
"GigaChat3ToolParser",
),
}
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def register_lazy_tool_parsers():
for name, (file_name, class_name) in _TOOL_PARSERS_TO_REGISTER.items():
module_path = f"vllm.entrypoints.openai.tool_parsers.{file_name}"
ToolParserManager.register_lazy_module(name, module_path, class_name)
register_lazy_tool_parsers()
...@@ -72,11 +72,7 @@ class ClassificationMixin(OpenAIServing): ...@@ -72,11 +72,7 @@ class ClassificationMixin(OpenAIServing):
if ret: if ret:
return ret return ret
( _, engine_prompts = await self._preprocess_chat(
_,
_,
engine_prompts,
) = await self._preprocess_chat(
cast(ChatCompletionRequest, chat_request), cast(ChatCompletionRequest, chat_request),
ctx.tokenizer, ctx.tokenizer,
messages, messages,
......
...@@ -20,7 +20,6 @@ from vllm.entrypoints.openai.serving_engine import ( ...@@ -20,7 +20,6 @@ from vllm.entrypoints.openai.serving_engine import (
EmbeddingServeContext, EmbeddingServeContext,
OpenAIServing, OpenAIServing,
ServeContext, ServeContext,
TextTokensPrompt,
) )
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.pooling.embed.protocol import ( from vllm.entrypoints.pooling.embed.protocol import (
...@@ -32,7 +31,7 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -32,7 +31,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingResponseData, EmbeddingResponseData,
) )
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import ( from vllm.outputs import (
EmbeddingRequestOutput, EmbeddingRequestOutput,
...@@ -83,11 +82,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -83,11 +82,7 @@ class EmbeddingMixin(OpenAIServing):
renderer = self._get_renderer(tokenizer) renderer = self._get_renderer(tokenizer)
if isinstance(ctx.request, EmbeddingChatRequest): if isinstance(ctx.request, EmbeddingChatRequest):
( _, ctx.engine_prompts = await self._preprocess_chat(
_,
_,
ctx.engine_prompts,
) = await self._preprocess_chat(
ctx.request, ctx.request,
tokenizer, tokenizer,
ctx.request.messages, ctx.request.messages,
...@@ -209,14 +204,13 @@ class EmbeddingMixin(OpenAIServing): ...@@ -209,14 +204,13 @@ class EmbeddingMixin(OpenAIServing):
async def _process_chunked_request( async def _process_chunked_request(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
original_prompt: TextTokensPrompt, token_ids: list[int],
pooling_params, pooling_params,
trace_headers, trace_headers,
prompt_idx: int, prompt_idx: int,
) -> list[AsyncGenerator[PoolingRequestOutput, None]]: ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
"""Process a single prompt using chunked processing.""" """Process a single prompt using chunked processing."""
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
token_ids = original_prompt["prompt_token_ids"]
# Split into chunks using max_position_embeddings # Split into chunks using max_position_embeddings
max_pos_embeddings = self._get_max_position_embeddings() max_pos_embeddings = self._get_max_position_embeddings()
...@@ -228,18 +222,12 @@ class EmbeddingMixin(OpenAIServing): ...@@ -228,18 +222,12 @@ class EmbeddingMixin(OpenAIServing):
chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}" chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
# Create engine prompt for this chunk # Create engine prompt for this chunk
chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens) chunk_engine_prompt = TokensPrompt(prompt_token_ids=chunk_tokens)
# Create chunk request prompt for logging
chunk_text = ""
chunk_request_prompt = TextTokensPrompt(
prompt=chunk_text, prompt_token_ids=chunk_tokens
)
# Log the chunk # Log the chunk
self._log_inputs( self._log_inputs(
chunk_request_id, chunk_request_id,
chunk_request_prompt, chunk_engine_prompt,
params=pooling_params, params=pooling_params,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
...@@ -263,7 +251,7 @@ class EmbeddingMixin(OpenAIServing): ...@@ -263,7 +251,7 @@ class EmbeddingMixin(OpenAIServing):
request, request,
input_ids: list[int], input_ids: list[int],
input_text: str, input_text: str,
) -> TextTokensPrompt: ) -> TokensPrompt:
"""Override to support chunked processing for embedding requests.""" """Override to support chunked processing for embedding requests."""
token_num = len(input_ids) token_num = len(input_ids)
...@@ -328,23 +316,15 @@ class EmbeddingMixin(OpenAIServing): ...@@ -328,23 +316,15 @@ class EmbeddingMixin(OpenAIServing):
) )
) )
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) return TokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
# For other request types, use the parent's implementation # For other request types, use the parent's implementation
return super()._validate_input(request, input_ids, input_text) return super()._validate_input(request, input_ids, input_text)
def _is_text_tokens_prompt(self, prompt) -> bool:
"""Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
return (
isinstance(prompt, dict)
and "prompt_token_ids" in prompt
and "prompt_embeds" not in prompt
)
async def _create_single_prompt_generator( async def _create_single_prompt_generator(
self, self,
ctx: EmbeddingServeContext, ctx: EmbeddingServeContext,
engine_prompt: EngineTokensPrompt, engine_prompt: TokensPrompt,
pooling_params: PoolingParams, pooling_params: PoolingParams,
trace_headers: Mapping[str, str] | None, trace_headers: Mapping[str, str] | None,
prompt_index: int, prompt_index: int,
...@@ -413,14 +393,16 @@ class EmbeddingMixin(OpenAIServing): ...@@ -413,14 +393,16 @@ class EmbeddingMixin(OpenAIServing):
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_prompt in enumerate(ctx.engine_prompts):
# Check if this specific prompt needs chunked processing # Check if this specific prompt needs chunked processing
if self._is_text_tokens_prompt(engine_prompt): if "prompt_token_ids" in engine_prompt:
# Cast to TextTokensPrompt since we've verified prompt_token_ids = engine_prompt["prompt_token_ids"]
# prompt_token_ids if len(prompt_token_ids) > max_pos_embeddings:
text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings:
# Use chunked processing for this prompt # Use chunked processing for this prompt
chunk_generators = await self._process_chunked_request( chunk_generators = await self._process_chunked_request(
ctx, text_tokens_prompt, pooling_params, trace_headers, i ctx,
prompt_token_ids,
pooling_params,
trace_headers,
i,
) )
generators.extend(chunk_generators) generators.extend(chunk_generators)
continue continue
...@@ -578,14 +560,13 @@ class EmbeddingMixin(OpenAIServing): ...@@ -578,14 +560,13 @@ class EmbeddingMixin(OpenAIServing):
# Get original prompt token IDs for this prompt # Get original prompt token IDs for this prompt
original_prompt = ctx.engine_prompts[prompt_idx] original_prompt = ctx.engine_prompts[prompt_idx]
if not self._is_text_tokens_prompt(original_prompt): if "prompt_token_ids" not in original_prompt:
return self.create_error_response( return self.create_error_response(
f"Chunked prompt {prompt_idx} is not a TextTokensPrompt" f"Chunked prompt {prompt_idx} does not contain "
"token IDs"
) )
original_token_ids = cast(TextTokensPrompt, original_prompt)[ original_token_ids = original_prompt["prompt_token_ids"]
"prompt_token_ids"
]
pooling_request_output = PoolingRequestOutput( pooling_request_output = PoolingRequestOutput(
request_id=aggregator["request_id"], request_id=aggregator["request_id"],
......
...@@ -137,11 +137,8 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -137,11 +137,8 @@ class OpenAIServingPooling(OpenAIServing):
) )
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
(
_, _, engine_prompts = await self._preprocess_chat(
_,
engine_prompts,
) = await self._preprocess_chat(
request, request,
tokenizer, tokenizer,
request.messages, request.messages,
......
...@@ -120,6 +120,7 @@ class RerankResult(BaseModel): ...@@ -120,6 +120,7 @@ class RerankResult(BaseModel):
class RerankUsage(BaseModel): class RerankUsage(BaseModel):
prompt_tokens: int
total_tokens: int total_tokens: int
......
...@@ -38,7 +38,8 @@ from vllm.inputs.data import TokensPrompt ...@@ -38,7 +38,8 @@ from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
from vllm.tokenizers import MistralTokenizer, TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.async_utils import make_async, merge_async_iterators from vllm.utils.async_utils import make_async, merge_async_iterators
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -501,5 +502,7 @@ class ServingScores(OpenAIServing): ...@@ -501,5 +502,7 @@ class ServingScores(OpenAIServing):
id=request_id, id=request_id,
model=model_name, model=model_name,
results=results, results=results,
usage=RerankUsage(total_tokens=num_prompt_tokens), usage=RerankUsage(
total_tokens=num_prompt_tokens, prompt_tokens=num_prompt_tokens
),
) )
...@@ -12,9 +12,7 @@ import torch ...@@ -12,9 +12,7 @@ import torch
from pydantic import Field from pydantic import Field
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt from vllm.inputs.data import EmbedsPrompt, TextPrompt, TokensPrompt
from vllm.inputs.data import TextPrompt as EngineTextPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import get_prompt_components, parse_raw_prompts from vllm.inputs.parse import get_prompt_components, parse_raw_prompts
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.utils.async_utils import AsyncMicrobatchTokenizer from vllm.utils.async_utils import AsyncMicrobatchTokenizer
...@@ -97,7 +95,7 @@ class BaseRenderer(ABC): ...@@ -97,7 +95,7 @@ class BaseRenderer(ABC):
*, *,
prompt_or_prompts: str | list[str] | list[int] | list[list[int]], prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
config: RenderConfig, config: RenderConfig,
) -> list[EngineTokensPrompt]: ) -> list[TokensPrompt]:
""" """
Convert text or token inputs into engine-ready TokensPrompt objects. Convert text or token inputs into engine-ready TokensPrompt objects.
...@@ -115,7 +113,7 @@ class BaseRenderer(ABC): ...@@ -115,7 +113,7 @@ class BaseRenderer(ABC):
(e.g., tokenization and length handling). (e.g., tokenization and length handling).
Returns: Returns:
list[EngineTokensPrompt]: Engine-ready token prompts. list[TokensPrompt]: Engine-ready token prompts.
Raises: Raises:
ValueError: If input formats are invalid or length limits exceeded. ValueError: If input formats are invalid or length limits exceeded.
...@@ -129,7 +127,7 @@ class BaseRenderer(ABC): ...@@ -129,7 +127,7 @@ class BaseRenderer(ABC):
prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None, prompt_embeds: bytes | list[bytes] | None = None,
config: RenderConfig, config: RenderConfig,
) -> list[EngineTokensPrompt | EngineEmbedsPrompt]: ) -> list[TokensPrompt | EmbedsPrompt]:
""" """
Convert text/token and/or base64-encoded embeddings inputs into Convert text/token and/or base64-encoded embeddings inputs into
engine-ready prompt objects using a unified RenderConfig. engine-ready prompt objects using a unified RenderConfig.
...@@ -146,7 +144,7 @@ class BaseRenderer(ABC): ...@@ -146,7 +144,7 @@ class BaseRenderer(ABC):
(e.g., tokenization and length handling). (e.g., tokenization and length handling).
Returns: Returns:
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]: list[Union[TokensPrompt, EmbedsPrompt]]:
Engine-ready prompt objects. Engine-ready prompt objects.
Raises: Raises:
...@@ -161,31 +159,34 @@ class BaseRenderer(ABC): ...@@ -161,31 +159,34 @@ class BaseRenderer(ABC):
prompt_embeds: bytes | list[bytes], prompt_embeds: bytes | list[bytes],
truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None, truncate_prompt_tokens: Annotated[int, Field(ge=0)] | None = None,
cache_salt: str | None = None, cache_salt: str | None = None,
) -> list[EngineEmbedsPrompt]: ) -> list[EmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects.""" """Load and validate base64-encoded embeddings into prompt objects."""
if not self.model_config.enable_prompt_embeds: if not self.model_config.enable_prompt_embeds:
raise ValueError( raise ValueError(
"You must set `--enable-prompt-embeds` to input `prompt_embeds`." "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
) )
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt: def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load( # Enable sparse tensor integrity checks to prevent out-of-bounds
io.BytesIO(pybase64.b64decode(embed, validate=True)), # writes from maliciously crafted tensors
weights_only=True, with torch.sparse.check_sparse_tensor_invariants():
map_location=torch.device("cpu"), tensor = torch.load(
) io.BytesIO(pybase64.b64decode(embed, validate=True)),
assert isinstance(tensor, torch.Tensor) and tensor.dtype in ( weights_only=True,
torch.float32, map_location=torch.device("cpu"),
torch.bfloat16, )
torch.float16, assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
) torch.float32,
tensor = tensor.to_dense() torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2: if tensor.dim() > 2:
tensor = tensor.squeeze(0) tensor = tensor.squeeze(0)
assert tensor.dim() == 2 assert tensor.dim() == 2
if truncate_prompt_tokens is not None: if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:] tensor = tensor[-truncate_prompt_tokens:]
embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor) embeds_prompt = EmbedsPrompt(prompt_embeds=tensor)
if cache_salt is not None: if cache_salt is not None:
embeds_prompt["cache_salt"] = cache_salt embeds_prompt["cache_salt"] = cache_salt
return embeds_prompt return embeds_prompt
...@@ -213,7 +214,7 @@ class CompletionRenderer(BaseRenderer): ...@@ -213,7 +214,7 @@ class CompletionRenderer(BaseRenderer):
*, *,
prompt_or_prompts: str | list[str] | list[int] | list[list[int]], prompt_or_prompts: str | list[str] | list[int] | list[list[int]],
config: RenderConfig, config: RenderConfig,
) -> list[EngineTokensPrompt]: ) -> list[TokensPrompt]:
"""Implementation of prompt rendering for completion-style requests. """Implementation of prompt rendering for completion-style requests.
Uses async tokenizer pooling for improved performance. See base class Uses async tokenizer pooling for improved performance. See base class
...@@ -240,7 +241,7 @@ class CompletionRenderer(BaseRenderer): ...@@ -240,7 +241,7 @@ class CompletionRenderer(BaseRenderer):
prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None, prompt_or_prompts: str | list[str] | list[int] | list[list[int]] | None = None,
prompt_embeds: bytes | list[bytes] | None = None, prompt_embeds: bytes | list[bytes] | None = None,
config: RenderConfig, config: RenderConfig,
) -> list[EngineTokensPrompt | EngineEmbedsPrompt]: ) -> list[TokensPrompt | EmbedsPrompt]:
""" """
Render text/token prompts and/or precomputed embedding prompts. At Render text/token prompts and/or precomputed embedding prompts. At
least one of `prompt_or_prompts` or `prompt_embeds` must be provided. least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
...@@ -249,7 +250,7 @@ class CompletionRenderer(BaseRenderer): ...@@ -249,7 +250,7 @@ class CompletionRenderer(BaseRenderer):
if truncate_prompt_tokens == 0: if truncate_prompt_tokens == 0:
return [] return []
rendered: list[EngineTokensPrompt | EngineEmbedsPrompt] = [] rendered: list[TokensPrompt | EmbedsPrompt] = []
if prompt_embeds is not None: if prompt_embeds is not None:
rendered.extend( rendered.extend(
...@@ -281,10 +282,10 @@ class CompletionRenderer(BaseRenderer): ...@@ -281,10 +282,10 @@ class CompletionRenderer(BaseRenderer):
async def _create_prompt( async def _create_prompt(
self, self,
prompt_input: EngineTextPrompt | EngineTokensPrompt, prompt_input: TextPrompt | TokensPrompt,
config: RenderConfig, config: RenderConfig,
truncate_prompt_tokens: int | None, truncate_prompt_tokens: int | None,
) -> EngineTokensPrompt: ) -> TokensPrompt:
prompt, prompt_token_ids, _ = get_prompt_components(prompt_input) prompt, prompt_token_ids, _ = get_prompt_components(prompt_input)
if prompt_token_ids is not None: if prompt_token_ids is not None:
...@@ -317,7 +318,7 @@ class CompletionRenderer(BaseRenderer): ...@@ -317,7 +318,7 @@ class CompletionRenderer(BaseRenderer):
truncate_prompt_tokens: int | None, truncate_prompt_tokens: int | None,
add_special_tokens: bool, add_special_tokens: bool,
cache_salt: str | None, cache_salt: str | None,
) -> EngineTokensPrompt: ) -> TokensPrompt:
"""Tokenize text input asynchronously.""" """Tokenize text input asynchronously."""
async_tokenizer = self._get_async_tokenizer() async_tokenizer = self._get_async_tokenizer()
...@@ -350,7 +351,7 @@ class CompletionRenderer(BaseRenderer): ...@@ -350,7 +351,7 @@ class CompletionRenderer(BaseRenderer):
truncate_prompt_tokens: int | None, truncate_prompt_tokens: int | None,
cache_salt: str | None, cache_salt: str | None,
needs_detokenization: bool | None = False, needs_detokenization: bool | None = False,
) -> EngineTokensPrompt: ) -> TokensPrompt:
"""Optionally detokenize token IDs and build a tokens prompt.""" """Optionally detokenize token IDs and build a tokens prompt."""
token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens) token_ids = self._maybe_apply_truncation(token_ids, truncate_prompt_tokens)
...@@ -392,8 +393,8 @@ class CompletionRenderer(BaseRenderer): ...@@ -392,8 +393,8 @@ class CompletionRenderer(BaseRenderer):
max_length: int | None = None, max_length: int | None = None,
cache_salt: str | None = None, cache_salt: str | None = None,
prompt: str | None = None, prompt: str | None = None,
) -> EngineTokensPrompt: ) -> TokensPrompt:
"""Create validated EngineTokensPrompt.""" """Create validated TokensPrompt."""
if max_length is not None and len(token_ids) > max_length: if max_length is not None and len(token_ids) > max_length:
raise ValueError( raise ValueError(
f"This model's maximum context length is {max_length} tokens. " f"This model's maximum context length is {max_length} tokens. "
...@@ -401,7 +402,7 @@ class CompletionRenderer(BaseRenderer): ...@@ -401,7 +402,7 @@ class CompletionRenderer(BaseRenderer):
"Please reduce the length of the input messages." "Please reduce the length of the input messages."
) )
tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids) tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
if cache_salt is not None: if cache_salt is not None:
tokens_prompt["cache_salt"] = cache_salt tokens_prompt["cache_salt"] = cache_salt
if prompt is not None: if prompt is not None:
......
...@@ -27,7 +27,7 @@ from vllm.entrypoints.serve.disagg.protocol import ( ...@@ -27,7 +27,7 @@ from vllm.entrypoints.serve.disagg.protocol import (
GenerateResponse, GenerateResponse,
GenerateResponseChoice, GenerateResponseChoice,
) )
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -99,7 +99,7 @@ class ServingTokens(OpenAIServing): ...@@ -99,7 +99,7 @@ class ServingTokens(OpenAIServing):
# TODO(NickLucche): Change to EngineCoreRequest once Renderer work is # TODO(NickLucche): Change to EngineCoreRequest once Renderer work is
# completed # completed
engine_prompt = EngineTokensPrompt(prompt_token_ids=request.token_ids) engine_prompt = TokensPrompt(prompt_token_ids=request.token_ids)
if request.features is not None: if request.features is not None:
engine_prompt["multi_modal_data"] = None engine_prompt["multi_modal_data"] = None
...@@ -115,7 +115,7 @@ class ServingTokens(OpenAIServing): ...@@ -115,7 +115,7 @@ class ServingTokens(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id, request_id,
request.token_ids, TokensPrompt(prompt_token_ids=request.token_ids),
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
......
...@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -21,6 +21,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -80,11 +81,8 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -80,11 +81,8 @@ class OpenAIServingTokenization(OpenAIServing):
) )
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
(
_, _, engine_prompts = await self._preprocess_chat(
_,
engine_prompts,
) = await self._preprocess_chat(
request, request,
tokenizer, tokenizer,
request.messages, request.messages,
...@@ -141,7 +139,10 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -141,7 +139,10 @@ class OpenAIServingTokenization(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer() tokenizer = await self.engine_client.get_tokenizer()
self._log_inputs( self._log_inputs(
request_id, request.tokens, params=None, lora_request=lora_request request_id,
TokensPrompt(prompt_token_ids=request.tokens),
params=None,
lora_request=lora_request,
) )
prompt_input = await self._tokenize_prompt_input_async( prompt_input = await self._tokenize_prompt_input_async(
......
...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.serving_models import LoRAModulePath
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.tokenizers import MistralTokenizer from vllm.tokenizers.mistral import MistralTokenizer
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -72,10 +72,9 @@ if TYPE_CHECKING: ...@@ -72,10 +72,9 @@ if TYPE_CHECKING:
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MEDIA_CONNECTOR: str = "http" VLLM_MEDIA_CONNECTOR: str = "http"
VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda" VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.9" VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest" VLLM_FLOAT32_MATMUL_PRECISION: Literal["ieee", "tf32"] = "ieee"
MAX_JOBS: str | None = None MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False VLLM_USE_PRECOMPILED: bool = False
...@@ -240,6 +239,7 @@ if TYPE_CHECKING: ...@@ -240,6 +239,7 @@ if TYPE_CHECKING:
VLLM_NCCL_INCLUDE_PATH: str | None = None VLLM_NCCL_INCLUDE_PATH: str | None = None
VLLM_USE_FBGEMM: bool = False VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = "" VLLM_GC_DEBUG: str = ""
VLLM_DEBUG_WORKSPACE: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
...@@ -458,11 +458,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -458,11 +458,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() "VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower()
or "12.9", or "12.9",
# Controls PyTorch float32 matmul precision mode within vLLM workers. # Controls PyTorch float32 matmul precision mode within vLLM workers.
# Valid options mirror torch.set_float32_matmul_precision # Accepted values:
# - "ieee" (default): force full IEEE FP32 matmul precision.
# - "tf32": enable TensorFloat32-based fast matmul.
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices( "VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
"VLLM_FLOAT32_MATMUL_PRECISION", "VLLM_FLOAT32_MATMUL_PRECISION",
"highest", "ieee",
["highest", "high", "medium"], ["ieee", "tf32"],
case_sensitive=False, case_sensitive=False,
), ),
# Maximum number of compilation jobs to run in parallel. # Maximum number of compilation jobs to run in parallel.
...@@ -787,9 +789,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -787,9 +789,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
# imported at runtime. # imported at runtime.
# If a non-existing backend is used, an AssertionError will be thrown. # If a non-existing backend is used, an AssertionError will be thrown.
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"), "VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
# Default is 4 GiB per API process + 4 GiB per engine core process
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
# Path to the XLA persistent cache directory. # Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs. # Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser( "VLLM_XLA_CACHE_PATH": lambda: os.path.expanduser(
...@@ -1540,6 +1539,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1540,6 +1539,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with # - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
# top 5 collected objects # top 5 collected objects
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""), "VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
# Debug workspace allocations.
# logging of workspace resize operations.
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
# Disables parallel execution of shared_experts via separate cuda stream # Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
...@@ -1584,6 +1586,12 @@ def __getattr__(name: str): ...@@ -1584,6 +1586,12 @@ def __getattr__(name: str):
raise AttributeError(f"module {__name__!r} has no attribute {name!r}") raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def _is_envs_cache_enabled() -> bool:
"""Checked if __getattr__ is wrapped with functools.cache"""
global __getattr__
return hasattr(__getattr__, "cache_clear")
def enable_envs_cache() -> None: def enable_envs_cache() -> None:
""" """
Enables caching of environment variables. This is useful for performance Enables caching of environment variables. This is useful for performance
...@@ -1594,6 +1602,9 @@ def enable_envs_cache() -> None: ...@@ -1594,6 +1602,9 @@ def enable_envs_cache() -> None:
runtime overhead. This also means that environment variables should NOT runtime overhead. This also means that environment variables should NOT
be updated after the service is initialized. be updated after the service is initialized.
""" """
if _is_envs_cache_enabled():
# Avoid wrapping functools.cache multiple times
return
# Tag __getattr__ with functools.cache # Tag __getattr__ with functools.cache
global __getattr__ global __getattr__
__getattr__ = functools.cache(__getattr__) __getattr__ = functools.cache(__getattr__)
...@@ -1603,6 +1614,17 @@ def enable_envs_cache() -> None: ...@@ -1603,6 +1614,17 @@ def enable_envs_cache() -> None:
__getattr__(key) __getattr__(key)
def disable_envs_cache() -> None:
"""
Resets the environment variables cache. It could be used to isolate environments
between unit tests.
"""
global __getattr__
# If __getattr__ is wrapped by functions.cache, unwrap the caching layer.
if _is_envs_cache_enabled():
__getattr__ = __getattr__.__wrapped__
def __dir__(): def __dir__():
return list(environment_variables.keys()) return list(environment_variables.keys())
...@@ -1665,7 +1687,6 @@ def compile_factors() -> dict[str, object]: ...@@ -1665,7 +1687,6 @@ def compile_factors() -> dict[str, object]:
"VLLM_MEDIA_CONNECTOR", "VLLM_MEDIA_CONNECTOR",
"VLLM_ASSETS_CACHE", "VLLM_ASSETS_CACHE",
"VLLM_ASSETS_CACHE_MODEL_CLEAN", "VLLM_ASSETS_CACHE_MODEL_CLEAN",
"VLLM_MM_INPUT_CACHE_GIB",
"VLLM_WORKER_MULTIPROC_METHOD", "VLLM_WORKER_MULTIPROC_METHOD",
"VLLM_ENABLE_V1_MULTIPROCESSING", "VLLM_ENABLE_V1_MULTIPROCESSING",
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
......
...@@ -33,22 +33,31 @@ def parse_raw_prompts( ...@@ -33,22 +33,31 @@ def parse_raw_prompts(
if len(prompt) == 0: if len(prompt) == 0:
raise ValueError("please provide at least one prompt") raise ValueError("please provide at least one prompt")
# case 2: array of strings
if is_list_of(prompt, str): if is_list_of(prompt, str):
# case 2: array of strings
prompt = cast(list[str], prompt) prompt = cast(list[str], prompt)
return [TextPrompt(prompt=elem) for elem in prompt] return [TextPrompt(prompt=elem) for elem in prompt]
# case 3: array of tokens
if is_list_of(prompt, int): if is_list_of(prompt, int):
# case 3: array of tokens
prompt = cast(list[int], prompt) prompt = cast(list[int], prompt)
return [TokensPrompt(prompt_token_ids=prompt)] return [TokensPrompt(prompt_token_ids=prompt)]
# case 4: array of token arrays
if is_list_of(prompt, list): if is_list_of(prompt, list):
prompt = cast(list[list[int]], prompt) first = prompt[0]
if len(prompt[0]) == 0: if not isinstance(first, list):
raise ValueError("please provide at least one prompt") raise ValueError("prompt expected to be a list of lists")
if is_list_of(prompt[0], int): if len(first) == 0:
# case 4: array of token arrays raise ValueError("Please provide at least one prompt")
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
# strict validation: every nested list must be list[int]
if not all(is_list_of(elem, int) for elem in prompt):
raise TypeError("Nested lists must contain only integers")
prompt = cast(list[list[int]], prompt)
return [TokensPrompt(prompt_token_ids=elem) for elem in prompt]
raise TypeError( raise TypeError(
"prompt must be a string, array of strings, " "prompt must be a string, array of strings, "
......
...@@ -229,6 +229,11 @@ def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]: ...@@ -229,6 +229,11 @@ def suppress_logging(level: int = logging.INFO) -> Generator[None, Any, None]:
# guaranteed by the Python GIL. # guaranteed by the Python GIL.
_configure_vllm_root_logger() _configure_vllm_root_logger()
# Transformers uses httpx to access the Hugging Face Hub. httpx is quite verbose,
# so we set its logging level to WARNING when vLLM's logging level is INFO.
if envs.VLLM_LOGGING_LEVEL == "INFO":
logging.getLogger("httpx").setLevel(logging.WARNING)
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -38,8 +38,9 @@ class CustomOp(nn.Module): ...@@ -38,8 +38,9 @@ class CustomOp(nn.Module):
) )
return super().__new__(op_cls_to_instantiate) return super().__new__(op_cls_to_instantiate)
def __init__(self): def __init__(self, enforce_enable: bool = False):
super().__init__() super().__init__()
self._enforce_enable = enforce_enable
self._forward_method = self.dispatch_forward() self._forward_method = self.dispatch_forward()
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
...@@ -84,7 +85,11 @@ class CustomOp(nn.Module): ...@@ -84,7 +85,11 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one # NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching. # specific backend. Currently, we do not support dynamic dispatching.
compilation_config = get_cached_compilation_config() compilation_config = get_cached_compilation_config()
enabled = self.enabled()
# CustomOp object can be enforce enabled, e.g., enable device-specific
# kernels in ViT models when enabling graph mode. By default, it will
# follow the compilation_config to determine whether enable itself.
enabled = self._enforce_enable or self.enabled()
if enabled: if enabled:
compilation_config.enabled_custom_ops.update([self.__class__.name]) compilation_config.enabled_custom_ops.update([self.__class__.name])
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment