Commit 8d75f22e authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori

parents ce888aa4 7d80c73d
......@@ -30,6 +30,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"deepseekv31_tool_parser",
"DeepSeekV31ToolParser",
),
"deepseek_v32": (
"deepseekv32_tool_parser",
"DeepSeekV32ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",
......@@ -130,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"xlam_tool_parser",
"xLAMToolParser",
),
"gigachat3": (
"gigachat3_tool_parser",
"GigaChat3ToolParser",
),
}
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import uuid
from collections.abc import Sequence
from typing import Any
import regex as re
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
class DeepSeekV32ToolParser(ToolParser):
"""
example tool call content:
<|DSML|function_calls>
<|DSML|invoke name="get_weather">
<|DSML|parameter name="location" string="true">杭州</|DSML|parameter>
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
</|DSML|invoke>
<|DSML|invoke name="get_weather">
<|DSML|parameter name="location" string="true">北京</|DSML|parameter>
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
</|DSML|invoke>
</|DSML|function_calls>
"""
def __init__(self, tokenizer: TokenizerLike):
    """Set up sentinel tokens, streaming state and the parsing regexes.

    Args:
        tokenizer: Tokenizer forwarded to the ``ToolParser`` base class.

    Raises:
        ValueError: If the base class ends up without a model tokenizer.
    """
    super().__init__(tokenizer)

    self.prev_tool_call_arr: list[dict] = []

    # Sentinel tokens
    self.dsml_token: str = "|DSML|"
    self.dsml_start_check: str = "<" + self.dsml_token
    self.tool_call_start_token: str = "<|DSML|function_calls>"
    self.tool_call_end_token: str = "</|DSML|function_calls>"
    self.invoke_start_prefix: str = "<|DSML|invoke name="
    self.invoke_end_token: str = "</|DSML|invoke>"
    self.parameter_prefix: str = "<|DSML|parameter name="
    self.parameter_end_token: str = "</|DSML|parameter>"

    # Streaming state variables
    self.current_tool_name_sent: bool = False
    # Override base class type - we use string IDs for tool calls
    self.current_tool_id: str | None = None  # type: ignore
    self.streamed_args_for_tool: list[str] = []
    self.is_tool_call_started: bool = False
    self.failed_count: int = 0

    # Initialize streaming state variables
    self.current_tool_index: int = 0
    self.invoke_index: int = 0
    self.header_sent: bool = False
    self.current_function_name: str | None = None
    self.current_param_name: str | None = None
    self.current_param_value: str = ""
    self.param_count: int = 0
    self.in_param: bool = False
    self.in_function: bool = False
    self.json_started: bool = False
    self.json_closed: bool = False
    self.accumulated_params: dict = {}
    self.streaming_request: ChatCompletionRequest | None = None

    # Enhanced streaming state - reset for each new message
    self._reset_streaming_state()

    # Regex patterns for complete parsing.
    # The sentinel text is escaped before being embedded: a bare "|" is the
    # regex alternation operator, so an unescaped "<|DSML|...>" pattern would
    # be split into unrelated alternatives and never match a whole block.
    invoke_open = re.escape(self.dsml_start_check + "invoke")
    parameter_open = re.escape(self.dsml_start_check + "parameter")
    self.tool_call_complete_regex = re.compile(
        re.escape(self.tool_call_start_token)
        + r"(.*?)"
        + re.escape(self.tool_call_end_token),
        re.DOTALL,
    )
    self.invoke_complete_regex = re.compile(
        invoke_open
        + r'\s+name="([^"]+)"\s*>(.*?)'
        + re.escape(self.invoke_end_token),
        re.DOTALL,
    )
    self.parameter_complete_regex = re.compile(
        parameter_open
        + r'\s+name="([^"]+)"\s+string="(?:true|false)"\s*>(.*?)'
        + re.escape(self.parameter_end_token),
        re.DOTALL,
    )

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    logger.debug(
        "vLLM Successfully import tool parser %s !", self.__class__.__name__
    )
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.invoke_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = None
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.json_started = False
self.json_closed = False
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
# Clear previous tool call history to avoid state pollution
self.prev_tool_call_arr.clear()
def _parse_invoke_params(self, invoke_str: str) -> dict | None:
param_dict = dict()
for param_name, param_val in self.parameter_complete_regex.findall(invoke_str):
param_dict[param_name] = param_val
return param_dict
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Parse a finished (non-streaming) model output into tool calls.

    Every invoke inside every complete function_calls block becomes one
    ``ToolCall`` whose arguments are the JSON-encoded parameter dict. The text
    before the first function_calls start token is returned as content (or
    ``None`` when there is none). If no start token is present, nothing is
    parsed, or parsing raises, the whole output is returned as plain content.
    """
    # Quick exit: no tool-call sentinel at all.
    if self.tool_call_start_token not in model_output:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        parsed_calls = [
            ToolCall(
                type="function",
                function=FunctionCall(
                    name=fn_name,
                    arguments=json.dumps(
                        self._parse_invoke_params(fn_body), ensure_ascii=False
                    ),
                ),
            )
            for call_block in self.tool_call_complete_regex.findall(model_output)
            for fn_name, fn_body in self.invoke_complete_regex.findall(call_block)
        ]

        if parsed_calls:
            # Anything in front of the first tool-call block is user-visible
            # content; an empty prefix is reported as None.
            prefix_len = model_output.find(self.tool_call_start_token)
            leading = model_output[:prefix_len] if prefix_len > 0 else None
            return ExtractedToolCallInformation(
                tools_called=True, tool_calls=parsed_calls, content=leading
            )
    except Exception:
        logger.exception("Error extracting tool calls")

    # Either nothing parsable was found or parsing failed.
    return ExtractedToolCallInformation(
        tools_called=False, tool_calls=[], content=model_output
    )
def _extract_name(self, name_str: str) -> str:
"""Extract name from quoted string."""
name_str = name_str.strip()
if (
name_str.startswith('"')
and name_str.endswith('"')
or name_str.startswith("'")
and name_str.endswith("'")
):
return name_str[1:-1]
return name_str
def _extract_param_name(self, input_str: str) -> str:
"""Extract param name"""
start = input_str.find('"') + 1
end = input_str.find('"', start)
return input_str[start:end] if start > 0 and end > start else input_str
def _convert_param_value(self, value: str, param_type: str) -> Any:
"""Convert parameter value to the correct type."""
if value.lower() == "null":
return None
param_type = param_type.lower()
if param_type in ["string", "str", "text"]:
return value
elif param_type in ["integer", "int"]:
try:
return int(value)
except (ValueError, TypeError):
return value
elif param_type in ["number", "float"]:
try:
val = float(value)
return val if val != int(val) else int(val)
except (ValueError, TypeError):
return value
elif param_type in ["boolean", "bool"]:
return value.lower() in ["true", "1"]
elif param_type in ["object", "array"]:
try:
return json.loads(value)
except json.JSONDecodeError:
return value
else:
# Try JSON parse first, fallback to string
try:
return json.loads(value)
except json.JSONDecodeError:
return value
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
    current_token_ids: Sequence[int],  # pylint: disable=unused-argument
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Extract tool calls from streaming model output.

    Called once per streamed chunk; returns at most one ``DeltaMessage`` per
    call. For each invoke block the method emits, over successive calls:
    a header delta (tool id + function name), a ``{`` delta, one delta per
    parameter (``"name": value`` JSON fragment) and finally a ``}`` delta.

    Args:
        previous_text: Output accumulated before this chunk; an empty value
            triggers a full state reset and stores ``request``.
        current_text: Output accumulated including this chunk.
        delta_text: Text of this chunk only.
        previous_token_ids: Unused.
        current_token_ids: Unused.
        delta_token_ids: Token ids of this chunk; only checked for
            non-emptiness when ``delta_text`` is empty.
        request: The chat request; its tool schemas supply parameter types.

    Returns:
        A ``DeltaMessage`` with content or tool-call fragments, or ``None``
        when there is nothing to send for this chunk.
    """
    # Store request for type conversion
    if not previous_text:
        self._reset_streaming_state()
        self.streaming_request = request

    # If no delta text, return None unless it's an EOS token after tools
    if not delta_text:
        # Check if this is an EOS token after all tool calls are complete
        if delta_token_ids:
            # Count complete tool calls
            complete_calls = len(
                self.tool_call_complete_regex.findall(current_text)
            )

            # If we have completed tool calls and populated prev_tool_call_arr
            if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                # Check if all tool calls are closed
                open_calls = current_text.count(
                    self.tool_call_start_token
                ) - current_text.count(self.tool_call_end_token)
                if open_calls == 0:
                    # Return empty delta for finish_reason processing
                    return DeltaMessage(content="")
            elif not self.is_tool_call_started and current_text:
                # This is a regular content response that's now complete
                return DeltaMessage(content="")
        return None

    # Check if we need to advance to next tool
    if self.json_closed and not self.in_function:
        # Check if this tool call has ended
        invoke_ends = current_text.count(self.invoke_end_token)
        if invoke_ends > self.current_tool_index:
            # This tool has ended, advance to next
            self.current_tool_index += 1
            self.header_sent = False
            self.param_count = 0
            self.json_started = False
            self.json_closed = False
            self.in_function = False  # Now we can safely set this to False
            self.accumulated_params = {}
            # Continue processing next tool
            return None

    # Handle normal content before tool calls
    if not self.is_tool_call_started:
        # Check if tool call is starting
        # (any occurrence of the DSML token switches to tool-call mode)
        if self.dsml_token in current_text:
            self.is_tool_call_started = True
            # Return any content before the tool call
            if self.dsml_start_check in delta_text:
                content_before = delta_text[
                    : delta_text.index(self.dsml_start_check)
                ]
                if content_before:
                    return DeltaMessage(content=content_before)
            return None
        else:
            # Check if we're between tool calls - skip whitespace
            if (
                current_text.rstrip().endswith(self.tool_call_end_token)
                and delta_text.strip() == ""
            ):
                # We just ended a tool call, skip whitespace
                return None
            # Normal content, no tool call
            # A trailing "<" may be the start of a DSML tag, so it is held
            # back here and re-attached on the next chunk if it was not.
            if delta_text.endswith("<"):
                return DeltaMessage(content=delta_text[:-1])
            if previous_text and previous_text.endswith("<"):
                return DeltaMessage(content="<" + delta_text)
            return DeltaMessage(content=delta_text)

    # Check if we're between tool calls (waiting for next one)
    invoke_starts_count = current_text.count(self.invoke_start_prefix)
    if self.current_tool_index >= invoke_starts_count:
        # We're past all tool calls, shouldn't be here
        return None

    # Find the current tool call portion
    invoke_start_positions: list[int] = []
    idx = 0
    while True:
        idx = current_text.find(self.invoke_start_prefix, idx)
        if idx == -1:
            break
        invoke_start_positions.append(idx)
        idx += len(self.invoke_start_prefix)

    if self.current_tool_index >= len(invoke_start_positions):
        # No more tool calls to process yet
        return None

    invoke_start_idx = invoke_start_positions[self.current_tool_index]
    # Find where this tool call ends (or current position if not ended yet)
    invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
    if invoke_end_idx == -1:
        tool_text = current_text[invoke_start_idx:]
    else:
        tool_text = current_text[
            invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
        ]

    # Looking for function header
    if not self.header_sent:
        if self.invoke_start_prefix in tool_text:
            func_start = tool_text.find(self.invoke_start_prefix) + len(
                self.invoke_start_prefix
            )
            # Find the end quote for the function name
            # (the search is actually for the closing ">" of the invoke tag)
            func_end = tool_text.find(">", func_start)

            if func_end != -1:
                # Found complete function name
                function_name_raw = tool_text[func_start:func_end]
                self.current_function_name = self._extract_name(function_name_raw)
                self.current_tool_id = self._generate_tool_call_id()
                self.header_sent = True
                self.in_function = True

                # Add to prev_tool_call_arr immediately when we detect a tool call
                # Each tool call should be recorded regardless of function name
                # Ensure we don't add the same tool call index multiple times
                if len(self.prev_tool_call_arr) <= self.current_tool_index:
                    self.prev_tool_call_arr.append(
                        {
                            "name": self.current_function_name,
                            "arguments": "{}",  # Placeholder, will be updated later
                        }
                    )

                # Send header with function info
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            id=self.current_tool_id,
                            function=DeltaFunctionCall(
                                name=self.current_function_name, arguments=""
                            ),
                            type="function",
                        )
                    ]
                )
        return None

    # We've sent header, now handle function body
    if self.in_function:
        # Send opening brace if not sent yet
        if self.in_function and not self.json_started:
            self.json_started = True
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="{"),
                    )
                ]
            )

        # Make sure json_started is set if we're processing parameters
        # NOTE(review): unreachable as written - the branch above returns
        # whenever json_started is False.
        if not self.json_started:
            self.json_started = True

        # Check for function end in accumulated text
        if not self.json_closed and self.invoke_end_token in tool_text:
            # Count total parameters in the tool text
            total_param_count = tool_text.count(self.parameter_prefix)

            # Only close JSON if all parameters have been processed
            if self.param_count >= total_param_count:
                # Close JSON
                self.json_closed = True

                # Extract complete tool call
                # Find the invoke content
                invoke_start = tool_text.find(self.invoke_start_prefix) + len(
                    self.invoke_start_prefix
                )
                invoke_content_end = tool_text.find(
                    self.invoke_end_token, invoke_start
                )
                if invoke_content_end != -1:
                    invoke_content = tool_text[invoke_start:invoke_content_end]
                    # Parse to get the complete arguments
                    # (raw string values, not the type-converted ones that
                    # were streamed as fragments)
                    try:
                        invoke_params = self._parse_invoke_params(invoke_content)
                        if invoke_params and self.current_tool_index < len(
                            self.prev_tool_call_arr
                        ):
                            # Update existing entry in prev_tool_call_arr
                            self.prev_tool_call_arr[self.current_tool_index][
                                "arguments"
                            ] = json.dumps(invoke_params, ensure_ascii=False)
                    except Exception:
                        pass  # Ignore parsing errors during streaming

                result = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(arguments="}"),
                        )
                    ]
                )

                # Reset state for next tool
                self.json_closed = True
                self.in_function = False
                self.accumulated_params = {}

                # NOTE(review): the "[M2_STREAMING]" tag does not match this
                # parser's name - possibly a copy/paste leftover; confirm.
                logger.debug("[M2_STREAMING] Tool call completed")

                return result
            else:
                # Don't close JSON yet, continue processing parameters
                return None

        # Look for parameters
        # Find all parameter starts
        param_starts = []
        idx = 0
        while True:
            idx = tool_text.find(self.parameter_prefix, idx)
            if idx == -1:
                break
            param_starts.append(idx)
            idx += len(self.parameter_prefix)

        # Check if we should start a new parameter
        # NOTE(review): the last two conditions are the same comparison, and
        # self.in_param is never set to True in the code visible here.
        if (
            not self.in_param
            and self.param_count < len(param_starts)
            and len(param_starts) > self.param_count
        ):
            # Process the next parameter
            param_idx = param_starts[self.param_count]
            param_start = param_idx + len(self.parameter_prefix)
            remaining = tool_text[param_start:]

            if ">" in remaining:
                # We have the complete parameter name
                name_end = remaining.find(">")
                param_name_raw = remaining[:name_end]
                self.current_param_name = self._extract_param_name(param_name_raw)

                # Find the parameter value
                value_start = param_start + name_end + 1
                value_text = tool_text[value_start:]
                if value_text.startswith("\n"):
                    value_text = value_text[1:]

                # Find where this parameter ends
                param_end_idx = value_text.find(self.parameter_end_token)
                if param_end_idx == -1:
                    # No closing tag, look for next parameter or function end
                    next_param_idx = value_text.find(self.parameter_prefix)
                    func_end_idx = value_text.find(self.invoke_end_token)

                    if next_param_idx != -1 and (
                        func_end_idx == -1 or next_param_idx < func_end_idx
                    ):
                        param_end_idx = next_param_idx
                    elif func_end_idx != -1:
                        param_end_idx = func_end_idx
                    else:
                        # Neither found, check if tool call is complete
                        if self.invoke_end_token in tool_text:
                            # Tool call and parameter is complete
                            param_end_idx = len(value_text)
                        else:
                            # Still streaming, wait for more content
                            return None

                if param_end_idx != -1:
                    # Complete parameter found
                    param_value = value_text[:param_end_idx]
                    if param_value.endswith("\n"):
                        param_value = param_value[:-1]

                    # Store raw value for later processing
                    self.accumulated_params[self.current_param_name] = param_value

                    # Get parameter configuration for type conversion
                    param_config = {}
                    if self.streaming_request and self.streaming_request.tools:
                        for tool in self.streaming_request.tools:
                            if (
                                hasattr(tool, "function")
                                and tool.function.name == self.current_function_name
                                and hasattr(tool.function, "parameters")
                            ):
                                params = tool.function.parameters
                                if (
                                    isinstance(params, dict)
                                    and "properties" in params
                                ):
                                    param_config = params["properties"]
                                break

                    # Get parameter type
                    # (defaults to "string" when the schema has no entry)
                    param_type = "string"
                    if (
                        self.current_param_name in param_config
                        and isinstance(param_config[self.current_param_name], dict)
                        and "type" in param_config[self.current_param_name]
                    ):
                        param_type = param_config[self.current_param_name]["type"]

                    # Convert param value to appropriate type
                    converted_value = self._convert_param_value(
                        param_value, param_type
                    )

                    # Build JSON fragment based on the converted type
                    # Use json.dumps to properly serialize the value
                    serialized_value = json.dumps(
                        converted_value, ensure_ascii=False
                    )

                    # NOTE(review): the parameter name is interpolated without
                    # JSON escaping; a name containing a backslash would
                    # produce an invalid fragment.
                    if self.param_count == 0:
                        json_fragment = (
                            f'"{self.current_param_name}": {serialized_value}'
                        )
                    else:
                        json_fragment = (
                            f', "{self.current_param_name}": {serialized_value}'
                        )

                    self.param_count += 1

                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_index,
                                function=DeltaFunctionCall(arguments=json_fragment),
                            )
                        ]
                    )

    return None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
logger = init_logger(__name__)
# Matches the literal "function call", optionally followed by
# "<|role_sep|>" and a newline, then captures everything from the opening
# "{" through the end of the text (DOTALL lets "." span newlines).
REGEX_FUNCTION_CALL = re.compile(
    r"function call(?:<\|role_sep\|>\n)?(\{.*)",
    re.DOTALL,
)

# Captures the string value that follows a `"name":` key.
NAME_REGEX = re.compile(
    r'"name"\s*:\s*"([^"]*)"',
    re.DOTALL,
)

# Captures everything after an `"arguments":` key up to the end of the text.
ARGS_REGEX = re.compile(
    r'"arguments"\s*:\s*(.*)',
    re.DOTALL,
)
class GigaChat3ToolParser(ToolParser):
def __init__(self, tokenizer: TokenizerLike):
    """Create the parser and initialise its streaming bookkeeping.

    Args:
        tokenizer: Tokenizer forwarded to the ``ToolParser`` base class.
    """
    super().__init__(tokenizer)

    # Literal text that the streaming path treats as the start of a tool
    # call; content that could still grow into it is buffered.
    self.trigger_start = "function call{"
    # Content withheld while it may still turn out to be a tool call.
    self.content_buffer: str = ""

    # Streaming progress for the tool call being emitted.
    self.prev_tool_call_arr: list[dict] = []
    self.tool_id: str | None = None
    self.tool_started: bool = False
    self.tool_name_sent: bool = False
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Parse a finished (non-streaming) output into at most one tool call.

    The text after ``function call`` must be a JSON object holding both
    ``name`` and ``arguments``; otherwise the whole output is returned as
    plain content. Non-string arguments are re-serialized to JSON. Text
    before the match is returned right-stripped as content, or ``None`` when
    it is empty or whitespace only.
    """

    def plain_content() -> ExtractedToolCallInformation:
        # Result used whenever the output is not a well-formed tool call.
        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )

    found = REGEX_FUNCTION_CALL.search(model_output)
    if not found:
        return plain_content()

    try:
        payload = json.loads(found.group(1).strip())
    except json.JSONDecodeError:
        return plain_content()

    if (
        not isinstance(payload, dict)
        or "name" not in payload
        or "arguments" not in payload
    ):
        return plain_content()

    raw_args = payload["arguments"]
    serialized_args = (
        raw_args
        if isinstance(raw_args, str)
        else json.dumps(raw_args, ensure_ascii=False)
    )

    call = ToolCall(
        type="function",
        function=FunctionCall(
            name=payload["name"],
            arguments=serialized_args,
        ),
    )

    # A whitespace-only prefix right-strips to "" and is reported as None.
    leading = model_output[: found.start()]
    return ExtractedToolCallInformation(
        tools_called=True,
        tool_calls=[call],
        content=leading.rstrip() or None,
    )
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Incrementally extract a single tool call from streamed output.

    Until a function call is detected in ``current_text``, incoming text is
    buffered while it could still become the trigger and flushed as content
    otherwise. After detection, the function name is sent once (with a new
    tool call id), followed by argument deltas computed as the suffix of the
    current arguments text over what was sent before. Only tool index 0 is
    ever used.

    Returns:
        A ``DeltaMessage`` with content or a tool-call fragment, or ``None``
        when nothing should be sent for this chunk.
    """
    func_name = None
    cur_args = None
    if not self.tool_started:
        match = REGEX_FUNCTION_CALL.search(current_text)
        if match:
            self.tool_started = True
            self.content_buffer = ""
        else:
            self.content_buffer += delta_text
            clean_buffer = self.content_buffer.lstrip()
            # Hold text back while it is a prefix of the trigger (or already
            # starts with it); otherwise release the buffer as content.
            is_prefix = self.trigger_start.startswith(clean_buffer)
            starts_with_trigger = clean_buffer.startswith(self.trigger_start)
            if is_prefix or starts_with_trigger:
                return None
            else:
                flush_text = self.content_buffer
                self.content_buffer = ""
                return DeltaMessage(content=flush_text)

    match = REGEX_FUNCTION_CALL.search(current_text)
    if not match:
        return None
    # Everything from the opening "{" onwards; possibly incomplete JSON.
    json_tail = match.group(1).strip()
    name_match = NAME_REGEX.search(json_tail)
    if name_match:
        func_name = name_match.group(1)
    args_match = ARGS_REGEX.search(json_tail)
    if args_match:
        cur_args = args_match.group(1).strip()
        if cur_args.endswith("}"):  # last '}' end of json
            # If dropping the final "}" leaves valid JSON, that brace closes
            # the outer object rather than the arguments, so it is removed.
            try:
                candidate = cur_args[:-1].strip()
                json.loads(candidate)
                cur_args = candidate
            except json.JSONDecodeError:
                pass
    if not self.prev_tool_call_arr:
        self.prev_tool_call_arr.append({})
    if not self.tool_name_sent:
        # The name is only sent once NAME_REGEX matches a closed string.
        if not func_name:
            return None
        self.tool_name_sent = True
        self.tool_id = make_tool_call_id()
        self.prev_tool_call_arr[0]["name"] = func_name
        return DeltaMessage(
            tool_calls=[
                DeltaToolCall(
                    index=0,
                    id=self.tool_id,
                    type="function",
                    function=DeltaFunctionCall(
                        name=func_name,
                    ).model_dump(exclude_none=True),
                )
            ],
            content=None,
        )
    if cur_args is None:
        return None
    prev_args = self.prev_tool_call_arr[0].get("arguments", "")
    if not prev_args:
        delta_args = cur_args
    elif cur_args.startswith(prev_args):
        delta_args = cur_args[len(prev_args) :]
    else:
        # Current text does not extend what was already sent; send nothing.
        return None
    if not delta_args:
        return None
    self.prev_tool_call_arr[0]["arguments"] = cur_args
    return DeltaMessage(
        tool_calls=[
            DeltaToolCall(
                index=0,
                function=DeltaFunctionCall(
                    arguments=delta_args,
                ).model_dump(exclude_none=True),
            )
        ],
        content=None,
    )
......@@ -3,12 +3,12 @@
import json
from collections.abc import Sequence
from enum import Enum, auto
from random import choices
from string import ascii_letters, digits
import partial_json_parser
import ijson
import regex as re
from partial_json_parser.core.options import Allow
from pydantic import Field
from vllm.entrypoints.openai.protocol import (
......@@ -23,7 +23,6 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.entrypoints.openai.tool_parsers.utils import extract_intermediate_diff
from vllm.logger import init_logger
from vllm.tokenizers import MistralTokenizer, TokenizerLike
......@@ -32,6 +31,22 @@ logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
class StreamingState(Enum):
    """States of the streaming tool-call parser, in declaration order."""

    WAITING_FOR_TOOL_START = auto()
    # Waiting for the "name" or "arguments" key to be complete.
    WAITING_FOR_TOOL_KEY = auto()
    PARSING_NAME = auto()
    PARSING_NAME_COMPLETED = auto()
    WAITING_FOR_ARGUMENTS_START = auto()
    PARSING_ARGUMENTS = auto()
    PARSING_ARGUMENTS_COMPLETED = auto()
    TOOL_COMPLETE = auto()
    ALL_TOOLS_COMPLETE = auto()
class MistralToolCall(ToolCall):
id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id())
......@@ -46,8 +61,8 @@ class MistralToolCall(ToolCall):
return id.isalnum() and len(id) == 9
def _is_fn_name_regex_support(model_tokenizer: TokenizerLike) -> bool:
return (
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
return not (
isinstance(model_tokenizer, MistralTokenizer) and model_tokenizer.version >= 11
)
......@@ -69,21 +84,22 @@ class MistralToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# streaming mode
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.current_tool_name_sent: bool = False
self.streamed_args_for_tool: list[
str
] = [] # map what has been streamed for each tool so far to a list
self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START
# For streaming pre v11 tokenizer tool calls
self.current_tool_name: str | None = None
self.current_tool_mistral_id: str | None = None
self.starting_new_tool = False
if _is_pre_v11_tokeniser(self.model_tokenizer):
self.parse_coro = ijson.parse_coro(
self.update_stream_state_pre_v11_tokenizer()
)
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if _is_fn_name_regex_support(self.model_tokenizer):
self.fn_name_regex = re.compile(
r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL
)
else:
self.fn_name_regex = None
self._is_pre_v11 = _is_pre_v11_tokeniser(self.model_tokenizer)
if self.bot_token_id is None:
raise RuntimeError(
......@@ -127,16 +143,18 @@ class MistralToolParser(ToolParser):
tool_content = model_output.replace(self.bot_token, "").strip()
try:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
if self.fn_name_regex:
matches = self.fn_name_regex.findall(tool_content)
if not self._is_pre_v11:
function_call_arr = []
for match in matches:
fn_name = match[0]
args = match[1]
for single_tool_content in model_output.split(self.bot_token):
if "{" not in single_tool_content:
continue
end_name = single_tool_content.find("{")
fn_name, args = (
single_tool_content[:end_name],
single_tool_content[end_name:],
)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
......@@ -193,198 +211,372 @@ class MistralToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
# if the tool call token is not in the tokens generated so far, append
# output to contents since it's not a tool
if self.bot_token not in current_text:
if self.bot_token_id not in current_token_ids:
# if the tool call token is not in the tokens generated so far,
# append output to contents since it's not a tool
return DeltaMessage(content=delta_text)
# if the tool call token ID IS in the tokens generated so far, that
# if the tool call token IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if self.bot_token_id in delta_token_ids and len(delta_token_ids) == 1:
# if it's the only token, return None, so we don't send a chat
# completion any don't send a control token
return None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
try:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr = current_text.split(self.bot_token)[-1]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try:
tool_call_arr: list[dict] = partial_json_parser.loads(
parsable_arr, flags
if _is_pre_v11_tokeniser(self.model_tokenizer):
return self._extract_tool_calls_streaming_pre_v11_tokenizer(
delta_text=delta_text,
delta_token_ids=delta_token_ids,
)
except partial_json_parser.core.exceptions.MalformedJSON:
logger.debug("not enough tokens to parse into JSON yet")
else:
return self._extract_tool_calls_streaming(
delta_text=delta_text, delta_token_ids=delta_token_ids
)
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None
# select as the current tool call the one we're on the state at
current_tool_call: dict = (
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
def _extract_tool_calls_streaming(
self,
delta_text: str,
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
"""
Extracts tool calls for Mistral models
doing tool calls of the following format:
`[TOOL_CALLS]add{"a": 3.5, "b": 4}`
"""
additional_content: str = ""
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
# this is the first tool call
assert self.bot_token_id in delta_token_ids
if not delta_text.startswith(self.bot_token):
additional_content += delta_text.split(self.bot_token)[0]
delta_text = self.bot_token + "".join(
delta_text.split(self.bot_token)[1:]
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if len(tool_call_arr) == 0:
delta_tool_calls = self._generate_delta_tool_call(delta_text)
if not additional_content and len(delta_tool_calls) == 0:
if self.streaming_state in [
StreamingState.PARSING_ARGUMENTS,
StreamingState.PARSING_ARGUMENTS_COMPLETED,
StreamingState.TOOL_COMPLETE,
StreamingState.ALL_TOOLS_COMPLETE,
]:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return DeltaMessage()
else:
# return None when the tool is not likely to be finished
# This can occur when the name is being parsed for example
# and we wait for the name to be complete
# before sending the function name
return None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif (
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if self.current_tool_id >= 0:
diff: str | None = current_tool_call.get("arguments")
if diff:
diff = json.dumps(diff, ensure_ascii=False).replace(
self.streamed_args_for_tool[self.current_tool_id], ""
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff
).model_dump(exclude_none=True),
)
]
)
self.streamed_args_for_tool[self.current_tool_id] += diff
else:
delta = None
else:
delta = None
# re-set stuff pertaining to progress in the current tool
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
delta = DeltaMessage()
if additional_content:
delta.content = additional_content
if len(delta_tool_calls) > 0:
delta.tool_calls = delta_tool_calls
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if delta_tool_calls and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
return delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(
tool_calls=[
def _generate_delta_tool_call(self, delta_text: str) -> list[DeltaToolCall]:
if delta_text == "" or delta_text is None:
return []
delta_function_name = None
tool_id = None
if self.streaming_state not in [
StreamingState.PARSING_NAME,
StreamingState.PARSING_ARGUMENTS,
] and delta_text.startswith(self.bot_token):
self.current_tool_id += 1
self.streaming_state = StreamingState.PARSING_NAME
delta_text = delta_text.replace(self.bot_token, "", 1)
if self.streaming_state == StreamingState.PARSING_NAME:
if self.current_tool_name is None:
self.current_tool_name = ""
# The name stops where the arguments start
# And the arguments start with the `{` char
if "{" in delta_text:
tool_id = MistralToolCall.generate_random_id()
delta_function_name = delta_text.split("{")[0]
self.current_tool_name += delta_function_name
delta_text = delta_text[len(delta_function_name) :]
self.streaming_state = StreamingState.PARSING_ARGUMENTS
else:
# we want to send the tool name once it's complete
self.current_tool_name += delta_text
return []
if self.streaming_state == StreamingState.PARSING_ARGUMENTS:
next_function_text = None
if self.bot_token in delta_text:
# current tool call is over
delta_arguments = ""
delta_arguments += delta_text.split(self.bot_token)[0]
next_function_text = delta_text[len(delta_arguments) :]
self.streaming_state = StreamingState.TOOL_COMPLETE
else:
delta_arguments = delta_text
ret = []
if self.current_tool_name or delta_arguments:
ret += [
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=MistralToolCall.generate_random_id(),
id=tool_id,
function=DeltaFunctionCall(
name=function_name
name=self.current_tool_name, arguments=delta_arguments
).model_dump(exclude_none=True),
)
]
self.current_tool_name = None
if next_function_text:
ret += self._generate_delta_tool_call(next_function_text)
return ret
# Should not happen
return []
@ijson.coroutine
def update_stream_state_pre_v11_tokenizer(self):
while True:
(prefix, event, value) = yield
if prefix == "item" and event == "start_map":
self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
if prefix == "item" and event == "map_key" and value == "name":
self.streaming_state = StreamingState.PARSING_NAME
if prefix == "item.name" and event == "string":
self.current_tool_name = value
self.streaming_state = StreamingState.PARSING_NAME_COMPLETED
if prefix == "item" and event == "map_key" and value == "arguments":
self.streaming_state = StreamingState.WAITING_FOR_ARGUMENTS_START
if prefix == "item.arguments" and event == "start_map":
self.streaming_state = StreamingState.PARSING_ARGUMENTS
if prefix == "item.arguments" and event == "end_map":
self.streaming_state = StreamingState.PARSING_ARGUMENTS_COMPLETED
if prefix == "item" and event == "end_map":
self.streaming_state = StreamingState.TOOL_COMPLETE
if prefix == "" and event == "end_array":
self.streaming_state = StreamingState.ALL_TOOLS_COMPLETE
def _extract_tool_calls_streaming_pre_v11_tokenizer(
self,
delta_text: str,
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
"""
Extracts tool calls for Mistral models
doing tool calls of the following format:
`[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}`
"""
assert self.parse_coro is not None
content = None
delta_tool_calls: list[DeltaToolCall] = []
current_tool_call: DeltaToolCall = DeltaToolCall(
index=self.current_tool_id, type="function"
)
self.current_tool_name_sent = True
else:
delta = None
# now we know we're on the same tool call and we're streaming
# arguments
else:
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments"
)
cur_arguments = current_tool_call.get("arguments")
new_text = delta_text.replace("'", '"')
if '"}' in new_text:
new_text = new_text[: new_text.rindex('"}')]
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"INVARIANT - impossible to have arguments reset mid-arguments"
current_tool_call_modified = False
if self.bot_token_id in delta_token_ids:
# this is the first tool call
if not delta_text.startswith(self.bot_token):
content = delta_text.split(self.bot_token)[0]
delta_text = "".join(delta_text.split(self.bot_token)[1:])
# Cut smartly the delta text to catch the ijson events
# as ijson does not give us the index in the text at each event.
# We need to cut so that we know
# where in the text the events are emitted from.
while len(delta_text) > 0:
streaming_state_before_parse = self.streaming_state
if self.streaming_state == StreamingState.WAITING_FOR_TOOL_START:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
)
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments, ensure_ascii=False)[
:-2
]
logger.debug("finding %s in %s", new_text, cur_arguments_json)
if new_text not in cur_arguments_json:
return None
arguments_delta = cur_arguments_json[
: cur_arguments_json.rindex(new_text) + len(new_text)
]
logger.debug(
"First tokens in arguments received: %s", arguments_delta
elif self.streaming_state == StreamingState.WAITING_FOR_TOOL_KEY:
# Wait until another key is sent
# or the current tool is completed
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_colon=1,
stop_after_opening_curly_braces=1,
# if the tool ends, we want to separate
# at the start of the next tool
)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=arguments_delta
).model_dump(exclude_none=True),
elif self.streaming_state == StreamingState.PARSING_NAME:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_comma=1,
stop_after_closing_brackets=1,
)
]
elif self.streaming_state == StreamingState.WAITING_FOR_ARGUMENTS_START:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
)
self.streamed_args_for_tool[self.current_tool_id] += arguments_delta
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
logger.debug(
"Searching for diff between \n%s\n%s",
cur_args_json,
prev_args_json,
elif self.streaming_state == StreamingState.PARSING_ARGUMENTS:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_closing_curly_braces=1,
# we could be more clever
# by listening to item.arguments.* start_map events
# and know how many curly braces we can allow
)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json
elif self.streaming_state in [
StreamingState.PARSING_ARGUMENTS_COMPLETED,
StreamingState.PARSING_NAME_COMPLETED,
]:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_closing_curly_braces=1,
stop_after_closing_brackets=1,
)
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff
).model_dump(exclude_none=True),
elif self.streaming_state == StreamingState.TOOL_COMPLETE:
delta_to_be_parsed, delta_text = self._split_delta(
delta_text=delta_text,
stop_after_opening_curly_braces=1,
stop_after_closing_brackets=1,
)
elif self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
content = delta_text
delta_text = ""
else:
delta_to_be_parsed = delta_text
delta_text = ""
if self.streaming_state != StreamingState.ALL_TOOLS_COMPLETE:
self.parse_coro.send(delta_to_be_parsed.encode("utf-8"))
# Given the parsed text and the possible streaming state change,
# let's add to the tool delta
if (
(streaming_state_before_parse != self.streaming_state)
and streaming_state_before_parse
in [StreamingState.WAITING_FOR_TOOL_START, StreamingState.TOOL_COMPLETE]
and self.streaming_state
not in [
StreamingState.ALL_TOOLS_COMPLETE,
StreamingState.TOOL_COMPLETE,
StreamingState.WAITING_FOR_TOOL_START,
]
):
# starting a new tool call
if current_tool_call_modified:
if self.current_tool_mistral_id is not None:
current_tool_call.id = self.current_tool_mistral_id
self.current_tool_mistral_id = None
delta_tool_calls.append(current_tool_call)
current_tool_call_modified = False
self.current_tool_id += 1
self.current_tool_mistral_id = MistralToolCall.generate_random_id()
current_tool_call = DeltaToolCall(
index=self.current_tool_id,
type="function",
)
self.streamed_args_for_tool[self.current_tool_id] += argument_diff
if current_tool_call.function is None:
current_tool_call.function = DeltaFunctionCall()
if self.current_tool_name is not None:
# we have the complete tool name
current_tool_call_modified = True
current_tool_call.function.name = self.current_tool_name
self.current_tool_name = None
if self.streaming_state == StreamingState.PARSING_NAME_COMPLETED:
self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
if self.streaming_state in [
StreamingState.PARSING_ARGUMENTS,
StreamingState.PARSING_ARGUMENTS_COMPLETED,
]:
if self.streaming_state == StreamingState.PARSING_ARGUMENTS_COMPLETED:
self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
# the delta_to_be_parsed is part of arguments.
current_tool_call_modified = True
if current_tool_call.function.arguments is None:
current_tool_call.function.arguments = delta_to_be_parsed
else:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta = None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
logger.debug(
"Skipping chunk as a result of tool streaming extraction error"
current_tool_call.function.arguments += delta_to_be_parsed
if streaming_state_before_parse != StreamingState.PARSING_ARGUMENTS:
# It's the first chunk of arg. let's lstrip it
current_tool_call.function.arguments = (
current_tool_call.function.arguments.lstrip()
)
if current_tool_call_modified:
if self.current_tool_mistral_id is not None:
current_tool_call.id = self.current_tool_mistral_id
self.current_tool_mistral_id = None
delta_tool_calls.append(current_tool_call)
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if delta_tool_calls and not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
if content or len(delta_tool_calls) > 0:
delta_message = DeltaMessage()
if content:
delta_message.content = content
if len(delta_tool_calls) > 0:
delta_message.tool_calls = delta_tool_calls
return delta_message
else:
if self.streaming_state == StreamingState.ALL_TOOLS_COMPLETE:
return DeltaMessage()
else:
return None
def _split_delta(
    self,
    delta_text: str,
    stop_after_quotes: int = -1,
    stop_after_opening_curly_braces: int = -1,
    stop_after_closing_curly_braces: int = -1,
    stop_after_closing_brackets: int = -1,
    stop_after_colon: int = -1,
    stop_after_comma: int = -1,
) -> tuple[str, str]:
    """Cut ``delta_text`` right after the N-th occurrence of a delimiter.

    Each ``stop_after_*`` argument is a countdown for one delimiter class
    (single and double quotes share ``stop_after_quotes``). Scanning left to
    right, every delimiter decrements its countdown; the text is split
    immediately after the first delimiter whose countdown reaches exactly
    zero. A countdown that starts at ``-1`` (the default) or ``0`` can never
    reach zero, so that delimiter never triggers a split.

    Args:
        delta_text: text to split.
        stop_after_quotes: countdown for ``"`` and ``'``.
        stop_after_opening_curly_braces: countdown for ``{``.
        stop_after_closing_curly_braces: countdown for ``}``.
        stop_after_closing_brackets: countdown for ``]``.
        stop_after_colon: countdown for ``:``.
        stop_after_comma: countdown for ``,``.

    Returns:
        ``(head, tail)`` where ``head`` ends with the triggering delimiter
        and ``head + tail == delta_text``. If nothing triggers, ``head`` is
        the whole text and ``tail`` is ``""``.
    """
    # Countdown per delimiter class; both quote characters map to the same
    # class so they consume the same budget.
    remaining = {
        "quote": stop_after_quotes,
        "{": stop_after_opening_curly_braces,
        "}": stop_after_closing_curly_braces,
        "]": stop_after_closing_brackets,
        ":": stop_after_colon,
        ",": stop_after_comma,
    }
    delimiter_class = {
        '"': "quote",
        "'": "quote",
        "{": "{",
        "}": "}",
        "]": "]",
        ":": ":",
        ",": ",",
    }

    for i, c in enumerate(delta_text):
        kind = delimiter_class.get(c)
        if kind is None:
            continue
        remaining[kind] -= 1
        if remaining[kind] == 0:
            return (delta_text[: i + 1], delta_text[i + 1 :])
    return (delta_text, "")
......@@ -4,7 +4,7 @@ import json
from collections.abc import Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaMessage,
......
......@@ -59,8 +59,8 @@ async def create_embedding(
return JSONResponse(content=generator.model_dump())
elif isinstance(generator, EmbeddingBytesResponse):
return StreamingResponse(
content=generator.body,
headers={"metadata": generator.metadata},
content=generator.content,
headers=generator.headers,
media_type=generator.media_type,
)
......
......@@ -203,6 +203,6 @@ class EmbeddingResponse(OpenAIBaseModel):
class EmbeddingBytesResponse(OpenAIBaseModel):
    """Embedding result delivered as raw bytes instead of JSON.

    The fields mirror the keyword arguments of the streaming HTTP response
    that is built from this object (``content``, ``headers``,
    ``media_type``).
    """

    # Byte chunks that make up the response body.
    content: list[bytes]
    # Extra HTTP headers to send with the body; ``None`` sends none.
    headers: dict[str, str] | None = None
    # MIME type reported for the body.
    media_type: str = "application/octet-stream"
......@@ -163,29 +163,35 @@ class EmbeddingMixin(OpenAIServing):
usage=usage,
)
def encode_bytes():
body, items, usage = encode_pooling_bytes(
def encode_bytes(bytes_only: bool) -> EmbeddingBytesResponse:
content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch_checked,
embed_dtype=embed_dtype,
endianness=endianness,
)
metadata = {
headers = (
None
if bytes_only
else {
"metadata": json.dumps(
{
"id": ctx.request_id,
"created": ctx.created_time,
"model": ctx.model_name,
"data": items,
"usage": usage,
}
return EmbeddingBytesResponse(
body=body,
metadata=json.dumps(metadata),
)
}
)
return EmbeddingBytesResponse(content=content, headers=headers)
if encoding_format == "float" or encoding_format == "base64":
return encode_float_base64()
elif encoding_format == "bytes":
return encode_bytes()
elif encoding_format == "bytes" or encoding_format == "bytes_only":
return encode_bytes(bytes_only=encoding_format == "bytes_only")
else:
assert_never(encoding_format)
......
......@@ -55,8 +55,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
return JSONResponse(content=generator.model_dump())
elif isinstance(generator, PoolingBytesResponse):
return StreamingResponse(
content=generator.body,
headers={"metadata": generator.metadata},
content=generator.content,
headers=generator.headers,
media_type=generator.media_type,
)
......
......@@ -143,6 +143,6 @@ class PoolingResponse(OpenAIBaseModel):
class PoolingBytesResponse(OpenAIBaseModel):
    """Pooling result delivered as raw bytes instead of JSON.

    The fields mirror the keyword arguments of the streaming HTTP response
    that is built from this object (``content``, ``headers``,
    ``media_type``).
    """

    # Byte chunks that make up the response body.
    content: list[bytes]
    # Extra HTTP headers to send with the body; ``None`` sends none.
    headers: dict[str, str] | None = None
    # MIME type reported for the body.
    media_type: str = "application/octet-stream"
......@@ -314,29 +314,38 @@ class OpenAIServingPooling(OpenAIServing):
usage=usage,
)
def encode_bytes():
body, items, usage = encode_pooling_bytes(
def encode_bytes(bytes_only: bool) -> PoolingBytesResponse:
content, items, usage = encode_pooling_bytes(
pooling_outputs=final_res_batch,
embed_dtype=embed_dtype,
endianness=endianness,
)
metadata = {
headers = (
None
if bytes_only
else {
"metadata": json.dumps(
{
"id": request_id,
"created": created_time,
"model": model_name,
"data": items,
"usage": usage,
}
)
}
)
return PoolingBytesResponse(
body=body,
metadata=json.dumps(metadata),
content=content,
headers=headers,
)
if encoding_format == "float" or encoding_format == "base64":
return encode_float_base64()
elif encoding_format == "bytes":
return encode_bytes()
elif encoding_format == "bytes" or encoding_format == "bytes_only":
return encode_bytes(bytes_only=encoding_format == "bytes_only")
else:
assert_never(encoding_format)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageToolCallParam,
......@@ -10,18 +12,53 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
Function as FunctionCallTool,
)
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputItem
from openai.types.responses.response import ToolChoice
from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.response_output_item import McpCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from openai.types.responses.tool import Tool
from vllm import envs
from vllm.entrypoints.constants import MCP_PREFIX
from vllm.entrypoints.openai.protocol import (
ChatCompletionMessageParam,
ResponseInputOutputItem,
)
from vllm.utils import random_uuid
def make_response_output_items_from_parsable_context(
    response_messages: list[ResponseInputOutputItem],
) -> list[ResponseOutputItem]:
    """Build the output item list, folding tool outputs into MCP calls.

    Items are copied through in order, except for
    ``ResponseFunctionToolCallOutputItem`` entries: when the item collected
    just before one is a ``ResponseFunctionToolCall``, that call is replaced
    by a completed ``McpCall`` carrying the call's name/arguments and the
    tool output. A tool output whose predecessor is anything else is not
    added to the result.

    Raises:
        ValueError: if a tool output appears before any other item.
    """
    collected: list[ResponseOutputItem] = []
    for entry in response_messages:
        if not isinstance(entry, ResponseFunctionToolCallOutputItem):
            collected.append(entry)
            continue

        if not collected:
            raise ValueError(
                "Cannot have a FunctionToolCallOutput before FunctionToolCall."
            )

        preceding = collected[-1]
        if not isinstance(preceding, ResponseFunctionToolCall):
            continue

        collected[-1] = McpCall(
            id=f"{MCP_PREFIX}{random_uuid()}",
            arguments=preceding.arguments,
            name=preceding.name,
            server_label=preceding.name,  # TODO: store the server label
            type=f"{MCP_PREFIX}call",
            status="completed",
            output=entry.output,
            # TODO: support error output
        )
    return collected
def construct_input_messages(
......@@ -62,12 +99,63 @@ def construct_input_messages(
if isinstance(request_input, str):
messages.append({"role": "user", "content": request_input})
else:
for item in request_input:
messages.append(construct_chat_message_with_tool_call(item))
input_messages = construct_chat_messages_with_tool_call(request_input)
messages.extend(input_messages)
return messages
def _maybe_combine_reasoning_and_tool_call(
    item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
) -> ChatCompletionMessageParam | None:
    """Merge an MCP tool call into the preceding reasoning message.

    Many models treat MCP calls and reasoning as a single message. When
    ``item`` is a ``ResponseFunctionToolCall`` whose ``id`` starts with
    ``MCP_PREFIX`` and the last entry of ``messages`` is an assistant message
    with a non-``None`` ``"reasoning"`` value, that last message gets a
    ``"tool_calls"`` list holding this call and is returned.

    Args:
        item: the response item being converted.
        messages: chat messages built so far; the last one may be mutated
            in place.

    Returns:
        The (mutated) last message when a merge happened, otherwise ``None``.
    """
    if not isinstance(item, ResponseFunctionToolCall):
        return None
    # `id` is optional on ResponseFunctionToolCall; a call without an id
    # cannot be an MCP call, so treat it as "nothing to combine" instead of
    # failing on `None.startswith`.
    if item.id is None or not item.id.startswith(MCP_PREFIX):
        return None
    if not messages:
        return None

    last_message = messages[-1]
    if (
        last_message.get("role") != "assistant"
        or last_message.get("reasoning") is None
    ):
        return None

    # NOTE(review): this replaces any "tool_calls" already present on the
    # message rather than appending to it - confirm that is intended when
    # several MCP calls follow one reasoning item.
    last_message["tool_calls"] = [
        ChatCompletionMessageToolCallParam(
            id=item.call_id,
            function=FunctionCallTool(
                name=item.name,
                arguments=item.arguments,
            ),
            type="function",
        )
    ]
    return last_message
def construct_chat_messages_with_tool_call(
    input_messages: list[ResponseInputOutputItem],
) -> list[ChatCompletionMessageParam]:
    """Convert response items into chat messages, merging where needed.

    A single chat message can originate from several response items (for
    example a reasoning item followed by an MCP tool call). Each item is
    first offered to ``_maybe_combine_reasoning_and_tool_call``; when that
    yields a message it replaces the last one, otherwise the item is
    converted on its own and appended.
    """
    chat_messages: list[ChatCompletionMessageParam] = []
    for response_item in input_messages:
        merged = _maybe_combine_reasoning_and_tool_call(response_item, chat_messages)
        if merged is None:
            chat_messages.append(
                _construct_single_message_from_response_item(response_item)
            )
            continue
        chat_messages[-1] = merged
    return chat_messages
def construct_chat_message_with_tool_call(
def _construct_single_message_from_response_item(
item: ResponseInputOutputItem,
) -> ChatCompletionMessageParam:
if isinstance(item, ResponseFunctionToolCall):
......@@ -146,3 +234,16 @@ def convert_tool_responses_to_completions_format(tool: dict) -> dict:
"type": "function",
"function": tool,
}
def construct_tool_dicts(
    tools: list[Tool] | None, tool_choice: ToolChoice
) -> list[dict[str, Any]] | None:
    """Convert Responses-API tools into completions-format tool dicts.

    Args:
        tools: tool definitions from the request, or ``None``.
        tool_choice: the request's tool choice; the literal ``"none"``
            disables tools altogether.

    Returns:
        ``None`` when ``tools`` is ``None`` or ``tool_choice == "none"``;
        otherwise one converted dict per tool, in input order.
    """
    if tools is None or tool_choice == "none":
        return None
    return [
        convert_tool_responses_to_completions_format(tool.model_dump())
        for tool in tools
    ]
......@@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
completion,
create_chat_completion,
create_completion,
health,
validate_json_request,
)
from vllm.entrypoints.openai.protocol import (
......@@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
score,
)
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
from vllm.entrypoints.serve.instrumentator.health import health
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
......
......@@ -89,12 +89,10 @@ def parse_score_data(
data_1: str | ScoreContentPartParam,
data_2: str | ScoreContentPartParam,
model_config: ModelConfig,
tokenizer: TokenizerLike,
) -> tuple[str, str, MultiModalDataDict | None]:
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
mm_tracker = MultiModalItemTracker(model_config)
content_1 = _parse_score_content(data_1, mm_tracker)
content_2 = _parse_score_content(data_2, mm_tracker)
def ensure_str(content: _ContentPart | None) -> str:
......@@ -188,7 +186,6 @@ def get_score_prompt(
data_1,
data_2,
model_config,
tokenizer,
)
from vllm.model_executor.model_loader import get_model_cls
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import FastAPI
def register_vllm_serve_api_routers(app: FastAPI) -> None:
    """Attach every vLLM serve API router to ``app``.

    Each router module is imported inside this function, right before its
    ``attach_router`` is called, so the modules are loaded only when this
    function runs. Routers are attached in the order written below.
    NOTE(review): the lazy imports presumably avoid import cycles or import
    cost at module load - confirm before hoisting them.
    """
    # LoRA
    from vllm.entrypoints.serve.lora.api_router import (
        attach_router as attach_lora_router,
    )

    attach_lora_router(app)

    # Elastic EP
    from vllm.entrypoints.serve.elastic_ep.api_router import (
        attach_router as attach_elastic_ep_router,
    )

    attach_elastic_ep_router(app)

    # Profiling
    from vllm.entrypoints.serve.profile.api_router import (
        attach_router as attach_profile_router,
    )

    attach_profile_router(app)

    # Sleep
    from vllm.entrypoints.serve.sleep.api_router import (
        attach_router as attach_sleep_router,
    )

    attach_sleep_router(app)

    # Tokenize
    from vllm.entrypoints.serve.tokenize.api_router import (
        attach_router as attach_tokenize_router,
    )

    attach_tokenize_router(app)

    # Disaggregated serving
    from vllm.entrypoints.serve.disagg.api_router import (
        attach_router as attach_disagg_router,
    )

    attach_disagg_router(app)

    # RLHF
    from vllm.entrypoints.serve.rlhf.api_router import (
        attach_router as attach_rlhf_router,
    )

    attach_rlhf_router(app)

    # Instrumentation: metrics
    from vllm.entrypoints.serve.instrumentator.metrics import (
        attach_router as attach_metrics_router,
    )

    attach_metrics_router(app)

    # Instrumentation: health
    from vllm.entrypoints.serve.instrumentator.health import (
        attach_router as attach_health_router,
    )

    attach_health_router(app)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
)
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
)
from vllm.entrypoints.serve.disagg.serving import (
ServingTokens,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the ``openai_serving_tokenization`` object from the app state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
def generate_tokens(request: Request) -> ServingTokens | None:
    """Return the ``serving_tokens`` object from the app state (may be None)."""
    app_state = request.app.state
    return app_state.serving_tokens
def engine_client(request: Request) -> EngineClient:
    """Return the ``engine_client`` object from the app state."""
    app_state = request.app.state
    return app_state.engine_client
router = APIRouter()
@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    """Handle ``POST /inference/v1/generate``.

    Looks up the token-serving handler on the app state and delegates to its
    ``serve_tokens``. The result is returned as:

    * an error response from the tokenization handler when no token-serving
      handler is configured;
    * a JSON response with the error's status code for ``ErrorResponse``;
    * a JSON response for ``GenerateResponse``;
    * an event-stream response for anything else.

    Raises:
        HTTPException: with status 500 if ``serve_tokens`` raises.
    """
    serving = generate_tokens(raw_request)
    if serving is None:
        return tokenization(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )

    try:
        result = await serving.serve_tokens(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        payload = result.model_dump()
        return JSONResponse(content=payload, status_code=result.error.code)
    if isinstance(result, GenerateResponse):
        return JSONResponse(content=result.model_dump())
    return StreamingResponse(content=result, media_type="text/event-stream")
def attach_router(app: FastAPI):
    """Include this module's router in ``app``.

    When ``app.state.args.tokens_only`` is truthy, an additional
    ``POST /abort_requests`` route is registered on the router first.
    """
    if getattr(app.state.args, "tokens_only", False):
        # The event loop keeps only weak references to tasks, so a task
        # whose result is discarded can be garbage-collected before it
        # finishes. Hold each abort task here until it is done.
        pending_aborts: set[asyncio.Task] = set()

        @router.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e
            # Valid JSON that is not an object (list, string, number...) has
            # no `.get`; report it as a client error instead of a 500.
            if not isinstance(body, dict):
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Request body must be a JSON object",
                )
            request_ids = body.get("request_ids")
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )
            # Abort requests in background
            abort_task = asyncio.create_task(
                engine_client(raw_request).abort(request_ids)
            )
            pending_aborts.add(abort_task)
            abort_task.add_done_callback(pending_aborts.discard)
            return Response(status_code=200)

    app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from pydantic import BaseModel, Field
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs,
Logprob,
SamplingParams,
StreamOptions,
)
from vllm.utils import random_uuid
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
    """Request model for token-in / token-out generation.

    The prompt is supplied as token ids (``token_ids``) together with the
    sampling parameters; the remaining fields are optional.
    """

    # Defaults to a freshly generated random uuid when not supplied.
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    token_ids: list[int]
    """The token ids to generate text from."""

    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None
    """The processed MM inputs for the model."""

    sampling_params: SamplingParams
    """The sampling parameters for the model."""

    model: str | None = None

    # Streaming is off unless requested.
    stream: bool | None = False
    stream_options: StreamOptions | None = None
    cache_salt: str | None = Field(
        default=None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )
    priority: int = Field(
        default=0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )
    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
class GenerateResponseChoice(BaseModel):
    """One choice inside a ``GenerateResponse``."""

    # Position of this choice in the response's ``choices`` list
    # (presumably - confirm against the producer).
    index: int
    logprobs: ChatCompletionLogProbs | None = None
    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"
    # Token ids for this choice; optional.
    token_ids: list[int] | None = None
class GenerateResponse(BaseModel):
    """Response model for token-in / token-out generation."""

    # Defaults to a freshly generated random uuid when not supplied.
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )
    choices: list[GenerateResponseChoice]

    # Optional; each entry is either a token-id -> Logprob mapping or None.
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        default=None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
......@@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ErrorResponse,
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
)
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment