Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d2b52805
Commit
d2b52805
authored
Sep 07, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc1' into v0.10.2rc1-ori
parents
9a521c23
5438967f
Changes
511
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1611 additions
and
408 deletions
+1611
-408
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
+2
-2
vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py
...entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py
+286
-243
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
+679
-0
vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
+83
-25
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+3
-1
vllm/envs.py
vllm/envs.py
+56
-11
vllm/executor/mp_distributed_executor.py
vllm/executor/mp_distributed_executor.py
+1
-1
vllm/executor/msgspec_utils.py
vllm/executor/msgspec_utils.py
+7
-2
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+6
-0
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+2
-1
vllm/inputs/data.py
vllm/inputs/data.py
+29
-7
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+127
-65
vllm/inputs/registry.py
vllm/inputs/registry.py
+9
-3
vllm/lora/layers.py
vllm/lora/layers.py
+1
-4
vllm/lora/models.py
vllm/lora/models.py
+9
-4
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+122
-0
vllm/model_executor/layers/attention_layer_base.py
vllm/model_executor/layers/attention_layer_base.py
+23
-0
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+30
-32
vllm/model_executor/layers/fused_moe/config.py
vllm/model_executor/layers/fused_moe/config.py
+14
-7
vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
...352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
+122
-0
No files found.
Too many changes to show.
To preserve performance only
511 of 511+
files are displayed.
Plain diff
Email patch
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
View file @
d2b52805
...
...
@@ -8,7 +8,7 @@ from typing import Any, Optional
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.chat_utils
import
random
_tool_call_id
from
vllm.entrypoints.chat_utils
import
make
_tool_call_id
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
,
ExtractedToolCallInformation
,
...
...
@@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser):
tool_calls
:
list
[
ToolCall
]
=
[
ToolCall
(
id
=
random
_tool_call_id
(),
id
=
make
_tool_call_id
(),
type
=
"function"
,
function
=
FunctionCall
(
name
=
raw_function_call
[
"name"
],
...
...
vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
ast
import
json
import
uuid
from
collections.abc
import
Sequence
...
...
@@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
logger
=
init_logger
(
__name__
)
@
ToolParserManager
.
register_module
(
[
"qwen3_coder"
]
)
@
ToolParserManager
.
register_module
(
"qwen3_coder"
)
class
Qwen3CoderToolParser
(
ToolParser
):
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
...
...
@@ -30,6 +30,8 @@ class Qwen3CoderToolParser(ToolParser):
self
.
current_tool_name_sent
:
bool
=
False
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
# Override base class type - we use string IDs for tool calls
self
.
current_tool_id
:
Optional
[
str
]
=
None
# type: ignore
self
.
streamed_args_for_tool
:
list
[
str
]
=
[]
# Sentinel tokens for streaming mode
...
...
@@ -42,20 +44,6 @@ class Qwen3CoderToolParser(ToolParser):
self
.
is_tool_call_started
:
bool
=
False
self
.
failed_count
:
int
=
0
# Streaming state variables
self
.
current_tool_index
:
int
=
0
self
.
header_sent
:
bool
=
False
self
.
current_tool_string_id
:
Optional
[
str
]
=
None
self
.
current_function_name
:
Optional
[
str
]
=
None
self
.
current_param_name
:
Optional
[
str
]
=
None
self
.
current_param_value
:
str
=
""
self
.
param_count
:
int
=
0
self
.
in_param
:
bool
=
False
self
.
in_function
:
bool
=
False
self
.
accumulated_text
:
str
=
""
self
.
json_started
:
bool
=
False
self
.
json_closed
:
bool
=
False
# Enhanced streaming state - reset for each new message
self
.
_reset_streaming_state
()
...
...
@@ -67,7 +55,8 @@ class Qwen3CoderToolParser(ToolParser):
self
.
tool_call_function_regex
=
re
.
compile
(
r
"<function=(.*?)</function>|<function=(.*)$"
,
re
.
DOTALL
)
self
.
tool_call_parameter_regex
=
re
.
compile
(
r
"<parameter=(.*?)</parameter>|<parameter=(.*?)$"
,
re
.
DOTALL
)
r
"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)"
,
re
.
DOTALL
)
if
not
self
.
model_tokenizer
:
raise
ValueError
(
...
...
@@ -84,7 +73,7 @@ class Qwen3CoderToolParser(ToolParser):
"Qwen3 XML Tool parser could not locate tool call start/end "
"tokens in the tokenizer!"
)
logger
.
debug
(
"vLLM Successfully import tool parser %s !"
,
logger
.
info
(
"vLLM Successfully import tool parser %s !"
,
self
.
__class__
.
__name__
)
def
_generate_tool_call_id
(
self
)
->
str
:
...
...
@@ -96,7 +85,7 @@ class Qwen3CoderToolParser(ToolParser):
self
.
current_tool_index
=
0
self
.
is_tool_call_started
=
False
self
.
header_sent
=
False
self
.
current_tool_
string_
id
=
None
self
.
current_tool_id
=
None
self
.
current_function_name
=
None
self
.
current_param_name
=
None
self
.
current_param_value
=
""
...
...
@@ -106,22 +95,21 @@ class Qwen3CoderToolParser(ToolParser):
self
.
accumulated_text
=
""
self
.
json_started
=
False
self
.
json_closed
=
False
def
_parse_xml_function_call
(
self
,
function_call_str
:
str
,
tools
:
Optional
[
list
[
ChatCompletionToolsParam
]]
)
->
Optional
[
ToolCall
]:
def
get_arguments_config
(
func_name
:
str
)
->
dict
:
# Store accumulated parameters for type conversion
self
.
accumulated_params
=
{}
self
.
streaming_request
=
None
def
_get_arguments_config
(
self
,
func_name
:
str
,
tools
:
Optional
[
list
[
ChatCompletionToolsParam
]])
->
dict
:
"""Extract argument configuration for a function."""
if
tools
is
None
:
return
{}
for
config
in
tools
:
if
not
hasattr
(
config
,
"type"
)
or
not
(
hasattr
(
config
,
"function"
)
and
hasattr
(
config
.
function
,
"name"
)):
if
not
hasattr
(
config
,
"type"
)
or
not
(
hasattr
(
config
,
"function"
)
and
hasattr
(
config
.
function
,
"name"
)):
continue
if
(
config
.
type
==
"function"
and
config
.
function
.
name
==
func_name
):
if
config
.
type
==
"function"
and
config
.
function
.
name
==
func_name
:
if
not
hasattr
(
config
.
function
,
"parameters"
):
return
{}
params
=
config
.
function
.
parameters
...
...
@@ -135,14 +123,13 @@ class Qwen3CoderToolParser(ToolParser):
func_name
)
return
{}
def
convert_param_value
(
param_value
:
str
,
param_name
:
str
,
def
_
convert_param_value
(
self
,
param_value
:
str
,
param_name
:
str
,
param_config
:
dict
,
func_name
:
str
)
->
Any
:
"""Convert parameter value based on its type in the schema."""
# Handle null value for any type
if
param_value
.
lower
()
==
"null"
:
return
None
converted_value
:
Any
if
param_name
not
in
param_config
:
if
param_config
!=
{}:
logger
.
warning
(
...
...
@@ -151,38 +138,31 @@ class Qwen3CoderToolParser(ToolParser):
"string value."
,
param_name
,
func_name
)
return
param_value
if
(
isinstance
(
param_config
[
param_name
],
dict
)
and
"type"
in
param_config
[
param_name
]):
param_type
=
str
(
param_config
[
param_name
][
"type"
]).
strip
().
lower
()
if
isinstance
(
param_config
[
param_name
],
dict
)
and
"type"
in
param_config
[
param_name
]:
param_type
=
str
(
param_config
[
param_name
][
"type"
]).
strip
().
lower
()
else
:
param_type
=
"string"
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
return
param_value
elif
(
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)
)
:
elif
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
):
try
:
converted_value
=
int
(
param_value
)
return
converted_value
except
ValueError
:
return
int
(
param_value
)
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not an "
"integer in tool '%s', degenerating to string."
,
param_value
,
param_name
,
func_name
)
return
param_value
elif
(
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
)):
elif
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
):
try
:
float_param_value
=
float
(
param_value
)
converted_value
=
(
float_param_value
if
float_param_value
-
int
(
float_param_value
)
!=
0
else
int
(
float_param_value
))
return
converted_value
except
ValueError
:
return
float_param_value
if
float_param_value
-
int
(
float_param_value
)
!=
0
else
int
(
float_param_value
)
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', degenerating to string."
,
param_value
,
...
...
@@ -192,36 +172,45 @@ class Qwen3CoderToolParser(ToolParser):
param_value
=
param_value
.
lower
()
if
param_value
not
in
[
"true"
,
"false"
]:
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a "
"boolean (`true` of `false`) in tool '%s', "
"degenerating to false."
,
param_value
,
param_name
,
func_name
)
"Parsed value '%s' of parameter '%s' is not a boolean "
"(`true` or `false`) in tool '%s', degenerating to "
"false."
,
param_value
,
param_name
,
func_name
)
return
param_value
==
"true"
else
:
if
param_type
==
"object"
or
param_type
.
startswith
(
"dict"
):
if
param_type
in
[
"object"
,
"array"
,
"arr"
]
or
param_type
.
startswith
(
"dict"
)
or
param_type
.
startswith
(
"list"
):
try
:
converted
_value
=
json
.
loads
(
param_value
)
return
converted
_value
except
json
.
JSONDecodeError
:
param
_value
=
json
.
loads
(
param_value
)
return
param
_value
except
(
json
.
JSONDecodeError
,
TypeError
,
ValueError
)
:
logger
.
warning
(
"Parsed value '%s' of parameter '%s'
is
not
a
"
"valid JSON object
in tool '%s', will try
other
"
"
methods to parse it."
,
param_value
,
param_name
,
"Parsed value '%s' of parameter '%s'
can
not
be
"
"parsed with json.loads
in tool '%s', will try "
"other
methods to parse it."
,
param_value
,
param_name
,
func_name
)
try
:
param_value
=
ast
.
literal_eval
(
param_value
)
# safer
except
(
ValueError
,
SyntaxError
,
TypeError
):
logger
.
warning
(
"Parameter '%s' has unknown type '%s'. "
"The value will be treated as a string."
,
param_name
,
param_type
)
"Parsed value '%s' of parameter '%s' cannot be "
"converted via Python `ast.literal_eval()` in tool "
"'%s', degenerating to string."
,
param_value
,
param_name
,
func_name
)
return
param_value
def
_parse_xml_function_call
(
self
,
function_call_str
:
str
,
tools
:
Optional
[
list
[
ChatCompletionToolsParam
]]
)
->
Optional
[
ToolCall
]:
# Extract function name
end_index
=
function_call_str
.
index
(
">"
)
function_name
=
function_call_str
[:
end_index
]
param_config
=
get_arguments_config
(
function_name
)
param_config
=
self
.
_
get_arguments_config
(
function_name
,
tools
)
parameters
=
function_call_str
[
end_index
+
1
:]
param_dict
=
{}
for
match
in
self
.
tool_call_parameter_regex
.
findall
(
parameters
):
match_text
=
match
[
0
]
if
match
[
0
]
else
match
[
1
]
for
match_text
in
self
.
tool_call_parameter_regex
.
findall
(
parameters
):
idx
=
match_text
.
index
(
">"
)
param_name
=
match_text
[:
idx
]
param_value
=
str
(
match_text
[
idx
+
1
:])
...
...
@@ -231,7 +220,7 @@ class Qwen3CoderToolParser(ToolParser):
if
param_value
.
endswith
(
"
\n
"
):
param_value
=
param_value
[:
-
1
]
param_dict
[
param_name
]
=
convert_param_value
(
param_dict
[
param_name
]
=
self
.
_
convert_param_value
(
param_value
,
param_name
,
param_config
,
function_name
)
return
ToolCall
(
type
=
"function"
,
...
...
@@ -284,8 +273,7 @@ class Qwen3CoderToolParser(ToolParser):
for
function_call_str
in
function_calls
]
# Populate prev_tool_call_arr for serving layer to set
# finish_reason
# Populate prev_tool_call_arr for serving layer to set finish_reason
self
.
prev_tool_call_arr
.
clear
()
# Clear previous calls
for
tool_call
in
tool_calls
:
if
tool_call
:
...
...
@@ -298,8 +286,8 @@ class Qwen3CoderToolParser(ToolParser):
# Extract content before tool calls
content_index
=
model_output
.
find
(
self
.
tool_call_start_token
)
content_index
=
(
content_index
if
content_index
>=
0
else
model_output
.
find
(
self
.
tool_call_prefix
))
idx
=
model_output
.
find
(
self
.
tool_call_prefix
)
content_index
=
content_index
if
content_index
>=
0
else
idx
content
=
model_output
[:
content_index
]
# .rstrip()
return
ExtractedToolCallInformation
(
...
...
@@ -324,13 +312,16 @@ class Qwen3CoderToolParser(ToolParser):
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
Union
[
DeltaMessage
,
None
]:
# If no delta text, return None unless it's an EOS token after tool
# calls
# Store request for type conversion
if
not
previous_text
:
self
.
_reset_streaming_state
()
self
.
streaming_request
=
request
# If no delta text, return None unless it's an EOS token after tools
if
not
delta_text
:
# Check if this is an EOS token after all tool calls are complete
# We check for tool calls in the text even if is_tool_call_started
# is False because it might have been reset after processing all
# tools
# Check for tool calls in text even if is_tool_call_started
# is False (might have been reset after processing all tools)
if
(
delta_token_ids
and
self
.
tool_call_end_token_id
not
in
delta_token_ids
):
# Count complete tool calls
...
...
@@ -339,24 +330,19 @@ class Qwen3CoderToolParser(ToolParser):
# If we have completed tool calls and populated
# prev_tool_call_arr
if
(
complete_calls
>
0
and
len
(
self
.
prev_tool_call_arr
)
>
0
)
:
if
complete_calls
>
0
and
len
(
self
.
prev_tool_call_arr
)
>
0
:
# Check if all tool calls are closed
open_calls
=
(
current_text
.
count
(
self
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
tool_call_end_token
)
)
open_calls
=
current_text
.
count
(
self
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
tool_call_end_token
)
if
open_calls
==
0
:
# Return empty delta message to allow finish_reason
# processing
# Return empty delta for finish_reason processing
return
DeltaMessage
(
content
=
""
)
elif
not
self
.
is_tool_call_started
and
current_text
:
# This is a regular content response that's now complete
return
DeltaMessage
(
content
=
""
)
return
None
# Check if this is the first call (reset state if needed)
if
not
previous_text
:
self
.
_reset_streaming_state
()
# Update accumulated text
self
.
accumulated_text
=
current_text
...
...
@@ -371,11 +357,11 @@ class Qwen3CoderToolParser(ToolParser):
self
.
param_count
=
0
self
.
json_started
=
False
self
.
json_closed
=
False
self
.
accumulated_params
=
{}
# Check if there are more tool calls
tool_starts_count
=
current_text
.
count
(
self
.
tool_call_start_token
)
if
self
.
current_tool_index
>=
tool_starts_count
:
tool_starts
=
current_text
.
count
(
self
.
tool_call_start_token
)
if
self
.
current_tool_index
>=
tool_starts
:
# No more tool calls
self
.
is_tool_call_started
=
False
# Continue processing next tool
...
...
@@ -412,20 +398,20 @@ class Qwen3CoderToolParser(ToolParser):
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
tool_starts
:
list
[
int
]
=
[]
tool_start
_position
s
:
list
[
int
]
=
[]
idx
=
0
while
True
:
idx
=
current_text
.
find
(
self
.
tool_call_start_token
,
idx
)
if
idx
==
-
1
:
break
tool_starts
.
append
(
idx
)
tool_start
_position
s
.
append
(
idx
)
idx
+=
len
(
self
.
tool_call_start_token
)
if
self
.
current_tool_index
>=
len
(
tool_starts
):
if
self
.
current_tool_index
>=
len
(
tool_start
_position
s
):
# No more tool calls to process yet
return
None
tool_start_idx
=
tool_starts
[
self
.
current_tool_index
]
tool_start_idx
=
tool_start
_position
s
[
self
.
current_tool_index
]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx
=
current_text
.
find
(
self
.
tool_call_end_token
,
tool_start_idx
)
...
...
@@ -438,19 +424,19 @@ class Qwen3CoderToolParser(ToolParser):
# Looking for function header
if
not
self
.
header_sent
:
if
self
.
tool_call_prefix
in
tool_text
:
func_start
=
(
tool_text
.
find
(
self
.
tool_call_prefix
)
+
len
(
self
.
tool_call_prefix
)
)
func_start
=
tool_text
.
find
(
self
.
tool_call_prefix
)
+
len
(
self
.
tool_call_prefix
)
func_end
=
tool_text
.
find
(
">"
,
func_start
)
if
func_end
!=
-
1
:
# Found complete function name
self
.
current_function_name
=
tool_text
[
func_start
:
func_end
]
self
.
current_tool_
string_
id
=
self
.
_generate_tool_call_id
()
self
.
current_tool_id
=
self
.
_generate_tool_call_id
()
self
.
header_sent
=
True
self
.
in_function
=
True
# IMPORTANT: Add to prev_tool_call_arr immediately when
we
# detect a tool call. This ensures
# IMPORTANT: Add to prev_tool_call_arr immediately when
#
we
detect a tool call. This ensures
# finish_reason="tool_calls" even if parsing isn't complete
already_added
=
any
(
tool
.
get
(
"name"
)
==
self
.
current_function_name
...
...
@@ -466,7 +452,7 @@ class Qwen3CoderToolParser(ToolParser):
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
id
=
self
.
current_tool_
string_
id
,
id
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
name
=
self
.
current_function_name
,
arguments
=
""
),
type
=
"function"
,
...
...
@@ -496,10 +482,11 @@ class Qwen3CoderToolParser(ToolParser):
# Close JSON
self
.
json_closed
=
True
# Extract the complete tool call to update prev_tool_call_arr
# with final arguments. Find the function content
func_start
=
(
tool_text
.
find
(
self
.
tool_call_prefix
)
+
len
(
self
.
tool_call_prefix
))
# Extract complete tool call to update
# prev_tool_call_arr with final arguments
# Find the function content
func_start
=
tool_text
.
find
(
self
.
tool_call_prefix
)
+
len
(
self
.
tool_call_prefix
)
func_content_end
=
tool_text
.
find
(
self
.
function_end_token
,
func_start
)
if
func_content_end
!=
-
1
:
...
...
@@ -507,15 +494,17 @@ class Qwen3CoderToolParser(ToolParser):
# Parse to get the complete arguments
try
:
parsed_tool
=
self
.
_parse_xml_function_call
(
func_content
,
request
.
tools
if
request
else
None
)
func_content
,
self
.
streaming_request
.
tools
if
self
.
streaming_request
else
None
)
if
parsed_tool
:
# Update existing entry in
prev_tool_call_arr with
# complete arg
ument
s
# Update existing entry in
#
prev_tool_call_arr with
complete args
for
i
,
tool
in
enumerate
(
self
.
prev_tool_call_arr
):
if
(
tool
.
get
(
"name"
)
==
parsed_tool
.
function
.
name
):
self
.
prev_tool_call_arr
[
i
][
"arguments"
]
=
(
parsed_tool
.
function
.
arguments
)
if
tool
.
get
(
"name"
)
==
parsed_tool
.
function
.
name
:
args
=
parsed_tool
.
function
.
arguments
self
.
prev_tool_call_arr
[
i
][
"arguments"
]
=
args
break
except
Exception
:
pass
# Ignore parsing errors during streaming
...
...
@@ -530,17 +519,12 @@ class Qwen3CoderToolParser(ToolParser):
# Reset state for next tool
self
.
in_function
=
False
self
.
json_closed
=
True
self
.
accumulated_params
=
{}
return
result
# Look for parameters
# Count how many complete parameters we have processed
complete_params
=
tool_text
.
count
(
self
.
parameter_end_token
)
# Check if we should start a new parameter
if
not
self
.
in_param
and
self
.
param_count
<
complete_params
:
# Find the unprocessed parameter
# Count parameter starts
# Find all parameter starts
param_starts
=
[]
idx
=
0
while
True
:
...
...
@@ -550,7 +534,9 @@ class Qwen3CoderToolParser(ToolParser):
param_starts
.
append
(
idx
)
idx
+=
len
(
self
.
parameter_prefix
)
if
len
(
param_starts
)
>
self
.
param_count
:
# Check if we should start a new parameter
if
(
not
self
.
in_param
and
self
.
param_count
<
len
(
param_starts
)
and
len
(
param_starts
)
>
self
.
param_count
):
# Process the next parameter
param_idx
=
param_starts
[
self
.
param_count
]
param_start
=
param_idx
+
len
(
self
.
parameter_prefix
)
...
...
@@ -568,23 +554,62 @@ class Qwen3CoderToolParser(ToolParser):
value_text
=
value_text
[
1
:]
# Find where this parameter ends
param_end_idx
=
value_text
.
find
(
self
.
parameter_end_token
)
param_end_idx
=
value_text
.
find
(
self
.
parameter_end_token
)
if
param_end_idx
==
-
1
:
# No closing tag, look for next parameter or
# function end
next_param_idx
=
value_text
.
find
(
self
.
parameter_prefix
)
func_end_idx
=
value_text
.
find
(
self
.
function_end_token
)
if
next_param_idx
!=
-
1
and
(
func_end_idx
==
-
1
or
next_param_idx
<
func_end_idx
):
param_end_idx
=
next_param_idx
elif
func_end_idx
!=
-
1
:
param_end_idx
=
func_end_idx
else
:
# Neither found, check if tool call is complete
if
self
.
tool_call_end_token
in
tool_text
:
# Tool call is complete, so parameter
# must be complete too. Use all
# remaining text before function end
param_end_idx
=
len
(
value_text
)
else
:
# Still streaming, wait for more content
return
None
if
param_end_idx
!=
-
1
:
# Complete parameter found
param_value
=
value_text
[:
param_end_idx
]
if
param_value
.
endswith
(
"
\n
"
):
param_value
=
param_value
[:
-
1
]
# Build complete JSON fragment for this parameter
# Store raw value for later processing
self
.
accumulated_params
[
self
.
current_param_name
]
=
param_value
# Get parameter configuration for type conversion
param_config
=
self
.
_get_arguments_config
(
self
.
current_function_name
or
""
,
self
.
streaming_request
.
tools
if
self
.
streaming_request
else
None
)
# Convert param value to appropriate type
converted_value
=
self
.
_convert_param_value
(
param_value
,
self
.
current_param_name
,
param_config
,
self
.
current_function_name
or
""
)
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value
=
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)
if
self
.
param_count
==
0
:
json_fragment
=
(
'"'
+
self
.
current_param_name
+
'": "'
+
json
.
dumps
(
param_value
)[
1
:
-
1
]
+
'"'
)
json_fragment
=
(
f
'"
{
self
.
current_param_name
}
": '
f
'
{
serialized_value
}
'
)
else
:
json_fragment
=
(
', "'
+
self
.
current_param_name
+
'": "'
+
json
.
dumps
(
param_value
)[
1
:
-
1
]
+
'"'
)
json_fragment
=
(
f
', "
{
self
.
current_param_name
}
": '
f
'
{
serialized_value
}
'
)
self
.
param_count
+=
1
...
...
@@ -596,7 +621,8 @@ class Qwen3CoderToolParser(ToolParser):
)
])
# Continue parameter value
# Continue parameter value - Not used in the current implementation
# since we process complete parameters above
if
self
.
in_param
:
if
self
.
parameter_end_token
in
delta_text
:
# End of parameter
...
...
@@ -608,25 +634,42 @@ class Qwen3CoderToolParser(ToolParser):
gt_idx
=
value_chunk
.
find
(
">"
)
value_chunk
=
value_chunk
[
gt_idx
+
1
:]
if
(
not
self
.
current_param_value
and
value_chunk
.
startswith
(
"
\n
"
)
)
:
if
not
self
.
current_param_value
and
value_chunk
.
startswith
(
"
\n
"
):
value_chunk
=
value_chunk
[
1
:]
#
Calculate incremental JSON
#
Store complete value
full_value
=
self
.
current_param_value
+
value_chunk
prev_escaped
=
(
json
.
dumps
(
self
.
current_param_value
)[
1
:
-
1
]
if
self
.
current_param_value
else
""
)
full_escaped
=
json
.
dumps
(
full_value
)[
1
:
-
1
]
delta_escaped
=
full_escaped
[
len
(
prev_escaped
):]
self
.
accumulated_params
[
self
.
current_param_name
]
=
full_value
# Get parameter configuration for type conversion
param_config
=
self
.
_get_arguments_config
(
self
.
current_function_name
or
""
,
self
.
streaming_request
.
tools
if
self
.
streaming_request
else
None
)
# Convert the parameter value to the appropriate type
converted_value
=
self
.
_convert_param_value
(
full_value
,
self
.
current_param_name
or
""
,
param_config
,
self
.
current_function_name
or
""
)
# Serialize the converted value
serialized_value
=
json
.
dumps
(
converted_value
,
ensure_ascii
=
False
)
# Since we've been streaming the quoted version,
# we need to close it properly
# This is complex - for now just complete the value
self
.
in_param
=
False
self
.
current_param_value
=
""
# Just close the current parameter string
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
delta_escaped
+
'"'
),
arguments
=
'"'
),
# Close the string quote
)
])
else
:
...
...
@@ -638,18 +681,18 @@ class Qwen3CoderToolParser(ToolParser):
gt_idx
=
value_chunk
.
find
(
">"
)
value_chunk
=
value_chunk
[
gt_idx
+
1
:]
if
(
not
self
.
current_param_value
and
value_chunk
.
startswith
(
"
\n
"
)
)
:
if
not
self
.
current_param_value
and
value_chunk
.
startswith
(
"
\n
"
):
value_chunk
=
value_chunk
[
1
:]
if
value_chunk
:
# Stream the escaped delta
prev_escaped
=
(
json
.
dumps
(
self
.
current_param_value
)[
1
:
-
1
]
if
self
.
current_param_value
else
""
)
prev_escaped
=
json
.
dumps
(
self
.
current_param_value
,
ensure_ascii
=
False
)[
1
:
-
1
]
if
self
.
current_param_value
else
""
self
.
current_param_value
+=
value_chunk
full_escaped
=
json
.
dumps
(
self
.
current_param_v
al
u
e
)[
1
:
-
1
]
full_escaped
=
json
.
dumps
(
self
.
current_param_value
,
ensure_ascii
=
F
al
s
e
)[
1
:
-
1
]
delta_escaped
=
full_escaped
[
len
(
prev_escaped
):]
if
delta_escaped
:
...
...
vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from qwen3coder xml parser, All rights reserved.
# ruff: noqa: E501
import
ast
import
json
import
uuid
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
,
Union
import
regex
as
re
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionToolsParam
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
ToolParserManager
)
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
@
ToolParserManager
.
register_module
(
"seed_oss"
)
class
SeedOssToolParser
(
ToolParser
):
TOOL_CALL_START
=
"<seed:tool_call>"
TOOL_CALL_END
=
"</seed:tool_call>"
def
__init__
(
self
,
tokenizer
:
AnyTokenizer
):
super
().
__init__
(
tokenizer
)
# --- streaming state ---
self
.
_reset_streaming_state
()
self
.
prev_tool_call_arr
:
list
[
dict
]
=
[]
self
.
tool_call_start_token
:
str
=
self
.
TOOL_CALL_START
self
.
tool_call_end_token
:
str
=
self
.
TOOL_CALL_END
# Sentinel tokens for streaming mode
self
.
tool_call_prefix
:
str
=
"<function="
self
.
function_end_token
:
str
=
"</function>"
self
.
parameter_prefix
:
str
=
"<parameter="
self
.
parameter_end_token
:
str
=
"</parameter>"
self
.
think_start_token
:
str
=
"<seed:think>"
self
.
think_end_token
:
str
=
"</seed:think>"
self
.
is_tool_call_started
:
bool
=
False
self
.
is_thinking_end
:
bool
=
False
self
.
failed_count
:
int
=
0
self
.
_reset_streaming_state
()
self
.
tool_call_start_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_start_token
)
self
.
tool_call_end_token_id
=
self
.
vocab
.
get
(
self
.
tool_call_end_token
)
self
.
think_end_token_id
=
self
.
vocab
.
get
(
self
.
think_end_token
)
if
(
self
.
tool_call_start_token_id
is
None
or
self
.
tool_call_end_token_id
is
None
):
raise
RuntimeError
(
"Seed_Oss XML parser: tokenizer did not include "
"<seed:tool_call> or its closing tag."
)
tool_start_re
=
re
.
escape
(
self
.
tool_call_start_token
)
tool_end_re
=
re
.
escape
(
self
.
tool_call_end_token
)
self
.
tool_call_complete_regex
=
re
.
compile
(
rf
"
{
tool_start_re
}
(.*?)
{
tool_end_re
}
"
,
re
.
DOTALL
)
self
.
tool_call_regex
=
re
.
compile
(
rf
"
{
tool_start_re
}
(.*?)
{
tool_end_re
}
|
{
tool_start_re
}
(.*?)$"
,
re
.
DOTALL
)
self
.
tool_call_function_regex
=
re
.
compile
(
r
"<function=(.*?)</function>|<function=(.*)$"
,
re
.
DOTALL
)
self
.
tool_call_parameter_regex
=
re
.
compile
(
r
"<parameter=(.*?)</parameter>|<parameter=(.*?)$"
,
re
.
DOTALL
)
logger
.
info
(
"vLLM Seed-Oss XML tool parser loaded (%s)."
,
self
.
__class__
.
__name__
)
def
_generate_tool_call_id
(
self
)
->
str
:
"""Generate a unique tool call ID."""
return
f
"call_
{
uuid
.
uuid4
().
hex
[:
24
]
}
"
def
_reset_streaming_state
(
self
):
"""Reset all streaming state."""
self
.
current_tool_index
=
0
self
.
is_tool_call_started
=
False
self
.
header_sent
=
False
self
.
current_tool_id
=
-
1
self
.
current_function_name
=
None
self
.
current_param_name
=
None
self
.
current_param_value
=
""
self
.
param_count
=
0
self
.
in_param
=
False
self
.
in_function
=
False
self
.
accumulated_text
=
""
self
.
json_started
=
False
self
.
json_closed
=
False
def
_parse_xml_function_call
(
self
,
function_call_str
:
str
,
tools
:
Optional
[
list
[
ChatCompletionToolsParam
]]
)
->
Optional
[
ToolCall
]:
def
get_arguments_config
(
func_name
:
str
)
->
dict
:
if
tools
is
None
:
return
{}
for
config
in
tools
:
if
not
hasattr
(
config
,
"type"
)
or
not
(
hasattr
(
config
,
"function"
)
and
hasattr
(
config
.
function
,
"name"
)):
continue
if
(
config
.
type
==
"function"
and
config
.
function
.
name
==
func_name
):
if
not
hasattr
(
config
.
function
,
"parameters"
):
return
{}
params
=
config
.
function
.
parameters
if
isinstance
(
params
,
dict
)
and
"properties"
in
params
:
return
params
[
"properties"
]
elif
isinstance
(
params
,
dict
):
return
params
else
:
return
{}
logger
.
warning
(
"Tool '%s' is not defined in the tools list."
,
func_name
)
return
{}
def
convert_param_value
(
param_value
:
str
,
param_name
:
str
,
param_config
:
dict
,
func_name
:
str
)
->
Any
:
# Handle null value for any type
if
param_value
.
lower
()
==
"null"
:
return
None
if
param_name
not
in
param_config
:
if
param_config
!=
{}:
logger
.
warning
(
"Parsed parameter '%s' is not defined in "
"the tool parameters for tool '%s', "
"directly returning the string value."
,
param_name
,
func_name
)
return
param_value
if
(
isinstance
(
param_config
[
param_name
],
dict
)
and
"type"
in
param_config
[
param_name
]):
param_type
=
str
(
param_config
[
param_name
][
"type"
]).
strip
().
lower
()
else
:
param_type
=
"string"
if
param_type
in
[
"string"
,
"str"
,
"text"
,
"varchar"
,
"char"
,
"enum"
]:
return
param_value
elif
(
param_type
.
startswith
(
"int"
)
or
param_type
.
startswith
(
"uint"
)
or
param_type
.
startswith
(
"long"
)
or
param_type
.
startswith
(
"short"
)
or
param_type
.
startswith
(
"unsigned"
)):
try
:
param_value
=
int
(
param_value
)
# type: ignore
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not an integer in tool "
"'%s', degenerating to string."
,
param_value
,
param_name
,
func_name
)
return
param_value
elif
param_type
.
startswith
(
"num"
)
or
param_type
.
startswith
(
"float"
):
try
:
float_param_value
=
float
(
param_value
)
param_value
=
float_param_value
if
float_param_value
-
int
(
float_param_value
)
!=
0
else
int
(
float_param_value
)
# type: ignore
except
(
ValueError
,
TypeError
):
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a float in tool "
"'%s', degenerating to string."
,
param_value
,
param_name
,
func_name
)
return
param_value
elif
param_type
in
[
"boolean"
,
"bool"
,
"binary"
]:
param_value
=
param_value
.
lower
()
if
param_value
not
in
[
"true"
,
"false"
]:
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a boolean "
"(`true` of `false`) in tool '%s', degenerating to false."
,
param_value
,
param_name
,
func_name
)
return
param_value
==
"true"
else
:
if
param_type
==
"object"
or
param_type
.
startswith
(
"dict"
):
try
:
param_value
=
json
.
loads
(
param_value
)
return
param_value
except
(
ValueError
,
TypeError
,
json
.
JSONDecodeError
):
logger
.
warning
(
"Parsed value '%s' of parameter '%s' is not a valid JSON "
"object in tool '%s', will try other methods to parse it."
,
param_value
,
param_name
,
func_name
)
try
:
param_value
=
ast
.
literal_eval
(
param_value
)
except
(
ValueError
,
SyntaxError
):
logger
.
warning
(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', degenerating to string."
,
param_value
,
param_name
,
func_name
)
return
param_value
# Extract function name
end_index
=
function_call_str
.
index
(
">"
)
function_name
=
function_call_str
[:
end_index
]
param_config
=
get_arguments_config
(
function_name
)
parameters
=
function_call_str
[
end_index
+
1
:]
param_dict
=
{}
for
match
in
self
.
tool_call_parameter_regex
.
findall
(
parameters
):
match_text
=
match
[
0
]
if
match
[
0
]
else
match
[
1
]
idx
=
match_text
.
index
(
">"
)
param_name
=
match_text
[:
idx
]
param_value
=
str
(
match_text
[
idx
+
1
:])
# Remove prefix and trailing \n
if
param_value
.
startswith
(
"
\n
"
):
param_value
=
param_value
[
1
:]
if
param_value
.
endswith
(
"
\n
"
):
param_value
=
param_value
[:
-
1
]
param_dict
[
param_name
]
=
convert_param_value
(
param_value
,
param_name
,
param_config
,
function_name
)
return
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
function_name
,
arguments
=
json
.
dumps
(
param_dict
,
ensure_ascii
=
False
)),
)
def
_get_function_calls
(
self
,
model_output
:
str
)
->
list
[
str
]:
# Find all tool calls
matched_ranges
=
self
.
tool_call_regex
.
findall
(
model_output
)
raw_tool_calls
=
[
match
[
0
]
if
match
[
0
]
else
match
[
1
]
for
match
in
matched_ranges
]
# Back-off strategy if no tool_call tags found
if
len
(
raw_tool_calls
)
==
0
:
raw_tool_calls
=
[
model_output
]
raw_function_calls
=
[]
for
tool_call
in
raw_tool_calls
:
raw_function_calls
.
extend
(
self
.
tool_call_function_regex
.
findall
(
tool_call
))
function_calls
=
[
match
[
0
]
if
match
[
0
]
else
match
[
1
]
for
match
in
raw_function_calls
]
return
function_calls
def
extract_tool_calls
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
,
)
->
ExtractedToolCallInformation
:
# Quick check to avoid unnecessary processing
if
self
.
tool_call_prefix
not
in
model_output
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
# Check if both think start and end tokens are present
if
(
self
.
think_start_token
in
model_output
and
self
.
think_end_token
in
model_output
):
# Find the position of think end token
think_end_index
=
model_output
.
find
(
self
.
think_end_token
)
+
len
(
self
.
think_end_token
)
# Extract content after think end token
result_content
=
model_output
[
think_end_index
:]
thinking_content
=
model_output
[:
think_end_index
]
else
:
thinking_content
=
""
result_content
=
model_output
try
:
function_calls
=
self
.
_get_function_calls
(
result_content
)
if
len
(
function_calls
)
==
0
:
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
tool_calls
=
[
self
.
_parse_xml_function_call
(
function_call_str
,
request
.
tools
)
for
function_call_str
in
function_calls
]
# Populate prev_tool_call_arr for serving layer to set finish_reason
self
.
prev_tool_call_arr
.
clear
()
# Clear previous calls
for
tool_call
in
tool_calls
:
if
tool_call
:
self
.
prev_tool_call_arr
.
append
({
"name"
:
tool_call
.
function
.
name
,
"arguments"
:
tool_call
.
function
.
arguments
,
})
# Extract content before tool calls
tool_call_start_index
=
result_content
.
find
(
self
.
tool_call_start_token
)
tool_call_start_index
=
(
tool_call_start_index
if
tool_call_start_index
>=
0
else
result_content
.
find
(
self
.
tool_call_prefix
))
content
=
thinking_content
+
result_content
[:
tool_call_start_index
]
return
ExtractedToolCallInformation
(
tools_called
=
(
len
(
tool_calls
)
>
0
),
tool_calls
=
tool_calls
,
content
=
content
if
content
else
None
,
)
except
Exception
:
logger
.
exception
(
"Error in extracting tool call from response."
)
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
Union
[
DeltaMessage
,
None
]:
# If no delta text, return None unless
# it's an EOS token after tool calls
if
not
delta_text
:
# Check if this is an EOS token after all tool calls are complete
# We check for tool calls in the text even if is_tool_call_started
# is False because it might have been reset after processing all tools
if
(
delta_token_ids
and
self
.
tool_call_end_token_id
not
in
delta_token_ids
):
# Count complete tool calls
complete_calls
=
len
(
self
.
tool_call_complete_regex
.
findall
(
current_text
))
# If we have completed tool calls and populated prev_tool_call_arr
if
complete_calls
>
0
and
len
(
self
.
prev_tool_call_arr
)
>
0
:
# Check if all tool calls are closed
open_calls
=
current_text
.
count
(
self
.
tool_call_start_token
)
-
current_text
.
count
(
self
.
tool_call_end_token
)
if
open_calls
==
0
:
# Return empty delta message to allow finish_reason processing
return
DeltaMessage
(
content
=
""
)
elif
not
self
.
is_tool_call_started
and
current_text
:
# This is a regular content response that's now complete
return
DeltaMessage
(
content
=
""
)
return
None
# Check if this is the first call (reset state if needed)
if
not
previous_text
:
self
.
_reset_streaming_state
()
# Update accumulated text
self
.
accumulated_text
=
current_text
# Check if we need to advance to next tool
if
self
.
json_closed
and
not
self
.
in_function
:
# Check if this tool call has ended
tool_ends
=
current_text
.
count
(
self
.
tool_call_end_token
)
if
tool_ends
>
self
.
current_tool_index
:
# This tool has ended, advance to next
self
.
current_tool_index
+=
1
self
.
header_sent
=
False
self
.
param_count
=
0
self
.
json_started
=
False
self
.
json_closed
=
False
# Check if there are more tool calls
if
self
.
current_tool_index
>=
current_text
.
count
(
self
.
tool_call_start_token
):
# No more tool calls
self
.
is_tool_call_started
=
False
# Continue processing next tool
return
None
# Check if end thinking
if
(
not
self
.
is_thinking_end
and
(
self
.
think_end_token_id
in
delta_token_ids
or
self
.
think_end_token
in
delta_text
)):
self
.
is_thinking_end
=
True
# If thinking hasn't ended yet, don't process any tool calls
if
not
self
.
is_thinking_end
:
return
DeltaMessage
(
content
=
delta_text
)
# Handle normal content before tool calls
if
not
self
.
is_tool_call_started
:
# Check if tool call is starting
if
(
self
.
tool_call_start_token_id
in
delta_token_ids
or
self
.
tool_call_start_token
in
delta_text
):
self
.
is_tool_call_started
=
True
# Return any content before the tool call
if
self
.
tool_call_start_token
in
delta_text
:
content_before
=
delta_text
[:
delta_text
.
index
(
self
.
tool_call_start_token
)]
if
content_before
:
return
DeltaMessage
(
content
=
content_before
)
return
None
else
:
# Check if we're between tool calls - skip whitespace
if
(
current_text
.
rstrip
().
endswith
(
self
.
tool_call_end_token
)
and
delta_text
.
strip
()
==
""
):
# We just ended a tool call, skip whitespace
return
None
# Normal content, no tool call
return
DeltaMessage
(
content
=
delta_text
)
# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
tool_starts_count
=
current_text
.
count
(
self
.
tool_call_start_token
)
if
self
.
current_tool_index
>=
tool_starts_count
:
# We're past all tool calls, shouldn't be here
return
None
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
# Only process tool calls after think_end_token
think_end_index
=
current_text
.
find
(
self
.
think_end_token
)
+
len
(
self
.
think_end_token
)
if
self
.
think_end_token
in
current_text
else
0
tool_starts
:
list
[
int
]
=
[]
idx
=
think_end_index
while
True
:
idx
=
current_text
.
find
(
self
.
tool_call_start_token
,
idx
)
if
idx
==
-
1
:
break
tool_starts
.
append
(
idx
)
idx
+=
len
(
self
.
tool_call_start_token
)
if
self
.
current_tool_index
>=
len
(
tool_starts
):
# No more tool calls to process yet
return
None
tool_start_idx
=
tool_starts
[
self
.
current_tool_index
]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx
=
current_text
.
find
(
self
.
tool_call_end_token
,
tool_start_idx
)
if
tool_end_idx
==
-
1
:
tool_text
=
current_text
[
tool_start_idx
:]
else
:
tool_text
=
current_text
[
tool_start_idx
:
tool_end_idx
+
len
(
self
.
tool_call_end_token
)]
# Looking for function header
if
not
self
.
header_sent
:
if
self
.
tool_call_prefix
in
tool_text
:
func_start
=
tool_text
.
find
(
self
.
tool_call_prefix
)
+
len
(
self
.
tool_call_prefix
)
func_end
=
tool_text
.
find
(
">"
,
func_start
)
if
func_end
!=
-
1
:
# Found complete function name
self
.
current_function_name
=
tool_text
[
func_start
:
func_end
]
self
.
current_tool_id
=
self
.
_generate_tool_call_id
(
)
# type: ignore
self
.
header_sent
=
True
self
.
in_function
=
True
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
# This ensures finish_reason="tool_calls" even if parsing isn't complete
already_added
=
any
(
tool
.
get
(
"name"
)
==
self
.
current_function_name
for
tool
in
self
.
prev_tool_call_arr
)
if
not
already_added
:
self
.
prev_tool_call_arr
.
append
({
"name"
:
self
.
current_function_name
,
"arguments"
:
"{}"
,
# Placeholder, will be updated later
})
# Send header with function info
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
id
=
self
.
current_tool_id
,
function
=
DeltaFunctionCall
(
name
=
self
.
current_function_name
,
arguments
=
""
),
type
=
"function"
,
)
])
return
None
# We've sent header, now handle function body
if
self
.
in_function
:
# Send opening brace if not sent yet
if
(
not
self
.
json_started
and
self
.
parameter_prefix
not
in
delta_text
):
self
.
json_started
=
True
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
"{"
),
)
])
# Make sure json_started is set if we're processing parameters
if
not
self
.
json_started
:
self
.
json_started
=
True
# Check for function end in accumulated text
if
not
self
.
json_closed
and
self
.
function_end_token
in
tool_text
:
# Close JSON
self
.
json_closed
=
True
# Extract the complete tool call to update prev_tool_call_arr with final arguments
# Find the function content
func_start
=
tool_text
.
find
(
self
.
tool_call_prefix
)
+
len
(
self
.
tool_call_prefix
)
func_content_end
=
tool_text
.
find
(
self
.
function_end_token
,
func_start
)
if
func_content_end
!=
-
1
:
func_content
=
tool_text
[
func_start
:
func_content_end
]
# Parse to get the complete arguments
try
:
parsed_tool
=
self
.
_parse_xml_function_call
(
func_content
,
request
.
tools
if
request
else
None
)
if
parsed_tool
:
# Update existing entry in prev_tool_call_arr with complete arguments
for
i
,
tool
in
enumerate
(
self
.
prev_tool_call_arr
):
if
tool
.
get
(
"name"
)
==
parsed_tool
.
function
.
name
:
self
.
prev_tool_call_arr
[
i
][
"arguments"
]
=
(
parsed_tool
.
function
.
arguments
)
break
except
Exception
:
logger
.
warning
(
"Failed to parse tool arguments during streaming."
,
exc_info
=
True
)
result
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
"}"
),
)
])
# Reset state for next tool
self
.
in_function
=
False
self
.
json_closed
=
True
return
result
# Look for parameters
# Count how many complete parameters we have processed
complete_params
=
tool_text
.
count
(
self
.
parameter_end_token
)
# Check if we should start a new parameter
if
not
self
.
in_param
and
self
.
param_count
<
complete_params
:
# Find the unprocessed parameter
# Count parameter starts
param_starts
=
[]
idx
=
0
while
True
:
idx
=
tool_text
.
find
(
self
.
parameter_prefix
,
idx
)
if
idx
==
-
1
:
break
param_starts
.
append
(
idx
)
idx
+=
len
(
self
.
parameter_prefix
)
if
len
(
param_starts
)
>
self
.
param_count
:
# Process the next parameter
param_idx
=
param_starts
[
self
.
param_count
]
param_start
=
param_idx
+
len
(
self
.
parameter_prefix
)
remaining
=
tool_text
[
param_start
:]
if
">"
in
remaining
:
# We have the complete parameter name
name_end
=
remaining
.
find
(
">"
)
self
.
current_param_name
=
remaining
[:
name_end
]
# Find the parameter value
value_start
=
param_start
+
name_end
+
1
value_text
=
tool_text
[
value_start
:]
if
value_text
.
startswith
(
"
\n
"
):
value_text
=
value_text
[
1
:]
# Find where this parameter ends
param_end_idx
=
value_text
.
find
(
self
.
parameter_end_token
)
if
param_end_idx
!=
-
1
:
# Complete parameter found
param_value
=
value_text
[:
param_end_idx
]
if
param_value
.
endswith
(
"
\n
"
):
param_value
=
param_value
[:
-
1
]
# Build complete JSON fragment for this parameter
if
self
.
param_count
==
0
:
json_fragment
=
(
'"'
+
self
.
current_param_name
+
'": "'
+
json
.
dumps
(
param_value
)[
1
:
-
1
]
+
'"'
)
else
:
json_fragment
=
(
', "'
+
self
.
current_param_name
+
'": "'
+
json
.
dumps
(
param_value
)[
1
:
-
1
]
+
'"'
)
self
.
param_count
+=
1
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
json_fragment
),
)
])
# Continue parameter value
if
self
.
in_param
:
if
self
.
parameter_end_token
in
delta_text
:
# End of parameter
end_idx
=
delta_text
.
find
(
self
.
parameter_end_token
)
value_chunk
=
delta_text
[:
end_idx
]
# Skip past > if at start
if
not
self
.
current_param_value
and
">"
in
value_chunk
:
gt_idx
=
value_chunk
.
find
(
">"
)
value_chunk
=
value_chunk
[
gt_idx
+
1
:]
if
not
self
.
current_param_value
and
value_chunk
.
startswith
(
"
\n
"
):
value_chunk
=
value_chunk
[
1
:]
# Calculate incremental JSON
full_value
=
self
.
current_param_value
+
value_chunk
prev_escaped
=
(
json
.
dumps
(
self
.
current_param_value
)[
1
:
-
1
]
if
self
.
current_param_value
else
""
)
full_escaped
=
json
.
dumps
(
full_value
)[
1
:
-
1
]
delta_escaped
=
full_escaped
[
len
(
prev_escaped
):]
self
.
in_param
=
False
self
.
current_param_value
=
""
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
delta_escaped
+
'"'
),
)
])
else
:
# Continue accumulating value
value_chunk
=
delta_text
# Handle first chunk after param name
if
not
self
.
current_param_value
and
">"
in
value_chunk
:
gt_idx
=
value_chunk
.
find
(
">"
)
value_chunk
=
value_chunk
[
gt_idx
+
1
:]
if
not
self
.
current_param_value
and
value_chunk
.
startswith
(
"
\n
"
):
value_chunk
=
value_chunk
[
1
:]
if
value_chunk
:
# Stream the escaped delta
prev_escaped
=
(
json
.
dumps
(
self
.
current_param_value
)[
1
:
-
1
]
if
self
.
current_param_value
else
""
)
self
.
current_param_value
+=
value_chunk
full_escaped
=
json
.
dumps
(
self
.
current_param_value
)[
1
:
-
1
]
delta_escaped
=
full_escaped
[
len
(
prev_escaped
):]
if
delta_escaped
:
return
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
self
.
current_tool_index
,
function
=
DeltaFunctionCall
(
arguments
=
delta_escaped
),
)
])
return
None
vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py
View file @
d2b52805
...
...
@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
import
regex
as
re
from
vllm.entrypoints.chat_utils
import
random
_tool_call_id
from
vllm.entrypoints.chat_utils
import
make
_tool_call_id
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
...
...
@@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
"""
Extract tool calls for streaming mode.
"""
# Simplify detection: if it begins with "[" treat it as a function call
is_function_call
=
(
current_text
.
strip
().
startswith
(
"["
))
# If not a function call, return normal content
if
not
is_function_call
:
# First, check for a definitive start of a tool call block.
# This prevents premature parsing of incomplete output.
stripped_text
=
current_text
.
strip
()
preprocessed_content
,
preprocessed_tool_calls
=
(
self
.
preprocess_model_output
(
current_text
))
# For JSON code blocks, we need to detect them earlier, even if incomplete
has_potential_json_block
=
(
"```json"
in
current_text
or
"```
\n
["
in
current_text
or
"[TOOL_CALLS]"
in
current_text
or
"<tool_call>"
in
current_text
)
is_tool_call_block
=
(
stripped_text
.
startswith
(
"["
)
or
stripped_text
.
startswith
(
"<tool_call>"
)
or
stripped_text
.
startswith
(
"[TOOL_CALLS]"
)
or
# Check if we have thinking tags with JSON-like content following
(
"</think>["
in
current_text
)
or
# Check if the text contains a JSON array after preprocessing
preprocessed_tool_calls
is
not
None
or
# For JSON code blocks, detect early if we see enough structure
(
has_potential_json_block
and
'"name"'
in
current_text
and
'"arguments"'
in
current_text
))
if
not
is_tool_call_block
:
return
DeltaMessage
(
content
=
delta_text
)
try
:
...
...
@@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
# Try parsing as JSON to check for complete tool calls
try
:
parsed_tools
=
json
.
loads
(
current_text
)
# Use preprocessed tool calls if available
tool_calls_text
=
(
preprocessed_tool_calls
if
preprocessed_tool_calls
else
current_text
)
parsed_tools
=
json
.
loads
(
tool_calls_text
)
if
isinstance
(
parsed_tools
,
list
):
# Update our tool array for next time
self
.
prev_tool_call_arr
=
parsed_tools
...
...
@@ -226,7 +249,7 @@ class xLAMToolParser(ToolParser):
function_name
=
name_match
.
group
(
1
)
# The test expects us to send just the name first
tool_id
=
random
_tool_call_id
()
tool_id
=
make
_tool_call_id
()
delta
=
DeltaMessage
(
tool_calls
=
[
DeltaToolCall
(
index
=
0
,
...
...
@@ -257,12 +280,39 @@ class xLAMToolParser(ToolParser):
return
delta
# Use regex to identify tool calls in the output
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
search_text
=
(
preprocessed_tool_calls
if
preprocessed_tool_calls
else
current_text
)
# For JSON code blocks that aren't complete yet, try to extract the JSON content
if
not
preprocessed_tool_calls
and
has_potential_json_block
:
# Try to extract the JSON array from within the code block
json_match
=
re
.
search
(
r
"```(?:json)?\s*([\s\S]*?)(?:```|$)"
,
current_text
)
if
json_match
:
potential_json
=
json_match
.
group
(
1
).
strip
()
# Use this as search text even if it's incomplete
if
potential_json
.
startswith
(
"["
)
and
(
'"name"'
in
potential_json
and
'"arguments"'
in
potential_json
):
search_text
=
potential_json
# Try to find complete tool names first
name_pattern
=
r
'"name"\s*:\s*"([^"]+)"'
name_matches
=
list
(
re
.
finditer
(
name_pattern
,
current
_text
))
name_matches
=
list
(
re
.
finditer
(
name_pattern
,
search
_text
))
tool_count
=
len
(
name_matches
)
# If no
tools found yet, return
# If no
complete tool names found, check for partial tool names
if
tool_count
==
0
:
# Check if we're in the middle of parsing a tool name
partial_name_pattern
=
r
'"name"\s*:\s*"([^"]*)'
partial_matches
=
list
(
re
.
finditer
(
partial_name_pattern
,
search_text
))
if
partial_matches
:
# We have a partial tool name - not ready to emit yet
return
None
else
:
# No tools found at all
return
None
# Ensure our state arrays are large enough
...
...
@@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
# First, check for the empty arguments case: "arguments": {}
empty_args_pattern
=
(
r
'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}'
)
empty_args_match
=
re
.
search
(
empty_args_pattern
,
current
_text
)
empty_args_match
=
re
.
search
(
empty_args_pattern
,
search
_text
)
# Check if this tool has empty arguments
if
empty_args_match
and
empty_args_match
.
start
()
>
0
:
...
...
@@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
# Extract arguments for current tool using regex for non-empty arguments
args_pattern
=
r
'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
args_matches
=
list
(
re
.
finditer
(
args_pattern
,
current
_text
))
args_matches
=
list
(
re
.
finditer
(
args_pattern
,
search
_text
))
if
current_idx
<
len
(
args_matches
):
args_text
=
args_matches
[
current_idx
].
group
(
1
)
...
...
@@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
# Handle transition between tools
is_last_tool
=
current_idx
==
tool_count
-
1
# Find where the arguments for our current tool end
if
not
is_last_tool
:
# If we have more tools after this one, try to find the complete argument block
next_tool_pos
=
current_text
.
find
(
"},{"
,
args_matches
[
current_idx
].
start
())
if
next_tool_pos
!=
-
1
:
args_end_pos
=
(
next_tool_pos
+
1
)
# +1 to include the '}'
args_text
=
(
current_text
[
args_matches
[
current_idx
]
.
start
():
args_end_pos
].
split
(
'"arguments":'
)[
1
].
strip
())
# For multiple tools, extract only the arguments for the current tool
if
tool_count
>
1
:
# Parse the entire JSON structure to properly extract arguments for each tool
try
:
parsed_tools
=
json
.
loads
(
search_text
)
if
isinstance
(
parsed_tools
,
list
)
and
current_idx
<
len
(
parsed_tools
):
current_tool
=
parsed_tools
[
current_idx
]
if
isinstance
(
current_tool
.
get
(
"arguments"
),
dict
):
args_text
=
json
.
dumps
(
current_tool
[
"arguments"
])
else
:
args_text
=
str
(
current_tool
.
get
(
"arguments"
,
"{}"
))
except
(
json
.
JSONDecodeError
,
KeyError
,
IndexError
):
# Fallback to regex-based extraction
pass
# If arguments haven't been sent yet
sent_args
=
self
.
streaming_state
[
"sent_tools"
][
...
...
vllm/entrypoints/utils.py
View file @
d2b52805
...
...
@@ -313,12 +313,14 @@ def log_non_default_args(args: Union[argparse.Namespace, EngineArgs]):
# Handle EngineArgs instance
elif
isinstance
(
args
,
EngineArgs
):
default_args
=
EngineArgs
()
# Create default instance
default_args
=
EngineArgs
(
model
=
args
.
model
)
# Create default instance
for
field
in
dataclasses
.
fields
(
args
):
current_val
=
getattr
(
args
,
field
.
name
)
default_val
=
getattr
(
default_args
,
field
.
name
)
if
current_val
!=
default_val
:
non_default_args
[
field
.
name
]
=
current_val
if
default_args
.
model
!=
EngineArgs
.
model
:
non_default_args
[
"model"
]
=
default_args
.
model
else
:
raise
TypeError
(
"Unsupported argument type. "
\
"Must be argparse.Namespace or EngineArgs instance."
)
...
...
vllm/envs.py
View file @
d2b52805
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
hashlib
import
json
import
os
import
sys
import
tempfile
...
...
@@ -42,7 +43,6 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_USE_FLASHINFER_SAMPLER
:
Optional
[
bool
]
=
None
VLLM_FLASHINFER_FORCE_TENSOR_CORES
:
bool
=
False
VLLM_PP_LAYER_PARTITION
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
Optional
[
int
]
=
0
VLLM_CPU_OMP_THREADS_BIND
:
str
=
""
...
...
@@ -99,6 +99,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_RMSNORM
:
bool
=
True
VLLM_ROCM_USE_AITER_MLA
:
bool
=
True
VLLM_ROCM_USE_AITER_MHA
:
bool
=
True
VLLM_ROCM_USE_AITER_FP8BMM
:
bool
=
True
VLLM_ROCM_USE_SKINNY_GEMM
:
bool
=
True
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ROCM_MOE_PADDING
:
bool
=
True
...
...
@@ -131,7 +132,9 @@ if TYPE_CHECKING:
VLLM_TPU_USING_PATHWAYS
:
bool
=
False
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_USE_DEEP_GEMM_E8M0
:
bool
=
True
VLLM_USE_DEEP_GEMM_E8M0_HOPPER
:
bool
=
False
VLLM_SKIP_DEEP_GEMM_WARMUP
:
bool
=
False
VLLM_USE_FUSED_MOE_GROUPED_TOPK
:
bool
=
True
VLLM_USE_FLASHINFER_MOE_FP8
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_FP4
:
bool
=
False
VLLM_FLASHINFER_MOE_BACKEND
:
str
=
"throughput"
...
...
@@ -159,9 +162,12 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE
:
bool
=
False
VLLM_ENABLE_RESPONSES_API_STORE
:
bool
=
False
VLLM_USE_TRTLLM_ATTENTION
:
Optional
[
str
]
=
None
VLLM_HAS_FLASHINFER_CUBIN
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
:
bool
=
False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
:
bool
=
False
VLLM_ALLREDUCE_USE_SYMM_MEM
:
bool
=
False
VLLM_TUNED_CONFIG_FOLDER
:
Optional
[
str
]
=
None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -465,11 +471,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
bool
(
int
(
os
.
environ
[
"VLLM_USE_FLASHINFER_SAMPLER"
]))
if
"VLLM_USE_FLASHINFER_SAMPLER"
in
os
.
environ
else
None
,
# If set, vllm will force flashinfer to use tensor cores;
# otherwise will use heuristic based on model architecture.
"VLLM_FLASHINFER_FORCE_TENSOR_CORES"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_FLASHINFER_FORCE_TENSOR_CORES"
,
"0"
))),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION"
:
lambda
:
os
.
getenv
(
"VLLM_PP_LAYER_PARTITION"
,
None
),
...
...
@@ -667,11 +668,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_LORA_RESOLVER_CACHE_DIR"
:
lambda
:
os
.
getenv
(
"VLLM_LORA_RESOLVER_CACHE_DIR"
,
None
),
# Enables torch profiler if set. Path to the directory where torch profiler
# traces are saved. Note that it must be an absolute path.
# Enables torch profiler if set.
# Both AsyncLLM's CPU traces as well as workers'
# traces (CPU & GPU) will be saved under this directory.
# Note that it must be an absolute path.
"VLLM_TORCH_PROFILER_DIR"
:
lambda
:
(
None
if
os
.
getenv
(
"VLLM_TORCH_PROFILER_DIR"
,
None
)
is
None
else
os
.
path
.
expanduser
(
os
.
getenv
(
"VLLM_TORCH_PROFILER_DIR"
,
"."
))),
.
path
.
abspath
(
os
.
path
.
expanduser
(
os
.
getenv
(
"VLLM_TORCH_PROFILER_DIR"
,
"."
)))),
# Enable torch profiler to record shapes if set
# VLLM_TORCH_PROFILER_RECORD_SHAPES=1. If not set, torch profiler will
...
...
@@ -771,6 +775,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_MHA"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_FP8BMM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_SKINNY_GEMM"
,
"True"
).
lower
()
in
...
...
@@ -953,9 +963,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM"
,
"0"
))),
# Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
# E8M0 is faster on B200 but may reduce accuracy.
"VLLM_USE_DEEP_GEMM_E8M0"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM_E8M0"
,
"1"
))),
# TODO(wentao): unify the two E8M0 flags after verifying the correctness.
# Whether to use E8M0 scaling when DeepGEMM is used on Hopper GPUs.
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER"
,
"0"
))),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
...
...
@@ -964,6 +977,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_SKIP_DEEP_GEMM_WARMUP"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_SKIP_DEEP_GEMM_WARMUP"
,
"0"
))),
# Whether to use fused grouped_topk used for MoE expert selection.
"VLLM_USE_FUSED_MOE_GROUPED_TOPK"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FUSED_MOE_GROUPED_TOPK"
,
"1"
))),
# Allow use of FlashInfer MoE kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE_FP8"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_USE_FLASHINFER_MOE_FP8"
,
"0"
))),
...
...
@@ -1042,6 +1059,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"
,
"163840"
)),
# Specifies the thresholds of the communicated tensor sizes under which
# vllm should use flashinfer fused allreduce. The variable should be a
# JSON with the following format:
# { <world size>: <max size in mb> }
# Unspecified world sizes will fallback to
# { 2: 64, 4: 1, <everything else>: 0.5 }
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB"
:
lambda
:
json
.
loads
(
os
.
getenv
(
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB"
,
"{}"
)),
# MoE routing strategy selector.
# See `RoutingSimulator.get_available_strategies()` # for available
# strategies.
...
...
@@ -1108,6 +1135,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION"
:
lambda
:
os
.
getenv
(
"VLLM_USE_TRTLLM_ATTENTION"
,
None
),
# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN"
:
lambda
:
os
.
getenv
(
"VLLM_HAS_FLASHINFER_CUBIN"
,
False
),
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM.
...
...
@@ -1120,6 +1152,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_CUDAGRAPH_GC"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_CUDAGRAPH_GC"
,
"0"
))),
# Disable padding to CUDA graph capture batch sizes.
# TODO(wentao): https://github.com/vllm-project/vllm/issues/23378
# After the issue is fixed, we can remove this flag.
"VLLM_DISABLE_PAD_FOR_CUDAGRAPH"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_DISABLE_PAD_FOR_CUDAGRAPH"
,
"0"
))),
# Used to force set up loopback IP
"VLLM_LOOPBACK_IP"
:
lambda
:
os
.
getenv
(
"VLLM_LOOPBACK_IP"
,
""
),
...
...
@@ -1153,6 +1191,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RESPONSES_API_STORE"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_RESPONSES_API_STORE"
,
"0"
))),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ALLREDUCE_USE_SYMM_MEM"
,
"0"
))),
# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER"
:
lambda
:
os
.
getenv
(
"VLLM_TUNED_CONFIG_FOLDER"
,
None
),
...
...
@@ -1218,10 +1260,12 @@ def compute_hash() -> str:
"VLLM_USE_AITER_UNIFIED_ATTENTION"
,
"VLLM_ATTENTION_BACKEND"
,
"VLLM_USE_FLASHINFER_SAMPLER"
,
"VLLM_FLASHINFER_FORCE_TENSOR_CORES"
,
"VLLM_DISABLED_KERNELS"
,
"VLLM_USE_DEEP_GEMM"
,
"VLLM_USE_DEEP_GEMM_E8M0"
,
"VLLM_USE_DEEP_GEMM_E8M0_HOPPER"
,
"VLLM_USE_TRTLLM_FP4_GEMM"
,
"VLLM_USE_FUSED_MOE_GROUPED_TOPK"
,
"VLLM_USE_FLASHINFER_MOE_FP8"
,
"VLLM_USE_FLASHINFER_MOE_FP4"
,
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"
,
...
...
@@ -1235,6 +1279,7 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_RMSNORM"
,
"VLLM_ROCM_USE_AITER_MLA"
,
"VLLM_ROCM_USE_AITER_MHA"
,
"VLLM_ROCM_USE_AITER_FP8BMM"
,
"VLLM_ROCM_USE_SKINNY_GEMM"
,
"VLLM_ROCM_FP8_PADDING"
,
"VLLM_ROCM_MOE_PADDING"
,
...
...
vllm/executor/mp_distributed_executor.py
View file @
d2b52805
...
...
@@ -101,7 +101,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
result_handler
.
start
()
self
.
worker_monitor
.
start
()
# Set up signal handlers to shutdown the executor cleanly
# Set up signal handlers to shut
down the executor cleanly
# sometimes gc does not work well
self
.
driver_worker
=
WorkerWrapperBase
(
self
.
vllm_config
,
0
)
...
...
vllm/executor/msgspec_utils.py
View file @
d2b52805
...
...
@@ -4,11 +4,12 @@
from
array
import
array
from
typing
import
Any
,
Type
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
def
encode_hook
(
obj
:
Any
)
->
Any
:
"""Custom msgspec enc hook that supports array types.
"""Custom msgspec enc hook that supports array types
and MultiModalKwargs
.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
...
...
@@ -17,10 +18,12 @@ def encode_hook(obj: Any) -> Any:
f
"vLLM array type should use '
{
VLLM_TOKEN_ID_ARRAY_TYPE
}
' type. "
f
"Given array has a type code of
{
obj
.
typecode
}
."
)
return
obj
.
tobytes
()
if
isinstance
(
obj
,
MultiModalKwargs
):
return
dict
(
obj
)
def
decode_hook
(
type
:
Type
,
obj
:
Any
)
->
Any
:
"""Custom msgspec dec hook that supports array types.
"""Custom msgspec dec hook that supports array types
and MultiModalKwargs
.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
"""
...
...
@@ -28,3 +31,5 @@ def decode_hook(type: Type, obj: Any) -> Any:
deserialized
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
)
deserialized
.
frombytes
(
obj
)
return
deserialized
if
type
is
MultiModalKwargs
:
return
MultiModalKwargs
(
obj
)
vllm/executor/ray_utils.py
View file @
d2b52805
...
...
@@ -10,6 +10,7 @@ import msgspec
import
vllm.platforms
from
vllm.config
import
ParallelConfig
from
vllm.distributed
import
get_pp_group
from
vllm.executor.msgspec_utils
import
decode_hook
,
encode_hook
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
...
...
@@ -136,6 +137,11 @@ try:
scheduler_output
,
intermediate_tensors
)
if
isinstance
(
output
,
IntermediateTensors
):
output
=
scheduler_output
,
output
elif
not
get_pp_group
().
is_last_rank
:
# Case where there are no scheduled requests
# but may still be finished requests.
assert
not
output
or
not
output
.
req_ids
output
=
scheduler_output
,
None
return
output
def
override_env_vars
(
self
,
vars
:
Dict
[
str
,
str
]):
...
...
vllm/inputs/__init__.py
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.data
import
(
DecoderOnlyInputs
,
EmbedsInputs
,
EmbedsPrompt
,
from
.data
import
(
DataPrompt
,
DecoderOnlyInputs
,
EmbedsInputs
,
EmbedsPrompt
,
EncoderDecoderInputs
,
ExplicitEncoderDecoderPrompt
,
ProcessorInputs
,
PromptType
,
SingletonInputs
,
SingletonPrompt
,
TextPrompt
,
TokenInputs
,
TokensPrompt
,
...
...
@@ -18,6 +18,7 @@ target model.
"""
__all__
=
[
"DataPrompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"PromptType"
,
...
...
vllm/inputs/data.py
View file @
d2b52805
...
...
@@ -7,7 +7,8 @@ import torch
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeIs
,
TypeVar
if
TYPE_CHECKING
:
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalInputs
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalInputs
,
MultiModalUUIDDict
)
class
TextPrompt
(
TypedDict
):
...
...
@@ -30,6 +31,15 @@ class TextPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids
:
NotRequired
[
"MultiModalUUIDDict"
]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching, and MUST be unique per
multimodal item.
"""
cache_salt
:
NotRequired
[
str
]
"""
Optional cache salt to be used for prefix caching.
...
...
@@ -59,6 +69,14 @@ class TokensPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids
:
NotRequired
[
"MultiModalUUIDDict"
]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching.
"""
cache_salt
:
NotRequired
[
str
]
"""
Optional cache salt to be used for prefix caching.
...
...
@@ -77,6 +95,16 @@ class EmbedsPrompt(TypedDict):
"""
class
DataPrompt
(
TypedDict
):
"""Represents generic inputs handled by IO processor plugins."""
data
:
Any
"""The input data"""
data_format
:
str
"""The input data format"""
SingletonPrompt
=
Union
[
str
,
TextPrompt
,
TokensPrompt
,
EmbedsPrompt
]
"""
Set of possible schemas for a single prompt:
...
...
@@ -174,9 +202,6 @@ class TokenInputs(TypedDict):
prompt_token_ids
:
list
[
int
]
"""The token IDs of the prompt."""
token_type_ids
:
NotRequired
[
list
[
int
]]
"""The token type IDs of the prompt."""
prompt
:
NotRequired
[
str
]
"""
The original prompt text corresponding to the token IDs, if available.
...
...
@@ -190,7 +215,6 @@ class TokenInputs(TypedDict):
def
token_inputs
(
prompt_token_ids
:
list
[
int
],
token_type_ids
:
Optional
[
list
[
int
]]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
TokenInputs
:
...
...
@@ -200,8 +224,6 @@ def token_inputs(
if
prompt
is
not
None
:
inputs
[
"prompt"
]
=
prompt
if
token_type_ids
is
not
None
:
inputs
[
"token_type_ids"
]
=
token_type_ids
if
cache_salt
is
not
None
:
inputs
[
"cache_salt"
]
=
cache_salt
...
...
vllm/inputs/preprocess.py
View file @
d2b52805
...
...
@@ -11,8 +11,9 @@ from vllm.config import ModelConfig
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal.cache
import
BaseMultiModalProcessorCache
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
MultiModalInputs
,
MultiModalUUIDDict
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
...
...
@@ -32,12 +33,14 @@ class InputPreprocessor:
model_config
:
ModelConfig
,
tokenizer
:
Optional
[
TokenizerGroup
],
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
mm_processor_cache
:
Optional
[
BaseMultiModalProcessorCache
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
model_config
=
model_config
self
.
tokenizer
=
tokenizer
self
.
mm_registry
=
mm_registry
self
.
mm_processor_cache
=
mm_processor_cache
def
get_tokenizer_group
(
self
)
->
TokenizerGroup
:
if
self
.
tokenizer
is
None
:
...
...
@@ -254,7 +257,9 @@ class InputPreprocessor:
mm_processor_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
MultiModalInputs
:
"""
Apply the model's multi-modal processor to a multi-modal prompt,
...
...
@@ -262,17 +267,22 @@ class InputPreprocessor:
"""
tokenizer
=
self
.
_get_mm_tokenizer
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
=
tokenizer
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
=
tokenizer
,
cache
=
self
.
mm_processor_cache
,
)
if
mm_processor_kwargs
is
None
:
mm_processor_kwargs
=
{}
return
mm_processor
.
apply
(
prompt
,
return
mm_processor
.
apply
(
prompt
,
mm_data
,
hf_processor_mm_kwargs
=
mm_processor_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
return_mm_hashes
=
return_mm_hashes
)
mm_hash_overrides
=
mm_hash_overrides
,
)
async
def
_process_multimodal_async
(
self
,
...
...
@@ -281,7 +291,9 @@ class InputPreprocessor:
mm_processor_kwargs
:
Optional
[
Mapping
[
str
,
object
]],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
MultiModalInputs
:
"""
Async version of
...
...
@@ -289,16 +301,22 @@ class InputPreprocessor:
"""
tokenizer
=
await
self
.
_get_mm_tokenizer_async
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
=
tokenizer
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
tokenizer
=
tokenizer
,
cache
=
self
.
mm_processor_cache
,
)
if
mm_processor_kwargs
is
None
:
mm_processor_kwargs
=
{}
return
mm_processor
.
apply
(
prompt
,
return
mm_processor
.
apply
(
prompt
,
mm_data
,
hf_processor_mm_kwargs
=
mm_processor_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
return_mm_hashes
=
return_mm_hashes
)
mm_hash_overrides
=
mm_hash_overrides
,
)
def
_process_embeds
(
self
,
...
...
@@ -330,15 +348,33 @@ class InputPreprocessor:
)
->
EmbedsInputs
:
return
self
.
_process_embeds
(
parsed_content
)
def
_truncate_inputs
(
self
,
inputs
:
list
[
int
],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
list
[
int
]:
if
not
tokenization_kwargs
or
"truncation"
not
in
\
tokenization_kwargs
or
self
.
tokenizer
is
None
:
return
inputs
max_length
=
tokenization_kwargs
[
"max_length"
]
if
self
.
tokenizer
.
truncation_side
==
"left"
:
return
inputs
[
-
max_length
:]
else
:
return
inputs
[:
max_length
]
def
_process_tokens
(
self
,
parsed_content
:
TokensPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
prompt_token_ids
=
parsed_content
[
"prompt_token_ids"
]
token_type_ids
=
parsed_content
.
get
(
"token_type_ids"
)
prompt_token_ids
=
self
.
_truncate_inputs
(
parsed_content
[
"prompt_token_ids"
],
tokenization_kwargs
)
inputs
:
Union
[
TokenInputs
,
MultiModalInputs
]
if
multi_modal_data
:
=
parsed_content
.
get
(
"multi_modal_data"
):
...
...
@@ -348,13 +384,10 @@ class InputPreprocessor:
parsed_content
.
get
(
"mm_processor_kwargs"
),
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
else
:
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
,
token_type_ids
=
token_type_ids
,
)
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
)
if
cache_salt
:
=
parsed_content
.
get
(
"cache_salt"
):
inputs
[
"cache_salt"
]
=
cache_salt
...
...
@@ -366,10 +399,12 @@ class InputPreprocessor:
parsed_content
:
TokensPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
prompt_token_ids
=
parsed_content
[
"prompt_token_ids"
]
token_type_ids
=
parsed_content
.
get
(
"token_type_ids"
)
prompt_token_ids
=
self
.
_truncate_inputs
(
parsed_content
[
"prompt_token_ids"
],
tokenization_kwargs
)
inputs
:
Union
[
TokenInputs
,
MultiModalInputs
]
if
multi_modal_data
:
=
parsed_content
.
get
(
"multi_modal_data"
):
...
...
@@ -379,13 +414,10 @@ class InputPreprocessor:
parsed_content
.
get
(
"mm_processor_kwargs"
),
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
else
:
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
,
token_type_ids
=
token_type_ids
,
)
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
,
)
if
cache_salt
:
=
parsed_content
.
get
(
"cache_salt"
):
inputs
[
"cache_salt"
]
=
cache_salt
...
...
@@ -397,7 +429,9 @@ class InputPreprocessor:
parsed_content
:
TextPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
prompt_text
=
parsed_content
[
"prompt"
]
...
...
@@ -409,7 +443,7 @@ class InputPreprocessor:
parsed_content
.
get
(
"mm_processor_kwargs"
),
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
else
:
prompt_token_ids
=
self
.
_tokenize_prompt
(
...
...
@@ -432,7 +466,9 @@ class InputPreprocessor:
parsed_content
:
TextPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
Union
[
TokenInputs
,
MultiModalInputs
]:
prompt_text
=
parsed_content
[
"prompt"
]
...
...
@@ -444,7 +480,7 @@ class InputPreprocessor:
parsed_content
.
get
(
"mm_processor_kwargs"
),
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
else
:
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
...
...
@@ -467,7 +503,9 @@ class InputPreprocessor:
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
SingletonInputs
:
"""
Extract the singleton inputs from a prompt.
...
...
@@ -476,7 +514,6 @@ class InputPreprocessor:
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
* return_mm_hashes: whether to return multimodal hashes
Returns:
...
...
@@ -490,21 +527,21 @@ class InputPreprocessor:
return
self
.
_process_tokens
(
parsed
[
"content"
],
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
if
parsed
[
"type"
]
==
"text"
:
return
self
.
_process_text
(
parsed
[
"content"
],
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
if
parsed
[
"type"
]
==
"str"
:
return
self
.
_process_text
(
TextPrompt
(
prompt
=
parsed
[
"content"
]),
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
assert_never
(
parsed
)
...
...
@@ -514,7 +551,9 @@ class InputPreprocessor:
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
SingletonInputs
:
"""
Async version of
...
...
@@ -528,21 +567,21 @@ class InputPreprocessor:
return
await
self
.
_process_tokens_async
(
parsed
[
"content"
],
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
if
parsed
[
"type"
]
==
"text"
:
return
await
self
.
_process_text_async
(
parsed
[
"content"
],
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
if
parsed
[
"type"
]
==
"str"
:
return
await
self
.
_process_text_async
(
TextPrompt
(
prompt
=
parsed
[
"content"
]),
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
assert_never
(
parsed
)
...
...
@@ -652,6 +691,9 @@ class InputPreprocessor:
self
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
EncoderDecoderInputs
:
"""
For encoder/decoder models only:
...
...
@@ -693,6 +735,7 @@ class InputPreprocessor:
encoder_inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
[
"encoder_prompt"
],
tokenization_kwargs
=
tokenization_kwargs
,
mm_hash_overrides
=
mm_hash_overrides
,
)
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
decoder_inputs
=
None
...
...
@@ -708,6 +751,7 @@ class InputPreprocessor:
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
mm_hash_overrides
=
mm_hash_overrides
,
)
if
self
.
model_config
.
is_multimodal_model
:
# Encoder-Decoder Multimodal model
...
...
@@ -723,6 +767,9 @@ class InputPreprocessor:
self
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
EncoderDecoderInputs
:
"""
Async version of
...
...
@@ -735,6 +782,7 @@ class InputPreprocessor:
encoder_task
=
self
.
_prompt_to_llm_inputs_async
(
prompt
[
"encoder_prompt"
],
tokenization_kwargs
=
tokenization_kwargs
,
mm_hash_overrides
=
mm_hash_overrides
,
)
if
(
decoder_input
:
=
prompt
[
"decoder_prompt"
])
is
None
:
...
...
@@ -744,6 +792,7 @@ class InputPreprocessor:
decoder_task
=
self
.
_prompt_to_llm_inputs_async
(
decoder_input
,
tokenization_kwargs
=
tokenization_kwargs
,
mm_hash_overrides
=
mm_hash_overrides
,
)
encoder_inputs
,
decoder_inputs
=
await
asyncio
.
gather
(
...
...
@@ -759,6 +808,7 @@ class InputPreprocessor:
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
mm_hash_overrides
=
mm_hash_overrides
,
)
if
self
.
model_config
.
is_multimodal_model
:
# Encoder-Decoder Multimodal model
...
...
@@ -785,7 +835,9 @@ class InputPreprocessor:
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
DecoderOnlyInputs
:
"""
For decoder-only models:
...
...
@@ -796,7 +848,6 @@ class InputPreprocessor:
* prompt: input prompt
* lora_request
* return_mm_hashes
Returns:
...
...
@@ -807,7 +858,7 @@ class InputPreprocessor:
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
)
...
...
@@ -817,7 +868,9 @@ class InputPreprocessor:
prompt
:
SingletonPrompt
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
DecoderOnlyInputs
:
"""
Async version of
...
...
@@ -827,7 +880,7 @@ class InputPreprocessor:
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
)
...
...
@@ -837,17 +890,19 @@ class InputPreprocessor:
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
ProcessorInputs
:
"""Preprocess the input prompt."""
if
self
.
model_config
.
is_encoder_decoder
:
assert
not
return_mm_hashes
,
(
"Multimodal hashes for encoder-decoder models should not be "
,
"returned until they are supported on vLLM V1."
)
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
# input prompts to encoder & decoder
.
return
self
.
_process_encoder_decoder_prompt
(
prompt
,
tokenization_kwargs
)
prompt
,
tokenization_kwargs
,
mm_hash_overrides
=
mm_hash_overrides
,
)
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
...
...
@@ -858,7 +913,7 @@ class InputPreprocessor:
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
async
def
preprocess_async
(
...
...
@@ -866,19 +921,22 @@ class InputPreprocessor:
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
*
,
mm_hash_overrides
:
Optional
[
Union
[
dict
[
str
,
list
[
str
]],
MultiModalUUIDDict
]]
=
None
,
)
->
ProcessorInputs
:
"""
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if
self
.
model_config
.
is_encoder_decoder
:
assert
not
return_mm_hashes
,
(
"Multimodal hashes for encoder-decoder models should not be "
,
"returned until they are supported on vLLM V1."
)
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return
await
self
.
_process_encoder_decoder_prompt_async
(
prompt
)
# input prompts to encoder & decoder.
return
await
self
.
_process_encoder_decoder_prompt_async
(
prompt
,
tokenization_kwargs
,
mm_hash_overrides
=
mm_hash_overrides
,
)
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
...
...
@@ -889,5 +947,9 @@ class InputPreprocessor:
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
lora_request
=
lora_request
,
return_mm_hashes
=
return_mm_hash
es
,
mm_hash_overrides
=
mm_hash_overrid
es
,
)
def
clear_cache
(
self
)
->
None
:
if
self
.
mm_processor_cache
is
not
None
:
self
.
mm_processor_cache
.
clear_cache
()
vllm/inputs/registry.py
View file @
d2b52805
...
...
@@ -223,20 +223,26 @@ class InputRegistry:
The model is identified by ``model_config``.
"""
# Avoid circular import
from
vllm.multimodal.cache
import
processor_only_cache_from_config
from
vllm.sequence
import
SequenceData
if
not
model_config
.
is_multimodal_model
:
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
))
return
DummyData
(
seq_data
=
seq_data
)
cache
=
processor_only_cache_from_config
(
model_config
,
mm_registry
)
# Encoder dummy data does not contain multi-modal data
if
is_encoder_data
:
enc_data
=
mm_registry
.
get_encoder_dummy_data
(
model_config
,
seq_len
)
enc_data
=
mm_registry
.
get_encoder_dummy_data
(
model_config
,
seq_len
,
cache
=
cache
)
seq_data
=
SequenceData
.
from_seqs
(
enc_data
.
prompt_token_ids
)
return
DummyData
(
seq_data
=
seq_data
)
dec_data
=
mm_registry
.
get_decoder_dummy_data
(
model_config
,
seq_len
)
dec_data
=
mm_registry
.
get_decoder_dummy_data
(
model_config
,
seq_len
,
cache
=
cache
)
return
DummyData
(
seq_data
=
SequenceData
.
from_seqs
(
dec_data
.
prompt_token_ids
),
...
...
vllm/lora/layers.py
View file @
d2b52805
...
...
@@ -48,9 +48,6 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# GPTQ/AWQ
elif
hasattr
(
base_layer
,
"qweight"
):
return
base_layer
.
qweight
.
device
# marlin
elif
hasattr
(
base_layer
,
"B"
):
return
base_layer
.
B
.
device
# HQQ marlin
elif
hasattr
(
base_layer
,
"W_q"
):
return
base_layer
.
W_q
.
device
...
...
@@ -608,7 +605,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
class
MergedColumnParallelLinearWithLoRA
(
ColumnParallelLinearWithLoRA
):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
packed together (eg. gate_proj + up_proj -> gate_up_proj).
packed together (e
.
g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer.
...
...
vllm/lora/models.py
View file @
d2b52805
...
...
@@ -207,6 +207,7 @@ class LoRAModel(AdapterModel):
"""
lora_tensor_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.safetensors"
)
lora_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.bin"
)
lora_pt_file_path
=
os
.
path
.
join
(
lora_dir
,
"adapter_model.pt"
)
new_embeddings_tensor_path
=
os
.
path
.
join
(
lora_dir
,
"new_embeddings.safetensors"
)
new_embeddings_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
...
...
@@ -255,9 +256,10 @@ class LoRAModel(AdapterModel):
check_unexpected_modules
(
f
)
for
module
in
f
.
keys
():
# noqa
tensors
[
module
]
=
f
.
get_tensor
(
module
)
elif
os
.
path
.
isfile
(
lora_bin_file_path
):
# When a bin file is provided, we rely on config to find unexpected
# modules.
elif
os
.
path
.
isfile
(
lora_bin_file_path
)
or
os
.
path
.
isfile
(
lora_pt_file_path
):
# When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules
=
[]
target_modules
=
peft_helper
.
target_modules
if
not
isinstance
(
target_modules
,
list
):
...
...
@@ -279,7 +281,10 @@ class LoRAModel(AdapterModel):
f
" target modules in
{
expected_lora_modules
}
"
f
" but received
{
unexpected_modules
}
."
f
" Please verify that the loaded LoRA module is correct"
)
tensors
=
torch
.
load
(
lora_bin_file_path
,
lora_file_path
=
(
lora_bin_file_path
if
os
.
path
.
isfile
(
lora_bin_file_path
)
else
lora_pt_file_path
)
tensors
=
torch
.
load
(
lora_file_path
,
map_location
=
device
,
weights_only
=
True
)
else
:
...
...
vllm/model_executor/layers/activation.py
View file @
d2b52805
...
...
@@ -10,11 +10,14 @@ import torch.nn.functional as F
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
LazyDict
logger
=
init_logger
(
__name__
)
@
CustomOp
.
register
(
"fatrelu_and_mul"
)
class
FatreluAndMul
(
CustomOp
):
...
...
@@ -363,6 +366,112 @@ class ReLUSquaredActivation(CustomOp):
return
self
.
forward_native
(
x
)
@
CustomOp
.
register
(
"xielu"
)
class
XIELU
(
CustomOp
):
"""
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
Otherwise, we emit a single warning and use xIELU Python
"""
def
__init__
(
self
,
alpha_p_init
:
float
=
0.8
,
alpha_n_init
:
float
=
0.8
,
beta
:
float
=
0.5
,
eps
:
float
=
-
1e-6
,
dtype
:
torch
.
dtype
=
torch
.
bfloat16
,
with_vector_loads
:
bool
=
False
,
):
super
().
__init__
()
self
.
alpha_p
=
nn
.
Parameter
(
torch
.
log
(
torch
.
exp
(
torch
.
tensor
(
alpha_p_init
,
dtype
=
dtype
))
-
1
).
unsqueeze
(
0
))
self
.
alpha_n
=
nn
.
Parameter
(
torch
.
log
(
torch
.
exp
(
torch
.
tensor
(
alpha_n_init
-
beta
,
dtype
=
dtype
))
-
1
).
unsqueeze
(
0
))
self
.
register_buffer
(
"beta"
,
torch
.
tensor
(
beta
,
dtype
=
dtype
))
self
.
register_buffer
(
"eps"
,
torch
.
tensor
(
eps
,
dtype
=
dtype
))
self
.
with_vector_loads
=
with_vector_loads
# Temporary until xIELU CUDA fully implemented
self
.
_beta_scalar
=
float
(
self
.
beta
.
detach
().
cpu
().
float
().
item
())
self
.
_eps_scalar
=
float
(
self
.
eps
.
detach
().
cpu
().
float
().
item
())
self
.
_xielu_cuda_obj
=
None
try
:
import
xielu.ops
# noqa: F401
self
.
_xielu_cuda_obj
=
torch
.
classes
.
xielu
.
XIELU
()
msg
=
"Using experimental xIELU CUDA."
try
:
from
torch._dynamo
import
allow_in_graph
self
.
_xielu_cuda_fn
=
allow_in_graph
(
self
.
_xielu_cuda
)
msg
+=
" Enabled torch._dynamo for xIELU CUDA."
except
Exception
as
err
:
msg
+=
(
f
" Could not enable torch._dynamo for xIELU (
{
err
}
) - "
"this may result in slower performance."
)
self
.
_xielu_cuda_fn
=
self
.
_xielu_cuda
logger
.
warning_once
(
msg
)
except
Exception
as
err
:
logger
.
warning_once
(
"CUDA-fused xIELU not available (%s) –"
" falling back to a Python version.
\n
"
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`"
,
str
(
err
),
)
def
_xielu_python
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
alpha_p
=
nn
.
functional
.
softplus
(
self
.
alpha_p
)
alpha_n
=
self
.
beta
+
nn
.
functional
.
softplus
(
self
.
alpha_n
)
return
torch
.
where
(
x
>
0
,
alpha_p
*
x
*
x
+
self
.
beta
*
x
,
(
torch
.
expm1
(
torch
.
min
(
x
,
self
.
eps
))
-
x
)
*
alpha_n
+
self
.
beta
*
x
,
)
def
_xielu_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Firewall function to prevent torch.compile from seeing .item()"""
assert
self
.
_xielu_cuda_obj
is
not
None
,
(
"XIELU CUDA object must not be None"
)
original_shape
=
x
.
shape
# CUDA kernel expects 3D tensors, reshape if needed
while
x
.
dim
()
<
3
:
x
=
x
.
unsqueeze
(
0
)
if
x
.
dim
()
>
3
:
x
=
x
.
view
(
-
1
,
1
,
x
.
size
(
-
1
))
if
original_shape
!=
x
.
shape
:
logger
.
warning_once
(
"Warning: xIELU input tensor expects 3 dimensions"
" but got (shape: %s). Reshaping to (shape: %s)."
,
original_shape
,
x
.
shape
,
)
result
=
self
.
_xielu_cuda_obj
.
forward
(
x
,
self
.
alpha_p
,
self
.
alpha_n
,
# Temporary until xIELU CUDA fully implemented ->
# self.{beta,eps}.item()
self
.
_beta_scalar
,
self
.
_eps_scalar
,
self
.
with_vector_loads
,
)
return
result
.
view
(
original_shape
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
torch
.
Tensor
:
if
self
.
_xielu_cuda_obj
is
not
None
and
input
.
is_cuda
:
if
not
torch
.
_dynamo
.
is_compiling
():
return
self
.
_xielu_cuda_fn
(
input
)
else
:
logger
.
warning_once
(
"torch._dynamo is compiling, using Python version of xIELU."
)
return
self
.
_xielu_python
(
input
)
class
ScaledActivation
(
nn
.
Module
):
"""An activation function with post-scale parameters.
...
...
@@ -422,12 +531,25 @@ _ACTIVATION_REGISTRY = LazyDict({
lambda
:
nn
.
SiLU
(),
"quick_gelu"
:
lambda
:
QuickGELU
(),
"tanh"
:
lambda
:
nn
.
Tanh
(),
"sigmoid"
:
lambda
:
nn
.
Sigmoid
(),
"xielu"
:
lambda
:
XIELU
(),
})
def
get_act_fn
(
act_fn_name
:
str
)
->
nn
.
Module
:
"""Get an activation function by name."""
act_fn_name
=
act_fn_name
.
lower
()
if
act_fn_name
.
startswith
(
"torch.nn.modules."
):
activation_name
=
act_fn_name
.
split
(
"."
)[
-
1
]
if
activation_name
==
"identity"
:
return
nn
.
Identity
()
act_fn_name
=
activation_name
if
act_fn_name
not
in
_ACTIVATION_REGISTRY
:
raise
ValueError
(
f
"Activation function
{
act_fn_name
!
r
}
is not supported."
)
...
...
vllm/model_executor/layers/attention_layer_base.py
0 → 100644
View file @
d2b52805
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
class
AttentionLayerBase
(
ABC
):
"""
Base class for attention-like layers (Attention, Mamba, etc.)
that support the v1 engine.
This provides a common interface for getting attention backends
from different layer types.
"""
@
abstractmethod
def
get_attn_backend
(
self
)
->
type
[
"AttentionBackend"
]:
"""Get the attention backend class for this layer."""
pass
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
d2b52805
...
...
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.deep_gemm
import
(
fp8_m_grouped_gemm_nt_masked
,
is_
blackwell_
deep_gemm_e8m0_used
)
is_deep_gemm_e8m0_used
)
logger
=
init_logger
(
__name__
)
...
...
@@ -70,53 +70,51 @@ def _silu_mul_fp8_quant_deep_gemm(
# number of valid tokens for this expert
n_tokens
=
tl
.
load
(
counts_ptr
+
e
*
stride_counts_e
).
to
(
tl
.
int64
)
cols
=
tl
.
arange
(
0
,
BLOCK
)
cols
=
cols
.
to
(
tl
.
int64
)
mask_h
=
cols
<
BLOCK
cols
=
tl
.
arange
(
0
,
BLOCK
).
to
(
tl
.
int64
)
mask
=
cols
<
BLOCK
base_input_offset
=
e
*
stride_i_e
+
g
*
GROUP_SIZE
*
stride_i_h
base_gate_offset
=
base_input_offset
+
cols
*
stride_i_h
base_up_offset
=
base_input_offset
+
H
*
stride_i_h
+
cols
*
stride_i_h
base_yq_offset
=
(
e
*
stride_yq_e
+
g
*
GROUP_SIZE
*
stride_yq_h
+
cols
*
stride_yq_h
)
base_ys_offset
=
e
*
stride_ys_e
+
g
*
stride_ys_g
for
t
in
tl
.
range
(
0
,
n_tokens
,
num_stages
=
NUM_STAGES
):
base_i_offset
=
(
e
*
stride_i_e
+
t
*
stride_i_t
+
g
*
GROUP_SIZE
*
stride_i_h
)
base_yq_offset
=
(
e
*
stride_yq_e
+
t
*
stride_yq_t
+
g
*
GROUP_SIZE
*
stride_yq_h
)
base_ys_offset
=
e
*
stride_ys_e
+
t
*
stride_ys_t
+
g
*
stride_ys_g
mask
=
mask_h
x
=
tl
.
load
(
input_ptr
+
base_i_offset
+
cols
*
stride_i_h
,
gate
=
tl
.
load
(
input_ptr
+
base_gate_offset
+
t
*
stride_i_t
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
y2
=
tl
.
load
(
input_ptr
+
base_i_offset
+
H
*
stride_i_h
+
cols
*
stride_i_h
,
up
=
tl
.
load
(
input_ptr
+
base_up_offset
+
t
*
stride_i_t
,
mask
=
mask
,
other
=
0.0
)
.
to
(
tl
.
float32
)
other
=
0.0
)
x
=
x
*
(
1.0
/
(
1.0
+
tl
.
exp
(
-
x
)))
y
=
x
*
y2
gate
=
gate
*
(
1.0
/
(
1.0
+
tl
.
exp
(
-
gate
)))
y
=
gate
*
up
y_s
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
/
fp8_max
if
use_ue8m0
:
y_s
=
tl
.
exp2
(
tl
.
ceil
(
tl
.
log2
(
y_s
)))
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
scale_raw
=
_absmax
/
fp8_max
y_s
=
tl
.
math
.
exp2
(
tl
.
ceil
(
tl
.
log2
(
scale_raw
)))
if
use_ue8m0
else
scale_raw
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
base_yq_offset
+
cols
*
stride_yq_
h
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
+
base_ys_offset
,
y_s
)
tl
.
store
(
y_q_ptr
+
base_yq_offset
+
t
*
stride_yq_
t
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
+
base_ys_offset
+
t
*
stride_ys_t
,
y_s
)
def
silu_mul_fp8_quant_deep_gemm
(
y
:
torch
.
Tensor
,
# (E, T, 2*H)
float32
y
:
torch
.
Tensor
,
# (E, T, 2*H)
tokens_per_expert
:
torch
.
Tensor
,
# (E,) number of valid tokens per expert
group_size
:
int
=
128
,
eps
:
float
=
1e-10
,
):
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
Returns `(y_q, y_s)` where
* `y_q`
is the
FP8 tensor
of
shape
`
(E, T, H)
`
, same layout as
`
y[..., :H]
`.
* `y_s`
has
shape
`
(E, T, H // group_size)
` and
strides
`
(T*G, 1, T)
`
* `y_q`
:
FP8 tensor
,
shape (E, T, H), same layout as y[..., :H]
* `y_s`
: FP32 tensor,
shape (E, T, H // group_size)
,
strides (T*G, 1, T)
"""
assert
y
.
ndim
==
3
,
"y must be (E, T, 2*H)"
E
,
T
,
H2
=
y
.
shape
...
...
@@ -148,7 +146,7 @@ def silu_mul_fp8_quant_deep_gemm(
stride_cnt_e
=
tokens_per_expert
.
stride
()[
0
]
#
s
tatic grid over experts and H-groups.
#
S
tatic grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid
=
(
E
*
G
,
)
...
...
@@ -176,9 +174,9 @@ def silu_mul_fp8_quant_deep_gemm(
eps
,
fp8_min
,
fp8_max
,
is_
blackwell_
deep_gemm_e8m0_used
(),
is_deep_gemm_e8m0_used
(),
BLOCK
=
group_size
,
NUM_STAGES
=
8
,
NUM_STAGES
=
4
,
num_warps
=
1
,
)
...
...
vllm/model_executor/layers/fused_moe/config.py
View file @
d2b52805
...
...
@@ -190,12 +190,6 @@ class FusedMoEParallelConfig:
return
(
self
.
use_all2all_kernels
and
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
)
@
property
def
use_flashinfer_cutlass_kernels
(
self
):
return
(
envs
.
VLLM_USE_FLASHINFER_MOE_FP4
and
has_flashinfer_cutlass_fused_moe
()
and
envs
.
VLLM_FLASHINFER_MOE_BACKEND
==
"throughput"
)
@
staticmethod
def
make
(
tp_size_
:
int
,
dp_size_
:
int
,
vllm_parallel_config
:
ParallelConfig
)
->
"FusedMoEParallelConfig"
:
...
...
@@ -404,7 +398,14 @@ class FusedMoEConfig:
@
property
def
use_flashinfer_cutlass_kernels
(
self
):
return
self
.
moe_parallel_config
.
use_flashinfer_cutlass_kernels
"""
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
"""
return
(
self
.
quant_config
is
not
None
and
self
.
quant_config
.
quant_dtype
==
"nvfp4"
and
envs
.
VLLM_USE_FLASHINFER_MOE_FP4
and
has_flashinfer_cutlass_fused_moe
()
and
envs
.
VLLM_FLASHINFER_MOE_BACKEND
==
"throughput"
)
@
staticmethod
def
make
(
...
...
@@ -450,6 +451,12 @@ class FusedMoEConfig:
if
quant_dtype
is
None
and
isinstance
(
quant_config
,
Fp8Config
):
quant_dtype
=
torch
.
float8_e4m3fn
from
vllm.model_executor.layers.quantization.mxfp4
import
(
Mxfp4Config
)
if
(
quant_dtype
is
None
and
isinstance
(
quant_config
,
Mxfp4Config
)
and
envs
.
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
):
quant_dtype
=
"mxfp8"
from
vllm.model_executor.layers.quantization.modelopt
import
(
ModelOptNvFp4Config
)
if
quant_dtype
is
None
and
isinstance
(
quant_config
,
...
...
vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json
0 → 100644
View file @
d2b52805
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"16384"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
2
}
}
Prev
1
…
20
21
22
23
24
25
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment