Unverified commit c798593f, authored by Chauncey and committed by GitHub
Browse files

[Bugfix] Fix the DSML token leakage in DSV4/3.2 (#40806)


Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
Signed-off-by: sfeng33 <4florafeng@gmail.com>
Co-authored-by: sfeng33 <4florafeng@gmail.com>
Co-authored-by: Windswithyou <1694599440@qq.com>
parent 12a3f645
...@@ -484,6 +484,58 @@ class TestExtractToolCallsStreaming: ...@@ -484,6 +484,58 @@ class TestExtractToolCallsStreaming:
# Should have no tool call deltas yet # Should have no tool call deltas yet
assert all(not d.tool_calls for d in deltas) assert all(not d.tool_calls for d in deltas)
def test_no_marker_leak_chunked(self, parser):
    """A bare tool call streamed in 5-character chunks must yield no
    content, i.e. no DSML start-marker fragment may leak (GitHub #40801)."""
    payload = build_tool_call("fn", {"k": "v"})
    stream = self._stream_chunked(parser, payload, chunk_size=5)

    # Collect every non-None content piece the parser emitted.
    emitted = [delta.content for delta in stream if delta.content is not None]
    assert "".join(emitted) == ""

    # The tool-call arguments must still be recoverable from the deltas.
    assert json.loads(self._reconstruct_args(stream)) == {"k": "v"}
def test_no_marker_leak_with_prefix_chunked(self, parser):
    """Plain text preceding a tool call must be emitted without any
    start-marker fragments when streamed in chunks (GitHub #40801)."""
    payload = "Hello!" + build_tool_call("fn", {"a": "b"})
    stream = self._stream_chunked(parser, payload, chunk_size=5)

    # Concatenate all non-None content pieces in emission order.
    pieces = [delta.content for delta in stream if delta.content is not None]
    text = "".join(pieces)

    # Only the prefix is content; nothing marker-like may appear in it.
    assert text == "Hello!"
    assert "DSML" not in text
    assert "<|" not in text

    # The tool-call arguments must still be recoverable from the deltas.
    assert json.loads(self._reconstruct_args(stream)) == {"a": "b"}
def test_no_marker_leak_char_by_char(self, parser):
    """Streaming one character at a time must not leak any marker
    fragment as content (GitHub #40801)."""
    payload = build_tool_call("fn", {"k": "v"})
    stream = self._stream_chunked(parser, payload, chunk_size=1)

    # Collect every non-None content piece the parser emitted.
    emitted = [delta.content for delta in stream if delta.content is not None]
    assert "".join(emitted) == ""

    # The tool-call arguments must still be recoverable from the deltas.
    assert json.loads(self._reconstruct_args(stream)) == {"k": "v"}
def test_no_marker_leak_all_split_points(self, parser):
    """Whatever chunk size is used to split the start token, none of
    its fragments may be emitted as content (GitHub #40801)."""
    # Chunk sizes from 1 up to and including len(FC_START) + 1.
    largest_chunk = len(FC_START) + 1
    for chunk_size in range(1, largest_chunk + 1):
        # A fresh parser per chunk size, so iterations do not share state.
        p = make_parser()
        full_text = build_tool_call("fn", {"k": "v"})
        deltas = self._stream_chunked(p, full_text, chunk_size=chunk_size)
        pieces = [d.content for d in deltas if d.content is not None]
        content = "".join(pieces)
        assert content == "", (
            f"Leaked content {content!r} at chunk_size={chunk_size}"
        )
def test_false_partial_marker_emitted(self, parser):
    """Text that begins like the start token but never completes it
    is ordinary content and must be emitted in full."""
    payload = "<|DSM some regular text"
    stream = self._stream_chunked(parser, payload, chunk_size=3)

    # Everything streamed in must come back out as content, unchanged.
    pieces = [delta.content for delta in stream if delta.content is not None]
    assert "".join(pieces) == payload
class TestDelimiterPreservation: class TestDelimiterPreservation:
"""Regression: fast detokenization skipping DSML delimiters (PR #33964).""" """Regression: fast detokenization skipping DSML delimiters (PR #33964)."""
......
...@@ -26,6 +26,7 @@ from vllm.tool_parsers.abstract_tool_parser import ( ...@@ -26,6 +26,7 @@ from vllm.tool_parsers.abstract_tool_parser import (
Tool, Tool,
ToolParser, ToolParser,
) )
from vllm.tool_parsers.utils import partial_tag_overlap
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -54,8 +55,8 @@ class DeepSeekV32ToolParser(ToolParser): ...@@ -54,8 +55,8 @@ class DeepSeekV32ToolParser(ToolParser):
self.tool_call_start_token: str = "<|DSML|function_calls>" self.tool_call_start_token: str = "<|DSML|function_calls>"
# Streaming state # Streaming state
self.is_tool_call_started: bool = False
self.current_tool_index: int = 0 self.current_tool_index: int = 0
self._sent_content_idx: int = 0
# Regex patterns for complete parsing # Regex patterns for complete parsing
self.tool_call_complete_regex = re.compile( self.tool_call_complete_regex = re.compile(
...@@ -219,7 +220,7 @@ class DeepSeekV32ToolParser(ToolParser): ...@@ -219,7 +220,7 @@ class DeepSeekV32ToolParser(ToolParser):
def _reset_streaming_state(self): def _reset_streaming_state(self):
"""Reset all streaming state.""" """Reset all streaming state."""
self.current_tool_index = 0 self.current_tool_index = 0
self.is_tool_call_started = False self._sent_content_idx = 0
self.prev_tool_call_arr.clear() self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear() self.streamed_args_for_tool.clear()
...@@ -264,6 +265,24 @@ class DeepSeekV32ToolParser(ToolParser): ...@@ -264,6 +265,24 @@ class DeepSeekV32ToolParser(ToolParser):
return delta_tool_calls return delta_tool_calls
def _extract_content(self, current_text: str) -> str | None:
    """Return the not-yet-sent plain text that precedes any tool call.

    Text is sendable up to the first occurrence of the tool-call start
    token. While that token is absent, the tail of ``current_text``
    reported by ``partial_tag_overlap`` is withheld so that a start
    marker split across chunks is never emitted as content.

    Args:
        current_text: The full text accumulated so far.

    Returns:
        The new content slice, or ``None`` when nothing new is sendable.
        ``self._sent_content_idx`` is advanced past whatever is returned.
    """
    marker = self.tool_call_start_token

    # One scan: position of the complete marker, or -1 if it is absent.
    cutoff = current_text.find(marker)
    if cutoff == -1:
        # No complete marker yet; withhold the overlapping tail
        # (presumably a partial marker prefix, per partial_tag_overlap).
        cutoff = len(current_text) - partial_tag_overlap(current_text, marker)

    already_sent = self._sent_content_idx
    if cutoff <= already_sent:
        return None

    self._sent_content_idx = cutoff
    return current_text[already_sent:cutoff]
def extract_tool_calls_streaming( def extract_tool_calls_streaming(
self, self,
previous_text: str, previous_text: str,
...@@ -285,29 +304,11 @@ class DeepSeekV32ToolParser(ToolParser): ...@@ -285,29 +304,11 @@ class DeepSeekV32ToolParser(ToolParser):
if not previous_text: if not previous_text:
self._reset_streaming_state() self._reset_streaming_state()
# Detect whether we've entered the tool-call region. content = self._extract_content(current_text)
# Use current_text (not delta_text) since the start token may
# be split across chunks.
content_before = None
if self.is_tool_call_started:
pass
elif self.tool_call_start_token in current_text:
# Tool-call region found, capture any plain text before it.
self.is_tool_call_started = True
start_idx = current_text.index(self.tool_call_start_token)
content_before = current_text[len(previous_text) : start_idx] or None
else:
# Still in plain-text region, forward as content.
return DeltaMessage(content=delta_text) if delta_text else None
# Inside tool-call region: emit any newly completed invokes.
delta_tool_calls = self._extract_delta_tool_calls(current_text, request) delta_tool_calls = self._extract_delta_tool_calls(current_text, request)
if delta_tool_calls or content_before: if delta_tool_calls or content:
return DeltaMessage( return DeltaMessage(content=content, tool_calls=delta_tool_calls)
content=content_before,
tool_calls=delta_tool_calls,
)
# Empty delta with token ids means EOS or closing tag; return # Empty delta with token ids means EOS or closing tag; return
# non-None so the serving framework can finalize finish_reason. # non-None so the serving framework can finalize finish_reason.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment