Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3f8d5dd
"vllm/vscode:/vscode.git/clone" did not exist on "185d5e7cca7dc4cfb0f6c2e595cb62d23585efa1"
Commit
a3f8d5dd
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori
parents
8d75f22e
f34eca5f
Changes
499
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
700 additions
and
238 deletions
+700
-238
vllm/distributed/ec_transfer/ec_connector/example_connector.py
...distributed/ec_transfer/ec_connector/example_connector.py
+1
-1
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+129
-1
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+15
-0
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+9
-1
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
...tributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+114
-3
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
...er/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+9
-2
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
+0
-3
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+6
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+56
-39
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+19
-99
vllm/entrypoints/anthropic/serving_messages.py
vllm/entrypoints/anthropic/serving_messages.py
+6
-6
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+92
-37
vllm/entrypoints/cli/__init__.py
vllm/entrypoints/cli/__init__.py
+2
-0
vllm/entrypoints/cli/benchmark/startup.py
vllm/entrypoints/cli/benchmark/startup.py
+21
-0
vllm/entrypoints/context.py
vllm/entrypoints/context.py
+8
-8
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-14
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+1
-1
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+2
-2
vllm/entrypoints/openai/parser/harmony_utils.py
vllm/entrypoints/openai/parser/harmony_utils.py
+204
-21
No files found.
vllm/distributed/ec_transfer/ec_connector/example_connector.py
View file @
a3f8d5dd
...
...
@@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
Update ECConnector state after encoder cache allocation.
"""
mm_hash
=
request
.
mm_features
[
index
].
identifier
num_encoder_token
=
request
.
get_num_encoder_
token
s
(
index
)
num_encoder_token
=
request
.
get_num_encoder_
embed
s
(
index
)
# Insert mm_hash only if this block has not been recorded yet.
self
.
_mm_datas_need_loads
[
mm_hash
]
=
num_encoder_token
...
...
vllm/distributed/kv_events.py
View file @
a3f8d5dd
...
...
@@ -5,7 +5,7 @@ import queue
import
threading
import
time
from
abc
import
ABC
,
abstractmethod
from
collections
import
deque
from
collections
import
Counter
,
deque
from
collections.abc
import
Callable
from
dataclasses
import
asdict
from
itertools
import
count
...
...
@@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
lora_id
:
int
|
None
medium
:
str
|
None
def
__hash__
(
self
)
->
int
:
return
hash
(
(
tuple
(
self
.
block_hashes
),
self
.
parent_block_hash
,
tuple
(
self
.
token_ids
),
self
.
block_size
,
self
.
lora_id
,
self
.
medium
,
)
)
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
ExternalBlockHash
]
medium
:
str
|
None
def
__hash__
(
self
)
->
int
:
return
hash
((
tuple
(
self
.
block_hashes
),
self
.
medium
))
class
AllBlocksCleared
(
KVCacheEvent
):
pass
...
...
@@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
events
:
list
[
BlockStored
|
BlockRemoved
|
AllBlocksCleared
]
class
KVEventAggregator
:
"""
Aggregates KV events across multiple workers.
Tracks how many times each event appears and returns only those
that were emitted by all workers.
"""
__slots__
=
(
"_event_counter"
,
"_num_workers"
)
def
__init__
(
self
,
num_workers
:
int
)
->
None
:
if
num_workers
<=
0
:
raise
ValueError
(
"num_workers must be greater than zero."
)
self
.
_event_counter
:
Counter
[
KVCacheEvent
]
=
Counter
()
self
.
_num_workers
:
int
=
num_workers
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
"""
Add events from a worker batch.
:param events: List of KVCacheEvent objects.
"""
if
not
isinstance
(
events
,
list
):
raise
TypeError
(
"events must be a list of KVCacheEvent."
)
self
.
_event_counter
.
update
(
events
)
def
get_common_events
(
self
)
->
list
[
KVCacheEvent
]:
"""
Return events that appeared in all workers.
:return: List of events present in all workers.
"""
return
[
event
for
event
,
count
in
self
.
_event_counter
.
items
()
if
count
==
self
.
_num_workers
]
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
"""
Return all events for all workers.
:return: List of events for all workers.
"""
return
list
(
self
.
_event_counter
.
elements
())
def
clear_events
(
self
)
->
None
:
"""
Clear all tracked events.
"""
self
.
_event_counter
.
clear
()
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
"""
Increment the number of workers contributing events.
:param count: Number to increment the workers by.
"""
if
count
<=
0
:
raise
ValueError
(
"count must be positive."
)
self
.
_num_workers
+=
count
def
reset_workers
(
self
)
->
None
:
"""
Reset the number of workers to 1.
"""
self
.
_num_workers
=
1
def
get_number_of_workers
(
self
)
->
int
:
"""
Return the number of workers.
:return: int number of workers.
"""
return
self
.
_num_workers
def
__repr__
(
self
)
->
str
:
return
(
f
"<KVEventAggregator workers=
{
self
.
_num_workers
}
, "
f
"events=
{
len
(
self
.
_event_counter
)
}
>"
)
class
KVConnectorKVEvents
(
ABC
):
"""
Abstract base class for KV events.
Acts as a container for KV events from the connector.
"""
@
abstractmethod
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
raise
NotImplementedError
@
abstractmethod
def
aggregate
(
self
)
->
"KVConnectorKVEvents"
:
raise
NotImplementedError
@
abstractmethod
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
raise
NotImplementedError
@
abstractmethod
def
get_number_of_workers
(
self
)
->
int
:
raise
NotImplementedError
@
abstractmethod
def
clear_events
(
self
)
->
None
:
raise
NotImplementedError
class
EventPublisher
(
ABC
):
"""Lightweight publisher for EventBatch batches with data parallelism
support.
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
a3f8d5dd
...
...
@@ -78,6 +78,7 @@ class KVOutputAggregator:
finished_sending
=
set
[
str
]()
finished_recving
=
set
[
str
]()
aggregated_kv_connector_stats
=
None
combined_kv_cache_events
=
None
invalid_block_ids
=
set
[
int
]()
for
model_runner_output
in
outputs
:
assert
model_runner_output
is
not
None
...
...
@@ -119,6 +120,19 @@ class KVOutputAggregator:
aggregated_kv_connector_stats
.
aggregate
(
kv_connector_stats
)
)
# Combine kv_cache_events from all workers.
if
combined_kv_cache_events
is
None
:
# Use the first worker's kv_cache events as start event list.
combined_kv_cache_events
=
kv_output
.
kv_cache_events
elif
kv_cache_events
:
=
kv_output
.
kv_cache_events
:
assert
isinstance
(
combined_kv_cache_events
,
type
(
kv_cache_events
),
)
worker_kv_cache_events
=
kv_cache_events
.
get_all_events
()
combined_kv_cache_events
.
add_events
(
worker_kv_cache_events
)
combined_kv_cache_events
.
increment_workers
(
1
)
invalid_block_ids
|=
kv_output
.
invalid_block_ids
# select output of the worker specified by output_rank
...
...
@@ -129,6 +143,7 @@ class KVOutputAggregator:
finished_sending
=
finished_sending
or
None
,
finished_recving
=
finished_recving
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
kv_cache_events
=
combined_kv_cache_events
or
None
,
invalid_block_ids
=
invalid_block_ids
,
expected_finished_count
=
self
.
_expected_finished_count
,
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
a3f8d5dd
...
...
@@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_events
import
KVCacheEvent
,
KVConnectorKVEvents
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorPromMetrics
,
KVConnectorStats
,
...
...
@@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
"""
return
None
def
get_kv_connector_kv_cache_events
(
self
)
->
Optional
[
"KVConnectorKVEvents"
]:
"""
Get the KV connector kv cache events collected during the last interval.
This function should be called by the model runner every time after the
model execution and before cleanup.
"""
return
None
def
get_handshake_metadata
(
self
)
->
KVConnectorHandshakeMetadata
|
None
:
"""
Get the KVConnector handshake metadata for this connector.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
import
torch
from
lmcache.integration.vllm.vllm_v1_adapter
import
(
LMCacheConnectorV1Impl
as
LMCacheConnectorLatestImpl
,
)
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
(
BlockStored
,
KVCacheEvent
,
KVConnectorKVEvents
,
KVEventAggregator
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
...
...
@@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
...
...
@@ -26,6 +31,44 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
class
LMCacheKVEvents
(
KVConnectorKVEvents
):
"""
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
"""
def
__init__
(
self
,
num_workers
:
int
)
->
None
:
self
.
_aggregator
=
KVEventAggregator
(
num_workers
)
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
self
.
_aggregator
.
add_events
(
events
)
def
aggregate
(
self
)
->
"LMCacheKVEvents"
:
"""
Aggregate KV events and retain only common events.
"""
common_events
=
self
.
_aggregator
.
get_common_events
()
self
.
_aggregator
.
clear_events
()
self
.
_aggregator
.
add_events
(
common_events
)
self
.
_aggregator
.
reset_workers
()
return
self
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
self
.
_aggregator
.
increment_workers
(
count
)
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
return
self
.
_aggregator
.
get_all_events
()
def
get_number_of_workers
(
self
)
->
int
:
return
self
.
_aggregator
.
get_number_of_workers
()
def
clear_events
(
self
)
->
None
:
self
.
_aggregator
.
clear_events
()
self
.
_aggregator
.
reset_workers
()
def
__repr__
(
self
)
->
str
:
return
f
"<LMCacheKVEvents events=
{
self
.
get_all_events
()
}
>"
class
LMCacheConnectorV1
(
KVConnectorBase_V1
):
def
__init__
(
self
,
...
...
@@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
cls
=
_adapter
.
LMCacheConnectorV1Impl
else
:
logger
.
info
(
"Initializing latest dev LMCache connector"
)
# lazy import
from
lmcache.integration.vllm.vllm_v1_adapter
import
(
LMCacheConnectorV1Impl
as
LMCacheConnectorLatestImpl
,
)
cls
=
LMCacheConnectorLatestImpl
self
.
_lmcache_engine
=
cls
(
vllm_config
,
role
,
self
)
self
.
_kv_cache_events
:
LMCacheKVEvents
|
None
=
None
# ==============================
# Worker-side methods
# ==============================
...
...
@@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
# Fallback for older versions that don't support this method
return
set
()
def
get_kv_connector_kv_cache_events
(
self
)
->
LMCacheKVEvents
|
None
:
"""
Get the KV connector kv cache events collected during the last interval.
"""
events
=
self
.
_lmcache_engine
.
get_kv_events
()
# type: ignore [attr-defined]
if
not
events
:
return
None
blocks
:
list
[
BlockStored
]
=
[
BlockStored
(
block_hashes
=
e
.
block_hashes
,
parent_block_hash
=
e
.
parent_block_hash
,
token_ids
=
e
.
token_ids
,
lora_id
=
e
.
lora_id
,
block_size
=
e
.
block_size
,
medium
=
e
.
medium
,
)
for
e
in
events
]
lmcache_kv_events
=
LMCacheKVEvents
(
num_workers
=
1
)
lmcache_kv_events
.
add_events
(
blocks
)
return
lmcache_kv_events
# ==============================
# Scheduler-side methods
# ==============================
...
...
@@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
"""
return
self
.
_lmcache_engine
.
build_connector_meta
(
scheduler_output
)
def
update_connector_output
(
self
,
connector_output
:
KVConnectorOutput
):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
# Get the KV events
kv_cache_events
=
connector_output
.
kv_cache_events
if
not
kv_cache_events
or
not
isinstance
(
kv_cache_events
,
LMCacheKVEvents
):
return
if
self
.
_kv_cache_events
is
None
:
self
.
_kv_cache_events
=
kv_cache_events
else
:
self
.
_kv_cache_events
.
add_events
(
kv_cache_events
.
get_all_events
())
self
.
_kv_cache_events
.
increment_workers
(
kv_cache_events
.
get_number_of_workers
()
)
return
def
request_finished
(
self
,
request
:
"Request"
,
...
...
@@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
returned by the engine.
"""
return
self
.
_lmcache_engine
.
request_finished
(
request
,
block_ids
)
def
take_events
(
self
)
->
Iterable
[
"KVCacheEvent"
]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
if
self
.
_kv_cache_events
is
not
None
:
self
.
_kv_cache_events
.
aggregate
()
kv_cache_events
=
self
.
_kv_cache_events
.
get_all_events
()
yield
from
kv_cache_events
self
.
_kv_cache_events
.
clear_events
()
self
.
_kv_cache_events
=
None
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
View file @
a3f8d5dd
...
...
@@ -27,7 +27,14 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
LMCacheAsyncLookupServer
,
)
from
lmcache.v1.offload_server.zmq_server
import
ZMQOffloadServer
from
lmcache.v1.plugin.plugin_launcher
import
PluginLauncher
try
:
from
lmcache.v1.plugin.runtime_plugin_launcher
import
RuntimePluginLauncher
except
ImportError
:
# Backwards compatibility for lmcache <= 0.3.10-post1
from
lmcache.v1.plugin.plugin_launcher
import
(
PluginLauncher
as
RuntimePluginLauncher
,
)
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
...
...
@@ -683,7 +690,7 @@ class LMCacheConnectorV1Impl:
self
.
api_server
=
InternalAPIServer
(
self
)
self
.
api_server
.
start
()
# Launch plugins
self
.
plugin_launcher
=
PluginLauncher
(
self
.
plugin_launcher
=
Runtime
PluginLauncher
(
self
.
config
,
role
,
self
.
worker_count
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
View file @
a3f8d5dd
...
...
@@ -7,7 +7,6 @@ from prometheus_client import Counter, Gauge, Histogram
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.distributed.kv_transfer.kv_transfer_state
import
has_kv_transfer_group
from
vllm.logger
import
init_logger
PromMetric
:
TypeAlias
=
Gauge
|
Counter
|
Histogram
...
...
@@ -53,8 +52,6 @@ class KVConnectorStats:
class
KVConnectorLogging
:
def
__init__
(
self
,
kv_transfer_config
:
KVTransferConfig
|
None
):
# This should be called on frontend process.
assert
not
has_kv_transfer_group
()
# Instantiate the connector's stats class.
if
kv_transfer_config
and
kv_transfer_config
.
kv_connector
:
self
.
connector_cls
=
KVConnectorFactory
.
get_connector_class
(
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
a3f8d5dd
...
...
@@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
agg_block_ids
|=
c
.
get_block_ids_with_load_errors
()
return
agg_block_ids
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
# for the MultiConnector. It should be able to get events from multiple
# connectors, handling the case where only a subset of the requested connectors
# implements the 'get_kv_connector_kv_cache_events'
# Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
# ==============================
# Scheduler-side methods
# ==============================
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
a3f8d5dd
...
...
@@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash(
return
compat_hash
@
dataclass
class
RemoteMeta
:
block_ids
:
list
[
int
]
host
:
str
port
:
int
engine_id
:
str
request_id
:
str
@
dataclass
class
ReqMeta
:
local_block_ids
:
list
[
int
]
# To be used when logical block size does not match the kernel block size
local_physical_block_ids
:
list
[
int
]
remote_block_ids
:
list
[
int
]
remote_host
:
str
remote_port
:
int
remote_engine_id
:
str
remote_request_id
:
str
tp_size
:
int
remote
:
RemoteMeta
|
None
=
None
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
...
...
@@ -223,31 +228,43 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self
.
reqs_in_batch
:
set
[
ReqId
]
=
set
()
self
.
reqs_not_processed
:
set
[
ReqId
]
=
set
()
def
add_new_req
(
def
_
add_new_req
(
self
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
],
kv_transfer_params
:
dict
[
str
,
Any
],
load_remote_cache
:
bool
=
True
,
save_to_host
:
bool
=
False
,
):
# save and load are mutually exclusive
assert
load_remote_cache
^
save_to_host
_req
=
ReqMeta
(
)
->
ReqMeta
:
return
ReqMeta
(
local_block_ids
=
local_block_ids
,
local_physical_block_ids
=
local_block_ids
,
remote_block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_request_id
=
kv_transfer_params
[
"remote_request_id"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
remote_port
=
kv_transfer_params
[
"remote_port"
],
# P workers don't need to receive tp_size from proxy here.
tp_size
=
kv_transfer_params
.
get
(
"tp_size"
,
1
),
)
if
save_to_host
:
self
.
reqs_to_save
[
request_id
]
=
_req
if
load_remote_cache
:
self
.
reqs_to_recv
[
request_id
]
=
_req
def
add_new_req_to_save
(
self
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
],
kv_transfer_params
:
dict
[
str
,
Any
],
):
self
.
reqs_to_save
[
request_id
]
=
self
.
_add_new_req
(
local_block_ids
,
kv_transfer_params
)
def
add_new_req_to_recv
(
self
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
],
kv_transfer_params
:
dict
[
str
,
Any
],
):
req
=
self
.
_add_new_req
(
local_block_ids
,
kv_transfer_params
)
req
.
remote
=
RemoteMeta
(
block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
request_id
=
kv_transfer_params
[
"remote_request_id"
],
host
=
kv_transfer_params
[
"remote_host"
],
port
=
kv_transfer_params
[
"remote_port"
],
)
self
.
reqs_to_recv
[
request_id
]
=
req
class
NixlConnector
(
KVConnectorBase_V1
):
...
...
@@ -666,22 +683,18 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta.
for
req_id
,
(
req
,
block_ids
)
in
self
.
_reqs_need_recv
.
items
():
assert
req
.
kv_transfer_params
is
not
None
meta
.
add_new_req
(
meta
.
add_new_req
_to_recv
(
request_id
=
req_id
,
local_block_ids
=
block_ids
,
kv_transfer_params
=
req
.
kv_transfer_params
,
load_remote_cache
=
True
,
save_to_host
=
False
,
)
for
req_id
,
(
req
,
block_ids
)
in
self
.
_reqs_need_save
.
items
():
assert
req
.
kv_transfer_params
is
not
None
meta
.
add_new_req
(
meta
.
add_new_req
_to_save
(
request_id
=
req_id
,
local_block_ids
=
block_ids
,
kv_transfer_params
=
req
.
kv_transfer_params
,
load_remote_cache
=
False
,
save_to_host
=
True
,
)
meta
.
reqs_to_send
=
self
.
_reqs_need_send
...
...
@@ -1124,10 +1137,11 @@ class NixlConnectorWorker:
# Do NIXL handshake in background and add to _ready_requests when done.
fut
=
self
.
_handshake_futures
.
get
(
remote_engine_id
)
if
fut
is
None
:
assert
meta
.
remote
is
not
None
fut
=
self
.
_handshake_initiation_executor
.
submit
(
self
.
_nixl_handshake
,
meta
.
remote
_
host
,
meta
.
remote
_
port
,
meta
.
remote
.
host
,
meta
.
remote
.
port
,
meta
.
tp_size
,
remote_engine_id
,
)
...
...
@@ -1774,6 +1788,7 @@ class NixlConnectorWorker:
# clean up metadata for completed requests
meta
=
self
.
_recving_metadata
.
pop
(
req_id
,
None
)
assert
meta
is
not
None
,
f
"
{
req_id
}
not found in recving_metadata list"
assert
meta
.
remote
is
not
None
if
self
.
use_host_buffer
:
self
.
sync_recved_kv_to_device
(
req_id
,
meta
)
if
self
.
enable_permute_local_kv
:
...
...
@@ -1781,7 +1796,7 @@ class NixlConnectorWorker:
# post processing for heteroblocksize
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
meta
.
remote
_
engine_id
meta
.
remote
.
engine_id
)
if
(
not
self
.
use_mla
...
...
@@ -1916,17 +1931,18 @@ class NixlConnectorWorker:
meta
.
local_physical_block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
local_block_ids
)
meta
.
remote_block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
remote_block_ids
assert
meta
.
remote
is
not
None
meta
.
remote
.
block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
remote
.
block_ids
)
remote_engine_id
=
meta
.
remote
_
engine_id
remote_engine_id
=
meta
.
remote
.
engine_id
logger
.
debug
(
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. "
,
req_id
,
remote_engine_id
,
len
(
meta
.
local_physical_block_ids
),
len
(
meta
.
remote
_
block_ids
),
len
(
meta
.
remote
.
block_ids
),
)
# always store metadata for failure recovery
self
.
_recving_metadata
[
req_id
]
=
meta
...
...
@@ -1965,17 +1981,18 @@ class NixlConnectorWorker:
self
.
_reqs_to_send
[
req_id
]
=
expiration_time
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
assert
meta
.
remote
is
not
None
logger
.
debug
(
"Remote agent %s available, calling _read_blocks for req %s"
,
meta
.
remote
_
engine_id
,
meta
.
remote
.
engine_id
,
req_id
,
)
self
.
_read_blocks
(
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
_
engine_id
,
remote_request_id
=
meta
.
remote
_
request_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote
_
block_ids
,
remote_block_ids
=
meta
.
remote
.
block_ids
,
)
def
_read_blocks
(
...
...
vllm/distributed/parallel_state.py
View file @
a3f8d5dd
...
...
@@ -1586,6 +1586,8 @@ def destroy_distributed_environment():
def
cleanup_dist_env_and_memory
(
shutdown_ray
:
bool
=
False
):
# Reset environment variable cache
envs
.
disable_envs_cache
()
# Ensure all objects are not frozen before cleanup
gc
.
unfreeze
()
...
...
vllm/engine/arg_utils.py
View file @
a3f8d5dd
...
...
@@ -71,7 +71,6 @@ from vllm.config.model import (
LogprobsMode
,
ModelDType
,
RunnerOption
,
TaskOption
,
TokenizerMode
,
)
from
vllm.config.multimodal
import
MMCacheType
,
MMEncoderTPMode
...
...
@@ -360,7 +359,6 @@ class EngineArgs:
hf_config_path
:
str
|
None
=
ModelConfig
.
hf_config_path
runner
:
RunnerOption
=
ModelConfig
.
runner
convert
:
ConvertOption
=
ModelConfig
.
convert
task
:
TaskOption
|
None
=
ModelConfig
.
task
skip_tokenizer_init
:
bool
=
ModelConfig
.
skip_tokenizer_init
enable_prompt_embeds
:
bool
=
ModelConfig
.
enable_prompt_embeds
tokenizer_mode
:
TokenizerMode
|
str
=
ModelConfig
.
tokenizer_mode
...
...
@@ -373,9 +371,8 @@ class EngineArgs:
config_format
:
str
=
ModelConfig
.
config_format
dtype
:
ModelDType
=
ModelConfig
.
dtype
kv_cache_dtype
:
CacheDType
=
CacheConfig
.
cache_dtype
seed
:
int
|
None
=
0
seed
:
int
=
ModelConfig
.
seed
max_model_len
:
int
|
None
=
ModelConfig
.
max_model_len
cuda_graph_sizes
:
list
[
int
]
|
None
=
CompilationConfig
.
cudagraph_capture_sizes
cudagraph_capture_sizes
:
list
[
int
]
|
None
=
(
CompilationConfig
.
cudagraph_capture_sizes
)
...
...
@@ -463,7 +460,6 @@ class EngineArgs:
MultiModalConfig
,
"media_io_kwargs"
)
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
MultiModalConfig
.
mm_processor_kwargs
disable_mm_preprocessor_cache
:
bool
=
False
# DEPRECATED
mm_processor_cache_gb
:
float
=
MultiModalConfig
.
mm_processor_cache_gb
mm_processor_cache_type
:
MMCacheType
|
None
=
(
MultiModalConfig
.
mm_processor_cache_type
...
...
@@ -495,7 +491,7 @@ class EngineArgs:
enable_chunked_prefill
:
bool
|
None
=
None
disable_chunked_mm_input
:
bool
=
SchedulerConfig
.
disable_chunked_mm_input
disable_hybrid_kv_cache_manager
:
bool
=
(
disable_hybrid_kv_cache_manager
:
bool
|
None
=
(
SchedulerConfig
.
disable_hybrid_kv_cache_manager
)
...
...
@@ -559,9 +555,6 @@ class EngineArgs:
use_tqdm_on_load
:
bool
=
LoadConfig
.
use_tqdm_on_load
pt_load_map_location
:
str
=
LoadConfig
.
pt_load_map_location
# DEPRECATED
enable_multimodal_encoder_data_parallel
:
bool
=
False
logits_processors
:
list
[
str
|
type
[
LogitsProcessor
]]
|
None
=
(
ModelConfig
.
logits_processors
)
...
...
@@ -629,7 +622,6 @@ class EngineArgs:
model_group
.
add_argument
(
"--model"
,
**
model_kwargs
[
"model"
])
model_group
.
add_argument
(
"--runner"
,
**
model_kwargs
[
"runner"
])
model_group
.
add_argument
(
"--convert"
,
**
model_kwargs
[
"convert"
])
model_group
.
add_argument
(
"--task"
,
**
model_kwargs
[
"task"
],
deprecated
=
True
)
model_group
.
add_argument
(
"--tokenizer"
,
**
model_kwargs
[
"tokenizer"
])
model_group
.
add_argument
(
"--tokenizer-mode"
,
**
model_kwargs
[
"tokenizer_mode"
])
model_group
.
add_argument
(
...
...
@@ -883,11 +875,6 @@ class EngineArgs:
parallel_group
.
add_argument
(
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
)
parallel_group
.
add_argument
(
"--enable-multimodal-encoder-data-parallel"
,
action
=
"store_true"
,
deprecated
=
True
,
)
# KV cache arguments
cache_kwargs
=
get_kwargs
(
CacheConfig
)
...
...
@@ -961,9 +948,6 @@ class EngineArgs:
multimodal_group
.
add_argument
(
"--mm-processor-cache-gb"
,
**
multimodal_kwargs
[
"mm_processor_cache_gb"
]
)
multimodal_group
.
add_argument
(
"--disable-mm-preprocessor-cache"
,
action
=
"store_true"
,
deprecated
=
True
)
multimodal_group
.
add_argument
(
"--mm-processor-cache-type"
,
**
multimodal_kwargs
[
"mm_processor_cache_type"
]
)
...
...
@@ -1121,15 +1105,6 @@ class EngineArgs:
compilation_group
.
add_argument
(
"--cudagraph-capture-sizes"
,
**
compilation_kwargs
[
"cudagraph_capture_sizes"
]
)
compilation_kwargs
[
"cudagraph_capture_sizes"
][
"help"
]
=
(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
" whichever is soonest. Please use --cudagraph-capture-sizes instead."
)
compilation_group
.
add_argument
(
"--cuda-graph-sizes"
,
**
compilation_kwargs
[
"cudagraph_capture_sizes"
],
deprecated
=
True
,
)
compilation_group
.
add_argument
(
"--max-cudagraph-capture-size"
,
**
compilation_kwargs
[
"max_cudagraph_capture_size"
],
...
...
@@ -1202,62 +1177,20 @@ class EngineArgs:
if
is_gguf
(
self
.
model
):
self
.
quantization
=
self
.
load_format
=
"gguf"
# NOTE(woosuk): In V1, we use separate processes for workers (unless
# VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
# doesn't affect the user process.
if
self
.
seed
is
None
:
logger
.
warning_once
(
"`seed=None` is equivalent to `seed=0` in V1 Engine. "
"You will no longer be allowed to pass `None` in v0.13."
,
scope
=
"local"
,
)
self
.
seed
=
0
if
not
envs
.
VLLM_ENABLE_V1_MULTIPROCESSING
:
logger
.
warning
(
"The global random seed is set to %d. Since "
"VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
"affect the random state of the Python process that "
"launched vLLM."
,
self
.
seed
,
)
if
self
.
disable_mm_preprocessor_cache
:
logger
.
warning_once
(
"`--disable-mm-preprocessor-cache` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb 0` instead."
,
scope
=
"local"
,
)
self
.
mm_processor_cache_gb
=
0
elif
envs
.
VLLM_MM_INPUT_CACHE_GIB
!=
4
:
logger
.
warning_once
(
"VLLM_MM_INPUT_CACHE_GIB` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb %d` instead."
,
envs
.
VLLM_MM_INPUT_CACHE_GIB
,
scope
=
"local"
,
)
self
.
mm_processor_cache_gb
=
envs
.
VLLM_MM_INPUT_CACHE_GIB
if
self
.
enable_multimodal_encoder_data_parallel
:
logger
.
warning_once
(
"--enable-multimodal-encoder-data-parallel` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-encoder-tp-mode data` instead."
,
scope
=
"local"
,
if
not
envs
.
VLLM_ENABLE_V1_MULTIPROCESSING
:
logger
.
warning
(
"The global random seed is set to %d. Since "
"VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
"affect the random state of the Python process that "
"launched vLLM."
,
self
.
seed
,
)
self
.
mm_encoder_tp_mode
=
"data"
return
ModelConfig
(
model
=
self
.
model
,
hf_config_path
=
self
.
hf_config_path
,
runner
=
self
.
runner
,
convert
=
self
.
convert
,
task
=
self
.
task
,
tokenizer
=
self
.
tokenizer
,
tokenizer_mode
=
self
.
tokenizer_mode
,
trust_remote_code
=
self
.
trust_remote_code
,
...
...
@@ -1716,7 +1649,13 @@ class EngineArgs:
"attention_backend and attention_config.backend "
"are mutually exclusive"
)
attention_config
.
backend
=
self
.
attention_backend
# Convert string to enum if needed (CLI parsing returns a string)
if
isinstance
(
self
.
attention_backend
,
str
):
attention_config
.
backend
=
AttentionBackendEnum
[
self
.
attention_backend
.
upper
()
]
else
:
attention_config
.
backend
=
self
.
attention_backend
load_config
=
self
.
create_load_config
()
...
...
@@ -1741,18 +1680,6 @@ class EngineArgs:
# Compilation config overrides
compilation_config
=
copy
.
deepcopy
(
self
.
compilation_config
)
if
self
.
cuda_graph_sizes
is
not
None
:
logger
.
warning
(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
"v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
"instead."
)
if
compilation_config
.
cudagraph_capture_sizes
is
not
None
:
raise
ValueError
(
"cuda_graph_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
compilation_config
.
cudagraph_capture_sizes
=
self
.
cuda_graph_sizes
if
self
.
cudagraph_capture_sizes
is
not
None
:
if
compilation_config
.
cudagraph_capture_sizes
is
not
None
:
raise
ValueError
(
...
...
@@ -1862,6 +1789,7 @@ class EngineArgs:
except
Exception
:
# This is only used to set default_max_num_batched_tokens
device_memory
=
0
device_name
=
""
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
...
...
@@ -1926,16 +1854,6 @@ class EngineArgs:
default_chunked_prefill
=
model_config
.
is_chunked_prefill_supported
default_prefix_caching
=
model_config
.
is_prefix_caching_supported
if
self
.
prefill_context_parallel_size
>
1
:
default_chunked_prefill
=
False
default_prefix_caching
=
False
logger
.
warning_once
(
"--prefill-context-parallel-size > 1 is not compatible with "
"chunked prefill and prefix caching now. Chunked prefill "
"and prefix caching have been disabled by default."
,
scope
=
"local"
,
)
if
self
.
enable_chunked_prefill
is
None
:
self
.
enable_chunked_prefill
=
default_chunked_prefill
...
...
@@ -2121,11 +2039,13 @@ def human_readable_int(value):
"k"
:
10
**
3
,
"m"
:
10
**
6
,
"g"
:
10
**
9
,
"t"
:
10
**
12
,
}
binary_multiplier
=
{
"K"
:
2
**
10
,
"M"
:
2
**
20
,
"G"
:
2
**
30
,
"T"
:
2
**
40
,
}
number
,
suffix
=
match
.
groups
()
...
...
vllm/entrypoints/anthropic/serving_messages.py
View file @
a3f8d5dd
...
...
@@ -324,12 +324,12 @@ class AnthropicServingMessages(OpenAIServingChat):
id
=
origin_chunk
.
id
,
content
=
[],
model
=
origin_chunk
.
model
,
),
usage
=
AnthropicUsage
(
input_tokens
=
origin_chunk
.
usage
.
prompt_tokens
if
origin_chunk
.
usage
else
0
,
output_tokens
=
0
,
usage
=
AnthropicUsage
(
input_tokens
=
origin_chunk
.
usage
.
prompt_tokens
if
origin_chunk
.
usage
else
0
,
output_tokens
=
0
,
)
,
),
)
first_item
=
False
...
...
vllm/entrypoints/chat_utils.py
View file @
a3f8d5dd
...
...
@@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
from
collections.abc
import
Awaitable
,
Callable
,
Iterable
from
functools
import
cached_property
,
lru_cache
,
partial
from
pathlib
import
Path
from
typing
import
Any
,
Generic
,
Literal
,
TypeAlias
,
TypeVar
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Literal
,
TypeAlias
,
TypeVar
,
cast
import
jinja2
import
jinja2.ext
...
...
@@ -24,6 +24,7 @@ from openai.types.chat import (
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartRefusalParam
,
ChatCompletionContentPartTextParam
,
ChatCompletionFunctionToolParam
,
ChatCompletionMessageToolCallParam
,
ChatCompletionToolMessageParam
,
)
...
...
@@ -49,11 +50,20 @@ from vllm.logger import init_logger
from
vllm.model_executor.models
import
SupportsMultiModal
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
MultiModalUUIDDict
from
vllm.multimodal.utils
import
MEDIA_CONNECTOR_REGISTRY
,
MediaConnector
from
vllm.tokenizers
import
MistralTokenizer
,
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.transformers_utils.chat_templates
import
get_chat_template_fallback_path
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.utils
import
random_uuid
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.func_utils
import
supports_kw
from
vllm.utils.import_utils
import
LazyLoader
if
TYPE_CHECKING
:
import
torch
from
vllm.tokenizers.mistral
import
MistralTokenizer
else
:
torch
=
LazyLoader
(
"torch"
,
globals
(),
"torch"
)
logger
=
init_logger
(
__name__
)
...
...
@@ -260,6 +270,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
reasoning
:
str
|
None
"""The reasoning content for interleaved thinking."""
tools
:
list
[
ChatCompletionFunctionToolParam
]
|
None
"""The tools for developer role."""
ChatCompletionMessageParam
:
TypeAlias
=
(
OpenAIChatCompletionMessageParam
...
...
@@ -291,6 +304,9 @@ class ConversationMessage(TypedDict, total=False):
reasoning_content
:
str
|
None
"""Deprecated: The reasoning content for interleaved thinking."""
tools
:
list
[
ChatCompletionFunctionToolParam
]
|
None
"""The tools for developer role."""
# Passed in by user
ChatTemplateContentFormatOption
=
Literal
[
"auto"
,
"string"
,
"openai"
]
...
...
@@ -620,6 +636,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
_T
=
TypeVar
(
"_T"
)
def
_extract_embeds
(
tensors
:
list
[
torch
.
Tensor
]):
if
len
(
tensors
)
==
0
:
return
tensors
if
len
(
tensors
)
==
1
:
tensors
[
0
].
_is_single_item
=
True
# type: ignore
return
tensors
[
0
]
# To keep backwards compatibility for single item input
first_shape
=
tensors
[
0
].
shape
if
all
(
t
.
shape
==
first_shape
for
t
in
tensors
):
return
torch
.
stack
(
tensors
)
return
tensors
def
_get_embeds_data
(
items_by_modality
:
dict
[
str
,
list
[
Any
]],
modality
:
str
):
embeds_key
=
f
"
{
modality
}
_embeds"
embeds
=
items_by_modality
[
embeds_key
]
if
len
(
embeds
)
==
0
:
return
embeds
if
is_list_of
(
embeds
,
torch
.
Tensor
):
return
_extract_embeds
(
embeds
)
if
is_list_of
(
embeds
,
dict
):
if
not
embeds
:
return
{}
first_keys
=
set
(
embeds
[
0
].
keys
())
if
any
(
set
(
item
.
keys
())
!=
first_keys
for
item
in
embeds
[
1
:]):
raise
ValueError
(
"All dictionaries in the list of embeddings must have the same keys."
)
return
{
k
:
_extract_embeds
([
item
[
k
]
for
item
in
embeds
])
for
k
in
first_keys
}
return
embeds
class
BaseMultiModalItemTracker
(
ABC
,
Generic
[
_T
]):
"""
Tracks multi-modal items in a given request and ensures that the number
...
...
@@ -688,11 +742,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def
all_mm_uuids
(
self
)
->
MultiModalUUIDDict
|
None
:
if
not
self
.
_items_by_modality
:
return
None
mm_uuids
=
{}
uuids_by_modality
=
dict
(
self
.
_uuids_by_modality
)
if
"image"
in
uuids_by_modality
and
"image_embeds"
in
uuids_by_modality
:
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
if
"audio"
in
uuids_by_modality
and
"audio_embeds"
in
uuids_by_modality
:
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
mm_uuids
=
{}
if
"image_embeds"
in
uuids_by_modality
:
mm_uuids
[
"image"
]
=
uuids_by_modality
[
"image_embeds"
]
if
"image"
in
uuids_by_modality
:
...
...
@@ -703,6 +760,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids
[
"audio"
]
=
uuids_by_modality
[
"audio"
]
# UUIDs of audios
if
"video"
in
uuids_by_modality
:
mm_uuids
[
"video"
]
=
uuids_by_modality
[
"video"
]
# UUIDs of videos
return
mm_uuids
@
abstractmethod
...
...
@@ -714,29 +772,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def
all_mm_data
(
self
)
->
MultiModalDataDict
|
None
:
if
not
self
.
_items_by_modality
:
return
None
mm_inputs
=
{}
items_by_modality
=
dict
(
self
.
_items_by_modality
)
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
if
"audio"
in
items_by_modality
and
"audio_embeds"
in
items_by_modality
:
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
mm_inputs
=
{}
if
"image_embeds"
in
items_by_modality
:
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
mm_inputs
[
"image"
]
=
(
image_embeds_lst
if
len
(
image_embeds_lst
)
!=
1
else
image_embeds_lst
[
0
]
)
mm_inputs
[
"image"
]
=
_get_embeds_data
(
items_by_modality
,
"image"
)
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
if
"audio_embeds"
in
items_by_modality
:
audio_embeds_lst
=
items_by_modality
[
"audio_embeds"
]
mm_inputs
[
"audio"
]
=
(
audio_embeds_lst
if
len
(
audio_embeds_lst
)
!=
1
else
audio_embeds_lst
[
0
]
)
mm_inputs
[
"audio"
]
=
_get_embeds_data
(
items_by_modality
,
"audio"
)
if
"audio"
in
items_by_modality
:
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
if
"video"
in
items_by_modality
:
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
return
mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
...
...
@@ -747,38 +801,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async
def
all_mm_data
(
self
)
->
MultiModalDataDict
|
None
:
if
not
self
.
_items_by_modality
:
return
None
mm_inputs
=
{}
items_by_modality
=
{}
for
modality
,
items
in
self
.
_items_by_modality
.
items
():
coros
=
[]
for
item
in
items
:
if
item
is
not
None
:
coros
.
append
(
item
)
else
:
coros
.
append
(
asyncio
.
sleep
(
0
))
items_by_modality
[
modality
]
=
await
asyncio
.
gather
(
*
coros
)
coros_by_modality
=
{
modality
:
[
item
or
asyncio
.
sleep
(
0
)
for
item
in
items
]
for
modality
,
items
in
self
.
_items_by_modality
.
items
()
}
items_by_modality
:
dict
[
str
,
list
[
object
|
None
]]
=
{
modality
:
await
asyncio
.
gather
(
*
coros
)
for
modality
,
coros
in
coros_by_modality
.
items
()
}
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
if
"audio"
in
items_by_modality
and
"audio_embeds"
in
items_by_modality
:
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
mm_inputs
=
{}
if
"image_embeds"
in
items_by_modality
:
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
mm_inputs
[
"image"
]
=
(
image_embeds_lst
if
len
(
image_embeds_lst
)
!=
1
else
image_embeds_lst
[
0
]
)
mm_inputs
[
"image"
]
=
_get_embeds_data
(
items_by_modality
,
"image"
)
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
if
"audio_embeds"
in
items_by_modality
:
audio_embeds_lst
=
items_by_modality
[
"audio_embeds"
]
mm_inputs
[
"audio"
]
=
(
audio_embeds_lst
if
len
(
audio_embeds_lst
)
!=
1
else
audio_embeds_lst
[
0
]
)
mm_inputs
[
"audio"
]
=
_get_embeds_data
(
items_by_modality
,
"audio"
)
if
"audio"
in
items_by_modality
:
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
if
"video"
in
items_by_modality
:
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
return
mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
...
...
@@ -1578,6 +1626,8 @@ def _parse_chat_message_content(
if
"name"
in
message
and
isinstance
(
message
[
"name"
],
str
):
result_msg
[
"name"
]
=
message
[
"name"
]
if
role
==
"developer"
:
result_msg
[
"tools"
]
=
message
.
get
(
"tools"
,
None
)
return
result
...
...
@@ -1588,12 +1638,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for
message
in
messages
:
if
(
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
and
isinstance
(
message
[
"tool_calls"
],
list
)
):
for
item
in
message
[
"tool_calls"
]:
if
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
:
tool_calls
=
message
.
get
(
"tool_calls"
)
if
not
isinstance
(
tool_calls
,
list
):
continue
if
len
(
tool_calls
)
==
0
:
# Drop empty tool_calls to keep templates on the normal assistant path.
message
.
pop
(
"tool_calls"
,
None
)
continue
for
item
in
tool_calls
:
# if arguments is None or empty string, set to {}
if
content
:
=
item
[
"function"
].
get
(
"arguments"
):
if
not
isinstance
(
content
,
(
dict
,
list
)):
...
...
@@ -1797,7 +1852,7 @@ def apply_hf_chat_template(
def
apply_mistral_chat_template
(
tokenizer
:
MistralTokenizer
,
tokenizer
:
"
MistralTokenizer
"
,
messages
:
list
[
ChatCompletionMessageParam
],
chat_template
:
str
|
None
,
tools
:
list
[
dict
[
str
,
Any
]]
|
None
,
...
...
vllm/entrypoints/cli/__init__.py
View file @
a3f8d5dd
...
...
@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.entrypoints.cli.benchmark.latency
import
BenchmarkLatencySubcommand
from
vllm.entrypoints.cli.benchmark.serve
import
BenchmarkServingSubcommand
from
vllm.entrypoints.cli.benchmark.startup
import
BenchmarkStartupSubcommand
from
vllm.entrypoints.cli.benchmark.sweep
import
BenchmarkSweepSubcommand
from
vllm.entrypoints.cli.benchmark.throughput
import
BenchmarkThroughputSubcommand
__all__
:
list
[
str
]
=
[
"BenchmarkLatencySubcommand"
,
"BenchmarkServingSubcommand"
,
"BenchmarkStartupSubcommand"
,
"BenchmarkSweepSubcommand"
,
"BenchmarkThroughputSubcommand"
,
]
vllm/entrypoints/cli/benchmark/startup.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
vllm.benchmarks.startup
import
add_cli_args
,
main
from
vllm.entrypoints.cli.benchmark.base
import
BenchmarkSubcommandBase
class
BenchmarkStartupSubcommand
(
BenchmarkSubcommandBase
):
"""The `startup` subcommand for `vllm bench`."""
name
=
"startup"
help
=
"Benchmark the startup time of vLLM models."
@
classmethod
def
add_cli_args
(
cls
,
parser
:
argparse
.
ArgumentParser
)
->
None
:
add_cli_args
(
parser
)
@
staticmethod
def
cmd
(
args
:
argparse
.
Namespace
)
->
None
:
main
(
args
)
vllm/entrypoints/context.py
View file @
a3f8d5dd
...
...
@@ -34,13 +34,13 @@ from vllm.entrypoints.openai.protocol import (
ResponseRawMessageAndToken
,
ResponsesRequest
,
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
ToolParser
from
vllm.entrypoints.responses_utils
import
construct_tool_dicts
from
vllm.entrypoints.tool
import
Tool
from
vllm.entrypoints.tool_server
import
ToolServer
from
vllm.outputs
import
RequestOutput
from
vllm.reasoning.abs_reasoning_parsers
import
ReasoningParser
from
vllm.tokenizers.protocol
import
TokenizerLike
from
vllm.tool_parsers.abstract_tool_parser
import
ToolParser
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
random_uuid
...
...
@@ -74,24 +74,24 @@ class TurnMetrics:
def
__init__
(
self
,
input_tokens
=
0
,
output_tokens
=
0
,
cached_input_tokens
=
0
,
tool_output_tokens
=
0
,
):
input_tokens
:
int
=
0
,
output_tokens
:
int
=
0
,
cached_input_tokens
:
int
=
0
,
tool_output_tokens
:
int
=
0
,
)
->
None
:
self
.
input_tokens
=
input_tokens
self
.
output_tokens
=
output_tokens
self
.
cached_input_tokens
=
cached_input_tokens
self
.
tool_output_tokens
=
tool_output_tokens
def
reset
(
self
):
def
reset
(
self
)
->
None
:
"""Reset counters for a new turn."""
self
.
input_tokens
=
0
self
.
output_tokens
=
0
self
.
cached_input_tokens
=
0
self
.
tool_output_tokens
=
0
def
copy
(
self
):
def
copy
(
self
)
->
"TurnMetrics"
:
"""Create a copy of this turn's token counts."""
return
TurnMetrics
(
self
.
input_tokens
,
...
...
vllm/entrypoints/llm.py
View file @
a3f8d5dd
...
...
@@ -9,7 +9,7 @@ import cloudpickle
import
torch.nn
as
nn
from
pydantic
import
ValidationError
from
tqdm.auto
import
tqdm
from
typing_extensions
import
TypeVar
,
deprecated
from
typing_extensions
import
TypeVar
from
vllm.beam_search
import
(
BeamSearchInstance
,
...
...
@@ -72,8 +72,8 @@ from vllm.platforms import current_platform
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
BeamSearchParams
,
RequestOutputKind
,
SamplingParams
from
vllm.tasks
import
PoolingTask
from
vllm.tokenizers
import
MistralTokenizer
,
TokenizerLike
from
vllm.tokenizers.
hf
import
get_cached_t
okenizer
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers.
mistral
import
MistralT
okenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils.collection_utils
import
as_iter
,
is_list_of
from
vllm.utils.counter
import
Counter
...
...
@@ -199,7 +199,7 @@ class LLM:
quantization
:
QuantizationMethods
|
None
=
None
,
revision
:
str
|
None
=
None
,
tokenizer_revision
:
str
|
None
=
None
,
seed
:
int
|
None
=
None
,
seed
:
int
=
0
,
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
float
=
4
,
cpu_offload_gb
:
float
=
0
,
...
...
@@ -367,16 +367,6 @@ class LLM:
def
get_tokenizer
(
self
)
->
TokenizerLike
:
return
self
.
llm_engine
.
get_tokenizer
()
@
deprecated
(
"`set_tokenizer` is deprecated and will be removed in v0.13."
)
def
set_tokenizer
(
self
,
tokenizer
:
TokenizerLike
)
->
None
:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if
tokenizer
.
__class__
.
__name__
.
startswith
(
"Cached"
):
self
.
llm_engine
.
tokenizer
=
tokenizer
else
:
self
.
llm_engine
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
def
reset_mm_cache
(
self
)
->
None
:
self
.
input_processor
.
clear_mm_cache
()
self
.
llm_engine
.
reset_mm_cache
()
...
...
vllm/entrypoints/openai/api_server.py
View file @
a3f8d5dd
...
...
@@ -72,7 +72,6 @@ from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription
,
OpenAIServingTranslation
,
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.pooling.classify.serving
import
ServingClassification
from
vllm.entrypoints.pooling.embed.serving
import
OpenAIServingEmbedding
...
...
@@ -95,6 +94,7 @@ from vllm.entrypoints.utils import (
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.tasks
import
POOLING_TASKS
from
vllm.tool_parsers
import
ToolParserManager
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.gc_utils
import
freeze_gc_heap
...
...
vllm/entrypoints/openai/cli_args.py
View file @
a3f8d5dd
...
...
@@ -27,8 +27,8 @@ from vllm.entrypoints.constants import (
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
,
)
from
vllm.entrypoints.openai.serving_models
import
LoRAModulePath
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.logger
import
init_logger
from
vllm.tool_parsers
import
ToolParserManager
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
logger
=
init_logger
(
__name__
)
...
...
@@ -176,7 +176,7 @@ class FrontendArgs:
enable_force_include_usage
:
bool
=
False
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint
:
bool
=
False
"""Enable the /
get_
tokenizer_info endpoint. May expose chat
"""Enable the
`
/tokenizer_info
`
endpoint. May expose chat
templates and other tokenizer configuration."""
enable_log_outputs
:
bool
=
False
"""If True, log model outputs (generations).
...
...
vllm/entrypoints/openai/parser/harmony_utils.py
View file @
a3f8d5dd
...
...
@@ -232,7 +232,177 @@ def parse_response_input(
return
msg
def
parse_chat_inputs_to_harmony_messages
(
chat_msgs
:
list
)
->
list
[
Message
]:
"""
Parse a list of messages from request.messages in the Chat Completion API to
Harmony messages.
"""
msgs
:
list
[
Message
]
=
[]
tool_id_names
:
dict
[
str
,
str
]
=
{}
# Collect tool id to name mappings for tool response recipient values
for
chat_msg
in
chat_msgs
:
for
tool_call
in
chat_msg
.
get
(
"tool_calls"
,
[]):
tool_id_names
[
tool_call
.
get
(
"id"
)]
=
tool_call
.
get
(
"function"
,
{}).
get
(
"name"
)
for
chat_msg
in
chat_msgs
:
msgs
.
extend
(
parse_chat_input_to_harmony_message
(
chat_msg
,
tool_id_names
))
msgs
=
auto_drop_analysis_messages
(
msgs
)
return
msgs
def
auto_drop_analysis_messages
(
msgs
:
list
[
Message
])
->
list
[
Message
]:
"""
Harmony models expect the analysis messages (representing raw chain of thought) to
be dropped after an assistant message to the final channel is produced from the
reasoning of those messages.
The openai-harmony library does this if the very last assistant message is to the
final channel, but it does not handle the case where we're in longer multi-turn
conversations and the client gave us reasoning content from previous turns of
the conversation with multiple assistant messages to the final channel in the
conversation.
So, we find the index of the last assistant message to the final channel and drop
all analysis messages that precede it, leaving only the analysis messages that
are relevant to the current part of the conversation.
"""
last_assistant_final_index
=
-
1
for
i
in
range
(
len
(
msgs
)
-
1
,
-
1
,
-
1
):
msg
=
msgs
[
i
]
if
msg
.
author
.
role
==
"assistant"
and
msg
.
channel
==
"final"
:
last_assistant_final_index
=
i
break
cleaned_msgs
:
list
[
Message
]
=
[]
for
i
,
msg
in
enumerate
(
msgs
):
if
i
<
last_assistant_final_index
and
msg
.
channel
==
"analysis"
:
continue
cleaned_msgs
.
append
(
msg
)
return
cleaned_msgs
def
flatten_chat_text_content
(
content
:
str
|
list
|
None
)
->
str
|
None
:
"""
Extract the text parts from a chat message content field and flatten them
into a single string.
"""
if
isinstance
(
content
,
list
):
return
""
.
join
(
item
.
get
(
"text"
,
""
)
for
item
in
content
if
isinstance
(
item
,
dict
)
and
item
.
get
(
"type"
)
==
"text"
)
return
content
def
parse_chat_input_to_harmony_message
(
chat_msg
,
tool_id_names
:
dict
[
str
,
str
]
|
None
=
None
)
->
list
[
Message
]:
"""
Parse a message from request.messages in the Chat Completion API to
Harmony messages.
"""
tool_id_names
=
tool_id_names
or
{}
if
not
isinstance
(
chat_msg
,
dict
):
# Handle Pydantic models
chat_msg
=
chat_msg
.
model_dump
(
exclude_none
=
True
)
role
=
chat_msg
.
get
(
"role"
)
msgs
:
list
[
Message
]
=
[]
# Assistant message with tool calls
tool_calls
=
chat_msg
.
get
(
"tool_calls"
,
[])
if
role
==
"assistant"
and
tool_calls
:
content
=
flatten_chat_text_content
(
chat_msg
.
get
(
"content"
))
if
content
:
commentary_msg
=
Message
.
from_role_and_content
(
Role
.
ASSISTANT
,
content
)
commentary_msg
=
commentary_msg
.
with_channel
(
"commentary"
)
msgs
.
append
(
commentary_msg
)
reasoning_content
=
chat_msg
.
get
(
"reasoning"
)
or
chat_msg
.
get
(
"reasoning_content"
)
if
reasoning_content
:
analysis_msg
=
Message
.
from_role_and_content
(
Role
.
ASSISTANT
,
reasoning_content
)
analysis_msg
=
analysis_msg
.
with_channel
(
"analysis"
)
msgs
.
append
(
analysis_msg
)
for
call
in
tool_calls
:
func
=
call
.
get
(
"function"
,
{})
name
=
func
.
get
(
"name"
,
""
)
arguments
=
func
.
get
(
"arguments"
,
""
)
or
""
msg
=
Message
.
from_role_and_content
(
Role
.
ASSISTANT
,
arguments
)
msg
=
msg
.
with_channel
(
"commentary"
)
msg
=
msg
.
with_recipient
(
f
"functions.
{
name
}
"
)
# Officially, this should be `<|constrain|>json` but there is not clear
# evidence that improves accuracy over `json` and some anecdotes to the
# contrary. Further testing of the different content_types is needed.
msg
=
msg
.
with_content_type
(
"json"
)
msgs
.
append
(
msg
)
return
msgs
# Tool role message (tool output)
if
role
==
"tool"
:
tool_call_id
=
chat_msg
.
get
(
"tool_call_id"
,
""
)
name
=
tool_id_names
.
get
(
tool_call_id
,
""
)
content
=
chat_msg
.
get
(
"content"
,
""
)
or
""
content
=
flatten_chat_text_content
(
content
)
msg
=
(
Message
.
from_author_and_content
(
Author
.
new
(
Role
.
TOOL
,
f
"functions.
{
name
}
"
),
content
)
.
with_channel
(
"commentary"
)
.
with_recipient
(
"assistant"
)
)
return
[
msg
]
# Non-tool reasoning content
reasoning_content
=
chat_msg
.
get
(
"reasoning"
)
or
chat_msg
.
get
(
"reasoning_content"
)
if
role
==
"assistant"
and
reasoning_content
:
analysis_msg
=
Message
.
from_role_and_content
(
Role
.
ASSISTANT
,
reasoning_content
)
analysis_msg
=
analysis_msg
.
with_channel
(
"analysis"
)
msgs
.
append
(
analysis_msg
)
# Default: user/assistant/system messages with content
content
=
chat_msg
.
get
(
"content"
)
or
""
if
content
is
None
:
content
=
""
if
isinstance
(
content
,
str
):
contents
=
[
TextContent
(
text
=
content
)]
else
:
# TODO: Support refusal.
contents
=
[
TextContent
(
text
=
c
.
get
(
"text"
,
""
))
for
c
in
content
]
# Only add assistant messages if they have content, as reasoning or tool calling
# assistant messages were already added above.
if
role
==
"assistant"
and
contents
and
contents
[
0
].
text
:
msg
=
Message
.
from_role_and_contents
(
role
,
contents
)
# Send non-tool assistant messages to the final channel
msg
=
msg
.
with_channel
(
"final"
)
msgs
.
append
(
msg
)
# For user/system/developer messages, add them directly even if no content.
elif
role
!=
"assistant"
:
msg
=
Message
.
from_role_and_contents
(
role
,
contents
)
msgs
.
append
(
msg
)
return
msgs
def
parse_input_to_harmony_message
(
chat_msg
)
->
list
[
Message
]:
"""
Parse a message from request.previous_input_messages in the Responsees API to
Harmony messages.
"""
if
not
isinstance
(
chat_msg
,
dict
):
# Handle Pydantic models
chat_msg
=
chat_msg
.
model_dump
(
exclude_none
=
True
)
...
...
@@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
if
role
==
"tool"
:
name
=
chat_msg
.
get
(
"name"
,
""
)
content
=
chat_msg
.
get
(
"content"
,
""
)
or
""
if
isinstance
(
content
,
list
):
# Handle array format for tool message content
# by concatenating all text parts.
content
=
""
.
join
(
item
.
get
(
"text"
,
""
)
for
item
in
content
if
isinstance
(
item
,
dict
)
and
item
.
get
(
"type"
)
==
"text"
)
content
=
flatten_chat_text_content
(
content
)
msg
=
Message
.
from_author_and_content
(
Author
.
new
(
Role
.
TOOL
,
f
"functions.
{
name
}
"
),
content
...
...
@@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
def
parse_chat_output
(
token_ids
:
Sequence
[
int
],
)
->
tuple
[
str
|
None
,
str
|
None
,
bool
]:
"""
Parse the output of a Harmony chat completion into reasoning and final content.
Note that when the `openai` tool parser is used, serving_chat only uses this
for the reasoning content and gets the final content from the tool call parser.
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
in use,this needs to return the final content without any tool calls parsed.
Empty reasoning or final content is returned as None instead of an empty string.
"""
parser
=
parse_output_into_messages
(
token_ids
)
output_msgs
=
parser
.
messages
is_tool_call
=
False
# TODO: update this when tool call is supported
if
len
(
output_msgs
)
==
0
:
# The generation has stopped during reasoning.
reasoning
=
parser
.
current_content
final_content
=
None
elif
len
(
output_msgs
)
==
1
:
# The generation has stopped during final message.
reasoning
=
output_msgs
[
0
].
content
[
0
].
text
final_content
=
parser
.
current_content
else
:
reasoning_msg
=
output_msgs
[:
-
1
]
final_msg
=
output_msgs
[
-
1
]
reasoning
=
"
\n
"
.
join
([
msg
.
content
[
0
].
text
for
msg
in
reasoning_msg
])
final_content
=
final_msg
.
content
[
0
].
text
# Get completed messages from the parser
reasoning_texts
=
[
msg
.
content
[
0
].
text
for
msg
in
output_msgs
if
msg
.
channel
==
"analysis"
]
final_texts
=
[
msg
.
content
[
0
].
text
for
msg
in
output_msgs
if
msg
.
channel
!=
"analysis"
]
# Extract partial messages from the parser
if
parser
.
current_channel
==
"analysis"
and
parser
.
current_content
:
reasoning_texts
.
append
(
parser
.
current_content
)
elif
parser
.
current_channel
!=
"analysis"
and
parser
.
current_content
:
final_texts
.
append
(
parser
.
current_content
)
# Flatten multiple messages into a single string
reasoning
:
str
|
None
=
"
\n
"
.
join
(
reasoning_texts
)
final_content
:
str
|
None
=
"
\n
"
.
join
(
final_texts
)
# Return None instead of empty string since existing callers check for None
reasoning
=
reasoning
or
None
final_content
=
final_content
or
None
return
reasoning
,
final_content
,
is_tool_call
Prev
1
…
8
9
10
11
12
13
14
15
16
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment