Unverified Commit 358bfd31 authored by wangln19's avatar wangln19 Committed by GitHub
Browse files

fix: update kimi k2 tool parser logic (#31207)


Signed-off-by: wangln19 <wanglinian@dev.wanglinian.msh-dev.svc.cluster.local>
Signed-off-by: Wang Linian <wanglinian@stu.pku.edu.cn>
Co-authored-by: wangln19 <wanglinian@dev.wanglinian.msh-dev.svc.cluster.local>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
parent 39512aba
...@@ -44,6 +44,33 @@ def assert_tool_calls( ...@@ -44,6 +44,33 @@ def assert_tool_calls(
) )
def run_streaming_sequence(parser, deltas):
    """Feed a sequence of streaming deltas to ``parser`` and collect results.

    The helper tracks the accumulated text and token-id history itself, so
    every call to ``parser.extract_tool_calls_streaming`` receives a
    consistent set of ``previous_*``, ``current_*`` and ``delta_*`` keyword
    arguments. ``request`` is always passed as ``None``.

    Args:
        parser: Object exposing ``extract_tool_calls_streaming`` that accepts
            the keyword arguments used below.
        deltas: Iterable of ``(delta_text, delta_token_ids)`` pairs in
            streaming order; ``delta_token_ids`` is a list of ints.

    Returns:
        A list with the parser's return value for each delta, in order.
        An empty ``deltas`` yields an empty list.
    """
    previous_text = ""
    previous_token_ids: list[int] = []
    results = []
    for delta_text, delta_token_ids in deltas:
        current_text = previous_text + delta_text
        # ``+`` builds a fresh list each step, so the history handed to the
        # parser on an earlier call is never mutated by a later one.
        current_token_ids = previous_token_ids + delta_token_ids
        result = parser.extract_tool_calls_streaming(
            previous_text=previous_text,
            current_text=current_text,
            delta_text=delta_text,
            previous_token_ids=previous_token_ids,
            current_token_ids=current_token_ids,
            delta_token_ids=delta_token_ids,
            request=None,
        )
        results.append(result)
        # What was "current" for this delta is "previous" for the next one.
        previous_text = current_text
        previous_token_ids = current_token_ids
    return results
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser): def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
model_output = "This is a test" model_output = "This is a test"
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
...@@ -346,61 +373,32 @@ def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser): ...@@ -346,61 +373,32 @@ def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>") tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
# Simulate streaming sequence: # Simulate streaming sequence:
deltas = [
("I'll help you with that. ", [1, 2, 3]),
("<|tool_calls_section_begin|>", [section_begin_token_id]),
(" spurious text ", [4, 5]),
("<|tool_call_begin|>", [tool_call_begin_token_id]),
]
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
# Delta 1: "I'll help you with that. " # Delta 1: "I'll help you with that. "
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( assert results[0] is not None
previous_text="", assert results[0].content == "I'll help you with that. "
current_text="I'll help you with that. ",
delta_text="I'll help you with that. ",
previous_token_ids=[],
current_token_ids=[1, 2, 3], # Regular tokens
delta_token_ids=[1, 2, 3],
request=None,
)
assert result1 is not None
assert result1.content == "I'll help you with that. "
# Delta 2: "<|tool_calls_section_begin|>" # Delta 2: "<|tool_calls_section_begin|>"
prev_ids = [1, 2, 3]
curr_ids = prev_ids + [section_begin_token_id]
result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. ",
current_text="I'll help you with that. <|tool_calls_section_begin|>",
delta_text="<|tool_calls_section_begin|>",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[section_begin_token_id],
request=None,
)
# Section marker should be stripped and suppressed # Section marker should be stripped and suppressed
assert result2 is None or (result2.content is None or result2.content == "") assert results[1] is None or (
results[1].content is None or results[1].content == ""
)
# Delta 3: " spurious text or tokens " (THE LEAK SCENARIO) # Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
prev_ids = curr_ids
curr_ids = curr_ids + [4, 5]
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. <|tool_calls_section_begin|>",
current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
delta_text=" spurious text ",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[4, 5],
request=None,
)
# CRITICAL: This text should be suppressed, NOT returned as reasoning_delta # CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
assert result3 is None or (result3.content is None or result3.content == "") assert results[2] is None or (
results[2].content is None or results[2].content == ""
)
# Delta 4: "<|tool_call_begin|>..." # Delta 4: "<|tool_call_begin|>..."
prev_ids = curr_ids
curr_ids = curr_ids + [tool_call_begin_token_id]
_result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>",
delta_text="<|tool_call_begin|>",
previous_token_ids=prev_ids,
current_token_ids=curr_ids,
delta_token_ids=[tool_call_begin_token_id],
request=None,
)
# Now we're in tool call mode, result depends on internal state # Now we're in tool call mode, result depends on internal state
# The key is that the spurious text from Delta 3 was not leaked # The key is that the spurious text from Delta 3 was not leaked
...@@ -416,31 +414,15 @@ def test_split_markers_across_deltas(kimi_k2_tool_parser): ...@@ -416,31 +414,15 @@ def test_split_markers_across_deltas(kimi_k2_tool_parser):
"<|tool_calls_section_begin|>" "<|tool_calls_section_begin|>"
) )
# Delta 1: "...reasoning<|tool_calls_sec" # Delta 1: partial token, Delta 2: complete marker
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( deltas = [
previous_text="Some reasoning", ("<|tool_calls_sec", [3]),
current_text="Some reasoning<|tool_calls_sec", ("tion_begin|> ", [section_begin_token_id, 4]),
delta_text="<|tool_calls_sec", ]
previous_token_ids=[1, 2],
current_token_ids=[1, 2, 3], # Partial token _results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
delta_token_ids=[3],
request=None,
)
# Partial token not recognized yet, might be buffered
# Should return as content or None (depends on implementation)
# Delta 2: "tion_begin|> " (completes the marker)
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="Some reasoning<|tool_calls_sec",
current_text="Some reasoning<|tool_calls_section_begin|> ",
delta_text="tion_begin|> ",
previous_token_ids=[1, 2, 3],
current_token_ids=[1, 2, section_begin_token_id, 4],
delta_token_ids=[section_begin_token_id, 4],
request=None,
)
# Now the complete marker should be detected via buffer # Now the complete marker should be detected via buffer
# The parser should enter tool section mode
assert kimi_k2_tool_parser.in_tool_section is True assert kimi_k2_tool_parser.in_tool_section is True
...@@ -475,42 +457,17 @@ def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser): ...@@ -475,42 +457,17 @@ def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>") section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>") section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
# Enter tool section deltas = [
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming( ("<|tool_calls_section_begin|>", [section_begin_id]),
previous_text="", ("<|tool_calls_section_end|>", [section_end_id]),
current_text="<|tool_calls_section_begin|>", (" More reasoning", [10, 11]),
delta_text="<|tool_calls_section_begin|>", ]
previous_token_ids=[],
current_token_ids=[section_begin_id],
delta_token_ids=[section_begin_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is True
# Exit tool section results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="<|tool_calls_section_begin|>",
current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
delta_text="<|tool_calls_section_end|>",
previous_token_ids=[section_begin_id],
current_token_ids=[section_begin_id, section_end_id],
delta_token_ids=[section_end_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is False
# Subsequent reasoning text should be returned normally assert kimi_k2_tool_parser.in_tool_section is False
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( assert results[2] is not None
previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>", assert results[2].content == " More reasoning"
current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning",
delta_text=" More reasoning",
previous_token_ids=[section_begin_id, section_end_id],
current_token_ids=[section_begin_id, section_end_id, 10, 11],
delta_token_ids=[10, 11],
request=None,
)
assert result3 is not None
assert result3.content == " More reasoning"
def test_empty_tool_section(kimi_k2_tool_parser): def test_empty_tool_section(kimi_k2_tool_parser):
...@@ -819,106 +776,150 @@ def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser): ...@@ -819,106 +776,150 @@ def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>") tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
# Simulate a streaming sequence for a SHORT tool call (all in one chunk): # Simulate a streaming sequence for a SHORT tool call (all in one chunk):
# 1. Reasoning text
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text="Let me help. ",
delta_text="Let me help. ",
previous_token_ids=[],
current_token_ids=[1, 2],
delta_token_ids=[1, 2],
request=None,
)
assert result1 is not None
assert result1.content == "Let me help. "
# 2. Section begin
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
previous_text="Let me help. ",
current_text="Let me help. <|tool_calls_section_begin|>",
delta_text="<|tool_calls_section_begin|>",
previous_token_ids=[1, 2],
current_token_ids=[1, 2, section_begin_id],
delta_token_ids=[section_begin_id],
request=None,
)
assert kimi_k2_tool_parser.in_tool_section is True
# 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK
# This is the critical scenario for short tool calls
combined = ( combined = (
'<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} ' '<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
"<|tool_call_end|><|tool_calls_section_end|>" "<|tool_call_end|><|tool_calls_section_end|>"
) )
# Build up the previous text gradually to simulate realistic streaming deltas = [
prev_text = "Let me help. <|tool_calls_section_begin|>" ("Let me help. ", [1, 2]),
curr_text = prev_text + combined ("<|tool_calls_section_begin|>", [section_begin_id]),
(combined, [tool_begin_id, 10, 11, 12, tool_end_id, section_end_id]),
(" Done", [20]),
]
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming( results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
previous_text=prev_text,
current_text=curr_text,
delta_text=combined,
previous_token_ids=[1, 2, section_begin_id],
current_token_ids=[
1,
2,
section_begin_id,
tool_begin_id,
10,
11,
12,
tool_end_id,
section_end_id,
],
delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id],
request=None,
)
# CRITICAL: Parser should have exited section AFTER processing tool # CRITICAL: Parser should have exited section AFTER processing tool
assert kimi_k2_tool_parser.in_tool_section is False assert kimi_k2_tool_parser.in_tool_section is False
# Tool call should have been emitted (not dropped) # Tool call should have been emitted (not dropped)
# The result might be the tool name or None depending on state, but if results[2] is not None and results[2].content is not None:
# importantly, it shouldn't be returning the literal tokens as content
if result3 is not None and result3.content is not None:
# Verify no special tokens leaked into content # Verify no special tokens leaked into content
assert "<|tool_call_end|>" not in result3.content assert "<|tool_call_end|>" not in results[2].content
assert "<|tool_calls_section_end|>" not in result3.content assert "<|tool_calls_section_end|>" not in results[2].content
# 4. Verify subsequent content streams normally # Content after tool section should stream normally
result4 = kimi_k2_tool_parser.extract_tool_calls_streaming( assert results[3] is not None
previous_text=curr_text, assert results[3].content == " Done"
current_text=curr_text + " Done",
delta_text=" Done",
previous_token_ids=[ def test_streaming_tool_call_markers_not_leaked(kimi_k2_tool_parser):
1, """
2, CRITICAL TEST: Verify that tool call markers (<|tool_call_begin|>,
section_begin_id, <|tool_call_end|>, <|tool_call_argument_begin|>) are NOT leaked
tool_begin_id, into the content field during streaming.
10,
11, This reproduces the AWS Bedrock bug where tool call markers appeared
12, in the 'text' field of responses.
tool_end_id, """
section_end_id, kimi_k2_tool_parser.reset_streaming_state()
],
current_token_ids=[ section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
1, section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
2, tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
section_begin_id, tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
tool_begin_id,
10, # List of markers that should NEVER appear in content
11, forbidden_markers = [
12, "<|tool_call_begin|>",
tool_end_id, "<|tool_call_end|>",
section_end_id, "<|tool_call_argument_begin|>",
20, "<|tool_calls_section_begin|>",
], "<|tool_calls_section_end|>",
delta_token_ids=[20], ]
request=None,
all_content = []
# Steps: reasoning, section begin, tool call, section end, more reasoning
tool_chunk = (
"<|tool_call_begin|> functions.get_weather:0 "
'<|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
) )
deltas = [
("I'll check the weather. ", [1, 2, 3]),
("<|tool_calls_section_begin|>", [section_begin_id]),
(tool_chunk, [tool_begin_id, 10, 11, tool_end_id]),
("<|tool_calls_section_end|>", [section_end_id]),
(" Here's the result.", [20, 21]),
]
# Content after tool section should stream normally results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
assert result4 is not None
assert result4.content == " Done" for res in results:
if res and res.content:
all_content.append(res.content)
# CRITICAL ASSERTIONS: No forbidden markers in any content
full_content = "".join(all_content)
for marker in forbidden_markers:
assert marker not in full_content, (
f"MARKER LEAK DETECTED: '{marker}' found in content. "
f"Full content: {repr(full_content)}"
)
# Also check that tool call content (function name, arguments) is not leaked
assert "get_weather" not in full_content, (
f"TOOL CALL CONTENT LEAKED: 'get_weather' found in content. "
f"Full content: {repr(full_content)}"
)
assert "Tokyo" not in full_content, (
f"TOOL CALL CONTENT LEAKED: 'Tokyo' found in content. "
f"Full content: {repr(full_content)}"
)
# Verify that legitimate content was preserved
assert "I'll check the weather." in full_content or len(all_content) > 0
def test_streaming_multiple_tool_calls_not_leaked(kimi_k2_tool_parser):
    """
    Test that MULTIPLE tool calls in streaming mode do not leak into content.

    This reproduces the AWS Bedrock scenario: "Compare weather in Tokyo and NYC".

    The stream is: reasoning text, section begin, two complete tool calls
    (one per delta), section end, trailing text. Only the ``content`` field
    of each streaming result is inspected.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
    section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
    tool_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
    tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")

    # Each tool call arrives whole in a single delta; the second one carries
    # a leading space, i.e. there is a separator character between the calls.
    tool1 = (
        "<|tool_call_begin|> get_weather:0 "
        '<|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
    )
    tool2 = (
        " <|tool_call_begin|> get_weather:1 "
        '<|tool_call_argument_begin|> {"city": "New York"} <|tool_call_end|>'
    )

    deltas = [
        ("I'll compare the weather. ", [1, 2, 3]),
        ("<|tool_calls_section_begin|>", [section_begin_id]),
        (tool1, [tool_begin_id, 10, tool_end_id]),
        (tool2, [tool_begin_id, 20, tool_end_id]),
        ("<|tool_calls_section_end|>", [section_end_id]),
        (" Here's the comparison.", [30]),
    ]

    results = run_streaming_sequence(kimi_k2_tool_parser, deltas)

    # Keep only non-empty content; None results and empty strings are skipped.
    all_content = [res.content for res in results if res and res.content]
    full_content = "".join(all_content)

    # Check no markers leaked. The prefixes cover every begin/end variant.
    forbidden = ["<|tool_call", "<|tool_calls_section"]
    for marker in forbidden:
        assert marker not in full_content, (
            f"MARKER LEAKED: {marker} in {full_content!r}"
        )

    # Check no tool call content leaked (both tools)
    assert "get_weather" not in full_content, f"TOOL NAME LEAKED: {full_content!r}"
    assert "Tokyo" not in full_content, f"TOOL ARG LEAKED (Tokyo): {full_content!r}"
    assert "New York" not in full_content, (
        f"TOOL ARG LEAKED (NYC): {full_content!r}"
    )

    # Legitimate content preserved.
    # NOTE(review): the ``or len(all_content) > 0`` arm lets this pass whenever
    # any content at all was emitted; consider asserting the exact text once
    # the parser's output for the first and last deltas is confirmed.
    assert "compare" in full_content.lower() or len(all_content) > 0
...@@ -122,7 +122,6 @@ class KimiK2ToolParser(ToolParser): ...@@ -122,7 +122,6 @@ class KimiK2ToolParser(ToolParser):
if variant in cleaned: if variant in cleaned:
cleaned = cleaned.replace(variant, "") cleaned = cleaned.replace(variant, "")
found_end = True found_end = True
return cleaned, found_begin, found_end return cleaned, found_begin, found_end
def _reset_section_state(self) -> None: def _reset_section_state(self) -> None:
...@@ -238,6 +237,7 @@ class KimiK2ToolParser(ToolParser): ...@@ -238,6 +237,7 @@ class KimiK2ToolParser(ToolParser):
self.in_tool_section = True self.in_tool_section = True
self.token_buffer = buffered_text # Use cleaned buffer self.token_buffer = buffered_text # Use cleaned buffer
self.section_char_count = 0 # Reset counter for new section self.section_char_count = 0 # Reset counter for new section
if found_section_end and self.in_tool_section: if found_section_end and self.in_tool_section:
logger.debug("Detected section end marker") logger.debug("Detected section end marker")
# CRITICAL: Don't exit early if tool_call_end is in this chunk. # CRITICAL: Don't exit early if tool_call_end is in this chunk.
...@@ -252,13 +252,18 @@ class KimiK2ToolParser(ToolParser): ...@@ -252,13 +252,18 @@ class KimiK2ToolParser(ToolParser):
else: else:
# No tool call ending, safe to exit immediately # No tool call ending, safe to exit immediately
logger.debug("Exiting tool section") logger.debug("Exiting tool section")
remaining = buffered_text
self._reset_section_state() self._reset_section_state()
# Return remaining text as reasoning content if non-empty # Extract any content AFTER the section end marker in delta_text
if remaining.strip(): # (don't use buffered_text as it contains tool call data)
return DeltaMessage(content=remaining) post_section_content = ""
# Return empty delta to maintain function contract for variant in self.tool_calls_end_token_variants:
# (always returns DeltaMessage) if variant in delta_text:
parts = delta_text.split(variant, 1)
if len(parts) > 1:
post_section_content = parts[1]
break
if post_section_content.strip():
return DeltaMessage(content=post_section_content)
return DeltaMessage(content="") return DeltaMessage(content="")
else: else:
self.token_buffer = buffered_text self.token_buffer = buffered_text
...@@ -316,12 +321,12 @@ class KimiK2ToolParser(ToolParser): ...@@ -316,12 +321,12 @@ class KimiK2ToolParser(ToolParser):
and prev_tool_end_count == cur_tool_end_count and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text and self.tool_call_end_token not in delta_text
): ):
# CRITICAL FIX: Suppress content if in tool section but # Suppress content between section begin and first tool begin
# no tool calls started # (header noise). Don't suppress content between tools to avoid
# breaking potential delimiter characters.
if self.in_tool_section and cur_tool_start_count == 0: if self.in_tool_section and cur_tool_start_count == 0:
logger.debug( logger.debug(
"In tool section but no tool calls started yet. " "In tool section before first tool, suppressing: %s",
"Suppressing: %s",
delta_text, delta_text,
) )
# Return empty delta to maintain iterator contract # Return empty delta to maintain iterator contract
...@@ -488,6 +493,9 @@ class KimiK2ToolParser(ToolParser): ...@@ -488,6 +493,9 @@ class KimiK2ToolParser(ToolParser):
if tool_call_portion is None: if tool_call_portion is None:
# if there's text but not tool calls, send that - # if there's text but not tool calls, send that -
# otherwise None to skip chunk # otherwise None to skip chunk
# CRITICAL: Never return content if we're in a tool section
if self.in_tool_section:
return None
delta = ( delta = (
DeltaMessage(content=delta_text) DeltaMessage(content=delta_text)
if text_portion is not None if text_portion is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment