Unverified Commit 13151a4d authored by Flora Feng's avatar Flora Feng Committed by GitHub
Browse files

[Bugfix] Fix Gemma4 streaming tool call corruption for split boolean/number values (#39114)


Signed-off-by: default avatarsfeng33 <4florafeng@gmail.com>
parent 56c976c1
...@@ -491,6 +491,51 @@ class TestStreamingExtraction: ...@@ -491,6 +491,51 @@ class TestStreamingExtraction:
assert parsed_args["count"] == 42 assert parsed_args["count"] == 42
assert parsed_args["active"] is True assert parsed_args["active"] is True
def test_streaming_boolean_split_across_chunks(self, parser, mock_request):
"""Boolean value split across token boundaries must not corrupt JSON."""
chunks = [
"<|tool_call>",
"call:search{input:{all:" + "true"[:3],
"e}}",
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
assert args_text, "No arguments were streamed"
parsed_args = json.loads(args_text)
assert parsed_args["input"]["all"] is True
def test_streaming_false_split_across_chunks(self, parser, mock_request):
"""Boolean false split across chunks."""
chunks = [
"<|tool_call>",
"call:set{flag:" + "false"[:4],
"e}",
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
assert args_text, "No arguments were streamed"
parsed_args = json.loads(args_text)
assert parsed_args["flag"] is False
def test_streaming_number_split_across_chunks(self, parser, mock_request):
"""Number split across chunks must not change type."""
chunks = [
"<|tool_call>",
"call:set{count:4",
"2}",
"<tool_call|>",
]
results = self._simulate_streaming(parser, mock_request, chunks)
args_text = self._collect_arguments(results)
assert args_text, "No arguments were streamed"
parsed_args = json.loads(args_text)
assert parsed_args["count"] == 42
def test_streaming_empty_args(self, parser, mock_request): def test_streaming_empty_args(self, parser, mock_request):
"""Tool call with no arguments.""" """Tool call with no arguments."""
chunks = [ chunks = [
......
...@@ -78,7 +78,7 @@ def _parse_gemma4_value(value_str: str) -> object: ...@@ -78,7 +78,7 @@ def _parse_gemma4_value(value_str: str) -> object:
return value_str return value_str
def _parse_gemma4_args(args_str: str) -> dict: def _parse_gemma4_args(args_str: str, *, partial: bool = False) -> dict:
"""Parse Gemma4's custom key:value format into a Python dict. """Parse Gemma4's custom key:value format into a Python dict.
Format examples:: Format examples::
...@@ -89,6 +89,12 @@ def _parse_gemma4_args(args_str: str) -> dict: ...@@ -89,6 +89,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
nested:{inner_key:<|"|>val<|"|>} nested:{inner_key:<|"|>val<|"|>}
items:[<|"|>a<|"|>,<|"|>b<|"|>] items:[<|"|>a<|"|>,<|"|>b<|"|>]
Args:
args_str: The raw Gemma4 argument string.
partial: When True (streaming), bare values at end of string are
omitted because they may be incomplete and type-unstable
(e.g. partial boolean parsed as bare string).
Returns a dict ready for ``json.dumps()``. Returns a dict ready for ``json.dumps()``.
""" """
if not args_str or not args_str.strip(): if not args_str or not args_str.strip():
...@@ -155,7 +161,12 @@ def _parse_gemma4_args(args_str: str) -> dict: ...@@ -155,7 +161,12 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif args_str[i] == "}": elif args_str[i] == "}":
depth -= 1 depth -= 1
i += 1 i += 1
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1]) if depth > 0:
# Incomplete nested object — use i (not i-1) to avoid
# dropping the last char, and recurse as partial.
result[key] = _parse_gemma4_args(args_str[obj_start:i], partial=True)
else:
result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])
# Array: [...] # Array: [...]
elif args_str[i] == "[": elif args_str[i] == "[":
...@@ -173,20 +184,26 @@ def _parse_gemma4_args(args_str: str) -> dict: ...@@ -173,20 +184,26 @@ def _parse_gemma4_args(args_str: str) -> dict:
elif args_str[i] == "]": elif args_str[i] == "]":
depth -= 1 depth -= 1
i += 1 i += 1
arr_content = args_str[arr_start : i - 1] if depth > 0:
result[key] = _parse_gemma4_array(arr_content) result[key] = _parse_gemma4_array(args_str[arr_start:i], partial=True)
else:
result[key] = _parse_gemma4_array(args_str[arr_start : i - 1])
# Bare value (number, boolean, etc.) # Bare value (number, boolean, etc.)
else: else:
val_start = i val_start = i
while i < n and args_str[i] not in (",", "}", "]"): while i < n and args_str[i] not in (",", "}", "]"):
i += 1 i += 1
if partial and i >= n:
# Value may be incomplete (e.g. partial boolean) —
# withhold to avoid type instability during streaming.
break
result[key] = _parse_gemma4_value(args_str[val_start:i]) result[key] = _parse_gemma4_value(args_str[val_start:i])
return result return result
def _parse_gemma4_array(arr_str: str) -> list: def _parse_gemma4_array(arr_str: str, *, partial: bool = False) -> list:
"""Parse a Gemma4 array content string into a Python list.""" """Parse a Gemma4 array content string into a Python list."""
items: list = [] items: list = []
i = 0 i = 0
...@@ -224,7 +241,10 @@ def _parse_gemma4_array(arr_str: str) -> list: ...@@ -224,7 +241,10 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif arr_str[i] == "}": elif arr_str[i] == "}":
depth -= 1 depth -= 1
i += 1 i += 1
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1])) if depth > 0:
items.append(_parse_gemma4_args(arr_str[obj_start:i], partial=True))
else:
items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))
# Nested array # Nested array
elif arr_str[i] == "[": elif arr_str[i] == "[":
...@@ -237,13 +257,18 @@ def _parse_gemma4_array(arr_str: str) -> list: ...@@ -237,13 +257,18 @@ def _parse_gemma4_array(arr_str: str) -> list:
elif arr_str[i] == "]": elif arr_str[i] == "]":
depth -= 1 depth -= 1
i += 1 i += 1
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1])) if depth > 0:
items.append(_parse_gemma4_array(arr_str[sub_start:i], partial=True))
else:
items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))
# Bare value # Bare value
else: else:
val_start = i val_start = i
while i < n and arr_str[i] not in (",", "]"): while i < n and arr_str[i] not in (",", "]"):
i += 1 i += 1
if partial and i >= n:
break
items.append(_parse_gemma4_value(arr_str[val_start:i])) items.append(_parse_gemma4_value(arr_str[val_start:i]))
return items return items
...@@ -663,7 +688,7 @@ class Gemma4ToolParser(ToolParser): ...@@ -663,7 +688,7 @@ class Gemma4ToolParser(ToolParser):
DeltaMessage with the argument diff, or None if no new content. DeltaMessage with the argument diff, or None if no new content.
""" """
try: try:
current_args = _parse_gemma4_args(raw_args_str) current_args = _parse_gemma4_args(raw_args_str, partial=True)
except Exception: except Exception:
logger.debug( logger.debug(
"Could not parse partial Gemma4 args yet: %s", "Could not parse partial Gemma4 args yet: %s",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment