Unverified Commit 03ce1c6e authored by Flora Feng, committed by GitHub
Browse files

[Bugfix] Kimi-K2 tool parser streaming - fix token leakage, argument...


[Bugfix] Kimi-K2 tool parser streaming - fix token leakage, argument truncation, and content dropping (#38579)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
parent 4353c9cb
......@@ -3,14 +3,20 @@
# ruff: noqa: E501
import json
from unittest.mock import MagicMock
import pytest
from vllm.entrypoints.openai.engine.protocol import FunctionCall, ToolCall
from tests.tool_parsers.utils import (
run_tool_extraction,
run_tool_extraction_streaming,
)
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
from vllm.tokenizers import get_tokenizer
from vllm.tool_parsers.kimi_k2_tool_parser import KimiK2ToolParser
# Use a common model that is likely to be available
MODEL = "moonshotai/Kimi-K2-Instruct"
......@@ -20,959 +26,557 @@ def kimi_k2_tokenizer():
@pytest.fixture
def kimi_k2_tool_parser(kimi_k2_tokenizer):
def parser(kimi_k2_tokenizer):
return KimiK2ToolParser(kimi_k2_tokenizer)
def assert_tool_calls(
actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]
):
assert len(actual_tool_calls) == len(expected_tool_calls)
SECTION_BEGIN = "<|tool_calls_section_begin|>"
SECTION_END = "<|tool_calls_section_end|>"
TOOL_BEGIN = "<|tool_call_begin|>"
TOOL_END = "<|tool_call_end|>"
ARG_BEGIN = "<|tool_call_argument_begin|>"
def _tool(tool_id: str, args: str) -> str:
return f"{TOOL_BEGIN}{tool_id} {ARG_BEGIN}{args}{TOOL_END}"
for actual_tool_call, expected_tool_call in zip(
actual_tool_calls, expected_tool_calls
):
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
# assert tool call id format: should contain function name and numeric index
# Format can be either "functions.func_name:0" or "func_name:0"
assert actual_tool_call.id.split(":")[-1].isdigit()
assert (
actual_tool_call.id.split(":")[0].split(".")[-1]
== expected_tool_call.function.name
)
def _wrap(*tool_strs: str) -> str:
    """Concatenate the given tool-call strings and enclose them in the section markers."""
    body = "".join(tool_strs)
    return SECTION_BEGIN + body + SECTION_END
def run_streaming_sequence(parser, deltas):
"""Helper to simulate a streaming sequence and return results."""
previous_text = ""
previous_token_ids: list[int] = []
results = []
for delta_text, delta_token_ids in deltas:
current_text = previous_text + delta_text
current_token_ids = previous_token_ids + delta_token_ids
result = parser.extract_tool_calls_streaming(
previous_text=previous_text,
current_text=current_text,
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
request=None,
class TestExtractToolCalls:
def test_no_tools(self, parser):
content, tool_calls = run_tool_extraction(
parser, "This is a test", streaming=False
)
results.append(result)
previous_text = current_text
previous_token_ids = current_token_ids
return results
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
    """Text without any tool markers yields no tool calls and is returned as content."""
    plain_text = "This is a test"
    outcome = kimi_k2_tool_parser.extract_tool_calls(
        plain_text, request=None
    )  # type: ignore[arg-type]

    assert not outcome.tools_called
    assert outcome.tool_calls == []
    assert outcome.content == plain_text
@pytest.mark.parametrize(
ids=[
"tool_call_with_content_before",
"multi_tool_call_with_content_before",
"concatenated_tool_calls_bug_fix",
"three_concatenated_tool_calls",
"mixed_spacing_tool_calls",
"angle_brackets_in_json",
"newlines_in_json",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""",
[
ToolCall(
id="functions.get_weather:0",
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"city": "Beijing",
},
),
),
type="function",
)
],
"I'll help you check the weather. ",
),
(
"""I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""",
[
ToolCall(
id="functions.get_weather:0",
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"city": "Beijing",
},
),
),
type="function",
),
ToolCall(
id="functions.get_weather:1",
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"city": "Shanghai",
},
),
),
type="function",
),
],
"I'll help you check the weather. ",
),
(
"""I'll get the weather and news for LA today. First, let me get the weather using Los Angeles coordinates, and then get the latest news. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"latitude": 34.0522, "longitude": -118.2437}<|tool_call_end|><|tool_call_begin|>functions.get_news:1<|tool_call_argument_begin|>{"content": "Los Angeles today"}<|tool_call_end|><|tool_calls_section_end|>""",
[
ToolCall(
id="functions.get_weather:0",
function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{"latitude": 34.0522, "longitude": -118.2437}
),
),
type="function",
),
ToolCall(
id="functions.get_news:1",
function=FunctionCall(
name="get_news",
arguments=json.dumps({"content": "Los Angeles today"}),
),
type="function",
),
],
"I'll get the weather and news for LA today. First, let me get the weather using Los Angeles coordinates, and then get the latest news. ",
),
(
"""I'll help you with multiple tasks. <|tool_calls_section_begin|><|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"city": "New York"}<|tool_call_end|><|tool_call_begin|>functions.get_news:1<|tool_call_argument_begin|>{"topic": "technology"}<|tool_call_end|><|tool_call_begin|>functions.send_email:2<|tool_call_argument_begin|>{"to": "user@example.com", "subject": "Daily Update"}<|tool_call_end|><|tool_calls_section_end|>""",
[
ToolCall(
id="functions.get_weather:0",
function=FunctionCall(
name="get_weather",
arguments=json.dumps({"city": "New York"}),
),
type="function",
assert content == "This is a test"
assert tool_calls == []
@pytest.mark.parametrize(
"model_output, expected_names, expected_args_list, expected_content",
[
pytest.param(
"I'll check. "
+ _wrap(_tool("functions.get_weather:0", '{"city": "Beijing"}')),
["get_weather"],
[{"city": "Beijing"}],
"I'll check. ",
id="single_tool_call",
),
pytest.param(
"Compare weather. "
+ _wrap(
_tool("functions.get_weather:0", '{"city": "Beijing"}'),
_tool("functions.get_weather:1", '{"city": "Shanghai"}'),
),
ToolCall(
id="functions.get_news:1",
function=FunctionCall(
name="get_news",
arguments=json.dumps({"topic": "technology"}),
["get_weather", "get_weather"],
[{"city": "Beijing"}, {"city": "Shanghai"}],
"Compare weather. ",
id="parallel_tool_calls",
),
pytest.param(
"Multiple tasks. "
+ _wrap(
_tool("functions.get_weather:0", '{"city": "New York"}'),
_tool("functions.get_news:1", '{"topic": "technology"}'),
_tool(
"functions.send_email:2",
'{"to": "user@example.com", "subject": "Daily Update"}',
),
type="function",
),
ToolCall(
id="functions.send_email:2",
function=FunctionCall(
name="send_email",
arguments=json.dumps(
{"to": "user@example.com", "subject": "Daily Update"}
),
),
type="function",
["get_weather", "get_news", "send_email"],
[
{"city": "New York"},
{"topic": "technology"},
{"to": "user@example.com", "subject": "Daily Update"},
],
"Multiple tasks. ",
id="three_tool_calls",
),
pytest.param(
"Process HTML. "
+ _wrap(
_tool("functions.process_html:0", '{"html": "<div>content</div>"}')
),
],
"I'll help you with multiple tasks. ",
),
(
"""Mixed spacing test. <|tool_calls_section_begin|> <|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {} <|tool_call_end|><|tool_call_begin|>functions.test2:1<|tool_call_argument_begin|>{}<|tool_call_end|> <|tool_calls_section_end|>""",
[
ToolCall(
id="functions.test:0",
function=FunctionCall(
name="test",
arguments=json.dumps({}),
),
type="function",
["process_html"],
[{"html": "<div>content</div>"}],
"Process HTML. ",
id="angle_brackets_in_json",
),
pytest.param(
"Formatted. "
+ _wrap(
_tool(
"functions.process_data:0",
'{\n "name": "test",\n "value": 123\n}',
)
),
ToolCall(
id="functions.test2:1",
function=FunctionCall(
name="test2",
arguments=json.dumps({}),
),
type="function",
),
],
"Mixed spacing test. ",
),
(
"""I need to process HTML content. <|tool_calls_section_begin|><|tool_call_begin|>functions.process_html:0<|tool_call_argument_begin|>{"html": "<div>content</div>", "text": "normal text"}<|tool_call_end|><|tool_calls_section_end|>""",
[
ToolCall(
id="functions.process_html:0",
function=FunctionCall(
name="process_html",
arguments=json.dumps(
{"html": "<div>content</div>", "text": "normal text"}
),
),
type="function",
)
],
"I need to process HTML content. ",
),
(
"""I need to process formatted JSON. <|tool_calls_section_begin|><|tool_call_begin|>functions.process_data:0<|tool_call_argument_begin|>{
"name": "test",
"value": 123,
"nested": {
"key": "value"
}
}<|tool_call_end|><|tool_calls_section_end|>""",
[
ToolCall(
id="functions.process_data:0",
function=FunctionCall(
name="process_data",
arguments=json.dumps(
{"name": "test", "value": 123, "nested": {"key": "value"}},
indent=2,
),
),
type="function",
)
],
"I need to process formatted JSON. ",
),
],
)
def test_extract_tool_calls(
kimi_k2_tool_parser, model_output, expected_tool_calls, expected_content
):
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
model_output, request=None
) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser):
    """Both tool calls are returned even though the first has unterminated JSON arguments."""
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
    outcome = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert outcome.tools_called
    # The call with broken JSON is not dropped: two calls come back, in order.
    returned_names = [call.function.name for call in outcome.tool_calls]
    assert returned_names == ["invalid_get_weather", "valid_get_weather"]
def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser):
    """Only the tool call whose id ends in ``:<digit>`` is returned.

    The first call's id ends in ``.0`` instead of ``:0``; the test expects
    that call to be absent from the result.
    """
    model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|>
functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>"""
    outcome = kimi_k2_tool_parser.extract_tool_calls(
        model_output, request=None
    )  # type: ignore[arg-type]

    assert outcome.tools_called
    # Exactly one call survives: the one with the well-formed id.
    returned_names = [call.function.name for call in outcome.tool_calls]
    assert returned_names == ["valid_get_weather"]
def test_streaming_basic_functionality(kimi_k2_tool_parser):
    """Smoke-test one streaming call over text that contains a full tool section."""
    parser = kimi_k2_tool_parser

    # Put the streaming bookkeeping attributes back to starting values by hand.
    parser.current_tool_name_sent = False
    parser.prev_tool_call_arr = []
    parser.current_tool_id = -1
    parser.streamed_args_for_tool = []

    full_text = """ check the weather. <|tool_calls_section_begin|> <|tool_call_begin|>
functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>"""

    delta = parser.extract_tool_calls_streaming(
        previous_text="I'll help you",
        current_text=full_text,
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # A None result, or one without tool calls, is accepted as-is.
    if delta is None or not hasattr(delta, "tool_calls") or not delta.tool_calls:
        return
    # NOTE(review): this comparison is always true, so the test effectively
    # only checks that the call above does not raise — confirm the intent.
    assert len(delta.tool_calls) >= 0
def test_streaming_no_tool_calls(kimi_k2_tool_parser):
    """A delta without tool markers is handed back unchanged as content."""
    addition = " without any tool calls."
    message = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="This is just regular text",
        current_text="This is just regular text without any tool calls.",
        delta_text=addition,
        previous_token_ids=[],
        current_token_ids=[],
        delta_token_ids=[],
        request=None,
    )

    # The streamed delta must surface as the message content.
    assert message is not None
    assert hasattr(message, "content")
    assert message.content == " without any tool calls."
def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
    """
    Text arriving between <|tool_calls_section_begin|> and <|tool_call_begin|>
    must be suppressed rather than streamed out as content.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    # Look up the token ids of the two markers used below.
    vocab = kimi_k2_tool_parser.vocab
    section_begin_token_id = vocab.get("<|tool_calls_section_begin|>")
    tool_call_begin_token_id = vocab.get("<|tool_call_begin|>")

    # Streamed chunks: leading text, section marker, stray text, tool-call marker.
    stream = [
        ("I'll help you with that. ", [1, 2, 3]),
        ("<|tool_calls_section_begin|>", [section_begin_token_id]),
        (" spurious text ", [4, 5]),
        ("<|tool_call_begin|>", [tool_call_begin_token_id]),
    ]
    results = run_streaming_sequence(kimi_k2_tool_parser, stream)

    # Chunk 1: ordinary text is streamed through as content.
    leading = results[0]
    assert leading is not None
    assert leading.content == "I'll help you with that. "

    # Chunk 2 (section marker) and chunk 3 (the stray text — the leak scenario)
    # must each produce either no message or a message with no content.
    for suppressed in (results[1], results[2]):
        assert suppressed is None or (
            suppressed.content is None or suppressed.content == ""
        )

    # Chunk 4 (tool-call marker) is deliberately not asserted on: its result
    # depends on parser state; the point is that chunk 3 did not leak.
def test_split_markers_across_deltas(kimi_k2_tool_parser):
"""
Test that markers split across delta chunks are correctly detected
via the rolling buffer mechanism.
"""
kimi_k2_tool_parser.reset_streaming_state()
section_begin_token_id = kimi_k2_tool_parser.vocab.get(
"<|tool_calls_section_begin|>"
["process_data"],
[{"name": "test", "value": 123}],
"Formatted. ",
id="multiline_json",
),
pytest.param(
"No prefix. " + _wrap(_tool("get_weather:0", '{"city": "Tokyo"}')),
["get_weather"],
[{"city": "Tokyo"}],
"No prefix. ",
id="no_functions_prefix",
),
pytest.param(
"Empty args. " + _wrap(_tool("functions.test:0", "{}")),
["test"],
[{}],
"Empty args. ",
id="empty_arguments",
),
],
)
# Delta 1: partial token, Delta 2: complete marker
deltas = [
("<|tool_calls_sec", [3]),
("tion_begin|> ", [section_begin_token_id, 4]),
]
_results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
# Now the complete marker should be detected via buffer
assert kimi_k2_tool_parser.in_tool_section is True
def test_marker_variants(kimi_k2_tool_parser):
"""Test that both singular and plural marker variants are recognized."""
kimi_k2_tool_parser.reset_streaming_state()
# Test singular variant: <|tool_call_section_begin|> (note: singular "call")
singular_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_section_begin|>")
if singular_token_id is not None: # Only test if tokenizer supports it
_result = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="Reasoning ",
current_text="Reasoning <|tool_call_section_begin|>",
delta_text="<|tool_call_section_begin|>",
previous_token_ids=[1, 2],
current_token_ids=[1, 2, singular_token_id],
delta_token_ids=[singular_token_id],
request=None,
def test_extract_tool_calls(
self, parser, model_output, expected_names, expected_args_list, expected_content
):
content, tool_calls = run_tool_extraction(parser, model_output, streaming=False)
assert content == expected_content
assert len(tool_calls) == len(expected_names)
for tc, name, expected_args in zip(
tool_calls, expected_names, expected_args_list
):
assert tc.type == "function"
assert tc.function.name == name
assert json.loads(tc.function.arguments) == expected_args
# id format: "something:digit"
assert tc.id.split(":")[-1].isdigit()
def test_invalid_json_still_extracted(self, parser):
"""Tool calls with invalid JSON are still returned (arguments as-is)."""
model_output = (
"Help. "
+ SECTION_BEGIN
+ _tool("functions.bad:0", '{"city": "Beijing"')
+ _tool("functions.good:1", '{"city": "Shanghai"}')
+ SECTION_END
)
# Should enter tool section mode with singular variant too
assert kimi_k2_tool_parser.in_tool_section is True
def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
    """
    Once <|tool_calls_section_end|> has closed a tool section, text that
    follows is streamed back as content again.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = parser.vocab.get("<|tool_calls_section_end|>")

    # Open the section, close it immediately, then emit ordinary text.
    stream = [
        ("<|tool_calls_section_begin|>", [begin_id]),
        ("<|tool_calls_section_end|>", [end_id]),
        (" More reasoning", [10, 11]),
    ]
    results = run_streaming_sequence(parser, stream)

    assert parser.in_tool_section is False
    trailing = results[2]
    assert trailing is not None
    assert trailing.content == " More reasoning"
def test_empty_tool_section(kimi_k2_tool_parser):
    """A section begin marker followed directly by the end marker leaves the parser outside the section."""
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = parser.vocab.get("<|tool_calls_section_end|>")

    text_before = "Reasoning "
    text_opened = "Reasoning <|tool_calls_section_begin|>"
    text_closed = "Reasoning <|tool_calls_section_begin|><|tool_calls_section_end|>"

    # Step 1: the section opens.
    parser.extract_tool_calls_streaming(
        previous_text=text_before,
        current_text=text_opened,
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[1],
        current_token_ids=[1, begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )

    # Step 2: the section closes right away.
    parser.extract_tool_calls_streaming(
        previous_text=text_opened,
        current_text=text_closed,
        delta_text="<|tool_calls_section_end|>",
        previous_token_ids=[1, begin_id],
        current_token_ids=[1, begin_id, end_id],
        delta_token_ids=[end_id],
        request=None,
    )

    # Neither call raised, and the section flag is cleared.
    assert parser.in_tool_section is False
def test_malformed_tool_section_recovery(kimi_k2_tool_parser):
    """
    A tool section that is opened but never closed must not trap the parser:
    after a very large delta it is expected to leave the section and return
    that delta as content.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    begin_marker = "<|tool_calls_section_begin|>"

    # Open the tool section.
    parser.extract_tool_calls_streaming(
        previous_text="",
        current_text=begin_marker,
        delta_text=begin_marker,
        previous_token_ids=[],
        current_token_ids=[begin_id],
        delta_token_ids=[begin_id],
        request=None,
    )
    assert parser.in_tool_section is True

    # Feed 10000 characters with no tool call and no section end.
    # NOTE(review): presumably this exceeds the parser's section size limit
    # (the original comment names max_section_chars) — the limit itself is
    # not visible from this test.
    filler = "x" * 10000
    filler_ids = list(range(100, 100 + len(filler)))
    recovered = parser.extract_tool_calls_streaming(
        previous_text=begin_marker,
        current_text=begin_marker + filler,
        delta_text=filler,
        previous_token_ids=[begin_id],
        current_token_ids=[begin_id] + filler_ids,
        delta_token_ids=list(filler_ids),
        request=None,
    )

    # The parser must have left the section and handed the text back as content.
    assert parser.in_tool_section is False
    assert recovered is not None
    assert recovered.content == filler
def test_state_reset(kimi_k2_tool_parser):
    """reset_streaming_state() restores every checked streaming attribute to its initial value."""
    parser = kimi_k2_tool_parser

    # Dirty a handful of streaming attributes first.
    dirty_values = {
        "in_tool_section": True,
        "token_buffer": "some buffer",
        "current_tool_id": 5,
        "prev_tool_call_arr": [{"id": "test"}],
        "section_char_count": 1000,
    }
    for attribute, value in dirty_values.items():
        setattr(parser, attribute, value)

    parser.reset_streaming_state()

    # Everything inspected below must be back at its pristine value.
    assert parser.in_tool_section is False
    assert parser.token_buffer == ""
    assert parser.current_tool_id == -1
    assert parser.prev_tool_call_arr == []
    assert parser.section_char_count == 0
    assert parser.current_tool_name_sent is False
    assert parser.streamed_args_for_tool == []
def test_section_begin_noise_tool_begin_same_chunk(kimi_k2_tool_parser):
    """
    When section begin, stray text and tool-call begin all arrive in ONE
    delta, the stray text must still be kept out of the content.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    tool_begin_id = parser.vocab.get("<|tool_call_begin|>")

    # One delta: section marker, noise, tool-call marker.
    combined_text = "<|tool_calls_section_begin|> noise text <|tool_call_begin|>"
    message = parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text="Reasoning " + combined_text,
        delta_text=combined_text,
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, begin_id, 3, 4, tool_begin_id],
        delta_token_ids=[begin_id, 3, 4, tool_begin_id],
        request=None,
    )

    # No message, or a message without content, is fine. If content is present
    # it must not carry the noise and may only be empty or whitespace.
    leaked = None if message is None else message.content
    if leaked is not None:
        assert "noise text" not in leaked
        assert leaked == "" or leaked.strip() == ""
def test_stream_ends_without_section_end_marker(kimi_k2_tool_parser):
    """
    Test that if the stream ends (EOF) without a proper section end marker,
    the parser doesn't leak text, doesn't crash, and resets state cleanly.
    """
    kimi_k2_tool_parser.reset_streaming_state()
    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")

    # Enter tool section
    _result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="<|tool_calls_section_begin|>",
        delta_text="<|tool_calls_section_begin|>",
        previous_token_ids=[],
        current_token_ids=[section_begin_id],
        delta_token_ids=[section_begin_id],
        request=None,
    )
    assert kimi_k2_tool_parser.in_tool_section is True

    # Some content in tool section
    result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="<|tool_calls_section_begin|>",
        current_text="<|tool_calls_section_begin|> partial content",
        delta_text=" partial content",
        previous_token_ids=[section_begin_id],
        current_token_ids=[section_begin_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    # Content should be suppressed. A None result counts as suppressed too;
    # guarding it here avoids an AttributeError on `.content` and matches the
    # way the other suppression tests in this file accept None.
    assert result2 is None or result2.content == "" or result2.content is None

    # Stream ends (EOF) - no more deltas, no section_end marker.
    # Simulate this by manually checking state and resetting
    # (in real usage, the request handler would call reset_streaming_state).
    assert kimi_k2_tool_parser.in_tool_section is True  # Still in section

    # Reset state (as would happen between requests)
    kimi_k2_tool_parser.reset_streaming_state()

    # Verify clean slate
    assert kimi_k2_tool_parser.in_tool_section is False
    assert kimi_k2_tool_parser.token_buffer == ""

    # Next request should work normally
    result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
        previous_text="",
        current_text="New reasoning",
        delta_text="New reasoning",
        previous_token_ids=[],
        current_token_ids=[20, 21],
        delta_token_ids=[20, 21],
        request=None,
    )
    assert result3 is not None
    assert result3.content == "New reasoning"
def test_same_chunk_begin_and_end_markers(kimi_k2_tool_parser):
    """
    Section begin and section end delivered in ONE delta.

    The parser is expected to enter the tool section, leave it again within
    the same call, and not remain with in_tool_section=True. Text streamed
    afterwards must come back as content.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = parser.vocab.get("<|tool_calls_section_end|>")

    # An empty tool section arriving as a single chunk.
    combined_delta = "<|tool_calls_section_begin|><|tool_calls_section_end|>"
    text_after_section = "Some reasoning " + combined_delta
    first = parser.extract_tool_calls_streaming(
        previous_text="Some reasoning ",
        current_text=text_after_section,
        delta_text=combined_delta,
        previous_token_ids=[1, 2],
        current_token_ids=[1, 2, begin_id, end_id],
        delta_token_ids=[begin_id, end_id],
        request=None,
    )

    # The section flag must already be cleared again.
    assert parser.in_tool_section is False, (
        "Parser stuck in tool section after processing both begin/end in same chunk. "
        "This indicates the elif bug was not fixed."
    )

    # A message is returned, but it carries no content.
    assert first is not None
    assert first.content == "" or first.content is None

    # Follow-up text is outside the section and must not be suppressed.
    second = parser.extract_tool_calls_streaming(
        previous_text=text_after_section,
        current_text=text_after_section + " More reasoning",
        delta_text=" More reasoning",
        previous_token_ids=[1, 2, begin_id, end_id],
        current_token_ids=[1, 2, begin_id, end_id, 10, 11],
        delta_token_ids=[10, 11],
        request=None,
    )
    assert second is not None
    assert second.content == " More reasoning"
def test_same_chunk_begin_content_end_markers(kimi_k2_tool_parser):
    """
    Same-chunk open/close with something between the markers.

    The delta is "<|tool_calls_section_begin|> <|tool_calls_section_end|>"
    (a single space between the markers). The parser must end up outside the
    tool section, and the next delta must be streamed back as content.
    """
    parser = kimi_k2_tool_parser
    parser.reset_streaming_state()
    begin_id = parser.vocab.get("<|tool_calls_section_begin|>")
    end_id = parser.vocab.get("<|tool_calls_section_end|>")

    # Begin marker, one space, end marker — all in a single chunk.
    combined_delta = "<|tool_calls_section_begin|> <|tool_calls_section_end|>"
    text_after_section = "Reasoning " + combined_delta
    parser.extract_tool_calls_streaming(
        previous_text="Reasoning ",
        current_text=text_after_section,
        delta_text=combined_delta,
        previous_token_ids=[1],
        current_token_ids=[1, begin_id, 100, end_id],
        delta_token_ids=[begin_id, 100, end_id],
        request=None,
    )

    # The parser must not be left inside the section.
    assert parser.in_tool_section is False

    # Text that follows is ordinary content again.
    trailing = parser.extract_tool_calls_streaming(
        previous_text=text_after_section,
        current_text=text_after_section + " Done",
        delta_text=" Done",
        previous_token_ids=[1, begin_id, 100, end_id],
        current_token_ids=[1, begin_id, 100, end_id, 200],
        delta_token_ids=[200],
        request=None,
    )
    assert trailing is not None
    assert trailing.content == " Done"
def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
"""
CRITICAL TEST (P1): Verify that when both <|tool_call_end|> and
<|tool_calls_section_end|> appear in the SAME chunk, the parser:
1. Processes the tool_call_end first (emits final arguments)
2. THEN exits the section
3. Does NOT drop the final tool call update
4. Does NOT leak special tokens into reasoning
This tests the deferred section exit fix.
"""
kimi_k2_tool_parser.reset_streaming_state()
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
# Simulate a streaming sequence for a SHORT tool call (all in one chunk):
combined = (
'<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
"<|tool_call_end|><|tool_calls_section_end|>"
)
deltas = [
("Let me help. ", [1, 2]),
("<|tool_calls_section_begin|>", [section_begin_id]),
(combined, [tool_begin_id, 10, 11, 12, tool_end_id, section_end_id]),
(" Done", [20]),
]
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
# CRITICAL: Parser should have exited section AFTER processing tool
assert kimi_k2_tool_parser.in_tool_section is False
# Tool call should have been emitted (not dropped)
if results[2] is not None and results[2].content is not None:
# Verify no special tokens leaked into content
assert "<|tool_call_end|>" not in results[2].content
assert "<|tool_calls_section_end|>" not in results[2].content
content, tool_calls = run_tool_extraction(parser, model_output, streaming=False)
assert len(tool_calls) == 2
assert tool_calls[0].function.name == "bad"
assert tool_calls[1].function.name == "good"
def test_invalid_funcall_id_skipped(self, parser):
    """A tool call whose id lacks the ``:<digit>`` suffix is left out of the result."""
    malformed = _tool("functions.invalid.0", '{"city": "Beijing"}')
    well_formed = _tool("functions.valid:1", '{"city": "Shanghai"}')
    model_output = "Help. " + SECTION_BEGIN + malformed + well_formed + SECTION_END

    _, tool_calls = run_tool_extraction(parser, model_output, streaming=False)

    # Only the call with the well-formed id is expected back.
    assert [call.function.name for call in tool_calls] == ["valid"]
def test_native_id_extracted(self, parser):
    """Regression (PR #32768): the id written in the model output is kept on the ToolCall."""
    tool_text = _tool("functions.get_weather:0", '{"city": "Tokyo"}')
    model_output = "Checking weather. " + _wrap(tool_text)

    _, tool_calls = run_tool_extraction(parser, model_output, streaming=False)

    assert len(tool_calls) == 1
    only_call = tool_calls[0]
    # The id is the full native string; the function name is the bare name.
    assert only_call.id == "functions.get_weather:0"
    assert only_call.function.name == "get_weather"
    assert json.loads(only_call.function.arguments) == {"city": "Tokyo"}
def test_multi_turn_native_id_continuity(self, kimi_k2_tokenizer):
"""Regression: native IDs from turn 1 preserved across turns (PR #32768)."""
turn1_parser = KimiK2ToolParser(kimi_k2_tokenizer)
turn1_output = "Let me check. " + _wrap(
_tool("functions.get_weather:0", '{"city": "Beijing"}')
)
_, turn1_tools = run_tool_extraction(
turn1_parser, turn1_output, streaming=False
)
assert len(turn1_tools) == 1
assert turn1_tools[0].id == "functions.get_weather:0"
# Content after tool section should stream normally
assert results[3] is not None
assert results[3].content == " Done"
# Fresh parser for turn 2
turn2_parser = KimiK2ToolParser(kimi_k2_tokenizer)
turn2_output = "Now let me get news. " + _wrap(
_tool("functions.get_news:0", '{"topic": "weather in Beijing"}')
)
_, turn2_tools = run_tool_extraction(
turn2_parser, turn2_output, streaming=False
)
assert len(turn2_tools) == 1
assert turn2_tools[0].id == "functions.get_news:0"
def test_streaming_tool_call_markers_not_leaked(kimi_k2_tool_parser):
    """
    CRITICAL TEST: Verify that tool call markers (<|tool_call_begin|>,
    <|tool_call_end|>, <|tool_call_argument_begin|>) are NOT leaked
    into the content field during streaming.

    This reproduces the AWS Bedrock bug where tool call markers appeared
    in the 'text' field of responses.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
    tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")

    # List of markers that should NEVER appear in content
    forbidden_markers = [
        "<|tool_call_begin|>",
        "<|tool_call_end|>",
        "<|tool_call_argument_begin|>",
        "<|tool_calls_section_begin|>",
        "<|tool_calls_section_end|>",
    ]

    all_content = []

    # Steps: reasoning, section begin, tool call, section end, more reasoning
    tool_chunk = (
        "<|tool_call_begin|> functions.get_weather:0 "
        '<|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
    )
    # Each delta is (delta_text, delta_token_ids).
    deltas = [
        ("I'll check the weather. ", [1, 2, 3]),
        ("<|tool_calls_section_begin|>", [section_begin_id]),
        (tool_chunk, [tool_begin_id, 10, 11, tool_end_id]),
        ("<|tool_calls_section_end|>", [section_end_id]),
        (" Here's the result.", [20, 21]),
    ]

    results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
    for res in results:
        if res and res.content:
            all_content.append(res.content)

    # CRITICAL ASSERTIONS: No forbidden markers in any content
    full_content = "".join(all_content)
    for marker in forbidden_markers:
        assert marker not in full_content, (
            f"MARKER LEAK DETECTED: '{marker}' found in content. "
            f"Full content: {repr(full_content)}"
        )

    # Also check that tool call content (function name, arguments) is not leaked
    assert "get_weather" not in full_content, (
        f"TOOL CALL CONTENT LEAKED: 'get_weather' found in content. "
        f"Full content: {repr(full_content)}"
    )
    assert "Tokyo" not in full_content, (
        f"TOOL CALL CONTENT LEAKED: 'Tokyo' found in content. "
        f"Full content: {repr(full_content)}"
    )

    # Verify that legitimate content was preserved
    assert "I'll check the weather." in full_content or len(all_content) > 0


def _split_tool_output_to_deltas(
    content: str, tool_strs: list[tuple[str, str]]
) -> list[str]:
    """Build a list of string deltas with special tokens as separate chunks.

    The result starts with ``content`` and ``SECTION_BEGIN``, then for every
    tool emits five chunks (``TOOL_BEGIN``, the id plus a trailing space,
    ``ARG_BEGIN``, the arguments plus a trailing space, ``TOOL_END``), and
    ends with ``SECTION_END``.

    Args:
        content: text before tool section
        tool_strs: list of (tool_id, args_json)

    Returns:
        The ordered list of delta strings.
    """
    deltas = [content, SECTION_BEGIN]
    for tool_id, args_json in tool_strs:
        deltas.extend(
            [
                TOOL_BEGIN,
                f"{tool_id} ",
                ARG_BEGIN,
                f"{args_json} ",
                TOOL_END,
            ]
        )
    deltas.append(SECTION_END)
    return deltas
def test_native_id_extracted_and_placed_on_tool_call(kimi_k2_tool_parser):
    """Regression (PR #32768): the native id ends up on the ``ToolCall``.

    Calls ``extract_tool_calls`` directly on a single-call section and checks
    the id, function name and decoded arguments of the one result.
    """
    pieces = [
        "Checking weather. ",
        "<|tool_calls_section_begin|>",
        "<|tool_call_begin|>functions.get_weather:0",
        '<|tool_call_argument_begin|>{"city": "Tokyo"}',
        "<|tool_call_end|>",
        "<|tool_calls_section_end|>",
    ]
    extraction = kimi_k2_tool_parser.extract_tool_calls(
        "".join(pieces), request=None
    )

    assert extraction.tools_called
    assert len(extraction.tool_calls) == 1

    only_call = extraction.tool_calls[0]
    # The id emitted by the model must be used verbatim as the tool call id.
    assert only_call.id == "functions.get_weather:0"
    assert only_call.function.name == "get_weather"
    assert json.loads(only_call.function.arguments) == {"city": "Tokyo"}
def test_multi_turn_native_id_continuity(kimi_k2_tool_parser, kimi_k2_tokenizer):
    """Regression: native IDs from turn 1 preserved across turns (PR #32768).

    Turn 1 uses the fixture parser; turn 2 uses a newly constructed
    ``KimiK2ToolParser``. Both call ``extract_tool_calls`` directly.
    """
    turn1_output = (
        "Let me check. "
        "<|tool_calls_section_begin|>"
        "<|tool_call_begin|>functions.get_weather:0"
        '<|tool_call_argument_begin|>{"city": "Beijing"}'
        "<|tool_call_end|>"
        "<|tool_calls_section_end|>"
    )

    turn1_result = kimi_k2_tool_parser.extract_tool_calls(turn1_output, request=None)
    assert turn1_result.tools_called
    assert turn1_result.tool_calls[0].id == "functions.get_weather:0"

    # Fresh parser for turn 2
    turn2_parser = KimiK2ToolParser(kimi_k2_tokenizer)
    turn2_output = (
        "Now let me get news. "
        "<|tool_calls_section_begin|>"
        "<|tool_call_begin|>functions.get_news:0"
        '<|tool_call_argument_begin|>{"topic": "weather in Beijing"}'
        "<|tool_call_end|>"
        "<|tool_calls_section_end|>"
    )

    turn2_result = turn2_parser.extract_tool_calls(turn2_output, request=None)
    assert turn2_result.tools_called
    assert turn2_result.tool_calls[0].id == "functions.get_news:0"


class TestStreamingHappyPath:
    """Streaming extraction on well-formed model output."""

    def test_single_tool_call(self, parser):
        """Verify DeltaToolCall output: name, id, arguments for one tool."""
        deltas = _split_tool_output_to_deltas(
            "I'll help. ",
            [("functions.get_weather:0", '{"city": "Beijing"}')],
        )
        rec = run_tool_extraction_streaming(parser, deltas)

        assert len(rec.tool_calls) == 1
        tc = rec.tool_calls[0]
        assert tc.function.name == "get_weather"
        assert tc.id == "functions.get_weather:0"
        assert json.loads(tc.function.arguments) == {"city": "Beijing"}

    def test_multiple_tool_calls(self, parser):
        """Two tool calls emitted with correct indices, names, arguments."""
        deltas = _split_tool_output_to_deltas(
            "Compare weather. ",
            [
                ("functions.get_weather:0", '{"city": "Tokyo"}'),
                ("functions.get_weather:1", '{"city": "NYC"}'),
            ],
        )
        rec = run_tool_extraction_streaming(parser, deltas)

        assert len(rec.tool_calls) == 2
        assert rec.tool_calls[0].function.name == "get_weather"
        assert rec.tool_calls[0].id == "functions.get_weather:0"
        assert json.loads(rec.tool_calls[0].function.arguments) == {"city": "Tokyo"}
        assert rec.tool_calls[1].function.name == "get_weather"
        assert rec.tool_calls[1].id == "functions.get_weather:1"
        assert json.loads(rec.tool_calls[1].function.arguments) == {"city": "NYC"}

    def test_content_before_tools(self, parser):
        """Content before section is streamed; markers/args don't leak."""
        deltas = _split_tool_output_to_deltas(
            "I'll check the weather. ",
            [("functions.get_weather:0", '{"city": "Tokyo"}')],
        )
        rec = run_tool_extraction_streaming(parser, deltas)

        assert "check the weather" in rec.other_content
        # No markers or tool content leaked
        for marker in [SECTION_BEGIN, SECTION_END, TOOL_BEGIN, TOOL_END, ARG_BEGIN]:
            assert marker not in rec.other_content
        assert "get_weather" not in rec.other_content
        assert "Tokyo" not in rec.other_content

    def test_no_tool_calls(self, parser):
        """Plain text streaming returns content only."""
        deltas = ["This is just ", "regular text ", "without tools."]
        rec = run_tool_extraction_streaming(parser, deltas)

        assert rec.other_content == "This is just regular text without tools."
        assert rec.tool_calls == []

    def test_incremental_arguments(self, parser):
        """Arguments split across small chunks accumulate correctly."""
        deltas = [
            "Help. ",
            SECTION_BEGIN,
            TOOL_BEGIN,
            "functions.get_weather:0 ",
            ARG_BEGIN,
            '{"ci',
            'ty": "Be',
            'ijing"}',
            " ",
            TOOL_END,
            SECTION_END,
        ]
        rec = run_tool_extraction_streaming(parser, deltas)

        assert len(rec.tool_calls) == 1
        assert rec.tool_calls[0].function.name == "get_weather"
        assert json.loads(rec.tool_calls[0].function.arguments) == {"city": "Beijing"}

    @pytest.mark.parametrize(
        "model_output",
        [
            pytest.param(
                "Single. "
                + _wrap(_tool("functions.get_weather:0", '{"city": "Beijing"}')),
                id="single_tool",
            ),
            pytest.param(
                "Multi. "
                + _wrap(
                    _tool("functions.get_weather:0", '{"city": "Tokyo"}'),
                    _tool("functions.get_news:1", '{"topic": "tech"}'),
                ),
                id="parallel_tools",
            ),
            pytest.param(
                "No prefix id. " + _wrap(_tool("get_weather:0", '{"city": "NYC"}')),
                id="no_functions_prefix",
            ),
        ],
    )
    def test_streaming_matches_nonstreaming(self, parser, model_output):
        """Streaming reconstruction matches non-streaming extraction.

        Only the tool calls are compared (count, function name and decoded
        arguments); the returned content strings are not checked here.
        """
        content_non, tools_non = run_tool_extraction(
            parser, model_output, streaming=False
        )
        content_stream, tools_stream = run_tool_extraction(
            parser, model_output, streaming=True
        )

        assert len(tools_non) == len(tools_stream)
        for tc_non, tc_stream in zip(tools_non, tools_stream):
            assert tc_non.function.name == tc_stream.function.name
            assert json.loads(tc_non.function.arguments) == json.loads(
                tc_stream.function.arguments
            )
class TestStreamingEdgeCases:
    """Streaming extraction on unusual, truncated or degenerate output."""

    def test_marker_suppression(self, parser):
        """Reconstructed content never contains a special-token marker."""
        chunks = _split_tool_output_to_deltas(
            "I'll check. ",
            [("functions.get_weather:0", '{"city": "Tokyo"}')],
        )
        result = run_tool_extraction_streaming(parser, chunks)

        for marker in [SECTION_BEGIN, SECTION_END, TOOL_BEGIN, TOOL_END, ARG_BEGIN]:
            assert marker not in result.other_content, (
                f"Marker leaked: {marker!r} in {result.other_content!r}"
            )

    def test_noise_between_markers_suppressed(self, parser):
        """Stray text between the section opener and the first tool call
        is kept out of the content."""
        chunks = ["Reasoning. ", SECTION_BEGIN, " spurious noise ", TOOL_BEGIN]
        chunks += ["functions.test:0 ", ARG_BEGIN, '{"k": "v"} ', TOOL_END]
        chunks.append(SECTION_END)

        result = run_tool_extraction_streaming(parser, chunks)

        assert "spurious" not in result.other_content
        assert "noise" not in result.other_content

    def test_empty_tool_section(self, parser):
        """A section that opens and closes immediately yields no tool calls."""
        result = run_tool_extraction_streaming(
            parser, ["Reasoning. ", SECTION_BEGIN, SECTION_END]
        )
        assert result.tool_calls == []

    def test_three_different_tools(self, parser):
        """Three calls to three different functions arrive in order with
        distinct ids."""
        requested = [
            ("functions.get_weather:0", '{"city": "NYC"}'),
            ("functions.get_news:1", '{"topic": "tech"}'),
            ("functions.send_email:2", '{"to": "a@b.com"}'),
        ]
        chunks = _split_tool_output_to_deltas("Multiple tasks. ", requested)

        result = run_tool_extraction_streaming(parser, chunks)

        assert len(result.tool_calls) == 3
        assert [call.function.name for call in result.tool_calls] == [
            "get_weather",
            "get_news",
            "send_email",
        ]
        # All three ids must be distinct.
        assert len({call.id for call in result.tool_calls}) == 3

    def test_truncated_tool_call_no_end_marker(self, parser):
        """A stream cut off mid-arguments (e.g. max_tokens) still yields the
        tool name, id and the partial argument text."""
        # The stream stops after a partial JSON value: no TOOL_END and no
        # SECTION_END are ever delivered.
        chunks = [
            "I'll check. ",
            SECTION_BEGIN,
            TOOL_BEGIN,
            "functions.get_weather:0 ",
            ARG_BEGIN,
            '{"city": "Bei',
        ]
        result = run_tool_extraction_streaming(parser, chunks)

        assert len(result.tool_calls) == 1
        partial_call = result.tool_calls[0]
        assert partial_call.function.name == "get_weather"
        assert partial_call.id == "functions.get_weather:0"
        assert partial_call.function.arguments == '{"city": "Bei'

        # Markers must stay out of the content even on a truncated stream.
        for marker in [SECTION_BEGIN, SECTION_END, TOOL_BEGIN, TOOL_END, ARG_BEGIN]:
            assert marker not in result.other_content

    def test_content_after_tool_section(self, parser):
        """Text following the section end is not emitted as content and
        does not disturb the extracted tool call."""
        chunks = _split_tool_output_to_deltas(
            "Before. ",
            [("functions.get_weather:0", '{"city": "Tokyo"}')],
        )
        chunks.append(" After tools.")

        result = run_tool_extraction_streaming(parser, chunks)

        # The tool call itself is intact.
        assert len(result.tool_calls) == 1
        only_call = result.tool_calls[0]
        assert only_call.function.name == "get_weather"
        assert json.loads(only_call.function.arguments) == {"city": "Tokyo"}

        # The trailing text is dropped rather than streamed.
        assert "After tools." not in result.other_content

        # And no marker leaks either.
        for marker in [SECTION_BEGIN, SECTION_END, TOOL_BEGIN, TOOL_END, ARG_BEGIN]:
            assert marker not in result.other_content
class TestAdjustRequest:
    """``adjust_request`` only disables special-token skipping when tools
    are present and ``tool_choice`` is not ``"none"``."""

    @staticmethod
    def _make_request(tools, tool_choice):
        """Return a ``ChatCompletionRequest``-spec'd mock that starts with
        ``skip_special_tokens`` set to True."""
        request = MagicMock(spec=ChatCompletionRequest)
        request.tools = tools
        request.tool_choice = tool_choice
        request.skip_special_tokens = True
        return request

    def test_sets_skip_special_tokens_false(self, parser):
        request = self._make_request(
            [{"type": "function", "function": {"name": "test"}}], "auto"
        )
        adjusted = parser.adjust_request(request)
        assert adjusted.skip_special_tokens is False

    def test_no_change_when_tool_choice_none(self, parser):
        request = self._make_request(
            [{"type": "function", "function": {"name": "test"}}], "none"
        )
        adjusted = parser.adjust_request(request)
        assert adjusted.skip_special_tokens is True

    def test_no_change_when_no_tools(self, parser):
        request = self._make_request(None, "auto")
        adjusted = parser.adjust_request(request)
        assert adjusted.skip_special_tokens is True
def _chunk_tokenized_deltas(tokenizer, text: str, stream_interval: int) -> list[str]:
"""Encode text, group tokens into chunks of stream_interval, decode each."""
token_ids = tokenizer.encode(text, add_special_tokens=False)
deltas = []
prev = ""
for i in range(0, len(token_ids), stream_interval):
decoded = tokenizer.decode(
token_ids[: i + stream_interval], skip_special_tokens=False
)
deltas.append(decoded[len(prev) :])
prev = decoded
return deltas
def test_streaming_multiple_tool_calls_not_leaked(kimi_k2_tool_parser):
    """
    Test that MULTIPLE tool calls in streaming mode do not leak into content.
    This reproduces the AWS Bedrock scenario: "Compare weather in Tokyo and NYC".
    """
    kimi_k2_tool_parser.reset_streaming_state()

    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
    tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")

    all_content = []

    tool1 = '<|tool_call_begin|> get_weather:0 <|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
    tool2 = ' <|tool_call_begin|> get_weather:1 <|tool_call_argument_begin|> {"city": "New York"} <|tool_call_end|>'

    # Each delta is (delta_text, delta_token_ids).
    deltas = [
        ("I'll compare the weather. ", [1, 2, 3]),
        ("<|tool_calls_section_begin|>", [section_begin_id]),
        (tool1, [tool_begin_id, 10, tool_end_id]),
        (tool2, [tool_begin_id, 20, tool_end_id]),
        ("<|tool_calls_section_end|>", [section_end_id]),
        (" Here's the comparison.", [30]),
    ]

    results = run_streaming_sequence(kimi_k2_tool_parser, deltas)

    for res in results:
        if res and res.content:
            all_content.append(res.content)

    # Assertions
    full_content = "".join(all_content)

    # Check no markers leaked
    forbidden = ["<|tool_call", "<|tool_calls_section"]
    for marker in forbidden:
        assert marker not in full_content, (
            f"MARKER LEAKED: {marker} in {repr(full_content)}"
        )

    # Check no tool call content leaked (both tools)
    assert "get_weather" not in full_content, f"TOOL NAME LEAKED: {repr(full_content)}"
    assert "Tokyo" not in full_content, f"TOOL ARG LEAKED (Tokyo): {repr(full_content)}"
    assert "New York" not in full_content, (
        f"TOOL ARG LEAKED (NYC): {repr(full_content)}"
    )

    # Legitimate content preserved
    assert "compare" in full_content.lower() or len(all_content) > 0


class TestStreamingIntervals:
    """Test streaming at various token-chunk sizes to catch boundary bugs."""

    @pytest.mark.parametrize("stream_interval", [1, 2, 3, 5, 8])
    def test_single_tool_call_at_interval(self, kimi_k2_tokenizer, stream_interval):
        """One tool call is recovered whatever the chunk size."""
        text = "Help. " + _wrap(_tool("functions.get_weather:0", '{"city": "Beijing"}'))
        deltas = _chunk_tokenized_deltas(kimi_k2_tokenizer, text, stream_interval)
        parser = KimiK2ToolParser(kimi_k2_tokenizer)
        rec = run_tool_extraction_streaming(
            parser, deltas, assert_one_tool_per_delta=False
        )

        assert len(rec.tool_calls) == 1
        assert rec.tool_calls[0].function.name == "get_weather"
        assert json.loads(rec.tool_calls[0].function.arguments) == {"city": "Beijing"}

    @pytest.mark.parametrize("stream_interval", [1, 2, 3, 5, 8])
    def test_content_then_tool_call_at_interval(
        self, kimi_k2_tokenizer, stream_interval
    ):
        """Leading content is streamed and the tool name stays out of it."""
        text = "Sure, let me check. " + _wrap(
            _tool("functions.get_weather:0", '{"city": "Tokyo"}')
        )
        deltas = _chunk_tokenized_deltas(kimi_k2_tokenizer, text, stream_interval)
        parser = KimiK2ToolParser(kimi_k2_tokenizer)
        rec = run_tool_extraction_streaming(
            parser, deltas, assert_one_tool_per_delta=False
        )

        assert "let me check" in rec.other_content
        assert "get_weather" not in rec.other_content
        assert len(rec.tool_calls) == 1
        assert rec.tool_calls[0].function.name == "get_weather"
        assert json.loads(rec.tool_calls[0].function.arguments) == {"city": "Tokyo"}

    @pytest.mark.parametrize("stream_interval", [1, 2, 3, 5, 8])
    def test_multiple_tool_calls_at_interval(self, kimi_k2_tokenizer, stream_interval):
        """Two calls to the same function are recovered in order."""
        text = "Compare. " + _wrap(
            _tool("functions.search:0", '{"q": "cats"}'),
            _tool("functions.search:1", '{"q": "dogs"}'),
        )
        deltas = _chunk_tokenized_deltas(kimi_k2_tokenizer, text, stream_interval)
        parser = KimiK2ToolParser(kimi_k2_tokenizer)
        rec = run_tool_extraction_streaming(
            parser, deltas, assert_one_tool_per_delta=False
        )

        assert len(rec.tool_calls) == 2
        assert rec.tool_calls[0].function.name == "search"
        assert json.loads(rec.tool_calls[0].function.arguments) == {"q": "cats"}
        assert rec.tool_calls[1].function.name == "search"
        assert json.loads(rec.tool_calls[1].function.arguments) == {"q": "dogs"}

    @pytest.mark.parametrize("stream_interval", [1, 2, 3, 5, 8])
    def test_plain_text_at_interval(self, kimi_k2_tokenizer, stream_interval):
        """Text without any tool section is passed through unchanged."""
        text = "This is plain text with no tool calling involved."
        deltas = _chunk_tokenized_deltas(kimi_k2_tokenizer, text, stream_interval)
        parser = KimiK2ToolParser(kimi_k2_tokenizer)
        rec = run_tool_extraction_streaming(
            parser, deltas, assert_one_tool_per_delta=False
        )

        assert rec.other_content == text
        assert rec.tool_calls == []

    def test_content_and_tool_call_in_single_chunk(self, kimi_k2_tokenizer):
        """Content + complete tool call in one chunk must both be emitted."""
        text = "Hi! " + _wrap(_tool("functions.get_weather:0", '{"city": "Beijing"}'))
        # An interval far larger than the token count yields a single delta.
        deltas = _chunk_tokenized_deltas(kimi_k2_tokenizer, text, stream_interval=9999)
        parser = KimiK2ToolParser(kimi_k2_tokenizer)
        rec = run_tool_extraction_streaming(
            parser, deltas, assert_one_tool_per_delta=False
        )

        assert "Hi!" in rec.other_content
        assert "get_weather" not in rec.other_content
        assert len(rec.tool_calls) == 1
        assert rec.tool_calls[0].function.name == "get_weather"
        assert json.loads(rec.tool_calls[0].function.arguments) == {"city": "Beijing"}
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# code modified from deepseekv3_tool_parser.py
from collections.abc import Sequence
......@@ -17,12 +16,14 @@ from vllm.entrypoints.openai.engine.protocol import (
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
Tool,
ToolParser,
)
from vllm.tool_parsers.utils import partial_tag_overlap
logger = init_logger(__name__)
......@@ -30,124 +31,44 @@ logger = init_logger(__name__)
class KimiK2ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
    """Set up marker strings, regexes, marker token ids and streaming state.

    Args:
        tokenizer: tokenizer forwarded to the base ``ToolParser``.
        tools: optional tool definitions forwarded to the base class.

    Raises:
        ValueError: if ``self.model_tokenizer`` is falsy after base init.
        RuntimeError: if either section marker is missing from the
            tokenizer vocabulary.
    """
    super().__init__(tokenizer, tools)

    # --- Streaming state -------------------------------------------------
    self.current_tool_name_sent: bool = False
    # Offset into current_text up to which plain content was already emitted.
    self._sent_content_idx: int = 0
    self.prev_tool_call_arr: list[dict] = []
    self.current_tool_id: int = -1
    # What has been streamed so far for each tool, indexed by tool position.
    self.streamed_args_for_tool: list[str] = []

    # --- Section-level state, used to prevent token leakage --------------
    self.in_tool_section: bool = False
    self.token_buffer: str = ""
    # Buffer size: empirical worst-case for longest marker (~30 chars) * 2
    # + safety margin for unicode + partial overlap. Prevents unbounded growth.
    self.buffer_max_size: int = 1024
    self.section_char_count: int = 0  # Track characters processed in tool section
    self.max_section_chars: int = 8192  # Force exit if section exceeds this
    self._buffer_overflow_logged: bool = False  # Log overflow once per session

    # --- Section markers (both singular and plural variants) -------------
    self.tool_calls_start_token: str = "<|tool_calls_section_begin|>"
    self.tool_calls_end_token: str = "<|tool_calls_section_end|>"
    self.tool_calls_start_token_variants: list[str] = [
        "<|tool_calls_section_begin|>",
        "<|tool_call_section_begin|>",  # singular variant
    ]
    self.tool_calls_end_token_variants: list[str] = [
        "<|tool_calls_section_end|>",
        "<|tool_call_section_end|>",  # singular variant
    ]

    # --- Individual tool call markers ------------------------------------
    self.tool_call_start_token: str = "<|tool_call_begin|>"
    self.tool_call_end_token: str = "<|tool_call_end|>"
    self.tool_call_arg_token: str = "<|tool_call_argument_begin|>"

    # Regex for non-streaming extraction. The pattern is passed once, as a
    # single concatenated literal, followed by the flags.
    self.tool_call_regex = re.compile(
        r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
        r"<\|tool_call_argument_begin\|>\s*"
        r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
        r"<\|tool_call_end\|>",
        re.DOTALL,
    )

    self.stream_tool_call_portion_regex = re.compile(
        r"(?P<tool_call_id>.+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)"
    )

    self.stream_tool_call_name_regex = re.compile(r"(?P<tool_call_id>.+:\d+)\s*")

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    # --- Marker token ids -------------------------------------------------
    self.tool_calls_start_token_id = self.vocab.get(self.tool_calls_start_token)
    self.tool_calls_end_token_id = self.vocab.get(self.tool_calls_end_token)

    # Token ids for all variants; variants absent from the vocab are skipped.
    self.tool_calls_start_token_ids: list[int] = [
        tid
        for variant in self.tool_calls_start_token_variants
        if (tid := self.vocab.get(variant)) is not None
    ]
    self.tool_calls_end_token_ids: list[int] = [
        tid
        for variant in self.tool_calls_end_token_variants
        if (tid := self.vocab.get(variant)) is not None
    ]

    self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
    self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)

    if (
        self.tool_calls_start_token_id is None
        or self.tool_calls_end_token_id is None
    ):
        raise RuntimeError(
            "Kimi-K2 Tool parser could not locate tool call start/end "
            "tokens in the tokenizer!"
        )
def _check_and_strip_markers(self, text: str) -> tuple[str, bool, bool]:
"""
Check for section begin/end markers in text and strip them.
Returns: (cleaned_text, found_section_begin, found_section_end)
"""
found_begin = False
found_end = False
cleaned = text
# Check for section begin markers (any variant)
for variant in self.tool_calls_start_token_variants:
if variant in cleaned:
cleaned = cleaned.replace(variant, "")
found_begin = True
# Check for section end markers (any variant)
for variant in self.tool_calls_end_token_variants:
if variant in cleaned:
cleaned = cleaned.replace(variant, "")
found_end = True
return cleaned, found_begin, found_end
def _reset_section_state(self) -> None:
"""Reset state when exiting tool section."""
self.in_tool_section = False
self.token_buffer = ""
self.section_char_count = 0
def reset_streaming_state(self) -> None:
    """
    Reset all streaming state. Call this between requests to prevent
    state leakage when parser instance is reused.
    """
    # Reset section state
    self._reset_section_state()

    # Reset parent class state
    self.current_tool_name_sent = False
    self.prev_tool_call_arr = []
    self.current_tool_id = -1
    self.streamed_args_for_tool = []

    # Content is emitted as current_text[_sent_content_idx:...], and only
    # when the sendable position is past this index. A stale value from a
    # previous request would suppress or mis-slice the next request's content.
    self._sent_content_idx = 0

    logger.debug("Streaming state reset")
def adjust_request(
    self, request: ChatCompletionRequest | ResponsesRequest
) -> ChatCompletionRequest | ResponsesRequest:
    """Apply the base adjustments, then turn off special-token skipping
    when the request carries tools and ``tool_choice`` is not ``"none"``.

    Returns:
        The request object returned by the base class, possibly with
        ``skip_special_tokens`` set to False.
    """
    adjusted = super().adjust_request(request)
    if not adjusted.tools or adjusted.tool_choice == "none":
        return adjusted
    # Ensure special-token markers appear as literal text in
    # current_text so we can do pure text-based parsing.
    adjusted.skip_special_tokens = False
    return adjusted
def extract_tool_calls(
self,
......@@ -198,6 +119,95 @@ class KimiK2ToolParser(ToolParser):
tools_called=False, tool_calls=[], content=model_output
)
def _extract_content(self, current_text: str) -> str | None:
"""Return unsent content before the tool-calls section, or None.
Holds back any trailing suffix that partially matches
``<|tool_calls_section_begin|>`` to avoid leaking marker bytes.
"""
if self.tool_calls_start_token not in current_text:
overlap = partial_tag_overlap(current_text, self.tool_calls_start_token)
sendable_idx = len(current_text) - overlap
else:
sendable_idx = current_text.index(self.tool_calls_start_token)
if sendable_idx > self._sent_content_idx:
content = current_text[self._sent_content_idx : sendable_idx]
self._sent_content_idx = sendable_idx
return content
return None
def _extract_tool_calls(self, current_text: str) -> list[str]:
"""Extract raw bodies from ``<|tool_call_begin|>…<|tool_call_end|>`` blocks."""
if self.tool_calls_start_token not in current_text:
return []
results: list[str] = []
pos = current_text.index(self.tool_calls_start_token)
while True:
start = current_text.find(self.tool_call_start_token, pos)
if start == -1:
break
tc_start = start + len(self.tool_call_start_token)
end = current_text.find(self.tool_call_end_token, tc_start)
if end != -1:
tool_call = current_text[tc_start:end]
pos = end + len(self.tool_call_end_token)
else:
tool_call = current_text[tc_start:]
overlap = partial_tag_overlap(tool_call, self.tool_call_end_token)
if overlap:
tool_call = tool_call[:-overlap]
results.append(tool_call)
if end == -1:
break
return results
@staticmethod
def _extract_tool_id_and_name(
header: str | None,
) -> tuple[str | None, str | None]:
"""Parse ``(tool_id, tool_name)`` from a header
like ``"functions.get_weather:0"``."""
if header is None:
return None, None
match = re.match(r"(.+:\d+)", header)
if not match:
return None, None
tool_id = match.group(1).strip()
tool_name = tool_id.split(":")[0].split(".")[-1]
return tool_id, tool_name
def _split_tool_call(self, tool_call: str) -> tuple[str | None, str | None]:
"""Split a tool-call body into ``(header, arguments)`` at the argument marker.
Example::
'get_weather:0 <|tool_call_argument_begin|>{"c'
-> ("get_weather:0", '{"c')
"""
arg_pos = tool_call.find(self.tool_call_arg_token)
if arg_pos == -1:
return None, None
header = tool_call[:arg_pos].strip()
tool_args = tool_call[arg_pos + len(self.tool_call_arg_token) :]
return header, tool_args
def _compute_args_diff(self, index: int, tool_args: str | None) -> str | None:
"""Return new argument text not yet sent for tool `index`, or None."""
if tool_args is None:
return None
prev = self.streamed_args_for_tool[index]
if len(tool_args) <= len(prev):
return None
diff = tool_args[len(prev) :]
self.streamed_args_for_tool[index] = tool_args
self.prev_tool_call_arr[index]["arguments"] = tool_args
return diff
def extract_tool_calls_streaming(
self,
previous_text: str,
......@@ -208,394 +218,59 @@ class KimiK2ToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# Flag to defer section exit until after tool parsing completes
deferred_section_exit = False
# Add delta to buffer for split marker detection
self.token_buffer += delta_text
# Enforce buffer size limit to prevent memory issues
if len(self.token_buffer) > self.buffer_max_size:
if not self._buffer_overflow_logged:
logger.warning(
"Token buffer exceeded max size (%d bytes), flushing excess. "
"This may indicate very long markers or unusual tokenization.",
self.buffer_max_size,
)
self._buffer_overflow_logged = True
# Keep only the most recent content that might contain partial markers
self.token_buffer = self.token_buffer[-self.buffer_max_size // 2 :]
# Check buffer for section markers (handles split tokens)
buffered_text, found_section_begin, found_section_end = (
self._check_and_strip_markers(self.token_buffer)
)
# Track section state transitions
if found_section_begin and not self.in_tool_section:
logger.debug("Entering tool section")
self.in_tool_section = True
self.token_buffer = buffered_text # Use cleaned buffer
self.section_char_count = 0 # Reset counter for new section
if found_section_end and self.in_tool_section:
logger.debug("Detected section end marker")
# CRITICAL: Don't exit early if tool_call_end is in this chunk.
# Tool parser must emit final arguments/close first to avoid dropping
# the final tool update and leaking tokens into reasoning channel.
has_tool_end = self.tool_call_end_token_id in delta_token_ids
if has_tool_end:
# Defer exit until after tool parsing completes
deferred_section_exit = True
logger.debug("Deferring section exit: tool_call_end in same chunk")
self.token_buffer = buffered_text
else:
# No tool call ending, safe to exit immediately
logger.debug("Exiting tool section")
self._reset_section_state()
# Extract any content AFTER the section end marker in delta_text
# (don't use buffered_text as it contains tool call data)
post_section_content = ""
for variant in self.tool_calls_end_token_variants:
if variant in delta_text:
parts = delta_text.split(variant, 1)
if len(parts) > 1:
post_section_content = parts[1]
break
if post_section_content.strip():
return DeltaMessage(content=post_section_content)
return DeltaMessage(content="")
else:
self.token_buffer = buffered_text
# Check if any variant of section start token is in current_token_ids
has_section_token = any(
tid in current_token_ids for tid in self.tool_calls_start_token_ids
)
# Early return: if no section token detected yet, return as reasoning content
if not has_section_token and not self.in_tool_section:
logger.debug("No tool call tokens found!")
# Don't clear buffer - it needs to accumulate partial markers across deltas
# Buffer overflow is already protected by lines 215-224
return DeltaMessage(content=delta_text)
# Strip section markers from delta_text for subsequent processing
# NOTE: This preprocessing happens BEFORE the regex-based tool call
# parsing (from PR #24847) to ensure markers are removed cleanly
# before pattern matching. No double-stripping occurs because
# section markers and tool call markers are distinct.
delta_text, _, _ = self._check_and_strip_markers(delta_text)
# Error recovery: If in tool section for too long, force exit
if self.in_tool_section:
self.section_char_count += len(delta_text)
if self.section_char_count > self.max_section_chars:
logger.warning(
"Tool section exceeded max length (%d chars), forcing exit. "
"This may indicate malformed model output.",
self.max_section_chars,
)
self._reset_section_state()
# Deferred exit already handled by forced exit above
# Return remaining content as reasoning (or empty delta if no content)
return DeltaMessage(content=delta_text if delta_text.strip() else "")
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id
)
prev_tool_end_count = previous_token_ids.count(self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id
)
cur_tool_end_count = current_token_ids.count(self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (
cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text
):
# Suppress content between section begin and first tool begin
# (header noise). Don't suppress content between tools to avoid
# breaking potential delimiter characters.
if self.in_tool_section and cur_tool_start_count == 0:
logger.debug(
"In tool section before first tool, suppressing: %s",
delta_text,
)
# Return empty delta to maintain iterator contract
return DeltaMessage(content="")
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = (
full_text.split(self.tool_call_start_token)[-1]
.split(self.tool_call_end_token)[0]
.rstrip()
)
delta_text = delta_text.split(self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(self.tool_call_end_token)[-1].lstrip()
# case -- we're starting a new tool call
if (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count
):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(self.tool_call_start_token)[
-1
]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (
cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count
):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (
cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count
):
if self.prev_tool_call_arr is None or len(self.prev_tool_call_arr) == 0:
logger.debug("attempting to close tool call, but no tool call")
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get("arguments")
if diff:
diff = (
diff.encode("utf-8").decode("unicode_escape")
if diff is str
else diff
)
if '"}' not in delta_text:
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s",
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
logger.debug("Completing deferred section exit")
self._reset_section_state()
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=diff).model_dump(
exclude_none=True
),
)
]
)
# case -- otherwise we're just generating text
else:
# Check if we're in tool section - if so, suppress
if self.in_tool_section:
logger.debug("In tool section, suppressing text generation")
# Handle deferred section exit before returning
if deferred_section_exit:
self._reset_section_state()
return DeltaMessage(content="")
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
# Handle deferred section exit before returning
if deferred_section_exit and self.in_tool_section:
self._reset_section_state()
return delta
current_tool_call = dict()
if tool_call_portion:
current_tool_call_matches = self.stream_tool_call_portion_regex.match(
tool_call_portion
)
if current_tool_call_matches:
tool_id, tool_args = current_tool_call_matches.groups()
tool_name = tool_id.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id.strip()
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args
else:
current_tool_call_name_matches = (
self.stream_tool_call_name_regex.match(tool_call_portion)
)
if current_tool_call_name_matches:
(tool_id_str,) = current_tool_call_name_matches.groups()
tool_name = tool_id_str.split(":")[0].split(".")[-1]
current_tool_call["id"] = tool_id_str.strip()
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = ""
else:
logger.debug("Not enough token")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if current_tool_call is None:
return None
function_name: str | None = current_tool_call.get("name")
tool_id = current_tool_call.get("id")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=tool_id,
function=DeltaFunctionCall(
name=function_name
).model_dump(exclude_none=True),
)
]
# Extract any content before tool calls.
content = self._extract_content(current_text)
tool_calls = self._extract_tool_calls(current_text)
tool_call_deltas: list[DeltaToolCall] = []
for i, tool_call in enumerate(tool_calls):
# First time seeing tool call at index i.
if i >= len(self.prev_tool_call_arr):
# Initialize streaming state.
self.prev_tool_call_arr.append({})
self.streamed_args_for_tool.append("")
header, tool_args = self._split_tool_call(tool_call)
# Stream back tool name.
if "name" not in self.prev_tool_call_arr[i]:
tool_id, tool_name = self._extract_tool_id_and_name(header)
if not tool_name:
# Can't skip to tool i+1 if i isn't ready
break
self.prev_tool_call_arr[i]["name"] = tool_name
self.prev_tool_call_arr[i]["id"] = tool_id
tool_call_deltas.append(
DeltaToolCall(
index=i,
type="function",
id=tool_id,
function=DeltaFunctionCall(name=tool_name).model_dump(
exclude_none=True
),
)
)
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
# CRITICAL: Never return content if we're in a tool section
if self.in_tool_section:
return None
delta = (
DeltaMessage(content=delta_text)
if text_portion is not None
else None
)
return delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger.debug(
"Trying to parse current tool call with ID %s", self.current_tool_id
)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but non are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error(
"should be impossible to have arguments reset "
"mid-call. skipping streaming anything."
)
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
delta = DeltaMessage(
tool_calls=[
# Stream back new tool args by diffing against what was sent.
args_diff = self._compute_args_diff(i, tool_args)
if args_diff:
tool_call_deltas.append(
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=cur_arguments
).model_dump(exclude_none=True),
index=i,
function=DeltaFunctionCall(arguments=args_diff).model_dump(
exclude_none=True
),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if (
isinstance(delta_text, str)
and cur_arguments != prev_arguments
and len(cur_arguments) > len(prev_arguments)
and cur_arguments.startswith(prev_arguments)
):
delta_arguments = cur_arguments[len(prev_arguments) :]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=delta_arguments
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] = cur_arguments
else:
delta = None
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[self.current_tool_id] = current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
# Handle deferred section exit after tool parsing completes
if deferred_section_exit and self.in_tool_section:
logger.debug("Completing deferred section exit")
self._reset_section_state()
return delta
if content or tool_call_deltas:
return DeltaMessage(
content=content,
tool_calls=tool_call_deltas,
)
return None
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment