Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
465
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2160 additions
and
784 deletions
+2160
-784
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+394
-131
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+583
-174
vllm/engine/metrics.py
vllm/engine/metrics.py
+70
-102
vllm/engine/metrics_types.py
vllm/engine/metrics_types.py
+88
-0
vllm/engine/output_processor/interfaces.py
vllm/engine/output_processor/interfaces.py
+2
-3
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+2
-3
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+2
-4
vllm/engine/protocol.py
vllm/engine/protocol.py
+28
-11
vllm/entrypoints/api_server.py
vllm/entrypoints/api_server.py
+13
-7
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+122
-47
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+58
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+128
-47
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+142
-36
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+27
-4
vllm/entrypoints/openai/logits_processors.py
vllm/entrypoints/openai/logits_processors.py
+12
-9
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+68
-35
vllm/entrypoints/openai/rpc/__init__.py
vllm/entrypoints/openai/rpc/__init__.py
+10
-2
vllm/entrypoints/openai/rpc/client.py
vllm/entrypoints/openai/rpc/client.py
+287
-82
vllm/entrypoints/openai/rpc/server.py
vllm/entrypoints/openai/rpc/server.py
+87
-74
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+37
-11
No files found.
Too many changes to show.
To preserve performance only
465 of 465+
files are displayed.
Plain diff
Email patch
vllm/engine/async_llm_engine.py
View file @
af7f4372
import
asyncio
import
time
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
(
Async
It
erator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
typing
import
(
Any
,
Async
Gen
erator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
transformers
import
PreTrainedTokenizer
import
torch
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
from
vllm.config
import
(
DecodingConfig
,
EngineConfig
,
LoRAConfig
,
ModelConfig
,
...
...
@@ -12,19 +14,25 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.metrics
import
StatLoggerBase
from
vllm.engine.llm_engine
import
(
DecoderPromptComponents
,
LLMEngine
,
PromptComponents
)
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
,
ray
from
vllm.inputs
import
LLMInputs
,
PromptInputs
from
vllm.inputs
import
(
EncoderDecoderLLMInputs
,
LLMInputs
,
PromptInputs
,
SingletonPromptInputs
)
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceGroupMetadata
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
...
@@ -58,41 +66,61 @@ def _log_task_completion(task: asyncio.Task,
error_callback
(
exception
)
raise
AsyncEngineDeadError
(
"Task finished unexpectedly. This should never happen! "
"Please open an issue on Github. See stack trace above for the"
"Please open an issue on Github. See stack trace above for the
"
"actual cause."
)
from
e
STOP_ITERATION
=
Exception
()
# Sentinel
class
AsyncStream
:
"""A stream of RequestOutputs or EmbeddingRequestOutputs for a request
that can be iterated over asynchronously."""
that can be iterated over asynchronously
via an async generator
."""
def
__init__
(
self
,
request_id
:
str
)
->
None
:
def
__init__
(
self
,
request_id
:
str
,
cancel
:
Callable
[[
str
],
None
]
)
->
None
:
self
.
request_id
=
request_id
self
.
_cancel
=
cancel
self
.
_queue
:
asyncio
.
Queue
=
asyncio
.
Queue
()
self
.
_finished
=
False
def
put
(
self
,
item
:
Union
[
RequestOutput
,
EmbeddingRequestOutput
,
Exception
])
->
None
:
if
self
.
_finished
:
return
if
not
self
.
_finished
:
self
.
_queue
.
put_nowait
(
item
)
def
finish
(
self
)
->
None
:
self
.
_queue
.
put_nowait
(
StopAsyncIteration
())
def
finish
(
self
,
exception
:
Optional
[
Union
[
BaseException
,
Type
[
BaseException
]]]
=
None
,
)
->
None
:
if
not
self
.
_finished
:
self
.
_finished
=
True
self
.
_queue
.
put_nowait
(
exception
if
self
.
_is_raisable
(
exception
)
else
STOP_ITERATION
)
@
property
def
finished
(
self
)
->
bool
:
return
self
.
_finished
def
__aiter__
(
self
):
return
self
async
def
__anext__
(
self
)
->
Union
[
RequestOutput
,
EmbeddingRequestOutput
]:
async
def
generator
(
self
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
try
:
while
True
:
result
=
await
self
.
_queue
.
get
()
if
isinstance
(
result
,
Exception
):
if
self
.
_is_raisable
(
result
):
if
result
==
STOP_ITERATION
:
return
raise
result
return
result
yield
result
except
GeneratorExit
:
self
.
_cancel
(
self
.
request_id
)
raise
asyncio
.
CancelledError
from
None
@
staticmethod
def
_is_raisable
(
value
:
Any
):
return
isinstance
(
value
,
BaseException
)
or
\
(
isinstance
(
value
,
type
)
and
\
issubclass
(
value
,
BaseException
))
class
RequestTracker
:
...
...
@@ -100,7 +128,7 @@ class RequestTracker:
def
__init__
(
self
)
->
None
:
self
.
_request_streams
:
Dict
[
str
,
AsyncStream
]
=
{}
self
.
_
finish
ed_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_
abort
ed_requests
:
asyncio
.
Queue
[
str
]
=
asyncio
.
Queue
()
self
.
_new_requests
:
asyncio
.
Queue
[
Tuple
[
AsyncStream
,
dict
]]
=
asyncio
.
Queue
()
self
.
new_requests_event
=
asyncio
.
Event
()
...
...
@@ -117,12 +145,12 @@ class RequestTracker:
"""Propagate an exception to request streams
(all if request_id is None)."""
if
request_id
is
not
None
:
self
.
_request_streams
[
request_id
].
put
(
exc
)
self
.
abort_request
(
request_id
)
self
.
abort_request
(
request_id
,
exception
=
exc
)
else
:
for
rid
,
stream
in
self
.
_request_streams
.
items
():
stream
.
put
(
exc
)
self
.
abort_request
(
rid
)
# NB: tuple() used here because self.abort_request pops the stream
# out of self._request_streams, so we can't iterate on it directly
for
rid
in
tuple
(
self
.
_request_streams
.
keys
()):
self
.
abort_request
(
rid
,
exception
=
exc
)
def
process_request_output
(
self
,
request_output
:
Union
[
RequestOutput
,
...
...
@@ -131,26 +159,31 @@ class RequestTracker:
verbose
:
bool
=
False
)
->
None
:
"""Process a request output from the engine."""
request_id
=
request_output
.
request_id
finished
=
request_output
.
finished
if
finished
:
stream
=
self
.
_request_streams
.
pop
(
request_id
,
None
)
else
:
stream
=
self
.
_request_streams
.
get
(
request_id
)
# Guard against a KeyError which can occur if the request was aborted
# while the output was generated
if
(
stream
:
=
self
.
_request_streams
.
get
(
request_id
))
is
not
None
:
if
stream
is
not
None
:
stream
.
put
(
request_output
)
if
request_output
.
finished
:
if
verbose
:
if
finished
:
stream
.
finish
()
if
verbose
and
finished
:
logger
.
info
(
"Finished request %s."
,
request_id
)
self
.
abort_request
(
request_id
)
def
process_exception
(
self
,
request_id
:
str
,
exception
:
Exception
,
exception
:
Base
Exception
,
*
,
verbose
:
bool
=
False
)
->
None
:
"""Propagate an exception from the engine."""
self
.
_request_streams
[
request_id
].
put
(
exception
)
if
verbose
:
logger
.
info
(
"Finished request %s."
,
request_id
)
self
.
abort_request
(
request_id
)
self
.
abort_request
(
request_id
,
exception
=
exception
)
def
add_request
(
self
,
request_id
:
str
,
...
...
@@ -162,7 +195,8 @@ class RequestTracker:
if
request_id
in
self
.
_request_streams
:
raise
KeyError
(
f
"Request
{
request_id
}
already exists."
)
stream
=
AsyncStream
(
request_id
)
abort_request
=
partial
(
self
.
abort_request
,
verbose
=
verbose
)
stream
=
AsyncStream
(
request_id
,
abort_request
)
self
.
_new_requests
.
put_nowait
((
stream
,
{
"request_id"
:
request_id
,
**
engine_add_request_kwargs
...
...
@@ -175,38 +209,41 @@ class RequestTracker:
return
stream
def
abort_request
(
self
,
request_id
:
str
,
*
,
verbose
:
bool
=
False
)
->
None
:
def
abort_request
(
self
,
request_id
:
str
,
*
,
exception
:
Optional
[
Union
[
BaseException
,
Type
[
BaseException
]]]
=
None
,
verbose
:
bool
=
False
)
->
None
:
"""Abort a request during next background loop iteration."""
if
verbose
:
logger
.
info
(
"Aborted request %s."
,
request_id
)
self
.
_
finish
ed_requests
.
put_nowait
(
request_id
)
self
.
_
abort
ed_requests
.
put_nowait
(
request_id
)
if
request_id
not
in
self
.
_request_streams
or
self
.
_request_streams
[
request_id
].
finished
:
# The request has already finished or been aborted.
return
self
.
_request_streams
[
request_id
].
finish
()
stream
=
self
.
_request_streams
.
pop
(
request_id
,
None
)
if
stream
is
not
None
:
stream
.
finish
(
exception
=
exception
)
def
get_new_and_
finish
ed_requests
(
self
)
->
Tuple
[
List
[
Dict
],
Set
[
str
]]:
def
get_new_and_
abort
ed_requests
(
self
)
->
Tuple
[
List
[
Dict
],
Set
[
str
]]:
"""Get the new requests and finished requests to be
sent to the engine."""
new_requests
:
List
[
Dict
]
=
[]
finished_requests
:
Set
[
str
]
=
set
()
while
not
self
.
_
finish
ed_requests
.
empty
():
request_id
=
self
.
_
finish
ed_requests
.
get_nowait
()
while
not
self
.
_
abort
ed_requests
.
empty
():
request_id
=
self
.
_
abort
ed_requests
.
get_nowait
()
finished_requests
.
add
(
request_id
)
self
.
_request_streams
.
pop
(
request_id
,
None
)
while
not
self
.
_new_requests
.
empty
():
stream
,
new_request
=
self
.
_new_requests
.
get_nowait
()
if
stream
.
request_id
in
finished_requests
:
request_id
=
stream
.
request_id
if
request_id
in
finished_requests
:
# The request has already been aborted.
stream
.
finish
()
continue
self
.
_request_streams
[
stream
.
request_id
]
=
stream
stream
.
finish
(
asyncio
.
CancelledError
)
finished_requests
.
discard
(
request_id
)
else
:
self
.
_request_streams
[
request_id
]
=
stream
new_requests
.
append
(
new_request
)
return
new_requests
,
finished_requests
...
...
@@ -220,9 +257,25 @@ class RequestTracker:
return
not
self
.
_new_requests
.
empty
()
@
dataclass
class
SchedulerOutputState
:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
last_output
:
Optional
[
SamplerOutput
]
=
None
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
=
None
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
class
_AsyncLLMEngine
(
LLMEngine
):
"""Extension of LLMEngine to add async methods."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
pipeline_parallel_size
=
\
self
.
parallel_config
.
pipeline_parallel_size
self
.
cached_scheduler_outputs
=
[
SchedulerOutputState
()
for
_
in
range
(
pipeline_parallel_size
)
]
async
def
step_async
(
self
,
virtual_engine
:
int
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
...
...
@@ -235,13 +288,39 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
# these are cached outputs from previous iterations. None if on first
# iteration
cached_outputs
=
self
.
cached_scheduler_outputs
[
virtual_engine
]
seq_group_metadata_list
=
cached_outputs
.
seq_group_metadata_list
scheduler_outputs
=
cached_outputs
.
scheduler_outputs
# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
[
virtual_engine
].
schedule
()
if
(
self
.
scheduler_config
.
is_multi_step
and
scheduler_outputs
.
num_lookahead_slots
>
0
):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self
.
_cache_scheduler_outputs_for_multi_step
(
virtual_engine
,
seq_group_metadata_list
,
scheduler_outputs
)
assert
seq_group_metadata_list
is
not
None
assert
scheduler_outputs
is
not
None
if
not
scheduler_outputs
.
is_empty
():
# Execute the model.
finished_requests_ids
=
self
.
scheduler
[
virtual_engine
].
get_and_reset_finished_requests_ids
()
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids
=
\
self
.
_get_last_sampled_token_ids
(
virtual_engine
)
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
seq_group_metadata_list
,
blocks_to_swap_in
=
scheduler_outputs
.
blocks_to_swap_in
,
...
...
@@ -250,15 +329,35 @@ class _AsyncLLMEngine(LLMEngine):
virtual_engine
=
virtual_engine
,
num_lookahead_slots
=
scheduler_outputs
.
num_lookahead_slots
,
running_queue_size
=
scheduler_outputs
.
running_queue_size
,
finished_requests_ids
=
finished_requests_ids
)
finished_requests_ids
=
finished_requests_ids
,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids
=
last_sampled_token_ids
)
# Execute the model.
output
=
await
self
.
model_executor
.
execute_model_async
(
execute_model_req
)
# we need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if
self
.
scheduler_config
.
is_multi_step
:
self
.
_update_cached_scheduler_output
(
virtual_engine
,
output
)
else
:
output
=
[]
# Finish the current step for all the sequence groups.
if
self
.
scheduler_config
.
is_multi_step
:
for
seq_group
in
seq_group_metadata_list
:
seq_group
.
finish_step
()
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# clear the cache if we have finished all the steps
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
request_outputs
=
self
.
_process_model_outputs
(
output
,
scheduler_outputs
.
scheduled_seq_groups
,
scheduler_outputs
.
ignored_seq_groups
,
seq_group_metadata_list
)
else
:
request_outputs
=
[]
# Log stats.
self
.
do_log_stats
(
scheduler_outputs
,
output
)
...
...
@@ -268,42 +367,196 @@ class _AsyncLLMEngine(LLMEngine):
return
request_outputs
def
_has_remaining_steps
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
)
->
bool
:
if
(
not
self
.
scheduler_config
.
is_multi_step
or
not
seq_group_metadata_list
):
return
False
# TODO(will) this is a sanity check for nowto make sure that all the
# seqs are on the same steps. Eventually we will want to do some sort of
# dynamic scheduling when doing multi-step decoding.
ref_remaining_steps
=
seq_group_metadata_list
[
0
].
state
.
remaining_steps
if
any
([
seq_group
.
state
.
remaining_steps
!=
ref_remaining_steps
for
seq_group
in
seq_group_metadata_list
[
1
:]
]):
raise
AssertionError
((
"All running sequence groups should "
"have the same remaining steps."
))
return
ref_remaining_steps
>
0
def
_cache_scheduler_outputs_for_multi_step
(
self
,
virtual_engine
:
int
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]],
scheduler_outputs
:
SchedulerOutputs
)
->
None
:
self
.
cached_scheduler_outputs
[
virtual_engine
].
seq_group_metadata_list
=
seq_group_metadata_list
self
.
cached_scheduler_outputs
[
virtual_engine
].
scheduler_outputs
=
\
scheduler_outputs
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
=
None
def
_get_last_sampled_token_ids
(
self
,
virtual_engine
:
int
)
->
Optional
[
torch
.
Tensor
]:
cached_last_output
=
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
if
(
self
.
scheduler_config
.
is_multi_step
and
self
.
parallel_config
.
pipeline_parallel_size
>
1
and
cached_last_output
is
not
None
and
cached_last_output
.
sampled_token_ids_cpu
is
not
None
):
return
cached_last_output
.
sampled_token_ids_cpu
return
None
def
_update_cached_scheduler_output
(
self
,
virtual_engine
:
int
,
output
:
List
[
Optional
[
SamplerOutput
]])
->
None
:
if
(
self
.
parallel_config
.
pipeline_parallel_size
>
1
and
len
(
output
)
>
0
and
output
[
0
]
is
not
None
):
last_output
=
output
[
-
1
]
assert
last_output
is
not
None
assert
last_output
.
sampled_token_ids_cpu
is
not
None
assert
last_output
.
sampled_token_ids
is
None
assert
last_output
.
sampled_token_probs
is
None
self
.
cached_scheduler_outputs
[
virtual_engine
].
last_output
=
last_output
async
def
stop_remote_worker_execution_loop_async
(
self
)
->
None
:
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
async
def
process_model_inputs
_async
(
async
def
_tokenize_prompt
_async
(
self
,
prompt
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
],
)
->
List
[
int
]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer
=
self
.
get_tokenizer_group
(
missing_msg
=
"prompts must be None if skip_tokenizer_init is True"
)
return
await
tokenizer
.
encode_async
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
async
def
_extract_prompt_components_async
(
self
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
"""Async version of :meth:`_extract_prompt_components`."""
if
isinstance
(
inputs
,
str
):
prompt
=
inputs
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
None
elif
isinstance
(
inputs
,
dict
):
if
"prompt_token_ids"
in
inputs
:
prompt
=
None
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
else
:
# NOTE: This extra assignment is required to pass mypy
prompt
=
parsed_prompt
=
inputs
[
"prompt"
]
prompt_token_ids
=
await
self
.
_tokenize_prompt_async
(
parsed_prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
else
:
assert_never
(
inputs
)
return
prompt
,
prompt_token_ids
,
multi_modal_data
async
def
_process_encoder_decoder_prompt_async
(
self
,
inputs
:
PromptInputs
,
request_id
:
str
,
)
->
EncoderDecoderLLMInputs
:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
encoder_task
=
self
.
_extract_prompt_components_async
(
inputs
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
encoder_comps
=
await
encoder_task
decoder_comps
=
None
,
None
,
None
else
:
decoder_task
=
self
.
_extract_prompt_components_async
(
decoder_input
,
request_id
=
request_id
,
)
encoder_comps
,
decoder_comps
=
await
asyncio
.
gather
(
encoder_task
,
decoder_task
)
else
:
encoder_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
request_id
=
request_id
,
)
decoder_comps
=
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
async
def
_process_decoder_only_prompt_async
(
self
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
if
isinstance
(
inputs
,
str
):
inputs
=
{
"prompt"
:
inputs
}
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps
=
await
self
.
_extract_prompt_components_async
(
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
if
"prompt_token_ids"
not
in
inputs
:
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
"skip_tokenizer_init is True"
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
,
prompt_adapter_request
=
prompt_adapter_request
,
)
prompt_token_ids
=
await
tokenizer
.
encode_async
(
async
def
process_model_inputs_async
(
self
,
inputs
:
PromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]:
"""Async version of :meth:`process_model_inputs`."""
if
self
.
is_encoder_decoder_model
():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs
=
await
self
.
_process_encoder_decoder_prompt_async
(
inputs
,
request_id
=
request_id
,
prompt
=
inputs
[
"prompt"
],
lora_request
=
lora_request
)
)
else
:
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
if
prompt_adapter_request
:
prompt_token_ids
=
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
\
prompt_token_ids
if
is_explicit_encoder_decoder_prompt
(
inputs
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
inputs
.
get
(
"prompt"
),
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
))
# Decoder-only operation
model_inputs
=
await
self
.
_process_decoder_only_prompt_async
(
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
return
self
.
input_processor
(
llm
_inputs
)
return
self
.
input_processor
(
model
_inputs
)
async
def
add_request_async
(
self
,
...
...
@@ -315,6 +568,7 @@ class _AsyncLLMEngine(LLMEngine):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
"""Async version of :meth:`add_request`."""
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
...
...
@@ -322,10 +576,11 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time
=
time
.
time
()
processed_inputs
=
await
self
.
process_model_inputs_async
(
inputs
,
request_id
=
request_id
,
inputs
=
inputs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
)
self
.
_add_processed_request
(
request_id
=
request_id
,
...
...
@@ -380,6 +635,20 @@ class AsyncLLMEngine:
self
.
log_requests
=
log_requests
self
.
engine
=
self
.
_init_engine
(
*
args
,
**
kwargs
)
if
self
.
engine_use_ray
:
print_warning_once
(
"DEPRECATED. `--engine-use-ray` is deprecated and will "
"be removed in a future update. "
"See https://github.com/vllm-project/vllm/issues/7045."
)
if
envs
.
VLLM_ALLOW_ENGINE_USE_RAY
:
print_warning_once
(
"VLLM_ALLOW_ENGINE_USE_RAY is set, force engine use Ray"
)
else
:
raise
ValueError
(
"`--engine-use-ray` is deprecated. "
"Set `VLLM_ALLOW_ENGINE_USE_RAY=1` to "
"force use it"
)
self
.
background_loop
:
Optional
[
asyncio
.
Future
]
=
None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
...
...
@@ -497,6 +766,11 @@ class AsyncLLMEngine:
def
errored
(
self
)
->
bool
:
return
self
.
_errored_with
is
not
None
@
property
def
limit_concurrency
(
self
)
->
Optional
[
int
]:
"""Maximum number of concurrently running requests."""
return
None
def
set_errored
(
self
,
exc
:
Exception
)
->
None
:
self
.
_errored_with
=
exc
...
...
@@ -507,7 +781,7 @@ class AsyncLLMEngine:
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
"PreTrained
Tokenizer
"
:
)
->
Any
Tokenizer
:
if
self
.
engine_use_ray
:
return
await
self
.
engine
.
get_tokenizer
.
remote
(
# type: ignore
lora_request
)
...
...
@@ -531,6 +805,20 @@ class AsyncLLMEngine:
partial
(
_log_task_completion
,
error_callback
=
self
.
_error_callback
))
self
.
background_loop
=
asyncio
.
shield
(
self
.
_background_loop_unshielded
)
def
shutdown_background_loop
(
self
)
->
None
:
"""
Shut down the background loop.
This method needs to be called during cleanup to remove
references to `self` and properly GC the resources held
by the async LLM engine (e.g., the executors as well as
their resources).
"""
if
self
.
_background_loop_unshielded
is
not
None
:
self
.
_background_loop_unshielded
.
cancel
()
self
.
_background_loop_unshielded
=
None
self
.
background_loop
=
None
def
_init_engine
(
self
,
*
args
,
**
kwargs
)
->
Union
[
_AsyncLLMEngine
,
"ray.ObjectRef"
]:
if
not
self
.
engine_use_ray
:
...
...
@@ -556,8 +844,8 @@ class AsyncLLMEngine:
Returns True if there are in-progress requests."""
new_requests
,
finish
ed_requests
=
(
self
.
_request_tracker
.
get_new_and_
finish
ed_requests
())
new_requests
,
abort
ed_requests
=
(
self
.
_request_tracker
.
get_new_and_
abort
ed_requests
())
for
new_request
in
new_requests
:
# Add the request into the vLLM engine's waiting queue.
...
...
@@ -576,8 +864,8 @@ class AsyncLLMEngine:
verbose
=
self
.
log_requests
,
)
if
finish
ed_requests
:
await
self
.
_engine_abort
(
finish
ed_requests
)
if
abort
ed_requests
:
await
self
.
_engine_abort
(
abort
ed_requests
)
if
self
.
engine_use_ray
:
request_outputs
=
await
self
.
engine
.
step
.
remote
()
# type: ignore
...
...
@@ -666,6 +954,8 @@ class AsyncLLMEngine:
raise
await
asyncio
.
sleep
(
0
)
# This method does not need to be async, but kept that way
# for backwards compatibility.
async
def
add_request
(
self
,
request_id
:
str
,
...
...
@@ -675,7 +965,7 @@ class AsyncLLMEngine:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Async
Stream
:
)
->
Async
Generator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]
:
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
self
.
start_background_loop
()
...
...
@@ -686,20 +976,17 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError)."
)
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
verbose
=
self
.
log_requests
,
inputs
=
inputs
,
params
=
params
,
arrival_time
=
arrival_time
,
arrival_time
=
arrival_time
or
time
.
time
()
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)
return
stream
return
stream
.
generator
()
async
def
generate
(
self
,
...
...
@@ -709,7 +996,7 @@ class AsyncLLMEngine:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Async
It
erator
[
RequestOutput
]:
)
->
Async
Gen
erator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
Generate outputs for a request. This method is a coroutine. It adds the
...
...
@@ -774,7 +1061,7 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async
for
output
in
self
.
_process
_request
(
async
for
output
in
await
self
.
add
_request
(
request_id
,
inputs
,
sampling_params
,
...
...
@@ -791,7 +1078,7 @@ class AsyncLLMEngine:
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
Async
It
erator
[
EmbeddingRequestOutput
]:
)
->
Async
Gen
erator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
Generate outputs for a request. This method is a coroutine. It adds the
...
...
@@ -852,7 +1139,7 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async
for
output
in
self
.
_process
_request
(
async
for
output
in
await
self
.
add
_request
(
request_id
,
inputs
,
pooling_params
,
...
...
@@ -861,37 +1148,6 @@ class AsyncLLMEngine:
):
yield
LLMEngine
.
validate_output
(
output
,
EmbeddingRequestOutput
)
async
def
_process_request
(
self
,
request_id
:
str
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
*
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
AsyncIterator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
arrival_time
=
time
.
time
()
stream
=
await
self
.
add_request
(
request_id
,
inputs
,
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
)
try
:
async
for
request_output
in
stream
:
yield
request_output
except
(
Exception
,
asyncio
.
CancelledError
)
as
e
:
self
.
_abort
(
request_id
)
raise
e
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
...
...
@@ -920,6 +1176,7 @@ class AsyncLLMEngine:
request_id: The unique id of the request.
"""
self
.
_request_tracker
.
abort_request
(
request_id
,
exception
=
asyncio
.
CancelledError
,
verbose
=
self
.
log_requests
)
async
def
get_model_config
(
self
)
->
ModelConfig
:
...
...
@@ -1009,3 +1266,9 @@ class AsyncLLMEngine:
logger_name
=
logger_name
))
else
:
self
.
engine
.
remove_logger
(
logger_name
=
logger_name
)
async
def
start_profile
(
self
)
->
None
:
self
.
engine
.
model_executor
.
_run_workers
(
"start_profile"
)
async
def
stop_profile
(
self
)
->
None
:
self
.
engine
.
model_executor
.
_run_workers
(
"stop_profile"
)
vllm/engine/llm_engine.py
View file @
af7f4372
...
...
@@ -3,28 +3,33 @@ from contextlib import contextmanager
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
TypeVar
,
Union
from
typing
import
Set
,
Tuple
,
Type
,
Union
from
typing_extensions
import
TypeVar
,
assert_never
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ObservabilityConfig
,
ParallelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
)
from
vllm.core.scheduler
import
(
ScheduledSequenceGroup
,
Scheduler
,
SchedulerOutputs
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
PrometheusStatLogger
,
StatLoggerBase
,
Stats
)
from
vllm.engine.metrics_types
import
StatLoggerBase
,
Stats
from
vllm.engine.output_processor.interfaces
import
(
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
INPUT_REGISTRY
,
LLMInputs
,
PromptInputs
from
vllm.inputs
import
(
INPUT_REGISTRY
,
EncoderDecoderLLMInputs
,
InputRegistry
,
LLMInputs
,
PromptInputs
,
SingletonPromptInputs
)
from
vllm.inputs.parse
import
is_explicit_encoder_decoder_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
from
vllm.pooling_params
import
PoolingParams
...
...
@@ -38,11 +43,12 @@ from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
(
AnyTokenizer
,
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
,
Device
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -62,8 +68,14 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
return
config
.
to_diff_dict
()
_G
=
TypeVar
(
"_G"
,
bound
=
BaseTokenizerGroup
,
default
=
BaseTokenizerGroup
)
_O
=
TypeVar
(
"_O"
,
RequestOutput
,
EmbeddingRequestOutput
)
PromptComponents
=
Tuple
[
Optional
[
str
],
List
[
int
],
Optional
[
MultiModalDataDict
]]
DecoderPromptComponents
=
Tuple
[
Optional
[
str
],
Optional
[
List
[
int
]],
Optional
[
MultiModalDataDict
]]
class
LLMEngine
:
"""An LLM engine that receives requests and generates texts.
...
...
@@ -89,8 +101,6 @@ class LLMEngine:
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
lora_config (Optional): The configuration related to serving multi-LoRA.
multimodal_config (Optional): The configuration related to multimodal
models.
speculative_config (Optional): The configuration related to speculative
decoding.
executor_class: The model executor class for managing distributed
...
...
@@ -130,24 +140,6 @@ class LLMEngine:
@
classmethod
def
validate_outputs
(
cls
,
outputs
:
GenericSequence
[
object
],
output_type
:
Type
[
_O
],
)
->
List
[
_O
]:
do_validate
=
cls
.
DO_VALIDATE_OUTPUT
outputs_
:
List
[
_O
]
if
TYPE_CHECKING
or
do_validate
:
outputs_
=
[]
for
output
in
outputs
:
if
not
isinstance
(
output
,
output_type
):
raise
TypeError
(
f
"Expected output of type
{
output_type
}
, "
f
"but found type
{
type
(
output
)
}
"
)
outputs_
.
append
(
output
)
else
:
outputs_
=
outputs
return
outputs_
tokenizer
:
Optional
[
BaseTokenizerGroup
]
...
...
@@ -161,7 +153,6 @@ class LLMEngine:
device_config
:
DeviceConfig
,
load_config
:
LoadConfig
,
lora_config
:
Optional
[
LoRAConfig
],
multimodal_config
:
Optional
[
MultiModalConfig
],
speculative_config
:
Optional
[
SpeculativeConfig
],
decoding_config
:
Optional
[
DecodingConfig
],
observability_config
:
Optional
[
ObservabilityConfig
],
...
...
@@ -170,6 +161,7 @@ class LLMEngine:
log_stats
:
bool
,
usage_context
:
UsageContext
=
UsageContext
.
ENGINE_CONTEXT
,
stat_loggers
:
Optional
[
Dict
[
str
,
StatLoggerBase
]]
=
None
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
)
->
None
:
logger
.
info
(
"Initializing an LLM engine (v%s) with config: "
...
...
@@ -216,11 +208,12 @@ class LLMEngine:
cache_config
.
enable_prefix_caching
,
)
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
self
.
multimodal_config
=
multimodal_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
device_config
=
device_config
...
...
@@ -235,16 +228,26 @@ class LLMEngine:
if
not
self
.
model_config
.
skip_tokenizer_init
:
self
.
tokenizer
=
self
.
_init_tokenizer
()
self
.
detokenizer
=
Detokenizer
(
self
.
tokenizer
)
tokenizer_group
=
self
.
get_tokenizer_group
()
else
:
self
.
tokenizer
=
None
self
.
detokenizer
=
None
tokenizer_group
=
None
# Ensure that the function doesn't contain a reference to self,
# to avoid engine GC issues
def
get_tokenizer_for_seq
(
sequence
:
Sequence
)
->
AnyTokenizer
:
assert
tokenizer_group
,
(
"tokenizer_group cannot be None, "
"make sure skip_tokenizer_init is False"
)
return
tokenizer_group
.
get_lora_tokenizer
(
sequence
.
lora_request
)
self
.
seq_counter
=
Counter
()
self
.
generation_config_fields
=
_load_generation_config_dict
(
model_config
)
self
.
input_processor
=
INPUT_REGISTRY
.
create_input_processor
(
self
.
model_config
)
self
.
input_registry
=
input_registry
self
.
input_processor
=
input_registry
.
create_input_processor
(
model_config
)
self
.
model_executor
=
executor_class
(
model_config
=
model_config
,
...
...
@@ -253,14 +256,12 @@ class LLMEngine:
scheduler_config
=
scheduler_config
,
device_config
=
device_config
,
lora_config
=
lora_config
,
multimodal_config
=
multimodal_config
,
speculative_config
=
speculative_config
,
load_config
=
load_config
,
prompt_adapter_config
=
prompt_adapter_config
,
observability_config
=
self
.
observability_config
,
)
init_success
=
False
try
:
if
not
self
.
model_config
.
embedding_mode
:
self
.
_initialize_kv_caches
()
...
...
@@ -320,6 +321,13 @@ class LLMEngine:
if
stat_loggers
is
not
None
:
self
.
stat_loggers
=
stat_loggers
else
:
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
vllm.engine.metrics
import
(
LoggingStatLogger
,
PrometheusStatLogger
)
self
.
stat_loggers
=
{
"logging"
:
LoggingStatLogger
(
...
...
@@ -339,11 +347,6 @@ class LLMEngine:
"vllm.llm_engine"
,
self
.
observability_config
.
otlp_traces_endpoint
)
def
get_tokenizer_for_seq
(
self
,
sequence
:
Sequence
)
->
"PreTrainedTokenizer"
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
sequence
.
lora_request
)
# Create sequence output processor, e.g. for beam search or
# speculative decoding.
self
.
output_processor
=
(
...
...
@@ -358,13 +361,6 @@ class LLMEngine:
get_tokenizer_for_seq
,
),
))
init_success
=
True
finally
:
if
not
init_success
:
# Ensure that model_executor is shut down if LLMEngine init
# failed
self
.
model_executor
.
shutdown
()
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -482,11 +478,20 @@ class LLMEngine:
def
get_tokenizer_group
(
self
,
fail_msg
:
str
=
MISSING_TOKENIZER_GROUP_MSG
)
->
BaseTokenizerGroup
:
if
self
.
tokenizer
is
None
:
raise
ValueError
(
fail_msg
)
group_type
:
Type
[
_G
]
=
BaseTokenizerGroup
,
*
,
missing_msg
:
str
=
MISSING_TOKENIZER_GROUP_MSG
,
)
->
_G
:
tokenizer_group
=
self
.
tokenizer
return
self
.
tokenizer
if
tokenizer_group
is
None
:
raise
ValueError
(
missing_msg
)
if
not
isinstance
(
tokenizer_group
,
group_type
):
raise
TypeError
(
"Invalid type of tokenizer group. "
f
"Expected type:
{
group_type
}
, but "
f
"found type:
{
type
(
tokenizer_group
)
}
"
)
return
tokenizer_group
def
get_tokenizer
(
self
,
...
...
@@ -494,10 +499,6 @@ class LLMEngine:
)
->
AnyTokenizer
:
return
self
.
get_tokenizer_group
().
get_lora_tokenizer
(
lora_request
)
# def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
# return self.get_tokenizer_group().get_lora_tokenizer(
# sequence.lora_request)
def
_init_tokenizer
(
self
)
->
BaseTokenizerGroup
:
return
init_tokenizer_from_configs
(
model_config
=
self
.
model_config
,
...
...
@@ -516,8 +517,19 @@ class LLMEngine:
self
.
prompt_adapter_config
.
verify_with_model_config
(
self
.
model_config
)
def
_get_eos_token_id
(
self
,
lora_request
:
Optional
[
LoRARequest
])
->
Optional
[
int
]:
def
_get_bos_token_id
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
Optional
[
int
]:
if
self
.
tokenizer
is
None
:
logger
.
warning
(
"Using None for BOS token id because tokenizer "
"is not initialized"
)
return
None
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
bos_token_id
def
_get_eos_token_id
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
)
->
Optional
[
int
]:
if
self
.
tokenizer
is
None
:
logger
.
warning
(
"Using None for EOS token id because tokenizer "
"is not initialized"
)
...
...
@@ -525,16 +537,43 @@ class LLMEngine:
return
self
.
tokenizer
.
get_lora_tokenizer
(
lora_request
).
eos_token_id
def
_get_decoder_start_token_id
(
self
)
->
Optional
[
int
]:
'''
Obtain the decoder start token id employed by an encoder/decoder
model. Returns None for non-encoder/decoder models or if the
model config is unavailable.
'''
if
not
self
.
is_encoder_decoder_model
():
logger
.
warning
(
"Using None for decoder start token id because "
"this is not an encoder/decoder model."
)
return
None
if
(
self
.
model_config
is
None
or
self
.
model_config
.
hf_config
is
None
):
logger
.
warning
(
"Using None for decoder start token id because "
"model config is not available."
)
return
None
dec_start_token_id
=
getattr
(
self
.
model_config
.
hf_config
,
'decoder_start_token_id'
,
None
)
if
dec_start_token_id
is
None
:
logger
.
warning
(
"Falling back on <BOS> for decoder start token id "
"because decoder start token id is not available."
)
dec_start_token_id
=
self
.
_get_bos_token_id
()
return
dec_start_token_id
def
_add_processed_request
(
self
,
request_id
:
str
,
processed_inputs
:
LLMInputs
,
processed_inputs
:
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
None
:
self
.
_validate_model_inputs
(
processed_inputs
)
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
seq_id
=
next
(
self
.
seq_counter
)
...
...
@@ -543,6 +582,16 @@ class LLMEngine:
seq
=
Sequence
(
seq_id
,
processed_inputs
,
block_size
,
eos_token_id
,
lora_request
,
prompt_adapter_request
)
encoder_seq
=
None
if
'encoder_prompt_token_ids'
in
processed_inputs
:
encoder_seq
=
Sequence
(
seq_id
,
processed_inputs
,
block_size
,
eos_token_id
,
lora_request
,
prompt_adapter_request
,
from_decoder_prompt
=
False
)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if
isinstance
(
params
,
SamplingParams
):
seq_group
=
self
.
_create_sequence_group_with_sampling
(
...
...
@@ -552,7 +601,8 @@ class LLMEngine:
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
elif
isinstance
(
params
,
PoolingParams
):
seq_group
=
self
.
_create_sequence_group_with_pooling
(
request_id
,
...
...
@@ -560,7 +610,8 @@ class LLMEngine:
params
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
else
:
raise
ValueError
(
"Either SamplingParams or PoolingParams must be provided."
)
...
...
@@ -576,36 +627,333 @@ class LLMEngine:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
def
process_model_inputs
(
_LLMInputComponentsType
=
Tuple
[
str
,
List
[
int
]]
def
_prepare_decoder_input_ids_for_generation
(
self
,
decoder_input_ids
:
Optional
[
List
[
int
]],
)
->
List
[
int
]:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Based on
https://github.com/huggingface/transformers/blob/
4037a2b5b1278736e566aec12e169100275545ea/
src/transformers/generation/utils.py
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Arguments:
* decoder_input_ids: input token ids to preprocess
Returns:
* Processed token list
"""
decoder_start_token_id
=
self
.
_get_decoder_start_token_id
()
assert
decoder_start_token_id
is
not
None
if
decoder_input_ids
is
None
:
# no decoder prompt input ->
# use decoder_start_token_id as decoder_input_ids
decoder_input_ids
=
self
.
_get_default_enc_dec_decoder_prompt
()
if
(
len
(
decoder_input_ids
)
==
0
or
decoder_input_ids
[
0
]
!=
decoder_start_token_id
):
decoder_input_ids
=
[
decoder_start_token_id
]
+
decoder_input_ids
return
decoder_input_ids
def
_tokenize_prompt
(
self
,
prompt
:
str
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
],
)
->
List
[
int
]:
'''
Wrapper around application of the model's tokenizer.
Arguments:
* prompt
* request_id
* lora_request
Returns:
* prompt token ids
'''
tokenizer
=
self
.
get_tokenizer_group
(
missing_msg
=
"prompts must be None if skip_tokenizer_init is True"
)
return
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
prompt
,
lora_request
=
lora_request
)
def
_extract_prompt_components
(
self
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PromptComponents
:
'''
Extract the components of any single encoder or decoder input prompt.
Arguments:
* request_id
* inputs: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
* prompt
* prompt_token_ids
* multi_modal_data
'''
if
isinstance
(
inputs
,
str
):
prompt
=
inputs
prompt_token_ids
=
self
.
_tokenize_prompt
(
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
None
elif
isinstance
(
inputs
,
dict
):
if
"prompt_token_ids"
in
inputs
:
prompt
=
None
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
else
:
# NOTE: This extra assignment is required to pass mypy
prompt
=
parsed_prompt
=
inputs
[
"prompt"
]
prompt_token_ids
=
self
.
_tokenize_prompt
(
parsed_prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
)
else
:
assert_never
(
inputs
)
return
prompt
,
prompt_token_ids
,
multi_modal_data
def
_apply_prompt_adapter
(
self
,
prompt_token_ids
:
List
[
int
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
List
[
int
]:
if
prompt_adapter_request
:
prompt_token_ids
=
(
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
+
prompt_token_ids
)
return
prompt_token_ids
def
_get_default_enc_dec_decoder_prompt
(
self
)
->
List
[
int
]:
'''
Specifically for encoder/decoder models:
generate a default decoder prompt for when
the user specifies only the encoder prompt.
Encoder/decoder models utilize the decoder
prompt in different ways; as new models are
added, it is intended that this function
will be extended to produce differing
default decoder prompts, depending on the
model variety.
Absent a special case, the default behavior
of this method is to mirror the behavior of
the HuggingFace (HF) GenerationMixin for a None
decoder prompt, which is to employ a logit processor
setting to force the first decoded token to be <BOS>.
Here, this behavior is approximated by having the
"default" decoder prompt be <BOS>.
However, it is possible that in the future
other models may have different or more
complex logic for the default decoder prompt.
This motivates having a special helper method
for default decoder prompts.
Returns:
* prompt_token_ids
'''
bos_token_id
=
self
.
_get_bos_token_id
()
assert
bos_token_id
is
not
None
return
[
bos_token_id
]
def
_build_enc_dec_llm_inputs
(
self
,
encoder_comps
:
PromptComponents
,
decoder_comps
:
DecoderPromptComponents
,
)
->
EncoderDecoderLLMInputs
:
encoder_prompt
,
encoder_prompt_ids
,
encoder_mm_data
=
encoder_comps
decoder_prompt
,
decoder_prompt_ids
,
decoder_mm_data
=
decoder_comps
if
encoder_mm_data
is
not
None
or
decoder_mm_data
is
not
None
:
raise
ValueError
(
"Multi-modal encoder-decoder models are "
"not supported yet"
)
decoder_prompt_ids
=
(
self
.
_prepare_decoder_input_ids_for_generation
(
decoder_prompt_ids
))
return
EncoderDecoderLLMInputs
(
prompt_token_ids
=
decoder_prompt_ids
,
prompt
=
decoder_prompt
,
encoder_prompt_token_ids
=
encoder_prompt_ids
,
encoder_prompt
=
encoder_prompt
,
)
def
_process_encoder_decoder_prompt
(
self
,
inputs
:
PromptInputs
,
request_id
:
str
,
)
->
EncoderDecoderLLMInputs
:
'''
For encoder/decoder models only:
Process an input prompt into an
:class:`EncoderDecoderLLMInputs` instance.
There are two types of input prompts:
singleton prompts which carry only the
encoder prompt, and explicit encoder/decoder
prompts which carry both the encoder and the
decoder prompts as member variables.
This function handles the following scenarios:
* Singleton encoder prompt: extract encoder prompt
token ids & infer default decoder prompt token ids
* Explicit encoder/decoder prompt: extract encoder
and decoder prompt token ids
Note that for Explicit encoder/decoder prompts,
each sub-prompt (encoder or decoder prompt) can
have any possible singleton type; thus this
method relies on helper functions to obtain
token ids for the sub-prompts.
Arguments:
* inputs: an input prompt
* request_id
Returns:
* :class:`EncoderDecoderLLMInputs` instance
'''
encoder_comps
:
PromptComponents
decoder_comps
:
DecoderPromptComponents
if
is_explicit_encoder_decoder_prompt
(
inputs
):
encoder_comps
=
self
.
_extract_prompt_components
(
inputs
[
"encoder_prompt"
],
request_id
=
request_id
,
)
if
(
decoder_input
:
=
inputs
[
"decoder_prompt"
])
is
None
:
decoder_comps
=
None
,
None
,
None
else
:
decoder_comps
=
self
.
_extract_prompt_components
(
decoder_input
,
request_id
=
request_id
,
)
else
:
encoder_comps
=
self
.
_extract_prompt_components
(
inputs
,
request_id
=
request_id
,
)
decoder_comps
=
None
,
None
,
None
return
self
.
_build_enc_dec_llm_inputs
(
encoder_comps
,
decoder_comps
)
def
_build_decoder_only_llm_inputs
(
self
,
prompt_comps
:
PromptComponents
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
LLMInputs
:
prompt
,
prompt_token_ids
,
multi_modal_data
=
prompt_comps
prompt_token_ids
=
self
.
_apply_prompt_adapter
(
prompt_token_ids
,
prompt_adapter_request
=
prompt_adapter_request
)
return
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
prompt
,
multi_modal_data
=
multi_modal_data
)
def
_process_decoder_only_prompt
(
self
,
inputs
:
SingletonPromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
LLMInputs
:
if
isinstance
(
inputs
,
str
):
inputs
=
{
"prompt"
:
inputs
}
'''
For decoder-only models:
Process an input prompt into an :class:`LLMInputs` instance.
if
"prompt_token_ids"
not
in
inputs
:
tokenizer
=
self
.
get_tokenizer_group
(
"prompts must be None if "
"skip_tokenizer_init is True"
)
Arguments:
prompt_token_ids
=
tokenizer
.
encode
(
request_id
=
request_id
,
prompt
=
inputs
[
"prompt"
],
lora_request
=
lora_request
)
else
:
prompt_token_ids
=
inputs
[
"prompt_token_ids"
]
* inputs: input prompt
* request_id
* lora_request
* prompt_adapter_request
if
prompt_adapter_request
:
prompt_token_ids
=
\
[
0
]
*
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
\
+
prompt_token_ids
Returns:
llm_inputs
=
LLMInputs
(
prompt_token_ids
=
prompt_token_ids
,
prompt
=
inputs
.
get
(
"prompt"
),
multi_modal_data
=
inputs
.
get
(
"multi_modal_data"
))
* :class:`LLMInputs` instance
'''
return
self
.
input_processor
(
llm_inputs
)
prompt_comps
=
self
.
_extract_prompt_components
(
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
)
return
self
.
_build_decoder_only_llm_inputs
(
prompt_comps
,
prompt_adapter_request
=
prompt_adapter_request
,
)
def
process_model_inputs
(
self
,
inputs
:
PromptInputs
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]:
if
self
.
is_encoder_decoder_model
():
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
model_inputs
=
self
.
_process_encoder_decoder_prompt
(
inputs
,
request_id
=
request_id
,
)
else
:
if
is_explicit_encoder_decoder_prompt
(
inputs
):
raise
ValueError
(
"Cannot pass encoder-decoder prompt "
"to decoder-only models"
)
# Decoder-only operation
model_inputs
=
self
.
_process_decoder_only_prompt
(
inputs
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
return
self
.
input_processor
(
model_inputs
)
def
add_request
(
self
,
...
...
@@ -666,10 +1014,11 @@ class LLMEngine:
arrival_time
=
time
.
time
()
processed_inputs
=
self
.
process_model_inputs
(
inputs
,
request_id
=
request_id
,
inputs
=
inputs
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
)
self
.
_add_processed_request
(
request_id
=
request_id
,
...
...
@@ -690,6 +1039,7 @@ class LLMEngine:
lora_request
:
Optional
[
LoRARequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
encoder_seq
:
Optional
[
Sequence
]
=
None
,
)
->
SequenceGroup
:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs
=
self
.
get_model_config
().
max_logprobs
...
...
@@ -715,7 +1065,8 @@ class LLMEngine:
sampling_params
=
sampling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
return
seq_group
...
...
@@ -727,6 +1078,7 @@ class LLMEngine:
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
encoder_seq
:
Optional
[
Sequence
]
=
None
,
)
->
SequenceGroup
:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
...
...
@@ -738,7 +1090,8 @@ class LLMEngine:
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
pooling_params
=
pooling_params
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
encoder_seq
=
encoder_seq
)
return
seq_group
def
abort_request
(
self
,
request_id
:
Union
[
str
,
Iterable
[
str
]])
->
None
:
...
...
@@ -836,6 +1189,22 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
if
output
is
not
None
and
len
(
output
)
>
0
:
for
o
in
output
:
if
(
isinstance
(
o
,
SamplerOutput
)
and
seq_group
.
metrics
is
not
None
):
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
seq_group
.
metrics
.
model_forward_time
+=
(
o
.
model_forward_time
)
else
:
seq_group
.
metrics
.
model_forward_time
=
(
o
.
model_forward_time
)
if
seq_group
.
metrics
.
model_execute_time
is
not
None
:
seq_group
.
metrics
.
model_execute_time
+=
(
o
.
model_execute_time
)
else
:
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
if
self
.
model_config
.
embedding_mode
:
self
.
_process_sequence_group_outputs
(
seq_group
,
outputs
)
continue
...
...
@@ -916,6 +1285,11 @@ class LLMEngine:
raise
NotImplementedError
(
"Pipeline parallelism is only supported through AsyncLLMEngine "
"as performance will be severely degraded otherwise."
)
if
self
.
scheduler_config
.
num_scheduler_steps
>
1
:
raise
NotImplementedError
(
"Multiple scheduler steps (multi-step) are only supported "
"through AsyncLLMEngine. "
)
seq_group_metadata_list
,
scheduler_outputs
=
self
.
scheduler
[
0
].
schedule
()
...
...
@@ -1015,6 +1389,13 @@ class LLMEngine:
for
scheduler
in
self
.
scheduler
)
cpu_cache_usage_sys
=
1.0
-
(
num_free_cpu
/
num_total_cpu
)
# Prefix Cache Hit Rate. Note that we always use
# the cache hit rate of the first virtual engine.
cpu_prefix_cache_hit_rate
=
self
.
scheduler
[
0
].
get_prefix_cache_hit_rate
(
Device
.
CPU
)
gpu_prefix_cache_hit_rate
=
self
.
scheduler
[
0
].
get_prefix_cache_hit_rate
(
Device
.
GPU
)
# Iteration stats
num_prompt_tokens_iter
=
0
num_generation_tokens_iter
=
0
...
...
@@ -1123,6 +1504,9 @@ class LLMEngine:
# KV Cache Usage in %
gpu_cache_usage_sys
=
gpu_cache_usage_sys
,
cpu_cache_usage_sys
=
cpu_cache_usage_sys
,
# Prefix Cache Hit Rate
cpu_prefix_cache_hit_rate
=
cpu_prefix_cache_hit_rate
,
gpu_prefix_cache_hit_rate
=
gpu_prefix_cache_hit_rate
,
# Iteration stats
num_prompt_tokens_iter
=
num_prompt_tokens_iter
,
...
...
@@ -1228,3 +1612,28 @@ class LLMEngine:
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_LATENCY_TIME_TO_FIRST_TOKEN
,
ttft
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_LATENCY_E2E
,
e2e_time
)
if
metrics
.
scheduler_time
is
not
None
:
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_LATENCY_TIME_IN_SCHEDULER
,
metrics
.
scheduler_time
)
if
metrics
.
model_forward_time
is
not
None
:
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_LATENCY_TIME_IN_MODEL_FORWARD
,
metrics
.
model_forward_time
/
1000.0
)
if
metrics
.
model_execute_time
is
not
None
:
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_LATENCY_TIME_IN_MODEL_EXECUTE
,
metrics
.
model_execute_time
)
def
is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder_model
def
is_embedding_model
(
self
):
return
self
.
model_config
.
is_embedding_model
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
LLMInputs
,
EncoderDecoderLLMInputs
]):
prompt_key
=
"encoder_prompt_token_ids"
\
if
self
.
is_encoder_decoder_model
()
else
"prompt_token_ids"
if
not
inputs
.
get
(
prompt_key
):
raise
ValueError
(
"Prompt cannot be empty"
)
\ No newline at end of file
vllm/engine/metrics.py
View file @
af7f4372
import
time
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Dict
,
List
,
Optional
,
Protocol
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
import
numpy
as
np
import
prometheus_client
from
vllm.engine.metrics_types
import
(
StatLoggerBase
,
Stats
,
SupportsMetricsInfo
)
from
vllm.executor.ray_utils
import
ray
from
vllm.logger
import
init_logger
...
...
@@ -29,41 +28,60 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
class
Metrics
:
"""
vLLM uses a multiprocessing-based frontend for the OpenAI server.
This means that we need to run prometheus_client in multiprocessing mode
See https://prometheus.github.io/client_python/multiprocess/ for more
details on limitations.
"""
labelname_finish_reason
=
"finished_reason"
_gauge_cls
=
prometheus_client
.
Gauge
_counter_cls
=
prometheus_client
.
Counter
_histogram_cls
=
prometheus_client
.
Histogram
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
# Unregister any existing vLLM collectors
# Unregister any existing vLLM collectors
(for CI/CD)
self
.
_unregister_vllm_metrics
()
# Config Information
self
.
_create_info_cache_config
()
# System stats
# Scheduler State
self
.
gauge_scheduler_running
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_running"
,
documentation
=
"Number of requests currently running on GPU."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
gauge_scheduler_waiting
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_waiting"
,
documentation
=
"Number of requests waiting to be processed."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
gauge_scheduler_swapped
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_swapped"
,
documentation
=
"Number of requests swapped to CPU."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
# KV Cache Usage in %
self
.
gauge_gpu_cache_usage
=
self
.
_gauge_cls
(
name
=
"vllm:gpu_cache_usage_perc"
,
documentation
=
"GPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
gauge_cpu_cache_usage
=
self
.
_gauge_cls
(
name
=
"vllm:cpu_cache_usage_perc"
,
documentation
=
"CPU KV-cache usage. 1 means 100 percent usage."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
# Prefix caching block hit rate
self
.
gauge_cpu_prefix_cache_hit_rate
=
self
.
_gauge_cls
(
name
=
"vllm:cpu_prefix_cache_hit_rate"
,
documentation
=
"CPU prefix cache block hit rate."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
gauge_gpu_prefix_cache_hit_rate
=
self
.
_gauge_cls
(
name
=
"vllm:gpu_prefix_cache_hit_rate"
,
documentation
=
"GPU prefix cache block hit rate."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
# Iteration stats
self
.
counter_num_preemption
=
self
.
_counter_cls
(
...
...
@@ -137,11 +155,13 @@ class Metrics:
self
.
gauge_spec_decode_draft_acceptance_rate
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_draft_acceptance_rate"
,
documentation
=
"Speulative token acceptance rate."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
gauge_spec_decode_efficiency
=
self
.
_gauge_cls
(
name
=
"vllm:spec_decode_efficiency"
,
documentation
=
"Speculative decoding system efficiency."
,
labelnames
=
labelnames
)
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
self
.
counter_spec_decode_num_accepted_tokens
=
(
self
.
_counter_cls
(
name
=
"vllm:spec_decode_num_accepted_tokens_total"
,
documentation
=
"Number of accepted tokens."
,
...
...
@@ -160,19 +180,18 @@ class Metrics:
name
=
"vllm:avg_prompt_throughput_toks_per_s"
,
documentation
=
"Average prefill throughput in tokens/s."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
,
)
# Deprecated in favor of vllm:generation_tokens_total
self
.
gauge_avg_generation_throughput
=
self
.
_gauge_cls
(
name
=
"vllm:avg_generation_throughput_toks_per_s"
,
documentation
=
"Average generation throughput in tokens/s."
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
,
)
def
_create_info_cache_config
(
self
)
->
None
:
# Config Information
self
.
info_cache_config
=
prometheus_client
.
Info
(
name
=
'vllm:cache_config'
,
documentation
=
'information of cache_config'
)
# end-metrics-definitions
def
_unregister_vllm_metrics
(
self
)
->
None
:
for
collector
in
list
(
prometheus_client
.
REGISTRY
.
_collector_to_names
):
...
...
@@ -180,9 +199,6 @@ class Metrics:
prometheus_client
.
REGISTRY
.
unregister
(
collector
)
# end-metrics-definitions
class
_RayGaugeWrapper
:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
...
...
@@ -190,7 +206,9 @@ class _RayGaugeWrapper:
def
__init__
(
self
,
name
:
str
,
documentation
:
str
=
""
,
labelnames
:
Optional
[
List
[
str
]]
=
None
):
labelnames
:
Optional
[
List
[
str
]]
=
None
,
multiprocess_mode
:
str
=
""
):
del
multiprocess_mode
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
self
.
_gauge
=
ray_metrics
.
Gauge
(
name
=
name
,
description
=
documentation
,
...
...
@@ -268,10 +286,6 @@ class RayMetrics(Metrics):
# No-op on purpose
pass
def
_create_info_cache_config
(
self
)
->
None
:
# No-op on purpose
pass
def
build_1_2_5_buckets
(
max_value
:
int
)
->
List
[
int
]:
"""
...
...
@@ -295,46 +309,6 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
exponent
+=
1
@
dataclass
class
Stats
:
"""Created by LLMEngine for use by StatLogger."""
now
:
float
# System stats (should have _sys suffix)
# Scheduler State
num_running_sys
:
int
num_waiting_sys
:
int
num_swapped_sys
:
int
# KV Cache Usage in %
gpu_cache_usage_sys
:
float
cpu_cache_usage_sys
:
float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter
:
int
num_generation_tokens_iter
:
int
time_to_first_tokens_iter
:
List
[
float
]
time_per_output_tokens_iter
:
List
[
float
]
num_preemption_iter
:
int
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests
:
List
[
float
]
# Metadata
num_prompt_tokens_requests
:
List
[
int
]
num_generation_tokens_requests
:
List
[
int
]
best_of_requests
:
List
[
int
]
n_requests
:
List
[
int
]
finished_reason_requests
:
List
[
str
]
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
class
SupportsMetricsInfo
(
Protocol
):
def
metrics_info
(
self
)
->
Dict
[
str
,
str
]:
...
def
local_interval_elapsed
(
now
:
float
,
last_log
:
float
,
local_interval
:
float
)
->
bool
:
elapsed_time
=
now
-
last_log
...
...
@@ -346,38 +320,9 @@ def get_throughput(tracked_stats: List[int], now: float,
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
last_log
))
class
StatLoggerBase
(
ABC
):
"""Base class for StatLogger."""
def
__init__
(
self
,
local_interval
:
float
)
->
None
:
# Tracked stats over current local logging interval.
self
.
num_prompt_tokens
:
List
[
int
]
=
[]
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
self
.
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
@
abstractmethod
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
log
(
self
,
stats
:
Stats
)
->
None
:
raise
NotImplementedError
def
maybe_update_spec_decode_metrics
(
self
,
stats
:
Stats
):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if
stats
.
spec_decode_metrics
is
not
None
:
self
.
spec_decode_metrics
=
stats
.
spec_decode_metrics
class
LoggingStatLogger
(
StatLoggerBase
):
"""LoggingStatLogger is used in LLMEngine to log to Stdout."""
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
def
log
(
self
,
stats
:
Stats
)
->
None
:
"""Called by LLMEngine.
Logs to Stdout every self.local_interval seconds."""
...
...
@@ -417,7 +362,13 @@ class LoggingStatLogger(StatLoggerBase):
stats
.
gpu_cache_usage_sys
*
100
,
stats
.
cpu_cache_usage_sys
*
100
,
)
if
(
stats
.
cpu_prefix_cache_hit_rate
>=
0
or
stats
.
gpu_prefix_cache_hit_rate
>=
0
):
logger
.
info
(
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%"
,
stats
.
gpu_prefix_cache_hit_rate
*
100
,
stats
.
cpu_prefix_cache_hit_rate
*
100
,
)
if
self
.
spec_decode_metrics
is
not
None
:
logger
.
info
(
self
.
_format_spec_decode_metrics_str
(
...
...
@@ -440,10 +391,14 @@ class LoggingStatLogger(StatLoggerBase):
f
"Number of draft tokens:
{
metrics
.
draft_tokens
}
, "
f
"Number of emitted tokens:
{
metrics
.
emitted_tokens
}
."
)
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
class
PrometheusStatLogger
(
StatLoggerBase
):
"""PrometheusStatLogger is used LLMEngine to log to Promethus."""
_metrics_cls
=
Metrics
_gauge_cls
=
prometheus_client
.
Gauge
def
__init__
(
self
,
local_interval
:
float
,
labels
:
Dict
[
str
,
str
],
max_model_len
:
int
)
->
None
:
...
...
@@ -453,10 +408,6 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
metrics
=
self
.
_metrics_cls
(
labelnames
=
list
(
labels
.
keys
()),
max_model_len
=
max_model_len
)
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
if
type
==
"cache_config"
:
self
.
metrics
.
info_cache_config
.
info
(
obj
.
metrics_info
())
def
_log_gauge
(
self
,
gauge
,
data
:
Union
[
int
,
float
])
->
None
:
# Convenience function for logging to gauge.
gauge
.
labels
(
**
self
.
labels
).
set
(
data
)
...
...
@@ -489,6 +440,10 @@ class PrometheusStatLogger(StatLoggerBase):
stats
.
gpu_cache_usage_sys
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_cpu_cache_usage
,
stats
.
cpu_cache_usage_sys
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_cpu_prefix_cache_hit_rate
,
stats
.
cpu_prefix_cache_hit_rate
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_gpu_prefix_cache_hit_rate
,
stats
.
gpu_prefix_cache_hit_rate
)
# Iteration level data
self
.
_log_counter
(
self
.
metrics
.
counter_num_preemption
,
...
...
@@ -586,6 +541,19 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
last_local_log
=
stats
.
now
self
.
spec_decode_metrics
=
None
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
# Info type metrics are syntactic sugar for a gauge permanently set to 1
# Since prometheus multiprocessing mode does not support Info, emulate
# info here with a gauge.
if
type
==
"cache_config"
:
metrics_info
=
obj
.
metrics_info
()
info_gauge
=
self
.
_gauge_cls
(
name
=
"vllm:cache_config_info"
,
documentation
=
"Information of the LLMEngine CacheConfig"
,
labelnames
=
metrics_info
.
keys
(),
multiprocess_mode
=
"mostrecent"
)
info_gauge
.
labels
(
**
metrics_info
).
set
(
1
)
class
RayPrometheusStatLogger
(
PrometheusStatLogger
):
"""RayPrometheusStatLogger uses Ray metrics instead."""
...
...
vllm/engine/metrics_types.py
0 → 100644
View file @
af7f4372
"""
These types are defined in this file to avoid importing vllm.engine.metrics
and therefore importing prometheus_client.
This is required due to usage of Prometheus multiprocess mode to enable
metrics after splitting out the uvicorn process from the engine process.
Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR
before prometheus_client is imported. Typically, this is done by setting
the env variable before launch, but since we are a library, we need to
do this in Python code and lazily import prometheus_client.
"""
import
time
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Protocol
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
@
dataclass
class
Stats
:
"""Created by LLMEngine for use by StatLogger."""
now
:
float
# System stats (should have _sys suffix)
# Scheduler State
num_running_sys
:
int
num_waiting_sys
:
int
num_swapped_sys
:
int
# KV Cache Usage in %
gpu_cache_usage_sys
:
float
cpu_cache_usage_sys
:
float
# Prefix caching block hit rate
cpu_prefix_cache_hit_rate
:
float
gpu_prefix_cache_hit_rate
:
float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter
:
int
num_generation_tokens_iter
:
int
time_to_first_tokens_iter
:
List
[
float
]
time_per_output_tokens_iter
:
List
[
float
]
num_preemption_iter
:
int
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests
:
List
[
float
]
# Metadata
num_prompt_tokens_requests
:
List
[
int
]
num_generation_tokens_requests
:
List
[
int
]
best_of_requests
:
List
[
int
]
n_requests
:
List
[
int
]
finished_reason_requests
:
List
[
str
]
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
class
SupportsMetricsInfo
(
Protocol
):
def
metrics_info
(
self
)
->
Dict
[
str
,
str
]:
...
class
StatLoggerBase
(
ABC
):
"""Base class for StatLogger."""
def
__init__
(
self
,
local_interval
:
float
)
->
None
:
# Tracked stats over current local logging interval.
self
.
num_prompt_tokens
:
List
[
int
]
=
[]
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
self
.
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
@
abstractmethod
def
log
(
self
,
stats
:
Stats
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
info
(
self
,
type
:
str
,
obj
:
SupportsMetricsInfo
)
->
None
:
raise
NotImplementedError
def
maybe_update_spec_decode_metrics
(
self
,
stats
:
Stats
):
"""Save spec decode metrics (since they are unlikely
to be emitted at same time as log interval)."""
if
stats
.
spec_decode_metrics
is
not
None
:
self
.
spec_decode_metrics
=
stats
.
spec_decode_metrics
vllm/engine/output_processor/interfaces.py
View file @
af7f4372
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
,
List
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
SchedulerConfig
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceGroupOutput
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Counter
...
...
@@ -29,7 +28,7 @@ class SequenceGroupOutputProcessor(ABC):
detokenizer
:
Detokenizer
,
scheduler
:
List
[
Scheduler
],
seq_counter
:
Counter
,
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
PreTrained
Tokenizer
],
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
Any
Tokenizer
],
stop_checker
:
"StopChecker"
,
):
"""Create an output processor.
...
...
vllm/engine/output_processor/multi_step.py
View file @
af7f4372
import
functools
from
typing
import
Callable
,
List
from
transformers
import
PreTrainedTokenizer
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
SequenceGroupOutputProcessor
)
...
...
@@ -12,6 +10,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Counter
logger
=
init_logger
(
__name__
)
...
...
@@ -36,7 +35,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
detokenizer
:
Detokenizer
,
scheduler
:
List
[
Scheduler
],
seq_counter
:
Counter
,
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
PreTrained
Tokenizer
],
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
Any
Tokenizer
],
stop_checker
:
StopChecker
,
):
self
.
detokenizer
=
detokenizer
...
...
vllm/engine/output_processor/stop_checker.py
View file @
af7f4372
from
typing
import
Callable
,
Optional
from
transformers
import
PreTrainedTokenizer
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
Sequence
,
SequenceStatus
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
class
StopChecker
:
...
...
@@ -15,8 +14,7 @@ class StopChecker:
"""
def
__init__
(
self
,
max_model_len
:
int
,
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
PreTrainedTokenizer
]):
get_tokenizer_for_seq
:
Callable
[[
Sequence
],
AnyTokenizer
]):
# Do not use it directly, but use `self._get_max_model_len`.
self
.
_max_model_len
=
max_model_len
self
.
get_tokenizer_for_seq
=
get_tokenizer_for_seq
...
...
vllm/engine/protocol.py
View file @
af7f4372
from
typing
import
(
Async
It
erator
,
List
,
Mapping
,
Optional
,
Protocol
,
from
typing
import
(
Async
Gen
erator
,
List
,
Mapping
,
Optional
,
Protocol
,
runtime_checkable
)
from
transformers
import
PreTrainedTokenizer
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
PromptInputs
...
...
@@ -12,6 +10,7 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SamplerOutput
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
@
runtime_checkable
...
...
@@ -30,7 +29,11 @@ class AsyncEngineClient(Protocol):
def
errored
(
self
)
->
bool
:
...
async
def
generate
(
@
property
def
limit_concurrency
(
self
)
->
Optional
[
int
]:
"""Maximum number of concurrently running requests."""
def
generate
(
self
,
inputs
:
PromptInputs
,
sampling_params
:
SamplingParams
,
...
...
@@ -38,18 +41,20 @@ class AsyncEngineClient(Protocol):
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Async
It
erator
[
RequestOutput
]:
)
->
Async
Gen
erator
[
RequestOutput
,
None
]:
"""Generates outputs for a request"""
...
async
def
encode
(
def
encode
(
self
,
inputs
:
PromptInputs
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
)
->
Async
It
erator
[
EmbeddingRequestOutput
]:
)
->
Async
Gen
erator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model."""
...
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
...
...
@@ -60,25 +65,37 @@ class AsyncEngineClient(Protocol):
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
...
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
...
"""Get the decoding configuration of the vLLM engine."""
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
)
->
PreTrainedTokenizer
:
"""Get the appropriate Tokenizer for the request"""
)
->
AnyTokenizer
:
"""Get the appropriate tokenizer for the request"""
...
async
def
is_tracing_enabled
(
self
)
->
bool
:
pass
...
async
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
,
)
->
None
:
pass
...
async
def
check_health
(
self
)
->
None
:
"""Raise if unhealthy"""
...
async
def
start_profile
(
self
)
->
None
:
"""Start profiling the engine"""
...
async
def
stop_profile
(
self
)
->
None
:
"""Start profiling the engine"""
...
vllm/entrypoints/api_server.py
View file @
af7f4372
...
...
@@ -20,7 +20,8 @@ from vllm.entrypoints.launcher import serve_http
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
from
vllm.utils
import
(
FlexibleArgumentParser
,
iterate_with_cancellation
,
random_uuid
)
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
"vllm.entrypoints.api_server"
)
...
...
@@ -53,11 +54,14 @@ async def generate(request: Request) -> Response:
assert
engine
is
not
None
results_generator
=
engine
.
generate
(
prompt
,
sampling_params
,
request_id
)
results_generator
=
iterate_with_cancellation
(
results_generator
,
is_cancelled
=
request
.
is_disconnected
)
# Streaming case
async
def
stream_results
()
->
AsyncGenerator
[
bytes
,
None
]:
async
for
request_output
in
results_generator
:
prompt
=
request_output
.
prompt
assert
prompt
is
not
None
text_outputs
=
[
prompt
+
output
.
text
for
output
in
request_output
.
outputs
]
...
...
@@ -69,15 +73,15 @@ async def generate(request: Request) -> Response:
# Non-streaming case
final_output
=
None
try
:
async
for
request_output
in
results_generator
:
if
await
request
.
is_disconnected
():
# Abort the request if the client disconnects.
await
engine
.
abort
(
request_id
)
return
Response
(
status_code
=
499
)
final_output
=
request_output
except
asyncio
.
CancelledError
:
return
Response
(
status_code
=
499
)
assert
final_output
is
not
None
prompt
=
final_output
.
prompt
assert
prompt
is
not
None
text_outputs
=
[
prompt
+
output
.
text
for
output
in
final_output
.
outputs
]
ret
=
{
"text"
:
text_outputs
}
return
JSONResponse
(
ret
)
...
...
@@ -113,9 +117,11 @@ async def run_server(args: Namespace,
logger
.
info
(
"args: %s"
,
args
)
app
=
await
init_app
(
args
,
llm_engine
)
assert
engine
is
not
None
shutdown_task
=
await
serve_http
(
app
,
engine
=
engine
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
log_level
,
...
...
vllm/entrypoints/chat_utils.py
View file @
af7f4372
import
codecs
from
dataclasses
import
dataclass
from
functools
import
lru_cache
from
typing
import
(
Awaitable
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
,
cast
,
final
)
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
Union
)
# yapf conflicts with isort for this block
# yapf: disable
...
...
@@ -14,18 +15,33 @@ from openai.types.chat import (
ChatCompletionMessageParam
as
OpenAIChatCompletionMessageParam
)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from
pydantic
import
ConfigDict
from
transformers
import
PreTrainedTokenizer
from
typing_extensions
import
Required
,
TypedDict
from
pydantic
import
ConfigDict
,
TypeAdapter
from
typing_extensions
import
Required
,
TypeAlias
,
TypedDict
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.utils
import
async_get_and_parse_image
from
vllm.multimodal.utils
import
(
async_get_and_parse_audio
,
async_get_and_parse_image
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
logger
=
init_logger
(
__name__
)
class
AudioURL
(
TypedDict
,
total
=
False
):
url
:
Required
[
str
]
"""
Either a URL of the audio or a data URL with base64 encoded audio data.
"""
class
ChatCompletionContentPartAudioParam
(
TypedDict
,
total
=
False
):
audio_url
:
Required
[
AudioURL
]
type
:
Required
[
Literal
[
"audio_url"
]]
"""The type of the content part."""
class
CustomChatCompletionContentPartParam
(
TypedDict
,
total
=
False
):
__pydantic_config__
=
ConfigDict
(
extra
=
"allow"
)
# type: ignore
...
...
@@ -33,8 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""
ChatCompletionContentPartParam
=
Union
[
OpenAIChatCompletionContentPartParam
,
CustomChatCompletionContentPartParam
]
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
CustomChatCompletionContentPartParam
,
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
...
...
@@ -57,7 +74,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
CustomChatCompletionMessageParam
]
@
final
# So that it should be compatible with Dict[str, str]
# TODO: Make fields ReadOnly once mypy supports it
class
ConversationMessage
(
TypedDict
):
role
:
str
content
:
str
...
...
@@ -69,13 +86,17 @@ class ChatMessageParseResult:
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
def
load_chat_template
(
chat_template
:
Optional
[
str
])
->
Optional
[
str
]:
def
load_chat_template
(
chat_template
:
Optional
[
Union
[
Path
,
str
]])
->
Optional
[
str
]:
if
chat_template
is
None
:
return
None
try
:
with
open
(
chat_template
,
"r"
)
as
f
:
resolved_chat_template
=
f
.
read
()
except
OSError
as
e
:
if
isinstance
(
chat_template
,
Path
):
raise
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
...
...
@@ -92,11 +113,12 @@ def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
@
lru_cache
(
maxsize
=
None
)
def
_
image
_token_str
(
model_config
:
ModelConfig
,
tokenizer
:
PreTrainedTokenizer
)
->
Optional
[
str
]:
def
_
mm
_token_str
(
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
modality
:
Literal
[
"image"
,
"audio"
]
)
->
Optional
[
str
]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type
=
model_config
.
hf_config
.
model_type
if
modality
==
"image"
:
if
model_type
==
"phi3_v"
:
# Workaround since this token is not defined in the tokenizer
return
"<|image_1|>"
...
...
@@ -109,40 +131,54 @@ def _image_token_str(model_config: ModelConfig,
return
tokenizer
.
decode
(
model_config
.
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
return
"<image>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
elif
modality
==
"audio"
:
if
model_type
==
"ultravox"
:
return
"<|reserved_special_token_0|>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
else
:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
# TODO: Let user specify how to insert
image
tokens into prompt
# TODO: Let user specify how to insert
multimodal
tokens into prompt
# (similar to chat template)
def
_get_full_image_text_prompt
(
image_token_str
:
str
,
text_prompt
:
str
)
->
str
:
"""Combine image and text prompts for vision language model"""
def
_get_full_multimodal_text_prompt
(
placeholder_token_str
:
str
,
text_prompt
:
str
)
->
str
:
"""Combine multimodal prompts for a multimodal language model"""
# NOTE: For now we assume all model architectures use the same
# image + text prompt format. This may change in the future.
return
f
"
{
image_token_str
}
\n
{
text_prompt
}
"
# placeholder + text prompt format. This may change in the future.
return
f
"
{
placeholder_token_str
}
\n
{
text_prompt
}
"
_TextParser
=
TypeAdapter
(
ChatCompletionContentPartTextParam
)
_ImageParser
=
TypeAdapter
(
ChatCompletionContentPartImageParam
)
_AudioParser
=
TypeAdapter
(
ChatCompletionContentPartAudioParam
)
def
_parse_chat_message_content_parts
(
role
:
str
,
parts
:
Iterable
[
ChatCompletionContentPartParam
],
model_config
:
ModelConfig
,
tokenizer
:
PreTrained
Tokenizer
,
tokenizer
:
Any
Tokenizer
,
)
->
ChatMessageParseResult
:
texts
:
List
[
str
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
modality
:
Literal
[
"image"
,
"audio"
]
=
"image"
for
part
in
parts
:
part_type
=
part
[
"type"
]
if
part_type
==
"text"
:
text
=
cast
(
ChatCompletionContentPartTextParam
,
part
)[
"text"
]
text
=
_TextParser
.
validate_python
(
part
)[
"text"
]
texts
.
append
(
text
)
elif
part_type
==
"image_url"
:
modality
=
"image"
if
len
(
mm_futures
)
>
0
:
raise
NotImplementedError
(
"Multiple
'image_url'
input is currently not supported."
)
"Multiple
multimodal
input
s
is currently not supported."
)
image_url
=
cast
(
ChatCompletionContentPartImageParam
,
part
)[
"image_url"
]
image_url
=
_ImageParser
.
validate_python
(
part
)[
"image_url"
]
if
image_url
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
logger
.
warning
(
...
...
@@ -151,21 +187,31 @@ def _parse_chat_message_content_parts(
image_future
=
async_get_and_parse_image
(
image_url
[
"url"
])
mm_futures
.
append
(
image_future
)
elif
part_type
==
"audio_url"
:
modality
=
"audio"
if
len
(
mm_futures
)
>
0
:
raise
NotImplementedError
(
"Multiple multimodal inputs is currently not supported."
)
audio_url
=
_AudioParser
.
validate_python
(
part
)[
"audio_url"
]
audio_future
=
async_get_and_parse_audio
(
audio_url
[
"url"
])
mm_futures
.
append
(
audio_future
)
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
if
mm_futures
:
image_token_str
=
_image_token_str
(
model_config
,
tokenizer
)
if
image_token_str
is
not
None
:
if
image_token_str
in
text_prompt
:
placeholder_token_str
=
_mm_token_str
(
model_config
,
tokenizer
,
modality
)
if
placeholder_token_str
is
not
None
:
if
placeholder_token_str
in
text_prompt
:
logger
.
warning
(
"Detected
image
token string in the text prompt. "
"Detected
multi-modal
token string in the text prompt. "
"Skipping prompt formatting."
)
else
:
text_prompt
=
_get_full_
image
_text_prompt
(
image_token_str
=
image
_token_str
,
text_prompt
=
_get_full_
multimodal
_text_prompt
(
placeholder_token_str
=
placeholder
_token_str
,
text_prompt
=
text_prompt
,
)
...
...
@@ -177,7 +223,7 @@ def _parse_chat_message_content_parts(
def
_parse_chat_message_content
(
message
:
ChatCompletionMessageParam
,
model_config
:
ModelConfig
,
tokenizer
:
PreTrained
Tokenizer
,
tokenizer
:
Any
Tokenizer
,
)
->
ChatMessageParseResult
:
role
=
message
[
"role"
]
content
=
message
.
get
(
"content"
)
...
...
@@ -188,14 +234,18 @@ def _parse_chat_message_content(
messages
=
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
return
ChatMessageParseResult
(
messages
=
messages
,
mm_futures
=
[])
return
_parse_chat_message_content_parts
(
role
,
content
,
model_config
,
tokenizer
)
return
_parse_chat_message_content_parts
(
role
,
content
,
# type: ignore
model_config
,
tokenizer
,
)
def
parse_chat_messages
(
messages
:
List
[
ChatCompletionMessageParam
],
model_config
:
ModelConfig
,
tokenizer
:
PreTrained
Tokenizer
,
tokenizer
:
Any
Tokenizer
,
)
->
Tuple
[
List
[
ConversationMessage
],
List
[
Awaitable
[
MultiModalDataDict
]]]:
conversation
:
List
[
ConversationMessage
]
=
[]
mm_futures
:
List
[
Awaitable
[
MultiModalDataDict
]]
=
[]
...
...
@@ -208,3 +258,28 @@ def parse_chat_messages(
mm_futures
.
extend
(
parse_result
.
mm_futures
)
return
conversation
,
mm_futures
def
apply_chat_template
(
tokenizer
:
AnyTokenizer
,
conversation
:
List
[
ConversationMessage
],
chat_template
:
Optional
[
str
],
*
,
tokenize
:
bool
=
False
,
# Different from HF's default
**
kwargs
:
Any
,
)
->
str
:
if
chat_template
is
None
and
tokenizer
.
chat_template
is
None
:
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
prompt
=
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
chat_template
=
chat_template
,
tokenize
=
tokenize
,
**
kwargs
,
)
assert
isinstance
(
prompt
,
str
)
return
prompt
vllm/entrypoints/launcher.py
View file @
af7f4372
import
asyncio
import
signal
from
http
import
HTTPStatus
from
typing
import
Any
import
uvicorn
from
fastapi
import
FastAPI
from
fastapi
import
FastAPI
,
Response
from
vllm
import
envs
from
vllm.engine.async_llm_engine
import
AsyncEngineDeadError
from
vllm.engine.protocol
import
AsyncEngineClient
from
vllm.logger
import
init_logger
from
vllm.utils
import
find_process_using_port
logger
=
init_logger
(
__name__
)
async
def
serve_http
(
app
:
FastAPI
,
**
uvicorn_kwargs
:
Any
):
async
def
serve_http
(
app
:
FastAPI
,
engine
:
AsyncEngineClient
,
**
uvicorn_kwargs
:
Any
):
logger
.
info
(
"Available routes are:"
)
for
route
in
app
.
routes
:
methods
=
getattr
(
route
,
"methods"
,
None
)
...
...
@@ -21,8 +27,18 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
logger
.
info
(
"Route: %s, Methods: %s"
,
path
,
', '
.
join
(
methods
))
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if
engine
.
limit_concurrency
is
not
None
:
logger
.
info
(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing"
,
engine
.
limit_concurrency
)
uvicorn_kwargs
[
"limit_concurrency"
]
=
engine
.
limit_concurrency
config
=
uvicorn
.
Config
(
app
,
**
uvicorn_kwargs
)
server
=
uvicorn
.
Server
(
config
)
_add_shutdown_handlers
(
app
,
server
,
engine
)
loop
=
asyncio
.
get_running_loop
()
...
...
@@ -42,5 +58,45 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
await
server_task
return
dummy_shutdown
()
except
asyncio
.
CancelledError
:
port
=
uvicorn_kwargs
[
"port"
]
process
=
find_process_using_port
(
port
)
if
process
is
not
None
:
logger
.
debug
(
"port %s is used by process %s launched with command:
\n
%s"
,
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
logger
.
info
(
"Gracefully stopping http server"
)
return
server
.
shutdown
()
def
_add_shutdown_handlers
(
app
:
FastAPI
,
server
:
uvicorn
.
Server
,
engine
:
AsyncEngineClient
)
->
None
:
"""Adds handlers for fatal errors that should crash the server"""
@
app
.
exception_handler
(
RuntimeError
)
async
def
runtime_error_handler
(
_
,
__
):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
if
(
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
and
engine
.
errored
and
not
engine
.
is_running
):
logger
.
fatal
(
"AsyncLLMEngine has failed, terminating server "
"process"
)
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server
.
should_exit
=
True
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
@
app
.
exception_handler
(
AsyncEngineDeadError
)
async
def
engine_dead_handler
(
_
,
__
):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if
not
envs
.
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH
:
logger
.
fatal
(
"AsyncLLMEngine is already dead, terminating server "
"process"
)
server
.
should_exit
=
True
return
Response
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
)
vllm/entrypoints/llm.py
View file @
af7f4372
...
...
@@ -2,12 +2,14 @@ from contextlib import contextmanager
from
typing
import
ClassVar
,
List
,
Optional
,
Sequence
,
Union
,
cast
,
overload
from
tqdm
import
tqdm
from
transformers
import
PreTrainedTokenizer
,
PreTrainedTokenizerFast
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.inputs
import
(
PromptInputs
,
TextPrompt
,
TokensPrompt
,
parse_and_batch_prompt
)
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
apply_chat_template
,
parse_chat_messages
)
from
vllm.inputs
import
PromptInputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
...
...
@@ -17,7 +19,9 @@ from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer
import
get_cached_tokenizer
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
deprecate_kwargs
...
...
@@ -119,18 +123,31 @@ class LLM:
tokenizer_revision
:
Optional
[
str
]
=
None
,
seed
:
int
=
0
,
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
in
t
=
4
,
swap_space
:
floa
t
=
4
,
cpu_offload_gb
:
float
=
0
,
enforce_eager
:
bool
=
Fals
e
,
enforce_eager
:
Optional
[
bool
]
=
Non
e
,
max_context_len_to_capture
:
Optional
[
int
]
=
None
,
max_seq_len_to_capture
:
int
=
8192
,
disable_custom_all_reduce
:
bool
=
False
,
**
kwargs
,
)
->
None
:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False for decoder-only models and True
for encoder/decoder models, since encoder/decoder models
do not currently support CUDAGraph.
'''
if
"disable_log_stats"
not
in
kwargs
:
kwargs
[
"disable_log_stats"
]
=
True
removed_vision_keys
=
(
"image_token_id"
,
"image_feature_size"
,
"image_input_shape"
,
"image_input_type"
)
removed_vision_keys
=
(
"image_token_id"
,
"image_feature_size"
,
"image_input_shape"
,
"image_input_type"
,
)
if
any
(
k
in
kwargs
for
k
in
removed_vision_keys
):
raise
TypeError
(
"There is no need to pass vision-related arguments anymore."
)
...
...
@@ -159,22 +176,19 @@ class LLM:
engine_args
,
usage_context
=
UsageContext
.
LLM_CLASS
)
self
.
request_counter
=
Counter
()
def
get_tokenizer
(
self
)
->
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
]:
return
self
.
llm_engine
.
tokenizer
.
tokenizer
def
get_tokenizer
(
self
)
->
AnyTokenizer
:
return
self
.
llm_engine
.
get_tokenizer_group
(
TokenizerGroup
).
tokenizer
def
set_tokenizer
(
self
,
tokenizer
:
AnyTokenizer
)
->
None
:
tokenizer_group
=
self
.
llm_engine
.
get_tokenizer_group
(
TokenizerGroup
)
def
set_tokenizer
(
self
,
tokenizer
:
Union
[
PreTrainedTokenizer
,
PreTrainedTokenizerFast
],
)
->
None
:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if
tokenizer
.
__class__
.
__name__
.
startswith
(
"Cached"
):
self
.
llm_engine
.
tokenizer
.
tokenizer
=
tokenizer
tokenizer
_group
.
tokenizer
=
tokenizer
else
:
self
.
llm_engine
.
tokenizer
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
tokenizer_group
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
@
overload
# LEGACY: single (prompt + optional token ids)
def
generate
(
...
...
@@ -250,11 +264,12 @@ class LLM:
)
->
List
[
RequestOutput
]:
...
@
deprecate_kwargs
(
"prompts"
,
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the 'inputs' parameter
"
"instead."
)
additional_message
=
"Please use the 'inputs' parameter
instead."
,
)
def
generate
(
self
,
prompts
:
Union
[
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
...
...
@@ -287,7 +302,7 @@ class LLM:
generation, if any.
Returns:
A list of `RequestOutput` objects containing the
A list of
`
`RequestOutput`
`
objects containing the
generated completions in the same order as the input prompts.
Note:
...
...
@@ -297,8 +312,8 @@ class LLM:
"""
if
self
.
llm_engine
.
model_config
.
embedding_mode
:
raise
ValueError
(
"LLM.generate() is only supported for generation
models
"
"(XForCausalLM)."
)
"LLM.generate() is only supported for
(conditional)
generation "
"
models
(XForCausalLM
, XForConditionalGeneration
)."
)
if
prompt_token_ids
is
not
None
:
inputs
=
self
.
_convert_v1_inputs
(
...
...
@@ -330,6 +345,62 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
LLMEngine
.
validate_outputs
(
outputs
,
RequestOutput
)
def
chat
(
self
,
messages
:
List
[
ChatCompletionMessageParam
],
sampling_params
:
Optional
[
Union
[
SamplingParams
,
List
[
SamplingParams
]]]
=
None
,
use_tqdm
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
add_generation_prompt
:
bool
=
True
,
)
->
List
[
RequestOutput
]:
"""
Generates responses for chat messages.
Converts the messages to prompts using the tokenizer and calls
the :meth:`generate` method to generate the responses.
Args:
messages: A list of messages to generate responses for. Each
message is a list of dictionaries with 'role' and 'content'
keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
add_generation_prompt: If True, adds a generation template
to each message.
Returns:
A list of ``RequestOutput`` objects containing the generated
responses in the same order as the input messages.
"""
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
conversations
,
_
=
parse_chat_messages
(
messages
,
model_config
,
tokenizer
)
prompts
=
apply_chat_template
(
tokenizer
,
conversations
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
)
return
self
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
use_tqdm
,
lora_request
=
lora_request
,
)
@
overload
# LEGACY: single (prompt + optional token ids)
def
encode
(
self
,
...
...
@@ -404,11 +475,12 @@ class LLM:
)
->
List
[
EmbeddingRequestOutput
]:
...
@
deprecate_kwargs
(
"prompts"
,
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the 'inputs' parameter
"
"instead."
)
additional_message
=
"Please use the 'inputs' parameter
instead."
,
)
def
encode
(
self
,
prompts
:
Union
[
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
...
...
@@ -504,6 +576,8 @@ class LLM:
inputs
:
List
[
PromptInputs
]
=
[]
for
i
in
range
(
num_requests
):
item
:
PromptInputs
if
prompts
is
not
None
:
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
elif
prompt_token_ids
is
not
None
:
...
...
@@ -554,15 +628,15 @@ class LLM:
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
)
def
_add_request
(
self
,
inputs
:
PromptInputs
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
Union
[
List
[
LoRARequest
],
LoRARequest
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
...
...
@@ -570,7 +644,8 @@ class LLM:
inputs
,
params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
)
def
_add_guided_processor
(
self
,
...
...
@@ -619,8 +694,8 @@ class LLM:
in_spd
=
total_in_toks
/
pbar
.
format_dict
[
"elapsed"
]
total_out_toks
+=
sum
(
len
(
stp
.
token_ids
)
for
stp
in
output
.
outputs
)
out_spd
=
total_out_toks
/
pbar
.
format_dict
[
"elapsed"
]
out_spd
=
(
total_out_toks
/
pbar
.
format_dict
[
"elapsed"
]
)
pbar
.
postfix
=
(
f
"est. speed input:
{
in_spd
:.
2
f
}
toks/s, "
f
"output:
{
out_spd
:.
2
f
}
toks/s"
)
...
...
@@ -631,3 +706,9 @@ class LLM:
# This is necessary because some requests may be finished earlier than
# its previous requests.
return
sorted
(
outputs
,
key
=
lambda
x
:
int
(
x
.
request_id
))
def
_is_encoder_decoder_model
(
self
):
return
self
.
llm_engine
.
is_encoder_decoder_model
()
def
_is_embedding_model
(
self
):
return
self
.
llm_engine
.
is_embedding_model
()
vllm/entrypoints/openai/api_server.py
View file @
af7f4372
import
asyncio
import
importlib
import
inspect
import
multiprocessing
import
os
import
re
import
tempfile
from
argparse
import
Namespace
from
contextlib
import
asynccontextmanager
from
http
import
HTTPStatus
from
multiprocessing
import
Process
from
typing
import
AsyncIterator
,
Set
from
typing
import
AsyncIterator
,
Optional
,
Set
from
fastapi
import
APIRouter
,
FastAPI
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
prometheus_client
import
make_asgi_app
from
starlette.routing
import
Mount
from
typing_extensions
import
assert_never
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
...
...
@@ -28,14 +30,16 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
ChatCompletionResponse
,
CompletionRequest
,
CompletionResponse
,
DetokenizeRequest
,
DetokenizeResponse
,
EmbeddingRequest
,
ErrorResponse
,
EmbeddingRequest
,
EmbeddingResponse
,
ErrorResponse
,
TokenizeRequest
,
TokenizeResponse
)
# yapf: enable
from
vllm.entrypoints.openai.rpc.client
import
AsyncEngineRPCClient
from
vllm.entrypoints.openai.rpc.server
import
run_rpc_server
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_completion
import
OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
...
...
@@ -43,7 +47,7 @@ from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization
)
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_
port
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_
zmq_ipc_path
from
vllm.version
import
__version__
as
VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
...
...
@@ -54,19 +58,23 @@ openai_serving_chat: OpenAIServingChat
openai_serving_completion
:
OpenAIServingCompletion
openai_serving_embedding
:
OpenAIServingEmbedding
openai_serving_tokenization
:
OpenAIServingTokenization
prometheus_multiproc_dir
:
tempfile
.
TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger
=
init_logger
(
'vllm.entrypoints.openai.api_server'
)
_running_tasks
:
Set
[
asyncio
.
Task
]
=
set
()
def
model_is_embedding
(
model_name
:
str
,
trust_remote_code
:
bool
)
->
bool
:
def
model_is_embedding
(
model_name
:
str
,
trust_remote_code
:
bool
,
quantization
:
str
)
->
bool
:
return
ModelConfig
(
model
=
model_name
,
tokenizer
=
model_name
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
trust_remote_code
,
quantization
=
quantization
,
seed
=
0
,
dtype
=
"
float16
"
).
embedding_mode
dtype
=
"
auto
"
).
embedding_mode
@
asynccontextmanager
...
...
@@ -86,7 +94,16 @@ async def lifespan(app: FastAPI):
@
asynccontextmanager
async
def
build_async_engine_client
(
args
)
->
AsyncIterator
[
AsyncEngineClient
]:
async
def
build_async_engine_client
(
args
:
Namespace
)
->
AsyncIterator
[
Optional
[
AsyncEngineClient
]]:
"""
Create AsyncEngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global
engine_args
...
...
@@ -97,7 +114,8 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if
(
model_is_embedding
(
args
.
model
,
args
.
trust_remote_code
)
if
(
model_is_embedding
(
args
.
model
,
args
.
trust_remote_code
,
args
.
quantization
)
or
args
.
disable_frontend_multiprocessing
):
async_engine_client
=
AsyncLLMEngine
.
from_engine_args
(
engine_args
,
usage_context
=
UsageContext
.
OPENAI_API_SERVER
)
...
...
@@ -106,37 +124,99 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
# Otherwise, use the multiprocessing AsyncLLMEngine.
else
:
# Start RPCServer in separate process (holds the AsyncLLMEngine).
port
=
get_open_port
(
envs
.
VLLM_RPC_PORT
)
rpc_server_process
=
Process
(
target
=
run_rpc_server
,
args
=
(
engine_args
,
UsageContext
.
OPENAI_API_SERVER
,
port
))
rpc_server_process
.
start
()
if
"PROMETHEUS_MULTIPROC_DIR"
not
in
os
.
environ
:
# Make TemporaryDirectory for prometheus multiprocessing
# Note: global TemporaryDirectory will be automatically
# cleaned up upon exit.
global
prometheus_multiproc_dir
prometheus_multiproc_dir
=
tempfile
.
TemporaryDirectory
()
os
.
environ
[
"PROMETHEUS_MULTIPROC_DIR"
]
=
prometheus_multiproc_dir
.
name
else
:
logger
.
warning
(
"Found PROMETHEUS_MULTIPROC_DIR was set by user. "
"This directory must be wiped between vLLM runs or "
"you will find inaccurate metrics. Unset the variable "
"and vLLM will properly handle cleanup."
)
# Select random path for IPC.
rpc_path
=
get_open_zmq_ipc_path
()
logger
.
info
(
"Multiprocessing frontend to use %s for RPC Path."
,
rpc_path
)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
async_engine_client
=
AsyncEngineRPCClient
(
port
)
await
async_engine_client
.
setup
()
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client
=
AsyncEngineRPCClient
(
rpc_path
)
async_engine_client
=
rpc_client
# type: ignore
# Start RPCServer in separate process (holds the AsyncLLMEngine).
context
=
multiprocessing
.
get_context
(
"spawn"
)
# the current process might have CUDA context,
# so we need to spawn a new process
rpc_server_process
=
context
.
Process
(
target
=
run_rpc_server
,
args
=
(
engine_args
,
UsageContext
.
OPENAI_API_SERVER
,
rpc_path
))
rpc_server_process
.
start
()
logger
.
info
(
"Started engine process with PID %d"
,
rpc_server_process
.
pid
)
try
:
while
True
:
try
:
await
rpc_client
.
setup
()
break
except
TimeoutError
:
if
not
rpc_server_process
.
is_alive
():
logger
.
error
(
"RPCServer process died before responding "
"to readiness probe"
)
yield
None
return
yield
async_engine_client
finally
:
# Ensure rpc server process was terminated
rpc_server_process
.
terminate
()
# Close all open connections to the backend
async_engine
_client
.
close
()
rpc
_client
.
close
()
# Wait for server process to join
rpc_server_process
.
join
()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
prometheus_client
import
multiprocess
multiprocess
.
mark_process_dead
(
rpc_server_process
.
pid
)
router
=
APIRouter
()
def
mount_metrics
(
app
:
FastAPI
):
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from
prometheus_client
import
(
CollectorRegistry
,
make_asgi_app
,
multiprocess
)
prometheus_multiproc_dir_path
=
os
.
getenv
(
"PROMETHEUS_MULTIPROC_DIR"
,
None
)
if
prometheus_multiproc_dir_path
is
not
None
:
logger
.
info
(
"vLLM to use %s as PROMETHEUS_MULTIPROC_DIR"
,
prometheus_multiproc_dir_path
)
registry
=
CollectorRegistry
()
multiprocess
.
MultiProcessCollector
(
registry
)
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
(
registry
=
registry
))
else
:
# Add prometheus asgi middleware to route /metrics requests
metrics_route
=
Mount
(
"/metrics"
,
make_asgi_app
())
# Workaround for 307 Redirect for /metrics
metrics_route
.
path_regex
=
re
.
compile
(
'^/metrics(?P<path>.*)$'
)
app
.
routes
.
append
(
metrics_route
)
...
...
@@ -155,10 +235,11 @@ async def tokenize(request: TokenizeRequest):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
else
:
assert
isinstance
(
generator
,
TokenizeResponse
)
elif
isinstance
(
generator
,
TokenizeResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
@
router
.
post
(
"/detokenize"
)
async
def
detokenize
(
request
:
DetokenizeRequest
):
...
...
@@ -166,10 +247,11 @@ async def detokenize(request: DetokenizeRequest):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
else
:
assert
isinstance
(
generator
,
DetokenizeResponse
)
elif
isinstance
(
generator
,
DetokenizeResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
@
router
.
get
(
"/v1/models"
)
async
def
show_available_models
():
...
...
@@ -191,13 +273,11 @@ async def create_chat_completion(request: ChatCompletionRequest,
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
if
request
.
stream
:
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
else
:
assert
isinstance
(
generator
,
ChatCompletionResponse
)
elif
isinstance
(
generator
,
ChatCompletionResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
@
router
.
post
(
"/v1/completions"
)
async
def
create_completion
(
request
:
CompletionRequest
,
raw_request
:
Request
):
...
...
@@ -206,12 +286,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
if
request
.
stream
:
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
else
:
elif
isinstance
(
generator
,
CompletionResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
())
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
@
router
.
post
(
"/v1/embeddings"
)
async
def
create_embedding
(
request
:
EmbeddingRequest
,
raw_request
:
Request
):
...
...
@@ -220,9 +299,31 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
if
isinstance
(
generator
,
ErrorResponse
):
return
JSONResponse
(
content
=
generator
.
model_dump
(),
status_code
=
generator
.
code
)
else
:
el
if
isinstance
(
generator
,
EmbeddingRespon
se
)
:
return
JSONResponse
(
content
=
generator
.
model_dump
())
assert_never
(
generator
)
if
envs
.
VLLM_TORCH_PROFILER_DIR
:
logger
.
warning
(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
@
router
.
post
(
"/start_profile"
)
async
def
start_profile
():
logger
.
info
(
"Starting profiler..."
)
await
async_engine_client
.
start_profile
()
logger
.
info
(
"Profiler started."
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/stop_profile"
)
async
def
stop_profile
():
logger
.
info
(
"Stopping profiler..."
)
await
async_engine_client
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
return
Response
(
status_code
=
200
)
def
build_app
(
args
:
Namespace
)
->
FastAPI
:
app
=
FastAPI
(
lifespan
=
lifespan
)
...
...
@@ -340,10 +441,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger
.
info
(
"args: %s"
,
args
)
async
with
build_async_engine_client
(
args
)
as
async_engine_client
:
# If None, creation of the client failed and we exit.
if
async_engine_client
is
None
:
return
app
=
await
init_app
(
async_engine_client
,
args
)
shutdown_task
=
await
serve_http
(
app
,
engine
=
async_engine_client
,
host
=
args
.
host
,
port
=
args
.
port
,
log_level
=
args
.
uvicorn_log_level
,
...
...
vllm/entrypoints/openai/cli_args.py
View file @
af7f4372
...
...
@@ -7,6 +7,7 @@ purposes.
import
argparse
import
json
import
ssl
from
typing
import
List
,
Optional
,
Sequence
,
Union
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
...
...
@@ -16,8 +17,19 @@ from vllm.utils import FlexibleArgumentParser
class
LoRAParserAction
(
argparse
.
Action
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
lora_list
=
[]
def
__call__
(
self
,
parser
:
argparse
.
ArgumentParser
,
namespace
:
argparse
.
Namespace
,
values
:
Optional
[
Union
[
str
,
Sequence
[
str
]]],
option_string
:
Optional
[
str
]
=
None
,
):
if
values
is
None
:
values
=
[]
if
isinstance
(
values
,
str
):
raise
TypeError
(
"Expected values to be a list"
)
lora_list
:
List
[
LoRAModulePath
]
=
[]
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
lora_list
.
append
(
LoRAModulePath
(
name
,
path
))
...
...
@@ -26,8 +38,19 @@ class LoRAParserAction(argparse.Action):
class
PromptAdapterParserAction
(
argparse
.
Action
):
def
__call__
(
self
,
parser
,
namespace
,
values
,
option_string
=
None
):
adapter_list
=
[]
def
__call__
(
self
,
parser
:
argparse
.
ArgumentParser
,
namespace
:
argparse
.
Namespace
,
values
:
Optional
[
Union
[
str
,
Sequence
[
str
]]],
option_string
:
Optional
[
str
]
=
None
,
):
if
values
is
None
:
values
=
[]
if
isinstance
(
values
,
str
):
raise
TypeError
(
"Expected values to be a list"
)
adapter_list
:
List
[
PromptAdapterPath
]
=
[]
for
item
in
values
:
name
,
path
=
item
.
split
(
'='
)
adapter_list
.
append
(
PromptAdapterPath
(
name
,
path
))
...
...
vllm/entrypoints/openai/logits_processors.py
View file @
af7f4372
...
...
@@ -2,9 +2,9 @@ from functools import lru_cache, partial
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Union
import
torch
from
transformers
import
PreTrainedTokenizer
from
vllm.sampling_params
import
LogitsProcessor
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
class
AllowedTokenIdsLogitsProcessor
:
...
...
@@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
return
AllowedTokenIdsLogitsProcessor
(
allowed_token_ids
)
def
logit_bias_logits_processor
(
logit_bias
:
Dict
[
str
,
float
],
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
logit_bias_logits_processor
(
logit_bias
:
Dict
[
int
,
float
],
token_ids
:
List
[
int
],
logits
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
for
token_id
,
bias
in
logit_bias
.
items
():
logits
[
token_id
]
+=
bias
return
logits
...
...
@@ -51,8 +53,9 @@ def logit_bias_logits_processor(logit_bias: Dict[str,
def
get_logits_processors
(
logit_bias
:
Optional
[
Union
[
Dict
[
int
,
float
],
Dict
[
str
,
float
]]],
allowed_token_ids
:
Optional
[
List
[
int
]],
tokenizer
:
PreTrainedTokenizer
)
->
List
[
LogitsProcessor
]:
logits_processors
=
[]
tokenizer
:
AnyTokenizer
,
)
->
List
[
LogitsProcessor
]:
logits_processors
:
List
[
LogitsProcessor
]
=
[]
if
logit_bias
:
try
:
# Convert token_id to integer
...
...
@@ -69,7 +72,7 @@ def get_logits_processors(
# Check if token_id is within the vocab size
for
token_id
,
bias
in
clamped_logit_bias
.
items
():
if
token_id
<
0
or
token_id
>=
tokenizer
.
vocab_size
:
raise
ValueError
(
"token_id in logit_bias contains "
raise
ValueError
(
f
"token_id
{
token_id
}
in logit_bias contains "
"out-of-vocab token id"
)
logits_processors
.
append
(
...
...
vllm/entrypoints/openai/protocol.py
View file @
af7f4372
...
...
@@ -6,18 +6,20 @@ from typing import Any, Dict, List, Literal, Optional, Union
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
Field
,
model_validator
from
transformers
import
PreTrainedTokenizer
from
typing_extensions
import
Annotated
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
LogitsProcessor
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
random_uuid
# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO
=
Namespace
(
min
=-
9223372036854775808
,
max
=
9223372036854775807
)
_LONG_INFO
:
Union
[
"torch.iinfo"
,
Namespace
]
try
:
from
sphinx.ext.autodoc.mock
import
_MockModule
...
...
@@ -152,6 +154,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
prompt_logprobs
:
Optional
[
int
]
=
None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
...
...
@@ -190,8 +193,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
default
=
None
,
description
=
(
"A Jinja template to use for this conversion. "
"If this is not passed, the model's default chat template will be "
"used instead."
),
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
),
)
chat_template_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
Field
(
default
=
None
,
...
...
@@ -232,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params
def
to_sampling_params
(
self
,
tokenizer
:
PreTrained
Tokenizer
,
self
,
tokenizer
:
Any
Tokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
prompt_logprobs
=
self
.
prompt_logprobs
if
prompt_logprobs
is
None
and
self
.
echo
:
prompt_logprobs
=
self
.
top_logprobs
# We now allow logprobs being true without top_logrobs.
logits_processors
=
get_logits_processors
(
logit_bias
=
self
.
logit_bias
,
...
...
@@ -248,7 +256,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
if
guided_decode_logits_processor
:
logits_processors
.
append
(
guided_decode_logits_processor
)
return
SamplingParams
(
return
SamplingParams
.
from_optional
(
n
=
self
.
n
,
best_of
=
self
.
best_of
,
presence_penalty
=
self
.
presence_penalty
,
...
...
@@ -262,7 +270,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
stop
=
self
.
stop
,
stop_token_ids
=
self
.
stop_token_ids
,
logprobs
=
self
.
top_logprobs
if
self
.
logprobs
else
None
,
prompt_logprobs
=
self
.
top
_logprobs
if
self
.
echo
else
None
,
prompt_logprobs
=
prompt
_logprobs
,
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
max_tokens
,
min_tokens
=
self
.
min_tokens
,
...
...
@@ -276,14 +284,36 @@ class ChatCompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
)
@
model_validator
(
mode
=
'
before
'
)
@
model_validator
(
mode
=
"
before
"
)
@
classmethod
def
validate_stream_options
(
cls
,
values
):
if
(
values
.
get
(
'stream_options'
)
is
not
None
and
not
values
.
get
(
'stream'
)):
def
validate_stream_options
(
cls
,
data
):
if
data
.
get
(
"stream_options"
)
and
not
data
.
get
(
"stream"
):
raise
ValueError
(
"stream_options can only be set if stream is true"
)
return
values
"Stream options can only be defined when `stream=True`."
)
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_logprobs
(
cls
,
data
):
if
(
prompt_logprobs
:
=
data
.
get
(
"prompt_logprobs"
))
is
not
None
:
if
data
.
get
(
"stream"
)
and
prompt_logprobs
>
0
:
raise
ValueError
(
"`prompt_logprobs` are not available when `stream=True`."
)
if
prompt_logprobs
<
0
:
raise
ValueError
(
"`prompt_logprobs` must be a positive value."
)
if
(
top_logprobs
:
=
data
.
get
(
"top_logprobs"
))
is
not
None
:
if
top_logprobs
<
0
:
raise
ValueError
(
"`top_logprobs` must be a positive value."
)
if
not
data
.
get
(
"logprobs"
):
raise
ValueError
(
"when using `top_logprobs`, `logprobs` must be set to true."
)
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
...
...
@@ -316,19 +346,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
"When using `tool_choice`, `tools` must be set."
)
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_logprobs
(
cls
,
data
):
if
"top_logprobs"
in
data
and
data
[
"top_logprobs"
]
is
not
None
:
if
"logprobs"
not
in
data
or
data
[
"logprobs"
]
is
False
:
raise
ValueError
(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif
data
[
"top_logprobs"
]
<
0
:
raise
ValueError
(
"`top_logprobs` must be a value a positive value."
)
return
data
class
CompletionRequest
(
OpenAIBaseModel
):
# Ordered by official OpenAI API documentation
...
...
@@ -367,6 +384,7 @@ class CompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens
:
bool
=
True
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
allowed_token_ids
:
Optional
[
List
[
int
]]
=
None
prompt_logprobs
:
Optional
[
int
]
=
None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
...
...
@@ -417,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params
def
to_sampling_params
(
self
,
tokenizer
:
PreTrained
Tokenizer
,
self
,
tokenizer
:
Any
Tokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
prompt_logprobs
=
self
.
prompt_logprobs
if
prompt_logprobs
is
None
and
self
.
echo
:
prompt_logprobs
=
self
.
logprobs
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
logits_processors
=
get_logits_processors
(
...
...
@@ -434,7 +456,7 @@ class CompletionRequest(OpenAIBaseModel):
if
guided_decode_logits_processor
:
logits_processors
.
append
(
guided_decode_logits_processor
)
return
SamplingParams
(
return
SamplingParams
.
from_optional
(
n
=
self
.
n
,
best_of
=
self
.
best_of
,
presence_penalty
=
self
.
presence_penalty
,
...
...
@@ -453,7 +475,7 @@ class CompletionRequest(OpenAIBaseModel):
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
prompt_logprobs
=
self
.
logprobs
if
self
.
echo
else
None
,
prompt_logprobs
=
prompt_
logprobs
,
skip_special_tokens
=
self
.
skip_special_tokens
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
...
...
@@ -479,9 +501,17 @@ class CompletionRequest(OpenAIBaseModel):
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_logprobs
(
cls
,
data
):
if
"logprobs"
in
data
and
data
[
"logprobs"
]
is
not
None
and
not
data
[
"logprobs"
]
>=
0
:
raise
ValueError
(
"if passed, `logprobs` must be a positive value."
)
if
(
prompt_logprobs
:
=
data
.
get
(
"prompt_logprobs"
))
is
not
None
:
if
data
.
get
(
"stream"
)
and
prompt_logprobs
>
0
:
raise
ValueError
(
"`prompt_logprobs` are not available when `stream=True`."
)
if
prompt_logprobs
<
0
:
raise
ValueError
(
"`prompt_logprobs` must be a positive value."
)
if
(
logprobs
:
=
data
.
get
(
"logprobs"
))
is
not
None
and
logprobs
<
0
:
raise
ValueError
(
"`logprobs` must be a positive value."
)
return
data
@
model_validator
(
mode
=
"before"
)
...
...
@@ -489,7 +519,8 @@ class CompletionRequest(OpenAIBaseModel):
def
validate_stream_options
(
cls
,
data
):
if
data
.
get
(
"stream_options"
)
and
not
data
.
get
(
"stream"
):
raise
ValueError
(
"Stream options can only be defined when stream is true."
)
"Stream options can only be defined when `stream=True`."
)
return
data
...
...
@@ -498,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/embeddings
model
:
str
input
:
Union
[
List
[
int
],
List
[
List
[
int
]],
str
,
List
[
str
]]
encoding_format
:
Optional
[
str
]
=
Field
(
'
float
'
,
pattern
=
'^(float|base64)$'
)
encoding_format
:
Literal
[
"
float
"
,
"base64"
]
=
"float"
dimensions
:
Optional
[
int
]
=
None
user
:
Optional
[
str
]
=
None
...
...
@@ -531,6 +562,7 @@ class CompletionResponseChoice(OpenAIBaseModel):
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"
),
)
prompt_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
Logprob
]]]]
=
None
class
CompletionResponse
(
OpenAIBaseModel
):
...
...
@@ -626,6 +658,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
model
:
str
choices
:
List
[
ChatCompletionResponseChoice
]
usage
:
UsageInfo
prompt_logprobs
:
Optional
[
List
[
Optional
[
Dict
[
int
,
Logprob
]]]]
=
None
class
DeltaMessage
(
OpenAIBaseModel
):
...
...
@@ -671,7 +704,7 @@ class BatchRequestInput(OpenAIBaseModel):
url
:
str
# The parameters of the request.
body
:
ChatCompletionRequest
body
:
Union
[
ChatCompletionRequest
,
EmbeddingRequest
]
class
BatchResponseData
(
OpenAIBaseModel
):
...
...
@@ -682,7 +715,7 @@ class BatchResponseData(OpenAIBaseModel):
request_id
:
str
# The body of the response.
body
:
Optional
[
ChatCompletionResponse
]
=
None
body
:
Optional
[
Union
[
ChatCompletionResponse
,
EmbeddingResponse
]
]
=
None
class
BatchRequestOutput
(
OpenAIBaseModel
):
...
...
vllm/entrypoints/openai/rpc/__init__.py
View file @
af7f4372
...
...
@@ -7,8 +7,14 @@ from vllm.lora.request import LoRARequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
# Success string used for RPC instructions.
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
VLLM_RPC_HEALTHY_STR
=
"HEALTHY"
# Minimum value of ZMQ.SOCKET_LIMIT to run mp.
VLLM_RPC_SOCKET_LIMIT_CUTOFF
=
2000
# HWM is set to Infinity.
VLLM_RPC_ZMQ_HWM
=
0
@
dataclass
...
...
@@ -34,8 +40,10 @@ class RPCUtilityRequest(Enum):
GET_SCHEDULER_CONFIG
=
5
GET_LORA_CONFIG
=
6
DO_LOG_STATS
=
7
CHECK
_HEALTH
=
8
IS_SERVER
_HEALTH
Y
=
8
IS_TRACING_ENABLED
=
9
START_PROFILE
=
10
STOP_PROFILE
=
11
RPC_REQUEST_TYPE
=
Union
[
RPCGenerateRequest
,
RPCAbortRequest
,
...
...
vllm/entrypoints/openai/rpc/client.py
View file @
af7f4372
from
contextlib
import
contextmanager
from
typing
import
Any
,
AsyncIterator
,
Mapping
,
Optional
import
asyncio
from
contextlib
import
contextmanager
,
suppress
from
typing
import
Any
,
AsyncGenerator
,
Mapping
,
Optional
from
uuid
import
uuid4
import
cloudpickle
import
zmq
...
...
@@ -7,29 +9,152 @@ import zmq.asyncio
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
# yapf: disable
from
vllm.entrypoints.openai.rpc
import
(
RPC_REQUEST_TYPE
,
VLLM_RPC_HEALTHY_STR
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
VLLM_RPC_SOCKET_LIMIT_CUTOFF
,
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
# yapf: enable
from
vllm.envs
import
VLLM_RPC_GET_DATA_TIMEOUT_MS
from
vllm.inputs
import
PromptInputs
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
logger
=
init_logger
(
__name__
)
# Path used for inprocess proxy.
INPROC_PROXY_PATH
=
f
"inproc://
{
uuid4
()
}
"
class
RPCClientClosedError
(
Exception
):
"""Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it.
"""
class
AsyncEngineRPCClient
:
def
__init__
(
self
,
port
:
int
):
class
AsyncEngineRPCClient
:
"""
RPCClient that connects to the RPCServer wrapping AsyncLLMEngine.
The overall design mirrors the Asynchronous Client Server Pattern
https://zguide.zeromq.org/docs/chapter3/#The-Asynchronous-Client-Server-Pattern
On startup, the RPCClient:
- makes DEALER socket (to_rpc_server) that connects to the RPCServer
via ipc, which uses unix sockets under the hood
(https://libzmq.readthedocs.io/en/zeromq4-1/zmq_ipc.html)
- makes ROUTER socket (from_api_server) that binds to a random
inproc address, which uses memory under the hood
(https://libzmq.readthedocs.io/en/zeromq3-x/zmq_inproc.html)
- runs a proxy in a background asyncio task between
from_api_server (ROUTER, inproc) and to_rpc_server (DEALER ipc, )
Each request handled by the asyncio api_server calls generate():
- make a DEALER socket that connects to from_api_server via inproc
- send a RCPGenerateRequest to the inproc socket
- background proxy forwards the request from inproc -> ipc
- RPCServer responds to the request one token at a time over ipc
- background proxy forwards the response from ipc -> inproc
The connection looks like this:
DEALER <- inproc -> [ ROUTER | DEALER ] <- ipc -> DEALER
Message routing is performed via identities that are managed by the
ROUTER socket. ROUTER sockets track every connection it has and
tells the caller about these. The way it tells the caller is to stick
the connection identity in front of each message received. When we
send the message via a ROUTER, we first send an identity frame.
See https://zguide.zeromq.org/docs/chapter3/#The-Extended-Reply-Envelope
for more details on connection identities.
This proxy design enables us to use a single unix socket, which
improves performance by avoiding syscalls (~5%) and avoids resource limits
such as ulimit, which defaults to 1024 on ubuntu.
Note: we run set_hwm(0) on each socket, which sets the HWM to inf,
which is required to avoid dropping messages under high load.
This is generally not advisable. However, since we are in control
of both sides of the connection + failure on either side is
catastrophic to the overall system health and memory profiling
suggests limited memory overhead relative to asyncio, we will
proceed for now.
See https://zguide.zeromq.org/docs/chapter2/#High-Water-Marks
for more details on high water marks.
"""
def
__init__
(
self
,
rpc_path
:
str
):
self
.
context
=
zmq
.
asyncio
.
Context
()
self
.
path
=
f
"tcp://localhost:
{
port
}
"
self
.
_data_timeout
=
VLLM_RPC_GET_DATA_TIMEOUT_MS
self
.
_errored
=
False
# Maximum number of sockets that can be opened (typically 65536).
# ZMQ_SOCKET_LIMIT (http://api.zeromq.org/4-2:zmq-ctx-get)
socket_limit
=
self
.
context
.
get
(
zmq
.
constants
.
SOCKET_LIMIT
)
if
socket_limit
<
VLLM_RPC_SOCKET_LIMIT_CUTOFF
:
raise
ValueError
(
f
"Found zmq.constants.SOCKET_LIMIT=
{
socket_limit
}
, which caps "
"the number of concurrent requests vLLM can process. Launch "
"vLLM with --disable-frontend-multiprocessing and open a "
"GitHub issue so we can investigate."
)
# We only have 1 ipc connection that uses unix sockets, so
# safe to set MAX_SOCKETS to the zmq SOCKET_LIMIT (i.e. will
# not run into ulimit issues)
self
.
context
.
set
(
zmq
.
constants
.
MAX_SOCKETS
,
socket_limit
)
# IPC connection to RPC Server (uses unix sockets).
self
.
to_rpc_server
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
to_rpc_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
to_rpc_server
.
bind
(
rpc_path
)
# In process proxy to RPC Server (uses memory-based messaging).
self
.
from_api_server
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
self
.
from_api_server
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
from_api_server
.
bind
(
INPROC_PROXY_PATH
)
# Asyncio background task for the proxy.
self
.
proxy_task
=
asyncio
.
create_task
(
self
.
run_proxy
(
self
.
from_api_server
,
self
.
to_rpc_server
))
# Since we open 1 inproc socket per request, we have a hard cap on
# the number of requests that can run in vLLM w. frontend
# mulitprocessing. This value is used uvicorn to launch
# with --limit-concurrency to return 503 when server is overloaded.
# We need 2 sockets per request - 2:
# 1 for generate(), 1 for abort(), do_log_stats(), check_health()
self
.
limit_concurrency
=
socket_limit
//
2
-
2
async
def
run_proxy
(
self
,
socket_from
,
socket_to
):
"""Background task that runs a proxy"""
poller
=
zmq
.
asyncio
.
Poller
()
poller
.
register
(
socket_from
,
zmq
.
constants
.
POLLIN
)
poller
.
register
(
socket_to
,
zmq
.
constants
.
POLLIN
)
while
True
:
events
=
await
poller
.
poll
()
events
=
dict
(
events
)
if
socket_from
in
events
:
identity
,
msg
=
await
socket_from
.
recv_multipart
()
await
socket_to
.
send_multipart
([
identity
,
msg
])
if
socket_to
in
events
:
identity
,
msg
=
await
socket_to
.
recv_multipart
()
await
socket_from
.
send_multipart
([
identity
,
msg
])
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
# Wait until server is ready.
await
self
.
wait_for_server
()
await
self
.
_
wait_for_server
_rpc
()
# Get the configs.
self
.
model_config
=
await
self
.
_get_model_config_rpc
()
...
...
@@ -47,59 +172,100 @@ class AsyncEngineRPCClient:
def
close
(
self
):
"""Destroy the ZeroMQ Context."""
# Close all sockets associated with this context and
# then terminate the context.
self
.
from_api_server
.
close
()
self
.
to_rpc_server
.
close
()
self
.
context
.
destroy
()
@
contextmanager
def
socket
(
self
):
# Ensure client sockets are always closed after use
def
to_proxy_socket
(
self
):
# Connect to the RPCServer via the proxy.
# Raise a sensible error if the client was already closed.
# This can happen if a server shutdown is triggered but some coroutines
# are still running requests.
# There should not be a race condition with this check because we don't
# yield to the event loop between here and opening the socket.
if
self
.
context
.
closed
:
raise
RPCClientClosedError
(
"The ZMQ client has already shut down"
)
# Connect to RPC socket for Request-Reply pattern,
# Note that we use DEALER to enable asynchronous communication
# to enable streaming.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
try
:
socket
.
connect
(
self
.
path
)
socket
.
connect
(
INPROC_PROXY_PATH
)
yield
socket
finally
:
socket
.
close
()
socket
.
close
(
linger
=
0
)
async
def
_send_get_data_rpc_request
(
self
,
request
:
RPCUtilityRequest
,
expected_type
:
Any
,
error_message
:
str
)
->
Any
:
"""Send an RPC request that is expecting data back."""
with
self
.
socket
()
as
socket
:
with
self
.
to_proxy_socket
()
as
socket
:
# Ping RPCServer with a request.
await
socket
.
send
(
cloudpickle
.
dumps
(
request
))
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
request
)])
# Make sure the server responds
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
# Await the data from the Server.
data
=
cloudpickle
.
loads
(
await
socket
.
recv
())
if
isinstance
(
data
,
Exception
):
# Re-raise exceptions returned by the server
raise
data
if
not
isinstance
(
data
,
expected_type
):
# LoRAConfig can be None.
if
expected_type
==
LoRAConfig
and
data
is
None
:
pass
elif
isinstance
(
data
,
Exception
):
logger
.
error
(
error_message
)
raise
data
else
:
raise
ValueError
(
error_message
)
return
data
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
):
async
def
_send_one_way_rpc_request
(
self
,
request
:
RPC_REQUEST_TYPE
,
error_message
:
str
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
):
"""Send one-way RPC request to trigger an action."""
with
self
.
socket
()
as
socket
:
# Ping RPC Server with request.
await
socket
.
send
(
cloudpickle
.
dumps
(
request
))
# Await acknowledgement from RPCServer.
response
=
cloudpickle
.
loads
(
await
socket
.
recv
())
async
def
do_rpc_call
(
socket
:
zmq
.
asyncio
.
Socket
,
request
:
RPC_REQUEST_TYPE
):
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
request
)])
if
await
socket
.
poll
(
timeout
=
self
.
_data_timeout
)
==
0
:
raise
TimeoutError
(
"Server didn't reply within "
f
"
{
self
.
_data_timeout
}
ms"
)
return
cloudpickle
.
loads
(
await
socket
.
recv
())
# Make a new socket connection.
if
socket
is
None
:
with
self
.
to_proxy_socket
()
as
socket
:
response
=
await
do_rpc_call
(
socket
,
request
)
# Use existing socket connection.
else
:
response
=
await
do_rpc_call
(
socket
,
request
)
if
not
isinstance
(
response
,
str
)
or
response
!=
VLLM_RPC_SUCCESS_STR
:
if
isinstance
(
response
,
Exception
):
logger
.
error
(
error_message
)
raise
response
raise
ValueError
(
error_message
)
return
response
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
...
...
@@ -112,12 +278,12 @@ class AsyncEngineRPCClient:
async
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracing_flag
async
def
wait_for_server
(
self
):
async
def
_
wait_for_server
_rpc
(
self
):
"""Wait for the RPCServer to start up."""
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_READY
,
error_message
=
"Unable to start RPC Server
.
"
)
error_message
=
"Unable to start RPC Server"
)
async
def
_get_model_config_rpc
(
self
)
->
ModelConfig
:
"""Get the ModelConfig object from the RPC Server"""
...
...
@@ -151,7 +317,7 @@ class AsyncEngineRPCClient:
expected_type
=
SchedulerConfig
,
error_message
=
"Could not get SchedulerConfig from RPC Server"
)
async
def
_get_lora_config_rpc
(
self
):
async
def
_get_lora_config_rpc
(
self
)
->
LoRAConfig
:
"""Get LoRAConfig from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
...
...
@@ -159,29 +325,51 @@ class AsyncEngineRPCClient:
expected_type
=
LoRAConfig
,
error_message
=
"Could not get LoRAConfig from RPC Server"
)
async
def
_is_tracing_enabled_rpc
(
self
)
->
ParallelConfig
:
async
def
_is_tracing_enabled_rpc
(
self
)
->
bool
:
"""Get is_tracing_enabled flag from the RPCServer"""
return
await
self
.
_send_get_data_rpc_request
(
RPCUtilityRequest
.
IS_TRACING_ENABLED
,
expected_type
=
bool
,
error_message
=
"Could not get is_tracing_enabled flag from RPC "
"Server"
)
error_message
=
"Could not get is_tracing_enabled from RPC Server"
)
async
def
abort
(
self
,
request_id
:
str
):
"""Send an ABORT_REQUEST signal to the RPC Server"""
# Suppress timeouts as well.
# In cases where the server is busy processing requests and a very
# large volume of abort requests arrive, it is likely that the server
# will not be able to ack all of them in time. We have seen this when
# we abort 20k requests at once while another 2k are processing- many
# of them time out, but we see the server successfully abort all of the
# requests.
# In this case we assume that the server has received or will receive
# these abort requests, and ignore the timeout. This prevents a massive
# wall of `TimeoutError` stack traces.
with
suppress
(
RPCClientClosedError
,
TimeoutError
):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCAbortRequest
(
request_id
),
error_message
=
f
"RPCAbortRequest
{
request_id
}
failed"
)
async
def
do_log_stats
(
self
):
"""Send a DO_LOG_STATS signal to the RPC Server"""
with
suppress
(
RPCClientClosedError
):
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
DO_LOG_STATS
,
error_message
=
"RPCRequest DO_LOG_STATS failed."
)
@
property
def
is_running
(
self
)
->
bool
:
return
not
self
.
_errored
@
property
def
is_stopped
(
self
)
->
bool
:
return
self
.
_errored
@
property
def
errored
(
self
)
->
bool
:
return
self
.
_errored
async
def
generate
(
self
,
inputs
:
PromptInputs
,
...
...
@@ -190,11 +378,12 @@ class AsyncEngineRPCClient:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
Async
It
erator
[
RequestOutput
]:
)
->
Async
Gen
erator
[
RequestOutput
,
None
]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
with
self
.
socket
()
as
socket
:
finished
=
False
try
:
with
self
.
to_proxy_socket
()
as
socket
:
# Send RPCGenerateRequest to the RPCServer.
await
socket
.
send_multipart
([
cloudpickle
.
dumps
(
...
...
@@ -208,41 +397,57 @@ class AsyncEngineRPCClient:
])
# Stream back the results from the RPC Server.
while
True
:
while
not
finished
:
message
=
await
socket
.
recv
()
request_output
=
cloudpickle
.
loads
(
message
)
if
isinstance
(
request_output
,
Exception
):
# On exception, check if the server is still healthy
# possibly setting the `errored` property.
if
not
self
.
_errored
:
try
:
await
self
.
check_health
(
socket
=
socket
)
except
Exception
as
e
:
self
.
_errored
=
True
logger
.
exception
(
repr
(
e
))
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
raise
request_output
if
request_output
.
finished
:
break
finished
=
request_output
.
finished
yield
request_output
yield
request_output
finally
:
# Request was canceled by the client.
if
not
finished
and
not
self
.
_errored
:
await
self
.
abort
(
request_id
)
async
def
check_health
(
self
)
->
None
:
async
def
check_health
(
self
,
socket
:
Optional
[
zmq
.
asyncio
.
Socket
]
=
None
)
->
None
:
"""Raise if unhealthy"""
with
self
.
socket
()
as
socket
:
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
IS_SERVER_HEALTHY
,
error_message
=
"Got Unhealthy response from RPC Server"
,
socket
=
socket
)
# Ping RPCServer with CHECK_HEALTH request.
await
socket
.
send
(
cloudpickle
.
dumps
(
RPCUtilityRequest
.
CHECK_HEALTH
)
)
async
def
encode
(
self
,
*
args
,
**
kwargs
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
raise
NotImplementedError
(
"Embeddings not supported with multiprocessing backend"
)
# Await the reply from the server.
# TODO: do we need an internal timeout here?
# Or do we expect the external probe to timeout and let this chill?
health_message
=
cloudpickle
.
loads
(
await
socket
.
recv
())
async
def
start_profile
(
self
)
->
None
:
"""Start profiling the engine"""
if
isinstance
(
health_message
,
Exception
):
raise
health_message
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
START_PROFILE
,
error_message
=
"RPCRequest START_PROFILE failed."
)
if
health_message
!=
VLLM_RPC_HEALTHY_STR
:
raise
ValueError
(
"Expected healthy response from backend but got "
"f{health_message}"
)
async
def
stop_profile
(
self
)
->
None
:
"""Stop profiling the engine"""
async
def
encode
(
self
,
*
args
,
**
kwargs
)
->
AsyncIterator
[
EmbeddingRequestOutput
]:
raise
NotImplementedError
(
"Embeddings not supported with multiprocessing backend"
)
await
self
.
_send_one_way_rpc_request
(
request
=
RPCUtilityRequest
.
STOP_PROFILE
,
error_message
=
"RPCRequest STOP_PROFILE failed."
)
\ No newline at end of file
vllm/entrypoints/openai/rpc/server.py
View file @
af7f4372
import
asyncio
import
signal
from
typing
import
Any
,
Coroutine
from
typing
import
Any
,
Coroutine
,
Union
import
cloudpickle
import
uvloop
import
zmq
import
zmq.asyncio
from
typing_extensions
import
Never
from
vllm
import
AsyncEngineArgs
,
AsyncLLMEngine
from
vllm.entrypoints.openai.rpc
import
(
VLLM_RPC_HEALTHY_STR
,
VLLM_RPC_SUCCESS_STR
,
RPCAbortRequest
,
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
from
vllm.entrypoints.openai.rpc
import
(
VLLM_RPC_SUCCESS_STR
,
VLLM_RPC_ZMQ_HWM
,
RPCAbortRequest
,
RPCGenerateRequest
,
RPCUtilityRequest
)
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
logger
=
init_logger
(
__name__
)
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
]
class
AsyncEngineRPCServer
:
def
__init__
(
self
,
async_engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
port
:
int
):
usage_context
:
UsageContext
,
rpc_path
:
str
):
# Initialize engine first.
self
.
engine
=
AsyncLLMEngine
.
from_engine_args
(
async_engine_args
,
usage_context
)
self
.
engine
=
AsyncLLMEngine
.
from_engine_args
(
async_engine_args
,
usage_context
=
usage_context
)
# Initialize context.
self
.
context
=
zmq
.
asyncio
.
Context
()
# Init socket for readiness state.
self
.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
ROUTER
)
# Note numeric form of localhost should be used for zmq bind(),
# see https://stackoverflow.com/a/8958414
self
.
socket
.
bind
(
f
"tcp://127.0.0.1:
{
port
}
"
)
# Init socket.
self
.
socket
=
self
.
context
.
socket
(
zmq
.
constants
.
DEALER
)
self
.
socket
.
set_hwm
(
VLLM_RPC_ZMQ_HWM
)
self
.
socket
.
connect
(
rpc_path
)
def
cleanup
(
self
):
"""Cleanup all resources."""
self
.
socket
.
close
()
self
.
context
.
destroy
()
self
.
engine
.
shutdown_background_loop
()
# Clear the engine reference so that it can be GC'ed.
del
self
.
engine
async
def
get_model_config
(
self
,
identity
):
"""Send the ModelConfig"""
model_config
=
await
self
.
engine
.
get_model_config
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
model_config
)])
async
def
get_decoding_config
(
self
,
identity
):
"""Send the DecodingConfig"""
decoding_config
=
await
self
.
engine
.
get_decoding_config
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
decoding_config
)])
async
def
get_lora_config
(
self
,
identity
):
lora_config
=
await
self
.
engine
.
get_lora_config
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
lora_config
)])
async
def
get_scheduler_config
(
self
,
identity
):
"""Send the SchedulerConfig"""
parallel_config
=
await
self
.
engine
.
get_scheduler_config
()
async
def
get_config
(
self
,
identity
,
request
):
try
:
config
:
CONFIG_TYPE
if
request
==
RPCUtilityRequest
.
GET_MODEL_CONFIG
:
config
=
await
self
.
engine
.
get_model_config
()
elif
request
==
RPCUtilityRequest
.
GET_DECODING_CONFIG
:
config
=
await
self
.
engine
.
get_decoding_config
()
elif
request
==
RPCUtilityRequest
.
GET_LORA_CONFIG
:
config
=
await
self
.
engine
.
get_lora_config
()
elif
request
==
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
:
config
=
await
self
.
engine
.
get_scheduler_config
()
elif
request
==
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
:
config
=
await
self
.
engine
.
get_parallel_config
()
else
:
raise
ValueError
(
"Unknown Config Request: %s"
,
request
)
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
parallel_config
)])
async
def
get_parallel_config
(
self
,
identity
):
"""Send the ParallelConfig"""
parallel_config
=
await
self
.
engine
.
get_parallel_config
()
[
identity
,
cloudpickle
.
dumps
(
config
)])
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
parallel_config
)])
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
e
)])
async
def
is_tracing_enabled
(
self
,
identity
):
"""Send the is_tracing_enabled flag"""
...
...
@@ -84,28 +80,23 @@ class AsyncEngineRPCServer:
"""Log stats and confirm success."""
await
self
.
engine
.
do_log_stats
()
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
])
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)])
async
def
is_server_ready
(
self
,
identity
):
"""Notify the client that we are ready."""
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
])
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)])
async
def
abort
(
self
,
identity
,
request
:
RPCAbortRequest
):
"""Abort request and notify the client of success."""
try
:
# Abort the request in the llm engine.
await
self
.
engine
.
abort
(
request
.
request_id
)
# Send confirmation to the client.
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
])
result
:
Union
[
str
,
Exception
]
=
VLLM_RPC_SUCCESS_STR
except
Exception
as
e
:
result
=
e
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
result
)])
async
def
generate
(
self
,
identity
,
generate_request
:
RPCGenerateRequest
):
try
:
...
...
@@ -122,17 +113,37 @@ class AsyncEngineRPCServer:
[
identity
,
cloudpickle
.
dumps
(
request_output
)])
except
Exception
as
e
:
### Notify client of all failures
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
async
def
check_health
(
self
,
identity
):
try
:
await
self
.
engine
.
check_health
()
await
self
.
socket
.
send_multipart
(
[
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_HEALTHY_STR
)])
[
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
)])
except
Exception
as
e
:
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
e
)])
async
def
start_profile
(
self
,
identity
):
logger
.
info
(
"Starting profiler..."
)
await
self
.
engine
.
start_profile
()
logger
.
info
(
"Profiler started."
)
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
])
async
def
stop_profile
(
self
,
identity
):
logger
.
info
(
"Stopping profiler..."
)
await
self
.
engine
.
stop_profile
()
logger
.
info
(
"Profiler stopped."
)
await
self
.
socket
.
send_multipart
([
identity
,
cloudpickle
.
dumps
(
VLLM_RPC_SUCCESS_STR
),
])
def
_make_handler_coro
(
self
,
identity
,
message
)
->
Coroutine
[
Any
,
Any
,
Never
]:
"""Route the zmq message to the handler coroutine."""
...
...
@@ -146,24 +157,26 @@ class AsyncEngineRPCServer:
return
self
.
abort
(
identity
,
request
)
elif
isinstance
(
request
,
RPCUtilityRequest
):
if
request
==
RPCUtilityRequest
.
GET_MODEL_CONFIG
:
return
self
.
get_model_config
(
identity
)
elif
request
==
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
:
return
self
.
get_parallel_config
(
identity
)
elif
request
==
RPCUtilityRequest
.
GET_DECODING_CONFIG
:
return
self
.
get_decoding_config
(
identity
)
elif
request
==
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
:
return
self
.
get_scheduler_config
(
identity
)
elif
request
==
RPCUtilityRequest
.
GET_LORA_CONFIG
:
return
self
.
get_lora_config
(
identity
)
if
request
in
[
RPCUtilityRequest
.
GET_MODEL_CONFIG
,
RPCUtilityRequest
.
GET_PARALLEL_CONFIG
,
RPCUtilityRequest
.
GET_DECODING_CONFIG
,
RPCUtilityRequest
.
GET_SCHEDULER_CONFIG
,
RPCUtilityRequest
.
GET_LORA_CONFIG
]:
return
self
.
get_config
(
identity
,
request
)
elif
request
==
RPCUtilityRequest
.
DO_LOG_STATS
:
return
self
.
do_log_stats
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_SERVER_READY
:
return
self
.
is_server_ready
(
identity
)
elif
request
==
RPCUtilityRequest
.
CHECK
_HEALTH
:
elif
request
==
RPCUtilityRequest
.
IS_SERVER
_HEALTH
Y
:
return
self
.
check_health
(
identity
)
elif
request
==
RPCUtilityRequest
.
IS_TRACING_ENABLED
:
return
self
.
is_tracing_enabled
(
identity
)
elif
request
==
RPCUtilityRequest
.
START_PROFILE
:
return
self
.
start_profile
(
identity
)
elif
request
==
RPCUtilityRequest
.
STOP_PROFILE
:
return
self
.
stop_profile
(
identity
)
else
:
raise
ValueError
(
f
"Unknown RPCUtilityRequest type:
{
request
}
"
)
...
...
@@ -213,6 +226,6 @@ async def run_server(server: AsyncEngineRPCServer):
def
run_rpc_server
(
async_engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
port
:
int
):
server
=
AsyncEngineRPCServer
(
async_engine_args
,
usage_context
,
port
)
asyncio
.
run
(
run_server
(
server
))
usage_context
:
UsageContext
,
rpc_path
:
str
):
server
=
AsyncEngineRPCServer
(
async_engine_args
,
usage_context
,
rpc_path
)
uvloop
.
run
(
run_server
(
server
))
vllm/entrypoints/openai/run_batch.py
View file @
af7f4372
import
asyncio
from
io
import
StringIO
from
typing
import
Awaitable
,
List
from
typing
import
Awaitable
,
Callable
,
List
import
aiohttp
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.logger
import
RequestLogger
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
BatchRequestInput
,
BatchRequestOutput
,
BatchResponseData
,
ChatCompletionResponse
,
ErrorResponse
)
EmbeddingResponse
,
ErrorResponse
)
# yapf: enable
from
vllm.entrypoints.openai.serving_chat
import
OpenAIServingChat
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
random_uuid
...
...
@@ -82,27 +85,26 @@ async def write_file(path_or_url: str, data: str) -> None:
f
.
write
(
data
)
async
def
run_request
(
chat_
serving
:
OpenAIServingChat
,
async
def
run_request
(
serving
_engine_func
:
Callable
,
request
:
BatchRequestInput
)
->
BatchRequestOutput
:
chat_request
=
request
.
body
chat_response
=
await
chat_serving
.
create_chat_completion
(
chat_request
)
response
=
await
serving_engine_func
(
request
.
body
)
if
isinstance
(
chat_
response
,
ChatCompletionResponse
):
if
isinstance
(
response
,
(
ChatCompletionResponse
,
EmbeddingResponse
)
):
batch_output
=
BatchRequestOutput
(
id
=
f
"vllm-
{
random_uuid
()
}
"
,
custom_id
=
request
.
custom_id
,
response
=
BatchResponseData
(
body
=
chat_
response
,
request_id
=
f
"vllm-batch-
{
random_uuid
()
}
"
),
body
=
response
,
request_id
=
f
"vllm-batch-
{
random_uuid
()
}
"
),
error
=
None
,
)
elif
isinstance
(
chat_
response
,
ErrorResponse
):
elif
isinstance
(
response
,
ErrorResponse
):
batch_output
=
BatchRequestOutput
(
id
=
f
"vllm-
{
random_uuid
()
}
"
,
custom_id
=
request
.
custom_id
,
response
=
BatchResponseData
(
status_code
=
chat_
response
.
code
,
status_code
=
response
.
code
,
request_id
=
f
"vllm-batch-
{
random_uuid
()
}
"
),
error
=
chat_
response
,
error
=
response
,
)
else
:
raise
ValueError
(
"Request must not be sent in stream mode"
)
...
...
@@ -128,6 +130,7 @@ async def main(args):
else
:
request_logger
=
RequestLogger
(
max_log_len
=
args
.
max_log_len
)
# Create the openai serving objects.
openai_serving_chat
=
OpenAIServingChat
(
engine
,
model_config
,
...
...
@@ -138,12 +141,35 @@ async def main(args):
request_logger
=
request_logger
,
chat_template
=
None
,
)
openai_serving_embedding
=
OpenAIServingEmbedding
(
engine
,
model_config
,
served_model_names
,
request_logger
=
request_logger
,
)
# Submit all requests in the file to the engine "concurrently".
response_futures
:
List
[
Awaitable
[
BatchRequestOutput
]]
=
[]
for
request_json
in
(
await
read_file
(
args
.
input_file
)).
strip
().
split
(
"
\n
"
):
# Skip empty lines.
request_json
=
request_json
.
strip
()
if
not
request_json
:
continue
request
=
BatchRequestInput
.
model_validate_json
(
request_json
)
response_futures
.
append
(
run_request
(
openai_serving_chat
,
request
))
# Determine the type of request and run it.
if
request
.
url
==
"/v1/chat/completions"
:
response_futures
.
append
(
run_request
(
openai_serving_chat
.
create_chat_completion
,
request
))
elif
request
.
url
==
"/v1/embeddings"
:
response_futures
.
append
(
run_request
(
openai_serving_embedding
.
create_embedding
,
request
))
else
:
raise
ValueError
(
"Only /v1/chat/completions and /v1/embeddings are"
"supported in the batch endpoint."
)
responses
=
await
asyncio
.
gather
(
*
response_futures
)
...
...
Prev
1
…
12
13
14
15
16
17
18
19
20
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment