Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dcb5624a
Commit
dcb5624a
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-dev
parents
55880ca2
ba41cc90
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
567 additions
and
657 deletions
+567
-657
vllm/entrypoints/cli/main.py
vllm/entrypoints/cli/main.py
+2
-0
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+62
-36
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+48
-34
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+31
-22
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+12
-12
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+104
-16
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+2
-2
vllm/entrypoints/openai/serving_engine.py
vllm/entrypoints/openai/serving_engine.py
+13
-1
vllm/entrypoints/openai/serving_models.py
vllm/entrypoints/openai/serving_models.py
+70
-0
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+1
-0
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+17
-0
vllm/env_override.py
vllm/env_override.py
+15
-2
vllm/envs.py
vllm/envs.py
+40
-16
vllm/executor/uniproc_executor.py
vllm/executor/uniproc_executor.py
+2
-2
vllm/forward_context.py
vllm/forward_context.py
+23
-0
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+3
-5
vllm/inputs/data.py
vllm/inputs/data.py
+2
-159
vllm/inputs/preprocess.py
vllm/inputs/preprocess.py
+14
-52
vllm/inputs/registry.py
vllm/inputs/registry.py
+23
-298
vllm/lora/resolver.py
vllm/lora/resolver.py
+83
-0
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
vllm/entrypoints/cli/main.py
View file @
dcb5624a
...
...
@@ -5,6 +5,7 @@ import signal
import
sys
import
vllm.entrypoints.cli.benchmark.main
import
vllm.entrypoints.cli.collect_env
import
vllm.entrypoints.cli.openai
import
vllm.entrypoints.cli.serve
import
vllm.version
...
...
@@ -15,6 +16,7 @@ CMD_MODULES = [
vllm
.
entrypoints
.
cli
.
openai
,
vllm
.
entrypoints
.
cli
.
serve
,
vllm
.
entrypoints
.
cli
.
benchmark
.
main
,
vllm
.
entrypoints
.
cli
.
collect_env
,
]
...
...
vllm/entrypoints/launcher.py
View file @
dcb5624a
...
...
@@ -12,9 +12,11 @@ from fastapi import FastAPI, Request, Response
from
vllm
import
envs
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.multiprocessing
import
MQEngineDeadError
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.ssl
import
SSLCertRefresher
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_process_using_port
from
vllm.v1.engine.exceptions
import
EngineDeadError
,
EngineGenerateError
logger
=
init_logger
(
__name__
)
...
...
@@ -40,6 +42,8 @@ async def serve_http(app: FastAPI,
loop
=
asyncio
.
get_running_loop
()
watchdog_task
=
loop
.
create_task
(
watchdog_loop
(
server
,
app
.
state
.
engine_client
))
server_task
=
loop
.
create_task
(
server
.
serve
(
sockets
=
[
sock
]
if
sock
else
None
))
...
...
@@ -52,6 +56,7 @@ async def serve_http(app: FastAPI,
def
signal_handler
()
->
None
:
# prevents the uvicorn signal handler to exit early
server_task
.
cancel
()
watchdog_task
.
cancel
()
if
ssl_cert_refresher
:
ssl_cert_refresher
.
stop
()
...
...
@@ -73,48 +78,69 @@ async def serve_http(app: FastAPI,
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
logger
.
info
(
"Shutting down FastAPI HTTP server."
)
return
server
.
shutdown
()
finally
:
watchdog_task
.
cancel
()
async
def
watchdog_loop
(
server
:
uvicorn
.
Server
,
engine
:
EngineClient
):
"""
# Watchdog task that runs in the background, checking
# for error state in the engine. Needed to trigger shutdown
# if an exception arises is StreamingResponse() generator.
"""
VLLM_WATCHDOG_TIME_S
=
5.0
while
True
:
await
asyncio
.
sleep
(
VLLM_WATCHDOG_TIME_S
)
terminate_if_errored
(
server
,
engine
)
def
terminate_if_errored
(
server
:
uvicorn
.
Server
,
engine
:
EngineClient
):
"""
See discussions here on shutting down a uvicorn server
https://github.com/encode/uvicorn/discussions/1103
In this case we cannot await the server shutdown here
because handler must first return to close the connection
for this request.
"""
engine_errored
=
engine
.
errored
and
not
engine
.
is_running
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
and
engine_errored
:
server
.
should_exit
=
True
def
_add_shutdown_handlers
(
app
:
FastAPI
,
server
:
uvicorn
.
Server
)
->
None
:
"""Adds handlers for fatal errors that should crash the server"""
"""
VLLM V1 AsyncLLM catches exceptions and returns
only two types: EngineGenerateError and EngineDeadError.
EngineGenerateError is raised by the per request generate()
method. This error could be request specific (and therefore
recoverable - e.g. if there is an error in input processing).
EngineDeadError is raised by the background output_handler
method. This error is global and therefore not recoverable.
We register these @app.exception_handlers to return nice
responses to the end user if they occur and shut down if needed.
See https://fastapi.tiangolo.com/tutorial/handling-errors/
for more details on how exception handlers work.
If an exception is encountered in a StreamingResponse
generator, the exception is not raised, since we already sent
a 200 status. Rather, we send an error message as the next chunk.
Since the exception is not raised, this means that the server
will not automatically shut down. Instead, we use the watchdog
background task for check for errored state.
"""
@
app
.
exception_handler
(
RuntimeError
)
async
def
runtime_error_handler
(
request
:
Request
,
__
):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine
=
request
.
app
.
state
.
engine_client
if
(
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
and
engine
.
errored
and
not
engine
.
is_running
):
logger
.
fatal
(
"AsyncLLMEngine has failed, terminating server "
"process"
)
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server
.
should_exit
=
True
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
@
app
.
exception_handler
(
AsyncEngineDeadError
)
async
def
async_engine_dead_handler
(
_
,
__
):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
:
logger
.
fatal
(
"AsyncLLMEngine is already dead, terminating server "
"process"
)
server
.
should_exit
=
True
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
@
app
.
exception_handler
(
MQEngineDeadError
)
async
def
mq_e
ngine
_d
ead
_handler
(
_
,
__
):
"""Kill the server if the mq engine is already dead. It will
not handle any further requests."""
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
:
logger
.
fatal
(
"MQLLMEngine is already dead, terminating
server
"
"process"
)
server
.
should_exit
=
True
@
app
.
exception_handler
(
E
ngine
D
ead
Error
)
@
app
.
exception_handler
(
EngineGenerateError
)
async
def
runtime_exception_handler
(
request
:
Request
,
__
):
terminate_if_errored
(
server
=
server
,
engine
=
request
.
app
.
state
.
engine_client
,
)
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
vllm/entrypoints/llm.py
View file @
dcb5624a
...
...
@@ -40,7 +40,6 @@ from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind
,
SamplingParams
)
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
)
...
...
@@ -118,7 +117,7 @@ class LLM:
disable_async_output_proc: Disable async output processing.
This may result in lower performance.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
...
...
@@ -252,11 +251,15 @@ class LLM:
self
.
request_counter
=
Counter
()
self
.
default_sampling_params
:
Union
[
dict
[
str
,
Any
],
None
]
=
None
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
return
self
.
llm_engine
.
get_tokenizer_group
(
TokenizerGroup
).
tokenizer
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
AnyTokenizer
:
return
self
.
llm_engine
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
def
set_tokenizer
(
self
,
tokenizer
:
AnyTokenizer
)
->
None
:
tokenizer_group
=
self
.
llm_engine
.
get_tokenizer_group
(
TokenizerGroup
)
tokenizer_group
=
self
.
llm_engine
.
get_tokenizer_group
()
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
...
...
@@ -520,11 +523,9 @@ class LLM:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
params: The beam search parameters.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
# TODO: how does beam search work together with length penalty,
# frequency, penalty, and stopping criteria, etc.?
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
temperature
=
params
.
temperature
...
...
@@ -536,15 +537,18 @@ class LLM:
tokenizer
.
eos_token_id
,
length_penalty
)
# TODO - fix handling of multimodal data for beam search; we pass it
# through in the async version on the abstract EngineClient, but not
# here.
if
any
(
"multi_modal_data"
in
prompt
and
prompt
[
"multi_modal_data"
]
is
not
None
for
prompt
in
prompts
):
logger
.
warning
(
"Multimodal data appears to have been provided, but is not"
" currently being passed through in LLM.beam_search()!"
)
def
create_tokens_prompt_from_beam
(
beam
:
BeamSearchSequence
)
->
TokensPrompt
:
token_prompt_kwargs
:
TokensPrompt
=
{
"prompt_token_ids"
:
beam
.
tokens
}
if
beam
.
multi_modal_data
is
not
None
:
token_prompt_kwargs
[
"multi_modal_data"
]
=
beam
.
multi_modal_data
if
beam
.
mm_processor_kwargs
is
not
None
:
token_prompt_kwargs
[
"mm_processor_kwargs"
]
=
beam
.
mm_processor_kwargs
return
TokensPrompt
(
**
token_prompt_kwargs
)
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
...
...
@@ -556,11 +560,20 @@ class LLM:
instances
:
list
[
BeamSearchInstance
]
=
[]
for
prompt
in
prompts
:
# Add multimodal processor kwargs & data
mm_kwargs
=
{}
if
"multi_modal_data"
in
prompt
:
mm_kwargs
[
"multi_modal_data"
]
=
prompt
[
"multi_modal_data"
]
if
"mm_processor_kwargs"
in
prompt
:
mm_kwargs
[
"mm_processor_kwargs"
]
=
prompt
[
"mm_processor_kwargs"
]
if
is_token_prompt
(
prompt
):
prompt_tokens
=
prompt
[
"prompt_token_ids"
]
else
:
prompt_tokens
=
tokenizer
.
encode
(
prompt
[
"prompt"
])
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
))
instances
.
append
(
BeamSearchInstance
(
prompt_tokens
,
logprobs
=
None
,
**
mm_kwargs
))
for
_
in
range
(
max_tokens
):
all_beams
:
list
[
BeamSearchSequence
]
=
list
(
...
...
@@ -575,8 +588,7 @@ class LLM:
break
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
for
beam
in
all_beams
create_tokens_prompt_from_beam
(
beam
)
for
beam
in
all_beams
]
# only runs for one step
...
...
@@ -602,7 +614,10 @@ class LLM:
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
logprob_obj
.
logprob
,
multi_modal_data
=
current_beam
.
multi_modal_data
,
mm_processor_kwargs
=
current_beam
.
mm_processor_kwargs
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
...
...
@@ -701,7 +716,7 @@ class LLM:
cast
(
list
[
ChatCompletionMessageParam
],
messages
)
]
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
(
lora_request
)
model_config
=
self
.
llm_engine
.
get_model_config
()
resolved_content_format
=
resolve_chat_template_content_format
(
chat_template
,
...
...
@@ -724,9 +739,8 @@ class LLM:
content_format
=
resolved_content_format
,
)
prompt_data
:
Union
[
str
,
list
[
int
]]
if
isinstance
(
tokenizer
,
MistralTokenizer
):
prompt_
data
=
apply_mistral_chat_template
(
prompt_
token_ids
=
apply_mistral_chat_template
(
tokenizer
,
messages
=
msgs
,
chat_template
=
chat_template
,
...
...
@@ -735,7 +749,7 @@ class LLM:
continue_final_message
=
continue_final_message
,
)
else
:
prompt_
data
=
apply_hf_chat_template
(
prompt_
str
=
apply_hf_chat_template
(
tokenizer
,
trust_remote_code
=
model_config
.
trust_remote_code
,
conversation
=
conversation
,
...
...
@@ -744,12 +758,12 @@ class LLM:
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
)
# Special tokens are already included in chat templates so
# should not be added by the tokenizer in this case.
prompt_token_ids
=
tokenizer
.
encode
(
prompt_str
,
add_special_tokens
=
False
)
prompt
:
Union
[
TokensPrompt
,
TextPrompt
]
if
is_list_of
(
prompt_data
,
int
):
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_data
)
else
:
prompt
=
TextPrompt
(
prompt
=
prompt_data
)
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
)
if
mm_data
is
not
None
:
prompt
[
"multi_modal_data"
]
=
mm_data
...
...
@@ -1048,8 +1062,6 @@ class LLM:
if
len
(
encoded_output_1
)
==
1
:
encoded_output_1
=
encoded_output_1
*
len
(
encoded_output_2
)
scores
:
list
[
PoolingRequestOutput
]
=
[]
scores
=
_cosine_similarity
(
tokenizer
=
tokenizer
,
embed_1
=
encoded_output_1
,
embed_2
=
encoded_output_2
)
...
...
@@ -1384,7 +1396,9 @@ class LLM:
grammar
=
guided_options
.
guided_grammar
,
json_object
=
guided_options
.
guided_json_object
,
backend
=
guided_options
.
guided_decoding_backend
,
whitespace_pattern
=
guided_options
.
guided_whitespace_pattern
)
whitespace_pattern
=
guided_options
.
guided_whitespace_pattern
,
structural_tag
=
guided_options
.
structural_tag
,
)
return
params
def
_run_engine
(
...
...
vllm/entrypoints/openai/api_server.py
View file @
dcb5624a
...
...
@@ -30,7 +30,7 @@ from starlette.routing import Mount
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
from
vllm.config
import
Model
Config
from
vllm.config
import
Vllm
Config
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
# type: ignore
from
vllm.engine.multiprocessing.client
import
MQLLMEngineClient
...
...
@@ -310,32 +310,33 @@ def mount_metrics(app: FastAPI):
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
prometheus_client
import
(
CollectorRegistry
,
make_asgi_app
,
from
prometheus_client
import
(
REGISTRY
,
CollectorRegistry
,
make_asgi_app
,
multiprocess
)
from
prometheus_fastapi_instrumentator
import
Instrumentator
registry
=
REGISTRY
prometheus_multiproc_dir_path
=
os
.
getenv
(
"PROMETHEUS_MULTIPROC_DIR"
,
None
)
if
prometheus_multiproc_dir_path
is
not
None
:
logger
.
debug
(
"vLLM to use %s as PROMETHEUS_MULTIPROC_DIR"
,
prometheus_multiproc_dir_path
)
registry
=
CollectorRegistry
()
multiprocess
.
MultiProcessCollector
(
registry
)
Instrumentator
(
excluded_handlers
=
[
"/metrics"
,
"/health"
,
"/load"
,
"/ping"
,
"/version"
,
],
registry
=
registry
,
).
add
().
instrument
(
app
).
expose
(
app
)
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
(
registry
=
registry
))
else
:
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
Instrumentator
(
excluded_handlers
=
[
"/metrics"
,
"/health"
,
"/load"
,
"/ping"
,
"/version"
,
"/server_info"
,
],
registry
=
registry
,
).
add
().
instrument
(
app
).
expose
(
app
)
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
(
registry
=
registry
))
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
"^/metrics(?P<path>.*)$"
)
...
...
@@ -687,6 +688,11 @@ TASK_HANDLERS: dict[str, dict[str, tuple]] = {
if
envs
.
VLLM_SERVER_DEV_MODE
:
@
router
.
get
(
"/server_info"
)
async
def
show_server_info
(
raw_request
:
Request
):
server_info
=
{
"vllm_config"
:
str
(
raw_request
.
app
.
state
.
vllm_config
)}
return
JSONResponse
(
content
=
server_info
)
@
router
.
post
(
"/reset_prefix_cache"
)
async
def
reset_prefix_cache
(
raw_request
:
Request
):
"""
...
...
@@ -875,7 +881,8 @@ def build_app(args: Namespace) -> FastAPI:
section
async
for
section
in
response
.
body_iterator
]
response
.
body_iterator
=
iterate_in_threadpool
(
iter
(
response_body
))
logger
.
info
(
"response_body={%s}"
,
response_body
[
0
].
decode
())
logger
.
info
(
"response_body={%s}"
,
response_body
[
0
].
decode
()
if
response_body
else
None
)
return
response
for
middleware
in
args
.
middleware
:
...
...
@@ -894,7 +901,7 @@ def build_app(args: Namespace) -> FastAPI:
async
def
init_app_state
(
engine_client
:
EngineClient
,
model
_config
:
Model
Config
,
vllm
_config
:
Vllm
Config
,
state
:
State
,
args
:
Namespace
,
)
->
None
:
...
...
@@ -915,6 +922,8 @@ async def init_app_state(
state
.
engine_client
=
engine_client
state
.
log_stats
=
not
args
.
disable_log_stats
state
.
vllm_config
=
vllm_config
model_config
=
vllm_config
.
model_config
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
if
resolved_chat_template
is
not
None
:
...
...
@@ -1069,8 +1078,8 @@ async def run_server(args, **uvicorn_kwargs) -> None:
async
with
build_async_engine_client
(
args
)
as
engine_client
:
app
=
build_app
(
args
)
model
_config
=
await
engine_client
.
get_
model
_config
()
await
init_app_state
(
engine_client
,
model
_config
,
app
.
state
,
args
)
vllm
_config
=
await
engine_client
.
get_
vllm
_config
()
await
init_app_state
(
engine_client
,
vllm
_config
,
app
.
state
,
args
)
def
_listen_addr
(
a
:
str
)
->
str
:
if
is_valid_ipv6_address
(
a
):
...
...
vllm/entrypoints/openai/cli_args.py
View file @
dcb5624a
...
...
@@ -11,7 +11,7 @@ import ssl
from
collections.abc
import
Sequence
from
typing
import
Optional
,
Union
,
get_args
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
optional_type
from
vllm.entrypoints.chat_utils
import
(
ChatTemplateContentFormatOption
,
validate_chat_template
)
from
vllm.entrypoints.openai.serving_models
import
(
LoRAModulePath
,
...
...
@@ -79,7 +79,7 @@ class PromptAdapterParserAction(argparse.Action):
def
make_arg_parser
(
parser
:
FlexibleArgumentParser
)
->
FlexibleArgumentParser
:
parser
.
add_argument
(
"--host"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
help
=
"Host name."
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"Port number."
)
...
...
@@ -108,13 +108,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default
=
[
"*"
],
help
=
"Allowed headers."
)
parser
.
add_argument
(
"--api-key"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
help
=
"If provided, the server will require this key "
"to be presented in the header."
)
parser
.
add_argument
(
"--lora-modules"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
nargs
=
'+'
,
action
=
LoRAParserAction
,
...
...
@@ -126,14 +126,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"
\"
base_model_name
\"
:
\"
id
\"
}``"
)
parser
.
add_argument
(
"--prompt-adapters"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
nargs
=
'+'
,
action
=
PromptAdapterParserAction
,
help
=
"Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified."
)
parser
.
add_argument
(
"--chat-template"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
help
=
"The file path to the chat template, "
"or the template in single-line form "
...
...
@@ -151,20 +151,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'similar to OpenAI schema. '
'Example: ``[{"type": "text", "text": "Hello world!"}]``'
)
parser
.
add_argument
(
"--response-role"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
"assistant"
,
help
=
"The role name to return if "
"``request.add_generation_prompt=true``."
)
parser
.
add_argument
(
"--ssl-keyfile"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
help
=
"The file path to the SSL key file."
)
parser
.
add_argument
(
"--ssl-certfile"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
help
=
"The file path to the SSL cert file."
)
parser
.
add_argument
(
"--ssl-ca-certs"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
help
=
"The CA certificates file."
)
parser
.
add_argument
(
...
...
@@ -180,13 +180,13 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
)
parser
.
add_argument
(
"--root-path"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
None
,
help
=
"FastAPI root_path when app is behind a path based routing proxy."
)
parser
.
add_argument
(
"--middleware"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
action
=
"append"
,
default
=
[],
help
=
"Additional ASGI middleware to apply to the app. "
...
...
vllm/entrypoints/openai/protocol.py
View file @
dcb5624a
...
...
@@ -2,6 +2,7 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import
json
import
re
import
time
from
argparse
import
Namespace
...
...
@@ -139,12 +140,30 @@ class JsonSchemaResponseFormat(OpenAIBaseModel):
strict
:
Optional
[
bool
]
=
None
class
StructuralTag
(
OpenAIBaseModel
):
begin
:
str
# schema is the field, but that causes conflicts with pydantic so
# instead use structural_tag_schema with an alias
structural_tag_schema
:
Optional
[
dict
[
str
,
Any
]]
=
Field
(
default
=
None
,
alias
=
"schema"
)
end
:
str
class
StructuralTagResponseFormat
(
OpenAIBaseModel
):
type
:
Literal
[
"structural_tag"
]
structures
:
list
[
StructuralTag
]
triggers
:
list
[
str
]
class
ResponseFormat
(
OpenAIBaseModel
):
# type must be "json_schema", "json_object" or "text"
# type must be "json_schema", "json_object"
,
or "text"
type
:
Literal
[
"text"
,
"json_object"
,
"json_schema"
]
json_schema
:
Optional
[
JsonSchemaResponseFormat
]
=
None
AnyResponseFormat
=
Union
[
ResponseFormat
,
StructuralTagResponseFormat
]
class
StreamOptions
(
OpenAIBaseModel
):
include_usage
:
Optional
[
bool
]
=
True
continuous_usage_stats
:
Optional
[
bool
]
=
False
...
...
@@ -227,7 +246,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
max_completion_tokens
:
Optional
[
int
]
=
None
n
:
Optional
[
int
]
=
1
presence_penalty
:
Optional
[
float
]
=
0.0
response_format
:
Optional
[
ResponseFormat
]
=
None
response_format
:
Optional
[
Any
ResponseFormat
]
=
None
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
stop
:
Optional
[
Union
[
str
,
list
[
str
]]]
=
Field
(
default_factory
=
list
)
stream
:
Optional
[
bool
]
=
False
...
...
@@ -340,6 +359,11 @@ class ChatCompletionRequest(OpenAIBaseModel):
description
=
(
"If specified, the output will follow the context free grammar."
),
)
structural_tag
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
"If specified, the output will follow the structural tag schema."
),
)
guided_decoding_backend
:
Optional
[
str
]
=
Field
(
default
=
None
,
description
=
(
...
...
@@ -476,6 +500,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_schema
=
self
.
response_format
.
json_schema
assert
json_schema
is
not
None
self
.
guided_json
=
json_schema
.
json_schema
elif
self
.
response_format
.
type
==
"structural_tag"
:
structural_tag
=
self
.
response_format
assert
structural_tag
is
not
None
and
isinstance
(
structural_tag
,
StructuralTagResponseFormat
)
s_tag_obj
=
structural_tag
.
model_dump
(
by_alias
=
True
)
self
.
structural_tag
=
json
.
dumps
(
s_tag_obj
)
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
...
...
@@ -485,6 +515,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
json_object
=
guided_json_object
,
backend
=
self
.
guided_decoding_backend
,
whitespace_pattern
=
self
.
guided_whitespace_pattern
,
structural_tag
=
self
.
structural_tag
,
)
return
SamplingParams
.
from_optional
(
...
...
@@ -742,12 +773,13 @@ class CompletionRequest(OpenAIBaseModel):
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."
),
)
response_format
:
Optional
[
ResponseFormat
]
=
Field
(
response_format
:
Optional
[
Any
ResponseFormat
]
=
Field
(
default
=
None
,
description
=
(
"Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
"{'type': 'text' } is supported."
),
description
=
(
"Similar to chat completion, this parameter specifies the format "
"of output. Only {'type': 'json_object'}, {'type': 'json_schema'}"
", {'type': 'structural_tag'}, or {'type': 'text' } is supported."
),
)
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
default
=
None
,
...
...
@@ -1577,14 +1609,6 @@ class TranscriptionRequest(OpenAIBaseModel):
"""
## TODO (varun) : Support if set to 0, certain thresholds are met !!
temperature
:
float
=
Field
(
default
=
0.0
)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
timestamp_granularities
:
list
[
Literal
[
"word"
,
"segment"
]]
=
Field
(
alias
=
"timestamp_granularities[]"
,
default
=
[])
...
...
@@ -1596,6 +1620,7 @@ class TranscriptionRequest(OpenAIBaseModel):
timestamps incurs additional latency.
"""
# doc: begin-transcription-extra-params
stream
:
Optional
[
bool
]
=
False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
...
...
@@ -1604,10 +1629,51 @@ class TranscriptionRequest(OpenAIBaseModel):
# Flattened stream option to simplify form data.
stream_include_usage
:
Optional
[
bool
]
=
False
stream_continuous_usage_stats
:
Optional
[
bool
]
=
False
# doc: end-transcription-extra-params
# doc: begin-transcription-sampling-params
temperature
:
float
=
Field
(
default
=
0.0
)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
top_p
:
Optional
[
float
]
=
None
"""Enables nucleus (top-p) sampling, where tokens are selected from the
smallest possible set whose cumulative probability exceeds `p`.
"""
top_k
:
Optional
[
int
]
=
None
"""Limits sampling to the `k` most probable tokens at each step."""
min_p
:
Optional
[
float
]
=
None
"""Filters out tokens with a probability lower than `min_p`, ensuring a
minimum likelihood threshold during sampling.
"""
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
"""The seed to use for sampling."""
frequency_penalty
:
Optional
[
float
]
=
0.0
"""The frequency penalty to use for sampling."""
repetition_penalty
:
Optional
[
float
]
=
None
"""The repetition penalty to use for sampling."""
presence_penalty
:
Optional
[
float
]
=
0.0
"""The presence penalty to use for sampling."""
# doc: end-transcription-sampling-params
# Default sampling parameters for transcription requests.
_DEFAULT_SAMPLING_PARAMS
:
dict
=
{
"temperature"
:
0
,
"repetition_penalty"
:
1.0
,
"temperature"
:
1.0
,
"top_p"
:
1.0
,
"top_k"
:
-
1
,
"min_p"
:
0.0
,
}
def
to_sampling_params
(
...
...
@@ -1619,13 +1685,35 @@ class TranscriptionRequest(OpenAIBaseModel):
if
default_sampling_params
is
None
:
default_sampling_params
=
{}
# Default parameters
if
(
temperature
:
=
self
.
temperature
)
is
None
:
temperature
=
default_sampling_params
.
get
(
"temperature"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"temperature"
])
if
(
top_p
:
=
self
.
top_p
)
is
None
:
top_p
=
default_sampling_params
.
get
(
"top_p"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"top_p"
])
if
(
top_k
:
=
self
.
top_k
)
is
None
:
top_k
=
default_sampling_params
.
get
(
"top_k"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"top_k"
])
if
(
min_p
:
=
self
.
min_p
)
is
None
:
min_p
=
default_sampling_params
.
get
(
"min_p"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"min_p"
])
if
(
repetition_penalty
:
=
self
.
repetition_penalty
)
is
None
:
repetition_penalty
=
default_sampling_params
.
get
(
"repetition_penalty"
,
self
.
_DEFAULT_SAMPLING_PARAMS
[
"repetition_penalty"
])
return
SamplingParams
.
from_optional
(
temperature
=
temperature
,
max_tokens
=
max_tokens
,
seed
=
self
.
seed
,
top_p
=
top_p
,
top_k
=
top_k
,
min_p
=
min_p
,
frequency_penalty
=
self
.
frequency_penalty
,
repetition_penalty
=
repetition_penalty
,
presence_penalty
=
self
.
presence_penalty
,
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
)
...
...
vllm/entrypoints/openai/run_batch.py
View file @
dcb5624a
...
...
@@ -12,7 +12,7 @@ import torch
from
prometheus_client
import
start_http_server
from
tqdm
import
tqdm
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
optional_type
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.logger
import
RequestLogger
,
logger
# yapf: disable
...
...
@@ -61,7 +61,7 @@ def parse_args():
"to the output URL."
,
)
parser
.
add_argument
(
"--response-role"
,
type
=
nullable_
str
,
type
=
optional_type
(
str
)
,
default
=
"assistant"
,
help
=
"The role name to return if "
"`request.add_generation_prompt=True`."
)
...
...
vllm/entrypoints/openai/serving_engine.py
View file @
dcb5624a
...
...
@@ -10,6 +10,7 @@ from fastapi import Request
from
pydantic
import
Field
from
starlette.datastructures
import
Headers
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
from
vllm.engine.protocol
import
EngineClient
# yapf conflicts with isort for this block
...
...
@@ -125,18 +126,29 @@ class OpenAIServing:
self
,
request
:
AnyRequest
,
)
->
Optional
[
ErrorResponse
]:
error_response
=
None
if
self
.
_is_model_supported
(
request
.
model
):
return
None
if
request
.
model
in
[
lora
.
lora_name
for
lora
in
self
.
models
.
lora_requests
]:
return
None
if
envs
.
VLLM_ALLOW_RUNTIME_LORA_UPDATING
and
request
.
model
and
(
load_result
:
=
await
self
.
models
.
resolve_lora
(
request
.
model
)):
if
isinstance
(
load_result
,
LoRARequest
):
return
None
if
isinstance
(
load_result
,
ErrorResponse
)
and
\
load_result
.
code
==
HTTPStatus
.
BAD_REQUEST
.
value
:
error_response
=
load_result
if
request
.
model
in
[
prompt_adapter
.
prompt_adapter_name
for
prompt_adapter
in
self
.
models
.
prompt_adapter_requests
]:
return
None
return
self
.
create_error_response
(
return
error_response
or
self
.
create_error_response
(
message
=
f
"The model `
{
request
.
model
}
` does not exist."
,
err_type
=
"NotFoundError"
,
status_code
=
HTTPStatus
.
NOT_FOUND
)
...
...
vllm/entrypoints/openai/serving_models.py
View file @
dcb5624a
...
...
@@ -2,6 +2,8 @@
import
json
import
pathlib
from
asyncio
import
Lock
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
http
import
HTTPStatus
from
typing
import
Optional
,
Union
...
...
@@ -15,6 +17,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
UnloadLoRAAdapterRequest
)
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.resolver
import
LoRAResolver
,
LoRAResolverRegistry
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.utils
import
AtomicCounter
...
...
@@ -63,11 +66,19 @@ class OpenAIServingModels:
self
.
base_model_paths
=
base_model_paths
self
.
max_model_len
=
model_config
.
max_model_len
self
.
engine_client
=
engine_client
self
.
model_config
=
model_config
self
.
static_lora_modules
=
lora_modules
self
.
lora_requests
:
list
[
LoRARequest
]
=
[]
self
.
lora_id_counter
=
AtomicCounter
(
0
)
self
.
lora_resolvers
:
list
[
LoRAResolver
]
=
[]
for
lora_resolver_name
in
LoRAResolverRegistry
.
get_supported_resolvers
(
):
self
.
lora_resolvers
.
append
(
LoRAResolverRegistry
.
get_resolver
(
lora_resolver_name
))
self
.
lora_resolver_lock
:
dict
[
str
,
Lock
]
=
defaultdict
(
Lock
)
self
.
prompt_adapter_requests
=
[]
if
prompt_adapters
is
not
None
:
for
i
,
prompt_adapter
in
enumerate
(
prompt_adapters
,
start
=
1
):
...
...
@@ -234,6 +245,65 @@ class OpenAIServingModels:
return
None
async def resolve_lora(
        self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
    """Attempt to resolve a LoRA adapter using available resolvers.

    Resolution for a given ``lora_name`` runs while holding
    ``self.lora_resolver_lock[lora_name]``. Each resolver in
    ``self.lora_resolvers`` is tried in order; the first adapter that the
    engine accepts via ``add_lora`` is recorded in ``self.lora_requests``
    and returned.

    Args:
        lora_name: Name/identifier of the LoRA adapter

    Returns:
        LoRARequest if found and loaded successfully.
        ErrorResponse (404) if no resolver finds the adapter.
        ErrorResponse (400) if adapter(s) are found but none load.

    Raises:
        BaseException: Non-``Exception`` errors raised while loading
            (e.g. ``asyncio.CancelledError``, ``KeyboardInterrupt``) are
            not treated as load failures and propagate to the caller.
    """
    async with self.lora_resolver_lock[lora_name]:
        # First check if this LoRA is already loaded
        for existing in self.lora_requests:
            if existing.lora_name == lora_name:
                return existing

        base_model_name = self.model_config.model
        unique_id = self.lora_id_counter.inc(1)
        found_adapter = False

        # Try to resolve using available resolvers
        for resolver in self.lora_resolvers:
            lora_request = await resolver.resolve_lora(
                base_model_name, lora_name)

            if lora_request is not None:
                found_adapter = True
                lora_request.lora_int_id = unique_id

                try:
                    await self.engine_client.add_lora(lora_request)
                    self.lora_requests.append(lora_request)
                    logger.info(
                        "Resolved and loaded LoRA adapter '%s' using %s",
                        lora_name, resolver.__class__.__name__)
                    return lora_request
                # Only ordinary errors count as a failed load. Catching
                # BaseException here would also swallow task cancellation
                # and interpreter shutdown and keep trying resolvers.
                except Exception as e:
                    logger.warning(
                        "Failed to load LoRA '%s' resolved by %s: %s. "
                        "Trying next resolver.", lora_name,
                        resolver.__class__.__name__, e)
                    continue

        if found_adapter:
            # An adapter was found, but all attempts to load it failed.
            return create_error_response(
                message=(f"LoRA adapter '{lora_name}' was found "
                         "but could not be loaded."),
                err_type="BadRequestError",
                status_code=HTTPStatus.BAD_REQUEST)
        else:
            # No adapter was found
            return create_error_response(
                message=f"LoRA adapter {lora_name} does not exist",
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)
def
create_error_response
(
message
:
str
,
...
...
vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
View file @
dcb5624a
...
...
@@ -27,6 +27,7 @@ logger = init_logger(__name__)
@
ToolParserManager
.
register_module
(
"llama3_json"
)
@
ToolParserManager
.
register_module
(
"llama4_json"
)
class
Llama3JsonToolParser
(
ToolParser
):
"""
Tool call parser for Llama 3.1 models intended for use with the
...
...
vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
View file @
dcb5624a
...
...
@@ -38,6 +38,10 @@ class MistralToolCall(ToolCall):
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return
""
.
join
(
choices
(
ALPHANUMERIC
,
k
=
9
))
@
staticmethod
def
is_valid_id
(
id
:
str
)
->
bool
:
return
id
.
isalnum
()
and
len
(
id
)
==
9
@
ToolParserManager
.
register_module
(
"mistral"
)
class
MistralToolParser
(
ToolParser
):
...
...
@@ -70,6 +74,19 @@ class MistralToolParser(ToolParser):
"Mistral Tool Parser could not locate the tool call token in "
"the tokenizer!"
)
def adjust_request(
        self, request: ChatCompletionRequest) -> ChatCompletionRequest:
    """Disable special-token skipping when tool calls may be produced.

    The request is modified in place and returned. It is left untouched
    when ``self.model_tokenizer`` is a ``MistralTokenizer``, when the
    request has no tools, or when ``tool_choice`` is ``'none'``.
    """
    # Note: we don't want skip_special_tokens=False
    # with MistralTokenizer as it is incompatible
    if isinstance(self.model_tokenizer, MistralTokenizer):
        return request

    may_call_tools = request.tools and request.tool_choice != 'none'
    if may_call_tools:
        # Do not skip special tokens when using chat template
        # with Mistral parser as TOOL_CALL token is needed
        # for tool detection.
        request.skip_special_tokens = False
    return request
def
extract_tool_calls
(
self
,
model_output
:
str
,
...
...
vllm/env_override.py
View file @
dcb5624a
...
...
@@ -8,8 +8,21 @@ import torch
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.
# see https://github.com/NVIDIA/nccl/issues/1234
os
.
environ
[
'NCCL_CUMEM_ENABLE'
]
=
'0'
if
not
os
.
path
.
exists
(
'/dev/nvidia-caps-imex-channels'
):
# normally, we disable NCCL_CUMEM_ENABLE because it
# will cost 1~2 GiB GPU memory with cudagraph+allreduce,
# see https://github.com/NVIDIA/nccl/issues/1234
# for more details.
# However, NCCL requires NCCL_CUMEM_ENABLE to work with
# multi-node NVLink, typically on GB200-NVL72 systems.
# The ultimate way to detect multi-node NVLink is to use
# NVML APIs, which are too expensive to call here.
# As an approximation, we check the existence of
# /dev/nvidia-caps-imex-channels, used by
# multi-node NVLink to communicate across nodes.
# This will still cost some GPU memory, but it is worthwhile
# because we can get very fast cross-node bandwidth with NVLink.
os
.
environ
[
'NCCL_CUMEM_ENABLE'
]
=
'0'
# see https://github.com/vllm-project/vllm/pull/15951
# it avoids unintentional cuda initialization from torch.cuda.is_available()
...
...
vllm/envs.py
View file @
dcb5624a
...
...
@@ -86,10 +86,12 @@ if TYPE_CHECKING:
VLLM_DISABLED_KERNELS
:
list
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
True
VLLM_ROCM_USE_AITER
:
bool
=
False
VLLM_ROCM_USE_AITER_PAGED_ATTN
:
bool
=
False
VLLM_ROCM_USE_AITER_LINEAR
:
bool
=
True
VLLM_ROCM_USE_AITER_MOE
:
bool
=
True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE
:
bool
=
False
VLLM_ROCM_USE_AITER_RMSNORM
:
bool
=
True
VLLM_ROCM_USE_AITER_MLA
:
bool
=
True
VLLM_ROCM_USE_SKINNY_GEMM
:
bool
=
True
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ROCM_MOE_PADDING
:
bool
=
True
VLLM_ROCM_CUSTOM_PAGED_ATTN
:
bool
=
True
...
...
@@ -107,6 +109,7 @@ if TYPE_CHECKING:
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
VLLM_CUDART_SO_PATH
:
Optional
[
str
]
=
None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
:
bool
=
True
VLLM_HPU_USE_DELAYED_SAMPLING
:
bool
=
False
VLLM_DP_RANK
:
int
=
0
VLLM_DP_RANK_LOCAL
:
int
=
-
1
VLLM_DP_SIZE
:
int
=
1
...
...
@@ -114,10 +117,10 @@ if TYPE_CHECKING:
VLLM_DP_MASTER_PORT
:
int
=
0
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_V0_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION
:
bool
=
False
VLLM_TPU_BUCKET_PADDING_GAP
:
int
=
0
VLLM_USE_DEEP_GEMM
:
bool
=
False
VLLM_XGRAMMAR_CACHE_MB
:
int
=
0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD
:
int
=
256
def
get_default_cache_root
():
...
...
@@ -586,6 +589,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# Whether to use aiter paged attention.
# By default is disabled.
"VLLM_ROCM_USE_AITER_PAGED_ATTN"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_PAGED_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)
...
...
@@ -599,18 +608,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_MOE"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Whether to use aiter block scaled moe kernel.
# By default this is disabled.
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE"
,
"false"
).
lower
()
in
(
"true"
,
"1"
)),
# use aiter rms norm op if aiter ops are enabled.
"VLLM_ROCM_USE_AITER_RMSNORM"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_RMSNORM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Whether to use aiter mla ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MLA"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_AITER_MLA"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# use rocm skinny gemms
"VLLM_ROCM_USE_SKINNY_GEMM"
:
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_USE_SKINNY_GEMM"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Pad the fp8 weights to 256 bytes for ROCm
"VLLM_ROCM_FP8_PADDING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ROCM_FP8_PADDING"
,
"1"
))),
...
...
@@ -700,6 +712,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
os
.
environ
.
get
(
"VLLM_CONTIGUOUS_PA"
,
"true"
).
lower
()
in
(
"1"
,
"true"
),
# Use delayed sampling for HPU to reduce host cpu overhead
# between each step.
"VLLM_HPU_USE_DELAYED_SAMPLING"
:
lambda
:
os
.
environ
.
get
(
"VLLM_DELAYED_SAMPLING"
,
"false"
).
lower
()
in
(
"1"
,
"true"
),
# Rank of the process in the data parallel setting
"VLLM_DP_RANK"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_DP_RANK"
,
"0"
)),
...
...
@@ -745,11 +763,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V0_USE_OUTLINES_CACHE"
:
lambda
:
os
.
environ
.
get
(
"VLLM_V0_USE_OUTLINES_CACHE"
,
"0"
)
==
"1"
,
# If set, disables TPU-specific optimization for top-k & top-p sampling
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
:
lambda
:
bool
(
int
(
os
.
environ
[
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
]))
if
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"
in
os
.
environ
else
None
,
# Gap between padding buckets for the forward pass. For example, with a
# gap of 8, we will run the forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP"
:
...
...
@@ -765,6 +778,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
# It can be changed with this variable if needed for some reason.
"VLLM_XGRAMMAR_CACHE_MB"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_XGRAMMAR_CACHE_MB"
,
"512"
)),
# Control the threshold for msgspec to use 'zero copy' for
# serialization/deserialization of tensors. Tensors below
# this limit will be encoded into the msgpack buffer, and
# tensors above will instead be sent via a separate message.
# While the sending side still actually copies the tensor
# in all cases, on the receiving side, tensors above this
# limit will actually be zero-copy decoded.
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_MSGPACK_ZERO_COPY_THRESHOLD"
,
"256"
)),
}
# end-env-vars-definition
...
...
@@ -803,7 +826,7 @@ def compute_hash() -> str:
variables, ensure that it is included in the factors list if
it affects the computation graph. For example, different values
of VLLM_PP_LAYER_PARTITION will generate different computation
graphs, so it is included in the factors list. The env vars that
graphs, so it is included in the factors list. The env vars that
affect the choice of different kernels or attention backends should
also be included in the factors list.
"""
...
...
@@ -832,6 +855,7 @@ def compute_hash() -> str:
if
key
in
environment_variables
:
factorize
(
key
)
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
()).
hexdigest
()
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
return
hash_str
\ No newline at end of file
vllm/executor/uniproc_executor.py
View file @
dcb5624a
...
...
@@ -34,13 +34,13 @@ class UniProcExecutor(ExecutorBase):
if
len
(
device_info
)
>
1
:
local_rank
=
int
(
device_info
[
1
])
rank
=
0
is_driver_worker
=
True
kwargs
=
dict
(
vllm_config
=
self
.
vllm_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
(
not
self
.
parallel_config
)
or
(
rank
%
self
.
parallel_config
.
tensor_parallel_size
==
0
),
is_driver_worker
=
is_driver_worker
,
)
self
.
collective_rpc
(
"init_worker"
,
args
=
([
kwargs
],
))
self
.
collective_rpc
(
"init_device"
)
...
...
vllm/forward_context.py
View file @
dcb5624a
...
...
@@ -11,6 +11,10 @@ import torch.distributed as dist
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
...
...
@@ -98,6 +102,17 @@ def set_forward_context(attn_metadata: Any,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
dp_metadata
=
dp_metadata
)
# KVConnector: trigger (possibly async) load before forward.
# Each attn layer will block until the reading is complete.
trigger_kv_transfer
=
(
attn_metadata
is
not
None
and
has_kv_transfer_group
()
and
is_v1_kv_transfer_group
())
if
trigger_kv_transfer
:
kv_connector
=
get_kv_transfer_group
()
assert
isinstance
(
kv_connector
,
KVConnectorBase_V1
)
kv_connector
.
start_load_kv
(
_forward_context
)
try
:
yield
finally
:
...
...
@@ -133,4 +148,12 @@ def set_forward_context(attn_metadata: Any,
logger
.
info
((
"Batchsize forward time stats "
"(batchsize, count, median_time(ms)): %s"
),
forward_stats
)
# KVConnector: each attn layer triggers (possibly async) save.
# Ensure all those operations complete before forward() is done.
if
trigger_kv_transfer
:
kv_connector
=
get_kv_transfer_group
()
assert
isinstance
(
kv_connector
,
KVConnectorBase_V1
)
kv_connector
.
wait_for_save
()
_forward_context
=
prev_context
vllm/inputs/__init__.py
View file @
dcb5624a
...
...
@@ -2,10 +2,9 @@
from
.data
import
(
DecoderOnlyInputs
,
EncoderDecoderInputs
,
ExplicitEncoderDecoderPrompt
,
ProcessorInputs
,
PromptType
,
SingletonInputs
,
SingletonInputsAdapter
,
SingletonPrompt
,
TextPrompt
,
TokenInputs
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
token_inputs
,
zip_enc_dec_prompts
)
SingletonInputs
,
SingletonPrompt
,
TextPrompt
,
TokenInputs
,
TokensPrompt
,
build_explicit_enc_dec_prompt
,
to_enc_dec_tuple_list
,
token_inputs
,
zip_enc_dec_prompts
)
from
.registry
import
(
DummyData
,
InputContext
,
InputProcessingContext
,
InputRegistry
)
...
...
@@ -27,7 +26,6 @@ __all__ = [
"EncoderDecoderInputs"
,
"ProcessorInputs"
,
"SingletonInputs"
,
"SingletonInputsAdapter"
,
"build_explicit_enc_dec_prompt"
,
"to_enc_dec_tuple_list"
,
"zip_enc_dec_prompts"
,
...
...
vllm/inputs/data.py
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
functools
import
cached_property
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Literal
,
Optional
,
Union
,
cast
import
torch
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
,
assert_never
from
typing_extensions
import
NotRequired
,
TypedDict
,
TypeVar
if
TYPE_CHECKING
:
from
vllm.multimodal
import
(
MultiModalDataDict
,
MultiModalKwargs
,
MultiModalPlaceholderDict
)
from
vllm.multimodal.inputs
import
MultiModalInputs
from
vllm.multimodal.inputs
import
MultiModalDataDict
,
MultiModalInputs
class
TextPrompt
(
TypedDict
):
...
...
@@ -147,46 +141,11 @@ class TokenInputs(TypedDict):
The original prompt text corresponding to the token IDs, if available.
"""
multi_modal_data
:
NotRequired
[
"MultiModalDataDict"
]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
multi_modal_inputs
:
NotRequired
[
"MultiModalKwargs"
]
"""
Optional multi-modal inputs to pass to the model,
if the model supports it.
"""
multi_modal_placeholders
:
NotRequired
[
"MultiModalPlaceholderDict"
]
"""
Placeholder ranges for the multi-modal data.
"""
multi_modal_hashes
:
NotRequired
[
list
[
str
]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs
:
NotRequired
[
dict
[
str
,
Any
]]
"""
Optional multi-modal processor kwargs to be forwarded to the
multimodal input mapper & processor. Note that if multiple modalities
have registered mappers etc for the model being considered, we attempt
to pass the mm_processor_kwargs to each of them.
"""
def
token_inputs
(
prompt_token_ids
:
list
[
int
],
token_type_ids
:
Optional
[
list
[
int
]]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_inputs
:
Optional
[
"MultiModalKwargs"
]
=
None
,
multi_modal_hashes
:
Optional
[
list
[
str
]]
=
None
,
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
)
->
TokenInputs
:
"""Construct :class:`TokenInputs` from optional values."""
inputs
=
TokenInputs
(
type
=
"token"
,
prompt_token_ids
=
prompt_token_ids
)
...
...
@@ -195,16 +154,6 @@ def token_inputs(
inputs
[
"prompt"
]
=
prompt
if
token_type_ids
is
not
None
:
inputs
[
"token_type_ids"
]
=
token_type_ids
if
multi_modal_data
is
not
None
:
inputs
[
"multi_modal_data"
]
=
multi_modal_data
if
multi_modal_inputs
is
not
None
:
inputs
[
"multi_modal_inputs"
]
=
multi_modal_inputs
if
multi_modal_hashes
is
not
None
:
inputs
[
"multi_modal_hashes"
]
=
multi_modal_hashes
if
multi_modal_placeholders
is
not
None
:
inputs
[
"multi_modal_placeholders"
]
=
multi_modal_placeholders
if
mm_processor_kwargs
is
not
None
:
inputs
[
"mm_processor_kwargs"
]
=
mm_processor_kwargs
return
inputs
...
...
@@ -237,112 +186,6 @@ A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.
"""
@dataclass
class SingletonInputsAdapter:
    """
    Uniform accessor layer over a :class:`SingletonInputs` mapping.

    Every property dispatches on ``inputs["type"]`` (``"token"`` or
    ``"multimodal"``) and reads the matching key, falling back to an
    empty default where one is given.
    """

    # The wrapped inputs mapping that all properties read from.
    inputs: SingletonInputs

    @cached_property
    def prompt(self) -> Optional[str]:
        data = self.inputs

        if data["type"] in ("token", "multimodal"):
            return data.get("prompt")

        assert_never(data)  # type: ignore[arg-type]

    @cached_property
    def prompt_token_ids(self) -> list[int]:
        data = self.inputs

        if data["type"] in ("token", "multimodal"):
            return data.get("prompt_token_ids", [])

        assert_never(data)  # type: ignore[arg-type]

    @cached_property
    def token_type_ids(self) -> list[int]:
        data = self.inputs

        if data["type"] in ("token", "multimodal"):
            return data.get("token_type_ids", [])

        assert_never(data)  # type: ignore[arg-type]

    @cached_property
    def prompt_embeds(self) -> Optional[torch.Tensor]:
        data = self.inputs

        # Neither supported input type carries prompt embeddings.
        if data["type"] in ("token", "multimodal"):
            return None

        assert_never(data)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_data(self) -> "MultiModalDataDict":
        data = self.inputs
        kind = data["type"]

        if kind == "token":
            return data.get("multi_modal_data", {})
        if kind == "multimodal":
            return data.get("mm_kwargs", {})

        assert_never(data)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_inputs(self) -> Union[dict, "MultiModalKwargs"]:
        data = self.inputs
        kind = data["type"]

        if kind == "token":
            return data.get("multi_modal_inputs", {})
        if kind == "multimodal":
            return data.get("mm_kwargs", {})

        assert_never(data)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_hashes(self) -> list[str]:
        data = self.inputs
        kind = data["type"]

        if kind == "token":
            return data.get("multi_modal_hashes", [])
        if kind == "multimodal":
            # only the case when we use MultiModalInputs
            return data.get("mm_hashes", [])  # type: ignore[return-value]

        assert_never(data)  # type: ignore[arg-type]

    @cached_property
    def multi_modal_placeholders(self) -> "MultiModalPlaceholderDict":
        data = self.inputs
        kind = data["type"]

        if kind == "token":
            return data.get("multi_modal_placeholders", {})
        if kind == "multimodal":
            return data.get("mm_placeholders", {})

        assert_never(data)  # type: ignore[arg-type]

    @cached_property
    def mm_processor_kwargs(self) -> dict[str, Any]:
        data = self.inputs
        kind = data["type"]

        if kind == "token":
            return data.get("mm_processor_kwargs", {})
        if kind == "multimodal":
            return {}

        assert_never(data)  # type: ignore[arg-type]
ProcessorInputs
=
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]
"""
The inputs to :data:`vllm.inputs.InputProcessor`.
...
...
vllm/inputs/preprocess.py
View file @
dcb5624a
...
...
@@ -13,7 +13,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
)
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.transformers_utils.tokenizer_group
import
Base
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
.data
import
(
DecoderOnlyInputs
,
EncoderDecoderInputs
,
ProcessorInputs
,
PromptType
,
SingletonInputs
,
SingletonPrompt
,
token_inputs
)
...
...
@@ -27,7 +27,7 @@ class InputPreprocessor:
def
__init__
(
self
,
model_config
:
ModelConfig
,
tokenizer
:
Optional
[
Base
TokenizerGroup
],
tokenizer
:
Optional
[
TokenizerGroup
],
mm_registry
:
MultiModalRegistry
=
MULTIMODAL_REGISTRY
,
)
->
None
:
super
().
__init__
()
...
...
@@ -36,7 +36,7 @@ class InputPreprocessor:
self
.
tokenizer
=
tokenizer
self
.
mm_registry
=
mm_registry
def
get_tokenizer_group
(
self
)
->
Base
TokenizerGroup
:
def
get_tokenizer_group
(
self
)
->
TokenizerGroup
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
"You cannot pass text prompts when "
"`skip_tokenizer_init` is True"
)
...
...
@@ -223,28 +223,6 @@ class InputPreprocessor:
lora_request
=
lora_request
,
add_special_tokens
=
add_special_tokens
)
def _can_process_multimodal(self) -> bool:
    """Report whether the multi-modal registry has a processor for the model.

    Returns:
        The result of ``self.mm_registry.has_processor(model_config)``.

    Raises:
        ValueError: If ``model_config.is_multimodal_model`` is falsy.
    """
    cfg = self.model_config
    if not cfg.is_multimodal_model:
        raise ValueError("Your model does not support multi-modal inputs")

    # Interim measure so we can handle models that have yet to be
    # updated to use the new multi-modal processor
    has_processor = self.mm_registry.has_processor(cfg)
    if has_processor:
        return has_processor

    from vllm.model_executor.models.registry import _VLLM_MODELS

    # Warn only when none of the architectures is listed in _VLLM_MODELS.
    listed_in_vllm = any(arch in _VLLM_MODELS for arch in cfg.architectures)
    if not listed_in_vllm:
        logger.warning_once(
            "Your model uses the legacy input pipeline, which will be "
            "removed in an upcoming release. "
            "Please upgrade to the new multi-modal processing pipeline "
            "(https://docs.vllm.ai/en/latest/design/mm_processing.html)")

    return has_processor
def
_process_multimodal
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
...
...
@@ -258,8 +236,7 @@ class InputPreprocessor:
returning the corresponding token IDs and metadata.
"""
# At the moment one model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal
# input.
# initialized without a tokenizer while using also multi-modal input
if
not
self
.
tokenizer
:
tokenizer
=
object
()
# Dummy
else
:
...
...
@@ -285,8 +262,7 @@ class InputPreprocessor:
)
->
MultiModalInputs
:
"""Async version of :meth:`_process_multimodal`."""
# At the moment one model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal
# input.
# initialized without a tokenizer while using also multi-modal input
if
not
self
.
tokenizer
:
tokenizer
=
object
()
# Dummy
else
:
...
...
@@ -343,7 +319,7 @@ class InputPreprocessor:
multi_modal_data
=
tokens_content
.
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
tokens_content
.
get
(
"mm_processor_kwargs"
)
if
multi_modal_data
is
not
None
and
self
.
_can_process_multimodal
()
:
if
multi_modal_data
is
not
None
:
return
self
.
_process_multimodal
(
prompt_token_ids
,
multi_modal_data
,
...
...
@@ -355,8 +331,6 @@ class InputPreprocessor:
return
token_inputs
(
prompt_token_ids
=
prompt_token_ids
,
token_type_ids
=
token_type_ids
,
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
if
parsed
[
"type"
]
==
"text"
:
...
...
@@ -366,7 +340,7 @@ class InputPreprocessor:
multi_modal_data
=
text_content
.
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
text_content
.
get
(
"mm_processor_kwargs"
)
if
multi_modal_data
is
not
None
and
self
.
_can_process_multimodal
()
:
if
multi_modal_data
is
not
None
:
return
self
.
_process_multimodal
(
prompt_text
,
multi_modal_data
,
...
...
@@ -383,8 +357,6 @@ class InputPreprocessor:
return
token_inputs
(
prompt
=
prompt_text
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
assert_never
(
parsed
)
...
...
@@ -417,7 +389,7 @@ class InputPreprocessor:
multi_modal_data
=
tokens_content
.
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
tokens_content
.
get
(
"mm_processor_kwargs"
)
if
multi_modal_data
is
not
None
and
self
.
_can_process_multimodal
()
:
if
multi_modal_data
is
not
None
:
return
await
self
.
_process_multimodal_async
(
prompt_token_ids
,
multi_modal_data
,
...
...
@@ -426,11 +398,7 @@ class InputPreprocessor:
return_mm_hashes
=
return_mm_hashes
,
)
return
token_inputs
(
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
return
token_inputs
(
prompt_token_ids
=
prompt_token_ids
)
if
parsed
[
"type"
]
==
"text"
:
text_content
=
parsed
[
"content"
]
...
...
@@ -439,7 +407,7 @@ class InputPreprocessor:
multi_modal_data
=
text_content
.
get
(
"multi_modal_data"
)
mm_processor_kwargs
=
text_content
.
get
(
"mm_processor_kwargs"
)
if
multi_modal_data
is
not
None
and
self
.
_can_process_multimodal
()
:
if
multi_modal_data
is
not
None
:
return
await
self
.
_process_multimodal_async
(
prompt_text
,
multi_modal_data
,
...
...
@@ -456,8 +424,6 @@ class InputPreprocessor:
return
token_inputs
(
prompt
=
prompt_text
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
multi_modal_data
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
assert_never
(
parsed
)
...
...
@@ -594,15 +560,13 @@ class InputPreprocessor:
decoder_inputs
=
self
.
_prompt_to_llm_inputs
(
decoder_input
)
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
if
self
.
model_config
.
is_multimodal_model
:
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
inputs
=
self
.
_prompt_to_llm_inputs
(
prompt
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
if
self
.
model_config
.
is_multimodal_model
:
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
...
...
@@ -637,15 +601,13 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
if
self
.
model_config
.
is_multimodal_model
:
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
encoder_inputs
,
decoder_inputs
))
else
:
inputs
=
await
self
.
_prompt_to_llm_inputs_async
(
prompt
)
if
self
.
model_config
.
is_multimodal_model
and
(
self
.
_can_process_multimodal
()):
if
self
.
model_config
.
is_multimodal_model
:
# Encoder-Decoder Multimodal model
encoder_inputs
,
decoder_inputs
=
(
self
.
_separate_enc_dec_inputs_from_mm_processor_outputs
(
...
...
vllm/inputs/registry.py
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
import
functools
from
collections
import
UserDict
from
collections.abc
import
Mapping
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
NamedTuple
,
Optional
,
Protocol
,
Union
)
from
typing
import
TYPE_CHECKING
,
Any
,
NamedTuple
,
Optional
,
Union
from
torch
import
nn
from
transformers
import
BatchFeature
,
PretrainedConfig
,
ProcessorMixin
from
typing_extensions
import
TypeVar
,
assert_never
from
typing_extensions
import
TypeVar
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.processor
import
cached_processor_from_config
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
(
ClassRegistry
,
get_allowed_kwarg_only_overrides
,
resolve_mm_processor_kwargs
)
from
.data
import
ProcessorInputs
,
SingletonInputs
from
.parse
import
split_enc_dec_inputs
from
vllm.utils
import
resolve_mm_processor_kwargs
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
...
...
@@ -26,8 +16,6 @@ if TYPE_CHECKING:
MultiModalRegistry
)
from
vllm.sequence
import
SequenceData
logger
=
init_logger
(
__name__
)
_T
=
TypeVar
(
"_T"
)
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
_P
=
TypeVar
(
"_P"
,
bound
=
ProcessorMixin
,
default
=
ProcessorMixin
)
...
...
@@ -172,142 +160,23 @@ class InputProcessingContext(InputContext):
raise
RuntimeError
(
msg
)
from
exc
N
=
TypeVar
(
"N"
,
bound
=
type
[
nn
.
Module
])
class
DummyData
(
NamedTuple
):
"""Dummy data used for profiling."""
"""
Dummy data used for profiling.
Note: This is only used in V0.
"""
seq_data
:
"SequenceData"
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
class DummyDataFactory(Protocol):
    """Call signature expected of a dummy-data factory."""

    def __call__(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        **mm_processor_kwargs: Any,
    ) -> DummyData:
        """
        Build the dummy data that will be fed into the model.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.

        The :code:`mm_processor_kwargs` are overrides provided at
        initialization time to values in the config whose values
        may affect the number of tokens per instance.
        """
        ...
class _MultiModalCounts(UserDict[str, int]):
    """
    ``mm_counts`` wrapper whose failed lookups raise a ``KeyError`` that
    names the missing key and lists the keys that are available.
    """

    def __getitem__(self, key: str) -> int:
        try:
            value = super().__getitem__(key)
        except KeyError as exc:
            available = set(self.keys())
            msg = (f"There is no multi-modal plugin with the key: {key}. "
                   f"Available keys: {available}")
            raise KeyError(msg) from exc
        else:
            return value
InputProcessor
=
Callable
[[
InputContext
,
ProcessorInputs
],
ProcessorInputs
]
"""Preprocess the inputs to the model."""
class
InputRegistry
:
"""
A registry to dispatch data processing
according to the target model.
Note: This is only used in V0.
"""
def __init__(self) -> None:
    """Create the three empty per-model-class registries."""
    factory_registry = ClassRegistry[nn.Module, DummyDataFactory]

    # Dummy-data factories for decoder and encoder inputs.
    self._dummy_factories_by_model_type = factory_registry()
    self._dummy_encoder_factories_by_model_type = factory_registry()
    # Input processors registered per model class.
    self._input_processors_by_model_type = (
        ClassRegistry[nn.Module, InputProcessor]())
def _default_dummy_data_factory(
    self,
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> DummyData:
    """
    Fallback factory: dummy data standing in for the longest possible
    text input to the model. ``ctx`` and ``mm_counts`` are not used.

    Note:
        :data:`InputProcessor` is not applied to the dummy data.
    """
    # Avoid circular import
    from vllm.sequence import SequenceData

    seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
    return DummyData(seq_data)
def register_dummy_data(self, factory: DummyDataFactory):
    """
    Return a class decorator that registers ``factory`` as the dummy data
    factory of the decorated model class.

    During memory profiling, the provided function is invoked to create
    dummy data to be inputted into the model. The resulting memory usage
    should be an upper bound of what the model would use at inference time.
    """

    def decorate(model_cls: N) -> N:
        registry = self._dummy_factories_by_model_type
        already_registered = registry.contains(model_cls, strict=True)
        if already_registered:
            logger.warning(
                "Model class %s already has dummy data "
                "registered to %s. It is overwritten by the new one.",
                model_cls, self)

        registry[model_cls] = factory
        return model_cls

    return decorate
def _get_dummy_data_factory(self, model_cls: type[nn.Module]):
    """Look up the dummy data factory for ``model_cls``.

    Falls back to ``self._default_dummy_data_factory`` when the registry's
    ``get`` finds no entry.
    """
    registry = self._dummy_factories_by_model_type
    fallback = self._default_dummy_data_factory
    return registry.get(model_cls, fallback)
def register_dummy_encoder_data(self, factory: DummyDataFactory):
    """
    Return a class decorator that registers ``factory`` as the dummy
    encoder data factory of the decorated model class.

    This is similar to :meth:`~register_dummy_data`, but for encoder input.
    """

    def decorate(model_cls: N) -> N:
        registry = self._dummy_encoder_factories_by_model_type
        already_registered = registry.contains(model_cls, strict=True)
        if already_registered:
            logger.warning(
                "Model class %s already has dummy encoder data "
                "registered to %s. It is overwritten by the new one.",
                model_cls, self)

        registry[model_cls] = factory
        return model_cls

    return decorate
def _get_dummy_encoder_data_factory(self, model_cls: type[nn.Module]):
    """Look up the dummy encoder data factory for ``model_cls``.

    Falls back to ``self._default_dummy_data_factory`` when the registry's
    ``get`` finds no entry.
    """
    registry = self._dummy_encoder_factories_by_model_type
    fallback = self._default_dummy_data_factory
    return registry.get(model_cls, fallback)
def
dummy_data_for_profiling
(
self
,
model_config
:
"ModelConfig"
,
...
...
@@ -319,169 +188,25 @@ class InputRegistry:
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
Note:
This should be called after
:meth:`~MultiModalRegistry.init_mm_limits_per_prompt`.
"""
# Avoid circular import
from
vllm.model_executor.model_loader
import
get_model_architecture
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal.profiling
import
MultiModalProfiler
from
vllm.sequence
import
SequenceData
if
mm_registry
.
has_processor
(
model_config
):
processor
=
mm_registry
.
create_processor
(
model_config
,
disable_cache
=
True
)
profiler
=
MultiModalProfiler
(
processor
)
dummy_data_v1
=
(
profiler
.
get_encoder_dummy_data
(
seq_len
)
if
is_encoder_data
else
profiler
.
get_decoder_dummy_data
(
seq_len
))
_seq_data
=
SequenceData
.
from_seqs
(
dummy_data_v1
.
prompt_token_ids
)
# type: ignore[attr-defined]
dummy_data
=
DummyData
(
seq_data
=
_seq_data
,
multi_modal_data
=
getattr
(
dummy_data_v1
,
"multi_modal_data"
,
None
),
multi_modal_placeholders
=
getattr
(
dummy_data_v1
,
"multi_modal_placeholders"
,
None
),
)
else
:
model_cls
,
_
=
get_model_architecture
(
model_config
)
if
is_encoder_data
:
dummy_factory
=
self
.
_get_dummy_encoder_data_factory
(
model_cls
)
else
:
dummy_factory
=
self
.
_get_dummy_data_factory
(
model_cls
)
mm_counts
=
mm_registry
.
get_mm_limits_per_prompt
(
model_config
)
mm_processor_kwargs
=
get_allowed_kwarg_only_overrides
(
dummy_factory
,
overrides
=
model_config
.
mm_processor_kwargs
,
requires_kw_only
=
False
,
allow_var_kwargs
=
True
,
)
dummy_data
=
dummy_factory
(
InputContext
(
model_config
),
seq_len
,
_MultiModalCounts
(
mm_counts
),
**
mm_processor_kwargs
)
# Having more tokens is over-conservative but otherwise fine
num_tokens
=
dummy_data
.
seq_data
.
prompt_token_ids
if
len
(
num_tokens
)
<
seq_len
:
if
is_encoder_data
:
logger
.
warning_once
(
f
"Expected at least
{
seq_len
}
dummy encoder tokens for "
f
"profiling, but found
{
len
(
num_tokens
)
}
tokens instead."
)
else
:
raise
AssertionError
(
f
"Expected at least
{
seq_len
}
dummy tokens for profiling, "
f
"but found
{
len
(
num_tokens
)
}
tokens instead."
)
if
(
dummy_data
.
multi_modal_data
is
not
None
and
not
isinstance
(
dummy_data
.
multi_modal_data
,
MultiModalKwargs
)):
for
k
,
v
in
dummy_data
.
multi_modal_data
.
items
():
num_items
=
len
(
v
)
if
isinstance
(
v
,
list
)
else
1
num_expected
=
mm_counts
[
k
]
assert
num_items
>=
num_expected
,
(
f
"Expected at least
{
num_expected
}
dummy '
{
k
}
' instances "
f
"for profiling, but found
{
num_items
}
instances instead."
)
return
dummy_data
def _default_input_processor(
    self,
    ctx: InputContext,
    inputs: ProcessorInputs,
    **kwargs: object,
) -> ProcessorInputs:
    """Pass ``inputs`` through unchanged.

    ``ctx`` and any extra keyword arguments are accepted but not used.
    """
    return inputs
def
register_input_processor
(
self
,
processor
:
InputProcessor
):
"""
Register an input processor to a model class.
The provided function is invoked on each input to the model. This
happens before
:meth:`~vllm.multimodal.registry.MultiModalRegistry.map_input`.
"""
# The inner ``wrapper`` stores ``processor`` under ``model_cls`` in
# ``self._input_processors_by_model_type`` and returns ``model_cls``.
def
wrapper
(
model_cls
:
N
)
->
N
:
# Warn when this exact class (strict match) already has an entry.
if
self
.
_input_processors_by_model_type
.
contains
(
model_cls
,
strict
=
True
):
logger
.
warning
(
"Model class %s already has input processor "
"registered to %s. It is overwritten by the new one."
,
model_cls
,
self
)
self
.
_input_processors_by_model_type
[
model_cls
]
=
processor
return
model_cls
# NOTE(review): the statements from here up to the final ``return wrapper``
# read ``model_config`` and ``seq_len``, neither of which is a parameter of
# this method or of ``wrapper``. They look like added-side diff lines that
# were interleaved into this body -- confirm against the real file.
if
not
model_config
.
is_multimodal_model
:
seq_data
=
SequenceData
.
from_prompt_token_counts
((
0
,
seq_len
))
return
DummyData
(
seq_data
=
seq_data
)
return
wrapper
# Encoder dummy data does not contain multi-modal data
if
is_encoder_data
:
enc_data
=
mm_registry
.
get_encoder_dummy_data
(
model_config
,
seq_len
)
seq_data
=
SequenceData
.
from_seqs
(
enc_data
.
prompt_token_ids
)
return
DummyData
(
seq_data
=
seq_data
)
def
_get_model_input_processor
(
self
,
model_cls
:
type
[
nn
.
Module
]):
return
self
.
_input_processors_by_model_type
\
.
get
(
model_cls
,
self
.
_default_input_processor
)
def _ensure_mm_kwargs(
    self,
    inputs: SingletonInputs,
    mm_processor_kwargs: dict[str, Any],
):
    """Check, or backfill, the multimodal fields of ``inputs``.

    * ``"token"`` inputs: ``mm_processor_kwargs`` is stored under the
      ``"mm_processor_kwargs"`` key unless that key is already present.
    * ``"multimodal"`` inputs: asserts that an ``"mm_kwargs"`` key exists.
    * Any other ``"type"`` value is handed to ``assert_never``.
    """
    input_type = inputs["type"]

    if input_type == "token":
        # In case the input processor for that model fails to set it
        if "mm_processor_kwargs" not in inputs:
            inputs["mm_processor_kwargs"] = mm_processor_kwargs
        return

    if input_type == "multimodal":
        # Be more strict in V2
        assert "mm_kwargs" in inputs
        return

    assert_never(input_type)  # type: ignore[arg-type]
def
process_input
(
self
,
model_config
:
"ModelConfig"
,
inputs
:
ProcessorInputs
)
->
ProcessorInputs
:
"""
Apply an input processor to an instance of model inputs.
The model is identified by ``model_config``.
"""
# Avoid circular import
from
vllm.model_executor.model_loader
import
get_model_architecture
# Resolve the model class from the config, then the processor for that class.
model_cls
,
_
=
get_model_architecture
(
model_config
)
processor
=
self
.
_get_model_input_processor
(
model_cls
)
# Handle multimodal processor kwargs with priority:
# Inference kwargs -> Init kwargs -> {}
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs
=
resolve_mm_processor_kwargs
(
model_config
.
mm_processor_kwargs
,
inputs
.
get
(
"mm_processor_kwargs"
,
{}),
# type: ignore
processor
,
requires_kw_only
=
False
,
allow_var_kwargs
=
True
,
)
# NOTE(review): the ``dec_data`` assignment below reads ``mm_registry`` and
# ``seq_len``, which are not parameters of this method. It looks like an
# added-side diff line interleaved into this body -- confirm against the
# real file.
dec_data
=
mm_registry
.
get_decoder_dummy_data
(
model_config
,
seq_len
)
# Invoke the resolved processor on the inputs with the resolved kwargs.
# NOTE(review): the keyword-argument name below is split across two lines
# (``mm_processor_kwarg`` / ``s``), which looks like a page-scrape artifact.
processed_inputs
=
processor
(
InputContext
(
model_config
),
inputs
,
**
mm_processor_kwarg
s
,
# NOTE(review): the ``return DummyData(...)`` below sits inside the still-open
# ``processor(`` call and uses ``dec_data``; it also looks like interleaved
# added-side diff lines (with ``multi_modal_placeholder`` / ``s`` split the
# same way). The closing ``)`` appears to be shared by both -- confirm.
return
DummyData
(
seq_data
=
SequenceData
.
from_seqs
(
dec_data
.
prompt_token_ids
),
multi_modal_data
=
dec_data
.
multi_modal_data
,
multi_modal_placeholders
=
dec_data
.
multi_modal_placeholder
s
,
)
# Split into encoder/decoder parts and run ``_ensure_mm_kwargs`` on each
# part that is not None.
encoder_inputs
,
decoder_inputs
=
split_enc_dec_inputs
(
processed_inputs
)
if
encoder_inputs
is
not
None
:
self
.
_ensure_mm_kwargs
(
encoder_inputs
,
mm_processor_kwargs
)
if
decoder_inputs
is
not
None
:
self
.
_ensure_mm_kwargs
(
decoder_inputs
,
mm_processor_kwargs
)
return
processed_inputs
def create_input_processor(self, model_config: "ModelConfig"):
    """
    Create an input processor (see :meth:`process_input`) for a
    specific model.

    The result is ``self.process_input`` with ``model_config`` already
    bound as its first argument.
    """
    bound_method = self.process_input
    return functools.partial(bound_method, model_config)
vllm/lora/resolver.py
0 → 100644
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
typing
import
AbstractSet
,
Dict
,
Optional
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
# Module-level logger, created by ``vllm.logger.init_logger`` with this
# module's ``__name__``.
logger
=
init_logger
(
__name__
)
class LoRAResolver(ABC):
    """Interface for LoRA adapter resolvers.

    Subclasses supply the logic for locating and downloading LoRA adapters
    from various sources (e.g. S3, cloud storage, etc.).
    """

    @abstractmethod
    async def resolve_lora(
        self,
        base_model_name: str,
        lora_name: str,
    ) -> Optional[LoRARequest]:
        """Resolve and fetch the LoRA adapter identified by ``lora_name``.

        Implementations locate and download the adapter based on its name;
        they might fetch from a blob storage or other sources.

        Args:
            base_model_name: The name/identifier of the base model to resolve.
            lora_name: The name/identifier of the LoRA model to resolve.

        Returns:
            The resolved LoRA model information, or ``None`` if the LoRA
            model cannot be found.
        """
        ...
@dataclass
class _LoRAResolverRegistry:
    """Name-indexed collection of ``LoRAResolver`` instances."""

    # Maps each registered name to its resolver; a fresh dict per instance.
    resolvers: Dict[str, LoRAResolver] = field(default_factory=dict)

    def get_supported_resolvers(self) -> AbstractSet[str]:
        """Return the names of all registered resolvers."""
        registered_names = self.resolvers.keys()
        return registered_names

    def register_resolver(
        self,
        resolver_name: str,
        resolver: LoRAResolver,
    ) -> None:
        """Store ``resolver`` under ``resolver_name``.

        An existing entry with the same name is replaced after a warning
        is logged.

        Args:
            resolver_name: Name to register the resolver under.
            resolver: The LoRA resolver instance to register.
        """
        name_taken = resolver_name in self.resolvers
        if name_taken:
            logger.warning(
                "LoRA resolver %s is already registered, and will be "
                "overwritten by the new resolver instance %s.",
                resolver_name, resolver)

        self.resolvers[resolver_name] = resolver

    def get_resolver(self, resolver_name: str) -> LoRAResolver:
        """Return the resolver registered under ``resolver_name``.

        Args:
            resolver_name: Name of the resolver to get.

        Returns:
            The resolver instance.

        Raises:
            KeyError: If the resolver is not found in the registry.
        """
        if resolver_name in self.resolvers:
            return self.resolvers[resolver_name]

        # Unknown name: report what is available to help the caller.
        raise KeyError(
            f"LoRA resolver '{resolver_name}' not found. "
            f"Available resolvers: {list(self.resolvers.keys())}")
# Module-level registry instance, constructed with no arguments so its
# ``resolvers`` mapping starts out empty.
LoRAResolverRegistry
=
_LoRAResolverRegistry
()
Prev
1
…
18
19
20
21
22
23
24
25
26
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment