"deploy/vscode:/vscode.git/clone" did not exist on "d8628cc453a03bdf4a1ee29eaf6c65da260a3f5e"
Unverified commit 8fd7de9a, authored by Richard Huo and committed by GitHub
Browse files

fix: update the tool calling functionalities for sglang frontend processor to...

fix: update the tool calling functionalities for sglang frontend processor to match with the latest sglang implementation (#8269)
parent 01688850
...@@ -86,6 +86,7 @@ class FrontendConfig(KvRouterConfigBase, AicPerfConfigBase): ...@@ -86,6 +86,7 @@ class FrontendConfig(KvRouterConfigBase, AicPerfConfigBase):
exclude_tools_when_tool_choice_none: bool exclude_tools_when_tool_choice_none: bool
preprocess_workers: int preprocess_workers: int
tokenizer_backend: str tokenizer_backend: str
trust_remote_code: bool
_VALID_TOKENIZER_BACKENDS = {"default", "fastokens"} _VALID_TOKENIZER_BACKENDS = {"default", "fastokens"}
...@@ -562,3 +563,14 @@ class FrontendArgGroup(ArgGroup): ...@@ -562,3 +563,14 @@ class FrontendArgGroup(ArgGroup):
), ),
choices=["default", "fastokens"], choices=["default", "fastokens"],
) )
add_negatable_bool_argument(
g,
flag_name="--trust-remote-code",
env_var="DYN_TRUST_REMOTE_CODE",
default=False,
help=(
"Trust remote code when loading the tokenizer. Required for models "
"that ship custom tokenizer code (e.g. Qwen, Falcon)."
),
)
...@@ -114,17 +114,6 @@ def parse_args() -> tuple[FrontendConfig, Optional[Namespace], Optional[Namespac ...@@ -114,17 +114,6 @@ def parse_args() -> tuple[FrontendConfig, Optional[Namespace], Optional[Namespac
vllm_flags = None vllm_flags = None
sglang_flags = None sglang_flags = None
# --trust-remote-code is only meaningful with --dyn-chat-processor vllm.
# Warn and strip it when a different (or no) chat processor is active so
# it does not propagate as an unknown-argument error below.
if "--trust-remote-code" in unknown and config.chat_processor != "vllm":
logger.warning(
"--trust-remote-code has no effect without '--dyn-chat-processor vllm'. "
"It is only supported by the vLLM chat processor. "
"Pass '--dyn-chat-processor vllm' to enable trust_remote_code."
)
unknown = [arg for arg in unknown if arg != "--trust-remote-code"]
# parse extra vllm flags using vllm native parser. # parse extra vllm flags using vllm native parser.
if config.chat_processor == "vllm": if config.chat_processor == "vllm":
try: try:
......
...@@ -32,6 +32,9 @@ from dynamo.runtime import DistributedRuntime ...@@ -32,6 +32,9 @@ from dynamo.runtime import DistributedRuntime
from .sglang_prepost import ( from .sglang_prepost import (
SglangStreamingPostProcessor, SglangStreamingPostProcessor,
ToolCallParserType,
_get_history_tool_calls_count,
convert_tools,
create_parsers, create_parsers,
preprocess_chat_request, preprocess_chat_request,
) )
...@@ -117,11 +120,12 @@ def _init_worker( ...@@ -117,11 +120,12 @@ def _init_worker(
tool_call_parser_name: str | None, tool_call_parser_name: str | None,
reasoning_parser_name: str | None, reasoning_parser_name: str | None,
exclude_tools_when_tool_choice_none: bool = True, exclude_tools_when_tool_choice_none: bool = True,
trust_remote_code: bool = False,
) -> None: ) -> None:
"""Initialize a worker process with its own tokenizer.""" """Initialize a worker process with its own tokenizer."""
global _w_tokenizer, _w_tool_call_parser_name, _w_reasoning_parser_name global _w_tokenizer, _w_tool_call_parser_name, _w_reasoning_parser_name
global _w_exclude_tools_when_tool_choice_none global _w_exclude_tools_when_tool_choice_none
_w_tokenizer = get_tokenizer(model_path) _w_tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
_w_tool_call_parser_name = tool_call_parser_name _w_tool_call_parser_name = tool_call_parser_name
_w_reasoning_parser_name = reasoning_parser_name _w_reasoning_parser_name = reasoning_parser_name
_w_exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none _w_exclude_tools_when_tool_choice_none = exclude_tools_when_tool_choice_none
...@@ -146,7 +150,12 @@ def _preprocess_worker( ...@@ -146,7 +150,12 @@ def _preprocess_worker(
raise PreprocessError(_unsupported_n_error(n)) raise PreprocessError(_unsupported_n_error(n))
dynamo_preproc = _build_dynamo_preproc( dynamo_preproc = _build_dynamo_preproc(
request, pre.prompt_token_ids, model_name, eos_token_id request,
pre.prompt_token_ids,
model_name,
eos_token_id,
pre.guided_decoding,
pre.tool_call_parser,
) )
return SglangPreprocessWorkerResult( return SglangPreprocessWorkerResult(
...@@ -161,6 +170,8 @@ def _build_dynamo_preproc( ...@@ -161,6 +170,8 @@ def _build_dynamo_preproc(
prompt_token_ids: list[int], prompt_token_ids: list[int],
model_name: str, model_name: str,
eos_token_id: int | None, eos_token_id: int | None,
guided_decoding: dict[str, Any] | None = None,
tool_call_parser: ToolCallParserType | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Build the Dynamo preprocessed request dict from request fields.""" """Build the Dynamo preprocessed request dict from request fields."""
max_tokens = request.get("max_completion_tokens") or request.get("max_tokens") max_tokens = request.get("max_completion_tokens") or request.get("max_tokens")
...@@ -205,11 +216,16 @@ def _build_dynamo_preproc( ...@@ -205,11 +216,16 @@ def _build_dynamo_preproc(
"top_k": request.get("top_k", 0) or -1, "top_k": request.get("top_k", 0) or -1,
"min_p": request.get("min_p", 0.0), "min_p": request.get("min_p", 0.0),
"seed": request.get("seed"), "seed": request.get("seed"),
"guided_decoding": guided_decoding,
}, },
"output_options": { "output_options": {
"logprobs": logprobs_val, "logprobs": logprobs_val,
"prompt_logprobs": None, "prompt_logprobs": None,
"skip_special_tokens": True, # Preserve special tokens only when a tool-call parser is
# actually active — the parser needs delimiter tokens
# (e.g. <|tool_call|>) to detect calls. Mirrors the
# post-processor's _skip_special_tokens logic.
"skip_special_tokens": tool_call_parser is None,
}, },
"eos_token_ids": [eos_token_id] if eos_token_id is not None else [], "eos_token_ids": [eos_token_id] if eos_token_id is not None else [],
"annotations": [], "annotations": [],
...@@ -320,7 +336,12 @@ class SglangProcessor: ...@@ -320,7 +336,12 @@ class SglangProcessor:
return return
dynamo_preproc = _build_dynamo_preproc( dynamo_preproc = _build_dynamo_preproc(
request, tokens, request["model"], self.eos_token_id request,
tokens,
request["model"],
self.eos_token_id,
pre.guided_decoding,
pre.tool_call_parser,
) )
except Exception as exc: except Exception as exc:
logger.exception("SGLang preprocessing failed for request %s", request_id) logger.exception("SGLang preprocessing failed for request %s", request_id)
...@@ -336,6 +357,11 @@ class SglangProcessor: ...@@ -336,6 +357,11 @@ class SglangProcessor:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
tool_call_parser=pre.tool_call_parser, tool_call_parser=pre.tool_call_parser,
reasoning_parser=pre.reasoning_parser, reasoning_parser=pre.reasoning_parser,
history_tool_calls_count=_get_history_tool_calls_count(
request.get("messages", [])
),
sglang_tools=convert_tools(request.get("tools")),
tool_call_parser_name=self.tool_call_parser_name,
) )
async for item in self._generate_and_stream( async for item in self._generate_and_stream(
...@@ -389,6 +415,11 @@ class SglangProcessor: ...@@ -389,6 +415,11 @@ class SglangProcessor:
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
tool_call_parser=tool_call_parser, tool_call_parser=tool_call_parser,
reasoning_parser=reasoning_parser, reasoning_parser=reasoning_parser,
history_tool_calls_count=_get_history_tool_calls_count(
request.get("messages", [])
),
sglang_tools=convert_tools(request.get("tools")),
tool_call_parser_name=self.tool_call_parser_name,
) )
async for item in self._generate_and_stream( async for item in self._generate_and_stream(
...@@ -530,6 +561,7 @@ class SglangEngineFactory: ...@@ -530,6 +561,7 @@ class SglangEngineFactory:
self.tool_call_parser_name = tool_call_parser_name self.tool_call_parser_name = tool_call_parser_name
self.reasoning_parser_name = reasoning_parser_name self.reasoning_parser_name = reasoning_parser_name
self.trust_remote_code = config.trust_remote_code
self.stream_interval = 20 self.stream_interval = 20
raw_stream_interval = os.getenv("DYN_SGLANG_STREAM_INTERVAL") raw_stream_interval = os.getenv("DYN_SGLANG_STREAM_INTERVAL")
if raw_stream_interval: if raw_stream_interval:
...@@ -560,7 +592,7 @@ class SglangEngineFactory: ...@@ -560,7 +592,7 @@ class SglangEngineFactory:
await fetch_model(source_path, ignore_weights=True) await fetch_model(source_path, ignore_weights=True)
logger.info("Loading SGLang tokenizer from %s", source_path) logger.info("Loading SGLang tokenizer from %s", source_path)
tokenizer = get_tokenizer(source_path) tokenizer = get_tokenizer(source_path, trust_remote_code=self.trust_remote_code)
eos_token_id = getattr(tokenizer, "eos_token_id", None) eos_token_id = getattr(tokenizer, "eos_token_id", None)
...@@ -610,6 +642,7 @@ class SglangEngineFactory: ...@@ -610,6 +642,7 @@ class SglangEngineFactory:
tool_call_parser_name, tool_call_parser_name,
reasoning_parser_name, reasoning_parser_name,
self.config.exclude_tools_when_tool_choice_none, self.config.exclude_tools_when_tool_choice_none,
self.trust_remote_code,
), ),
) )
futures = [ futures = [
......
...@@ -10,7 +10,11 @@ Parallels test_vllm_unit.py for the vLLM backend. ...@@ -10,7 +10,11 @@ Parallels test_vllm_unit.py for the vLLM backend.
""" """
import json
import pytest import pytest
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.srt.utils.hf_transformers_utils import get_tokenizer
import dynamo.frontend.sglang_processor as sglang_processor_module import dynamo.frontend.sglang_processor as sglang_processor_module
...@@ -18,6 +22,8 @@ from dynamo.frontend.sglang_prepost import ( ...@@ -18,6 +22,8 @@ from dynamo.frontend.sglang_prepost import (
SglangPreprocessResult, SglangPreprocessResult,
SglangStreamingPostProcessor, SglangStreamingPostProcessor,
_normalize_prompt_token_ids, _normalize_prompt_token_ids,
_parse_json_array_buffer,
build_tool_call_guided_decoding,
convert_tools, convert_tools,
create_parsers, create_parsers,
preprocess_chat_request, preprocess_chat_request,
...@@ -119,6 +125,18 @@ class TestBuildDynamoPreproc: ...@@ -119,6 +125,18 @@ class TestBuildDynamoPreproc:
assert sampling["repetition_penalty"] == 1.1 assert sampling["repetition_penalty"] == 1.1
assert sampling["seed"] == 42 assert sampling["seed"] == 42
def test_guided_decoding_passthrough(self):
    """The guided_decoding argument lands in sampling_options as given."""
    spec = {"json": {"type": "object"}}
    preproc = _build_dynamo_preproc(
        {"model": "test"},
        prompt_token_ids=[1, 2, 3],
        model_name="test",
        eos_token_id=None,
        guided_decoding=spec,
    )
    # Compare against a fresh literal so an in-place mutation of ``spec``
    # would still be caught.
    forwarded = preproc["sampling_options"]["guided_decoding"]
    assert forwarded == {"json": {"type": "object"}}
def test_stop_conditions_string(self): def test_stop_conditions_string(self):
"""Single stop string is wrapped in a list.""" """Single stop string is wrapped in a list."""
result = _build_dynamo_preproc( result = _build_dynamo_preproc(
...@@ -368,6 +386,92 @@ class TestCreateParsers: ...@@ -368,6 +386,92 @@ class TestCreateParsers:
assert tcp is None assert tcp is None
assert rp is not None assert rp is not None
class TestBuildToolCallGuidedDecoding:
    """Exercise build_tool_call_guided_decoding across tool_choice settings."""

    @staticmethod
    def _one_tool(function_spec):
        """Return the convert_tools() result for a single function tool."""
        return convert_tools([{"type": "function", "function": function_spec}])

    def test_none_when_no_tools(self):
        """sglang_tools=None produces no guided-decoding spec."""
        guided = build_tool_call_guided_decoding(
            {"tool_choice": "auto"},
            tool_call_parser_name="hermes",
            sglang_tools=None,
        )
        assert guided is None

    def test_none_when_tool_choice_none(self):
        """tool_choice='none' produces no spec even when tools are supplied."""
        weather = self._one_tool(
            {
                "name": "get_weather",
                "parameters": {"type": "object", "properties": {}},
            }
        )
        guided = build_tool_call_guided_decoding(
            {"tool_choice": "none"},
            tool_call_parser_name="hermes",
            sglang_tools=weather,
        )
        assert guided is None

    def test_required_tool_choice_builds_json_schema_guidance(self):
        """tool_choice='required' returns a dict carrying a 'json' entry."""
        weather = self._one_tool(
            {
                "name": "get_weather",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            }
        )
        guided = build_tool_call_guided_decoding(
            {"tool_choice": "required"},
            tool_call_parser_name="hermes",
            sglang_tools=weather,
        )
        assert isinstance(guided, dict)
        assert "json" in guided

    def test_auto_strict_tools_can_build_structural_tag_guidance(self):
        """A strict tool under tool_choice='auto' with the kimi_k2 parser
        returns a dict carrying a 'structural_tag' entry."""
        strict_weather = self._one_tool(
            {
                "name": "get_weather",
                "strict": True,
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            }
        )
        guided = build_tool_call_guided_decoding(
            {"tool_choice": "auto"},
            tool_call_parser_name="kimi_k2",
            sglang_tools=strict_weather,
        )
        assert isinstance(guided, dict)
        assert "structural_tag" in guided
def test_tool_parser_requires_tools(self): def test_tool_parser_requires_tools(self):
"""Tool parser is not created if no tools in request.""" """Tool parser is not created if no tools in request."""
tcp, rp = create_parsers( tcp, rp = create_parsers(
...@@ -437,6 +541,170 @@ class TestCreateParsers: ...@@ -437,6 +541,170 @@ class TestCreateParsers:
assert tcp is not None assert tcp is not None
assert rp is not None assert rp is not None
def test_required_creates_json_array_parser(self):
    """With tool_choice='required' the tool-call parser is a JsonArrayParser."""
    tool = {
        "type": "function",
        "function": {
            "name": "f",
            "description": "d",
            "parameters": {},
        },
    }
    parser, _ = create_parsers(
        {"tools": [tool], "tool_choice": "required"},
        tool_call_parser_name="hermes",
        reasoning_parser_name=None,
    )
    assert isinstance(parser, JsonArrayParser)
def test_named_tool_choice_creates_json_array_parser(self):
    """A tool_choice naming a specific function yields a JsonArrayParser."""
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {},
        },
    }
    named_choice = {
        "type": "function",
        "function": {"name": "get_weather"},
    }
    parser, _ = create_parsers(
        {"tools": [weather_tool], "tool_choice": named_choice},
        tool_call_parser_name="hermes",
        reasoning_parser_name=None,
    )
    assert isinstance(parser, JsonArrayParser)
def test_auto_creates_function_call_parser(self):
    """With tool_choice='auto' the tool-call parser is a FunctionCallParser."""
    tool = {
        "type": "function",
        "function": {
            "name": "f",
            "description": "d",
            "parameters": {},
        },
    }
    parser, _ = create_parsers(
        {"tools": [tool], "tool_choice": "auto"},
        tool_call_parser_name="hermes",
        reasoning_parser_name=None,
    )
    assert isinstance(parser, FunctionCallParser)
def test_required_without_parser_name_still_creates_json_array_parser(self):
    """tool_choice='required' yields a JsonArrayParser even when
    tool_call_parser_name is None."""
    tool = {
        "type": "function",
        "function": {
            "name": "f",
            "description": "d",
            "parameters": {},
        },
    }
    parser, _ = create_parsers(
        {"tools": [tool], "tool_choice": "required"},
        tool_call_parser_name=None,
        reasoning_parser_name=None,
    )
    assert isinstance(parser, JsonArrayParser)
# ---------------------------------------------------------------------------
# _parse_json_array_buffer
# ---------------------------------------------------------------------------
class TestParseJsonArrayBuffer:
    """Cover _parse_json_array_buffer, the JSON-array fallback used for
    constrained-decoding output."""

    def test_single_tool_call(self):
        """One array element becomes one call with tool_index 0."""
        payload = json.dumps(
            [{"name": "get_weather", "parameters": {"city": "NYC"}}]
        )
        parsed = _parse_json_array_buffer(payload)
        assert len(parsed) == 1
        only = parsed[0]
        assert only.name == "get_weather"
        assert only.tool_index == 0
        assert json.loads(only.parameters) == {"city": "NYC"}

    def test_multiple_tool_calls(self):
        """Two array elements become two calls indexed in array order."""
        entries = [
            {"name": "get_weather", "parameters": {"city": "NYC"}},
            {"name": "search", "parameters": {"q": "hello"}},
        ]
        parsed = _parse_json_array_buffer(json.dumps(entries))
        assert len(parsed) == 2
        first, second = parsed[0], parsed[1]
        assert first.name == "get_weather"
        assert first.tool_index == 0
        assert second.name == "search"
        assert second.tool_index == 1

    def test_arguments_key_also_accepted(self):
        """An 'arguments' key is read the same way as 'parameters'."""
        payload = json.dumps([{"name": "f", "arguments": {"x": 1}}])
        parsed = _parse_json_array_buffer(payload)
        assert len(parsed) == 1
        assert json.loads(parsed[0].parameters) == {"x": 1}

    def test_string_parameters_preserved(self):
        """Parameters that are already a string come back as that string."""
        payload = json.dumps([{"name": "f", "parameters": "already_a_string"}])
        parsed = _parse_json_array_buffer(payload)
        assert parsed[0].parameters == "already_a_string"

    def test_invalid_json_returns_empty(self):
        """Text that is not JSON gives an empty list."""
        assert _parse_json_array_buffer("not json") == []

    def test_non_array_returns_empty(self):
        """A top-level JSON object (not an array) gives an empty list."""
        assert _parse_json_array_buffer('{"name": "f"}') == []

    def test_empty_buffer_returns_empty(self):
        """An empty string gives an empty list."""
        assert _parse_json_array_buffer("") == []

    def test_non_dict_items_skipped(self):
        """Non-dict elements are dropped; the surviving call keeps its
        original array position as tool_index."""
        payload = json.dumps(["not_a_dict", {"name": "f", "parameters": {}}])
        parsed = _parse_json_array_buffer(payload)
        assert len(parsed) == 1
        survivor = parsed[0]
        assert survivor.name == "f"
        assert survivor.tool_index == 1

    def test_trailing_special_token(self):
        """A trailing EOS/special token after the array does not break parsing."""
        payload = '[{"name": "f", "parameters": {"x": 1}}]<|endoftext|>'
        parsed = _parse_json_array_buffer(payload)
        assert len(parsed) == 1
        assert parsed[0].name == "f"
        assert json.loads(parsed[0].parameters) == {"x": 1}

    def test_leading_text_with_array(self):
        """Non-JSON text ahead of the array is tolerated."""
        payload = 'some preamble [{"name": "f", "parameters": {"x": 1}}]'
        parsed = _parse_json_array_buffer(payload)
        assert len(parsed) == 1
        assert parsed[0].name == "f"

    def test_trailing_and_leading_noise(self):
        """Noise on both sides of the array is tolerated."""
        payload = 'text [{"name": "g", "parameters": {"y": 2}}] <|end|>'
        parsed = _parse_json_array_buffer(payload)
        assert len(parsed) == 1
        assert parsed[0].name == "g"
class TestNormalizePromptTokenIds: class TestNormalizePromptTokenIds:
def test_batch_encoding_like_object_uses_input_ids(self): def test_batch_encoding_like_object_uses_input_ids(self):
...@@ -641,6 +909,37 @@ class TestPreprocessChatRequest: ...@@ -641,6 +909,37 @@ class TestPreprocessChatRequest:
with_auto.prompt_token_ids with_auto.prompt_token_ids
), "tool_choice=none with flag off should keep tools in template" ), "tool_choice=none with flag off should keep tools in template"
def test_named_tool_choice_missing_function_raises(self, tokenizer):
    """A named tool_choice whose function is not among the request's tools
    raises ValueError mentioning the missing name."""
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
        },
    }
    unknown_choice = {
        "type": "function",
        "function": {"name": "does_not_exist"},
    }
    request = {
        "model": MODEL,
        "messages": [{"role": "user", "content": "Hello"}],
        "tools": [weather_tool],
        "tool_choice": unknown_choice,
    }
    with pytest.raises(ValueError, match="does_not_exist"):
        preprocess_chat_request(
            request,
            tokenizer=tokenizer,
            tool_call_parser_name="hermes",
            reasoning_parser_name=None,
        )
def test_init_worker_propagates_exclude_flag_true(self): def test_init_worker_propagates_exclude_flag_true(self):
"""_init_worker sets the worker-global exclude_tools flag to True.""" """_init_worker sets the worker-global exclude_tools flag to True."""
_init_worker(MODEL, None, None, exclude_tools_when_tool_choice_none=True) _init_worker(MODEL, None, None, exclude_tools_when_tool_choice_none=True)
......
...@@ -15,6 +15,7 @@ import pytest ...@@ -15,6 +15,7 @@ import pytest
from sglang.srt.entrypoints.openai.protocol import Function as SglangFunction from sglang.srt.entrypoints.openai.protocol import Function as SglangFunction
from sglang.srt.entrypoints.openai.protocol import Tool as SglangTool from sglang.srt.entrypoints.openai.protocol import Tool as SglangTool
from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.function_call.json_array_parser import JsonArrayParser
from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.srt.utils.hf_transformers_utils import get_tokenizer
...@@ -153,6 +154,114 @@ class TestSingleToolCall: ...@@ -153,6 +154,114 @@ class TestSingleToolCall:
assert tc[0]["index"] == 0 assert tc[0]["index"] == 0
class TestKimiToolCallIds:
    """Tool-call id formatting when the parser name is ``kimi_k2``."""

    class _CharTokenizer:
        """Stand-in tokenizer: decodes each token id as one character."""

        def decode(self, token_ids, skip_special_tokens=True):
            return "".join(map(chr, token_ids))

    class _FakeToolCall:
        """Plain record exposing the attributes read from a parsed call."""

        def __init__(self, tool_index, name, parameters):
            self.tool_index = tool_index
            self.name = name
            self.parameters = parameters

    def _finish_with(self, parser_cls):
        """Run one finishing chunk through a post-processor built around
        ``parser_cls`` (history count 3) and return the emitted tool calls."""
        post = SglangStreamingPostProcessor(
            tokenizer=self._CharTokenizer(),
            tool_call_parser=parser_cls(),
            reasoning_parser=None,
            history_tool_calls_count=3,
            tool_call_parser_name="kimi_k2",
        )
        choice = post.process_output(
            {
                "token_ids": [ord("x")],
                "finish_reason": "stop",
            }
        )
        return choice["delta"]["tool_calls"]

    def test_kimi_uses_history_adjusted_ids(self):
        """Ids for streamed calls are offset by history_tool_calls_count."""
        make_call = self._FakeToolCall

        class DummyParser:
            tool_call_parser = "kimi_k2"
            detector = type("Detector", (), {"_buffer": ""})()

            def parse_stream_chunk(self, text):
                return "", [
                    make_call(0, "get_weather", '{"city":"Paris"}'),
                    make_call(
                        1, "search_gutenberg_books", '{"search_terms":["Joyce"]}'
                    ),
                ]

        emitted = self._finish_with(DummyParser)
        assert [item["id"] for item in emitted] == [
            "functions.get_weather:3",
            "functions.search_gutenberg_books:4",
        ]

    def test_kimi_reparse_uses_sequential_index_not_tool_index(self):
        """kimi_k2 ids after a re-parse follow output position, not tool_index.

        ``FunctionCallParser.parse_non_stream`` can report
        ``ToolCallItem.tool_index`` values that reflect the tool-definition
        position instead of the call's sequential position.  Ids have to
        line up with the emitted ``index`` field, so they are derived from
        the post-processor's ``seq_idx``.
        """
        make_call = self._FakeToolCall

        class DummyParser:
            tool_call_parser = "kimi_k2"
            detector = type("Detector", (), {"_buffer": ""})()

            def parse_stream_chunk(self, text):
                # Streaming reports nothing, which forces the re-parse path.
                return "", []

            def has_tool_call(self, text):
                return True

            def parse_non_stream(self, text):
                # Deliberately non-sequential tool_index values, mimicking
                # tool-definition positions.
                return "", [
                    make_call(5, "get_weather", '{"city":"Paris"}'),
                    make_call(2, "search_gutenberg_books", '{"q":"Joyce"}'),
                ]

        emitted = self._finish_with(DummyParser)
        # Expect seq_idx (0, 1) + history (3), never tool_index (5, 2).
        assert [item["id"] for item in emitted] == [
            "functions.get_weather:3",
            "functions.search_gutenberg_books:4",
        ]
        assert [item["index"] for item in emitted] == [0, 1]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# No reasoning parser # No reasoning parser
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -276,3 +385,159 @@ class TestNoToolCalls: ...@@ -276,3 +385,159 @@ class TestNoToolCalls:
c = r.get("delta", {}).get("content", "") c = r.get("delta", {}).get("content", "")
content += c content += c
assert "Hello, world!" in content assert "Hello, world!" in content
# ---------------------------------------------------------------------------
# Single-chunk tool calls (finish-time re-parse fallback)
# ---------------------------------------------------------------------------
class TestSingleChunkFallback:
    """Finish-time re-parse when the whole response arrives in one batch.

    If every tool-call token and the finish signal land in a single
    ``process_output`` call, the streaming parser only handles one event;
    the re-parse at finish has to recover the arguments and any further
    tool calls.
    """

    TEXT = (
        "<think>\nLet me search for books.\n</think>\n\n"
        '<tool_call>\n{"name": "search_gutenberg_books", '
        '"arguments": {"search_terms": ["James Joyce"]}}\n</tool_call>'
    )

    def test_all_tokens_plus_finish_in_one_batch(self, tokenizer):
        """The full response plus finish in one call yields the one tool call."""
        tool_parser = FunctionCallParser(tools=TOOLS, tool_call_parser="hermes")
        reasoning = ReasoningParser(model_type="qwen3", stream_reasoning=True)
        post = SglangStreamingPostProcessor(
            tokenizer=tokenizer,
            tool_call_parser=tool_parser,
            reasoning_parser=reasoning,
        )
        # Every token is delivered together with finish_reason.
        payload = {
            "token_ids": tokenizer.encode(self.TEXT),
            "finish_reason": "stop",
        }
        choice = post.process_output(payload)

        assert choice is not None
        tc = choice.get("delta", {}).get("tool_calls", [])
        assert len(tc) == 1, f"Expected 1 tool call, got {len(tc)}"
        function = tc[0]["function"]
        assert function["name"] == "search_gutenberg_books"
        assert json.loads(function["arguments"]) == {
            "search_terms": ["James Joyce"]
        }

    def test_multiple_tools_single_chunk(self, tokenizer):
        """Several tool calls in one chunk are all recovered, with arguments."""
        text = (
            "<think>\nI'll search and check weather.\n</think>\n\n"
            '<tool_call>\n{"name": "search_gutenberg_books", '
            '"arguments": {"search_terms": ["Joyce"]}}\n</tool_call>\n'
            '<tool_call>\n{"name": "get_weather", '
            '"arguments": {"city": "London"}}\n</tool_call>'
        )
        tool_parser = FunctionCallParser(tools=TOOLS, tool_call_parser="hermes")
        reasoning = ReasoningParser(model_type="qwen3", stream_reasoning=True)
        post = SglangStreamingPostProcessor(
            tokenizer=tokenizer,
            tool_call_parser=tool_parser,
            reasoning_parser=reasoning,
        )
        payload = {
            "token_ids": tokenizer.encode(text),
            "finish_reason": "stop",
        }
        choice = post.process_output(payload)

        assert choice is not None
        tc = choice.get("delta", {}).get("tool_calls", [])
        assert len(tc) == 2, f"Expected 2 tool calls, got {len(tc)}"
        assert {t["function"]["name"] for t in tc} == {
            "search_gutenberg_books",
            "get_weather",
        }
        for t in tc:
            args = json.loads(t["function"]["arguments"])
            assert args, f"Arguments should not be empty for {t['function']['name']}"

    def test_finish_reason_rewritten_to_tool_calls(self, tokenizer):
        """finish_reason becomes 'tool_calls' when the re-parse finds calls."""
        tool_parser = FunctionCallParser(tools=TOOLS, tool_call_parser="hermes")
        post = SglangStreamingPostProcessor(
            tokenizer=tokenizer,
            tool_call_parser=tool_parser,
            reasoning_parser=None,
        )
        text = (
            '<tool_call>\n{"name": "get_weather", '
            '"arguments": {"city": "NYC"}}\n</tool_call>'
        )
        payload = {
            "token_ids": tokenizer.encode(text),
            "finish_reason": "stop",
        }
        choice = post.process_output(payload)

        assert choice is not None
        assert choice["finish_reason"] == "tool_calls"
# ---------------------------------------------------------------------------
# JsonArrayParser path (tool_choice="required" / named function)
# ---------------------------------------------------------------------------
class TestJsonArrayParserReparse:
    """Cover the JsonArrayParser branch of the finish-time re-parse.

    With ``tool_choice="required"`` or a named function, guided decoding
    constrains the model to emit a raw JSON array, and
    SglangStreamingPostProcessor is built with a JsonArrayParser rather
    than a FunctionCallParser.  The re-parse path relies on the parser's
    ``has_tool_call`` as a cheap gate and on ``_parse_json_array_buffer``
    for recovery; these tests pin that API surface so an SGLang upgrade
    cannot break it unnoticed.
    """

    @staticmethod
    def _finish_in_one_chunk(tokenizer, text):
        """Encode ``text`` and push it, with finish_reason='stop', through a
        fresh JsonArrayParser-backed post-processor; return the choice."""
        post = SglangStreamingPostProcessor(
            tokenizer=tokenizer,
            tool_call_parser=JsonArrayParser(),
            reasoning_parser=None,
            sglang_tools=TOOLS,
        )
        token_ids = tokenizer.encode(text)
        return post.process_output({"token_ids": token_ids, "finish_reason": "stop"})

    def test_single_call_reparse(self, tokenizer):
        """A complete JSON array in one chunk is recovered by the re-parse."""
        choice = self._finish_in_one_chunk(
            tokenizer,
            '[{"name": "get_weather", "parameters": {"city": "NYC"}}]',
        )
        assert choice is not None
        emitted = choice.get("delta", {}).get("tool_calls", [])
        assert len(emitted) == 1
        function = emitted[0]["function"]
        assert function["name"] == "get_weather"
        assert json.loads(function["arguments"]) == {"city": "NYC"}
        assert choice["finish_reason"] == "tool_calls"

    def test_multiple_calls_reparse(self, tokenizer):
        """Every call in a multi-element array is recovered."""
        text = (
            '[{"name": "search_gutenberg_books", '
            '"parameters": {"search_terms": ["Joyce"]}}, '
            '{"name": "get_weather", "parameters": {"city": "London"}}]'
        )
        choice = self._finish_in_one_chunk(tokenizer, text)
        assert choice is not None
        emitted = choice.get("delta", {}).get("tool_calls", [])
        assert len(emitted) == 2
        assert {call["function"]["name"] for call in emitted} == {
            "search_gutenberg_books",
            "get_weather",
        }

    def test_plain_text_skips_reparse(self, tokenizer):
        """Plain text without JSON markers must not crash the re-parse path.

        Pins that JsonArrayParser's ``has_tool_call`` gate is False for text
        lacking '[' or '{', so neither ``_parse_json_array_buffer`` nor the
        secondary FunctionCallParser fallback is reached.
        """
        choice = self._finish_in_one_chunk(tokenizer, "Hello, world!")
        # No tool calls and no crash; ``choice`` itself may be falsy.
        emitted = (choice or {}).get("delta", {}).get("tool_calls", [])
        assert emitted == []
...@@ -448,7 +448,7 @@ class EngineFactory: ...@@ -448,7 +448,7 @@ class EngineFactory:
tokenizer_mode = getattr(self.flags, "tokenizer_mode", None) or "auto" tokenizer_mode = getattr(self.flags, "tokenizer_mode", None) or "auto"
config_format = getattr(self.flags, "config_format", None) or "auto" config_format = getattr(self.flags, "config_format", None) or "auto"
load_format = getattr(self.flags, "load_format", None) or "dummy" load_format = getattr(self.flags, "load_format", None) or "dummy"
trust_remote_code = getattr(self.flags, "trust_remote_code", False) trust_remote_code = self.config.trust_remote_code
enable_auto_tool_choice = getattr(self.flags, "enable_auto_tool_choice", False) enable_auto_tool_choice = getattr(self.flags, "enable_auto_tool_choice", False)
model_config = ModelConfig( model_config = ModelConfig(
......
...@@ -1062,6 +1062,11 @@ class BaseWorkerHandler(LoraMixin, RLMixin, BaseGenerativeHandler[RequestT, Resp ...@@ -1062,6 +1062,11 @@ class BaseWorkerHandler(LoraMixin, RLMixin, BaseGenerativeHandler[RequestT, Resp
json_schema = guided_decoding.get("json") json_schema = guided_decoding.get("json")
if json_schema is not None: if json_schema is not None:
return {"json_schema": json.dumps(json_schema)} return {"json_schema": json.dumps(json_schema)}
structural_tag = guided_decoding.get("structural_tag")
if structural_tag is not None:
if hasattr(structural_tag, "model_dump"):
structural_tag = structural_tag.model_dump()
return {"structural_tag": json.dumps(structural_tag)}
return {} return {}
@staticmethod @staticmethod
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment