Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2216a4e5
Commit
2216a4e5
authored
Oct 23, 2024
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/main'
parents
ad385667
51c24c97
Changes
239
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
588 additions
and
301 deletions
+588
-301
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+63
-5
vllm/config.py
vllm/config.py
+69
-32
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+1
-1
vllm/core/evictor.py
vllm/core/evictor.py
+0
-0
vllm/core/evictor_v1.py
vllm/core/evictor_v1.py
+0
-106
vllm/core/scheduler.py
vllm/core/scheduler.py
+1
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+15
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+23
-4
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+86
-19
vllm/engine/metrics.py
vllm/engine/metrics.py
+28
-1
vllm/engine/metrics_types.py
vllm/engine/metrics_types.py
+3
-0
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+12
-0
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+17
-10
vllm/engine/protocol.py
vllm/engine/protocol.py
+19
-10
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+118
-28
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+51
-13
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+17
-5
vllm/entrypoints/openai/serving_chat.py
vllm/entrypoints/openai/serving_chat.py
+2
-2
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+61
-62
vllm/entrypoints/openai/serving_embedding.py
vllm/entrypoints/openai/serving_embedding.py
+2
-1
No files found.
vllm/compilation/decorators.py
View file @
2216a4e5
import
inspect
import
inspect
from
typing
import
Dict
,
List
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
supports_dynamo
from
vllm.utils
import
supports_dynamo
logger
=
init_logger
(
__name__
)
def
support_torch_compile
(
dynamic_arg_dims
:
Dict
[
str
,
Union
[
int
,
List
[
int
]]]):
def
support_torch_compile
(
cls
:
Optional
[
type
]
=
None
,
dynamic_arg_dims
:
Optional
[
Dict
[
str
,
Union
[
int
,
List
[
int
]]]]
=
None
):
"""
"""
A decorator to add support for compiling the forward method of a class.
A decorator to add support for compiling the forward method of a class.
Usage 1: use directly as a decorator without arguments:
```python
@support_torch_compile
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
Usage 2: use as a decorator with arguments:
```python
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
dimensions of the argument. The dynamic dimensions can be either a single
dimensions of the argument. The dynamic dimensions can be either a single
integer or a list of integers.
integer or a list of integers.
Depending on the value of arguments:
if `dynamic_arg_dims` is `None`, it is inferred from the type annotation
of the `forward` method, based on the following default rules:
- if the argument is annotated as `torch.Tensor` or
`Optional[torch.Tensor]`, the first dimension will be
marked as dynamic.
- if the argument is annotated as `IntermediateTensors`, the first
dimension of all the tensors in the intermediate tensors
will be marked as dynamic.
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer, the corresponding dimension of the argument
- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
will be marked as dynamic.
...
@@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
...
@@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
if
not
hasattr
(
cls
,
'forward'
):
if
not
hasattr
(
cls
,
'forward'
):
raise
TypeError
(
"decorated class should have a forward method."
)
raise
TypeError
(
"decorated class should have a forward method."
)
sig
=
inspect
.
signature
(
cls
.
forward
)
sig
=
inspect
.
signature
(
cls
.
forward
)
for
k
in
dynamic_arg_dims
:
inferred_dynamic_arg_dims
=
dynamic_arg_dims
if
inferred_dynamic_arg_dims
is
None
:
inferred_dynamic_arg_dims
=
{}
for
k
,
v
in
sig
.
parameters
.
items
():
if
v
.
annotation
in
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
IntermediateTensors
,
Optional
[
IntermediateTensors
]
]:
inferred_dynamic_arg_dims
[
k
]
=
0
logger
.
debug
((
"Inferred dynamic dimensions for "
"forward method of %s: %s"
),
cls
,
list
(
inferred_dynamic_arg_dims
.
keys
()))
if
len
(
inferred_dynamic_arg_dims
)
==
0
:
raise
ValueError
(
"No dynamic dimensions found in the forward method of "
f
"
{
cls
}
. Please provide dynamic_arg_dims explicitly."
)
for
k
in
inferred_dynamic_arg_dims
:
if
k
not
in
sig
.
parameters
:
if
k
not
in
sig
.
parameters
:
raise
ValueError
(
raise
ValueError
(
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
return
_support_torch_compile
(
cls
,
dynamic_arg_dims
)
return
_support_torch_compile
(
cls
,
inferred_dynamic_arg_dims
)
if
cls
is
not
None
:
# use `support_torch_compile` as a decorator without arguments
assert
isinstance
(
cls
,
type
)
return
cls_decorator_helper
(
cls
)
return
cls_decorator_helper
return
cls_decorator_helper
...
...
vllm/config.py
View file @
2216a4e5
import
enum
import
enum
import
json
import
json
from
dataclasses
import
dataclass
,
field
,
fields
from
dataclasses
import
dataclass
,
field
,
fields
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
List
,
Mapping
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
ClassVar
,
Dict
,
Final
,
List
,
Literal
,
Optional
,
Tuple
,
Type
,
Union
)
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -17,8 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
...
@@ -17,8 +17,7 @@ from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config
,
get_hf_image_processor_config
,
get_hf_text_config
)
get_hf_text_config
)
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
from
vllm.utils
import
(
GiB_bytes
,
cuda_device_count_stateless
,
get_cpu_memory
,
is_hip
,
is_neuron
,
is_openvino
,
is_xpu
,
is_hip
,
is_openvino
,
is_xpu
,
print_warning_once
)
print_warning_once
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
from
ray.util.placement_group
import
PlacementGroup
...
@@ -33,6 +32,11 @@ logger = init_logger(__name__)
...
@@ -33,6 +32,11 @@ logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS
=
32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
5120
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
5120
TaskOption
=
Literal
[
"auto"
,
"generate"
,
"embedding"
]
# "draft" is only used internally for speculative decoding
_Task
=
Literal
[
"generate"
,
"embedding"
,
"draft"
]
class
ModelConfig
:
class
ModelConfig
:
"""Configuration for the model.
"""Configuration for the model.
...
@@ -40,7 +44,11 @@ class ModelConfig:
...
@@ -40,7 +44,11 @@ class ModelConfig:
Args:
Args:
model: Name or path of the huggingface model to use.
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
When the model only supports one task, "auto" can be used to select
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
available, "slow" will always use the slow tokenizer, and
...
@@ -108,6 +116,7 @@ class ModelConfig:
...
@@ -108,6 +116,7 @@ class ModelConfig:
def
__init__
(
self
,
def
__init__
(
self
,
model
:
str
,
model
:
str
,
task
:
Union
[
TaskOption
,
_Task
],
tokenizer
:
str
,
tokenizer
:
str
,
tokenizer_mode
:
str
,
tokenizer_mode
:
str
,
trust_remote_code
:
bool
,
trust_remote_code
:
bool
,
...
@@ -205,9 +214,15 @@ class ModelConfig:
...
@@ -205,9 +214,15 @@ class ModelConfig:
self
.
is_attention_free
=
self
.
_init_attention_free
()
self
.
is_attention_free
=
self
.
_init_attention_free
()
self
.
has_inner_state
=
self
.
_init_has_inner_state
()
self
.
has_inner_state
=
self
.
_init_has_inner_state
()
self
.
override_neuron_config
=
override_neuron_config
if
is_neuron
(
if
current_platform
.
is_neuron
():
)
else
None
self
.
override_neuron_config
=
override_neuron_config
self
.
_verify_embedding_mode
()
else
:
self
.
override_neuron_config
=
None
supported_tasks
,
task
=
self
.
_resolve_task
(
task
,
self
.
hf_config
)
self
.
supported_tasks
=
supported_tasks
self
.
task
:
Final
=
task
self
.
_verify_quantization
()
self
.
_verify_quantization
()
self
.
_verify_cuda_graph
()
self
.
_verify_cuda_graph
()
self
.
_verify_bnb_config
()
self
.
_verify_bnb_config
()
...
@@ -241,18 +256,44 @@ class ModelConfig:
...
@@ -241,18 +256,44 @@ class ModelConfig:
"either 'auto', 'slow' or 'mistral'."
)
"either 'auto', 'slow' or 'mistral'."
)
self
.
tokenizer_mode
=
tokenizer_mode
self
.
tokenizer_mode
=
tokenizer_mode
def
_verify_embedding_mode
(
self
)
->
None
:
def
_resolve_task
(
architectures
=
getattr
(
self
.
hf_config
,
"architectures"
,
[])
self
,
task_option
:
Union
[
TaskOption
,
_Task
],
hf_config
:
PretrainedConfig
,
)
->
Tuple
[
Set
[
_Task
],
_Task
]:
if
task_option
==
"draft"
:
return
{
"draft"
},
"draft"
architectures
=
getattr
(
hf_config
,
"architectures"
,
[])
task_support
:
Dict
[
_Task
,
bool
]
=
{
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate"
:
ModelRegistry
.
is_text_generation_model
(
architectures
),
"embedding"
:
ModelRegistry
.
is_embedding_model
(
architectures
),
}
supported_tasks_lst
:
List
[
_Task
]
=
[
task
for
task
,
is_supported
in
task_support
.
items
()
if
is_supported
]
supported_tasks
=
set
(
supported_tasks_lst
)
if
task_option
==
"auto"
:
selected_task
=
next
(
iter
(
supported_tasks_lst
))
# TODO: Allow the same model architecture to be specified as either
if
len
(
supported_tasks
)
>
1
:
# generation or embedding model
logger
.
info
(
if
"Phi3VForCausalLM"
in
architectures
:
"This model supports multiple tasks: %s. "
# Match both remote and local names
"Defaulting to '%s'."
,
supported_tasks
,
selected_task
)
embedding_mode
=
"/VLM2Vec"
in
self
.
model
else
:
else
:
embedding_mode
=
ModelRegistry
.
is_embedding_model
(
architectures
)
if
task_option
not
in
supported_tasks
:
msg
=
(
f
"This model does not support the '
{
task_option
}
' task. "
f
"Supported tasks:
{
supported_tasks
}
"
)
raise
ValueError
(
msg
)
self
.
embedding_mode
=
embedding_mode
selected_task
=
task_option
return
supported_tasks
,
selected_task
def
_parse_quant_hf_config
(
self
):
def
_parse_quant_hf_config
(
self
):
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
quant_cfg
=
getattr
(
self
.
hf_config
,
"quantization_config"
,
None
)
...
@@ -337,7 +378,7 @@ class ModelConfig:
...
@@ -337,7 +378,7 @@ class ModelConfig:
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ."
)
" is not set, enabling VLLM_USE_TRITON_AWQ."
)
envs
.
VLLM_USE_TRITON_AWQ
=
True
envs
.
VLLM_USE_TRITON_AWQ
=
True
if
is_neuron
(
if
current_platform
.
is_neuron
(
)
and
self
.
quantization
not
in
neuron_supported_quantization
:
)
and
self
.
quantization
not
in
neuron_supported_quantization
:
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"
{
self
.
quantization
}
quantization is currently not "
...
@@ -410,7 +451,7 @@ class ModelConfig:
...
@@ -410,7 +451,7 @@ class ModelConfig:
# Async postprocessor is not necessary with embedding mode
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
# since there is no token generation
if
self
.
embedding
_mode
:
if
self
.
task
==
"
embedding
"
:
self
.
use_async_output_proc
=
False
self
.
use_async_output_proc
=
False
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
...
@@ -591,11 +632,6 @@ class ModelConfig:
...
@@ -591,11 +632,6 @@ class ModelConfig:
(
hasattr
(
self
.
hf_config
,
"text_config"
)
and
getattr
(
(
hasattr
(
self
.
hf_config
,
"text_config"
)
and
getattr
(
self
.
hf_config
.
text_config
,
"is_encoder_decoder"
,
False
)))
self
.
hf_config
.
text_config
,
"is_encoder_decoder"
,
False
)))
@
property
def
is_embedding_model
(
self
)
->
bool
:
"""Extract the embedding model flag."""
return
self
.
embedding_mode
@
property
@
property
def
is_multimodal_model
(
self
)
->
bool
:
def
is_multimodal_model
(
self
)
->
bool
:
return
self
.
multimodal_config
is
not
None
return
self
.
multimodal_config
is
not
None
...
@@ -952,6 +988,7 @@ class SchedulerConfig:
...
@@ -952,6 +988,7 @@ class SchedulerConfig:
"""Scheduler configuration.
"""Scheduler configuration.
Args:
Args:
task: The task to use the model for.
max_num_batched_tokens: Maximum number of tokens to be processed in
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
max_num_seqs: Maximum number of sequences to be processed in a single
...
@@ -966,7 +1003,6 @@ class SchedulerConfig:
...
@@ -966,7 +1003,6 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
We use recomputation by default since it incurs lower overhead than
...
@@ -981,13 +1017,13 @@ class SchedulerConfig:
...
@@ -981,13 +1017,13 @@ class SchedulerConfig:
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
task
:
_Task
,
max_num_batched_tokens
:
Optional
[
int
],
max_num_batched_tokens
:
Optional
[
int
],
max_num_seqs
:
int
,
max_num_seqs
:
int
,
max_model_len
:
int
,
max_model_len
:
int
,
num_lookahead_slots
:
int
=
0
,
num_lookahead_slots
:
int
=
0
,
delay_factor
:
float
=
0.0
,
delay_factor
:
float
=
0.0
,
enable_chunked_prefill
:
bool
=
False
,
enable_chunked_prefill
:
bool
=
False
,
embedding_mode
:
bool
=
False
,
is_multimodal_model
:
bool
=
False
,
is_multimodal_model
:
bool
=
False
,
preemption_mode
:
Optional
[
str
]
=
None
,
preemption_mode
:
Optional
[
str
]
=
None
,
num_scheduler_steps
:
int
=
1
,
num_scheduler_steps
:
int
=
1
,
...
@@ -1011,7 +1047,7 @@ class SchedulerConfig:
...
@@ -1011,7 +1047,7 @@ class SchedulerConfig:
# for higher throughput.
# for higher throughput.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
if
embedding
_mode
:
if
task
==
"
embedding
"
:
# For embedding, choose specific value for higher throughput
# For embedding, choose specific value for higher throughput
max_num_batched_tokens
=
max
(
max_num_batched_tokens
=
max
(
max_num_batched_tokens
,
max_num_batched_tokens
,
...
@@ -1031,12 +1067,12 @@ class SchedulerConfig:
...
@@ -1031,12 +1067,12 @@ class SchedulerConfig:
"Chunked prefill is enabled with max_num_batched_tokens=%d."
,
"Chunked prefill is enabled with max_num_batched_tokens=%d."
,
self
.
max_num_batched_tokens
)
self
.
max_num_batched_tokens
)
self
.
task
:
Final
=
task
self
.
max_num_seqs
=
max_num_seqs
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
delay_factor
=
delay_factor
self
.
delay_factor
=
delay_factor
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
embedding_mode
=
embedding_mode
self
.
preemption_mode
=
preemption_mode
self
.
preemption_mode
=
preemption_mode
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
num_scheduler_steps
=
num_scheduler_steps
self
.
multi_step_stream_outputs
=
multi_step_stream_outputs
self
.
multi_step_stream_outputs
=
multi_step_stream_outputs
...
@@ -1086,7 +1122,7 @@ class DeviceConfig:
...
@@ -1086,7 +1122,7 @@ class DeviceConfig:
# Automated device type detection
# Automated device type detection
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
self
.
device_type
=
"cuda"
self
.
device_type
=
"cuda"
elif
is_neuron
():
elif
current_platform
.
is_neuron
():
self
.
device_type
=
"neuron"
self
.
device_type
=
"neuron"
elif
is_openvino
():
elif
is_openvino
():
self
.
device_type
=
"openvino"
self
.
device_type
=
"openvino"
...
@@ -1248,6 +1284,7 @@ class SpeculativeConfig:
...
@@ -1248,6 +1284,7 @@ class SpeculativeConfig:
ngram_prompt_lookup_min
=
0
ngram_prompt_lookup_min
=
0
draft_model_config
=
ModelConfig
(
draft_model_config
=
ModelConfig
(
model
=
speculative_model
,
model
=
speculative_model
,
task
=
"draft"
,
tokenizer
=
target_model_config
.
tokenizer
,
tokenizer
=
target_model_config
.
tokenizer
,
tokenizer_mode
=
target_model_config
.
tokenizer_mode
,
tokenizer_mode
=
target_model_config
.
tokenizer_mode
,
trust_remote_code
=
target_model_config
.
trust_remote_code
,
trust_remote_code
=
target_model_config
.
trust_remote_code
,
...
@@ -1381,11 +1418,11 @@ class SpeculativeConfig:
...
@@ -1381,11 +1418,11 @@ class SpeculativeConfig:
else
:
else
:
speculative_draft_tensor_parallel_size
=
\
speculative_draft_tensor_parallel_size
=
\
target_parallel_config
.
tensor_parallel_size
target_parallel_config
.
tensor_parallel_size
elif
speculative_draft_tensor_parallel_size
!=
1
:
elif
speculative_draft_tensor_parallel_size
not
in
(
# TODO(wooyeon): allow tp values larger than 1
1
,
target_parallel_config
.
tensor_parallel_size
):
raise
ValueError
(
raise
ValueError
(
f
"
{
speculative_draft_tensor_parallel_size
=
}
cannot be "
f
"
{
speculative_draft_tensor_parallel_size
=
}
cannot be "
f
"other value than 1"
)
f
"other value than 1
or target model tensor_parallel_size
"
)
draft_parallel_config
=
ParallelConfig
(
draft_parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
target_parallel_config
.
pipeline_parallel_size
=
target_parallel_config
.
...
...
vllm/core/block/prefix_caching_block.py
View file @
2216a4e5
...
@@ -7,7 +7,7 @@ from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
...
@@ -7,7 +7,7 @@ from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.naive_block
import
(
BlockPool
,
NaiveBlock
,
from
vllm.core.block.naive_block
import
(
BlockPool
,
NaiveBlock
,
NaiveBlockAllocator
)
NaiveBlockAllocator
)
from
vllm.core.evictor
_v2
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.core.evictor
import
EvictionPolicy
,
Evictor
,
make_evictor
PrefixHash
=
int
PrefixHash
=
int
...
...
vllm/core/evictor
_v2
.py
→
vllm/core/evictor.py
View file @
2216a4e5
File moved
vllm/core/evictor_v1.py
deleted
100644 → 0
View file @
ad385667
import
enum
from
abc
import
ABC
,
abstractmethod
from
typing
import
OrderedDict
from
vllm.block
import
PhysicalTokenBlock
class EvictionPolicy(enum.Enum):
    """Selects which Evictor subclass make_evictor instantiates."""

    # Least-recently-used eviction.
    LRU = enum.auto()
class Evictor(ABC):
    """Abstract interface for eviction of freed PhysicalTokenBlocks.

    Concrete subclasses are meant to be used by the BlockAllocator class.
    """

    @abstractmethod
    def __init__(self):
        """Set up whatever bookkeeping the concrete evictor needs."""

    @abstractmethod
    def __contains__(self, block_hash: int) -> bool:
        """Return whether a block with hash `block_hash` is held."""

    @abstractmethod
    def evict(self) -> PhysicalTokenBlock:
        """Run the eviction algorithm and return the evicted block."""

    @abstractmethod
    def add(self, block: PhysicalTokenBlock):
        """Add `block` to the evictor, making it a candidate for eviction."""

    @abstractmethod
    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Take the block with hash `block_hash` out of the evictor.

        The caller is responsible for making sure that `block_hash` is
        contained in the evictor before calling remove. Intended for
        "bringing back" blocks that have been freed but not evicted yet.
        """

    @property
    @abstractmethod
    def num_blocks(self) -> int:
        """Number of blocks currently held by the evictor."""
class LRUEvictor(Evictor):
    """Evicts in a least-recently-used order using the last_accessed timestamp
    that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
    the same last_accessed time, then the one with the largest num_hashed_tokens
    will be evicted. If two blocks each have the lowest last_accessed time and
    highest num_hashed_tokens value, then one will be chosen arbitrarily.
    """

    def __init__(self):
        # Maps block_hash -> block, in insertion order.
        self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict()

    def __contains__(self, block_hash: int) -> bool:
        """Return whether a block with hash `block_hash` is in the table."""
        return block_hash in self.free_table

    def evict(self) -> PhysicalTokenBlock:
        """Remove and return the block chosen for eviction.

        The returned block has its `computed` flag reset to False.

        Raises:
            ValueError: if the evictor holds no blocks.
        """
        if not self.free_table:
            raise ValueError("No usable cache memory left")

        evicted_block = next(iter(self.free_table.values()))

        # The blocks with the lowest timestamps should be placed consecutively
        # at the start of OrderedDict. Loop through all these blocks to
        # find the one with maximum number of hashed tokens.
        for block in self.free_table.values():
            if evicted_block.last_accessed < block.last_accessed:
                break
            if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
                evicted_block = block

        self.free_table.pop(evicted_block.block_hash)

        evicted_block.computed = False
        return evicted_block

    def add(self, block: PhysicalTokenBlock):
        """Store `block` under its `block_hash`, making it evictable."""
        self.free_table[block.block_hash] = block

    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Remove and return the block stored under `block_hash`.

        Raises:
            ValueError: if no block with that hash is in the evictor.
        """
        # Single lookup: pop and translate the miss into the ValueError
        # callers expect, without exposing the internal KeyError.
        try:
            return self.free_table.pop(block_hash)
        except KeyError:
            raise ValueError(
                "Attempting to remove block that's not in the evictor"
            ) from None

    @property
    def num_blocks(self) -> int:
        """Number of blocks currently held in the table."""
        return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
    """Build the Evictor that implements `eviction_policy`.

    Raises:
        ValueError: if the policy is anything other than EvictionPolicy.LRU.
    """
    # LRU is the only policy handled here; reject everything else up front.
    if not eviction_policy == EvictionPolicy.LRU:
        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
    return LRUEvictor()
vllm/core/scheduler.py
View file @
2216a4e5
...
@@ -313,7 +313,7 @@ class Scheduler:
...
@@ -313,7 +313,7 @@ class Scheduler:
self
.
lora_config
=
lora_config
self
.
lora_config
=
lora_config
version
=
"selfattn"
version
=
"selfattn"
if
(
self
.
scheduler_config
.
embedding
_mode
if
(
self
.
scheduler_config
.
task
==
"
embedding
"
or
self
.
cache_config
.
is_attention_free
):
or
self
.
cache_config
.
is_attention_free
):
version
=
"placeholder"
version
=
"placeholder"
...
...
vllm/distributed/parallel_state.py
View file @
2216a4e5
...
@@ -7,7 +7,7 @@ It takes over the control of the distributed environment from PyTorch.
...
@@ -7,7 +7,7 @@ It takes over the control of the distributed environment from PyTorch.
The typical workflow is:
The typical workflow is:
- call `init_distributed_environment` to initialize the distributed environment.
- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
initialize the model parallel groups.
initialize the model parallel groups.
- any code dealing with the distributed stuff
- any code dealing with the distributed stuff
...
@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
...
@@ -20,6 +20,7 @@ If you only need to use the distributed environment without model/pipeline
steps.
steps.
"""
"""
import
contextlib
import
contextlib
import
gc
import
pickle
import
pickle
import
weakref
import
weakref
from
collections
import
namedtuple
from
collections
import
namedtuple
...
@@ -1129,6 +1130,19 @@ def destroy_distributed_environment():
...
@@ -1129,6 +1130,19 @@ def destroy_distributed_environment():
torch
.
distributed
.
destroy_process_group
()
torch
.
distributed
.
destroy_process_group
()
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down the distributed state and release cached memory.

    Destroys the model-parallel groups and the distributed environment,
    optionally shuts down Ray, runs the garbage collector and, on non-CPU
    platforms, empties the CUDA cache.

    Args:
        shutdown_ray: If True, import Ray and call `ray.shutdown()`.
    """
    destroy_model_parallel()
    destroy_distributed_environment()

    # One more attempt at destroying the default process group; an
    # AssertionError from this call is deliberately ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()

    gc.collect()

    # Nothing further to release on CPU-only platforms.
    if current_platform.is_cpu():
        return
    torch.cuda.empty_cache()
def
in_the_same_node_as
(
pg
:
ProcessGroup
,
source_rank
:
int
=
0
)
->
List
[
bool
]:
def
in_the_same_node_as
(
pg
:
ProcessGroup
,
source_rank
:
int
=
0
)
->
List
[
bool
]:
"""
"""
This is a collective operation that returns if each rank is in the same node
This is a collective operation that returns if each rank is in the same node
...
...
vllm/engine/arg_utils.py
View file @
2216a4e5
...
@@ -3,7 +3,7 @@ import dataclasses
...
@@ -3,7 +3,7 @@ import dataclasses
import
json
import
json
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Literal
,
Mapping
,
Optional
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
Type
,
Union
,
cast
)
Tuple
,
Type
,
Union
,
cast
,
get_args
)
import
torch
import
torch
...
@@ -12,10 +12,12 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
...
@@ -12,10 +12,12 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
DeviceConfig
,
EngineConfig
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TokenizerPoolConfig
)
SpeculativeConfig
,
TaskOption
,
TokenizerPoolConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
...
@@ -84,6 +86,7 @@ class EngineArgs:
...
@@ -84,6 +86,7 @@ class EngineArgs:
model
:
str
=
'facebook/opt-125m'
model
:
str
=
'facebook/opt-125m'
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
tokenizer
:
Optional
[
str
]
=
None
tokenizer
:
Optional
[
str
]
=
None
task
:
TaskOption
=
"auto"
skip_tokenizer_init
:
bool
=
False
skip_tokenizer_init
:
bool
=
False
tokenizer_mode
:
str
=
'auto'
tokenizer_mode
:
str
=
'auto'
trust_remote_code
:
bool
=
False
trust_remote_code
:
bool
=
False
...
@@ -198,6 +201,15 @@ class EngineArgs:
...
@@ -198,6 +201,15 @@ class EngineArgs:
type
=
str
,
type
=
str
,
default
=
EngineArgs
.
model
,
default
=
EngineArgs
.
model
,
help
=
'Name or path of the huggingface model to use.'
)
help
=
'Name or path of the huggingface model to use.'
)
parser
.
add_argument
(
'--task'
,
default
=
EngineArgs
.
task
,
choices
=
get_args
(
TaskOption
),
help
=
'The task to use the model for. Each vLLM instance only '
'supports one task, even if the same model can be used for '
'multiple tasks. When the model only supports one task, "auto" '
'can be used to select it; otherwise, you must specify explicitly '
'which task to use.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--tokenizer'
,
'--tokenizer'
,
type
=
nullable_str
,
type
=
nullable_str
,
...
@@ -418,7 +430,11 @@ class EngineArgs:
...
@@ -418,7 +430,11 @@ class EngineArgs:
help
=
'The fraction of GPU memory to be used for the model '
help
=
'The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9.'
)
'will use the default value of 0.9. This is a global gpu memory '
'utilization limit, for example if 50%% of the gpu memory is '
'already used before vLLM starts and --gpu-memory-utilization is '
'set to 0.9, then only 40%% of the gpu memory will be allocated '
'to the model executor.'
)
parser
.
add_argument
(
parser
.
add_argument
(
'--num-gpu-blocks-override'
,
'--num-gpu-blocks-override'
,
type
=
int
,
type
=
int
,
...
@@ -838,6 +854,7 @@ class EngineArgs:
...
@@ -838,6 +854,7 @@ class EngineArgs:
def
create_model_config
(
self
)
->
ModelConfig
:
def
create_model_config
(
self
)
->
ModelConfig
:
return
ModelConfig
(
return
ModelConfig
(
model
=
self
.
model
,
model
=
self
.
model
,
task
=
self
.
task
,
# We know this is not None because we set it in __post_init__
# We know this is not None because we set it in __post_init__
tokenizer
=
cast
(
str
,
self
.
tokenizer
),
tokenizer
=
cast
(
str
,
self
.
tokenizer
),
tokenizer_mode
=
self
.
tokenizer_mode
,
tokenizer_mode
=
self
.
tokenizer_mode
,
...
@@ -909,6 +926,8 @@ class EngineArgs:
...
@@ -909,6 +926,8 @@ class EngineArgs:
"supported for multimodal models and has been disabled."
)
"supported for multimodal models and has been disabled."
)
self
.
enable_prefix_caching
=
False
self
.
enable_prefix_caching
=
False
maybe_register_config_serialize_by_value
(
self
.
trust_remote_code
)
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
# neuron needs block_size = max_model_len
# neuron needs block_size = max_model_len
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
block_size
=
self
.
block_size
if
self
.
device
!=
"neuron"
else
...
@@ -1026,13 +1045,13 @@ class EngineArgs:
...
@@ -1026,13 +1045,13 @@ class EngineArgs:
" please file an issue with detailed information."
)
" please file an issue with detailed information."
)
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
task
=
model_config
.
task
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
max_num_seqs
=
self
.
max_num_seqs
,
max_num_seqs
=
self
.
max_num_seqs
,
max_model_len
=
model_config
.
max_model_len
,
max_model_len
=
model_config
.
max_model_len
,
num_lookahead_slots
=
num_lookahead_slots
,
num_lookahead_slots
=
num_lookahead_slots
,
delay_factor
=
self
.
scheduler_delay_factor
,
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
embedding_mode
=
model_config
.
embedding_mode
,
is_multimodal_model
=
model_config
.
is_multimodal_model
,
is_multimodal_model
=
model_config
.
is_multimodal_model
,
preemption_mode
=
self
.
preemption_mode
,
preemption_mode
=
self
.
preemption_mode
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
num_scheduler_steps
=
self
.
num_scheduler_steps
,
...
...
vllm/engine/llm_engine.py
View file @
2216a4e5
import
time
import
time
from
collections
import
Counter
as
collectionsCounter
from
collections
import
deque
from
collections
import
deque
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
...
@@ -43,8 +44,10 @@ from vllm.pooling_params import PoolingParams
...
@@ -43,8 +44,10 @@ from vllm.pooling_params import PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
from
vllm.sequence
import
(
EmbeddingSequenceGroupOutput
,
ExecuteModelRequest
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
ParallelSampleSequenceGroup
,
Sequence
,
SequenceGroupOutput
,
SequenceStatus
)
SequenceGroup
,
SequenceGroupBase
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceStatus
)
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
from
vllm.tracing
import
(
SpanAttributes
,
SpanKind
,
extract_trace_context
,
init_tracer
)
init_tracer
)
from
vllm.transformers_utils.config
import
try_get_generation_config
from
vllm.transformers_utils.config
import
try_get_generation_config
...
@@ -344,7 +347,7 @@ class LLMEngine:
...
@@ -344,7 +347,7 @@ class LLMEngine:
observability_config
=
self
.
observability_config
,
observability_config
=
self
.
observability_config
,
)
)
if
not
self
.
model_config
.
embedding
_mode
:
if
self
.
model_config
.
task
!=
"
embedding
"
:
self
.
_initialize_kv_caches
()
self
.
_initialize_kv_caches
()
# If usage stat is enabled, collect relevant info.
# If usage stat is enabled, collect relevant info.
...
@@ -473,6 +476,8 @@ class LLMEngine:
...
@@ -473,6 +476,8 @@ class LLMEngine:
),
),
))
))
self
.
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
=
{}
def
_initialize_kv_caches
(
self
)
->
None
:
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
"""Initialize the KV cache in the worker(s).
...
@@ -641,7 +646,10 @@ class LLMEngine:
...
@@ -641,7 +646,10 @@ class LLMEngine:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
],
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
None
:
)
->
SequenceGroup
:
"""Add a processed request to the engine's request pool.
return the created sequence group.
"""
self
.
_validate_model_inputs
(
processed_inputs
)
self
.
_validate_model_inputs
(
processed_inputs
)
# Create the sequences.
# Create the sequences.
block_size
=
self
.
cache_config
.
block_size
block_size
=
self
.
cache_config
.
block_size
...
@@ -695,6 +703,8 @@ class LLMEngine:
...
@@ -695,6 +703,8 @@ class LLMEngine:
min_cost_scheduler
=
self
.
scheduler
[
costs
.
index
(
min
(
costs
))]
min_cost_scheduler
=
self
.
scheduler
[
costs
.
index
(
min
(
costs
))]
min_cost_scheduler
.
add_seq_group
(
seq_group
)
min_cost_scheduler
.
add_seq_group
(
seq_group
)
return
seq_group
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
def
stop_remote_worker_execution_loop
(
self
)
->
None
:
self
.
model_executor
.
stop_remote_worker_execution_loop
()
self
.
model_executor
.
stop_remote_worker_execution_loop
()
...
@@ -710,7 +720,7 @@ class LLMEngine:
...
@@ -710,7 +720,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
None
:
)
->
Optional
[
SequenceGroup
]
:
...
...
@
overload
@
overload
...
@@ -724,7 +734,7 @@ class LLMEngine:
...
@@ -724,7 +734,7 @@ class LLMEngine:
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
)
->
None
:
)
->
Optional
[
SequenceGroup
]
:
...
...
@
deprecate_kwargs
(
@
deprecate_kwargs
(
...
@@ -743,7 +753,7 @@ class LLMEngine:
...
@@ -743,7 +753,7 @@ class LLMEngine:
priority
:
int
=
0
,
priority
:
int
=
0
,
*
,
*
,
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
inputs
:
Optional
[
PromptType
]
=
None
,
# DEPRECATED
)
->
None
:
)
->
Optional
[
SequenceGroup
]
:
"""Add a request to the engine's request pool.
"""Add a request to the engine's request pool.
The request is added to the request pool and will be processed by the
The request is added to the request pool and will be processed by the
...
@@ -787,6 +797,22 @@ class LLMEngine:
...
@@ -787,6 +797,22 @@ class LLMEngine:
>>> # continue the request processing
>>> # continue the request processing
>>> ...
>>> ...
"""
"""
if
isinstance
(
params
,
SamplingParams
)
and
params
.
n
>
1
:
ParallelSampleSequenceGroup
.
add_request
(
request_id
,
self
,
params
,
prompt
=
prompt
,
arrival_time
=
arrival_time
,
lora_request
=
lora_request
,
trace_headers
=
trace_headers
,
prompt_adapter_request
=
prompt_adapter_request
,
priority
=
priority
,
inputs
=
inputs
,
)
return
None
if
inputs
is
not
None
:
if
inputs
is
not
None
:
prompt
=
inputs
prompt
=
inputs
assert
prompt
is
not
None
and
params
is
not
None
assert
prompt
is
not
None
and
params
is
not
None
...
@@ -817,7 +843,7 @@ class LLMEngine:
...
@@ -817,7 +843,7 @@ class LLMEngine:
processed_inputs
[
"mm_processor_kwargs"
]
=
preprocessed_inputs
.
get
(
processed_inputs
[
"mm_processor_kwargs"
]
=
preprocessed_inputs
.
get
(
"mm_processor_kwargs"
)
"mm_processor_kwargs"
)
self
.
_add_processed_request
(
return
self
.
_add_processed_request
(
request_id
=
request_id
,
request_id
=
request_id
,
processed_inputs
=
processed_inputs
,
processed_inputs
=
processed_inputs
,
params
=
params
,
params
=
params
,
...
@@ -1116,7 +1142,7 @@ class LLMEngine:
...
@@ -1116,7 +1142,7 @@ class LLMEngine:
seq_group
.
metrics
.
model_execute_time
=
(
seq_group
.
metrics
.
model_execute_time
=
(
o
.
model_execute_time
)
o
.
model_execute_time
)
if
self
.
model_config
.
embedding
_mode
:
if
self
.
model_config
.
task
==
"
embedding
"
:
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
self
.
_process_sequence_group_outputs
(
seq_group
,
output
)
else
:
else
:
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
self
.
output_processor
.
process_prompt_logprob
(
seq_group
,
output
)
...
@@ -1134,7 +1160,9 @@ class LLMEngine:
...
@@ -1134,7 +1160,9 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
...
@@ -1174,7 +1202,9 @@ class LLMEngine:
...
@@ -1174,7 +1202,9 @@ class LLMEngine:
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
.
maybe_set_first_token_time
(
now
)
seq_group
.
maybe_set_first_token_time
(
now
)
request_output
=
RequestOutputFactory
.
create
(
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
...
@@ -1193,7 +1223,10 @@ class LLMEngine:
...
@@ -1193,7 +1223,10 @@ class LLMEngine:
continue
continue
request_output
=
RequestOutputFactory
.
create
(
request_output
=
RequestOutputFactory
.
create
(
seq_group
,
use_cache
=
self
.
use_cached_outputs
)
seq_group
,
self
.
seq_id_to_seq_group
,
use_cache
=
self
.
use_cached_outputs
,
)
if
request_output
:
if
request_output
:
ctx
.
request_outputs
.
append
(
request_output
)
ctx
.
request_outputs
.
append
(
request_output
)
...
@@ -1212,7 +1245,7 @@ class LLMEngine:
...
@@ -1212,7 +1245,7 @@ class LLMEngine:
skip
)
skip
)
# Tracing
# Tracing
self
.
do_tracing
(
scheduler_outputs
)
self
.
do_tracing
(
scheduler_outputs
,
finished_before
)
return
None
return
None
...
@@ -1617,6 +1650,25 @@ class LLMEngine:
...
@@ -1617,6 +1650,25 @@ class LLMEngine:
n_requests
:
List
[
int
]
=
[]
n_requests
:
List
[
int
]
=
[]
finished_reason_requests
:
List
[
str
]
=
[]
finished_reason_requests
:
List
[
str
]
=
[]
# Lora requests
running_lora_adapters
=
dict
(
collectionsCounter
([
running_request
.
lora_request
.
lora_name
for
scheduler
in
self
.
scheduler
for
running_request
in
scheduler
.
running
if
running_request
.
lora_request
]))
waiting_lora_adapters
=
dict
(
collectionsCounter
([
waiting_request
.
lora_request
.
lora_name
for
scheduler
in
self
.
scheduler
for
waiting_request
in
scheduler
.
waiting
if
waiting_request
.
lora_request
]))
max_lora_stat
=
"0"
if
self
.
lora_config
:
max_lora_stat
=
str
(
self
.
lora_config
.
max_loras
)
# NOTE: This loop assumes prefill seq_groups are before
# NOTE: This loop assumes prefill seq_groups are before
# decode seq_groups in scheduled_seq_groups.
# decode seq_groups in scheduled_seq_groups.
if
scheduler_outputs
is
not
None
:
if
scheduler_outputs
is
not
None
:
...
@@ -1666,6 +1718,15 @@ class LLMEngine:
...
@@ -1666,6 +1718,15 @@ class LLMEngine:
# TPOTs.
# TPOTs.
latency
=
seq_group
.
get_last_latency
(
now
)
latency
=
seq_group
.
get_last_latency
(
now
)
time_per_output_tokens_iter
.
append
(
latency
)
time_per_output_tokens_iter
.
append
(
latency
)
if
seq_group
.
state
.
current_step
==
0
:
# For async_output_proc, the do_log_stats()
# is called following init_multi_step(), which
# sets the current_step to zero.
actual_num_batched_tokens
+=
\
seq_group
.
state
.
num_steps
-
1
else
:
actual_num_batched_tokens
+=
\
seq_group
.
state
.
current_step
-
1
# Because of chunked prefill, we can have a single sequence
# Because of chunked prefill, we can have a single sequence
# group that does multiple prompt_runs. To prevent logging
# group that does multiple prompt_runs. To prevent logging
...
@@ -1738,7 +1799,9 @@ class LLMEngine:
...
@@ -1738,7 +1799,9 @@ class LLMEngine:
num_generation_tokens_requests
=
num_generation_tokens_requests
,
num_generation_tokens_requests
=
num_generation_tokens_requests
,
n_requests
=
n_requests
,
n_requests
=
n_requests
,
finished_reason_requests
=
finished_reason_requests
,
finished_reason_requests
=
finished_reason_requests
,
)
max_lora
=
str
(
max_lora_stat
),
waiting_lora_adapters
=
list
(
waiting_lora_adapters
.
keys
()),
running_lora_adapters
=
list
(
running_lora_adapters
.
keys
()))
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_executor
.
add_lora
(
lora_request
)
return
self
.
model_executor
.
add_lora
(
lora_request
)
...
@@ -1786,11 +1849,18 @@ class LLMEngine:
...
@@ -1786,11 +1849,18 @@ class LLMEngine:
def
is_tracing_enabled
(
self
)
->
bool
:
def
is_tracing_enabled
(
self
)
->
bool
:
return
self
.
tracer
is
not
None
return
self
.
tracer
is
not
None
def
do_tracing
(
self
,
scheduler_outputs
:
SchedulerOutputs
)
->
None
:
def
do_tracing
(
self
,
scheduler_outputs
:
SchedulerOutputs
,
finished_before
:
Optional
[
List
[
int
]]
=
None
)
->
None
:
if
self
.
tracer
is
None
:
if
self
.
tracer
is
None
:
return
return
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
for
idx
,
scheduled_seq_group
in
enumerate
(
scheduler_outputs
.
scheduled_seq_groups
):
# Skip double tracing when using async output proc
if
finished_before
and
idx
in
finished_before
:
continue
seq_group
=
scheduled_seq_group
.
seq_group
seq_group
=
scheduled_seq_group
.
seq_group
if
seq_group
.
is_finished
():
if
seq_group
.
is_finished
():
self
.
create_trace_span
(
seq_group
)
self
.
create_trace_span
(
seq_group
)
...
@@ -1855,9 +1925,6 @@ class LLMEngine:
...
@@ -1855,9 +1925,6 @@ class LLMEngine:
def
is_encoder_decoder_model
(
self
):
def
is_encoder_decoder_model
(
self
):
return
self
.
input_preprocessor
.
is_encoder_decoder_model
()
return
self
.
input_preprocessor
.
is_encoder_decoder_model
()
def
is_embedding_model
(
self
):
return
self
.
model_config
.
is_embedding_model
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
DecoderOnlyInputs
,
def
_validate_model_inputs
(
self
,
inputs
:
Union
[
DecoderOnlyInputs
,
EncoderDecoderInputs
]):
EncoderDecoderInputs
]):
if
self
.
model_config
.
is_multimodal_model
:
if
self
.
model_config
.
is_multimodal_model
:
...
...
vllm/engine/metrics.py
View file @
2216a4e5
...
@@ -34,7 +34,11 @@ class Metrics:
...
@@ -34,7 +34,11 @@ class Metrics:
See https://prometheus.github.io/client_python/multiprocess/ for more
See https://prometheus.github.io/client_python/multiprocess/ for more
details on limitations.
details on limitations.
"""
"""
labelname_finish_reason
=
"finished_reason"
labelname_finish_reason
=
"finished_reason"
labelname_waiting_lora_adapters
=
"waiting_lora_adapters"
labelname_running_lora_adapters
=
"running_lora_adapters"
labelname_max_lora
=
"max_lora"
_gauge_cls
=
prometheus_client
.
Gauge
_gauge_cls
=
prometheus_client
.
Gauge
_counter_cls
=
prometheus_client
.
Counter
_counter_cls
=
prometheus_client
.
Counter
_histogram_cls
=
prometheus_client
.
Histogram
_histogram_cls
=
prometheus_client
.
Histogram
...
@@ -55,6 +59,16 @@ class Metrics:
...
@@ -55,6 +59,16 @@ class Metrics:
documentation
=
"Number of requests waiting to be processed."
,
documentation
=
"Number of requests waiting to be processed."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
multiprocess_mode
=
"sum"
)
multiprocess_mode
=
"sum"
)
self
.
gauge_lora_info
=
self
.
_gauge_cls
(
name
=
"vllm:lora_requests_info"
,
documentation
=
"Running stats on lora requests."
,
labelnames
=
[
self
.
labelname_running_lora_adapters
,
self
.
labelname_max_lora
,
self
.
labelname_waiting_lora_adapters
,
],
multiprocess_mode
=
"livemostrecent"
,
)
self
.
gauge_scheduler_swapped
=
self
.
_gauge_cls
(
self
.
gauge_scheduler_swapped
=
self
.
_gauge_cls
(
name
=
"vllm:num_requests_swapped"
,
name
=
"vllm:num_requests_swapped"
,
documentation
=
"Number of requests swapped to CPU."
,
documentation
=
"Number of requests swapped to CPU."
,
...
@@ -426,6 +440,9 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -426,6 +440,9 @@ class PrometheusStatLogger(StatLoggerBase):
for
datum
in
data
:
for
datum
in
data
:
histogram
.
labels
(
**
self
.
labels
).
observe
(
datum
)
histogram
.
labels
(
**
self
.
labels
).
observe
(
datum
)
def
_log_gauge_string
(
self
,
gauge
,
data
:
Dict
[
str
,
str
])
->
None
:
gauge
.
labels
(
**
data
).
set
(
1
)
def
_log_prometheus
(
self
,
stats
:
Stats
)
->
None
:
def
_log_prometheus
(
self
,
stats
:
Stats
)
->
None
:
# System state data
# System state data
self
.
_log_gauge
(
self
.
metrics
.
gauge_scheduler_running
,
self
.
_log_gauge
(
self
.
metrics
.
gauge_scheduler_running
,
...
@@ -442,7 +459,17 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -442,7 +459,17 @@ class PrometheusStatLogger(StatLoggerBase):
stats
.
cpu_prefix_cache_hit_rate
)
stats
.
cpu_prefix_cache_hit_rate
)
self
.
_log_gauge
(
self
.
metrics
.
gauge_gpu_prefix_cache_hit_rate
,
self
.
_log_gauge
(
self
.
metrics
.
gauge_gpu_prefix_cache_hit_rate
,
stats
.
gpu_prefix_cache_hit_rate
)
stats
.
gpu_prefix_cache_hit_rate
)
# Including max-lora in metric, in future this property of lora
# config maybe extended to be dynamic.
lora_info
=
{
self
.
metrics
.
labelname_running_lora_adapters
:
","
.
join
(
stats
.
running_lora_adapters
),
self
.
metrics
.
labelname_waiting_lora_adapters
:
","
.
join
(
stats
.
waiting_lora_adapters
),
self
.
metrics
.
labelname_max_lora
:
stats
.
max_lora
,
}
self
.
_log_gauge_string
(
self
.
metrics
.
gauge_lora_info
,
lora_info
)
# Iteration level data
# Iteration level data
self
.
_log_counter
(
self
.
metrics
.
counter_num_preemption
,
self
.
_log_counter
(
self
.
metrics
.
counter_num_preemption
,
stats
.
num_preemption_iter
)
stats
.
num_preemption_iter
)
...
...
vllm/engine/metrics_types.py
View file @
2216a4e5
...
@@ -51,6 +51,9 @@ class Stats:
...
@@ -51,6 +51,9 @@ class Stats:
num_generation_tokens_requests
:
List
[
int
]
num_generation_tokens_requests
:
List
[
int
]
n_requests
:
List
[
int
]
n_requests
:
List
[
int
]
finished_reason_requests
:
List
[
str
]
finished_reason_requests
:
List
[
str
]
waiting_lora_adapters
:
List
[
str
]
running_lora_adapters
:
List
[
str
]
max_lora
:
str
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
spec_decode_metrics
:
Optional
[
"SpecDecodeWorkerMetrics"
]
=
None
...
...
vllm/engine/multiprocessing/client.py
View file @
2216a4e5
...
@@ -204,8 +204,20 @@ class MQLLMEngineClient(EngineClient):
...
@@ -204,8 +204,20 @@ class MQLLMEngineClient(EngineClient):
# (and record only the first one)
# (and record only the first one)
if
is_engine_errored
and
not
self
.
_errored_with
:
if
is_engine_errored
and
not
self
.
_errored_with
:
self
.
_errored_with
=
exception
self
.
_errored_with
=
exception
# If engine is errored, no matter the type of exception
# it will no longer be able to receive new requests,
# therefore we have to inform that the current
# processed requests failed as well. Send back a dead
# engine error give this feedback and also give a
# 'hint' to the server to shutdown next.
exception
=
self
.
dead_error
if
request_id
is
None
:
if
request_id
is
None
:
# If request_id is None, then the engine raised an
# exception for a batch, and we may not know the
# request that caused it, neither if it was actually
# caused by any of them (e.g. CUDA OOM). Therefore we
# broadcast the same exception for all requests.
for
queue_i
in
tuple
(
self
.
output_queues
.
values
()):
for
queue_i
in
tuple
(
self
.
output_queues
.
values
()):
queue_i
.
put_nowait
(
exception
)
queue_i
.
put_nowait
(
exception
)
else
:
else
:
...
...
vllm/engine/multiprocessing/engine.py
View file @
2216a4e5
...
@@ -8,7 +8,7 @@ from typing import Iterator, List, Optional, Union
...
@@ -8,7 +8,7 @@ from typing import Iterator, List, Optional, Union
import
cloudpickle
import
cloudpickle
import
zmq
import
zmq
from
vllm
import
AsyncEngineArgs
,
LLMEngine
,
SamplingParams
from
vllm
import
AsyncEngineArgs
,
SamplingParams
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
)
ParallelConfig
,
SchedulerConfig
)
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
...
@@ -21,12 +21,17 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
...
@@ -21,12 +21,17 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest
,
RPCStartupResponse
,
RPCStartupRequest
,
RPCStartupResponse
,
RPCUProfileRequest
)
RPCUProfileRequest
)
# yapf: enable
# yapf: enable
from
vllm.envs
import
VLLM_RPC_TIMEOUT
from
vllm.envs
import
VLLM_RPC_TIMEOUT
,
VLLM_USE_V1
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.executor.gpu_executor
import
GPUExecutor
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
if
VLLM_USE_V1
:
from
vllm.v1.engine.llm_engine
import
LLMEngine
else
:
from
vllm.engine.llm_engine
import
LLMEngine
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
CONFIG_TYPE
=
Union
[
ModelConfig
,
DecodingConfig
,
ParallelConfig
,
SchedulerConfig
,
LoRAConfig
]
SchedulerConfig
,
LoRAConfig
]
...
@@ -136,14 +141,16 @@ class MQLLMEngine:
...
@@ -136,14 +141,16 @@ class MQLLMEngine:
executor_class
=
LLMEngine
.
_get_executor_cls
(
engine_config
)
executor_class
=
LLMEngine
.
_get_executor_cls
(
engine_config
)
return
cls
(
use_async_sockets
=
(
engine_config
.
model_config
.
use_async_output_proc
ipc_path
=
ipc_path
,
and
not
VLLM_USE_V1
)
use_async_sockets
=
engine_config
.
model_config
.
use_async_output_proc
,
**
engine_config
.
to_dict
(),
return
cls
(
ipc_path
=
ipc_path
,
executor_class
=
executor_class
,
use_async_sockets
=
use_async_sockets
,
log_requests
=
not
engine_args
.
disable_log_requests
,
**
engine_config
.
to_dict
(),
log_stats
=
not
engine_args
.
disable_log_stats
,
executor_class
=
executor_class
,
usage_context
=
usage_context
)
log_requests
=
not
engine_args
.
disable_log_requests
,
log_stats
=
not
engine_args
.
disable_log_stats
,
usage_context
=
usage_context
)
def
start
(
self
):
def
start
(
self
):
try
:
try
:
...
...
vllm/engine/protocol.py
View file @
2216a4e5
...
@@ -59,7 +59,7 @@ class EngineClient(ABC):
...
@@ -59,7 +59,7 @@ class EngineClient(ABC):
async
def
beam_search
(
async
def
beam_search
(
self
,
self
,
prompt
:
Union
[
PromptType
,
List
[
int
]],
prompt
:
Union
[
str
,
List
[
int
]],
request_id
:
str
,
request_id
:
str
,
params
:
BeamSearchParams
,
params
:
BeamSearchParams
,
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
)
->
AsyncGenerator
[
RequestOutput
,
None
]:
...
@@ -71,9 +71,13 @@ class EngineClient(ABC):
...
@@ -71,9 +71,13 @@ class EngineClient(ABC):
length_penalty
=
params
.
length_penalty
length_penalty
=
params
.
length_penalty
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizer
=
await
self
.
get_tokenizer
(
lora_request
=
None
)
tokenizedPrompt
=
prompt
if
isinstance
(
if
isinstance
(
prompt
,
str
):
prompt
,
list
)
else
tokenizer
.
encode
(
prompt
)
tokenized_prompt
=
tokenizer
.
encode
(
prompt
)
tokenizedLength
=
len
(
tokenizedPrompt
)
prompt_text
=
prompt
else
:
tokenized_prompt
=
prompt
prompt_text
=
None
tokenized_length
=
len
(
tokenized_prompt
)
sort_beams_key
=
create_sort_beams_key_function
(
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
tokenizer
.
eos_token_id
,
length_penalty
)
...
@@ -81,7 +85,11 @@ class EngineClient(ABC):
...
@@ -81,7 +85,11 @@ class EngineClient(ABC):
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
max_tokens
=
1
,
max_tokens
=
1
,
temperature
=
temperature
)
temperature
=
temperature
)
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenizedPrompt
,
cum_logprob
=
0
)]
all_beams
=
[
BeamSearchSequence
(
tokens
=
tokenized_prompt
,
logprobs
=
[],
cum_logprob
=
0
)
]
completed
=
[]
completed
=
[]
for
_
in
range
(
max_tokens
):
for
_
in
range
(
max_tokens
):
...
@@ -114,6 +122,7 @@ class EngineClient(ABC):
...
@@ -114,6 +122,7 @@ class EngineClient(ABC):
for
token_id
,
logprob_obj
in
logprobs
.
items
():
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
logprob_obj
.
logprob
)
...
@@ -131,22 +140,22 @@ class EngineClient(ABC):
...
@@ -131,22 +140,22 @@ class EngineClient(ABC):
best_beams
=
sorted_completed
[:
beam_width
]
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
for
beam
in
best_beams
:
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenized
L
ength
:])
beam
.
text
=
tokenizer
.
decode
(
beam
.
tokens
[
tokenized
_l
ength
:])
beam_search_output
=
RequestOutput
(
beam_search_output
=
RequestOutput
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt
,
prompt
=
prompt
_text
,
outputs
=
[
outputs
=
[
CompletionOutput
(
CompletionOutput
(
text
=
beam
.
text
,
text
=
beam
.
text
,
cumulative_logprob
=
beam
.
cum_logprob
,
cumulative_logprob
=
beam
.
cum_logprob
,
token_ids
=
beam
.
tokens
,
token_ids
=
beam
.
tokens
[
tokenized_length
:]
,
index
=
i
,
index
=
i
,
logprobs
=
beam
.
cum_
logprob
,
logprobs
=
beam
.
logprob
s
,
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
)
for
(
i
,
beam
)
in
enumerate
(
best_beams
)
],
],
finished
=
True
,
finished
=
True
,
prompt_token_ids
=
tokenized
P
rompt
,
prompt_token_ids
=
tokenized
_p
rompt
,
prompt_logprobs
=
None
)
prompt_logprobs
=
None
)
yield
beam_search_output
yield
beam_search_output
...
...
vllm/entrypoints/chat_utils.py
View file @
2216a4e5
...
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
...
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
lru_cache
,
partial
from
functools
import
lru_cache
,
partial
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
(
Any
,
Awaitable
,
Dict
,
Generic
,
Iterable
,
List
,
Literal
,
from
typing
import
(
Any
,
Awaitable
,
Callable
,
Dict
,
Generic
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
Literal
,
Mapping
,
Optional
,
Tuple
,
TypeVar
,
Union
,
cast
)
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
...
@@ -33,6 +33,7 @@ from vllm.multimodal.utils import (async_get_and_parse_audio,
...
@@ -33,6 +33,7 @@ from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image
,
async_get_and_parse_image
,
get_and_parse_audio
,
get_and_parse_image
)
get_and_parse_audio
,
get_and_parse_image
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -58,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
...
@@ -58,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""
"""The type of the content part."""
class
CustomChatCompletionContentSimpleImageParam
(
TypedDict
,
total
=
False
):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Example:
{
"image_url": "https://example.com/image.jpg"
}
"""
image_url
:
Required
[
str
]
class
CustomChatCompletionContentSimpleAudioParam
(
TypedDict
,
total
=
False
):
"""A simpler version of the param that only accepts a plain audio_url.
Example:
{
"audio_url": "https://example.com/audio.mp3"
}
"""
audio_url
:
Required
[
str
]
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
ChatCompletionContentPartRefusalParam
,
ChatCompletionContentPartRefusalParam
,
CustomChatCompletionContentPartParam
]
CustomChatCompletionContentPartParam
,
CustomChatCompletionContentSimpleImageParam
,
CustomChatCompletionContentSimpleAudioParam
,
str
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
...
@@ -386,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
...
@@ -386,6 +412,71 @@ _AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
_RefusalParser
=
partial
(
cast
,
ChatCompletionContentPartRefusalParam
)
MODEL_KEEP_MULTI_MODAL_CONTENT
=
{
'mllama'
}
MODEL_KEEP_MULTI_MODAL_CONTENT
=
{
'mllama'
}
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP
:
Dict
[
str
,
Callable
[[
ChatCompletionContentPartParam
],
str
]]
=
{
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
""
),
"image_url"
:
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
""
),
"audio_url"
:
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
""
),
"refusal"
:
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
""
),
}
def
_parse_chat_message_content_mm_part
(
part
:
ChatCompletionContentPartParam
)
->
Tuple
[
str
,
str
]:
"""
Parses a given multi modal content part based on its type.
Args:
part: A dict containing the content part, with a potential 'type' field.
Returns:
A tuple (part_type, content) where:
- part_type: Type of the part (e.g., 'text', 'image_url').
- content: Parsed content (e.g., text, image URL).
Raises:
ValueError: If the 'type' field is missing and no direct URL is found.
"""
assert
isinstance
(
part
,
dict
)
# This is needed to avoid mypy errors: part.get() from str
part_type
=
part
.
get
(
"type"
,
None
)
if
isinstance
(
part_type
,
str
)
and
part_type
in
MM_PARSER_MAP
:
content
=
MM_PARSER_MAP
[
part_type
](
part
)
# Special case for 'image_url.detail'
if
part_type
==
"image_url"
and
part
.
get
(
"detail"
)
!=
"auto"
:
logger
.
warning
(
"'image_url.detail' is currently not supported "
"and will be ignored."
)
return
part_type
,
content
# Handle missing 'type' but provided direct URL fields.
if
part_type
is
None
:
if
part
.
get
(
"image_url"
)
is
not
None
:
image_params
=
cast
(
CustomChatCompletionContentSimpleImageParam
,
part
)
return
"image_url"
,
image_params
.
get
(
"image_url"
,
""
)
if
part
.
get
(
"audio_url"
)
is
not
None
:
audio_params
=
cast
(
CustomChatCompletionContentSimpleAudioParam
,
part
)
return
"audio_url"
,
audio_params
.
get
(
"audio_url"
,
""
)
# Raise an error if no 'type' or direct URL is found.
raise
ValueError
(
"Missing 'type' field in multimodal part."
)
if
not
isinstance
(
part_type
,
str
):
raise
ValueError
(
"Invalid 'type' field in multimodal part."
)
return
part_type
,
"unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES
=
(
"text"
,
"refusal"
,
"image_url"
,
"audio_url"
)
def
_parse_chat_message_content_parts
(
def
_parse_chat_message_content_parts
(
role
:
str
,
role
:
str
,
...
@@ -401,29 +492,28 @@ def _parse_chat_message_content_parts(
...
@@ -401,29 +492,28 @@ def _parse_chat_message_content_parts(
has_image
=
False
has_image
=
False
for
part
in
parts
:
for
part
in
parts
:
part_type
=
part
[
"type"
]
if
isinstance
(
part
,
str
):
# Handle plain text parts
if
part_type
==
"text"
:
text
=
_TextParser
(
part
)
text
=
_TextParser
(
part
)[
"text"
]
texts
.
append
(
text
)
texts
.
append
(
text
)
el
if
part_type
==
"image_url"
:
el
se
:
# Handle structured dictionary parts
image_url
=
_ImageParser
(
part
)[
"image_url"
]
part_type
,
content
=
_parse_chat_message_content_mm_part
(
part
)
if
image_url
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
#
if
part_type is text/refusal/image_url/audio_url but
logg
er
.
warning
(
# content is empty,
logg
a
warning
and skip
"'image_url.detail' is currently not supported and "
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
not
content
:
"will be ignored."
)
logger
.
warning
(
"Skipping multimodal part "
"with empty / unparsable content."
)
mm_parser
.
parse_image
(
image_url
[
"url"
])
continue
has_image
=
True
el
if
part_type
==
"audio_ur
l"
:
if
part_type
in
(
"text"
,
"refusa
l"
)
:
audio_url
=
_AudioParser
(
part
)[
"audio_url"
]
texts
.
append
(
content
)
elif
part_type
==
"image_url"
:
mm_parser
.
parse_
audio
(
audio_url
[
"url"
]
)
mm_parser
.
parse_
image
(
content
)
elif
part_typ
e
=
=
"refusal"
:
has_imag
e
=
True
text
=
_RefusalParser
(
part
)[
"refusa
l"
]
elif
part_type
==
"audio_ur
l"
:
texts
.
append
(
te
x
t
)
mm_parser
.
parse_audio
(
con
te
n
t
)
else
:
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
text_prompt
=
"
\n
"
.
join
(
texts
)
text_prompt
=
"
\n
"
.
join
(
texts
)
if
keep_multimodal_content
:
if
keep_multimodal_content
:
...
@@ -564,14 +654,14 @@ def apply_mistral_chat_template(
...
@@ -564,14 +654,14 @@ def apply_mistral_chat_template(
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
List
[
int
]:
)
->
List
[
int
]:
if
chat_template
is
not
None
:
if
chat_template
is
not
None
:
logger
.
warning
(
print_
warning
_once
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
"'chat_template' cannot be overridden for mistral tokenizer."
)
if
"add_generation_prompt"
in
kwargs
:
if
"add_generation_prompt"
in
kwargs
:
logger
.
warning
(
print_
warning
_once
(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
"so it will be ignored."
)
if
"continue_final_message"
in
kwargs
:
if
"continue_final_message"
in
kwargs
:
logger
.
warning
(
print_
warning
_once
(
"'continue_final_message' is not supported for mistral tokenizer, "
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
)
"so it will be ignored."
)
...
...
vllm/entrypoints/llm.py
View file @
2216a4e5
...
@@ -6,10 +6,10 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
...
@@ -6,10 +6,10 @@ from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
vllm
import
envs
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchOutput
,
BeamSearchSequence
,
get_beam_search_score
)
BeamSearchSequence
,
get_beam_search_score
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
,
TaskOption
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
from
vllm.entrypoints.chat_utils
import
(
ChatCompletionMessageParam
,
apply_hf_chat_template
,
apply_hf_chat_template
,
apply_mistral_chat_template
,
apply_mistral_chat_template
,
...
@@ -29,7 +29,12 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
...
@@ -29,7 +29,12 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer
)
get_cached_tokenizer
)
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.transformers_utils.tokenizer_group
import
TokenizerGroup
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Counter
,
deprecate_kwargs
,
is_list_of
from
vllm.utils
import
Counter
,
deprecate_args
,
deprecate_kwargs
,
is_list_of
if
envs
.
VLLM_USE_V1
:
from
vllm.v1.engine.llm_engine
import
LLMEngine
# type: ignore
else
:
from
vllm.engine.llm_engine
import
LLMEngine
# type: ignore
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -108,6 +113,12 @@ class LLM:
...
@@ -108,6 +113,12 @@ class LLM:
DEPRECATE_LEGACY
:
ClassVar
[
bool
]
=
False
DEPRECATE_LEGACY
:
ClassVar
[
bool
]
=
False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS
:
ClassVar
[
bool
]
=
True
"""
A flag to toggle whether to deprecate positional arguments in
:meth:`LLM.__init__`.
"""
@
classmethod
@
classmethod
@
contextmanager
@
contextmanager
def
deprecate_legacy_api
(
cls
):
def
deprecate_legacy_api
(
cls
):
...
@@ -117,6 +128,13 @@ class LLM:
...
@@ -117,6 +128,13 @@ class LLM:
cls
.
DEPRECATE_LEGACY
=
False
cls
.
DEPRECATE_LEGACY
=
False
@
deprecate_args
(
start_index
=
2
,
# Ignore self and model
is_deprecated
=
lambda
:
LLM
.
DEPRECATE_INIT_POSARGS
,
additional_message
=
(
"All positional arguments other than `model` will be "
"replaced with keyword arguments in an upcoming version."
),
)
def
__init__
(
def
__init__
(
self
,
self
,
model
:
str
,
model
:
str
,
...
@@ -139,6 +157,8 @@ class LLM:
...
@@ -139,6 +157,8 @@ class LLM:
disable_custom_all_reduce
:
bool
=
False
,
disable_custom_all_reduce
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
disable_async_output_proc
:
bool
=
False
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
# After positional args are removed, move this right below `model`
task
:
TaskOption
=
"auto"
,
**
kwargs
,
**
kwargs
,
)
->
None
:
)
->
None
:
'''
'''
...
@@ -153,6 +173,7 @@ class LLM:
...
@@ -153,6 +173,7 @@ class LLM:
engine_args
=
EngineArgs
(
engine_args
=
EngineArgs
(
model
=
model
,
model
=
model
,
task
=
task
,
tokenizer
=
tokenizer
,
tokenizer
=
tokenizer
,
tokenizer_mode
=
tokenizer_mode
,
tokenizer_mode
=
tokenizer_mode
,
skip_tokenizer_init
=
skip_tokenizer_init
,
skip_tokenizer_init
=
skip_tokenizer_init
,
...
@@ -316,10 +337,21 @@ class LLM:
...
@@ -316,10 +337,21 @@ class LLM:
considered legacy and may be deprecated in the future. You should
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
instead pass them via the ``inputs`` parameter.
"""
"""
if
self
.
llm_engine
.
model_config
.
embedding_mode
:
task
=
self
.
llm_engine
.
model_config
.
task
raise
ValueError
(
if
task
!=
"generate"
:
messages
=
[
"LLM.generate() is only supported for (conditional) generation "
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration)."
)
"models (XForCausalLM, XForConditionalGeneration)."
,
]
supported_tasks
=
self
.
llm_engine
.
model_config
.
supported_tasks
if
"generate"
in
supported_tasks
:
messages
.
append
(
"Your model supports the 'generate' task, but is "
f
"currently initialized for the '
{
task
}
' task. Please "
"initialize the model using `--task generate`."
)
raise
ValueError
(
" "
.
join
(
messages
))
if
prompt_token_ids
is
not
None
:
if
prompt_token_ids
is
not
None
:
parsed_prompts
=
self
.
_convert_v1_inputs
(
parsed_prompts
=
self
.
_convert_v1_inputs
(
...
@@ -433,6 +465,7 @@ class LLM:
...
@@ -433,6 +465,7 @@ class LLM:
for
token_id
,
logprob_obj
in
logprobs
.
items
():
for
token_id
,
logprob_obj
in
logprobs
.
items
():
new_beam
=
BeamSearchSequence
(
new_beam
=
BeamSearchSequence
(
tokens
=
current_beam
.
tokens
+
[
token_id
],
tokens
=
current_beam
.
tokens
+
[
token_id
],
logprobs
=
current_beam
.
logprobs
+
[
logprobs
],
cum_logprob
=
current_beam
.
cum_logprob
+
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
)
logprob_obj
.
logprob
)
...
@@ -691,10 +724,18 @@ class LLM:
...
@@ -691,10 +724,18 @@ class LLM:
considered legacy and may be deprecated in the future. You should
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
instead pass them via the ``inputs`` parameter.
"""
"""
if
not
self
.
llm_engine
.
model_config
.
embedding_mode
:
task
=
self
.
llm_engine
.
model_config
.
task
raise
ValueError
(
if
task
!=
"embedding"
:
"LLM.encode() is only supported for embedding models (XModel)."
messages
=
[
"LLM.encode() is only supported for embedding models."
]
)
supported_tasks
=
self
.
llm_engine
.
model_config
.
supported_tasks
if
"embedding"
in
supported_tasks
:
messages
.
append
(
"Your model supports the 'embedding' task, but is "
f
"currently initialized for the '
{
task
}
' task. Please "
"initialize the model using `--task embedding`."
)
raise
ValueError
(
" "
.
join
(
messages
))
if
prompt_token_ids
is
not
None
:
if
prompt_token_ids
is
not
None
:
parsed_prompts
=
self
.
_convert_v1_inputs
(
parsed_prompts
=
self
.
_convert_v1_inputs
(
...
@@ -904,6 +945,3 @@ class LLM:
...
@@ -904,6 +945,3 @@ class LLM:
def
_is_encoder_decoder_model
(
self
):
def
_is_encoder_decoder_model
(
self
):
return
self
.
llm_engine
.
is_encoder_decoder_model
()
return
self
.
llm_engine
.
is_encoder_decoder_model
()
def
_is_embedding_model
(
self
):
return
self
.
llm_engine
.
is_embedding_model
()
vllm/entrypoints/openai/protocol.py
View file @
2216a4e5
...
@@ -284,6 +284,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -284,6 +284,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
"The priority of the request (lower means earlier handling; "
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
))
"if the served model does not use priority scheduling."
))
request_id
:
str
=
Field
(
default_factory
=
lambda
:
f
"
{
random_uuid
()
}
"
,
description
=
(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
))
# doc: end-chat-completion-extra-params
# doc: end-chat-completion-extra-params
...
@@ -314,9 +320,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
...
@@ -314,9 +320,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
prompt_logprobs
=
self
.
top_logprobs
prompt_logprobs
=
self
.
top_logprobs
guided_json_object
=
None
guided_json_object
=
None
if
(
self
.
response_format
is
not
None
if
self
.
response_format
is
not
None
:
and
self
.
response_format
.
type
==
"json_object"
):
if
self
.
response_format
.
type
==
"json_object"
:
guided_json_object
=
True
guided_json_object
=
True
elif
self
.
response_format
.
type
==
"json_schema"
:
json_schema
=
self
.
response_format
.
json_schema
assert
json_schema
is
not
None
self
.
guided_json
=
json_schema
.
json_schema
if
self
.
guided_decoding_backend
is
None
:
self
.
guided_decoding_backend
=
"lm-format-enforcer"
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
guided_decoding
=
GuidedDecodingParams
.
from_optional
(
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
json
=
self
.
_get_guided_json_from_tool
()
or
self
.
guided_json
,
...
@@ -537,8 +549,8 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -537,8 +549,8 @@ class CompletionRequest(OpenAIBaseModel):
default
=
None
,
default
=
None
,
description
=
description
=
(
"Similar to chat completion, this parameter specifies the format of "
(
"Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'}
or
{'type': '
text' } is
"
"output. Only {'type': 'json_object'}
,
{'type': '
json_schema'} or
"
"supported."
),
"
{'type': 'text' } is
supported."
),
)
)
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
guided_json
:
Optional
[
Union
[
str
,
dict
,
BaseModel
]]
=
Field
(
default
=
None
,
default
=
None
,
...
...
vllm/entrypoints/openai/serving_chat.py
View file @
2216a4e5
...
@@ -38,7 +38,7 @@ from vllm.sequence import Logprob
...
@@ -38,7 +38,7 @@ from vllm.sequence import Logprob
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
from
vllm.tracing
import
(
contains_trace_headers
,
extract_trace_headers
,
log_tracing_disabled_warning
)
log_tracing_disabled_warning
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
,
MistralTokenizer
from
vllm.utils
import
iterate_with_cancellation
,
random_uuid
from
vllm.utils
import
iterate_with_cancellation
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -176,7 +176,7 @@ class OpenAIServingChat(OpenAIServing):
...
@@ -176,7 +176,7 @@ class OpenAIServingChat(OpenAIServing):
"
\"
auto
\"
tool choice requires "
"
\"
auto
\"
tool choice requires "
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
"--enable-auto-tool-choice and --tool-call-parser to be set"
)
request_id
=
f
"chat-
{
r
andom_uuid
()
}
"
request_id
=
f
"chat-
{
r
equest
.
request_id
}
"
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
request_metadata
=
RequestResponseMetadata
(
request_id
=
request_id
)
if
raw_request
:
if
raw_request
:
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
2216a4e5
...
@@ -258,6 +258,14 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -258,6 +258,14 @@ class OpenAIServingCompletion(OpenAIServing):
has_echoed
=
[
False
]
*
num_choices
*
num_prompts
has_echoed
=
[
False
]
*
num_choices
*
num_prompts
num_prompt_tokens
=
[
0
]
*
num_prompts
num_prompt_tokens
=
[
0
]
*
num_prompts
stream_options
=
request
.
stream_options
if
stream_options
:
include_usage
=
stream_options
.
include_usage
include_continuous_usage
=
include_usage
and
\
stream_options
.
continuous_usage_stats
else
:
include_usage
,
include_continuous_usage
=
False
,
False
try
:
try
:
async
for
prompt_idx
,
res
in
result_generator
:
async
for
prompt_idx
,
res
in
result_generator
:
prompt_token_ids
=
res
.
prompt_token_ids
prompt_token_ids
=
res
.
prompt_token_ids
...
@@ -276,28 +284,25 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -276,28 +284,25 @@ class OpenAIServingCompletion(OpenAIServing):
i
=
output
.
index
+
prompt_idx
*
num_choices
i
=
output
.
index
+
prompt_idx
*
num_choices
assert
request
.
max_tokens
is
not
None
assert
request
.
max_tokens
is
not
None
if
request
.
echo
and
request
.
max_tokens
==
0
:
if
request
.
echo
and
not
has_echoed
[
i
]
:
assert
prompt_token_ids
is
not
None
assert
prompt_token_ids
is
not
None
assert
prompt_text
is
not
None
assert
prompt_text
is
not
None
# only return the prompt
if
request
.
max_tokens
==
0
:
delta_text
=
prompt_text
# only return the prompt
delta_token_ids
=
prompt_token_ids
delta_text
=
prompt_text
out_logprobs
=
prompt_logprobs
delta_token_ids
=
prompt_token_ids
has_echoed
[
i
]
=
True
out_logprobs
=
prompt_logprobs
elif
(
request
.
echo
and
request
.
max_tokens
>
0
else
:
and
not
has_echoed
[
i
]):
assert
prompt_logprobs
is
not
None
assert
prompt_token_ids
is
not
None
# echo the prompt and first token
assert
prompt_text
is
not
None
delta_text
=
prompt_text
+
output
.
text
assert
prompt_logprobs
is
not
None
delta_token_ids
=
[
# echo the prompt and first token
*
prompt_token_ids
,
*
output
.
token_ids
delta_text
=
prompt_text
+
output
.
text
]
delta_token_ids
=
[
out_logprobs
=
[
*
prompt_token_ids
,
*
output
.
token_ids
*
prompt_logprobs
,
]
*
(
output
.
logprobs
or
[]),
out_logprobs
=
[
]
*
prompt_logprobs
,
*
(
output
.
logprobs
or
[]),
]
has_echoed
[
i
]
=
True
has_echoed
[
i
]
=
True
else
:
else
:
# return just the delta
# return just the delta
...
@@ -341,45 +346,39 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -341,45 +346,39 @@ class OpenAIServingCompletion(OpenAIServing):
stop_reason
=
stop_reason
,
stop_reason
=
stop_reason
,
)
)
])
])
if
(
request
.
stream_options
if
include_continuous_usage
:
and
request
.
stream_options
.
include_usage
):
prompt_tokens
=
num_prompt_tokens
[
prompt_idx
]
if
(
request
.
stream_options
.
continuous_usage_stats
completion_tokens
=
previous_num_tokens
[
i
]
or
output
.
finish_reason
is
not
None
):
chunk
.
usage
=
UsageInfo
(
prompt_tokens
=
num_prompt_tokens
[
prompt_idx
]
prompt_tokens
=
prompt_tokens
,
completion_tokens
=
previous_num_tokens
[
i
]
completion_tokens
=
completion_tokens
,
usage
=
UsageInfo
(
total_tokens
=
prompt_tokens
+
completion_tokens
,
prompt_tokens
=
prompt_tokens
,
)
completion_tokens
=
completion_tokens
,
total_tokens
=
prompt_tokens
+
completion_tokens
,
)
if
request
.
stream_options
.
continuous_usage_stats
:
chunk
.
usage
=
usage
else
:
chunk
.
usage
=
None
response_json
=
chunk
.
model_dump_json
(
exclude_unset
=
False
)
response_json
=
chunk
.
model_dump_json
(
exclude_unset
=
False
)
yield
f
"data:
{
response_json
}
\n\n
"
yield
f
"data:
{
response_json
}
\n\n
"
if
(
request
.
stream_options
total_prompt_tokens
=
sum
(
num_prompt_tokens
)
and
request
.
stream_options
.
include_usage
):
total_completion_tokens
=
sum
(
previous_num_tokens
)
final_usage_info
=
UsageInfo
(
prompt_tokens
=
total_prompt_tokens
,
completion_tokens
=
total_completion_tokens
,
total_tokens
=
total_prompt_tokens
+
total_completion_tokens
)
if
include_usage
:
final_usage_chunk
=
CompletionStreamResponse
(
final_usage_chunk
=
CompletionStreamResponse
(
id
=
request_id
,
id
=
request_id
,
created
=
created_time
,
created
=
created_time
,
model
=
model_name
,
model
=
model_name
,
choices
=
[],
choices
=
[],
usage
=
usage
,
usage
=
final_usage_info
,
)
)
final_usage_data
=
(
final_usage_chunk
.
model_dump_json
(
final_usage_data
=
(
final_usage_chunk
.
model_dump_json
(
exclude_unset
=
False
,
exclude_none
=
True
))
exclude_unset
=
False
,
exclude_none
=
True
))
yield
f
"data:
{
final_usage_data
}
\n\n
"
yield
f
"data:
{
final_usage_data
}
\n\n
"
# report to FastAPI middleware aggregate usage across all choices
# report to FastAPI middleware aggregate usage across all choices
total_prompt_tokens
=
sum
(
num_prompt_tokens
)
request_metadata
.
final_usage_info
=
final_usage_info
total_completion_tokens
=
sum
(
previous_num_tokens
)
request_metadata
.
final_usage_info
=
UsageInfo
(
prompt_tokens
=
total_prompt_tokens
,
completion_tokens
=
total_completion_tokens
,
total_tokens
=
total_prompt_tokens
+
total_completion_tokens
)
except
ValueError
as
e
:
except
ValueError
as
e
:
# TODO: Use a vllm-specific Validation Error
# TODO: Use a vllm-specific Validation Error
...
@@ -413,26 +412,26 @@ class OpenAIServingCompletion(OpenAIServing):
...
@@ -413,26 +412,26 @@ class OpenAIServingCompletion(OpenAIServing):
for
output
in
final_res
.
outputs
:
for
output
in
final_res
.
outputs
:
assert
request
.
max_tokens
is
not
None
assert
request
.
max_tokens
is
not
None
if
request
.
echo
and
request
.
max_tokens
==
0
:
if
request
.
echo
:
assert
prompt_text
is
not
None
token_ids
=
prompt_token_ids
out_logprobs
=
prompt_logprobs
output_text
=
prompt_text
elif
request
.
echo
and
request
.
max_tokens
>
0
:
assert
prompt_text
is
not
None
assert
prompt_text
is
not
None
token_ids
=
[
*
prompt_token_ids
,
*
output
.
token_ids
]
if
request
.
max_tokens
==
0
:
token_ids
=
prompt_token_ids
if
request
.
logprobs
is
None
:
out_logprobs
=
prompt_logprobs
out
_logprobs
=
None
out
put_text
=
prompt_text
else
:
else
:
assert
prompt_logprobs
is
not
None
token_ids
=
[
*
prompt_token_ids
,
*
output
.
token_ids
]
assert
output
.
logprobs
is
not
None
out_logprobs
=
[
if
request
.
logprobs
is
None
:
*
prompt_logprobs
,
out_logprobs
=
None
*
output
.
logprobs
,
else
:
]
assert
prompt_logprobs
is
not
None
assert
output
.
logprobs
is
not
None
output_text
=
prompt_text
+
output
.
text
out_logprobs
=
[
*
prompt_logprobs
,
*
output
.
logprobs
,
]
output_text
=
prompt_text
+
output
.
text
else
:
else
:
token_ids
=
output
.
token_ids
token_ids
=
output
.
token_ids
out_logprobs
=
output
.
logprobs
out_logprobs
=
output
.
logprobs
...
...
vllm/entrypoints/openai/serving_embedding.py
View file @
2216a4e5
...
@@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing):
...
@@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules
=
None
,
lora_modules
=
None
,
prompt_adapters
=
None
,
prompt_adapters
=
None
,
request_logger
=
request_logger
)
request_logger
=
request_logger
)
self
.
_enabled
=
self
.
_check_embedding_mode
(
model_config
.
embedding_mode
)
self
.
_enabled
=
self
.
_check_embedding_mode
(
model_config
.
task
==
"embedding"
)
async
def
create_embedding
(
async
def
create_embedding
(
self
,
self
,
...
...
Prev
1
…
4
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment