Unverified commit 00f8e0d2, authored by Sage and committed by GitHub
Browse files

[Frontend] Delegate tokenization serving preprocessing to OpenAIServingRender (#37266)


Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
parent 4af9ed21
...@@ -111,7 +111,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat: ...@@ -111,7 +111,7 @@ def _build_serving_chat(engine: AsyncLLM) -> OpenAIServingChat:
[{"prompt_token_ids": [1, 2, 3]}], [{"prompt_token_ids": [1, 2, 3]}],
) )
serving_chat.openai_serving_render._preprocess_chat = AsyncMock( serving_chat.openai_serving_render.preprocess_chat = AsyncMock(
side_effect=_fake_preprocess_chat side_effect=_fake_preprocess_chat
) )
return serving_chat return serving_chat
......
...@@ -46,6 +46,7 @@ from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap ...@@ -46,6 +46,7 @@ from vllm.entrypoints.sagemaker.api_router import sagemaker_standards_bootstrap
from vllm.entrypoints.serve.elastic_ep.middleware import ( from vllm.entrypoints.serve.elastic_ep.middleware import (
ScalingMiddleware, ScalingMiddleware,
) )
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import ( from vllm.entrypoints.utils import (
cli_env_setup, cli_env_setup,
...@@ -365,9 +366,27 @@ async def init_app_state( ...@@ -365,9 +366,27 @@ async def init_app_state(
lora_modules=lora_modules, lora_modules=lora_modules,
) )
await state.openai_serving_models.init_static_loras() await state.openai_serving_models.init_static_loras()
state.openai_serving_render = OpenAIServingRender(
model_config=engine_client.model_config,
renderer=engine_client.renderer,
io_processor=engine_client.io_processor,
model_registry=state.openai_serving_models.registry,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
default_chat_template_kwargs=args.default_chat_template_kwargs,
log_error_stack=args.log_error_stack,
)
state.openai_serving_tokenization = OpenAIServingTokenization( state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client, engine_client,
state.openai_serving_models, state.openai_serving_models,
state.openai_serving_render,
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
......
...@@ -74,26 +74,7 @@ async def init_generate_state( ...@@ -74,26 +74,7 @@ async def init_generate_state(
# Render endpoints are always backed by OpenAIServingRender so that # Render endpoints are always backed by OpenAIServingRender so that
# /v1/chat/completions/render and /v1/completions/render work on both # /v1/chat/completions/render and /v1/completions/render work on both
# generate-mode and render-only servers. # generate-mode and render-only servers. Created in init_app_state.
# It is created first so that OpenAIServingChat and OpenAIServingCompletion
# can delegate their preprocessing logic to it.
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
state.openai_serving_render = OpenAIServingRender(
model_config=engine_client.model_config,
renderer=engine_client.renderer,
io_processor=engine_client.io_processor,
model_registry=state.openai_serving_models.registry,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
tool_parser=args.tool_call_parser,
default_chat_template_kwargs=args.default_chat_template_kwargs,
log_error_stack=args.log_error_stack,
)
state.openai_serving_responses = ( state.openai_serving_responses = (
OpenAIServingResponses( OpenAIServingResponses(
......
...@@ -226,7 +226,7 @@ class OpenAIServingRender: ...@@ -226,7 +226,7 @@ class OpenAIServingRender:
if not self.use_harmony: if not self.use_harmony:
# Common case. # Common case.
error_check_ret = self._validate_chat_template( error_check_ret = self.validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs, chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template, trust_request_chat_template=self.trust_request_chat_template,
...@@ -234,7 +234,7 @@ class OpenAIServingRender: ...@@ -234,7 +234,7 @@ class OpenAIServingRender:
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
conversation, engine_prompts = await self._preprocess_chat( conversation, engine_prompts = await self.preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -328,7 +328,7 @@ class OpenAIServingRender: ...@@ -328,7 +328,7 @@ class OpenAIServingRender:
"prompt_logprobs is not compatible with prompt embeds." "prompt_logprobs is not compatible with prompt embeds."
) )
engine_prompts = await self._preprocess_completion( engine_prompts = await self.preprocess_completion(
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=request.prompt_embeds, prompt_embeds=request.prompt_embeds,
...@@ -426,7 +426,7 @@ class OpenAIServingRender: ...@@ -426,7 +426,7 @@ class OpenAIServingRender:
) -> ErrorResponse | None: ) -> ErrorResponse | None:
return await self.model_registry.check_model(request.model) return await self.model_registry.check_model(request.model)
def _validate_chat_template( def validate_chat_template(
self, self,
request_chat_template: str | None, request_chat_template: str | None,
chat_template_kwargs: dict[str, Any] | None, chat_template_kwargs: dict[str, Any] | None,
...@@ -447,7 +447,7 @@ class OpenAIServingRender: ...@@ -447,7 +447,7 @@ class OpenAIServingRender:
) )
return None return None
async def _preprocess_completion( async def preprocess_completion(
self, self,
request: Any, request: Any,
prompt_input: str | list[str] | list[int] | list[list[int]] | None, prompt_input: str | list[str] | list[int] | list[list[int]] | None,
...@@ -490,7 +490,7 @@ class OpenAIServingRender: ...@@ -490,7 +490,7 @@ class OpenAIServingRender:
}, },
) )
async def _preprocess_chat( async def preprocess_chat(
self, self,
request: Any, request: Any,
messages: list[Any], messages: list[Any],
......
...@@ -11,6 +11,7 @@ from vllm.entrypoints.logger import RequestLogger ...@@ -11,6 +11,7 @@ from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.models.serving import OpenAIServingModels from vllm.entrypoints.openai.models.serving import OpenAIServingModels
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.entrypoints.serve.tokenize.protocol import ( from vllm.entrypoints.serve.tokenize.protocol import (
DetokenizeRequest, DetokenizeRequest,
DetokenizeResponse, DetokenizeResponse,
...@@ -31,6 +32,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -31,6 +32,7 @@ class OpenAIServingTokenization(OpenAIServing):
self, self,
engine_client: EngineClient, engine_client: EngineClient,
models: OpenAIServingModels, models: OpenAIServingModels,
openai_serving_render: OpenAIServingRender,
*, *,
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
...@@ -44,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -44,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger=request_logger, request_logger=request_logger,
) )
self.openai_serving_render = openai_serving_render
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.default_chat_template_kwargs = default_chat_template_kwargs or {} self.default_chat_template_kwargs = default_chat_template_kwargs or {}
...@@ -68,7 +71,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -68,7 +71,7 @@ class OpenAIServingTokenization(OpenAIServing):
if request.tools is None if request.tools is None
else [tool.model_dump() for tool in request.tools] else [tool.model_dump() for tool in request.tools]
) )
error_check_ret = self._validate_chat_template( error_check_ret = self.openai_serving_render.validate_chat_template(
request_chat_template=request.chat_template, request_chat_template=request.chat_template,
chat_template_kwargs=request.chat_template_kwargs, chat_template_kwargs=request.chat_template_kwargs,
trust_request_chat_template=self.trust_request_chat_template, trust_request_chat_template=self.trust_request_chat_template,
...@@ -76,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -76,7 +79,7 @@ class OpenAIServingTokenization(OpenAIServing):
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
_, engine_prompts = await self._preprocess_chat( _, engine_prompts = await self.openai_serving_render.preprocess_chat(
request, request,
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
...@@ -85,7 +88,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -85,7 +88,7 @@ class OpenAIServingTokenization(OpenAIServing):
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
) )
else: else:
engine_prompts = await self._preprocess_completion( engine_prompts = await self.openai_serving_render.preprocess_completion(
request, request,
prompt_input=request.prompt, prompt_input=request.prompt,
prompt_embeds=None, prompt_embeds=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment