Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38d80967
Commit
38d80967
authored
Sep 12, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori
parents
33650733
880c741b
Changes
544
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
924 additions
and
333 deletions
+924
-333
vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
+1
-1
vllm/distributed/kv_transfer/kv_transfer_state.py
vllm/distributed/kv_transfer/kv_transfer_state.py
+7
-0
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+59
-11
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+49
-45
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+4
-3
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+9
-8
vllm/engine/metrics.py
vllm/engine/metrics.py
+16
-2
vllm/engine/metrics_types.py
vllm/engine/metrics_types.py
+1
-1
vllm/engine/multiprocessing/client.py
vllm/engine/multiprocessing/client.py
+2
-2
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+1
-1
vllm/engine/protocol.py
vllm/engine/protocol.py
+6
-7
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+363
-190
vllm/entrypoints/context.py
vllm/entrypoints/context.py
+218
-24
vllm/entrypoints/harmony_utils.py
vllm/entrypoints/harmony_utils.py
+105
-21
vllm/entrypoints/launcher.py
vllm/entrypoints/launcher.py
+1
-1
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+15
-2
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+16
-2
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+5
-6
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+45
-5
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
544 of 544+
files are displayed.
Plain diff
Email patch
vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
View file @
38d80967
...
@@ -20,7 +20,7 @@ from typing import Callable, Optional
...
@@ -20,7 +20,7 @@ from typing import Callable, Optional
import
torch
import
torch
from
vllm.config
import
KVTransferConfig
from
vllm.config
.kv_transfer
import
KVTransferConfig
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.kv_transfer.kv_pipe.base
import
KVPipeBase
from
vllm.distributed.kv_transfer.kv_pipe.base
import
KVPipeBase
from
vllm.distributed.utils
import
StatelessProcessGroup
from
vllm.distributed.utils
import
StatelessProcessGroup
...
...
vllm/distributed/kv_transfer/kv_transfer_state.py
View file @
38d80967
...
@@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
...
@@ -64,3 +64,10 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
config
=
vllm_config
,
role
=
KVConnectorRole
.
WORKER
)
config
=
vllm_config
,
role
=
KVConnectorRole
.
WORKER
)
else
:
else
:
raise
ValueError
(
"V0 is no longer supported"
)
raise
ValueError
(
"V0 is no longer supported"
)
def
ensure_kv_transfer_shutdown
()
->
None
:
global
_KV_CONNECTOR_AGENT
if
_KV_CONNECTOR_AGENT
is
not
None
:
_KV_CONNECTOR_AGENT
.
shutdown
()
_KV_CONNECTOR_AGENT
=
None
vllm/distributed/parallel_state.py
View file @
38d80967
...
@@ -29,6 +29,7 @@ import weakref
...
@@ -29,6 +29,7 @@ import weakref
from
collections
import
namedtuple
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
datetime
import
timedelta
from
multiprocessing
import
shared_memory
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
Union
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -904,6 +905,18 @@ def get_tensor_model_parallel_group():
...
@@ -904,6 +905,18 @@ def get_tensor_model_parallel_group():
return
get_tp_group
()
return
get_tp_group
()
_DCP
:
Optional
[
GroupCoordinator
]
=
None
def
get_dcp_group
()
->
GroupCoordinator
:
assert
_DCP
is
not
None
,
(
"decode context model parallel group is not initialized"
)
return
_DCP
# kept for backward compatibility
get_context_model_parallel_group
=
get_dcp_group
_PP
:
Optional
[
GroupCoordinator
]
=
None
_PP
:
Optional
[
GroupCoordinator
]
=
None
_DP
:
Optional
[
GroupCoordinator
]
=
None
_DP
:
Optional
[
GroupCoordinator
]
=
None
...
@@ -939,8 +952,8 @@ def get_pipeline_model_parallel_group():
...
@@ -939,8 +952,8 @@ def get_pipeline_model_parallel_group():
def
graph_capture
(
device
:
torch
.
device
):
def
graph_capture
(
device
:
torch
.
device
):
"""
"""
`graph_capture` is a context manager which should surround the code that
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that
th
e
is capturing the CUDA graph. Its main purpose is to ensure that
som
e
some
operations will be run after the graph is captured, before the graph
operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
stream that the graph capture is running on. This stream is set to the
...
@@ -966,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
...
@@ -966,13 +979,12 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE
=
enable
_ENABLE_CUSTOM_ALL_REDUCE
=
enable
def
init_distributed_environment
(
def
init_distributed_environment
(
world_size
:
int
=
-
1
,
world_size
:
int
=
-
1
,
rank
:
int
=
-
1
,
rank
:
int
=
-
1
,
distributed_init_method
:
str
=
"env://"
,
distributed_init_method
:
str
=
"env://"
,
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
backend
:
str
=
"nccl"
,
backend
:
str
=
"nccl"
,
timeout
:
Optional
[
timedelta
]
=
None
):
):
logger
.
debug
(
logger
.
debug
(
"world_size=%d rank=%d local_rank=%d "
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s"
,
world_size
,
rank
,
local_rank
,
"distributed_init_method=%s backend=%s"
,
world_size
,
rank
,
local_rank
,
...
@@ -1008,7 +1020,8 @@ def init_distributed_environment(
...
@@ -1008,7 +1020,8 @@ def init_distributed_environment(
backend
=
backend
,
backend
=
backend
,
init_method
=
distributed_init_method
,
init_method
=
distributed_init_method
,
world_size
=
world_size
,
world_size
=
world_size
,
rank
=
rank
)
rank
=
rank
,
timeout
=
timeout
)
# set the local rank
# set the local rank
# local_rank is not available in torch ProcessGroup,
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
# see https://github.com/pytorch/pytorch/issues/122816
...
@@ -1034,6 +1047,7 @@ def init_distributed_environment(
...
@@ -1034,6 +1047,7 @@ def init_distributed_environment(
def
initialize_model_parallel
(
def
initialize_model_parallel
(
tensor_model_parallel_size
:
int
=
1
,
tensor_model_parallel_size
:
int
=
1
,
pipeline_model_parallel_size
:
int
=
1
,
pipeline_model_parallel_size
:
int
=
1
,
decode_context_model_parallel_size
:
Optional
[
int
]
=
1
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
"""
"""
...
@@ -1098,6 +1112,23 @@ def initialize_model_parallel(
...
@@ -1098,6 +1112,23 @@ def initialize_model_parallel(
use_message_queue_broadcaster
=
True
,
use_message_queue_broadcaster
=
True
,
group_name
=
"tp"
)
group_name
=
"tp"
)
# Build the DCP model-parallel groups.
global
_DCP
assert
_DCP
is
None
,
(
"decode context model parallel group is already initialized"
)
# Note(hc): In the current implementation of decode context parallel,
# dcp_size must not exceed tp_size, because the world size does not
# change by DCP, it simply reuses the GPUs of TP group, and split one
# TP group into tp_size//dcp_size DCP groups.
group_ranks
=
all_ranks
.
reshape
(
-
1
,
decode_context_model_parallel_size
).
unbind
(
0
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
_DCP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
backend
,
use_message_queue_broadcaster
=
True
,
group_name
=
"dcp"
)
# Build the pipeline model-parallel groups.
# Build the pipeline model-parallel groups.
global
_PP
global
_PP
assert
_PP
is
None
,
(
assert
_PP
is
None
,
(
...
@@ -1141,6 +1172,7 @@ def initialize_model_parallel(
...
@@ -1141,6 +1172,7 @@ def initialize_model_parallel(
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
tensor_model_parallel_size
:
int
,
tensor_model_parallel_size
:
int
,
pipeline_model_parallel_size
:
int
,
pipeline_model_parallel_size
:
int
,
decode_context_model_parallel_size
:
Optional
[
int
]
=
1
,
backend
:
Optional
[
str
]
=
None
,
backend
:
Optional
[
str
]
=
None
,
)
->
None
:
)
->
None
:
"""Helper to initialize model parallel groups if they are not initialized,
"""Helper to initialize model parallel groups if they are not initialized,
...
@@ -1151,7 +1183,8 @@ def ensure_model_parallel_initialized(
...
@@ -1151,7 +1183,8 @@ def ensure_model_parallel_initialized(
get_world_group
().
device_group
)
get_world_group
().
device_group
)
if
not
model_parallel_is_initialized
():
if
not
model_parallel_is_initialized
():
initialize_model_parallel
(
tensor_model_parallel_size
,
initialize_model_parallel
(
tensor_model_parallel_size
,
pipeline_model_parallel_size
,
backend
)
pipeline_model_parallel_size
,
decode_context_model_parallel_size
,
backend
)
return
return
assert
(
assert
(
...
@@ -1226,6 +1259,16 @@ def get_tensor_model_parallel_rank():
...
@@ -1226,6 +1259,16 @@ def get_tensor_model_parallel_rank():
return
get_tp_group
().
rank_in_group
return
get_tp_group
().
rank_in_group
def
get_decode_context_model_parallel_world_size
():
"""Return world size for the decode context model parallel group."""
return
get_dcp_group
().
world_size
def
get_decode_context_model_parallel_rank
():
"""Return my rank for the decode context model parallel group."""
return
get_dcp_group
().
rank_in_group
def
get_node_count
()
->
int
:
def
get_node_count
()
->
int
:
"""Return the total number of nodes in the distributed environment. """
"""Return the total number of nodes in the distributed environment. """
assert
_NODE_COUNT
is
not
None
,
(
assert
_NODE_COUNT
is
not
None
,
(
...
@@ -1246,6 +1289,11 @@ def destroy_model_parallel():
...
@@ -1246,6 +1289,11 @@ def destroy_model_parallel():
_PP
.
destroy
()
_PP
.
destroy
()
_PP
=
None
_PP
=
None
global
_DCP
if
_DCP
:
_DCP
.
destroy
()
_DCP
=
None
global
_DP
global
_DP
if
_DP
:
if
_DP
:
_DP
.
destroy
()
_DP
.
destroy
()
...
...
vllm/engine/arg_utils.py
View file @
38d80967
...
@@ -22,9 +22,9 @@ from typing_extensions import TypeIs, deprecated
...
@@ -22,9 +22,9 @@ from typing_extensions import TypeIs, deprecated
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
BlockSize
,
CacheConfig
,
CacheDType
,
CompilationConfig
,
from
vllm.config
import
(
BlockSize
,
CacheConfig
,
CacheDType
,
CompilationConfig
,
ConfigFormat
,
ConfigType
,
ConvertOption
,
ConfigType
,
ConvertOption
,
DecodingConfig
,
DecodingConfig
,
DetailedTraceModules
,
Device
,
DetailedTraceModules
,
Device
,
DeviceConfig
,
DeviceConfig
,
DistributedExecutorBackend
,
EPLBConfig
,
DistributedExecutorBackend
,
EPLBConfig
,
GuidedDecodingBackend
,
HfOverrides
,
KVEventsConfig
,
GuidedDecodingBackend
,
HfOverrides
,
KVEventsConfig
,
KVTransferConfig
,
LoadConfig
,
LogprobsMode
,
KVTransferConfig
,
LoadConfig
,
LogprobsMode
,
LoRAConfig
,
MambaDType
,
MMEncoderTPMode
,
ModelConfig
,
LoRAConfig
,
MambaDType
,
MMEncoderTPMode
,
ModelConfig
,
...
@@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
...
@@ -227,8 +227,14 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
elif
contains_type
(
type_hints
,
int
):
elif
contains_type
(
type_hints
,
int
):
kwargs
[
name
][
"type"
]
=
int
kwargs
[
name
][
"type"
]
=
int
# Special case for large integers
# Special case for large integers
if
name
in
{
"max_model_len"
,
"max_num_batched_tokens"
}:
human_readable_ints
=
{
"max_model_len"
,
"max_num_batched_tokens"
,
"kv_cache_memory_bytes"
,
}
if
name
in
human_readable_ints
:
kwargs
[
name
][
"type"
]
=
human_readable_int
kwargs
[
name
][
"type"
]
=
human_readable_int
kwargs
[
name
][
"help"
]
+=
f
"
\n\n
{
human_readable_int
.
__doc__
}
"
elif
contains_type
(
type_hints
,
float
):
elif
contains_type
(
type_hints
,
float
):
kwargs
[
name
][
"type"
]
=
float
kwargs
[
name
][
"type"
]
=
float
elif
(
contains_type
(
type_hints
,
dict
)
elif
(
contains_type
(
type_hints
,
dict
)
...
@@ -289,6 +295,7 @@ class EngineArgs:
...
@@ -289,6 +295,7 @@ class EngineArgs:
trust_remote_code
:
bool
=
ModelConfig
.
trust_remote_code
trust_remote_code
:
bool
=
ModelConfig
.
trust_remote_code
allowed_local_media_path
:
str
=
ModelConfig
.
allowed_local_media_path
allowed_local_media_path
:
str
=
ModelConfig
.
allowed_local_media_path
download_dir
:
Optional
[
str
]
=
LoadConfig
.
download_dir
download_dir
:
Optional
[
str
]
=
LoadConfig
.
download_dir
safetensors_load_strategy
:
str
=
LoadConfig
.
safetensors_load_strategy
load_format
:
Union
[
str
,
LoadFormats
]
=
LoadConfig
.
load_format
load_format
:
Union
[
str
,
LoadFormats
]
=
LoadConfig
.
load_format
config_format
:
str
=
ModelConfig
.
config_format
config_format
:
str
=
ModelConfig
.
config_format
dtype
:
ModelDType
=
ModelConfig
.
dtype
dtype
:
ModelDType
=
ModelConfig
.
dtype
...
@@ -306,6 +313,8 @@ class EngineArgs:
...
@@ -306,6 +313,8 @@ class EngineArgs:
# number of P/D disaggregation (or other disaggregation) workers
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size
:
int
=
ParallelConfig
.
pipeline_parallel_size
pipeline_parallel_size
:
int
=
ParallelConfig
.
pipeline_parallel_size
tensor_parallel_size
:
int
=
ParallelConfig
.
tensor_parallel_size
tensor_parallel_size
:
int
=
ParallelConfig
.
tensor_parallel_size
decode_context_parallel_size
:
int
=
\
ParallelConfig
.
decode_context_parallel_size
data_parallel_size
:
int
=
ParallelConfig
.
data_parallel_size
data_parallel_size
:
int
=
ParallelConfig
.
data_parallel_size
data_parallel_rank
:
Optional
[
int
]
=
None
data_parallel_rank
:
Optional
[
int
]
=
None
data_parallel_start_rank
:
Optional
[
int
]
=
None
data_parallel_start_rank
:
Optional
[
int
]
=
None
...
@@ -332,6 +341,7 @@ class EngineArgs:
...
@@ -332,6 +341,7 @@ class EngineArgs:
swap_space
:
float
=
CacheConfig
.
swap_space
swap_space
:
float
=
CacheConfig
.
swap_space
cpu_offload_gb
:
float
=
CacheConfig
.
cpu_offload_gb
cpu_offload_gb
:
float
=
CacheConfig
.
cpu_offload_gb
gpu_memory_utilization
:
float
=
CacheConfig
.
gpu_memory_utilization
gpu_memory_utilization
:
float
=
CacheConfig
.
gpu_memory_utilization
kv_cache_memory_bytes
:
Optional
[
int
]
=
CacheConfig
.
kv_cache_memory_bytes
max_num_batched_tokens
:
Optional
[
max_num_batched_tokens
:
Optional
[
int
]
=
SchedulerConfig
.
max_num_batched_tokens
int
]
=
SchedulerConfig
.
max_num_batched_tokens
max_num_partial_prefills
:
int
=
SchedulerConfig
.
max_num_partial_prefills
max_num_partial_prefills
:
int
=
SchedulerConfig
.
max_num_partial_prefills
...
@@ -417,8 +427,6 @@ class EngineArgs:
...
@@ -417,8 +427,6 @@ class EngineArgs:
scheduling_policy
:
SchedulerPolicy
=
SchedulerConfig
.
policy
scheduling_policy
:
SchedulerPolicy
=
SchedulerConfig
.
policy
scheduler_cls
:
Union
[
str
,
Type
[
object
]]
=
SchedulerConfig
.
scheduler_cls
scheduler_cls
:
Union
[
str
,
Type
[
object
]]
=
SchedulerConfig
.
scheduler_cls
override_neuron_config
:
dict
[
str
,
Any
]
=
\
get_field
(
ModelConfig
,
"override_neuron_config"
)
override_pooler_config
:
Optional
[
Union
[
dict
,
PoolerConfig
]]
=
\
override_pooler_config
:
Optional
[
Union
[
dict
,
PoolerConfig
]]
=
\
ModelConfig
.
override_pooler_config
ModelConfig
.
override_pooler_config
compilation_config
:
CompilationConfig
=
\
compilation_config
:
CompilationConfig
=
\
...
@@ -547,7 +555,6 @@ class EngineArgs:
...
@@ -547,7 +555,6 @@ class EngineArgs:
help
=
"Disable async output processing. This may result in "
help
=
"Disable async output processing. This may result in "
"lower performance."
)
"lower performance."
)
model_group
.
add_argument
(
"--config-format"
,
model_group
.
add_argument
(
"--config-format"
,
choices
=
[
f
.
value
for
f
in
ConfigFormat
],
**
model_kwargs
[
"config_format"
])
**
model_kwargs
[
"config_format"
])
# This one is a special case because it can bool
# This one is a special case because it can bool
# or str. TODO: Handle this in get_kwargs
# or str. TODO: Handle this in get_kwargs
...
@@ -559,8 +566,6 @@ class EngineArgs:
...
@@ -559,8 +566,6 @@ class EngineArgs:
help
=
model_kwargs
[
"hf_token"
][
"help"
])
help
=
model_kwargs
[
"hf_token"
][
"help"
])
model_group
.
add_argument
(
"--hf-overrides"
,
model_group
.
add_argument
(
"--hf-overrides"
,
**
model_kwargs
[
"hf_overrides"
])
**
model_kwargs
[
"hf_overrides"
])
model_group
.
add_argument
(
"--override-neuron-config"
,
**
model_kwargs
[
"override_neuron_config"
])
model_group
.
add_argument
(
"--override-pooler-config"
,
model_group
.
add_argument
(
"--override-pooler-config"
,
**
model_kwargs
[
"override_pooler_config"
])
**
model_kwargs
[
"override_pooler_config"
])
model_group
.
add_argument
(
"--logits-processor-pattern"
,
model_group
.
add_argument
(
"--logits-processor-pattern"
,
...
@@ -590,6 +595,8 @@ class EngineArgs:
...
@@ -590,6 +595,8 @@ class EngineArgs:
load_group
.
add_argument
(
"--load-format"
,
**
load_kwargs
[
"load_format"
])
load_group
.
add_argument
(
"--load-format"
,
**
load_kwargs
[
"load_format"
])
load_group
.
add_argument
(
"--download-dir"
,
load_group
.
add_argument
(
"--download-dir"
,
**
load_kwargs
[
"download_dir"
])
**
load_kwargs
[
"download_dir"
])
load_group
.
add_argument
(
"--safetensors-load-strategy"
,
**
load_kwargs
[
"safetensors_load_strategy"
])
load_group
.
add_argument
(
"--model-loader-extra-config"
,
load_group
.
add_argument
(
"--model-loader-extra-config"
,
**
load_kwargs
[
"model_loader_extra_config"
])
**
load_kwargs
[
"model_loader_extra_config"
])
load_group
.
add_argument
(
"--ignore-patterns"
,
load_group
.
add_argument
(
"--ignore-patterns"
,
...
@@ -636,6 +643,9 @@ class EngineArgs:
...
@@ -636,6 +643,9 @@ class EngineArgs:
**
parallel_kwargs
[
"pipeline_parallel_size"
])
**
parallel_kwargs
[
"pipeline_parallel_size"
])
parallel_group
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
parallel_group
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
**
parallel_kwargs
[
"tensor_parallel_size"
])
**
parallel_kwargs
[
"tensor_parallel_size"
])
parallel_group
.
add_argument
(
"--decode-context-parallel-size"
,
"-dcp"
,
**
parallel_kwargs
[
"decode_context_parallel_size"
])
parallel_group
.
add_argument
(
"--data-parallel-size"
,
"-dp"
,
parallel_group
.
add_argument
(
"--data-parallel-size"
,
"-dp"
,
**
parallel_kwargs
[
"data_parallel_size"
])
**
parallel_kwargs
[
"data_parallel_size"
])
parallel_group
.
add_argument
(
parallel_group
.
add_argument
(
...
@@ -731,6 +741,8 @@ class EngineArgs:
...
@@ -731,6 +741,8 @@ class EngineArgs:
cache_group
.
add_argument
(
"--block-size"
,
**
cache_kwargs
[
"block_size"
])
cache_group
.
add_argument
(
"--block-size"
,
**
cache_kwargs
[
"block_size"
])
cache_group
.
add_argument
(
"--gpu-memory-utilization"
,
cache_group
.
add_argument
(
"--gpu-memory-utilization"
,
**
cache_kwargs
[
"gpu_memory_utilization"
])
**
cache_kwargs
[
"gpu_memory_utilization"
])
cache_group
.
add_argument
(
"--kv-cache-memory-bytes"
,
**
cache_kwargs
[
"kv_cache_memory_bytes"
])
cache_group
.
add_argument
(
"--swap-space"
,
**
cache_kwargs
[
"swap_space"
])
cache_group
.
add_argument
(
"--swap-space"
,
**
cache_kwargs
[
"swap_space"
])
cache_group
.
add_argument
(
"--kv-cache-dtype"
,
cache_group
.
add_argument
(
"--kv-cache-dtype"
,
**
cache_kwargs
[
"cache_dtype"
])
**
cache_kwargs
[
"cache_dtype"
])
...
@@ -987,7 +999,6 @@ class EngineArgs:
...
@@ -987,7 +999,6 @@ class EngineArgs:
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_kwargs
=
self
.
mm_processor_kwargs
,
mm_processor_cache_gb
=
self
.
mm_processor_cache_gb
,
mm_processor_cache_gb
=
self
.
mm_processor_cache_gb
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
mm_encoder_tp_mode
=
self
.
mm_encoder_tp_mode
,
override_neuron_config
=
self
.
override_neuron_config
,
override_pooler_config
=
self
.
override_pooler_config
,
override_pooler_config
=
self
.
override_pooler_config
,
logits_processor_pattern
=
self
.
logits_processor_pattern
,
logits_processor_pattern
=
self
.
logits_processor_pattern
,
generation_config
=
self
.
generation_config
,
generation_config
=
self
.
generation_config
,
...
@@ -1024,6 +1035,7 @@ class EngineArgs:
...
@@ -1024,6 +1035,7 @@ class EngineArgs:
return
LoadConfig
(
return
LoadConfig
(
load_format
=
self
.
load_format
,
load_format
=
self
.
load_format
,
download_dir
=
self
.
download_dir
,
download_dir
=
self
.
download_dir
,
safetensors_load_strategy
=
self
.
safetensors_load_strategy
,
device
=
"cpu"
device
=
"cpu"
if
is_online_quantization
(
self
.
quantization
)
else
None
,
if
is_online_quantization
(
self
.
quantization
)
else
None
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
model_loader_extra_config
=
self
.
model_loader_extra_config
,
...
@@ -1053,9 +1065,10 @@ class EngineArgs:
...
@@ -1053,9 +1065,10 @@ class EngineArgs:
SpeculatorsConfig
)
SpeculatorsConfig
)
if
self
.
speculative_config
is
None
:
if
self
.
speculative_config
is
None
:
hf_config
=
get_config
(
self
.
hf_config_path
or
self
.
model
,
hf_config
=
get_config
(
self
.
trust_remote_code
,
self
.
revision
,
self
.
hf_config_path
or
target_model_config
.
model
,
self
.
code_revision
,
self
.
config_format
)
self
.
trust_remote_code
,
self
.
revision
,
self
.
code_revision
,
self
.
config_format
)
# if loading a SpeculatorsConfig, load the speculative_config
# if loading a SpeculatorsConfig, load the speculative_config
# details from the config directly
# details from the config directly
...
@@ -1065,7 +1078,7 @@ class EngineArgs:
...
@@ -1065,7 +1078,7 @@ class EngineArgs:
self
.
speculative_config
=
{}
self
.
speculative_config
=
{}
self
.
speculative_config
[
self
.
speculative_config
[
"num_speculative_tokens"
]
=
hf_config
.
num_lookahead_tokens
"num_speculative_tokens"
]
=
hf_config
.
num_lookahead_tokens
self
.
speculative_config
[
"model"
]
=
self
.
model
self
.
speculative_config
[
"model"
]
=
target_model_config
.
model
self
.
speculative_config
[
"method"
]
=
hf_config
.
method
self
.
speculative_config
[
"method"
]
=
hf_config
.
method
else
:
else
:
return
None
return
None
...
@@ -1156,9 +1169,21 @@ class EngineArgs:
...
@@ -1156,9 +1169,21 @@ class EngineArgs:
# global layers in interleaved sliding window models.
# global layers in interleaved sliding window models.
sliding_window
=
model_config
.
get_sliding_window
()
sliding_window
=
model_config
.
get_sliding_window
()
# Note(hc): In the current implementation of decode context
# parallel(DCP), tp_size needs to be divisible by dcp_size,
# because the world size does not change by dcp, it simply
# reuses the GPUs of TP group, and split one TP group into
# tp_size//dcp_size DCP groups.
assert
self
.
tensor_parallel_size
%
self
.
decode_context_parallel_size
\
==
0
,
(
f
"tp_size=
{
self
.
tensor_parallel_size
}
must be divisible by"
f
"dcp_size=
{
self
.
decode_context_parallel_size
}
."
)
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
kv_cache_memory_bytes
=
self
.
kv_cache_memory_bytes
,
swap_space
=
self
.
swap_space
,
swap_space
=
self
.
swap_space
,
cache_dtype
=
self
.
kv_cache_dtype
,
cache_dtype
=
self
.
kv_cache_dtype
,
is_attention_free
=
model_config
.
is_attention_free
,
is_attention_free
=
model_config
.
is_attention_free
,
...
@@ -1306,6 +1331,7 @@ class EngineArgs:
...
@@ -1306,6 +1331,7 @@ class EngineArgs:
distributed_executor_backend
=
self
.
distributed_executor_backend
,
distributed_executor_backend
=
self
.
distributed_executor_backend
,
worker_cls
=
self
.
worker_cls
,
worker_cls
=
self
.
worker_cls
,
worker_extension_cls
=
self
.
worker_extension_cls
,
worker_extension_cls
=
self
.
worker_extension_cls
,
decode_context_parallel_size
=
self
.
decode_context_parallel_size
,
)
)
speculative_config
=
self
.
create_speculative_config
(
speculative_config
=
self
.
create_speculative_config
(
...
@@ -1436,17 +1462,6 @@ class EngineArgs:
...
@@ -1436,17 +1462,6 @@ class EngineArgs:
recommend_to_remove
=
True
)
recommend_to_remove
=
True
)
return
False
return
False
# Triton v3.3 has f16 conversion regression issue on Turing and Volta,
# which broke fp16 inference
# see: https://github.com/triton-lang/triton/issues/6698
if
(
current_platform
.
is_cuda
()
and
not
current_platform
.
has_device_capability
(
80
)
and
model_config
.
dtype
==
torch
.
float16
):
_raise_or_fallback
(
feature_name
=
"Compute Capability < 8.0 with FP16"
,
recommend_to_remove
=
False
)
return
False
if
self
.
kv_cache_dtype
!=
"auto"
:
if
self
.
kv_cache_dtype
!=
"auto"
:
supported
=
current_platform
.
is_kv_cache_dtype_supported
(
supported
=
current_platform
.
is_kv_cache_dtype_supported
(
self
.
kv_cache_dtype
,
model_config
)
self
.
kv_cache_dtype
,
model_config
)
...
@@ -1476,12 +1491,6 @@ class EngineArgs:
...
@@ -1476,12 +1491,6 @@ class EngineArgs:
recommend_to_remove
=
False
)
recommend_to_remove
=
False
)
return
False
return
False
# No OTLP observability so far.
if
(
self
.
otlp_traces_endpoint
or
self
.
collect_detailed_traces
):
_raise_or_fallback
(
feature_name
=
"--otlp-traces-endpoint"
,
recommend_to_remove
=
False
)
return
False
# V1 supports N-gram, Medusa, and Eagle speculative decoding.
# V1 supports N-gram, Medusa, and Eagle speculative decoding.
if
(
self
.
speculative_config
is
not
None
if
(
self
.
speculative_config
is
not
None
and
self
.
speculative_config
.
get
(
"method"
)
==
"draft_model"
):
and
self
.
speculative_config
.
get
(
"method"
)
==
"draft_model"
):
...
@@ -1499,8 +1508,11 @@ class EngineArgs:
...
@@ -1499,8 +1508,11 @@ class EngineArgs:
"TRITON_MLA"
,
"TRITON_MLA"
,
"CUTLASS_MLA"
,
"CUTLASS_MLA"
,
"FLASHMLA"
,
"FLASHMLA"
,
"FLASHMLA_VLLM_V1"
,
"FLASH_ATTN_MLA"
,
"FLASHINFER"
,
"FLASHINFER"
,
"FLASHINFER_VLLM_V1"
,
"FLASHINFER_VLLM_V1"
,
"FLASHINFER_MLA"
,
"ROCM_AITER_MLA"
,
"ROCM_AITER_MLA"
,
"TORCH_SDPA_VLLM_V1"
,
"TORCH_SDPA_VLLM_V1"
,
"FLEX_ATTENTION"
,
"FLEX_ATTENTION"
,
...
@@ -1589,20 +1601,12 @@ class EngineArgs:
...
@@ -1589,20 +1601,12 @@ class EngineArgs:
"in low performance due to small KV cache size. Consider "
"in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value."
,
max_model_len
)
"setting --max-model-len to a smaller value."
,
max_model_len
)
# if using prefix caching, we must set a hash algo
# Disable prefix caching for multimodal models for VLLM_V0.
if
self
.
enable_prefix_caching
:
if
self
.
enable_prefix_caching
and
model_config
.
is_multimodal_model
:
# Disable prefix caching for multimodal models for VLLM_V0.
logger
.
warning
(
if
model_config
.
is_multimodal_model
:
"--enable-prefix-caching is not supported for multimodal "
logger
.
warning
(
"models in V0 and has been disabled."
)
"--enable-prefix-caching is not supported for multimodal "
self
.
enable_prefix_caching
=
False
"models in V0 and has been disabled."
)
self
.
enable_prefix_caching
=
False
# VLLM_V0 only supports builtin hash algo for prefix caching.
if
self
.
prefix_caching_hash_algo
==
"sha256"
:
raise
ValueError
(
"sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'."
)
# Set max_num_seqs to 256 for VLLM_V0.
# Set max_num_seqs to 256 for VLLM_V0.
if
self
.
max_num_seqs
is
None
:
if
self
.
max_num_seqs
is
None
:
...
...
vllm/engine/async_llm_engine.py
View file @
38d80967
...
@@ -10,8 +10,9 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
...
@@ -10,8 +10,9 @@ from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
from
weakref
import
ReferenceType
from
weakref
import
ReferenceType
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
DecodingConfig
,
LoRAConfig
,
ModelConfig
,
from
vllm.config
import
(
DecodingConfig
,
ModelConfig
,
ParallelConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
)
SchedulerConfig
,
VllmConfig
)
from
vllm.config.lora
import
LoRAConfig
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.core.scheduler
import
SchedulerOutputs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_timeout
import
asyncio_timeout
from
vllm.engine.async_timeout
import
asyncio_timeout
...
@@ -717,7 +718,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -717,7 +718,7 @@ class AsyncLLMEngine(EngineClient):
# Stop the execute model loop in parallel workers until there
# Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting
# are more requests to process. This avoids waiting
# indefinitely in torch.distributed ops which may otherwise
# indefinitely in torch.distributed ops which may otherwise
# timeout, and unblocks the RPC thread in the workers so that
# time
out, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
# such as add/remove lora adapters.
await
engine
.
engine
.
stop_remote_worker_execution_loop_async
()
await
engine
.
engine
.
stop_remote_worker_execution_loop_async
()
...
...
vllm/engine/llm_engine.py
View file @
38d80967
...
@@ -16,9 +16,9 @@ import torch
...
@@ -16,9 +16,9 @@ import torch
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
DecodingConfig
,
LoRA
Config
,
Model
Config
,
from
vllm.config
import
(
DecodingConfig
,
Model
Config
,
Observability
Config
,
ObservabilityConfig
,
ParallelConfig
,
SchedulerConfig
,
ParallelConfig
,
SchedulerConfig
,
VllmConfig
)
Vllm
Config
)
from
vllm.config.lora
import
LoRA
Config
from
vllm.core.scheduler
import
ScheduledSequenceGroup
,
SchedulerOutputs
from
vllm.core.scheduler
import
ScheduledSequenceGroup
,
SchedulerOutputs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
,
Stats
from
vllm.engine.metrics_types
import
StatLoggerBase
,
Stats
...
@@ -278,7 +278,8 @@ class LLMEngine:
...
@@ -278,7 +278,8 @@ class LLMEngine:
self
.
cache_config
.
block_size
,
self
.
cache_config
.
block_size
,
"gpu_memory_utilization"
:
"gpu_memory_utilization"
:
self
.
cache_config
.
gpu_memory_utilization
,
self
.
cache_config
.
gpu_memory_utilization
,
"kv_cache_memory_bytes"
:
self
.
cache_config
.
kv_cache_memory_bytes
,
# Quantization
# Quantization
"quantization"
:
"quantization"
:
self
.
model_config
.
quantization
,
self
.
model_config
.
quantization
,
...
@@ -1414,7 +1415,7 @@ class LLMEngine:
...
@@ -1414,7 +1415,7 @@ class LLMEngine:
num_generation_tokens_iter
=
0
num_generation_tokens_iter
=
0
num_tokens_iter
=
0
num_tokens_iter
=
0
time_to_first_tokens_iter
:
List
[
float
]
=
[]
time_to_first_tokens_iter
:
List
[
float
]
=
[]
time_per_output_token
s_iter
:
List
[
float
]
=
[]
inter_token_latencie
s_iter
:
List
[
float
]
=
[]
num_preemption_iter
=
(
0
if
scheduler_outputs
is
None
else
num_preemption_iter
=
(
0
if
scheduler_outputs
is
None
else
scheduler_outputs
.
preempted
)
scheduler_outputs
.
preempted
)
...
@@ -1498,9 +1499,9 @@ class LLMEngine:
...
@@ -1498,9 +1499,9 @@ class LLMEngine:
num_generation_tokens_from_prefill_groups
+=
(
num_generation_tokens_from_prefill_groups
+=
(
seq_group
.
num_seqs
())
seq_group
.
num_seqs
())
else
:
else
:
#
TPOTs.
#
ITLs
latency
=
seq_group
.
get_last_token_latency
()
latency
=
seq_group
.
get_last_token_latency
()
time_per_output_token
s_iter
.
append
(
latency
)
inter_token_latencie
s_iter
.
append
(
latency
)
if
seq_group
.
state
.
current_step
==
0
:
if
seq_group
.
state
.
current_step
==
0
:
# For async_output_proc, the do_log_stats()
# For async_output_proc, the do_log_stats()
# is called following init_multi_step(), which
# is called following init_multi_step(), which
...
@@ -1582,7 +1583,7 @@ class LLMEngine:
...
@@ -1582,7 +1583,7 @@ class LLMEngine:
num_generation_tokens_iter
=
num_generation_tokens_iter
,
num_generation_tokens_iter
=
num_generation_tokens_iter
,
num_tokens_iter
=
num_tokens_iter
,
num_tokens_iter
=
num_tokens_iter
,
time_to_first_tokens_iter
=
time_to_first_tokens_iter
,
time_to_first_tokens_iter
=
time_to_first_tokens_iter
,
time_per_output_tokens_iter
=
time_per_output_token
s_iter
,
inter_token_latencies_iter
=
inter_token_latencie
s_iter
,
num_preemption_iter
=
num_preemption_iter
,
num_preemption_iter
=
num_preemption_iter
,
# Request stats
# Request stats
...
...
vllm/engine/metrics.py
View file @
38d80967
...
@@ -113,9 +113,21 @@ class Metrics:
...
@@ -113,9 +113,21 @@ class Metrics:
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
,
160.0
,
640.0
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
,
160.0
,
640.0
,
2560.0
2560.0
])
])
# Deprecated in 0.11 - Renamed as vllm:inter_token_latency_seconds
# TODO: in 0.12, only enable if show_hidden_metrics=True
self
.
histogram_time_per_output_token
=
self
.
_histogram_cls
(
self
.
histogram_time_per_output_token
=
self
.
_histogram_cls
(
name
=
"vllm:time_per_output_token_seconds"
,
name
=
"vllm:time_per_output_token_seconds"
,
documentation
=
"Histogram of time per output token in seconds."
,
documentation
=
(
"Histogram of time per output token in seconds."
"DEPRECATED: Use vllm:inter_token_latency_seconds instead."
),
labelnames
=
labelnames
,
buckets
=
[
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.75
,
1.0
,
2.5
,
5.0
,
7.5
,
10.0
,
20.0
,
40.0
,
80.0
])
self
.
histogram_inter_token_latency
=
self
.
_histogram_cls
(
name
=
"vllm:inter_token_latency_seconds"
,
documentation
=
"Histogram of inter token latency in seconds."
,
labelnames
=
labelnames
,
labelnames
=
labelnames
,
buckets
=
[
buckets
=
[
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.75
,
0.01
,
0.025
,
0.05
,
0.075
,
0.1
,
0.15
,
0.2
,
0.3
,
0.4
,
0.5
,
0.75
,
...
@@ -491,7 +503,9 @@ class PrometheusStatLogger(StatLoggerBase):
...
@@ -491,7 +503,9 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
_log_histogram
(
self
.
metrics
.
histogram_time_to_first_token
,
self
.
_log_histogram
(
self
.
metrics
.
histogram_time_to_first_token
,
stats
.
time_to_first_tokens_iter
)
stats
.
time_to_first_tokens_iter
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_time_per_output_token
,
self
.
_log_histogram
(
self
.
metrics
.
histogram_time_per_output_token
,
stats
.
time_per_output_tokens_iter
)
stats
.
inter_token_latencies_iter
)
self
.
_log_histogram
(
self
.
metrics
.
histogram_inter_token_latency
,
stats
.
inter_token_latencies_iter
)
# Request level data
# Request level data
# Latency
# Latency
...
...
vllm/engine/metrics_types.py
View file @
38d80967
...
@@ -43,7 +43,7 @@ class Stats:
...
@@ -43,7 +43,7 @@ class Stats:
num_generation_tokens_iter
:
int
num_generation_tokens_iter
:
int
num_tokens_iter
:
int
num_tokens_iter
:
int
time_to_first_tokens_iter
:
List
[
float
]
time_to_first_tokens_iter
:
List
[
float
]
time_per_output_token
s_iter
:
List
[
float
]
inter_token_latencie
s_iter
:
List
[
float
]
num_preemption_iter
:
int
num_preemption_iter
:
int
# Request stats (should have _requests suffix)
# Request stats (should have _requests suffix)
...
...
vllm/engine/multiprocessing/client.py
View file @
38d80967
...
@@ -235,7 +235,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -235,7 +235,7 @@ class MQLLMEngineClient(EngineClient):
# therefore we have to inform that the current
# therefore we have to inform that the current
# processed requests failed as well. Send back a dead
# processed requests failed as well. Send back a dead
# engine error give this feedback and also give a
# engine error give this feedback and also give a
# 'hint' to the server to shutdown next.
# 'hint' to the server to shut
down next.
exception
=
self
.
dead_error
exception
=
self
.
dead_error
if
request_id
is
None
:
if
request_id
is
None
:
...
@@ -270,7 +270,7 @@ class MQLLMEngineClient(EngineClient):
...
@@ -270,7 +270,7 @@ class MQLLMEngineClient(EngineClient):
queue
.
put_nowait
(
request_output
)
queue
.
put_nowait
(
request_output
)
async
def
setup
(
self
):
async
def
setup
(
self
):
"""Setup the client before it starts sending server requests."""
"""Set
up the client before it starts sending server requests."""
# Start output_loop
# Start output_loop
if
self
.
output_loop
is
None
:
if
self
.
output_loop
is
None
:
...
...
vllm/engine/multiprocessing/engine.py
View file @
38d80967
...
@@ -49,7 +49,7 @@ class MQLLMEngine:
...
@@ -49,7 +49,7 @@ class MQLLMEngine:
This class is used to wrap the
This class is used to wrap the
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
[`LLMEngine`][vllm.engine.llm_engine.LLMEngine] class to enable use
in concurr
n
et manner. It runs a background loop and uses zeromq to
in concurre
n
t manner. It runs a background loop and uses zeromq to
receive new requests and stream outputs incrementally via ipc.
receive new requests and stream outputs incrementally via ipc.
The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
The [`LLMEngine`][vllm.engine.llm_engine.LLMEngine] generate or encode
...
...
vllm/engine/protocol.py
View file @
38d80967
...
@@ -78,6 +78,7 @@ class EngineClient(ABC):
...
@@ -78,6 +78,7 @@ class EngineClient(ABC):
preprocessor
=
await
self
.
get_input_preprocessor
()
preprocessor
=
await
self
.
get_input_preprocessor
()
tokenizer_group
=
preprocessor
.
get_tokenizer_group
()
tokenizer_group
=
preprocessor
.
get_tokenizer_group
()
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
()
tokenizer
=
await
tokenizer_group
.
get_lora_tokenizer_async
()
eos_token_id
=
tokenizer
.
eos_token_id
if
is_explicit_encoder_decoder_prompt
(
prompt
):
if
is_explicit_encoder_decoder_prompt
(
prompt
):
raise
NotImplementedError
raise
NotImplementedError
...
@@ -104,7 +105,7 @@ class EngineClient(ABC):
...
@@ -104,7 +105,7 @@ class EngineClient(ABC):
tokenized_length
=
len
(
prompt_token_ids
)
tokenized_length
=
len
(
prompt_token_ids
)
sort_beams_key
=
create_sort_beams_key_function
(
sort_beams_key
=
create_sort_beams_key_function
(
tokenizer
.
eos_token_id
,
length_penalty
)
eos_token_id
,
length_penalty
)
beam_search_params
=
SamplingParams
(
beam_search_params
=
SamplingParams
(
logprobs
=
2
*
beam_width
,
logprobs
=
2
*
beam_width
,
...
@@ -154,7 +155,7 @@ class EngineClient(ABC):
...
@@ -154,7 +155,7 @@ class EngineClient(ABC):
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
if
result
.
outputs
[
0
].
logprobs
is
not
None
:
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
logprobs
=
result
.
outputs
[
0
].
logprobs
[
0
]
for
token_id
,
logprob_obj
in
logprobs
.
items
():
for
token_id
,
logprob_obj
in
logprobs
.
items
():
if
token_id
==
tokenizer
.
eos_token_id
and
\
if
token_id
==
eos_token_id
and
\
not
ignore_eos
:
not
ignore_eos
:
completed
.
append
(
completed
.
append
(
BeamSearchSequence
(
BeamSearchSequence
(
...
@@ -166,7 +167,7 @@ class EngineClient(ABC):
...
@@ -166,7 +167,7 @@ class EngineClient(ABC):
cum_logprob
=
current_beam
.
cum_logprob
+
cum_logprob
=
current_beam
.
cum_logprob
+
logprob_obj
.
logprob
,
logprob_obj
.
logprob
,
finish_reason
=
"stop"
,
finish_reason
=
"stop"
,
stop_reason
=
tokenizer
.
eos_token_id
))
stop_reason
=
eos_token_id
))
else
:
else
:
new_beams
.
append
(
new_beams
.
append
(
BeamSearchSequence
(
BeamSearchSequence
(
...
@@ -189,14 +190,14 @@ class EngineClient(ABC):
...
@@ -189,14 +190,14 @@ class EngineClient(ABC):
best_beams
=
sorted_completed
[:
beam_width
]
best_beams
=
sorted_completed
[:
beam_width
]
for
beam
in
best_beams
:
for
beam
in
best_beams
:
if
(
beam
.
tokens
[
-
1
]
==
tokenizer
.
eos_token_id
and
not
ignore_eos
):
if
(
beam
.
tokens
[
-
1
]
==
eos_token_id
and
not
ignore_eos
):
# Skip the eos token in the text.
# Skip the eos token in the text.
tokens
=
beam
.
tokens
[
tokenized_length
:
-
1
]
tokens
=
beam
.
tokens
[
tokenized_length
:
-
1
]
else
:
else
:
tokens
=
beam
.
tokens
[
tokenized_length
:]
tokens
=
beam
.
tokens
[
tokenized_length
:]
beam
.
text
=
tokenizer
.
decode
(
tokens
)
beam
.
text
=
tokenizer
.
decode
(
tokens
)
beam_search_output
=
RequestOutput
(
yield
RequestOutput
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt
=
prompt_text
,
prompt
=
prompt_text
,
outputs
=
[
outputs
=
[
...
@@ -214,8 +215,6 @@ class EngineClient(ABC):
...
@@ -214,8 +215,6 @@ class EngineClient(ABC):
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
prompt_logprobs
=
None
)
prompt_logprobs
=
None
)
yield
beam_search_output
@
abstractmethod
@
abstractmethod
def
encode
(
def
encode
(
self
,
self
,
...
...
vllm/entrypoints/chat_utils.py
View file @
38d80967
...
@@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
...
@@ -41,7 +41,8 @@ from typing_extensions import Required, TypeAlias, TypedDict
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.models
import
SupportsMultiModal
from
vllm.model_executor.models
import
SupportsMultiModal
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
MultiModalUUIDDict
)
from
vllm.multimodal.utils
import
MediaConnector
from
vllm.multimodal.utils
import
MediaConnector
# yapf: disable
# yapf: disable
from
vllm.transformers_utils.chat_templates
import
(
from
vllm.transformers_utils.chat_templates
import
(
...
@@ -72,6 +73,11 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
...
@@ -72,6 +73,11 @@ class ChatCompletionContentPartAudioParam(TypedDict, total=False):
type
:
Required
[
Literal
[
"audio_url"
]]
type
:
Required
[
Literal
[
"audio_url"
]]
"""The type of the content part."""
"""The type of the content part."""
uuid
:
Optional
[
str
]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class
ChatCompletionContentPartImageEmbedsParam
(
TypedDict
,
total
=
False
):
class
ChatCompletionContentPartImageEmbedsParam
(
TypedDict
,
total
=
False
):
...
@@ -83,6 +89,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
...
@@ -83,6 +89,11 @@ class ChatCompletionContentPartImageEmbedsParam(TypedDict, total=False):
"""
"""
type
:
Required
[
Literal
[
"image_embeds"
]]
type
:
Required
[
Literal
[
"image_embeds"
]]
"""The type of the content part."""
"""The type of the content part."""
uuid
:
Optional
[
str
]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class
VideoURL
(
TypedDict
,
total
=
False
):
class
VideoURL
(
TypedDict
,
total
=
False
):
...
@@ -97,12 +108,18 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
...
@@ -97,12 +108,18 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
type
:
Required
[
Literal
[
"video_url"
]]
type
:
Required
[
Literal
[
"video_url"
]]
"""The type of the content part."""
"""The type of the content part."""
uuid
:
Optional
[
str
]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class
PILImage
(
BaseModel
):
class
PILImage
(
BaseModel
):
"""
"""
A PIL.Image.Image object.
A PIL.Image.Image object.
"""
"""
image_pil
:
Image
.
Image
image_pil
:
Image
.
Image
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
...
@@ -115,7 +132,13 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
...
@@ -115,7 +132,13 @@ class CustomChatCompletionContentPILImageParam(TypedDict, total=False):
"image_pil": ImageAsset('cherry_blossom').pil_image
"image_pil": ImageAsset('cherry_blossom').pil_image
}
}
"""
"""
image_pil
:
Required
[
PILImage
]
image_pil
:
Required
[
PILImage
]
uuid
:
Optional
[
str
]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class
CustomChatCompletionContentSimpleImageParam
(
TypedDict
,
total
=
False
):
class
CustomChatCompletionContentSimpleImageParam
(
TypedDict
,
total
=
False
):
...
@@ -127,7 +150,13 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
...
@@ -127,7 +150,13 @@ class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"image_url": "https://example.com/image.jpg"
"image_url": "https://example.com/image.jpg"
}
}
"""
"""
image_url
:
Required
[
str
]
image_url
:
Required
[
str
]
uuid
:
Optional
[
str
]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class
CustomChatCompletionContentSimpleAudioParam
(
TypedDict
,
total
=
False
):
class
CustomChatCompletionContentSimpleAudioParam
(
TypedDict
,
total
=
False
):
...
@@ -138,6 +167,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
...
@@ -138,6 +167,7 @@ class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False):
"audio_url": "https://example.com/audio.mp3"
"audio_url": "https://example.com/audio.mp3"
}
}
"""
"""
audio_url
:
Required
[
str
]
audio_url
:
Required
[
str
]
...
@@ -149,7 +179,13 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
...
@@ -149,7 +179,13 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
"video_url": "https://example.com/video.mp4"
"video_url": "https://example.com/video.mp4"
}
}
"""
"""
video_url
:
Required
[
str
]
video_url
:
Required
[
str
]
uuid
:
Optional
[
str
]
"""
User-provided UUID of a media. User must guarantee that it is properly
generated and unique for different medias.
"""
class
CustomThinkCompletionContentParam
(
TypedDict
,
total
=
False
):
class
CustomThinkCompletionContentParam
(
TypedDict
,
total
=
False
):
...
@@ -174,19 +210,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
...
@@ -174,19 +210,24 @@ class CustomThinkCompletionContentParam(TypedDict, total=False):
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
ChatCompletionContentPartParam
:
TypeAlias
=
Union
[
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
OpenAIChatCompletionContentPartParam
,
ChatCompletionContentPartAudioParam
,
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartVideoParam
,
ChatCompletionContentPartRefusalParam
,
ChatCompletionContentPartVideoParam
,
ChatCompletionContentPartRefusalParam
,
CustomChatCompletionContentPILImageParam
,
CustomChatCompletionContentPILImageParam
,
CustomChatCompletionContentSimpleImageParam
,
CustomChatCompletionContentSimpleImageParam
,
ChatCompletionContentPartImageEmbedsParam
,
ChatCompletionContentPartImageEmbedsParam
,
CustomChatCompletionContentSimpleAudioParam
,
CustomChatCompletionContentSimpleAudioParam
,
CustomChatCompletionContentSimpleVideoParam
,
str
,
CustomChatCompletionContentSimpleVideoParam
,
CustomThinkCompletionContentParam
]
str
,
CustomThinkCompletionContentParam
,
]
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
class
CustomChatCompletionMessageParam
(
TypedDict
,
total
=
False
):
"""Enables custom roles in the Chat Completion API."""
"""Enables custom roles in the Chat Completion API."""
role
:
Required
[
str
]
role
:
Required
[
str
]
"""The role of the message's author."""
"""The role of the message's author."""
...
@@ -207,9 +248,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
...
@@ -207,9 +248,11 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls."""
"""The tool calls generated by the model, such as function calls."""
ChatCompletionMessageParam
=
Union
[
OpenAIChatCompletionMessageParam
,
ChatCompletionMessageParam
=
Union
[
CustomChatCompletionMessageParam
,
OpenAIChatCompletionMessageParam
,
OpenAIHarmonyMessage
]
CustomChatCompletionMessageParam
,
OpenAIHarmonyMessage
,
]
# TODO: Make fields ReadOnly once mypy supports it
# TODO: Make fields ReadOnly once mypy supports it
...
@@ -262,13 +305,13 @@ def _is_var_or_elems_access(
...
@@ -262,13 +305,13 @@ def _is_var_or_elems_access(
key
:
Optional
[
str
]
=
None
,
key
:
Optional
[
str
]
=
None
,
)
->
bool
:
)
->
bool
:
if
isinstance
(
node
,
jinja2
.
nodes
.
Filter
):
if
isinstance
(
node
,
jinja2
.
nodes
.
Filter
):
return
(
node
.
node
is
not
None
return
node
.
node
is
not
None
and
_is_var_or_elems_access
(
and
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
)
node
.
node
,
varname
,
key
)
if
isinstance
(
node
,
jinja2
.
nodes
.
Test
):
if
isinstance
(
node
,
jinja2
.
nodes
.
Test
):
return
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
return
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
if
(
isinstance
(
node
,
jinja2
.
nodes
.
Getitem
)
if
isinstance
(
node
,
jinja2
.
nodes
.
Getitem
)
and
isinstance
(
and
isinstance
(
node
.
arg
,
jinja2
.
nodes
.
Slice
)
)
:
node
.
arg
,
jinja2
.
nodes
.
Slice
):
return
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
return
_is_var_or_elems_access
(
node
.
node
,
varname
,
key
)
# yapf: disable
# yapf: disable
...
@@ -373,15 +416,18 @@ def resolve_mistral_chat_template(
...
@@ -373,15 +416,18 @@ def resolve_mistral_chat_template(
)
->
Optional
[
str
]:
)
->
Optional
[
str
]:
if
chat_template
is
not
None
:
if
chat_template
is
not
None
:
logger
.
warning_once
(
logger
.
warning_once
(
"'chat_template' cannot be overridden for mistral tokenizer."
)
"'chat_template' cannot be overridden for mistral tokenizer."
)
if
"add_generation_prompt"
in
kwargs
:
if
"add_generation_prompt"
in
kwargs
:
logger
.
warning_once
(
logger
.
warning_once
(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
"so it will be ignored."
)
if
"continue_final_message"
in
kwargs
:
if
"continue_final_message"
in
kwargs
:
logger
.
warning_once
(
logger
.
warning_once
(
"'continue_final_message' is not supported for mistral tokenizer, "
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
)
"so it will be ignored."
)
return
None
return
None
...
@@ -401,23 +447,35 @@ def resolve_hf_chat_template(
...
@@ -401,23 +447,35 @@ def resolve_hf_chat_template(
try
:
try
:
processor
=
cached_get_processor
(
processor
=
cached_get_processor
(
tokenizer
.
name_or_path
,
tokenizer
.
name_or_path
,
processor_cls
=
(
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
processor_cls
=
(
ProcessorMixin
),
PreTrainedTokenizer
,
PreTrainedTokenizerFast
,
ProcessorMixin
,
),
trust_remote_code
=
model_config
.
trust_remote_code
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
)
if
isinstance
(
processor
,
ProcessorMixin
)
and
\
if
(
hasattr
(
processor
,
'chat_template'
)
and
\
isinstance
(
processor
,
ProcessorMixin
)
processor
.
chat_template
is
not
None
:
and
hasattr
(
processor
,
"chat_template"
)
and
processor
.
chat_template
is
not
None
):
return
processor
.
chat_template
return
processor
.
chat_template
except
Exception
:
except
Exception
:
logger
.
debug
(
"Failed to load AutoProcessor chat template for %s"
,
tokenizer
.
name_or_path
,
exc_info
=
True
)
# noqa: E501
logger
.
debug
(
"Failed to load AutoProcessor chat template for %s"
,
tokenizer
.
name_or_path
,
exc_info
=
True
,
)
# noqa: E501
# 3rd priority: AutoTokenizer chat template
# 3rd priority: AutoTokenizer chat template
try
:
try
:
return
tokenizer
.
get_chat_template
(
chat_template
,
tools
=
tools
)
return
tokenizer
.
get_chat_template
(
chat_template
,
tools
=
tools
)
except
Exception
:
except
Exception
:
logger
.
debug
(
"Failed to load AutoTokenizer chat template for %s"
,
logger
.
debug
(
tokenizer
.
name_or_path
,
exc_info
=
True
)
"Failed to load AutoTokenizer chat template for %s"
,
tokenizer
.
name_or_path
,
exc_info
=
True
,
)
# 4th priority: Predefined fallbacks
# 4th priority: Predefined fallbacks
path
=
get_chat_template_fallback_path
(
path
=
get_chat_template_fallback_path
(
...
@@ -425,12 +483,16 @@ def resolve_hf_chat_template(
...
@@ -425,12 +483,16 @@ def resolve_hf_chat_template(
tokenizer_name_or_path
=
model_config
.
tokenizer
,
tokenizer_name_or_path
=
model_config
.
tokenizer
,
)
)
if
path
is
not
None
:
if
path
is
not
None
:
logger
.
info
(
"Loading chat template fallback for %s as there isn't one "
logger
.
info
(
"defined on HF Hub."
,
tokenizer
.
name_or_path
)
"Loading chat template fallback for %s as there isn't one "
"defined on HF Hub."
,
tokenizer
.
name_or_path
,
)
chat_template
=
load_chat_template
(
path
)
chat_template
=
load_chat_template
(
path
)
else
:
else
:
logger
.
debug
(
"There is no chat template fallback for %s"
,
logger
.
debug
(
tokenizer
.
name_or_path
)
"There is no chat template fallback for %s"
,
tokenizer
.
name_or_path
)
return
chat_template
return
chat_template
...
@@ -452,11 +514,17 @@ def _resolve_chat_template_content_format(
...
@@ -452,11 +514,17 @@ def _resolve_chat_template_content_format(
else
:
else
:
hf_chat_template
=
None
hf_chat_template
=
None
jinja_text
=
(
hf_chat_template
if
isinstance
(
hf_chat_template
,
str
)
jinja_text
=
(
else
load_chat_template
(
chat_template
,
is_literal
=
True
))
hf_chat_template
if
isinstance
(
hf_chat_template
,
str
)
else
load_chat_template
(
chat_template
,
is_literal
=
True
)
)
detected_format
=
(
"string"
if
jinja_text
is
None
else
detected_format
=
(
_detect_content_format
(
jinja_text
,
default
=
"string"
))
"string"
if
jinja_text
is
None
else
_detect_content_format
(
jinja_text
,
default
=
"string"
)
)
return
detected_format
return
detected_format
...
@@ -512,7 +580,6 @@ def resolve_chat_template_content_format(
...
@@ -512,7 +580,6 @@ def resolve_chat_template_content_format(
return
detected_format
return
detected_format
ModalityStr
=
Literal
[
"image"
,
"audio"
,
"video"
,
"image_embeds"
]
ModalityStr
=
Literal
[
"image"
,
"audio"
,
"video"
,
"image_embeds"
]
_T
=
TypeVar
(
"_T"
)
_T
=
TypeVar
(
"_T"
)
...
@@ -531,6 +598,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -531,6 +598,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self
.
_tokenizer
=
tokenizer
self
.
_tokenizer
=
tokenizer
self
.
_items_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
self
.
_items_by_modality
=
defaultdict
[
str
,
list
[
_T
]](
list
)
self
.
_uuids_by_modality
=
defaultdict
[
str
,
list
[
Optional
[
str
]]](
list
)
@
property
@
property
def
model_config
(
self
)
->
ModelConfig
:
def
model_config
(
self
)
->
ModelConfig
:
...
@@ -539,6 +607,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -539,6 +607,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
@
cached_property
@
cached_property
def
model_cls
(
self
)
->
type
[
SupportsMultiModal
]:
def
model_cls
(
self
)
->
type
[
SupportsMultiModal
]:
from
vllm.model_executor.model_loader
import
get_model_cls
from
vllm.model_executor.model_loader
import
get_model_cls
model_cls
=
get_model_cls
(
self
.
model_config
)
model_cls
=
get_model_cls
(
self
.
model_config
)
return
cast
(
type
[
SupportsMultiModal
],
model_cls
)
return
cast
(
type
[
SupportsMultiModal
],
model_cls
)
...
@@ -554,10 +623,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -554,10 +623,15 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def
mm_processor
(
self
):
def
mm_processor
(
self
):
return
self
.
mm_registry
.
create_processor
(
self
.
model_config
)
return
self
.
mm_registry
.
create_processor
(
self
.
model_config
)
def
add
(
self
,
modality
:
ModalityStr
,
item
:
_T
)
->
Optional
[
str
]:
def
add
(
self
,
modality
:
ModalityStr
,
item
:
_T
,
uuid
:
Optional
[
str
]
=
None
)
->
Optional
[
str
]:
"""
"""
Add a multi-modal item to the current prompt and returns the
Add a multi-modal item to the current prompt and returns the
placeholder string to use, if any.
placeholder string to use, if any.
An optional uuid can be added which serves as a unique identifier of the
media.
"""
"""
input_modality
=
modality
.
replace
(
"_embeds"
,
""
)
input_modality
=
modality
.
replace
(
"_embeds"
,
""
)
num_items
=
len
(
self
.
_items_by_modality
[
modality
])
+
1
num_items
=
len
(
self
.
_items_by_modality
[
modality
])
+
1
...
@@ -565,37 +639,64 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -565,37 +639,64 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self
.
mm_processor
.
validate_num_items
(
input_modality
,
num_items
)
self
.
mm_processor
.
validate_num_items
(
input_modality
,
num_items
)
self
.
_items_by_modality
[
modality
].
append
(
item
)
self
.
_items_by_modality
[
modality
].
append
(
item
)
self
.
_uuids_by_modality
[
modality
].
append
(
uuid
)
return
self
.
model_cls
.
get_placeholder_str
(
modality
,
num_items
)
return
self
.
model_cls
.
get_placeholder_str
(
modality
,
num_items
)
def
all_mm_uuids
(
self
)
->
Optional
[
MultiModalUUIDDict
]:
if
not
self
.
_items_by_modality
:
return
None
mm_uuids
=
{}
uuids_by_modality
=
dict
(
self
.
_uuids_by_modality
)
if
"image"
in
uuids_by_modality
and
"image_embeds"
in
uuids_by_modality
:
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
if
"image_embeds"
in
uuids_by_modality
:
image_embeds_uuids
=
uuids_by_modality
[
"image_embeds"
]
if
len
(
image_embeds_uuids
)
>
1
:
raise
ValueError
(
"Only one message can have {'type': 'image_embeds'}"
)
mm_uuids
[
"image"
]
=
uuids_by_modality
[
"image_embeds"
]
if
"image"
in
uuids_by_modality
:
mm_uuids
[
"image"
]
=
uuids_by_modality
[
"image"
]
# UUIDs of images
if
"audio"
in
uuids_by_modality
:
mm_uuids
[
"audio"
]
=
uuids_by_modality
[
"audio"
]
# UUIDs of audios
if
"video"
in
uuids_by_modality
:
mm_uuids
[
"video"
]
=
uuids_by_modality
[
"video"
]
# UUIDs of videos
return
mm_uuids
@
abstractmethod
@
abstractmethod
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
raise
NotImplementedError
raise
NotImplementedError
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
object
]):
class
MultiModalItemTracker
(
BaseMultiModalItemTracker
[
object
]):
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
not
self
.
_items_by_modality
:
if
not
self
.
_items_by_modality
:
return
None
return
None
mm_inputs
=
{}
mm_inputs
=
{}
items_by_modality
=
dict
(
self
.
_items_by_modality
)
items_by_modality
=
dict
(
self
.
_items_by_modality
)
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
raise
ValueError
(
\
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
"Mixing raw image and embedding inputs is not allowed"
)
if
"image_embeds"
in
items_by_modality
:
if
"image_embeds"
in
items_by_modality
:
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
if
len
(
image_embeds_lst
)
>
1
:
if
len
(
image_embeds_lst
)
>
1
:
raise
ValueError
(
\
raise
ValueError
(
"Only one message can have {'type': 'image_embeds'}"
)
"Only one message can have {'type': 'image_embeds'}"
)
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
if
"image"
in
items_by_modality
:
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
if
"audio"
in
items_by_modality
:
if
"audio"
in
items_by_modality
:
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
if
"video"
in
items_by_modality
:
if
"video"
in
items_by_modality
:
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
return
mm_inputs
return
mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
...
@@ -603,32 +704,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
...
@@ -603,32 +704,33 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
class
AsyncMultiModalItemTracker
(
BaseMultiModalItemTracker
[
Awaitable
[
object
]]):
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
async
def
all_mm_data
(
self
)
->
Optional
[
MultiModalDataDict
]:
if
not
self
.
_items_by_modality
:
if
not
self
.
_items_by_modality
:
return
None
return
None
mm_inputs
=
{}
mm_inputs
=
{}
items_by_modality
=
{
items_by_modality
=
{
modality
:
await
asyncio
.
gather
(
*
items
)
modality
:
await
asyncio
.
gather
(
*
items
)
for
modality
,
items
in
self
.
_items_by_modality
.
items
()
for
modality
,
items
in
self
.
_items_by_modality
.
items
()
}
}
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
raise
ValueError
(
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
"Mixing raw image and embedding inputs is not allowed"
)
if
"image_embeds"
in
items_by_modality
:
if
"image_embeds"
in
items_by_modality
:
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
if
len
(
image_embeds_lst
)
>
1
:
if
len
(
image_embeds_lst
)
>
1
:
raise
ValueError
(
raise
ValueError
(
"Only one message can have {'type': 'image_embeds'}"
)
"Only one message can have {'type': 'image_embeds'}"
)
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
mm_inputs
[
"image"
]
=
image_embeds_lst
[
0
]
if
"image"
in
items_by_modality
:
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
if
"audio"
in
items_by_modality
:
if
"audio"
in
items_by_modality
:
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
if
"video"
in
items_by_modality
:
if
"video"
in
items_by_modality
:
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
return
mm_inputs
return
mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
...
@@ -636,7 +738,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
...
@@ -636,7 +738,6 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
class
BaseMultiModalContentParser
(
ABC
):
class
BaseMultiModalContentParser
(
ABC
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -648,8 +749,9 @@ class BaseMultiModalContentParser(ABC):
...
@@ -648,8 +749,9 @@ class BaseMultiModalContentParser(ABC):
# }
# }
self
.
_placeholder_storage
:
dict
[
str
,
list
]
=
defaultdict
(
list
)
self
.
_placeholder_storage
:
dict
[
str
,
list
]
=
defaultdict
(
list
)
def
_add_placeholder
(
self
,
modality
:
ModalityStr
,
def
_add_placeholder
(
placeholder
:
Optional
[
str
]):
self
,
modality
:
ModalityStr
,
placeholder
:
Optional
[
str
]
):
mod_placeholder
=
MODALITY_PLACEHOLDERS_MAP
[
modality
]
mod_placeholder
=
MODALITY_PLACEHOLDERS_MAP
[
modality
]
if
placeholder
:
if
placeholder
:
self
.
_placeholder_storage
[
mod_placeholder
].
append
(
placeholder
)
self
.
_placeholder_storage
[
mod_placeholder
].
append
(
placeholder
)
...
@@ -658,33 +760,39 @@ class BaseMultiModalContentParser(ABC):
...
@@ -658,33 +760,39 @@ class BaseMultiModalContentParser(ABC):
return
dict
(
self
.
_placeholder_storage
)
return
dict
(
self
.
_placeholder_storage
)
@
abstractmethod
@
abstractmethod
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
def
parse_image
(
self
,
image_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
parse_image_embeds
(
self
,
def
parse_image_embeds
(
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]])
->
None
:
self
,
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]],
uuid
:
Optional
[
str
]
=
None
,
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
parse_image_pil
(
self
,
image_pil
:
Image
.
Image
)
->
None
:
def
parse_image_pil
(
self
,
image_pil
:
Image
.
Image
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
def
parse_audio
(
self
,
audio_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
def
parse_video
(
self
,
video_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
class
MultiModalContentParser
(
BaseMultiModalContentParser
):
class
MultiModalContentParser
(
BaseMultiModalContentParser
):
def
__init__
(
self
,
tracker
:
MultiModalItemTracker
)
->
None
:
def
__init__
(
self
,
tracker
:
MultiModalItemTracker
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -695,70 +803,79 @@ class MultiModalContentParser(BaseMultiModalContentParser):
...
@@ -695,70 +803,79 @@ class MultiModalContentParser(BaseMultiModalContentParser):
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
def
parse_image
(
self
,
image_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
image
=
self
.
_connector
.
fetch_image
(
image_url
)
image
=
self
.
_connector
.
fetch_image
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image
,
uuid
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_embeds
(
self
,
def
parse_image_embeds
(
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]])
->
None
:
self
,
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]],
uuid
:
Optional
[
str
]
=
None
,
)
->
None
:
if
isinstance
(
image_embeds
,
dict
):
if
isinstance
(
image_embeds
,
dict
):
embeds
=
{
embeds
=
{
k
:
self
.
_connector
.
fetch_image_embedding
(
v
)
k
:
self
.
_connector
.
fetch_image_embedding
(
v
)
for
k
,
v
in
image_embeds
.
items
()
for
k
,
v
in
image_embeds
.
items
()
}
}
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
embeds
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
embeds
,
uuid
)
if
isinstance
(
image_embeds
,
str
):
if
isinstance
(
image_embeds
,
str
):
embedding
=
self
.
_connector
.
fetch_image_embedding
(
image_embeds
)
embedding
=
self
.
_connector
.
fetch_image_embedding
(
image_embeds
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
embedding
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
embedding
,
uuid
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_pil
(
self
,
image_pil
:
Image
.
Image
)
->
None
:
def
parse_image_pil
(
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_pil
)
self
,
image_pil
:
Image
.
Image
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_pil
,
uuid
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
def
parse_audio
(
self
,
audio_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
audio
=
self
.
_connector
.
fetch_audio
(
audio_url
)
audio
=
self
.
_connector
.
fetch_audio
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio
,
uuid
)
self
.
_add_placeholder
(
"audio"
,
placeholder
)
self
.
_add_placeholder
(
"audio"
,
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
return
self
.
parse_audio
(
audio_url
)
return
self
.
parse_audio
(
audio_url
,
uuid
)
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
def
parse_video
(
self
,
video_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
video
=
self
.
_connector
.
fetch_video
(
video_url
=
video_url
)
video
=
self
.
_connector
.
fetch_video
(
video_url
=
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
,
uuid
)
self
.
_add_placeholder
(
"video"
,
placeholder
)
self
.
_add_placeholder
(
"video"
,
placeholder
)
class
AsyncMultiModalContentParser
(
BaseMultiModalContentParser
):
class
AsyncMultiModalContentParser
(
BaseMultiModalContentParser
):
def
__init__
(
self
,
tracker
:
AsyncMultiModalItemTracker
)
->
None
:
def
__init__
(
self
,
tracker
:
AsyncMultiModalItemTracker
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
_tracker
=
tracker
self
.
_tracker
=
tracker
self
.
_connector
=
MediaConnector
(
self
.
_connector
=
MediaConnector
(
media_io_kwargs
=
self
.
_tracker
.
_model_config
.
media_io_kwargs
,
media_io_kwargs
=
self
.
_tracker
.
_model_config
.
media_io_kwargs
,
allowed_local_media_path
=
tracker
.
allowed_local_media_path
allowed_local_media_path
=
tracker
.
allowed_local_media_path
,
)
)
def
parse_image
(
self
,
image_url
:
str
)
->
None
:
def
parse_image
(
self
,
image_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
image_coro
=
self
.
_connector
.
fetch_image_async
(
image_url
)
image_coro
=
self
.
_connector
.
fetch_image_async
(
image_url
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
image_coro
,
uuid
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_embeds
(
self
,
def
parse_image_embeds
(
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]])
->
None
:
self
,
image_embeds
:
Union
[
str
,
dict
[
str
,
str
]],
uuid
:
Optional
[
str
]
=
None
,
)
->
None
:
future
:
asyncio
.
Future
[
Union
[
str
,
dict
[
str
,
str
]]]
=
asyncio
.
Future
()
future
:
asyncio
.
Future
[
Union
[
str
,
dict
[
str
,
str
]]]
=
asyncio
.
Future
()
if
isinstance
(
image_embeds
,
dict
):
if
isinstance
(
image_embeds
,
dict
):
...
@@ -769,37 +886,40 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
...
@@ -769,37 +886,40 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
future
.
set_result
(
embeds
)
future
.
set_result
(
embeds
)
if
isinstance
(
image_embeds
,
str
):
if
isinstance
(
image_embeds
,
str
):
embedding
=
self
.
_connector
.
\
embedding
=
self
.
_connector
.
fetch_image_embedding
(
image_embeds
)
fetch_image_embedding
(
image_embeds
)
future
.
set_result
(
embedding
)
future
.
set_result
(
embedding
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
future
)
placeholder
=
self
.
_tracker
.
add
(
"image_embeds"
,
future
,
uuid
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_image_pil
(
self
,
image_pil
:
Image
.
Image
)
->
None
:
def
parse_image_pil
(
self
,
image_pil
:
Image
.
Image
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
future
:
asyncio
.
Future
[
Image
.
Image
]
=
asyncio
.
Future
()
future
:
asyncio
.
Future
[
Image
.
Image
]
=
asyncio
.
Future
()
future
.
set_result
(
image_pil
)
future
.
set_result
(
image_pil
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
future
)
placeholder
=
self
.
_tracker
.
add
(
"image"
,
future
,
uuid
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
self
.
_add_placeholder
(
"image"
,
placeholder
)
def
parse_audio
(
self
,
audio_url
:
str
)
->
None
:
def
parse_audio
(
self
,
audio_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
audio_coro
=
self
.
_connector
.
fetch_audio_async
(
audio_url
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
)
placeholder
=
self
.
_tracker
.
add
(
"audio"
,
audio_coro
,
uuid
)
self
.
_add_placeholder
(
"audio"
,
placeholder
)
self
.
_add_placeholder
(
"audio"
,
placeholder
)
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
)
->
None
:
def
parse_input_audio
(
self
,
input_audio
:
InputAudio
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_data
=
input_audio
.
get
(
"data"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_format
=
input_audio
.
get
(
"format"
,
""
)
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
audio_url
=
f
"data:audio/
{
audio_format
}
;base64,
{
audio_data
}
"
return
self
.
parse_audio
(
audio_url
)
return
self
.
parse_audio
(
audio_url
,
uuid
)
def
parse_video
(
self
,
video_url
:
str
)
->
None
:
def
parse_video
(
self
,
video_url
:
str
,
uuid
:
Optional
[
str
]
=
None
)
->
None
:
video
=
self
.
_connector
.
fetch_video_async
(
video_url
=
video_url
)
video
=
self
.
_connector
.
fetch_video_async
(
video_url
=
video_url
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
)
placeholder
=
self
.
_tracker
.
add
(
"video"
,
video
,
uuid
)
self
.
_add_placeholder
(
"video"
,
placeholder
)
self
.
_add_placeholder
(
"video"
,
placeholder
)
...
@@ -809,20 +929,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
...
@@ -809,20 +929,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
return
return
elif
isinstance
(
chat_template
,
Path
)
and
not
chat_template
.
exists
():
elif
isinstance
(
chat_template
,
Path
)
and
not
chat_template
.
exists
():
raise
FileNotFoundError
(
raise
FileNotFoundError
(
"the supplied chat template path doesn't exist"
)
"the supplied chat template path doesn't exist"
)
elif
isinstance
(
chat_template
,
str
):
elif
isinstance
(
chat_template
,
str
):
JINJA_CHARS
=
"{}
\n
"
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
if
(
for
c
in
JINJA_CHARS
)
and
not
Path
(
chat_template
).
exists
():
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
)
and
not
Path
(
chat_template
).
exists
()
):
raise
ValueError
(
raise
ValueError
(
f
"The supplied chat template string (
{
chat_template
}
) "
f
"The supplied chat template string (
{
chat_template
}
) "
f
"appears path-like, but doesn't exist!"
)
f
"appears path-like, but doesn't exist!"
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
f
"
{
type
(
chat_template
)
}
is not a valid chat template type"
)
f
"
{
type
(
chat_template
)
}
is not a valid chat template type"
)
def
_load_chat_template
(
def
_load_chat_template
(
...
@@ -835,8 +958,9 @@ def _load_chat_template(
...
@@ -835,8 +958,9 @@ def _load_chat_template(
if
is_literal
:
if
is_literal
:
if
isinstance
(
chat_template
,
Path
):
if
isinstance
(
chat_template
,
Path
):
raise
TypeError
(
"chat_template is expected to be read directly "
raise
TypeError
(
"from its value"
)
"chat_template is expected to be read directly from its value"
)
return
chat_template
return
chat_template
...
@@ -849,9 +973,11 @@ def _load_chat_template(
...
@@ -849,9 +973,11 @@ def _load_chat_template(
JINJA_CHARS
=
"{}
\n
"
JINJA_CHARS
=
"{}
\n
"
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
if
not
any
(
c
in
chat_template
for
c
in
JINJA_CHARS
):
msg
=
(
f
"The supplied chat template (
{
chat_template
}
) "
msg
=
(
f
"looks like a file path, but it failed to be "
f
"The supplied chat template (
{
chat_template
}
) "
f
"opened. Reason:
{
e
}
"
)
f
"looks like a file path, but it failed to be "
f
"opened. Reason:
{
e
}
"
)
raise
ValueError
(
msg
)
from
e
raise
ValueError
(
msg
)
from
e
# If opening a file fails, set chat template to be args to
# If opening a file fails, set chat template to be args to
...
@@ -870,8 +996,9 @@ def load_chat_template(
...
@@ -870,8 +996,9 @@ def load_chat_template(
return
_cached_load_chat_template
(
chat_template
,
is_literal
=
is_literal
)
return
_cached_load_chat_template
(
chat_template
,
is_literal
=
is_literal
)
def
_get_interleaved_text_prompt
(
placeholder_storage
:
dict
[
str
,
list
],
def
_get_interleaved_text_prompt
(
texts
:
list
[
str
])
->
str
:
placeholder_storage
:
dict
[
str
,
list
],
texts
:
list
[
str
]
)
->
str
:
for
idx
,
elem
in
enumerate
(
texts
):
for
idx
,
elem
in
enumerate
(
texts
):
if
elem
in
placeholder_storage
:
if
elem
in
placeholder_storage
:
texts
[
idx
]
=
placeholder_storage
[
elem
].
pop
(
0
)
texts
[
idx
]
=
placeholder_storage
[
elem
].
pop
(
0
)
...
@@ -881,10 +1008,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
...
@@ -881,10 +1008,11 @@ def _get_interleaved_text_prompt(placeholder_storage: dict[str, list],
# TODO: Let user specify how to insert multimodal tokens into prompt
# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
# (similar to chat template)
def
_get_full_multimodal_text_prompt
(
placeholder_storage
:
dict
[
str
,
list
],
def
_get_full_multimodal_text_prompt
(
texts
:
list
[
str
],
placeholder_storage
:
dict
[
str
,
list
],
interleave_strings
:
bool
texts
:
list
[
str
],
)
->
str
:
interleave_strings
:
bool
,
)
->
str
:
"""Combine multimodal prompts for a multimodal language model."""
"""Combine multimodal prompts for a multimodal language model."""
# flatten storage to make it looks like
# flatten storage to make it looks like
...
@@ -907,7 +1035,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
...
@@ -907,7 +1035,6 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
# Look through the text prompt to check for missing placeholders
# Look through the text prompt to check for missing placeholders
missing_placeholders
:
list
[
str
]
=
[]
missing_placeholders
:
list
[
str
]
=
[]
for
placeholder
in
placeholder_counts
:
for
placeholder
in
placeholder_counts
:
# For any existing placeholder in the text prompt, we leave it as is
# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts
[
placeholder
]
-=
text_prompt
.
count
(
placeholder
)
placeholder_counts
[
placeholder
]
-=
text_prompt
.
count
(
placeholder
)
...
@@ -916,15 +1043,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
...
@@ -916,15 +1043,18 @@ def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list],
"Placeholder count is negative! "
"Placeholder count is negative! "
"Ensure that the 'interleave_strings' flag is disabled "
"Ensure that the 'interleave_strings' flag is disabled "
"(current value: %s) "
"(current value: %s) "
"when manually placing image placeholders."
,
interleave_strings
"when manually placing image placeholders."
,
interleave_strings
,
)
)
logger
.
debug
(
"Input prompt: %s"
,
text_prompt
)
logger
.
debug
(
"Input prompt: %s"
,
text_prompt
)
raise
ValueError
(
raise
ValueError
(
f
"Found more '
{
placeholder
}
' placeholders in input prompt than "
f
"Found more '
{
placeholder
}
' placeholders in input prompt than "
"actual multimodal data items."
)
"actual multimodal data items."
)
missing_placeholders
.
extend
([
placeholder
]
*
missing_placeholders
.
extend
(
placeholder_counts
[
placeholder
])
[
placeholder
]
*
placeholder_counts
[
placeholder
]
)
# NOTE: Default behaviour: we always add missing placeholders
# NOTE: Default behaviour: we always add missing placeholders
# at the front of the prompt, if interleave_strings=False
# at the front of the prompt, if interleave_strings=False
...
@@ -944,7 +1074,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
...
@@ -944,7 +1074,8 @@ _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser
=
TypeAdapter
(
ChatCompletionContentPartVideoParam
).
validate_python
_VideoParser
=
TypeAdapter
(
ChatCompletionContentPartVideoParam
).
validate_python
_ResponsesInputImageParser
=
TypeAdapter
(
_ResponsesInputImageParser
=
TypeAdapter
(
ResponseInputImageParam
).
validate_python
ResponseInputImageParam
).
validate_python
_ContentPart
:
TypeAlias
=
Union
[
str
,
dict
[
str
,
str
],
InputAudio
,
PILImage
]
_ContentPart
:
TypeAlias
=
Union
[
str
,
dict
[
str
,
str
],
InputAudio
,
PILImage
]
# Define a mapping from part types to their corresponding parsing functions.
# Define a mapping from part types to their corresponding parsing functions.
...
@@ -952,32 +1083,35 @@ MM_PARSER_MAP: dict[
...
@@ -952,32 +1083,35 @@ MM_PARSER_MAP: dict[
str
,
str
,
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
Callable
[[
ChatCompletionContentPartParam
],
_ContentPart
],
]
=
{
]
=
{
"text"
:
"text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
lambda
part
:
_T
ext
Parser
(
part
).
get
(
"t
ext
"
,
None
),
"thinking"
:
lambda
part
:
_T
hink
Parser
(
part
).
get
(
"t
hinking
"
,
None
),
"
thinking"
:
"
input_text"
:
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
lambda
part
:
_ThinkParser
(
part
).
get
(
"thinking"
,
None
),
"input_image"
:
lambda
part
:
_ResponsesInputImageParser
(
part
).
get
(
"i
nput_text"
:
"i
mage_url"
,
None
lambda
part
:
_TextParser
(
part
).
get
(
"text"
,
None
),
),
"i
nput_image"
:
"i
mage_url"
:
lambda
part
:
_ImageParser
(
part
)
lambda
part
:
_ResponsesInputImageParser
(
part
)
.
get
(
"image_url"
,
None
),
.
get
(
"image_url"
,
{})
"image_url"
:
.
get
(
"url"
,
None
),
lambda
part
:
_ImageParser
(
part
).
get
(
"image_url"
,
{}).
get
(
"url"
,
None
),
"image_embeds"
:
lambda
part
:
_Image
Embeds
Parser
(
part
).
get
(
"image_embeds"
:
"image_embeds"
,
None
lambda
part
:
_ImageEmbedsParser
(
part
).
get
(
"image_embeds"
,
None
),
),
"image_pil"
:
lambda
part
:
_PILImageParser
(
part
).
get
(
"image_pil"
,
None
),
"image_pil"
:
lambda
part
:
_PILImageParser
(
part
).
get
(
"image_pil"
,
None
),
"audio_url"
:
"audio_url"
:
lambda
part
:
_AudioParser
(
part
)
lambda
part
:
_AudioParser
(
part
).
get
(
"audio_url"
,
{}).
get
(
"url"
,
None
),
.
get
(
"audio_url"
,
{})
"input_audio"
:
.
get
(
"url"
,
None
),
lambda
part
:
_InputAudioParser
(
part
).
get
(
"input_audio"
,
None
),
"input_audio"
:
lambda
part
:
_InputAudioParser
(
part
).
get
(
"refusal"
:
"input_audio"
,
None
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
None
),
),
"video_url"
:
"refusal"
:
lambda
part
:
_RefusalParser
(
part
).
get
(
"refusal"
,
None
),
lambda
part
:
_VideoParser
(
part
).
get
(
"video_url"
,
{}).
get
(
"url"
,
None
),
"video_url"
:
lambda
part
:
_VideoParser
(
part
)
.
get
(
"video_url"
,
{})
.
get
(
"url"
,
None
),
}
}
def
_parse_chat_message_content_mm_part
(
def
_parse_chat_message_content_mm_part
(
part
:
ChatCompletionContentPartParam
)
->
tuple
[
str
,
_ContentPart
]:
part
:
ChatCompletionContentPartParam
,
)
->
tuple
[
str
,
_ContentPart
]:
"""
"""
Parses a given multi-modal content part based on its type.
Parses a given multi-modal content part based on its type.
...
@@ -993,7 +1127,8 @@ def _parse_chat_message_content_mm_part(
...
@@ -993,7 +1127,8 @@ def _parse_chat_message_content_mm_part(
ValueError: If the 'type' field is missing and no direct URL is found.
ValueError: If the 'type' field is missing and no direct URL is found.
"""
"""
assert
isinstance
(
assert
isinstance
(
part
,
dict
)
# This is needed to avoid mypy errors: part.get() from str
part
,
dict
)
# This is needed to avoid mypy errors: part.get() from str
part_type
=
part
.
get
(
"type"
,
None
)
part_type
=
part
.
get
(
"type"
,
None
)
if
isinstance
(
part_type
,
str
)
and
part_type
in
MM_PARSER_MAP
:
if
isinstance
(
part_type
,
str
)
and
part_type
in
MM_PARSER_MAP
:
...
@@ -1002,8 +1137,10 @@ def _parse_chat_message_content_mm_part(
...
@@ -1002,8 +1137,10 @@ def _parse_chat_message_content_mm_part(
# Special case for 'image_url.detail'
# Special case for 'image_url.detail'
# We only support 'auto', which is the default
# We only support 'auto', which is the default
if
part_type
==
"image_url"
and
part
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
if
part_type
==
"image_url"
and
part
.
get
(
"detail"
,
"auto"
)
!=
"auto"
:
logger
.
warning
(
"'image_url.detail' is currently not supported "
logger
.
warning
(
"and will be ignored."
)
"'image_url.detail' is currently not supported "
"and will be ignored."
)
return
part_type
,
content
return
part_type
,
content
...
@@ -1011,19 +1148,22 @@ def _parse_chat_message_content_mm_part(
...
@@ -1011,19 +1148,22 @@ def _parse_chat_message_content_mm_part(
# 'type' is required field by pydantic
# 'type' is required field by pydantic
if
part_type
is
None
:
if
part_type
is
None
:
if
part
.
get
(
"image_url"
)
is
not
None
:
if
part
.
get
(
"image_url"
)
is
not
None
:
image_params
=
cast
(
CustomChatCompletionContentSimpleImageParam
,
image_params
=
cast
(
part
)
CustomChatCompletionContentSimpleImageParam
,
part
)
return
"image_url"
,
image_params
.
get
(
"image_url"
,
""
)
return
"image_url"
,
image_params
.
get
(
"image_url"
,
""
)
if
part
.
get
(
"audio_url"
)
is
not
None
:
if
part
.
get
(
"audio_url"
)
is
not
None
:
audio_params
=
cast
(
CustomChatCompletionContentSimpleAudioParam
,
audio_params
=
cast
(
part
)
CustomChatCompletionContentSimpleAudioParam
,
part
)
return
"audio_url"
,
audio_params
.
get
(
"audio_url"
,
""
)
return
"audio_url"
,
audio_params
.
get
(
"audio_url"
,
""
)
if
part
.
get
(
"input_audio"
)
is
not
None
:
if
part
.
get
(
"input_audio"
)
is
not
None
:
input_audio_params
=
cast
(
dict
[
str
,
str
],
part
)
input_audio_params
=
cast
(
dict
[
str
,
str
],
part
)
return
"input_audio"
,
input_audio_params
return
"input_audio"
,
input_audio_params
if
part
.
get
(
"video_url"
)
is
not
None
:
if
part
.
get
(
"video_url"
)
is
not
None
:
video_params
=
cast
(
CustomChatCompletionContentSimpleVideoParam
,
video_params
=
cast
(
part
)
CustomChatCompletionContentSimpleVideoParam
,
part
)
return
"video_url"
,
video_params
.
get
(
"video_url"
,
""
)
return
"video_url"
,
video_params
.
get
(
"video_url"
,
""
)
# Raise an error if no 'type' or direct URL is found.
# Raise an error if no 'type' or direct URL is found.
raise
ValueError
(
"Missing 'type' field in multimodal part."
)
raise
ValueError
(
"Missing 'type' field in multimodal part."
)
...
@@ -1033,9 +1173,16 @@ def _parse_chat_message_content_mm_part(
...
@@ -1033,9 +1173,16 @@ def _parse_chat_message_content_mm_part(
return
part_type
,
"unknown part_type content"
return
part_type
,
"unknown part_type content"
VALID_MESSAGE_CONTENT_MM_PART_TYPES
=
(
"text"
,
"refusal"
,
"image_url"
,
VALID_MESSAGE_CONTENT_MM_PART_TYPES
=
(
"image_embeds"
,
"image_pil"
,
"text"
,
"audio_url"
,
"input_audio"
,
"video_url"
)
"refusal"
,
"image_url"
,
"image_embeds"
,
"image_pil"
,
"audio_url"
,
"input_audio"
,
"video_url"
,
)
def
_parse_chat_message_content_parts
(
def
_parse_chat_message_content_parts
(
...
@@ -1055,21 +1202,20 @@ def _parse_chat_message_content_parts(
...
@@ -1055,21 +1202,20 @@ def _parse_chat_message_content_parts(
part
,
part
,
mm_parser
,
mm_parser
,
wrap_dicts
=
wrap_dicts
,
wrap_dicts
=
wrap_dicts
,
interleave_strings
=
interleave_strings
interleave_strings
=
interleave_strings
,
)
)
if
parse_res
:
if
parse_res
:
content
.
append
(
parse_res
)
content
.
append
(
parse_res
)
if
wrap_dicts
:
if
wrap_dicts
:
# Parsing wraps images and texts as interleaved dictionaries
# Parsing wraps images and texts as interleaved dictionaries
return
[
ConversationMessage
(
role
=
role
,
return
[
ConversationMessage
(
role
=
role
,
content
=
content
)]
# type: ignore
content
=
content
)]
# type: ignore
texts
=
cast
(
list
[
str
],
content
)
texts
=
cast
(
list
[
str
],
content
)
mm_placeholder_storage
=
mm_parser
.
mm_placeholder_storage
()
mm_placeholder_storage
=
mm_parser
.
mm_placeholder_storage
()
if
mm_placeholder_storage
:
if
mm_placeholder_storage
:
text_prompt
=
_get_full_multimodal_text_prompt
(
mm_placeholder_storage
,
text_prompt
=
_get_full_multimodal_text_prompt
(
texts
,
mm_placeholder_storage
,
texts
,
interleave_strings
interleave_strings
)
)
else
:
else
:
text_prompt
=
"
\n
"
.
join
(
texts
)
text_prompt
=
"
\n
"
.
join
(
texts
)
...
@@ -1099,46 +1245,59 @@ def _parse_chat_message_content_part(
...
@@ -1099,46 +1245,59 @@ def _parse_chat_message_content_part(
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
content
is
None
:
if
part_type
in
VALID_MESSAGE_CONTENT_MM_PART_TYPES
and
content
is
None
:
logger
.
warning
(
logger
.
warning
(
"Skipping multimodal part '%s' (type: '%s') "
"Skipping multimodal part '%s' (type: '%s') "
"with empty / unparsable content."
,
part
,
part_type
)
"with empty / unparsable content."
,
part
,
part_type
,
)
return
None
return
None
if
part_type
in
(
"text"
,
"input_text"
,
"refusal"
,
"thinking"
):
if
part_type
in
(
"text"
,
"input_text"
,
"refusal"
,
"thinking"
):
str_content
=
cast
(
str
,
content
)
str_content
=
cast
(
str
,
content
)
if
wrap_dicts
:
if
wrap_dicts
:
return
{
'
type
'
:
'
text
'
,
'
text
'
:
str_content
}
return
{
"
type
"
:
"
text
"
,
"
text
"
:
str_content
}
else
:
else
:
return
str_content
return
str_content
# For media items, if a user has provided one, use it. Otherwise, insert
# a placeholder empty uuid.
uuid
=
part
.
get
(
"uuid"
,
None
)
if
uuid
is
not
None
:
uuid
=
str
(
uuid
)
modality
=
None
modality
=
None
if
part_type
==
"image_pil"
:
if
part_type
==
"image_pil"
:
image_content
=
cast
(
Image
.
Image
,
content
)
image_content
=
cast
(
Image
.
Image
,
content
)
mm_parser
.
parse_image_pil
(
image_content
)
mm_parser
.
parse_image_pil
(
image_content
,
uuid
)
modality
=
"image"
modality
=
"image"
elif
part_type
in
(
"image_url"
,
"input_image"
):
elif
part_type
in
(
"image_url"
,
"input_image"
):
str_content
=
cast
(
str
,
content
)
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_image
(
str_content
)
mm_parser
.
parse_image
(
str_content
,
uuid
)
modality
=
"image"
modality
=
"image"
elif
part_type
==
"image_embeds"
:
elif
part_type
==
"image_embeds"
:
content
=
cast
(
Union
[
str
,
dict
[
str
,
str
]],
content
)
content
=
cast
(
Union
[
str
,
dict
[
str
,
str
]],
content
)
mm_parser
.
parse_image_embeds
(
content
)
mm_parser
.
parse_image_embeds
(
content
,
uuid
)
modality
=
"image"
modality
=
"image"
elif
part_type
==
"audio_url"
:
elif
part_type
==
"audio_url"
:
str_content
=
cast
(
str
,
content
)
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_audio
(
str_content
)
mm_parser
.
parse_audio
(
str_content
,
uuid
)
modality
=
"audio"
modality
=
"audio"
elif
part_type
==
"input_audio"
:
elif
part_type
==
"input_audio"
:
dict_content
=
cast
(
InputAudio
,
content
)
dict_content
=
cast
(
InputAudio
,
content
)
mm_parser
.
parse_input_audio
(
dict_content
)
mm_parser
.
parse_input_audio
(
dict_content
,
uuid
)
modality
=
"audio"
modality
=
"audio"
elif
part_type
==
"video_url"
:
elif
part_type
==
"video_url"
:
str_content
=
cast
(
str
,
content
)
str_content
=
cast
(
str
,
content
)
mm_parser
.
parse_video
(
str_content
)
mm_parser
.
parse_video
(
str_content
,
uuid
)
modality
=
"video"
modality
=
"video"
else
:
else
:
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
raise
NotImplementedError
(
f
"Unknown part type:
{
part_type
}
"
)
return
{
'type'
:
modality
}
if
wrap_dicts
else
(
return
(
MODALITY_PLACEHOLDERS_MAP
[
modality
]
if
interleave_strings
else
None
{
"type"
:
modality
}
if
wrap_dicts
else
(
MODALITY_PLACEHOLDERS_MAP
[
modality
]
if
interleave_strings
else
None
)
)
)
...
@@ -1171,14 +1330,16 @@ def _parse_chat_message_content(
...
@@ -1171,14 +1330,16 @@ def _parse_chat_message_content(
)
)
for
result_msg
in
result
:
for
result_msg
in
result
:
if
role
==
'
assistant
'
:
if
role
==
"
assistant
"
:
parsed_msg
=
_AssistantParser
(
message
)
parsed_msg
=
_AssistantParser
(
message
)
# The 'tool_calls' is not None check ensures compatibility.
# The 'tool_calls' is not None check ensures compatibility.
# It's needed only if downstream code doesn't strictly
# It's needed only if downstream code doesn't strictly
# follow the OpenAI spec.
# follow the OpenAI spec.
if
(
"tool_calls"
in
parsed_msg
if
(
and
parsed_msg
[
"tool_calls"
]
is
not
None
):
"tool_calls"
in
parsed_msg
and
parsed_msg
[
"tool_calls"
]
is
not
None
):
result_msg
[
"tool_calls"
]
=
list
(
parsed_msg
[
"tool_calls"
])
result_msg
[
"tool_calls"
]
=
list
(
parsed_msg
[
"tool_calls"
])
elif
role
==
"tool"
:
elif
role
==
"tool"
:
parsed_msg
=
_ToolParser
(
message
)
parsed_msg
=
_ToolParser
(
message
)
...
@@ -1198,12 +1359,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
...
@@ -1198,12 +1359,15 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# so, for messages that have tool_calls, parse the string (which we get
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
# from openAI format) to dict
for
message
in
messages
:
for
message
in
messages
:
if
(
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
if
(
and
isinstance
(
message
[
"tool_calls"
],
list
)):
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
and
isinstance
(
message
[
"tool_calls"
],
list
)
):
for
item
in
message
[
"tool_calls"
]:
for
item
in
message
[
"tool_calls"
]:
item
[
"function"
][
"arguments"
]
=
json
.
loads
(
item
[
"function"
][
"arguments"
]
=
json
.
loads
(
item
[
"function"
][
"arguments"
])
item
[
"function"
][
"arguments"
]
)
def
parse_chat_messages
(
def
parse_chat_messages
(
...
@@ -1211,7 +1375,11 @@ def parse_chat_messages(
...
@@ -1211,7 +1375,11 @@ def parse_chat_messages(
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
content_format
:
_ChatTemplateContentFormat
,
content_format
:
_ChatTemplateContentFormat
,
)
->
tuple
[
list
[
ConversationMessage
],
Optional
[
MultiModalDataDict
]]:
)
->
tuple
[
list
[
ConversationMessage
],
Optional
[
MultiModalDataDict
],
Optional
[
MultiModalUUIDDict
],
]:
conversation
:
list
[
ConversationMessage
]
=
[]
conversation
:
list
[
ConversationMessage
]
=
[]
mm_tracker
=
MultiModalItemTracker
(
model_config
,
tokenizer
)
mm_tracker
=
MultiModalItemTracker
(
model_config
,
tokenizer
)
...
@@ -1224,14 +1392,14 @@ def parse_chat_messages(
...
@@ -1224,14 +1392,14 @@ def parse_chat_messages(
content_format
==
"string"
content_format
==
"string"
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
.
interleave_mm_strings
and
model_config
.
multimodal_config
.
interleave_mm_strings
)
)
,
)
)
conversation
.
extend
(
sub_messages
)
conversation
.
extend
(
sub_messages
)
_postprocess_messages
(
conversation
)
_postprocess_messages
(
conversation
)
return
conversation
,
mm_tracker
.
all_mm_data
()
return
conversation
,
mm_tracker
.
all_mm_data
()
,
mm_tracker
.
all_mm_uuids
()
def
parse_chat_messages_futures
(
def
parse_chat_messages_futures
(
...
@@ -1239,7 +1407,11 @@ def parse_chat_messages_futures(
...
@@ -1239,7 +1407,11 @@ def parse_chat_messages_futures(
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
tokenizer
:
AnyTokenizer
,
tokenizer
:
AnyTokenizer
,
content_format
:
_ChatTemplateContentFormat
,
content_format
:
_ChatTemplateContentFormat
,
)
->
tuple
[
list
[
ConversationMessage
],
Awaitable
[
Optional
[
MultiModalDataDict
]]]:
)
->
tuple
[
list
[
ConversationMessage
],
Awaitable
[
Optional
[
MultiModalDataDict
]],
Optional
[
MultiModalUUIDDict
],
]:
conversation
:
list
[
ConversationMessage
]
=
[]
conversation
:
list
[
ConversationMessage
]
=
[]
mm_tracker
=
AsyncMultiModalItemTracker
(
model_config
,
tokenizer
)
mm_tracker
=
AsyncMultiModalItemTracker
(
model_config
,
tokenizer
)
...
@@ -1252,14 +1424,14 @@ def parse_chat_messages_futures(
...
@@ -1252,14 +1424,14 @@ def parse_chat_messages_futures(
content_format
==
"string"
content_format
==
"string"
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
is
not
None
and
model_config
.
multimodal_config
.
interleave_mm_strings
and
model_config
.
multimodal_config
.
interleave_mm_strings
)
)
,
)
)
conversation
.
extend
(
sub_messages
)
conversation
.
extend
(
sub_messages
)
_postprocess_messages
(
conversation
)
_postprocess_messages
(
conversation
)
return
conversation
,
mm_tracker
.
all_mm_data
()
return
conversation
,
mm_tracker
.
all_mm_data
()
,
mm_tracker
.
all_mm_uuids
()
def
apply_hf_chat_template
(
def
apply_hf_chat_template
(
...
@@ -1283,10 +1455,10 @@ def apply_hf_chat_template(
...
@@ -1283,10 +1455,10 @@ def apply_hf_chat_template(
raise
ValueError
(
raise
ValueError
(
"As of transformers v4.44, default chat template is no longer "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."
)
"does not define one."
)
try
:
try
:
return
tokenizer
.
apply_chat_template
(
return
tokenizer
.
apply_chat_template
(
conversation
=
conversation
,
# type: ignore[arg-type]
conversation
=
conversation
,
# type: ignore[arg-type]
tools
=
tools
,
# type: ignore[arg-type]
tools
=
tools
,
# type: ignore[arg-type]
...
@@ -1298,13 +1470,14 @@ def apply_hf_chat_template(
...
@@ -1298,13 +1470,14 @@ def apply_hf_chat_template(
# External library exceptions can sometimes occur despite the framework's
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
# internal exception management capabilities.
except
Exception
as
e
:
except
Exception
as
e
:
# Log and report any library-related exceptions for further
# Log and report any library-related exceptions for further
# investigation.
# investigation.
logger
.
exception
(
logger
.
exception
(
"An error occurred in `transformers` while applying chat template"
)
"An error occurred in `transformers` while applying chat template"
)
raise
ValueError
(
str
(
e
))
from
e
raise
ValueError
(
str
(
e
))
from
e
def
apply_mistral_chat_template
(
def
apply_mistral_chat_template
(
tokenizer
:
MistralTokenizer
,
tokenizer
:
MistralTokenizer
,
messages
:
list
[
ChatCompletionMessageParam
],
messages
:
list
[
ChatCompletionMessageParam
],
...
@@ -1337,26 +1510,26 @@ def apply_mistral_chat_template(
...
@@ -1337,26 +1510,26 @@ def apply_mistral_chat_template(
# External library exceptions can sometimes occur despite the framework's
# External library exceptions can sometimes occur despite the framework's
# internal exception management capabilities.
# internal exception management capabilities.
except
Exception
as
e
:
except
Exception
as
e
:
# Log and report any library-related exceptions for further
# Log and report any library-related exceptions for further
# investigation.
# investigation.
logger
.
exception
(
logger
.
exception
(
"An error occurred in `mistral_common` while applying chat "
"An error occurred in `mistral_common` while applying chat
template
"
"template"
)
)
raise
ValueError
(
str
(
e
))
from
e
raise
ValueError
(
str
(
e
))
from
e
def
get_history_tool_calls_cnt
(
conversation
:
list
[
ConversationMessage
]):
def
get_history_tool_calls_cnt
(
conversation
:
list
[
ConversationMessage
]):
idx
=
0
idx
=
0
for
msg
in
conversation
:
for
msg
in
conversation
:
if
msg
[
'
role
'
]
==
'
assistant
'
:
if
msg
[
"
role
"
]
==
"
assistant
"
:
tool_calls
=
msg
.
get
(
'
tool_calls
'
)
tool_calls
=
msg
.
get
(
"
tool_calls
"
)
idx
+=
len
(
list
(
tool_calls
))
if
tool_calls
is
not
None
else
0
# noqa
idx
+=
len
(
list
(
tool_calls
))
if
tool_calls
is
not
None
else
0
# noqa
return
idx
return
idx
def
make_tool_call_id
(
id_type
:
str
=
'random'
,
func_name
=
None
,
idx
=
None
):
if
id_type
==
'kimi_k2'
:
def
make_tool_call_id
(
id_type
:
str
=
"random"
,
func_name
=
None
,
idx
=
None
):
return
f
'functions.
{
func_name
}
:
{
idx
}
'
if
id_type
==
"kimi_k2"
:
return
f
"functions.
{
func_name
}
:
{
idx
}
"
else
:
else
:
# by default return random
# by default return random
return
f
"chatcmpl-tool-
{
random_uuid
()
}
"
return
f
"chatcmpl-tool-
{
random_uuid
()
}
"
vllm/entrypoints/context.py
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
contextlib
import
json
import
json
import
logging
import
logging
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Sequence
from
contextlib
import
AsyncExitStack
from
contextlib
import
AsyncExitStack
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
...
@@ -21,6 +22,23 @@ if TYPE_CHECKING:
...
@@ -21,6 +22,23 @@ if TYPE_CHECKING:
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
class
TurnTokens
:
"""Tracks token counts for a single conversation turn."""
def
__init__
(
self
,
input_tokens
=
0
,
output_tokens
=
0
):
self
.
input_tokens
=
input_tokens
self
.
output_tokens
=
output_tokens
def
reset
(
self
):
"""Reset counters for a new turn."""
self
.
input_tokens
=
0
self
.
output_tokens
=
0
def
copy
(
self
):
"""Create a copy of this turn's token counts."""
return
TurnTokens
(
self
.
input_tokens
,
self
.
output_tokens
)
class
ConversationContext
(
ABC
):
class
ConversationContext
(
ABC
):
@
abstractmethod
@
abstractmethod
...
@@ -41,17 +59,32 @@ class ConversationContext(ABC):
...
@@ -41,17 +59,32 @@ class ConversationContext(ABC):
@
abstractmethod
@
abstractmethod
async
def
init_tool_sessions
(
self
,
tool_server
:
Optional
[
ToolServer
],
async
def
init_tool_sessions
(
self
,
tool_server
:
Optional
[
ToolServer
],
exit_stack
:
AsyncExitStack
)
->
None
:
exit_stack
:
AsyncExitStack
,
request_id
:
str
)
->
None
:
pass
pass
@
abstractmethod
async
def
cleanup_session
(
self
)
->
None
:
raise
NotImplementedError
(
"Should not be called."
)
class
SimpleContext
(
ConversationContext
):
class
SimpleContext
(
ConversationContext
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
last_output
=
None
self
.
last_output
=
None
self
.
num_prompt_tokens
=
0
self
.
num_output_tokens
=
0
self
.
num_cached_tokens
=
0
# todo num_reasoning_tokens is not implemented yet.
self
.
num_reasoning_tokens
=
0
def
append_output
(
self
,
output
)
->
None
:
def
append_output
(
self
,
output
)
->
None
:
self
.
last_output
=
output
self
.
last_output
=
output
if
not
isinstance
(
output
,
RequestOutput
):
raise
ValueError
(
"SimpleContext only supports RequestOutput."
)
self
.
num_prompt_tokens
=
len
(
output
.
prompt_token_ids
or
[])
self
.
num_cached_tokens
=
output
.
num_cached_tokens
or
0
self
.
num_output_tokens
+=
len
(
output
.
outputs
[
0
].
token_ids
or
[])
def
need_builtin_tool_call
(
self
)
->
bool
:
def
need_builtin_tool_call
(
self
)
->
bool
:
return
False
return
False
...
@@ -63,9 +96,13 @@ class SimpleContext(ConversationContext):
...
@@ -63,9 +96,13 @@ class SimpleContext(ConversationContext):
raise
NotImplementedError
(
"Should not be called."
)
raise
NotImplementedError
(
"Should not be called."
)
async
def
init_tool_sessions
(
self
,
tool_server
:
Optional
[
ToolServer
],
async
def
init_tool_sessions
(
self
,
tool_server
:
Optional
[
ToolServer
],
exit_stack
:
AsyncExitStack
)
->
None
:
exit_stack
:
AsyncExitStack
,
request_id
:
str
)
->
None
:
pass
pass
async
def
cleanup_session
(
self
)
->
None
:
raise
NotImplementedError
(
"Should not be called."
)
class
HarmonyContext
(
ConversationContext
):
class
HarmonyContext
(
ConversationContext
):
...
@@ -77,39 +114,130 @@ class HarmonyContext(ConversationContext):
...
@@ -77,39 +114,130 @@ class HarmonyContext(ConversationContext):
self
.
_messages
=
messages
self
.
_messages
=
messages
self
.
available_tools
=
available_tools
self
.
available_tools
=
available_tools
self
.
_tool_sessions
:
dict
[
str
,
Union
[
ClientSession
,
Tool
]]
=
{}
self
.
_tool_sessions
:
dict
[
str
,
Union
[
ClientSession
,
Tool
]]
=
{}
self
.
called_tools
:
set
[
str
]
=
set
()
self
.
parser
=
get_streamable_parser_for_assistant
()
self
.
parser
=
get_streamable_parser_for_assistant
()
self
.
num_init_messages
=
len
(
messages
)
self
.
num_init_messages
=
len
(
messages
)
self
.
num_prompt_tokens
=
0
self
.
num_prompt_tokens
=
0
self
.
num_output_tokens
=
0
self
.
num_output_tokens
=
0
# TODO(woosuk): Implement the following fields.
self
.
num_cached_tokens
=
0
self
.
num_cached_tokens
=
0
self
.
num_reasoning_tokens
=
0
self
.
num_reasoning_tokens
=
0
self
.
num_tool_output_tokens
=
0
def
_update_num_prompt_tokens
(
self
,
output
:
RequestOutput
):
# Turn tracking - replaces multiple individual tracking variables
if
output
.
prompt_token_ids
and
len
(
output
.
prompt_token_ids
)
>
0
:
self
.
current_turn
=
TurnTokens
()
# NOTE: with built-in tools, there might be multiple rounds in
self
.
previous_turn
=
TurnTokens
()
# the conversation, with the full conversation being resent
self
.
is_first_turn
=
True
# as new prompt each time. Hence the sum.
self
.
first_tok_of_message
=
True
# For streaming support
self
.
num_prompt_tokens
+=
len
(
output
.
prompt_token_ids
)
def
_update_num_output_tokens
(
self
,
token_ids
:
Sequence
[
int
]):
def
_update_num_reasoning_tokens
(
self
):
self
.
num_output_tokens
+=
len
(
token_ids
)
# Count all analysis and commentary channels as reasoning tokens
if
self
.
parser
.
current_channel
in
{
"analysis"
,
"commentary"
}:
self
.
num_reasoning_tokens
+=
1
def
append_output
(
self
,
output
)
->
None
:
def
append_output
(
self
,
output
)
->
None
:
if
isinstance
(
output
,
RequestOutput
):
if
isinstance
(
output
,
RequestOutput
):
self
.
_update_num_prompt_tokens
(
output
)
output_token_ids
=
output
.
outputs
[
0
].
token_ids
output_token_ids
=
output
.
outputs
[
0
].
token_ids
self
.
_update_num_output_tokens
(
output_token_ids
)
self
.
parser
=
get_streamable_parser_for_assistant
()
self
.
parser
=
get_streamable_parser_for_assistant
()
for
token_id
in
output_token_ids
:
for
token_id
in
output_token_ids
:
self
.
parser
.
process
(
token_id
)
self
.
parser
.
process
(
token_id
)
# Check if the current token is part of reasoning content
self
.
_update_num_reasoning_tokens
()
self
.
_update_prefill_token_usage
(
output
)
# Reset current turn output tokens for this turn
self
.
current_turn
.
output_tokens
=
0
self
.
_update_decode_token_usage
(
output
)
# Move current turn to previous turn for next turn's calculations
self
.
previous_turn
=
self
.
current_turn
.
copy
()
output_msgs
=
self
.
parser
.
messages
output_msgs
=
self
.
parser
.
messages
else
:
else
:
# Tool output.
# Tool output.
output_msgs
=
output
output_msgs
=
output
self
.
_messages
.
extend
(
output_msgs
)
self
.
_messages
.
extend
(
output_msgs
)
def
_update_prefill_token_usage
(
self
,
output
:
RequestOutput
)
->
None
:
"""Update token usage statistics for the prefill phase of generation.
The prefill phase processes the input prompt tokens. This method:
1. Counts the prompt tokens for this turn
2. Calculates tool output tokens for multi-turn conversations
3. Updates cached token counts
4. Tracks state for next turn calculations
Tool output tokens are calculated as:
current_prompt_tokens - last_turn_prompt_tokens -
last_turn_output_tokens
This represents tokens added between turns (typically tool responses).
Args:
output: The RequestOutput containing prompt token information
"""
if
output
.
prompt_token_ids
is
not
None
:
this_turn_input_tokens
=
len
(
output
.
prompt_token_ids
)
else
:
this_turn_input_tokens
=
0
logger
.
error
(
"RequestOutput appended contains no prompt_token_ids."
)
# Update current turn input tokens
self
.
current_turn
.
input_tokens
=
this_turn_input_tokens
self
.
num_prompt_tokens
+=
this_turn_input_tokens
# Calculate tool tokens (except on first turn)
if
self
.
is_first_turn
:
self
.
is_first_turn
=
False
else
:
# start counting tool after first turn
# tool tokens = this turn prefill - last turn prefill -
# last turn decode
this_turn_tool_tokens
=
(
self
.
current_turn
.
input_tokens
-
self
.
previous_turn
.
input_tokens
-
self
.
previous_turn
.
output_tokens
)
# Handle negative tool token counts (shouldn't happen in normal
# cases)
if
this_turn_tool_tokens
<
0
:
logger
.
error
(
"Negative tool output tokens calculated: %d "
"(current_input=%d, previous_input=%d, "
"previous_output=%d). Setting to 0."
,
this_turn_tool_tokens
,
self
.
current_turn
.
input_tokens
,
self
.
previous_turn
.
input_tokens
,
self
.
previous_turn
.
output_tokens
)
this_turn_tool_tokens
=
0
self
.
num_tool_output_tokens
+=
this_turn_tool_tokens
# Update cached tokens
if
output
.
num_cached_tokens
is
not
None
:
self
.
num_cached_tokens
+=
output
.
num_cached_tokens
def
_update_decode_token_usage
(
self
,
output
:
RequestOutput
)
->
int
:
"""Update token usage statistics for the decode phase of generation.
The decode phase processes the generated output tokens. This method:
1. Counts output tokens from all completion outputs
2. Updates the total output token count
3. Tracks tokens generated in the current turn
In streaming mode, this is called for each token generated.
In non-streaming mode, this is called once with all output tokens.
Args:
output: The RequestOutput containing generated token information
Returns:
int: Number of output tokens processed in this call
"""
updated_output_token_count
=
0
if
output
.
outputs
:
for
completion_output
in
output
.
outputs
:
# only keep last round
updated_output_token_count
+=
len
(
completion_output
.
token_ids
)
self
.
num_output_tokens
+=
updated_output_token_count
self
.
current_turn
.
output_tokens
+=
updated_output_token_count
return
updated_output_token_count
@
property
@
property
def
messages
(
self
)
->
list
:
def
messages
(
self
)
->
list
:
return
self
.
_messages
return
self
.
_messages
...
@@ -118,7 +246,8 @@ class HarmonyContext(ConversationContext):
...
@@ -118,7 +246,8 @@ class HarmonyContext(ConversationContext):
last_msg
=
self
.
messages
[
-
1
]
last_msg
=
self
.
messages
[
-
1
]
recipient
=
last_msg
.
recipient
recipient
=
last_msg
.
recipient
return
recipient
is
not
None
and
(
recipient
.
startswith
(
"browser."
)
return
recipient
is
not
None
and
(
recipient
.
startswith
(
"browser."
)
or
recipient
.
startswith
(
"python"
))
or
recipient
.
startswith
(
"python"
)
or
recipient
.
startswith
(
"container."
))
async
def
call_tool
(
self
)
->
list
[
Message
]:
async
def
call_tool
(
self
)
->
list
[
Message
]:
if
not
self
.
messages
:
if
not
self
.
messages
:
...
@@ -132,6 +261,9 @@ class HarmonyContext(ConversationContext):
...
@@ -132,6 +261,9 @@ class HarmonyContext(ConversationContext):
elif
recipient
.
startswith
(
"python"
):
elif
recipient
.
startswith
(
"python"
):
return
await
self
.
call_python_tool
(
return
await
self
.
call_python_tool
(
self
.
_tool_sessions
[
"python"
],
last_msg
)
self
.
_tool_sessions
[
"python"
],
last_msg
)
elif
recipient
.
startswith
(
"container."
):
return
await
self
.
call_container_tool
(
self
.
_tool_sessions
[
"container"
],
last_msg
)
raise
ValueError
(
"No tool call found"
)
raise
ValueError
(
"No tool call found"
)
def
render_for_completion
(
self
)
->
list
[
int
]:
def
render_for_completion
(
self
)
->
list
[
int
]:
...
@@ -140,6 +272,7 @@ class HarmonyContext(ConversationContext):
...
@@ -140,6 +272,7 @@ class HarmonyContext(ConversationContext):
async
def
call_search_tool
(
self
,
tool_session
:
Union
[
"ClientSession"
,
async
def
call_search_tool
(
self
,
tool_session
:
Union
[
"ClientSession"
,
Tool
],
Tool
],
last_msg
:
Message
)
->
list
[
Message
]:
last_msg
:
Message
)
->
list
[
Message
]:
self
.
called_tools
.
add
(
"browser"
)
if
isinstance
(
tool_session
,
Tool
):
if
isinstance
(
tool_session
,
Tool
):
return
await
tool_session
.
get_result
(
self
)
return
await
tool_session
.
get_result
(
self
)
tool_name
=
last_msg
.
recipient
.
split
(
"."
)[
1
]
tool_name
=
last_msg
.
recipient
.
split
(
"."
)[
1
]
...
@@ -149,12 +282,16 @@ class HarmonyContext(ConversationContext):
...
@@ -149,12 +282,16 @@ class HarmonyContext(ConversationContext):
content
=
TextContent
(
text
=
result_str
)
content
=
TextContent
(
text
=
result_str
)
author
=
Author
(
role
=
Role
.
TOOL
,
name
=
last_msg
.
recipient
)
author
=
Author
(
role
=
Role
.
TOOL
,
name
=
last_msg
.
recipient
)
return
[
return
[
Message
(
author
=
author
,
content
=
[
content
],
recipient
=
Role
.
ASSISTANT
)
Message
(
author
=
author
,
content
=
[
content
],
recipient
=
Role
.
ASSISTANT
,
channel
=
last_msg
.
channel
)
]
]
async
def
call_python_tool
(
self
,
tool_session
:
Union
[
"ClientSession"
,
async
def
call_python_tool
(
self
,
tool_session
:
Union
[
"ClientSession"
,
Tool
],
Tool
],
last_msg
:
Message
)
->
list
[
Message
]:
last_msg
:
Message
)
->
list
[
Message
]:
self
.
called_tools
.
add
(
"python"
)
if
isinstance
(
tool_session
,
Tool
):
if
isinstance
(
tool_session
,
Tool
):
return
await
tool_session
.
get_result
(
self
)
return
await
tool_session
.
get_result
(
self
)
param
=
{
param
=
{
...
@@ -174,13 +311,63 @@ class HarmonyContext(ConversationContext):
...
@@ -174,13 +311,63 @@ class HarmonyContext(ConversationContext):
]
]
async
def
init_tool_sessions
(
self
,
tool_server
:
Optional
[
ToolServer
],
async
def
init_tool_sessions
(
self
,
tool_server
:
Optional
[
ToolServer
],
exit_stack
:
AsyncExitStack
)
->
None
:
exit_stack
:
AsyncExitStack
,
request_id
:
str
)
->
None
:
if
tool_server
:
if
tool_server
:
for
tool_name
in
self
.
available_tools
:
for
tool_name
in
self
.
available_tools
:
if
tool_name
not
in
self
.
_tool_sessions
:
if
tool_name
not
in
self
.
_tool_sessions
:
self
.
_tool_sessions
[
tool_session
=
await
exit_stack
.
enter_async_context
(
tool_name
]
=
await
exit_stack
.
enter_async_context
(
tool_server
.
new_session
(
tool_name
,
request_id
))
tool_server
.
new_session
(
tool_name
))
self
.
_tool_sessions
[
tool_name
]
=
tool_session
exit_stack
.
push_async_exit
(
self
.
cleanup_session
)
async
def
call_container_tool
(
self
,
tool_session
:
Union
[
"ClientSession"
,
Tool
],
last_msg
:
Message
)
->
list
[
Message
]:
"""
Call container tool. Expect this to be run in a stateful docker
with command line terminal.
The official container tool would at least
expect the following format:
- for tool name: exec
- args:
{
"cmd":List[str] "command to execute",
"workdir":optional[str] "current working directory",
"env":optional[object/dict] "environment variables",
"session_name":optional[str] "session name",
"timeout":optional[int] "timeout in seconds",
"user":optional[str] "user name",
}
"""
self
.
called_tools
.
add
(
"container"
)
if
isinstance
(
tool_session
,
Tool
):
return
await
tool_session
.
get_result
(
self
)
tool_name
=
last_msg
.
recipient
.
split
(
"."
)[
1
].
split
(
" "
)[
0
]
args
=
json
.
loads
(
last_msg
.
content
[
0
].
text
)
result
=
await
tool_session
.
call_tool
(
tool_name
,
args
)
result_str
=
result
.
content
[
0
].
text
content
=
TextContent
(
text
=
result_str
)
author
=
Author
(
role
=
Role
.
TOOL
,
name
=
last_msg
.
recipient
)
return
[
Message
(
author
=
author
,
content
=
[
content
],
recipient
=
Role
.
ASSISTANT
,
channel
=
last_msg
.
channel
)
]
async
def
cleanup_session
(
self
,
*
args
,
**
kwargs
)
->
None
:
"""Can be used as coro to used in __aexit__"""
async
def
cleanup_tool_session
(
tool_session
):
if
not
isinstance
(
tool_session
,
Tool
):
logger
.
info
(
"Cleaning up tool session for %s"
,
tool_session
.
_client_info
)
with
contextlib
.
suppress
(
Exception
):
await
tool_session
.
call_tool
(
"cleanup_session"
,
{})
await
asyncio
.
gather
(
*
(
cleanup_tool_session
(
self
.
_tool_sessions
[
tool
])
for
tool
in
self
.
called_tools
))
class
StreamingHarmonyContext
(
HarmonyContext
):
class
StreamingHarmonyContext
(
HarmonyContext
):
...
@@ -203,15 +390,22 @@ class StreamingHarmonyContext(HarmonyContext):
...
@@ -203,15 +390,22 @@ class StreamingHarmonyContext(HarmonyContext):
# append_output is called for each output token in streaming case,
# append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message.
# so we only want to add the prompt tokens once for each message.
if
self
.
first_tok_of_message
:
if
self
.
first_tok_of_message
:
self
.
_update_num_prompt_tokens
(
output
)
self
.
_update_prefill_token_usage
(
output
)
self
.
current_turn
.
output_tokens
=
0
# Reset self.first_tok_of_message if needed:
# Reset self.first_tok_of_message if needed:
# if the current token is the last one of the current message
# if the current token is the last one of the current message
# (finished=True), then the next token processed will mark the
# (finished=True), then the next token processed will mark the
# beginning of a new message
# beginning of a new message
self
.
first_tok_of_message
=
output
.
finished
self
.
first_tok_of_message
=
output
.
finished
tok
=
output
.
outputs
[
0
].
token_ids
[
0
]
for
tok
in
output
.
outputs
[
0
].
token_ids
:
self
.
parser
.
process
(
tok
)
self
.
parser
.
process
(
tok
)
self
.
_update_num_output_tokens
(
output
.
outputs
[
0
].
token_ids
)
self
.
_update_decode_token_usage
(
output
)
# For streaming, update previous turn when message is complete
if
output
.
finished
:
self
.
previous_turn
=
self
.
current_turn
.
copy
()
# Check if the current token is part of reasoning content
self
.
_update_num_reasoning_tokens
()
self
.
last_tok
=
tok
self
.
last_tok
=
tok
else
:
else
:
# Handle the case of tool output in direct message format
# Handle the case of tool output in direct message format
...
...
vllm/entrypoints/harmony_utils.py
View file @
38d80967
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
datetime
import
datetime
import
json
import
json
from
collections.abc
import
Iterable
,
Sequence
from
collections.abc
import
Iterable
,
Sequence
...
@@ -13,12 +16,15 @@ from openai.types.responses.response_function_web_search import (
...
@@ -13,12 +16,15 @@ from openai.types.responses.response_function_web_search import (
from
openai.types.responses.response_reasoning_item
import
(
from
openai.types.responses.response_reasoning_item
import
(
Content
as
ResponseReasoningTextContent
)
Content
as
ResponseReasoningTextContent
)
from
openai.types.responses.tool
import
Tool
from
openai.types.responses.tool
import
Tool
from
openai_harmony
import
(
Author
,
Conversation
,
DeveloperContent
,
from
openai_harmony
import
(
Author
,
ChannelConfig
,
Conversation
,
HarmonyEncodingName
,
Message
,
ReasoningEffort
,
DeveloperContent
,
HarmonyEncodingName
,
Message
,
Role
,
StreamableParser
,
SystemContent
,
TextContent
,
ReasoningEffort
,
Role
,
StreamableParser
,
ToolDescription
,
load_harmony_encoding
)
SystemContent
,
TextContent
,
ToolDescription
,
load_harmony_encoding
)
from
vllm.entrypoints.openai.protocol
import
ResponseInputOutputItem
from
vllm
import
envs
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionToolsParam
,
ResponseInputOutputItem
)
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
REASONING_EFFORT
=
{
REASONING_EFFORT
=
{
...
@@ -29,6 +35,20 @@ REASONING_EFFORT = {
...
@@ -29,6 +35,20 @@ REASONING_EFFORT = {
_harmony_encoding
=
None
_harmony_encoding
=
None
# Builtin tools that should be included in the system message when
# they are available and requested by the user.
# Tool args are provided by MCP tool descriptions. Output
# of the tools are stringified.
BUILTIN_TOOLS
=
{
"web_search_preview"
,
"code_interpreter"
,
"container"
,
}
def
has_custom_tools
(
tool_types
:
list
[
str
])
->
bool
:
return
not
set
(
tool_types
).
issubset
(
BUILTIN_TOOLS
)
def
get_encoding
():
def
get_encoding
():
global
_harmony_encoding
global
_harmony_encoding
...
@@ -44,10 +64,19 @@ def get_system_message(
...
@@ -44,10 +64,19 @@ def get_system_message(
start_date
:
Optional
[
str
]
=
None
,
start_date
:
Optional
[
str
]
=
None
,
browser_description
:
Optional
[
str
]
=
None
,
browser_description
:
Optional
[
str
]
=
None
,
python_description
:
Optional
[
str
]
=
None
,
python_description
:
Optional
[
str
]
=
None
,
container_description
:
Optional
[
str
]
=
None
,
instructions
:
Optional
[
str
]
=
None
,
with_custom_tools
:
bool
=
False
,
)
->
Message
:
)
->
Message
:
sys_msg_content
=
SystemContent
.
new
()
sys_msg_content
=
SystemContent
.
new
()
if
model_identity
is
not
None
:
if
model_identity
is
not
None
:
sys_msg_content
=
sys_msg_content
.
with_model_identity
(
model_identity
)
sys_msg_content
=
sys_msg_content
.
with_model_identity
(
model_identity
)
if
(
instructions
is
not
None
and
envs
.
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS
):
current_identity
=
sys_msg_content
.
model_identity
new_identity
=
(
f
'
{
current_identity
}
\n
{
instructions
}
'
if
current_identity
else
instructions
)
sys_msg_content
=
sys_msg_content
.
with_model_identity
(
new_identity
)
if
reasoning_effort
is
not
None
:
if
reasoning_effort
is
not
None
:
sys_msg_content
=
sys_msg_content
.
with_reasoning_effort
(
sys_msg_content
=
sys_msg_content
.
with_reasoning_effort
(
REASONING_EFFORT
[
reasoning_effort
])
REASONING_EFFORT
[
reasoning_effort
])
...
@@ -59,32 +88,55 @@ def get_system_message(
...
@@ -59,32 +88,55 @@ def get_system_message(
sys_msg_content
=
sys_msg_content
.
with_tools
(
browser_description
)
sys_msg_content
=
sys_msg_content
.
with_tools
(
browser_description
)
if
python_description
is
not
None
:
if
python_description
is
not
None
:
sys_msg_content
=
sys_msg_content
.
with_tools
(
python_description
)
sys_msg_content
=
sys_msg_content
.
with_tools
(
python_description
)
if
container_description
is
not
None
:
sys_msg_content
=
sys_msg_content
.
with_tools
(
container_description
)
if
not
with_custom_tools
:
channel_config
=
sys_msg_content
.
channel_config
invalid_channel
=
"commentary"
new_config
=
ChannelConfig
.
require_channels
(
[
c
for
c
in
channel_config
.
valid_channels
if
c
!=
invalid_channel
])
sys_msg_content
=
sys_msg_content
.
with_channel_config
(
new_config
)
sys_msg
=
Message
.
from_role_and_content
(
Role
.
SYSTEM
,
sys_msg_content
)
sys_msg
=
Message
.
from_role_and_content
(
Role
.
SYSTEM
,
sys_msg_content
)
return
sys_msg
return
sys_msg
def
get_developer_message
(
instructions
:
Optional
[
str
]
=
None
,
def
create_tool_definition
(
tool
:
Union
[
ChatCompletionToolsParam
,
Tool
]):
tools
:
Optional
[
list
[
Tool
]]
=
None
)
->
Message
:
if
isinstance
(
tool
,
ChatCompletionToolsParam
):
return
ToolDescription
.
new
(
name
=
tool
.
function
.
name
,
description
=
tool
.
function
.
description
,
parameters
=
tool
.
function
.
parameters
,
)
return
ToolDescription
.
new
(
name
=
tool
.
name
,
description
=
tool
.
description
,
parameters
=
tool
.
parameters
,
)
def
get_developer_message
(
instructions
:
Optional
[
str
]
=
None
,
tools
:
Optional
[
list
[
Union
[
Tool
,
ChatCompletionToolsParam
]]]
=
None
,
)
->
Message
:
dev_msg_content
=
DeveloperContent
.
new
()
dev_msg_content
=
DeveloperContent
.
new
()
if
instructions
is
not
None
:
if
(
instructions
is
not
None
and
not
envs
.
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS
):
dev_msg_content
=
dev_msg_content
.
with_instructions
(
instructions
)
dev_msg_content
=
dev_msg_content
.
with_instructions
(
instructions
)
if
tools
is
not
None
:
if
tools
is
not
None
:
function_tools
=
[]
function_tools
:
list
[
Union
[
Tool
,
ChatCompletionToolsParam
]]
=
[]
for
tool
in
tools
:
for
tool
in
tools
:
if
tool
.
type
in
(
"web_search_preview"
,
"code_interpreter"
):
if
tool
.
type
in
(
"web_search_preview"
,
"code_interpreter"
,
"container"
):
# These are built-in tools that are added to the system message.
# These are built-in tools that are added to the system message.
pass
pass
elif
tool
.
type
==
"function"
:
elif
tool
.
type
==
"function"
:
function_tools
.
append
(
tool
)
function_tools
.
append
(
tool
)
else
:
else
:
raise
ValueError
(
f
"tool type
{
tool
.
type
}
not supported"
)
raise
ValueError
(
f
"tool type
{
tool
.
type
}
not supported"
)
if
function_tools
:
if
function_tools
:
function_tool_descriptions
=
[
function_tool_descriptions
=
[
ToolDescription
.
new
(
create_tool_definition
(
tool
)
for
tool
in
function_tools
name
=
tool
.
name
,
description
=
tool
.
description
,
parameters
=
tool
.
parameters
,
)
for
tool
in
function_tools
]
]
dev_msg_content
=
dev_msg_content
.
with_function_tools
(
dev_msg_content
=
dev_msg_content
.
with_function_tools
(
function_tool_descriptions
)
function_tool_descriptions
)
...
@@ -120,6 +172,8 @@ def parse_response_input(
...
@@ -120,6 +172,8 @@ def parse_response_input(
TextContent
(
text
=
text_prefix
+
c
[
"text"
])
for
c
in
content
TextContent
(
text
=
text_prefix
+
c
[
"text"
])
for
c
in
content
]
]
msg
=
Message
.
from_role_and_contents
(
role
,
contents
)
msg
=
Message
.
from_role_and_contents
(
role
,
contents
)
if
role
==
"assistant"
:
msg
=
msg
.
with_channel
(
"final"
)
elif
response_msg
[
"type"
]
==
"function_call_output"
:
elif
response_msg
[
"type"
]
==
"function_call_output"
:
call_id
=
response_msg
[
"call_id"
]
call_id
=
response_msg
[
"call_id"
]
call_response
:
Optional
[
ResponseFunctionToolCall
]
=
None
call_response
:
Optional
[
ResponseFunctionToolCall
]
=
None
...
@@ -148,16 +202,46 @@ def parse_response_input(
...
@@ -148,16 +202,46 @@ def parse_response_input(
return
msg
return
msg
def
parse_chat_input
(
chat_msg
)
->
Message
:
def
parse_chat_input
(
chat_msg
)
->
list
[
Message
]:
role
=
chat_msg
[
"role"
]
if
not
isinstance
(
chat_msg
,
dict
):
content
=
chat_msg
[
"content"
]
# Handle Pydantic models
chat_msg
=
chat_msg
.
model_dump
(
exclude_none
=
True
)
role
=
chat_msg
.
get
(
"role"
)
# Assistant message with tool calls
tool_calls
=
chat_msg
.
get
(
"tool_calls"
)
if
role
==
"assistant"
and
tool_calls
:
msgs
:
list
[
Message
]
=
[]
for
call
in
tool_calls
:
func
=
call
.
get
(
"function"
,
{})
name
=
func
.
get
(
"name"
,
""
)
arguments
=
func
.
get
(
"arguments"
,
""
)
or
""
msg
=
Message
.
from_role_and_content
(
Role
.
ASSISTANT
,
arguments
)
msg
=
msg
.
with_channel
(
"commentary"
)
msg
=
msg
.
with_recipient
(
f
"functions.
{
name
}
"
)
msg
=
msg
.
with_content_type
(
"json"
)
msgs
.
append
(
msg
)
return
msgs
# Tool role message (tool output)
if
role
==
"tool"
:
name
=
chat_msg
.
get
(
"name"
,
""
)
content
=
chat_msg
.
get
(
"content"
,
""
)
or
""
msg
=
Message
.
from_author_and_content
(
Author
.
new
(
Role
.
TOOL
,
f
"functions.
{
name
}
"
),
content
).
with_channel
(
"commentary"
)
return
[
msg
]
# Default: user/assistant/system messages with content
content
=
chat_msg
.
get
(
"content"
,
""
)
if
isinstance
(
content
,
str
):
if
isinstance
(
content
,
str
):
contents
=
[
TextContent
(
text
=
content
)]
contents
=
[
TextContent
(
text
=
content
)]
else
:
else
:
# TODO: Support refusal.
# TODO: Support refusal.
contents
=
[
TextContent
(
text
=
c
.
get
(
"text"
,
""
))
for
c
in
content
]
contents
=
[
TextContent
(
text
=
c
.
get
(
"text"
,
""
))
for
c
in
content
]
msg
=
Message
.
from_role_and_contents
(
role
,
contents
)
msg
=
Message
.
from_role_and_contents
(
role
,
contents
)
return
msg
return
[
msg
]
def
render_for_completion
(
messages
:
list
[
Message
])
->
list
[
int
]:
def
render_for_completion
(
messages
:
list
[
Message
])
->
list
[
int
]:
...
@@ -227,7 +311,7 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
...
@@ -227,7 +311,7 @@ def parse_output_message(message: Message) -> list[ResponseOutputItem]:
call_id
=
f
"call_
{
random_id
}
"
,
call_id
=
f
"call_
{
random_id
}
"
,
type
=
"function_call"
,
type
=
"function_call"
,
name
=
function_name
,
name
=
function_name
,
id
=
f
"f
t
_
{
random_id
}
"
,
id
=
f
"f
c
_
{
random_id
}
"
,
)
)
output_items
.
append
(
response_item
)
output_items
.
append
(
response_item
)
elif
recipient
is
not
None
and
(
recipient
.
startswith
(
"python"
)
elif
recipient
is
not
None
and
(
recipient
.
startswith
(
"python"
)
...
...
vllm/entrypoints/launcher.py
View file @
38d80967
...
@@ -95,7 +95,7 @@ async def serve_http(app: FastAPI,
...
@@ -95,7 +95,7 @@ async def serve_http(app: FastAPI,
port
=
uvicorn_kwargs
[
"port"
]
port
=
uvicorn_kwargs
[
"port"
]
process
=
find_process_using_port
(
port
)
process
=
find_process_using_port
(
port
)
if
process
is
not
None
:
if
process
is
not
None
:
logger
.
debu
g
(
logger
.
warnin
g
(
"port %s is used by process %s launched with command:
\n
%s"
,
"port %s is used by process %s launched with command:
\n
%s"
,
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
port
,
process
,
" "
.
join
(
process
.
cmdline
()))
logger
.
info
(
"Shutting down FastAPI HTTP server."
)
logger
.
info
(
"Shutting down FastAPI HTTP server."
)
...
...
vllm/entrypoints/llm.py
View file @
38d80967
...
@@ -110,6 +110,14 @@ class LLM:
...
@@ -110,6 +110,14 @@ class LLM:
values will increase the KV cache size and thus improve the model's
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
memory (OOM) errors.
kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
this is set to None and vllm can automatically infer the kv cache
size based on gpu_memory_utilization. However, users may want to
manually specify the kv cache memory size. kv_cache_memory_bytes
allows more fine-grain control of how much memory gets used when
compared with using gpu_memory_memory_utilization. Note that
kv_cache_memory_bytes (when not-None) ignores
gpu_memory_utilization
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
when their `best_of` sampling parameters are larger than 1. If all
...
@@ -184,6 +192,7 @@ class LLM:
...
@@ -184,6 +192,7 @@ class LLM:
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
hf_overrides
:
Optional
[
HfOverrides
]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
,
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
,
kv_cache_memory_bytes
:
Optional
[
int
]
=
None
,
compilation_config
:
Optional
[
Union
[
int
,
dict
[
str
,
Any
],
compilation_config
:
Optional
[
Union
[
int
,
dict
[
str
,
Any
],
CompilationConfig
]]
=
None
,
CompilationConfig
]]
=
None
,
logits_processors
:
Optional
[
list
[
Union
[
str
,
logits_processors
:
Optional
[
list
[
Union
[
str
,
...
@@ -204,7 +213,7 @@ class LLM:
...
@@ -204,7 +213,7 @@ class LLM:
if
"kv_transfer_config"
in
kwargs
and
isinstance
(
if
"kv_transfer_config"
in
kwargs
and
isinstance
(
kwargs
[
"kv_transfer_config"
],
dict
):
kwargs
[
"kv_transfer_config"
],
dict
):
from
vllm.config
import
KVTransferConfig
from
vllm.config
.kv_transfer
import
KVTransferConfig
raw_config_dict
=
kwargs
[
"kv_transfer_config"
]
raw_config_dict
=
kwargs
[
"kv_transfer_config"
]
try
:
try
:
kwargs
[
"kv_transfer_config"
]
=
KVTransferConfig
(
kwargs
[
"kv_transfer_config"
]
=
KVTransferConfig
(
...
@@ -251,6 +260,7 @@ class LLM:
...
@@ -251,6 +260,7 @@ class LLM:
tokenizer_revision
=
tokenizer_revision
,
tokenizer_revision
=
tokenizer_revision
,
seed
=
seed
,
seed
=
seed
,
gpu_memory_utilization
=
gpu_memory_utilization
,
gpu_memory_utilization
=
gpu_memory_utilization
,
kv_cache_memory_bytes
=
kv_cache_memory_bytes
,
swap_space
=
swap_space
,
swap_space
=
swap_space
,
cpu_offload_gb
=
cpu_offload_gb
,
cpu_offload_gb
=
cpu_offload_gb
,
enforce_eager
=
enforce_eager
,
enforce_eager
=
enforce_eager
,
...
@@ -796,7 +806,7 @@ class LLM:
...
@@ -796,7 +806,7 @@ class LLM:
# NOTE: _parse_chat_message_content_parts() currently doesn't
# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
# the chat message parsing for it.
conversation
,
mm_data
=
parse_chat_messages
(
conversation
,
mm_data
,
mm_uuids
=
parse_chat_messages
(
msgs
,
msgs
,
model_config
,
model_config
,
tokenizer
,
tokenizer
,
...
@@ -826,6 +836,9 @@ class LLM:
...
@@ -826,6 +836,9 @@ class LLM:
if
mm_data
is
not
None
:
if
mm_data
is
not
None
:
prompt
[
"multi_modal_data"
]
=
mm_data
prompt
[
"multi_modal_data"
]
=
mm_data
if
mm_uuids
is
not
None
:
prompt
[
"multi_modal_uuids"
]
=
mm_uuids
if
mm_processor_kwargs
is
not
None
:
if
mm_processor_kwargs
is
not
None
:
prompt
[
"mm_processor_kwargs"
]
=
mm_processor_kwargs
prompt
[
"mm_processor_kwargs"
]
=
mm_processor_kwargs
...
...
vllm/entrypoints/openai/api_server.py
View file @
38d80967
...
@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
...
@@ -616,14 +616,23 @@ async def create_responses(request: ResponsesRequest, raw_request: Request):
@
router
.
get
(
"/v1/responses/{response_id}"
)
@
router
.
get
(
"/v1/responses/{response_id}"
)
async
def
retrieve_responses
(
response_id
:
str
,
raw_request
:
Request
):
async
def
retrieve_responses
(
response_id
:
str
,
raw_request
:
Request
,
starting_after
:
Optional
[
int
]
=
None
,
stream
:
Optional
[
bool
]
=
False
,
):
handler
=
responses
(
raw_request
)
handler
=
responses
(
raw_request
)
if
handler
is
None
:
if
handler
is
None
:
return
base
(
raw_request
).
create_error_response
(
return
base
(
raw_request
).
create_error_response
(
message
=
"The model does not support Responses API"
)
message
=
"The model does not support Responses API"
)
try
:
try
:
response
=
await
handler
.
retrieve_responses
(
response_id
)
response
=
await
handler
.
retrieve_responses
(
response_id
,
starting_after
=
starting_after
,
stream
=
stream
,
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
raise
HTTPException
(
status_code
=
HTTPStatus
.
INTERNAL_SERVER_ERROR
.
value
,
detail
=
str
(
e
))
from
e
detail
=
str
(
e
))
from
e
...
@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request):
...
@@ -631,6 +640,9 @@ async def retrieve_responses(response_id: str, raw_request: Request):
if
isinstance
(
response
,
ErrorResponse
):
if
isinstance
(
response
,
ErrorResponse
):
return
JSONResponse
(
content
=
response
.
model_dump
(),
return
JSONResponse
(
content
=
response
.
model_dump
(),
status_code
=
response
.
error
.
code
)
status_code
=
response
.
error
.
code
)
elif
stream
:
return
StreamingResponse
(
content
=
response
,
media_type
=
"text/event-stream"
)
return
JSONResponse
(
content
=
response
.
model_dump
())
return
JSONResponse
(
content
=
response
.
model_dump
())
...
@@ -1705,6 +1717,8 @@ async def init_app_state(
...
@@ -1705,6 +1717,8 @@ async def init_app_state(
if
args
.
tool_server
==
"demo"
:
if
args
.
tool_server
==
"demo"
:
tool_server
:
Optional
[
ToolServer
]
=
DemoToolServer
()
tool_server
:
Optional
[
ToolServer
]
=
DemoToolServer
()
assert
isinstance
(
tool_server
,
DemoToolServer
)
await
tool_server
.
init_and_validate
()
elif
args
.
tool_server
:
elif
args
.
tool_server
:
tool_server
=
MCPToolServer
()
tool_server
=
MCPToolServer
()
await
tool_server
.
add_tool_server
(
args
.
tool_server
)
await
tool_server
.
add_tool_server
(
args
.
tool_server
)
...
...
vllm/entrypoints/openai/cli_args.py
View file @
38d80967
...
@@ -134,14 +134,13 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
...
@@ -134,14 +134,13 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
"""If specified, will run the OpenAI frontend server in the same process as
"""If specified, will run the OpenAI frontend server in the same process as
the model serving engine."""
the model serving engine."""
enable_request_id_headers
:
bool
=
False
enable_request_id_headers
:
bool
=
False
"""If specified, API server will add X-Request-Id header to responses.
"""If specified, API server will add X-Request-Id header to responses."""
Caution: this hurts performance at high QPS."""
enable_auto_tool_choice
:
bool
=
False
enable_auto_tool_choice
:
bool
=
False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
exclude_tools_when_tool_choice_none
:
bool
=
False
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
"""Enable auto tool choice for supported models. Use `--tool-call-parser`
to specify which parser to use."""
to specify which parser to use."""
exclude_tools_when_tool_choice_none
:
bool
=
False
"""If specified, exclude tool definitions in prompts when
tool_choice='none'."""
tool_call_parser
:
Optional
[
str
]
=
None
tool_call_parser
:
Optional
[
str
]
=
None
"""Select the tool call parser depending on the model that you're using.
"""Select the tool call parser depending on the model that you're using.
This is used to parse the model-generated tool call into OpenAI API format.
This is used to parse the model-generated tool call into OpenAI API format.
...
@@ -204,7 +203,7 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
...
@@ -204,7 +203,7 @@ schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
frontend_kwargs
[
"lora_modules"
][
"type"
]
=
optional_type
(
str
)
frontend_kwargs
[
"lora_modules"
][
"type"
]
=
optional_type
(
str
)
frontend_kwargs
[
"lora_modules"
][
"action"
]
=
LoRAParserAction
frontend_kwargs
[
"lora_modules"
][
"action"
]
=
LoRAParserAction
# Special case: Middleware needs append action
# Special case: Middleware needs
to
append action
frontend_kwargs
[
"middleware"
][
"action"
]
=
"append"
frontend_kwargs
[
"middleware"
][
"action"
]
=
"append"
frontend_kwargs
[
"middleware"
][
"type"
]
=
str
frontend_kwargs
[
"middleware"
][
"type"
]
=
str
if
"nargs"
in
frontend_kwargs
[
"middleware"
]:
if
"nargs"
in
frontend_kwargs
[
"middleware"
]:
...
...
vllm/entrypoints/openai/protocol.py
View file @
38d80967
...
@@ -43,10 +43,10 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
...
@@ -43,10 +43,10 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
from
vllm.entrypoints.score_utils
import
(
ScoreContentPartParam
,
from
vllm.entrypoints.score_utils
import
(
ScoreContentPartParam
,
ScoreMultiModalParam
)
ScoreMultiModalParam
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.logprobs
import
Logprob
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
from
vllm.sampling_params
import
(
BeamSearchParams
,
GuidedDecodingParams
,
RequestOutputKind
,
SamplingParams
)
RequestOutputKind
,
SamplingParams
)
from
vllm.sequence
import
Logprob
from
vllm.utils
import
random_uuid
,
resolve_obj_by_qualname
from
vllm.utils
import
random_uuid
,
resolve_obj_by_qualname
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -1270,9 +1270,20 @@ class CompletionRequest(OpenAIBaseModel):
...
@@ -1270,9 +1270,20 @@ class CompletionRequest(OpenAIBaseModel):
@
model_validator
(
mode
=
"before"
)
@
model_validator
(
mode
=
"before"
)
@
classmethod
@
classmethod
def
validate_prompt_and_prompt_embeds
(
cls
,
data
):
def
validate_prompt_and_prompt_embeds
(
cls
,
data
):
if
data
.
get
(
"prompt"
)
is
None
and
data
.
get
(
"prompt_embeds"
)
is
None
:
prompt
=
data
.
get
(
"prompt"
)
prompt_embeds
=
data
.
get
(
"prompt_embeds"
)
prompt_is_empty
=
(
prompt
is
None
or
(
isinstance
(
prompt
,
str
)
and
prompt
==
""
))
embeds_is_empty
=
(
prompt_embeds
is
None
or
(
isinstance
(
prompt_embeds
,
list
)
and
len
(
prompt_embeds
)
==
0
))
if
prompt_is_empty
and
embeds_is_empty
:
raise
ValueError
(
raise
ValueError
(
"At least one of `prompt` or `prompt_embeds` must be set."
)
"Either prompt or prompt_embeds must be provided and non-empty."
)
return
data
return
data
@
model_validator
(
mode
=
"before"
)
@
model_validator
(
mode
=
"before"
)
...
@@ -1342,6 +1353,14 @@ class EmbeddingChatRequest(OpenAIBaseModel):
...
@@ -1342,6 +1353,14 @@ class EmbeddingChatRequest(OpenAIBaseModel):
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=-
1
)]]
=
None
# --8<-- [start:chat-embedding-extra-params]
# --8<-- [start:chat-embedding-extra-params]
add_generation_prompt
:
bool
=
Field
(
default
=
False
,
description
=
(
"If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."
),
)
add_special_tokens
:
bool
=
Field
(
add_special_tokens
:
bool
=
Field
(
default
=
False
,
default
=
False
,
description
=
(
description
=
(
...
@@ -1424,9 +1443,10 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
...
@@ -1424,9 +1443,10 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
When using plugins IOProcessor plugins, the actual input is processed
When using plugins IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data
by the plugin itself. Hence, we use a generic type for the request data
"""
"""
softmax
:
bool
=
True
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
task
=
"encode"
)
return
PoolingParams
(
task
=
"encode"
,
softmax
=
self
.
softmax
)
class
IOProcessorResponse
(
OpenAIBaseModel
,
Generic
[
T
]):
class
IOProcessorResponse
(
OpenAIBaseModel
,
Generic
[
T
]):
...
@@ -1832,7 +1852,8 @@ class InputTokensDetails(OpenAIBaseModel):
...
@@ -1832,7 +1852,8 @@ class InputTokensDetails(OpenAIBaseModel):
class
OutputTokensDetails
(
OpenAIBaseModel
):
class
OutputTokensDetails
(
OpenAIBaseModel
):
reasoning_tokens
:
int
reasoning_tokens
:
int
=
0
tool_output_tokens
:
int
=
0
class
ResponseUsage
(
OpenAIBaseModel
):
class
ResponseUsage
(
OpenAIBaseModel
):
...
@@ -2175,6 +2196,13 @@ class TranscriptionRequest(OpenAIBaseModel):
...
@@ -2175,6 +2196,13 @@ class TranscriptionRequest(OpenAIBaseModel):
)
)
# --8<-- [end:transcription-extra-params]
# --8<-- [end:transcription-extra-params]
to_language
:
Optional
[
str
]
=
None
"""The language of the output audio we transcribe to.
Please note that this is not currently used by supported models at this
time, but it is a placeholder for future use, matching translation api.
"""
# --8<-- [start:transcription-sampling-params]
# --8<-- [start:transcription-sampling-params]
temperature
:
float
=
Field
(
default
=
0.0
)
temperature
:
float
=
Field
(
default
=
0.0
)
"""The sampling temperature, between 0 and 1.
"""The sampling temperature, between 0 and 1.
...
@@ -2408,6 +2436,9 @@ class TranslationRequest(OpenAIBaseModel):
...
@@ -2408,6 +2436,9 @@ class TranslationRequest(OpenAIBaseModel):
# TODO support additional sampling parameters
# TODO support additional sampling parameters
# --8<-- [start:translation-sampling-params]
# --8<-- [start:translation-sampling-params]
seed
:
Optional
[
int
]
=
Field
(
None
,
ge
=
_LONG_INFO
.
min
,
le
=
_LONG_INFO
.
max
)
"""The seed to use for sampling."""
temperature
:
float
=
Field
(
default
=
0.0
)
temperature
:
float
=
Field
(
default
=
0.0
)
"""The sampling temperature, between 0 and 1.
"""The sampling temperature, between 0 and 1.
...
@@ -2427,6 +2458,14 @@ class TranslationRequest(OpenAIBaseModel):
...
@@ -2427,6 +2458,14 @@ class TranslationRequest(OpenAIBaseModel):
will improve accuracy.
will improve accuracy.
"""
"""
to_language
:
Optional
[
str
]
=
None
"""The language of the input audio we translate to.
Please note that this is not supported by all models, refer to the specific
model documentation for more details.
For instance, Whisper only supports `to_language=en`.
"""
stream
:
Optional
[
bool
]
=
False
stream
:
Optional
[
bool
]
=
False
"""Custom field not present in the original OpenAI definition. When set,
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
it will enable output to be streamed in a similar fashion as the Chat
...
@@ -2458,6 +2497,7 @@ class TranslationRequest(OpenAIBaseModel):
...
@@ -2458,6 +2497,7 @@ class TranslationRequest(OpenAIBaseModel):
return
SamplingParams
.
from_optional
(
temperature
=
temperature
,
return
SamplingParams
.
from_optional
(
temperature
=
temperature
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
seed
=
self
.
seed
,
output_kind
=
RequestOutputKind
.
DELTA
output_kind
=
RequestOutputKind
.
DELTA
if
self
.
stream
\
if
self
.
stream
\
else
RequestOutputKind
.
FINAL_ONLY
)
else
RequestOutputKind
.
FINAL_ONLY
)
...
...
vllm/entrypoints/openai/run_batch.py
View file @
38d80967
...
@@ -161,7 +161,7 @@ async def write_local_file(output_path: str,
...
@@ -161,7 +161,7 @@ async def write_local_file(output_path: str,
batch_outputs: The list of batch outputs to write.
batch_outputs: The list of batch outputs to write.
"""
"""
# We should make this async, but as long as run_batch runs as a
# We should make this async, but as long as run_batch runs as a
# standalone program, blocking the event loop won't
e
ffect performance.
# standalone program, blocking the event loop won't
a
ffect performance.
with
open
(
output_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
output_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
for
o
in
batch_outputs
:
for
o
in
batch_outputs
:
print
(
o
.
model_dump_json
(),
file
=
f
)
print
(
o
.
model_dump_json
(),
file
=
f
)
...
...
Prev
1
…
18
19
20
21
22
23
24
25
26
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment