Unverified Commit 8fd7de9a authored by Richard Huo's avatar Richard Huo Committed by GitHub
Browse files

fix: update the tool calling functionalities for sglang frontend processor to...

fix: update the tool calling functionalities for sglang frontend processor to match with the latest sglang implementation (#8269)
parent 01688850
...@@ -86,6 +86,7 @@ class FrontendConfig(KvRouterConfigBase, AicPerfConfigBase): ...@@ -86,6 +86,7 @@ class FrontendConfig(KvRouterConfigBase, AicPerfConfigBase):
exclude_tools_when_tool_choice_none: bool exclude_tools_when_tool_choice_none: bool
preprocess_workers: int preprocess_workers: int
tokenizer_backend: str tokenizer_backend: str
trust_remote_code: bool
_VALID_TOKENIZER_BACKENDS = {"default", "fastokens"} _VALID_TOKENIZER_BACKENDS = {"default", "fastokens"}
...@@ -562,3 +563,14 @@ class FrontendArgGroup(ArgGroup): ...@@ -562,3 +563,14 @@ class FrontendArgGroup(ArgGroup):
), ),
choices=["default", "fastokens"], choices=["default", "fastokens"],
) )
add_negatable_bool_argument(
g,
flag_name="--trust-remote-code",
env_var="DYN_TRUST_REMOTE_CODE",
default=False,
help=(
"Trust remote code when loading the tokenizer. Required for models "
"that ship custom tokenizer code (e.g. Qwen, Falcon)."
),
)
...@@ -114,17 +114,6 @@ def parse_args() -> tuple[FrontendConfig, Optional[Namespace], Optional[Namespac ...@@ -114,17 +114,6 @@ def parse_args() -> tuple[FrontendConfig, Optional[Namespace], Optional[Namespac
vllm_flags = None vllm_flags = None
sglang_flags = None sglang_flags = None
# --trust-remote-code is only meaningful with --dyn-chat-processor vllm.
# Warn and strip it when a different (or no) chat processor is active so
# it does not propagate as an unknown-argument error below.
if "--trust-remote-code" in unknown and config.chat_processor != "vllm":
logger.warning(
"--trust-remote-code has no effect without '--dyn-chat-processor vllm'. "
"It is only supported by the vLLM chat processor. "
"Pass '--dyn-chat-processor vllm' to enable trust_remote_code."
)
unknown = [arg for arg in unknown if arg != "--trust-remote-code"]
# parse extra vllm flags using vllm native parser. # parse extra vllm flags using vllm native parser.
if config.chat_processor == "vllm": if config.chat_processor == "vllm":
try: try:
......
...@@ -3,24 +3,42 @@ ...@@ -3,24 +3,42 @@
from __future__ import annotations from __future__ import annotations
import json
import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any, TypeAlias
from sglang.srt.entrypoints.openai.protocol import Function as SglangFunction from sglang.srt.entrypoints.openai.protocol import Function as SglangFunction
from sglang.srt.entrypoints.openai.protocol import Tool as SglangTool from sglang.srt.entrypoints.openai.protocol import Tool as SglangTool
from sglang.srt.entrypoints.openai.protocol import ToolChoice as SglangToolChoice
from sglang.srt.entrypoints.openai.protocol import (
ToolChoiceFuncName as SglangToolChoiceFuncName,
)
from sglang.srt.function_call.core_types import ToolCallItem
from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.function_call.utils import get_json_schema_constraint
from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.parser.reasoning_parser import ReasoningParser
from .utils import random_call_id from .utils import random_call_id
logger = logging.getLogger(__name__)
# Union of parser types used for tool call detection.
# - FunctionCallParser: model-specific format detection (tool_choice="auto")
# - JsonArrayParser: direct JSON array parsing under constrained decoding
# (tool_choice="required" or named function)
ToolCallParserType: TypeAlias = FunctionCallParser | JsonArrayParser
@dataclass @dataclass
class SglangPreprocessResult: class SglangPreprocessResult:
"""Result of SGLang preprocessing.""" """Result of SGLang preprocessing."""
prompt_token_ids: list[int] prompt_token_ids: list[int]
tool_call_parser: FunctionCallParser | None tool_call_parser: ToolCallParserType | None
reasoning_parser: ReasoningParser | None reasoning_parser: ReasoningParser | None
guided_decoding: dict[str, Any] | None
request: dict[str, Any] request: dict[str, Any]
...@@ -64,7 +82,7 @@ def create_parsers( ...@@ -64,7 +82,7 @@ def create_parsers(
tool_call_parser_name: str | None, tool_call_parser_name: str | None,
reasoning_parser_name: str | None, reasoning_parser_name: str | None,
sglang_tools: list[SglangTool] | None = None, sglang_tools: list[SglangTool] | None = None,
) -> tuple[FunctionCallParser | None, ReasoningParser | None]: ) -> tuple[ToolCallParserType | None, ReasoningParser | None]:
"""Create tool call and reasoning parsers for a request. """Create tool call and reasoning parsers for a request.
Shared by both the single-process preprocessing path and the pool path Shared by both the single-process preprocessing path and the pool path
...@@ -72,17 +90,25 @@ def create_parsers( ...@@ -72,17 +90,25 @@ def create_parsers(
If ``sglang_tools`` is provided, reuses them; otherwise converts from If ``sglang_tools`` is provided, reuses them; otherwise converts from
the request's ``tools`` field. the request's ``tools`` field.
For ``tool_choice="required"`` or a named function, uses
:class:`JsonArrayParser` (matching native SGLang) since guided decoding
constrains the output to a JSON array. Otherwise uses the model-specific
:class:`FunctionCallParser`.
""" """
if sglang_tools is None: if sglang_tools is None:
sglang_tools = convert_tools(request.get("tools")) sglang_tools = convert_tools(request.get("tools"))
tool_choice = request.get("tool_choice", "auto") tool_choice = request.get("tool_choice", "auto")
tool_call_parser = None tool_call_parser: ToolCallParserType | None = None
if tool_call_parser_name and sglang_tools and tool_choice != "none": if sglang_tools and tool_choice != "none":
tool_call_parser = FunctionCallParser( if tool_choice == "required" or _is_named_tool_choice(tool_choice):
tools=sglang_tools, tool_call_parser = JsonArrayParser()
tool_call_parser=tool_call_parser_name, elif tool_call_parser_name:
) tool_call_parser = FunctionCallParser(
tools=sglang_tools,
tool_call_parser=tool_call_parser_name,
)
reasoning_parser = None reasoning_parser = None
if reasoning_parser_name: if reasoning_parser_name:
...@@ -94,6 +120,78 @@ def create_parsers( ...@@ -94,6 +120,78 @@ def create_parsers(
return tool_call_parser, reasoning_parser return tool_call_parser, reasoning_parser
def _is_named_tool_choice(tool_choice: Any) -> bool:
return (
isinstance(tool_choice, dict)
and tool_choice.get("type") == "function"
and isinstance(tool_choice.get("function"), dict)
and bool(tool_choice["function"].get("name"))
)
def build_tool_call_guided_decoding(
    request: dict[str, Any],
    *,
    tool_call_parser_name: str | None,
    sglang_tools: list[SglangTool] | None,
) -> dict[str, Any] | None:
    """Build native-SGLang-like tool call constraints for guided decoding.

    Args:
        request: Chat request dict. ``tool_choice`` (default ``"auto"``) and
            ``parallel_tool_calls`` are read from it.
        tool_call_parser_name: Name of the model-specific tool-call parser, or
            None when no parser is configured.
        sglang_tools: Tools already converted to SGLang ``Tool`` objects.

    Returns:
        ``{"json": schema}`` or ``{"structural_tag": tag}`` when a constraint
        applies; None when there are no tools, ``tool_choice`` is ``"none"``,
        no parser is configured for the non-forced case, or the constraint
        helper produced no usable payload.
    """
    if not sglang_tools:
        return None

    tool_choice = request.get("tool_choice", "auto")
    if tool_choice == "none":
        return None

    parallel_tool_calls = request.get("parallel_tool_calls")
    named_choice = _is_named_tool_choice(tool_choice)

    constraint: Any = None
    if tool_choice == "required" or named_choice:
        # get_json_schema_constraint branches on isinstance(tool_choice,
        # ToolChoice) for the named-function case — passing our raw dict
        # would silently fall through and return None, disabling guided
        # decoding and letting the model omit required fields.
        sglang_tool_choice: Any = tool_choice
        if named_choice:
            sglang_tool_choice = SglangToolChoice(
                type="function",
                function=SglangToolChoiceFuncName(
                    name=tool_choice["function"]["name"],
                ),
            )
        constraint = (
            "json_schema",
            get_json_schema_constraint(
                sglang_tools,
                sglang_tool_choice,
                parallel_tool_calls=parallel_tool_calls,
            ),
        )
    elif tool_call_parser_name:
        parser = FunctionCallParser(
            tools=sglang_tools,
            tool_call_parser=tool_call_parser_name,
        )
        constraint = parser.get_structure_constraint(
            tool_choice,
            parallel_tool_calls=parallel_tool_calls,
        )

    if not (isinstance(constraint, tuple) and len(constraint) == 2):
        return None

    constraint_kind, payload = constraint
    # A missing payload means there is nothing to enforce; report "no
    # constraint" instead of emitting a dict that wraps None.
    if payload is None:
        return None

    if constraint_kind == "json_schema":
        return {"json": payload}
    if constraint_kind == "structural_tag":
        # SGLang returns a Pydantic model (LegacyStructuralTagResponseFormat)
        # here. Convert to a plain dict before it hits the RPC layer —
        # msgpack/serde_json cannot serialize BaseModel instances.
        if hasattr(payload, "model_dump"):
            payload = payload.model_dump()
        return {"structural_tag": payload}
    return None
def _normalize_prompt_token_ids(prompt_token_ids: Any) -> list[int]: def _normalize_prompt_token_ids(prompt_token_ids: Any) -> list[int]:
if isinstance(prompt_token_ids, list): if isinstance(prompt_token_ids, list):
return prompt_token_ids return prompt_token_ids
...@@ -127,6 +225,20 @@ def preprocess_chat_request( ...@@ -127,6 +225,20 @@ def preprocess_chat_request(
# Convert tools to SGLang format (done once, shared with parser creation) # Convert tools to SGLang format (done once, shared with parser creation)
sglang_tools = convert_tools(request.get("tools")) sglang_tools = convert_tools(request.get("tools"))
# Reject a named tool_choice whose function is missing from tools —
# otherwise the chat template would render with zero tools while
# guided decoding still constrains the output to that function's
# schema, producing confusing model behavior.
tool_choice = request.get("tool_choice", "auto")
if _is_named_tool_choice(tool_choice):
chosen_name = tool_choice["function"]["name"]
available_names = {t.function.name for t in (sglang_tools or [])}
if chosen_name not in available_names:
raise ValueError(
f"tool_choice names function {chosen_name!r}, but it is not "
f"present in tools (available: {sorted(available_names) or 'none'})"
)
# Build template kwargs -- single call for rendering + tokenization # Build template kwargs -- single call for rendering + tokenization
template_kwargs: dict[str, Any] = { template_kwargs: dict[str, Any] = {
"add_generation_prompt": True, "add_generation_prompt": True,
...@@ -134,11 +246,18 @@ def preprocess_chat_request( ...@@ -134,11 +246,18 @@ def preprocess_chat_request(
} }
# Strip tools from template when tool_choice=none so the model doesn't # Strip tools from template when tool_choice=none so the model doesn't
# see them and generate raw XML tool calls in its response. # see them and generate raw XML tool calls in its response.
tool_choice = request.get("tool_choice", "auto") # When tool_choice names a specific function, only include that tool
# in the template so the model doesn't see irrelevant definitions.
if sglang_tools and not ( if sglang_tools and not (
exclude_tools_when_tool_choice_none and tool_choice == "none" exclude_tools_when_tool_choice_none and tool_choice == "none"
): ):
template_kwargs["tools"] = [t.model_dump() for t in sglang_tools] if _is_named_tool_choice(tool_choice):
chosen_name = tool_choice["function"]["name"]
template_kwargs["tools"] = [
t.model_dump() for t in sglang_tools if t.function.name == chosen_name
]
else:
template_kwargs["tools"] = [t.model_dump() for t in sglang_tools]
prompt_token_ids = _normalize_prompt_token_ids( prompt_token_ids = _normalize_prompt_token_ids(
tokenizer.apply_chat_template(messages, **template_kwargs) tokenizer.apply_chat_template(messages, **template_kwargs)
...@@ -150,11 +269,17 @@ def preprocess_chat_request( ...@@ -150,11 +269,17 @@ def preprocess_chat_request(
reasoning_parser_name=reasoning_parser_name, reasoning_parser_name=reasoning_parser_name,
sglang_tools=sglang_tools, sglang_tools=sglang_tools,
) )
guided_decoding = build_tool_call_guided_decoding(
request,
tool_call_parser_name=tool_call_parser_name,
sglang_tools=sglang_tools,
)
return SglangPreprocessResult( return SglangPreprocessResult(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
tool_call_parser=tool_call_parser, tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser, reasoning_parser=reasoning_parser,
guided_decoding=guided_decoding,
request=request, request=request,
) )
...@@ -163,13 +288,102 @@ def _random_call_id() -> str: ...@@ -163,13 +288,102 @@ def _random_call_id() -> str:
return random_call_id() return random_call_id()
def _get_history_tool_calls_count(messages: list[dict[str, Any]]) -> int:
"""Count prior assistant tool calls for parser-specific ID generation."""
count = 0
for msg in messages:
if msg.get("role") != "assistant":
continue
tool_calls = msg.get("tool_calls")
if isinstance(tool_calls, list):
count += len(tool_calls)
return count
def _tool_call_id_for_parser(
parser_name: str | None,
name: str,
index: int,
history_tool_calls_count: int,
) -> str:
"""Match native SGLang tool-call ID behavior for parser-specific formats.
``index`` is the sequential position of this call within the current
response — callers must pass the same index they use as the dict key
for the call, so the ID stays consistent with the emitted ``index``
field. For ``parse_non_stream`` output, ``ToolCallItem.tool_index``
can instead reflect the tool-definition position, so it is not safe
to read here directly.
"""
if parser_name != "kimi_k2":
return _random_call_id()
return f"functions.{name or ''}:{history_tool_calls_count + index}"
def _parse_json_array_buffer(buffer: str) -> list[ToolCallItem]:
    """Convert a JSON-array text buffer into a list of ``ToolCallItem``.

    Used as the finish-time fallback when streaming parsing with
    JsonArrayParser missed arguments. The buffer may carry surrounding noise
    (for example trailing special tokens); locating the array inside it is
    delegated to ``_try_parse_json_array``.

    Each dict element becomes one item. Its arguments come from
    ``"parameters"``, falling back to ``"arguments"``; non-string values are
    serialized to JSON text and a missing value becomes ``""``. Non-dict
    elements are skipped, but ``tool_index`` still reflects the element's
    position in the parsed array. Returns an empty list when no array can be
    parsed.
    """
    entries = _try_parse_json_array(buffer)
    if entries is None:
        return []

    items: list[ToolCallItem] = []
    for position, entry in enumerate(entries):
        if not isinstance(entry, dict):
            continue

        tool_name = entry.get("name", "")
        raw_args = entry.get("parameters")
        if raw_args is None:
            raw_args = entry.get("arguments")

        if raw_args is None:
            serialized_args = ""
        elif isinstance(raw_args, str):
            serialized_args = raw_args
        else:
            serialized_args = json.dumps(raw_args, ensure_ascii=False)

        items.append(
            ToolCallItem(
                tool_index=position,
                name=tool_name,
                parameters=serialized_args,
            )
        )
    return items
def _try_parse_json_array(text: str) -> list | None:
"""Try to parse a JSON array from *text*, tolerating surrounding noise."""
try:
data = json.loads(text)
if isinstance(data, list):
return data
except (json.JSONDecodeError, TypeError):
pass
# Retry: extract the outermost [...] substring (handles trailing
# special tokens or leading content text).
start = text.find("[")
end = text.rfind("]")
if start != -1 and end > start:
try:
data = json.loads(text[start : end + 1])
if isinstance(data, list):
return data
except (json.JSONDecodeError, TypeError):
pass
return None
class SglangStreamingPostProcessor: class SglangStreamingPostProcessor:
"""Streaming post-processor using SGLang parsers and HF tokenizer detokenization. """Streaming post-processor using SGLang parsers and HF tokenizer detokenization.
Handles: Handles:
- Incremental detokenization via sliding-window decode (6-token lookback) - Incremental detokenization via sliding-window decode (6-token lookback)
- Reasoning content extraction via SGLang ReasoningParser - Reasoning content extraction via SGLang ReasoningParser
- Tool call parsing via SGLang FunctionCallParser (parameter deltas) - Tool call parsing via SGLang FunctionCallParser or JsonArrayParser
""" """
# Lookback window size for incremental detokenization. UTF-8 characters # Lookback window size for incremental detokenization. UTF-8 characters
...@@ -182,26 +396,46 @@ class SglangStreamingPostProcessor: ...@@ -182,26 +396,46 @@ class SglangStreamingPostProcessor:
self, self,
*, *,
tokenizer, tokenizer,
tool_call_parser: FunctionCallParser | None, tool_call_parser: ToolCallParserType | None,
reasoning_parser: ReasoningParser | None, reasoning_parser: ReasoningParser | None,
history_tool_calls_count: int = 0,
sglang_tools: list[SglangTool] | None = None,
tool_call_parser_name: str | None = None,
) -> None: ) -> None:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tool_call_parser = tool_call_parser self.tool_call_parser = tool_call_parser
self.reasoning_parser = reasoning_parser self.reasoning_parser = reasoning_parser
self.history_tool_calls_count = history_tool_calls_count
self._sglang_tools = sglang_tools or []
self._tool_call_parser_name = tool_call_parser_name
self._fast_plain_text = tool_call_parser is None and reasoning_parser is None self._fast_plain_text = tool_call_parser is None and reasoning_parser is None
# Preserve special tokens when a tool call parser is active so
# delimiter tokens (e.g. <|tool_call|>) remain visible to the parser.
self._skip_special_tokens = tool_call_parser is None
self._is_json_array_parser = isinstance(tool_call_parser, JsonArrayParser)
self._all_token_ids: list[int] = [] self._all_token_ids: list[int] = []
# Tool call accumulation. SGLang's streaming parser returns # Tool call accumulation. SGLang's streaming parser returns
# deltas (name in one chunk, argument fragments across subsequent # deltas (name in one chunk, argument fragments across subsequent
# chunks). However, when the complete tool-call JSON arrives in a # chunks). However, the base detector processes at most one event
# single chunk the parser emits the name but never streams # per call (a name OR an argument diff), and the post-processor
# arguments (a chunking-sensitivity issue in the base detector). # calls it only once per token batch. When multiple tool calls
# We accumulate names + arg fragments from streaming deltas and, # arrive together, later calls may not be detected during streaming.
# on finish, fall back to parse_non_stream on the detector buffer # We accumulate all text fed to the parser and, on finish, re-parse
# for any tool call whose arguments are still missing. # the full text to recover any missed tool calls or arguments.
self._tool_call_ids: dict[int, str] = {} # tool_index -> call_id self._tool_call_ids: dict[int, str] = {} # tool_index -> call_id
self._tool_call_names: dict[int, str] = {} # tool_index -> name self._tool_call_names: dict[int, str] = {} # tool_index -> name
self._tool_call_args: dict[int, list[str]] = {} # tool_index -> arg chunks self._tool_call_args: dict[int, list[str]] = {} # tool_index -> arg chunks
# Full text accumulator for robust finish-time re-parse.
self._tool_text_parts: list[str] = []
def _tool_call_id(self, name: str, index: int) -> str:
    """Return the call ID for the tool call named ``name`` at ``index``.

    Delegates to ``_tool_call_id_for_parser`` using this instance's
    configured parser name and history tool-call count.
    """
    return _tool_call_id_for_parser(
        parser_name=self._tool_call_parser_name,
        name=name,
        index=index,
        history_tool_calls_count=self.history_tool_calls_count,
    )
def _incremental_decode(self, new_token_ids: list[int]) -> str: def _incremental_decode(self, new_token_ids: list[int]) -> str:
"""Decode new tokens with lookback window for multi-byte char boundaries. """Decode new tokens with lookback window for multi-byte char boundaries.
...@@ -226,14 +460,18 @@ class SglangStreamingPostProcessor: ...@@ -226,14 +460,18 @@ class SglangStreamingPostProcessor:
# Decode lookback-only prefix (before new tokens) # Decode lookback-only prefix (before new tokens)
prefix_tokens = self._all_token_ids[start:prev_count] prefix_tokens = self._all_token_ids[start:prev_count]
prefix_text = ( prefix_text = (
self.tokenizer.decode(prefix_tokens, skip_special_tokens=True) self.tokenizer.decode(
prefix_tokens, skip_special_tokens=self._skip_special_tokens
)
if prefix_tokens if prefix_tokens
else "" else ""
) )
# Decode lookback + new tokens together # Decode lookback + new tokens together
window_tokens = self._all_token_ids[start:] window_tokens = self._all_token_ids[start:]
window_text = self.tokenizer.decode(window_tokens, skip_special_tokens=True) window_text = self.tokenizer.decode(
window_tokens, skip_special_tokens=self._skip_special_tokens
)
return window_text[len(prefix_text) :] return window_text[len(prefix_text) :]
...@@ -282,15 +520,24 @@ class SglangStreamingPostProcessor: ...@@ -282,15 +520,24 @@ class SglangStreamingPostProcessor:
content_text = normal_text content_text = normal_text
if self.tool_call_parser and normal_text: if self.tool_call_parser and normal_text:
parsed_text, tool_calls = self.tool_call_parser.parse_stream_chunk( # Accumulate raw text for finish-time re-parse.
normal_text self._tool_text_parts.append(normal_text)
)
if self._is_json_array_parser:
result = self.tool_call_parser.parse_streaming_increment(
normal_text, self._sglang_tools
)
parsed_text, tool_calls = result.normal_text, result.calls
else:
parsed_text, tool_calls = self.tool_call_parser.parse_stream_chunk(
normal_text
)
content_text = parsed_text content_text = parsed_text
for tc in tool_calls: for tc in tool_calls:
idx = tc.tool_index idx = tc.tool_index
if idx not in self._tool_call_ids: if idx not in self._tool_call_ids:
self._tool_call_ids[idx] = _random_call_id() self._tool_call_ids[idx] = self._tool_call_id(tc.name or "", idx)
if tc.name: if tc.name:
self._tool_call_names[idx] = tc.name self._tool_call_names[idx] = tc.name
if tc.parameters: if tc.parameters:
...@@ -307,26 +554,128 @@ class SglangStreamingPostProcessor: ...@@ -307,26 +554,128 @@ class SglangStreamingPostProcessor:
delta["reasoning_content"] = reasoning_text delta["reasoning_content"] = reasoning_text
has_content = True has_content = True
# Emit complete tool calls on finish. For any tool call whose # On finish, re-parse the full accumulated text to recover tool
# arguments are still empty (chunking-sensitivity issue), fall # calls or arguments that the streaming parser missed.
# back to parse_non_stream on the detector's buffer. #
if finish_reason and self._tool_call_names: # The streaming parser (BaseFormatDetector.parse_streaming_increment)
# processes at most one event per invocation — a tool name OR an
# argument diff — and the post-processor calls it once per token
# batch. When multiple tool calls arrive together or the complete
# JSON lands in a single chunk, later calls (or arguments) may
# never be detected during streaming.
#
# The re-parse uses the accumulated text (not the parser's internal
# _buffer, which is consumed during streaming) and assigns
# sequential indices to match the OpenAI API convention.
if (
finish_reason
and self.tool_call_parser is not None
and self._tool_text_parts
):
# Purge streaming results that don't match any known tool.
# When guided decoding is not enforced the streaming parser
# can misidentify words in the prompt (e.g. a person's name)
# as function names.
known_names = (
{t.function.name for t in self._sglang_tools}
if self._sglang_tools
else set()
)
if known_names:
for idx in list(self._tool_call_names):
if self._tool_call_names[idx] not in known_names:
del self._tool_call_names[idx]
self._tool_call_ids.pop(idx, None)
self._tool_call_args.pop(idx, None)
# Discard malformed (non-JSON) argument fragments that the
# streaming parser accumulated from mixed content.
for idx in list(self._tool_call_args):
combined = "".join(self._tool_call_args[idx])
if combined:
try:
json.loads(combined)
except (json.JSONDecodeError, ValueError):
del self._tool_call_args[idx]
missing_names = not self._tool_call_names
missing_args = any( missing_args = any(
idx not in self._tool_call_args for idx in self._tool_call_names idx not in self._tool_call_args for idx in self._tool_call_names
) )
if missing_args and self.tool_call_parser is not None: should_reparse = False
buffer = getattr(self.tool_call_parser.detector, "_buffer", "") full_text = ""
if buffer: if missing_names or missing_args:
_, final_calls = self.tool_call_parser.parse_non_stream(buffer) full_text = "".join(self._tool_text_parts)
for tc in final_calls: # Skip the re-parse when the accumulated text has no
idx = tc.tool_index # tool-call markers. Avoids wasted `parse_non_stream`
if idx not in self._tool_call_ids: # work on plain-text responses (common when tools are
self._tool_call_ids[idx] = _random_call_id() # offered but the model replies without calling any) and
# guards against detectors that raise on arbitrary input.
should_reparse = bool(
full_text
) and self.tool_call_parser.has_tool_call(full_text)
if should_reparse:
if self._is_json_array_parser:
final_calls = _parse_json_array_buffer(full_text)
# Secondary fallback: when guided decoding did not
# constrain the output (e.g. the backend doesn't
# support it), the model may have produced tool calls
# in its native format. Try the model-specific
# parser so we don't silently drop them.
if (
not final_calls
and self._tool_call_parser_name
and self._sglang_tools
):
try:
fcp = FunctionCallParser(
tools=self._sglang_tools,
tool_call_parser=self._tool_call_parser_name,
)
_, final_calls = fcp.parse_non_stream(full_text)
except (
ValueError,
KeyError,
json.JSONDecodeError,
IndexError,
) as e:
# Fallback path: model-native tool-call text is
# malformed. Log and return no tool calls rather
# than crashing the whole response — the primary
# JSON-array path has already failed, and the
# normal text is still usable.
logger.warning(
"Native tool-call fallback parse failed (parser=%r): %s",
self._tool_call_parser_name,
e,
)
final_calls = []
else:
_, final_calls = self.tool_call_parser.parse_non_stream(full_text)
# Filter to known tool names (reuse set from above).
if known_names:
final_calls = [tc for tc in final_calls if tc.name in known_names]
# Re-index sequentially so repeated calls to the same
# tool get distinct indices (parse_non_stream may assign
# indices based on the tool-definition position instead).
# When the re-parse returns results, it is authoritative:
# clear streaming state first so we don't mix a name from
# the re-parse with args from streaming at the same index.
if final_calls:
self._tool_call_ids.clear()
self._tool_call_names.clear()
self._tool_call_args.clear()
for seq_idx, tc in enumerate(final_calls):
self._tool_call_ids[seq_idx] = self._tool_call_id(
tc.name or "", seq_idx
)
if tc.name: if tc.name:
self._tool_call_names[idx] = tc.name self._tool_call_names[seq_idx] = tc.name
if tc.parameters: if tc.parameters:
self._tool_call_args[idx] = [tc.parameters] self._tool_call_args[seq_idx] = [tc.parameters]
if finish_reason and self._tool_call_names:
tool_calls_out: list[dict[str, Any]] = [] tool_calls_out: list[dict[str, Any]] = []
for idx in sorted(self._tool_call_names): for idx in sorted(self._tool_call_names):
tool_calls_out.append( tool_calls_out.append(
...@@ -343,11 +692,17 @@ class SglangStreamingPostProcessor: ...@@ -343,11 +692,17 @@ class SglangStreamingPostProcessor:
delta["tool_calls"] = tool_calls_out delta["tool_calls"] = tool_calls_out
has_content = True has_content = True
if has_content or finish_reason: # Rewrite finish_reason "stop" → "tool_calls" when tool calls were
# detected, matching the OpenAI API spec and official SGLang behaviour.
effective_finish = finish_reason
if finish_reason == "stop" and self._tool_call_names:
effective_finish = "tool_calls"
if has_content or effective_finish:
return { return {
"index": 0, "index": 0,
"delta": delta if has_content else {}, "delta": delta if has_content else {},
"finish_reason": finish_reason, "finish_reason": effective_finish,
"logprobs": None, "logprobs": None,
} }
......
...@@ -32,6 +32,9 @@ from dynamo.runtime import DistributedRuntime ...@@ -32,6 +32,9 @@ from dynamo.runtime import DistributedRuntime
from .sglang_prepost import ( from .sglang_prepost import (
SglangStreamingPostProcessor, SglangStreamingPostProcessor,
ToolCallParserType,
_get_history_tool_calls_count,
convert_tools,
create_parsers, create_parsers,
preprocess_chat_request, preprocess_chat_request,
) )
...@@ -117,11 +120,12 @@ def _init_worker( ...@@ -117,11 +120,12 @@ def _init_worker(
tool_call_parser_name: str | None, tool_call_parser_name: str | None,
reasoning_parser_name: str | None, reasoning_parser_name: str | None,
exclude_tools_when_tool_choice_none: bool = True, exclude_tools_when_tool_choice_none: bool = True,
trust_remote_code: bool = False,
) -> None: ) -> None:
"""Initialize a worker process with its own tokenizer.""" """Initialize a worker process with its own tokenizer."""
global _w_tokenizer, _w_tool_call_parser_name, _w_reasoning_parser_name global _w_tokenizer, _w_tool_call_parser_name, _w_reasoning_parser_name
global _w_exclude_tools_when_tool_choice_none global _w_exclude_tools_when_tool_choice_none
_w_tokenizer = get_tokenizer(model_path) _w_tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
_w_tool_call_parser_name = tool_call_parser_name _w_tool_call_parser_name = tool_call_parser_name
_w_reasoning_parser_name = reasoning_parser_name _w_reasoning_parser_name = reasoning_parser_name
_w_exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none _w_exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
...@@ -146,7 +150,12 @@ def _preprocess_worker( ...@@ -146,7 +150,12 @@ def _preprocess_worker(
raise PreprocessError(_unsupported_n_error(n)) raise PreprocessError(_unsupported_n_error(n))
dynamo_preproc = _build_dynamo_preproc( dynamo_preproc = _build_dynamo_preproc(
request, pre.prompt_token_ids, model_name, eos_token_id request,
pre.prompt_token_ids,
model_name,
eos_token_id,
pre.guided_decoding,
pre.tool_call_parser,
) )
return SglangPreprocessWorkerResult( return SglangPreprocessWorkerResult(
...@@ -161,6 +170,8 @@ def _build_dynamo_preproc( ...@@ -161,6 +170,8 @@ def _build_dynamo_preproc(
prompt_token_ids: list[int], prompt_token_ids: list[int],
model_name: str, model_name: str,
eos_token_id: int | None, eos_token_id: int | None,
guided_decoding: dict[str, Any] | None = None,
tool_call_parser: ToolCallParserType | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Build the Dynamo preprocessed request dict from request fields.""" """Build the Dynamo preprocessed request dict from request fields."""
max_tokens = request.get("max_completion_tokens") or request.get("max_tokens") max_tokens = request.get("max_completion_tokens") or request.get("max_tokens")
...@@ -205,11 +216,16 @@ def _build_dynamo_preproc( ...@@ -205,11 +216,16 @@ def _build_dynamo_preproc(
"top_k": request.get("top_k", 0) or -1, "top_k": request.get("top_k", 0) or -1,
"min_p": request.get("min_p", 0.0), "min_p": request.get("min_p", 0.0),
"seed": request.get("seed"), "seed": request.get("seed"),
"guided_decoding": guided_decoding,
}, },
"output_options": { "output_options": {
"logprobs": logprobs_val, "logprobs": logprobs_val,
"prompt_logprobs": None, "prompt_logprobs": None,
"skip_special_tokens": True, # Preserve special tokens only when a tool-call parser is
# actually active — the parser needs delimiter tokens
# (e.g. <|tool_call|>) to detect calls. Mirrors the
# post-processor's _skip_special_tokens logic.
"skip_special_tokens": tool_call_parser is None,
}, },
"eos_token_ids": [eos_token_id] if eos_token_id is not None else [], "eos_token_ids": [eos_token_id] if eos_token_id is not None else [],
"annotations": [], "annotations": [],
...@@ -320,7 +336,12 @@ class SglangProcessor: ...@@ -320,7 +336,12 @@ class SglangProcessor:
return return
dynamo_preproc = _build_dynamo_preproc( dynamo_preproc = _build_dynamo_preproc(
request, tokens, request["model"], self.eos_token_id request,
tokens,
request["model"],
self.eos_token_id,
pre.guided_decoding,
pre.tool_call_parser,
) )
except Exception as exc: except Exception as exc:
logger.exception("SGLang preprocessing failed for request %s", request_id) logger.exception("SGLang preprocessing failed for request %s", request_id)
...@@ -336,6 +357,11 @@ class SglangProcessor: ...@@ -336,6 +357,11 @@ class SglangProcessor:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
tool_call_parser=pre.tool_call_parser, tool_call_parser=pre.tool_call_parser,
reasoning_parser=pre.reasoning_parser, reasoning_parser=pre.reasoning_parser,
history_tool_calls_count=_get_history_tool_calls_count(
request.get("messages", [])
),
sglang_tools=convert_tools(request.get("tools")),
tool_call_parser_name=self.tool_call_parser_name,
) )
async for item in self._generate_and_stream( async for item in self._generate_and_stream(
...@@ -389,6 +415,11 @@ class SglangProcessor: ...@@ -389,6 +415,11 @@ class SglangProcessor:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
tool_call_parser=tool_call_parser, tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser, reasoning_parser=reasoning_parser,
history_tool_calls_count=_get_history_tool_calls_count(
request.get("messages", [])
),
sglang_tools=convert_tools(request.get("tools")),
tool_call_parser_name=self.tool_call_parser_name,
) )
async for item in self._generate_and_stream( async for item in self._generate_and_stream(
...@@ -530,6 +561,7 @@ class SglangEngineFactory: ...@@ -530,6 +561,7 @@ class SglangEngineFactory:
self.tool_call_parser_name = tool_call_parser_name self.tool_call_parser_name = tool_call_parser_name
self.reasoning_parser_name = reasoning_parser_name self.reasoning_parser_name = reasoning_parser_name
self.trust_remote_code = config.trust_remote_code
self.stream_interval = 20 self.stream_interval = 20
raw_stream_interval = os.getenv("DYN_SGLANG_STREAM_INTERVAL") raw_stream_interval = os.getenv("DYN_SGLANG_STREAM_INTERVAL")
if raw_stream_interval: if raw_stream_interval:
...@@ -560,7 +592,7 @@ class SglangEngineFactory: ...@@ -560,7 +592,7 @@ class SglangEngineFactory:
await fetch_model(source_path, ignore_weights=True) await fetch_model(source_path, ignore_weights=True)
logger.info("Loading SGLang tokenizer from %s", source_path) logger.info("Loading SGLang tokenizer from %s", source_path)
tokenizer = get_tokenizer(source_path) tokenizer = get_tokenizer(source_path, trust_remote_code=self.trust_remote_code)
eos_token_id = getattr(tokenizer, "eos_token_id", None) eos_token_id = getattr(tokenizer, "eos_token_id", None)
...@@ -610,6 +642,7 @@ class SglangEngineFactory: ...@@ -610,6 +642,7 @@ class SglangEngineFactory:
tool_call_parser_name, tool_call_parser_name,
reasoning_parser_name, reasoning_parser_name,
self.config.exclude_tools_when_tool_choice_none, self.config.exclude_tools_when_tool_choice_none,
self.trust_remote_code,
), ),
) )
futures = [ futures = [
......
...@@ -10,7 +10,11 @@ Parallels test_vllm_unit.py for the vLLM backend. ...@@ -10,7 +10,11 @@ Parallels test_vllm_unit.py for the vLLM backend.
""" """
import json
import pytest import pytest
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.srt.utils.hf_transformers_utils import get_tokenizer
import dynamo.frontend.sglang_processor as sglang_processor_module import dynamo.frontend.sglang_processor as sglang_processor_module
...@@ -18,6 +22,8 @@ from dynamo.frontend.sglang_prepost import ( ...@@ -18,6 +22,8 @@ from dynamo.frontend.sglang_prepost import (
SglangPreprocessResult, SglangPreprocessResult,
SglangStreamingPostProcessor, SglangStreamingPostProcessor,
_normalize_prompt_token_ids, _normalize_prompt_token_ids,
_parse_json_array_buffer,
build_tool_call_guided_decoding,
convert_tools, convert_tools,
create_parsers, create_parsers,
preprocess_chat_request, preprocess_chat_request,
...@@ -119,6 +125,18 @@ class TestBuildDynamoPreproc: ...@@ -119,6 +125,18 @@ class TestBuildDynamoPreproc:
assert sampling["repetition_penalty"] == 1.1 assert sampling["repetition_penalty"] == 1.1
assert sampling["seed"] == 42 assert sampling["seed"] == 42
def test_guided_decoding_passthrough(self):
    """A guided_decoding spec given to the builder appears unchanged in sampling_options."""
    preproc = _build_dynamo_preproc(
        {"model": "test"},
        prompt_token_ids=[1, 2, 3],
        model_name="test",
        eos_token_id=None,
        guided_decoding={"json": {"type": "object"}},
    )
    forwarded = preproc["sampling_options"]["guided_decoding"]
    assert forwarded == {"json": {"type": "object"}}
def test_stop_conditions_string(self): def test_stop_conditions_string(self):
"""Single stop string is wrapped in a list.""" """Single stop string is wrapped in a list."""
result = _build_dynamo_preproc( result = _build_dynamo_preproc(
...@@ -368,6 +386,92 @@ class TestCreateParsers: ...@@ -368,6 +386,92 @@ class TestCreateParsers:
assert tcp is None assert tcp is None
assert rp is not None assert rp is not None
class TestBuildToolCallGuidedDecoding:
    """Covers which guided-decoding spec ``build_tool_call_guided_decoding`` returns."""

    def test_none_when_no_tools(self):
        """No converted tools are supplied, so no guidance is expected."""
        guided = build_tool_call_guided_decoding(
            {"tool_choice": "auto"},
            tool_call_parser_name="hermes",
            sglang_tools=None,
        )
        assert guided is None

    def test_none_when_tool_choice_none(self):
        """``tool_choice="none"`` yields no guidance even with a declared tool."""
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {"type": "object", "properties": {}},
            },
        }
        sglang_tools = convert_tools([weather_tool])
        guided = build_tool_call_guided_decoding(
            {"tool_choice": "none"},
            tool_call_parser_name="hermes",
            sglang_tools=sglang_tools,
        )
        assert guided is None

    def test_required_tool_choice_builds_json_schema_guidance(self):
        """``tool_choice="required"`` produces a dict carrying a ``json`` entry."""
        weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
        sglang_tools = convert_tools([weather_tool])
        guided = build_tool_call_guided_decoding(
            {"tool_choice": "required"},
            tool_call_parser_name="hermes",
            sglang_tools=sglang_tools,
        )
        assert isinstance(guided, dict)
        assert "json" in guided

    def test_auto_strict_tools_can_build_structural_tag_guidance(self):
        """A strict tool under ``auto`` with the kimi_k2 parser gives ``structural_tag``."""
        strict_weather_tool = {
            "type": "function",
            "function": {
                "name": "get_weather",
                "strict": True,
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
        sglang_tools = convert_tools([strict_weather_tool])
        guided = build_tool_call_guided_decoding(
            {"tool_choice": "auto"},
            tool_call_parser_name="kimi_k2",
            sglang_tools=sglang_tools,
        )
        assert isinstance(guided, dict)
        assert "structural_tag" in guided
def test_tool_parser_requires_tools(self): def test_tool_parser_requires_tools(self):
"""Tool parser is not created if no tools in request.""" """Tool parser is not created if no tools in request."""
tcp, rp = create_parsers( tcp, rp = create_parsers(
...@@ -437,6 +541,170 @@ class TestCreateParsers: ...@@ -437,6 +541,170 @@ class TestCreateParsers:
assert tcp is not None assert tcp is not None
assert rp is not None assert rp is not None
def test_required_creates_json_array_parser(self):
    """With tool_choice='required' the tool parser is a JsonArrayParser."""
    tool = {
        "type": "function",
        "function": {
            "name": "f",
            "description": "d",
            "parameters": {},
        },
    }
    request = {"tools": [tool], "tool_choice": "required"}
    tool_parser, _reasoning_parser = create_parsers(
        request,
        tool_call_parser_name="hermes",
        reasoning_parser_name=None,
    )
    assert isinstance(tool_parser, JsonArrayParser)
def test_named_tool_choice_creates_json_array_parser(self):
    """A tool_choice naming a specific function also yields a JsonArrayParser."""
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {},
        },
    }
    named_choice = {
        "type": "function",
        "function": {"name": "get_weather"},
    }
    request = {"tools": [weather_tool], "tool_choice": named_choice}
    tool_parser, _reasoning_parser = create_parsers(
        request,
        tool_call_parser_name="hermes",
        reasoning_parser_name=None,
    )
    assert isinstance(tool_parser, JsonArrayParser)
def test_auto_creates_function_call_parser(self):
    """With tool_choice='auto' the tool parser is a FunctionCallParser."""
    tool = {
        "type": "function",
        "function": {
            "name": "f",
            "description": "d",
            "parameters": {},
        },
    }
    request = {"tools": [tool], "tool_choice": "auto"}
    tool_parser, _reasoning_parser = create_parsers(
        request,
        tool_call_parser_name="hermes",
        reasoning_parser_name=None,
    )
    assert isinstance(tool_parser, FunctionCallParser)
def test_required_without_parser_name_still_creates_json_array_parser(self):
    """tool_choice='required' gives a JsonArrayParser even with no parser name configured."""
    tool = {
        "type": "function",
        "function": {
            "name": "f",
            "description": "d",
            "parameters": {},
        },
    }
    request = {"tools": [tool], "tool_choice": "required"}
    tool_parser, _reasoning_parser = create_parsers(
        request,
        tool_call_parser_name=None,
        reasoning_parser_name=None,
    )
    assert isinstance(tool_parser, JsonArrayParser)
# ---------------------------------------------------------------------------
# _parse_json_array_buffer
# ---------------------------------------------------------------------------
class TestParseJsonArrayBuffer:
    """Exercises ``_parse_json_array_buffer`` on well-formed and noisy buffers."""

    def test_single_tool_call(self):
        """One array element becomes one call with tool_index 0."""
        payload = [{"name": "get_weather", "parameters": {"city": "NYC"}}]
        calls = _parse_json_array_buffer(json.dumps(payload))
        assert len(calls) == 1
        only = calls[0]
        assert only.name == "get_weather"
        assert only.tool_index == 0
        assert json.loads(only.parameters) == {"city": "NYC"}

    def test_multiple_tool_calls(self):
        """Two array elements become two calls indexed in array order."""
        payload = [
            {"name": "get_weather", "parameters": {"city": "NYC"}},
            {"name": "search", "parameters": {"q": "hello"}},
        ]
        calls = _parse_json_array_buffer(json.dumps(payload))
        assert len(calls) == 2
        first, second = calls[0], calls[1]
        assert first.name == "get_weather"
        assert first.tool_index == 0
        assert second.name == "search"
        assert second.tool_index == 1

    def test_arguments_key_also_accepted(self):
        """Some formats use 'arguments' instead of 'parameters'."""
        payload = [{"name": "f", "arguments": {"x": 1}}]
        calls = _parse_json_array_buffer(json.dumps(payload))
        assert len(calls) == 1
        assert json.loads(calls[0].parameters) == {"x": 1}

    def test_string_parameters_preserved(self):
        """A string-valued 'parameters' entry is returned as that same string."""
        payload = [{"name": "f", "parameters": "already_a_string"}]
        calls = _parse_json_array_buffer(json.dumps(payload))
        assert calls[0].parameters == "already_a_string"

    def test_invalid_json_returns_empty(self):
        """Text that is not JSON yields an empty list."""
        parsed = _parse_json_array_buffer("not json")
        assert parsed == []

    def test_non_array_returns_empty(self):
        """A JSON object (not an array) yields an empty list."""
        parsed = _parse_json_array_buffer('{"name": "f"}')
        assert parsed == []

    def test_empty_buffer_returns_empty(self):
        """An empty buffer yields an empty list."""
        parsed = _parse_json_array_buffer("")
        assert parsed == []

    def test_non_dict_items_skipped(self):
        """Non-dict elements are dropped; the kept call reports its array position."""
        payload = ["not_a_dict", {"name": "f", "parameters": {}}]
        calls = _parse_json_array_buffer(json.dumps(payload))
        assert len(calls) == 1
        kept = calls[0]
        assert kept.name == "f"
        assert kept.tool_index == 1

    def test_trailing_special_token(self):
        """Trailing EOS/special tokens should not break parsing."""
        noisy = '[{"name": "f", "parameters": {"x": 1}}]<|endoftext|>'
        calls = _parse_json_array_buffer(noisy)
        assert len(calls) == 1
        assert calls[0].name == "f"
        assert json.loads(calls[0].parameters) == {"x": 1}

    def test_leading_text_with_array(self):
        """Leading non-JSON text before the array should be tolerated."""
        noisy = 'some preamble [{"name": "f", "parameters": {"x": 1}}]'
        calls = _parse_json_array_buffer(noisy)
        assert len(calls) == 1
        assert calls[0].name == "f"

    def test_trailing_and_leading_noise(self):
        """Both leading and trailing noise."""
        noisy = 'text [{"name": "g", "parameters": {"y": 2}}] <|end|>'
        calls = _parse_json_array_buffer(noisy)
        assert len(calls) == 1
        assert calls[0].name == "g"
class TestNormalizePromptTokenIds: class TestNormalizePromptTokenIds:
def test_batch_encoding_like_object_uses_input_ids(self): def test_batch_encoding_like_object_uses_input_ids(self):
...@@ -641,6 +909,37 @@ class TestPreprocessChatRequest: ...@@ -641,6 +909,37 @@ class TestPreprocessChatRequest:
with_auto.prompt_token_ids with_auto.prompt_token_ids
), "tool_choice=none with flag off should keep tools in template" ), "tool_choice=none with flag off should keep tools in template"
def test_named_tool_choice_missing_function_raises(self, tokenizer):
    """A named tool_choice whose function is not among the tools raises ValueError."""
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
    }
    unknown_choice = {
        "type": "function",
        "function": {"name": "does_not_exist"},
    }
    request = {
        "model": MODEL,
        "messages": [{"role": "user", "content": "Hello"}],
        "tools": [weather_tool],
        "tool_choice": unknown_choice,
    }
    # The error message is expected to mention the missing function name.
    with pytest.raises(ValueError, match="does_not_exist"):
        preprocess_chat_request(
            request,
            tokenizer=tokenizer,
            tool_call_parser_name="hermes",
            reasoning_parser_name=None,
        )
def test_init_worker_propagates_exclude_flag_true(self): def test_init_worker_propagates_exclude_flag_true(self):
"""_init_worker sets the worker-global exclude_tools flag to True.""" """_init_worker sets the worker-global exclude_tools flag to True."""
_init_worker(MODEL, None, None, exclude_tools_when_tool_choice_none=True) _init_worker(MODEL, None, None, exclude_tools_when_tool_choice_none=True)
......
...@@ -15,6 +15,7 @@ import pytest ...@@ -15,6 +15,7 @@ import pytest
from sglang.srt.entrypoints.openai.protocol import Function as SglangFunction from sglang.srt.entrypoints.openai.protocol import Function as SglangFunction
from sglang.srt.entrypoints.openai.protocol import Tool as SglangTool from sglang.srt.entrypoints.openai.protocol import Tool as SglangTool
from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.srt.utils.hf_transformers_utils import get_tokenizer
...@@ -153,6 +154,114 @@ class TestSingleToolCall: ...@@ -153,6 +154,114 @@ class TestSingleToolCall:
assert tc[0]["index"] == 0 assert tc[0]["index"] == 0
class TestKimiToolCallIds:
    """Tool-call ids emitted when the post-processor is configured for ``kimi_k2``."""

    def test_kimi_uses_history_adjusted_ids(self):
        """Ids of streamed calls are offset by ``history_tool_calls_count``."""

        class DummyTokenizer:
            def decode(self, token_ids, skip_special_tokens=True):
                return "".join(map(chr, token_ids))

        class DummyToolCall:
            def __init__(self, tool_index, name, parameters):
                self.tool_index = tool_index
                self.name = name
                self.parameters = parameters

        class DummyParser:
            tool_call_parser = "kimi_k2"
            detector = type("Detector", (), {"_buffer": ""})()

            def parse_stream_chunk(self, text):
                # Both calls are reported by the streaming parse itself.
                streamed = [
                    DummyToolCall(0, "get_weather", '{"city":"Paris"}'),
                    DummyToolCall(
                        1, "search_gutenberg_books", '{"search_terms":["Joyce"]}'
                    ),
                ]
                return "", streamed

        processor = SglangStreamingPostProcessor(
            tokenizer=DummyTokenizer(),
            tool_call_parser=DummyParser(),
            reasoning_parser=None,
            history_tool_calls_count=3,
            tool_call_parser_name="kimi_k2",
        )
        finish_event = {
            "token_ids": [ord("x")],
            "finish_reason": "stop",
        }
        choice = processor.process_output(finish_event)

        tool_calls = choice["delta"]["tool_calls"]
        emitted_ids = [item["id"] for item in tool_calls]
        assert emitted_ids == [
            "functions.get_weather:3",
            "functions.search_gutenberg_books:4",
        ]

    def test_kimi_reparse_uses_sequential_index_not_tool_index(self):
        """kimi_k2 ids after a re-parse follow output position, not ``tool_index``.

        The dummy ``parse_non_stream`` below returns non-sequential
        ``tool_index`` values (5, 2). The emitted ids must line up with the
        emitted ``index`` field (0, 1) plus the history count (3), so they
        are expected to be built from the call's sequential position.
        """

        class DummyTokenizer:
            def decode(self, token_ids, skip_special_tokens=True):
                return "".join(map(chr, token_ids))

        class DummyToolCall:
            def __init__(self, tool_index, name, parameters):
                self.tool_index = tool_index
                self.name = name
                self.parameters = parameters

        class DummyParser:
            tool_call_parser = "kimi_k2"
            detector = type("Detector", (), {"_buffer": ""})()

            def parse_stream_chunk(self, text):
                # Streaming reports no calls, which forces the re-parse path.
                return "", []

            def has_tool_call(self, text):
                return True

            def parse_non_stream(self, text):
                # Deliberately non-sequential tool_index values.
                reparsed = [
                    DummyToolCall(5, "get_weather", '{"city":"Paris"}'),
                    DummyToolCall(2, "search_gutenberg_books", '{"q":"Joyce"}'),
                ]
                return "", reparsed

        processor = SglangStreamingPostProcessor(
            tokenizer=DummyTokenizer(),
            tool_call_parser=DummyParser(),
            reasoning_parser=None,
            history_tool_calls_count=3,
            tool_call_parser_name="kimi_k2",
        )
        finish_event = {
            "token_ids": [ord("x")],
            "finish_reason": "stop",
        }
        choice = processor.process_output(finish_event)

        tool_calls = choice["delta"]["tool_calls"]
        # Ids use the sequential position (0, 1) + history (3), not tool_index (5, 2).
        emitted_ids = [item["id"] for item in tool_calls]
        assert emitted_ids == [
            "functions.get_weather:3",
            "functions.search_gutenberg_books:4",
        ]
        emitted_indexes = [item["index"] for item in tool_calls]
        assert emitted_indexes == [0, 1]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# No reasoning parser # No reasoning parser
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -276,3 +385,159 @@ class TestNoToolCalls: ...@@ -276,3 +385,159 @@ class TestNoToolCalls:
c = r.get("delta", {}).get("content", "") c = r.get("delta", {}).get("content", "")
content += c content += c
assert "Hello, world!" in content assert "Hello, world!" in content
# ---------------------------------------------------------------------------
# Single-chunk tool calls (finish-time re-parse fallback)
# ---------------------------------------------------------------------------
class TestSingleChunkFallback:
    """Tool calls whose tokens and finish signal arrive in one batch.

    Every test feeds the complete response plus ``finish_reason`` to a single
    ``process_output`` call and checks that the tool calls (with their
    arguments) are still reported.
    """

    TEXT = (
        "<think>\nLet me search for books.\n</think>\n\n"
        '<tool_call>\n{"name": "search_gutenberg_books", '
        '"arguments": {"search_terms": ["James Joyce"]}}\n</tool_call>'
    )

    @staticmethod
    def _finish_in_one_batch(tokenizer, text, *, with_reasoning):
        """Run ``text`` through a fresh hermes post-processor in one finishing call."""
        tool_parser = FunctionCallParser(tools=TOOLS, tool_call_parser="hermes")
        reasoning = None
        if with_reasoning:
            reasoning = ReasoningParser(model_type="qwen3", stream_reasoning=True)
        processor = SglangStreamingPostProcessor(
            tokenizer=tokenizer,
            tool_call_parser=tool_parser,
            reasoning_parser=reasoning,
        )
        token_ids = tokenizer.encode(text)
        return processor.process_output(
            {"token_ids": token_ids, "finish_reason": "stop"}
        )

    def test_all_tokens_plus_finish_in_one_batch(self, tokenizer):
        """Entire response + finish in a single process_output call."""
        choice = self._finish_in_one_batch(tokenizer, self.TEXT, with_reasoning=True)
        assert choice is not None

        tool_calls = choice.get("delta", {}).get("tool_calls", [])
        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
        function = tool_calls[0]["function"]
        assert function["name"] == "search_gutenberg_books"
        assert json.loads(function["arguments"]) == {"search_terms": ["James Joyce"]}

    def test_multiple_tools_single_chunk(self, tokenizer):
        """Multiple tool calls in one chunk -- re-parse must find all."""
        text = (
            "<think>\nI'll search and check weather.\n</think>\n\n"
            '<tool_call>\n{"name": "search_gutenberg_books", '
            '"arguments": {"search_terms": ["Joyce"]}}\n</tool_call>\n'
            '<tool_call>\n{"name": "get_weather", '
            '"arguments": {"city": "London"}}\n</tool_call>'
        )
        choice = self._finish_in_one_batch(tokenizer, text, with_reasoning=True)
        assert choice is not None

        tool_calls = choice.get("delta", {}).get("tool_calls", [])
        assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"
        seen_names = {call["function"]["name"] for call in tool_calls}
        assert seen_names == {"search_gutenberg_books", "get_weather"}
        for call in tool_calls:
            parsed_args = json.loads(call["function"]["arguments"])
            assert (
                parsed_args
            ), f"Arguments should not be empty for {call['function']['name']}"

    def test_finish_reason_rewritten_to_tool_calls(self, tokenizer):
        """finish_reason should be 'tool_calls' when re-parse finds calls."""
        text = (
            '<tool_call>\n{"name": "get_weather", '
            '"arguments": {"city": "NYC"}}\n</tool_call>'
        )
        choice = self._finish_in_one_batch(tokenizer, text, with_reasoning=False)
        assert choice is not None
        assert choice["finish_reason"] == "tool_calls"
# ---------------------------------------------------------------------------
# JsonArrayParser path (tool_choice="required" / named function)
# ---------------------------------------------------------------------------
class TestJsonArrayParserReparse:
    """Finish-time handling when the post-processor wraps a ``JsonArrayParser``.

    These tests build ``SglangStreamingPostProcessor`` with a
    ``JsonArrayParser`` (the parser expected for ``tool_choice="required"``
    or a named function) and feed it a raw JSON array in a single finishing
    chunk. They pin the observable outcome of that path so an SGLang upgrade
    cannot silently change it.
    """

    @staticmethod
    def _finish_in_one_batch(tokenizer, text):
        """Feed ``text`` plus a stop signal to a fresh JsonArrayParser post-processor."""
        processor = SglangStreamingPostProcessor(
            tokenizer=tokenizer,
            tool_call_parser=JsonArrayParser(),
            reasoning_parser=None,
            sglang_tools=TOOLS,
        )
        token_ids = tokenizer.encode(text)
        return processor.process_output(
            {"token_ids": token_ids, "finish_reason": "stop"}
        )

    def test_single_call_reparse(self, tokenizer):
        """Full JSON array arriving in one chunk triggers the re-parse."""
        text = '[{"name": "get_weather", "parameters": {"city": "NYC"}}]'
        choice = self._finish_in_one_batch(tokenizer, text)
        assert choice is not None

        tool_calls = choice.get("delta", {}).get("tool_calls", [])
        assert len(tool_calls) == 1
        function = tool_calls[0]["function"]
        assert function["name"] == "get_weather"
        assert json.loads(function["arguments"]) == {"city": "NYC"}
        assert choice["finish_reason"] == "tool_calls"

    def test_multiple_calls_reparse(self, tokenizer):
        """Multiple calls in one chunk; re-parse must recover all."""
        text = (
            '[{"name": "search_gutenberg_books", '
            '"parameters": {"search_terms": ["Joyce"]}}, '
            '{"name": "get_weather", "parameters": {"city": "London"}}]'
        )
        choice = self._finish_in_one_batch(tokenizer, text)
        assert choice is not None

        tool_calls = choice.get("delta", {}).get("tool_calls", [])
        assert len(tool_calls) == 2
        seen_names = {call["function"]["name"] for call in tool_calls}
        assert seen_names == {"search_gutenberg_books", "get_weather"}

    def test_plain_text_skips_reparse(self, tokenizer):
        """Plain text with no JSON markers must not crash the re-parse path.

        The input contains neither '[' nor '{'; the post-processor must
        return without raising and without reporting any tool calls.
        """
        choice = self._finish_in_one_batch(tokenizer, "Hello, world!")
        # ``choice`` may be None here, hence the ``or {}`` fallback.
        tool_calls = (choice or {}).get("delta", {}).get("tool_calls", [])
        assert tool_calls == []
...@@ -448,7 +448,7 @@ class EngineFactory: ...@@ -448,7 +448,7 @@ class EngineFactory:
tokenizer_mode = getattr(self.flags, "tokenizer_mode", None) or "auto" tokenizer_mode = getattr(self.flags, "tokenizer_mode", None) or "auto"
config_format = getattr(self.flags, "config_format", None) or "auto" config_format = getattr(self.flags, "config_format", None) or "auto"
load_format = getattr(self.flags, "load_format", None) or "dummy" load_format = getattr(self.flags, "load_format", None) or "dummy"
trust_remote_code = getattr(self.flags, "trust_remote_code", False) trust_remote_code = self.config.trust_remote_code
enable_auto_tool_choice = getattr(self.flags, "enable_auto_tool_choice", False) enable_auto_tool_choice = getattr(self.flags, "enable_auto_tool_choice", False)
model_config = ModelConfig( model_config = ModelConfig(
......
...@@ -1062,6 +1062,11 @@ class BaseWorkerHandler(LoraMixin, RLMixin, BaseGenerativeHandler[RequestT, Resp ...@@ -1062,6 +1062,11 @@ class BaseWorkerHandler(LoraMixin, RLMixin, BaseGenerativeHandler[RequestT, Resp
json_schema = guided_decoding.get("json") json_schema = guided_decoding.get("json")
if json_schema is not None: if json_schema is not None:
return {"json_schema": json.dumps(json_schema)} return {"json_schema": json.dumps(json_schema)}
structural_tag = guided_decoding.get("structural_tag")
if structural_tag is not None:
if hasattr(structural_tag, "model_dump"):
structural_tag = structural_tag.model_dump()
return {"structural_tag": json.dumps(structural_tag)}
return {} return {}
@staticmethod @staticmethod
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""End-to-end tool calling tests against the Dynamo frontend.
Validates:
- Streaming protocol shape (chat.completion.chunk SSE)
- Tool-call reconstruction from streamed deltas
- Tool-call argument JSON validated against the declared JSON Schema
- tool_choice variants: auto / required / none / named function
- Multi-turn conversations carrying tool results
- Multi-tool parallel calls
"""
from __future__ import annotations
import json
import logging
import os
import shutil
import time
from dataclasses import dataclass
from typing import Any, Generator
import pytest
from tests.conftest import EtcdServer, NatsServer
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_models_api
from tests.utils.port_utils import allocate_ports
# Skip the whole module (rather than fail at import) when the optional
# ``openai`` client package is not installed.
openai = pytest.importorskip("openai")
OpenAI = openai.OpenAI
# Same treatment for ``jsonschema``.
jsonschema = pytest.importorskip("jsonschema")
Draft7Validator = jsonschema.Draft7Validator
logger = logging.getLogger(__name__)
# Model identifier passed to the worker as both --model-path and
# --served-model-name, and returned by the ``model`` fixture.
MODEL_NAME = "Qwen/Qwen3-0.6B"
# Markers applied to every test collected from this module.
# NOTE(review): timeout(300) is presumably seconds (pytest-timeout) — confirm.
pytestmark = [
    pytest.mark.sglang,
    pytest.mark.e2e,
    pytest.mark.gpu_1,
    pytest.mark.integration,
    pytest.mark.pre_merge,
    pytest.mark.model(MODEL_NAME),
    pytest.mark.timeout(300),
]
# ---------------------------------------------------------------------------
# Process management
# ---------------------------------------------------------------------------
def _check_ready(response) -> bool:
try:
return (response.json() or {}).get("status") == "ready"
except ValueError:
return False
def _prepare_log_dir(request, suffix: str) -> str:
log_dir = f"{request.node.name}_{suffix}"
shutil.rmtree(log_dir, ignore_errors=True)
return log_dir
class WorkerProcess(ManagedProcess):
    """SGLang backend worker (``python3 -m dynamo.sglang``) for the tool-calling tests."""

    def __init__(self, request, *, system_port: int):
        # Inherit the current environment and layer the worker settings on top.
        worker_env = {
            **os.environ,
            "DYN_LOG": "info",
            "DYN_SYSTEM_PORT": str(system_port),
            "DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS": '["generate"]',
        }
        launch_command = [
            "python3",
            "-m",
            "dynamo.sglang",
            "--model-path",
            MODEL_NAME,
            "--served-model-name",
            MODEL_NAME,
            "--trust-remote-code",
        ]
        # The worker is considered up once /health reports status "ready".
        health_url = f"http://localhost:{system_port}/health"
        super().__init__(
            command=launch_command,
            env=worker_env,
            health_check_urls=[(health_url, _check_ready)],
            timeout=600,
            display_output=True,
            terminate_all_matching_process_names=False,
            stragglers=["SGLANG:EngineCore"],
            straggler_commands=["-m dynamo.sglang"],
            log_dir=_prepare_log_dir(request, "sglang-worker"),
        )
class ToolCallingFrontendProcess(ManagedProcess):
    """Frontend HTTP ingress (``python3 -m dynamo.frontend``) for the tool-calling tests.

    The command line always selects the SGLang chat processor
    (``--dyn-chat-processor sglang``) together with the ``qwen25`` tool-call
    parser, the ``qwen3`` reasoning parser and ``--trust-remote-code``.
    These flags are attached unconditionally, so the frontend is expected to
    fail to start in an environment where ``sglang`` is not importable.
    """

    def __init__(self, request, *, frontend_port: int):
        env = os.environ.copy()
        env["DYN_LOG"] = "info"
        # Drop any inherited DYN_SYSTEM_PORT from the frontend's environment.
        # NOTE(review): presumably to avoid clashing with the worker's system
        # port — confirm.
        env.pop("DYN_SYSTEM_PORT", None)
        command = [
            "python3",
            "-m",
            "dynamo.frontend",
            "--http-port",
            str(frontend_port),
            "--router-mode",
            "round-robin",
            "--dyn-chat-processor",
            "sglang",
            "--tool-call-parser",
            "qwen25",
            "--reasoning-parser",
            "qwen3",
            "--trust-remote-code",
        ]
        super().__init__(
            command=command,
            env=env,
            # Health is judged by check_models_api on the /v1/models endpoint.
            health_check_urls=[
                (f"http://localhost:{frontend_port}/v1/models", check_models_api),
            ],
            timeout=240,
            display_output=True,
            terminate_all_matching_process_names=False,
            straggler_commands=["-m dynamo.frontend"],
            log_dir=_prepare_log_dir(request, "frontend"),
        )
@pytest.fixture(scope="module")
def runtime_services(request) -> Generator[None, None, None]:
    """Run module-scoped NATS and Etcd servers and point the environment at them.

    Kept module-scoped (instead of reusing the function-scoped
    ``runtime_services_dynamic_ports``) so the worker and frontend processes
    can be shared by every test in this module. ``NATS_SERVER`` and
    ``ETCD_ENDPOINTS`` are restored to their previous state on teardown.
    """
    with NatsServer(request, port=0) as nats, EtcdServer(request, port=0) as etcd:
        overrides = {
            "NATS_SERVER": f"nats://localhost:{nats.port}",
            "ETCD_ENDPOINTS": f"http://localhost:{etcd.port}",
        }
        # Remember the prior values (None means "was unset") before overriding.
        previous = {key: os.environ.get(key) for key in overrides}
        for key, value in overrides.items():
            os.environ[key] = value
        try:
            yield
        finally:
            for key, old_value in previous.items():
                if old_value is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old_value
@pytest.fixture(scope="module")
def tool_calling_services(
    request, runtime_services, predownload_models
) -> Generator[int, None, None]:
    """Bring up the SGLang worker and the tool-calling frontend for the module.

    Yields:
        The HTTP port the frontend listens on.
    """
    allocated = allocate_ports(count=2, start_port=10000)
    frontend_port, system_port = allocated

    worker = WorkerProcess(request, system_port=system_port)
    with worker:
        # Allow worker to register with discovery.
        time.sleep(2)
        frontend = ToolCallingFrontendProcess(request, frontend_port=frontend_port)
        with frontend:
            logger.info(
                "Tool calling stack ready (frontend=%d worker_system=%d)",
                frontend_port,
                system_port,
            )
            yield frontend_port
@pytest.fixture(scope="module")
def client(tool_calling_services: int) -> OpenAI:
    """OpenAI client pointed at the local frontend's ``/v1`` endpoint."""
    base_url = f"http://localhost:{tool_calling_services}/v1"
    return OpenAI(api_key="EMPTY", base_url=base_url)
@pytest.fixture(scope="module")
def model() -> str:
    """Name of the model served by the stack under test."""
    served_model = MODEL_NAME
    return served_model
# ---------------------------------------------------------------------------
# Tool definitions
# ---------------------------------------------------------------------------
# Each TOOLS_* constant is a one-element list holding a single function tool
# definition with a JSON Schema under "parameters". Every schema (including
# nested object schemas) sets "additionalProperties": True.

# get_weather: required "city"; optional "unit" limited to celsius/fahrenheit.
TOOLS_WEATHER = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city"],
                "additionalProperties": True,
            },
        },
    }
]
# search_web: required "query"; optional integer "num_results".
TOOLS_SEARCH = [
    {
        "type": "function",
        "function": {
            "name": "search_web",
            "description": "Search the web for information",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "num_results": {"type": "integer"},
                },
                "required": ["query"],
                "additionalProperties": True,
            },
        },
    }
]
# calculate: single required string "expression".
TOOLS_CALCULATOR = [
    {
        "type": "function",
        "function": {
            "name": "calculate",
            "description": "Evaluate a mathematical expression",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {"type": "string"},
                },
                "required": ["expression"],
                "additionalProperties": True,
            },
        },
    }
]
# create_event: nested schema — an array of attendee objects and a
# "recurrence" object, in addition to the required title/start/end strings.
TOOLS_COMPLEX_ARGS = [
    {
        "type": "function",
        "function": {
            "name": "create_event",
            "description": "Create a calendar event with attendees and recurrence",
            "parameters": {
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "start_time": {"type": "string"},
                    "end_time": {"type": "string"},
                    "attendees": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "name": {"type": "string"},
                                "email": {"type": "string"},
                                "role": {
                                    "type": "string",
                                    "enum": ["required", "optional", "organizer"],
                                },
                            },
                            "required": ["email"],
                            "additionalProperties": True,
                        },
                    },
                    "recurrence": {
                        "type": "object",
                        "properties": {
                            "frequency": {
                                "type": "string",
                                "enum": ["daily", "weekly", "monthly"],
                            },
                            "interval": {"type": "integer"},
                            "count": {"type": "integer"},
                        },
                        "additionalProperties": True,
                    },
                    "location": {"type": "string"},
                    "description": {"type": "string"},
                },
                "required": ["title", "start_time", "end_time"],
                "additionalProperties": True,
            },
        },
    }
]
# query_database: required "sql" and "database" (users/orders/products);
# optional "params" array of strings.
TOOLS_DATABASE = [
    {
        "type": "function",
        "function": {
            "name": "query_database",
            "description": "Execute a SQL query against the database",
            "parameters": {
                "type": "object",
                "properties": {
                    "sql": {"type": "string"},
                    "params": {
                        "type": "array",
                        "items": {"type": "string"},
                    },
                    "database": {
                        "type": "string",
                        "enum": ["users", "orders", "products"],
                    },
                },
                "required": ["sql", "database"],
                "additionalProperties": True,
            },
        },
    }
]
# get_time: single required string "timezone".
TOOLS_GET_TIME = [
    {
        "type": "function",
        "function": {
            "name": "get_time",
            "description": "Get the current time in a timezone",
            "parameters": {
                "type": "object",
                "properties": {
                    "timezone": {"type": "string"},
                },
                "required": ["timezone"],
                "additionalProperties": True,
            },
        },
    }
]
# Concatenation of five of the lists above, in this order.
# NOTE(review): TOOLS_GET_TIME is not part of ALL_TOOLS — confirm that is
# intentional.
ALL_TOOLS = (
    TOOLS_WEATHER
    + TOOLS_SEARCH
    + TOOLS_CALCULATOR
    + TOOLS_COMPLEX_ARGS
    + TOOLS_DATABASE
)
# ---------------------------------------------------------------------------
# Streaming helpers
# ---------------------------------------------------------------------------
def tool_schema_map(tools: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Index tool definitions by function name.

    Args:
        tools: Tool definitions in the chat-completions format, i.e. dicts
            with a ``"function"`` entry that carries at least ``"name"``.

    Returns:
        A mapping from function name to its ``parameters`` JSON Schema. A
        function that omits ``parameters`` maps to an empty object schema,
        since the key is optional in the tool format. When two tools share a
        name, the later one wins.
    """
    return {
        fn["name"]: fn.get("parameters", {"type": "object", "properties": {}})
        for fn in (tool["function"] for tool in tools)
    }
@dataclass
class StreamResult:
    """Aggregate of one streamed chat completion, as built by ``collect_stream``."""

    # Concatenation of every non-empty ``delta.content`` fragment.
    content: str
    # Concatenation of every non-empty ``delta.reasoning_content`` fragment.
    reasoning_content: str
    # Merged tool calls ({"id", "type", "function": {"name", "arguments"}}),
    # ordered by their streaming index.
    tool_calls: list[dict[str, Any]]
    # Last non-empty ``finish_reason`` seen on any choice, or None if none was sent.
    finish_reason: str | None
    # ``model`` attribute of the first chunk ("" when the stream was empty).
    model: str
    # Number of chunks received.
    chunks: int
    # Milliseconds between the start of stream consumption and the first chunk.
    ttft_ms: float
    # Every chunk object, in arrival order.
    raw_chunks: list[Any]
def collect_stream(stream) -> StreamResult:
    """Drain a streaming chat completion into a single ``StreamResult``.

    Content and reasoning fragments are joined in arrival order. Tool-call
    deltas are merged per ``index``: ``id``, ``type`` and function ``name``
    are overwritten when present, while ``arguments`` fragments are appended.
    ``model`` and ``ttft_ms`` are taken when the first chunk arrives.

    Raises:
        AssertionError: If the id or the function name of a tool call changes
            between two deltas that share the same index.
    """
    started = time.monotonic()
    first_chunk_ms = 0.0
    model_name = ""
    last_finish = None
    seen: list[Any] = []
    text: list[str] = []
    reasoning: list[str] = []
    calls: dict[int, dict[str, Any]] = {}

    for chunk in stream:
        seen.append(chunk)
        if len(seen) == 1:
            first_chunk_ms = (time.monotonic() - started) * 1000.0
            model_name = chunk.model

        for choice in chunk.choices:
            delta = choice.delta

            piece = getattr(delta, "content", None)
            if piece:
                text.append(piece)

            thought = getattr(delta, "reasoning_content", None)
            if thought:
                reasoning.append(thought)

            for call_delta in getattr(delta, "tool_calls", None) or ():
                position = call_delta.index
                slot = calls.setdefault(
                    position,
                    {
                        "id": "",
                        "type": "function",
                        "function": {"name": "", "arguments": ""},
                    },
                )

                if call_delta.id:
                    known_id = slot["id"]
                    if known_id and known_id != call_delta.id:
                        raise AssertionError(
                            f"Tool call id changed within same index {position}: "
                            f"{known_id} -> {call_delta.id}"
                        )
                    slot["id"] = call_delta.id

                if call_delta.type:
                    slot["type"] = call_delta.type

                fn_delta = call_delta.function
                if not fn_delta:
                    continue

                merged_fn = slot["function"]
                if fn_delta.name:
                    known_name = merged_fn["name"]
                    if known_name and known_name != fn_delta.name:
                        raise AssertionError(
                            f"Tool name changed within same index {position}: "
                            f"{known_name} -> {fn_delta.name}"
                        )
                    merged_fn["name"] = fn_delta.name
                if fn_delta.arguments:
                    # Argument JSON arrives in fragments; concatenate them.
                    merged_fn["arguments"] += fn_delta.arguments

            if choice.finish_reason:
                last_finish = choice.finish_reason

    # Emit tool calls in ascending index order, not arrival order.
    by_index = [call for _, call in sorted(calls.items(), key=lambda item: item[0])]
    return StreamResult(
        content="".join(text),
        reasoning_content="".join(reasoning),
        tool_calls=by_index,
        finish_reason=last_finish,
        model=model_name,
        chunks=len(seen),
        ttft_ms=first_chunk_ms,
        raw_chunks=seen,
    )
def stream_chat(
    client: OpenAI,
    model: str,
    *,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    max_tokens: int = 4096,
    **kwargs,
) -> StreamResult:
    """Send one streaming chat-completion request and aggregate the reply.

    Args:
        client: Client whose ``chat.completions.create`` is called.
        model: Model name placed in the request.
        messages: Conversation history to send.
        tools: Tool definitions; the ``tools`` key is left out when ``None``.
        max_tokens: Completion token limit.
        **kwargs: Extra request fields. They are applied last, so they
            override any of the fields above.

    Returns:
        The ``StreamResult`` that ``collect_stream`` builds from the stream.
    """
    tool_fields = {} if tools is None else {"tools": tools}
    request: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
        "max_tokens": max_tokens,
        **tool_fields,
        **kwargs,
    }
    return collect_stream(client.chat.completions.create(**request))
def parse_and_validate_tool_call(
tc: dict[str, Any],
schema_by_name: dict[str, dict[str, Any]],
*,
expected_name: str | None = None,
) -> dict[str, Any]:
assert tc["type"] == "function", f"unexpected tool type: {tc['type']!r}"
assert tc["id"], "tool call id must be non-empty"
fn_name = tc["function"]["name"]
assert fn_name, "tool call function name must be non-empty"
if expected_name is not None:
assert fn_name == expected_name, f"expected {expected_name!r}, got {fn_name!r}"
assert fn_name in schema_by_name, f"unknown tool name {fn_name!r}"
args_str = tc["function"]["arguments"]
assert args_str, "tool call arguments must be non-empty"
try:
args = json.loads(args_str)
except json.JSONDecodeError as e:
raise AssertionError(f"arguments are not valid JSON: {args_str!r}") from e
assert isinstance(args, dict), f"arguments must decode to object, got {type(args)}"
validator = Draft7Validator(schema_by_name[fn_name])
errors = sorted(validator.iter_errors(args), key=lambda e: list(e.path))
if errors:
rendered = "; ".join(
f"path={list(err.path)} message={err.message}" for err in errors
)
raise AssertionError(f"arguments failed schema validation: {rendered}")
return args
def assert_finish_reason(result: StreamResult, allowed: set[str]) -> None:
    """Assert that ``result.finish_reason`` is one of ``allowed``.

    The failure message shows the actual value and the sorted allowed set.
    """
    actual = result.finish_reason
    assert (
        actual in allowed
    ), f"unexpected finish_reason={actual!r}, allowed={sorted(allowed)}"
def assistant_tool_message_from_result(result: StreamResult) -> dict[str, Any]:
    """Turn a streamed result into an assistant message for the next turn.

    Empty content is sent as ``None``; the tool-call list is passed through
    as-is (same list object, not a copy).
    """
    text = result.content if result.content else None
    message: dict[str, Any] = {"role": "assistant"}
    message["content"] = text
    message["tool_calls"] = result.tool_calls
    return message
# ---------------------------------------------------------------------------
# Protocol / contract tests
# ---------------------------------------------------------------------------
class TestToolCallingProtocol:
    """Protocol-level contract tests for streamed tool calling.

    Each test sends a request through the ``client`` / ``model`` fixtures and
    checks chunk shape, ``finish_reason`` and the aggregated tool calls. Every
    test that reads ``result.tool_calls[0]`` first asserts the list is
    non-empty, so an empty list fails as an assertion rather than an
    ``IndexError``.
    """

    def test_stream_has_required_chunk_shape(self, client: OpenAI, model: str):
        """Every chunk has the standard envelope and the stream reports a finish_reason."""
        stream = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "What's the weather in Berlin?"}],
            tools=TOOLS_WEATHER,
            stream=True,
            max_tokens=256,
        )
        chunk_count = 0
        saw_finish = False
        for chunk in stream:
            chunk_count += 1
            assert chunk.id
            # Only the type is enforced; the reported name is not required to
            # equal the requested ``model``.
            assert isinstance(chunk.model, str)
            assert chunk.object == "chat.completion.chunk"
            assert chunk.created > 0
            assert len(chunk.choices) >= 1
            for choice in chunk.choices:
                assert choice.index == 0
                if choice.finish_reason is not None:
                    saw_finish = True
                    assert choice.finish_reason in {"stop", "tool_calls", "length"}
        assert chunk_count > 0
        assert saw_finish, "stream never emitted a finish_reason"

    def test_single_tool_call_schema_valid(self, client: OpenAI, model: str):
        """A weather question yields a schema-valid ``get_weather`` call."""
        result = stream_chat(
            client,
            model,
            messages=[{"role": "user", "content": "What's the weather in Tokyo?"}],
            tools=TOOLS_WEATHER,
        )
        assert_finish_reason(result, {"tool_calls"})
        assert len(result.tool_calls) >= 1
        schema = tool_schema_map(TOOLS_WEATHER)
        args = parse_and_validate_tool_call(
            result.tool_calls[0], schema, expected_name="get_weather"
        )
        assert "city" in args
        assert isinstance(args["city"], str)
        assert args["city"]

    def test_tool_choice_required_forces_a_tool_call(self, client: OpenAI, model: str):
        """``tool_choice="required"`` produces a tool call even for small talk."""
        result = stream_chat(
            client,
            model,
            messages=[{"role": "user", "content": "Hello there."}],
            tools=TOOLS_WEATHER,
            tool_choice="required",
        )
        assert_finish_reason(result, {"tool_calls"})
        assert len(result.tool_calls) >= 1
        # Intent of this test: verify tool_choice=required forces a call.
        # The prompt doesn't warrant a tool call, so a small model may
        # hallucinate values for optional fields. Validate only that the
        # call is well-formed and the required fields are present; don't
        # enforce the full schema (e.g. enum values on optional fields).
        schema = tool_schema_map(TOOLS_WEATHER)
        for tc in result.tool_calls:
            assert tc["type"] == "function"
            assert tc["id"]
            fn_name = tc["function"]["name"]
            assert fn_name in schema, f"unknown tool name {fn_name!r}"
            args = json.loads(tc["function"]["arguments"])
            assert isinstance(args, dict)
            for required_field in schema[fn_name].get("required", []):
                assert (
                    required_field in args
                ), f"{fn_name} missing required field {required_field!r}"

    def test_tool_choice_none_suppresses_tool_calls(self, client: OpenAI, model: str):
        """``tool_choice="none"`` yields plain text and no tool calls."""
        result = stream_chat(
            client,
            model,
            messages=[{"role": "user", "content": "What's the weather in Paris?"}],
            tools=TOOLS_WEATHER,
            tool_choice="none",
        )
        assert_finish_reason(result, {"stop"})
        assert result.tool_calls == []
        assert result.content.strip()

    def test_named_tool_choice_forces_specific_function(
        self, client: OpenAI, model: str
    ):
        """A named ``tool_choice`` makes every emitted call target that function."""
        result = stream_chat(
            client,
            model,
            messages=[{"role": "user", "content": "What's the weather in Paris?"}],
            tools=TOOLS_WEATHER,
            tool_choice={"type": "function", "function": {"name": "get_weather"}},
        )
        assert_finish_reason(result, {"tool_calls"})
        assert len(result.tool_calls) >= 1
        schema = tool_schema_map(TOOLS_WEATHER)
        for tc in result.tool_calls:
            parse_and_validate_tool_call(tc, schema, expected_name="get_weather")

    def test_parallel_multi_tool_request_includes_all_expected_tools(
        self, client: OpenAI, model: str
    ):
        """A three-task prompt produces valid calls to at least two distinct tools."""
        result = stream_chat(
            client,
            model,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Do all three of these with tools: "
                        "1) weather in Paris, "
                        "2) search the web for latest Python release, "
                        "3) calculate 15 * 23 + 7."
                    ),
                }
            ],
            tools=TOOLS_WEATHER + TOOLS_SEARCH + TOOLS_CALCULATOR,
            parallel_tool_calls=True,
        )
        assert_finish_reason(result, {"tool_calls"})
        # Models sometimes batch only a subset and emit follow-up calls in
        # later turns; require at least 2 distinct tools rather than all 3.
        schemas = tool_schema_map(TOOLS_WEATHER + TOOLS_SEARCH + TOOLS_CALCULATOR)
        names: set[str] = set()
        for tc in result.tool_calls:
            parse_and_validate_tool_call(tc, schemas)
            names.add(tc["function"]["name"])
        assert len(names) >= 2, f"expected at least 2 distinct tools, got {names}"

    def test_tool_call_ids_unique_in_single_response(self, client: OpenAI, model: str):
        """Tool-call ids within one response are non-empty and pairwise distinct."""
        result = stream_chat(
            client,
            model,
            messages=[
                {
                    "role": "user",
                    "content": "Get weather for New York, London, and Tokyo.",
                }
            ],
            tools=TOOLS_WEATHER,
            tool_choice="required",
            parallel_tool_calls=True,
        )
        assert_finish_reason(result, {"tool_calls"})
        ids = [tc["id"] for tc in result.tool_calls]
        # Without these guards the uniqueness check passes vacuously when no
        # tool call (or no id) was aggregated.
        assert ids, "expected at least one tool call"
        assert all(ids), f"empty tool id in: {ids}"
        assert len(ids) == len(set(ids)), f"duplicate tool ids: {ids}"

    def test_complex_nested_arguments_schema_valid(self, client: OpenAI, model: str):
        """Nested array/object arguments for ``create_event`` validate against the schema."""
        result = stream_chat(
            client,
            model,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Create a weekly team standup meeting titled 'Engineering Standup' "
                        "from 2025-01-15T09:00:00Z to 2025-01-15T09:30:00Z. "
                        "Add attendees: Alice (alice@example.com, required) and "
                        "Bob (bob@example.com, optional). "
                        "Set recurrence weekly every 1 week for 10 occurrences. "
                        "Location: Conference Room B. "
                        "Use the create_event tool."
                    ),
                }
            ],
            tools=TOOLS_COMPLEX_ARGS,
        )
        if result.finish_reason != "tool_calls":
            pytest.skip(
                "Model declined to call create_event for complex schema "
                f"(finish_reason={result.finish_reason!r})"
            )
        assert len(result.tool_calls) >= 1
        schema = tool_schema_map(TOOLS_COMPLEX_ARGS)
        args = parse_and_validate_tool_call(
            result.tool_calls[0], schema, expected_name="create_event"
        )
        assert args["title"]
        assert args["start_time"]
        assert args["end_time"]
        assert isinstance(args.get("attendees", []), list)

    def test_sql_tool_arguments_schema_valid(self, client: OpenAI, model: str):
        """SQL text with quotes and wildcards arrives as a valid ``query_database`` call."""
        result = stream_chat(
            client,
            model,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Call the query_database tool. "
                        "Set the 'database' parameter to \"users\" and "
                        "set the 'sql' parameter to: "
                        "SELECT * FROM users WHERE name LIKE '%O''Brien%' "
                        "AND created_at > '2024-01-01'"
                    ),
                }
            ],
            tools=TOOLS_DATABASE,
            tool_choice={"type": "function", "function": {"name": "query_database"}},
        )
        assert_finish_reason(result, {"tool_calls"})
        assert len(result.tool_calls) >= 1
        schema = tool_schema_map(TOOLS_DATABASE)
        args = parse_and_validate_tool_call(
            result.tool_calls[0], schema, expected_name="query_database"
        )
        assert args["database"] == "users"
        assert "SELECT" in args["sql"].upper()

    def test_array_argument_schema_valid(self, client: OpenAI, model: str):
        """An array-of-strings argument carries all three requested recipients."""
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "send_emails",
                    "description": "Send emails",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "recipients": {
                                "type": "array",
                                "items": {"type": "string"},
                            },
                            "subject": {"type": "string"},
                            "body": {"type": "string"},
                        },
                        "required": ["recipients", "subject", "body"],
                    },
                },
            }
        ]
        result = stream_chat(
            client,
            model,
            messages=[
                {
                    "role": "user",
                    "content": (
                        "Send an email with subject 'Team Update' and body "
                        "'Meeting at 3pm' to alice@example.com, bob@example.com, "
                        "and carol@example.com."
                    ),
                }
            ],
            tools=tools,
            tool_choice={"type": "function", "function": {"name": "send_emails"}},
        )
        assert_finish_reason(result, {"tool_calls"})
        assert len(result.tool_calls) >= 1
        schema = tool_schema_map(tools)
        args = parse_and_validate_tool_call(
            result.tool_calls[0], schema, expected_name="send_emails"
        )
        assert isinstance(args["recipients"], list)
        assert len(args["recipients"]) >= 3

    def test_no_tools_is_plain_text(self, client: OpenAI, model: str):
        """A request without tools finishes with ``stop`` and non-empty text."""
        result = stream_chat(
            client,
            model,
            messages=[{"role": "user", "content": "What is the capital of France?"}],
        )
        assert_finish_reason(result, {"stop"})
        assert result.tool_calls == []
        assert result.content.strip()
# ---------------------------------------------------------------------------
# Multi-turn contract tests
# ---------------------------------------------------------------------------
class TestToolCallingMultiTurn:
    """Multi-turn contract tests: tool results are fed back and the model continues."""

    @staticmethod
    def _tool_round(result: StreamResult, content: str) -> list[dict[str, Any]]:
        """Build the history entries that answer the tool calls in ``result``.

        Returns the assistant tool-call message followed by one ``tool``
        message per emitted call, all carrying ``content``. The assistant
        message replays every call, so answering only the first would leave
        the others without a reply when the model emits parallel calls.
        """
        return [
            assistant_tool_message_from_result(result),
            *(
                {"role": "tool", "tool_call_id": tc["id"], "content": content}
                for tc in result.tool_calls
            ),
        ]

    def test_tool_result_is_consumed_and_final_answer_is_text(
        self, client: OpenAI, model: str
    ):
        """After a weather tool result is supplied, the model answers in text using it."""
        schemas = tool_schema_map(TOOLS_WEATHER)
        first = stream_chat(
            client,
            model,
            messages=[{"role": "user", "content": "What is the weather in London?"}],
            tools=TOOLS_WEATHER,
        )
        assert_finish_reason(first, {"tool_calls"})
        assert len(first.tool_calls) >= 1
        parse_and_validate_tool_call(
            first.tool_calls[0], schemas, expected_name="get_weather"
        )
        second = stream_chat(
            client,
            model,
            messages=[
                {"role": "user", "content": "What is the weather in London?"},
                *self._tool_round(
                    first,
                    json.dumps(
                        {"temperature": 15, "unit": "celsius", "condition": "cloudy"}
                    ),
                ),
            ],
            tools=TOOLS_WEATHER,
        )
        assert_finish_reason(second, {"stop"})
        assert second.tool_calls == []
        assert second.content.strip()
        assert "15" in second.content or "cloud" in second.content.lower()

    def test_chained_tool_use_search_then_calculate(self, client: OpenAI, model: str):
        """Search, then calculate: the final answer contains 10% of the searched figure."""
        schemas = tool_schema_map(TOOLS_SEARCH + TOOLS_CALCULATOR)
        messages: list[dict[str, Any]] = [
            {
                "role": "user",
                "content": (
                    "Search the web for the population of Tokyo, "
                    "then calculate what 10% of that number is."
                ),
            }
        ]
        step1 = stream_chat(
            client, model, messages=messages, tools=TOOLS_SEARCH + TOOLS_CALCULATOR
        )
        assert_finish_reason(step1, {"tool_calls"})
        assert len(step1.tool_calls) >= 1
        parse_and_validate_tool_call(step1.tool_calls[0], schemas)
        messages.extend(
            self._tool_round(
                step1,
                json.dumps(
                    {
                        "results": [
                            {"title": "Tokyo population", "snippet": "13,960,000"}
                        ]
                    }
                ),
            )
        )
        step2 = stream_chat(
            client, model, messages=messages, tools=TOOLS_SEARCH + TOOLS_CALCULATOR
        )
        # Small models sometimes short-circuit and compute the answer in
        # their reasoning instead of chaining a second tool call. Accept
        # either path: (a) another tool call to `calculate`, or (b) a
        # direct text answer containing the correct result.
        assert_finish_reason(step2, {"tool_calls", "stop"})
        if step2.finish_reason == "tool_calls":
            assert len(step2.tool_calls) >= 1
            args2 = parse_and_validate_tool_call(
                step2.tool_calls[0], schemas, expected_name="calculate"
            )
            # "1396000" is a substring of "13960000", so this single check
            # accepts both the raw population and the precomputed 10% figure.
            assert "1396000" in args2["expression"].replace(",", "")
            messages.extend(self._tool_round(step2, "1396000"))
            step3 = stream_chat(
                client, model, messages=messages, tools=TOOLS_SEARCH + TOOLS_CALCULATOR
            )
            assert_finish_reason(step3, {"stop"})
            assert step3.tool_calls == []
            assert "1396000" in step3.content.replace(",", "")
        else:
            # Short-circuit path: model did the math itself. Just verify
            # the final answer is present in the text.
            assert step2.tool_calls == []
            assert "1396000" in step2.content.replace(",", "")

    def test_multiple_prior_tool_results_synthesize_to_text(
        self, client: OpenAI, model: str
    ):
        """Two pre-answered tool calls in the history lead to a text summary."""
        result = stream_chat(
            client,
            model,
            messages=[
                {"role": "user", "content": "Get the weather in Tokyo and Paris."},
                {
                    "role": "assistant",
                    "content": None,
                    "tool_calls": [
                        {
                            "id": "call_001",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": json.dumps({"city": "Tokyo"}),
                            },
                        },
                        {
                            "id": "call_002",
                            "type": "function",
                            "function": {
                                "name": "get_weather",
                                "arguments": json.dumps({"city": "Paris"}),
                            },
                        },
                    ],
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_001",
                    "content": json.dumps(
                        {"temperature": 22, "unit": "celsius", "condition": "sunny"}
                    ),
                },
                {
                    "role": "tool",
                    "tool_call_id": "call_002",
                    "content": json.dumps(
                        {"temperature": 18, "unit": "celsius", "condition": "rainy"}
                    ),
                },
            ],
            tools=TOOLS_WEATHER,
        )
        assert_finish_reason(result, {"stop"})
        assert result.tool_calls == []
        assert result.content.strip()
        lower = result.content.lower()
        assert "tokyo" in lower or "paris" in lower
# ---------------------------------------------------------------------------
# Model-behavior smoke tests
# These are intentionally looser because the model may vary.
# ---------------------------------------------------------------------------
class TestToolCallingModelBehavior:
    """Loose smoke tests of model behavior.

    Either a text answer or a tool call is accepted; the stricter checks only
    run when the model actually calls a tool.
    """

    def test_many_tools_prefers_calculator_for_math_question(
        self, client: OpenAI, model: str
    ):
        """With every tool offered, a math question goes to ``calculate`` if a tool is used."""
        result = stream_chat(
            client,
            model,
            messages=[
                {"role": "user", "content": "What is 2^10? Use a tool if helpful."}
            ],
            tools=ALL_TOOLS,
        )
        assert_finish_reason(result, {"stop", "tool_calls"})
        if result.finish_reason == "tool_calls":
            assert len(result.tool_calls) >= 1
            assert result.tool_calls[0]["function"]["name"] == "calculate"

    def test_unicode_arguments_are_preserved(self, client: OpenAI, model: str):
        """A non-ASCII city name comes back without encoding damage."""
        result = stream_chat(
            client,
            model,
            messages=[
                {
                    "role": "user",
                    "content": "What's the weather in Zürich, Switzerland?",
                }
            ],
            tools=TOOLS_WEATHER,
        )
        assert_finish_reason(result, {"stop", "tool_calls"})
        if result.finish_reason == "tool_calls":
            assert len(result.tool_calls) >= 1
            schema = tool_schema_map(TOOLS_WEATHER)
            args = parse_and_validate_tool_call(
                result.tool_calls[0], schema, expected_name="get_weather"
            )
            city = args["city"]
            assert city
            # The spelling is left to the model ("Zürich" or "Zurich"); only
            # typical corruption markers are rejected: U+FFFD, "Ã" (UTF-8
            # bytes decoded as Latin-1) and a literal, doubly escaped "\u".
            for marker in ("\ufffd", "Ã", "\\u"):
                assert marker not in city, f"city looks mis-encoded: {city!r}"

    def test_system_instruction_encourages_tool_use(self, client: OpenAI, model: str):
        """With a system prompt demanding tool use, any emitted call is a valid ``get_weather``."""
        result = stream_chat(
            client,
            model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a careful weather assistant. "
                        "Always use the get_weather tool for weather questions."
                    ),
                },
                {"role": "user", "content": "How's the weather in Sydney?"},
            ],
            tools=TOOLS_WEATHER,
        )
        assert_finish_reason(result, {"stop", "tool_calls"})
        if result.finish_reason == "tool_calls":
            assert len(result.tool_calls) >= 1
            schema = tool_schema_map(TOOLS_WEATHER)
            parse_and_validate_tool_call(
                result.tool_calls[0], schema, expected_name="get_weather"
            )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment