Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6d2051cc
Commit
6d2051cc
authored
Oct 21, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev
parents
2c7f740a
a2c71c54
Changes
457
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1287 additions
and
593 deletions
+1287
-593
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+173
-22
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+223
-49
vllm/engine/metrics.py
vllm/engine/metrics.py
+9
-13
vllm/engine/metrics_types.py
vllm/engine/metrics_types.py
+0
-1
vllm/engine/multiprocessing/__init__.py
vllm/engine/multiprocessing/__init__.py
+63
-3
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+134
-22
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+8
-6
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+30
-6
vllm/engine/output_processor/single_step.py
vllm/engine/output_processor/single_step.py
+12
-160
vllm/engine/output_processor/stop_checker.py
vllm/engine/output_processor/stop_checker.py
+2
-2
vllm/engine/output_processor/util.py
vllm/engine/output_processor/util.py
+8
-5
vllm/engine/protocol.py
vllm/engine/protocol.py
+128
-11
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+36
-4
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+105
-103
vllm/entrypoints/logger.py
vllm/entrypoints/logger.py
+3
-2
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+19
-5
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+28
-1
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+139
-40
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+133
-121
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+34
-17
No files found.
Too many changes to show.
To preserve performance only
457 of 457+
files are displayed.
Plain diff
Email patch
vllm/engine/async_llm_engine.py
View file @
6d2051cc
...
...
@@ -2,8 +2,8 @@ import asyncio
import
time
import
weakref
from
functools
import
partial
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
from
typing
import
(
Any
,
AsyncGenerator
,
Callable
,
Coroutine
,
Dict
,
Iterable
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
,
overload
)
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
...
...
@@ -14,12 +14,15 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.llm_engine
import
LLMEngine
,
SchedulerOutputState
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.protocol
import
EngineClient
from
vllm.executor.executor_base
import
ExecutorAsyncBase
from
vllm.executor.gpu_executor
import
GPUExecutorAsync
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
get_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
...
...
@@ -28,7 +31,7 @@ from vllm.sampling_params import SamplingParams
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
weak_bind
from
vllm.utils
import
deprecate_kwargs
,
weak_bind
logger
=
init_logger
(
__name__
)
ENGINE_ITERATION_TIMEOUT_S
=
envs
.
VLLM_ENGINE_ITERATION_TIMEOUT_S
...
...
@@ -363,11 +366,18 @@ class _AsyncLLMEngine(LLMEngine):
self
.
cached_scheduler_outputs
[
virtual_engine
]
=
SchedulerOutputState
()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output
:
bool
=
False
if
not
seq_group_metadata_list
\
else
seq_group_metadata_list
[
0
].
state
.
num_steps
==
1
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
allow_async_output_proc
,
is_last_step
=
True
)
is_last_step
=
True
,
is_first_step_output
=
is_first_step_output
)
if
outputs
and
allow_async_output_proc
:
assert
len
(
...
...
@@ -402,31 +412,86 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await
self
.
model_executor
.
stop_remote_worker_execution_loop_async
()
@
overload
# DEPRECATED
async
def
add_request_async
(
self
,
request_id
:
str
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
overload
async
def
add_request_async
(
self
,
request_id
:
str
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
async
def
add_request_async
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
"""Async version of :meth:`add_request`."""
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
if
priority
!=
0
and
not
self
.
scheduler_config
.
policy
==
"priority"
:
raise
ValueError
(
f
"Got priority
{
priority
}
but "
"Priority scheduling is not enabled."
)
if
arrival_time
is
None
:
arrival_time
=
time
.
time
()
preprocessed_inputs
=
await
self
.
input_preprocessor
.
preprocess_async
(
inputs
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
processed_inputs
=
self
.
input_processor
(
preprocessed_inputs
)
if
isinstance
(
params
,
SamplingParams
)
and
\
params
.
guided_decoding
is
not
None
:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params
=
await
build_guided_decoding_logits_processor_async
(
sampling_params
=
params
,
tokenizer
=
self
.
get_tokenizer
(
lora_request
),
default_guided_backend
=
self
.
decoding_config
.
guided_decoding_backend
)
self
.
_add_processed_request
(
request_id
=
request_id
,
processed_inputs
=
processed_inputs
,
...
...
@@ -435,6 +500,7 @@ class _AsyncLLMEngine(LLMEngine):
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
priority
,
)
async
def
check_health_async
(
self
)
->
None
:
...
...
@@ -443,7 +509,37 @@ class _AsyncLLMEngine(LLMEngine):
self
.
model_executor
.
check_health
()
class
AsyncLLMEngine
:
async
def
build_guided_decoding_logits_processor_async
(
sampling_params
:
SamplingParams
,
tokenizer
:
AnyTokenizer
,
default_guided_backend
:
str
)
->
SamplingParams
:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if
(
guided_decoding
:
=
sampling_params
.
guided_decoding
)
is
None
:
return
sampling_params
logger
.
debug
(
"Building guided decoding logits processor. "
"Params: %s"
,
guided_decoding
)
guided_decoding
.
backend
=
guided_decoding
.
backend
or
default_guided_backend
processor
=
await
get_guided_decoding_logits_processor
(
guided_params
=
guided_decoding
,
tokenizer
=
tokenizer
)
if
processor
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
[]
sampling_params
.
logits_processors
.
append
(
processor
)
# Unset guided decoding params after constructing the lp from them
sampling_params
.
guided_decoding
=
None
return
sampling_params
class
AsyncLLMEngine
(
EngineClient
):
"""An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the :class:`LLMEngine` class to make it
...
...
@@ -774,16 +870,58 @@ class AsyncLLMEngine:
# This method does not need to be async, but kept that way
# for backwards compatibility.
async
def
add_request
(
@
overload
# DEPRECATED
def
add_request
(
self
,
request_id
:
str
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]]:
...
@
overload
def
add_request
(
self
,
request_id
:
str
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
Coroutine
[
None
,
None
,
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
async
def
add_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
AsyncGenerator
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
],
None
]:
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
not
self
.
is_running
:
if
self
.
start_engine_loop
:
self
.
start_background_loop
()
...
...
@@ -794,26 +932,34 @@ class AsyncLLMEngine:
"error that caused the background loop to stop "
"(AsyncEngineDeadError)."
)
if
(
priority
!=
0
and
not
self
.
engine
.
scheduler_config
.
policy
==
"priority"
):
raise
ValueError
(
f
"Got priority
{
priority
}
but "
"Priority scheduling is not enabled."
)
stream
=
self
.
_request_tracker
.
add_request
(
request_id
,
verbose
=
self
.
log_requests
,
inputs
=
inputs
,
prompt
=
prompt
,
params
=
params
,
arrival_time
=
arrival_time
or
time
.
time
(),
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
)
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
return
stream
.
generator
()
async
def
generate
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
...
...
@@ -822,8 +968,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -831,6 +976,8 @@ class AsyncLLMEngine:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `RequestOutput` objects from the LLMEngine
...
...
@@ -881,21 +1028,23 @@ class AsyncLLMEngine:
"""
async
for
output
in
await
self
.
add_request
(
request_id
,
inputs
,
prompt
,
sampling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
):
yield
LLMEngine
.
validate_output
(
output
,
RequestOutput
)
async
def
encode
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
...
...
@@ -904,13 +1053,14 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
lora_request: LoRA request to use for generation, if any.
trace_headers: OpenTelemetry trace headers.
priority: The priority of the request.
Only applicable with priority scheduling.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
...
...
@@ -959,10 +1109,11 @@ class AsyncLLMEngine:
"""
async
for
output
in
await
self
.
add_request
(
request_id
,
inputs
,
prompt
,
pooling_params
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
priority
=
priority
,
):
yield
LLMEngine
.
validate_output
(
output
,
EmbeddingRequestOutput
)
...
...
vllm/engine/llm_engine.py
View file @
6d2051cc
...
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Deque
,
Dict
,
Iterable
,
List
,
Mapping
,
NamedTuple
,
Optional
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Type
,
Union
from
typing
import
Set
,
Type
,
Union
,
cast
,
overload
import
torch
from
typing_extensions
import
TypeVar
...
...
@@ -25,14 +25,17 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.engine.output_processor.util
import
create_output_by_sequence_group
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.ray_utils
import
initialize_ray_cluster
from
vllm.inputs
import
(
INPUT_REGISTRY
,
Encoder
Decoder
LLM
Inputs
,
InputRegistry
,
LLMInputs
,
Prompt
Inputs
)
from
vllm.inputs
import
(
INPUT_REGISTRY
,
Decoder
Only
Inputs
,
EncoderDecoderInputs
,
InputRegistry
,
Prompt
Type
)
from
vllm.inputs.preprocess
import
InputPreprocessor
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
get_local_guided_decoding_logits_processor
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
(
EmbeddingRequestOutput
,
RequestOutput
,
RequestOutputFactory
)
...
...
@@ -41,7 +44,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceStatus
)
SequenceGroupOutput
,
SequenceStatus
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
...
...
@@ -51,7 +54,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup
,
init_tokenizer_from_configs
)
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
Counter
,
Device
,
weak_bind
from
vllm.utils
import
Counter
,
Device
,
deprecate_kwargs
,
weak_bind
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -90,6 +93,12 @@ class OutputData(NamedTuple):
scheduler_outputs
:
SchedulerOutputs
is_async
:
bool
is_last_step
:
bool
# Indicates if this output is from the first step of the
# multi-step. When multi-step is disabled, this is always
# set to True.
# is_first_step_output is invalid when `outputs` has
# outputs from multiple steps.
is_first_step_output
:
Optional
[
bool
]
skip
:
List
[
int
]
...
...
@@ -108,13 +117,15 @@ class SchedulerContext:
def
append_output
(
self
,
outputs
:
List
[
SamplerOutput
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
scheduler_outputs
:
SchedulerOutputs
,
is_async
:
bool
,
is_last_step
:
bool
):
is_last_step
:
bool
,
is_first_step_output
:
Optional
[
bool
]):
self
.
output_queue
.
append
(
OutputData
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
is_async
,
is_last_step
=
is_last_step
,
is_first_step_output
=
is_first_step_output
,
skip
=
[]))
...
...
@@ -177,7 +188,7 @@ class LLMEngine:
raise
TypeError
(
f
"Expected output of type
{
output_type
}
, "
f
"but found type
{
type
(
output
)
}
"
)
return
output
return
cast
(
_O
,
output
)
@
classmethod
def
validate_outputs
(
...
...
@@ -236,10 +247,11 @@ class LLMEngine:
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)"
,
"seed=%d, served_model_name=%s, "
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s)"
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
...
...
@@ -268,8 +280,8 @@ class LLMEngine:
observability_config
,
model_config
.
seed
,
model_config
.
served_model_name
,
scheduler_config
.
use_v2_block_manager
,
scheduler_config
.
num_scheduler_steps
,
scheduler_config
.
chunked_prefill_enabled
,
scheduler_config
.
multi_step_stream_outputs
,
cache_config
.
enable_prefix_caching
,
model_config
.
use_async_output_proc
,
...
...
@@ -277,9 +289,6 @@ class LLMEngine:
model_config
.
mm_processor_kwargs
,
)
# TODO(woosuk): Print more configs in debug mode.
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
self
.
model_config
=
model_config
self
.
cache_config
=
cache_config
self
.
lora_config
=
lora_config
...
...
@@ -625,7 +634,7 @@ class LLMEngine:
def
_add_processed_request
(
self
,
request_id
:
str
,
processed_inputs
:
Union
[
LLM
Inputs
,
EncoderDecoder
LLM
Inputs
],
processed_inputs
:
Union
[
DecoderOnly
Inputs
,
EncoderDecoderInputs
],
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
float
,
lora_request
:
Optional
[
LoRARequest
],
...
...
@@ -689,16 +698,51 @@ class LLMEngine:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
@
overload
# DEPRECATED
def
add_request
(
self
,
request_id
:
str
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
overload
def
add_request
(
self
,
request_id
:
str
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
add_request
(
self
,
request_id
:
str
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
"""Add a request to the engine's request pool.
...
...
@@ -708,8 +752,7 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
...
...
@@ -723,7 +766,7 @@ class LLMEngine:
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `
best_of
` number of :class:`~vllm.Sequence` objects.
- Create `
n
` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
...
...
@@ -744,11 +787,15 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if
inputs
is
not
None
:
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
if
lora_request
is
not
None
and
not
self
.
lora_config
:
raise
ValueError
(
f
"Got lora_request
{
lora_request
}
but LoRA is "
"not enabled!"
)
if
priority
>
0
and
not
self
.
scheduler_config
.
policy
==
"priority"
:
if
priority
!=
0
and
not
self
.
scheduler_config
.
policy
==
"priority"
:
raise
ValueError
(
f
"Got priority
{
priority
}
but "
"Priority scheduling is not enabled."
)
...
...
@@ -756,13 +803,20 @@ class LLMEngine:
arrival_time
=
time
.
time
()
preprocessed_inputs
=
self
.
input_preprocessor
.
preprocess
(
inputs
,
prompt
,
request_id
=
request_id
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
)
processed_inputs
=
self
.
input_processor
(
preprocessed_inputs
)
# This is a bit of a hack - copy the mm_processor_kwargs that were
# used in the input processor to the processed output, since these
# kwargs are presumed to be immutable and the values should be aligned
# between the input processor (here) and the input mapper.
processed_inputs
[
"mm_processor_kwargs"
]
=
preprocessed_inputs
.
get
(
"mm_processor_kwargs"
)
self
.
_add_processed_request
(
request_id
=
request_id
,
processed_inputs
=
processed_inputs
,
...
...
@@ -795,6 +849,9 @@ class LLMEngine:
raise
ValueError
(
f
"Cannot request more than "
f
"
{
max_logprobs
}
logprobs."
)
sampling_params
=
self
.
_build_logits_processors
(
sampling_params
,
lora_request
)
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params
=
sampling_params
.
clone
()
...
...
@@ -911,6 +968,45 @@ class LLMEngine:
return
def
_update_num_computed_tokens_for_multi_step_prefill
(
self
,
seq_group
:
SequenceGroup
,
seq_group_meta
:
SequenceGroupMetadata
,
is_first_step_output
:
Optional
[
bool
]):
"""
This function updates num_computed_tokens for prompt sequences
when Multi-Step is enabled.
seq_group: SequenceGroup to update the num_computed_tokens for.
seq_group_meta: Metadata of the given SequenceGroup.
is_first_step_output: Optional[bool] -
When available, is_first_step_output indicates if the appended
output token is the output of the first-step in multi-step.
A value of None indicates that outputs from all steps in
in multi-step are submitted in a single burst.
"""
assert
self
.
scheduler_config
.
is_multi_step
if
not
seq_group_meta
.
is_prompt
:
# num_computed_token updates for multi-step decodes happen after
# the tokens are appended to the sequence.
return
do_update
:
bool
=
False
if
self
.
scheduler_config
.
chunked_prefill_enabled
:
# In multi-step + chunked-prefill case, the prompt sequences
# that are scheduled are fully processed in the first step.
do_update
=
is_first_step_output
is
None
or
is_first_step_output
else
:
# Normal multi-step decoding case. In this case prompt-sequences
# are actually single-stepped. Always update in this case.
assert
seq_group
.
state
.
num_steps
==
1
do_update
=
True
if
do_update
:
seq_group
.
update_num_computed_tokens
(
seq_group_meta
.
token_chunk_size
)
def
_process_model_outputs
(
self
,
ctx
:
SchedulerContext
,
request_id
:
Optional
[
str
]
=
None
)
->
None
:
...
...
@@ -919,8 +1015,8 @@ class LLMEngine:
ctx: The virtual engine context to work on
request_id: If provided, then only this request is going to be processed
"""
now
=
time
.
time
()
if
len
(
ctx
.
output_queue
)
==
0
:
...
...
@@ -931,20 +1027,28 @@ class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
skip
)
=
ctx
.
output_queue
[
0
]
is_last_step
,
is_first_step_output
,
skip
)
=
ctx
.
output_queue
[
0
]
else
:
(
outputs
,
seq_group_metadata_list
,
scheduler_outputs
,
is_async
,
is_last_step
,
skip
)
=
ctx
.
output_queue
.
popleft
()
is_last_step
,
is_first_step_output
,
skip
)
=
ctx
.
output_queue
.
popleft
()
# Sanity check
assert
len
(
seq_group_metadata_list
)
==
len
(
scheduler_outputs
.
scheduled_seq_groups
)
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if
len
(
outputs
)
>
1
:
has_multiple_outputs
:
bool
=
len
(
outputs
)
>
1
outputs_by_sequence_group
:
List
[
List
[
SequenceGroupOutput
]]
if
has_multiple_outputs
:
assert
self
.
scheduler_config
.
is_multi_step
or
\
self
.
speculative_config
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
outputs_by_sequence_group
=
create_output_by_sequence_group
(
outputs
,
num_seq_groups
=
len
(
seq_group_metadata_list
))
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output
=
None
else
:
outputs_by_sequence_group
=
outputs
...
...
@@ -974,20 +1078,26 @@ class LLMEngine:
seq_group_meta
=
seq_group_metadata_list
[
i
]
scheduled_seq_group
=
scheduler_outputs
.
scheduled_seq_groups
[
i
]
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
:
SequenceGroup
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
finished_before
.
append
(
i
)
continue
if
len
(
outputs
)
>
1
:
output
:
List
[
SequenceGroupOutput
]
if
has_multiple_outputs
:
output
=
outputs_by_sequence_group
[
i
]
else
:
output
=
[
outputs_by_sequence_group
[
0
][
i
]]
if
not
is_async
:
seq_group
.
update_num_computed_tokens
(
scheduled_seq_group
.
token_chunk_size
)
if
self
.
scheduler_config
.
is_multi_step
:
# Updates happen only if the sequence is prefill
self
.
_update_num_computed_tokens_for_multi_step_prefill
(
seq_group
,
seq_group_meta
,
is_first_step_output
)
else
:
seq_group
.
update_num_computed_tokens
(
seq_group_meta
.
token_chunk_size
or
0
)
if
outputs
:
for
o
in
outputs
:
...
...
@@ -995,13 +1105,13 @@ class LLMEngine:
and
seq_group
.
metrics
is
not
None
):
if
seq_group
.
metrics
.
model_forward_time
is
not
None
:
seq_group
.
metrics
.
model_forward_time
+=
(
o
.
model_forward_time
)
o
.
model_forward_time
or
0
)
else
:
seq_group
.
metrics
.
model_forward_time
=
(
o
.
model_forward_time
)
if
seq_group
.
metrics
.
model_execute_time
is
not
None
:
seq_group
.
metrics
.
model_execute_time
+=
(
o
.
model_execute_time
)
o
.
model_execute_time
or
0
)
else
:
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
...
...
@@ -1121,19 +1231,34 @@ class LLMEngine:
if
seq_group
.
is_finished
():
continue
seq_group
.
update_num_computed_tokens
(
seq_group_metadata
.
token_chunk_size
)
if
self
.
scheduler_config
.
is_multi_step
:
# Updates happen only if the sequence is prefill
self
.
_update_num_computed_tokens_for_multi_step_prefill
(
seq_group
,
seq_group_metadata
,
seq_group
.
state
.
num_steps
==
1
)
else
:
token_chunk_size
=
(
seq_group_metadata
.
token_chunk_size
if
seq_group_metadata
.
token_chunk_size
is
not
None
else
0
)
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
if
seq_group_metadata
.
do_sample
:
assert
len
(
sequence_group_outputs
.
samples
)
==
1
,
(
"Async output processor expects a single sample"
" (i.e sampling_params.n == 1 and no "
"sampling_params.best_of > 1)"
)
" (i.e sampling_params.n == 1)"
)
sample
=
sequence_group_outputs
.
samples
[
0
]
assert
len
(
seq_group
.
seqs
)
==
1
seq
=
seq_group
.
seqs
[
0
]
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
if
self
.
scheduler_config
.
is_multi_step
:
is_prefill_append
=
seq
.
data
.
get_num_uncomputed_tokens
(
)
==
0
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
if
not
is_prefill_append
:
seq_group
.
update_num_computed_tokens
(
1
)
else
:
seq
.
append_token_id
(
sample
.
output_token
,
sample
.
logprobs
)
def
step
(
self
)
->
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]:
"""Performs one decoding iteration and returns newly generated results.
...
...
@@ -1286,12 +1411,19 @@ class LLMEngine:
if
self
.
scheduler_config
.
is_multi_step
:
self
.
cached_scheduler_outputs
[
0
]
=
SchedulerOutputState
()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output
:
bool
=
False
if
not
seq_group_metadata_list
\
else
seq_group_metadata_list
[
0
].
state
.
num_steps
==
1
# Add results to the output_queue
ctx
.
append_output
(
outputs
=
outputs
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
is_async
=
allow_async_output_proc
,
is_last_step
=
True
)
is_last_step
=
True
,
is_first_step_output
=
is_first_step_output
)
if
outputs
and
allow_async_output_proc
:
assert
len
(
outputs
)
==
1
,
(
...
...
@@ -1482,7 +1614,6 @@ class LLMEngine:
# Metadata
num_prompt_tokens_requests
:
List
[
int
]
=
[]
num_generation_tokens_requests
:
List
[
int
]
=
[]
best_of_requests
:
List
[
int
]
=
[]
n_requests
:
List
[
int
]
=
[]
finished_reason_requests
:
List
[
str
]
=
[]
...
...
@@ -1553,8 +1684,6 @@ class LLMEngine:
for
seq
in
seq_group
.
get_finished_seqs
()
])
if
seq_group
.
sampling_params
is
not
None
:
best_of_requests
.
append
(
seq_group
.
sampling_params
.
best_of
)
n_requests
.
append
(
seq_group
.
sampling_params
.
n
)
finished_reason_requests
.
extend
([
SequenceStatus
.
get_finished_reason
(
seq
.
status
)
...
...
@@ -1607,7 +1736,6 @@ class LLMEngine:
# Metadata
num_prompt_tokens_requests
=
num_prompt_tokens_requests
,
num_generation_tokens_requests
=
num_generation_tokens_requests
,
best_of_requests
=
best_of_requests
,
n_requests
=
n_requests
,
finished_reason_requests
=
finished_reason_requests
,
)
...
...
@@ -1694,8 +1822,6 @@ class LLMEngine:
seq_group
.
sampling_params
.
top_p
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_REQUEST_MAX_TOKENS
,
seq_group
.
sampling_params
.
max_tokens
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_REQUEST_BEST_OF
,
seq_group
.
sampling_params
.
best_of
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_REQUEST_N
,
seq_group
.
sampling_params
.
n
)
seq_span
.
set_attribute
(
SpanAttributes
.
LLM_USAGE_NUM_SEQUENCES
,
...
...
@@ -1732,8 +1858,8 @@ class LLMEngine:
def
is_embedding_model
(
self
):
return
self
.
model_config
.
is_embedding_model
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
LLM
Inputs
,
EncoderDecoder
LLM
Inputs
]):
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
DecoderOnly
Inputs
,
EncoderDecoderInputs
]):
if
self
.
model_config
.
is_multimodal_model
:
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
...
...
@@ -1760,4 +1886,52 @@ class LLMEngine:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
\ No newline at end of file
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def
_build_logits_processors
(
self
,
sampling_params
:
SamplingParams
,
lora_request
:
Optional
[
LoRARequest
])
->
SamplingParams
:
"""Constructs logits processors based on the guided_decoding,
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
those fields and adds the constructed logits processors to the
logits_processors field. Returns the modified sampling params."""
logits_processors
=
[]
if
(
guided_decoding
:
=
sampling_params
.
guided_decoding
)
is
not
None
:
logger
.
debug
(
"Building guided decoding logits processor in "
"LLMEngine. Params: %s"
,
guided_decoding
)
tokenizer
=
self
.
get_tokenizer
(
lora_request
=
lora_request
)
guided_decoding
.
backend
=
guided_decoding
.
backend
or
\
self
.
decoding_config
.
guided_decoding_backend
processor
=
get_local_guided_decoding_logits_processor
(
guided_params
=
guided_decoding
,
tokenizer
=
tokenizer
)
if
processor
:
logits_processors
.
append
(
processor
)
# Unset so this doesn't get passed down to the model
sampling_params
.
guided_decoding
=
None
if
(
sampling_params
.
logit_bias
or
sampling_params
.
allowed_token_ids
):
tokenizer
=
self
.
get_tokenizer
(
lora_request
=
lora_request
)
processors
=
get_logits_processors
(
logit_bias
=
sampling_params
.
logit_bias
,
allowed_token_ids
=
sampling_params
.
allowed_token_ids
,
tokenizer
=
tokenizer
)
logits_processors
.
extend
(
processors
)
# Unset so these don't get passed down to the model
sampling_params
.
logit_bias
=
None
sampling_params
.
allowed_token_ids
=
None
if
logits_processors
:
if
sampling_params
.
logits_processors
is
None
:
sampling_params
.
logits_processors
=
logits_processors
else
:
sampling_params
.
logits_processors
.
extend
(
logits_processors
)
return
sampling_params
vllm/engine/metrics.py
View file @
6d2051cc
from
typing
import
TYPE_CHECKING
from
typing
import
Counter
as
CollectionsCounter
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Type
,
Union
,
cast
import
numpy
as
np
import
prometheus_client
...
...
@@ -134,12 +134,6 @@ class Metrics:
labelnames
=
labelnames
,
buckets
=
build_1_2_5_buckets
(
max_model_len
),
)
self
.
histogram_best_of_request
=
self
.
_histogram_cls
(
name
=
"vllm:request_params_best_of"
,
documentation
=
"Histogram of the best_of request parameter."
,
labelnames
=
labelnames
,
buckets
=
[
1
,
2
,
5
,
10
,
20
],
)
self
.
histogram_n_request
=
self
.
_histogram_cls
(
name
=
"vllm:request_params_n"
,
documentation
=
"Histogram of the n request parameter."
,
...
...
@@ -255,10 +249,11 @@ class _RayHistogramWrapper:
labelnames
:
Optional
[
List
[
str
]]
=
None
,
buckets
:
Optional
[
List
[
float
]]
=
None
):
labelnames_tuple
=
tuple
(
labelnames
)
if
labelnames
else
None
boundaries
=
buckets
if
buckets
else
[]
self
.
_histogram
=
ray_metrics
.
Histogram
(
name
=
name
,
description
=
documentation
,
tag_keys
=
labelnames_tuple
,
boundaries
=
b
ucket
s
)
boundaries
=
b
oundarie
s
)
def
labels
(
self
,
**
labels
):
self
.
_histogram
.
set_default_tags
(
labels
)
...
...
@@ -273,9 +268,12 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_gauge_cls
=
_RayGaugeWrapper
_counter_cls
=
_RayCounterWrapper
_histogram_cls
=
_RayHistogramWrapper
_gauge_cls
:
Type
[
prometheus_client
.
Gauge
]
=
cast
(
Type
[
prometheus_client
.
Gauge
],
_RayGaugeWrapper
)
_counter_cls
:
Type
[
prometheus_client
.
Counter
]
=
cast
(
Type
[
prometheus_client
.
Counter
],
_RayCounterWrapper
)
_histogram_cls
:
Type
[
prometheus_client
.
Histogram
]
=
cast
(
Type
[
prometheus_client
.
Histogram
],
_RayHistogramWrapper
)
def
__init__
(
self
,
labelnames
:
List
[
str
],
max_model_len
:
int
):
if
ray_metrics
is
None
:
...
...
@@ -473,8 +471,6 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
metrics
.
histogram_num_generation_tokens_request
,
stats
.
num_generation_tokens_requests
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_n_request
,
stats
.
n_requests
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_best_of_request
,
stats
.
best_of_requests
)
def
_log_prometheus_interval
(
self
,
prompt_throughput
:
float
,
generation_throughput
:
float
)
->
None
:
...
...
vllm/engine/metrics_types.py
View file @
6d2051cc
...
...
@@ -49,7 +49,6 @@ class Stats:
# Metadata
num_prompt_tokens_requests
:
List
[
int
]
num_generation_tokens_requests
:
List
[
int
]
best_of_requests
:
List
[
int
]
n_requests
:
List
[
int
]
finished_reason_requests
:
List
[
str
]
...
...
vllm/engine/multiprocessing/__init__.py
View file @
6d2051cc
from
dataclasses
import
dataclass
from
enum
import
Enum
from
typing
import
List
,
Mapping
,
Optional
,
Union
from
typing
import
List
,
Mapping
,
Optional
,
Union
,
overload
from
vllm
import
PoolingParams
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.lora.request
import
LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
deprecate_kwargs
VLLM_RPC_SUCCESS_STR
=
"SUCCESS"
...
...
@@ -23,12 +24,71 @@ class MQEngineDeadError(RuntimeError):
@
dataclass
class
RPCProcessRequest
:
inputs
:
Prompt
Inputs
prompt
:
Prompt
Type
params
:
Union
[
SamplingParams
,
PoolingParams
]
request_id
:
str
lora_request
:
Optional
[
LoRARequest
]
=
None
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
priority
:
int
=
0
@
overload
# DEPRECATED
def
__init__
(
self
,
*
,
inputs
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
overload
def
__init__
(
self
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
None
:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
__init__
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
params
is
not
None
and
request_id
is
not
None
)
super
().
__init__
()
self
.
prompt
=
prompt
self
.
params
=
params
self
.
request_id
=
request_id
self
.
lora_request
=
lora_request
self
.
trace_headers
=
trace_headers
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
priority
=
priority
@
dataclass
...
...
vllm/engine/multiprocessing/client.py
View file @
6d2051cc
...
...
@@ -2,8 +2,8 @@ import asyncio
import
copy
import
pickle
from
contextlib
import
contextmanager
,
suppress
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
Mapping
,
Optional
,
Union
)
from
typing
import
(
Any
,
AsyncGenerator
,
Dict
,
Iterator
,
List
,
Mapping
,
Optional
,
Union
,
cast
,
overload
)
import
cloudpickle
import
zmq
...
...
@@ -13,9 +13,12 @@ from zmq.asyncio import Socket
from
vllm
import
PoolingParams
from
vllm.config
import
DecodingConfig
,
EngineConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.engine.async_llm_engine
import
(
build_guided_decoding_logits_processor_async
)
from
vllm.engine.multiprocessing
import
(
ENGINE_DEAD_ERROR
,
IPC_DATA_EXT
,
IPC_HEALTH_EXT
,
IPC_INPUT_EXT
,
IPC_OUTPUT_EXT
,
RPC_REQUEST_T
,
...
...
@@ -23,15 +26,18 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCError
,
RPCProcessRequest
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
from
vllm.engine.protocol
import
EngineClient
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.inputs
import
Prompt
Inputs
from
vllm.inputs
import
Prompt
Type
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.utils
import
deprecate_kwargs
logger
=
init_logger
(
__name__
)
...
...
@@ -47,7 +53,7 @@ class MQClientClosedError(Exception):
"""
class
MQLLMEngineClient
:
class
MQLLMEngineClient
(
EngineClient
)
:
"""A client wrapper for MQLLMEngine that conforms to the
EngineClient protocol.
...
...
@@ -310,7 +316,7 @@ class MQLLMEngineClient:
or
response
!=
VLLM_RPC_SUCCESS_STR
):
raise
ValueError
(
error_message
)
async
def
get_tokenizer
(
self
,
lora_request
:
LoRARequest
):
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
):
return
await
self
.
tokenizer
.
get_lora_tokenizer_async
(
lora_request
)
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
...
...
@@ -338,8 +344,14 @@ class MQLLMEngineClient:
await
self
.
_send_one_way_rpc_request
(
request
=
RPCAbortRequest
(
request_id
),
socket
=
self
.
input_socket
)
async
def
do_log_stats
(
self
):
"""Ignore do_log_stats (handled on MQLLMEngine polling)"""
async
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
model_output
:
Optional
[
List
[
SamplerOutput
]]
=
None
,
)
->
None
:
"""
Ignore do_log_stats (handled on MQLLMEngine polling)
"""
pass
async
def
check_health
(
self
):
...
...
@@ -367,14 +379,48 @@ class MQLLMEngineClient:
def
dead_error
(
self
)
->
BaseException
:
return
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
@
overload
# DEPRECATED
def
generate
(
self
,
inputs
:
PromptInputs
,
*
,
inputs
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
overload
def
generate
(
self
,
prompt
:
PromptType
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
generate
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate outputs for a request.
...
...
@@ -383,8 +429,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -392,18 +437,58 @@ class MQLLMEngineClient:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
priority: Priority of the request (lower means earlier handling).
Any priority other than 0 will lead to an error if the
scheduling policy is not "priority".
"""
return
self
.
_process_request
(
inputs
,
sampling_params
,
request_id
,
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
sampling_params
is
not
None
and
request_id
is
not
None
)
return
self
.
_process_request
(
prompt
,
sampling_params
,
request_id
,
lora_request
,
trace_headers
,
prompt_adapter_request
)
prompt_adapter_request
,
priority
)
@
overload
# DEPRECATED
def
encode
(
self
,
*
,
inputs
:
PromptType
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
@
overload
def
encode
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
...
@
deprecate_kwargs
(
"inputs"
,
additional_message
=
"Please use the 'prompt' parameter instead."
,
)
def
encode
(
self
,
prompt
:
Optional
[
PromptType
]
=
None
,
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
request_id
:
Optional
[
str
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
# DEPRECATED
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model.
...
...
@@ -412,8 +497,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
...
...
@@ -424,17 +508,29 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return
self
.
_process_request
(
inputs
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
)
if
inputs
is
not
None
:
prompt
=
inputs
assert
(
prompt
is
not
None
and
pooling_params
is
not
None
and
request_id
is
not
None
)
return
cast
(
AsyncGenerator
[
EmbeddingRequestOutput
,
None
],
self
.
_process_request
(
prompt
,
pooling_params
,
request_id
,
lora_request
,
trace_headers
,
priority
=
priority
))
async
def
_process_request
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
Union
[
AsyncGenerator
[
RequestOutput
,
None
],
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
...
...
@@ -443,6 +539,20 @@ class MQLLMEngineClient:
if
self
.
_errored_with
is
not
None
:
raise
ENGINE_DEAD_ERROR
(
self
.
_errored_with
)
# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if
isinstance
(
params
,
SamplingParams
)
and
\
params
.
guided_decoding
is
not
None
:
params
=
await
\
build_guided_decoding_logits_processor_async
(
sampling_params
=
params
,
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
),
default_guided_backend
=
(
self
.
decoding_config
.
guided_decoding_backend
if
self
.
decoding_config
else
DecodingConfig
.
guided_decoding_backend
),
)
# 1) Create output queue for this requests.
queue
:
asyncio
.
Queue
[
Union
[
RequestOutput
,
BaseException
]]
=
asyncio
.
Queue
()
...
...
@@ -462,12 +572,14 @@ class MQLLMEngineClient:
request_bytes
=
pickle
.
dumps
(
RPCProcessRequest
(
inputs
=
inputs
,
prompt
=
prompt
,
params
=
params
,
request_id
=
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
))
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts
=
(
request_bytes
,
...
...
vllm/engine/multiprocessing/engine.py
View file @
6d2051cc
...
...
@@ -73,11 +73,9 @@ class MQLLMEngine:
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and send over the socket, which frees
# the python object to be reused again.
use_cached_outputs
=
True
kwargs
[
'
use_cached_outputs
'
]
=
True
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
,
use_cached_outputs
=
use_cached_outputs
)
self
.
engine
=
LLMEngine
(
*
args
,
**
kwargs
)
self
.
log_requests
=
log_requests
self
.
use_async_sockets
=
use_async_sockets
...
...
@@ -130,6 +128,9 @@ class MQLLMEngine:
def
from_engine_args
(
cls
,
engine_args
:
AsyncEngineArgs
,
usage_context
:
UsageContext
,
ipc_path
:
str
):
"""Creates an MQLLMEngine from the engine arguments."""
# Setup plugins for each process
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
engine_config
=
engine_args
.
create_engine_config
()
...
...
@@ -278,11 +279,12 @@ class MQLLMEngine:
try
:
self
.
engine
.
add_request
(
request_id
=
request_id
,
inputs
=
request
.
inputs
,
prompt
=
request
.
prompt
,
params
=
request
.
params
,
lora_request
=
request
.
lora_request
,
trace_headers
=
request
.
trace_headers
,
prompt_adapter_request
=
request
.
prompt_adapter_request
)
prompt_adapter_request
=
request
.
prompt_adapter_request
,
priority
=
request
.
priority
)
if
self
.
log_requests
:
logger
.
info
(
"Added request %s."
,
request
.
request_id
)
...
...
vllm/engine/output_processor/multi_step.py
View file @
6d2051cc
import
functools
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
cast
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.output_processor.interfaces
import
(
...
...
@@ -9,8 +9,10 @@ from vllm.engine.output_processor.single_step import (
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.sequence
import
(
VLLM_INVALID_TOKEN_ID
,
CompletionSequenceGroupOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
Counter
...
...
@@ -57,11 +59,14 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
"""
for
output
in
outputs
:
# Concatenate single-step prompt logprob processing results.
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
@
staticmethod
@
functools
.
lru_cache
()
def
_log_prompt_logprob_unsupported_warning_once
():
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
logger
.
warning
(
"Prompt logprob is not supported by multi step workers. "
"(e.g., speculative decode uses multi step workers)."
)
...
...
@@ -97,6 +102,19 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
assert
len
(
seqs
)
==
1
,
(
"Beam search not supported in multi-step decoding."
)
seq
=
seqs
[
0
]
seq_id
=
seq
.
seq_id
# This method is defined in the more generic
# SequenceGroupOutputProcessor, but here we assume that the outputs are
# of a more specific type.
assert
all
([
isinstance
(
output
,
CompletionSequenceGroupOutput
)
for
output
in
outputs
])
compl_outputs
=
cast
(
List
[
CompletionSequenceGroupOutput
],
outputs
)
assert
all
([
seq_id
==
output
.
samples
[
0
].
parent_seq_id
for
output
in
compl_outputs
])
if
is_async
:
# Async case: We process tokens one by one. Here, we know the token
...
...
@@ -108,7 +126,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
# Since there's only one sequence per sequence group,
# we can take the first sample.
samples
=
[
output
.
samples
[
0
]
for
output
in
outputs
]
samples
=
[
output
.
samples
[
0
]
for
output
in
compl_
outputs
]
# entries in sample tokens may be invalid (eg. due to spec decode
# rejecting tokens).
...
...
@@ -145,7 +163,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
remaining_tokens
=
sampling_params
.
max_tokens
-
(
seq
.
get_output_len
()
+
len
(
output_token_ids
))
if
remaining_tokens
<
0
:
valid_samples
=
valid_samples
[:
remaining_tokens
]
output_token_ids
=
output_token_ids
[:
remaining_tokens
]
# Truncate any tokens after EOS. This is required as spec decode
...
...
@@ -159,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
for
i
in
range
(
len
(
output_token_ids
)):
if
output_token_ids
[
i
]
==
eos_token_id
:
output_token_ids
=
output_token_ids
[:
i
+
1
]
valid_samples
=
valid_samples
[:
i
+
1
]
break
is_prefill_sampled_token
=
seq
.
data
.
get_num_uncomputed_tokens
()
==
0
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for
output_token_id
,
output_logprob
in
zip
(
output_token_ids
,
...
...
@@ -171,6 +188,13 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
logprobs
=
output_logprob
,
)
if
is_prefill_sampled_token
:
is_prefill_sampled_token
=
False
else
:
# Update num_computed_tokens iff the sampled token is not from
# a prefill step.
seq
.
data
.
update_num_computed_tokens
(
1
)
self
.
_process_decode_and_stop
(
seq
,
sampling_params
)
if
seq
.
is_finished
():
...
...
vllm/engine/output_processor/single_step.py
View file @
6d2051cc
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Tuple
from
vllm.config
import
SchedulerConfig
from
vllm.core.scheduler
import
Scheduler
...
...
@@ -6,9 +6,9 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor
)
from
vllm.engine.output_processor.stop_checker
import
StopChecker
from
vllm.logger
import
init_logger
from
vllm.s
ampling_params
import
SamplingParams
from
vllm.sequence
import
(
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.s
equence
import
(
CompletionSequenceGroupOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupOutput
,
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.detokenizer
import
Detokenizer
from
vllm.utils
import
Counter
...
...
@@ -17,7 +17,7 @@ logger = init_logger(__name__)
def
single_step_process_prompt_logprob
(
sg_output_proc
:
SequenceGroupOutputProcessor
,
seq_group
:
SequenceGroup
,
output
:
SequenceGroupOutput
)
->
None
:
output
:
Completion
SequenceGroupOutput
)
->
None
:
"""Process prompt logprobs associated with the :class:`SequenceGroupOutput`
for a given step.
...
...
@@ -107,13 +107,14 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
"""
assert
len
(
outputs
)
==
1
,
(
"Single step should only has 1 output."
)
output
=
outputs
[
0
]
assert
isinstance
(
output
,
CompletionSequenceGroupOutput
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
,
is_async
:
bool
)
->
None
:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
.
best_of
==
1
and
not
sampling_params
.
use_beam_search
:
if
sampling_params
.
n
==
1
:
# only have one output sample
sample
=
outputs
.
samples
[
0
]
# only have one sequence
...
...
@@ -142,7 +143,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Process samples
samples
=
outputs
.
samples
parent_seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
existing_finished_seqs
=
seq_group
.
get_finished_seqs
()
parent_child_dict
:
Dict
[
int
,
List
[
SequenceOutput
]]
=
{
parent_seq
.
seq_id
:
[]
for
parent_seq
in
parent_seqs
...
...
@@ -197,106 +197,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
lora_req
=
seq_group
.
lora_request
,
)
# Non-beam search case
if
not
sampling_params
.
use_beam_search
:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for
seq
,
parent
in
child_seqs
:
if
seq
is
not
parent
:
seq_group
.
add
(
seq
)
if
not
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
fork_seq
(
parent
,
seq
)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for
seq
,
parent
in
child_seqs
:
if
seq
is
parent
and
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs
:
List
[
Tuple
[
Sequence
,
Optional
[
Sequence
]]]
=
[]
unselected_child_seqs
:
List
[
Tuple
[
Sequence
,
Optional
[
Sequence
]]]
=
[]
beam_width
=
sampling_params
.
best_of
length_penalty
=
sampling_params
.
length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs
=
[(
seq
,
None
,
False
)
for
seq
in
existing_finished_seqs
]
new_finished_seqs
=
[(
seq
,
parent
,
True
)
for
seq
,
parent
in
child_seqs
if
seq
.
is_finished
()]
all_finished_seqs
=
existing_finished_seqs
+
new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
x
[
0
].
eos_token_id
),
reverse
=
True
)
for
seq
,
parent
,
is_new
in
all_finished_seqs
[:
beam_width
]:
if
is_new
:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs
.
append
((
seq
,
parent
))
for
seq
,
parent
,
is_new
in
all_finished_seqs
[
beam_width
:]:
if
is_new
:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs
.
append
((
seq
,
parent
))
else
:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group
.
remove
(
seq
.
seq_id
)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs
=
[(
seq
,
parent
)
for
seq
,
parent
in
child_seqs
if
not
seq
.
is_finished
()]
# Sort the running sequences by their scores.
running_child_seqs
.
sort
(
key
=
lambda
x
:
x
[
0
].
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
x
[
0
].
eos_token_id
),
reverse
=
True
)
# Check if we can stop the beam search.
if
len
(
running_child_seqs
)
==
0
:
# No running sequences, stop the beam search.
stop_beam_search
=
True
elif
len
(
all_finished_seqs
)
<
beam_width
:
# Not enough finished sequences, continue the beam search.
stop_beam_search
=
False
else
:
# Check the early stopping criteria
best_running_seq
=
running_child_seqs
[
0
][
0
]
current_worst_seq
=
all_finished_seqs
[
beam_width
-
1
][
0
]
stop_beam_search
=
self
.
_check_beam_search_early_stopping
(
sampling_params
.
early_stopping
,
sampling_params
,
best_running_seq
,
current_worst_seq
)
if
stop_beam_search
:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs
.
extend
(
running_child_seqs
)
else
:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs
.
extend
(
running_child_seqs
[:
beam_width
])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs
.
extend
(
running_child_seqs
[
beam_width
:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for
seq
,
parent
in
selected_
child_seqs
:
for
seq
,
parent
in
child_seqs
:
if
seq
is
not
parent
:
seq_group
.
add
(
seq
)
if
not
seq
.
is_finished
():
...
...
@@ -305,61 +208,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for
seq
,
parent
in
selected_child_seqs
:
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for
seq
,
parent
in
child_seqs
:
if
seq
is
parent
and
seq
.
is_finished
():
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for
seq
,
parent
in
unselected_child_seqs
:
if
seq
is
parent
:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group
.
remove
(
seq
.
seq_id
)
for
scheduler
in
self
.
scheduler
:
scheduler
.
free_seq
(
seq
)
def
_check_beam_search_early_stopping
(
self
,
early_stopping
:
Union
[
bool
,
str
],
sampling_params
:
SamplingParams
,
best_running_seq
:
Sequence
,
current_worst_seq
:
Sequence
,
)
->
bool
:
assert
sampling_params
.
use_beam_search
length_penalty
=
sampling_params
.
length_penalty
if
early_stopping
is
True
:
return
True
current_worst_score
=
current_worst_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
current_worst_seq
.
eos_token_id
)
if
early_stopping
is
False
:
highest_attainable_score
=
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
best_running_seq
.
eos_token_id
)
else
:
assert
early_stopping
==
"never"
if
length_penalty
>
0.0
:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length
=
max
(
best_running_seq
.
get_prompt_len
()
+
sampling_params
.
max_tokens
,
self
.
scheduler_config
.
max_model_len
)
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
best_running_seq
.
eos_token_id
,
seq_len
=
max_possible_length
))
else
:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score
=
(
best_running_seq
.
get_beam_search_score
(
length_penalty
=
length_penalty
,
eos_token_id
=
best_running_seq
.
eos_token_id
))
return
current_worst_score
>=
highest_attainable_score
return
vllm/engine/output_processor/stop_checker.py
View file @
6d2051cc
...
...
@@ -57,7 +57,7 @@ class StopChecker:
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id
=
seq
.
get_last_token_id
()
if
last_token_id
in
sampling_params
.
stop_token_ids
:
if
last_token_id
in
(
sampling_params
.
stop_token_ids
or
())
:
if
new_char_count
and
(
not
sampling_params
.
include_stop_str_in_output
):
# Remove last token
...
...
@@ -92,7 +92,7 @@ class StopChecker:
Returns the stop string if matched or else None.
"""
if
not
new_char_count
:
if
not
new_char_count
or
not
sampling_params
.
stop
:
return
None
for
stop_str
in
sampling_params
.
stop
:
...
...
vllm/engine/output_processor/util.py
View file @
6d2051cc
from
typing
import
List
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Union
from
typing
import
cast
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
Pooler
Output
,
SequenceGroupOutput
from
vllm.sequence
import
CompletionSequenceGroup
Output
,
SequenceGroupOutput
def
create_output_by_sequence_group
(
outputs
:
GenericSequence
[
Union
[
SamplerOutput
,
PoolerOutput
]
],
outputs
:
GenericSequence
[
SamplerOutput
],
num_seq_groups
:
int
)
->
List
[
List
[
SequenceGroupOutput
]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group
:
List
[
List
[
SequenceGroupOutput
]]
=
[
output_by_sequence_group
:
List
[
List
[
Completion
SequenceGroupOutput
]]
=
[
[]
for
_
in
range
(
num_seq_groups
)
]
for
step
in
outputs
:
sequence_group_output
:
CompletionSequenceGroupOutput
for
i
,
sequence_group_output
in
enumerate
(
step
):
output_by_sequence_group
[
i
].
append
(
sequence_group_output
)
return
output_by_sequence_group
# Cast to the more generic type that CompletionSequenceGroupOutput
# inherits from.
return
cast
(
List
[
List
[
SequenceGroupOutput
]],
output_by_sequence_group
)
vllm/engine/protocol.py
View file @
6d2051cc
from
typing
import
(
AsyncGenerator
,
List
,
Mapping
,
Optional
,
Protocol
,
runtime_checkable
)
import
asyncio
from
abc
import
ABC
,
abstractmethod
from
typing
import
AsyncGenerator
,
List
,
Mapping
,
Optional
,
Union
from
vllm.beam_search
import
BeamSearchSequence
,
create_sort_beams_key_function
from
vllm.config
import
DecodingConfig
,
ModelConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.inputs.data
import
PromptInputs
from
vllm.inputs.data
import
PromptType
,
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.outputs
import
(
CompletionOutput
,
EmbeddingRequestOutput
,
RequestOutput
)
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
collect_from_async_generator
,
random_uuid
logger
=
init_logger
(
__name__
)
@
runtime_checkable
class
EngineClient
(
Protocol
):
class
EngineClient
(
ABC
):
"""Protocol class for Clients to Engine"""
@
property
@
abstractmethod
def
is_running
(
self
)
->
bool
:
...
@
property
@
abstractmethod
def
is_stopped
(
self
)
->
bool
:
...
@
property
@
abstractmethod
def
errored
(
self
)
->
bool
:
...
@
property
@
abstractmethod
def
dead_error
(
self
)
->
BaseException
:
...
@
abstractmethod
def
generate
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
sampling_params
:
SamplingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
"""Generate
s
outputs for a request"""
"""Generate outputs for a request
.
"""
...
async
def
beam_search
(
self
,
prompt
:
Union
[
PromptType
,
List
[
int
]],
request_id
:
str
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
ignore_eos
=
params
.
ignore_eos
temperature
=
params
.
temperature
length_penalty
=
params
.
length_penalty
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizedPrompt
=
prompt
if
isinstance
(
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
temperature
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenizedPrompt
,
cum_logprob
=
0
)]
completed
=
[]
for
_
in
range
(
max_tokens
):
prompts_batch
=
[
TokensPrompt
(
prompt_token_ids
=
beam
.
tokens
)
for
beam
in
all_beams
]
tasks
=
[]
request_id
=
f
"beam_search-
{
random_uuid
()
}
"
for
i
,
individual_prompt
in
enumerate
(
prompts_batch
):
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
task
=
asyncio
.
create_task
(
collect_from_async_generator
(
self
.
generate
(
individual_prompt
,
beam_search_params
,
request_id_item
)))
tasks
.
append
(
task
)
output
=
await
asyncio
.
gather
(
*
tasks
)
output
=
[
x
[
0
]
for
x
in
output
]
new_beams
=
[]
for
i
,
current_beam
in
enumerate
(
all_beams
):
result
=
output
[
i
]
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
if
token_id
==
tokenizer
.
eos_token_id
and
\
not
ignore_eos
:
completed
.
append
(
new_beam
)
else
:
new_beams
.
append
(
new_beam
)
sorted_beams
=
sorted
(
new_beams
,
key
=
sort_beams_key
,
reverse
=
True
)
all_beams
=
sorted_beams
[:
beam_width
]
completed
.
extend
(
all_beams
)
sorted_completed
=
sorted
(
completed
,
key
=
sort_beams_key
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenizedLength
:])
beam_search_output
=
RequestOutput
(
request_id
=
request_id
,
prompt
=
prompt
,
outputs
=
[
CompletionOutput
(
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
,
index
=
i
,
logprobs
=
beam
.
cum_logprob
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
finished
=
True
,
prompt_token_ids
=
tokenizedPrompt
,
prompt_logprobs
=
None
)
yield
beam_search_output
@
abstractmethod
def
encode
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
pooling_params
:
PoolingParams
,
request_id
:
str
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
)
->
AsyncGenerator
[
EmbeddingRequestOutput
,
None
]:
"""Generate outputs for a request from an embedding model."""
...
@
abstractmethod
async
def
abort
(
self
,
request_id
:
str
)
->
None
:
"""Abort a request.
...
...
@@ -63,14 +172,17 @@ class EngineClient(Protocol):
request_id: The unique id of the request.
"""
@
abstractmethod
async
def
get_model_config
(
self
)
->
ModelConfig
:
"""Get the model configuration of the vLLM engine."""
...
@
abstractmethod
async
def
get_decoding_config
(
self
)
->
DecodingConfig
:
...
"""Get the decoding configuration of the vLLM engine."""
@
abstractmethod
async
def
get_tokenizer
(
self
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
...
...
@@ -78,9 +190,11 @@ class EngineClient(Protocol):
"""Get the appropriate tokenizer for the request"""
...
@
abstractmethod
async
def
is_tracing_enabled
(
self
)
->
bool
:
...
@
abstractmethod
async
def
do_log_stats
(
self
,
scheduler_outputs
:
Optional
[
SchedulerOutputs
]
=
None
,
...
...
@@ -88,14 +202,17 @@ class EngineClient(Protocol):
)
->
None
:
...
@
abstractmethod
async
def
check_health
(
self
)
->
None
:
"""Raise if unhealthy"""
...
@
abstractmethod
async
def
start_profile
(
self
)
->
None
:
"""Start profiling the engine"""
...
@
abstractmethod
async
def
stop_profile
(
self
)
->
None
:
"""Start profiling the engine"""
...
vllm/entrypoints/chat_utils.py
View file @
6d2051cc
...
...
@@ -157,22 +157,24 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
if
model_type
.
startswith
(
"llava"
):
return
self
.
_cached_token_str
(
self
.
_tokenizer
,
hf_config
.
image_token_index
)
if
model_type
in
(
"chameleon"
,
"internvl_chat"
):
if
model_type
in
(
"chameleon"
,
"internvl_chat"
,
"NVLM_D"
):
return
"<image>"
if
model_type
==
"mllama"
:
return
"<|image|>"
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|image_pad|><|vision_end|>"
if
model_type
==
"molmo"
:
return
""
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
raise
TypeError
(
f
"Unknown
{
modality
}
model type:
{
model_type
}
"
)
elif
modality
==
"audio"
:
if
model_type
==
"ultravox"
:
return
"<|reserved_special_token_0|>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
raise
TypeError
(
f
"Unknown
{
modality
}
model type:
{
model_type
}
"
)
elif
modality
==
"video"
:
if
model_type
==
"qwen2_vl"
:
return
"<|vision_start|><|video_pad|><|vision_end|>"
raise
TypeError
(
f
"Unknown model type:
{
model_type
}
"
)
raise
TypeError
(
f
"Unknown
{
modality
}
model type:
{
model_type
}
"
)
else
:
raise
TypeError
(
f
"Unknown modality:
{
modality
}
"
)
...
...
@@ -303,6 +305,28 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self
.
_add_placeholder
(
placeholder
)
def
validate_chat_template
(
chat_template
:
Optional
[
Union
[
Path
,
str
]]):
"""Raises if the provided chat template appears invalid."""
if
chat_template
is
None
:
return
elif
isinstance
(
chat_template
,
Path
)
and
not
chat_template
.
exists
():
raise
FileNotFoundError
(
"the supplied chat template path doesn't exist"
)
elif
isinstance
(
chat_template
,
str
):
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
)
and
not
Path
(
chat_template
).
exists
():
raise
ValueError
(
f
"The supplied chat template string (
{
chat_template
}
) "
f
"appears path-like, but doesn't exist!"
)
else
:
raise
TypeError
(
f
"
{
type
(
chat_template
)
}
is not a valid chat template type"
)
def
load_chat_template
(
chat_template
:
Optional
[
Union
[
Path
,
str
]])
->
Optional
[
str
]:
if
chat_template
is
None
:
...
...
@@ -542,6 +566,14 @@ def apply_mistral_chat_template(
if
chat_template
is
not
None
:
logger
.
warning
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
if
"add_generation_prompt"
in
kwargs
:
logger
.
warning
(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if
"continue_final_message"
in
kwargs
:
logger
.
warning
(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
)
return
tokenizer
.
apply_chat_template
(
messages
=
messages
,
...
...
vllm/entrypoints/llm.py
View file @
6d2051cc
import
itertools
import
warnings
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
,
cast
,
overload
)
from
tqdm
import
tqdm
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
BeamSearchSequence
,
get_beam_search_score
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
parse_chat_messages
)
from
vllm.inputs
import
Prompt
Inputs
,
TextPrompt
,
TokensPrompt
from
vllm.inputs
import
Prompt
Type
,
TextPrompt
,
TokensPrompt
from
vllm.inputs.parse
import
parse_and_batch_prompt
from
vllm.logger
import
init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor.guided_decoding
import
(
GuidedDecodingRequest
,
get_local_guided_decoding_logits_processor
)
from
vllm.model_executor.guided_decoding.guided_fields
import
LLMGuidedOptions
from
vllm.model_executor.guided_decoding.guided_fields
import
(
GuidedDecodingRequest
,
LLMGuidedOptions
)
from
vllm.outputs
import
EmbeddingRequestOutput
,
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
RequestOutputKind
,
SamplingParams
)
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
MistralTokenizer
,
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
...
...
@@ -32,37 +34,6 @@ from vllm.utils import Counter, deprecate_kwargs, is_list_of
logger
=
init_logger
(
__name__
)
@
dataclass
class
BeamSearchSequence
:
"""A sequence for beam search.
It keeps track of the tokens and the log probability of the sequence.
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
tokens
:
List
[
int
]
cum_logprob
:
float
=
0.0
text
:
Optional
[
str
]
=
None
@
dataclass
class
BeamSearchOutput
:
"""The output of beam search.
It contains the list of the best beam search sequences.
The length of the list is equal to the beam width.
"""
sequences
:
List
[
BeamSearchSequence
]
class
BeamSearchInstance
:
def
__init__
(
self
,
prompt_tokens
:
List
[
int
]):
self
.
beams
:
List
[
BeamSearchSequence
]
=
[
BeamSearchSequence
(
tokens
=
prompt_tokens
)
]
self
.
completed
:
List
[
BeamSearchSequence
]
=
[]
class
LLM
:
"""An LLM for generating texts from given prompts and sampling parameters.
...
...
@@ -179,15 +150,7 @@ class LLM:
if
"disable_log_stats"
not
in
kwargs
:
kwargs
[
"disable_log_stats"
]
=
True
removed_vision_keys
=
(
"image_token_id"
,
"image_feature_size"
,
"image_input_shape"
,
"image_input_type"
,
)
if
any
(
k
in
kwargs
for
k
in
removed_vision_keys
):
raise
TypeError
(
"There is no need to pass vision-related arguments anymore."
)
engine_args
=
EngineArgs
(
model
=
model
,
tokenizer
=
tokenizer
,
...
...
@@ -293,8 +256,8 @@ class LLM:
@
overload
def
generate
(
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
*
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
...
...
@@ -304,14 +267,13 @@ class LLM:
...
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the '
inpu
ts' parameter instead."
,
additional_message
=
"Please use the '
promp
ts' parameter instead."
,
)
def
generate
(
self
,
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
sampling_params
:
Optional
[
Union
[
SamplingParams
,
Sequence
[
SamplingParams
]]]
=
None
,
...
...
@@ -330,7 +292,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inputs: A list of inputs to generate completions for.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
...
...
@@ -358,12 +322,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration)."
)
if
prompt_token_ids
is
not
None
:
inpu
ts
=
self
.
_convert_v1_inputs
(
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
)
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
if
isinstance
(
guided_options_request
,
dict
):
if
len
(
guided_options_request
)
>
1
:
...
...
@@ -378,7 +343,7 @@ class LLM:
sampling_params
=
SamplingParams
()
self
.
_validate_and_add_requests
(
inputs
=
inpu
ts
,
prompts
=
parsed_promp
ts
,
params
=
sampling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -391,9 +356,7 @@ class LLM:
def
beam_search
(
self
,
prompts
:
List
[
Union
[
str
,
List
[
int
]]],
beam_width
:
int
,
max_tokens
:
int
,
ignore_eos
:
bool
=
False
,
params
:
BeamSearchParams
,
)
->
List
[
BeamSearchOutput
]:
"""
Generate sequences using beam search.
...
...
@@ -401,20 +364,30 @@ class LLM:
Args:
prompts: A list of prompts. Each prompt can be a string or a list
of token IDs.
beam_width: The number of beams to keep at each step.
max_tokens: The max number of tokens to generate for each prompt.
params: The beam search parameters.
TODO: how does beam search work together with length penalty, frequency
penalty, and stopping criteria, etc.?
"""
beam_width
=
params
.
beam_width
max_tokens
=
params
.
max_tokens
temperature
=
params
.
temperature
ignore_eos
=
params
.
ignore_eos
length_penalty
=
params
.
length_penalty
def
sort_beams_key
(
x
:
BeamSearchSequence
)
->
float
:
return
get_beam_search_score
(
x
.
tokens
,
x
.
cum_logprob
,
tokenizer
.
eos_token_id
,
length_penalty
)
tokenizer
=
self
.
get_tokenizer
()
# generate 2 * beam_width candidates at each step
# following the huggingface transformers implementation
# at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
temperature
=
0.0
)
temperature
=
temperature
)
instances
:
List
[
BeamSearchInstance
]
=
[]
for
prompt
in
prompts
:
...
...
@@ -469,7 +442,7 @@ class LLM:
else
:
instance_new_beams
.
append
(
new_beam
)
sorted_beams
=
sorted
(
instance_new_beams
,
key
=
lambda
x
:
x
.
cum_logprob
,
key
=
sort_beams_key
,
reverse
=
True
)
instance
.
beams
=
sorted_beams
[:
beam_width
]
...
...
@@ -477,7 +450,7 @@ class LLM:
for
instance
in
instances
:
instance
.
completed
.
extend
(
instance
.
beams
)
sorted_completed
=
sorted
(
instance
.
completed
,
key
=
lambda
x
:
x
.
cum_logprob
,
key
=
sort_beams_key
,
reverse
=
True
)
best_beams
=
sorted_completed
[:
beam_width
]
...
...
@@ -497,7 +470,9 @@ class LLM:
lora_request
:
Optional
[
LoRARequest
]
=
None
,
chat_template
:
Optional
[
str
]
=
None
,
add_generation_prompt
:
bool
=
True
,
continue_final_message
:
bool
=
False
,
tools
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
List
[
RequestOutput
]:
"""
Generate responses for a chat conversation.
...
...
@@ -524,6 +499,11 @@ class LLM:
If not provided, the model's default chat template will be used.
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
Returns:
A list of ``RequestOutput`` objects containing the generated
...
...
@@ -534,10 +514,13 @@ class LLM:
# Handle multi and single conversations
if
is_list_of
(
messages
,
list
):
# messages is List[List[...]]
list_of_messages
=
messages
list_of_messages
=
cast
(
List
[
List
[
ChatCompletionMessageParam
]],
messages
)
else
:
# messages is List[...]
list_of_messages
=
[
messages
]
list_of_messages
=
[
cast
(
List
[
ChatCompletionMessageParam
],
messages
)
]
prompts
:
List
[
Union
[
TokensPrompt
,
TextPrompt
]]
=
[]
...
...
@@ -545,6 +528,9 @@ class LLM:
tokenizer
=
self
.
get_tokenizer
()
model_config
=
self
.
llm_engine
.
get_model_config
()
# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation
,
mm_data
=
parse_chat_messages
(
msgs
,
model_config
,
tokenizer
)
...
...
@@ -555,6 +541,7 @@ class LLM:
messages
=
msgs
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
)
else
:
...
...
@@ -563,6 +550,7 @@ class LLM:
conversation
=
conversation
,
chat_template
=
chat_template
,
add_generation_prompt
=
add_generation_prompt
,
continue_final_message
=
continue_final_message
,
tools
=
tools
,
)
...
...
@@ -575,6 +563,9 @@ class LLM:
if
mm_data
is
not
None
:
prompt
[
"multi_modal_data"
]
=
mm_data
if
mm_processor_kwargs
is
not
None
:
prompt
[
"mm_processor_kwargs"
]
=
mm_processor_kwargs
prompts
.
append
(
prompt
)
return
self
.
generate
(
...
...
@@ -648,8 +639,8 @@ class LLM:
@
overload
def
encode
(
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
/
,
# We may enable `inputs` keyword after removing the old API
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
/
,
*
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
...
...
@@ -659,14 +650,13 @@ class LLM:
...
@
deprecate_kwargs
(
"prompts"
,
"prompt_token_ids"
,
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_LEGACY
,
additional_message
=
"Please use the '
inpu
ts' parameter instead."
,
additional_message
=
"Please use the '
promp
ts' parameter instead."
,
)
def
encode
(
self
,
prompts
:
Union
[
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
prompts
:
Union
[
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
Optional
[
Union
[
str
,
List
[
str
]]]]
=
None
,
pooling_params
:
Optional
[
Union
[
PoolingParams
,
Sequence
[
PoolingParams
]]]
=
None
,
...
...
@@ -682,9 +672,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inpu
ts: The
inpu
ts to the LLM. You may pass a sequence of
inputs for
batch inference. See :class:`~vllm.inputs.Prompt
Inputs
`
for more details about the format of each
input
.
promp
ts: The
promp
ts to the LLM. You may pass a sequence of
prompts
for
batch inference. See :class:`~vllm.inputs.Prompt
Type
`
for more details about the format of each
prompts
.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
...
...
@@ -707,19 +697,20 @@ class LLM:
)
if
prompt_token_ids
is
not
None
:
inpu
ts
=
self
.
_convert_v1_inputs
(
parsed_promp
ts
=
self
.
_convert_v1_inputs
(
prompts
=
cast
(
Optional
[
Union
[
str
,
List
[
str
]]],
prompts
),
prompt_token_ids
=
prompt_token_ids
,
)
else
:
inputs
=
cast
(
Union
[
PromptInputs
,
Sequence
[
PromptInputs
]],
prompts
)
parsed_prompts
=
cast
(
Union
[
PromptType
,
Sequence
[
PromptType
]],
prompts
)
if
pooling_params
is
None
:
# Use default pooling params.
pooling_params
=
PoolingParams
()
self
.
_validate_and_add_requests
(
inputs
=
inpu
ts
,
prompts
=
parsed_promp
ts
,
params
=
pooling_params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
...
...
@@ -763,9 +754,9 @@ class LLM:
raise
ValueError
(
"Either prompts or prompt_token_ids must be "
"provided."
)
inpu
ts
:
List
[
Prompt
Inputs
]
=
[]
parsed_promp
ts
:
List
[
Prompt
Type
]
=
[]
for
i
in
range
(
num_requests
):
item
:
Prompt
Inputs
item
:
Prompt
Type
if
prompts
is
not
None
:
item
=
TextPrompt
(
prompt
=
prompts
[
i
])
...
...
@@ -774,13 +765,13 @@ class LLM:
else
:
raise
AssertionError
inpu
ts
.
append
(
item
)
parsed_promp
ts
.
append
(
item
)
return
inpu
ts
return
parsed_promp
ts
def
_validate_and_add_requests
(
self
,
inpu
ts
:
Union
[
Prompt
Inputs
,
Sequence
[
Prompt
Inputs
]],
promp
ts
:
Union
[
Prompt
Type
,
Sequence
[
Prompt
Type
]],
params
:
Union
[
SamplingParams
,
Sequence
[
SamplingParams
],
PoolingParams
,
Sequence
[
PoolingParams
]],
lora_request
:
Optional
[
Union
[
Sequence
[
LoRARequest
],
LoRARequest
]],
...
...
@@ -788,11 +779,19 @@ class LLM:
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
,
priority
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
if
isinstance
(
inputs
,
(
str
,
dict
)):
if
guided_options
is
not
None
:
warnings
.
warn
(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead"
,
DeprecationWarning
,
stacklevel
=
2
,
)
if
isinstance
(
prompts
,
(
str
,
dict
)):
# Convert a single prompt to a list.
inpu
ts
=
[
inpu
ts
]
promp
ts
=
[
promp
ts
]
num_requests
=
len
(
inpu
ts
)
num_requests
=
len
(
promp
ts
)
if
isinstance
(
params
,
list
)
and
len
(
params
)
!=
num_requests
:
raise
ValueError
(
"The lengths of prompts and params "
"must be the same."
)
...
...
@@ -803,15 +802,15 @@ class LLM:
for
sp
in
params
if
isinstance
(
params
,
list
)
else
(
params
,
):
if
isinstance
(
sp
,
SamplingParams
):
self
.
_add_guided_p
rocessor
(
sp
,
guided_options
)
self
.
_add_guided_p
arams
(
sp
,
guided_options
)
# We only care about the final output
sp
.
output_kind
=
RequestOutputKind
.
FINAL_ONLY
# Add requests to the engine.
for
i
,
request_inputs
in
enumerate
(
inpu
ts
):
for
i
,
prompt
in
enumerate
(
promp
ts
):
self
.
_add_request
(
request_inputs
,
prompt
,
params
[
i
]
if
isinstance
(
params
,
Sequence
)
else
params
,
lora_request
=
lora_request
[
i
]
if
isinstance
(
lora_request
,
Sequence
)
else
lora_request
,
...
...
@@ -821,7 +820,7 @@ class LLM:
def
_add_request
(
self
,
inputs
:
Prompt
Inputs
,
prompt
:
Prompt
Type
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
lora_request
:
Optional
[
LoRARequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
...
...
@@ -830,29 +829,32 @@ class LLM:
request_id
=
str
(
next
(
self
.
request_counter
))
self
.
llm_engine
.
add_request
(
request_id
,
inputs
,
prompt
,
params
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
)
def
_add_guided_p
rocessor
(
def
_add_guided_p
arams
(
self
,
params
:
SamplingParams
,
guided_options
:
Optional
[
GuidedDecodingRequest
]
=
None
):
if
guided_options
:
if
guided_options
.
guided_decoding_backend
is
None
:
decoding_config
=
self
.
llm_engine
.
get_decoding_config
()
guided_options
.
guided_decoding_backend
=
(
decoding_config
.
guided_decoding_backend
)
guided_logits_processor
=
get_local_guided_decoding_logits_processor
(
#noqa
guided_options
.
guided_decoding_backend
,
guided_options
,
self
.
get_tokenizer
())
if
guided_logits_processor
:
if
params
.
logits_processors
is
None
:
params
.
logits_processors
=
[]
params
.
logits_processors
.
append
(
guided_logits_processor
)
if
guided_options
is
None
:
return
params
if
params
.
guided_decoding
is
not
None
:
raise
ValueError
(
"Cannot set both guided_options_request and"
"params.guided_decoding."
)
params
.
guided_decoding
=
GuidedDecodingParams
(
json
=
guided_options
.
guided_json
,
regex
=
guided_options
.
guided_regex
,
choice
=
guided_options
.
guided_choice
,
grammar
=
guided_options
.
guided_grammar
,
json_object
=
guided_options
.
guided_json_object
,
backend
=
guided_options
.
guided_decoding_backend
,
whitespace_pattern
=
guided_options
.
guided_whitespace_pattern
)
return
params
def
_run_engine
(
...
...
vllm/entrypoints/logger.py
View file @
6d2051cc
...
...
@@ -4,7 +4,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
logger
=
init_logger
(
__name__
)
...
...
@@ -21,7 +21,8 @@ class RequestLogger:
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt_token_ids
:
Optional
[
List
[
int
]],
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
]],
params
:
Optional
[
Union
[
SamplingParams
,
PoolingParams
,
BeamSearchParams
]],
lora_request
:
Optional
[
LoRARequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
)
->
None
:
...
...
vllm/entrypoints/openai/api_server.py
View file @
6d2051cc
...
...
@@ -31,7 +31,8 @@ from vllm.engine.multiprocessing.engine import run_mp_engine
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.launcher
import
serve_http
from
vllm.entrypoints.logger
import
RequestLogger
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
(
make_arg_parser
,
validate_parsed_serve_args
)
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
...
...
@@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from
vllm.entrypoints.openai.serving_engine
import
BaseModelPath
from
vllm.entrypoints.openai.serving_tokenization
import
(
OpenAIServingTokenization
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
,
get_open_zmq_ipc_path
...
...
@@ -526,8 +528,20 @@ async def run_server(args, **uvicorn_kwargs) -> None:
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
temp_socket
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
temp_socket
.
bind
((
""
,
args
.
port
))
if
args
.
tool_parser_plugin
and
len
(
args
.
tool_parser_plugin
)
>
3
:
ToolParserManager
.
import_tool_parser
(
args
.
tool_parser_plugin
)
valide_tool_parses
=
ToolParserManager
.
tool_parsers
.
keys
()
if
args
.
enable_auto_tool_choice
\
and
args
.
tool_call_parser
not
in
valide_tool_parses
:
raise
KeyError
(
f
"invalid tool call parser:
{
args
.
tool_call_parser
}
"
f
"(chose from {{
{
','
.
join
(
valide_tool_parses
)
}
}})"
)
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
bind
((
""
,
args
.
port
))
def
signal_handler
(
*
_
)
->
None
:
# Interrupt server on sigterm while initializing
...
...
@@ -541,8 +555,6 @@ async def run_server(args, **uvicorn_kwargs) -> None:
model_config
=
await
engine_client
.
get_model_config
()
init_app_state
(
engine_client
,
model_config
,
app
.
state
,
args
)
temp_socket
.
close
()
shutdown_task
=
await
serve_http
(
app
,
host
=
args
.
host
,
...
...
@@ -553,6 +565,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
ssl_certfile
=
args
.
ssl_certfile
,
ssl_ca_certs
=
args
.
ssl_ca_certs
,
ssl_cert_reqs
=
args
.
ssl_cert_reqs
,
fd
=
sock
.
fileno
(),
**
uvicorn_kwargs
,
)
...
...
@@ -567,5 +580,6 @@ if __name__ == "__main__":
description
=
"vLLM OpenAI-Compatible RESTful API server."
)
parser
=
make_arg_parser
(
parser
)
args
=
parser
.
parse_args
()
validate_parsed_serve_args
(
args
)
uvloop
.
run
(
run_server
(
args
))
vllm/entrypoints/openai/cli_args.py
View file @
6d2051cc
...
...
@@ -10,8 +10,10 @@ import ssl
from
typing
import
List
,
Optional
,
Sequence
,
Union
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.entrypoints.chat_utils
import
validate_chat_template
from
vllm.entrypoints.openai.serving_engine
import
(
LoRAModulePath
,
PromptAdapterPath
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -190,16 +192,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"Enable auto tool choice for supported models. Use --tool-call-parser"
"to specify which parser to use"
)
valid_tool_parsers
=
ToolParserManager
.
tool_parsers
.
keys
()
parser
.
add_argument
(
"--tool-call-parser"
,
type
=
str
,
choices
=
[
"mistral"
,
"hermes"
],
metavar
=
"{"
+
","
.
join
(
valid_tool_parsers
)
+
"} or name registered in "
"--tool-parser-plugin"
,
default
=
None
,
help
=
"Select the tool call parser depending on the model that you're using."
" This is used to parse the model-generated tool call into OpenAI API "
"format. Required for --enable-auto-tool-choice."
)
parser
.
add_argument
(
"--tool-parser-plugin"
,
type
=
str
,
default
=
""
,
help
=
"Special the tool parser plugin write to parse the model-generated tool"
" into OpenAI API format, the name register in this plugin can be used "
"in --tool-call-parser."
)
parser
=
AsyncEngineArgs
.
add_cli_args
(
parser
)
parser
.
add_argument
(
'--max-log-len'
,
...
...
@@ -219,6 +232,20 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
return
parser
def
validate_parsed_serve_args
(
args
:
argparse
.
Namespace
):
"""Quick checks for model serve args that raise prior to loading."""
if
hasattr
(
args
,
"subparser"
)
and
args
.
subparser
!=
"serve"
:
return
# Ensure that the chat template is valid; raises if it likely isn't
validate_chat_template
(
args
.
chat_template
)
# Enable auto tool needs a tool call parser to be valid
if
args
.
enable_auto_tool_choice
and
not
args
.
tool_call_parser
:
raise
TypeError
(
"Error: --enable-auto-tool-choice requires "
"--tool-call-parser"
)
def
create_parser_for_docs
()
->
FlexibleArgumentParser
:
parser_for_docs
=
FlexibleArgumentParser
(
prog
=
"-m vllm.entrypoints.openai.api_server"
)
...
...
vllm/entrypoints/openai/protocol.py
View file @
6d2051cc
...
...
@@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from
typing_extensions
import
Annotated
,
Required
,
TypedDict
from
vllm.entrypoints.chat_utils
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.logits_processors
import
get_logits_processors
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
(
LogitsProcessor
,
RequestOutputKind
,
SamplingParams
)
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
RequestOutputKind
,
SamplingParams
)
from
vllm.sequence
import
Logprob
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
random_uuid
# torch is mocked during docs generation,
...
...
@@ -186,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_p
:
float
=
0.0
repetition_penalty
:
float
=
1.0
length_penalty
:
float
=
1.0
early_stopping
:
bool
=
False
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
include_stop_str_in_output
:
bool
=
False
ignore_eos
:
bool
=
False
...
...
@@ -211,6 +208,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
continue_final_message
:
bool
=
Field
(
default
=
False
,
description
=
(
"If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to
\"
prefill
\"
part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."
),
)
add_special_tokens
:
bool
=
Field
(
default
=
False
,
description
=
(
...
...
@@ -272,13 +278,33 @@ class ChatCompletionRequest(OpenAIBaseModel):
description
=
(
"If specified, will override the default whitespace pattern "
"for guided json decoding."
))
priority
:
int
=
Field
(
default
=
0
,
description
=
(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
))
# doc: end-chat-completion-extra-params
def
to_sampling_params
(
self
,
tokenizer
:
AnyTokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
def
to_beam_search_params
(
self
,
default_max_tokens
:
int
)
->
BeamSearchParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
n
=
self
.
n
if
self
.
n
is
not
None
else
1
temperature
=
self
.
temperature
if
self
.
temperature
is
not
None
else
0.0
return
BeamSearchParams
(
beam_width
=
n
,
max_tokens
=
max_tokens
,
ignore_eos
=
self
.
ignore_eos
,
temperature
=
temperature
,
length_penalty
=
self
.
length_penalty
,
)
def
to_sampling_params
(
self
,
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
...
...
@@ -287,14 +313,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
if
prompt_logprobs
is
None
and
self
.
echo
:
prompt_logprobs
=
self
.
top_logprobs
# We now allow logprobs being true without top_logrobs.
logits_processors
=
get_logits_processors
(
logit_bias
=
self
.
logit_bias
,
allowed_token_ids
=
None
,
tokenizer
=
tokenizer
,
)
if
guided_decode_logits_processor
:
logits_processors
.
append
(
guided_decode_logits_processor
)
guided_json_object
=
None
if
(
self
.
response_format
is
not
None
and
self
.
response_format
.
type
==
"json_object"
):
guided_json_object
=
True
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
regex
=
self
.
guided_regex
,
choice
=
self
.
guided_choice
,
grammar
=
self
.
guided_grammar
,
json_object
=
guided_json_object
,
backend
=
self
.
guided_decoding_backend
,
whitespace_pattern
=
self
.
guided_whitespace_pattern
)
return
SamplingParams
.
from_optional
(
n
=
self
.
n
,
...
...
@@ -314,17 +345,32 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
max_tokens
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
skip_special_tokens
=
self
.
skip_special_tokens
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
,
)
guided_decoding
=
guided_decoding
,
logit_bias
=
self
.
logit_bias
)
def
_get_guided_json_from_tool
(
self
)
->
Optional
[
Union
[
str
,
dict
,
BaseModel
]]:
# user has chosen to not use any tool
if
self
.
tool_choice
==
"none"
or
self
.
tools
is
None
:
return
None
# user has chosen to use a named tool
if
type
(
self
.
tool_choice
)
is
ChatCompletionNamedToolChoiceParam
:
tool_name
=
self
.
tool_choice
.
function
.
name
tools
=
{
tool
.
function
.
name
:
tool
.
function
for
tool
in
self
.
tools
}
if
tool_name
not
in
tools
:
raise
ValueError
(
f
"Tool '
{
tool_name
}
' has not been passed in `tools`."
)
tool
=
tools
[
tool_name
]
return
tool
.
parameters
return
None
@
model_validator
(
mode
=
"before"
)
@
classmethod
...
...
@@ -386,7 +432,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# if "tool_choice" is not specified but tools are provided,
# default to "auto" tool_choice
if
"tool_choice"
not
in
data
and
"tools"
in
data
:
if
"tool_choice"
not
in
data
and
data
.
get
(
"tools"
)
:
data
[
"tool_choice"
]
=
"auto"
# if "tool_choice" is specified -- validation
...
...
@@ -431,6 +477,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
" of the specified `tools`"
)
return
data
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_generation_prompt
(
cls
,
data
):
if
data
.
get
(
"continue_final_message"
)
and
data
.
get
(
"add_generation_prompt"
):
raise
ValueError
(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
)
return
data
class
CompletionRequest
(
OpenAIBaseModel
):
# Ordered by official OpenAI API documentation
...
...
@@ -460,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel):
min_p
:
float
=
0.0
repetition_penalty
:
float
=
1.0
length_penalty
:
float
=
1.0
early_stopping
:
bool
=
False
stop_token_ids
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
list
)
include_stop_str_in_output
:
bool
=
False
ignore_eos
:
bool
=
False
...
...
@@ -516,13 +570,33 @@ class CompletionRequest(OpenAIBaseModel):
description
=
(
"If specified, will override the default whitespace pattern "
"for guided json decoding."
))
priority
:
int
=
Field
(
default
=
0
,
description
=
(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
))
# doc: end-completion-extra-params
def
to_sampling_params
(
self
,
tokenizer
:
AnyTokenizer
,
guided_decode_logits_processor
:
Optional
[
LogitsProcessor
],
default_max_tokens
:
int
)
->
SamplingParams
:
def
to_beam_search_params
(
self
,
default_max_tokens
:
int
)
->
BeamSearchParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
n
=
self
.
n
if
self
.
n
is
not
None
else
1
temperature
=
self
.
temperature
if
self
.
temperature
is
not
None
else
0.0
return
BeamSearchParams
(
beam_width
=
n
,
max_tokens
=
max_tokens
,
ignore_eos
=
self
.
ignore_eos
,
temperature
=
temperature
,
length_penalty
=
self
.
length_penalty
,
)
def
to_sampling_params
(
self
,
default_max_tokens
:
int
)
->
SamplingParams
:
max_tokens
=
self
.
max_tokens
if
max_tokens
is
None
:
max_tokens
=
default_max_tokens
...
...
@@ -533,13 +607,19 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation
=
self
.
echo
and
self
.
max_tokens
==
0
logits_processors
=
get_logits_processors
(
logit_bias
=
self
.
logit_bias
,
allowed_token_ids
=
self
.
allowed_token_ids
,
tokenizer
=
tokenizer
,
)
if
guided_decode_logits_processor
:
logits_processors
.
append
(
guided_decode_logits_processor
)
guided_json_object
=
None
if
(
self
.
response_format
is
not
None
and
self
.
response_format
.
type
==
"json_object"
):
guided_json_object
=
True
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
guided_json
,
regex
=
self
.
guided_regex
,
choice
=
self
.
guided_choice
,
grammar
=
self
.
guided_grammar
,
json_object
=
guided_json_object
,
backend
=
self
.
guided_decoding_backend
,
whitespace_pattern
=
self
.
guided_whitespace_pattern
)
return
SamplingParams
.
from_optional
(
n
=
self
.
n
,
...
...
@@ -558,18 +638,16 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos
=
self
.
ignore_eos
,
max_tokens
=
max_tokens
if
not
echo_without_generation
else
1
,
min_tokens
=
self
.
min_tokens
,
use_beam_search
=
self
.
use_beam_search
,
early_stopping
=
self
.
early_stopping
,
prompt_logprobs
=
prompt_logprobs
,
skip_special_tokens
=
self
.
skip_special_tokens
,
spaces_between_special_tokens
=
self
.
spaces_between_special_tokens
,
include_stop_str_in_output
=
self
.
include_stop_str_in_output
,
length_penalty
=
self
.
length_penalty
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
self
.
truncate_prompt_tokens
,
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
,
)
guided_decoding
=
guided_decoding
,
logit_bias
=
self
.
logit_bias
,
allowed_token_ids
=
self
.
allowed_token_ids
)
@
model_validator
(
mode
=
"before"
)
@
classmethod
...
...
@@ -619,12 +697,23 @@ class EmbeddingRequest(OpenAIBaseModel):
encoding_format
:
Literal
[
"float"
,
"base64"
]
=
"float"
dimensions
:
Optional
[
int
]
=
None
user
:
Optional
[
str
]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
# doc: begin-embedding-pooling-params
additional_data
:
Optional
[
Any
]
=
None
# doc: end-embedding-pooling-params
# doc: begin-embedding-extra-params
priority
:
int
=
Field
(
default
=
0
,
description
=
(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
))
# doc: end-embedding-extra-params
def
to_pooling_params
(
self
):
return
PoolingParams
(
additional_data
=
self
.
additional_data
)
...
...
@@ -862,8 +951,18 @@ class TokenizeChatRequest(OpenAIBaseModel):
messages
:
List
[
ChatCompletionMessageParam
]
add_generation_prompt
:
bool
=
Field
(
default
=
True
)
continue_final_message
:
bool
=
Field
(
default
=
False
)
add_special_tokens
:
bool
=
Field
(
default
=
False
)
@
model_validator
(
mode
=
"before"
)
@
classmethod
def
check_generation_prompt
(
cls
,
data
):
if
data
.
get
(
"continue_final_message"
)
and
data
.
get
(
"add_generation_prompt"
):
raise
ValueError
(
"Cannot set both `continue_final_message` and "
"`add_generation_prompt` to True."
)
return
data
TokenizeRequest
=
Union
[
TokenizeCompletionRequest
,
TokenizeChatRequest
]
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
6d2051cc
...
...
@@ -29,12 +29,11 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
OpenAIServing
,
PromptAdapterPath
,
TextTokensPrompt
)
from
vllm.entrypoints.openai.tool_parsers
import
(
Hermes2ProToolParser
,
MistralToolParser
,
ToolParser
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParser
,
ToolParserManager
from
vllm.inputs
import
TokensPrompt
from
vllm.logger
import
init_logger
from
vllm.outputs
import
CompletionOutput
,
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
...
...
@@ -81,13 +80,13 @@ class OpenAIServingChat(OpenAIServing):
self
.
tool_parser
:
Optional
[
Callable
[[
AnyTokenizer
],
ToolParser
]]
=
None
if
self
.
enable_auto_tools
:
if
tool_parser
==
"mistral"
:
self
.
tool_parser
=
MistralToolParser
elif
tool_parser
==
"hermes"
:
self
.
tool_parser
=
Hermes2ProToolParser
else
:
try
:
self
.
tool_parser
=
ToolParserManager
.
get_tool_parser
(
tool_parser
)
except
Exception
as
e
:
raise
TypeError
(
"Error: --enable-auto-tool-choice requires "
"--tool-call-parser"
)
f
"tool_parser:'
{
tool_parser
}
' which has not "
"been registered"
)
from
e
async
def
create_chat_completion
(
self
,
...
...
@@ -137,6 +136,7 @@ class OpenAIServingChat(OpenAIServing):
messages
=
request
.
messages
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
continue_final_message
=
request
.
continue_final_message
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
...
...
@@ -147,18 +147,19 @@ class OpenAIServingChat(OpenAIServing):
conversation
=
conversation
,
chat_template
=
request
.
chat_template
or
self
.
chat_template
,
add_generation_prompt
=
request
.
add_generation_prompt
,
continue_final_message
=
request
.
continue_final_message
,
tools
=
tool_dicts
,
documents
=
request
.
documents
,
**
(
request
.
chat_template_kwargs
or
{}),
)
except
Exception
as
e
:
logger
.
e
rror
(
"Error in applying chat template from request
: %s"
,
e
)
logger
.
e
xception
(
"Error in applying chat template from request
"
)
return
self
.
create_error_response
(
str
(
e
))
try
:
mm_data
=
await
mm_data_future
except
Exception
as
e
:
logger
.
e
rror
(
"Error in loading multi-modal data
: %s"
,
e
)
logger
.
e
xception
(
"Error in loading multi-modal data
"
)
return
self
.
create_error_response
(
str
(
e
))
# validation for OpenAI tools
...
...
@@ -182,8 +183,9 @@ class OpenAIServingChat(OpenAIServing):
raw_request
.
state
.
request_metadata
=
request_metadata
try
:
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
if
self
.
enable_auto_tools
and
self
.
tool_parser
:
request
=
self
.
tool_parser
(
tokenizer
).
adjust_request
(
request
=
request
)
if
isinstance
(
prompt
,
str
):
prompt_inputs
=
self
.
_tokenize_prompt_input
(
...
...
@@ -202,11 +204,15 @@ class OpenAIServingChat(OpenAIServing):
assert
prompt_inputs
is
not
None
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
guided_decode_logits_processor
,
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
])
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
default_max_tokens
)
else
:
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
)
self
.
_log_inputs
(
request_id
,
prompt_inputs
,
...
...
@@ -228,14 +234,22 @@ class OpenAIServingChat(OpenAIServing):
and
contains_trace_headers
(
raw_request
.
headers
)):
log_tracing_disabled_warning
()
result_generator
=
self
.
engine_client
.
generate
(
engine_inputs
,
sampling_params
,
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
)
if
isinstance
(
sampling_params
,
BeamSearchParams
):
result_generator
=
self
.
engine_client
.
beam_search
(
engine_inputs
[
'prompt_token_ids'
],
request_id
,
sampling_params
,
)
else
:
result_generator
=
self
.
engine_client
.
generate
(
engine_inputs
,
sampling_params
,
request_id
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
request
.
priority
,
)
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
return
self
.
create_error_response
(
str
(
e
))
...
...
@@ -281,12 +295,8 @@ class OpenAIServingChat(OpenAIServing):
num_choices
=
1
if
request
.
n
is
None
else
request
.
n
previous_num_tokens
=
[
0
]
*
num_choices
finish_reason_sent
=
[
False
]
*
num_choices
num_prompt_tokens
=
0
tool_parser
:
Optional
[
ToolParser
]
=
self
.
tool_parser
(
tokenizer
)
if
self
.
tool_parser
else
None
if
isinstance
(
request
.
tool_choice
,
ChatCompletionNamedToolChoiceParam
):
tool_choice_function_name
=
request
.
tool_choice
.
function
.
name
else
:
...
...
@@ -305,6 +315,29 @@ class OpenAIServingChat(OpenAIServing):
else
:
previous_texts
,
all_previous_token_ids
=
None
,
None
# Prepare the tool parser if it's needed
try
:
if
tool_choice_auto
and
self
.
tool_parser
:
tool_parsers
:
List
[
Optional
[
ToolParser
]]
=
[
self
.
tool_parser
(
tokenizer
)
]
*
num_choices
else
:
tool_parsers
=
[
None
]
*
num_choices
except
RuntimeError
as
e
:
logger
.
exception
(
"Error in tool parser creation."
)
data
=
self
.
create_streaming_error_response
(
str
(
e
))
yield
f
"data:
{
data
}
\n\n
"
yield
"data: [DONE]
\n\n
"
return
stream_options
=
request
.
stream_options
if
stream_options
:
include_usage
=
stream_options
.
include_usage
include_continuous_usage
=
include_usage
and
\
stream_options
.
continuous_usage_stats
else
:
include_usage
,
include_continuous_usage
=
False
,
False
try
:
async
for
res
in
result_generator
:
if
res
.
prompt_token_ids
is
not
None
:
...
...
@@ -323,7 +356,6 @@ class OpenAIServingChat(OpenAIServing):
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for
i
in
range
(
num_choices
):
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
DeltaMessage
(
...
...
@@ -339,26 +371,19 @@ class OpenAIServingChat(OpenAIServing):
choices
=
[
choice_data
],
model
=
model_name
)
# if usage should be included
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
# if continuous usage stats are requested, add it
if
request
.
stream_options
.
continuous_usage_stats
:
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
num_prompt_tokens
)
chunk
.
usage
=
usage
# otherwise don't
else
:
chunk
.
usage
=
None
# if continuous usage stats are requested, add it
if
include_continuous_usage
:
chunk
.
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
num_prompt_tokens
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
# Send response to echo the input portion of the
# last message
if
request
.
echo
:
if
request
.
echo
or
request
.
continue_final_message
:
last_msg_content
:
str
=
""
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
...
...
@@ -379,17 +404,11 @@ class OpenAIServingChat(OpenAIServing):
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
if
(
request
.
stream_options
.
continuous_usage_stats
):
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
num_prompt_tokens
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
if
include_continuous_usage
:
chunk
.
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
0
,
total_tokens
=
num_prompt_tokens
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
...
...
@@ -398,6 +417,7 @@ class OpenAIServingChat(OpenAIServing):
for
output
in
res
.
outputs
:
i
=
output
.
index
tool_parser
=
tool_parsers
[
i
]
if
finish_reason_sent
[
i
]:
continue
...
...
@@ -415,6 +435,12 @@ class OpenAIServingChat(OpenAIServing):
logprobs
=
None
delta_text
=
output
.
text
if
not
delta_text
and
not
output
.
token_ids
and
\
not
previous_num_tokens
[
i
]:
# Chunked prefill case, don't return empty chunks
continue
delta_message
:
Optional
[
DeltaMessage
]
# handle streaming deltas for tools with named tool_choice
...
...
@@ -445,7 +471,8 @@ class OpenAIServingChat(OpenAIServing):
delta_text
=
delta_text
,
previous_token_ids
=
previous_token_ids
,
current_token_ids
=
current_token_ids
,
delta_token_ids
=
output
.
token_ids
))
delta_token_ids
=
output
.
token_ids
,
request
=
request
))
# update the previous values for the next iteration
previous_texts
[
i
]
=
current_text
...
...
@@ -467,36 +494,11 @@ class OpenAIServingChat(OpenAIServing):
if
output
.
finish_reason
is
None
:
# Send token-by-token response for each request.n
choice_data
=
ChatCompletionResponseStreamChoice
(
index
=
i
,
delta
=
delta_message
,
logprobs
=
logprobs
,
finish_reason
=
None
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
# handle usage stats if requested & if continuous
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
if
request
.
stream_options
.
continuous_usage_stats
:
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
num_prompt_tokens
+
completion_tokens
,
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
# if the model is finished generating
else
:
...
...
@@ -504,10 +506,12 @@ class OpenAIServingChat(OpenAIServing):
# any tokens that were generated but previously
# matched by partial json parsing
# only happens if we are NOT using guided decoding
auto_tools_called
=
False
if
tool_parser
:
index
=
len
(
tool_parser
.
prev_tool_call_arr
)
-
1
if
len
(
tool_parser
.
prev_tool_call_arr
)
>
0
else
0
auto_tools_called
=
len
(
tool_parser
.
prev_tool_call_arr
)
>
0
index
=
len
(
tool_parser
.
prev_tool_call_arr
)
-
1
if
auto_tools_called
else
0
else
:
index
=
0
...
...
@@ -542,38 +546,34 @@ class OpenAIServingChat(OpenAIServing):
delta
=
delta_message
,
logprobs
=
logprobs
,
finish_reason
=
output
.
finish_reason
if
not
(
tool_parser
and
len
(
tool_parser
.
prev_tool_call_arr
))
else
"tool_calls"
,
if
not
auto_tools_called
else
"tool_calls"
,
stop_reason
=
output
.
stop_reason
)
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
if
request
.
stream_options
.
continuous_usage_stats
:
completion_tokens
=
len
(
output
.
token_ids
)
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
num_prompt_tokens
+
completion_tokens
,
)
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
finish_reason_sent
[
i
]
=
True
chunk
=
ChatCompletionStreamResponse
(
id
=
request_id
,
object
=
chunk_object_type
,
created
=
created_time
,
choices
=
[
choice_data
],
model
=
model_name
)
# handle usage stats if requested & if continuous
if
include_continuous_usage
:
completion_tokens
=
previous_num_tokens
[
i
]
chunk
.
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
completion_tokens
,
total_tokens
=
num_prompt_tokens
+
completion_tokens
,
)
data
=
chunk
.
model_dump_json
(
exclude_unset
=
True
)
yield
f
"data:
{
data
}
\n\n
"
# once the final token is handled, if stream_options.include_usage
# is sent, send the usage
if
(
request
.
stream_options
and
request
.
stream_options
.
include_usage
):
completion_tokens
=
previous_num_tokens
[
i
]
if
include_usage
:
completion_tokens
=
sum
(
previous_num_tokens
)
final_usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
,
completion_tokens
=
completion_tokens
,
...
...
@@ -600,7 +600,7 @@ class OpenAIServingChat(OpenAIServing):
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
logger
.
e
rror
(
"
e
rror in chat completion stream generator
: %s"
,
e
)
logger
.
e
xception
(
"
E
rror in chat completion stream generator
."
)
data
=
self
.
create_streaming_error_response
(
str
(
e
))
yield
f
"data:
{
data
}
\n\n
"
# Send the final done message after all response.n are finished
...
...
@@ -646,8 +646,10 @@ class OpenAIServingChat(OpenAIServing):
else
:
logprobs
=
None
# by default, tools are not used.
tools_called
=
False
# In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called
=
False
# if auto tools are not enabled, and a named tool choice using
# outlines is not being used
...
...
@@ -669,7 +671,6 @@ class OpenAIServingChat(OpenAIServing):
name
=
request
.
tool_choice
.
function
.
name
,
arguments
=
output
.
text
))
])
tools_called
=
True
# if the request doesn't use tool choice
# OR specifies to not use a tool
...
...
@@ -683,9 +684,18 @@ class OpenAIServingChat(OpenAIServing):
or
request
.
tool_choice
is
None
)
and
self
.
enable_auto_tools
\
and
self
.
tool_parser
:
tool_parser
=
self
.
tool_parser
(
tokenizer
)
tool_call_info
=
tool_parser
.
extract_tool_calls
(
output
.
text
)
tools_called
=
tool_call_info
.
tools_called
try
:
tool_parser
=
self
.
tool_parser
(
tokenizer
)
except
RuntimeError
as
e
:
logger
.
exception
(
"Error in tool parser creation."
)
return
self
.
create_error_response
(
str
(
e
))
tool_call_info
=
tool_parser
.
extract_tool_calls
(
output
.
text
,
request
=
request
)
# In the OpenAI API the finish_reason is "tools_called"
# if the tool choice is auto and the model produced a tool
# call. The same is not true for named function calls
auto_tools_called
=
tool_call_info
.
tools_called
if
tool_call_info
.
tools_called
:
message
=
ChatMessage
(
role
=
role
,
content
=
tool_call_info
.
content
,
...
...
@@ -708,12 +718,12 @@ class OpenAIServingChat(OpenAIServing):
index
=
output
.
index
,
message
=
message
,
logprobs
=
logprobs
,
finish_reason
=
"tool_calls"
if
tools_called
else
finish_reason
=
"tool_calls"
if
auto_
tools_called
else
output
.
finish_reason
if
output
.
finish_reason
else
"stop"
,
stop_reason
=
output
.
stop_reason
)
choices
.
append
(
choice_data
)
if
request
.
echo
:
if
request
.
echo
or
request
.
continue_final_message
:
last_msg_content
=
""
if
conversation
and
"content"
in
conversation
[
-
1
]
and
conversation
[
-
1
].
get
(
"role"
)
==
role
:
...
...
@@ -726,6 +736,8 @@ class OpenAIServingChat(OpenAIServing):
assert
final_res
.
prompt_token_ids
is
not
None
num_prompt_tokens
=
len
(
final_res
.
prompt_token_ids
)
if
final_res
.
encoder_prompt_token_ids
is
not
None
:
num_prompt_tokens
+=
len
(
final_res
.
encoder_prompt_token_ids
)
num_generated_tokens
=
sum
(
len
(
output
.
token_ids
)
for
output
in
final_res
.
outputs
)
usage
=
UsageInfo
(
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
6d2051cc
...
...
@@ -28,6 +28,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
PromptAdapterPath
)
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
BeamSearchParams
,
SamplingParams
from
vllm.sequence
import
Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
...
...
@@ -110,8 +111,6 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer
=
await
self
.
engine_client
.
get_tokenizer
(
lora_request
)
guided_decode_logits_processor
=
(
await
self
.
_guided_decode_logits_processor
(
request
,
tokenizer
))
prompts
=
list
(
self
.
_tokenize_prompt_input_or_inputs
(
request
,
...
...
@@ -122,11 +121,15 @@ class OpenAIServingCompletion(OpenAIServing):
))
for
i
,
prompt_inputs
in
enumerate
(
prompts
):
sampling_params
=
request
.
to_sampling_params
(
tokenizer
,
guided_decode_logits_processor
,
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
]))
sampling_params
:
Union
[
SamplingParams
,
BeamSearchParams
]
default_max_tokens
=
self
.
max_model_len
-
len
(
prompt_inputs
[
"prompt_token_ids"
])
if
request
.
use_beam_search
:
sampling_params
=
request
.
to_beam_search_params
(
default_max_tokens
)
else
:
sampling_params
=
request
.
to_sampling_params
(
default_max_tokens
)
request_id_item
=
f
"
{
request_id
}
-
{
i
}
"
...
...
@@ -145,14 +148,25 @@ class OpenAIServingCompletion(OpenAIServing):
raw_request
.
headers
):
log_tracing_disabled_warning
()
generator
=
self
.
engine_client
.
generate
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]},
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
)
if
isinstance
(
sampling_params
,
BeamSearchParams
):
generator
=
self
.
engine_client
.
beam_search
(
prompt_inputs
[
"prompt_token_ids"
],
request_id_item
,
sampling_params
,
)
else
:
generator
=
self
.
engine_client
.
generate
(
{
"prompt_token_ids"
:
prompt_inputs
[
"prompt_token_ids"
]
},
sampling_params
,
request_id_item
,
lora_request
=
lora_request
,
prompt_adapter_request
=
prompt_adapter_request
,
trace_headers
=
trace_headers
,
priority
=
request
.
priority
,
)
generators
.
append
(
generator
)
except
ValueError
as
e
:
...
...
@@ -260,8 +274,6 @@ class OpenAIServingCompletion(OpenAIServing):
for
output
in
res
.
outputs
:
i
=
output
.
index
+
prompt_idx
*
num_choices
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
assert
request
.
max_tokens
is
not
None
if
request
.
echo
and
request
.
max_tokens
==
0
:
...
...
@@ -293,6 +305,11 @@ class OpenAIServingCompletion(OpenAIServing):
delta_token_ids
=
output
.
token_ids
out_logprobs
=
output
.
logprobs
if
not
delta_text
and
not
delta_token_ids
\
and
not
previous_num_tokens
[
i
]:
# Chunked prefill case, don't return empty chunks
continue
if
request
.
logprobs
is
not
None
:
assert
out_logprobs
is
not
None
,
(
"Did not output logprobs"
)
...
...
Prev
1
…
13
14
15
16
17
18
19
20
21
…
23
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment