Unverified commit 4a718e77, authored by Sergey Zinchenko, committed by GitHub
Browse files

[Bug] Fix Failure in /v1/chat/completions/render for Multimodal Requests...

[Bug] Fix Failure in /v1/chat/completions/render for Multimodal Requests (https://github.com/vllm-project/vllm/issues/35665) (#35684)
parent 600a039f
......@@ -167,6 +167,7 @@ fo = "fo"
nd = "nd"
eles = "eles"
datas = "datas"
ser = "ser"
ure = "ure"
[tool.uv]
......
......@@ -7,7 +7,7 @@ import httpx
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from tests.utils import RemoteLaunchRenderServer
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
......@@ -16,7 +16,7 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
def server():
args: list[str] = []
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
with RemoteLaunchRenderServer(MODEL_NAME, args) as remote_server:
yield remote_server
......@@ -43,23 +43,20 @@ async def test_completion_render_basic(client):
assert response.status_code == 200
data = response.json()
# Verify response structure
# Verify response structure - list of GenerateRequest
assert isinstance(data, list)
assert len(data) > 0
# Verify first prompt
# Verify first prompt is a GenerateRequest
first_prompt = data[0]
assert "prompt_token_ids" in first_prompt
assert "prompt" in first_prompt
assert isinstance(first_prompt["prompt_token_ids"], list)
assert len(first_prompt["prompt_token_ids"]) > 0
assert isinstance(first_prompt["prompt"], str)
# Verify prompt text is preserved
assert (
"When should a chat-completions handler return an empty string?"
in first_prompt["prompt"]
)
assert "token_ids" in first_prompt
assert "sampling_params" in first_prompt
assert "model" in first_prompt
assert "request_id" in first_prompt
assert isinstance(first_prompt["token_ids"], list)
assert len(first_prompt["token_ids"]) > 0
assert first_prompt["model"] == MODEL_NAME
assert first_prompt["request_id"].startswith("cmpl-")
@pytest.mark.asyncio
......@@ -84,36 +81,15 @@ async def test_chat_completion_render_basic(client):
assert response.status_code == 200
data = response.json()
# Verify response structure - should be [conversation, engine_prompts]
assert isinstance(data, list)
assert len(data) == 2
conversation, engine_prompts = data
# Verify conversation
assert isinstance(conversation, list)
assert len(conversation) > 0
assert conversation[0]["role"] == "user"
assert "empty string" in conversation[0]["content"]
# Verify engine_prompts
assert isinstance(engine_prompts, list)
assert len(engine_prompts) > 0
# Verify response structure - should be a GenerateRequest
assert isinstance(data, dict)
assert "token_ids" in data
assert isinstance(data["token_ids"], list)
assert len(data["token_ids"]) > 0
first_prompt = engine_prompts[0]
assert "prompt_token_ids" in first_prompt
assert "prompt" in first_prompt
assert isinstance(first_prompt["prompt_token_ids"], list)
assert len(first_prompt["prompt_token_ids"]) > 0
# Verify chat template was applied (should have instruction markers)
assert "[INST]" in first_prompt["prompt"]
assert "[/INST]" in first_prompt["prompt"]
# Verify token IDs are correctly preserved as integers
token_ids = first_prompt["prompt_token_ids"]
# Verify token IDs are integers and BOS token is present
token_ids = data["token_ids"]
assert all(isinstance(tid, int) for tid in token_ids)
# Verify BOS token (usually 1 for LLaMA models)
assert token_ids[0] == 1
......@@ -131,15 +107,18 @@ async def test_completion_render_multiple_prompts(client):
assert response.status_code == 200
data = response.json()
# Should return two prompts
# Should return two GenerateRequest items
assert isinstance(data, list)
assert len(data) == 2
# Verify both prompts have required fields
# Verify both prompts have GenerateRequest fields
for prompt in data:
assert "prompt_token_ids" in prompt
assert "prompt" in prompt
assert len(prompt["prompt_token_ids"]) > 0
assert "token_ids" in prompt
assert "sampling_params" in prompt
assert "model" in prompt
assert "request_id" in prompt
assert len(prompt["token_ids"]) > 0
assert prompt["request_id"].startswith("cmpl-")
@pytest.mark.asyncio
......@@ -160,17 +139,49 @@ async def test_chat_completion_render_multi_turn(client):
assert response.status_code == 200
data = response.json()
conversation, engine_prompts = data
# Verify tokenization occurred
assert isinstance(data, dict)
assert "token_ids" in data
assert isinstance(data["token_ids"], list)
assert len(data["token_ids"]) > 0
# Verify all messages preserved
assert len(conversation) == 3
assert conversation[0]["role"] == "user"
assert conversation[1]["role"] == "assistant"
assert conversation[2]["role"] == "user"
# Verify tokenization occurred
assert len(engine_prompts) > 0
assert len(engine_prompts[0]["prompt_token_ids"]) > 0
@pytest.mark.asyncio
async def test_chat_completion_render_with_stream_true(client):
    """Streaming flags are accepted by /render, which still replies with JSON."""
    stream_options = {
        "include_usage": True,
        "continuous_usage_stats": True,
    }
    payload = {
        "model": MODEL_NAME,
        "stream": True,
        "stream_options": stream_options,
        "messages": [
            {
                "role": "user",
                "content": "Stream options should be accepted by /render.",
            }
        ],
    }

    response = await client.post("/v1/chat/completions/render", json=payload)

    # The reply is an ordinary JSON document even though streaming was asked for.
    assert response.status_code == 200
    content_type = response.headers.get("content-type", "")
    assert content_type.startswith("application/json")

    data = response.json()
    assert isinstance(data, dict)
    assert "token_ids" in data
    token_ids = data["token_ids"]
    assert isinstance(token_ids, list)
    assert len(token_ids) > 0

    # The stream settings sent in the request must come back on the
    # returned token-in request.
    assert data.get("stream") is True
    returned_options = data.get("stream_options")
    assert isinstance(returned_options, dict)
    for flag in ("include_usage", "continuous_usage_stats"):
        assert returned_options.get(flag) is True
@pytest.mark.asyncio
......@@ -224,3 +235,31 @@ async def test_completion_render_no_generation(client):
assert response.status_code == 200
# Render should be fast (< 1 second) since no generation
assert elapsed < 1.0
@pytest.mark.asyncio
async def test_chat_completion_render_with_sampling_params(client):
    """Sampling values sent to /render are echoed back in ``sampling_params``."""
    # Values the request sets and the response is expected to report unchanged.
    requested = {
        "temperature": 0.123,
        "top_p": 0.456,
        "frequency_penalty": 1.1,
    }
    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": "Test sampling params"}],
        **requested,
    }

    response = await client.post("/v1/chat/completions/render", json=payload)
    assert response.status_code == 200

    data = response.json()
    assert "sampling_params" in data
    sampling_params = data["sampling_params"]

    for name, expected in requested.items():
        assert sampling_params.get(name) == expected

    # Check that internal fields are not present
    assert "_all_stop_token_ids" not in sampling_params
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Multimodal tests for the /render endpoints that expose prompt preprocessing."""
import httpx
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.multimodal.utils import encode_image_url
VISION_MODEL_NAME = "Qwen/Qwen3-VL-2B-Instruct"
@pytest.fixture(scope="module")
def vision_server():
    """Start one server for ``VISION_MODEL_NAME`` per module and yield it.

    The server is launched through ``RemoteOpenAIServer`` with the CLI flags
    listed below and no extra environment variables; leaving the ``with``
    block hands control back to the context manager.
    """
    cli_args = [
        "--enforce-eager",
        "--max-model-len", "100",
        "--max-num-seqs", "1",
        "--limit-mm-per-prompt.image", "1",
        "--limit-mm-per-prompt.video", "0",
    ]
    no_env_overrides: dict[str, str] = {}

    launcher = RemoteOpenAIServer(
        VISION_MODEL_NAME,
        cli_args,
        env_dict=no_env_overrides,
    )
    with launcher as running_server:
        yield running_server
@pytest_asyncio.fixture
async def vision_client(vision_server):
    """Yield an ``httpx.AsyncClient`` rooted at the vision server's base URL."""
    root_url = vision_server.url_for("")
    async with httpx.AsyncClient(base_url=root_url, timeout=60.0) as session:
        yield session
@pytest.mark.asyncio
async def test_chat_completion_render_with_base64_image_url(
    vision_client,
    local_asset_server,
):
    """Render a chat request carrying an inline image and inspect the result.

    Checks that token ids come back and that ``features`` reports image
    hashes and placeholder ranges.
    """
    image = local_asset_server.get_image_asset("RGBA_comp.png")
    data_url = encode_image_url(image, format="PNG")
    assert data_url.startswith("data:image/")
    assert ";base64," in data_url

    user_message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": data_url}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }
    response = await vision_client.post(
        "/v1/chat/completions/render",
        json={"model": VISION_MODEL_NAME, "messages": [user_message]},
    )
    assert response.status_code == 200

    data = response.json()
    assert isinstance(data, dict)
    assert "token_ids" in data
    token_ids = data["token_ids"]
    assert isinstance(token_ids, list)
    assert len(token_ids) > 0

    # Multimodal metadata must be present and non-null.
    assert "features" in data
    features = data["features"]
    assert features is not None

    # mm_hashes: an "image" entry holding a non-empty list of strings.
    assert "mm_hashes" in features
    hashes_by_modality = features["mm_hashes"]
    assert "image" in hashes_by_modality
    image_hashes = hashes_by_modality["image"]
    assert isinstance(image_hashes, list)
    assert len(image_hashes) > 0
    assert all(isinstance(h, str) for h in image_hashes)

    # mm_placeholders: an "image" entry holding offset/length dicts.
    assert "mm_placeholders" in features
    placeholders_by_modality = features["mm_placeholders"]
    assert "image" in placeholders_by_modality
    image_placeholders = placeholders_by_modality["image"]
    assert isinstance(image_placeholders, list)
    assert len(image_placeholders) > 0
    for placeholder in image_placeholders:
        assert "offset" in placeholder
        assert "length" in placeholder
        assert isinstance(placeholder["offset"], int)
        assert isinstance(placeholder["length"], int)
        assert placeholder["length"] > 0
@pytest.mark.asyncio
async def test_tokenize_matches_render_for_multimodal_input(
    vision_client,
    local_asset_server,
):
    """`/tokenize` and `/v1/chat/completions/render` must agree on the tokens."""
    image = local_asset_server.get_image_asset("RGBA_comp.png")
    data_url = encode_image_url(image, format="PNG")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": data_url}},
                {"type": "text", "text": "What's in this image?"},
            ],
        }
    ]
    # Both endpoints receive the same model and messages.
    request_body = {"model": VISION_MODEL_NAME, "messages": messages}

    render_response = await vision_client.post(
        "/v1/chat/completions/render", json=request_body
    )
    assert render_response.status_code == 200
    rendered_ids = render_response.json()["token_ids"]

    tokenize_response = await vision_client.post("/tokenize", json=request_body)
    assert tokenize_response.status_code == 200
    tokenized = tokenize_response.json()

    assert tokenized["tokens"] == rendered_ids
    assert tokenized["count"] == len(rendered_ids)
......@@ -42,21 +42,12 @@ async def test_chat_render_basic(client):
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
conversation, engine_prompts = data
assert isinstance(conversation, list)
assert conversation[0]["role"] == "user"
assert isinstance(engine_prompts, list)
assert len(engine_prompts) > 0
first_prompt = engine_prompts[0]
assert "prompt_token_ids" in first_prompt
assert "prompt" in first_prompt
assert isinstance(first_prompt["prompt_token_ids"], list)
assert all(isinstance(t, int) for t in first_prompt["prompt_token_ids"])
# Response should be a GenerateRequest dict
assert isinstance(data, dict)
assert "token_ids" in data
assert isinstance(data["token_ids"], list)
assert len(data["token_ids"]) > 0
assert all(isinstance(t, int) for t in data["token_ids"])
@pytest.mark.asyncio
......@@ -74,14 +65,12 @@ async def test_chat_render_multi_turn(client):
)
assert response.status_code == 200
conversation, engine_prompts = response.json()
data = response.json()
assert len(conversation) == 3
assert conversation[0]["role"] == "user"
assert conversation[1]["role"] == "assistant"
assert conversation[2]["role"] == "user"
assert len(engine_prompts) > 0
assert len(engine_prompts[0]["prompt_token_ids"]) > 0
assert isinstance(data, dict)
assert "token_ids" in data
assert isinstance(data["token_ids"], list)
assert len(data["token_ids"]) > 0
@pytest.mark.asyncio
......@@ -118,11 +107,13 @@ async def test_completion_render_basic(client):
assert len(data) > 0
first_prompt = data[0]
assert "prompt_token_ids" in first_prompt
assert "prompt" in first_prompt
assert isinstance(first_prompt["prompt_token_ids"], list)
assert len(first_prompt["prompt_token_ids"]) > 0
assert "Once upon a time" in first_prompt["prompt"]
assert "token_ids" in first_prompt
assert "sampling_params" in first_prompt
assert "model" in first_prompt
assert "request_id" in first_prompt
assert isinstance(first_prompt["token_ids"], list)
assert len(first_prompt["token_ids"]) > 0
assert first_prompt["request_id"].startswith("cmpl-")
@pytest.mark.asyncio
......@@ -142,9 +133,12 @@ async def test_completion_render_multiple_prompts(client):
assert len(data) == 2
for prompt in data:
assert "prompt_token_ids" in prompt
assert "prompt" in prompt
assert len(prompt["prompt_token_ids"]) > 0
assert "token_ids" in prompt
assert "sampling_params" in prompt
assert "model" in prompt
assert "request_id" in prompt
assert len(prompt["token_ids"]) > 0
assert prompt["request_id"].startswith("cmpl-")
@pytest.mark.asyncio
......
......@@ -368,6 +368,7 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template,
)
......@@ -457,6 +458,9 @@ async def init_render_app_state(
state.openai_serving_models = model_registry
# Expose tokenization via the render handler (no engine required).
state.openai_serving_tokenization = state.openai_serving_render
state.vllm_config = vllm_config
# Disable stats logging — there is no engine to poll.
state.log_stats = False
......
......@@ -17,7 +17,6 @@ from pydantic import (
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname
......@@ -269,53 +268,3 @@ class GenerationError(Exception):
def __init__(self, message: str = "Internal server error"):
    """Build the error from ``message`` and tag it with a server-error status.

    Args:
        message: Text handed to the base exception constructor.
    """
    super().__init__(message)
    # Stored on the instance so code holding the exception can read the
    # HTTP status from it.
    self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    # Token-in request: already-tokenized prompt plus the options that
    # accompany it. Field order is kept exactly as declared originally.

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )

    # The token ids to generate text from.
    token_ids: list[int]

    # The processed MM inputs for the model.
    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None

    # The sampling parameters for the model.
    sampling_params: SamplingParams

    model: str | None = None

    stream: bool | None = False
    stream_options: StreamOptions | None = None

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
......@@ -11,7 +11,7 @@ from contextlib import asynccontextmanager
from http import HTTPStatus
import pydantic
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from starlette.concurrency import iterate_in_threadpool
......@@ -350,7 +350,8 @@ async def engine_error_handler(
server=req.app.state.server,
engine=req.app.state.engine_client,
)
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
err = create_error_response(exc)
return JSONResponse(err.model_dump(), status_code=err.error.code)
async def exception_handler(req: Request, exc: Exception):
......
......@@ -2,20 +2,55 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from vllm.config import ModelConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs
from vllm.entrypoints.openai.engine.protocol import (
SamplingParams,
StreamOptions,
)
from vllm.entrypoints.openai.engine.protocol import StreamOptions
from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
####### Tokens IN <> Tokens OUT #######
class PlaceholderRangeInfo(BaseModel):
    """Serializable placeholder location for a single multi-modal item."""

    # Start index of the placeholder tokens in the prompt.
    offset: int

    # Number of placeholder tokens.
    length: int

    # TODO: add ``is_embed: list[bool] | None`` once the /generate side
    # consumes features — some models (e.g. Qwen-VL) use sparse
    # placeholder masks that cannot be recomputed from offset+length alone.
class MultiModalFeatures(BaseModel):
    """Lightweight multimodal metadata produced by the render step.

    Carries hashes (for cache lookup / identification) and placeholder
    positions so the downstream ``/generate`` service knows *where* in
    the token sequence each multimodal item lives.

    .. note:: Phase 1 — metadata only.
        Phase 2 should add ``mm_kwargs`` (processed tensor data) using a
        binary transport so the ``/generate`` side can skip re-processing.
        The ``/generate`` endpoint must also be updated to inject these
        features into ``ProcessorInputs`` before passing to
        ``InputProcessor.process_inputs``.
    """

    # Per-modality item hashes, e.g. ``{"image": ["abc", "def"]}``.
    mm_hashes: dict[str, list[str]]

    # Per-modality placeholder ranges in the token sequence.
    mm_placeholders: dict[str, list[PlaceholderRangeInfo]]
class GenerateRequest(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
......@@ -28,10 +63,15 @@ class GenerateRequest(BaseModel):
token_ids: list[int]
"""The token ids to generate text from."""
# features: MultiModalFeatureSpec
# TODO (NickLucche): implement once Renderer work is completed
features: str | None = None
"""The processed MM inputs for the model."""
@field_validator("token_ids")
@classmethod
def validate_token_ids(cls, v: list[int]) -> list[int]:
    """Reject any ``token_ids`` list that contains a negative id.

    Args:
        v: The candidate ``token_ids`` value.

    Returns:
        ``v`` unchanged when every id is non-negative.

    Raises:
        ValueError: If at least one id is below zero.
    """
    for token_id in v:
        if token_id < 0:
            raise ValueError("token_ids must not contain negative values")
    return v
features: MultiModalFeatures | None = None
"""Multimodal hashes and placeholder positions (populated for MM inputs)."""
sampling_params: SamplingParams
"""The sampling parameters for the model."""
......
......@@ -9,6 +9,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.logger import init_logger
......@@ -24,7 +25,7 @@ def render(request: Request) -> OpenAIServingRender | None:
@router.post(
"/v1/chat/completions/render",
dependencies=[Depends(validate_json_request)],
response_model=list,
response_model=GenerateRequest,
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
......@@ -44,13 +45,13 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re
if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
return JSONResponse(content=result)
return JSONResponse(content=result.model_dump())
@router.post(
"/v1/completions/render",
dependencies=[Depends(validate_json_request)],
response_model=list,
response_model=list[GenerateRequest],
responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
......@@ -67,7 +68,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request):
if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code)
return JSONResponse(content=result)
return JSONResponse(content=[item.model_dump() for item in result])
def attach_router(app: FastAPI) -> None:
......
......@@ -24,14 +24,29 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
parse_chat_inputs_to_harmony_messages,
render_for_completion,
)
from vllm.entrypoints.utils import create_error_response
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
MultiModalFeatures,
PlaceholderRangeInfo,
)
from vllm.entrypoints.utils import (
create_error_response,
get_max_tokens,
)
from vllm.inputs.data import ProcessorInputs, PromptType, SingletonPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.multimodal.inputs import MultiModalHashes, MultiModalPlaceholderDict
from vllm.parser import ParserManager
from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser
from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt
......@@ -83,10 +98,18 @@ class OpenAIServingRender:
self.supports_browsing = False
self.supports_code_interpreter = False
self.default_sampling_params = model_config.get_diff_sampling_param()
mc = model_config
self.override_max_tokens = (
self.default_sampling_params.get("max_tokens")
if mc.generation_config not in ("auto", "vllm")
else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
)
async def render_chat_request(
self,
request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse:
) -> GenerateRequest | ErrorResponse:
"""Validate the model and preprocess a chat completion request.
This is the authoritative implementation used directly by the
......@@ -96,7 +119,56 @@ class OpenAIServingRender:
if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret)
return error_check_ret
return await self.render_chat(request)
if request.use_beam_search:
return self.create_error_response(
"Beam search is not supported by the render endpoint"
)
result = await self.render_chat(request)
if isinstance(result, ErrorResponse):
return result
_, engine_prompts = result
if len(engine_prompts) != 1:
return self.create_error_response(
f"Expected exactly 1 engine prompt, got {len(engine_prompts)}"
)
engine_prompt = engine_prompts[0]
prompt_components = extract_prompt_components(self.model_config, engine_prompt)
token_ids = prompt_components.token_ids
if not token_ids:
return self.create_error_response("No token_ids rendered")
token_ids = list(token_ids)
input_length = extract_prompt_len(self.model_config, engine_prompt)
max_tokens = get_max_tokens(
self.model_config.max_model_len,
request.max_completion_tokens
if request.max_completion_tokens is not None
else request.max_tokens,
input_length,
self.default_sampling_params,
self.override_max_tokens,
)
params = request.to_sampling_params(max_tokens, self.default_sampling_params)
request_id = f"chatcmpl-{random_uuid()}"
return GenerateRequest(
request_id=request_id,
token_ids=token_ids,
features=self._extract_mm_features(engine_prompt),
sampling_params=params,
model=request.model,
stream=bool(request.stream),
stream_options=(request.stream_options if request.stream else None),
cache_salt=request.cache_salt,
priority=request.priority,
)
async def render_chat(
self,
......@@ -183,7 +255,7 @@ class OpenAIServingRender:
async def render_completion_request(
self,
request: CompletionRequest,
) -> list[ProcessorInputs] | ErrorResponse:
) -> list[GenerateRequest] | ErrorResponse:
"""Validate the model and preprocess a completion request.
This is the authoritative implementation used directly by the
......@@ -192,7 +264,48 @@ class OpenAIServingRender:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
return await self.render_completion(request)
result = await self.render_completion(request)
if isinstance(result, ErrorResponse):
return result
generate_requests: list[GenerateRequest] = []
for engine_prompt in result:
prompt_components = extract_prompt_components(
self.model_config, engine_prompt
)
token_ids = prompt_components.token_ids
if not token_ids:
return self.create_error_response("No token_ids rendered")
token_ids = list(token_ids)
input_length = extract_prompt_len(self.model_config, engine_prompt)
max_tokens = get_max_tokens(
self.model_config.max_model_len,
request.max_tokens,
input_length,
self.default_sampling_params,
self.override_max_tokens,
)
params = request.to_sampling_params(
max_tokens, self.default_sampling_params
)
request_id = f"cmpl-{random_uuid()}"
generate_requests.append(
GenerateRequest(
request_id=request_id,
token_ids=token_ids,
features=self._extract_mm_features(engine_prompt),
sampling_params=params,
model=request.model,
stream=bool(request.stream),
stream_options=(request.stream_options if request.stream else None),
cache_salt=request.cache_salt,
priority=request.priority,
)
)
return generate_requests
async def render_completion(
self,
......@@ -223,6 +336,33 @@ class OpenAIServingRender:
return engine_prompts
@staticmethod
def _extract_mm_features(
    engine_prompt: ProcessorInputs,
) -> MultiModalFeatures | None:
    """Build the multimodal metadata for a rendered engine prompt.

    Args:
        engine_prompt: Rendered prompt; only inspected further when its
            ``"type"`` entry equals ``"multimodal"``.

    Returns:
        ``None`` when the prompt's type is not ``"multimodal"``; otherwise
        a ``MultiModalFeatures`` holding the prompt's ``mm_hashes`` and
        its placeholder ranges reduced to offset/length pairs.
    """
    if engine_prompt.get("type") != "multimodal":
        return None

    # From here on the prompt is handled as a MultiModalInputs TypedDict.
    mm_hashes: MultiModalHashes = engine_prompt["mm_hashes"]  # type: ignore[typeddict-item]
    raw_placeholders: MultiModalPlaceholderDict = engine_prompt["mm_placeholders"]  # type: ignore[typeddict-item]

    # Keep only the offset and length of every placeholder range.
    mm_placeholders: dict[str, list[PlaceholderRangeInfo]] = {}
    for modality, ranges in raw_placeholders.items():
        mm_placeholders[modality] = [
            PlaceholderRangeInfo(offset=item.offset, length=item.length)
            for item in ranges
        ]

    return MultiModalFeatures(
        mm_hashes=mm_hashes,
        mm_placeholders=mm_placeholders,
    )
def _make_request_with_harmony(
self,
request: ChatCompletionRequest,
......
......@@ -35,6 +35,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger: RequestLogger | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
default_chat_template_kwargs: dict[str, Any] | None = None,
trust_request_chat_template: bool = False,
) -> None:
super().__init__(
......@@ -45,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing):
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.default_chat_template_kwargs = default_chat_template_kwargs or {}
self.trust_request_chat_template = trust_request_chat_template
async def create_tokenize(
......@@ -79,7 +81,7 @@ class OpenAIServingTokenization(OpenAIServing):
request.messages,
default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None,
default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts,
)
else:
......@@ -98,8 +100,9 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request=lora_request,
)
if "prompt_token_ids" in engine_prompt:
input_ids.extend(engine_prompt["prompt_token_ids"]) # type: ignore[typeddict-item]
prompt_components = self._extract_prompt_components(engine_prompt)
if prompt_components.token_ids is not None:
input_ids.extend(prompt_components.token_ids)
token_strs = None
if request.return_token_strs:
......
......@@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence
from functools import partial
from inspect import isclass
from types import FunctionType
from typing import Any, TypeAlias, get_type_hints
from typing import Any, ClassVar, TypeAlias, cast, get_type_hints
import cloudpickle
import msgspec
......@@ -460,6 +460,19 @@ def run_method(
class PydanticMsgspecMixin:
"""Make a ``msgspec.Struct`` compatible with Pydantic for both
**validation** (JSON/dict -> Struct) and **serialization**
(Struct -> JSON-safe dict).
Subclasses may set ``__pydantic_msgspec_exclude__`` (a ``set[str]``)
to list non-underscore field names that should also be stripped from
serialized output. Fields whose names start with ``_`` are always
excluded automatically.
"""
# Subclasses can override to exclude additional public-but-internal keys.
__pydantic_msgspec_exclude__: ClassVar[set[str]] = set()
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
......@@ -476,32 +489,62 @@ class PydanticMsgspecMixin:
# Build the Pydantic typed_dict_field for each msgspec field
fields = {}
for name, hint in type_hints.items():
if name not in msgspec_fields:
# Skip ClassVar and other non-struct annotations.
continue
# Skip private fields — they are excluded from serialization
# and should not appear in the generated JSON/OpenAPI schema.
if name.startswith("_"):
continue
msgspec_field = msgspec_fields[name]
# typed_dict_field using the handler to get the schema
field_schema = handler(hint)
# Add default value to the schema.
# Mark fields with defaults as not required so the generated
# JSON Schema stays consistent with ``omit_defaults=True``
# serialization (fields at their default value may be absent).
if msgspec_field.default_factory is not msgspec.NODEFAULT:
wrapped_schema = core_schema.with_default_schema(
schema=field_schema,
default_factory=msgspec_field.default_factory,
)
fields[name] = core_schema.typed_dict_field(wrapped_schema)
fields[name] = core_schema.typed_dict_field(
wrapped_schema, required=False
)
elif msgspec_field.default is not msgspec.NODEFAULT:
wrapped_schema = core_schema.with_default_schema(
schema=field_schema,
default=msgspec_field.default,
)
fields[name] = core_schema.typed_dict_field(wrapped_schema)
fields[name] = core_schema.typed_dict_field(
wrapped_schema, required=False
)
else:
# No default, so Pydantic will treat it as required
fields[name] = core_schema.typed_dict_field(field_schema)
return core_schema.no_info_after_validator_function(
typed_dict_then_convert = core_schema.no_info_after_validator_function(
cls._validate_msgspec,
core_schema.typed_dict_schema(fields),
)
# Build a serializer that strips private / excluded fields.
serializer = core_schema.plain_serializer_function_ser_schema(
cls._serialize_msgspec,
info_arg=False,
)
# Accept either an already-constructed msgspec.Struct instance or a
# JSON/dict-like payload.
return core_schema.union_schema(
[
core_schema.is_instance_schema(source_type),
typed_dict_then_convert,
],
serialization=serializer,
)
@classmethod
def _validate_msgspec(cls, value: Any) -> Any:
"""Validate and convert input to msgspec.Struct instance."""
......@@ -510,3 +553,25 @@ class PydanticMsgspecMixin:
if isinstance(value, dict):
return cls(**value)
return msgspec.convert(value, type=cls)
@staticmethod
def _serialize_msgspec(value: Any) -> Any:
    """Turn a msgspec.Struct into a JSON-compatible dict without hidden keys.

    The value is converted with ``msgspec.to_builtins`` (which, per the
    original note, respects ``omit_defaults=True``). When the result is a
    dict, every key that starts with ``_`` or is listed in the value
    type's ``__pydantic_msgspec_exclude__`` is removed from it in place.

    Args:
        value: Object to serialize.

    Returns:
        The filtered dict, or the raw ``to_builtins`` result when that
        result is not a dict.
    """
    converted = msgspec.to_builtins(value)
    if not isinstance(converted, dict):
        return converted

    excluded: set[str] = cast(
        set[str],
        getattr(type(value), "__pydantic_msgspec_exclude__", set()),
    )
    # Collect first, then delete, so the dict is never mutated mid-iteration.
    hidden_keys = [
        key for key in converted if key.startswith("_") or key in excluded
    ]
    for key in hidden_keys:
        del converted[key]
    return converted
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment