Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4eabe123
Commit
4eabe123
authored
May 28, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori
parents
45840cd2
58738772
Changes
670
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
559 additions
and
126 deletions
+559
-126
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
+11
-7
vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py
...points/openai/tool_parsers/llama4_pythonic_tool_parser.py
+302
-0
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+9
-5
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+1
-1
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
+6
-5
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
+7
-4
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+59
-0
vllm/envs.py
vllm/envs.py
+14
-3
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+1
-1
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_distributed_executor.py
+3
-3
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+2
-4
vllm/forward_context.py
vllm/forward_context.py
+4
-1
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+3
-2
vllm/inputs/data.py
vllm/inputs/data.py
+39
-29
vllm/inputs/parse.py
vllm/inputs/parse.py
+4
-4
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+39
-17
vllm/inputs/registry.py
vllm/inputs/registry.py
+2
-2
vllm/logger.py
vllm/logger.py
+5
-5
vllm/logging_utils/dump_input.py
vllm/logging_utils/dump_input.py
+3
-3
vllm/lora/models.py
vllm/lora/models.py
+45
-30
No files found.
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
json
import
re
from
collections.abc
import
Sequence
from
typing
import
Union
import
partial_json_parser
import
regex
as
re
from
partial_json_parser.core.options
import
Allow
from
vllm.entrypoints.chat_utils
import
random_tool_call_id
...
...
@@ -96,8 +96,9 @@ class JambaToolParser(ToolParser):
function
=
FunctionCall
(
name
=
function_call
[
"name"
],
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
function_call
[
"arguments"
])))
for
function_call
in
raw_function_calls
arguments
=
json
.
dumps
(
function_call
[
"arguments"
],
ensure_ascii
=
False
),
))
for
function_call
in
raw_function_calls
]
content
=
model_output
[:
model_output
.
...
...
@@ -187,7 +188,7 @@ class JambaToolParser(ToolParser):
diff
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"arguments"
)
if
diff
:
diff
=
json
.
dumps
(
diff
).
replace
(
diff
=
json
.
dumps
(
diff
,
ensure_ascii
=
False
).
replace
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
],
""
)
delta
=
DeltaMessage
(
tool_calls
=
[
...
...
@@ -248,7 +249,8 @@ class JambaToolParser(ToolParser):
"mid-arguments"
)
delta
=
None
elif
cur_arguments
and
not
prev_arguments
:
cur_arguments_json
=
json
.
dumps
(
cur_arguments
)
cur_arguments_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
logger
.
debug
(
"finding %s in %s"
,
new_text
,
cur_arguments_json
)
...
...
@@ -267,8 +269,10 @@ class JambaToolParser(ToolParser):
self
.
current_tool_id
]
+=
arguments_delta
elif
cur_arguments
and
prev_arguments
:
cur_args_json
=
json
.
dumps
(
cur_arguments
)
prev_args_json
=
json
.
dumps
(
prev_arguments
)
cur_args_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
prev_args_json
=
json
.
dumps
(
prev_arguments
,
ensure_ascii
=
False
)
logger
.
debug
(
"Searching for diff between
\n
%s
\n
%s"
,
cur_args_json
,
prev_args_json
)
...
...
vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py
0 → 100644
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
ast
import
json
from
collections.abc
import
Sequence
from
typing
import
Any
,
Union
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
ToolParserManager
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
class
_UnexpectedAstError
(
Exception
):
pass
@
ToolParserManager
.
register_module
(
"llama4_pythonic"
)
class
Llama4PythonicToolParser
(
ToolParser
):
"""
Toolcall parser for Llama4 that produce tool calls in a pythonic style
Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
"""
# TODO(mdepinet): Possible future improvements:
# 1. Support text + tools separated by either <|python_tag|> or \n\n
# 2. Support tools outside of a list (or separated by a semicolon).
# This depends on item 1 for consistent streaming.
# Neither of these are necessary for e.g. ToolACE, but both would help make
# Llama3.2 models more reliable.
TOOL_CALL_REGEX
=
re
.
compile
(
r
"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]"
,
re
.
DOTALL
)
def
__init__
(
self
,
tokenizer
:
PreTrainedTokenizerBase
):
super
().
__init__
(
tokenizer
)
# Rename for readability. This is NOT a tool id.
@
property
def
current_tool_index
(
self
)
->
int
:
return
self
.
current_tool_id
@
current_tool_index
.
setter
def
current_tool_index
(
self
,
value
:
int
)
->
None
:
self
.
current_tool_id
=
value
def
extract_tool_calls
(
self
,
model_output
:
str
,
request
:
ChatCompletionRequest
)
->
ExtractedToolCallInformation
:
"""
Extract the tool calls from a complete model response.
"""
# remove <|python_start|> and <|python_end|>
# as Llama 4 model sometime will output those tokens
if
model_output
.
startswith
(
"<|python_start|>"
):
model_output
=
model_output
[
len
(
"<|python_start|>"
):]
model_output
=
model_output
.
replace
(
"<|python_end|>"
,
""
)
if
not
(
self
.
TOOL_CALL_REGEX
.
match
(
model_output
)):
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
try
:
module
=
ast
.
parse
(
model_output
)
parsed
=
getattr
(
module
.
body
[
0
],
"value"
,
None
)
if
isinstance
(
parsed
,
ast
.
List
)
and
all
(
isinstance
(
e
,
ast
.
Call
)
for
e
in
parsed
.
elts
):
return
ExtractedToolCallInformation
(
tools_called
=
True
,
tool_calls
=
[
_handle_single_tool
(
e
)
# type: ignore
for
e
in
parsed
.
elts
],
content
=
None
)
else
:
raise
_UnexpectedAstError
(
"Tool output must be a list of function calls"
)
except
Exception
:
logger
.
exception
(
"Error in extracting tool call from response."
)
# Treat as regular text
return
ExtractedToolCallInformation
(
tools_called
=
False
,
tool_calls
=
[],
content
=
model_output
)
def
extract_tool_calls_streaming
(
self
,
previous_text
:
str
,
current_text
:
str
,
delta_text
:
str
,
previous_token_ids
:
Sequence
[
int
],
current_token_ids
:
Sequence
[
int
],
delta_token_ids
:
Sequence
[
int
],
request
:
ChatCompletionRequest
,
)
->
Union
[
DeltaMessage
,
None
]:
if
not
current_text
.
startswith
(
"["
)
and
not
current_text
.
startswith
(
"<|python_start|>"
):
return
DeltaMessage
(
content
=
delta_text
)
try
:
# remove <|python_start|> and <|python_end|>
if
current_text
.
startswith
(
"<|python_start|>"
):
current_text
=
current_text
[
len
(
"<|python_start|>"
):]
if
current_text
.
endswith
(
"<|python_end|>"
):
current_text
=
current_text
[:
current_text
.
rfind
(
"<|python_end|>"
)]
valid_and_added_text
=
_make_valid_python
(
current_text
)
if
valid_and_added_text
is
None
:
return
None
valid_text
,
added_text
=
valid_and_added_text
module
=
ast
.
parse
(
valid_text
)
parsed
=
getattr
(
module
.
body
[
0
],
"value"
,
None
)
if
not
isinstance
(
parsed
,
ast
.
List
)
or
not
all
(
isinstance
(
e
,
ast
.
Call
)
for
e
in
parsed
.
elts
):
raise
_UnexpectedAstError
(
"Tool output must be a list of function calls"
)
tool_calls
=
[
_handle_single_tool
(
e
)
# type: ignore
for
e
in
parsed
.
elts
]
tool_deltas
=
[]
for
index
,
new_call
in
enumerate
(
tool_calls
):
if
index
<
self
.
current_tool_index
:
continue
self
.
current_tool_index
=
index
if
len
(
self
.
streamed_args_for_tool
)
==
index
:
self
.
streamed_args_for_tool
.
append
(
""
)
new_call_complete
=
index
<
len
(
tool_calls
)
-
1
or
")]"
not
in
added_text
if
new_call_complete
:
self
.
current_tool_index
+=
1
withheld_suffix
=
(
added_text
[:
-
2
]
if
not
new_call_complete
else
""
)
if
not
new_call_complete
and
added_text
[
-
2
]
==
")"
:
# Function call is incomplete. Withhold the closing bracket.
withheld_suffix
=
withheld_suffix
+
"}"
# Strings get single quotes in the model-produced string.
# JSON requires double quotes.
withheld_suffix
=
withheld_suffix
.
replace
(
"'"
,
'"'
)
delta
=
_compute_tool_delta
(
self
.
streamed_args_for_tool
[
index
],
new_call
,
index
,
withheld_suffix
)
if
delta
is
not
None
:
tool_deltas
.
append
(
delta
)
if
(
delta
.
function
is
not
None
and
delta
.
function
.
arguments
is
not
None
):
self
.
streamed_args_for_tool
[
index
]
+=
delta
.
function
.
arguments
# HACK: serving_chat.py inspects the internal state of tool parsers
# when determining it's final streaming delta, automatically
# adding autocompleted JSON.
# These two lines avoid that nonsense while ensuring finish_reason
# is set to tool_calls when at least one tool is called.
if
tool_deltas
and
not
self
.
prev_tool_call_arr
:
self
.
prev_tool_call_arr
=
[{
"arguments"
:
{}}]
if
tool_deltas
:
return
DeltaMessage
(
tool_calls
=
tool_deltas
)
elif
not
added_text
and
self
.
current_tool_id
>
0
:
# Return an empty DeltaMessage once the tool calls are all done
# so that finish_reason gets set.
return
DeltaMessage
(
content
=
''
)
else
:
return
None
except
Exception
:
logger
.
exception
(
"Error trying to handle streaming tool call."
)
logger
.
debug
(
"Skipping chunk as a result of tool streaming extraction "
"error"
)
return
None
def
_get_parameter_value
(
val
:
ast
.
expr
)
->
Any
:
if
isinstance
(
val
,
ast
.
Constant
):
return
val
.
value
elif
isinstance
(
val
,
ast
.
Dict
):
if
not
all
(
isinstance
(
k
,
ast
.
Constant
)
for
k
in
val
.
keys
):
raise
_UnexpectedAstError
(
"Dict tool call arguments must have literal keys"
)
return
{
k
.
value
:
_get_parameter_value
(
v
)
# type: ignore
for
k
,
v
in
zip
(
val
.
keys
,
val
.
values
)
}
elif
isinstance
(
val
,
ast
.
List
):
return
[
_get_parameter_value
(
v
)
for
v
in
val
.
elts
]
else
:
raise
_UnexpectedAstError
(
"Tool call arguments must be literals"
)
def
_handle_single_tool
(
call
:
ast
.
Call
)
->
ToolCall
:
if
not
isinstance
(
call
.
func
,
ast
.
Name
):
raise
_UnexpectedAstError
(
"Invalid tool call name"
)
function_name
=
call
.
func
.
id
arguments
=
{}
for
keyword
in
call
.
keywords
:
arguments
[
keyword
.
arg
]
=
_get_parameter_value
(
keyword
.
value
)
return
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
function_name
,
arguments
=
json
.
dumps
(
arguments
)))
def
_make_valid_python
(
text
:
str
)
->
Union
[
tuple
[
str
,
str
],
None
]:
bracket_stack
=
[]
for
index
,
char
in
enumerate
(
text
):
if
char
in
{
"["
,
"("
,
"{"
}:
bracket_stack
.
append
(
char
)
elif
char
==
"]"
:
if
not
bracket_stack
or
bracket_stack
.
pop
()
!=
"["
:
raise
_UnexpectedAstError
(
"Mismatched square brackets"
)
elif
char
==
")"
:
if
not
bracket_stack
or
bracket_stack
.
pop
()
!=
"("
:
raise
_UnexpectedAstError
(
"Mismatched parentheses"
)
elif
char
==
"}"
:
if
not
bracket_stack
or
bracket_stack
.
pop
()
!=
"{"
:
raise
_UnexpectedAstError
(
"Mismatched curly braces"
)
elif
char
in
{
"'"
,
'"'
}:
if
bracket_stack
and
bracket_stack
[
-
1
]
==
char
:
if
index
>
0
and
text
[
index
-
1
]
==
"
\\
"
:
# Treat an escaped quote as a regular character
pass
else
:
bracket_stack
.
pop
()
elif
bracket_stack
and
bracket_stack
[
-
1
]
in
{
"'"
,
'"'
}:
# Double quote within a single quote string or vice versa.
pass
else
:
bracket_stack
.
append
(
char
)
text
=
text
.
rstrip
()
if
text
.
endswith
(
"="
)
or
text
.
endswith
(
":"
):
# Since we have no type information for this property/parameter value,
# we can't fill in a valid value.
return
None
if
bracket_stack
and
bracket_stack
[
-
1
]
==
"{"
:
trailing_dict_text
=
text
[:
text
.
rfind
(
"{"
)]
num_keys
=
trailing_dict_text
.
count
(
":"
)
num_values
=
trailing_dict_text
.
count
(
","
)
if
num_keys
<=
num_values
:
return
None
# Incomplete property name within parameter value
if
bracket_stack
and
bracket_stack
[
-
1
]
==
"("
:
trailing_params_text
=
text
[:
text
.
rfind
(
"("
)]
num_full_param_names
=
trailing_params_text
.
count
(
"="
)
num_full_param_values
=
trailing_params_text
.
count
(
","
)
if
num_full_param_names
<=
num_full_param_values
:
return
None
# Incomplete parameter name
if
text
.
endswith
(
","
):
text
=
text
[:
-
1
]
if
bracket_stack
and
bracket_stack
[
-
1
]
==
"["
and
not
text
.
endswith
(
"["
)
and
not
text
.
endswith
(
")"
):
return
None
# Incomplete function name
added_text
=
""
for
char
in
reversed
(
bracket_stack
):
if
char
==
"["
:
added_text
+=
"]"
elif
char
==
"("
:
added_text
+=
")"
elif
char
==
"{"
:
added_text
+=
"}"
elif
char
==
"'"
:
added_text
+=
"'"
elif
char
==
'"'
:
added_text
+=
'"'
return
text
+
added_text
,
added_text
def
_compute_tool_delta
(
previously_sent_args
:
str
,
new_call
:
ToolCall
,
index
:
int
,
withheld_suffix
:
str
)
->
Union
[
DeltaToolCall
,
None
]:
new_call_args
=
new_call
.
function
.
arguments
if
withheld_suffix
:
assert
new_call_args
.
endswith
(
withheld_suffix
)
new_call_args
=
new_call_args
[:
-
len
(
withheld_suffix
)]
if
not
previously_sent_args
:
return
DeltaToolCall
(
id
=
new_call
.
id
,
type
=
"function"
,
index
=
index
,
function
=
DeltaFunctionCall
(
name
=
new_call
.
function
.
name
,
arguments
=
new_call_args
,
))
arg_diff
=
new_call_args
[
len
(
previously_sent_args
):]
return
DeltaToolCall
(
id
=
None
,
index
=
index
,
function
=
DeltaFunctionCall
(
arguments
=
arg_diff
))
if
arg_diff
else
None
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
json
import
re
from
collections.abc
import
Sequence
from
json
import
JSONDecoder
from
typing
import
Union
import
partial_json_parser
import
regex
as
re
from
partial_json_parser.core.options
import
Allow
from
transformers
import
PreTrainedTokenizerBase
...
...
@@ -88,7 +88,8 @@ class Llama3JsonToolParser(ToolParser):
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
raw_function_call
[
"arguments"
]
\
if
"arguments"
in
raw_function_call
\
else
raw_function_call
[
"parameters"
])))
else
raw_function_call
[
"parameters"
],
ensure_ascii
=
False
)))
for
raw_function_call
in
function_call_arr
]
...
...
@@ -174,7 +175,8 @@ class Llama3JsonToolParser(ToolParser):
if
self
.
current_tool_id
>=
0
:
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
if
cur_arguments
:
cur_args_json
=
json
.
dumps
(
cur_arguments
)
cur_args_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
sent
=
len
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
])
argument_diff
=
cur_args_json
[
sent
:]
...
...
@@ -226,7 +228,8 @@ class Llama3JsonToolParser(ToolParser):
if
cur_arguments
:
sent
=
len
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
])
cur_args_json
=
json
.
dumps
(
cur_arguments
)
cur_args_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
...
...
@@ -234,7 +237,8 @@ class Llama3JsonToolParser(ToolParser):
if
is_complete
[
self
.
current_tool_id
]:
argument_diff
=
cur_args_json
[
sent
:]
elif
prev_arguments
:
prev_args_json
=
json
.
dumps
(
prev_arguments
)
prev_args_json
=
json
.
dumps
(
prev_arguments
,
ensure_ascii
=
False
)
if
cur_args_json
!=
prev_args_json
:
prefix
=
find_common_prefix
(
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
json
import
re
from
collections.abc
import
Sequence
from
random
import
choices
from
string
import
ascii_letters
,
digits
from
typing
import
Union
import
partial_json_parser
import
regex
as
re
from
partial_json_parser.core.options
import
Allow
from
pydantic
import
Field
...
...
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
json
import
re
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.chat_utils
import
random_tool_call_id
...
...
@@ -79,10 +79,11 @@ class Phi4MiniJsonToolParser(ToolParser):
name
=
raw_function_call
[
"name"
],
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
raw_function_call
[
"arguments"
]
if
"arguments"
in
raw_function_call
else
raw_function_call
[
"parameters"
])))
for
raw_function_call
in
function_call_arr
raw_function_call
[
"arguments"
]
if
"arguments"
in
raw_function_call
else
raw_function_call
[
"parameters"
],
ensure_ascii
=
False
),
))
for
raw_function_call
in
function_call_arr
]
# get any content before the tool call
...
...
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
View file @
4eabe123
...
...
@@ -2,10 +2,10 @@
import
ast
import
json
import
re
from
collections.abc
import
Sequence
from
typing
import
Any
,
Union
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
...
...
@@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments
=
{}
for
keyword
in
call
.
keywords
:
arguments
[
keyword
.
arg
]
=
_get_parameter_value
(
keyword
.
value
)
return
ToolCall
(
type
=
"function"
,
return
ToolCall
(
type
=
"function"
,
function
=
FunctionCall
(
name
=
function_name
,
arguments
=
json
.
dumps
(
arguments
)))
arguments
=
json
.
dumps
(
arguments
,
ensure_ascii
=
False
)),
)
def
_make_valid_python
(
text
:
str
)
->
Union
[
tuple
[
str
,
str
],
None
]:
...
...
vllm/entrypoints/utils.py
View file @
4eabe123
...
...
@@ -13,6 +13,13 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
VLLM_SERVE_PARSER_EPILOG
=
(
"Tip: Use `vllm serve --help=<keyword>` to explore arguments from help.
\n
"
" - To view a argument group: --help=ModelConfig
\n
"
" - To view a single argument: --help=max-num-seqs
\n
"
" - To search by keyword: --help=max
\n
"
" - To list all groups: --help=listgroup"
)
async
def
listen_for_disconnect
(
request
:
Request
)
->
None
:
"""Returns if a disconnect message is received"""
...
...
@@ -158,3 +165,55 @@ def _validate_truncation_size(
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
return
truncate_prompt_tokens
def
show_filtered_argument_or_group_from_help
(
parser
):
import
sys
for
arg
in
sys
.
argv
:
if
arg
.
startswith
(
'--help='
):
search_keyword
=
arg
.
split
(
'='
,
1
)[
1
]
# List available groups
if
search_keyword
==
'listgroup'
:
print
(
"
\n
Available argument groups:"
)
for
group
in
parser
.
_action_groups
:
if
group
.
title
and
not
group
.
title
.
startswith
(
"positional arguments"
):
print
(
f
" -
{
group
.
title
}
"
)
if
group
.
description
:
print
(
" "
+
group
.
description
.
strip
())
print
()
sys
.
exit
(
0
)
# For group search
formatter
=
parser
.
_get_formatter
()
for
group
in
parser
.
_action_groups
:
if
group
.
title
and
group
.
title
.
lower
()
==
search_keyword
.
lower
(
):
formatter
.
start_section
(
group
.
title
)
formatter
.
add_text
(
group
.
description
)
formatter
.
add_arguments
(
group
.
_group_actions
)
formatter
.
end_section
()
print
(
formatter
.
format_help
())
sys
.
exit
(
0
)
# For single arg
matched_actions
=
[]
for
group
in
parser
.
_action_groups
:
for
action
in
group
.
_group_actions
:
# search option name
if
any
(
search_keyword
.
lower
()
in
opt
.
lower
()
for
opt
in
action
.
option_strings
):
matched_actions
.
append
(
action
)
if
matched_actions
:
print
(
f
"
\n
Parameters matching '
{
search_keyword
}
':
\n
"
)
formatter
=
parser
.
_get_formatter
()
formatter
.
add_arguments
(
matched_actions
)
print
(
formatter
.
format_help
())
sys
.
exit
(
0
)
print
(
f
"
\n
No group or parameter matching '
{
search_keyword
}
'"
)
print
(
"Tip: use `--help=listgroup` to view all groups."
)
sys
.
exit
(
1
)
vllm/envs.py
View file @
4eabe123
...
...
@@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST
:
str
=
"localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT
:
int
=
5557
VLLM_ALL2ALL_BACKEND
:
str
=
"naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
:
int
=
163840
def
get_default_cache_root
():
...
...
@@ -163,7 +164,7 @@ def get_vllm_port() -> Optional[int]:
raise
ValueError
(
f
"VLLM_PORT '
{
port
}
' appears to be a URI. "
"This may be caused by a Kubernetes service discovery issue"
"check the warning in: https://docs.vllm.ai/en/stable/
serving
/env_vars.html"
"check the warning in: https://docs.vllm.ai/en/stable/
usage
/env_vars.html"
)
except
Exception
:
pass
...
...
@@ -175,7 +176,7 @@ def get_vllm_port() -> Optional[int]:
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
#
begin-
env-vars-definition
#
--8<-- [start:
env-vars-definition
]
environment_variables
:
dict
[
str
,
Callable
[[],
Any
]]
=
{
...
...
@@ -809,11 +810,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
int
(
os
.
getenv
(
"VLLM_NIXL_SIDE_CHANNEL_PORT"
,
"5557"
)),
# all2all backend for vllm's expert parallel communication
# Available options:
# - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels
"VLLM_ALL2ALL_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_ALL2ALL_BACKEND"
,
"naive"
),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
# This is used to prevent the kernel from running out of memory.
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"
,
"163840"
)),
}
# end
-
env-vars-definition
#
--8<-- [
end
:
env-vars-definition
]
def
__getattr__
(
name
:
str
):
...
...
vllm/executor/executor_base.py
View file @
4eabe123
...
...
@@ -74,7 +74,7 @@ class ExecutorBase(ABC):
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
{exc}
`TimeoutError` on timeout. `None` means wait indefinitely.
[
`TimeoutError`
][]
on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
...
...
vllm/executor/ray_distributed_executor.py
View file @
4eabe123
...
...
@@ -528,12 +528,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
ray
.
get
(
parallel_worker_tasks
)
def
_check_ray_cgraph_installation
(
self
):
import
pkg_resources
import
importlib.metadata
from
packaging
import
version
required_version
=
version
.
parse
(
"2.43.0"
)
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
current_version
=
version
.
parse
(
importlib
.
metadata
.
version
(
"ray"
))
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
is "
f
"required, but found
{
current_version
}
"
)
...
...
vllm/executor/ray_utils.py
View file @
4eabe123
...
...
@@ -87,9 +87,8 @@ try:
# TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's
# current device.
import
torch
if
not
self
.
compiled_dag_cuda_device_set
:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
current_platform
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
output
=
self
.
worker
.
_execute_model_spmd
(
execute_model_req
,
...
...
@@ -113,8 +112,7 @@ try:
# Not needed
pass
else
:
import
torch
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
current_platform
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
...
...
vllm/forward_context.py
View file @
4eabe123
...
...
@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
torch
.
cuda
.
synchronize
()
from
vllm.platforms
import
current_platform
synchronize
=
current_platform
.
synchronize
if
synchronize
is
not
None
:
synchronize
()
now
=
time
.
perf_counter
()
# time measurement is in milliseconds
batchsize_forward_time
[
batchsize
].
append
(
...
...
vllm/inputs/__init__.py
View file @
4eabe123
...
...
@@ -10,8 +10,9 @@ from .registry import (DummyData, InputContext, InputProcessingContext,
INPUT_REGISTRY
=
InputRegistry
()
"""
The global {class}`~InputRegistry` which is used by {class}`~vllm.LLMEngine`
to dispatch data processing according to the target model.
The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
target model.
"""
__all__
=
[
...
...
vllm/inputs/data.py
View file @
4eabe123
...
...
@@ -80,22 +80,24 @@ SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
- A text prompt (
{class}
`str` or
{class}`
TextPrompt
`
)
- A tokenized prompt (
{class}
`TokensPrompt`)
- An embeddings prompt (
{class}
`EmbedsPrompt`)
- A text prompt (
[
`str`
][]
or
[`TextPrompt`][vllm.inputs.data.
TextPrompt
]
)
- A tokenized prompt (
[
`TokensPrompt`
][vllm.inputs.data.TokensPrompt]
)
- An embeddings prompt (
[
`EmbedsPrompt`
][vllm.inputs.data.EmbedsPrompt]
)
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. {class}`ExplicitEncoderDecoderPrompt`
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type
{class}`
SingletonPrompt
`
may be
employed
as (1) input to a decoder-only model, (2) input to
A prompt of type
[`SingletonPrompt`][vllm.inputs.data.
SingletonPrompt
]
may be
employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt`
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""
...
...
@@ -126,18 +128,20 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, may be formatted
according to any of the {class}`SingletonPrompt` schemas,
according to any of the
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
Note that an {class}`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
Note that an
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
may not be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
{class}`
SingletonPrompt
`
instances.
[`SingletonPrompt`][vllm.inputs.data.
SingletonPrompt
]
instances.
"""
encoder_prompt
:
_T1_co
...
...
@@ -152,11 +156,11 @@ PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
- A text prompt (
{class}
`str` or
{class}`
TextPrompt
`
)
- A tokenized prompt (
{class}
`TokensPrompt`)
- An embeddings prompt (
{class}
`EmbedsPrompt`)
- A text prompt (
[
`str`
][]
or
[`TextPrompt`][vllm.inputs.data.
TextPrompt
]
)
- A tokenized prompt (
[
`TokensPrompt`
][vllm.inputs.data.TokensPrompt]
)
- An embeddings prompt (
[
`EmbedsPrompt`
][vllm.inputs.data.EmbedsPrompt]
)
- A single data structure containing both an encoder and a decoder prompt
(
{class}`
ExplicitEncoderDecoderPrompt
`
)
(
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.
ExplicitEncoderDecoderPrompt
]
)
"""
...
...
@@ -189,7 +193,8 @@ def token_inputs(
prompt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
TokenInputs
:
"""Construct {class}`TokenInputs` from optional values."""
"""Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values."""
inputs
=
TokenInputs
(
type
=
"token"
,
prompt_token_ids
=
prompt_token_ids
)
if
prompt
is
not
None
:
...
...
@@ -221,7 +226,8 @@ def embeds_inputs(
prompt_embeds
:
torch
.
Tensor
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
EmbedsInputs
:
"""Construct :class:`EmbedsInputs` from optional values."""
"""Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
values."""
inputs
=
EmbedsInputs
(
type
=
"embeds"
,
prompt_embeds
=
prompt_embeds
)
if
cache_salt
is
not
None
:
...
...
@@ -232,7 +238,7 @@ def embeds_inputs(
DecoderOnlyInputs
=
Union
[
TokenInputs
,
EmbedsInputs
,
"MultiModalInputs"
]
"""
The inputs in
{class}`~vllm
.LLMEngine
`
before they are
The inputs in
[`LLMEngine`][vllm.engine.llm_engine
.LLMEngine
]
before they are
passed to the model executor.
This specifies the data required for decoder-only models.
"""
...
...
@@ -240,11 +246,12 @@ This specifies the data required for decoder-only models.
class
EncoderDecoderInputs
(
TypedDict
):
"""
The inputs in
{class}`~vllm
.LLMEngine
`
before they
are
passed to the model executor.
The inputs in
[`LLMEngine`][vllm.engine.llm_engine
.LLMEngine
]
before they
are
passed to the model executor.
This specifies the required data for encoder-decoder models.
"""
encoder
:
Union
[
TokenInputs
,
"MultiModalInputs"
]
"""The inputs for the encoder portion."""
...
...
@@ -254,13 +261,13 @@ class EncoderDecoderInputs(TypedDict):
SingletonInputs
=
Union
[
TokenInputs
,
EmbedsInputs
,
"MultiModalInputs"
]
"""
A processed
{class}
`SingletonPrompt` which can be
passed to
{class}
`vllm.sequence.Sequence`.
A processed
[
`SingletonPrompt`
][vllm.inputs.data.SingletonPrompt]
which can be
passed to [
`vllm.sequence.Sequence`
][]
.
"""
ProcessorInputs
=
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]
"""
The
in
puts
to {data}
`vllm.inputs.InputProcessor`.
The
out
puts
from [
`vllm.inputs.
preprocess.
InputPr
epr
ocessor`
][]
.
"""
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
...
...
@@ -277,7 +284,8 @@ def build_explicit_enc_dec_prompt(
return
ExplicitEncoderDecoderPrompt
(
encoder_prompt
=
encoder_prompt
,
decoder_prompt
=
decoder_prompt
,
mm_processor_kwargs
=
mm_processor_kwargs
)
mm_processor_kwargs
=
mm_processor_kwargs
,
)
def
zip_enc_dec_prompts
(
...
...
@@ -288,7 +296,8 @@ def zip_enc_dec_prompts(
)
->
list
[
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]]:
"""
Zip encoder and decoder prompts together into a list of
{class}`ExplicitEncoderDecoderPrompt` instances.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
instances.
``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
dictionary will be used for every encoder/decoder prompt. If an iterable is
...
...
@@ -299,9 +308,10 @@ def zip_enc_dec_prompts(
if
isinstance
(
mm_processor_kwargs
,
dict
):
return
[
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
,
cast
(
dict
[
str
,
Any
],
mm_processor_kwargs
))
for
(
encoder_prompt
,
encoder_prompt
,
decoder_prompt
,
cast
(
dict
[
str
,
Any
],
mm_processor_kwargs
),
)
for
(
encoder_prompt
,
decoder_prompt
)
in
zip
(
enc_prompts
,
dec_prompts
)
]
return
[
...
...
vllm/inputs/parse.py
View file @
4eabe123
...
...
@@ -23,13 +23,13 @@ class ParsedTokens(TypedDict):
@
overload
def
parse_and_batch_prompt
(
prompt
:
Union
[
str
,
list
[
str
]])
->
Sequence
[
ParsedText
]:
prompt
:
Union
[
str
,
list
[
str
]]
,
)
->
Sequence
[
ParsedText
]:
...
@
overload
def
parse_and_batch_prompt
(
prompt
:
Union
[
list
[
int
],
list
[
list
[
int
]]])
->
Sequence
[
ParsedTokens
]:
prompt
:
Union
[
list
[
int
],
list
[
list
[
int
]]]
,
)
->
Sequence
[
ParsedTokens
]:
...
...
...
@@ -86,7 +86,7 @@ class ParsedTokensPrompt(TypedDict):
class
ParsedEmbedsPrompt
(
TypedDict
):
type
:
Literal
[
'
embeds
'
]
type
:
Literal
[
"
embeds
"
]
content
:
EmbedsPrompt
...
...
@@ -133,7 +133,7 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
def
is_explicit_encoder_decoder_prompt
(
prompt
:
PromptType
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
prompt
:
PromptType
,
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
...
...
vllm/inputs/preprocess.py
View file @
4eabe123
...
...
@@ -67,11 +67,11 @@ class InputPreprocessor:
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
def
get_decoder_start_token_id
(
self
)
->
Optional
[
int
]:
'''
"""
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model config is unavailable.
'''
"""
if
not
self
.
model_config
.
is_encoder_decoder
:
logger
.
warning_once
(
...
...
@@ -79,14 +79,14 @@ class InputPreprocessor:
"this is not an encoder/decoder model."
)
return
None
if
(
self
.
model_config
is
None
or
self
.
model_config
.
hf_config
is
None
)
:
if
self
.
model_config
is
None
or
self
.
model_config
.
hf_config
is
None
:
logger
.
warning_once
(
"Using None for decoder start token id because "
"model config is not available."
)
return
None
dec_start_token_id
=
getattr
(
self
.
model_config
.
hf_config
,
'
decoder_start_token_id
'
,
None
)
"
decoder_start_token_id
"
,
None
)
if
dec_start_token_id
is
None
:
logger
.
warning_once
(
"Falling back on <BOS> for decoder start token "
...
...
@@ -97,7 +97,7 @@ class InputPreprocessor:
return
dec_start_token_id
def
_get_default_enc_dec_decoder_prompt
(
self
)
->
list
[
int
]:
'''
"""
Specifically for encoder/decoder models:
generate a default decoder prompt for when
the user specifies only the encoder prompt.
...
...
@@ -126,7 +126,7 @@ class InputPreprocessor:
Returns:
* prompt_token_ids
'''
"""
bos_token_id
=
self
.
get_bos_token_id
()
assert
bos_token_id
is
not
None
...
...
@@ -224,7 +224,10 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
int
]:
"""Async version of {meth}`_tokenize_prompt`."""
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer
=
self
.
get_tokenizer_group
()
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
...
...
@@ -287,7 +290,10 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
],
return_mm_hashes
:
bool
=
False
,
)
->
MultiModalInputs
:
"""Async version of {meth}`_process_multimodal`."""
"""
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer
=
await
self
.
_get_mm_tokenizer_async
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
...
...
@@ -472,7 +478,7 @@ class InputPreprocessor:
Returns:
*
{class}`
SingletonInputs
`
instance
*
[`SingletonInputs`][vllm.inputs.data.
SingletonInputs
]
instance
"""
parsed
=
parse_singleton_prompt
(
prompt
)
...
...
@@ -508,7 +514,10 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
SingletonInputs
:
"""Async version of {meth}`_prompt_to_llm_inputs`."""
"""
Async version of
[`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
"""
parsed
=
parse_singleton_prompt
(
prompt
)
if
parsed
[
"type"
]
==
"embeds"
:
...
...
@@ -644,7 +653,9 @@ class InputPreprocessor:
)
->
EncoderDecoderInputs
:
"""
For encoder/decoder models only:
Process an input prompt into an {class}`EncoderDecoderInputs` instance.
Process an input prompt into an
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance.
There are two types of input prompts:
singleton prompts which carry only the
...
...
@@ -670,7 +681,8 @@ class InputPreprocessor:
Returns:
* {class}`EncoderDecoderInputs` instance
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance
"""
encoder_inputs
:
SingletonInputs
decoder_inputs
:
Optional
[
SingletonInputs
]
...
...
@@ -710,7 +722,10 @@ class InputPreprocessor:
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
EncoderDecoderInputs
:
"""Async version of {meth}`_process_encoder_decoder_prompt`."""
"""
Async version of
[`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
"""
encoder_inputs
:
SingletonInputs
decoder_inputs
:
Optional
[
SingletonInputs
]
...
...
@@ -778,7 +793,8 @@ class InputPreprocessor:
)
->
DecoderOnlyInputs
:
"""
For decoder-only models:
Process an input prompt into an {class}`DecoderOnlyInputs` instance.
Process an input prompt into a
[`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.
Arguments:
...
...
@@ -789,7 +805,7 @@ class InputPreprocessor:
Returns:
*
{class}`
DecoderOnlyInputs
`
instance
*
[`DecoderOnlyInputs`][vllm.inputs.data.
DecoderOnlyInputs
]
instance
"""
prompt_comps
=
self
.
_prompt_to_llm_inputs
(
...
...
@@ -812,7 +828,10 @@ class InputPreprocessor:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
DecoderOnlyInputs
:
"""Async version of {meth}`_process_decoder_only_prompt`."""
"""
Async version of
[`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
"""
prompt_comps
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
...
...
@@ -863,7 +882,10 @@ class InputPreprocessor:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
)
->
ProcessorInputs
:
"""Async version of {meth}`preprocess`."""
"""
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if
self
.
model_config
.
is_encoder_decoder
:
assert
not
return_mm_hashes
,
(
"Multimodal hashes for encoder-decoder models should not be "
,
...
...
vllm/inputs/registry.py
View file @
4eabe123
...
...
@@ -38,7 +38,7 @@ class InputContext:
)
->
_C
:
"""
Get the HuggingFace configuration
(
{class}
`transformers.PretrainedConfig`) of the model,
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
Raises:
...
...
@@ -79,7 +79,7 @@ class InputContext:
)
->
_P
:
"""
Get the HuggingFace processor
(
{class}
`transformers.ProcessorMixin`) of the model,
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
Raises:
...
...
vllm/logger.py
View file @
4eabe123
...
...
@@ -68,22 +68,22 @@ class _VllmLogger(Logger):
"""
Note:
This class is just to provide type information.
We actually patch the methods directly on the
{class}
`logging.Logger`
We actually patch the methods directly on the
[
`logging.Logger`
][]
instance to avoid conflicting with other libraries such as
`intel_extension_for_pytorch.utils._logger`.
"""
def
info_once
(
self
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
"""
As
{meth}`info`
, but subsequent calls with
the same message
are silently dropped.
As
[`info`][logging.Logger.info]
, but subsequent calls with
the same message
are silently dropped.
"""
_print_info_once
(
self
,
msg
,
*
args
)
def
warning_once
(
self
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
"""
As
{meth}
`warning`, but subsequent calls with
the same message
are silently dropped.
As
[
`warning`
][logging.Logger.warning]
, but subsequent calls with
the same message
are silently dropped.
"""
_print_warning_once
(
self
,
msg
,
*
args
)
...
...
vllm/logging_utils/dump_input.py
View file @
4eabe123
...
...
@@ -18,7 +18,7 @@ logger = init_logger(__name__)
def
prepare_object_to_dump
(
obj
)
->
str
:
if
isinstance
(
obj
,
str
):
return
"'{obj}'"
# Double quotes
return
f
"'
{
obj
}
'"
# Double quotes
elif
isinstance
(
obj
,
dict
):
dict_str
=
', '
.
join
({
f
'
{
str
(
k
)
}
:
{
prepare_object_to_dump
(
v
)
}
'
\
for
k
,
v
in
obj
.
items
()})
...
...
@@ -42,9 +42,9 @@ def prepare_object_to_dump(obj) -> str:
return
obj
.
anon_repr
()
elif
hasattr
(
obj
,
'__dict__'
):
items
=
obj
.
__dict__
.
items
()
dict_str
=
','
.
join
([
f
'
{
str
(
k
)
}
=
{
prepare_object_to_dump
(
v
)
}
'
\
dict_str
=
',
'
.
join
([
f
'
{
str
(
k
)
}
=
{
prepare_object_to_dump
(
v
)
}
'
\
for
k
,
v
in
items
])
return
(
f
"
{
type
(
obj
).
__name__
}
(
{
dict_str
}
)"
)
return
f
"
{
type
(
obj
).
__name__
}
(
{
dict_str
}
)"
else
:
# Hacky way to make sure we can serialize the object in JSON format
try
:
...
...
vllm/lora/models.py
View file @
4eabe123
...
...
@@ -3,11 +3,11 @@
import
copy
import
math
import
os
import
re
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
regex
as
re
import
safetensors.torch
import
torch
from
torch
import
nn
...
...
@@ -29,6 +29,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules
,
is_regex_target_modules
,
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models.interfaces
import
is_pooling_model
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
...
...
@@ -197,7 +198,7 @@ class LoRAModel(AdapterModel):
embedding_modules
:
Optional
[
dict
[
str
,
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
list
[
str
]]
=
None
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
,
)
->
"LoRAModel"
:
tensorizer_config_dict
:
Optional
[
dict
]
=
None
)
->
"LoRAModel"
:
"""Create a LoRAModel from a local checkpoint.
Args:
...
...
@@ -219,20 +220,11 @@ class LoRAModel(AdapterModel):
lora_dir
,
"new_embeddings.safetensors"
)
new_embeddings_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
"new_embeddings.bin"
)
unexpected_modules
:
list
[
Union
[
list
[
str
],
str
]]
if
os
.
path
.
isfile
(
lora_tensor_path
):
tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it won’t error and model will be trained with A, B
# loraified. C won’t exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules
=
[]
with
safetensors
.
safe_open
(
lora_tensor_path
,
framework
=
"pt"
)
as
f
:
# type: ignore
for
lora_module
in
f
.
keys
():
# noqa
unexpected_modules
:
list
[
Union
[
list
[
str
],
str
]]
=
[]
def
check_unexpected_modules
(
modules
:
dict
):
for
lora_module
in
modules
.
keys
():
# noqa
module_name
,
_
,
_
=
parse_fine_tuned_lora_name
(
lora_module
,
weights_mapper
)
part_name
=
module_name
.
split
(
"."
)[
-
1
]
...
...
@@ -243,9 +235,32 @@ class LoRAModel(AdapterModel):
f
"While loading
{
lora_dir
}
, expected"
f
" target modules in
{
expected_lora_modules
}
"
f
" but received
{
unexpected_modules
}
."
f
" Please verify that the loaded LoRA module is correct"
)
f
" Please verify that the loaded LoRA module is correct"
)
if
tensorizer_config_dict
:
from
tensorizer
import
TensorDeserializer
tensorizer_config
=
TensorizerConfig
(
**
tensorizer_config_dict
)
lora_tensor_path
=
os
.
path
.
join
(
tensorizer_config
.
tensorizer_dir
,
"adapter_model.tensors"
)
tensorizer_args
=
tensorizer_config
.
_construct_tensorizer_args
()
tensors
=
TensorDeserializer
(
lora_tensor_path
,
dtype
=
tensorizer_config
.
dtype
,
**
tensorizer_args
.
deserializer_params
)
check_unexpected_modules
(
tensors
)
elif
os
.
path
.
isfile
(
lora_tensor_path
):
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it won’t error and model will be trained with A, B
# loraified. C won’t exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules
=
[]
with
safetensors
.
safe_open
(
lora_tensor_path
,
framework
=
"pt"
)
as
f
:
# type: ignore
# Load tensors if there are only expected modules.
check_unexpected_modules
(
f
)
for
module
in
f
.
keys
():
# noqa
tensors
[
module
]
=
f
.
get_tensor
(
module
)
elif
os
.
path
.
isfile
(
lora_bin_file_path
):
...
...
Prev
1
…
21
22
23
24
25
26
27
28
29
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment