Unverified commit 0210024a, authored by Bugen Zhao and committed by GitHub
Browse files

[Bugfix] Pass effective chat template kwargs to reasoning parsers (#40460)


Signed-off-by: Bugen Zhao <i@bugenzhao.com>
parent 4eafc729
...@@ -114,12 +114,15 @@ class OpenAIServingChatBatch(OpenAIServingChat): ...@@ -114,12 +114,15 @@ class OpenAIServingChatBatch(OpenAIServingChat):
""" """
tokenizer = self.renderer.tokenizer tokenizer = self.renderer.tokenizer
assert tokenizer is not None assert tokenizer is not None
single_requests = [
request.to_chat_completion_request(messages)
for messages in request.messages
]
reasoning_parser: ReasoningParser | None = None reasoning_parser: ReasoningParser | None = None
if self.reasoning_parser_cls: if self.reasoning_parser_cls:
chat_template_kwargs = self._prepare_extra_chat_template_kwargs( chat_template_kwargs = self._effective_chat_template_kwargs(
request.chat_template_kwargs, single_requests[0]
self.default_chat_template_kwargs,
) )
reasoning_parser = self.reasoning_parser_cls( reasoning_parser = self.reasoning_parser_cls(
tokenizer, tokenizer,
...@@ -155,7 +158,7 @@ class OpenAIServingChatBatch(OpenAIServingChat): ...@@ -155,7 +158,7 @@ class OpenAIServingChatBatch(OpenAIServingChat):
self.default_sampling_params, self.default_sampling_params,
self.override_max_tokens, self.override_max_tokens,
) )
single_request = request.to_chat_completion_request(request.messages[i]) single_request = single_requests[i]
sampling_params = single_request.to_sampling_params( sampling_params = single_request.to_sampling_params(
max_tokens, self.default_sampling_params max_tokens, self.default_sampling_params
) )
......
...@@ -189,6 +189,18 @@ class OpenAIServingChat(OpenAIServing): ...@@ -189,6 +189,18 @@ class OpenAIServingChat(OpenAIServing):
) )
) )
def _effective_chat_template_kwargs(
    self, request: ChatCompletionRequest
) -> dict[str, Any]:
    """Return the chat template kwargs in effect for ``request``.

    The request's chat params are built with this server's chat template
    and content format, then ``self.default_chat_template_kwargs`` is
    applied through ``with_defaults``. The ``chat_template_kwargs`` of the
    resulting params object is returned, so callers (e.g. reasoning parser
    construction) see the same kwargs that the built chat params carry.

    Args:
        request: The chat completion request to derive the kwargs from.

    Returns:
        The ``chat_template_kwargs`` attribute of the chat params after
        the server defaults have been applied.
    """
    chat_params = request.build_chat_params(
        self.chat_template,
        self.chat_template_content_format,
    )
    # NOTE(review): how with_defaults() resolves clashes between request
    # values and server defaults is defined by the params object, not here.
    return chat_params.with_defaults(
        self.default_chat_template_kwargs
    ).chat_template_kwargs
async def render_chat_request( async def render_chat_request(
self, self,
request: ChatCompletionRequest, request: ChatCompletionRequest,
...@@ -231,10 +243,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -231,10 +243,7 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response # Streaming response
tokenizer = self.renderer.tokenizer tokenizer = self.renderer.tokenizer
assert tokenizer is not None assert tokenizer is not None
chat_template_kwargs = self._prepare_extra_chat_template_kwargs( chat_template_kwargs = self._effective_chat_template_kwargs(request)
request.chat_template_kwargs,
self.default_chat_template_kwargs,
)
reasoning_parser: ReasoningParser | None = None reasoning_parser: ReasoningParser | None = None
if self.reasoning_parser_cls: if self.reasoning_parser_cls:
reasoning_parser = self.reasoning_parser_cls( reasoning_parser = self.reasoning_parser_cls(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging import logging
from collections.abc import Callable from typing import Any
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response_function_tool_call_output_item import ( from openai.types.responses.response_function_tool_call_output_item import (
...@@ -15,6 +15,7 @@ from openai.types.responses.response_reasoning_item import ( ...@@ -15,6 +15,7 @@ from openai.types.responses.response_reasoning_item import (
ResponseReasoningItem, ResponseReasoningItem,
) )
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.constants import MCP_PREFIX from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.responses.protocol import ( from vllm.entrypoints.openai.responses.protocol import (
ResponseInputOutputItem, ResponseInputOutputItem,
...@@ -36,10 +37,12 @@ class ResponsesParser: ...@@ -36,10 +37,12 @@ class ResponsesParser:
self, self,
*, *,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser], reasoning_parser_cls: type[ReasoningParser],
response_messages: list[ResponseInputOutputItem], response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest, request: ResponsesRequest,
tool_parser_cls: type[ToolParser] | None, tool_parser_cls: type[ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
): ):
self.response_messages: list[ResponseInputOutputItem] = ( self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed # TODO: initial messages may not be properly typed
...@@ -49,7 +52,14 @@ class ResponsesParser: ...@@ -49,7 +52,14 @@ class ResponsesParser:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.request = request self.request = request
self.reasoning_parser_instance = reasoning_parser_cls(tokenizer) self.reasoning_parser_instance = reasoning_parser_cls(
tokenizer,
chat_template_kwargs=_effective_chat_template_kwargs(
request,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
),
)
self.tool_parser_instance = None self.tool_parser_instance = None
if tool_parser_cls is not None: if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer, request.tools) self.tool_parser_instance = tool_parser_cls(tokenizer, request.tools)
...@@ -159,10 +169,12 @@ class ResponsesParser: ...@@ -159,10 +169,12 @@ class ResponsesParser:
def get_responses_parser_for_simple_context( def get_responses_parser_for_simple_context(
*, *,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser], reasoning_parser_cls: type[ReasoningParser],
response_messages: list[ResponseInputOutputItem], response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest, request: ResponsesRequest,
tool_parser_cls, tool_parser_cls,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
) -> ResponsesParser: ) -> ResponsesParser:
"""Factory function to create a ResponsesParser with """Factory function to create a ResponsesParser with
optional reasoning parser. optional reasoning parser.
...@@ -176,4 +188,17 @@ def get_responses_parser_for_simple_context( ...@@ -176,4 +188,17 @@ def get_responses_parser_for_simple_context(
response_messages=response_messages, response_messages=response_messages,
request=request, request=request,
tool_parser_cls=tool_parser_cls, tool_parser_cls=tool_parser_cls,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
) )
def _effective_chat_template_kwargs(
    request: ResponsesRequest,
    chat_template: str | None,
    chat_template_content_format: ChatTemplateContentFormatOption,
) -> dict[str, Any]:
    """Return the chat template kwargs in effect for a responses request.

    Builds the request's chat params, passing ``chat_template`` as the
    default template and ``chat_template_content_format`` as the default
    content format, and returns the ``chat_template_kwargs`` attribute of
    the result.

    Args:
        request: The responses request to derive the kwargs from.
        chat_template: Default chat template forwarded to
            ``request.build_chat_params`` (may be ``None``).
        chat_template_content_format: Default content format forwarded to
            ``request.build_chat_params``.

    Returns:
        The ``chat_template_kwargs`` of the built chat params.
    """
    chat_params = request.build_chat_params(
        default_template=chat_template,
        default_template_content_format=chat_template_content_format,
    )
    return chat_params.chat_template_kwargs
...@@ -6,7 +6,6 @@ import copy ...@@ -6,7 +6,6 @@ import copy
import json import json
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import AsyncExitStack from contextlib import AsyncExitStack
from dataclasses import replace from dataclasses import replace
from typing import TYPE_CHECKING, Any, Final, Union from typing import TYPE_CHECKING, Any, Final, Union
...@@ -273,7 +272,7 @@ class ParsableContext(ConversationContext): ...@@ -273,7 +272,7 @@ class ParsableContext(ConversationContext):
*, *,
response_messages: list[ResponseInputOutputItem], response_messages: list[ResponseInputOutputItem],
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
reasoning_parser_cls: Callable[[TokenizerLike], ReasoningParser] | None, reasoning_parser_cls: type[ReasoningParser] | None,
request: ResponsesRequest, request: ResponsesRequest,
available_tools: list[str] | None, available_tools: list[str] | None,
tool_parser_cls: type[ToolParser] | None, tool_parser_cls: type[ToolParser] | None,
...@@ -296,6 +295,8 @@ class ParsableContext(ConversationContext): ...@@ -296,6 +295,8 @@ class ParsableContext(ConversationContext):
response_messages=response_messages, response_messages=response_messages,
request=request, request=request,
tool_parser_cls=tool_parser_cls, tool_parser_cls=tool_parser_cls,
chat_template=chat_template,
chat_template_content_format=chat_template_content_format,
) )
self.tool_parser_cls = tool_parser_cls self.tool_parser_cls = tool_parser_cls
self.request = request self.request = request
......
...@@ -267,6 +267,14 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -267,6 +267,14 @@ class OpenAIServingResponses(OpenAIServing):
self.tool_server = tool_server self.tool_server = tool_server
def _effective_chat_template_kwargs(
    self, request: ResponsesRequest
) -> dict[str, Any]:
    """Return the chat template kwargs in effect for ``request``.

    The request's chat params are built with this server's chat template
    and content format, and their ``chat_template_kwargs`` attribute is
    returned unchanged.

    Args:
        request: The responses request to derive the kwargs from.

    Returns:
        The ``chat_template_kwargs`` of the built chat params.
    """
    chat_params = request.build_chat_params(
        self.chat_template,
        self.chat_template_content_format,
    )
    return chat_params.chat_template_kwargs
def _validate_generator_input( def _validate_generator_input(
self, self,
engine_input: EngineInput, engine_input: EngineInput,
...@@ -464,7 +472,10 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -464,7 +472,10 @@ class OpenAIServingResponses(OpenAIServing):
context = SimpleContext() context = SimpleContext()
if self.parser and self.parser.reasoning_parser_cls is not None: if self.parser and self.parser.reasoning_parser_cls is not None:
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer) reasoning_parser = self.parser.reasoning_parser_cls(
tokenizer,
chat_template_kwargs=self._effective_chat_template_kwargs(request),
)
if ( if (
isinstance( isinstance(
struct_out := sampling_params.structured_outputs, struct_out := sampling_params.structured_outputs,
...@@ -835,7 +846,10 @@ class OpenAIServingResponses(OpenAIServing): ...@@ -835,7 +846,10 @@ class OpenAIServingResponses(OpenAIServing):
and self.parser.reasoning_parser_cls is not None and self.parser.reasoning_parser_cls is not None
and isinstance(context, (SimpleContext, ParsableContext)) and isinstance(context, (SimpleContext, ParsableContext))
): ):
reasoning_parser = self.parser.reasoning_parser_cls(tokenizer) reasoning_parser = self.parser.reasoning_parser_cls(
tokenizer,
chat_template_kwargs=self._effective_chat_template_kwargs(request),
)
accumulated = getattr(context, "_accumulated_token_ids", []) or [] accumulated = getattr(context, "_accumulated_token_ids", []) or []
num_reasoning_tokens = reasoning_parser.count_reasoning_tokens(accumulated) num_reasoning_tokens = reasoning_parser.count_reasoning_tokens(accumulated)
......
...@@ -566,7 +566,9 @@ class OpenAIServingRender: ...@@ -566,7 +566,9 @@ class OpenAIServingRender:
if reasoning_parser is not None: if reasoning_parser is not None:
tokenizer = renderer.get_tokenizer() tokenizer = renderer.get_tokenizer()
request = reasoning_parser( request = reasoning_parser(
tokenizer, model_config=self.model_config tokenizer,
model_config=self.model_config,
chat_template_kwargs=chat_params.chat_template_kwargs,
).adjust_request(request=request) ).adjust_request(request=request)
# tool parsing is done only if a tool_parser has been set and if # tool parsing is done only if a tool_parser has been set and if
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment