Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4eabe123
Commit
4eabe123
authored
May 28, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori
parents
45840cd2
58738772
Changes
670
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
559 additions
and
126 deletions
+559
-126
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
+11
-7
vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py
...points/openai/tool_parsers/llama4_pythonic_tool_parser.py
+302
-0
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+9
-5
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+1
-1
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
+6
-5
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
+7
-4
vllm/entrypoints/utils.py
vllm/entrypoints/utils.py
+59
-0
vllm/envs.py
vllm/envs.py
+14
-3
vllm/executor/executor_base.py
vllm/executor/executor_base.py
+1
-1
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_distributed_executor.py
+3
-3
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+2
-4
vllm/forward_context.py
vllm/forward_context.py
+4
-1
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+3
-2
vllm/inputs/data.py
vllm/inputs/data.py
+39
-29
vllm/inputs/parse.py
vllm/inputs/parse.py
+4
-4
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+39
-17
vllm/inputs/registry.py
vllm/inputs/registry.py
+2
-2
vllm/logger.py
vllm/logger.py
+5
-5
vllm/logging_utils/dump_input.py
vllm/logging_utils/dump_input.py
+3
-3
vllm/lora/models.py
vllm/lora/models.py
+45
-30
No files found.
vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
import
re
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Union
from
typing
import
Union
import
partial_json_parser
import
partial_json_parser
import
regex
as
re
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
from
vllm.entrypoints.chat_utils
import
random_tool_call_id
from
vllm.entrypoints.chat_utils
import
random_tool_call_id
...
@@ -96,8 +96,9 @@ class JambaToolParser(ToolParser):
...
@@ -96,8 +96,9 @@ class JambaToolParser(ToolParser):
function
=
FunctionCall
(
function
=
FunctionCall
(
name
=
function_call
[
"name"
],
name
=
function_call
[
"name"
],
# function call args are JSON but as a string
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
function_call
[
"arguments"
])))
arguments
=
json
.
dumps
(
function_call
[
"arguments"
],
for
function_call
in
raw_function_calls
ensure_ascii
=
False
),
))
for
function_call
in
raw_function_calls
]
]
content
=
model_output
[:
model_output
.
content
=
model_output
[:
model_output
.
...
@@ -187,7 +188,7 @@ class JambaToolParser(ToolParser):
...
@@ -187,7 +188,7 @@ class JambaToolParser(ToolParser):
diff
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"arguments"
)
diff
:
Union
[
str
,
None
]
=
current_tool_call
.
get
(
"arguments"
)
if
diff
:
if
diff
:
diff
=
json
.
dumps
(
diff
).
replace
(
diff
=
json
.
dumps
(
diff
,
ensure_ascii
=
False
).
replace
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
],
self
.
streamed_args_for_tool
[
self
.
current_tool_id
],
""
)
""
)
delta
=
DeltaMessage
(
tool_calls
=
[
delta
=
DeltaMessage
(
tool_calls
=
[
...
@@ -248,7 +249,8 @@ class JambaToolParser(ToolParser):
...
@@ -248,7 +249,8 @@ class JambaToolParser(ToolParser):
"mid-arguments"
)
"mid-arguments"
)
delta
=
None
delta
=
None
elif
cur_arguments
and
not
prev_arguments
:
elif
cur_arguments
and
not
prev_arguments
:
cur_arguments_json
=
json
.
dumps
(
cur_arguments
)
cur_arguments_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
logger
.
debug
(
"finding %s in %s"
,
new_text
,
logger
.
debug
(
"finding %s in %s"
,
new_text
,
cur_arguments_json
)
cur_arguments_json
)
...
@@ -267,8 +269,10 @@ class JambaToolParser(ToolParser):
...
@@ -267,8 +269,10 @@ class JambaToolParser(ToolParser):
self
.
current_tool_id
]
+=
arguments_delta
self
.
current_tool_id
]
+=
arguments_delta
elif
cur_arguments
and
prev_arguments
:
elif
cur_arguments
and
prev_arguments
:
cur_args_json
=
json
.
dumps
(
cur_arguments
)
cur_args_json
=
json
.
dumps
(
cur_arguments
,
prev_args_json
=
json
.
dumps
(
prev_arguments
)
ensure_ascii
=
False
)
prev_args_json
=
json
.
dumps
(
prev_arguments
,
ensure_ascii
=
False
)
logger
.
debug
(
"Searching for diff between
\n
%s
\n
%s"
,
logger
.
debug
(
"Searching for diff between
\n
%s
\n
%s"
,
cur_args_json
,
prev_args_json
)
cur_args_json
,
prev_args_json
)
...
...
vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py
0 → 100644
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
import
ast
import
json
from
collections.abc
import
Sequence
from
typing
import
Any
,
Union
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaFunctionCall
,
DeltaMessage
,
DeltaToolCall
,
ExtractedToolCallInformation
,
FunctionCall
,
ToolCall
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
(
ToolParser
,
ToolParserManager
)
from
vllm.logger
import
init_logger
# Module-level logger for this parser, created via vLLM's init_logger helper.
logger = init_logger(__name__)
class _UnexpectedAstError(Exception):
    """Raised when model output does not have the AST shape this parser
    accepts (e.g. not a list of calls, non-literal arguments, mismatched
    brackets)."""
@ToolParserManager.register_module("llama4_pythonic")
class Llama4PythonicToolParser(ToolParser):
    """
    Tool call parser for Llama4 models that produce tool calls in a pythonic
    style, i.e. a Python list of keyword-argument function calls.

    Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
    """
    # TODO(mdepinet): Possible future improvements:
    #   1. Support text + tools separated by either <|python_tag|> or \n\n
    #   2. Support tools outside of a list (or separated by a semicolon).
    #      This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    # Matches a bracketed list of `name(param=value, ...)` calls.
    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL)

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

    # Rename for readability. This is NOT a tool id.
    @property
    def current_tool_index(self) -> int:
        """Alias for ``self.current_tool_id``."""
        return self.current_tool_id

    @current_tool_index.setter
    def current_tool_index(self, value: int) -> None:
        self.current_tool_id = value

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.

        Returns the parsed tool calls (with ``content=None``) when the output
        is a list of function calls. Otherwise, or when parsing fails, the
        output is returned unchanged as plain content with no tool calls.
        """
        # remove <|python_start|> and <|python_end|>
        # as Llama 4 models sometimes output those tokens
        if model_output.startswith("<|python_start|>"):
            model_output = model_output[len("<|python_start|>"):]
            model_output = model_output.replace("<|python_end|>", "")
        if not (self.TOOL_CALL_REGEX.match(model_output)):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            module = ast.parse(model_output)
            parsed = getattr(module.body[0], "value", None)
            if isinstance(parsed, ast.List) and all(
                    isinstance(e, ast.Call) for e in parsed.elts):
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=[
                        _handle_single_tool(e)  # type: ignore
                        for e in parsed.elts
                    ],
                    content=None)
            else:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls")
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Treat as regular text
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Compute the streaming delta for the text generated so far.

        Text that starts with neither ``[`` nor ``<|python_start|>`` is passed
        through as plain content. Otherwise the partial text is completed
        into valid Python, parsed, and the not-yet-sent part of each tool
        call is returned. Returns ``None`` when nothing can be emitted for
        this chunk or when an error occurs (the error is logged).
        """
        if not current_text.startswith(
                "[") and not current_text.startswith("<|python_start|>"):
            return DeltaMessage(content=delta_text)

        try:
            # remove <|python_start|> and <|python_end|>
            if current_text.startswith("<|python_start|>"):
                current_text = current_text[len("<|python_start|>"):]
            if current_text.endswith("<|python_end|>"):
                current_text = current_text[:current_text.
                                            rfind("<|python_end|>")]
            valid_and_added_text = _make_valid_python(current_text)
            if valid_and_added_text is None:
                return None
            valid_text, added_text = valid_and_added_text

            if not valid_text:
                # Only the <|python_start|> marker (or whitespace) has been
                # generated so far. ast.parse("") yields a module with an
                # empty body, so indexing body[0] below would raise and log
                # a spurious error for a perfectly normal first chunk.
                return None

            module = ast.parse(valid_text)
            parsed = getattr(module.body[0], "value", None)
            if not isinstance(parsed, ast.List) or not all(
                    isinstance(e, ast.Call) for e in parsed.elts):
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls")
            tool_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]

            tool_deltas = []
            for index, new_call in enumerate(tool_calls):
                # Calls before the current index were already fully streamed.
                if index < self.current_tool_index:
                    continue

                self.current_tool_index = index
                if len(self.streamed_args_for_tool) == index:
                    self.streamed_args_for_tool.append("")

                # A call is complete if another call follows it, or if no
                # closing ")]" had to be synthesized to make the text valid.
                new_call_complete = index < len(
                    tool_calls) - 1 or ")]" not in added_text
                if new_call_complete:
                    self.current_tool_index += 1

                # Everything synthesized before the final ")]" belongs to the
                # arguments and must not be sent yet.
                withheld_suffix = (added_text[:-2]
                                   if not new_call_complete else "")
                if not new_call_complete and added_text[-2] == ")":
                    # Function call is incomplete. Withhold the closing bracket.
                    withheld_suffix = withheld_suffix + "}"
                # Strings get single quotes in the model-produced string.
                # JSON requires double quotes.
                withheld_suffix = withheld_suffix.replace("'", '"')
                delta = _compute_tool_delta(self.streamed_args_for_tool[index],
                                            new_call, index, withheld_suffix)

                if delta is not None:
                    tool_deltas.append(delta)
                    if (delta.function is not None
                            and delta.function.arguments is not None):
                        self.streamed_args_for_tool[
                            index] += delta.function.arguments

            # HACK: serving_chat.py inspects the internal state of tool parsers
            # when determining its final streaming delta, automatically
            # adding autocompleted JSON.
            # These two lines avoid that nonsense while ensuring finish_reason
            # is set to tool_calls when at least one tool is called.
            if tool_deltas and not self.prev_tool_call_arr:
                self.prev_tool_call_arr = [{"arguments": {}}]

            if tool_deltas:
                return DeltaMessage(tool_calls=tool_deltas)
            elif not added_text and self.current_tool_id > 0:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage(content='')
            else:
                return None
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None
def _get_parameter_value(val: ast.expr) -> Any:
    """
    Convert the AST node of a tool-call argument into a plain Python value.

    Supported nodes are constants, signed numeric literals (e.g. ``-1`` or
    ``+2.5``, which Python parses as a unary operator applied to a
    constant), dicts with literal keys, and lists. Dict and list members are
    converted recursively.

    Raises:
        _UnexpectedAstError: if the node, or anything nested in it, is not
            one of the supported literal forms.
    """
    if isinstance(val, ast.Constant):
        return val.value
    elif (isinstance(val, ast.UnaryOp)
          and isinstance(val.op, (ast.USub, ast.UAdd))
          and isinstance(val.operand, ast.Constant)
          and isinstance(val.operand.value, (int, float))
          and not isinstance(val.operand.value, bool)):
        # Negative numbers are not ast.Constant nodes; without this branch a
        # call such as f(x=-1) would be rejected as a non-literal argument.
        number = val.operand.value
        return -number if isinstance(val.op, ast.USub) else number
    elif isinstance(val, ast.Dict):
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            raise _UnexpectedAstError(
                "Dict tool call arguments must have literal keys")
        return {
            k.value: _get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }
    elif isinstance(val, ast.List):
        return [_get_parameter_value(v) for v in val.elts]
    else:
        raise _UnexpectedAstError("Tool call arguments must be literals")
def _handle_single_tool(call: ast.Call) -> ToolCall:
    """
    Convert one parsed function-call node into a ``ToolCall``.

    Only keyword arguments (``call.keywords``) are read; each value is
    converted with ``_get_parameter_value`` and the resulting mapping is
    serialized to a JSON string.

    Raises:
        _UnexpectedAstError: if the called object is not a plain name, or an
            argument value is not a supported literal.
    """
    if not isinstance(call.func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    function_name = call.func.id
    arguments = {}
    for keyword in call.keywords:
        arguments[keyword.arg] = _get_parameter_value(keyword.value)
    return ToolCall(
        type="function",
        function=FunctionCall(
            name=function_name,
            # Keep non-ASCII characters as-is instead of \uXXXX escapes,
            # matching the other tool parsers.
            arguments=json.dumps(arguments, ensure_ascii=False)),
    )
def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
    """
    Complete a partial pythonic tool-call string so that it can be parsed.

    Scans ``text`` while tracking open brackets and string quotes, then
    appends the closers needed to balance them.

    Returns:
        ``(completed_text, added_text)`` where ``added_text`` is the suffix
        that was appended, or ``None`` when the text ends at a point where no
        sensible completion exists (e.g. right after ``=`` or ``:``, or in
        the middle of a function or parameter name).

    Raises:
        _UnexpectedAstError: if a closing bracket does not match the most
            recently opened one.
    """
    bracket_stack = []
    for index, char in enumerate(text):
        if bracket_stack and bracket_stack[-1] in {"'", '"'}:
            # Inside a string literal. Brackets here are ordinary characters
            # and must not touch the stack; otherwise an argument such as
            # "a)" would be reported as mismatched parentheses. Only the
            # matching, unescaped quote ends the string.
            if char == bracket_stack[-1] and not (index > 0 and
                                                  text[index - 1] == "\\"):
                bracket_stack.pop()
            continue
        if char in {"[", "(", "{"}:
            bracket_stack.append(char)
        elif char == "]":
            if not bracket_stack or bracket_stack.pop() != "[":
                raise _UnexpectedAstError("Mismatched square brackets")
        elif char == ")":
            if not bracket_stack or bracket_stack.pop() != "(":
                raise _UnexpectedAstError("Mismatched parentheses")
        elif char == "}":
            if not bracket_stack or bracket_stack.pop() != "{":
                raise _UnexpectedAstError("Mismatched curly braces")
        elif char in {"'", '"'}:
            # Opening quote of a new string literal.
            bracket_stack.append(char)

    text = text.rstrip()
    if text.endswith("=") or text.endswith(":"):
        # Since we have no type information for this property/parameter value,
        # we can't fill in a valid value.
        return None
    if bracket_stack and bracket_stack[-1] == "{":
        trailing_dict_text = text[:text.rfind("{")]
        num_keys = trailing_dict_text.count(":")
        num_values = trailing_dict_text.count(",")
        if num_keys <= num_values:
            return None  # Incomplete property name within parameter value
    if bracket_stack and bracket_stack[-1] == "(":
        trailing_params_text = text[:text.rfind("(")]
        num_full_param_names = trailing_params_text.count("=")
        num_full_param_values = trailing_params_text.count(",")
        if num_full_param_names <= num_full_param_values:
            return None  # Incomplete parameter name
    if text.endswith(","):
        text = text[:-1]
    if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
            "[") and not text.endswith(")"):
        return None  # Incomplete function name

    # Close whatever is still open, innermost first.
    added_text = ""
    for char in reversed(bracket_stack):
        if char == "[":
            added_text += "]"
        elif char == "(":
            added_text += ")"
        elif char == "{":
            added_text += "}"
        elif char == "'":
            added_text += "'"
        elif char == '"':
            added_text += '"'

    return text + added_text, added_text
def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
                        index: int,
                        withheld_suffix: str) -> Union[DeltaToolCall, None]:
    """
    Build the streaming delta for one tool call.

    ``withheld_suffix`` is first trimmed from the end of the call's
    serialized arguments. If nothing was sent before, the delta carries the
    call id, type, function name and all remaining arguments. Otherwise it
    carries only the arguments beyond ``previously_sent_args``, or ``None``
    is returned when there is nothing new to send.
    """
    sendable_args = new_call.function.arguments
    if withheld_suffix:
        assert sendable_args.endswith(withheld_suffix)
        sendable_args = sendable_args[:-len(withheld_suffix)]

    if not previously_sent_args:
        # First delta for this call: include its identity and name.
        first_chunk = DeltaFunctionCall(
            name=new_call.function.name,
            arguments=sendable_args,
        )
        return DeltaToolCall(id=new_call.id,
                             type="function",
                             index=index,
                             function=first_chunk)

    unsent_args = sendable_args[len(previously_sent_args):]
    if not unsent_args:
        return None
    return DeltaToolCall(id=None,
                         index=index,
                         function=DeltaFunctionCall(arguments=unsent_args))
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
import
re
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
json
import
JSONDecoder
from
json
import
JSONDecoder
from
typing
import
Union
from
typing
import
Union
import
partial_json_parser
import
partial_json_parser
import
regex
as
re
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
...
@@ -88,7 +88,8 @@ class Llama3JsonToolParser(ToolParser):
...
@@ -88,7 +88,8 @@ class Llama3JsonToolParser(ToolParser):
# function call args are JSON but as a string
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
raw_function_call
[
"arguments"
]
\
arguments
=
json
.
dumps
(
raw_function_call
[
"arguments"
]
\
if
"arguments"
in
raw_function_call
\
if
"arguments"
in
raw_function_call
\
else
raw_function_call
[
"parameters"
])))
else
raw_function_call
[
"parameters"
],
ensure_ascii
=
False
)))
for
raw_function_call
in
function_call_arr
for
raw_function_call
in
function_call_arr
]
]
...
@@ -174,7 +175,8 @@ class Llama3JsonToolParser(ToolParser):
...
@@ -174,7 +175,8 @@ class Llama3JsonToolParser(ToolParser):
if
self
.
current_tool_id
>=
0
:
if
self
.
current_tool_id
>=
0
:
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
cur_arguments
=
current_tool_call
.
get
(
"arguments"
)
if
cur_arguments
:
if
cur_arguments
:
cur_args_json
=
json
.
dumps
(
cur_arguments
)
cur_args_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
sent
=
len
(
sent
=
len
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
])
argument_diff
=
cur_args_json
[
sent
:]
argument_diff
=
cur_args_json
[
sent
:]
...
@@ -226,7 +228,8 @@ class Llama3JsonToolParser(ToolParser):
...
@@ -226,7 +228,8 @@ class Llama3JsonToolParser(ToolParser):
if
cur_arguments
:
if
cur_arguments
:
sent
=
len
(
sent
=
len
(
self
.
streamed_args_for_tool
[
self
.
current_tool_id
])
self
.
streamed_args_for_tool
[
self
.
current_tool_id
])
cur_args_json
=
json
.
dumps
(
cur_arguments
)
cur_args_json
=
json
.
dumps
(
cur_arguments
,
ensure_ascii
=
False
)
prev_arguments
=
self
.
prev_tool_call_arr
[
prev_arguments
=
self
.
prev_tool_call_arr
[
self
.
current_tool_id
].
get
(
"arguments"
)
self
.
current_tool_id
].
get
(
"arguments"
)
...
@@ -234,7 +237,8 @@ class Llama3JsonToolParser(ToolParser):
...
@@ -234,7 +237,8 @@ class Llama3JsonToolParser(ToolParser):
if
is_complete
[
self
.
current_tool_id
]:
if
is_complete
[
self
.
current_tool_id
]:
argument_diff
=
cur_args_json
[
sent
:]
argument_diff
=
cur_args_json
[
sent
:]
elif
prev_arguments
:
elif
prev_arguments
:
prev_args_json
=
json
.
dumps
(
prev_arguments
)
prev_args_json
=
json
.
dumps
(
prev_arguments
,
ensure_ascii
=
False
)
if
cur_args_json
!=
prev_args_json
:
if
cur_args_json
!=
prev_args_json
:
prefix
=
find_common_prefix
(
prefix
=
find_common_prefix
(
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
import
re
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
random
import
choices
from
random
import
choices
from
string
import
ascii_letters
,
digits
from
string
import
ascii_letters
,
digits
from
typing
import
Union
from
typing
import
Union
import
partial_json_parser
import
partial_json_parser
import
regex
as
re
from
partial_json_parser.core.options
import
Allow
from
partial_json_parser.core.options
import
Allow
from
pydantic
import
Field
from
pydantic
import
Field
...
...
vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py
View file @
4eabe123
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
json
import
json
import
re
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.chat_utils
import
random_tool_call_id
from
vllm.entrypoints.chat_utils
import
random_tool_call_id
...
@@ -79,10 +79,11 @@ class Phi4MiniJsonToolParser(ToolParser):
...
@@ -79,10 +79,11 @@ class Phi4MiniJsonToolParser(ToolParser):
name
=
raw_function_call
[
"name"
],
name
=
raw_function_call
[
"name"
],
# function call args are JSON but as a string
# function call args are JSON but as a string
arguments
=
json
.
dumps
(
arguments
=
json
.
dumps
(
raw_function_call
[
"arguments"
]
if
"arguments"
in
raw_function_call
[
"arguments"
]
raw_function_call
else
if
"arguments"
in
raw_function_call
else
raw_function_call
[
"parameters"
])))
raw_function_call
[
"parameters"
],
for
raw_function_call
in
function_call_arr
ensure_ascii
=
False
),
))
for
raw_function_call
in
function_call_arr
]
]
# get any content before the tool call
# get any content before the tool call
...
...
vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py
View file @
4eabe123
...
@@ -2,10 +2,10 @@
...
@@ -2,10 +2,10 @@
import
ast
import
ast
import
json
import
json
import
re
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
Any
,
Union
from
typing
import
Any
,
Union
import
regex
as
re
from
transformers
import
PreTrainedTokenizerBase
from
transformers
import
PreTrainedTokenizerBase
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
...
@@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
...
@@ -200,9 +200,12 @@ def _handle_single_tool(call: ast.Call) -> ToolCall:
arguments
=
{}
arguments
=
{}
for
keyword
in
call
.
keywords
:
for
keyword
in
call
.
keywords
:
arguments
[
keyword
.
arg
]
=
_get_parameter_value
(
keyword
.
value
)
arguments
[
keyword
.
arg
]
=
_get_parameter_value
(
keyword
.
value
)
return
ToolCall
(
type
=
"function"
,
return
ToolCall
(
function
=
FunctionCall
(
name
=
function_name
,
type
=
"function"
,
arguments
=
json
.
dumps
(
arguments
)))
function
=
FunctionCall
(
name
=
function_name
,
arguments
=
json
.
dumps
(
arguments
,
ensure_ascii
=
False
)),
)
def
_make_valid_python
(
text
:
str
)
->
Union
[
tuple
[
str
,
str
],
None
]:
def
_make_valid_python
(
text
:
str
)
->
Union
[
tuple
[
str
,
str
],
None
]:
...
...
vllm/entrypoints/utils.py
View file @
4eabe123
...
@@ -13,6 +13,13 @@ from vllm.logger import init_logger
...
@@ -13,6 +13,13 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Epilog text describing the `--help=<keyword>` forms (group name, single
# argument, keyword search, `listgroup`).
VLLM_SERVE_PARSER_EPILOG = (
    "Tip: Use `vllm serve --help=<keyword>` to explore arguments from help.\n"
    " - To view a argument group: --help=ModelConfig\n"
    " - To view a single argument: --help=max-num-seqs\n"
    " - To search by keyword: --help=max\n"
    " - To list all groups: --help=listgroup")
async
def
listen_for_disconnect
(
request
:
Request
)
->
None
:
async
def
listen_for_disconnect
(
request
:
Request
)
->
None
:
"""Returns if a disconnect message is received"""
"""Returns if a disconnect message is received"""
...
@@ -158,3 +165,55 @@ def _validate_truncation_size(
...
@@ -158,3 +165,55 @@ def _validate_truncation_size(
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
tokenization_kwargs
[
"max_length"
]
=
truncate_prompt_tokens
return
truncate_prompt_tokens
return
truncate_prompt_tokens
def show_filtered_argument_or_group_from_help(parser):
    """
    Handle a ``--help=<keyword>`` command-line argument, if one is present.

    Scans ``sys.argv``; for the first argument starting with ``--help=``:

    - ``listgroup`` prints the parser's argument groups and exits with 0.
    - A keyword equal (case-insensitively) to a group title prints that
      group's help and exits with 0.
    - A keyword contained (case-insensitively) in any option string prints
      the matching arguments and exits with 0.
    - Otherwise a "no match" message is printed and the process exits
      with 1.

    Returns normally when no such argument is present.
    """
    import sys

    for cli_arg in sys.argv:
        if not cli_arg.startswith('--help='):
            continue

        search_keyword = cli_arg.split('=', 1)[1]

        # List available groups
        if search_keyword == 'listgroup':
            print("\nAvailable argument groups:")
            for section in parser._action_groups:
                title = section.title
                if not title or title.startswith("positional arguments"):
                    continue
                print(f" - {title}")
                if section.description:
                    print(" " + section.description.strip())
                print()
            sys.exit(0)

        wanted = search_keyword.lower()

        # For group search
        formatter = parser._get_formatter()
        for section in parser._action_groups:
            if not section.title or section.title.lower() != wanted:
                continue
            formatter.start_section(section.title)
            formatter.add_text(section.description)
            formatter.add_arguments(section._group_actions)
            formatter.end_section()
            print(formatter.format_help())
            sys.exit(0)

        # For single arg: search option names
        matched_actions = [
            action for section in parser._action_groups
            for action in section._group_actions
            if any(wanted in opt.lower() for opt in action.option_strings)
        ]

        if matched_actions:
            print(f"\nParameters matching '{search_keyword}':\n")
            formatter = parser._get_formatter()
            formatter.add_arguments(matched_actions)
            print(formatter.format_help())
            sys.exit(0)

        print(f"\nNo group or parameter matching '{search_keyword}'")
        print("Tip: use `--help=listgroup` to view all groups.")
        sys.exit(1)
vllm/envs.py
View file @
4eabe123
...
@@ -117,6 +117,7 @@ if TYPE_CHECKING:
...
@@ -117,6 +117,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST
:
str
=
"localhost"
VLLM_NIXL_SIDE_CHANNEL_HOST
:
str
=
"localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT
:
int
=
5557
VLLM_NIXL_SIDE_CHANNEL_PORT
:
int
=
5557
VLLM_ALL2ALL_BACKEND
:
str
=
"naive"
VLLM_ALL2ALL_BACKEND
:
str
=
"naive"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
:
int
=
163840
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -163,7 +164,7 @@ def get_vllm_port() -> Optional[int]:
...
@@ -163,7 +164,7 @@ def get_vllm_port() -> Optional[int]:
raise
ValueError
(
raise
ValueError
(
f
"VLLM_PORT '
{
port
}
' appears to be a URI. "
f
"VLLM_PORT '
{
port
}
' appears to be a URI. "
"This may be caused by a Kubernetes service discovery issue"
"This may be caused by a Kubernetes service discovery issue"
"check the warning in: https://docs.vllm.ai/en/stable/
serving
/env_vars.html"
"check the warning in: https://docs.vllm.ai/en/stable/
usage
/env_vars.html"
)
)
except
Exception
:
except
Exception
:
pass
pass
...
@@ -175,7 +176,7 @@ def get_vllm_port() -> Optional[int]:
...
@@ -175,7 +176,7 @@ def get_vllm_port() -> Optional[int]:
# The begin-* and end* here are used by the documentation generator
# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
# to extract the used env vars.
#
begin-
env-vars-definition
#
--8<-- [start:
env-vars-definition
]
environment_variables
:
dict
[
str
,
Callable
[[],
Any
]]
=
{
environment_variables
:
dict
[
str
,
Callable
[[],
Any
]]
=
{
...
@@ -809,11 +810,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -809,11 +810,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
int
(
os
.
getenv
(
"VLLM_NIXL_SIDE_CHANNEL_PORT"
,
"5557"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_NIXL_SIDE_CHANNEL_PORT"
,
"5557"
)),
# all2all backend for vllm's expert parallel communication
# all2all backend for vllm's expert parallel communication
# Available options:
# - "naive": naive all2all implementation using all-reduce
# - "pplx": use pplx kernels
"VLLM_ALL2ALL_BACKEND"
:
"VLLM_ALL2ALL_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_ALL2ALL_BACKEND"
,
"naive"
),
lambda
:
os
.
getenv
(
"VLLM_ALL2ALL_BACKEND"
,
"naive"
),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
# the blockscale tensor of activations NVFP4 Quantization.
# This is used to prevent the kernel from running out of memory.
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE"
,
"163840"
)),
}
}
# end
-
env-vars-definition
#
--8<-- [
end
:
env-vars-definition
]
def
__getattr__
(
name
:
str
):
def
__getattr__
(
name
:
str
):
...
...
vllm/executor/executor_base.py
View file @
4eabe123
...
@@ -74,7 +74,7 @@ class ExecutorBase(ABC):
...
@@ -74,7 +74,7 @@ class ExecutorBase(ABC):
`self` argument, in addition to the arguments passed in `args`
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
timeout: Maximum time in seconds to wait for execution. Raises a
{exc}
`TimeoutError` on timeout. `None` means wait indefinitely.
[
`TimeoutError`
][]
on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
...
...
vllm/executor/ray_distributed_executor.py
View file @
4eabe123
...
@@ -528,12 +528,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
...
@@ -528,12 +528,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
ray
.
get
(
parallel_worker_tasks
)
ray
.
get
(
parallel_worker_tasks
)
def
_check_ray_cgraph_installation
(
self
):
def
_check_ray_cgraph_installation
(
self
):
import
pkg_resources
import
importlib.metadata
from
packaging
import
version
from
packaging
import
version
required_version
=
version
.
parse
(
"2.43.0"
)
required_version
=
version
.
parse
(
"2.43.0"
)
current_version
=
version
.
parse
(
current_version
=
version
.
parse
(
importlib
.
metadata
.
version
(
"ray"
))
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
is "
raise
ValueError
(
f
"Ray version
{
required_version
}
is "
f
"required, but found
{
current_version
}
"
)
f
"required, but found
{
current_version
}
"
)
...
...
vllm/executor/ray_utils.py
View file @
4eabe123
...
@@ -87,9 +87,8 @@ try:
...
@@ -87,9 +87,8 @@ try:
# TODO(swang): This is needed right now because Ray Compiled Graph
# TODO(swang): This is needed right now because Ray Compiled Graph
# executes on a background thread, so we need to reset torch's
# executes on a background thread, so we need to reset torch's
# current device.
# current device.
import
torch
if
not
self
.
compiled_dag_cuda_device_set
:
if
not
self
.
compiled_dag_cuda_device_set
:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
current_platform
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
self
.
compiled_dag_cuda_device_set
=
True
output
=
self
.
worker
.
_execute_model_spmd
(
execute_model_req
,
output
=
self
.
worker
.
_execute_model_spmd
(
execute_model_req
,
...
@@ -113,8 +112,7 @@ try:
...
@@ -113,8 +112,7 @@ try:
# Not needed
# Not needed
pass
pass
else
:
else
:
import
torch
current_platform
.
set_device
(
self
.
worker
.
device
)
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
self
.
compiled_dag_cuda_device_set
=
True
...
...
vllm/forward_context.py
View file @
4eabe123
...
@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
...
@@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
# we use synchronous scheduling right now,
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# adding a sync point here should not affect
# scheduling of the next batch
# scheduling of the next batch
torch
.
cuda
.
synchronize
()
from
vllm.platforms
import
current_platform
synchronize
=
current_platform
.
synchronize
if
synchronize
is
not
None
:
synchronize
()
now
=
time
.
perf_counter
()
now
=
time
.
perf_counter
()
# time measurement is in milliseconds
# time measurement is in milliseconds
batchsize_forward_time
[
batchsize
].
append
(
batchsize_forward_time
[
batchsize
].
append
(
...
...
vllm/inputs/__init__.py
View file @
4eabe123
...
@@ -10,8 +10,9 @@ from .registry import (DummyData, InputContext, InputProcessingContext,
...
@@ -10,8 +10,9 @@ from .registry import (DummyData, InputContext, InputProcessingContext,
INPUT_REGISTRY
=
InputRegistry
()
INPUT_REGISTRY
=
InputRegistry
()
"""
"""
The global {class}`~InputRegistry` which is used by {class}`~vllm.LLMEngine`
The global [`InputRegistry`][vllm.inputs.registry.InputRegistry] which is used
to dispatch data processing according to the target model.
by [`LLMEngine`][vllm.LLMEngine] to dispatch data processing according to the
target model.
"""
"""
__all__
=
[
__all__
=
[
...
...
vllm/inputs/data.py
View file @
4eabe123
...
@@ -80,22 +80,24 @@ SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
...
@@ -80,22 +80,24 @@ SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
"""
Set of possible schemas for a single prompt:
Set of possible schemas for a single prompt:
- A text prompt (
{class}
`str` or
{class}`
TextPrompt
`
)
- A text prompt (
[
`str`
][]
or
[`TextPrompt`][vllm.inputs.data.
TextPrompt
]
)
- A tokenized prompt (
{class}
`TokensPrompt`)
- A tokenized prompt (
[
`TokensPrompt`
][vllm.inputs.data.TokensPrompt]
)
- An embeddings prompt (
{class}
`EmbedsPrompt`)
- An embeddings prompt (
[
`EmbedsPrompt`
][vllm.inputs.data.EmbedsPrompt]
)
Note that "singleton" is as opposed to a data structure
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
which encapsulates multiple prompts, i.e. of the sort
which may be utilized for encoder/decoder models when
which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
the user desires to express both the encoder & decoder
prompts explicitly, i.e. {class}`ExplicitEncoderDecoderPrompt`
prompts explicitly, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
A prompt of type
{class}`
SingletonPrompt
`
may be
employed
A prompt of type
[`SingletonPrompt`][vllm.inputs.data.
SingletonPrompt
]
may be
as (1) input to a decoder-only model, (2) input to
employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
where the decoder-prompt is not specified explicitly, or
(3) as a member of a larger data structure encapsulating
(3) as a member of a larger data structure encapsulating
more than one prompt, i.e. {class}`ExplicitEncoderDecoderPrompt`
more than one prompt, i.e.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
"""
"""
...
@@ -126,18 +128,20 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
...
@@ -126,18 +128,20 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
comprising an explicit encoder prompt and a decoder prompt.
comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, may be formatted
The encoder and decoder prompts, respectively, may be formatted
according to any of the {class}`SingletonPrompt` schemas,
according to any of the
[`SingletonPrompt`][vllm.inputs.data.SingletonPrompt] schemas,
and are not required to have the same schema.
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
Only the encoder prompt may have multi-modal data. mm_processor_kwargs
should be at the top-level, and should not be set in the encoder/decoder
should be at the top-level, and should not be set in the encoder/decoder
prompts, since they are agnostic to the encoder/decoder.
prompts, since they are agnostic to the encoder/decoder.
Note that an {class}`ExplicitEncoderDecoderPrompt` may not
Note that an
be used as an input to a decoder-only model,
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
may not be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
fields of this data structure themselves must be
{class}`
SingletonPrompt
`
instances.
[`SingletonPrompt`][vllm.inputs.data.
SingletonPrompt
]
instances.
"""
"""
encoder_prompt
:
_T1_co
encoder_prompt
:
_T1_co
...
@@ -152,11 +156,11 @@ PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
...
@@ -152,11 +156,11 @@ PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
Set of possible schemas for an LLM input, including
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
both decoder-only and encoder/decoder input types:
- A text prompt (
{class}
`str` or
{class}`
TextPrompt
`
)
- A text prompt (
[
`str`
][]
or
[`TextPrompt`][vllm.inputs.data.
TextPrompt
]
)
- A tokenized prompt (
{class}
`TokensPrompt`)
- A tokenized prompt (
[
`TokensPrompt`
][vllm.inputs.data.TokensPrompt]
)
- An embeddings prompt (
{class}
`EmbedsPrompt`)
- An embeddings prompt (
[
`EmbedsPrompt`
][vllm.inputs.data.EmbedsPrompt]
)
- A single data structure containing both an encoder and a decoder prompt
- A single data structure containing both an encoder and a decoder prompt
(
{class}`
ExplicitEncoderDecoderPrompt
`
)
(
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.
ExplicitEncoderDecoderPrompt
]
)
"""
"""
...
@@ -189,7 +193,8 @@ def token_inputs(
...
@@ -189,7 +193,8 @@ def token_inputs(
prompt
:
Optional
[
str
]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
TokenInputs
:
)
->
TokenInputs
:
"""Construct {class}`TokenInputs` from optional values."""
"""Construct [`TokenInputs`][vllm.inputs.data.TokenInputs] from optional
values."""
inputs
=
TokenInputs
(
type
=
"token"
,
prompt_token_ids
=
prompt_token_ids
)
inputs
=
TokenInputs
(
type
=
"token"
,
prompt_token_ids
=
prompt_token_ids
)
if
prompt
is
not
None
:
if
prompt
is
not
None
:
...
@@ -221,7 +226,8 @@ def embeds_inputs(
...
@@ -221,7 +226,8 @@ def embeds_inputs(
prompt_embeds
:
torch
.
Tensor
,
prompt_embeds
:
torch
.
Tensor
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
)
->
EmbedsInputs
:
)
->
EmbedsInputs
:
"""Construct :class:`EmbedsInputs` from optional values."""
"""Construct [`EmbedsInputs`][vllm.inputs.data.EmbedsInputs] from optional
values."""
inputs
=
EmbedsInputs
(
type
=
"embeds"
,
prompt_embeds
=
prompt_embeds
)
inputs
=
EmbedsInputs
(
type
=
"embeds"
,
prompt_embeds
=
prompt_embeds
)
if
cache_salt
is
not
None
:
if
cache_salt
is
not
None
:
...
@@ -232,7 +238,7 @@ def embeds_inputs(
...
@@ -232,7 +238,7 @@ def embeds_inputs(
DecoderOnlyInputs
=
Union
[
TokenInputs
,
EmbedsInputs
,
"MultiModalInputs"
]
DecoderOnlyInputs
=
Union
[
TokenInputs
,
EmbedsInputs
,
"MultiModalInputs"
]
"""
"""
The inputs in
{class}`~vllm
.LLMEngine
`
before they are
The inputs in
[`LLMEngine`][vllm.engine.llm_engine
.LLMEngine
]
before they are
passed to the model executor.
passed to the model executor.
This specifies the data required for decoder-only models.
This specifies the data required for decoder-only models.
"""
"""
...
@@ -240,11 +246,12 @@ This specifies the data required for decoder-only models.
...
@@ -240,11 +246,12 @@ This specifies the data required for decoder-only models.
class
EncoderDecoderInputs
(
TypedDict
):
class
EncoderDecoderInputs
(
TypedDict
):
"""
"""
The inputs in
{class}`~vllm
.LLMEngine
`
before they
are
The inputs in
[`LLMEngine`][vllm.engine.llm_engine
.LLMEngine
]
before they
passed to the model executor.
are
passed to the model executor.
This specifies the required data for encoder-decoder models.
This specifies the required data for encoder-decoder models.
"""
"""
encoder
:
Union
[
TokenInputs
,
"MultiModalInputs"
]
encoder
:
Union
[
TokenInputs
,
"MultiModalInputs"
]
"""The inputs for the encoder portion."""
"""The inputs for the encoder portion."""
...
@@ -254,13 +261,13 @@ class EncoderDecoderInputs(TypedDict):
...
@@ -254,13 +261,13 @@ class EncoderDecoderInputs(TypedDict):
SingletonInputs
=
Union
[
TokenInputs
,
EmbedsInputs
,
"MultiModalInputs"
]
SingletonInputs
=
Union
[
TokenInputs
,
EmbedsInputs
,
"MultiModalInputs"
]
"""
"""
A processed
{class}
`SingletonPrompt` which can be
passed to
A processed
[
`SingletonPrompt`
][vllm.inputs.data.SingletonPrompt]
which can be
{class}
`vllm.sequence.Sequence`.
passed to [
`vllm.sequence.Sequence`
][]
.
"""
"""
ProcessorInputs
=
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]
ProcessorInputs
=
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]
"""
"""
The
in
puts
to {data}
`vllm.inputs.InputProcessor`.
The
out
puts
from [
`vllm.inputs.
preprocess.
InputPr
epr
ocessor`
][]
.
"""
"""
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
_T1
=
TypeVar
(
"_T1"
,
bound
=
SingletonPrompt
,
default
=
SingletonPrompt
)
...
@@ -277,7 +284,8 @@ def build_explicit_enc_dec_prompt(
...
@@ -277,7 +284,8 @@ def build_explicit_enc_dec_prompt(
return
ExplicitEncoderDecoderPrompt
(
return
ExplicitEncoderDecoderPrompt
(
encoder_prompt
=
encoder_prompt
,
encoder_prompt
=
encoder_prompt
,
decoder_prompt
=
decoder_prompt
,
decoder_prompt
=
decoder_prompt
,
mm_processor_kwargs
=
mm_processor_kwargs
)
mm_processor_kwargs
=
mm_processor_kwargs
,
)
def
zip_enc_dec_prompts
(
def
zip_enc_dec_prompts
(
...
@@ -288,7 +296,8 @@ def zip_enc_dec_prompts(
...
@@ -288,7 +296,8 @@ def zip_enc_dec_prompts(
)
->
list
[
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]]:
)
->
list
[
ExplicitEncoderDecoderPrompt
[
_T1
,
_T2
]]:
"""
"""
Zip encoder and decoder prompts together into a list of
Zip encoder and decoder prompts together into a list of
{class}`ExplicitEncoderDecoderPrompt` instances.
[`ExplicitEncoderDecoderPrompt`][vllm.inputs.data.ExplicitEncoderDecoderPrompt]
instances.
``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
``mm_processor_kwargs`` may also be provided; if a dict is passed, the same
dictionary will be used for every encoder/decoder prompt. If an iterable is
dictionary will be used for every encoder/decoder prompt. If an iterable is
...
@@ -299,10 +308,11 @@ def zip_enc_dec_prompts(
...
@@ -299,10 +308,11 @@ def zip_enc_dec_prompts(
if
isinstance
(
mm_processor_kwargs
,
dict
):
if
isinstance
(
mm_processor_kwargs
,
dict
):
return
[
return
[
build_explicit_enc_dec_prompt
(
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
,
encoder_prompt
,
cast
(
dict
[
str
,
Any
],
mm_processor_kwargs
))
decoder_prompt
,
for
(
encoder_prompt
,
cast
(
dict
[
str
,
Any
],
mm_processor_kwargs
),
decoder_prompt
)
in
zip
(
enc_prompts
,
dec_prompts
)
)
for
(
encoder_prompt
,
decoder_prompt
)
in
zip
(
enc_prompts
,
dec_prompts
)
]
]
return
[
return
[
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
,
build_explicit_enc_dec_prompt
(
encoder_prompt
,
decoder_prompt
,
...
...
vllm/inputs/parse.py
View file @
4eabe123
...
@@ -23,13 +23,13 @@ class ParsedTokens(TypedDict):
...
@@ -23,13 +23,13 @@ class ParsedTokens(TypedDict):
@
overload
@
overload
def
parse_and_batch_prompt
(
def
parse_and_batch_prompt
(
prompt
:
Union
[
str
,
list
[
str
]])
->
Sequence
[
ParsedText
]:
prompt
:
Union
[
str
,
list
[
str
]]
,
)
->
Sequence
[
ParsedText
]:
...
...
@
overload
@
overload
def
parse_and_batch_prompt
(
def
parse_and_batch_prompt
(
prompt
:
Union
[
list
[
int
],
list
[
list
[
int
]]])
->
Sequence
[
ParsedTokens
]:
prompt
:
Union
[
list
[
int
],
list
[
list
[
int
]]]
,
)
->
Sequence
[
ParsedTokens
]:
...
...
...
@@ -86,7 +86,7 @@ class ParsedTokensPrompt(TypedDict):
...
@@ -86,7 +86,7 @@ class ParsedTokensPrompt(TypedDict):
class
ParsedEmbedsPrompt
(
TypedDict
):
class
ParsedEmbedsPrompt
(
TypedDict
):
type
:
Literal
[
'
embeds
'
]
type
:
Literal
[
"
embeds
"
]
content
:
EmbedsPrompt
content
:
EmbedsPrompt
...
@@ -133,7 +133,7 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
...
@@ -133,7 +133,7 @@ def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
def
is_explicit_encoder_decoder_prompt
(
def
is_explicit_encoder_decoder_prompt
(
prompt
:
PromptType
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
prompt
:
PromptType
,
)
->
TypeIs
[
ExplicitEncoderDecoderPrompt
]:
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
return
isinstance
(
prompt
,
dict
)
and
"encoder_prompt"
in
prompt
...
...
vllm/inputs/preprocess.py
View file @
4eabe123
...
@@ -67,11 +67,11 @@ class InputPreprocessor:
...
@@ -67,11 +67,11 @@ class InputPreprocessor:
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
def
get_decoder_start_token_id
(
self
)
->
Optional
[
int
]:
def
get_decoder_start_token_id
(
self
)
->
Optional
[
int
]:
'''
"""
Obtain the decoder start token id employed by an encoder/decoder
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model. Returns None for non-encoder/decoder models or if the
model config is unavailable.
model config is unavailable.
'''
"""
if
not
self
.
model_config
.
is_encoder_decoder
:
if
not
self
.
model_config
.
is_encoder_decoder
:
logger
.
warning_once
(
logger
.
warning_once
(
...
@@ -79,14 +79,14 @@ class InputPreprocessor:
...
@@ -79,14 +79,14 @@ class InputPreprocessor:
"this is not an encoder/decoder model."
)
"this is not an encoder/decoder model."
)
return
None
return
None
if
(
self
.
model_config
is
None
or
self
.
model_config
.
hf_config
is
None
)
:
if
self
.
model_config
is
None
or
self
.
model_config
.
hf_config
is
None
:
logger
.
warning_once
(
logger
.
warning_once
(
"Using None for decoder start token id because "
"Using None for decoder start token id because "
"model config is not available."
)
"model config is not available."
)
return
None
return
None
dec_start_token_id
=
getattr
(
self
.
model_config
.
hf_config
,
dec_start_token_id
=
getattr
(
self
.
model_config
.
hf_config
,
'
decoder_start_token_id
'
,
None
)
"
decoder_start_token_id
"
,
None
)
if
dec_start_token_id
is
None
:
if
dec_start_token_id
is
None
:
logger
.
warning_once
(
logger
.
warning_once
(
"Falling back on <BOS> for decoder start token "
"Falling back on <BOS> for decoder start token "
...
@@ -97,7 +97,7 @@ class InputPreprocessor:
...
@@ -97,7 +97,7 @@ class InputPreprocessor:
return
dec_start_token_id
return
dec_start_token_id
def
_get_default_enc_dec_decoder_prompt
(
self
)
->
list
[
int
]:
def
_get_default_enc_dec_decoder_prompt
(
self
)
->
list
[
int
]:
'''
"""
Specifically for encoder/decoder models:
Specifically for encoder/decoder models:
generate a default decoder prompt for when
generate a default decoder prompt for when
the user specifies only the encoder prompt.
the user specifies only the encoder prompt.
...
@@ -126,7 +126,7 @@ class InputPreprocessor:
...
@@ -126,7 +126,7 @@ class InputPreprocessor:
Returns:
Returns:
* prompt_token_ids
* prompt_token_ids
'''
"""
bos_token_id
=
self
.
get_bos_token_id
()
bos_token_id
=
self
.
get_bos_token_id
()
assert
bos_token_id
is
not
None
assert
bos_token_id
is
not
None
...
@@ -224,7 +224,10 @@ class InputPreprocessor:
...
@@ -224,7 +224,10 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
list
[
int
]:
)
->
list
[
int
]:
"""Async version of {meth}`_tokenize_prompt`."""
"""
Async version of
[`_tokenize_prompt`][vllm.inputs.preprocess.InputPreprocessor._tokenize_prompt].
"""
tokenizer
=
self
.
get_tokenizer_group
()
tokenizer
=
self
.
get_tokenizer_group
()
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
tokenization_kwargs
=
self
.
_get_tokenization_kw
(
tokenization_kwargs
)
...
@@ -287,7 +290,10 @@ class InputPreprocessor:
...
@@ -287,7 +290,10 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
],
lora_request
:
Optional
[
LoRARequest
],
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
MultiModalInputs
:
)
->
MultiModalInputs
:
"""Async version of {meth}`_process_multimodal`."""
"""
Async version of
[`_process_multimodal`][vllm.inputs.preprocess.InputPreprocessor._process_multimodal].
"""
tokenizer
=
await
self
.
_get_mm_tokenizer_async
(
lora_request
)
tokenizer
=
await
self
.
_get_mm_tokenizer_async
(
lora_request
)
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
mm_processor
=
self
.
mm_registry
.
create_processor
(
self
.
model_config
,
...
@@ -472,7 +478,7 @@ class InputPreprocessor:
...
@@ -472,7 +478,7 @@ class InputPreprocessor:
Returns:
Returns:
*
{class}`
SingletonInputs
`
instance
*
[`SingletonInputs`][vllm.inputs.data.
SingletonInputs
]
instance
"""
"""
parsed
=
parse_singleton_prompt
(
prompt
)
parsed
=
parse_singleton_prompt
(
prompt
)
...
@@ -508,7 +514,10 @@ class InputPreprocessor:
...
@@ -508,7 +514,10 @@ class InputPreprocessor:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
SingletonInputs
:
)
->
SingletonInputs
:
"""Async version of {meth}`_prompt_to_llm_inputs`."""
"""
Async version of
[`_prompt_to_llm_inputs`][vllm.inputs.preprocess.InputPreprocessor._prompt_to_llm_inputs].
"""
parsed
=
parse_singleton_prompt
(
prompt
)
parsed
=
parse_singleton_prompt
(
prompt
)
if
parsed
[
"type"
]
==
"embeds"
:
if
parsed
[
"type"
]
==
"embeds"
:
...
@@ -644,7 +653,9 @@ class InputPreprocessor:
...
@@ -644,7 +653,9 @@ class InputPreprocessor:
)
->
EncoderDecoderInputs
:
)
->
EncoderDecoderInputs
:
"""
"""
For encoder/decoder models only:
For encoder/decoder models only:
Process an input prompt into an {class}`EncoderDecoderInputs` instance.
Process an input prompt into an
[`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance.
There are two types of input prompts:
There are two types of input prompts:
singleton prompts which carry only the
singleton prompts which carry only the
...
@@ -670,7 +681,8 @@ class InputPreprocessor:
...
@@ -670,7 +681,8 @@ class InputPreprocessor:
Returns:
Returns:
* {class}`EncoderDecoderInputs` instance
* [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
instance
"""
"""
encoder_inputs
:
SingletonInputs
encoder_inputs
:
SingletonInputs
decoder_inputs
:
Optional
[
SingletonInputs
]
decoder_inputs
:
Optional
[
SingletonInputs
]
...
@@ -710,7 +722,10 @@ class InputPreprocessor:
...
@@ -710,7 +722,10 @@ class InputPreprocessor:
prompt
:
PromptType
,
prompt
:
PromptType
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
tokenization_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
EncoderDecoderInputs
:
)
->
EncoderDecoderInputs
:
"""Async version of {meth}`_process_encoder_decoder_prompt`."""
"""
Async version of
[`_process_encoder_decoder_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_encoder_decoder_prompt].
"""
encoder_inputs
:
SingletonInputs
encoder_inputs
:
SingletonInputs
decoder_inputs
:
Optional
[
SingletonInputs
]
decoder_inputs
:
Optional
[
SingletonInputs
]
...
@@ -778,7 +793,8 @@ class InputPreprocessor:
...
@@ -778,7 +793,8 @@ class InputPreprocessor:
)
->
DecoderOnlyInputs
:
)
->
DecoderOnlyInputs
:
"""
"""
For decoder-only models:
For decoder-only models:
Process an input prompt into an {class}`DecoderOnlyInputs` instance.
Process an input prompt into a
[`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.
Arguments:
Arguments:
...
@@ -789,7 +805,7 @@ class InputPreprocessor:
...
@@ -789,7 +805,7 @@ class InputPreprocessor:
Returns:
Returns:
*
{class}`
DecoderOnlyInputs
`
instance
*
[`DecoderOnlyInputs`][vllm.inputs.data.
DecoderOnlyInputs
]
instance
"""
"""
prompt_comps
=
self
.
_prompt_to_llm_inputs
(
prompt_comps
=
self
.
_prompt_to_llm_inputs
(
...
@@ -812,7 +828,10 @@ class InputPreprocessor:
...
@@ -812,7 +828,10 @@ class InputPreprocessor:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
DecoderOnlyInputs
:
)
->
DecoderOnlyInputs
:
"""Async version of {meth}`_process_decoder_only_prompt`."""
"""
Async version of
[`_process_decoder_only_prompt`][vllm.inputs.preprocess.InputPreprocessor._process_decoder_only_prompt].
"""
prompt_comps
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt_comps
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
,
prompt
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
...
@@ -863,7 +882,10 @@ class InputPreprocessor:
...
@@ -863,7 +882,10 @@ class InputPreprocessor:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
return_mm_hashes
:
bool
=
False
,
return_mm_hashes
:
bool
=
False
,
)
->
ProcessorInputs
:
)
->
ProcessorInputs
:
"""Async version of {meth}`preprocess`."""
"""
Async version of
[`preprocess`][vllm.inputs.preprocess.InputPreprocessor.preprocess].
"""
if
self
.
model_config
.
is_encoder_decoder
:
if
self
.
model_config
.
is_encoder_decoder
:
assert
not
return_mm_hashes
,
(
assert
not
return_mm_hashes
,
(
"Multimodal hashes for encoder-decoder models should not be "
,
"Multimodal hashes for encoder-decoder models should not be "
,
...
...
vllm/inputs/registry.py
View file @
4eabe123
...
@@ -38,7 +38,7 @@ class InputContext:
...
@@ -38,7 +38,7 @@ class InputContext:
)
->
_C
:
)
->
_C
:
"""
"""
Get the HuggingFace configuration
Get the HuggingFace configuration
(
{class}
`transformers.PretrainedConfig`) of the model,
(`transformers.PretrainedConfig`) of the model,
additionally checking its type.
additionally checking its type.
Raises:
Raises:
...
@@ -79,7 +79,7 @@ class InputContext:
...
@@ -79,7 +79,7 @@ class InputContext:
)
->
_P
:
)
->
_P
:
"""
"""
Get the HuggingFace processor
Get the HuggingFace processor
(
{class}
`transformers.ProcessorMixin`) of the model,
(`transformers.ProcessorMixin`) of the model,
additionally checking its type.
additionally checking its type.
Raises:
Raises:
...
...
vllm/logger.py
View file @
4eabe123
...
@@ -68,22 +68,22 @@ class _VllmLogger(Logger):
...
@@ -68,22 +68,22 @@ class _VllmLogger(Logger):
"""
"""
Note:
Note:
This class is just to provide type information.
This class is just to provide type information.
We actually patch the methods directly on the
{class}
`logging.Logger`
We actually patch the methods directly on the
[
`logging.Logger`
][]
instance to avoid conflicting with other libraries such as
instance to avoid conflicting with other libraries such as
`intel_extension_for_pytorch.utils._logger`.
`intel_extension_for_pytorch.utils._logger`.
"""
"""
def
info_once
(
self
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
def
info_once
(
self
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
"""
"""
As
{meth}`info`
, but subsequent calls with
the same message
As
[`info`][logging.Logger.info]
, but subsequent calls with
are silently dropped.
the same message
are silently dropped.
"""
"""
_print_info_once
(
self
,
msg
,
*
args
)
_print_info_once
(
self
,
msg
,
*
args
)
def
warning_once
(
self
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
def
warning_once
(
self
,
msg
:
str
,
*
args
:
Hashable
)
->
None
:
"""
"""
As
{meth}
`warning`, but subsequent calls with
the same message
As
[
`warning`
][logging.Logger.warning]
, but subsequent calls with
are silently dropped.
the same message
are silently dropped.
"""
"""
_print_warning_once
(
self
,
msg
,
*
args
)
_print_warning_once
(
self
,
msg
,
*
args
)
...
...
vllm/logging_utils/dump_input.py
View file @
4eabe123
...
@@ -18,7 +18,7 @@ logger = init_logger(__name__)
...
@@ -18,7 +18,7 @@ logger = init_logger(__name__)
def
prepare_object_to_dump
(
obj
)
->
str
:
def
prepare_object_to_dump
(
obj
)
->
str
:
if
isinstance
(
obj
,
str
):
if
isinstance
(
obj
,
str
):
return
"'{obj}'"
# Double quotes
return
f
"'
{
obj
}
'"
# Double quotes
elif
isinstance
(
obj
,
dict
):
elif
isinstance
(
obj
,
dict
):
dict_str
=
', '
.
join
({
f
'
{
str
(
k
)
}
:
{
prepare_object_to_dump
(
v
)
}
'
\
dict_str
=
', '
.
join
({
f
'
{
str
(
k
)
}
:
{
prepare_object_to_dump
(
v
)
}
'
\
for
k
,
v
in
obj
.
items
()})
for
k
,
v
in
obj
.
items
()})
...
@@ -42,9 +42,9 @@ def prepare_object_to_dump(obj) -> str:
...
@@ -42,9 +42,9 @@ def prepare_object_to_dump(obj) -> str:
return
obj
.
anon_repr
()
return
obj
.
anon_repr
()
elif
hasattr
(
obj
,
'__dict__'
):
elif
hasattr
(
obj
,
'__dict__'
):
items
=
obj
.
__dict__
.
items
()
items
=
obj
.
__dict__
.
items
()
dict_str
=
','
.
join
([
f
'
{
str
(
k
)
}
=
{
prepare_object_to_dump
(
v
)
}
'
\
dict_str
=
',
'
.
join
([
f
'
{
str
(
k
)
}
=
{
prepare_object_to_dump
(
v
)
}
'
\
for
k
,
v
in
items
])
for
k
,
v
in
items
])
return
(
f
"
{
type
(
obj
).
__name__
}
(
{
dict_str
}
)"
)
return
f
"
{
type
(
obj
).
__name__
}
(
{
dict_str
}
)"
else
:
else
:
# Hacky way to make sure we can serialize the object in JSON format
# Hacky way to make sure we can serialize the object in JSON format
try
:
try
:
...
...
vllm/lora/models.py
View file @
4eabe123
...
@@ -3,11 +3,11 @@
...
@@ -3,11 +3,11 @@
import
copy
import
copy
import
math
import
math
import
os
import
os
import
re
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
import
regex
as
re
import
safetensors.torch
import
safetensors.torch
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -29,6 +29,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
...
@@ -29,6 +29,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules
,
get_supported_lora_modules
,
is_regex_target_modules
,
is_regex_target_modules
,
parse_fine_tuned_lora_name
,
replace_submodule
)
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models
import
SupportsLoRA
,
supports_multimodal
from
vllm.model_executor.models.interfaces
import
is_pooling_model
from
vllm.model_executor.models.interfaces
import
is_pooling_model
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
...
@@ -185,19 +186,19 @@ class LoRAModel(AdapterModel):
...
@@ -185,19 +186,19 @@ class LoRAModel(AdapterModel):
@
classmethod
@
classmethod
def
from_local_checkpoint
(
def
from_local_checkpoint
(
cls
,
cls
,
lora_dir
:
str
,
lora_dir
:
str
,
expected_lora_modules
:
list
[
str
],
expected_lora_modules
:
list
[
str
],
peft_helper
:
PEFTHelper
,
peft_helper
:
PEFTHelper
,
*
,
*
,
lora_model_id
:
Optional
[
int
]
=
None
,
lora_model_id
:
Optional
[
int
]
=
None
,
device
:
str
=
"cuda"
,
device
:
str
=
"cuda"
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
target_embedding_padding
:
Optional
[
int
]
=
None
,
target_embedding_padding
:
Optional
[
int
]
=
None
,
embedding_modules
:
Optional
[
dict
[
str
,
str
]]
=
None
,
embedding_modules
:
Optional
[
dict
[
str
,
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
list
[
str
]]
=
None
,
embedding_padding_modules
:
Optional
[
list
[
str
]]
=
None
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
,
weights_mapper
:
Optional
[
WeightsMapper
]
=
None
,
)
->
"LoRAModel"
:
tensorizer_config_dict
:
Optional
[
dict
]
=
None
)
->
"LoRAModel"
:
"""Create a LoRAModel from a local checkpoint.
"""Create a LoRAModel from a local checkpoint.
Args:
Args:
...
@@ -219,10 +220,36 @@ class LoRAModel(AdapterModel):
...
@@ -219,10 +220,36 @@ class LoRAModel(AdapterModel):
lora_dir
,
"new_embeddings.safetensors"
)
lora_dir
,
"new_embeddings.safetensors"
)
new_embeddings_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
new_embeddings_bin_file_path
=
os
.
path
.
join
(
lora_dir
,
"new_embeddings.bin"
)
"new_embeddings.bin"
)
tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
unexpected_modules
:
list
[
Union
[
list
[
str
],
str
]]
=
[]
def
check_unexpected_modules
(
modules
:
dict
):
for
lora_module
in
modules
.
keys
():
# noqa
module_name
,
_
,
_
=
parse_fine_tuned_lora_name
(
lora_module
,
weights_mapper
)
part_name
=
module_name
.
split
(
"."
)[
-
1
]
if
part_name
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module_name
)
if
unexpected_modules
:
raise
ValueError
(
f
"While loading
{
lora_dir
}
, expected"
f
" target modules in
{
expected_lora_modules
}
"
f
" but received
{
unexpected_modules
}
."
f
" Please verify that the loaded LoRA module is correct"
)
unexpected_modules
:
list
[
Union
[
list
[
str
],
str
]]
if
tensorizer_config_dict
:
if
os
.
path
.
isfile
(
lora_tensor_path
):
from
tensorizer
import
TensorDeserializer
tensors
:
dict
[
str
,
torch
.
Tensor
]
=
{}
tensorizer_config
=
TensorizerConfig
(
**
tensorizer_config_dict
)
lora_tensor_path
=
os
.
path
.
join
(
tensorizer_config
.
tensorizer_dir
,
"adapter_model.tensors"
)
tensorizer_args
=
tensorizer_config
.
_construct_tensorizer_args
()
tensors
=
TensorDeserializer
(
lora_tensor_path
,
dtype
=
tensorizer_config
.
dtype
,
**
tensorizer_args
.
deserializer_params
)
check_unexpected_modules
(
tensors
)
elif
os
.
path
.
isfile
(
lora_tensor_path
):
# Find unexpected modules.
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in peft if you have target_modules A, B, C and C does not exist
...
@@ -232,20 +259,8 @@ class LoRAModel(AdapterModel):
...
@@ -232,20 +259,8 @@ class LoRAModel(AdapterModel):
unexpected_modules
=
[]
unexpected_modules
=
[]
with
safetensors
.
safe_open
(
lora_tensor_path
,
with
safetensors
.
safe_open
(
lora_tensor_path
,
framework
=
"pt"
)
as
f
:
# type: ignore
framework
=
"pt"
)
as
f
:
# type: ignore
for
lora_module
in
f
.
keys
():
# noqa
module_name
,
_
,
_
=
parse_fine_tuned_lora_name
(
lora_module
,
weights_mapper
)
part_name
=
module_name
.
split
(
"."
)[
-
1
]
if
part_name
not
in
expected_lora_modules
:
unexpected_modules
.
append
(
module_name
)
if
unexpected_modules
:
raise
ValueError
(
f
"While loading
{
lora_dir
}
, expected"
f
" target modules in
{
expected_lora_modules
}
"
f
" but received
{
unexpected_modules
}
."
f
" Please verify that the loaded LoRA module is correct"
)
# Load tensors if there are only expected modules.
# Load tensors if there are only expected modules.
check_unexpected_modules
(
f
)
for
module
in
f
.
keys
():
# noqa
for
module
in
f
.
keys
():
# noqa
tensors
[
module
]
=
f
.
get_tensor
(
module
)
tensors
[
module
]
=
f
.
get_tensor
(
module
)
elif
os
.
path
.
isfile
(
lora_bin_file_path
):
elif
os
.
path
.
isfile
(
lora_bin_file_path
):
...
...
Prev
1
…
21
22
23
24
25
26
27
28
29
…
34
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment