Unverified commit 4a718e77, authored by Sergey Zinchenko and committed by GitHub
Browse files

[Bug] Fix Failure in /v1/chat/completions/render for Multimodal Requests...

[Bug] Fix Failure in /v1/chat/completions/render for Multimodal Requests (https://github.com/vllm-project/vllm/issues/35665) (#35684)
parent 600a039f
...@@ -167,6 +167,7 @@ fo = "fo" ...@@ -167,6 +167,7 @@ fo = "fo"
nd = "nd" nd = "nd"
eles = "eles" eles = "eles"
datas = "datas" datas = "datas"
ser = "ser"
ure = "ure" ure = "ure"
[tool.uv] [tool.uv]
......
...@@ -7,7 +7,7 @@ import httpx ...@@ -7,7 +7,7 @@ import httpx
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteLaunchRenderServer
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
...@@ -16,7 +16,7 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" ...@@ -16,7 +16,7 @@ MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
def server(): def server():
args: list[str] = [] args: list[str] = []
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: with RemoteLaunchRenderServer(MODEL_NAME, args) as remote_server:
yield remote_server yield remote_server
...@@ -43,23 +43,20 @@ async def test_completion_render_basic(client): ...@@ -43,23 +43,20 @@ async def test_completion_render_basic(client):
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
# Verify response structure # Verify response structure - list of GenerateRequest
assert isinstance(data, list) assert isinstance(data, list)
assert len(data) > 0 assert len(data) > 0
# Verify first prompt # Verify first prompt is a GenerateRequest
first_prompt = data[0] first_prompt = data[0]
assert "prompt_token_ids" in first_prompt assert "token_ids" in first_prompt
assert "prompt" in first_prompt assert "sampling_params" in first_prompt
assert isinstance(first_prompt["prompt_token_ids"], list) assert "model" in first_prompt
assert len(first_prompt["prompt_token_ids"]) > 0 assert "request_id" in first_prompt
assert isinstance(first_prompt["prompt"], str) assert isinstance(first_prompt["token_ids"], list)
assert len(first_prompt["token_ids"]) > 0
# Verify prompt text is preserved assert first_prompt["model"] == MODEL_NAME
assert ( assert first_prompt["request_id"].startswith("cmpl-")
"When should a chat-completions handler return an empty string?"
in first_prompt["prompt"]
)
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -84,36 +81,15 @@ async def test_chat_completion_render_basic(client): ...@@ -84,36 +81,15 @@ async def test_chat_completion_render_basic(client):
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
# Verify response structure - should be [conversation, engine_prompts] # Verify response structure - should be a GenerateRequest
assert isinstance(data, list) assert isinstance(data, dict)
assert len(data) == 2 assert "token_ids" in data
assert isinstance(data["token_ids"], list)
conversation, engine_prompts = data assert len(data["token_ids"]) > 0
# Verify conversation
assert isinstance(conversation, list)
assert len(conversation) > 0
assert conversation[0]["role"] == "user"
assert "empty string" in conversation[0]["content"]
# Verify engine_prompts
assert isinstance(engine_prompts, list)
assert len(engine_prompts) > 0
first_prompt = engine_prompts[0] # Verify token IDs are integers and BOS token is present
assert "prompt_token_ids" in first_prompt token_ids = data["token_ids"]
assert "prompt" in first_prompt
assert isinstance(first_prompt["prompt_token_ids"], list)
assert len(first_prompt["prompt_token_ids"]) > 0
# Verify chat template was applied (should have instruction markers)
assert "[INST]" in first_prompt["prompt"]
assert "[/INST]" in first_prompt["prompt"]
# Verify token IDs are correctly preserved as integers
token_ids = first_prompt["prompt_token_ids"]
assert all(isinstance(tid, int) for tid in token_ids) assert all(isinstance(tid, int) for tid in token_ids)
# Verify BOS token (usually 1 for LLaMA models)
assert token_ids[0] == 1 assert token_ids[0] == 1
...@@ -131,15 +107,18 @@ async def test_completion_render_multiple_prompts(client): ...@@ -131,15 +107,18 @@ async def test_completion_render_multiple_prompts(client):
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
# Should return two prompts # Should return two GenerateRequest items
assert isinstance(data, list) assert isinstance(data, list)
assert len(data) == 2 assert len(data) == 2
# Verify both prompts have required fields # Verify both prompts have GenerateRequest fields
for prompt in data: for prompt in data:
assert "prompt_token_ids" in prompt assert "token_ids" in prompt
assert "prompt" in prompt assert "sampling_params" in prompt
assert len(prompt["prompt_token_ids"]) > 0 assert "model" in prompt
assert "request_id" in prompt
assert len(prompt["token_ids"]) > 0
assert prompt["request_id"].startswith("cmpl-")
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -160,17 +139,49 @@ async def test_chat_completion_render_multi_turn(client): ...@@ -160,17 +139,49 @@ async def test_chat_completion_render_multi_turn(client):
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
conversation, engine_prompts = data # Verify tokenization occurred
assert isinstance(data, dict)
assert "token_ids" in data
assert isinstance(data["token_ids"], list)
assert len(data["token_ids"]) > 0
# Verify all messages preserved
assert len(conversation) == 3
assert conversation[0]["role"] == "user"
assert conversation[1]["role"] == "assistant"
assert conversation[2]["role"] == "user"
# Verify tokenization occurred @pytest.mark.asyncio
assert len(engine_prompts) > 0 async def test_chat_completion_render_with_stream_true(client):
assert len(engine_prompts[0]["prompt_token_ids"]) > 0 """Render accepts stream params but still returns JSON (non-streamed)."""
response = await client.post(
"/v1/chat/completions/render",
json={
"model": MODEL_NAME,
"stream": True,
"stream_options": {
"include_usage": True,
"continuous_usage_stats": True,
},
"messages": [
{
"role": "user",
"content": "Stream options should be accepted by /render.",
}
],
},
)
assert response.status_code == 200
assert response.headers.get("content-type", "").startswith("application/json")
data = response.json()
assert isinstance(data, dict)
assert "token_ids" in data
assert isinstance(data["token_ids"], list)
assert len(data["token_ids"]) > 0
# /render should preserve stream fields on the returned token-in request.
assert data.get("stream") is True
assert isinstance(data.get("stream_options"), dict)
assert data["stream_options"].get("include_usage") is True
assert data["stream_options"].get("continuous_usage_stats") is True
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -224,3 +235,31 @@ async def test_completion_render_no_generation(client): ...@@ -224,3 +235,31 @@ async def test_completion_render_no_generation(client):
assert response.status_code == 200 assert response.status_code == 200
# Render should be fast (< 1 second) since no generation # Render should be fast (< 1 second) since no generation
assert elapsed < 1.0 assert elapsed < 1.0
@pytest.mark.asyncio
async def test_chat_completion_render_with_sampling_params(client):
    """Sampling values sent to /render must come back under ``sampling_params``."""
    # Sampling knobs sent with the request; each must be echoed back unchanged.
    requested = {
        "temperature": 0.123,
        "top_p": 0.456,
        "frequency_penalty": 1.1,
    }
    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": "Test sampling params"}],
        **requested,
    }

    response = await client.post("/v1/chat/completions/render", json=payload)
    assert response.status_code == 200

    data = response.json()
    assert "sampling_params" in data
    sampling_params = data["sampling_params"]

    for name, expected in requested.items():
        assert sampling_params.get(name) == expected

    # Internal bookkeeping fields must not appear in the serialized output.
    assert "_all_stop_token_ids" not in sampling_params
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Multimodal tests for the /render endpoints that expose prompt preprocessing."""
import httpx
import pytest
import pytest_asyncio
from tests.utils import RemoteOpenAIServer
from vllm.multimodal.utils import encode_image_url
VISION_MODEL_NAME = "Qwen/Qwen3-VL-2B-Instruct"
@pytest.fixture(scope="module")
def vision_server():
    """Start one vision-model server per module for the multimodal /render tests.

    Yields the running ``RemoteOpenAIServer`` context object.
    """
    # Flags that take a value, in the order they are handed to the server.
    valued_flags = {
        "--max-model-len": "100",
        "--max-num-seqs": "1",
        "--limit-mm-per-prompt.image": "1",
        "--limit-mm-per-prompt.video": "0",
    }
    args = ["--enforce-eager"]
    for flag, value in valued_flags.items():
        args.extend((flag, value))

    # No environment overrides are needed; an empty mapping is passed explicitly.
    env_overrides: dict[str, str] = {}

    with RemoteOpenAIServer(
        VISION_MODEL_NAME, args, env_dict=env_overrides
    ) as remote_server:
        yield remote_server
@pytest_asyncio.fixture
async def vision_client(vision_server):
    """Yield an async HTTP client rooted at the vision server's base URL."""
    base_url = vision_server.url_for("")
    # 60-second timeout on every request made through this client.
    async with httpx.AsyncClient(base_url=base_url, timeout=60.0) as session:
        yield session
@pytest.mark.asyncio
async def test_chat_completion_render_with_base64_image_url(
    vision_client,
    local_asset_server,
):
    """Rendering an inline base64 image must yield tokens and image metadata.

    Posts one user message (image + text) to ``/v1/chat/completions/render``
    and validates ``token_ids`` as well as the ``"image"`` entries of
    ``features.mm_hashes`` and ``features.mm_placeholders``.
    """
    image = local_asset_server.get_image_asset("RGBA_comp.png")
    data_url = encode_image_url(image, format="PNG")
    # Sanity-check that the helper produced a base64 data URL.
    assert data_url.startswith("data:image/")
    assert ";base64," in data_url

    user_message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": data_url}},
            {"type": "text", "text": "What's in this image?"},
        ],
    }
    response = await vision_client.post(
        "/v1/chat/completions/render",
        json={"model": VISION_MODEL_NAME, "messages": [user_message]},
    )
    assert response.status_code == 200

    data = response.json()
    assert isinstance(data, dict)
    assert "token_ids" in data
    token_ids = data["token_ids"]
    assert isinstance(token_ids, list)
    assert len(token_ids) > 0

    # Multimodal features must be present and non-null.
    assert "features" in data
    features = data["features"]
    assert features is not None

    # mm_hashes: the "image" modality maps to a non-empty list of strings.
    assert "mm_hashes" in features
    hashes_by_modality = features["mm_hashes"]
    assert "image" in hashes_by_modality
    image_hashes = hashes_by_modality["image"]
    assert isinstance(image_hashes, list)
    assert len(image_hashes) > 0
    assert all(isinstance(item_hash, str) for item_hash in image_hashes)

    # mm_placeholders: the "image" modality maps to offset/length records.
    assert "mm_placeholders" in features
    placeholders_by_modality = features["mm_placeholders"]
    assert "image" in placeholders_by_modality
    image_placeholders = placeholders_by_modality["image"]
    assert isinstance(image_placeholders, list)
    assert len(image_placeholders) > 0
    for placeholder in image_placeholders:
        assert "offset" in placeholder
        assert "length" in placeholder
        offset, length = placeholder["offset"], placeholder["length"]
        assert isinstance(offset, int)
        assert isinstance(length, int)
        assert length > 0
@pytest.mark.asyncio
async def test_tokenize_matches_render_for_multimodal_input(
    vision_client,
    local_asset_server,
):
    """`/tokenize` and `/v1/chat/completions/render` must agree on the tokens."""
    data_url = encode_image_url(
        local_asset_server.get_image_asset("RGBA_comp.png"), format="PNG"
    )
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": data_url}},
                {"type": "text", "text": "What's in this image?"},
            ],
        }
    ]
    # Both endpoints receive exactly the same request body.
    payload = {"model": VISION_MODEL_NAME, "messages": messages}

    render_response = await vision_client.post(
        "/v1/chat/completions/render", json=payload
    )
    assert render_response.status_code == 200
    rendered_ids = render_response.json()["token_ids"]

    tokenize_response = await vision_client.post("/tokenize", json=payload)
    assert tokenize_response.status_code == 200
    tokenize_data = tokenize_response.json()

    assert tokenize_data["tokens"] == rendered_ids
    assert tokenize_data["count"] == len(rendered_ids)
...@@ -42,21 +42,12 @@ async def test_chat_render_basic(client): ...@@ -42,21 +42,12 @@ async def test_chat_render_basic(client):
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert isinstance(data, list) # Response should be a GenerateRequest dict
assert len(data) == 2 assert isinstance(data, dict)
assert "token_ids" in data
conversation, engine_prompts = data assert isinstance(data["token_ids"], list)
assert len(data["token_ids"]) > 0
assert isinstance(conversation, list) assert all(isinstance(t, int) for t in data["token_ids"])
assert conversation[0]["role"] == "user"
assert isinstance(engine_prompts, list)
assert len(engine_prompts) > 0
first_prompt = engine_prompts[0]
assert "prompt_token_ids" in first_prompt
assert "prompt" in first_prompt
assert isinstance(first_prompt["prompt_token_ids"], list)
assert all(isinstance(t, int) for t in first_prompt["prompt_token_ids"])
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -74,14 +65,12 @@ async def test_chat_render_multi_turn(client): ...@@ -74,14 +65,12 @@ async def test_chat_render_multi_turn(client):
) )
assert response.status_code == 200 assert response.status_code == 200
conversation, engine_prompts = response.json() data = response.json()
assert len(conversation) == 3 assert isinstance(data, dict)
assert conversation[0]["role"] == "user" assert "token_ids" in data
assert conversation[1]["role"] == "assistant" assert isinstance(data["token_ids"], list)
assert conversation[2]["role"] == "user" assert len(data["token_ids"]) > 0
assert len(engine_prompts) > 0
assert len(engine_prompts[0]["prompt_token_ids"]) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -118,11 +107,13 @@ async def test_completion_render_basic(client): ...@@ -118,11 +107,13 @@ async def test_completion_render_basic(client):
assert len(data) > 0 assert len(data) > 0
first_prompt = data[0] first_prompt = data[0]
assert "prompt_token_ids" in first_prompt assert "token_ids" in first_prompt
assert "prompt" in first_prompt assert "sampling_params" in first_prompt
assert isinstance(first_prompt["prompt_token_ids"], list) assert "model" in first_prompt
assert len(first_prompt["prompt_token_ids"]) > 0 assert "request_id" in first_prompt
assert "Once upon a time" in first_prompt["prompt"] assert isinstance(first_prompt["token_ids"], list)
assert len(first_prompt["token_ids"]) > 0
assert first_prompt["request_id"].startswith("cmpl-")
@pytest.mark.asyncio @pytest.mark.asyncio
...@@ -142,9 +133,12 @@ async def test_completion_render_multiple_prompts(client): ...@@ -142,9 +133,12 @@ async def test_completion_render_multiple_prompts(client):
assert len(data) == 2 assert len(data) == 2
for prompt in data: for prompt in data:
assert "prompt_token_ids" in prompt assert "token_ids" in prompt
assert "prompt" in prompt assert "sampling_params" in prompt
assert len(prompt["prompt_token_ids"]) > 0 assert "model" in prompt
assert "request_id" in prompt
assert len(prompt["token_ids"]) > 0
assert prompt["request_id"].startswith("cmpl-")
@pytest.mark.asyncio @pytest.mark.asyncio
......
...@@ -368,6 +368,7 @@ async def init_app_state( ...@@ -368,6 +368,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template, trust_request_chat_template=args.trust_request_chat_template,
) )
...@@ -457,6 +458,9 @@ async def init_render_app_state( ...@@ -457,6 +458,9 @@ async def init_render_app_state(
state.openai_serving_models = model_registry state.openai_serving_models = model_registry
# Expose tokenization via the render handler (no engine required).
state.openai_serving_tokenization = state.openai_serving_render
state.vllm_config = vllm_config state.vllm_config = vllm_config
# Disable stats logging — there is no engine to poll. # Disable stats logging — there is no engine to poll.
state.log_stats = False state.log_stats = False
......
...@@ -17,7 +17,6 @@ from pydantic import ( ...@@ -17,7 +17,6 @@ from pydantic import (
from vllm.entrypoints.chat_utils import make_tool_call_id from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
...@@ -269,53 +268,3 @@ class GenerationError(Exception): ...@@ -269,53 +268,3 @@ class GenerationError(Exception):
def __init__(self, message: str = "Internal server error"): def __init__(self, message: str = "Internal server error"):
super().__init__(message) super().__init__(message)
self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR self.status_code = HTTPStatus.INTERNAL_SERVER_ERROR
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    """Token-in request body: pre-tokenized prompt plus generation settings."""

    request_id: str = Field(
        default_factory=random_uuid,
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )

    token_ids: list[int]
    """Prompt token ids that generation starts from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """Processed multimodal inputs for the model (placeholder type for now)."""

    sampling_params: SamplingParams
    """Sampling parameters applied to this request."""

    model: str | None = None
    stream: bool | None = False
    stream_options: StreamOptions | None = None

    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
...@@ -11,7 +11,7 @@ from contextlib import asynccontextmanager ...@@ -11,7 +11,7 @@ from contextlib import asynccontextmanager
from http import HTTPStatus from http import HTTPStatus
import pydantic import pydantic
from fastapi import FastAPI, HTTPException, Request, Response from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from starlette.concurrency import iterate_in_threadpool from starlette.concurrency import iterate_in_threadpool
...@@ -350,7 +350,8 @@ async def engine_error_handler( ...@@ -350,7 +350,8 @@ async def engine_error_handler(
server=req.app.state.server, server=req.app.state.server,
engine=req.app.state.engine_client, engine=req.app.state.engine_client,
) )
return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) err = create_error_response(exc)
return JSONResponse(err.model_dump(), status_code=err.error.code)
async def exception_handler(req: Request, exc: Exception): async def exception_handler(req: Request, exc: Exception):
......
...@@ -2,20 +2,55 @@ ...@@ -2,20 +2,55 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionLogProbs
from vllm.entrypoints.openai.engine.protocol import ( from vllm.entrypoints.openai.engine.protocol import StreamOptions
SamplingParams,
StreamOptions,
)
from vllm.logprobs import Logprob from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams from vllm.renderers import TokenizeParams
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid from vllm.utils import random_uuid
####### Tokens IN <> Tokens OUT ####### ####### Tokens IN <> Tokens OUT #######
class PlaceholderRangeInfo(BaseModel):
    """JSON-friendly position of one multimodal item's placeholder tokens."""

    offset: int
    """Index in the prompt where the placeholder tokens begin."""

    length: int
    """How many placeholder tokens the item occupies."""

    # TODO: add ``is_embed: list[bool] | None`` once the /generate side
    # consumes features — some models (e.g. Qwen-VL) use sparse
    # placeholder masks that cannot be recomputed from offset+length alone.
class MultiModalFeatures(BaseModel):
    """Multimodal metadata emitted by the render step.

    Holds per-item hashes (for cache lookup / identification) together with
    placeholder positions, so the downstream ``/generate`` service can tell
    *where* in the token sequence each multimodal item sits.

    .. note:: Phase 1 — metadata only.

        Phase 2 should add ``mm_kwargs`` (processed tensor data) using a
        binary transport so the ``/generate`` side can skip re-processing.
        The ``/generate`` endpoint must also be updated to inject these
        features into ``ProcessorInputs`` before passing to
        ``InputProcessor.process_inputs``.
    """

    mm_hashes: dict[str, list[str]]
    """Item hashes keyed by modality, e.g. ``{"image": ["abc", "def"]}``."""

    mm_placeholders: dict[str, list[PlaceholderRangeInfo]]
    """Placeholder ranges in the token sequence, keyed by modality."""
class GenerateRequest(BaseModel): class GenerateRequest(BaseModel):
request_id: str = Field( request_id: str = Field(
default_factory=lambda: f"{random_uuid()}", default_factory=lambda: f"{random_uuid()}",
...@@ -28,10 +63,15 @@ class GenerateRequest(BaseModel): ...@@ -28,10 +63,15 @@ class GenerateRequest(BaseModel):
token_ids: list[int] token_ids: list[int]
"""The token ids to generate text from.""" """The token ids to generate text from."""
# features: MultiModalFeatureSpec @field_validator("token_ids")
# TODO (NickLucche): implement once Renderer work is completed @classmethod
features: str | None = None def validate_token_ids(cls, v: list[int]) -> list[int]:
"""The processed MM inputs for the model.""" if any(t < 0 for t in v):
raise ValueError("token_ids must not contain negative values")
return v
features: MultiModalFeatures | None = None
"""Multimodal hashes and placeholder positions (populated for MM inputs)."""
sampling_params: SamplingParams sampling_params: SamplingParams
"""The sampling parameters for the model.""" """The sampling parameters for the model."""
......
...@@ -9,6 +9,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque ...@@ -9,6 +9,7 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque
from vllm.entrypoints.openai.completion.protocol import CompletionRequest from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest
from vllm.entrypoints.serve.render.serving import OpenAIServingRender from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -24,7 +25,7 @@ def render(request: Request) -> OpenAIServingRender | None: ...@@ -24,7 +25,7 @@ def render(request: Request) -> OpenAIServingRender | None:
@router.post( @router.post(
"/v1/chat/completions/render", "/v1/chat/completions/render",
dependencies=[Depends(validate_json_request)], dependencies=[Depends(validate_json_request)],
response_model=list, response_model=GenerateRequest,
responses={ responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
...@@ -44,13 +45,13 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re ...@@ -44,13 +45,13 @@ async def render_chat_completion(request: ChatCompletionRequest, raw_request: Re
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code) return JSONResponse(content=result.model_dump(), status_code=result.error.code)
return JSONResponse(content=result) return JSONResponse(content=result.model_dump())
@router.post( @router.post(
"/v1/completions/render", "/v1/completions/render",
dependencies=[Depends(validate_json_request)], dependencies=[Depends(validate_json_request)],
response_model=list, response_model=list[GenerateRequest],
responses={ responses={
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
...@@ -67,7 +68,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request): ...@@ -67,7 +68,7 @@ async def render_completion(request: CompletionRequest, raw_request: Request):
if isinstance(result, ErrorResponse): if isinstance(result, ErrorResponse):
return JSONResponse(content=result.model_dump(), status_code=result.error.code) return JSONResponse(content=result.model_dump(), status_code=result.error.code)
return JSONResponse(content=result) return JSONResponse(content=[item.model_dump() for item in result])
def attach_router(app: FastAPI) -> None: def attach_router(app: FastAPI) -> None:
......
...@@ -24,14 +24,29 @@ from vllm.entrypoints.openai.parser.harmony_utils import ( ...@@ -24,14 +24,29 @@ from vllm.entrypoints.openai.parser.harmony_utils import (
parse_chat_inputs_to_harmony_messages, parse_chat_inputs_to_harmony_messages,
render_for_completion, render_for_completion,
) )
from vllm.entrypoints.utils import create_error_response from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
MultiModalFeatures,
PlaceholderRangeInfo,
)
from vllm.entrypoints.utils import (
create_error_response,
get_max_tokens,
)
from vllm.inputs.data import ProcessorInputs, PromptType, SingletonPrompt, TokensPrompt from vllm.inputs.data import ProcessorInputs, PromptType, SingletonPrompt, TokensPrompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal.inputs import MultiModalHashes, MultiModalPlaceholderDict
from vllm.parser import ParserManager from vllm.parser import ParserManager
from vllm.renderers import BaseRenderer, merge_kwargs from vllm.renderers import BaseRenderer, merge_kwargs
from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq from vllm.renderers.inputs.preprocess import (
extract_prompt_components,
extract_prompt_len,
parse_model_prompt,
prompt_to_seq,
)
from vllm.tokenizers import TokenizerLike from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers import ToolParser from vllm.tool_parsers import ToolParser
from vllm.utils import random_uuid
from vllm.utils.mistral import is_mistral_tokenizer from vllm.utils.mistral import is_mistral_tokenizer
from vllm.utils.mistral import mt as _mt from vllm.utils.mistral import mt as _mt
...@@ -83,10 +98,18 @@ class OpenAIServingRender: ...@@ -83,10 +98,18 @@ class OpenAIServingRender:
self.supports_browsing = False self.supports_browsing = False
self.supports_code_interpreter = False self.supports_code_interpreter = False
self.default_sampling_params = model_config.get_diff_sampling_param()
mc = model_config
self.override_max_tokens = (
self.default_sampling_params.get("max_tokens")
if mc.generation_config not in ("auto", "vllm")
else getattr(mc, "override_generation_config", {}).get("max_new_tokens")
)
async def render_chat_request( async def render_chat_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> tuple[list[ConversationMessage], list[ProcessorInputs]] | ErrorResponse: ) -> GenerateRequest | ErrorResponse:
"""Validate the model and preprocess a chat completion request. """Validate the model and preprocess a chat completion request.
This is the authoritative implementation used directly by the This is the authoritative implementation used directly by the
...@@ -96,7 +119,56 @@ class OpenAIServingRender: ...@@ -96,7 +119,56 @@ class OpenAIServingRender:
if error_check_ret is not None: if error_check_ret is not None:
logger.error("Error with model %s", error_check_ret) logger.error("Error with model %s", error_check_ret)
return error_check_ret return error_check_ret
return await self.render_chat(request)
if request.use_beam_search:
return self.create_error_response(
"Beam search is not supported by the render endpoint"
)
result = await self.render_chat(request)
if isinstance(result, ErrorResponse):
return result
_, engine_prompts = result
if len(engine_prompts) != 1:
return self.create_error_response(
f"Expected exactly 1 engine prompt, got {len(engine_prompts)}"
)
engine_prompt = engine_prompts[0]
prompt_components = extract_prompt_components(self.model_config, engine_prompt)
token_ids = prompt_components.token_ids
if not token_ids:
return self.create_error_response("No token_ids rendered")
token_ids = list(token_ids)
input_length = extract_prompt_len(self.model_config, engine_prompt)
max_tokens = get_max_tokens(
self.model_config.max_model_len,
request.max_completion_tokens
if request.max_completion_tokens is not None
else request.max_tokens,
input_length,
self.default_sampling_params,
self.override_max_tokens,
)
params = request.to_sampling_params(max_tokens, self.default_sampling_params)
request_id = f"chatcmpl-{random_uuid()}"
return GenerateRequest(
request_id=request_id,
token_ids=token_ids,
features=self._extract_mm_features(engine_prompt),
sampling_params=params,
model=request.model,
stream=bool(request.stream),
stream_options=(request.stream_options if request.stream else None),
cache_salt=request.cache_salt,
priority=request.priority,
)
async def render_chat( async def render_chat(
self, self,
...@@ -183,7 +255,7 @@ class OpenAIServingRender: ...@@ -183,7 +255,7 @@ class OpenAIServingRender:
async def render_completion_request( async def render_completion_request(
self, self,
request: CompletionRequest, request: CompletionRequest,
) -> list[ProcessorInputs] | ErrorResponse: ) -> list[GenerateRequest] | ErrorResponse:
"""Validate the model and preprocess a completion request. """Validate the model and preprocess a completion request.
This is the authoritative implementation used directly by the This is the authoritative implementation used directly by the
...@@ -192,7 +264,48 @@ class OpenAIServingRender: ...@@ -192,7 +264,48 @@ class OpenAIServingRender:
error_check_ret = await self._check_model(request) error_check_ret = await self._check_model(request)
if error_check_ret is not None: if error_check_ret is not None:
return error_check_ret return error_check_ret
return await self.render_completion(request) result = await self.render_completion(request)
if isinstance(result, ErrorResponse):
return result
generate_requests: list[GenerateRequest] = []
for engine_prompt in result:
prompt_components = extract_prompt_components(
self.model_config, engine_prompt
)
token_ids = prompt_components.token_ids
if not token_ids:
return self.create_error_response("No token_ids rendered")
token_ids = list(token_ids)
input_length = extract_prompt_len(self.model_config, engine_prompt)
max_tokens = get_max_tokens(
self.model_config.max_model_len,
request.max_tokens,
input_length,
self.default_sampling_params,
self.override_max_tokens,
)
params = request.to_sampling_params(
max_tokens, self.default_sampling_params
)
request_id = f"cmpl-{random_uuid()}"
generate_requests.append(
GenerateRequest(
request_id=request_id,
token_ids=token_ids,
features=self._extract_mm_features(engine_prompt),
sampling_params=params,
model=request.model,
stream=bool(request.stream),
stream_options=(request.stream_options if request.stream else None),
cache_salt=request.cache_salt,
priority=request.priority,
)
)
return generate_requests
async def render_completion( async def render_completion(
self, self,
...@@ -223,6 +336,33 @@ class OpenAIServingRender: ...@@ -223,6 +336,33 @@ class OpenAIServingRender:
return engine_prompts return engine_prompts
@staticmethod
def _extract_mm_features(
    engine_prompt: ProcessorInputs,
) -> MultiModalFeatures | None:
    """Build the multimodal feature metadata for a rendered engine prompt.

    A prompt whose ``"type"`` entry is anything other than
    ``"multimodal"`` is treated as text-only and yields ``None``.
    Otherwise the prompt's ``"mm_hashes"`` entry is forwarded unchanged
    and every placeholder range under ``"mm_placeholders"`` is reduced to
    a ``PlaceholderRangeInfo`` holding just its ``offset`` and ``length``.
    """
    is_multimodal = engine_prompt.get("type") == "multimodal"
    if not is_multimodal:
        return None

    # Past the guard, engine_prompt is handled as a MultiModalInputs TypedDict.
    hashes: MultiModalHashes = engine_prompt["mm_hashes"]  # type: ignore[typeddict-item]
    placeholders_by_modality: MultiModalPlaceholderDict = engine_prompt["mm_placeholders"]  # type: ignore[typeddict-item]

    def _as_range_infos(ranges):
        # Keep only the positional data (offset/length) of each range.
        return [
            PlaceholderRangeInfo(offset=rng.offset, length=rng.length)
            for rng in ranges
        ]

    range_infos = {
        modality: _as_range_infos(ranges)
        for modality, ranges in placeholders_by_modality.items()
    }
    return MultiModalFeatures(mm_hashes=hashes, mm_placeholders=range_infos)
def _make_request_with_harmony( def _make_request_with_harmony(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
......
...@@ -35,6 +35,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -35,6 +35,7 @@ class OpenAIServingTokenization(OpenAIServing):
request_logger: RequestLogger | None, request_logger: RequestLogger | None,
chat_template: str | None, chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption, chat_template_content_format: ChatTemplateContentFormatOption,
default_chat_template_kwargs: dict[str, Any] | None = None,
trust_request_chat_template: bool = False, trust_request_chat_template: bool = False,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -45,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -45,6 +46,7 @@ class OpenAIServingTokenization(OpenAIServing):
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.default_chat_template_kwargs = default_chat_template_kwargs or {}
self.trust_request_chat_template = trust_request_chat_template self.trust_request_chat_template = trust_request_chat_template
async def create_tokenize( async def create_tokenize(
...@@ -79,7 +81,7 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -79,7 +81,7 @@ class OpenAIServingTokenization(OpenAIServing):
request.messages, request.messages,
default_template=self.chat_template, default_template=self.chat_template,
default_template_content_format=self.chat_template_content_format, default_template_content_format=self.chat_template_content_format,
default_template_kwargs=None, default_template_kwargs=self.default_chat_template_kwargs,
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
) )
else: else:
...@@ -98,8 +100,9 @@ class OpenAIServingTokenization(OpenAIServing): ...@@ -98,8 +100,9 @@ class OpenAIServingTokenization(OpenAIServing):
lora_request=lora_request, lora_request=lora_request,
) )
if "prompt_token_ids" in engine_prompt: prompt_components = self._extract_prompt_components(engine_prompt)
input_ids.extend(engine_prompt["prompt_token_ids"]) # type: ignore[typeddict-item] if prompt_components.token_ids is not None:
input_ids.extend(prompt_components.token_ids)
token_strs = None token_strs = None
if request.return_token_strs: if request.return_token_strs:
......
...@@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence ...@@ -8,7 +8,7 @@ from collections.abc import Callable, Sequence
from functools import partial from functools import partial
from inspect import isclass from inspect import isclass
from types import FunctionType from types import FunctionType
from typing import Any, TypeAlias, get_type_hints from typing import Any, ClassVar, TypeAlias, cast, get_type_hints
import cloudpickle import cloudpickle
import msgspec import msgspec
...@@ -460,6 +460,19 @@ def run_method( ...@@ -460,6 +460,19 @@ def run_method(
class PydanticMsgspecMixin: class PydanticMsgspecMixin:
"""Make a ``msgspec.Struct`` compatible with Pydantic for both
**validation** (JSON/dict -> Struct) and **serialization**
(Struct -> JSON-safe dict).
Subclasses may set ``__pydantic_msgspec_exclude__`` (a ``set[str]``)
to list non-underscore field names that should also be stripped from
serialized output. Fields whose names start with ``_`` are always
excluded automatically.
"""
# Subclasses can override to exclude additional public-but-internal keys.
__pydantic_msgspec_exclude__: ClassVar[set[str]] = set()
@classmethod @classmethod
def __get_pydantic_core_schema__( def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler cls, source_type: Any, handler: GetCoreSchemaHandler
...@@ -476,32 +489,62 @@ class PydanticMsgspecMixin: ...@@ -476,32 +489,62 @@ class PydanticMsgspecMixin:
# Build the Pydantic typed_dict_field for each msgspec field # Build the Pydantic typed_dict_field for each msgspec field
fields = {} fields = {}
for name, hint in type_hints.items(): for name, hint in type_hints.items():
if name not in msgspec_fields:
# Skip ClassVar and other non-struct annotations.
continue
# Skip private fields — they are excluded from serialization
# and should not appear in the generated JSON/OpenAPI schema.
if name.startswith("_"):
continue
msgspec_field = msgspec_fields[name] msgspec_field = msgspec_fields[name]
# typed_dict_field using the handler to get the schema # typed_dict_field using the handler to get the schema
field_schema = handler(hint) field_schema = handler(hint)
# Add default value to the schema. # Add default value to the schema.
# Mark fields with defaults as not required so the generated
# JSON Schema stays consistent with ``omit_defaults=True``
# serialization (fields at their default value may be absent).
if msgspec_field.default_factory is not msgspec.NODEFAULT: if msgspec_field.default_factory is not msgspec.NODEFAULT:
wrapped_schema = core_schema.with_default_schema( wrapped_schema = core_schema.with_default_schema(
schema=field_schema, schema=field_schema,
default_factory=msgspec_field.default_factory, default_factory=msgspec_field.default_factory,
) )
fields[name] = core_schema.typed_dict_field(wrapped_schema) fields[name] = core_schema.typed_dict_field(
wrapped_schema, required=False
)
elif msgspec_field.default is not msgspec.NODEFAULT: elif msgspec_field.default is not msgspec.NODEFAULT:
wrapped_schema = core_schema.with_default_schema( wrapped_schema = core_schema.with_default_schema(
schema=field_schema, schema=field_schema,
default=msgspec_field.default, default=msgspec_field.default,
) )
fields[name] = core_schema.typed_dict_field(wrapped_schema) fields[name] = core_schema.typed_dict_field(
wrapped_schema, required=False
)
else: else:
# No default, so Pydantic will treat it as required # No default, so Pydantic will treat it as required
fields[name] = core_schema.typed_dict_field(field_schema) fields[name] = core_schema.typed_dict_field(field_schema)
return core_schema.no_info_after_validator_function( typed_dict_then_convert = core_schema.no_info_after_validator_function(
cls._validate_msgspec, cls._validate_msgspec,
core_schema.typed_dict_schema(fields), core_schema.typed_dict_schema(fields),
) )
# Build a serializer that strips private / excluded fields.
serializer = core_schema.plain_serializer_function_ser_schema(
cls._serialize_msgspec,
info_arg=False,
)
# Accept either an already-constructed msgspec.Struct instance or a
# JSON/dict-like payload.
return core_schema.union_schema(
[
core_schema.is_instance_schema(source_type),
typed_dict_then_convert,
],
serialization=serializer,
)
@classmethod @classmethod
def _validate_msgspec(cls, value: Any) -> Any: def _validate_msgspec(cls, value: Any) -> Any:
"""Validate and convert input to msgspec.Struct instance.""" """Validate and convert input to msgspec.Struct instance."""
...@@ -510,3 +553,25 @@ class PydanticMsgspecMixin: ...@@ -510,3 +553,25 @@ class PydanticMsgspecMixin:
if isinstance(value, dict): if isinstance(value, dict):
return cls(**value) return cls(**value)
return msgspec.convert(value, type=cls) return msgspec.convert(value, type=cls)
@staticmethod
def _serialize_msgspec(value: Any) -> Any:
    """Dump a msgspec.Struct to JSON-compatible builtins for Pydantic.

    The value is converted with ``msgspec.to_builtins``. When the result
    is a dict, entries are dropped if their key starts with ``_`` or is
    listed in the ``__pydantic_msgspec_exclude__`` attribute of the
    value's type (an empty set when the attribute is absent). Any
    non-dict result is returned unchanged.

    Because the conversion is delegated to ``msgspec.to_builtins``,
    structs declared with ``omit_defaults=True`` only emit fields that
    differ from their declared defaults.
    """
    dumped = msgspec.to_builtins(value)
    if isinstance(dumped, dict):
        hidden: set[str] = cast(
            set[str],
            getattr(type(value), "__pydantic_msgspec_exclude__", set()),
        )
        # Keep insertion order; filter out private and excluded keys.
        return {
            name: item
            for name, item in dumped.items()
            if not (name.startswith("_") or name in hidden)
        }
    return dumped
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment