Unverified Commit dc837bc2 authored by Hojin Yang's avatar Hojin Yang Committed by GitHub
Browse files

feat(frontend): add --default-chat-template-kwargs CLI argument (#31343)


Signed-off-by: default avatareffortprogrammer <yhjhoward7@gmail.com>
parent e54ee3ea
...@@ -204,6 +204,42 @@ The reasoning content is also available when both tool calling and the reasoning ...@@ -204,6 +204,42 @@ The reasoning content is also available when both tool calling and the reasoning
For more examples, please refer to [examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py](../../examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py). For more examples, please refer to [examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py](../../examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py).
## Server-Level Default Chat Template Kwargs
You can set default `chat_template_kwargs` at the server level using the `--default-chat-template-kwargs` CLI argument. This is useful for configuring reasoning behavior across all requests without requiring clients to specify it in each request.
### Disabling Thinking Mode by Default
For models like Qwen3 where thinking is enabled by default, you can disable it server-wide:
```bash
vllm serve Qwen/Qwen3-8B \
--reasoning-parser qwen3 \
--default-chat-template-kwargs '{"enable_thinking": false}'
```
### Enabling Thinking Mode by Default
For models like IBM Granite 3.2 or DeepSeek-V3.1 where thinking is disabled by default, you can enable it server-wide:
```bash
vllm serve ibm-granite/granite-3.2-2b-instruct \
--reasoning-parser granite \
--default-chat-template-kwargs '{"thinking": true}'
```
### Request-Level Override
Request-level `chat_template_kwargs` always take priority over server defaults. For example, if the server is started with `enable_thinking=false`, a client can still enable it for a specific request:
```python
response = client.chat.completions.create(
model=model,
messages=messages,
extra_body={"chat_template_kwargs": {"enable_thinking": True}} # Overrides server default
)
```
## Limitations ## Limitations
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). - The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
......
...@@ -208,3 +208,36 @@ def test_middleware(serve_parser, cli_args, expected_middleware): ...@@ -208,3 +208,36 @@ def test_middleware(serve_parser, cli_args, expected_middleware):
"""Ensure multiple middleware args are parsed properly""" """Ensure multiple middleware args are parsed properly"""
args = serve_parser.parse_args(args=cli_args) args = serve_parser.parse_args(args=cli_args)
assert args.middleware == expected_middleware assert args.middleware == expected_middleware
def test_default_chat_template_kwargs_parsing(serve_parser):
"""Ensure default_chat_template_kwargs JSON is parsed correctly"""
args = serve_parser.parse_args(
args=["--default-chat-template-kwargs", '{"enable_thinking": false}']
)
assert args.default_chat_template_kwargs == {"enable_thinking": False}
def test_default_chat_template_kwargs_complex(serve_parser):
"""Ensure complex default_chat_template_kwargs JSON is parsed correctly"""
kwargs_json = '{"enable_thinking": false, "custom_param": "value", "num": 42}'
args = serve_parser.parse_args(args=["--default-chat-template-kwargs", kwargs_json])
assert args.default_chat_template_kwargs == {
"enable_thinking": False,
"custom_param": "value",
"num": 42,
}
def test_default_chat_template_kwargs_default_none(serve_parser):
"""Ensure default_chat_template_kwargs defaults to None"""
args = serve_parser.parse_args(args=[])
assert args.default_chat_template_kwargs is None
def test_default_chat_template_kwargs_invalid_json(serve_parser):
"""Ensure invalid JSON raises an error"""
with pytest.raises(SystemExit):
serve_parser.parse_args(
args=["--default-chat-template-kwargs", "not valid json"]
)
...@@ -1081,6 +1081,7 @@ async def init_app_state( ...@@ -1081,6 +1081,7 @@ async def init_app_state(
request_logger=request_logger, request_logger=request_logger,
chat_template=resolved_chat_template, chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format, chat_template_content_format=args.chat_template_content_format,
default_chat_template_kwargs=args.default_chat_template_kwargs,
trust_request_chat_template=args.trust_request_chat_template, trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids, return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice, enable_auto_tools=args.enable_auto_tool_choice,
......
...@@ -11,7 +11,7 @@ import json ...@@ -11,7 +11,7 @@ import json
import ssl import ssl
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import field from dataclasses import field
from typing import Literal from typing import Any, Literal
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
...@@ -114,6 +114,12 @@ class FrontendArgs: ...@@ -114,6 +114,12 @@ class FrontendArgs:
"""Whether to trust the chat template provided in the request. If False, """Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template` the server will always use the chat template specified by `--chat-template`
or the ones from tokenizer.""" or the ones from tokenizer."""
default_chat_template_kwargs: dict[str, Any] | None = None
"""Default keyword arguments to pass to the chat template renderer.
These will be merged with request-level chat_template_kwargs,
with request values taking precedence. Useful for setting default
behavior for reasoning models. Example: '{"enable_thinking": false}'
to disable thinking mode by default for Qwen3/DeepSeek models."""
response_role: str = "assistant" response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`.""" """The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: str | None = None ssl_keyfile: str | None = None
...@@ -216,6 +222,9 @@ class FrontendArgs: ...@@ -216,6 +222,9 @@ class FrontendArgs:
del frontend_kwargs["allowed_methods"]["nargs"] del frontend_kwargs["allowed_methods"]["nargs"]
del frontend_kwargs["allowed_headers"]["nargs"] del frontend_kwargs["allowed_headers"]["nargs"]
# Special case: default_chat_template_kwargs needs json.loads type
frontend_kwargs["default_chat_template_kwargs"]["type"] = json.loads
# Special case: LoRA modules need custom parser action and # Special case: LoRA modules need custom parser action and
# optional_type(str) # optional_type(str)
frontend_kwargs["lora_modules"]["type"] = optional_type(str) frontend_kwargs["lora_modules"]["type"] = optional_type(str)
......
...@@ -468,6 +468,9 @@ async def run_batch( ...@@ -468,6 +468,9 @@ async def run_batch(
reasoning_parser=args.structured_outputs_config.reasoning_parser, reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage, enable_force_include_usage=args.enable_force_include_usage,
default_chat_template_kwargs=getattr(
args, "default_chat_template_kwargs", None
),
) )
if "generate" in supported_tasks if "generate" in supported_tasks
else None else None
......
...@@ -6,7 +6,7 @@ import json ...@@ -6,7 +6,7 @@ import json
import time import time
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
from typing import Final from typing import Any, Final
import jinja2 import jinja2
import partial_json_parser import partial_json_parser
...@@ -102,6 +102,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -102,6 +102,7 @@ class OpenAIServingChat(OpenAIServing):
enable_force_include_usage: bool = False, enable_force_include_usage: bool = False,
enable_log_outputs: bool = False, enable_log_outputs: bool = False,
log_error_stack: bool = False, log_error_stack: bool = False,
default_chat_template_kwargs: dict[str, Any] | None = None,
) -> None: ) -> None:
super().__init__( super().__init__(
engine_client=engine_client, engine_client=engine_client,
...@@ -115,6 +116,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -115,6 +116,7 @@ class OpenAIServingChat(OpenAIServing):
self.chat_template = chat_template self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template self.trust_request_chat_template = trust_request_chat_template
self.default_chat_template_kwargs = default_chat_template_kwargs or {}
self.enable_log_outputs = enable_log_outputs self.enable_log_outputs = enable_log_outputs
# set up logits processors # set up logits processors
...@@ -203,6 +205,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -203,6 +205,7 @@ class OpenAIServingChat(OpenAIServing):
tool_dicts=None, tool_dicts=None,
documents=None, documents=None,
chat_template_kwargs=None, chat_template_kwargs=None,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=None, tool_parser=None,
add_special_tokens=False, add_special_tokens=False,
) )
...@@ -310,6 +313,7 @@ class OpenAIServingChat(OpenAIServing): ...@@ -310,6 +313,7 @@ class OpenAIServingChat(OpenAIServing):
tool_dicts=tool_dicts, tool_dicts=tool_dicts,
documents=request.documents, documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs, chat_template_kwargs=request.chat_template_kwargs,
default_chat_template_kwargs=self.default_chat_template_kwargs,
tool_parser=tool_parser, tool_parser=tool_parser,
add_special_tokens=request.add_special_tokens, add_special_tokens=request.add_special_tokens,
) )
......
...@@ -1158,6 +1158,7 @@ class OpenAIServing: ...@@ -1158,6 +1158,7 @@ class OpenAIServing:
tool_dicts: list[dict[str, Any]] | None = None, tool_dicts: list[dict[str, Any]] | None = None,
documents: list[dict[str, str]] | None = None, documents: list[dict[str, str]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None, chat_template_kwargs: dict[str, Any] | None = None,
default_chat_template_kwargs: dict[str, Any] | None = None,
tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
add_special_tokens: bool = False, add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], list[TokensPrompt]]: ) -> tuple[list[ConversationMessage], list[TokensPrompt]]:
...@@ -1183,6 +1184,8 @@ class OpenAIServing: ...@@ -1183,6 +1184,8 @@ class OpenAIServing:
tools=tool_dicts, tools=tool_dicts,
documents=documents, documents=documents,
) )
if default_chat_template_kwargs:
_chat_template_kwargs.update(default_chat_template_kwargs)
_chat_template_kwargs.update(chat_template_kwargs or {}) _chat_template_kwargs.update(chat_template_kwargs or {})
request_prompt: str | list[int] request_prompt: str | list[int]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment