Unverified Commit 0210024a authored by Bugen Zhao, committed by GitHub
Browse files

[Bugfix] Pass effective chat template kwargs to reasoning parsers (#40460)


Signed-off-by: Bugen Zhao <i@bugenzhao.com>
parent 4eafc729
......@@ -114,12 +114,15 @@ class OpenAIServingChatBatch(OpenAIServingChat):
"""
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
single_requests = [
request.to_chat_completion_request(messages)
for messages in request.messages
]
reasoning_parser: ReasoningParser | None = None
if self.reasoning_parser_cls:
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
chat_template_kwargs = self._effective_chat_template_kwargs(
single_requests[0]
)
reasoning_parser = self.reasoning_parser_cls(
tokenizer,
......@@ -155,7 +158,7 @@ class OpenAIServingChatBatch(OpenAIServingChat):
self.default_sampling_params,
self.override_max_tokens,
)
single_request = request.to_chat_completion_request(request.messages[i])
single_request = single_requests[i]
sampling_params = single_request.to_sampling_params(
max_tokens, self.default_sampling_params
)
......
......@@ -189,6 +189,18 @@ class OpenAIServingChat(OpenAIServing):
)
)
def _effective_chat_template_kwargs(
    self, request: ChatCompletionRequest
) -> dict[str, Any]:
    """Return the chat template kwargs that apply to ``request``.

    The request's chat params are built with this server's chat template
    and content format, ``self.default_chat_template_kwargs`` is applied
    through ``with_defaults``, and the resulting ``chat_template_kwargs``
    attribute is returned.

    Args:
        request: The chat completion request to resolve kwargs for.

    Returns:
        The ``chat_template_kwargs`` of the resolved chat params.
    """
    request_params = request.build_chat_params(
        self.chat_template,
        self.chat_template_content_format,
    )
    # Apply the server-level defaults before reading the kwargs back.
    resolved_params = request_params.with_defaults(
        self.default_chat_template_kwargs
    )
    return resolved_params.chat_template_kwargs
async def render_chat_request(
self,
request: ChatCompletionRequest,
......@@ -231,10 +243,7 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
chat_template_kwargs = self._effective_chat_template_kwargs(request)
reasoning_parser: ReasoningParser | None = None
if self.reasoning_parser_cls:
reasoning_parser = self.reasoning_parser_cls(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections.abc import Callable
from typing import Any
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response_function_tool_call_output_item import (
......@@ -15,6 +15,7 @@ from openai.types.responses.response_reasoning_item import (
ResponseReasoningItem,
)
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.responses.protocol import (
ResponseInputOutputItem,
......@@ -36,10 +37,12 @@ class ResponsesParser:
self,
*,
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
reasoning_parser_cls: type[ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls: type[ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed
......@@ -49,7 +52,14 @@ class ResponsesParser:
self.tokenizer = tokenizer
self.request = request
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.reasoning_parser_instance = reasoning_parser_cls(
tokenizer,
chat_template_kwargs=_effective_chat_template_kwargs(
request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
),
)
self.tool_parser_instance = None
if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer, request.tools)
......@@ -159,10 +169,12 @@ class ResponsesParser:
def get_responses_parser_for_simple_context(
*,
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser],
reasoning_parser_cls: type[ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
) -> ResponsesParser:
"""Factory function to create a ResponsesParser with
optional reasoning parser.
......@@ -176,4 +188,17 @@ def get_responses_parser_for_simple_context(
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
)
def _effective_chat_template_kwargs(
    request: ResponsesRequest,
    chat_template: str | None,
    chat_template_content_format: ChatTemplateContentFormatOption,
) -> dict[str, Any]:
    """Return the chat template kwargs that apply to ``request``.

    Args:
        request: The responses request whose chat params are built.
        chat_template: Passed to ``build_chat_params`` as
            ``default_template``.
        chat_template_content_format: Passed to ``build_chat_params`` as
            ``default_template_content_format``.

    Returns:
        The ``chat_template_kwargs`` attribute of the built chat params.
    """
    chat_params = request.build_chat_params(
        default_template=chat_template,
        default_template_content_format=chat_template_content_format,
    )
    return chat_params.chat_template_kwargs
......@@ -6,7 +6,6 @@ import copy
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import replace
from typing import TYPE_CHECKING, Any, Final, Union
......@@ -273,7 +272,7 @@ class ParsableContext(ConversationContext):
*,
response_messages: list[ResponseInputOutputItem],
tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None,
reasoning_parser_cls: type[ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
tool_parser_cls: type[ToolParser] | None,
......@@ -296,6 +295,8 @@ class ParsableContext(ConversationContext):
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
)
self.tool_parser_cls = tool_parser_cls
self.request = request
......
......@@ -267,6 +267,14 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server
def _effective_chat_template_kwargs(
    self, request: ResponsesRequest
) -> dict[str, Any]:
    """Return the chat template kwargs that apply to ``request``.

    The request's chat params are built with this server's chat template
    and content format, and their ``chat_template_kwargs`` attribute is
    returned.

    Args:
        request: The responses request to resolve kwargs for.

    Returns:
        The ``chat_template_kwargs`` of the built chat params.
    """
    chat_params = request.build_chat_params(
        self.chat_template,
        self.chat_template_content_format,
    )
    return chat_params.chat_template_kwargs
def _validate_generator_input(
self,
engine_input: EngineInput,
......@@ -464,7 +472,10 @@ class OpenAIServingResponses(OpenAIServing):
context = SimpleContext()
if self.parser and self.parser.reasoning_parser_cls is not None:
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
reasoning_parser = self.parser.reasoning_parser_cls(
tokenizer,
chat_template_kwargs=self._effective_chat_template_kwargs(request),
)
if (
isinstance(
struct_out := sampling_params.structured_outputs,
......@@ -835,7 +846,10 @@ class OpenAIServingResponses(OpenAIServing):
and self.parser.reasoning_parser_cls is not None
and isinstance(context, (SimpleContext, ParsableContext))
):
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
reasoning_parser = self.parser.reasoning_parser_cls(
tokenizer,
chat_template_kwargs=self._effective_chat_template_kwargs(request),
)
accumulated = getattr(context, "_accumulated_token_ids", []) or []
num_reasoning_tokens = reasoning_parser.count_reasoning_tokens(accumulated)
......
......@@ -566,7 +566,9 @@ class OpenAIServingRender:
if reasoning_parser is not None:
tokenizer = renderer.get_tokenizer()
request = reasoning_parser(
tokenizer, model_config=self.model_config
tokenizer,
model_config=self.model_config,
chat_template_kwargs=chat_params.chat_template_kwargs,
).adjust_request(request=request)
# tool parsing is done only if a tool_parser has been set and if
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment