# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
)
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
)
logger = init_logger(__name__)
class Step3p5ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
# Override base class type - we use string IDs for tool calls
self.current_tool_id: str | None = None # type: ignore
self.streamed_args_for_tool: list[str] = []
# Sentinel tokens for streaming mode
self.tool_call_start_token: str = ""
self.tool_call_end_token: str = ""
self.tool_call_prefix: str = "(.*?)", re.DOTALL
)
self.tool_call_function_regex = re.compile(
r"", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"", re.DOTALL
)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction."
)
self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if self.tool_call_start_token_id is None or self.tool_call_end_token_id is None:
raise RuntimeError(
"Step3p5 RL Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
# Get EOS token ID for EOS detection
self.eos_token_id = getattr(self.model_tokenizer, "eos_token_id", None)
logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def parser_should_check_for_unstreamed_tool_arg_tokens(self) -> bool:
"""
Skip the remaining_call calculation in serving
"""
return False
def _reset_streaming_state(self):
"""Reset all streaming state for a new request."""
self._processed_length: int = 0 # Position of last processed character
self._tool_call_index: int = 0 # Number of tool calls processed so far
self.streaming_request = None # Current request being processed
def _get_arguments_config(
self, func_name: str, tools: list[ChatCompletionToolsParam] | None
) -> dict:
"""Extract argument configuration for a function."""
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (
hasattr(config, "function") and hasattr(config.function, "name")
):
continue
if config.type == "function" and config.function.name == func_name:
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.warning("Tool '%s' is not defined in the tools list.", func_name)
return {}
def _convert_param_value(
self, param_value: str, param_name: str, param_config: dict, func_name: str
) -> Any:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if param_value.lower() == "null":
return None
if param_name not in param_config:
if param_config != {}:
logger.warning(
"Parsed parameter '%s' is not defined in the tool "
"parameters for tool '%s', directly returning the "
"string value.",
param_name,
func_name,
)
return param_value
if (
isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]
):
param_type = str(param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
return param_value
elif (
param_type.startswith("int")
or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")
):
try:
return int(param_value)
except (ValueError, TypeError):
try:
float_value = float(param_value)
if float_value.is_integer():
return int(float_value)
except (ValueError, TypeError):
pass
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, bool):
return int(literal_value)
if isinstance(literal_value, (int, float)):
return (
int(literal_value)
if float(literal_value).is_integer()
else literal_value
)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not an integer "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
try:
float_param_value = float(param_value)
return (
float_param_value
if float_param_value - int(float_param_value) != 0
else int(float_param_value)
)
except (ValueError, TypeError):
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, (int, float)):
return (
float(literal_value)
if float(literal_value) - int(float(literal_value)) != 0
else int(float(literal_value))
)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
normalized_value = param_value.strip().lower()
if normalized_value in ["true", "false"]:
return normalized_value == "true"
if normalized_value in ["1", "0"]:
return normalized_value == "1"
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, bool):
return literal_value
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' is not a boolean "
"in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
else:
if (
param_type in ["object", "array", "arr"]
or param_type.startswith("dict")
or param_type.startswith("list")
):
try:
param_value = json.loads(param_value)
return param_value
except (json.JSONDecodeError, TypeError, ValueError):
try:
literal_value = ast.literal_eval(param_value)
if isinstance(literal_value, (list, dict)):
return literal_value
if isinstance(literal_value, (tuple, set)):
return list(literal_value)
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be parsed "
"as JSON in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
try:
literal_value = ast.literal_eval(param_value) # safer
if isinstance(literal_value, (tuple, set)):
return list(literal_value)
if (
isinstance(literal_value, (list, dict, str, int, float, bool))
or literal_value is None
):
return literal_value
except (ValueError, SyntaxError, TypeError):
pass
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', returning raw string.",
param_value,
param_name,
func_name,
)
return param_value
def _parse_parameters_fallback(
self,
parameters: str,
allowed_param_names: set[str] | None = None,
) -> list[tuple[str, str]]:
"""Fallback parser for malformed parameter tags."""
param_pairs: list[tuple[str, str]] = []
pos = 0
while True:
start = parameters.find(self.parameter_prefix, pos)
if start == -1:
break
name_start = start + len(self.parameter_prefix)
name_end = parameters.find(">", name_start)
if name_end == -1:
newline_idx = parameters.find("\n", name_start)
end_tag = parameters.find(self.parameter_end_token, name_start)
next_param = parameters.find(self.parameter_prefix, name_start)
candidates = [
idx for idx in [newline_idx, end_tag, next_param] if idx != -1
]
if not candidates:
break
name_end = min(candidates)
value_start = name_end
else:
value_start = name_end + 1
param_name = parameters[name_start:name_end].strip()
next_param = parameters.find(self.parameter_prefix, value_start)
end_tag = parameters.find(self.parameter_end_token, value_start)
if end_tag == -1 or (next_param != -1 and next_param < end_tag):
end = next_param if next_param != -1 else len(parameters)
pos = end
else:
end = end_tag
pos = end + len(self.parameter_end_token)
param_value = parameters[value_start:end]
if allowed_param_names is None or param_name in allowed_param_names:
param_pairs.append((param_name, param_value))
return param_pairs
def _is_valid_json_arguments(self, arguments: str) -> bool:
"""Check if arguments can be loaded as JSON."""
try:
json.loads(arguments)
except Exception:
return False
return True
def _parse_xml_function_call(
self, function_call_str: str, tools: list[ChatCompletionToolsParam] | None
) -> ToolCall | None:
# Extract function name
end_index = function_call_str.index(">")
# check empty function name
function_name = function_call_str[:end_index].strip()
if function_name.startswith("="):
function_name = function_name.lstrip("=").strip()
if not function_name or function_name.strip("'\"") == "":
logger.warning("Empty function name in tool call.")
return None
if function_name[0] in "\"'" and function_name[-1] == function_name[0]:
function_name = function_name[1:-1].strip()
if not function_name:
logger.warning("Empty function name in tool call.")
return None
param_config = self._get_arguments_config(function_name, tools)
parameters = function_call_str[end_index + 1 :]
param_dict = {}
match_texts = self.tool_call_parameter_regex.findall(parameters)
use_fallback = False
if match_texts:
for match_text in match_texts:
if self.parameter_prefix in match_text or ">" not in match_text:
use_fallback = True
break
else:
use_fallback = self.parameter_prefix in parameters
if use_fallback:
allowed_param_names = (
set(param_config.keys())
if isinstance(param_config, dict) and param_config
else None
)
param_pairs = self._parse_parameters_fallback(
parameters, allowed_param_names
)
else:
param_pairs = []
for match_text in match_texts:
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1 :])
param_pairs.append((param_name, param_value))
for param_name, param_value in param_pairs:
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = self._convert_param_value(
param_value, param_name, param_config, function_name
)
try:
arguments = json.dumps(param_dict, ensure_ascii=False)
except Exception as e:
logger.warning("Error in converting parameter value: %s", e)
return None
return ToolCall(
type="function",
function=FunctionCall(name=function_name, arguments=arguments),
)
def _get_function_calls(self, model_output: str) -> list[str]:
# Find all tool calls
raw_tool_calls = self.tool_call_complete_regex.findall(model_output)
# if no closed tool_call tags found, return empty list
if len(raw_tool_calls) == 0:
return []
raw_function_calls = []
for tool_call in raw_tool_calls:
function_matches = self.tool_call_function_regex.findall(tool_call)
raw_function_calls.extend(function_matches)
return raw_function_calls
def _check_format(self, model_output: str) -> bool:
"""Check if model output contains properly formatted tool call.
Requirements:
1. Must have closed tool_call tags (...)
2. Must have closed function tags ()
3. If parameter tags exist, they must be closed and correct
Returns True if the format is valid, False otherwise.
"""
# Check 1: Must have closed tool_call tags
tool_call_matches = self.tool_call_complete_regex.findall(model_output)
if len(tool_call_matches) == 0:
return False
# Check 2: Must have closed function tags within tool_call
has_valid_function = False
for tool_call_content in tool_call_matches:
function_matches = self.tool_call_function_regex.findall(tool_call_content)
if len(function_matches) > 0:
has_valid_function = True
# Check if there's an unclosed function tag
if (
self.tool_call_prefix in tool_call_content
and self.function_end_token not in tool_call_content
):
return False
if not has_valid_function:
return False
# Check 3: If parameter tags exist, they must be closed and correct
for tool_call_content in tool_call_matches:
# Count opening and closing parameter tags
param_open_count = tool_call_content.count(self.parameter_prefix)
param_close_count = tool_call_content.count(self.parameter_end_token)
# If there are parameter tags, they must be balanced
if param_open_count > 0:
if param_open_count != param_close_count:
return False
# Check if all parameter tags are properly closed using regex
param_matches = self.tool_call_parameter_regex.findall(
tool_call_content
)
if len(param_matches) != param_open_count:
return False
return True
def _wrap_missing_tool_call_tags(self, model_output: str) -> str:
"""Wrap bare blocks with tags."""
if (
self.tool_call_prefix not in model_output
or self.function_end_token not in model_output
):
return model_output
def _wrap_bare_functions(text: str) -> str:
pos = 0
wrapped_parts: list[str] = []
while True:
func_idx = text.find(self.tool_call_prefix, pos)
if func_idx == -1:
wrapped_parts.append(text[pos:])
break
end_idx = text.find(self.function_end_token, func_idx)
if end_idx == -1:
wrapped_parts.append(text[pos:])
break
end_idx += len(self.function_end_token)
wrapped_parts.append(text[pos:func_idx])
wrapped_parts.append(self.tool_call_start_token)
wrapped_parts.append(text[func_idx:end_idx])
wrapped_parts.append(self.tool_call_end_token)
ws_idx = end_idx
while ws_idx < len(text) and text[ws_idx].isspace():
ws_idx += 1
if text.startswith(self.tool_call_end_token, ws_idx):
if ws_idx > end_idx:
wrapped_parts.append(text[end_idx:ws_idx])
pos = ws_idx + len(self.tool_call_end_token)
else:
pos = end_idx
return "".join(wrapped_parts)
tool_call_ranges = [
match.span()
for match in self.tool_call_complete_regex.finditer(model_output)
]
if not tool_call_ranges:
return _wrap_bare_functions(model_output)
wrapped_parts: list[str] = []
pos = 0
for start, end in tool_call_ranges:
if start < pos:
continue
wrapped_parts.append(_wrap_bare_functions(model_output[pos:start]))
wrapped_parts.append(model_output[start:end])
pos = end
wrapped_parts.append(_wrap_bare_functions(model_output[pos:]))
return "".join(wrapped_parts)
def _normalize_prev_arguments(self, args_value: Any) -> Any:
if isinstance(args_value, str):
try:
return json.loads(args_value)
except (TypeError, ValueError, json.JSONDecodeError):
return args_value
return args_value
def _update_prev_tool_call_state(self, tool_calls: list[ToolCall]) -> None:
self.prev_tool_call_arr.clear()
self.streamed_args_for_tool.clear()
for tool_call in tool_calls:
if not tool_call or not tool_call.function:
continue
args_value = tool_call.function.arguments
if isinstance(args_value, str):
args_json = args_value
elif args_value is None:
args_json = ""
else:
try:
args_json = json.dumps(args_value, ensure_ascii=False)
except (TypeError, ValueError):
args_json = str(args_value)
prev_args = self._normalize_prev_arguments(args_json)
self.prev_tool_call_arr.append(
{
"name": tool_call.function.name,
"arguments": prev_args,
}
)
try:
expected_args_json = json.dumps(prev_args, ensure_ascii=False)
except (TypeError, ValueError):
expected_args_json = args_json
# Serving may subtract the latest delta length from
# streamed_args_for_tool to detect unstreamed suffixes. Since this
# parser emits full arguments at once, store expected+actual so
# the subtraction yields expected_args_json and no resend occurs.
self.streamed_args_for_tool.append(expected_args_json + args_json)
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
try:
origin_model_output = model_output
try:
# Fallback: handle outputs without wrapper.
origin_model_output = self._wrap_missing_tool_call_tags(
origin_model_output
)
model_output = origin_model_output
except Exception:
pass
# Use streaming-like approach: process position by position
valid_tool_calls = []
content_parts = []
processed_length = 0
while processed_length < len(model_output):
# Find next tool call start
tool_start_idx = self._find_tool_call_start(
model_output, processed_length
)
# Case 1: No more tool calls - add remaining as content
if tool_start_idx == -1:
remaining = model_output[processed_length:]
if remaining:
content_parts.append(remaining)
break
# Case 2: Content before tool call
if tool_start_idx > processed_length:
content_before = model_output[processed_length:tool_start_idx]
# Skip whitespace-only content between tool calls
# Check if we just ended a tool call and this is pure whitespace
if processed_length > 0:
text_before = model_output[:processed_length]
if (
text_before.rstrip().endswith(self.tool_call_end_token)
and content_before.strip() == ""
):
# Skip whitespace between tool calls
pass
else:
content_parts.append(content_before)
else:
content_parts.append(content_before)
# Case 3: Try to find complete tool call
tool_end_idx = self._find_first_complete_tool_call_end(
model_output, tool_start_idx
)
# If tool call is incomplete - add remaining as content and stop
if tool_end_idx == -1:
remaining = model_output[tool_start_idx:]
if remaining:
content_parts.append(remaining)
break
# Extract and try to parse the complete tool call
tool_call_text = model_output[tool_start_idx:tool_end_idx]
parsed_result = self.extract_tool_calls_basic(tool_call_text, request)
# If parsing succeeded, record the tool call(s)
if parsed_result.tools_called and parsed_result.tool_calls:
valid_tool_calls.extend(parsed_result.tool_calls)
processed_length = tool_end_idx
else:
# Parsing failed - treat this tool call as content
content_parts.append(tool_call_text)
processed_length = tool_end_idx
# Populate prev_tool_call_arr for serving layer to set finish_reason
self._update_prev_tool_call_state(valid_tool_calls)
# Combine content parts
content = "".join(content_parts) if content_parts else None
return ExtractedToolCallInformation(
tools_called=(len(valid_tool_calls) > 0),
tool_calls=valid_tool_calls,
content=content if content else None,
)
except Exception:
logger.warning("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def extract_tool_calls_basic(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
model_output = self._wrap_missing_tool_call_tags(model_output)
# Quick check to avoid unnecessary processing
if not self._check_format(model_output):
tool_call_matches = self.tool_call_complete_regex.findall(model_output)
if len(tool_call_matches) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
function_calls = self._get_function_calls(model_output)
if len(function_calls) == 0:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
tool_calls: list[ToolCall] = []
for function_call_str in function_calls:
tool_call = self._parse_xml_function_call(
function_call_str, request.tools
)
if tool_call:
tool_calls.append(tool_call)
if not tool_calls:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
for tool_call in tool_calls:
if (
not tool_call.function
or tool_call.function.arguments is None
or not self._is_valid_json_arguments(tool_call.function.arguments)
):
logger.warning(
"Invalid JSON arguments in tool call, falling back to content."
)
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
# Populate prev_tool_call_arr for serving layer to set finish_reason
self._update_prev_tool_call_state(tool_calls)
# Extract content before tool calls
content_index = model_output.find(self.tool_call_start_token)
content = model_output[:content_index] # .rstrip()
return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0),
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.warning("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
def _find_first_complete_tool_call_end(self, text: str, start_pos: int = 0) -> int:
"""Find the end position of the first complete tool call.
Args:
text: Text to search in
start_pos: Position to start searching from
Returns:
Position after the first tag, or -1 if incomplete
Example:
"......" returns position after
"""
# Find tool call start
start_idx = text.find(self.tool_call_start_token, start_pos)
if start_idx == -1:
return -1
# Find matching end token
end_idx = text.find(
self.tool_call_end_token, start_idx + len(self.tool_call_start_token)
)
if end_idx == -1:
return -1 # Incomplete tool call
# Return position after end token
return end_idx + len(self.tool_call_end_token)
def _find_tool_call_start(self, text: str, start_pos: int = 0) -> int:
"""Find the start position of next tool call.
Args:
text: Text to search in
start_pos: Position to start searching from
Returns:
Position of token, or -1 if not found
"""
return text.find(self.tool_call_start_token, start_pos)
def _extract_content_between_tool_calls_list(self, text: str) -> list[str]:
"""Extract content segments after each tool call.
For n tool calls, returns n segments where segment[i] is the content
after tool_call[i] (before tool_call[i+1] or at the end).
Empty or whitespace-only segments are represented as empty string "".
Args:
text: Text containing tool calls
Returns:
List of content segments (one per tool call)
"""
content_segments = []
pos = 0
while True:
# Find end of current tool call
end_pos = text.find(self.tool_call_end_token, pos)
if end_pos == -1:
break
# Move past the end token
end_pos += len(self.tool_call_end_token)
# Find start of next tool call
next_start = self._find_tool_call_start(text, end_pos)
# Extract content between current end and next start (or text end)
content = text[end_pos:next_start] if next_start != -1 else text[end_pos:]
# Store content (empty string if whitespace-only)
content_segments.append(content if content.strip() else "")
if next_start == -1:
break
pos = next_start
return content_segments
def _convert_tool_calls_to_deltas(
self, tool_calls: list[ToolCall], starting_index: int = 0
) -> list[DeltaToolCall]:
"""Convert complete ToolCall list to DeltaToolCall list.
Returns complete tool calls without splitting into fragments.
Args:
tool_calls: List of tool calls to convert
starting_index: Starting index for tool calls (default 0)
Returns:
List of DeltaToolCall with complete arguments
"""
delta_tool_calls = []
for i, tool_call in enumerate[ToolCall](tool_calls):
index = starting_index + i
tool_id = self._generate_tool_call_id()
# Create complete DeltaToolCall with full arguments
delta_tool_calls.append(
DeltaToolCall(
index=index,
id=tool_id,
function=DeltaFunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
),
type="function",
)
)
return delta_tool_calls
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming text using complete parsing.
Strategy:
1. Accumulate text in buffer and track processed position
2. In each iteration, try to extract content or complete tool calls
3. Parse complete tool calls using non-streaming method
4. Convert parsed results to delta sequence
5. Handle EOS token to flush incomplete tool calls as content
"""
# Initialize state for new request
if not previous_text:
self._reset_streaming_state()
self.streaming_request = request
# Check for EOS token
has_eos = (
self.eos_token_id is not None
and delta_token_ids
and self.eos_token_id in delta_token_ids
)
# If no delta text, check if we need to return empty delta for finish_reason
if not delta_text and not has_eos:
# Check if this is an EOS token after all tool calls are complete
if delta_token_ids and self.tool_call_end_token_id not in delta_token_ids:
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text)
)
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
# Return empty delta for finish_reason processing
return DeltaMessage(content="")
return None
# Process all available content
accumulated_deltas: list[DeltaMessage] = []
while self._has_unprocessed_content(current_text):
# Try to process next chunk (content or tool call)
delta = self._process_next_chunk(current_text)
if delta is None:
# Cannot proceed further, need more tokens
break
# Accumulate deltas
if isinstance(delta, list):
accumulated_deltas.extend(delta)
else:
accumulated_deltas.append(delta)
# Handle EOS: flush any remaining incomplete tool calls as content
if has_eos:
remaining_delta = self._flush_remaining_content(current_text)
if remaining_delta:
accumulated_deltas.append(remaining_delta)
# If no remaining content but we have tool calls, return empty delta
elif len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token
) - current_text.count(self.tool_call_end_token)
if open_calls == 0:
accumulated_deltas.append(DeltaMessage(content=""))
# Return results
return self._format_delta_result(accumulated_deltas)
def _has_unprocessed_content(self, current_text: str) -> bool:
"""Check if there's unprocessed content in the buffer."""
return self._processed_length < len(current_text)
def _process_next_chunk(
self, current_text: str
) -> DeltaMessage | list[DeltaMessage] | None:
"""Process next chunk: either regular content or a complete tool call.
Args:
current_text: Current accumulated text
Returns:
- DeltaMessage or list of DeltaMessage if processed successfully
- None if cannot proceed (need more tokens)
"""
# Find next tool call start
tool_start_idx = self._find_tool_call_start(
current_text, self._processed_length
)
# Case 1: No tool call found - return remaining content
if tool_start_idx == -1:
return self._process_content(
current_text, self._processed_length, len(current_text)
)
# Case 2: Content before tool call
if tool_start_idx > self._processed_length:
return self._process_content(
current_text, self._processed_length, tool_start_idx
)
# Case 3: Tool call at current position
# Find end of the first complete tool call
tool_end_idx = self._find_first_complete_tool_call_end(
current_text, tool_start_idx
)
if tool_end_idx == -1:
# Tool call incomplete, wait for more tokens
return None
# Process complete tool call
return self._process_complete_tool_calls(
current_text, tool_start_idx, tool_end_idx
)
def _process_content(
self, current_text: str, start_pos: int, end_pos: int
) -> DeltaMessage | None:
"""Process regular content (non-tool-call text).
Args:
current_text: Current accumulated text
start_pos: Start position in buffer
end_pos: End position in buffer
Returns:
DeltaMessage with content if non-empty
"""
if start_pos >= end_pos:
return None
content = current_text[start_pos:end_pos]
# Check if we're between tool calls - skip whitespace
if start_pos > 0:
# Check if text before start_pos ends with
text_before = current_text[:start_pos]
if (
text_before.rstrip().endswith(self.tool_call_end_token)
and content.strip() == ""
):
# We just ended a tool call, skip whitespace between tool calls
self._processed_length = end_pos
return None
# Return content if non-empty
if content:
self._processed_length = end_pos
return DeltaMessage(content=content)
# Mark as processed even if empty
self._processed_length = end_pos
return None
def _flush_remaining_content(self, current_text: str) -> DeltaMessage | None:
"""Flush any remaining unprocessed content as regular content.
Args:
current_text: Current accumulated text
Used when EOS token is encountered to handle incomplete tool calls.
"""
if not self._has_unprocessed_content(current_text):
return None
remaining = current_text[self._processed_length :]
if remaining:
self._processed_length = len(current_text)
return DeltaMessage(content=remaining)
self._processed_length = len(current_text)
return None
def _format_delta_result(self, deltas: list[DeltaMessage]) -> DeltaMessage | None:
"""Format delta result for return.
Merges all deltas into a single DeltaMessage.
Args:
deltas: List of delta messages
Returns:
- None if empty
- Single merged DeltaMessage with all content and tool_calls
"""
if not deltas:
return None
if len(deltas) == 1:
return deltas[0]
# Merge multiple deltas into one
merged_content_parts = []
merged_tool_calls = []
for delta in deltas:
if delta.content:
merged_content_parts.append(delta.content)
if delta.tool_calls:
merged_tool_calls.extend(delta.tool_calls)
# Create merged DeltaMessage
merged_content = "".join(merged_content_parts) if merged_content_parts else None
# Build kwargs - only include tool_calls if non-empty
kwargs: dict[str, Any] = {"content": merged_content}
if merged_tool_calls:
kwargs["tool_calls"] = merged_tool_calls
return DeltaMessage(**kwargs)
def _process_complete_tool_calls(
self, current_text: str, start_pos: int, end_pos: int
) -> list[DeltaMessage] | None:
"""Process complete tool calls and convert to delta sequence.
Args:
current_text: Current accumulated text
start_pos: Start position (should be at )
end_pos: End position (after )
Returns:
List of DeltaMessage if successful, None otherwise
"""
try:
# Extract text segment containing complete tool call(s)
text_to_parse = current_text[start_pos:end_pos]
# Parse using non-streaming method
result = self.extract_tool_calls_basic(
text_to_parse, self.streaming_request
)
# Case 1: Successfully parsed tool calls
if result.tools_called and result.tool_calls:
# Note: Due to _find_first_complete_tool_call_end, we typically
# process only one tool call at a time
# but we can also process multiple tool calls below
deltas = self._build_tool_call_deltas(result.tool_calls, text_to_parse)
self._update_state_after_tool_calls(result.tool_calls, end_pos)
return deltas if deltas else None
# Case 2: Parsing failed - treat as regular content
self._processed_length = end_pos
return [DeltaMessage(content=text_to_parse)]
except Exception as e:
# Exception during parsing - treat as content
logger.debug("Failed to parse tool calls: %s, treating as content", e)
self._processed_length = end_pos
failed_text = current_text[start_pos:end_pos]
return [DeltaMessage(content=failed_text)] if failed_text else None
def _build_tool_call_deltas(
self, tool_calls: list[ToolCall], parsed_text: str
) -> list[DeltaMessage]:
"""Build delta messages from parsed tool calls with interleaved content.
Args:
tool_calls: List of parsed tool calls
parsed_text: Original text that was parsed
Returns:
List of DeltaMessage with tool calls and content interleaved
"""
# Extract content segments between tool calls
content_segments = self._extract_content_between_tool_calls_list(parsed_text)
# Convert all tool calls to DeltaToolCall list
delta_tool_calls = self._convert_tool_calls_to_deltas(
tool_calls, self._tool_call_index
)
# Merge all content segments into a single string
merged_content = "".join(content_segments)
# Return a single DeltaMessage with all tool calls and content
# Build kwargs - only include non-empty fields
kwargs: dict[str, Any] = {}
if merged_content:
kwargs["content"] = merged_content
if delta_tool_calls:
kwargs["tool_calls"] = delta_tool_calls
# Only return DeltaMessage if we have content or tool_calls
if kwargs:
return [DeltaMessage(**kwargs)]
else:
return []
def _update_state_after_tool_calls(
self, tool_calls: list[ToolCall], end_pos: int
) -> None:
"""Update internal state after processing tool calls.
Args:
tool_calls: List of processed tool calls
end_pos: End position in buffer
"""
# Update processed position
self._processed_length = end_pos
# Update tool call index
self._tool_call_index += len(tool_calls)
# Update prev_tool_call_arr for finish_reason
self._update_prev_tool_call_state(tool_calls)