Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3f8d5dd
"vllm/vscode:/vscode.git/clone" did not exist on "ea228b4491342f6b7a283e1a414e1a75171a0241"
Commit
a3f8d5dd
authored
Dec 17, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.13.0rc2' into v0.13.0rc2-ori
parents
8d75f22e
f34eca5f
Changes
499
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
700 additions
and
238 deletions
+700
-238
vllm/distributed/ec_transfer/ec_connector/example_connector.py
...distributed/ec_transfer/ec_connector/example_connector.py
+1
-1
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+129
-1
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+15
-0
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+9
-1
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
...tributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+114
-3
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
...er/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+9
-2
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
+0
-3
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+6
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+56
-39
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+2
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+19
-99
vllm/entrypoints/anthropic/serving_messages.py
vllm/entrypoints/anthropic/serving_messages.py
+6
-6
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+92
-37
vllm/entrypoints/cli/__init__.py
vllm/entrypoints/cli/__init__.py
+2
-0
vllm/entrypoints/cli/benchmark/startup.py
vllm/entrypoints/cli/benchmark/startup.py
+21
-0
vllm/entrypoints/context.py
vllm/entrypoints/context.py
+8
-8
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+4
-14
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+1
-1
vllm/entrypoints/openai/cli_args.py
vllm/entrypoints/openai/cli_args.py
+2
-2
vllm/entrypoints/openai/parser/harmony_utils.py
vllm/entrypoints/openai/parser/harmony_utils.py
+204
-21
No files found.
vllm/distributed/ec_transfer/ec_connector/example_connector.py
View file @
a3f8d5dd
...
@@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
...
@@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
Update ECConnector state after encoder cache allocation.
Update ECConnector state after encoder cache allocation.
"""
"""
mm_hash
=
request
.
mm_features
[
index
].
identifier
mm_hash
=
request
.
mm_features
[
index
].
identifier
num_encoder_token
=
request
.
get_num_encoder_
token
s
(
index
)
num_encoder_token
=
request
.
get_num_encoder_
embed
s
(
index
)
# Insert mm_hash only if this block has not been recorded yet.
# Insert mm_hash only if this block has not been recorded yet.
self
.
_mm_datas_need_loads
[
mm_hash
]
=
num_encoder_token
self
.
_mm_datas_need_loads
[
mm_hash
]
=
num_encoder_token
...
...
vllm/distributed/kv_events.py
View file @
a3f8d5dd
...
@@ -5,7 +5,7 @@ import queue
...
@@ -5,7 +5,7 @@ import queue
import
threading
import
threading
import
time
import
time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
deque
from
collections
import
Counter
,
deque
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
dataclasses
import
asdict
from
dataclasses
import
asdict
from
itertools
import
count
from
itertools
import
count
...
@@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
...
@@ -54,11 +54,26 @@ class BlockStored(KVCacheEvent):
lora_id
:
int
|
None
lora_id
:
int
|
None
medium
:
str
|
None
medium
:
str
|
None
def
__hash__
(
self
)
->
int
:
return
hash
(
(
tuple
(
self
.
block_hashes
),
self
.
parent_block_hash
,
tuple
(
self
.
token_ids
),
self
.
block_size
,
self
.
lora_id
,
self
.
medium
,
)
)
class
BlockRemoved
(
KVCacheEvent
):
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
ExternalBlockHash
]
block_hashes
:
list
[
ExternalBlockHash
]
medium
:
str
|
None
medium
:
str
|
None
def
__hash__
(
self
)
->
int
:
return
hash
((
tuple
(
self
.
block_hashes
),
self
.
medium
))
class
AllBlocksCleared
(
KVCacheEvent
):
class
AllBlocksCleared
(
KVCacheEvent
):
pass
pass
...
@@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
...
@@ -68,6 +83,119 @@ class KVEventBatch(EventBatch):
events
:
list
[
BlockStored
|
BlockRemoved
|
AllBlocksCleared
]
events
:
list
[
BlockStored
|
BlockRemoved
|
AllBlocksCleared
]
class
KVEventAggregator
:
"""
Aggregates KV events across multiple workers.
Tracks how many times each event appears and returns only those
that were emitted by all workers.
"""
__slots__
=
(
"_event_counter"
,
"_num_workers"
)
def
__init__
(
self
,
num_workers
:
int
)
->
None
:
if
num_workers
<=
0
:
raise
ValueError
(
"num_workers must be greater than zero."
)
self
.
_event_counter
:
Counter
[
KVCacheEvent
]
=
Counter
()
self
.
_num_workers
:
int
=
num_workers
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
"""
Add events from a worker batch.
:param events: List of KVCacheEvent objects.
"""
if
not
isinstance
(
events
,
list
):
raise
TypeError
(
"events must be a list of KVCacheEvent."
)
self
.
_event_counter
.
update
(
events
)
def
get_common_events
(
self
)
->
list
[
KVCacheEvent
]:
"""
Return events that appeared in all workers.
:return: List of events present in all workers.
"""
return
[
event
for
event
,
count
in
self
.
_event_counter
.
items
()
if
count
==
self
.
_num_workers
]
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
"""
Return all events for all workers.
:return: List of events for all workers.
"""
return
list
(
self
.
_event_counter
.
elements
())
def
clear_events
(
self
)
->
None
:
"""
Clear all tracked events.
"""
self
.
_event_counter
.
clear
()
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
"""
Increment the number of workers contributing events.
:param count: Number to increment the workers by.
"""
if
count
<=
0
:
raise
ValueError
(
"count must be positive."
)
self
.
_num_workers
+=
count
def
reset_workers
(
self
)
->
None
:
"""
Reset the number of workers to 1.
"""
self
.
_num_workers
=
1
def
get_number_of_workers
(
self
)
->
int
:
"""
Return the number of workers.
:return: int number of workers.
"""
return
self
.
_num_workers
def
__repr__
(
self
)
->
str
:
return
(
f
"<KVEventAggregator workers=
{
self
.
_num_workers
}
, "
f
"events=
{
len
(
self
.
_event_counter
)
}
>"
)
class
KVConnectorKVEvents
(
ABC
):
"""
Abstract base class for KV events.
Acts as a container for KV events from the connector.
"""
@
abstractmethod
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
raise
NotImplementedError
@
abstractmethod
def
aggregate
(
self
)
->
"KVConnectorKVEvents"
:
raise
NotImplementedError
@
abstractmethod
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
raise
NotImplementedError
@
abstractmethod
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
raise
NotImplementedError
@
abstractmethod
def
get_number_of_workers
(
self
)
->
int
:
raise
NotImplementedError
@
abstractmethod
def
clear_events
(
self
)
->
None
:
raise
NotImplementedError
class
EventPublisher
(
ABC
):
class
EventPublisher
(
ABC
):
"""Lightweight publisher for EventBatch batches with data parallelism
"""Lightweight publisher for EventBatch batches with data parallelism
support.
support.
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
a3f8d5dd
...
@@ -78,6 +78,7 @@ class KVOutputAggregator:
...
@@ -78,6 +78,7 @@ class KVOutputAggregator:
finished_sending
=
set
[
str
]()
finished_sending
=
set
[
str
]()
finished_recving
=
set
[
str
]()
finished_recving
=
set
[
str
]()
aggregated_kv_connector_stats
=
None
aggregated_kv_connector_stats
=
None
combined_kv_cache_events
=
None
invalid_block_ids
=
set
[
int
]()
invalid_block_ids
=
set
[
int
]()
for
model_runner_output
in
outputs
:
for
model_runner_output
in
outputs
:
assert
model_runner_output
is
not
None
assert
model_runner_output
is
not
None
...
@@ -119,6 +120,19 @@ class KVOutputAggregator:
...
@@ -119,6 +120,19 @@ class KVOutputAggregator:
aggregated_kv_connector_stats
.
aggregate
(
kv_connector_stats
)
aggregated_kv_connector_stats
.
aggregate
(
kv_connector_stats
)
)
)
# Combine kv_cache_events from all workers.
if
combined_kv_cache_events
is
None
:
# Use the first worker's kv_cache events as start event list.
combined_kv_cache_events
=
kv_output
.
kv_cache_events
elif
kv_cache_events
:
=
kv_output
.
kv_cache_events
:
assert
isinstance
(
combined_kv_cache_events
,
type
(
kv_cache_events
),
)
worker_kv_cache_events
=
kv_cache_events
.
get_all_events
()
combined_kv_cache_events
.
add_events
(
worker_kv_cache_events
)
combined_kv_cache_events
.
increment_workers
(
1
)
invalid_block_ids
|=
kv_output
.
invalid_block_ids
invalid_block_ids
|=
kv_output
.
invalid_block_ids
# select output of the worker specified by output_rank
# select output of the worker specified by output_rank
...
@@ -129,6 +143,7 @@ class KVOutputAggregator:
...
@@ -129,6 +143,7 @@ class KVOutputAggregator:
finished_sending
=
finished_sending
or
None
,
finished_sending
=
finished_sending
or
None
,
finished_recving
=
finished_recving
or
None
,
finished_recving
=
finished_recving
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
kv_cache_events
=
combined_kv_cache_events
or
None
,
invalid_block_ids
=
invalid_block_ids
,
invalid_block_ids
=
invalid_block_ids
,
expected_finished_count
=
self
.
_expected_finished_count
,
expected_finished_count
=
self
.
_expected_finished_count
,
)
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
a3f8d5dd
...
@@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
...
@@ -49,7 +49,7 @@ from vllm.v1.outputs import KVConnectorOutput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_events
import
KVCacheEvent
,
KVConnectorKVEvents
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorPromMetrics
,
KVConnectorPromMetrics
,
KVConnectorStats
,
KVConnectorStats
,
...
@@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
...
@@ -379,6 +379,14 @@ class KVConnectorBase_V1(ABC):
"""
"""
return
None
return
None
def
get_kv_connector_kv_cache_events
(
self
)
->
Optional
[
"KVConnectorKVEvents"
]:
"""
Get the KV connector kv cache events collected during the last interval.
This function should be called by the model runner every time after the
model execution and before cleanup.
"""
return
None
def
get_handshake_metadata
(
self
)
->
KVConnectorHandshakeMetadata
|
None
:
def
get_handshake_metadata
(
self
)
->
KVConnectorHandshakeMetadata
|
None
:
"""
"""
Get the KVConnector handshake metadata for this connector.
Get the KVConnector handshake metadata for this connector.
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
import
torch
import
torch
from
lmcache.integration.vllm.vllm_v1_adapter
import
(
LMCacheConnectorV1Impl
as
LMCacheConnectorLatestImpl
,
)
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
(
BlockStored
,
KVCacheEvent
,
KVConnectorKVEvents
,
KVEventAggregator
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorMetadata
,
...
@@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
...
@@ -16,6 +20,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
)
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
...
@@ -26,6 +31,44 @@ if TYPE_CHECKING:
...
@@ -26,6 +31,44 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
LMCacheKVEvents
(
KVConnectorKVEvents
):
"""
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
"""
def
__init__
(
self
,
num_workers
:
int
)
->
None
:
self
.
_aggregator
=
KVEventAggregator
(
num_workers
)
def
add_events
(
self
,
events
:
list
[
KVCacheEvent
])
->
None
:
self
.
_aggregator
.
add_events
(
events
)
def
aggregate
(
self
)
->
"LMCacheKVEvents"
:
"""
Aggregate KV events and retain only common events.
"""
common_events
=
self
.
_aggregator
.
get_common_events
()
self
.
_aggregator
.
clear_events
()
self
.
_aggregator
.
add_events
(
common_events
)
self
.
_aggregator
.
reset_workers
()
return
self
def
increment_workers
(
self
,
count
:
int
=
1
)
->
None
:
self
.
_aggregator
.
increment_workers
(
count
)
def
get_all_events
(
self
)
->
list
[
KVCacheEvent
]:
return
self
.
_aggregator
.
get_all_events
()
def
get_number_of_workers
(
self
)
->
int
:
return
self
.
_aggregator
.
get_number_of_workers
()
def
clear_events
(
self
)
->
None
:
self
.
_aggregator
.
clear_events
()
self
.
_aggregator
.
reset_workers
()
def
__repr__
(
self
)
->
str
:
return
f
"<LMCacheKVEvents events=
{
self
.
get_all_events
()
}
>"
class
LMCacheConnectorV1
(
KVConnectorBase_V1
):
class
LMCacheConnectorV1
(
KVConnectorBase_V1
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
...
@@ -50,10 +93,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
cls
=
_adapter
.
LMCacheConnectorV1Impl
cls
=
_adapter
.
LMCacheConnectorV1Impl
else
:
else
:
logger
.
info
(
"Initializing latest dev LMCache connector"
)
logger
.
info
(
"Initializing latest dev LMCache connector"
)
# lazy import
from
lmcache.integration.vllm.vllm_v1_adapter
import
(
LMCacheConnectorV1Impl
as
LMCacheConnectorLatestImpl
,
)
cls
=
LMCacheConnectorLatestImpl
cls
=
LMCacheConnectorLatestImpl
self
.
_lmcache_engine
=
cls
(
vllm_config
,
role
,
self
)
self
.
_lmcache_engine
=
cls
(
vllm_config
,
role
,
self
)
self
.
_kv_cache_events
:
LMCacheKVEvents
|
None
=
None
# ==============================
# ==============================
# Worker-side methods
# Worker-side methods
# ==============================
# ==============================
...
@@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
...
@@ -151,6 +201,31 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
# Fallback for older versions that don't support this method
# Fallback for older versions that don't support this method
return
set
()
return
set
()
def
get_kv_connector_kv_cache_events
(
self
)
->
LMCacheKVEvents
|
None
:
"""
Get the KV connector kv cache events collected during the last interval.
"""
events
=
self
.
_lmcache_engine
.
get_kv_events
()
# type: ignore [attr-defined]
if
not
events
:
return
None
blocks
:
list
[
BlockStored
]
=
[
BlockStored
(
block_hashes
=
e
.
block_hashes
,
parent_block_hash
=
e
.
parent_block_hash
,
token_ids
=
e
.
token_ids
,
lora_id
=
e
.
lora_id
,
block_size
=
e
.
block_size
,
medium
=
e
.
medium
,
)
for
e
in
events
]
lmcache_kv_events
=
LMCacheKVEvents
(
num_workers
=
1
)
lmcache_kv_events
.
add_events
(
blocks
)
return
lmcache_kv_events
# ==============================
# ==============================
# Scheduler-side methods
# Scheduler-side methods
# ==============================
# ==============================
...
@@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
...
@@ -198,6 +273,28 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
"""
"""
return
self
.
_lmcache_engine
.
build_connector_meta
(
scheduler_output
)
return
self
.
_lmcache_engine
.
build_connector_meta
(
scheduler_output
)
def
update_connector_output
(
self
,
connector_output
:
KVConnectorOutput
):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
# Get the KV events
kv_cache_events
=
connector_output
.
kv_cache_events
if
not
kv_cache_events
or
not
isinstance
(
kv_cache_events
,
LMCacheKVEvents
):
return
if
self
.
_kv_cache_events
is
None
:
self
.
_kv_cache_events
=
kv_cache_events
else
:
self
.
_kv_cache_events
.
add_events
(
kv_cache_events
.
get_all_events
())
self
.
_kv_cache_events
.
increment_workers
(
kv_cache_events
.
get_number_of_workers
()
)
return
def
request_finished
(
def
request_finished
(
self
,
self
,
request
:
"Request"
,
request
:
"Request"
,
...
@@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
...
@@ -214,3 +311,17 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
returned by the engine.
returned by the engine.
"""
"""
return
self
.
_lmcache_engine
.
request_finished
(
request
,
block_ids
)
return
self
.
_lmcache_engine
.
request_finished
(
request
,
block_ids
)
def
take_events
(
self
)
->
Iterable
[
"KVCacheEvent"
]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
if
self
.
_kv_cache_events
is
not
None
:
self
.
_kv_cache_events
.
aggregate
()
kv_cache_events
=
self
.
_kv_cache_events
.
get_all_events
()
yield
from
kv_cache_events
self
.
_kv_cache_events
.
clear_events
()
self
.
_kv_cache_events
=
None
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
View file @
a3f8d5dd
...
@@ -27,7 +27,14 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
...
@@ -27,7 +27,14 @@ from lmcache.v1.lookup_client.lmcache_async_lookup_client import (
LMCacheAsyncLookupServer
,
LMCacheAsyncLookupServer
,
)
)
from
lmcache.v1.offload_server.zmq_server
import
ZMQOffloadServer
from
lmcache.v1.offload_server.zmq_server
import
ZMQOffloadServer
from
lmcache.v1.plugin.plugin_launcher
import
PluginLauncher
try
:
from
lmcache.v1.plugin.runtime_plugin_launcher
import
RuntimePluginLauncher
except
ImportError
:
# Backwards compatibility for lmcache <= 0.3.10-post1
from
lmcache.v1.plugin.plugin_launcher
import
(
PluginLauncher
as
RuntimePluginLauncher
,
)
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -683,7 +690,7 @@ class LMCacheConnectorV1Impl:
...
@@ -683,7 +690,7 @@ class LMCacheConnectorV1Impl:
self
.
api_server
=
InternalAPIServer
(
self
)
self
.
api_server
=
InternalAPIServer
(
self
)
self
.
api_server
.
start
()
self
.
api_server
.
start
()
# Launch plugins
# Launch plugins
self
.
plugin_launcher
=
PluginLauncher
(
self
.
plugin_launcher
=
Runtime
PluginLauncher
(
self
.
config
,
self
.
config
,
role
,
role
,
self
.
worker_count
,
self
.
worker_count
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
View file @
a3f8d5dd
...
@@ -7,7 +7,6 @@ from prometheus_client import Counter, Gauge, Histogram
...
@@ -7,7 +7,6 @@ from prometheus_client import Counter, Gauge, Histogram
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.config
import
KVTransferConfig
,
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.distributed.kv_transfer.kv_transfer_state
import
has_kv_transfer_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
PromMetric
:
TypeAlias
=
Gauge
|
Counter
|
Histogram
PromMetric
:
TypeAlias
=
Gauge
|
Counter
|
Histogram
...
@@ -53,8 +52,6 @@ class KVConnectorStats:
...
@@ -53,8 +52,6 @@ class KVConnectorStats:
class
KVConnectorLogging
:
class
KVConnectorLogging
:
def
__init__
(
self
,
kv_transfer_config
:
KVTransferConfig
|
None
):
def
__init__
(
self
,
kv_transfer_config
:
KVTransferConfig
|
None
):
# This should be called on frontend process.
assert
not
has_kv_transfer_group
()
# Instantiate the connector's stats class.
# Instantiate the connector's stats class.
if
kv_transfer_config
and
kv_transfer_config
.
kv_connector
:
if
kv_transfer_config
and
kv_transfer_config
.
kv_connector
:
self
.
connector_cls
=
KVConnectorFactory
.
get_connector_class
(
self
.
connector_cls
=
KVConnectorFactory
.
get_connector_class
(
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
a3f8d5dd
...
@@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -259,6 +259,12 @@ class MultiConnector(KVConnectorBase_V1):
agg_block_ids
|=
c
.
get_block_ids_with_load_errors
()
agg_block_ids
|=
c
.
get_block_ids_with_load_errors
()
return
agg_block_ids
return
agg_block_ids
# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events' method
# for the MultiConnector. It should be able to get events from multiple
# connectors, handling the case where only a subset of the requested connectors
# implements the 'get_kv_connector_kv_cache_events'
# Follow on PR from https://github.com/vllm-project/vllm/pull/28309#pullrequestreview-3566351082
# ==============================
# ==============================
# Scheduler-side methods
# Scheduler-side methods
# ==============================
# ==============================
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
a3f8d5dd
...
@@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash(
...
@@ -202,17 +202,22 @@ def compute_nixl_compatibility_hash(
return
compat_hash
return
compat_hash
@
dataclass
class
RemoteMeta
:
block_ids
:
list
[
int
]
host
:
str
port
:
int
engine_id
:
str
request_id
:
str
@
dataclass
@
dataclass
class
ReqMeta
:
class
ReqMeta
:
local_block_ids
:
list
[
int
]
local_block_ids
:
list
[
int
]
# To be used when logical block size does not match the kernel block size
# To be used when logical block size does not match the kernel block size
local_physical_block_ids
:
list
[
int
]
local_physical_block_ids
:
list
[
int
]
remote_block_ids
:
list
[
int
]
remote_host
:
str
remote_port
:
int
remote_engine_id
:
str
remote_request_id
:
str
tp_size
:
int
tp_size
:
int
remote
:
RemoteMeta
|
None
=
None
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
class
NixlConnectorMetadata
(
KVConnectorMetadata
):
...
@@ -223,31 +228,43 @@ class NixlConnectorMetadata(KVConnectorMetadata):
...
@@ -223,31 +228,43 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self
.
reqs_in_batch
:
set
[
ReqId
]
=
set
()
self
.
reqs_in_batch
:
set
[
ReqId
]
=
set
()
self
.
reqs_not_processed
:
set
[
ReqId
]
=
set
()
self
.
reqs_not_processed
:
set
[
ReqId
]
=
set
()
def
add_new_req
(
def
_
add_new_req
(
self
,
self
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
],
local_block_ids
:
list
[
int
],
kv_transfer_params
:
dict
[
str
,
Any
],
kv_transfer_params
:
dict
[
str
,
Any
],
load_remote_cache
:
bool
=
True
,
)
->
ReqMeta
:
save_to_host
:
bool
=
False
,
return
ReqMeta
(
):
# save and load are mutually exclusive
assert
load_remote_cache
^
save_to_host
_req
=
ReqMeta
(
local_block_ids
=
local_block_ids
,
local_block_ids
=
local_block_ids
,
local_physical_block_ids
=
local_block_ids
,
local_physical_block_ids
=
local_block_ids
,
remote_block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
remote_engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
remote_request_id
=
kv_transfer_params
[
"remote_request_id"
],
remote_host
=
kv_transfer_params
[
"remote_host"
],
remote_port
=
kv_transfer_params
[
"remote_port"
],
# P workers don't need to receive tp_size from proxy here.
# P workers don't need to receive tp_size from proxy here.
tp_size
=
kv_transfer_params
.
get
(
"tp_size"
,
1
),
tp_size
=
kv_transfer_params
.
get
(
"tp_size"
,
1
),
)
)
if
save_to_host
:
self
.
reqs_to_save
[
request_id
]
=
_req
def
add_new_req_to_save
(
if
load_remote_cache
:
self
,
self
.
reqs_to_recv
[
request_id
]
=
_req
request_id
:
ReqId
,
local_block_ids
:
list
[
int
],
kv_transfer_params
:
dict
[
str
,
Any
],
):
self
.
reqs_to_save
[
request_id
]
=
self
.
_add_new_req
(
local_block_ids
,
kv_transfer_params
)
def
add_new_req_to_recv
(
self
,
request_id
:
ReqId
,
local_block_ids
:
list
[
int
],
kv_transfer_params
:
dict
[
str
,
Any
],
):
req
=
self
.
_add_new_req
(
local_block_ids
,
kv_transfer_params
)
req
.
remote
=
RemoteMeta
(
block_ids
=
kv_transfer_params
[
"remote_block_ids"
],
engine_id
=
kv_transfer_params
[
"remote_engine_id"
],
request_id
=
kv_transfer_params
[
"remote_request_id"
],
host
=
kv_transfer_params
[
"remote_host"
],
port
=
kv_transfer_params
[
"remote_port"
],
)
self
.
reqs_to_recv
[
request_id
]
=
req
class
NixlConnector
(
KVConnectorBase_V1
):
class
NixlConnector
(
KVConnectorBase_V1
):
...
@@ -666,22 +683,18 @@ class NixlConnectorScheduler:
...
@@ -666,22 +683,18 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta.
# Loop through scheduled reqs and convert to ReqMeta.
for
req_id
,
(
req
,
block_ids
)
in
self
.
_reqs_need_recv
.
items
():
for
req_id
,
(
req
,
block_ids
)
in
self
.
_reqs_need_recv
.
items
():
assert
req
.
kv_transfer_params
is
not
None
assert
req
.
kv_transfer_params
is
not
None
meta
.
add_new_req
(
meta
.
add_new_req
_to_recv
(
request_id
=
req_id
,
request_id
=
req_id
,
local_block_ids
=
block_ids
,
local_block_ids
=
block_ids
,
kv_transfer_params
=
req
.
kv_transfer_params
,
kv_transfer_params
=
req
.
kv_transfer_params
,
load_remote_cache
=
True
,
save_to_host
=
False
,
)
)
for
req_id
,
(
req
,
block_ids
)
in
self
.
_reqs_need_save
.
items
():
for
req_id
,
(
req
,
block_ids
)
in
self
.
_reqs_need_save
.
items
():
assert
req
.
kv_transfer_params
is
not
None
assert
req
.
kv_transfer_params
is
not
None
meta
.
add_new_req
(
meta
.
add_new_req
_to_save
(
request_id
=
req_id
,
request_id
=
req_id
,
local_block_ids
=
block_ids
,
local_block_ids
=
block_ids
,
kv_transfer_params
=
req
.
kv_transfer_params
,
kv_transfer_params
=
req
.
kv_transfer_params
,
load_remote_cache
=
False
,
save_to_host
=
True
,
)
)
meta
.
reqs_to_send
=
self
.
_reqs_need_send
meta
.
reqs_to_send
=
self
.
_reqs_need_send
...
@@ -1124,10 +1137,11 @@ class NixlConnectorWorker:
...
@@ -1124,10 +1137,11 @@ class NixlConnectorWorker:
# Do NIXL handshake in background and add to _ready_requests when done.
# Do NIXL handshake in background and add to _ready_requests when done.
fut
=
self
.
_handshake_futures
.
get
(
remote_engine_id
)
fut
=
self
.
_handshake_futures
.
get
(
remote_engine_id
)
if
fut
is
None
:
if
fut
is
None
:
assert
meta
.
remote
is
not
None
fut
=
self
.
_handshake_initiation_executor
.
submit
(
fut
=
self
.
_handshake_initiation_executor
.
submit
(
self
.
_nixl_handshake
,
self
.
_nixl_handshake
,
meta
.
remote
_
host
,
meta
.
remote
.
host
,
meta
.
remote
_
port
,
meta
.
remote
.
port
,
meta
.
tp_size
,
meta
.
tp_size
,
remote_engine_id
,
remote_engine_id
,
)
)
...
@@ -1774,6 +1788,7 @@ class NixlConnectorWorker:
...
@@ -1774,6 +1788,7 @@ class NixlConnectorWorker:
# clean up metadata for completed requests
# clean up metadata for completed requests
meta
=
self
.
_recving_metadata
.
pop
(
req_id
,
None
)
meta
=
self
.
_recving_metadata
.
pop
(
req_id
,
None
)
assert
meta
is
not
None
,
f
"
{
req_id
}
not found in recving_metadata list"
assert
meta
is
not
None
,
f
"
{
req_id
}
not found in recving_metadata list"
assert
meta
.
remote
is
not
None
if
self
.
use_host_buffer
:
if
self
.
use_host_buffer
:
self
.
sync_recved_kv_to_device
(
req_id
,
meta
)
self
.
sync_recved_kv_to_device
(
req_id
,
meta
)
if
self
.
enable_permute_local_kv
:
if
self
.
enable_permute_local_kv
:
...
@@ -1781,7 +1796,7 @@ class NixlConnectorWorker:
...
@@ -1781,7 +1796,7 @@ class NixlConnectorWorker:
# post processing for heteroblocksize
# post processing for heteroblocksize
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
meta
.
remote
_
engine_id
meta
.
remote
.
engine_id
)
)
if
(
if
(
not
self
.
use_mla
not
self
.
use_mla
...
@@ -1916,17 +1931,18 @@ class NixlConnectorWorker:
...
@@ -1916,17 +1931,18 @@ class NixlConnectorWorker:
meta
.
local_physical_block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
local_physical_block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
local_block_ids
meta
.
local_block_ids
)
)
meta
.
remote_block_ids
=
self
.
_logical_to_kernel_block_ids
(
assert
meta
.
remote
is
not
None
meta
.
remote_block_ids
meta
.
remote
.
block_ids
=
self
.
_logical_to_kernel_block_ids
(
meta
.
remote
.
block_ids
)
)
remote_engine_id
=
meta
.
remote
_
engine_id
remote_engine_id
=
meta
.
remote
.
engine_id
logger
.
debug
(
logger
.
debug
(
"start_load_kv for request %s from remote engine %s. "
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. "
,
"Num local_block_ids: %s. Num remote_block_ids: %s. "
,
req_id
,
req_id
,
remote_engine_id
,
remote_engine_id
,
len
(
meta
.
local_physical_block_ids
),
len
(
meta
.
local_physical_block_ids
),
len
(
meta
.
remote
_
block_ids
),
len
(
meta
.
remote
.
block_ids
),
)
)
# always store metadata for failure recovery
# always store metadata for failure recovery
self
.
_recving_metadata
[
req_id
]
=
meta
self
.
_recving_metadata
[
req_id
]
=
meta
...
@@ -1965,17 +1981,18 @@ class NixlConnectorWorker:
...
@@ -1965,17 +1981,18 @@ class NixlConnectorWorker:
self
.
_reqs_to_send
[
req_id
]
=
expiration_time
self
.
_reqs_to_send
[
req_id
]
=
expiration_time
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
assert
meta
.
remote
is
not
None
logger
.
debug
(
logger
.
debug
(
"Remote agent %s available, calling _read_blocks for req %s"
,
"Remote agent %s available, calling _read_blocks for req %s"
,
meta
.
remote
_
engine_id
,
meta
.
remote
.
engine_id
,
req_id
,
req_id
,
)
)
self
.
_read_blocks
(
self
.
_read_blocks
(
request_id
=
req_id
,
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
_
engine_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
_
request_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote
_
block_ids
,
remote_block_ids
=
meta
.
remote
.
block_ids
,
)
)
def
_read_blocks
(
def
_read_blocks
(
...
...
vllm/distributed/parallel_state.py
View file @
a3f8d5dd
...
@@ -1586,6 +1586,8 @@ def destroy_distributed_environment():
...
@@ -1586,6 +1586,8 @@ def destroy_distributed_environment():
def
cleanup_dist_env_and_memory
(
shutdown_ray
:
bool
=
False
):
def
cleanup_dist_env_and_memory
(
shutdown_ray
:
bool
=
False
):
# Reset environment variable cache
envs
.
disable_envs_cache
()
# Ensure all objects are not frozen before cleanup
# Ensure all objects are not frozen before cleanup
gc
.
unfreeze
()
gc
.
unfreeze
()
...
...
vllm/engine/arg_utils.py
View file @
a3f8d5dd
...
@@ -71,7 +71,6 @@ from vllm.config.model import (
...
@@ -71,7 +71,6 @@ from vllm.config.model import (
LogprobsMode
,
LogprobsMode
,
ModelDType
,
ModelDType
,
RunnerOption
,
RunnerOption
,
TaskOption
,
TokenizerMode
,
TokenizerMode
,
)
)
from
vllm.config.multimodal
import
MMCacheType
,
MMEncoderTPMode
from
vllm.config.multimodal
import
MMCacheType
,
MMEncoderTPMode
...
@@ -360,7 +359,6 @@ class EngineArgs:
...
@@ -360,7 +359,6 @@ class EngineArgs:
hf_config_path
:
str
|
None
=
ModelConfig
.
hf_config_path
hf_config_path
:
str
|
None
=
ModelConfig
.
hf_config_path
runner
:
RunnerOption
=
ModelConfig
.
runner
runner
:
RunnerOption
=
ModelConfig
.
runner
convert
:
ConvertOption
=
ModelConfig
.
convert
convert
:
ConvertOption
=
ModelConfig
.
convert
task
:
TaskOption
|
None
=
ModelConfig
.
task
skip_tokenizer_init
:
bool
=
ModelConfig
.
skip_tokenizer_init
skip_tokenizer_init
:
bool
=
ModelConfig
.
skip_tokenizer_init
enable_prompt_embeds
:
bool
=
ModelConfig
.
enable_prompt_embeds
enable_prompt_embeds
:
bool
=
ModelConfig
.
enable_prompt_embeds
tokenizer_mode
:
TokenizerMode
|
str
=
ModelConfig
.
tokenizer_mode
tokenizer_mode
:
TokenizerMode
|
str
=
ModelConfig
.
tokenizer_mode
...
@@ -373,9 +371,8 @@ class EngineArgs:
...
@@ -373,9 +371,8 @@ class EngineArgs:
config_format
:
str
=
ModelConfig
.
config_format
config_format
:
str
=
ModelConfig
.
config_format
dtype
:
ModelDType
=
ModelConfig
.
dtype
dtype
:
ModelDType
=
ModelConfig
.
dtype
kv_cache_dtype
:
CacheDType
=
CacheConfig
.
cache_dtype
kv_cache_dtype
:
CacheDType
=
CacheConfig
.
cache_dtype
seed
:
int
|
None
=
0
seed
:
int
=
ModelConfig
.
seed
max_model_len
:
int
|
None
=
ModelConfig
.
max_model_len
max_model_len
:
int
|
None
=
ModelConfig
.
max_model_len
cuda_graph_sizes
:
list
[
int
]
|
None
=
CompilationConfig
.
cudagraph_capture_sizes
cudagraph_capture_sizes
:
list
[
int
]
|
None
=
(
cudagraph_capture_sizes
:
list
[
int
]
|
None
=
(
CompilationConfig
.
cudagraph_capture_sizes
CompilationConfig
.
cudagraph_capture_sizes
)
)
...
@@ -463,7 +460,6 @@ class EngineArgs:
...
@@ -463,7 +460,6 @@ class EngineArgs:
MultiModalConfig
,
"media_io_kwargs"
MultiModalConfig
,
"media_io_kwargs"
)
)
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
MultiModalConfig
.
mm_processor_kwargs
mm_processor_kwargs
:
dict
[
str
,
Any
]
|
None
=
MultiModalConfig
.
mm_processor_kwargs
disable_mm_preprocessor_cache
:
bool
=
False
# DEPRECATED
mm_processor_cache_gb
:
float
=
MultiModalConfig
.
mm_processor_cache_gb
mm_processor_cache_gb
:
float
=
MultiModalConfig
.
mm_processor_cache_gb
mm_processor_cache_type
:
MMCacheType
|
None
=
(
mm_processor_cache_type
:
MMCacheType
|
None
=
(
MultiModalConfig
.
mm_processor_cache_type
MultiModalConfig
.
mm_processor_cache_type
...
@@ -495,7 +491,7 @@ class EngineArgs:
...
@@ -495,7 +491,7 @@ class EngineArgs:
enable_chunked_prefill
:
bool
|
None
=
None
enable_chunked_prefill
:
bool
|
None
=
None
disable_chunked_mm_input
:
bool
=
SchedulerConfig
.
disable_chunked_mm_input
disable_chunked_mm_input
:
bool
=
SchedulerConfig
.
disable_chunked_mm_input
disable_hybrid_kv_cache_manager
:
bool
=
(
disable_hybrid_kv_cache_manager
:
bool
|
None
=
(
SchedulerConfig
.
disable_hybrid_kv_cache_manager
SchedulerConfig
.
disable_hybrid_kv_cache_manager
)
)
...
@@ -559,9 +555,6 @@ class EngineArgs:
...
@@ -559,9 +555,6 @@ class EngineArgs:
use_tqdm_on_load
:
bool
=
LoadConfig
.
use_tqdm_on_load
use_tqdm_on_load
:
bool
=
LoadConfig
.
use_tqdm_on_load
pt_load_map_location
:
str
=
LoadConfig
.
pt_load_map_location
pt_load_map_location
:
str
=
LoadConfig
.
pt_load_map_location
# DEPRECATED
enable_multimodal_encoder_data_parallel
:
bool
=
False
logits_processors
:
list
[
str
|
type
[
LogitsProcessor
]]
|
None
=
(
logits_processors
:
list
[
str
|
type
[
LogitsProcessor
]]
|
None
=
(
ModelConfig
.
logits_processors
ModelConfig
.
logits_processors
)
)
...
@@ -629,7 +622,6 @@ class EngineArgs:
...
@@ -629,7 +622,6 @@ class EngineArgs:
model_group
.
add_argument
(
"--model"
,
**
model_kwargs
[
"model"
])
model_group
.
add_argument
(
"--model"
,
**
model_kwargs
[
"model"
])
model_group
.
add_argument
(
"--runner"
,
**
model_kwargs
[
"runner"
])
model_group
.
add_argument
(
"--runner"
,
**
model_kwargs
[
"runner"
])
model_group
.
add_argument
(
"--convert"
,
**
model_kwargs
[
"convert"
])
model_group
.
add_argument
(
"--convert"
,
**
model_kwargs
[
"convert"
])
model_group
.
add_argument
(
"--task"
,
**
model_kwargs
[
"task"
],
deprecated
=
True
)
model_group
.
add_argument
(
"--tokenizer"
,
**
model_kwargs
[
"tokenizer"
])
model_group
.
add_argument
(
"--tokenizer"
,
**
model_kwargs
[
"tokenizer"
])
model_group
.
add_argument
(
"--tokenizer-mode"
,
**
model_kwargs
[
"tokenizer_mode"
])
model_group
.
add_argument
(
"--tokenizer-mode"
,
**
model_kwargs
[
"tokenizer_mode"
])
model_group
.
add_argument
(
model_group
.
add_argument
(
...
@@ -883,11 +875,6 @@ class EngineArgs:
...
@@ -883,11 +875,6 @@ class EngineArgs:
parallel_group
.
add_argument
(
parallel_group
.
add_argument
(
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
)
)
parallel_group
.
add_argument
(
"--enable-multimodal-encoder-data-parallel"
,
action
=
"store_true"
,
deprecated
=
True
,
)
# KV cache arguments
# KV cache arguments
cache_kwargs
=
get_kwargs
(
CacheConfig
)
cache_kwargs
=
get_kwargs
(
CacheConfig
)
...
@@ -961,9 +948,6 @@ class EngineArgs:
...
@@ -961,9 +948,6 @@ class EngineArgs:
multimodal_group
.
add_argument
(
multimodal_group
.
add_argument
(
"--mm-processor-cache-gb"
,
**
multimodal_kwargs
[
"mm_processor_cache_gb"
]
"--mm-processor-cache-gb"
,
**
multimodal_kwargs
[
"mm_processor_cache_gb"
]
)
)
multimodal_group
.
add_argument
(
"--disable-mm-preprocessor-cache"
,
action
=
"store_true"
,
deprecated
=
True
)
multimodal_group
.
add_argument
(
multimodal_group
.
add_argument
(
"--mm-processor-cache-type"
,
**
multimodal_kwargs
[
"mm_processor_cache_type"
]
"--mm-processor-cache-type"
,
**
multimodal_kwargs
[
"mm_processor_cache_type"
]
)
)
...
@@ -1121,15 +1105,6 @@ class EngineArgs:
...
@@ -1121,15 +1105,6 @@ class EngineArgs:
compilation_group
.
add_argument
(
compilation_group
.
add_argument
(
"--cudagraph-capture-sizes"
,
**
compilation_kwargs
[
"cudagraph_capture_sizes"
]
"--cudagraph-capture-sizes"
,
**
compilation_kwargs
[
"cudagraph_capture_sizes"
]
)
)
compilation_kwargs
[
"cudagraph_capture_sizes"
][
"help"
]
=
(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or v1.0.0,"
" whichever is soonest. Please use --cudagraph-capture-sizes instead."
)
compilation_group
.
add_argument
(
"--cuda-graph-sizes"
,
**
compilation_kwargs
[
"cudagraph_capture_sizes"
],
deprecated
=
True
,
)
compilation_group
.
add_argument
(
compilation_group
.
add_argument
(
"--max-cudagraph-capture-size"
,
"--max-cudagraph-capture-size"
,
**
compilation_kwargs
[
"max_cudagraph_capture_size"
],
**
compilation_kwargs
[
"max_cudagraph_capture_size"
],
...
@@ -1202,62 +1177,20 @@ class EngineArgs:
...
@@ -1202,62 +1177,20 @@ class EngineArgs:
if
is_gguf
(
self
.
model
):
if
is_gguf
(
self
.
model
):
self
.
quantization
=
self
.
load_format
=
"gguf"
self
.
quantization
=
self
.
load_format
=
"gguf"
# NOTE(woosuk): In V1, we use separate processes for workers (unless
if
not
envs
.
VLLM_ENABLE_V1_MULTIPROCESSING
:
# VLLM_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here
logger
.
warning
(
# doesn't affect the user process.
"The global random seed is set to %d. Since "
if
self
.
seed
is
None
:
"VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
logger
.
warning_once
(
"affect the random state of the Python process that "
"`seed=None` is equivalent to `seed=0` in V1 Engine. "
"launched vLLM."
,
"You will no longer be allowed to pass `None` in v0.13."
,
self
.
seed
,
scope
=
"local"
,
)
self
.
seed
=
0
if
not
envs
.
VLLM_ENABLE_V1_MULTIPROCESSING
:
logger
.
warning
(
"The global random seed is set to %d. Since "
"VLLM_ENABLE_V1_MULTIPROCESSING is set to False, this may "
"affect the random state of the Python process that "
"launched vLLM."
,
self
.
seed
,
)
if
self
.
disable_mm_preprocessor_cache
:
logger
.
warning_once
(
"`--disable-mm-preprocessor-cache` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb 0` instead."
,
scope
=
"local"
,
)
self
.
mm_processor_cache_gb
=
0
elif
envs
.
VLLM_MM_INPUT_CACHE_GIB
!=
4
:
logger
.
warning_once
(
"VLLM_MM_INPUT_CACHE_GIB` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-processor-cache-gb %d` instead."
,
envs
.
VLLM_MM_INPUT_CACHE_GIB
,
scope
=
"local"
,
)
self
.
mm_processor_cache_gb
=
envs
.
VLLM_MM_INPUT_CACHE_GIB
if
self
.
enable_multimodal_encoder_data_parallel
:
logger
.
warning_once
(
"--enable-multimodal-encoder-data-parallel` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-encoder-tp-mode data` instead."
,
scope
=
"local"
,
)
)
self
.
mm_encoder_tp_mode
=
"data"
return
ModelConfig
(
return
ModelConfig
(
model
=
self
.
model
,
model
=
self
.
model
,
hf_config_path
=
self
.
hf_config_path
,
hf_config_path
=
self
.
hf_config_path
,
runner
=
self
.
runner
,
runner
=
self
.
runner
,
convert
=
self
.
convert
,
convert
=
self
.
convert
,
task
=
self
.
task
,
tokenizer
=
self
.
tokenizer
,
tokenizer
=
self
.
tokenizer
,
tokenizer_mode
=
self
.
tokenizer_mode
,
tokenizer_mode
=
self
.
tokenizer_mode
,
trust_remote_code
=
self
.
trust_remote_code
,
trust_remote_code
=
self
.
trust_remote_code
,
...
@@ -1716,7 +1649,13 @@ class EngineArgs:
...
@@ -1716,7 +1649,13 @@ class EngineArgs:
"attention_backend and attention_config.backend "
"attention_backend and attention_config.backend "
"are mutually exclusive"
"are mutually exclusive"
)
)
attention_config
.
backend
=
self
.
attention_backend
# Convert string to enum if needed (CLI parsing returns a string)
if
isinstance
(
self
.
attention_backend
,
str
):
attention_config
.
backend
=
AttentionBackendEnum
[
self
.
attention_backend
.
upper
()
]
else
:
attention_config
.
backend
=
self
.
attention_backend
load_config
=
self
.
create_load_config
()
load_config
=
self
.
create_load_config
()
...
@@ -1741,18 +1680,6 @@ class EngineArgs:
...
@@ -1741,18 +1680,6 @@ class EngineArgs:
# Compilation config overrides
# Compilation config overrides
compilation_config
=
copy
.
deepcopy
(
self
.
compilation_config
)
compilation_config
=
copy
.
deepcopy
(
self
.
compilation_config
)
if
self
.
cuda_graph_sizes
is
not
None
:
logger
.
warning
(
"--cuda-graph-sizes is deprecated and will be removed in v0.13.0 or "
"v1.0.0, whichever is soonest. Please use --cudagraph-capture-sizes "
"instead."
)
if
compilation_config
.
cudagraph_capture_sizes
is
not
None
:
raise
ValueError
(
"cuda_graph_sizes and compilation_config."
"cudagraph_capture_sizes are mutually exclusive"
)
compilation_config
.
cudagraph_capture_sizes
=
self
.
cuda_graph_sizes
if
self
.
cudagraph_capture_sizes
is
not
None
:
if
self
.
cudagraph_capture_sizes
is
not
None
:
if
compilation_config
.
cudagraph_capture_sizes
is
not
None
:
if
compilation_config
.
cudagraph_capture_sizes
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -1862,6 +1789,7 @@ class EngineArgs:
...
@@ -1862,6 +1789,7 @@ class EngineArgs:
except
Exception
:
except
Exception
:
# This is only used to set default_max_num_batched_tokens
# This is only used to set default_max_num_batched_tokens
device_memory
=
0
device_memory
=
0
device_name
=
""
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces
# throughput, see PR #17885 for more details.
# throughput, see PR #17885 for more details.
...
@@ -1926,16 +1854,6 @@ class EngineArgs:
...
@@ -1926,16 +1854,6 @@ class EngineArgs:
default_chunked_prefill
=
model_config
.
is_chunked_prefill_supported
default_chunked_prefill
=
model_config
.
is_chunked_prefill_supported
default_prefix_caching
=
model_config
.
is_prefix_caching_supported
default_prefix_caching
=
model_config
.
is_prefix_caching_supported
if
self
.
prefill_context_parallel_size
>
1
:
default_chunked_prefill
=
False
default_prefix_caching
=
False
logger
.
warning_once
(
"--prefill-context-parallel-size > 1 is not compatible with "
"chunked prefill and prefix caching now. Chunked prefill "
"and prefix caching have been disabled by default."
,
scope
=
"local"
,
)
if
self
.
enable_chunked_prefill
is
None
:
if
self
.
enable_chunked_prefill
is
None
:
self
.
enable_chunked_prefill
=
default_chunked_prefill
self
.
enable_chunked_prefill
=
default_chunked_prefill
...
@@ -2121,11 +2039,13 @@ def human_readable_int(value):
...
@@ -2121,11 +2039,13 @@ def human_readable_int(value):
"k"
:
10
**
3
,
"k"
:
10
**
3
,
"m"
:
10
**
6
,
"m"
:
10
**
6
,
"g"
:
10
**
9
,
"g"
:
10
**
9
,
"t"
:
10
**
12
,
}
}
binary_multiplier
=
{
binary_multiplier
=
{
"K"
:
2
**
10
,
"K"
:
2
**
10
,
"M"
:
2
**
20
,
"M"
:
2
**
20
,
"G"
:
2
**
30
,
"G"
:
2
**
30
,
"T"
:
2
**
40
,
}
}
number
,
suffix
=
match
.
groups
()
number
,
suffix
=
match
.
groups
()
...
...
vllm/entrypoints/anthropic/serving_messages.py
View file @
a3f8d5dd
...
@@ -324,12 +324,12 @@ class AnthropicServingMessages(OpenAIServingChat):
...
@@ -324,12 +324,12 @@ class AnthropicServingMessages(OpenAIServingChat):
id
=
origin_chunk
.
id
,
id
=
origin_chunk
.
id
,
content
=
[],
content
=
[],
model
=
origin_chunk
.
model
,
model
=
origin_chunk
.
model
,
),
usage
=
AnthropicUsage
(
usage
=
AnthropicUsage
(
input_tokens
=
origin_chunk
.
usage
.
prompt_tokens
input_tokens
=
origin_chunk
.
usage
.
prompt_tokens
if
origin_chunk
.
usage
if
origin_chunk
.
usage
else
0
,
else
0
,
output_tokens
=
0
,
output_tokens
=
0
,
)
,
),
),
)
)
first_item
=
False
first_item
=
False
...
...
vllm/entrypoints/chat_utils.py
View file @
a3f8d5dd
...
@@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
...
@@ -9,7 +9,7 @@ from collections import Counter, defaultdict, deque
from
collections.abc
import
Awaitable
,
Callable
,
Iterable
from
collections.abc
import
Awaitable
,
Callable
,
Iterable
from
functools
import
cached_property
,
lru_cache
,
partial
from
functools
import
cached_property
,
lru_cache
,
partial
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Generic
,
Literal
,
TypeAlias
,
TypeVar
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Generic
,
Literal
,
TypeAlias
,
TypeVar
,
cast
import
jinja2
import
jinja2
import
jinja2.ext
import
jinja2.ext
...
@@ -24,6 +24,7 @@ from openai.types.chat import (
...
@@ -24,6 +24,7 @@ from openai.types.chat import (
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartInputAudioParam
,
ChatCompletionContentPartRefusalParam
,
ChatCompletionContentPartRefusalParam
,
ChatCompletionContentPartTextParam
,
ChatCompletionContentPartTextParam
,
ChatCompletionFunctionToolParam
,
ChatCompletionMessageToolCallParam
,
ChatCompletionMessageToolCallParam
,
ChatCompletionToolMessageParam
,
ChatCompletionToolMessageParam
,
)
)
...
@@ -49,11 +50,20 @@ from vllm.logger import init_logger
...
@@ -49,11 +50,20 @@ from vllm.logger import init_logger
from
vllm.model_executor.models
import
SupportsMultiModal
from
vllm.model_executor.models
import
SupportsMultiModal
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
MultiModalUUIDDict
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
MultiModalUUIDDict
from
vllm.multimodal.utils
import
MEDIA_CONNECTOR_REGISTRY
,
MediaConnector
from
vllm.multimodal.utils
import
MEDIA_CONNECTOR_REGISTRY
,
MediaConnector
from
vllm.tokenizers
import
MistralTokenizer
,
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.transformers_utils.chat_templates
import
get_chat_template_fallback_path
from
vllm.transformers_utils.chat_templates
import
get_chat_template_fallback_path
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.transformers_utils.processor
import
cached_get_processor
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
from
vllm.utils.collection_utils
import
is_list_of
from
vllm.utils.func_utils
import
supports_kw
from
vllm.utils.func_utils
import
supports_kw
from
vllm.utils.import_utils
import
LazyLoader
if
TYPE_CHECKING
:
import
torch
from
vllm.tokenizers.mistral
import
MistralTokenizer
else
:
torch
=
LazyLoader
(
"torch"
,
globals
(),
"torch"
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -260,6 +270,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
...
@@ -260,6 +270,9 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):
reasoning
:
str
|
None
reasoning
:
str
|
None
"""The reasoning content for interleaved thinking."""
"""The reasoning content for interleaved thinking."""
tools
:
list
[
ChatCompletionFunctionToolParam
]
|
None
"""The tools for developer role."""
ChatCompletionMessageParam
:
TypeAlias
=
(
ChatCompletionMessageParam
:
TypeAlias
=
(
OpenAIChatCompletionMessageParam
OpenAIChatCompletionMessageParam
...
@@ -291,6 +304,9 @@ class ConversationMessage(TypedDict, total=False):
...
@@ -291,6 +304,9 @@ class ConversationMessage(TypedDict, total=False):
reasoning_content
:
str
|
None
reasoning_content
:
str
|
None
"""Deprecated: The reasoning content for interleaved thinking."""
"""Deprecated: The reasoning content for interleaved thinking."""
tools
:
list
[
ChatCompletionFunctionToolParam
]
|
None
"""The tools for developer role."""
# Passed in by user
# Passed in by user
ChatTemplateContentFormatOption
=
Literal
[
"auto"
,
"string"
,
"openai"
]
ChatTemplateContentFormatOption
=
Literal
[
"auto"
,
"string"
,
"openai"
]
...
@@ -620,6 +636,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
...
@@ -620,6 +636,44 @@ ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
_T
=
TypeVar
(
"_T"
)
_T
=
TypeVar
(
"_T"
)
def
_extract_embeds
(
tensors
:
list
[
torch
.
Tensor
]):
if
len
(
tensors
)
==
0
:
return
tensors
if
len
(
tensors
)
==
1
:
tensors
[
0
].
_is_single_item
=
True
# type: ignore
return
tensors
[
0
]
# To keep backwards compatibility for single item input
first_shape
=
tensors
[
0
].
shape
if
all
(
t
.
shape
==
first_shape
for
t
in
tensors
):
return
torch
.
stack
(
tensors
)
return
tensors
def
_get_embeds_data
(
items_by_modality
:
dict
[
str
,
list
[
Any
]],
modality
:
str
):
embeds_key
=
f
"
{
modality
}
_embeds"
embeds
=
items_by_modality
[
embeds_key
]
if
len
(
embeds
)
==
0
:
return
embeds
if
is_list_of
(
embeds
,
torch
.
Tensor
):
return
_extract_embeds
(
embeds
)
if
is_list_of
(
embeds
,
dict
):
if
not
embeds
:
return
{}
first_keys
=
set
(
embeds
[
0
].
keys
())
if
any
(
set
(
item
.
keys
())
!=
first_keys
for
item
in
embeds
[
1
:]):
raise
ValueError
(
"All dictionaries in the list of embeddings must have the same keys."
)
return
{
k
:
_extract_embeds
([
item
[
k
]
for
item
in
embeds
])
for
k
in
first_keys
}
return
embeds
class
BaseMultiModalItemTracker
(
ABC
,
Generic
[
_T
]):
class
BaseMultiModalItemTracker
(
ABC
,
Generic
[
_T
]):
"""
"""
Tracks multi-modal items in a given request and ensures that the number
Tracks multi-modal items in a given request and ensures that the number
...
@@ -688,11 +742,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -688,11 +742,14 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def
all_mm_uuids
(
self
)
->
MultiModalUUIDDict
|
None
:
def
all_mm_uuids
(
self
)
->
MultiModalUUIDDict
|
None
:
if
not
self
.
_items_by_modality
:
if
not
self
.
_items_by_modality
:
return
None
return
None
mm_uuids
=
{}
uuids_by_modality
=
dict
(
self
.
_uuids_by_modality
)
uuids_by_modality
=
dict
(
self
.
_uuids_by_modality
)
if
"image"
in
uuids_by_modality
and
"image_embeds"
in
uuids_by_modality
:
if
"image"
in
uuids_by_modality
and
"image_embeds"
in
uuids_by_modality
:
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
if
"audio"
in
uuids_by_modality
and
"audio_embeds"
in
uuids_by_modality
:
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
mm_uuids
=
{}
if
"image_embeds"
in
uuids_by_modality
:
if
"image_embeds"
in
uuids_by_modality
:
mm_uuids
[
"image"
]
=
uuids_by_modality
[
"image_embeds"
]
mm_uuids
[
"image"
]
=
uuids_by_modality
[
"image_embeds"
]
if
"image"
in
uuids_by_modality
:
if
"image"
in
uuids_by_modality
:
...
@@ -703,6 +760,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
...
@@ -703,6 +760,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
mm_uuids
[
"audio"
]
=
uuids_by_modality
[
"audio"
]
# UUIDs of audios
mm_uuids
[
"audio"
]
=
uuids_by_modality
[
"audio"
]
# UUIDs of audios
if
"video"
in
uuids_by_modality
:
if
"video"
in
uuids_by_modality
:
mm_uuids
[
"video"
]
=
uuids_by_modality
[
"video"
]
# UUIDs of videos
mm_uuids
[
"video"
]
=
uuids_by_modality
[
"video"
]
# UUIDs of videos
return
mm_uuids
return
mm_uuids
@
abstractmethod
@
abstractmethod
...
@@ -714,29 +772,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
...
@@ -714,29 +772,25 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[object]):
def
all_mm_data
(
self
)
->
MultiModalDataDict
|
None
:
def
all_mm_data
(
self
)
->
MultiModalDataDict
|
None
:
if
not
self
.
_items_by_modality
:
if
not
self
.
_items_by_modality
:
return
None
return
None
mm_inputs
=
{}
items_by_modality
=
dict
(
self
.
_items_by_modality
)
items_by_modality
=
dict
(
self
.
_items_by_modality
)
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
if
"audio"
in
items_by_modality
and
"audio_embeds"
in
items_by_modality
:
if
"audio"
in
items_by_modality
and
"audio_embeds"
in
items_by_modality
:
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
mm_inputs
=
{}
if
"image_embeds"
in
items_by_modality
:
if
"image_embeds"
in
items_by_modality
:
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
mm_inputs
[
"image"
]
=
_get_embeds_data
(
items_by_modality
,
"image"
)
mm_inputs
[
"image"
]
=
(
image_embeds_lst
if
len
(
image_embeds_lst
)
!=
1
else
image_embeds_lst
[
0
]
)
if
"image"
in
items_by_modality
:
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
if
"audio_embeds"
in
items_by_modality
:
if
"audio_embeds"
in
items_by_modality
:
audio_embeds_lst
=
items_by_modality
[
"audio_embeds"
]
mm_inputs
[
"audio"
]
=
_get_embeds_data
(
items_by_modality
,
"audio"
)
mm_inputs
[
"audio"
]
=
(
audio_embeds_lst
if
len
(
audio_embeds_lst
)
!=
1
else
audio_embeds_lst
[
0
]
)
if
"audio"
in
items_by_modality
:
if
"audio"
in
items_by_modality
:
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
if
"video"
in
items_by_modality
:
if
"video"
in
items_by_modality
:
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
return
mm_inputs
return
mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
...
@@ -747,38 +801,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
...
@@ -747,38 +801,32 @@ class AsyncMultiModalItemTracker(BaseMultiModalItemTracker[Awaitable[object]]):
async
def
all_mm_data
(
self
)
->
MultiModalDataDict
|
None
:
async
def
all_mm_data
(
self
)
->
MultiModalDataDict
|
None
:
if
not
self
.
_items_by_modality
:
if
not
self
.
_items_by_modality
:
return
None
return
None
mm_inputs
=
{}
items_by_modality
=
{}
for
modality
,
items
in
self
.
_items_by_modality
.
items
():
coros
=
[]
for
item
in
items
:
if
item
is
not
None
:
coros
.
append
(
item
)
else
:
coros
.
append
(
asyncio
.
sleep
(
0
))
items_by_modality
[
modality
]
=
await
asyncio
.
gather
(
*
coros
)
coros_by_modality
=
{
modality
:
[
item
or
asyncio
.
sleep
(
0
)
for
item
in
items
]
for
modality
,
items
in
self
.
_items_by_modality
.
items
()
}
items_by_modality
:
dict
[
str
,
list
[
object
|
None
]]
=
{
modality
:
await
asyncio
.
gather
(
*
coros
)
for
modality
,
coros
in
coros_by_modality
.
items
()
}
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
if
"image"
in
items_by_modality
and
"image_embeds"
in
items_by_modality
:
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
raise
ValueError
(
"Mixing raw image and embedding inputs is not allowed"
)
if
"audio"
in
items_by_modality
and
"audio_embeds"
in
items_by_modality
:
if
"audio"
in
items_by_modality
and
"audio_embeds"
in
items_by_modality
:
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
raise
ValueError
(
"Mixing raw audio and embedding inputs is not allowed"
)
mm_inputs
=
{}
if
"image_embeds"
in
items_by_modality
:
if
"image_embeds"
in
items_by_modality
:
image_embeds_lst
=
items_by_modality
[
"image_embeds"
]
mm_inputs
[
"image"
]
=
_get_embeds_data
(
items_by_modality
,
"image"
)
mm_inputs
[
"image"
]
=
(
image_embeds_lst
if
len
(
image_embeds_lst
)
!=
1
else
image_embeds_lst
[
0
]
)
if
"image"
in
items_by_modality
:
if
"image"
in
items_by_modality
:
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
mm_inputs
[
"image"
]
=
items_by_modality
[
"image"
]
# A list of images
if
"audio_embeds"
in
items_by_modality
:
if
"audio_embeds"
in
items_by_modality
:
audio_embeds_lst
=
items_by_modality
[
"audio_embeds"
]
mm_inputs
[
"audio"
]
=
_get_embeds_data
(
items_by_modality
,
"audio"
)
mm_inputs
[
"audio"
]
=
(
audio_embeds_lst
if
len
(
audio_embeds_lst
)
!=
1
else
audio_embeds_lst
[
0
]
)
if
"audio"
in
items_by_modality
:
if
"audio"
in
items_by_modality
:
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
mm_inputs
[
"audio"
]
=
items_by_modality
[
"audio"
]
# A list of audios
if
"video"
in
items_by_modality
:
if
"video"
in
items_by_modality
:
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
mm_inputs
[
"video"
]
=
items_by_modality
[
"video"
]
# A list of videos
return
mm_inputs
return
mm_inputs
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
def
create_parser
(
self
)
->
"BaseMultiModalContentParser"
:
...
@@ -1578,6 +1626,8 @@ def _parse_chat_message_content(
...
@@ -1578,6 +1626,8 @@ def _parse_chat_message_content(
if
"name"
in
message
and
isinstance
(
message
[
"name"
],
str
):
if
"name"
in
message
and
isinstance
(
message
[
"name"
],
str
):
result_msg
[
"name"
]
=
message
[
"name"
]
result_msg
[
"name"
]
=
message
[
"name"
]
if
role
==
"developer"
:
result_msg
[
"tools"
]
=
message
.
get
(
"tools"
,
None
)
return
result
return
result
...
@@ -1588,12 +1638,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
...
@@ -1588,12 +1638,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None:
# so, for messages that have tool_calls, parse the string (which we get
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
# from openAI format) to dict
for
message
in
messages
:
for
message
in
messages
:
if
(
if
message
[
"role"
]
==
"assistant"
and
"tool_calls"
in
message
:
message
[
"role"
]
==
"assistant"
tool_calls
=
message
.
get
(
"tool_calls"
)
and
"tool_calls"
in
message
if
not
isinstance
(
tool_calls
,
list
):
and
isinstance
(
message
[
"tool_calls"
],
list
)
continue
):
for
item
in
message
[
"tool_calls"
]:
if
len
(
tool_calls
)
==
0
:
# Drop empty tool_calls to keep templates on the normal assistant path.
message
.
pop
(
"tool_calls"
,
None
)
continue
for
item
in
tool_calls
:
# if arguments is None or empty string, set to {}
# if arguments is None or empty string, set to {}
if
content
:
=
item
[
"function"
].
get
(
"arguments"
):
if
content
:
=
item
[
"function"
].
get
(
"arguments"
):
if
not
isinstance
(
content
,
(
dict
,
list
)):
if
not
isinstance
(
content
,
(
dict
,
list
)):
...
@@ -1797,7 +1852,7 @@ def apply_hf_chat_template(
...
@@ -1797,7 +1852,7 @@ def apply_hf_chat_template(
def
apply_mistral_chat_template
(
def
apply_mistral_chat_template
(
tokenizer
:
MistralTokenizer
,
tokenizer
:
"
MistralTokenizer
"
,
messages
:
list
[
ChatCompletionMessageParam
],
messages
:
list
[
ChatCompletionMessageParam
],
chat_template
:
str
|
None
,
chat_template
:
str
|
None
,
tools
:
list
[
dict
[
str
,
Any
]]
|
None
,
tools
:
list
[
dict
[
str
,
Any
]]
|
None
,
...
...
vllm/entrypoints/cli/__init__.py
View file @
a3f8d5dd
...
@@ -2,12 +2,14 @@
...
@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.entrypoints.cli.benchmark.latency
import
BenchmarkLatencySubcommand
from
vllm.entrypoints.cli.benchmark.latency
import
BenchmarkLatencySubcommand
from
vllm.entrypoints.cli.benchmark.serve
import
BenchmarkServingSubcommand
from
vllm.entrypoints.cli.benchmark.serve
import
BenchmarkServingSubcommand
from
vllm.entrypoints.cli.benchmark.startup
import
BenchmarkStartupSubcommand
from
vllm.entrypoints.cli.benchmark.sweep
import
BenchmarkSweepSubcommand
from
vllm.entrypoints.cli.benchmark.sweep
import
BenchmarkSweepSubcommand
from
vllm.entrypoints.cli.benchmark.throughput
import
BenchmarkThroughputSubcommand
from
vllm.entrypoints.cli.benchmark.throughput
import
BenchmarkThroughputSubcommand
__all__
:
list
[
str
]
=
[
__all__
:
list
[
str
]
=
[
"BenchmarkLatencySubcommand"
,
"BenchmarkLatencySubcommand"
,
"BenchmarkServingSubcommand"
,
"BenchmarkServingSubcommand"
,
"BenchmarkStartupSubcommand"
,
"BenchmarkSweepSubcommand"
,
"BenchmarkSweepSubcommand"
,
"BenchmarkThroughputSubcommand"
,
"BenchmarkThroughputSubcommand"
,
]
]
vllm/entrypoints/cli/benchmark/startup.py
0 → 100644
View file @
a3f8d5dd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
from
vllm.benchmarks.startup
import
add_cli_args
,
main
from
vllm.entrypoints.cli.benchmark.base
import
BenchmarkSubcommandBase
class
BenchmarkStartupSubcommand
(
BenchmarkSubcommandBase
):
"""The `startup` subcommand for `vllm bench`."""
name
=
"startup"
help
=
"Benchmark the startup time of vLLM models."
@
classmethod
def
add_cli_args
(
cls
,
parser
:
argparse
.
ArgumentParser
)
->
None
:
add_cli_args
(
parser
)
@
staticmethod
def
cmd
(
args
:
argparse
.
Namespace
)
->
None
:
main
(
args
)
vllm/entrypoints/context.py
View file @
a3f8d5dd
...
@@ -34,13 +34,13 @@ from vllm.entrypoints.openai.protocol import (
...
@@ -34,13 +34,13 @@ from vllm.entrypoints.openai.protocol import (
ResponseRawMessageAndToken
,
ResponseRawMessageAndToken
,
ResponsesRequest
,
ResponsesRequest
,
)
)
from
vllm.entrypoints.openai.tool_parsers.abstract_tool_parser
import
ToolParser
from
vllm.entrypoints.responses_utils
import
construct_tool_dicts
from
vllm.entrypoints.responses_utils
import
construct_tool_dicts
from
vllm.entrypoints.tool
import
Tool
from
vllm.entrypoints.tool
import
Tool
from
vllm.entrypoints.tool_server
import
ToolServer
from
vllm.entrypoints.tool_server
import
ToolServer
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.reasoning.abs_reasoning_parsers
import
ReasoningParser
from
vllm.reasoning.abs_reasoning_parsers
import
ReasoningParser
from
vllm.tokenizers.protocol
import
TokenizerLike
from
vllm.tokenizers.protocol
import
TokenizerLike
from
vllm.tool_parsers.abstract_tool_parser
import
ToolParser
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.utils
import
random_uuid
from
vllm.utils
import
random_uuid
...
@@ -74,24 +74,24 @@ class TurnMetrics:
...
@@ -74,24 +74,24 @@ class TurnMetrics:
def
__init__
(
def
__init__
(
self
,
self
,
input_tokens
=
0
,
input_tokens
:
int
=
0
,
output_tokens
=
0
,
output_tokens
:
int
=
0
,
cached_input_tokens
=
0
,
cached_input_tokens
:
int
=
0
,
tool_output_tokens
=
0
,
tool_output_tokens
:
int
=
0
,
):
)
->
None
:
self
.
input_tokens
=
input_tokens
self
.
input_tokens
=
input_tokens
self
.
output_tokens
=
output_tokens
self
.
output_tokens
=
output_tokens
self
.
cached_input_tokens
=
cached_input_tokens
self
.
cached_input_tokens
=
cached_input_tokens
self
.
tool_output_tokens
=
tool_output_tokens
self
.
tool_output_tokens
=
tool_output_tokens
def
reset
(
self
):
def
reset
(
self
)
->
None
:
"""Reset counters for a new turn."""
"""Reset counters for a new turn."""
self
.
input_tokens
=
0
self
.
input_tokens
=
0
self
.
output_tokens
=
0
self
.
output_tokens
=
0
self
.
cached_input_tokens
=
0
self
.
cached_input_tokens
=
0
self
.
tool_output_tokens
=
0
self
.
tool_output_tokens
=
0
def
copy
(
self
):
def
copy
(
self
)
->
"TurnMetrics"
:
"""Create a copy of this turn's token counts."""
"""Create a copy of this turn's token counts."""
return
TurnMetrics
(
return
TurnMetrics
(
self
.
input_tokens
,
self
.
input_tokens
,
...
...
vllm/entrypoints/llm.py
View file @
a3f8d5dd
...
@@ -9,7 +9,7 @@ import cloudpickle
...
@@ -9,7 +9,7 @@ import cloudpickle
import
torch.nn
as
nn
import
torch.nn
as
nn
from
pydantic
import
ValidationError
from
pydantic
import
ValidationError
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
typing_extensions
import
TypeVar
,
deprecated
from
typing_extensions
import
TypeVar
from
vllm.beam_search
import
(
from
vllm.beam_search
import
(
BeamSearchInstance
,
BeamSearchInstance
,
...
@@ -72,8 +72,8 @@ from vllm.platforms import current_platform
...
@@ -72,8 +72,8 @@ from vllm.platforms import current_platform
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
BeamSearchParams
,
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
BeamSearchParams
,
RequestOutputKind
,
SamplingParams
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.tokenizers
import
MistralTokenizer
,
TokenizerLike
from
vllm.tokenizers
import
TokenizerLike
from
vllm.tokenizers.
hf
import
get_cached_t
okenizer
from
vllm.tokenizers.
mistral
import
MistralT
okenizer
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils.collection_utils
import
as_iter
,
is_list_of
from
vllm.utils.collection_utils
import
as_iter
,
is_list_of
from
vllm.utils.counter
import
Counter
from
vllm.utils.counter
import
Counter
...
@@ -199,7 +199,7 @@ class LLM:
...
@@ -199,7 +199,7 @@ class LLM:
quantization
:
QuantizationMethods
|
None
=
None
,
quantization
:
QuantizationMethods
|
None
=
None
,
revision
:
str
|
None
=
None
,
revision
:
str
|
None
=
None
,
tokenizer_revision
:
str
|
None
=
None
,
tokenizer_revision
:
str
|
None
=
None
,
seed
:
int
|
None
=
None
,
seed
:
int
=
0
,
gpu_memory_utilization
:
float
=
0.9
,
gpu_memory_utilization
:
float
=
0.9
,
swap_space
:
float
=
4
,
swap_space
:
float
=
4
,
cpu_offload_gb
:
float
=
0
,
cpu_offload_gb
:
float
=
0
,
...
@@ -367,16 +367,6 @@ class LLM:
...
@@ -367,16 +367,6 @@ class LLM:
def
get_tokenizer
(
self
)
->
TokenizerLike
:
def
get_tokenizer
(
self
)
->
TokenizerLike
:
return
self
.
llm_engine
.
get_tokenizer
()
return
self
.
llm_engine
.
get_tokenizer
()
@
deprecated
(
"`set_tokenizer` is deprecated and will be removed in v0.13."
)
def
set_tokenizer
(
self
,
tokenizer
:
TokenizerLike
)
->
None
:
# While CachedTokenizer is dynamic, have no choice but
# compare class name. Misjudgment will arise from
# user-defined tokenizer started with 'Cached'
if
tokenizer
.
__class__
.
__name__
.
startswith
(
"Cached"
):
self
.
llm_engine
.
tokenizer
=
tokenizer
else
:
self
.
llm_engine
.
tokenizer
=
get_cached_tokenizer
(
tokenizer
)
def
reset_mm_cache
(
self
)
->
None
:
def
reset_mm_cache
(
self
)
->
None
:
self
.
input_processor
.
clear_mm_cache
()
self
.
input_processor
.
clear_mm_cache
()
self
.
llm_engine
.
reset_mm_cache
()
self
.
llm_engine
.
reset_mm_cache
()
...
...
vllm/entrypoints/openai/api_server.py
View file @
a3f8d5dd
...
@@ -72,7 +72,6 @@ from vllm.entrypoints.openai.serving_transcription import (
...
@@ -72,7 +72,6 @@ from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription
,
OpenAIServingTranscription
,
OpenAIServingTranslation
,
OpenAIServingTranslation
,
)
)
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.openai.utils
import
validate_json_request
from
vllm.entrypoints.pooling.classify.serving
import
ServingClassification
from
vllm.entrypoints.pooling.classify.serving
import
ServingClassification
from
vllm.entrypoints.pooling.embed.serving
import
OpenAIServingEmbedding
from
vllm.entrypoints.pooling.embed.serving
import
OpenAIServingEmbedding
...
@@ -95,6 +94,7 @@ from vllm.entrypoints.utils import (
...
@@ -95,6 +94,7 @@ from vllm.entrypoints.utils import (
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.reasoning
import
ReasoningParserManager
from
vllm.tasks
import
POOLING_TASKS
from
vllm.tasks
import
POOLING_TASKS
from
vllm.tool_parsers
import
ToolParserManager
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.gc_utils
import
freeze_gc_heap
from
vllm.utils.gc_utils
import
freeze_gc_heap
...
...
vllm/entrypoints/openai/cli_args.py
View file @
a3f8d5dd
...
@@ -27,8 +27,8 @@ from vllm.entrypoints.constants import (
...
@@ -27,8 +27,8 @@ from vllm.entrypoints.constants import (
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
,
H11_MAX_INCOMPLETE_EVENT_SIZE_DEFAULT
,
)
)
from
vllm.entrypoints.openai.serving_models
import
LoRAModulePath
from
vllm.entrypoints.openai.serving_models
import
LoRAModulePath
from
vllm.entrypoints.openai.tool_parsers
import
ToolParserManager
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.tool_parsers
import
ToolParserManager
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -176,7 +176,7 @@ class FrontendArgs:
...
@@ -176,7 +176,7 @@ class FrontendArgs:
enable_force_include_usage
:
bool
=
False
enable_force_include_usage
:
bool
=
False
"""If set to True, including usage on every request."""
"""If set to True, including usage on every request."""
enable_tokenizer_info_endpoint
:
bool
=
False
enable_tokenizer_info_endpoint
:
bool
=
False
"""Enable the /
get_
tokenizer_info endpoint. May expose chat
"""Enable the
`
/tokenizer_info
`
endpoint. May expose chat
templates and other tokenizer configuration."""
templates and other tokenizer configuration."""
enable_log_outputs
:
bool
=
False
enable_log_outputs
:
bool
=
False
"""If True, log model outputs (generations).
"""If True, log model outputs (generations).
...
...
vllm/entrypoints/openai/parser/harmony_utils.py
View file @
a3f8d5dd
...
@@ -232,7 +232,177 @@ def parse_response_input(
...
@@ -232,7 +232,177 @@ def parse_response_input(
return
msg
return
msg
def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
    """
    Parse a list of messages from request.messages in the Chat Completion API to
    Harmony messages.

    Two passes are made over ``chat_msgs``: the first records which function
    name each tool call id refers to, the second converts every chat message
    with ``parse_chat_input_to_harmony_message``. The combined result is then
    passed through ``auto_drop_analysis_messages``.

    Args:
        chat_msgs: Chat messages, either plain dicts or objects exposing
            ``model_dump`` (Pydantic models).

    Returns:
        The Harmony messages produced for the whole conversation.
    """
    # Collect tool id to name mappings for tool response recipient values
    tool_id_names: dict[str, str] = {}
    for chat_msg in chat_msgs:
        # Handle Pydantic models the same way the per-message parser does, so
        # that a non-dict message does not fail on `.get` here.
        msg_dict = (
            chat_msg
            if isinstance(chat_msg, dict)
            else chat_msg.model_dump(exclude_none=True)
        )
        # `tool_calls` / `function` may be present but explicitly None; the
        # `or` fallbacks keep that from raising when iterating or on `.get`.
        for tool_call in msg_dict.get("tool_calls") or []:
            function = tool_call.get("function") or {}
            tool_id_names[tool_call.get("id")] = function.get("name")

    msgs: list[Message] = []
    for chat_msg in chat_msgs:
        msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))

    return auto_drop_analysis_messages(msgs)
def auto_drop_analysis_messages(msgs: list[Message]) -> list[Message]:
    """
    Drop analysis (raw chain of thought) messages that precede the most recent
    assistant message on the final channel.

    Harmony models expect analysis messages to be discarded once an assistant
    message to the final channel has been produced from them. The
    openai-harmony library does this only when the very last assistant message
    is to the final channel; it does not cover longer multi-turn conversations
    in which the client sent back reasoning content from earlier turns.

    So the index of the last assistant message to the final channel is located,
    and every analysis message before that index is removed. Analysis messages
    at or after it, and all non-analysis messages, are kept in order. When no
    such final message exists, nothing is dropped.
    """
    # Search from the end; -1 means "no assistant final message found".
    cutoff = next(
        (
            idx
            for idx in reversed(range(len(msgs)))
            if msgs[idx].author.role == "assistant" and msgs[idx].channel == "final"
        ),
        -1,
    )

    return [
        message
        for idx, message in enumerate(msgs)
        if idx >= cutoff or message.channel != "analysis"
    ]
def flatten_chat_text_content(content: str | list | None) -> str | None:
    """
    Collapse a chat message ``content`` field into a single string.

    A list is treated as a sequence of content parts: the ``text`` value of
    every dict part whose ``type`` is ``"text"`` is concatenated with no
    separator (a text part without a ``text`` key contributes an empty
    string), and all other parts are ignored. Any non-list value, including
    ``None``, is returned unchanged.
    """
    if not isinstance(content, list):
        return content

    text_parts = []
    for part in content:
        if not isinstance(part, dict):
            continue
        if part.get("type") != "text":
            continue
        text_parts.append(part.get("text", ""))
    return "".join(text_parts)
def parse_chat_input_to_harmony_message(
    chat_msg, tool_id_names: dict[str, str] | None = None
) -> list[Message]:
    """
    Parse a message from request.messages in the Chat Completion API to
    Harmony messages.

    Args:
        chat_msg: A chat message as a dict, or an object exposing
            ``model_dump`` (Pydantic model), which is converted to a dict.
        tool_id_names: Optional mapping from tool call id to function name,
            used to build the author name of tool output messages.

    Returns:
        Zero or more Harmony messages, depending on the role:

        * assistant with tool calls: an optional ``commentary`` message for
          the text content, an optional ``analysis`` message for the reasoning,
          then one ``commentary`` message per tool call addressed to
          ``functions.<name>`` with content type ``json``.
        * tool: a single ``commentary`` message authored by
          ``functions.<name>`` and addressed to ``assistant``.
        * assistant without tool calls: an optional ``analysis`` message and,
          if the first content part has text, a ``final`` message.
        * any other role: a single message with the content, even when empty.
    """
    tool_id_names = tool_id_names or {}

    if not isinstance(chat_msg, dict):
        # Handle Pydantic models
        chat_msg = chat_msg.model_dump(exclude_none=True)

    role = chat_msg.get("role")
    msgs: list[Message] = []

    tool_calls = chat_msg.get("tool_calls", [])
    # Reasoning may arrive under either key; looked up once for both
    # assistant paths below.
    reasoning_content = chat_msg.get("reasoning") or chat_msg.get("reasoning_content")

    # Assistant message with tool calls
    if role == "assistant" and tool_calls:
        content = flatten_chat_text_content(chat_msg.get("content"))
        if content:
            commentary_msg = Message.from_role_and_content(Role.ASSISTANT, content)
            commentary_msg = commentary_msg.with_channel("commentary")
            msgs.append(commentary_msg)

        if reasoning_content:
            analysis_msg = Message.from_role_and_content(
                Role.ASSISTANT, reasoning_content
            )
            analysis_msg = analysis_msg.with_channel("analysis")
            msgs.append(analysis_msg)

        for call in tool_calls:
            func = call.get("function", {})
            name = func.get("name", "")
            arguments = func.get("arguments", "") or ""
            msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
            msg = msg.with_channel("commentary")
            msg = msg.with_recipient(f"functions.{name}")
            # Officially, this should be `<|constrain|>json` but there is not clear
            # evidence that improves accuracy over `json` and some anecdotes to the
            # contrary. Further testing of the different content_types is needed.
            msg = msg.with_content_type("json")
            msgs.append(msg)
        return msgs

    # Tool role message (tool output)
    if role == "tool":
        tool_call_id = chat_msg.get("tool_call_id", "")
        name = tool_id_names.get(tool_call_id, "")
        content = chat_msg.get("content", "") or ""
        content = flatten_chat_text_content(content)

        msg = (
            Message.from_author_and_content(
                Author.new(Role.TOOL, f"functions.{name}"), content
            )
            .with_channel("commentary")
            .with_recipient("assistant")
        )
        return [msg]

    # Non-tool reasoning content
    if role == "assistant" and reasoning_content:
        analysis_msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_content)
        analysis_msg = analysis_msg.with_channel("analysis")
        msgs.append(analysis_msg)

    # Default: user/assistant/system messages with content.
    # `or ""` already maps a missing or None content to an empty string, so no
    # separate None check is needed.
    content = chat_msg.get("content") or ""
    if isinstance(content, str):
        contents = [TextContent(text=content)]
    else:
        # TODO: Support refusal.
        contents = [TextContent(text=c.get("text", "")) for c in content]

    # Only add assistant messages if they have content, as reasoning or tool calling
    # assistant messages were already added above.
    if role == "assistant" and contents and contents[0].text:
        msg = Message.from_role_and_contents(role, contents)
        # Send non-tool assistant messages to the final channel
        msg = msg.with_channel("final")
        msgs.append(msg)
    # For user/system/developer messages, add them directly even if no content.
    elif role != "assistant":
        msg = Message.from_role_and_contents(role, contents)
        msgs.append(msg)

    return msgs
def
parse_input_to_harmony_message
(
chat_msg
)
->
list
[
Message
]:
def
parse_input_to_harmony_message
(
chat_msg
)
->
list
[
Message
]:
"""
Parse a message from request.previous_input_messages in the Responses API to
Harmony messages.
"""
if
not
isinstance
(
chat_msg
,
dict
):
if
not
isinstance
(
chat_msg
,
dict
):
# Handle Pydantic models
# Handle Pydantic models
chat_msg
=
chat_msg
.
model_dump
(
exclude_none
=
True
)
chat_msg
=
chat_msg
.
model_dump
(
exclude_none
=
True
)
...
@@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
...
@@ -258,14 +428,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
if
role
==
"tool"
:
if
role
==
"tool"
:
name
=
chat_msg
.
get
(
"name"
,
""
)
name
=
chat_msg
.
get
(
"name"
,
""
)
content
=
chat_msg
.
get
(
"content"
,
""
)
or
""
content
=
chat_msg
.
get
(
"content"
,
""
)
or
""
if
isinstance
(
content
,
list
):
content
=
flatten_chat_text_content
(
content
)
# Handle array format for tool message content
# by concatenating all text parts.
content
=
""
.
join
(
item
.
get
(
"text"
,
""
)
for
item
in
content
if
isinstance
(
item
,
dict
)
and
item
.
get
(
"type"
)
==
"text"
)
msg
=
Message
.
from_author_and_content
(
msg
=
Message
.
from_author_and_content
(
Author
.
new
(
Role
.
TOOL
,
f
"functions.
{
name
}
"
),
content
Author
.
new
(
Role
.
TOOL
,
f
"functions.
{
name
}
"
),
content
...
@@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
...
@@ -623,20 +786,40 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
def
parse_chat_output
(
def
parse_chat_output
(
token_ids
:
Sequence
[
int
],
token_ids
:
Sequence
[
int
],
)
->
tuple
[
str
|
None
,
str
|
None
,
bool
]:
)
->
tuple
[
str
|
None
,
str
|
None
,
bool
]:
"""
Parse the output of a Harmony chat completion into reasoning and final content.
Note that when the `openai` tool parser is used, serving_chat only uses this
for the reasoning content and gets the final content from the tool call parser.
When the `openai` tool parser is not enabled, or when `GptOssReasoningParser` is
in use, this needs to return the final content without any tool calls parsed.
Empty reasoning or final content is returned as None instead of an empty string.
"""
parser
=
parse_output_into_messages
(
token_ids
)
parser
=
parse_output_into_messages
(
token_ids
)
output_msgs
=
parser
.
messages
output_msgs
=
parser
.
messages
is_tool_call
=
False
# TODO: update this when tool call is supported
is_tool_call
=
False
# TODO: update this when tool call is supported
if
len
(
output_msgs
)
==
0
:
# The generation has stopped during reasoning.
# Get completed messages from the parser
reasoning
=
parser
.
current_content
reasoning_texts
=
[
final_content
=
None
msg
.
content
[
0
].
text
for
msg
in
output_msgs
if
msg
.
channel
==
"analysis"
elif
len
(
output_msgs
)
==
1
:
]
# The generation has stopped during final message.
final_texts
=
[
reasoning
=
output_msgs
[
0
].
content
[
0
].
text
msg
.
content
[
0
].
text
for
msg
in
output_msgs
if
msg
.
channel
!=
"analysis"
final_content
=
parser
.
current_content
]
else
:
reasoning_msg
=
output_msgs
[:
-
1
]
# Extract partial messages from the parser
final_msg
=
output_msgs
[
-
1
]
if
parser
.
current_channel
==
"analysis"
and
parser
.
current_content
:
reasoning
=
"
\n
"
.
join
([
msg
.
content
[
0
].
text
for
msg
in
reasoning_msg
])
reasoning_texts
.
append
(
parser
.
current_content
)
final_content
=
final_msg
.
content
[
0
].
text
elif
parser
.
current_channel
!=
"analysis"
and
parser
.
current_content
:
final_texts
.
append
(
parser
.
current_content
)
# Flatten multiple messages into a single string
reasoning
:
str
|
None
=
"
\n
"
.
join
(
reasoning_texts
)
final_content
:
str
|
None
=
"
\n
"
.
join
(
final_texts
)
# Return None instead of empty string since existing callers check for None
reasoning
=
reasoning
or
None
final_content
=
final_content
or
None
return
reasoning
,
final_content
,
is_tool_call
return
reasoning
,
final_content
,
is_tool_call
Prev
1
…
8
9
10
11
12
13
14
15
16
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment