Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8d75f22e
Commit
8d75f22e
authored
Dec 13, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc1' into v0.13.0rc1-ori
parents
ce888aa4
7d80c73d
Changes
679
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1608 additions
and
250 deletions
+1608
-250
vllm/entrypoints/openai/tool_parsers/__init__.py
vllm/entrypoints/openai/tool_parsers/__init__.py
+8
-0
vllm/entrypoints/openai/tool_parsers/deepseekv32_tool_parser.py
...ntrypoints/openai/tool_parsers/deepseekv32_tool_parser.py
+591
-0
vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py
.../entrypoints/openai/tool_parsers/gigachat3_tool_parser.py
+190
-0
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+395
-203
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
+1
-1
vllm/entrypoints/pooling/embed/api_router.py
vllm/entrypoints/pooling/embed/api_router.py
+2
-2
vllm/entrypoints/pooling/embed/protocol.py
vllm/entrypoints/pooling/embed/protocol.py
+2
-2
vllm/entrypoints/pooling/embed/serving.py
vllm/entrypoints/pooling/embed/serving.py
+20
-14
vllm/entrypoints/pooling/pooling/api_router.py
vllm/entrypoints/pooling/pooling/api_router.py
+2
-2
vllm/entrypoints/pooling/pooling/protocol.py
vllm/entrypoints/pooling/pooling/protocol.py
+2
-2
vllm/entrypoints/pooling/pooling/serving.py
vllm/entrypoints/pooling/pooling/serving.py
+22
-13
vllm/entrypoints/responses_utils.py
vllm/entrypoints/responses_utils.py
+104
-3
vllm/entrypoints/sagemaker/routes.py
vllm/entrypoints/sagemaker/routes.py
+1
-1
vllm/entrypoints/score_utils.py
vllm/entrypoints/score_utils.py
+1
-4
vllm/entrypoints/serve/__init__.py
vllm/entrypoints/serve/__init__.py
+60
-0
vllm/entrypoints/serve/disagg/__init__.py
vllm/entrypoints/serve/disagg/__init__.py
+0
-0
vllm/entrypoints/serve/disagg/api_router.py
vllm/entrypoints/serve/disagg/api_router.py
+110
-0
vllm/entrypoints/serve/disagg/protocol.py
vllm/entrypoints/serve/disagg/protocol.py
+90
-0
vllm/entrypoints/serve/disagg/serving.py
vllm/entrypoints/serve/disagg/serving.py
+7
-3
vllm/entrypoints/serve/elastic_ep/__init__.py
vllm/entrypoints/serve/elastic_ep/__init__.py
+0
-0
No files found.
Too many changes to show.
To preserve performance only
679 of 679+
files are displayed.
Plain diff
Email patch
vllm/entrypoints/openai/tool_parsers/__init__.py
View file @
8d75f22e
...
...
@@ -30,6 +30,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"deepseekv31_tool_parser"
,
"DeepSeekV31ToolParser"
,
),
"deepseek_v32"
:
(
"deepseekv32_tool_parser"
,
"DeepSeekV32ToolParser"
,
),
"ernie45"
:
(
"ernie45_tool_parser"
,
"Ernie45ToolParser"
,
...
...
@@ -130,6 +134,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"xlam_tool_parser"
,
"xLAMToolParser"
,
),
"gigachat3"
:
(
"gigachat3_tool_parser"
,
"GigaChat3ToolParser"
,
),
}
...
...
vllm/entrypoints/openai/tool_parsers/deepseekv32_tool_parser.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
uuid
from
collections.abc
import
Sequence
from
typing
import
Any
import
regex
as
re
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
,
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
)
from
vllm.logger
import
init_logger
from
vllm.tokenizers
import
TokenizerLike
# Module-level logger, created through vLLM's init_logger with this module's
# __name__.
logger = init_logger(__name__)
class
DeepSeekV32ToolParser
(
ToolParser
):
"""
example tool call content:
<|DSML|function_calls>
<|DSML|invoke name="get_weather">
<|DSML|parameter name="location" string="true">杭州</|DSML|parameter>
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
</|DSML|invoke>
<|DSML|invoke name="get_weather">
<|DSML|parameter name="location" string="true">北京</|DSML|parameter>
<|DSML|parameter name="date" string="true">2024-01-16</|DSML|parameter>
</|DSML|invoke>
</|DSML|function_calls>
"""
def __init__(self, tokenizer: TokenizerLike):
    """Set up sentinel tokens, streaming state and the parsing regexes.

    Args:
        tokenizer: Tokenizer forwarded unchanged to the base-class
            constructor.

    Raises:
        ValueError: If ``self.model_tokenizer`` is falsy after the base
            class has been initialised (presumably the base class stores
            the tokenizer there -- confirm against ``ToolParser``).
    """
    super().__init__(tokenizer)

    self.prev_tool_call_arr: list[dict] = []

    # Sentinel tokens
    self.dsml_token: str = "|DSML|"
    self.dsml_start_check: str = "<" + self.dsml_token
    self.tool_call_start_token: str = "<|DSML|function_calls>"
    self.tool_call_end_token: str = "</|DSML|function_calls>"
    self.invoke_start_prefix: str = "<|DSML|invoke name="
    self.invoke_end_token: str = "</|DSML|invoke>"
    self.parameter_prefix: str = "<|DSML|parameter name="
    self.parameter_end_token: str = "</|DSML|parameter>"

    # Streaming state variables
    self.current_tool_name_sent: bool = False
    # Override base class type - we use string IDs for tool calls
    self.current_tool_id: str | None = None  # type: ignore
    self.streamed_args_for_tool: list[str] = []
    self.is_tool_call_started: bool = False
    self.failed_count: int = 0

    # Initialize streaming state variables
    self.current_tool_index: int = 0
    self.invoke_index: int = 0
    self.header_sent: bool = False
    self.current_function_name: str | None = None
    self.current_param_name: str | None = None
    self.current_param_value: str = ""
    self.param_count: int = 0
    self.in_param: bool = False
    self.in_function: bool = False
    self.json_started: bool = False
    self.json_closed: bool = False
    self.accumulated_params: dict = {}
    self.streaming_request: ChatCompletionRequest | None = None

    # Enhanced streaming state - reset for each new message
    self._reset_streaming_state()

    # Regex patterns for complete parsing.
    # The DSML marker is escaped before being embedded: a bare "|" is the
    # regex alternation operator, which would split each pattern into
    # unrelated alternatives and make it match fragments such as a lone "<"
    # instead of a whole DSML block.
    dsml = re.escape(self.dsml_token)
    self.tool_call_complete_regex = re.compile(
        rf"<{dsml}function_calls>(.*?)</{dsml}function_calls>", re.DOTALL
    )
    self.invoke_complete_regex = re.compile(
        rf'<{dsml}invoke\s+name="([^"]+)"\s*>(.*?)</{dsml}invoke>', re.DOTALL
    )
    self.parameter_complete_regex = re.compile(
        rf'<{dsml}parameter\s+name="([^"]+)"\s+string="(?:true|false)"\s*>'
        rf"(.*?)</{dsml}parameter>",
        re.DOTALL,
    )

    if not self.model_tokenizer:
        raise ValueError(
            "The model tokenizer must be passed to the ToolParser "
            "constructor during construction."
        )

    logger.debug(
        "vLLM Successfully import tool parser %s !", self.__class__.__name__
    )
def
_generate_tool_call_id
(
self
)
->
str
:
"""Generate a unique tool call ID."""
return
f
"call_
{
uuid
.
uuid4
().
hex
[:
24
]
}
"
def
_reset_streaming_state
(
self
):
"""Reset all streaming state."""
self
.
current_tool_index
=
0
self
.
invoke_index
=
0
self
.
is_tool_call_started
=
False
self
.
header_sent
=
False
self
.
current_tool_id
=
None
self
.
current_function_name
=
None
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
param_count
=
0
self
.
in_param
=
False
self
.
in_function
=
False
self
.
json_started
=
False
self
.
json_closed
=
False
# Store accumulated parameters for type conversion
self
.
accumulated_params
=
{}
self
.
streaming_request
=
None
# Clear previous tool call history to avoid state pollution
self
.
prev_tool_call_arr
.
clear
()
def
_parse_invoke_params
(
self
,
invoke_str
:
str
)
->
dict
|
None
:
param_dict
=
dict
()
for
param_name
,
param_val
in
self
.
parameter_complete_regex
.
findall
(
invoke_str
):
param_dict
[
param_name
]
=
param_val
return
param_dict
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Parse tool calls out of a complete (non-streaming) model output.

    Every invoke found inside a complete function_calls block becomes one
    ``ToolCall`` whose arguments are the JSON-serialised parameter mapping
    (values are left as raw strings). Text preceding the first
    function_calls start token is returned as ``content``. If there is no
    start token, no complete invoke, or any error occurs, the whole output
    is returned as plain content with ``tools_called=False``.
    """

    def _as_plain_content() -> ExtractedToolCallInformation:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    start_token = self.tool_call_start_token

    # Quick check
    if start_token not in model_output:
        return _as_plain_content()

    try:
        # One ToolCall per invoke, across all complete function_calls blocks
        parsed_calls = [
            ToolCall(
                type="function",
                function=FunctionCall(
                    name=fn_name,
                    arguments=json.dumps(
                        self._parse_invoke_params(fn_body), ensure_ascii=False
                    ),
                ),
            )
            for block in self.tool_call_complete_regex.findall(model_output)
            for fn_name, fn_body in self.invoke_complete_regex.findall(block)
        ]

        if not parsed_calls:
            return _as_plain_content()

        # Keep whatever text came before the first tool call, if any
        split_at = model_output.find(start_token)
        leading_text = None if split_at <= 0 else model_output[:split_at]

        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=parsed_calls, content=leading_text
        )
    except Exception:
        logger.exception("Error extracting tool calls")
        return _as_plain_content()
def
_extract_name
(
self
,
name_str
:
str
)
->
str
:
"""Extract name from quoted string."""
name_str
=
name_str
.
strip
()
if
(
name_str
.
startswith
(
'"'
)
and
name_str
.
endswith
(
'"'
)
or
name_str
.
startswith
(
"'"
)
and
name_str
.
endswith
(
"'"
)
):
return
name_str
[
1
:
-
1
]
return
name_str
def
_extract_param_name
(
self
,
input_str
:
str
)
->
str
:
"""Extract param name"""
start
=
input_str
.
find
(
'"'
)
+
1
end
=
input_str
.
find
(
'"'
,
start
)
return
input_str
[
start
:
end
]
if
start
>
0
and
end
>
start
else
input_str
def
_convert_param_value
(
self
,
value
:
str
,
param_type
:
str
)
->
Any
:
"""Convert parameter value to the correct type."""
if
value
.
lower
()
==
"null"
:
return
None
param_type
=
param_type
.
lower
()
if
param_type
in
[
"string"
,
"str"
,
"text"
]:
return
value
elif
param_type
in
[
"integer"
,
"int"
]:
try
:
return
int
(
value
)
except
(
ValueError
,
TypeError
):
return
value
elif
param_type
in
[
"number"
,
"float"
]:
try
:
val
=
float
(
value
)
return
val
if
val
!=
int
(
val
)
else
int
(
val
)
except
(
ValueError
,
TypeError
):
return
value
elif
param_type
in
[
"boolean"
,
"bool"
]:
return
value
.
lower
()
in
[
"true"
,
"1"
]
elif
param_type
in
[
"object"
,
"array"
]:
try
:
return
json
.
loads
(
value
)
except
json
.
JSONDecodeError
:
return
value
else
:
# Try JSON parse first, fallback to string
try
:
return
json
.
loads
(
value
)
except
json
.
JSONDecodeError
:
return
value
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],  # pylint: disable=unused-argument
    current_token_ids: Sequence[int],  # pylint: disable=unused-argument
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Extract tool calls from streaming model output.

    Each call inspects the text generated so far and returns at most one
    delta: plain content, a tool call header (index, id, function name),
    the opening "{" of the arguments, one '"name": value' argument
    fragment, or the closing "}". ``None`` means there is nothing to emit
    for this step. Progress between calls is kept in instance attributes
    (``current_tool_index``, ``header_sent``, ``json_started``,
    ``param_count``, ``json_closed``, ``in_function``, ...).

    Args:
        previous_text: Text generated before this step; an empty value is
            treated as the start of a new message and resets all state.
        current_text: All text generated so far, including ``delta_text``.
        delta_text: Text added in this step.
        previous_token_ids: Not used by this method.
        current_token_ids: Not used by this method.
        delta_token_ids: Token ids added in this step; only checked for
            being non-empty when ``delta_text`` is empty.
        request: Request stored for looking up parameter types in its
            ``tools`` list.

    Returns:
        A ``DeltaMessage`` to stream, or ``None``.
    """
    # Store request for type conversion
    if not previous_text:
        self._reset_streaming_state()
        self.streaming_request = request

    # If no delta text, return None unless it's an EOS token after tools
    if not delta_text:
        # Check if this is an EOS token after all tool calls are complete
        if delta_token_ids:
            # Count complete tool calls
            complete_calls = len(
                self.tool_call_complete_regex.findall(current_text)
            )

            # If we have completed tool calls and populated prev_tool_call_arr
            if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
                # Check if all tool calls are closed
                open_calls = current_text.count(
                    self.tool_call_start_token
                ) - current_text.count(self.tool_call_end_token)
                if open_calls == 0:
                    # Return empty delta for finish_reason processing
                    return DeltaMessage(content="")
            elif not self.is_tool_call_started and current_text:
                # This is a regular content response that's now complete
                return DeltaMessage(content="")
        return None

    # Check if we need to advance to next tool
    # (json_closed and not in_function is the state left behind after the
    # closing "}" of the previous tool call was emitted further below.)
    if self.json_closed and not self.in_function:
        # Check if this tool call has ended
        invoke_ends = current_text.count(self.invoke_end_token)
        if invoke_ends > self.current_tool_index:
            # This tool has ended, advance to next
            self.current_tool_index += 1
            self.header_sent = False
            self.param_count = 0
            self.json_started = False
            self.json_closed = False
            self.in_function = False  # Now we can safely set this to False
            self.accumulated_params = {}
            # Continue processing next tool
            # (nothing is emitted for this delta; the next call picks up
            # the next invoke, if one has started)
            return None

    # Handle normal content before tool calls
    if not self.is_tool_call_started:
        # Check if tool call is starting
        if self.dsml_token in current_text:
            self.is_tool_call_started = True
            # Return any content before the tool call
            if self.dsml_start_check in delta_text:
                content_before = delta_text[
                    : delta_text.index(self.dsml_start_check)
                ]
                if content_before:
                    return DeltaMessage(content=content_before)
            return None
        else:
            # Check if we're between tool calls - skip whitespace
            if (
                current_text.rstrip().endswith(self.tool_call_end_token)
                and delta_text.strip() == ""
            ):
                # We just ended a tool call, skip whitespace
                return None
            # Normal content, no tool call
            # A trailing "<" may be the first character of a DSML tag, so it
            # is held back here and re-attached on the next call if no tool
            # call has started by then.
            if delta_text.endswith("<"):
                return DeltaMessage(content=delta_text[:-1])
            if previous_text and previous_text.endswith("<"):
                return DeltaMessage(content="<" + delta_text)
            return DeltaMessage(content=delta_text)

    # Check if we're between tool calls (waiting for next one)
    invoke_starts_count = current_text.count(self.invoke_start_prefix)
    if self.current_tool_index >= invoke_starts_count:
        # We're past all tool calls, shouldn't be here
        return None

    # Find the current tool call portion
    # (collect the offset of every invoke start prefix seen so far)
    invoke_start_positions: list[int] = []
    idx = 0
    while True:
        idx = current_text.find(self.invoke_start_prefix, idx)
        if idx == -1:
            break
        invoke_start_positions.append(idx)
        idx += len(self.invoke_start_prefix)

    if self.current_tool_index >= len(invoke_start_positions):
        # No more tool calls to process yet
        return None

    invoke_start_idx = invoke_start_positions[self.current_tool_index]
    # Find where this tool call ends (or current position if not ended yet)
    invoke_end_idx = current_text.find(self.invoke_end_token, invoke_start_idx)
    if invoke_end_idx == -1:
        tool_text = current_text[invoke_start_idx:]
    else:
        tool_text = current_text[
            invoke_start_idx : invoke_end_idx + len(self.invoke_end_token)
        ]

    # Looking for function header
    if not self.header_sent:
        if self.invoke_start_prefix in tool_text:
            func_start = tool_text.find(self.invoke_start_prefix) + len(
                self.invoke_start_prefix
            )
            # Find the end quote for the function name
            # (the search is actually for the ">" that closes the invoke
            # tag; the quotes are stripped by _extract_name below)
            func_end = tool_text.find(">", func_start)

            if func_end != -1:
                # Found complete function name
                function_name_raw = tool_text[func_start:func_end]
                self.current_function_name = self._extract_name(function_name_raw)
                self.current_tool_id = self._generate_tool_call_id()
                self.header_sent = True
                self.in_function = True

                # Add to prev_tool_call_arr immediately when we detect a tool call
                # Each tool call should be recorded regardless of function name
                # Ensure we don't add the same tool call index multiple times
                if len(self.prev_tool_call_arr) <= self.current_tool_index:
                    self.prev_tool_call_arr.append(
                        {
                            "name": self.current_function_name,
                            "arguments": "{}",  # Placeholder, will be updated later
                        }
                    )

                # Send header with function info
                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            id=self.current_tool_id,
                            function=DeltaFunctionCall(
                                name=self.current_function_name, arguments=""
                            ),
                            type="function",
                        )
                    ]
                )
        return None

    # We've sent header, now handle function body
    if self.in_function:
        # Send opening brace if not sent yet
        if self.in_function and not self.json_started:
            self.json_started = True
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_index,
                        function=DeltaFunctionCall(arguments="{"),
                    )
                ]
            )

        # Make sure json_started is set if we're processing parameters
        # NOTE(review): json_started is always True here because the branch
        # above returns whenever it is False, so this check looks redundant.
        if not self.json_started:
            self.json_started = True

        # Check for function end in accumulated text
        if not self.json_closed and self.invoke_end_token in tool_text:
            # Count total parameters in the tool text
            total_param_count = tool_text.count(self.parameter_prefix)

            # Only close JSON if all parameters have been processed
            if self.param_count >= total_param_count:
                # Close JSON
                self.json_closed = True

                # Extract complete tool call
                # Find the invoke content
                invoke_start = tool_text.find(self.invoke_start_prefix) + len(
                    self.invoke_start_prefix
                )
                invoke_content_end = tool_text.find(
                    self.invoke_end_token, invoke_start
                )
                if invoke_content_end != -1:
                    invoke_content = tool_text[invoke_start:invoke_content_end]
                    # Parse to get the complete arguments
                    # (values recorded here are the raw strings returned by
                    # _parse_invoke_params, whereas the fragments streamed to
                    # the client below are type-converted)
                    try:
                        invoke_params = self._parse_invoke_params(invoke_content)
                        if invoke_params and self.current_tool_index < len(
                            self.prev_tool_call_arr
                        ):
                            # Update existing entry in prev_tool_call_arr
                            self.prev_tool_call_arr[self.current_tool_index][
                                "arguments"
                            ] = json.dumps(invoke_params, ensure_ascii=False)
                    except Exception:
                        pass  # Ignore parsing errors during streaming

                result = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_index,
                            function=DeltaFunctionCall(arguments="}"),
                        )
                    ]
                )

                # Reset state for next tool
                self.json_closed = True
                self.in_function = False
                self.accumulated_params = {}

                # NOTE(review): the "[M2_STREAMING]" tag does not mention
                # this parser; it may have been carried over from another
                # parser -- confirm before relying on it in log searches.
                logger.debug("[M2_STREAMING] Tool call completed")

                return result
            else:
                # Don't close JSON yet, continue processing parameters
                # NOTE(review): this returns before the parameter handling
                # below can run. If the invoke end token is already present
                # while param_count is still behind (e.g. several parameters
                # arrived in one delta), later calls appear to hit this
                # branch again and the tool call would never be closed --
                # please verify whether falling through was intended.
                return None

        # Look for parameters
        # Find all parameter starts
        param_starts = []
        idx = 0
        while True:
            idx = tool_text.find(self.parameter_prefix, idx)
            if idx == -1:
                break
            param_starts.append(idx)
            idx += len(self.parameter_prefix)

        # Check if we should start a new parameter
        # NOTE(review): in_param is only ever assigned False in this class,
        # and the two length comparisons express the same condition.
        if (
            not self.in_param
            and self.param_count < len(param_starts)
            and len(param_starts) > self.param_count
        ):
            # Process the next parameter
            param_idx = param_starts[self.param_count]
            param_start = param_idx + len(self.parameter_prefix)
            remaining = tool_text[param_start:]

            if ">" in remaining:
                # We have the complete parameter name
                name_end = remaining.find(">")
                param_name_raw = remaining[:name_end]
                self.current_param_name = self._extract_param_name(param_name_raw)

                # Find the parameter value
                value_start = param_start + name_end + 1
                value_text = tool_text[value_start:]
                # One leading newline directly after the tag is dropped
                if value_text.startswith("\n"):
                    value_text = value_text[1:]

                # Find where this parameter ends
                param_end_idx = value_text.find(self.parameter_end_token)
                if param_end_idx == -1:
                    # No closing tag, look for next parameter or function end
                    next_param_idx = value_text.find(self.parameter_prefix)
                    func_end_idx = value_text.find(self.invoke_end_token)

                    if next_param_idx != -1 and (
                        func_end_idx == -1 or next_param_idx < func_end_idx
                    ):
                        param_end_idx = next_param_idx
                    elif func_end_idx != -1:
                        param_end_idx = func_end_idx
                    else:
                        # Neither found, check if tool call is complete
                        if self.invoke_end_token in tool_text:
                            # Tool call and parameter is complete
                            param_end_idx = len(value_text)
                        else:
                            # Still streaming, wait for more content
                            return None

                if param_end_idx != -1:
                    # Complete parameter found
                    param_value = value_text[:param_end_idx]
                    # One trailing newline before the closing tag is dropped
                    if param_value.endswith("\n"):
                        param_value = param_value[:-1]

                    # Store raw value for later processing
                    self.accumulated_params[self.current_param_name] = param_value

                    # Get parameter configuration for type conversion
                    # (the "properties" mapping of the first request tool
                    # whose function name matches the current function)
                    param_config = {}
                    if self.streaming_request and self.streaming_request.tools:
                        for tool in self.streaming_request.tools:
                            if (
                                hasattr(tool, "function")
                                and tool.function.name == self.current_function_name
                                and hasattr(tool.function, "parameters")
                            ):
                                params = tool.function.parameters
                                if (
                                    isinstance(params, dict)
                                    and "properties" in params
                                ):
                                    param_config = params["properties"]
                                break

                    # Get parameter type
                    # (defaults to "string" when the schema gives no type)
                    param_type = "string"
                    if (
                        self.current_param_name in param_config
                        and isinstance(param_config[self.current_param_name], dict)
                        and "type" in param_config[self.current_param_name]
                    ):
                        param_type = param_config[self.current_param_name]["type"]

                    # Convert param value to appropriate type
                    converted_value = self._convert_param_value(
                        param_value, param_type
                    )

                    # Build JSON fragment based on the converted type
                    # Use json.dumps to properly serialize the value
                    serialized_value = json.dumps(
                        converted_value, ensure_ascii=False
                    )

                    # The first parameter has no leading comma; later ones do.
                    # NOTE(review): the parameter name is inserted without
                    # JSON escaping.
                    if self.param_count == 0:
                        json_fragment = (
                            f'"{self.current_param_name}": {serialized_value}'
                        )
                    else:
                        json_fragment = (
                            f', "{self.current_param_name}": {serialized_value}'
                        )

                    self.param_count += 1

                    return DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_index,
                                function=DeltaFunctionCall(arguments=json_fragment),
                            )
                        ]
                    )

    return None
vllm/entrypoints/openai/tool_parsers/gigachat3_tool_parser.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
from
collections.abc
import
Sequence
import
regex
as
re
from
vllm.entrypoints.chat_utils
import
make_tool_call_id
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
,
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
ToolParser
from
vllm.logger
import
init_logger
from
vllm.tokenizers
import
TokenizerLike
# Module-level logger, created through vLLM's init_logger with this module's
# __name__.
logger = init_logger(__name__)

# Matches the literal text "function call", optionally followed by
# "<|role_sep|>" and a newline, then captures everything from the "{" that
# follows through the end of the input (DOTALL lets "." span newlines).
REGEX_FUNCTION_CALL = re.compile(
    r"function call(?:<\|role_sep\|>\n)?(\{.*)",
    re.DOTALL,
)

# Captures the string value of a "name" key; the value may be empty but
# cannot contain a double quote.
NAME_REGEX = re.compile(
    r'"name"\s*:\s*"([^"]*)"',
    re.DOTALL,
)

# Captures everything after an "arguments" key and its colon, through the end
# of the input.
ARGS_REGEX = re.compile(
    r'"arguments"\s*:\s*(.*)',
    re.DOTALL,
)
class
GigaChat3ToolParser
(
ToolParser
):
def __init__(self, tokenizer: TokenizerLike):
    """Initialise the parser and its streaming bookkeeping.

    Args:
        tokenizer: Tokenizer forwarded unchanged to the base-class
            constructor.
    """
    super().__init__(tokenizer)

    # Literal text that marks the beginning of a tool call
    self.trigger_start = "function call{"

    # Text held back while it could still turn out to be the trigger
    self.content_buffer: str = ""

    # Tool call state; this parser only ever uses entry 0 of the list
    self.prev_tool_call_arr: list[dict] = []
    self.tool_id: str | None = None
    self.tool_name_sent: bool = False
    self.tool_started: bool = False
def extract_tool_calls(
    self,
    model_output: str,
    request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
    """Parse a single tool call out of a complete model output.

    Whatever ``REGEX_FUNCTION_CALL`` captures must be valid JSON holding
    an object with both "name" and "arguments" keys. If any of that fails,
    the whole output is returned as plain content with
    ``tools_called=False``. Non-string arguments are JSON-serialised. Text
    before the marker is returned right-stripped as ``content``, or
    ``None`` when it is empty or whitespace only.
    """

    def _as_plain_content() -> ExtractedToolCallInformation:
        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )

    found = REGEX_FUNCTION_CALL.search(model_output)
    if not found:
        return _as_plain_content()

    try:
        payload = json.loads(found.group(1).strip())
    except json.JSONDecodeError:
        return _as_plain_content()

    if not isinstance(payload, dict):
        return _as_plain_content()
    if "name" not in payload or "arguments" not in payload:
        return _as_plain_content()

    raw_args = payload["arguments"]
    if isinstance(raw_args, str):
        serialized_args = raw_args
    else:
        serialized_args = json.dumps(raw_args, ensure_ascii=False)

    call = ToolCall(
        type="function",
        function=FunctionCall(
            name=payload["name"],
            arguments=serialized_args,
        ),
    )

    # Text ahead of the marker; whitespace-only text collapses to None
    leading_text = model_output[: found.start()].rstrip() or None

    return ExtractedToolCallInformation(
        tools_called=True,
        tool_calls=[call],
        content=leading_text,
    )
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> DeltaMessage | None:
    """Incrementally extract a single tool call from streamed output.

    Until the tool call marker shows up in ``current_text``, deltas are
    buffered while the buffer could still be the beginning of the trigger
    text and flushed as plain content once it cannot. After the marker has
    been seen, the function name is sent first (together with a new tool
    call id), followed by the not-yet-sent suffix of the arguments text on
    each later call. ``None`` means there is nothing to emit for this step.
    """
    call_match = REGEX_FUNCTION_CALL.search(current_text)

    if not self.tool_started:
        if call_match is None:
            # Hold text back while it may still grow into the trigger
            self.content_buffer += delta_text
            trimmed = self.content_buffer.lstrip()
            may_be_trigger = self.trigger_start.startswith(
                trimmed
            ) or trimmed.startswith(self.trigger_start)
            if may_be_trigger:
                return None
            pending, self.content_buffer = self.content_buffer, ""
            return DeltaMessage(content=pending)
        self.tool_started = True
        self.content_buffer = ""
    elif call_match is None:
        return None

    json_tail = call_match.group(1).strip()

    name_hit = NAME_REGEX.search(json_tail)
    func_name = name_hit.group(1) if name_hit else None

    cur_args = None
    args_hit = ARGS_REGEX.search(json_tail)
    if args_hit:
        cur_args = args_hit.group(1).strip()
        if cur_args.endswith("}"):
            # A trailing '}' may close the enclosing object rather than the
            # arguments; drop it only when the remainder is valid JSON.
            without_brace = cur_args[:-1].strip()
            try:
                json.loads(without_brace)
            except json.JSONDecodeError:
                pass
            else:
                cur_args = without_brace

    if not self.prev_tool_call_arr:
        self.prev_tool_call_arr.append({})
    tracked_call = self.prev_tool_call_arr[0]

    if not self.tool_name_sent:
        if not func_name:
            return None
        self.tool_name_sent = True
        self.tool_id = make_tool_call_id()
        tracked_call["name"] = func_name
        header = DeltaToolCall(
            index=0,
            id=self.tool_id,
            type="function",
            function=DeltaFunctionCall(
                name=func_name,
            ).model_dump(exclude_none=True),
        )
        return DeltaMessage(tool_calls=[header], content=None)

    if cur_args is None:
        return None

    already_sent = tracked_call.get("arguments", "")
    if not already_sent:
        delta_args = cur_args
    elif cur_args.startswith(already_sent):
        delta_args = cur_args[len(already_sent) :]
    else:
        # The new text does not extend what was already sent; emit nothing
        return None

    if not delta_args:
        return None

    tracked_call["arguments"] = cur_args
    fragment = DeltaToolCall(
        index=0,
        function=DeltaFunctionCall(
            arguments=delta_args,
        ).model_dump(exclude_none=True),
    )
    return DeltaMessage(tool_calls=[fragment], content=None)
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
8d75f22e
...
...
@@ -3,12 +3,12 @@
import
json
from
collections.abc
import
Sequence
from
enum
import
Enum
,
auto
from
random
import
choices
from
string
import
ascii_letters
,
digits
import
partial_json_parser
import
ijson
import
regex
as
re
from
partial_json_parser.core.options
import
Allow
from
pydantic
import
Field
from
vllm.entrypoints.openai.protocol
import
(
...
...
@@ -23,7 +23,6 @@ from vllm.entrypoints.openai.protocol import (
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
)
from
vllm.entrypoints.openai.tool_parsers.utils
import
extract_intermediate_diff
from
vllm.logger
import
init_logger
from
vllm.tokenizers
import
MistralTokenizer
,
TokenizerLike
...
...
@@ -32,6 +31,22 @@ logger = init_logger(__name__)
ALPHANUMERIC
=
ascii_letters
+
digits
class StreamingState(Enum):
    """Enum for tracking the current streaming parsing state."""

    # Member values are generated by auto() in declaration order.
    # NOTE(review): the code that moves between these states is not visible
    # in this hunk, so each member's meaning should be confirmed against the
    # parser that uses it.
    WAITING_FOR_TOOL_START = auto()
    WAITING_FOR_TOOL_KEY = (
        auto()
    )  # waiting for the "name" or "arguments" key to be complete
    PARSING_NAME = auto()
    PARSING_NAME_COMPLETED = auto()
    WAITING_FOR_ARGUMENTS_START = auto()
    PARSING_ARGUMENTS = auto()
    PARSING_ARGUMENTS_COMPLETED = auto()
    TOOL_COMPLETE = auto()
    ALL_TOOLS_COMPLETE = auto()
class
MistralToolCall
(
ToolCall
):
id
:
str
=
Field
(
default_factory
=
lambda
:
MistralToolCall
.
generate_random_id
())
...
...
@@ -46,8 +61,8 @@ class MistralToolCall(ToolCall):
return
id
.
isalnum
()
and
len
(
id
)
==
9
def
_is_
fn_name_regex_support
(
model_tokenizer
:
TokenizerLike
)
->
bool
:
return
(
def _is_pre_v11_tokeniser(model_tokenizer: TokenizerLike) -> bool:
    """Return False only for a ``MistralTokenizer`` whose version is >= 11.

    Every other tokenizer, including non-Mistral ones, yields True.
    """
    if not isinstance(model_tokenizer, MistralTokenizer):
        return True
    return not (model_tokenizer.version >= 11)
...
...
@@ -69,21 +84,22 @@ class MistralToolParser(ToolParser):
# initialize properties used for state when parsing tool calls in
# streaming mode
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
self
.
current_tool_id
:
int
=
-
1
self
.
current_tool_name_sent
:
bool
=
False
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
# map what has been streamed for each tool so far to a list
self
.
streaming_state
:
StreamingState
=
StreamingState
.
WAITING_FOR_TOOL_START
# For streaming pre v11 tokenizer tool calls
self
.
current_tool_name
:
str
|
None
=
None
self
.
current_tool_mistral_id
:
str
|
None
=
None
self
.
starting_new_tool
=
False
if
_is_pre_v11_tokeniser
(
self
.
model_tokenizer
):
self
.
parse_coro
=
ijson
.
parse_coro
(
self
.
update_stream_state_pre_v11_tokenizer
()
)
self
.
bot_token
=
"[TOOL_CALLS]"
self
.
bot_token_id
=
self
.
vocab
.
get
(
self
.
bot_token
)
self
.
tool_call_regex
=
re
.
compile
(
r
"\[{.*}\]"
,
re
.
DOTALL
)
if
_is_fn_name_regex_support
(
self
.
model_tokenizer
):
self
.
fn_name_regex
=
re
.
compile
(
r
"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)"
,
re
.
DOTALL
)
else
:
self
.
fn_name_regex
=
None
self
.
_is_pre_v11
=
_is_pre_v11_tokeniser
(
self
.
model_tokenizer
)
if
self
.
bot_token_id
is
None
:
raise
RuntimeError
(
...
...
@@ -127,16 +143,18 @@ class MistralToolParser(ToolParser):
tool_content
=
model_output
.
replace
(
self
.
bot_token
,
""
).
strip
()
try
:
# we first try to directly load the json as parsing very nested
# jsons is difficult
try
:
if
self
.
fn_name_regex
:
matches
=
self
.
fn_name_regex
.
findall
(
tool_content
)
if
not
self
.
_is_pre_v11
:
function_call_arr
=
[]
for
match
in
matches
:
fn_name
=
match
[
0
]
args
=
match
[
1
]
for
single_tool_content
in
model_output
.
split
(
self
.
bot_token
):
if
"{"
not
in
single_tool_content
:
continue
end_name
=
single_tool_content
.
find
(
"{"
)
fn_name
,
args
=
(
single_tool_content
[:
end_name
],
single_tool_content
[
end_name
:],
)
# fn_name is encoded outside serialized json dump
# only arguments are serialized
...
...
@@ -193,198 +211,372 @@ class MistralToolParser(ToolParser):
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
DeltaMessage
|
None
:
#
if
the tool call
token
is
not in
the tokens generated so far, append
# output to contents since it's not a tool
if
self
.
bot_token
not
in
current_text
:
if
self
.
bot_
token
_id
not
in
current_token_ids
:
# if the tool call token is not in the tokens generated so far,
# append output to contents since it's not a tool
return
DeltaMessage
(
content
=
delta_text
)
# if the tool call token
ID
IS in the tokens generated so far, that
# if the tool call token IS in the tokens generated so far, that
# means we're parsing as tool calls now
# handle if we detected the BOT token which means the start of tool
# calling
if
self
.
bot_token_id
in
delta_token_ids
and
len
(
delta_token_ids
)
==
1
:
# if it's the only token, return None, so we don't send a chat
# completion any don't send a control token
return
None
# bit mask flags for partial JSON parsing. If the name hasn't been
# sent yet, don't allow sending
# an incomplete string since OpenAI only ever (as far as I have
# seen) allows sending the entire tool/ function name at once.
flags
=
Allow
.
ALL
if
self
.
current_tool_name_sent
else
Allow
.
ALL
&
~
Allow
.
STR
try
:
# replace BOT token with empty string, and convert single quotes
# to double to allow parsing as JSON since mistral uses single
# quotes instead of double for tool calls
parsable_arr
=
current_text
.
split
(
self
.
bot_token
)[
-
1
]
# tool calls are generated in an array, so do partial JSON
# parsing on the entire array
try
:
tool_call_arr
:
list
[
dict
]
=
partial_json_parser
.
loads
(
parsable_arr
,
flags
if
_is_pre_v11_tokeniser
(
self
.
model_tokenizer
):
return
self
.
_extract_tool_calls_streaming_pre_v11_tokenizer
(
delta_text
=
delta_text
,
delta_token_ids
=
delta_token_ids
,
)
except
partial_json_parser
.
core
.
exceptions
.
MalformedJSON
:
logger
.
debug
(
"not enough tokens to parse into JSON yet"
)
else
:
return
self
.
_extract_tool_calls_streaming
(
delta_text
=
delta_text
,
delta_token_ids
=
delta_token_ids
)
except
Exception
:
logger
.
exception
(
"Error trying to handle streaming tool call."
)
return
None
# select as the current tool call the one we're on the state at
current_tool_call
:
dict
=
(
tool_call_arr
[
self
.
current_tool_id
]
if
len
(
tool_call_arr
)
>
0
else
{}
def
_extract_tool_calls_streaming
(
self
,
delta_text
:
str
,
delta_token_ids
:
Sequence
[
int
],
)
->
DeltaMessage
|
None
:
"""
Extracts tool calls for Mistral models
doing tool calls of the following format:
`[TOOL_CALLS]add{"a": 3.5, "b": 4}`
"""
additional_content
:
str
=
""
if
self
.
streaming_state
==
StreamingState
.
WAITING_FOR_TOOL_START
:
# this is the first tool call
assert
self
.
bot_token_id
in
delta_token_ids
if
not
delta_text
.
startswith
(
self
.
bot_token
):
additional_content
+=
delta_text
.
split
(
self
.
bot_token
)[
0
]
delta_text
=
self
.
bot_token
+
""
.
join
(
delta_text
.
split
(
self
.
bot_token
)[
1
:]
)
# case -- if no tokens have been streamed for the tool, e.g.
# only the array brackets, stream nothing
if
len
(
tool_call_arr
)
==
0
:
delta_tool_calls
=
self
.
_generate_delta_tool_call
(
delta_text
)
if
not
additional_content
and
len
(
delta_tool_calls
)
==
0
:
if
self
.
streaming_state
in
[
StreamingState
.
PARSING_ARGUMENTS
,
StreamingState
.
PARSING_ARGUMENTS_COMPLETED
,
StreamingState
.
TOOL_COMPLETE
,
StreamingState
.
ALL_TOOLS_COMPLETE
,
]:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return
DeltaMessage
()
else
:
# return None when the tool is not likely to be finished
# This can occur when the name is being parsed for example
# and we wait for the name to be complete
# before sending the function name
return
None
# case: we are starting a new tool in the array
# -> array has > 0 length AND length has moved past cursor
elif
(
len
(
tool_call_arr
)
>
0
and
len
(
tool_call_arr
)
>
self
.
current_tool_id
+
1
):
# if we're moving on to a new call, first make sure we
# haven't missed anything in the previous one that was
# auto-generated due to JSON completions, but wasn't
# streamed to the client yet.
if
self
.
current_tool_id
>=
0
:
diff
:
str
|
None
=
current_tool_call
.
get
(
"arguments"
)
if
diff
:
diff
=
json
.
dumps
(
diff
,
ensure_ascii
=
False
).
replace
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
],
""
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
diff
).
model_dump
(
exclude_none
=
True
),
)
]
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
diff
else
:
delta
=
None
else
:
delta
=
None
# re-set stuff pertaining to progress in the current tool
self
.
current_tool_id
=
len
(
tool_call_arr
)
-
1
self
.
current_tool_name_sent
=
False
self
.
streamed_args_for_tool
.
append
(
""
)
logger
.
debug
(
"starting on new tool %d"
,
self
.
current_tool_id
)
delta
=
DeltaMessage
()
if
additional_content
:
delta
.
content
=
additional_content
if
len
(
delta_tool_calls
)
>
0
:
delta
.
tool_calls
=
delta_tool_calls
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining its final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if
delta_tool_calls
and
not
self
.
prev_tool_call_arr
:
self
.
prev_tool_call_arr
=
[{
"arguments"
:
{}}]
return
delta
# case: update an existing tool - this is handled below
# if the current tool name hasn't been sent, send if available
# - otherwise send nothing
if
not
self
.
current_tool_name_sent
:
function_name
=
current_tool_call
.
get
(
"name"
)
if
function_name
:
delta
=
DeltaMessage
(
tool_calls
=
[
def
_generate_delta_tool_call
(
self
,
delta_text
:
str
)
->
list
[
DeltaToolCall
]:
if
delta_text
==
""
or
delta_text
is
None
:
return
[]
delta_function_name
=
None
tool_id
=
None
if
self
.
streaming_state
not
in
[
StreamingState
.
PARSING_NAME
,
StreamingState
.
PARSING_ARGUMENTS
,
]
and
delta_text
.
startswith
(
self
.
bot_token
):
self
.
current_tool_id
+=
1
self
.
streaming_state
=
StreamingState
.
PARSING_NAME
delta_text
=
delta_text
.
replace
(
self
.
bot_token
,
""
,
1
)
if
self
.
streaming_state
==
StreamingState
.
PARSING_NAME
:
if
self
.
current_tool_name
is
None
:
self
.
current_tool_name
=
""
# The name stops where the arguments start
# And the arguments start with the `{` char
if
"{"
in
delta_text
:
tool_id
=
MistralToolCall
.
generate_random_id
()
delta_function_name
=
delta_text
.
split
(
"{"
)[
0
]
self
.
current_tool_name
+=
delta_function_name
delta_text
=
delta_text
[
len
(
delta_function_name
)
:]
self
.
streaming_state
=
StreamingState
.
PARSING_ARGUMENTS
else
:
# we want to send the tool name once it's complete
self
.
current_tool_name
+=
delta_text
return
[]
if
self
.
streaming_state
==
StreamingState
.
PARSING_ARGUMENTS
:
next_function_text
=
None
if
self
.
bot_token
in
delta_text
:
# current tool call is over
delta_arguments
=
""
delta_arguments
+=
delta_text
.
split
(
self
.
bot_token
)[
0
]
next_function_text
=
delta_text
[
len
(
delta_arguments
)
:]
self
.
streaming_state
=
StreamingState
.
TOOL_COMPLETE
else
:
delta_arguments
=
delta_text
ret
=
[]
if
self
.
current_tool_name
or
delta_arguments
:
ret
+=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
id
=
MistralToolCall
.
generate_random
_id
()
,
id
=
tool
_id
,
function
=
DeltaFunctionCall
(
name
=
function_name
name
=
self
.
current_tool_name
,
arguments
=
delta_arguments
).
model_dump
(
exclude_none
=
True
),
)
]
self
.
current_tool_name
=
None
if
next_function_text
:
ret
+=
self
.
_generate_delta_tool_call
(
next_function_text
)
return
ret
# Should not happen
return
[]
@ijson.coroutine
def update_stream_state_pre_v11_tokenizer(self):
    """Coroutine sink that maps parse events onto ``self.streaming_state``.

    Each value sent into the coroutine is unpacked as a
    ``(prefix, event, value)`` triple. Depending on the prefix/event pair
    (and, for map keys, the key itself) the parser state is moved to the
    matching ``StreamingState`` member. When the ``item.name`` string is
    received it is also stored in ``self.current_tool_name``.

    Triples that match none of the handled combinations leave the state
    untouched. The loop never terminates on its own.
    """
    while True:
        prefix, event, value = yield

        if prefix == "item":
            # Events emitted for one tool-call object of the top-level array.
            if event == "start_map":
                self.streaming_state = StreamingState.WAITING_FOR_TOOL_KEY
            elif event == "map_key":
                if value == "name":
                    self.streaming_state = StreamingState.PARSING_NAME
                elif value == "arguments":
                    self.streaming_state = (
                        StreamingState.WAITING_FOR_ARGUMENTS_START
                    )
            elif event == "end_map":
                self.streaming_state = StreamingState.TOOL_COMPLETE
        elif prefix == "item.name":
            if event == "string":
                # The whole name arrives as a single string event.
                self.current_tool_name = value
                self.streaming_state = StreamingState.PARSING_NAME_COMPLETED
        elif prefix == "item.arguments":
            if event == "start_map":
                self.streaming_state = StreamingState.PARSING_ARGUMENTS
            elif event == "end_map":
                self.streaming_state = (
                    StreamingState.PARSING_ARGUMENTS_COMPLETED
                )
        elif prefix == "" and event == "end_array":
            # Closing bracket of the top-level array of tool calls.
            self.streaming_state = StreamingState.ALL_TOOLS_COMPLETE
def
_extract_tool_calls_streaming_pre_v11_tokenizer
(
self
,
delta_text
:
str
,
delta_token_ids
:
Sequence
[
int
],
)
->
DeltaMessage
|
None
:
"""
Extracts tool calls for Mistral models
doing tool calls of the following format:
`[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}`
"""
assert
self
.
parse_coro
is
not
None
content
=
None
delta_tool_calls
:
list
[
DeltaToolCall
]
=
[]
current_tool_call
:
DeltaToolCall
=
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
)
self
.
current_tool_name_sent
=
True
else
:
delta
=
None
# now we know we're on the same tool call and we're streaming
# arguments
else
:
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
new_text
=
delta_text
.
replace
(
"'"
,
'"'
)
if
'"}'
in
new_text
:
new_text
=
new_text
[:
new_text
.
rindex
(
'"}'
)]
if
not
cur_arguments
and
not
prev_arguments
:
delta
=
None
elif
not
cur_arguments
and
prev_arguments
:
logger
.
error
(
"INVARIANT - impossible to have arguments reset mid-arguments"
current_tool_call_modified
=
False
if
self
.
bot_token_id
in
delta_token_ids
:
# this is the first tool call
if
not
delta_text
.
startswith
(
self
.
bot_token
):
content
=
delta_text
.
split
(
self
.
bot_token
)[
0
]
delta_text
=
""
.
join
(
delta_text
.
split
(
self
.
bot_token
)[
1
:])
# Cut smartly the delta text to catch the ijson events
# as ijson does not give us the index in the text at each event.
# We need to cut so that we know
# where in the text the events are emitted from.
while
len
(
delta_text
)
>
0
:
streaming_state_before_parse
=
self
.
streaming_state
if
self
.
streaming_state
==
StreamingState
.
WAITING_FOR_TOOL_START
:
delta_to_be_parsed
,
delta_text
=
self
.
_split_delta
(
delta_text
=
delta_text
,
stop_after_opening_curly_braces
=
1
,
)
delta
=
None
elif
cur_arguments
and
not
prev_arguments
:
cur_arguments_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)[
:
-
2
]
logger
.
debug
(
"finding %s in %s"
,
new_text
,
cur_arguments_json
)
if
new_text
not
in
cur_arguments_json
:
return
None
arguments_delta
=
cur_arguments_json
[
:
cur_arguments_json
.
rindex
(
new_text
)
+
len
(
new_text
)
]
logger
.
debug
(
"First tokens in arguments received: %s"
,
arguments_delta
elif
self
.
streaming_state
==
StreamingState
.
WAITING_FOR_TOOL_KEY
:
# Wait until another key is sent
# or the current tool is completed
delta_to_be_parsed
,
delta_text
=
self
.
_split_delta
(
delta_text
=
delta_text
,
stop_after_colon
=
1
,
stop_after_opening_curly_braces
=
1
,
# if the tool ends, we want to separate
# at the start of the next tool
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
arguments_delta
).
model_dump
(
exclude_none
=
True
),
elif
self
.
streaming_state
==
StreamingState
.
PARSING_NAME
:
delta_to_be_parsed
,
delta_text
=
self
.
_split_delta
(
delta_text
=
delta_text
,
stop_after_comma
=
1
,
stop_after_closing_brackets
=
1
,
)
]
elif
self
.
streaming_state
==
StreamingState
.
WAITING_FOR_ARGUMENTS_START
:
delta_to_be_parsed
,
delta_text
=
self
.
_split_delta
(
delta_text
=
delta_text
,
stop_after_opening_curly_braces
=
1
,
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
arguments_delta
elif
cur_arguments
and
prev_arguments
:
cur_args_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
prev_args_json
=
json
.
dumps
(
prev_arguments
,
ensure_ascii
=
False
)
logger
.
debug
(
"Searching for diff between
\n
%s
\n
%s"
,
cur_args_json
,
prev_args_json
,
elif
self
.
streaming_state
==
StreamingState
.
PARSING_ARGUMENTS
:
delta_to_be_parsed
,
delta_text
=
self
.
_split_delta
(
delta_text
=
delta_text
,
stop_after_closing_curly_braces
=
1
,
# we could be more clever
# by listening to item.arguments.* start_map events
# and know how many curly braces we can allow
)
argument_diff
=
extract_intermediate_diff
(
cur_args_json
,
prev_args_json
elif
self
.
streaming_state
in
[
StreamingState
.
PARSING_ARGUMENTS_COMPLETED
,
StreamingState
.
PARSING_NAME_COMPLETED
,
]:
delta_to_be_parsed
,
delta_text
=
self
.
_split_delta
(
delta_text
=
delta_text
,
stop_after_closing_curly_braces
=
1
,
stop_after_closing_brackets
=
1
,
)
logger
.
debug
(
"got arguments diff: %s"
,
argument_diff
)
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
arguments
=
argument_diff
).
model_dump
(
exclude_none
=
True
),
elif
self
.
streaming_state
==
StreamingState
.
TOOL_COMPLETE
:
delta_to_be_parsed
,
delta_text
=
self
.
_split_delta
(
delta_text
=
delta_text
,
stop_after_opening_curly_braces
=
1
,
stop_after_closing_brackets
=
1
,
)
elif
self
.
streaming_state
==
StreamingState
.
ALL_TOOLS_COMPLETE
:
content
=
delta_text
delta_text
=
""
else
:
delta_to_be_parsed
=
delta_text
delta_text
=
""
if
self
.
streaming_state
!=
StreamingState
.
ALL_TOOLS_COMPLETE
:
self
.
parse_coro
.
send
(
delta_to_be_parsed
.
encode
(
"utf-8"
))
# Given the parsed text and the possible streaming state change,
# let's add to the tool delta
if
(
(
streaming_state_before_parse
!=
self
.
streaming_state
)
and
streaming_state_before_parse
in
[
StreamingState
.
WAITING_FOR_TOOL_START
,
StreamingState
.
TOOL_COMPLETE
]
and
self
.
streaming_state
not
in
[
StreamingState
.
ALL_TOOLS_COMPLETE
,
StreamingState
.
TOOL_COMPLETE
,
StreamingState
.
WAITING_FOR_TOOL_START
,
]
):
# starting a new tool call
if
current_tool_call_modified
:
if
self
.
current_tool_mistral_id
is
not
None
:
current_tool_call
.
id
=
self
.
current_tool_mistral_id
self
.
current_tool_mistral_id
=
None
delta_tool_calls
.
append
(
current_tool_call
)
current_tool_call_modified
=
False
self
.
current_tool_id
+=
1
self
.
current_tool_mistral_id
=
MistralToolCall
.
generate_random_id
()
current_tool_call
=
DeltaToolCall
(
index
=
self
.
current_tool_id
,
type
=
"function"
,
)
self
.
streamed_args_for_tool
[
self
.
current_tool_id
]
+=
argument_diff
if
current_tool_call
.
function
is
None
:
current_tool_call
.
function
=
DeltaFunctionCall
()
if
self
.
current_tool_name
is
not
None
:
# we have the complete tool name
current_tool_call_modified
=
True
current_tool_call
.
function
.
name
=
self
.
current_tool_name
self
.
current_tool_name
=
None
if
self
.
streaming_state
==
StreamingState
.
PARSING_NAME_COMPLETED
:
self
.
streaming_state
=
StreamingState
.
WAITING_FOR_TOOL_KEY
if
self
.
streaming_state
in
[
StreamingState
.
PARSING_ARGUMENTS
,
StreamingState
.
PARSING_ARGUMENTS_COMPLETED
,
]:
if
self
.
streaming_state
==
StreamingState
.
PARSING_ARGUMENTS_COMPLETED
:
self
.
streaming_state
=
StreamingState
.
WAITING_FOR_TOOL_KEY
# the delta_to_be_parsed is part of arguments.
current_tool_call_modified
=
True
if
current_tool_call
.
function
.
arguments
is
None
:
current_tool_call
.
function
.
arguments
=
delta_to_be_parsed
else
:
# try parsing it with regular JSON - if it works we're
# at the end, and we need to send the difference between
# tokens streamed so far and the valid JSON
delta
=
None
# check to see if the name is defined and has been sent. if so,
# stream the name - otherwise keep waiting
# finish by setting old and returning None as base case
self
.
prev_tool_call_arr
=
tool_call_arr
return
delta
except
Exception
:
logger
.
exception
(
"Error trying to handle streaming tool call."
)
logger
.
debug
(
"Skipping chunk as a result of tool streaming extraction error"
current_tool_call
.
function
.
arguments
+=
delta_to_be_parsed
if
streaming_state_before_parse
!=
StreamingState
.
PARSING_ARGUMENTS
:
# It's the first chunk of arg. let's lstrip it
current_tool_call
.
function
.
arguments
=
(
current_tool_call
.
function
.
arguments
.
lstrip
()
)
if
current_tool_call_modified
:
if
self
.
current_tool_mistral_id
is
not
None
:
current_tool_call
.
id
=
self
.
current_tool_mistral_id
self
.
current_tool_mistral_id
=
None
delta_tool_calls
.
append
(
current_tool_call
)
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if
delta_tool_calls
and
not
self
.
prev_tool_call_arr
:
self
.
prev_tool_call_arr
=
[{
"arguments"
:
{}}]
if
content
or
len
(
delta_tool_calls
)
>
0
:
delta_message
=
DeltaMessage
()
if
content
:
delta_message
.
content
=
content
if
len
(
delta_tool_calls
)
>
0
:
delta_message
.
tool_calls
=
delta_tool_calls
return
delta_message
else
:
if
self
.
streaming_state
==
StreamingState
.
ALL_TOOLS_COMPLETE
:
return
DeltaMessage
()
else
:
return
None
def
_split_delta
(
self
,
delta_text
:
str
,
stop_after_quotes
:
int
=
-
1
,
stop_after_opening_curly_braces
:
int
=
-
1
,
stop_after_closing_curly_braces
:
int
=
-
1
,
stop_after_closing_brackets
:
int
=
-
1
,
stop_after_colon
:
int
=
-
1
,
stop_after_comma
=-
1
,
)
->
tuple
[
str
,
str
]:
delta_to_be_parsed
=
""
for
i
,
c
in
enumerate
(
delta_text
):
if
c
in
[
'"'
,
"'"
]:
delta_to_be_parsed
+=
c
stop_after_quotes
-=
1
if
stop_after_quotes
==
0
:
return
(
delta_to_be_parsed
,
delta_text
[
i
+
1
:])
elif
c
==
"{"
:
delta_to_be_parsed
+=
c
stop_after_opening_curly_braces
-=
1
if
stop_after_opening_curly_braces
==
0
:
return
(
delta_to_be_parsed
,
delta_text
[
i
+
1
:])
elif
c
==
"}"
:
delta_to_be_parsed
+=
c
stop_after_closing_curly_braces
-=
1
if
stop_after_closing_curly_braces
==
0
:
return
(
delta_to_be_parsed
,
delta_text
[
i
+
1
:])
elif
c
==
"]"
:
delta_to_be_parsed
+=
c
stop_after_closing_brackets
-=
1
if
stop_after_closing_brackets
==
0
:
return
(
delta_to_be_parsed
,
delta_text
[
i
+
1
:])
elif
c
==
":"
:
delta_to_be_parsed
+=
c
stop_after_colon
-=
1
if
stop_after_colon
==
0
:
return
(
delta_to_be_parsed
,
delta_text
[
i
+
1
:])
elif
c
==
","
:
delta_to_be_parsed
+=
c
stop_after_comma
-=
1
if
stop_after_comma
==
0
:
return
(
delta_to_be_parsed
,
delta_text
[
i
+
1
:])
else
:
delta_to_be_parsed
+=
c
return
(
delta_to_be_parsed
,
""
)
vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py
View file @
8d75f22e
...
...
@@ -4,7 +4,7 @@ import json
from
collections.abc
import
Sequence
from
typing
import
TYPE_CHECKING
from
vllm.entrypoints.harmony_utils
import
parse_output_into_messages
from
vllm.entrypoints.
openai.parser.
harmony_utils
import
parse_output_into_messages
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
,
...
...
vllm/entrypoints/pooling/embed/api_router.py
View file @
8d75f22e
...
...
@@ -59,8 +59,8 @@ async def create_embedding(
return
JSONResponse
(
content
=
generator
.
model_dump
())
elif
isinstance
(
generator
,
EmbeddingBytesResponse
):
return
StreamingResponse
(
content
=
generator
.
body
,
headers
=
{
"metadata"
:
generator
.
metadata
}
,
content
=
generator
.
content
,
headers
=
generator
.
headers
,
media_type
=
generator
.
media_type
,
)
...
...
vllm/entrypoints/pooling/embed/protocol.py
View file @
8d75f22e
...
...
@@ -203,6 +203,6 @@ class EmbeddingResponse(OpenAIBaseModel):
class
EmbeddingBytesResponse
(
OpenAIBaseModel
):
body
:
list
[
bytes
]
metadata
:
str
content
:
list
[
bytes
]
headers
:
dict
[
str
,
str
]
|
None
=
None
media_type
:
str
=
"application/octet-stream"
vllm/entrypoints/pooling/embed/serving.py
View file @
8d75f22e
...
...
@@ -163,29 +163,35 @@ class EmbeddingMixin(OpenAIServing):
usage
=
usage
,
)
def
encode_bytes
(
)
:
body
,
items
,
usage
=
encode_pooling_bytes
(
def
encode_bytes
(
bytes_only
:
bool
)
->
EmbeddingBytesResponse
:
content
,
items
,
usage
=
encode_pooling_bytes
(
pooling_outputs
=
final_res_batch_checked
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
metadata
=
{
headers
=
(
None
if
bytes_only
else
{
"metadata"
:
json
.
dumps
(
{
"id"
:
ctx
.
request_id
,
"created"
:
ctx
.
created_time
,
"model"
:
ctx
.
model_name
,
"data"
:
items
,
"usage"
:
usage
,
}
return
EmbeddingBytesResponse
(
body
=
body
,
metadata
=
json
.
dumps
(
metadata
),
)
}
)
return
EmbeddingBytesResponse
(
content
=
content
,
headers
=
headers
)
if
encoding_format
==
"float"
or
encoding_format
==
"base64"
:
return
encode_float_base64
()
elif
encoding_format
==
"bytes"
:
return
encode_bytes
()
elif
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
encode_bytes
(
bytes_only
=
encoding_format
==
"bytes_only"
)
else
:
assert_never
(
encoding_format
)
...
...
vllm/entrypoints/pooling/pooling/api_router.py
View file @
8d75f22e
...
...
@@ -55,8 +55,8 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
return
JSONResponse
(
content
=
generator
.
model_dump
())
elif
isinstance
(
generator
,
PoolingBytesResponse
):
return
StreamingResponse
(
content
=
generator
.
body
,
headers
=
{
"metadata"
:
generator
.
metadata
}
,
content
=
generator
.
content
,
headers
=
generator
.
headers
,
media_type
=
generator
.
media_type
,
)
...
...
vllm/entrypoints/pooling/pooling/protocol.py
View file @
8d75f22e
...
...
@@ -143,6 +143,6 @@ class PoolingResponse(OpenAIBaseModel):
class
PoolingBytesResponse
(
OpenAIBaseModel
):
body
:
list
[
bytes
]
metadata
:
str
content
:
list
[
bytes
]
headers
:
dict
[
str
,
str
]
|
None
=
None
media_type
:
str
=
"application/octet-stream"
vllm/entrypoints/pooling/pooling/serving.py
View file @
8d75f22e
...
...
@@ -314,29 +314,38 @@ class OpenAIServingPooling(OpenAIServing):
usage
=
usage
,
)
def
encode_bytes
(
)
:
body
,
items
,
usage
=
encode_pooling_bytes
(
def
encode_bytes
(
bytes_only
:
bool
)
->
PoolingBytesResponse
:
content
,
items
,
usage
=
encode_pooling_bytes
(
pooling_outputs
=
final_res_batch
,
embed_dtype
=
embed_dtype
,
endianness
=
endianness
,
)
metadata
=
{
headers
=
(
None
if
bytes_only
else
{
"metadata"
:
json
.
dumps
(
{
"id"
:
request_id
,
"created"
:
created_time
,
"model"
:
model_name
,
"data"
:
items
,
"usage"
:
usage
,
}
)
}
)
return
PoolingBytesResponse
(
body
=
body
,
metadata
=
json
.
dumps
(
metadata
)
,
content
=
content
,
headers
=
headers
,
)
if
encoding_format
==
"float"
or
encoding_format
==
"base64"
:
return
encode_float_base64
()
elif
encoding_format
==
"bytes"
:
return
encode_bytes
()
elif
encoding_format
==
"bytes"
or
encoding_format
==
"bytes_only"
:
return
encode_bytes
(
bytes_only
=
encoding_format
==
"bytes_only"
)
else
:
assert_never
(
encoding_format
)
...
...
vllm/entrypoints/responses_utils.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
from
openai.types.chat
import
(
ChatCompletionAssistantMessageParam
,
ChatCompletionMessageToolCallParam
,
...
...
@@ -10,18 +12,53 @@ from openai.types.chat.chat_completion_message_tool_call_param import (
Function
as
FunctionCallTool
,
)
from
openai.types.responses
import
ResponseFunctionToolCall
,
ResponseOutputItem
from
openai.types.responses.response
import
ToolChoice
from
openai.types.responses.response_function_tool_call_output_item
import
(
ResponseFunctionToolCallOutputItem
,
)
from
openai.types.responses.response_output_item
import
McpCall
from
openai.types.responses.response_output_message
import
ResponseOutputMessage
from
openai.types.responses.response_reasoning_item
import
ResponseReasoningItem
from
openai.types.responses.tool
import
Tool
from
vllm
import
envs
from
vllm.entrypoints.constants
import
MCP_PREFIX
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionMessageParam
,
ResponseInputOutputItem
,
)
from
vllm.utils
import
random_uuid
def make_response_output_items_from_parsable_context(
    response_messages: list[ResponseInputOutputItem],
) -> list[ResponseOutputItem]:
    """Build the list of response output items from parsed messages.

    Messages are copied through in order, except for
    ``ResponseFunctionToolCallOutputItem`` entries: such an entry is never
    emitted itself. If the item emitted just before it is a
    ``ResponseFunctionToolCall``, that item is replaced by an ``McpCall``
    carrying the call's name/arguments and the tool output.

    Raises:
        ValueError: If a tool-call output appears before any other item.
    """
    items: list[ResponseOutputItem] = []

    for message in response_messages:
        if not isinstance(message, ResponseFunctionToolCallOutputItem):
            items.append(message)
            continue

        if not items:
            raise ValueError(
                "Cannot have a FunctionToolCallOutput before FunctionToolCall."
            )

        if not isinstance(items[-1], ResponseFunctionToolCall):
            # The output does not follow a function call: nothing to merge.
            continue

        # Fold the call and its output into one completed MCP call.
        items[-1] = McpCall(
            id=f"{MCP_PREFIX}{random_uuid()}",
            arguments=items[-1].arguments,
            name=items[-1].name,
            server_label=items[-1].name,  # TODO: store the server label
            type=f"{MCP_PREFIX}call",
            status="completed",
            output=message.output,  # TODO: support error output
        )

    return items
def
construct_input_messages
(
...
...
@@ -62,12 +99,63 @@ def construct_input_messages(
if
isinstance
(
request_input
,
str
):
messages
.
append
({
"role"
:
"user"
,
"content"
:
request_input
})
else
:
for
item
in
request_input
:
messages
.
append
(
construct_chat_message_with_tool_call
(
item
))
input_messages
=
construct_chat_messages_with_tool_call
(
request_input
)
messages
.
extend
(
input_messages
)
return
messages
def _maybe_combine_reasoning_and_tool_call(
    item: ResponseInputOutputItem, messages: list[ChatCompletionMessageParam]
) -> ChatCompletionMessageParam | None:
    """Merge an MCP tool call into a preceding assistant reasoning message.

    Many models treat MCP calls and reasoning as a single message. The merge
    happens only when all of the following hold:

    * ``item`` is a ``ResponseFunctionToolCall`` whose ``id`` starts with
      ``MCP_PREFIX``;
    * ``messages`` is not empty;
    * the last message has role ``"assistant"`` and a non-``None``
      ``"reasoning"`` entry.

    In that case the last message is mutated in place: its ``"tool_calls"``
    entry is set to a one-element list describing ``item``.

    Returns:
        The (mutated) last message when the merge happened, otherwise
        ``None``.
    """
    if not isinstance(item, ResponseFunctionToolCall):
        return None
    # A tool call without an id cannot be an MCP call; checking this first
    # avoids calling ``startswith`` on ``None``.
    if not item.id or not item.id.startswith(MCP_PREFIX):
        return None
    if not messages:
        return None

    last_message = messages[-1]
    is_assistant_reasoning = (
        last_message.get("role") == "assistant"
        and last_message.get("reasoning") is not None
    )
    if not is_assistant_reasoning:
        return None

    last_message["tool_calls"] = [
        ChatCompletionMessageToolCallParam(
            id=item.call_id,
            function=FunctionCallTool(
                name=item.name,
                arguments=item.arguments,
            ),
            type="function",
        )
    ]
    return last_message
def construct_chat_messages_with_tool_call(
    input_messages: list[ResponseInputOutputItem],
) -> list[ChatCompletionMessageParam]:
    """Convert response items into chat messages, merging where needed.

    One chat message can originate from several response items (for
    example a reasoning item followed by an MCP tool call). Each item is
    first offered to ``_maybe_combine_reasoning_and_tool_call``; when that
    returns a message it replaces the last chat message, otherwise the item
    is converted on its own with
    ``_construct_single_message_from_response_item`` and appended.
    """
    chat_messages: list[ChatCompletionMessageParam] = []
    for response_item in input_messages:
        combined = _maybe_combine_reasoning_and_tool_call(
            response_item, chat_messages
        )
        if combined is None:
            chat_messages.append(
                _construct_single_message_from_response_item(response_item)
            )
        else:
            chat_messages[-1] = combined
    return chat_messages
def
construct_
chat
_message_
with_tool_call
(
def
_
construct_
single
_message_
from_response_item
(
item
:
ResponseInputOutputItem
,
)
->
ChatCompletionMessageParam
:
if
isinstance
(
item
,
ResponseFunctionToolCall
):
...
...
@@ -146,3 +234,16 @@ def convert_tool_responses_to_completions_format(tool: dict) -> dict:
"type"
:
"function"
,
"function"
:
tool
,
}
def construct_tool_dicts(
    tools: list[Tool], tool_choice: ToolChoice
) -> list[dict[str, Any]] | None:
    """Return the tools as completions-format dicts, or ``None``.

    ``None`` is returned when ``tools`` is ``None`` or when ``tool_choice``
    equals ``"none"``. Otherwise every tool is dumped with ``model_dump()``
    and passed through ``convert_tool_responses_to_completions_format``.
    """
    if tools is None or tool_choice == "none":
        return None
    return [
        convert_tool_responses_to_completions_format(tool.model_dump())
        for tool in tools
    ]
vllm/entrypoints/sagemaker/routes.py
View file @
8d75f22e
...
...
@@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
completion
,
create_chat_completion
,
create_completion
,
health
,
validate_json_request
,
)
from
vllm.entrypoints.openai.protocol
import
(
...
...
@@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
score
,
)
from
vllm.entrypoints.pooling.score.protocol
import
RerankRequest
,
ScoreRequest
from
vllm.entrypoints.serve.instrumentator.health
import
health
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
...
...
vllm/entrypoints/score_utils.py
View file @
8d75f22e
...
...
@@ -89,12 +89,10 @@ def parse_score_data(
data_1
:
str
|
ScoreContentPartParam
,
data_2
:
str
|
ScoreContentPartParam
,
model_config
:
ModelConfig
,
tokenizer
:
TokenizerLike
,
)
->
tuple
[
str
,
str
,
MultiModalDataDict
|
None
]:
mm_tracker
=
MultiModalItemTracker
(
model_config
,
tokenizer
)
mm_tracker
=
MultiModalItemTracker
(
model_config
)
content_1
=
_parse_score_content
(
data_1
,
mm_tracker
)
content_2
=
_parse_score_content
(
data_2
,
mm_tracker
)
def
ensure_str
(
content
:
_ContentPart
|
None
)
->
str
:
...
...
@@ -188,7 +186,6 @@ def get_score_prompt(
data_1
,
data_2
,
model_config
,
tokenizer
,
)
from
vllm.model_executor.model_loader
import
get_model_cls
...
...
vllm/entrypoints/serve/__init__.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
fastapi
import
FastAPI
def register_vllm_serve_api_routers(app: FastAPI):
    """Attach every ``vllm.entrypoints.serve`` router to ``app``.

    Each sub-module is imported inside this function, immediately before
    its ``attach_router`` is called, in this order: lora, elastic_ep,
    profile, sleep, tokenize, disagg, rlhf, metrics, health.
    """
    # NOTE(review): the function-local imports presumably avoid import
    # cycles at module load time - confirm before hoisting them.
    from vllm.entrypoints.serve.lora.api_router import (
        attach_router as _attach_lora,
    )

    _attach_lora(app)

    from vllm.entrypoints.serve.elastic_ep.api_router import (
        attach_router as _attach_elastic_ep,
    )

    _attach_elastic_ep(app)

    from vllm.entrypoints.serve.profile.api_router import (
        attach_router as _attach_profile,
    )

    _attach_profile(app)

    from vllm.entrypoints.serve.sleep.api_router import (
        attach_router as _attach_sleep,
    )

    _attach_sleep(app)

    from vllm.entrypoints.serve.tokenize.api_router import (
        attach_router as _attach_tokenize,
    )

    _attach_tokenize(app)

    from vllm.entrypoints.serve.disagg.api_router import (
        attach_router as _attach_disagg,
    )

    _attach_disagg(app)

    from vllm.entrypoints.serve.rlhf.api_router import (
        attach_router as _attach_rlhf,
    )

    _attach_rlhf(app)

    from vllm.entrypoints.serve.instrumentator.metrics import (
        attach_router as _attach_metrics,
    )

    _attach_metrics(app)

    from vllm.entrypoints.serve.instrumentator.health import (
        attach_router as _attach_health,
    )

    _attach_health(app)
vllm/entrypoints/serve/disagg/__init__.py
0 → 100644
View file @
8d75f22e
vllm/entrypoints/serve/disagg/api_router.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
json
from
http
import
HTTPStatus
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
HTTPException
,
Request
,
Response
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.openai.api_server
import
validate_json_request
from
vllm.entrypoints.openai.protocol
import
(
ErrorResponse
,
)
from
vllm.entrypoints.serve.disagg.protocol
import
(
GenerateRequest
,
GenerateResponse
,
)
from
vllm.entrypoints.serve.disagg.serving
import
(
ServingTokens
,
)
from
vllm.entrypoints.serve.tokenize.serving
import
OpenAIServingTokenization
from
vllm.entrypoints.utils
import
(
load_aware_call
,
with_cancellation
,
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return ``openai_serving_tokenization`` from the app state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
def generate_tokens(request: Request) -> ServingTokens | None:
    """Return ``serving_tokens`` from the app state (may be ``None``)."""
    app_state = request.app.state
    return app_state.serving_tokens
def engine_client(request: Request) -> EngineClient:
    """Return ``engine_client`` from the app state."""
    app_state = request.app.state
    return app_state.engine_client
router
=
APIRouter
()
@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    """Generate tokens from prompt token ids.

    Responds with a JSON body for complete results and errors, or with a
    ``text/event-stream`` stream otherwise.
    """
    handler = generate_tokens(raw_request)
    if handler is None:
        error = tokenization(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )
        # Report the error with its own status code, exactly like the
        # ErrorResponse branch below. Returning the bare object would let
        # FastAPI serialise it with a misleading HTTP 200.
        # NOTE(review): assumes create_error_response() returns an
        # ErrorResponse (model_dump() / .error.code) - confirm.
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )

    try:
        generator = await handler.serve_tokens(request, raw_request)
    except Exception as e:
        # Endpoint boundary: surface any handler failure as HTTP 500 while
        # keeping the original exception chained for the server logs.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    if isinstance(generator, GenerateResponse):
        return JSONResponse(content=generator.model_dump())

    # Anything else is passed through as the body of a server-sent-events
    # stream.
    return StreamingResponse(content=generator, media_type="text/event-stream")
def attach_router(app: FastAPI):
    """Include this module's router into ``app``.

    When ``app.state.args.tokens_only`` is truthy, a ``POST /abort_requests``
    route is first registered on the module-level router.
    """
    if getattr(app.state.args, "tokens_only", False):
        # The event loop only keeps weak references to tasks, so a
        # fire-and-forget task can be garbage-collected before it runs.
        # Hold a strong reference until each abort task has finished.
        pending_aborts: set[asyncio.Task] = set()

        @router.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e

            # Valid JSON is not necessarily an object (it may be a list, a
            # string, ...); treat such bodies as missing the field rather
            # than failing with an AttributeError / HTTP 500.
            request_ids = (
                body.get("request_ids") if isinstance(body, dict) else None
            )
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )

            # Abort requests in background
            abort_task = asyncio.create_task(
                engine_client(raw_request).abort(request_ids)
            )
            pending_aborts.add(abort_task)
            abort_task.add_done_callback(pending_aborts.discard)
            return Response(status_code=200)

    # NOTE(review): indentation was lost in the reviewed copy; the router is
    # assumed to be included unconditionally so that /inference/v1/generate
    # is served regardless of tokens_only - confirm.
    app.include_router(router)
vllm/entrypoints/serve/disagg/protocol.py
0 → 100644
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
from
pydantic
import
BaseModel
,
Field
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionLogProbs
,
Logprob
,
SamplingParams
,
StreamOptions
,
)
from
vllm.utils
import
random_uuid
####### Tokens IN <> Tokens OUT #######
# Request schema for token-in / token-out generation.
# Fields are documented with plain comments rather than docstrings so the
# generated JSON schema is left untouched.
class GenerateRequest(BaseModel):
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )

    # The token ids to generate text from.
    token_ids: list[int]

    # The processed MM inputs for the model.
    # features: MultiModalFeatureSpec
    # TODO (NickLucche): implement once Renderer work is completed
    features: str | None = None

    # The sampling parameters for the model.
    sampling_params: SamplingParams

    model: str | None = None
    stream: bool | None = False
    stream_options: StreamOptions | None = None

    cache_salt: str | None = Field(
        None,
        description=(
            "If specified, the prefix cache will be salted with the provided "
            "string to prevent an attacker to guess prompts in multi-user "
            "environments. The salt should be random, protected from "
            "access by 3rd parties, and long enough to be "
            "unpredictable (e.g., 43 characters base64-encoded, corresponding "
            "to 256 bit)."
        ),
    )

    priority: int = Field(
        0,
        description=(
            "The priority of the request (lower means earlier handling; "
            "default: 0). Any priority other than 0 will raise an error "
            "if the served model does not use priority scheduling."
        ),
    )

    kv_transfer_params: dict[str, Any] | None = Field(
        None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
# A single entry of GenerateResponse.choices.
class GenerateResponseChoice(BaseModel):
    # Position of this choice within the response.
    index: int

    # Optional log-probability information for this choice.
    logprobs: ChatCompletionLogProbs | None = None

    # per OpenAI spec this is the default
    finish_reason: str | None = "stop"

    # Optional token ids for this choice.
    token_ids: list[int] | None = None
# Response schema for token-in / token-out generation.
# Fields are documented with plain comments rather than docstrings so the
# generated JSON schema is left untouched.
class GenerateResponse(BaseModel):
    request_id: str = Field(
        default_factory=lambda: f"{random_uuid()}",
        description=(
            "The request_id related to this request. If the caller does "
            "not set it, a random_uuid will be generated. This id is used "
            "through out the inference process and return in response."
        ),
    )

    # The choices carried by this response.
    choices: list[GenerateResponseChoice]

    # Optional per-position prompt logprobs; an entry maps an int key to a
    # Logprob, or is None.
    prompt_logprobs: list[dict[int, Logprob] | None] | None = None

    kv_transfer_params: dict[str, Any] | None = Field(
        None,
        description="KVTransfer parameters used for disaggregated serving.",
    )
vllm/entrypoints/
openai
/serving
_tokens
.py
→
vllm/entrypoints/
serve/disagg
/serving.py
View file @
8d75f22e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
time
from
collections.abc
import
AsyncGenerator
...
...
@@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs
,
ChatCompletionLogProbsContent
,
ErrorResponse
,
GenerateRequest
,
GenerateResponse
,
GenerateResponseChoice
,
PromptTokenUsageInfo
,
RequestResponseMetadata
,
UsageInfo
,
)
from
vllm.entrypoints.openai.serving_engine
import
OpenAIServing
,
clamp_prompt_logprobs
from
vllm.entrypoints.openai.serving_models
import
OpenAIServingModels
from
vllm.entrypoints.serve.disagg.protocol
import
(
GenerateRequest
,
GenerateResponse
,
GenerateResponseChoice
,
)
from
vllm.inputs.data
import
TokensPrompt
as
EngineTokensPrompt
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
...
...
vllm/entrypoints/serve/elastic_ep/__init__.py
0 → 100644
View file @
8d75f22e
Prev
1
…
18
19
20
21
22
23
24
25
26
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment