Unverified commit 14a6122e, authored by Yifan Jiang, committed by GitHub
Browse files

fix: strip tools from chat template when tool_choice=none (#7391)


Signed-off-by: Yifan Jiang <19356972+yifjiang@users.noreply.github.com>
Co-authored-by: Ayush Agarwal <ayushag@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 70e1ec1d
...@@ -28,6 +28,7 @@ class DynamoRuntimeConfig(ConfigBase): ...@@ -28,6 +28,7 @@ class DynamoRuntimeConfig(ConfigBase):
dyn_tool_call_parser: Optional[str] = None dyn_tool_call_parser: Optional[str] = None
dyn_reasoning_parser: Optional[str] = None dyn_reasoning_parser: Optional[str] = None
exclude_tools_when_tool_choice_none: bool = True
custom_jinja_template: Optional[str] = None custom_jinja_template: Optional[str] = None
endpoint_types: str endpoint_types: str
dump_config_to: Optional[str] = None dump_config_to: Optional[str] = None
...@@ -141,6 +142,19 @@ class DynamoRuntimeArgGroup(ArgGroup): ...@@ -141,6 +142,19 @@ class DynamoRuntimeArgGroup(ArgGroup):
help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.", help="Reasoning parser name for the model. If not specified, no reasoning parsing is performed.",
choices=get_reasoning_parser_names(), choices=get_reasoning_parser_names(),
) )
# NOTE: This flag also exists in FrontendArgGroup (frontend_args.py).
# Both definitions are needed: this one controls the Rust-native chat
# template path (oai.rs), while the frontend copy controls the Python
# processors (vllm_processor / sglang_processor) which parse arguments
# independently via FrontendConfig.
add_negatable_bool_argument(
g,
flag_name="--exclude-tools-when-tool-choice-none",
env_var="DYN_EXCLUDE_TOOLS_WHEN_TOOL_CHOICE_NONE",
default=True,
help="Exclude tool definitions from the chat template when tool_choice='none'. "
"Prevents models from generating raw XML tool calls in the content field.",
)
add_argument( add_argument(
g, g,
flag_name="--custom-jinja-template", flag_name="--custom-jinja-template",
......
...@@ -75,6 +75,7 @@ class FrontendConfig(KvRouterConfigBase): ...@@ -75,6 +75,7 @@ class FrontendConfig(KvRouterConfigBase):
debug_perf: bool debug_perf: bool
enable_streaming_tool_dispatch: bool enable_streaming_tool_dispatch: bool
enable_streaming_reasoning_dispatch: bool enable_streaming_reasoning_dispatch: bool
exclude_tools_when_tool_choice_none: bool
preprocess_workers: int preprocess_workers: int
tokenizer_backend: str tokenizer_backend: str
...@@ -392,6 +393,22 @@ class FrontendArgGroup(ArgGroup): ...@@ -392,6 +393,22 @@ class FrontendArgGroup(ArgGroup):
"Can be combined with --enable-streaming-tool-dispatch." "Can be combined with --enable-streaming-tool-dispatch."
), ),
) )
# NOTE: This flag also exists in DynamoRuntimeArgGroup (runtime_args.py).
# Both definitions are needed: runtime_args controls the Rust-native
# chat template path (oai.rs), while this one controls the Python
# frontend processors (vllm_processor / sglang_processor) which parse
# arguments independently via FrontendConfig.
add_negatable_bool_argument(
g,
flag_name="--exclude-tools-when-tool-choice-none",
env_var="DYN_EXCLUDE_TOOLS_WHEN_TOOL_CHOICE_NONE",
default=True,
help=(
"Exclude tool definitions from the chat template when "
"tool_choice='none'. Prevents models from generating raw XML "
"tool calls in the content field."
),
)
add_argument( add_argument(
g, g,
flag_name="--dyn-chat-processor", flag_name="--dyn-chat-processor",
......
...@@ -79,6 +79,7 @@ def _prepare_request( ...@@ -79,6 +79,7 @@ def _prepare_request(
*, *,
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
exclude_tools_when_tool_choice_none: bool = True,
) -> tuple[ChatCompletionRequest, ToolParser | None, dict[str, Any], Any, ChatParams]: ) -> tuple[ChatCompletionRequest, ToolParser | None, dict[str, Any], Any, ChatParams]:
"""Validate request and build arguments for template rendering. """Validate request and build arguments for template rendering.
...@@ -107,9 +108,15 @@ def _prepare_request( ...@@ -107,9 +108,15 @@ def _prepare_request(
tool_parser = tool_parser_class(tokenizer) tool_parser = tool_parser_class(tokenizer)
request_for_sampling = tool_parser.adjust_request(request_for_sampling) request_for_sampling = tool_parser.adjust_request(request_for_sampling)
# Strip tools from the template when tool_choice=none so the model doesn't
# see them and generate raw XML tool calls in its response.
tool_dicts = ( tool_dicts = (
[tool.model_dump() for tool in request_for_sampling.tools] [tool.model_dump() for tool in request_for_sampling.tools]
if request_for_sampling.tools if request_for_sampling.tools
and not (
exclude_tools_when_tool_choice_none
and request_for_sampling.tool_choice == "none"
)
else None else None
) )
chat_template_kwargs = dict(request_for_sampling.chat_template_kwargs or {}) chat_template_kwargs = dict(request_for_sampling.chat_template_kwargs or {})
...@@ -155,6 +162,7 @@ async def preprocess_chat_request( ...@@ -155,6 +162,7 @@ async def preprocess_chat_request(
tokenizer: TokenizerLike, tokenizer: TokenizerLike,
renderer, renderer,
tool_parser_class: type[ToolParser] | None, tool_parser_class: type[ToolParser] | None,
exclude_tools_when_tool_choice_none: bool = True,
) -> PreprocessResult: ) -> PreprocessResult:
( (
request_for_sampling, request_for_sampling,
...@@ -163,7 +171,10 @@ async def preprocess_chat_request( ...@@ -163,7 +171,10 @@ async def preprocess_chat_request(
messages, messages,
chat_params, chat_params,
) = _prepare_request( ) = _prepare_request(
request, tokenizer=tokenizer, tool_parser_class=tool_parser_class request,
tokenizer=tokenizer,
tool_parser_class=tool_parser_class,
exclude_tools_when_tool_choice_none=exclude_tools_when_tool_choice_none,
) )
_, engine_prompt = await renderer.render_messages_async(messages, chat_params) _, engine_prompt = await renderer.render_messages_async(messages, chat_params)
......
...@@ -100,6 +100,7 @@ def preprocess_chat_request( ...@@ -100,6 +100,7 @@ def preprocess_chat_request(
tokenizer, tokenizer,
tool_call_parser_name: str | None, tool_call_parser_name: str | None,
reasoning_parser_name: str | None, reasoning_parser_name: str | None,
exclude_tools_when_tool_choice_none: bool = True,
) -> SglangPreprocessResult: ) -> SglangPreprocessResult:
"""Preprocess a chat request using SGLang tokenizer and parser APIs. """Preprocess a chat request using SGLang tokenizer and parser APIs.
...@@ -115,7 +116,12 @@ def preprocess_chat_request( ...@@ -115,7 +116,12 @@ def preprocess_chat_request(
"add_generation_prompt": True, "add_generation_prompt": True,
"tokenize": True, "tokenize": True,
} }
if sglang_tools: # Strip tools from template when tool_choice=none so the model doesn't
# see them and generate raw XML tool calls in its response.
tool_choice = request.get("tool_choice", "auto")
if sglang_tools and not (
exclude_tools_when_tool_choice_none and tool_choice == "none"
):
template_kwargs["tools"] = [t.model_dump() for t in sglang_tools] template_kwargs["tools"] = [t.model_dump() for t in sglang_tools]
prompt_token_ids = tokenizer.apply_chat_template(messages, **template_kwargs) prompt_token_ids = tokenizer.apply_chat_template(messages, **template_kwargs)
......
...@@ -89,6 +89,7 @@ def _map_finish_reason(raw: str | None) -> str | None: ...@@ -89,6 +89,7 @@ def _map_finish_reason(raw: str | None) -> str | None:
_w_tokenizer: Any = None _w_tokenizer: Any = None
_w_tool_call_parser_name: str | None = None _w_tool_call_parser_name: str | None = None
_w_reasoning_parser_name: str | None = None _w_reasoning_parser_name: str | None = None
_w_exclude_tools_when_tool_choice_none: bool = True
@dataclass @dataclass
...@@ -104,12 +105,15 @@ def _init_worker( ...@@ -104,12 +105,15 @@ def _init_worker(
model_path: str, model_path: str,
tool_call_parser_name: str | None, tool_call_parser_name: str | None,
reasoning_parser_name: str | None, reasoning_parser_name: str | None,
exclude_tools_when_tool_choice_none: bool = True,
) -> None: ) -> None:
"""Initialize a worker process with its own tokenizer.""" """Initialize a worker process with its own tokenizer."""
global _w_tokenizer, _w_tool_call_parser_name, _w_reasoning_parser_name global _w_tokenizer, _w_tool_call_parser_name, _w_reasoning_parser_name
global _w_exclude_tools_when_tool_choice_none
_w_tokenizer = get_tokenizer(model_path) _w_tokenizer = get_tokenizer(model_path)
_w_tool_call_parser_name = tool_call_parser_name _w_tool_call_parser_name = tool_call_parser_name
_w_reasoning_parser_name = reasoning_parser_name _w_reasoning_parser_name = reasoning_parser_name
_w_exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
def _preprocess_worker( def _preprocess_worker(
...@@ -123,6 +127,7 @@ def _preprocess_worker( ...@@ -123,6 +127,7 @@ def _preprocess_worker(
tokenizer=_w_tokenizer, tokenizer=_w_tokenizer,
tool_call_parser_name=_w_tool_call_parser_name, tool_call_parser_name=_w_tool_call_parser_name,
reasoning_parser_name=_w_reasoning_parser_name, reasoning_parser_name=_w_reasoning_parser_name,
exclude_tools_when_tool_choice_none=_w_exclude_tools_when_tool_choice_none,
) )
n = request.get("n", 1) n = request.get("n", 1)
...@@ -218,6 +223,7 @@ class SglangProcessor: ...@@ -218,6 +223,7 @@ class SglangProcessor:
self.is_kv_router = isinstance(router, KvRouter) self.is_kv_router = isinstance(router, KvRouter)
self.tool_call_parser_name = tool_call_parser_name self.tool_call_parser_name = tool_call_parser_name
self.reasoning_parser_name = reasoning_parser_name self.reasoning_parser_name = reasoning_parser_name
self.exclude_tools_when_tool_choice_none = True
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.debug_perf = debug_perf self.debug_perf = debug_perf
self.stream_interval = stream_interval self.stream_interval = stream_interval
...@@ -275,6 +281,7 @@ class SglangProcessor: ...@@ -275,6 +281,7 @@ class SglangProcessor:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
tool_call_parser_name=self.tool_call_parser_name, tool_call_parser_name=self.tool_call_parser_name,
reasoning_parser_name=self.reasoning_parser_name, reasoning_parser_name=self.reasoning_parser_name,
exclude_tools_when_tool_choice_none=self.exclude_tools_when_tool_choice_none,
) )
if self.debug_perf: if self.debug_perf:
...@@ -576,6 +583,7 @@ class SglangEngineFactory: ...@@ -576,6 +583,7 @@ class SglangEngineFactory:
source_path, source_path,
tool_call_parser_name, tool_call_parser_name,
reasoning_parser_name, reasoning_parser_name,
self.config.exclude_tools_when_tool_choice_none,
), ),
) )
futures = [ futures = [
...@@ -612,5 +620,8 @@ class SglangEngineFactory: ...@@ -612,5 +620,8 @@ class SglangEngineFactory:
preprocess_workers=preprocess_workers, preprocess_workers=preprocess_workers,
stream_interval=self.stream_interval, stream_interval=self.stream_interval,
) )
gen.exclude_tools_when_tool_choice_none = (
self.config.exclude_tools_when_tool_choice_none
)
return PythonAsyncEngine(gen.generator, loop) return PythonAsyncEngine(gen.generator, loop)
...@@ -13,6 +13,7 @@ Parallels test_vllm_unit.py for the vLLM backend. ...@@ -13,6 +13,7 @@ Parallels test_vllm_unit.py for the vLLM backend.
import pytest import pytest
from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.srt.utils.hf_transformers_utils import get_tokenizer
import dynamo.frontend.sglang_processor as sglang_processor_module
from dynamo.frontend.sglang_prepost import ( from dynamo.frontend.sglang_prepost import (
SglangPreprocessResult, SglangPreprocessResult,
SglangStreamingPostProcessor, SglangStreamingPostProcessor,
...@@ -23,6 +24,7 @@ from dynamo.frontend.sglang_prepost import ( ...@@ -23,6 +24,7 @@ from dynamo.frontend.sglang_prepost import (
from dynamo.frontend.sglang_processor import ( from dynamo.frontend.sglang_processor import (
SglangPreprocessWorkerResult, SglangPreprocessWorkerResult,
_build_dynamo_preproc, _build_dynamo_preproc,
_init_worker,
_map_finish_reason, _map_finish_reason,
) )
from dynamo.frontend.utils import PreprocessError, random_call_id, random_uuid from dynamo.frontend.utils import PreprocessError, random_call_id, random_uuid
...@@ -513,6 +515,94 @@ class TestPreprocessChatRequest: ...@@ -513,6 +515,94 @@ class TestPreprocessChatRequest:
assert len(with_tools.prompt_token_ids) > len(without_tools.prompt_token_ids) assert len(with_tools.prompt_token_ids) > len(without_tools.prompt_token_ids)
assert with_tools.tool_call_parser is not None assert with_tools.tool_call_parser is not None
def test_tool_choice_none_strips_tools_from_template(self, tokenizer):
    """With the exclude flag enabled, tool_choice="none" drops tools from the prompt."""
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
    }
    base_request = dict(
        model=MODEL,
        messages=[{"role": "user", "content": "Hello"}],
        tools=[weather_tool],
    )

    def prompt_length(choice):
        # Render the same request with the given tool_choice and report how
        # many prompt tokens the chat template produced.
        outcome = preprocess_chat_request(
            dict(base_request, tool_choice=choice),
            tokenizer=tokenizer,
            tool_call_parser_name=None,
            reasoning_parser_name=None,
            exclude_tools_when_tool_choice_none=True,
        )
        return len(outcome.prompt_token_ids)

    auto_length = prompt_length("auto")
    none_length = prompt_length("none")

    # Without the tool definitions the rendered prompt must be shorter.
    assert (
        none_length < auto_length
    ), "tool_choice=none with exclude flag should strip tools from template"
def test_tool_choice_none_keeps_tools_when_flag_off(self, tokenizer):
    """With the exclude flag disabled, tool_choice="none" still renders the tools."""
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
    }
    base_request = dict(
        model=MODEL,
        messages=[{"role": "user", "content": "Hello"}],
        tools=[weather_tool],
    )

    def prompt_length(choice):
        # Render the same request with the given tool_choice and report how
        # many prompt tokens the chat template produced.
        outcome = preprocess_chat_request(
            dict(base_request, tool_choice=choice),
            tokenizer=tokenizer,
            tool_call_parser_name=None,
            reasoning_parser_name=None,
            exclude_tools_when_tool_choice_none=False,
        )
        return len(outcome.prompt_token_ids)

    auto_length = prompt_length("auto")
    none_length = prompt_length("none")

    # With the flag off the tools stay in the template for both choices, so
    # the two prompts must have exactly the same token count.
    assert (
        none_length == auto_length
    ), "tool_choice=none with flag off should keep tools in template"
def test_init_worker_propagates_exclude_flag_true(self, monkeypatch):
    """_init_worker sets the worker-global exclude_tools flag to True."""
    # The module-level default is already True, so seed the opposite value
    # first; otherwise this test would pass even if _init_worker never wrote
    # the flag. monkeypatch restores the original value on teardown.
    monkeypatch.setattr(
        sglang_processor_module, "_w_exclude_tools_when_tool_choice_none", False
    )
    _init_worker(MODEL, None, None, exclude_tools_when_tool_choice_none=True)
    assert sglang_processor_module._w_exclude_tools_when_tool_choice_none is True
def test_init_worker_propagates_exclude_flag_false(self, monkeypatch):
    """_init_worker sets the worker-global exclude_tools flag to False."""
    # Register the flag with monkeypatch so it is restored on teardown even
    # when the assertion below fails; a manual reset placed after the assert
    # would be skipped on failure and leak False into later tests.
    monkeypatch.setattr(
        sglang_processor_module, "_w_exclude_tools_when_tool_choice_none", True
    )
    _init_worker(MODEL, None, None, exclude_tools_when_tool_choice_none=False)
    assert sglang_processor_module._w_exclude_tools_when_tool_choice_none is False
def test_with_reasoning_parser(self, tokenizer): def test_with_reasoning_parser(self, tokenizer):
"""Reasoning parser is attached to result.""" """Reasoning parser is attached to result."""
result = preprocess_chat_request( result = preprocess_chat_request(
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for vLLM processor components.
Tests for the tool-stripping behaviour of _prepare_request when
tool_choice='none' and the exclude_tools_when_tool_choice_none flag.
"""
import pytest
from transformers import AutoTokenizer
from dynamo.frontend.prepost import _prepare_request
MODEL = "Qwen/Qwen3-0.6B"
TOOL_REQUEST = {
"model": MODEL,
"messages": [{"role": "user", "content": "Hello"}],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {
"type": "object",
"properties": {"city": {"type": "string"}},
},
},
}
],
}
@pytest.fixture(scope="module")
def tokenizer():
    """Load the tokenizer for ``MODEL`` once and share it across this module's tests."""
    loaded = AutoTokenizer.from_pretrained(MODEL)
    return loaded
# ---------------------------------------------------------------------------
# _prepare_request: tool_choice=none tool-stripping
# ---------------------------------------------------------------------------
class TestPrepareRequestToolStripping:
    """Check which tools ``_prepare_request`` hands to the chat template.

    Each test inspects the ``tools`` entry of the resulting chat template
    kwargs for a given ``tool_choice`` and exclude-flag combination.
    """

    @staticmethod
    def _template_tools(request, tokenizer, *, exclude):
        """Run ``_prepare_request`` and return the ``tools`` template kwarg."""
        _, _, _, _, chat_params = _prepare_request(
            request,
            tokenizer=tokenizer,
            tool_parser_class=None,
            exclude_tools_when_tool_choice_none=exclude,
        )
        return chat_params.chat_template_kwargs["tools"]

    def test_tool_choice_none_strips_tools_from_template(self, tokenizer):
        """Flag on + tool_choice=none: no tools reach the template kwargs."""
        rendered_tools = self._template_tools(
            dict(TOOL_REQUEST, tool_choice="none"), tokenizer, exclude=True
        )
        assert (
            rendered_tools is None
        ), "tool_choice=none with exclude flag should strip tools from template"

    def test_tool_choice_none_keeps_tools_when_flag_off(self, tokenizer):
        """Flag off + tool_choice=none: the single tool is still passed through."""
        rendered_tools = self._template_tools(
            dict(TOOL_REQUEST, tool_choice="none"), tokenizer, exclude=False
        )
        assert (
            rendered_tools is not None and len(rendered_tools) == 1
        ), "tool_choice=none with flag off should keep tools in template"

    def test_tool_choice_auto_keeps_tools(self, tokenizer):
        """tool_choice=auto keeps the tool even though the flag is on."""
        rendered_tools = self._template_tools(
            dict(TOOL_REQUEST, tool_choice="auto"), tokenizer, exclude=True
        )
        assert (
            rendered_tools is not None and len(rendered_tools) == 1
        ), "tool_choice=auto should keep tools in template"

    def test_tool_choice_required_keeps_tools(self, tokenizer):
        """tool_choice=required keeps the tool even though the flag is on."""
        rendered_tools = self._template_tools(
            dict(TOOL_REQUEST, tool_choice="required"), tokenizer, exclude=True
        )
        assert (
            rendered_tools is not None and len(rendered_tools) == 1
        ), "tool_choice=required should keep tools in template"

    def test_no_tools_in_request(self, tokenizer):
        """A request that carries no tools yields ``None`` for the tools kwarg."""
        bare_request = {
            "model": MODEL,
            "messages": [{"role": "user", "content": "Hello"}],
        }
        rendered_tools = self._template_tools(bare_request, tokenizer, exclude=True)
        assert (
            rendered_tools is None
        ), "No tools in request should produce None tools in template"
...@@ -85,6 +85,7 @@ class VllmProcessor: ...@@ -85,6 +85,7 @@ class VllmProcessor:
self.output_processor = output_processor self.output_processor = output_processor
self.tool_parser_class = tool_parser_class self.tool_parser_class = tool_parser_class
self.reasoning_parser_class = reasoning_parser_class self.reasoning_parser_class = reasoning_parser_class
self.exclude_tools_when_tool_choice_none = True
def _get_eos_token_ids(self) -> list[int]: def _get_eos_token_ids(self) -> list[int]:
"""Return EOS token ids using tokenizer metadata. """Return EOS token ids using tokenizer metadata.
...@@ -125,6 +126,7 @@ class VllmProcessor: ...@@ -125,6 +126,7 @@ class VllmProcessor:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
renderer=self.input_processor.renderer, renderer=self.input_processor.renderer,
tool_parser_class=self.tool_parser_class, tool_parser_class=self.tool_parser_class,
exclude_tools_when_tool_choice_none=self.exclude_tools_when_tool_choice_none,
) )
request_for_sampling = pre.request_for_sampling request_for_sampling = pre.request_for_sampling
...@@ -472,5 +474,8 @@ class EngineFactory: ...@@ -472,5 +474,8 @@ class EngineFactory:
tool_parser_class, tool_parser_class,
reasoning_parser_class, reasoning_parser_class,
) )
gen.exclude_tools_when_tool_choice_none = (
self.config.exclude_tools_when_tool_choice_none
)
return PythonAsyncEngine(gen.generator, loop) return PythonAsyncEngine(gen.generator, loop)
...@@ -162,6 +162,9 @@ async def _get_runtime_config( ...@@ -162,6 +162,9 @@ async def _get_runtime_config(
# set reasoning parser and tool call parser # set reasoning parser and tool call parser
runtime_config.reasoning_parser = dynamo_args.dyn_reasoning_parser runtime_config.reasoning_parser = dynamo_args.dyn_reasoning_parser
runtime_config.tool_call_parser = dynamo_args.dyn_tool_call_parser runtime_config.tool_call_parser = dynamo_args.dyn_tool_call_parser
runtime_config.exclude_tools_when_tool_choice_none = (
dynamo_args.exclude_tools_when_tool_choice_none
)
# Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
is_decode_worker = server_args.disaggregation_mode == "decode" is_decode_worker = server_args.disaggregation_mode == "decode"
runtime_config.enable_local_indexer = ( runtime_config.enable_local_indexer = (
......
...@@ -372,6 +372,9 @@ async def init_llm_worker( ...@@ -372,6 +372,9 @@ async def init_llm_worker(
runtime_config.max_num_batched_tokens = config.max_num_tokens runtime_config.max_num_batched_tokens = config.max_num_tokens
runtime_config.reasoning_parser = config.dyn_reasoning_parser runtime_config.reasoning_parser = config.dyn_reasoning_parser
runtime_config.tool_call_parser = config.dyn_tool_call_parser runtime_config.tool_call_parser = config.dyn_tool_call_parser
runtime_config.exclude_tools_when_tool_choice_none = (
config.exclude_tools_when_tool_choice_none
)
# Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer # Decode workers don't create the WorkerKvQuery endpoint, so don't advertise local indexer
runtime_config.enable_local_indexer = ( runtime_config.enable_local_indexer = (
config.enable_local_indexer config.enable_local_indexer
......
...@@ -635,6 +635,9 @@ async def register_vllm_model( ...@@ -635,6 +635,9 @@ async def register_vllm_model(
if model_type != ModelType.Prefill: if model_type != ModelType.Prefill:
runtime_config.tool_call_parser = config.dyn_tool_call_parser runtime_config.tool_call_parser = config.dyn_tool_call_parser
runtime_config.reasoning_parser = config.dyn_reasoning_parser runtime_config.reasoning_parser = config.dyn_reasoning_parser
runtime_config.exclude_tools_when_tool_choice_none = (
config.exclude_tools_when_tool_choice_none
)
# Get data_parallel_size from vllm_config (defaults to 1) # Get data_parallel_size from vllm_config (defaults to 1)
dp_range = get_dp_range_for_worker(vllm_config) dp_range = get_dp_range_for_worker(vllm_config)
......
...@@ -60,6 +60,14 @@ impl ModelRuntimeConfig { ...@@ -60,6 +60,14 @@ impl ModelRuntimeConfig {
self.inner.enable_local_indexer = enable_local_indexer; self.inner.enable_local_indexer = enable_local_indexer;
} }
/// Stores the flag on the wrapped runtime config.
///
/// The wrapped field is documented as: when true, tool definitions are
/// stripped from the chat template when `tool_choice` is "none".
#[setter]
fn set_exclude_tools_when_tool_choice_none(
&mut self,
exclude_tools_when_tool_choice_none: bool,
) {
self.inner.exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none;
}
#[setter] #[setter]
fn set_enable_eagle(&mut self, enable_eagle: bool) { fn set_enable_eagle(&mut self, enable_eagle: bool) {
self.inner.enable_eagle = enable_eagle; self.inner.enable_eagle = enable_eagle;
...@@ -124,6 +132,11 @@ impl ModelRuntimeConfig { ...@@ -124,6 +132,11 @@ impl ModelRuntimeConfig {
self.inner.enable_local_indexer self.inner.enable_local_indexer
} }
/// Returns the current value of `exclude_tools_when_tool_choice_none` from
/// the wrapped runtime config.
#[getter]
fn exclude_tools_when_tool_choice_none(&self) -> bool {
self.inner.exclude_tools_when_tool_choice_none
}
#[getter] #[getter]
fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> { fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py); let dict = PyDict::new(py);
......
...@@ -479,6 +479,7 @@ class ModelRuntimeConfig: ...@@ -479,6 +479,7 @@ class ModelRuntimeConfig:
max_num_batched_tokens: int | None max_num_batched_tokens: int | None
tool_call_parser: str | None tool_call_parser: str | None
reasoning_parser: str | None reasoning_parser: str | None
exclude_tools_when_tool_choice_none: bool
data_parallel_start_rank: int data_parallel_start_rank: int
data_parallel_size: int data_parallel_size: int
enable_local_indexer: bool enable_local_indexer: bool
......
...@@ -1404,7 +1404,8 @@ async fn main() -> Result<()> { ...@@ -1404,7 +1404,8 @@ async fn main() -> Result<()> {
if let Some(contents) = contents { if let Some(contents) = contents {
match serde_json::from_str::<ChatTemplate>(&contents) { match serde_json::from_str::<ChatTemplate>(&contents) {
Ok(chat_template) => { Ok(chat_template) => {
match PromptFormatter::from_parts(chat_template, ContextMixins::new(&[])) { match PromptFormatter::from_parts(chat_template, ContextMixins::new(&[]), true)
{
Ok(formatter) => { Ok(formatter) => {
println!( println!(
" Prompt formatter loaded from tokenizer_config.json (using frontend-compatible renderer)" " Prompt formatter loaded from tokenizer_config.json (using frontend-compatible renderer)"
......
...@@ -28,6 +28,10 @@ pub struct ModelRuntimeConfig { ...@@ -28,6 +28,10 @@ pub struct ModelRuntimeConfig {
pub reasoning_parser: Option<String>, pub reasoning_parser: Option<String>,
/// When true, strip tool definitions from the chat template when tool_choice is "none".
#[serde(default = "default_exclude_tools_when_tool_choice_none")]
pub exclude_tools_when_tool_choice_none: bool,
/// Starting rank of data parallel ranks for this worker (0 if DP not enabled) /// Starting rank of data parallel ranks for this worker (0 if DP not enabled)
#[serde(default = "default_data_parallel_start_rank")] #[serde(default = "default_data_parallel_start_rank")]
pub data_parallel_start_rank: u32, pub data_parallel_start_rank: u32,
...@@ -74,6 +78,10 @@ const fn default_local_indexer() -> bool { ...@@ -74,6 +78,10 @@ const fn default_local_indexer() -> bool {
true true
} }
/// Default for `ModelRuntimeConfig::exclude_tools_when_tool_choice_none`,
/// used both by serde when the field is absent and by the `Default` impl:
/// tool stripping is enabled.
const fn default_exclude_tools_when_tool_choice_none() -> bool {
true
}
const fn default_eagle() -> bool { const fn default_eagle() -> bool {
false false
} }
...@@ -86,6 +94,7 @@ impl Default for ModelRuntimeConfig { ...@@ -86,6 +94,7 @@ impl Default for ModelRuntimeConfig {
max_num_batched_tokens: None, max_num_batched_tokens: None,
tool_call_parser: None, tool_call_parser: None,
reasoning_parser: None, reasoning_parser: None,
exclude_tools_when_tool_choice_none: default_exclude_tools_when_tool_choice_none(),
data_parallel_start_rank: default_data_parallel_start_rank(), data_parallel_start_rank: default_data_parallel_start_rank(),
data_parallel_size: default_data_parallel_size(), data_parallel_size: default_data_parallel_size(),
enable_local_indexer: true, enable_local_indexer: true,
......
...@@ -79,6 +79,7 @@ impl PromptFormatter { ...@@ -79,6 +79,7 @@ impl PromptFormatter {
mdc.prompt_context mdc.prompt_context
.clone() .clone()
.map_or(ContextMixins::default(), |x| ContextMixins::new(&x)), .map_or(ContextMixins::default(), |x| ContextMixins::new(&x)),
mdc.runtime_config.exclude_tools_when_tool_choice_none,
) )
} }
PromptFormatterArtifact::HfChatTemplate { .. } => Err(anyhow::anyhow!( PromptFormatterArtifact::HfChatTemplate { .. } => Err(anyhow::anyhow!(
...@@ -87,8 +88,16 @@ impl PromptFormatter { ...@@ -87,8 +88,16 @@ impl PromptFormatter {
} }
} }
pub fn from_parts(config: ChatTemplate, context: ContextMixins) -> Result<PromptFormatter> { pub fn from_parts(
let formatter = HfTokenizerConfigJsonFormatter::new(config, context)?; config: ChatTemplate,
context: ContextMixins,
exclude_tools_when_tool_choice_none: bool,
) -> Result<PromptFormatter> {
let formatter = HfTokenizerConfigJsonFormatter::with_options(
config,
context,
exclude_tools_when_tool_choice_none,
)?;
Ok(Self::OAI(Arc::new(formatter))) Ok(Self::OAI(Arc::new(formatter)))
} }
} }
...@@ -123,6 +132,9 @@ struct HfTokenizerConfigJsonFormatter { ...@@ -123,6 +132,9 @@ struct HfTokenizerConfigJsonFormatter {
mixins: Arc<ContextMixins>, mixins: Arc<ContextMixins>,
supports_add_generation_prompt: bool, supports_add_generation_prompt: bool,
requires_content_arrays: bool, requires_content_arrays: bool,
/// When true, strip tool definitions from the chat template when tool_choice is "none".
/// This prevents models from generating raw XML tool calls in the content field.
exclude_tools_when_tool_choice_none: bool,
} }
// /// OpenAI Standard Prompt Formatter // /// OpenAI Standard Prompt Formatter
......
...@@ -74,7 +74,16 @@ impl Default for JinjaEnvironment { ...@@ -74,7 +74,16 @@ impl Default for JinjaEnvironment {
} }
impl HfTokenizerConfigJsonFormatter { impl HfTokenizerConfigJsonFormatter {
#[cfg(test)]
pub fn new(config: ChatTemplate, mixins: ContextMixins) -> anyhow::Result<Self> { pub fn new(config: ChatTemplate, mixins: ContextMixins) -> anyhow::Result<Self> {
Self::with_options(config, mixins, true)
}
pub fn with_options(
config: ChatTemplate,
mixins: ContextMixins,
exclude_tools_when_tool_choice_none: bool,
) -> anyhow::Result<Self> {
let mut env = JinjaEnvironment::default().env(); let mut env = JinjaEnvironment::default().env();
let chat_template = config.chat_template.as_ref().ok_or(anyhow::anyhow!( let chat_template = config.chat_template.as_ref().ok_or(anyhow::anyhow!(
...@@ -158,6 +167,7 @@ impl HfTokenizerConfigJsonFormatter { ...@@ -158,6 +167,7 @@ impl HfTokenizerConfigJsonFormatter {
mixins: Arc::new(mixins), mixins: Arc::new(mixins),
supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false), supports_add_generation_prompt: supports_add_generation_prompt.unwrap_or(false),
requires_content_arrays, requires_content_arrays,
exclude_tools_when_tool_choice_none,
}) })
} }
} }
......
...@@ -346,6 +346,16 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { ...@@ -346,6 +346,16 @@ impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
let mixins = Value::from_dyn_object(self.mixins.clone()); let mixins = Value::from_dyn_object(self.mixins.clone());
let tools = req.tools(); let tools = req.tools();
// Strip tools when tool_choice is "none" and the flag is enabled, so the model
// doesn't see tool definitions and generate raw XML tool calls in its response.
let tools = if self.exclude_tools_when_tool_choice_none {
match req.tool_choice() {
Some(ref tc) if tc.as_str() == Some("none") => None,
_ => tools,
}
} else {
tools
};
// has_tools should be true if tools is a non-empty array // has_tools should be true if tools is a non-empty array
let has_tools = tools.as_ref().and_then(|v| v.len()).is_some_and(|l| l > 0); let has_tools = tools.as_ref().and_then(|v| v.len()).is_some_and(|l| l > 0);
let add_generation_prompt = req.should_add_generation_prompt(); let add_generation_prompt = req.should_add_generation_prompt();
...@@ -1225,4 +1235,81 @@ NORMAL MODE ...@@ -1225,4 +1235,81 @@ NORMAL MODE
let s = dummy_state(vec![]); let s = dummy_state(vec![]);
assert!(s.should_add_generation_prompt()); assert!(s.should_add_generation_prompt());
} }
/// Builds a formatter around a minimal template that prints `TOOL_MODE` when
/// it is handed a non-empty `tools` list and `NORMAL_MODE` otherwise, so tests
/// can tell from the rendered text whether tools reached the template.
fn tool_aware_formatter(
    exclude_tools_when_tool_choice_none: bool,
) -> HfTokenizerConfigJsonFormatter {
    const TEMPLATE: &str = r#"
{%- if tools is iterable and tools | length > 0 %}
TOOL_MODE tools={{ tools | length }}
{%- else %}
NORMAL_MODE
{%- endif %}
{{ messages[0].content }}"#;

    let config_json = serde_json::json!({ "chat_template": TEMPLATE });
    let parsed_config =
        serde_json::from_value::<super::tokcfg::ChatTemplate>(config_json).unwrap();
    let no_mixins = ContextMixins::new(&[]);

    HfTokenizerConfigJsonFormatter::with_options(
        parsed_config,
        no_mixins,
        exclude_tools_when_tool_choice_none,
    )
    .unwrap()
}
/// Builds a chat completion request carrying one `get_weather` function tool
/// and the given string `tool_choice`.
fn request_with_tool_choice(tool_choice: &str) -> NvCreateChatCompletionRequest {
    let weather_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}
        }
    });
    let payload = serde_json::json!({
        "model": "test",
        "messages": [{"role": "user", "content": "hello"}],
        "tools": [weather_tool],
        "tool_choice": tool_choice
    });
    serde_json::from_value(payload).unwrap()
}
/// Flag on + `tool_choice: "none"`: the template must take its no-tools branch.
#[test]
fn test_exclude_tools_strips_when_tool_choice_none() {
    let rendered = tool_aware_formatter(true)
        .render(&request_with_tool_choice("none"))
        .unwrap();
    let took_normal_branch = rendered.contains("NORMAL_MODE");
    assert!(
        took_normal_branch,
        "With exclude_tools=true and tool_choice=none, tools should be stripped. Got: {}",
        rendered
    );
}
/// Flag on + `tool_choice: "auto"`: tools must still reach the template.
#[test]
fn test_exclude_tools_keeps_when_tool_choice_auto() {
    let rendered = tool_aware_formatter(true)
        .render(&request_with_tool_choice("auto"))
        .unwrap();
    let took_tool_branch = rendered.contains("TOOL_MODE");
    assert!(
        took_tool_branch,
        "With tool_choice=auto, tools should be included. Got: {}",
        rendered
    );
}
/// Flag off + `tool_choice: "none"`: tools are passed to the template unchanged.
#[test]
fn test_no_exclude_tools_keeps_when_tool_choice_none() {
    let rendered = tool_aware_formatter(false)
        .render(&request_with_tool_choice("none"))
        .unwrap();
    let took_tool_branch = rendered.contains("TOOL_MODE");
    assert!(
        took_tool_branch,
        "With exclude_tools=false and tool_choice=none, tools should NOT be stripped. Got: {}",
        rendered
    );
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment