Unverified Commit ba2f0acc authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Misc] Reorganize inputs (#35182)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 678b3c99
...@@ -797,12 +797,12 @@ class AnthropicServingMessages(OpenAIServingChat): ...@@ -797,12 +797,12 @@ class AnthropicServingMessages(OpenAIServingChat):
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
_, engine_prompts = result _, engine_inputs = result
input_tokens = sum( # type: ignore input_tokens = sum( # type: ignore
len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc] len(engine_input["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
for prompt in engine_prompts for engine_input in engine_inputs
if "prompt_token_ids" in prompt if "prompt_token_ids" in engine_input
) )
response = AnthropicCountTokensResponse( response = AnthropicCountTokensResponse(
......
...@@ -40,9 +40,10 @@ from typing_extensions import Required, TypedDict ...@@ -40,9 +40,10 @@ from typing_extensions import Required, TypedDict
from vllm import envs from vllm import envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalBatchedField, MultiModalBatchedField,
MultiModalFlatField, MultiModalFlatField,
......
...@@ -57,9 +57,9 @@ from vllm.entrypoints.pooling.score.utils import ( ...@@ -57,9 +57,9 @@ from vllm.entrypoints.pooling.score.utils import (
validate_score_input, validate_score_input,
) )
from vllm.entrypoints.utils import log_non_default_args from vllm.entrypoints.utils import log_non_default_args
from vllm.inputs.data import ( from vllm.inputs import (
DataPrompt, DataPrompt,
ProcessorInputs, EngineInput,
PromptType, PromptType,
SingletonPrompt, SingletonPrompt,
TextPrompt, TextPrompt,
...@@ -589,7 +589,7 @@ class LLM: ...@@ -589,7 +589,7 @@ class LLM:
def _resolve_mm_lora( def _resolve_mm_lora(
self, self,
prompt: ProcessorInputs, prompt: EngineInput,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
) -> LoRARequest | None: ) -> LoRARequest | None:
if prompt["type"] != "multimodal": if prompt["type"] != "multimodal":
...@@ -716,8 +716,8 @@ class LLM: ...@@ -716,8 +716,8 @@ class LLM:
eos_token_id = tokenizer.eos_token_id eos_token_id = tokenizer.eos_token_id
sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty) sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)
engine_prompts = self._preprocess_cmpl(prompts) engine_inputs = self._preprocess_cmpl(prompts)
lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts)) lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))
if use_tqdm and concurrency_limit is not None: if use_tqdm and concurrency_limit is not None:
logger.warning( logger.warning(
...@@ -727,7 +727,7 @@ class LLM: ...@@ -727,7 +727,7 @@ class LLM:
use_tqdm = False use_tqdm = False
if concurrency_limit is None: if concurrency_limit is None:
concurrency_limit = len(engine_prompts) concurrency_limit = len(engine_inputs)
# generate 2 * beam_width candidates at each step # generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation # following the huggingface transformers implementation
...@@ -740,7 +740,7 @@ class LLM: ...@@ -740,7 +740,7 @@ class LLM:
) )
instances: list[BeamSearchInstance] = [] instances: list[BeamSearchInstance] = []
for lora_req, prompt in zip(lora_requests, engine_prompts): for lora_req, prompt in zip(lora_requests, engine_inputs):
if prompt["type"] == "embeds": if prompt["type"] == "embeds":
raise NotImplementedError( raise NotImplementedError(
"Embedding prompt not supported for beam search" "Embedding prompt not supported for beam search"
...@@ -845,7 +845,7 @@ class LLM: ...@@ -845,7 +845,7 @@ class LLM:
self, self,
prompts: Sequence[PromptType], prompts: Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[ProcessorInputs]: ) -> Sequence[EngineInput]:
""" """
Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
a format that can be passed to `_add_request`. a format that can be passed to `_add_request`.
...@@ -853,7 +853,7 @@ class LLM: ...@@ -853,7 +853,7 @@ class LLM:
Refer to [LLM.generate][] for a complete description of the arguments. Refer to [LLM.generate][] for a complete description of the arguments.
Returns: Returns:
A list of `ProcessorInputs` objects ready to be passed into LLMEngine. A list of `EngineInput` objects ready to be passed into LLMEngine.
""" """
renderer = self.renderer renderer = self.renderer
model_config = self.model_config model_config = self.model_config
...@@ -871,9 +871,9 @@ class LLM: ...@@ -871,9 +871,9 @@ class LLM:
self, self,
prompt: PromptType, prompt: PromptType,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs: ) -> EngineInput:
(engine_prompt,) = self._preprocess_cmpl([prompt], tokenization_kwargs) (engine_input,) = self._preprocess_cmpl([prompt], tokenization_kwargs)
return engine_prompt return engine_input
def _preprocess_chat( def _preprocess_chat(
self, self,
...@@ -886,7 +886,7 @@ class LLM: ...@@ -886,7 +886,7 @@ class LLM:
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
) -> Sequence[ProcessorInputs]: ) -> Sequence[EngineInput]:
""" """
Convert a list of conversations into prompts so that they can then Convert a list of conversations into prompts so that they can then
be used as input for other LLM APIs. be used as input for other LLM APIs.
...@@ -894,7 +894,7 @@ class LLM: ...@@ -894,7 +894,7 @@ class LLM:
Refer to [LLM.chat][] for a complete description of the arguments. Refer to [LLM.chat][] for a complete description of the arguments.
Returns: Returns:
A list of `ProcessorInputs` objects ready to be passed into LLMEngine. A list of `EngineInput` objects ready to be passed into LLMEngine.
""" """
renderer = self.renderer renderer = self.renderer
...@@ -915,14 +915,14 @@ class LLM: ...@@ -915,14 +915,14 @@ class LLM:
**(tokenization_kwargs or {}) **(tokenization_kwargs or {})
) )
_, engine_prompts = renderer.render_chat( _, engine_inputs = renderer.render_chat(
conversations, conversations,
chat_params, chat_params,
tok_params, tok_params,
prompt_extras={"mm_processor_kwargs": mm_processor_kwargs}, prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
) )
return engine_prompts return engine_inputs
def _preprocess_chat_one( def _preprocess_chat_one(
self, self,
...@@ -935,8 +935,8 @@ class LLM: ...@@ -935,8 +935,8 @@ class LLM:
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
mm_processor_kwargs: dict[str, Any] | None = None, mm_processor_kwargs: dict[str, Any] | None = None,
) -> ProcessorInputs: ) -> EngineInput:
(engine_prompt,) = self._preprocess_chat( (engine_input,) = self._preprocess_chat(
[conversation], [conversation],
chat_template=chat_template, chat_template=chat_template,
chat_template_content_format=chat_template_content_format, chat_template_content_format=chat_template_content_format,
...@@ -948,7 +948,7 @@ class LLM: ...@@ -948,7 +948,7 @@ class LLM:
mm_processor_kwargs=mm_processor_kwargs, mm_processor_kwargs=mm_processor_kwargs,
) )
return engine_prompt return engine_input
def chat( def chat(
self, self,
...@@ -1909,7 +1909,7 @@ class LLM: ...@@ -1909,7 +1909,7 @@ class LLM:
def _render_and_run_requests( def _render_and_run_requests(
self, self,
prompts: Iterable[ProcessorInputs], prompts: Iterable[EngineInput],
params: Sequence[SamplingParams | PoolingParams], params: Sequence[SamplingParams | PoolingParams],
output_type: type[_O], output_type: type[_O],
*, *,
...@@ -1938,7 +1938,7 @@ class LLM: ...@@ -1938,7 +1938,7 @@ class LLM:
def _render_and_add_requests( def _render_and_add_requests(
self, self,
prompts: Iterable[ProcessorInputs], prompts: Iterable[EngineInput],
params: Sequence[SamplingParams | PoolingParams], params: Sequence[SamplingParams | PoolingParams],
*, *,
lora_requests: Sequence[LoRARequest | None] | None = None, lora_requests: Sequence[LoRARequest | None] | None = None,
...@@ -1967,7 +1967,7 @@ class LLM: ...@@ -1967,7 +1967,7 @@ class LLM:
def _add_request( def _add_request(
self, self,
prompt: ProcessorInputs, prompt: EngineInput,
params: SamplingParams | PoolingParams, params: SamplingParams | PoolingParams,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
priority: int = 0, priority: int = 0,
......
...@@ -63,7 +63,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -63,7 +63,7 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
) )
from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls from vllm.entrypoints.openai.utils import maybe_filter_parallel_tool_calls
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.inputs.data import ProcessorInputs from vllm.inputs import EngineInput
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import CompletionOutput, RequestOutput from vllm.outputs import CompletionOutput, RequestOutput
...@@ -177,7 +177,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -177,7 +177,7 @@ class OpenAIServingChat(OpenAIServing):
async def render_chat_request( async def render_chat_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse: ) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
""" """
Validate the model and preprocess a chat completion request. Validate the model and preprocess a chat completion request.
...@@ -185,7 +185,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -185,7 +185,7 @@ class OpenAIServingChat(OpenAIServing):
engine-aware checks (LoRA model validation, engine health). engine-aware checks (LoRA model validation, engine health).
Returns: Returns:
A tuple of (conversation, engine_prompts) on success, A tuple of (conversation, engine_inputs) on success,
or an ErrorResponse on failure. or an ErrorResponse on failure.
""" """
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
...@@ -231,7 +231,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -231,7 +231,7 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
conversation, engine_prompts = result conversation, engine_inputs = result
request_id = ( request_id = (
f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}" f"chatcmpl-{self._base_request_id(raw_request, request.request_id)}"
...@@ -251,13 +251,13 @@ class OpenAIServingChat(OpenAIServing): ...@@ -251,13 +251,13 @@ class OpenAIServingChat(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
for i, engine_prompt in enumerate(engine_prompts): for i, engine_input in enumerate(engine_inputs):
prompt_token_ids = self._extract_prompt_components(engine_prompt).token_ids prompt_token_ids = self._extract_prompt_components(engine_input).token_ids
# If we are creating sub requests for multiple prompts, ensure that they # If we are creating sub requests for multiple prompts, ensure that they
# have unique request ids. # have unique request ids.
sub_request_id = ( sub_request_id = (
request_id if len(engine_prompts) == 1 else f"{request_id}_{i}" request_id if len(engine_inputs) == 1 else f"{request_id}_{i}"
) )
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
...@@ -265,7 +265,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -265,7 +265,7 @@ class OpenAIServingChat(OpenAIServing):
request.max_completion_tokens request.max_completion_tokens
if request.max_completion_tokens is not None if request.max_completion_tokens is not None
else request.max_tokens, else request.max_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_input),
self.default_sampling_params, self.default_sampling_params,
self.override_max_tokens, self.override_max_tokens,
) )
...@@ -283,7 +283,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -283,7 +283,7 @@ class OpenAIServingChat(OpenAIServing):
self._log_inputs( self._log_inputs(
sub_request_id, sub_request_id,
engine_prompt, engine_input,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -296,7 +296,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -296,7 +296,7 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search( generator = self.beam_search(
prompt=engine_prompt, prompt=engine_input,
request_id=sub_request_id, request_id=sub_request_id,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
...@@ -313,7 +313,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -313,7 +313,7 @@ class OpenAIServingChat(OpenAIServing):
reasoning_ended = None reasoning_ended = None
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_prompt, engine_input,
sampling_params, sampling_params,
sub_request_id, sub_request_id,
lora_request=lora_request, lora_request=lora_request,
......
...@@ -33,7 +33,7 @@ from vllm.entrypoints.openai.engine.serving import ( ...@@ -33,7 +33,7 @@ from vllm.entrypoints.openai.engine.serving import (
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens, should_include_usage from vllm.entrypoints.utils import get_max_tokens, should_include_usage
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs from vllm.inputs import EngineInput
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -82,7 +82,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -82,7 +82,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def render_completion_request( async def render_completion_request(
self, self,
request: CompletionRequest, request: CompletionRequest,
) -> list[ProcessorInputs] | ErrorResponse: ) -> list[EngineInput] | ErrorResponse:
""" """
Validate the model and preprocess a completion request. Validate the model and preprocess a completion request.
...@@ -90,8 +90,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -90,8 +90,7 @@ class OpenAIServingCompletion(OpenAIServing):
engine-aware checks (LoRA model validation, engine health). engine-aware checks (LoRA model validation, engine health).
Returns: Returns:
A list of engine_prompts on success, A list of engine_inputs on success, or an ErrorResponse on failure.
or an ErrorResponse on failure.
""" """
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
...@@ -128,7 +127,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -128,7 +127,7 @@ class OpenAIServingCompletion(OpenAIServing):
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
engine_prompts = result engine_inputs = result
request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}" request_id = f"cmpl-{self._base_request_id(raw_request, request.request_id)}"
created_time = int(time.time()) created_time = int(time.time())
...@@ -145,11 +144,11 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -145,11 +144,11 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
generators: list[AsyncGenerator[RequestOutput, None]] = [] generators: list[AsyncGenerator[RequestOutput, None]] = []
for i, engine_prompt in enumerate(engine_prompts): for i, engine_input in enumerate(engine_inputs):
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
max_model_len, max_model_len,
request.max_tokens, request.max_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_input),
self.default_sampling_params, self.default_sampling_params,
self.override_max_tokens, self.override_max_tokens,
) )
...@@ -169,7 +168,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -169,7 +168,7 @@ class OpenAIServingCompletion(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, engine_input,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -182,7 +181,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -182,7 +181,7 @@ class OpenAIServingCompletion(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search( generator = self.beam_search(
prompt=engine_prompt, prompt=engine_input,
request_id=request_id, request_id=request_id,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
...@@ -190,7 +189,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -190,7 +189,7 @@ class OpenAIServingCompletion(OpenAIServing):
) )
else: else:
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_prompt, engine_input,
sampling_params, sampling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
...@@ -204,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -204,7 +203,7 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(*generators) result_generator = merge_async_iterators(*generators)
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts) num_prompts = len(engine_inputs)
# Streaming response # Streaming response
tokenizer = self.renderer.tokenizer tokenizer = self.renderer.tokenizer
...@@ -212,7 +211,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -212,7 +211,7 @@ class OpenAIServingCompletion(OpenAIServing):
if request.stream: if request.stream:
return self.completion_stream_generator( return self.completion_stream_generator(
request, request,
engine_prompts, engine_inputs,
result_generator, result_generator,
request_id, request_id,
created_time, created_time,
...@@ -235,8 +234,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -235,8 +234,7 @@ class OpenAIServingCompletion(OpenAIServing):
# We did not pass it into vLLM engine to avoid being redundant # We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs # with the inputs token IDs
if final_res.prompt is None: if final_res.prompt is None:
engine_prompt = engine_prompts[i] final_res.prompt = self._extract_prompt_text(engine_inputs[i])
final_res.prompt = self._extract_prompt_text(engine_prompt)
final_res_batch_checked = cast(list[RequestOutput], final_res_batch) final_res_batch_checked = cast(list[RequestOutput], final_res_batch)
...@@ -268,7 +266,7 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -268,7 +266,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator( async def completion_stream_generator(
self, self,
request: CompletionRequest, request: CompletionRequest,
engine_prompts: list[ProcessorInputs], engine_inputs: list[EngineInput],
result_generator: AsyncIterator[tuple[int, RequestOutput]], result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str, request_id: str,
created_time: int, created_time: int,
...@@ -301,8 +299,8 @@ class OpenAIServingCompletion(OpenAIServing): ...@@ -301,8 +299,8 @@ class OpenAIServingCompletion(OpenAIServing):
prompt_text = res.prompt prompt_text = res.prompt
if prompt_text is None: if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx] engine_input = engine_inputs[prompt_idx]
prompt_text = self._extract_prompt_text(engine_prompt) prompt_text = self._extract_prompt_text(engine_input)
# Prompt details are excluded from later streamed outputs # Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None: if prompt_token_ids is not None:
......
...@@ -72,11 +72,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -72,11 +72,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
) )
from vllm.entrypoints.utils import create_error_response from vllm.entrypoints.utils import create_error_response
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ( from vllm.inputs import EngineInput, PromptType, TokensPrompt
ProcessorInputs,
PromptType,
TokensPrompt,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
...@@ -163,7 +159,7 @@ class ServeContext(Generic[RequestT]): ...@@ -163,7 +159,7 @@ class ServeContext(Generic[RequestT]):
request_id: str request_id: str
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[ProcessorInputs] | None = None engine_inputs: list[EngineInput] | None = None
result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
None None
...@@ -202,7 +198,7 @@ class OpenAIServing: ...@@ -202,7 +198,7 @@ class OpenAIServing:
async def beam_search( async def beam_search(
self, self,
prompt: ProcessorInputs, prompt: EngineInput,
request_id: str, request_id: str,
params: BeamSearchParams, params: BeamSearchParams,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -493,21 +489,21 @@ class OpenAIServing: ...@@ -493,21 +489,21 @@ class OpenAIServing:
if isinstance(pooling_params, ErrorResponse): if isinstance(pooling_params, ErrorResponse):
return pooling_params return pooling_params
if ctx.engine_prompts is None: if ctx.engine_inputs is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_input in enumerate(ctx.engine_inputs):
request_id_item = f"{ctx.request_id}-{i}" request_id_item = f"{ctx.request_id}-{i}"
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, engine_input,
params=pooling_params, params=pooling_params,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_input,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
...@@ -526,10 +522,10 @@ class OpenAIServing: ...@@ -526,10 +522,10 @@ class OpenAIServing:
ctx: ServeContext, ctx: ServeContext,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Collect batch results from the result generator.""" """Collect batch results from the result generator."""
if ctx.engine_prompts is None: if ctx.engine_inputs is None:
return self.create_error_response("Engine prompts not available") return self.create_error_response("Engine prompts not available")
num_prompts = len(ctx.engine_prompts) num_prompts = len(ctx.engine_inputs)
final_res_batch: list[PoolingRequestOutput | None] final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_prompts final_res_batch = [None] * num_prompts
...@@ -806,19 +802,19 @@ class OpenAIServing: ...@@ -806,19 +802,19 @@ class OpenAIServing:
# Apply server defaults first, then request kwargs override. # Apply server defaults first, then request kwargs override.
return default_chat_template_kwargs | request_chat_template_kwargs return default_chat_template_kwargs | request_chat_template_kwargs
def _extract_prompt_components(self, prompt: PromptType | ProcessorInputs): def _extract_prompt_components(self, prompt: PromptType | EngineInput):
return extract_prompt_components(self.model_config, prompt) return extract_prompt_components(self.model_config, prompt)
def _extract_prompt_text(self, prompt: ProcessorInputs): def _extract_prompt_text(self, prompt: PromptType | EngineInput):
return self._extract_prompt_components(prompt).text return self._extract_prompt_components(prompt).text
def _extract_prompt_len(self, prompt: ProcessorInputs): def _extract_prompt_len(self, prompt: EngineInput):
return extract_prompt_len(self.model_config, prompt) return extract_prompt_len(self.model_config, prompt)
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
inputs: PromptType | ProcessorInputs, inputs: PromptType | EngineInput,
params: SamplingParams | PoolingParams | BeamSearchParams | None, params: SamplingParams | PoolingParams | BeamSearchParams | None,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
) -> None: ) -> None:
......
...@@ -12,7 +12,7 @@ from vllm.engine.protocol import EngineClient, StreamingInput ...@@ -12,7 +12,7 @@ from vllm.engine.protocol import EngineClient, StreamingInput
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.inputs.data import PromptType from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsRealtime from vllm.model_executor.models.interfaces import SupportsRealtime
from vllm.renderers.inputs.preprocess import parse_model_prompt from vllm.renderers.inputs.preprocess import parse_model_prompt
...@@ -83,6 +83,6 @@ class OpenAIServingRealtime(OpenAIServing): ...@@ -83,6 +83,6 @@ class OpenAIServingRealtime(OpenAIServing):
async for prompt in stream_input_iter: async for prompt in stream_input_iter:
parsed_prompt = parse_model_prompt(model_config, prompt) parsed_prompt = parse_model_prompt(model_config, prompt)
(engine_prompt,) = await renderer.render_cmpl_async([parsed_prompt]) (engine_input,) = await renderer.render_cmpl_async([parsed_prompt])
yield StreamingInput(prompt=engine_prompt) yield StreamingInput(prompt=engine_input)
...@@ -110,7 +110,7 @@ from vllm.entrypoints.openai.responses.utils import ( ...@@ -110,7 +110,7 @@ from vllm.entrypoints.openai.responses.utils import (
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs.data import ProcessorInputs, token_inputs from vllm.inputs import EngineInput, tokens_input
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob as SampleLogprob from vllm.logprobs import Logprob as SampleLogprob
from vllm.logprobs import SampleLogprobs from vllm.logprobs import SampleLogprobs
...@@ -269,10 +269,10 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -269,10 +269,10 @@ class OpenAIServingResponses(OpenAIServing):
def _validate_generator_input( def _validate_generator_input(
self, self,
engine_prompt: ProcessorInputs, engine_input: EngineInput,
) -> ErrorResponse | None: ) -> ErrorResponse | None:
"""Add validations to the input to the generator here.""" """Add validations to the input to the generator here."""
prompt_len = self._extract_prompt_len(engine_prompt) prompt_len = self._extract_prompt_len(engine_input)
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
if prompt_len >= max_model_len: if prompt_len >= max_model_len:
...@@ -369,11 +369,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -369,11 +369,11 @@ class OpenAIServingResponses(OpenAIServing):
model_name = self.models.model_name(lora_request) model_name = self.models.model_name(lora_request)
if self.use_harmony: if self.use_harmony:
messages, engine_prompts = self._make_request_with_harmony( messages, engine_inputs = self._make_request_with_harmony(
request, prev_response request, prev_response
) )
else: else:
messages, engine_prompts = await self._make_request(request, prev_response) messages, engine_inputs = await self._make_request(request, prev_response)
request_metadata = RequestResponseMetadata(request_id=request.request_id) request_metadata = RequestResponseMetadata(request_id=request.request_id)
if raw_request: if raw_request:
...@@ -413,15 +413,15 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -413,15 +413,15 @@ class OpenAIServingResponses(OpenAIServing):
available_tools = [] available_tools = []
tokenizer = self.renderer.get_tokenizer() tokenizer = self.renderer.get_tokenizer()
for engine_prompt in engine_prompts: for engine_input in engine_inputs:
maybe_error = self._validate_generator_input(engine_prompt) maybe_error = self._validate_generator_input(engine_input)
if maybe_error is not None: if maybe_error is not None:
return maybe_error return maybe_error
default_max_tokens = get_max_tokens( default_max_tokens = get_max_tokens(
max_model_len, max_model_len,
request.max_output_tokens, request.max_output_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_input),
self.default_sampling_params, self.default_sampling_params,
self.override_max_tokens, self.override_max_tokens,
) )
...@@ -480,7 +480,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -480,7 +480,7 @@ class OpenAIServingResponses(OpenAIServing):
) )
generator = self._generate_with_builtin_tools( generator = self._generate_with_builtin_tools(
request_id=request.request_id, request_id=request.request_id,
engine_prompt=engine_prompt, engine_input=engine_input,
sampling_params=sampling_params, sampling_params=sampling_params,
context=context, context=context,
lora_request=lora_request, lora_request=lora_request,
...@@ -586,7 +586,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -586,7 +586,7 @@ class OpenAIServingResponses(OpenAIServing):
prev_response_output=prev_response.output if prev_response else None, prev_response_output=prev_response.output if prev_response else None,
) )
_, engine_prompts = await self.openai_serving_render.preprocess_chat( _, engine_inputs = await self.openai_serving_render.preprocess_chat(
request, request,
messages, messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -595,7 +595,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -595,7 +595,7 @@ class OpenAIServingResponses(OpenAIServing):
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=self.parser.tool_parser_cls if self.parser else None, tool_parser=self.parser.tool_parser_cls if self.parser else None,
) )
return messages, engine_prompts return messages, engine_inputs
async def _render_next_turn( async def _render_next_turn(
self, self,
...@@ -610,7 +610,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -610,7 +610,7 @@ class OpenAIServingResponses(OpenAIServing):
request_input=messages, request_input=messages,
) )
_, engine_prompts = await self.openai_serving_render.preprocess_chat( _, engine_inputs = await self.openai_serving_render.preprocess_chat(
request, request,
new_messages, new_messages,
default_template=chat_template, default_template=chat_template,
...@@ -619,12 +619,12 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -619,12 +619,12 @@ class OpenAIServingResponses(OpenAIServing):
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
tool_parser=tool_parser, tool_parser=tool_parser,
) )
return engine_prompts return engine_inputs
async def _generate_with_builtin_tools( async def _generate_with_builtin_tools(
self, self,
request_id: str, request_id: str,
engine_prompt: ProcessorInputs, engine_input: EngineInput,
sampling_params: SamplingParams, sampling_params: SamplingParams,
context: ConversationContext, context: ConversationContext,
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
...@@ -641,13 +641,13 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -641,13 +641,13 @@ class OpenAIServingResponses(OpenAIServing):
self._log_inputs( self._log_inputs(
sub_request_id, sub_request_id,
engine_prompt, engine_input,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_prompt, engine_input,
sampling_params, sampling_params,
sub_request_id, sub_request_id,
lora_request=lora_request, lora_request=lora_request,
...@@ -675,11 +675,11 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -675,11 +675,11 @@ class OpenAIServingResponses(OpenAIServing):
# Render the next prompt token ids and update sampling_params. # Render the next prompt token ids and update sampling_params.
if isinstance(context, (HarmonyContext, StreamingHarmonyContext)): if isinstance(context, (HarmonyContext, StreamingHarmonyContext)):
token_ids = context.render_for_completion() token_ids = context.render_for_completion()
engine_prompt = token_inputs(token_ids) engine_input = tokens_input(token_ids)
sampling_params.max_tokens = max_model_len - len(token_ids) sampling_params.max_tokens = max_model_len - len(token_ids)
elif isinstance(context, ParsableContext): elif isinstance(context, ParsableContext):
(engine_prompt,) = await self._render_next_turn( (engine_input,) = await self._render_next_turn(
context.request, context.request,
context.parser.response_messages, context.parser.response_messages,
context.tool_dicts, context.tool_dicts,
...@@ -691,7 +691,7 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -691,7 +691,7 @@ class OpenAIServingResponses(OpenAIServing):
sampling_params.max_tokens = get_max_tokens( sampling_params.max_tokens = get_max_tokens(
max_model_len, max_model_len,
context.request.max_output_tokens, context.request.max_output_tokens,
self._extract_prompt_len(engine_prompt), self._extract_prompt_len(engine_input),
self.default_sampling_params, # type: ignore self.default_sampling_params, # type: ignore
self.override_max_tokens, # type: ignore self.override_max_tokens, # type: ignore
) )
...@@ -713,14 +713,10 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -713,14 +713,10 @@ class OpenAIServingResponses(OpenAIServing):
arrival_time = time.time() arrival_time = time.time()
messages = self._construct_input_messages_with_harmony(request, prev_response) messages = self._construct_input_messages_with_harmony(request, prev_response)
prompt_token_ids = render_for_completion(messages) prompt_token_ids = render_for_completion(messages)
engine_prompt = token_inputs(prompt_token_ids) engine_input = tokens_input(prompt_token_ids, cache_salt=request.cache_salt)
engine_prompt["arrival_time"] = arrival_time engine_input["arrival_time"] = arrival_time
# Add cache_salt if provided in the request return messages, [engine_input]
if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return messages, [engine_prompt]
async def _initialize_tool_sessions( async def _initialize_tool_sessions(
self, self,
......
...@@ -38,7 +38,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( ...@@ -38,7 +38,7 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
) )
from vllm.entrypoints.utils import get_max_tokens from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs import EncoderDecoderInputs, ProcessorInputs from vllm.inputs import EncoderDecoderInput, EngineInput
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription from vllm.model_executor.models import SupportsTranscription
...@@ -171,7 +171,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -171,7 +171,7 @@ class OpenAISpeechToText(OpenAIServing):
request: SpeechToTextRequest, request: SpeechToTextRequest,
audio_data: bytes, audio_data: bytes,
request_id: str, request_id: str,
) -> tuple[list[ProcessorInputs], float]: ) -> tuple[list[EngineInput], float]:
# Validate request # Validate request
language = self.model_cls.validate_language(request.language) language = self.model_cls.validate_language(request.language)
# Skip to_language validation to avoid extra logging for Whisper. # Skip to_language validation to avoid extra logging for Whisper.
...@@ -250,9 +250,9 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -250,9 +250,9 @@ class OpenAISpeechToText(OpenAIServing):
parsed_prompts.append(parsed_prompt) parsed_prompts.append(parsed_prompt)
engine_prompts = await self.renderer.render_cmpl_async(parsed_prompts) engine_inputs = await self.renderer.render_cmpl_async(parsed_prompts)
return engine_prompts, duration return engine_inputs, duration
def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt): def _preprocess_verbose_prompt(self, prompt: EncoderDecoderDictPrompt):
dec_prompt = prompt["decoder_prompt"] dec_prompt = prompt["decoder_prompt"]
...@@ -271,7 +271,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -271,7 +271,7 @@ class OpenAISpeechToText(OpenAIServing):
return prompt return prompt
@staticmethod @staticmethod
def _get_decoder_prompt_len(engine_prompts: list[ProcessorInputs]) -> int: def _get_decoder_prompt_len(engine_inputs: list[EngineInput]) -> int:
"""Get the length of the decoder prompt. Currently we need to offset """Get the length of the decoder prompt. Currently we need to offset
by the decoder prompt length when running beam search because the mm by the decoder prompt length when running beam search because the mm
encoder is not currently cached and runs on decode calls; because of encoder is not currently cached and runs on decode calls; because of
...@@ -282,12 +282,13 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -282,12 +282,13 @@ class OpenAISpeechToText(OpenAIServing):
encoder/decoder caching is implemented. encoder/decoder caching is implemented.
""" """
input_len = 0 input_len = 0
assert len(engine_prompts) > 0 assert len(engine_inputs) > 0
first_eng_prompt = engine_prompts[0] first_input = engine_inputs[0]
if first_input.get("type") == "enc_dec":
first_input = cast(EncoderDecoderInput, first_input)
input_len = len(first_input["decoder_prompt"]["prompt_token_ids"])
if first_eng_prompt.get("type") == "enc_dec":
first_eng_prompt = cast(EncoderDecoderInputs, first_eng_prompt)
input_len = len(first_eng_prompt["decoder_prompt"]["prompt_token_ids"])
return input_len return input_len
def _get_verbose_segments( def _get_verbose_segments(
...@@ -409,7 +410,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -409,7 +410,7 @@ class OpenAISpeechToText(OpenAIServing):
lora_request = self._maybe_get_adapters(request) lora_request = self._maybe_get_adapters(request)
engine_prompts, duration_s = await self._preprocess_speech_to_text( engine_inputs, duration_s = await self._preprocess_speech_to_text(
request=request, request=request,
audio_data=audio_data, audio_data=audio_data,
request_id=request_id, request_id=request_id,
...@@ -420,7 +421,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -420,7 +421,7 @@ class OpenAISpeechToText(OpenAIServing):
list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None list_result_generator: list[AsyncGenerator[RequestOutput, None]] | None = None
input_len = ( input_len = (
OpenAISpeechToText._get_decoder_prompt_len(engine_prompts) OpenAISpeechToText._get_decoder_prompt_len(engine_inputs)
if request.use_beam_search if request.use_beam_search
else 0 else 0
) )
...@@ -450,12 +451,12 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -450,12 +451,12 @@ class OpenAISpeechToText(OpenAIServing):
sampling_params.logprobs = 1 sampling_params.logprobs = 1
list_result_generator = [] list_result_generator = []
for i, engine_prompt in enumerate(engine_prompts): for i, engine_input in enumerate(engine_inputs):
request_id_item = f"{request_id}_{i}" request_id_item = f"{request_id}_{i}"
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, engine_input,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -468,7 +469,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -468,7 +469,7 @@ class OpenAISpeechToText(OpenAIServing):
if isinstance(sampling_params, BeamSearchParams): if isinstance(sampling_params, BeamSearchParams):
generator = self.beam_search( generator = self.beam_search(
prompt=engine_prompt, prompt=engine_input,
params=sampling_params, params=sampling_params,
request_id=request_id_item, request_id=request_id_item,
lora_request=lora_request, lora_request=lora_request,
...@@ -476,7 +477,7 @@ class OpenAISpeechToText(OpenAIServing): ...@@ -476,7 +477,7 @@ class OpenAISpeechToText(OpenAIServing):
) )
else: else:
generator = self.engine_client.generate( generator = self.engine_client.generate(
engine_prompt, engine_input,
sampling_params, sampling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
......
...@@ -18,7 +18,7 @@ from vllm.entrypoints.pooling.typing import ( ...@@ -18,7 +18,7 @@ from vllm.entrypoints.pooling.typing import (
PoolingCompletionLikeRequest, PoolingCompletionLikeRequest,
PoolingServeContext, PoolingServeContext,
) )
from vllm.inputs.data import ProcessorInputs, SingletonPrompt from vllm.inputs import EngineInput, SingletonPrompt
from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
...@@ -60,7 +60,7 @@ class PoolingIOProcessor: ...@@ -60,7 +60,7 @@ class PoolingIOProcessor:
chat_template_kwargs=request.chat_template_kwargs, chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template, trust_request_chat_template=self.trust_request_chat_template,
) )
_, engine_prompts = self._preprocess_chat_online( _, engine_inputs = self._preprocess_chat_online(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -68,7 +68,7 @@ class PoolingIOProcessor: ...@@ -68,7 +68,7 @@ class PoolingIOProcessor:
default_template_kwargs=None, default_template_kwargs=None,
) )
elif isinstance(request, PoolingCompletionLikeRequest): elif isinstance(request, PoolingCompletionLikeRequest):
engine_prompts = self._preprocess_completion_online( engine_inputs = self._preprocess_completion_online(
request, request,
prompt_input=request.input, prompt_input=request.input,
prompt_embeds=None, prompt_embeds=None,
...@@ -76,7 +76,7 @@ class PoolingIOProcessor: ...@@ -76,7 +76,7 @@ class PoolingIOProcessor:
else: else:
raise ValueError(f"Invalid {self.name} request type") raise ValueError(f"Invalid {self.name} request type")
ctx.engine_prompts = engine_prompts ctx.engine_inputs = engine_inputs
async def pre_process_online_async(self, ctx: PoolingServeContext): async def pre_process_online_async(self, ctx: PoolingServeContext):
self.pre_process_online(ctx) self.pre_process_online(ctx)
...@@ -100,7 +100,7 @@ class PoolingIOProcessor: ...@@ -100,7 +100,7 @@ class PoolingIOProcessor:
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[ProcessorInputs]: ) -> Sequence[EngineInput]:
return self._preprocess_completion_offline( return self._preprocess_completion_offline(
prompts=prompts, tokenization_kwargs=tokenization_kwargs prompts=prompts, tokenization_kwargs=tokenization_kwargs
) )
...@@ -128,7 +128,7 @@ class PoolingIOProcessor: ...@@ -128,7 +128,7 @@ class PoolingIOProcessor:
request: RendererRequest, request: RendererRequest,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None, prompt_embeds: bytes | list[bytes] | None,
) -> list[ProcessorInputs]: ) -> list[EngineInput]:
renderer = self.renderer renderer = self.renderer
model_config = self.model_config model_config = self.model_config
...@@ -167,7 +167,7 @@ class PoolingIOProcessor: ...@@ -167,7 +167,7 @@ class PoolingIOProcessor:
default_template_kwargs: dict[str, Any] | None, default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: type[ToolParser] | None = None, tool_parser: type[ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: ) -> tuple[list[ConversationMessage], list[EngineInput]]:
renderer = self.renderer renderer = self.renderer
default_template_kwargs = merge_kwargs( default_template_kwargs = merge_kwargs(
...@@ -188,7 +188,7 @@ class PoolingIOProcessor: ...@@ -188,7 +188,7 @@ class PoolingIOProcessor:
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
) )
(conversation,), (engine_prompt,) = renderer.render_chat( (conversation,), (engine_input,) = renderer.render_chat(
[messages], [messages],
chat_params, chat_params,
tok_params, tok_params,
...@@ -199,13 +199,13 @@ class PoolingIOProcessor: ...@@ -199,13 +199,13 @@ class PoolingIOProcessor:
}, },
) )
return conversation, [engine_prompt] return conversation, [engine_input]
def _preprocess_completion_offline( def _preprocess_completion_offline(
self, self,
prompts: PromptType | Sequence[PromptType], prompts: PromptType | Sequence[PromptType],
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> Sequence[ProcessorInputs]: ) -> Sequence[EngineInput]:
renderer = self.renderer renderer = self.renderer
model_config = self.model_config model_config = self.model_config
......
...@@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse ...@@ -20,7 +20,7 @@ from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext from vllm.entrypoints.pooling.typing import AnyPoolingRequest, PoolingServeContext
from vllm.exceptions import VLLMNotFoundError from vllm.exceptions import VLLMNotFoundError
from vllm.inputs.data import ProcessorInputs from vllm.inputs import EngineInput
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.renderers.base import BaseRenderer from vllm.renderers.base import BaseRenderer
from vllm.renderers.inputs.preprocess import extract_prompt_components from vllm.renderers.inputs.preprocess import extract_prompt_components
...@@ -106,7 +106,7 @@ class PoolingServing: ...@@ -106,7 +106,7 @@ class PoolingServing:
self, self,
ctx: PoolingServeContext, ctx: PoolingServeContext,
): ):
if ctx.engine_prompts is None: if ctx.engine_inputs is None:
raise ValueError("Engine prompts not available") raise ValueError("Engine prompts not available")
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
...@@ -120,7 +120,7 @@ class PoolingServing: ...@@ -120,7 +120,7 @@ class PoolingServing:
pooling_params = self.io_processor.create_pooling_params(ctx.request) pooling_params = self.io_processor.create_pooling_params(ctx.request)
pooling_params.verify(self.model_config) pooling_params.verify(self.model_config)
for i, engine_prompt in enumerate(ctx.engine_prompts): for i, engine_input in enumerate(ctx.engine_inputs):
prompt_request_id = ( prompt_request_id = (
f"{ctx.request_id}-{i}" f"{ctx.request_id}-{i}"
if ctx.prompt_request_ids is None if ctx.prompt_request_ids is None
...@@ -129,13 +129,13 @@ class PoolingServing: ...@@ -129,13 +129,13 @@ class PoolingServing:
self._log_inputs( self._log_inputs(
prompt_request_id, prompt_request_id,
engine_prompt, engine_input,
params=pooling_params, params=pooling_params,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
) )
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_input,
pooling_params, pooling_params,
prompt_request_id, prompt_request_id,
lora_request=ctx.lora_request, lora_request=ctx.lora_request,
...@@ -151,13 +151,13 @@ class PoolingServing: ...@@ -151,13 +151,13 @@ class PoolingServing:
self, self,
ctx: PoolingServeContext, ctx: PoolingServeContext,
): ):
if ctx.engine_prompts is None: if ctx.engine_inputs is None:
raise ValueError("Engine prompts not available") raise ValueError("Engine prompts not available")
if ctx.result_generator is None: if ctx.result_generator is None:
raise ValueError("Result generator not available") raise ValueError("Result generator not available")
num_inputs = len(ctx.engine_prompts) num_inputs = len(ctx.engine_inputs)
final_res_batch: list[PoolingRequestOutput | None] final_res_batch: list[PoolingRequestOutput | None]
final_res_batch = [None] * num_inputs final_res_batch = [None] * num_inputs
...@@ -317,7 +317,7 @@ class PoolingServing: ...@@ -317,7 +317,7 @@ class PoolingServing:
def _log_inputs( def _log_inputs(
self, self,
request_id: str, request_id: str,
inputs: ProcessorInputs, inputs: EngineInput,
params: PoolingParams, params: PoolingParams,
lora_request: LoRARequest | None, lora_request: LoRARequest | None,
) -> None: ) -> None:
......
...@@ -24,7 +24,7 @@ from vllm.entrypoints.pooling.embed.protocol import ( ...@@ -24,7 +24,7 @@ from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingCompletionRequest, EmbeddingCompletionRequest,
) )
from vllm.entrypoints.pooling.typing import PoolingServeContext from vllm.entrypoints.pooling.typing import PoolingServeContext
from vllm.inputs.data import ProcessorInputs, token_inputs from vllm.inputs import EngineInput, tokens_input
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.renderers import merge_kwargs from vllm.renderers import merge_kwargs
...@@ -83,20 +83,20 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -83,20 +83,20 @@ class EmbedIOProcessor(PoolingIOProcessor):
################################################################# #################################################################
def _pre_process_chunked(self, ctx: PoolingServeContext) -> None: def _pre_process_chunked(self, ctx: PoolingServeContext) -> None:
if ctx.engine_prompts is None: if ctx.engine_inputs is None:
raise ValueError("Engine prompts not available") raise ValueError("Engine prompts not available")
ctx.intermediates = ctx.engine_prompts ctx.intermediates = ctx.engine_inputs
request_id = ctx.request_id request_id = ctx.request_id
max_model_len = self.model_config.max_model_len max_model_len = self.model_config.max_model_len
chunked_engine_prompts: list[ProcessorInputs] = [] chunked_engine_inputs: list[EngineInput] = []
prompt_request_ids: list[str] = [] prompt_request_ids: list[str] = []
for prompt_idx, engine_prompt in enumerate(ctx.engine_prompts): for prompt_idx, engine_input in enumerate(ctx.engine_inputs):
token_ids = engine_prompt.get("prompt_token_ids", None) token_ids = engine_input.get("prompt_token_ids", None)
if token_ids is None: if token_ids is None:
raise NotImplementedError( raise NotImplementedError(
"Long Text Embedding with Chunked Processing does " "Long Text Embedding with Chunked Processing does "
"not support EmbedsPrompt and EncoderDecoderInputs." "not support EmbedsPrompt and EncoderDecoderInput."
) )
prompt_token_ids = cast(list[int], token_ids) prompt_token_ids = cast(list[int], token_ids)
...@@ -104,14 +104,14 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -104,14 +104,14 @@ class EmbedIOProcessor(PoolingIOProcessor):
for chunk_idx, chunk_tokens in enumerate( for chunk_idx, chunk_tokens in enumerate(
chunk_list(prompt_token_ids, max_model_len) chunk_list(prompt_token_ids, max_model_len)
): ):
chunked_engine_prompts.append( chunked_engine_inputs.append(
token_inputs(prompt_token_ids=chunk_tokens) tokens_input(prompt_token_ids=chunk_tokens)
) )
prompt_request_ids.append( prompt_request_ids.append(
f"{request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}" f"{request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"
) )
ctx.engine_prompts = chunked_engine_prompts ctx.engine_inputs = chunked_engine_inputs
ctx.prompt_request_ids = prompt_request_ids ctx.prompt_request_ids = prompt_request_ids
return None return None
...@@ -184,8 +184,8 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -184,8 +184,8 @@ class EmbedIOProcessor(PoolingIOProcessor):
if ctx.intermediates is None: if ctx.intermediates is None:
raise ValueError("Original prompts inputs not available") raise ValueError("Original prompts inputs not available")
original_engine_prompts = cast(list[ProcessorInputs], ctx.intermediates) original_engine_inputs = cast(list[EngineInput], ctx.intermediates)
num_prompts = len(original_engine_prompts) num_prompts = len(original_engine_inputs)
# Finalize aggregated results # Finalize aggregated results
final_res_batch: list[PoolingRequestOutput] = [] final_res_batch: list[PoolingRequestOutput] = []
...@@ -211,12 +211,12 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -211,12 +211,12 @@ class EmbedIOProcessor(PoolingIOProcessor):
pooling_output_data = PoolingOutput(data=final_embedding) pooling_output_data = PoolingOutput(data=final_embedding)
# Get original prompt token IDs for this prompt # Get original prompt token IDs for this prompt
original_prompt = original_engine_prompts[prompt_idx] original_prompt = original_engine_inputs[prompt_idx]
token_ids = original_prompt.get("prompt_token_ids", None) token_ids = original_prompt.get("prompt_token_ids", None)
if token_ids is None: if token_ids is None:
raise NotImplementedError( raise NotImplementedError(
"Long Text Embedding with Chunked Processing does " "Long Text Embedding with Chunked Processing does "
"not support EmbedsPrompt and EncoderDecoderInputs." "not support EmbedsPrompt and EncoderDecoderInput."
) )
original_token_ids = cast(list[int], token_ids) original_token_ids = cast(list[int], token_ids)
...@@ -372,7 +372,7 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -372,7 +372,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
] ]
for uri in request.images for uri in request.images
] ]
ctx.engine_prompts = self._batch_render_chat( ctx.engine_inputs = self._batch_render_chat(
request, all_messages, truncate_prompt_tokens, truncation_side request, all_messages, truncate_prompt_tokens, truncation_side
) )
...@@ -382,7 +382,7 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -382,7 +382,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
self._mixed_input_to_messages(inp, task_prefix=task_prefix) self._mixed_input_to_messages(inp, task_prefix=task_prefix)
for inp in request.inputs for inp in request.inputs
] ]
ctx.engine_prompts = self._batch_render_chat( ctx.engine_inputs = self._batch_render_chat(
request, all_messages, truncate_prompt_tokens, truncation_side request, all_messages, truncate_prompt_tokens, truncation_side
) )
...@@ -396,7 +396,7 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -396,7 +396,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
truncation_side=truncation_side, truncation_side=truncation_side,
) )
ctx.engine_prompts = self._preprocess_completion_online( ctx.engine_inputs = self._preprocess_completion_online(
proxy, prompt_input=proxy.input, prompt_embeds=None proxy, prompt_input=proxy.input, prompt_embeds=None
) )
...@@ -406,7 +406,7 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -406,7 +406,7 @@ class EmbedIOProcessor(PoolingIOProcessor):
all_messages: Sequence[list[ChatCompletionMessageParam]], all_messages: Sequence[list[ChatCompletionMessageParam]],
truncate_prompt_tokens: int | None, truncate_prompt_tokens: int | None,
truncation_side: Literal["left", "right"] | None, truncation_side: Literal["left", "right"] | None,
) -> list[ProcessorInputs]: ) -> list[EngineInput]:
"""Batch-render multiple conversations through the chat template.""" """Batch-render multiple conversations through the chat template."""
if not all_messages: if not all_messages:
return [] return []
...@@ -438,8 +438,8 @@ class EmbedIOProcessor(PoolingIOProcessor): ...@@ -438,8 +438,8 @@ class EmbedIOProcessor(PoolingIOProcessor):
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
) )
_, engine_prompts = renderer.render_chat(all_messages, chat_params, tok_params) _, engine_inputs = renderer.render_chat(all_messages, chat_params, tok_params)
return engine_prompts return engine_inputs
def _validate_input_type(self, input_type: str | None) -> None: def _validate_input_type(self, input_type: str | None) -> None:
"""Raise if *input_type* is not supported by this model.""" """Raise if *input_type* is not supported by this model."""
......
...@@ -33,7 +33,7 @@ from vllm.entrypoints.pooling.utils import ( ...@@ -33,7 +33,7 @@ from vllm.entrypoints.pooling.utils import (
encode_pooling_output_float, encode_pooling_output_float,
) )
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.inputs import ProcessorInputs from vllm.inputs import EngineInput
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers.inputs.preprocess import prompt_to_seq from vllm.renderers.inputs.preprocess import prompt_to_seq
...@@ -110,7 +110,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -110,7 +110,7 @@ class OpenAIServingPooling(OpenAIServing):
request.task, request.task,
) )
engine_prompts: Sequence[ProcessorInputs] engine_inputs: Sequence[EngineInput]
if use_io_processor := isinstance(request, IOProcessorRequest): if use_io_processor := isinstance(request, IOProcessorRequest):
if self.io_processor is None: if self.io_processor is None:
raise ValueError( raise ValueError(
...@@ -125,7 +125,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -125,7 +125,7 @@ class OpenAIServingPooling(OpenAIServing):
raw_prompts = await self.io_processor.pre_process_async( raw_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id prompt=validated_prompt, request_id=request_id
) )
engine_prompts = await self.openai_serving_render.preprocess_cmpl( engine_inputs = await self.openai_serving_render.preprocess_cmpl(
request, request,
prompt_to_seq(raw_prompts), prompt_to_seq(raw_prompts),
) )
...@@ -138,7 +138,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -138,7 +138,7 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
_, engine_prompts = await self.openai_serving_render.preprocess_chat( _, engine_inputs = await self.openai_serving_render.preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -146,7 +146,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -146,7 +146,7 @@ class OpenAIServingPooling(OpenAIServing):
default_template_kwargs=None, default_template_kwargs=None,
) )
elif isinstance(request, PoolingCompletionRequest): elif isinstance(request, PoolingCompletionRequest):
engine_prompts = await self.openai_serving_render.preprocess_completion( engine_inputs = await self.openai_serving_render.preprocess_completion(
request, request,
prompt_input=request.input, prompt_input=request.input,
prompt_embeds=None, prompt_embeds=None,
...@@ -165,12 +165,12 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -165,12 +165,12 @@ class OpenAIServingPooling(OpenAIServing):
else: else:
pooling_params = request.to_pooling_params() # type: ignore pooling_params = request.to_pooling_params() # type: ignore
for i, engine_prompt in enumerate(engine_prompts): for i, engine_input in enumerate(engine_inputs):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, engine_input,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -182,7 +182,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -182,7 +182,7 @@ class OpenAIServingPooling(OpenAIServing):
) )
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, engine_input,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
...@@ -221,7 +221,7 @@ class OpenAIServingPooling(OpenAIServing): ...@@ -221,7 +221,7 @@ class OpenAIServingPooling(OpenAIServing):
return IOProcessorResponse(request_id=request_id, data=output) return IOProcessorResponse(request_id=request_id, data=output)
assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest)) assert isinstance(request, (PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts) num_prompts = len(engine_inputs)
# Non-streaming response # Non-streaming response
final_res_batch: list[PoolingRequestOutput | None] final_res_batch: list[PoolingRequestOutput | None]
......
...@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import ( ...@@ -35,7 +35,7 @@ from vllm.entrypoints.pooling.score.utils import (
parse_score_data_single, parse_score_data_single,
validate_score_input, validate_score_input,
) )
from vllm.inputs.data import ProcessorInputs, TokensPrompt, token_inputs from vllm.inputs import EngineInput, TokensPrompt, tokens_input
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
...@@ -110,12 +110,12 @@ class ServingScores(OpenAIServing): ...@@ -110,12 +110,12 @@ class ServingScores(OpenAIServing):
*(encode_async(t, **tokenization_kwargs) for t in input_texts) *(encode_async(t, **tokenization_kwargs) for t in input_texts)
) )
engine_prompts: list[ProcessorInputs] = [] engine_inputs: list[EngineInput] = []
for tok_result, input_text in zip(tokenized_prompts, input_texts): for tok_result, input_text in zip(tokenized_prompts, input_texts):
text_token_prompt = self._validate_input(request, tok_result, input_text) text_token_prompt = self._validate_input(request, tok_result, input_text)
engine_prompts.append( engine_inputs.append(
token_inputs( tokens_input(
text_token_prompt["prompt_token_ids"], text_token_prompt["prompt_token_ids"],
prompt=input_text, prompt=input_text,
) )
...@@ -125,19 +125,19 @@ class ServingScores(OpenAIServing): ...@@ -125,19 +125,19 @@ class ServingScores(OpenAIServing):
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
pooling_params = request.to_pooling_params("embed") pooling_params = request.to_pooling_params("embed")
for i, engine_prompt in enumerate(engine_prompts): for i, engine_input in enumerate(engine_inputs):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, engine_input,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
) )
generators.append( generators.append(
self.engine_client.encode( self.engine_client.encode(
engine_prompt, engine_input,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
...@@ -151,7 +151,7 @@ class ServingScores(OpenAIServing): ...@@ -151,7 +151,7 @@ class ServingScores(OpenAIServing):
# Non-streaming response # Non-streaming response
final_res_batch: list[PoolingRequestOutput] = [] final_res_batch: list[PoolingRequestOutput] = []
embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_prompts) embeddings: list[PoolingRequestOutput | None] = [None] * len(engine_inputs)
async for i, res in result_generator: async for i, res in result_generator:
embeddings[i] = res embeddings[i] = res
...@@ -183,7 +183,7 @@ class ServingScores(OpenAIServing): ...@@ -183,7 +183,7 @@ class ServingScores(OpenAIServing):
request: RerankRequest | ScoreRequest, request: RerankRequest | ScoreRequest,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
tokenization_kwargs: dict[str, Any], tokenization_kwargs: dict[str, Any],
) -> tuple[str, TokensPrompt]: ) -> TokensPrompt:
"""Parse a single ScoreData into a text + optional multimodal """Parse a single ScoreData into a text + optional multimodal
TokensPrompt for late-interaction encoding. TokensPrompt for late-interaction encoding.
...@@ -197,21 +197,22 @@ class ServingScores(OpenAIServing): ...@@ -197,21 +197,22 @@ class ServingScores(OpenAIServing):
else: else:
text, mm_data, mm_uuids = parse_score_data_single(data, role, model_config) text, mm_data, mm_uuids = parse_score_data_single(data, role, model_config)
prompt_inputs = tokenizer(text, **tokenization_kwargs) prompt_ids = tokenizer.encode(text, **tokenization_kwargs)
self._validate_input(request, prompt_inputs["input_ids"], text) self._validate_input(request, prompt_ids, text)
engine_prompt = TokensPrompt( tok_prompt = TokensPrompt(
prompt_token_ids=prompt_inputs["input_ids"], prompt_token_ids=prompt_ids,
prompt=text,
) )
if mm_data is not None: if mm_data is not None:
engine_prompt["multi_modal_data"] = mm_data tok_prompt["multi_modal_data"] = mm_data
if mm_uuids is not None: if mm_uuids is not None:
engine_prompt["multi_modal_uuids"] = mm_uuids tok_prompt["multi_modal_uuids"] = mm_uuids
if request.mm_processor_kwargs is not None: if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs tok_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs
return text, engine_prompt return tok_prompt
async def _late_interaction_score( async def _late_interaction_score(
self, self,
...@@ -240,7 +241,7 @@ class ServingScores(OpenAIServing): ...@@ -240,7 +241,7 @@ class ServingScores(OpenAIServing):
executor=self._tokenizer_executor, executor=self._tokenizer_executor,
) )
preprocessed = await asyncio.gather( tok_prompts = await asyncio.gather(
*( *(
preprocess_async( preprocess_async(
data=d, data=d,
...@@ -253,12 +254,8 @@ class ServingScores(OpenAIServing): ...@@ -253,12 +254,8 @@ class ServingScores(OpenAIServing):
) )
) )
query_prompts: list[TokensPrompt] = [ query_prompts = tok_prompts[: len(data_1)]
prompt for _, prompt in preprocessed[: len(data_1)] doc_prompts = tok_prompts[len(data_1) :]
]
doc_prompts: list[TokensPrompt] = [
prompt for _, prompt in preprocessed[len(data_1) :]
]
default_pooling_params = request.to_pooling_params("token_embed") default_pooling_params = request.to_pooling_params("token_embed")
...@@ -268,7 +265,7 @@ class ServingScores(OpenAIServing): ...@@ -268,7 +265,7 @@ class ServingScores(OpenAIServing):
query_prompts query_prompts
) )
query_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] query_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
for i, engine_prompt in enumerate(query_prompts): for i, tok_prompt in enumerate(query_prompts):
request_id_item = f"{request_id}-query-{i}" request_id_item = f"{request_id}-query-{i}"
pooling_params = default_pooling_params.clone() pooling_params = default_pooling_params.clone()
pooling_params.late_interaction_params = ( pooling_params.late_interaction_params = (
...@@ -280,14 +277,14 @@ class ServingScores(OpenAIServing): ...@@ -280,14 +277,14 @@ class ServingScores(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, tok_prompt,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
) )
query_generators.append( query_generators.append(
self.engine_client.encode( self.engine_client.encode(
engine_prompt, tok_prompt,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
...@@ -306,7 +303,7 @@ class ServingScores(OpenAIServing): ...@@ -306,7 +303,7 @@ class ServingScores(OpenAIServing):
# stage 2: encode docs and return scalar scores from workers. # stage 2: encode docs and return scalar scores from workers.
doc_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] doc_generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
for i, engine_prompt in enumerate(doc_prompts): for i, tok_prompt in enumerate(doc_prompts):
request_id_item = f"{request_id}-doc-{i}" request_id_item = f"{request_id}-doc-{i}"
query_idx = 0 if len(query_prompts) == 1 else i query_idx = 0 if len(query_prompts) == 1 else i
pooling_params = default_pooling_params.clone() pooling_params = default_pooling_params.clone()
...@@ -316,14 +313,14 @@ class ServingScores(OpenAIServing): ...@@ -316,14 +313,14 @@ class ServingScores(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
engine_prompt, tok_prompt,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
) )
doc_generators.append( doc_generators.append(
self.engine_client.encode( self.engine_client.encode(
engine_prompt, tok_prompt,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
...@@ -404,28 +401,22 @@ class ServingScores(OpenAIServing): ...@@ -404,28 +401,22 @@ class ServingScores(OpenAIServing):
) )
) )
request_prompts: list[str] = []
engine_prompts: list[TokensPrompt] = []
for full_prompt, engine_prompt in preprocessed_prompts:
request_prompts.append(full_prompt)
engine_prompts.append(engine_prompt)
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
default_pooling_params = request.to_pooling_params("classify") default_pooling_params = request.to_pooling_params("classify")
for i, engine_prompt in enumerate(engine_prompts): for i, (full_prompt, tok_prompt) in enumerate(preprocessed_prompts):
request_id_item = f"{request_id}-{i}" request_id_item = f"{request_id}-{i}"
self._log_inputs( self._log_inputs(
request_id_item, request_id_item,
request_prompts[i], full_prompt,
params=default_pooling_params, params=default_pooling_params,
lora_request=lora_request, lora_request=lora_request,
) )
if token_type_ids := engine_prompt.pop("token_type_ids", None): if token_type_ids := tok_prompt.pop("token_type_ids", None):
pooling_params = default_pooling_params.clone() pooling_params = default_pooling_params.clone()
compressed = compress_token_type_ids(token_type_ids) compressed = compress_token_type_ids(token_type_ids)
pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed} pooling_params.extra_kwargs = {"compressed_token_type_ids": compressed}
...@@ -433,7 +424,7 @@ class ServingScores(OpenAIServing): ...@@ -433,7 +424,7 @@ class ServingScores(OpenAIServing):
pooling_params = default_pooling_params pooling_params = default_pooling_params
generator = self.engine_client.encode( generator = self.engine_client.encode(
engine_prompt, tok_prompt,
pooling_params, pooling_params,
request_id_item, request_id_item,
lora_request=lora_request, lora_request=lora_request,
...@@ -447,7 +438,7 @@ class ServingScores(OpenAIServing): ...@@ -447,7 +438,7 @@ class ServingScores(OpenAIServing):
# Non-streaming response # Non-streaming response
final_res_batch: list[PoolingRequestOutput | None] = [None] * len( final_res_batch: list[PoolingRequestOutput | None] = [None] * len(
engine_prompts preprocessed_prompts
) )
async for i, res in result_generator: async for i, res in result_generator:
...@@ -464,7 +455,7 @@ class ServingScores(OpenAIServing): ...@@ -464,7 +455,7 @@ class ServingScores(OpenAIServing):
data_2: ScoreData, data_2: ScoreData,
) -> tuple[str, TokensPrompt]: ) -> tuple[str, TokensPrompt]:
model_config = self.model_config model_config = self.model_config
full_prompt, engine_prompt = get_score_prompt( full_prompt, engine_input = get_score_prompt(
model_config=model_config, model_config=model_config,
data_1=data_1, data_1=data_1,
data_2=data_2, data_2=data_2,
...@@ -472,11 +463,11 @@ class ServingScores(OpenAIServing): ...@@ -472,11 +463,11 @@ class ServingScores(OpenAIServing):
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
score_template=self.score_template, score_template=self.score_template,
) )
self._validate_input(request, engine_prompt["prompt_token_ids"], full_prompt) self._validate_input(request, engine_input["prompt_token_ids"], full_prompt)
if request.mm_processor_kwargs is not None: if request.mm_processor_kwargs is not None:
engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs engine_input["mm_processor_kwargs"] = request.mm_processor_kwargs
return full_prompt, engine_prompt return full_prompt, engine_input
async def _run_scoring( async def _run_scoring(
self, self,
......
...@@ -20,10 +20,14 @@ from vllm.entrypoints.chat_utils import ( ...@@ -20,10 +20,14 @@ from vllm.entrypoints.chat_utils import (
MultiModalItemTracker, MultiModalItemTracker,
_parse_chat_message_content_parts, _parse_chat_message_content_parts,
) )
from vllm.inputs import TokensPrompt from vllm.inputs import (
from vllm.inputs.data import PromptType, TextPrompt MultiModalDataDict,
MultiModalUUIDDict,
PromptType,
TextPrompt,
TokensPrompt,
)
from vllm.model_executor.models.interfaces import supports_score_template from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput from vllm.outputs import PoolingRequestOutput
from vllm.renderers.hf import safe_apply_chat_template from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
......
...@@ -32,7 +32,7 @@ from vllm.entrypoints.pooling.score.protocol import ( ...@@ -32,7 +32,7 @@ from vllm.entrypoints.pooling.score.protocol import (
ScoreRequest, ScoreRequest,
ScoreResponse, ScoreResponse,
) )
from vllm.inputs import ProcessorInputs from vllm.inputs import EngineInput
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
PoolingCompletionLikeRequest: TypeAlias = ( PoolingCompletionLikeRequest: TypeAlias = (
...@@ -74,7 +74,7 @@ class PoolingServeContext(Generic[PoolingRequestT]): ...@@ -74,7 +74,7 @@ class PoolingServeContext(Generic[PoolingRequestT]):
created_time: int = field(default_factory=lambda: int(time.time())) created_time: int = field(default_factory=lambda: int(time.time()))
lora_request: LoRARequest | None = None lora_request: LoRARequest | None = None
engine_prompts: list[ProcessorInputs] | None = None engine_inputs: list[EngineInput] | None = None
prompt_request_ids: list[str] | None = None prompt_request_ids: list[str] | None = None
intermediates: Any | None = None intermediates: Any | None = None
......
...@@ -33,19 +33,20 @@ class MultiModalFeatures(BaseModel): ...@@ -33,19 +33,20 @@ class MultiModalFeatures(BaseModel):
"""Lightweight multimodal metadata produced by the render step. """Lightweight multimodal metadata produced by the render step.
Carries hashes (for cache lookup / identification) and placeholder Carries hashes (for cache lookup / identification) and placeholder
positions so the downstream ``/generate`` service knows *where* in positions so the downstream `/generate` service knows *where* in
the token sequence each multimodal item lives. the token sequence each multimodal item lives.
.. note:: Phase 1 — metadata only. Note:
Phase 2 should add ``mm_kwargs`` (processed tensor data) using a Phase 1 — metadata only.
binary transport so the ``/generate`` side can skip re-processing. Phase 2 should add `mm_kwargs` (processed tensor data) using a
The ``/generate`` endpoint must also be updated to inject these binary transport so the ``/generate` side can skip re-processing.
features into ``ProcessorInputs`` before passing to The `/generate` endpoint must also be updated to inject these
``InputProcessor.process_inputs``. features into `EngineInput` before passing to
`InputProcessor.process_inputs`.
""" """
mm_hashes: dict[str, list[str]] mm_hashes: dict[str, list[str]]
"""Per-modality item hashes, e.g. ``{"image": ["abc", "def"]}``.""" """Per-modality item hashes, e.g. `{"image": ["abc", "def"]}`."""
mm_placeholders: dict[str, list[PlaceholderRangeInfo]] mm_placeholders: dict[str, list[PlaceholderRangeInfo]]
"""Per-modality placeholder ranges in the token sequence.""" """Per-modality placeholder ranges in the token sequence."""
......
...@@ -99,13 +99,11 @@ class ServingTokens(OpenAIServing): ...@@ -99,13 +99,11 @@ class ServingTokens(OpenAIServing):
if raw_request: if raw_request:
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
engine_prompts = await self.openai_serving_render.preprocess_completion( (engine_input,) = await self.openai_serving_render.preprocess_completion(
request, request,
prompt_input=request.token_ids, prompt_input=request.token_ids,
prompt_embeds=None, prompt_embeds=None,
) )
assert len(engine_prompts) == 1
engine_prompt = engine_prompts[0]
# Schedule the request and get the result generator. # Schedule the request and get the result generator.
result_generator: AsyncGenerator[RequestOutput, None] | None = None result_generator: AsyncGenerator[RequestOutput, None] | None = None
...@@ -115,7 +113,7 @@ class ServingTokens(OpenAIServing): ...@@ -115,7 +113,7 @@ class ServingTokens(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id, request_id,
engine_prompt, engine_input,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
) )
...@@ -127,7 +125,7 @@ class ServingTokens(OpenAIServing): ...@@ -127,7 +125,7 @@ class ServingTokens(OpenAIServing):
) )
result_generator = self.engine_client.generate( result_generator = self.engine_client.generate(
engine_prompt, engine_input,
sampling_params, sampling_params,
request_id, request_id,
lora_request=lora_request, lora_request=lora_request,
......
...@@ -34,9 +34,15 @@ from vllm.entrypoints.utils import ( ...@@ -34,9 +34,15 @@ from vllm.entrypoints.utils import (
create_error_response, create_error_response,
get_max_tokens, get_max_tokens,
) )
from vllm.inputs.data import ProcessorInputs, PromptType, SingletonPrompt, TokensPrompt from vllm.inputs import (
EngineInput,
MultiModalHashes,
MultiModalPlaceholders,
PromptType,
SingletonPrompt,
tokens_input,
)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal.inputs import MultiModalHashes, MultiModalPlaceholderDict
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import ( from vllm.renderers.inputs.preprocess import (
...@@ -127,22 +133,22 @@ class OpenAIServingRender: ...@@ -127,22 +133,22 @@ class OpenAIServingRender:
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
_, engine_prompts = result _, engine_inputs = result
if len(engine_prompts) != 1: if len(engine_inputs) != 1:
return self.create_error_response( return self.create_error_response(
f"Expected exactly 1 engine prompt, got {len(engine_prompts)}" f"Expected exactly 1 engine prompt, got {len(engine_inputs)}"
) )
engine_prompt = engine_prompts[0] engine_input = engine_inputs[0]
prompt_components = extract_prompt_components(self.model_config, engine_prompt) prompt_components = extract_prompt_components(self.model_config, engine_input)
token_ids = prompt_components.token_ids token_ids = prompt_components.token_ids
if not token_ids: if not token_ids:
return self.create_error_response("No token_ids rendered") return self.create_error_response("No token_ids rendered")
token_ids = list(token_ids) token_ids = list(token_ids)
input_length = extract_prompt_len(self.model_config, engine_prompt) input_length = extract_prompt_len(self.model_config, engine_input)
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
self.model_config.max_model_len, self.model_config.max_model_len,
request.max_completion_tokens request.max_completion_tokens
...@@ -159,7 +165,7 @@ class OpenAIServingRender: ...@@ -159,7 +165,7 @@ class OpenAIServingRender:
return GenerateRequest( return GenerateRequest(
request_id=request_id, request_id=request_id,
token_ids=token_ids, token_ids=token_ids,
features=self._extract_mm_features(engine_prompt), features=self._extract_mm_features(engine_input),
sampling_params=params, sampling_params=params,
model=request.model, model=request.model,
stream=bool(request.stream), stream=bool(request.stream),
...@@ -171,7 +177,7 @@ class OpenAIServingRender: ...@@ -171,7 +177,7 @@ class OpenAIServingRender:
async def render_chat( async def render_chat(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse: ) -> tuple[list[ConversationMessage], list[EngineInput]] | ErrorResponse:
"""Core preprocessing logic for chat requests (no model/engine check). """Core preprocessing logic for chat requests (no model/engine check).
Called directly by render_chat_request and delegated to by Called directly by render_chat_request and delegated to by
...@@ -184,7 +190,6 @@ class OpenAIServingRender: ...@@ -184,7 +190,6 @@ class OpenAIServingRender:
if is_mistral_tokenizer(tokenizer): if is_mistral_tokenizer(tokenizer):
# because of issues with pydantic we need to potentially # because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request # re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
_mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type] _mt.maybe_serialize_tool_calls(request) # type: ignore[arg-type]
_mt.truncate_tool_call_ids(request) # type: ignore[arg-type] _mt.truncate_tool_call_ids(request) # type: ignore[arg-type]
_mt.validate_request_params(request) _mt.validate_request_params(request)
...@@ -232,7 +237,7 @@ class OpenAIServingRender: ...@@ -232,7 +237,7 @@ class OpenAIServingRender:
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
conversation, engine_prompts = await self.preprocess_chat( conversation, engine_inputs = await self.preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -244,11 +249,11 @@ class OpenAIServingRender: ...@@ -244,11 +249,11 @@ class OpenAIServingRender:
else: else:
# For GPT-OSS. # For GPT-OSS.
should_include_tools = tool_dicts is not None should_include_tools = tool_dicts is not None
conversation, engine_prompts = self._make_request_with_harmony( conversation, engine_inputs = self._make_request_with_harmony(
request, should_include_tools request, should_include_tools
) )
return conversation, engine_prompts return conversation, engine_inputs
async def render_completion_request( async def render_completion_request(
self, self,
...@@ -266,16 +271,16 @@ class OpenAIServingRender: ...@@ -266,16 +271,16 @@ class OpenAIServingRender:
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return result return result
generate_requests: list[GenerateRequest] = [] generate_requests: list[GenerateRequest] = []
for engine_prompt in result: for engine_input in result:
prompt_components = extract_prompt_components( prompt_components = extract_prompt_components(
self.model_config, engine_prompt self.model_config, engine_input
) )
token_ids = prompt_components.token_ids token_ids = prompt_components.token_ids
if not token_ids: if not token_ids:
return self.create_error_response("No token_ids rendered") return self.create_error_response("No token_ids rendered")
token_ids = list(token_ids) token_ids = list(token_ids)
input_length = extract_prompt_len(self.model_config, engine_prompt) input_length = extract_prompt_len(self.model_config, engine_input)
max_tokens = get_max_tokens( max_tokens = get_max_tokens(
self.model_config.max_model_len, self.model_config.max_model_len,
request.max_tokens, request.max_tokens,
...@@ -293,7 +298,7 @@ class OpenAIServingRender: ...@@ -293,7 +298,7 @@ class OpenAIServingRender:
GenerateRequest( GenerateRequest(
request_id=request_id, request_id=request_id,
token_ids=token_ids, token_ids=token_ids,
features=self._extract_mm_features(engine_prompt), features=self._extract_mm_features(engine_input),
sampling_params=params, sampling_params=params,
model=request.model, model=request.model,
stream=bool(request.stream), stream=bool(request.stream),
...@@ -308,7 +313,7 @@ class OpenAIServingRender: ...@@ -308,7 +313,7 @@ class OpenAIServingRender:
async def render_completion( async def render_completion(
self, self,
request: CompletionRequest, request: CompletionRequest,
) -> list[ProcessorInputs] | ErrorResponse: ) -> list[EngineInput] | ErrorResponse:
"""Core preprocessing logic for completion requests (no model/engine check). """Core preprocessing logic for completion requests (no model/engine check).
Called directly by render_completion_request and delegated to by Called directly by render_completion_request and delegated to by
...@@ -326,28 +331,28 @@ class OpenAIServingRender: ...@@ -326,28 +331,28 @@ class OpenAIServingRender:
"prompt_logprobs is not compatible with prompt embeds." "prompt_logprobs is not compatible with prompt embeds."
) )
engine_prompts = await self.preprocess_completion( engine_inputs = await self.preprocess_completion(
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds, prompt_embeds=request.prompt_embeds,
) )
return engine_prompts return engine_inputs
@staticmethod @staticmethod
def _extract_mm_features( def _extract_mm_features(
engine_prompt: ProcessorInputs, engine_input: EngineInput,
) -> MultiModalFeatures | None: ) -> MultiModalFeatures | None:
"""Extract multimodal metadata from a rendered engine prompt. """Extract multimodal metadata from a rendered engine prompt.
Returns ``None`` for text-only prompts. Returns ``None`` for text-only prompts.
""" """
if engine_prompt.get("type") != "multimodal": if engine_input.get("type") != "multimodal":
return None return None
# At this point engine_prompt is a MultiModalInputs TypedDict. # At this point engine_input is a MultiModalInputs TypedDict.
mm_hashes: MultiModalHashes = engine_prompt["mm_hashes"] # type: ignore[typeddict-item] mm_hashes: MultiModalHashes = engine_input["mm_hashes"] # type: ignore[typeddict-item]
raw_placeholders: MultiModalPlaceholderDict = engine_prompt["mm_placeholders"] # type: ignore[typeddict-item] raw_placeholders: MultiModalPlaceholders = engine_input["mm_placeholders"] # type: ignore[typeddict-item]
mm_placeholders = { mm_placeholders = {
modality: [ modality: [
...@@ -401,13 +406,9 @@ class OpenAIServingRender: ...@@ -401,13 +406,9 @@ class OpenAIServingRender:
# Render prompt token ids. # Render prompt token ids.
prompt_token_ids = render_for_completion(messages) prompt_token_ids = render_for_completion(messages)
engine_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) engine_input = tokens_input(prompt_token_ids, cache_salt=request.cache_salt)
# Add cache_salt if provided in the request
if request.cache_salt is not None:
engine_prompt["cache_salt"] = request.cache_salt
return messages, [engine_prompt] return messages, [engine_input]
def create_error_response( def create_error_response(
self, self,
...@@ -450,7 +451,7 @@ class OpenAIServingRender: ...@@ -450,7 +451,7 @@ class OpenAIServingRender:
request: Any, request: Any,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
prompt_embeds: bytes | list[bytes] | None, prompt_embeds: bytes | list[bytes] | None,
) -> list[ProcessorInputs]: ) -> list[EngineInput]:
"""Copied from OpenAIServing._preprocess_completion.""" """Copied from OpenAIServing._preprocess_completion."""
prompts = list[SingletonPrompt | bytes]() prompts = list[SingletonPrompt | bytes]()
if prompt_embeds is not None: # embeds take higher priority if prompt_embeds is not None: # embeds take higher priority
...@@ -463,7 +464,7 @@ class OpenAIServingRender: ...@@ -463,7 +464,7 @@ class OpenAIServingRender:
self, self,
request: Any, request: Any,
prompts: Sequence[PromptType | bytes], prompts: Sequence[PromptType | bytes],
) -> list[ProcessorInputs]: ) -> list[EngineInput]:
"""Copied from OpenAIServing._preprocess_cmpl.""" """Copied from OpenAIServing._preprocess_cmpl."""
renderer = self.renderer renderer = self.renderer
model_config = self.model_config model_config = self.model_config
...@@ -497,7 +498,7 @@ class OpenAIServingRender: ...@@ -497,7 +498,7 @@ class OpenAIServingRender:
default_template_kwargs: dict[str, Any] | None, default_template_kwargs: dict[str, Any] | None,
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
tool_parser: type[ToolParser] | None = None, tool_parser: type[ToolParser] | None = None,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]]: ) -> tuple[list[ConversationMessage], list[EngineInput]]:
"""Copied from OpenAIServing._preprocess_chat.""" """Copied from OpenAIServing._preprocess_chat."""
renderer = self.renderer renderer = self.renderer
mm_config = self.model_config.multimodal_config mm_config = self.model_config.multimodal_config
...@@ -519,7 +520,7 @@ class OpenAIServingRender: ...@@ -519,7 +520,7 @@ class OpenAIServingRender:
default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
) )
(conversation,), (engine_prompt,) = await renderer.render_chat_async( (conversation,), (engine_input,) = await renderer.render_chat_async(
[messages], [messages],
chat_params, chat_params,
tok_params, tok_params,
...@@ -546,4 +547,4 @@ class OpenAIServingRender: ...@@ -546,4 +547,4 @@ class OpenAIServingRender:
tokenizer = renderer.get_tokenizer() tokenizer = renderer.get_tokenizer()
request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type] request = tool_parser(tokenizer).adjust_request(request=request) # type: ignore[arg-type]
return conversation, [engine_prompt] return conversation, [engine_input]
...@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import ( ...@@ -20,7 +20,7 @@ from vllm.entrypoints.serve.tokenize.protocol import (
TokenizeResponse, TokenizeResponse,
TokenizerInfoResponse, TokenizerInfoResponse,
) )
from vllm.inputs import TokensPrompt, token_inputs from vllm.inputs import TokensPrompt, tokens_input
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
...@@ -79,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -79,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
_, engine_prompts = await self.openai_serving_render.preprocess_chat( _, engine_inputs = await self.openai_serving_render.preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -88,22 +88,22 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -88,22 +88,22 @@ class OpenAIServingTokenization(OpenAIServing):
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
) )
else: else:
engine_prompts = await self.openai_serving_render.preprocess_completion( engine_inputs = await self.openai_serving_render.preprocess_completion(
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=None, prompt_embeds=None,
) )
input_ids: list[int] = [] input_ids: list[int] = []
for engine_prompt in engine_prompts: for engine_input in engine_inputs:
self._log_inputs( self._log_inputs(
request_id, request_id,
engine_prompt, engine_input,
params=None, params=None,
lora_request=lora_request, lora_request=lora_request,
) )
prompt_components = self._extract_prompt_components(engine_prompt) prompt_components = self._extract_prompt_components(engine_input)
if prompt_components.token_ids is not None: if prompt_components.token_ids is not None:
input_ids.extend(prompt_components.token_ids) input_ids.extend(prompt_components.token_ids)
...@@ -134,16 +134,16 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -134,16 +134,16 @@ class OpenAIServingTokenization(OpenAIServing):
self._log_inputs( self._log_inputs(
request_id, request_id,
token_inputs(request.tokens), tokens_input(request.tokens),
params=None, params=None,
lora_request=lora_request, lora_request=lora_request,
) )
engine_prompt = await self.renderer.tokenize_prompt_async( tok_prompt = await self.renderer.tokenize_prompt_async(
TokensPrompt(prompt_token_ids=request.tokens), TokensPrompt(prompt_token_ids=request.tokens),
request.build_tok_params(self.model_config), request.build_tok_params(self.model_config),
) )
prompt_text = engine_prompt["prompt"] # type: ignore[typeddict-item] prompt_text = tok_prompt["prompt"] # type: ignore[typeddict-item]
return DetokenizeResponse(prompt=prompt_text) return DetokenizeResponse(prompt=prompt_text)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment