Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3d087ad
Unverified
Commit
a3d087ad
authored
Sep 19, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Sep 19, 2025
Browse files
[P/D][Nixl] Introduce `KVTransferMetrics` and aggregation strategy (#22188)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
058525b9
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
525 additions
and
25 deletions
+525
-25
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+210
-1
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+18
-3
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+21
-1
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
+100
-0
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+65
-3
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+65
-3
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+17
-10
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+7
-1
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+2
-1
vllm/v1/outputs.py
vllm/v1/outputs.py
+10
-1
vllm/v1/worker/kv_connector_model_runner_mixin.py
vllm/v1/worker/kv_connector_model_runner_mixin.py
+10
-1
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
a3d087ad
...
@@ -18,12 +18,18 @@ import torch
...
@@ -18,12 +18,18 @@ import torch
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.distributed.kv_transfer.kv_connector.v1.multi_connector
import
(
MultiKVConnectorStats
)
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
NixlConnectorWorker
)
NixlConnectorWorker
,
NixlKVConnectorStats
)
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
...
@@ -475,6 +481,209 @@ class TestNixlHandshake:
...
@@ -475,6 +481,209 @@ class TestNixlHandshake:
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then
# we put here is important. First run ray, it will clean up the resources, then
# the rest of the tests.
# the rest of the tests.
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
)
def
test_kv_connector_stats
(
dist_init
):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
# Verify that xfer_stats starts empty
initial_stats
=
connector
.
get_kv_connector_stats
()
assert
initial_stats
is
None
# Create transfer metadata
request_id
=
"test_req_for_stats"
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req
(
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
})
connector
.
bind_connector_metadata
(
metadata
)
# Start the transfer
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
)
connector
.
start_load_kv
(
dummy_ctx
)
# Verify stats are recorded after transfer is complete
max_iterations
=
2
# Clear metadata before start_load_kv to prevent reprocessing same request
connector
.
bind_connector_metadata
(
NixlConnectorMetadata
())
for
_
in
range
(
max_iterations
):
# Need to call start_load_kv to process completed handshakes
connector
.
start_load_kv
(
dummy_ctx
)
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
if
len
(
done_recving
)
>
0
and
request_id
in
done_recving
:
break
time
.
sleep
(
0.1
)
# Small delay to allow background handshake to complete
else
:
assert
"Transfer did not complete within expected iterations"
# Now check that stats were recorded
stats_after_transfer
=
connector
.
get_kv_connector_stats
()
assert
isinstance
(
stats_after_transfer
,
NixlKVConnectorStats
)
# Verify stats values are recorded
assert
not
stats_after_transfer
.
is_empty
()
assert
stats_after_transfer
.
data
[
"num_successful_transfers"
]
==
1
# Verify stats are reset after retrieval
stats_after_reset
=
connector
.
get_kv_connector_stats
()
assert
stats_after_reset
is
None
def
test_kv_connector_stats_aggregation
():
"""
Test KV transfer stats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# done in MultiprocExecutor.execute_model
aggregator
=
KVOutputAggregator
(
world_size
=
3
)
# Create stats for multiple workers with different transfer patterns
worker1_stats
=
NixlKVConnectorStats
()
worker2_stats
=
NixlKVConnectorStats
()
worker3_stats
=
NixlKVConnectorStats
()
# Record different transfers on each worker
# Worker 1: 2 transfers
worker1_stats
.
record_transfer
()
worker1_stats
.
record_transfer
()
# Worker 2: 1 transfer
worker2_stats
.
record_transfer
()
# Worker 3: 3 transfers
worker3_stats
.
record_transfer
()
worker3_stats
.
record_transfer
()
worker3_stats
.
record_transfer
()
# Create ModelRunnerOutput instances for each worker
worker_outputs
=
[]
for
i
,
worker_stats
in
enumerate
(
[
worker1_stats
,
worker2_stats
,
worker3_stats
]):
output
=
ModelRunnerOutput
(
req_ids
=
[
f
"req_
{
i
}
"
],
req_id_to_index
=
{
f
"req_
{
i
}
"
:
0
},
sampled_token_ids
=
[[
123
]],
# dummy token
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
set
([
f
"req_
{
i
}
_send"
])
if
i
<
2
else
None
,
# Workers 0,1 finished sending
finished_recving
=
set
([
f
"req_
{
i
}
_recv"
])
if
i
>
0
else
None
,
# Workers 1,2 finished receiving
kv_connector_stats
=
worker_stats
,
))
worker_outputs
.
append
(
output
)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output
=
aggregator
.
aggregate
(
worker_outputs
,
output_rank
=
0
)
kv_connector_stats
=
\
aggregated_output
.
kv_connector_output
.
kv_connector_stats
assert
isinstance
(
kv_connector_stats
,
NixlKVConnectorStats
)
# Number of total transfers across all workers.
assert
kv_connector_stats
.
data
[
"num_successful_transfers"
]
==
6
def
test_multi_kv_connector_stats_aggregation
():
"""
Test MultiKVConnectorStats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
aggregator
=
KVOutputAggregator
(
world_size
=
3
)
from
dataclasses
import
dataclass
@
dataclass
class
FooKVConnectorStats
(
KVConnectorStats
):
def
reset
(
self
):
self
.
data
=
{
"num_foo_transfers"
:
0
}
def
record_transfer
(
self
):
if
"num_foo_transfers"
not
in
self
.
data
:
self
.
data
[
"num_foo_transfers"
]
=
0
self
.
data
[
"num_foo_transfers"
]
+=
1
def
is_empty
(
self
)
->
bool
:
return
self
.
data
[
"num_foo_transfers"
]
==
0
def
aggregate
(
self
,
other
:
"FooKVConnectorStats"
)
->
"FooKVConnectorStats"
:
if
not
other
.
is_empty
():
self
.
data
[
"num_foo_transfers"
]
+=
other
.
data
[
"num_foo_transfers"
]
return
self
def
make_multi_stats
(
nixl_count
:
int
,
foo_count
:
int
)
->
MultiKVConnectorStats
:
data
:
dict
[
str
,
KVConnectorStats
]
=
{}
if
nixl_count
>
0
:
nixl_stats
=
NixlKVConnectorStats
()
for
_
in
range
(
nixl_count
):
nixl_stats
.
record_transfer
()
data
[
"NixlConnector"
]
=
nixl_stats
if
foo_count
>
0
:
foo_stats
=
FooKVConnectorStats
()
for
_
in
range
(
foo_count
):
foo_stats
.
record_transfer
()
data
[
"FooConnector"
]
=
foo_stats
return
MultiKVConnectorStats
(
data
=
data
)
# Create heterogeneous stats across 3 workers
worker_patterns
=
[(
2
,
1
),
(
3
,
0
),
(
0
,
5
)]
# (Nixl, Foo)
worker_outputs
:
list
[
ModelRunnerOutput
]
=
[]
for
i
,
(
nixl
,
foo
)
in
enumerate
(
worker_patterns
):
stats
=
make_multi_stats
(
nixl
,
foo
)
output
=
ModelRunnerOutput
(
req_ids
=
[
f
"req_
{
i
}
"
],
req_id_to_index
=
{
f
"req_
{
i
}
"
:
0
},
sampled_token_ids
=
[[
123
]],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
set
([
f
"req_
{
i
}
_send"
])
if
i
<
2
else
None
,
finished_recving
=
set
([
f
"req_
{
i
}
_recv"
])
if
i
>
0
else
None
,
kv_connector_stats
=
stats
,
),
)
worker_outputs
.
append
(
output
)
aggregated_output
=
aggregator
.
aggregate
(
worker_outputs
,
output_rank
=
0
)
kv_connector_stats
=
\
aggregated_output
.
kv_connector_output
.
kv_connector_stats
assert
isinstance
(
kv_connector_stats
,
MultiKVConnectorStats
)
# Validate per-connector totals across workers
assert
kv_connector_stats
[
"NixlConnector"
].
data
[
"num_successful_transfers"
]
==
5
assert
kv_connector_stats
[
"FooConnector"
].
data
[
"num_foo_transfers"
]
==
6
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
None
])
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
None
])
@
patch
(
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
a3d087ad
...
@@ -129,7 +129,7 @@ class KVOutputAggregator:
...
@@ -129,7 +129,7 @@ class KVOutputAggregator:
def
aggregate
(
self
,
def
aggregate
(
self
,
outputs
:
list
[
ModelRunnerOutput
],
outputs
:
list
[
ModelRunnerOutput
],
output_rank
:
int
=
0
)
->
ModelRunnerOutput
:
output_rank
:
int
=
0
)
->
ModelRunnerOutput
:
#
a
ggregate kv_connector_output from all workers
#
A
ggregate kv_connector_output from all workers
def
update_finished_set
(
req_ids
:
Optional
[
set
[
str
]],
def
update_finished_set
(
req_ids
:
Optional
[
set
[
str
]],
remaining_count_dict
:
dict
[
str
,
int
],
remaining_count_dict
:
dict
[
str
,
int
],
...
@@ -142,8 +142,9 @@ class KVOutputAggregator:
...
@@ -142,8 +142,9 @@ class KVOutputAggregator:
finished_sending
=
set
[
str
]()
finished_sending
=
set
[
str
]()
finished_recving
=
set
[
str
]()
finished_recving
=
set
[
str
]()
for
output
in
outputs
:
aggregated_kv_connector_stats
=
None
output
=
output
.
kv_connector_output
for
model_runner_output
in
outputs
:
output
=
model_runner_output
.
kv_connector_output
if
not
output
:
if
not
output
:
continue
continue
update_finished_set
(
output
.
finished_sending
,
update_finished_set
(
output
.
finished_sending
,
...
@@ -151,12 +152,26 @@ class KVOutputAggregator:
...
@@ -151,12 +152,26 @@ class KVOutputAggregator:
update_finished_set
(
output
.
finished_recving
,
update_finished_set
(
output
.
finished_recving
,
self
.
_recv_remaining_count
,
finished_recving
)
self
.
_recv_remaining_count
,
finished_recving
)
# Aggregate kv_connector_stats from all workers.
if
aggregated_kv_connector_stats
is
None
:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats
=
output
.
kv_connector_stats
elif
kv_connector_stats
:
=
output
.
kv_connector_stats
:
if
aggregated_kv_connector_stats
is
None
:
aggregated_kv_connector_stats
=
kv_connector_stats
else
:
assert
isinstance
(
aggregated_kv_connector_stats
,
type
(
kv_connector_stats
))
aggregated_kv_connector_stats
=
\
aggregated_kv_connector_stats
.
aggregate
(
kv_connector_stats
)
# select output of the worker specified by output_rank
# select output of the worker specified by output_rank
output
=
outputs
[
output_rank
]
output
=
outputs
[
output_rank
]
output
.
kv_connector_output
=
KVConnectorOutput
(
output
.
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
finished_sending
or
None
,
finished_sending
=
finished_sending
or
None
,
finished_recving
=
finished_recving
or
None
,
finished_recving
=
finished_recving
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
)
)
return
output
return
output
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
a3d087ad
...
@@ -49,6 +49,8 @@ if TYPE_CHECKING:
...
@@ -49,6 +49,8 @@ if TYPE_CHECKING:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -235,6 +237,12 @@ class KVConnectorBase_V1(ABC):
...
@@ -235,6 +237,12 @@ class KVConnectorBase_V1(ABC):
"""
"""
return
None
return
None
def
get_kv_connector_stats
(
self
)
->
Optional
[
"KVConnectorStats"
]:
"""
Get the KV connector stats collected during the last interval.
"""
return
None
# ==============================
# ==============================
# Scheduler-side methods
# Scheduler-side methods
# ==============================
# ==============================
...
@@ -366,3 +374,15 @@ class KVConnectorBase_V1(ABC):
...
@@ -366,3 +374,15 @@ class KVConnectorBase_V1(ABC):
"""
"""
return
None
return
None
@
classmethod
def
build_kv_connector_stats
(
cls
,
data
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
Optional
[
"KVConnectorStats"
]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return
None
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
0 → 100644
View file @
a3d087ad
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Optional
,
Union
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
has_kv_transfer_group
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@
dataclass
class
KVConnectorStats
:
"""
Base class for KV Connector Stats, a container for transfer performance
metrics or otherwise important telemetry from the connector.
All sub-classes need to be serializable as stats are sent from worker to
logger process.
"""
data
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
def
reset
(
self
):
"""Reset the stats, clear the state."""
raise
NotImplementedError
def
aggregate
(
self
,
other
:
"KVConnectorStats"
)
->
"KVConnectorStats"
:
"""
Aggregate stats with another `KVConnectorStats` object.
"""
raise
NotImplementedError
def
reduce
(
self
)
->
dict
[
str
,
Union
[
int
,
float
]]:
"""
Reduce the observations collected during a time interval to one or
more representative values (eg avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
raise
NotImplementedError
def
is_empty
(
self
)
->
bool
:
"""Return True if the stats are empty."""
raise
NotImplementedError
class
KVConnectorLogging
:
def
__init__
(
self
,
kv_tranfer_config
:
KVTransferConfig
):
# This should be called on frontend process.
assert
not
has_kv_transfer_group
()
# Instantiate the connector's stats class.
if
kv_tranfer_config
and
kv_tranfer_config
.
kv_connector
:
self
.
connector_cls
=
KVConnectorFactory
.
get_connector_class
(
kv_tranfer_config
)
self
.
reset
()
def
reset
(
self
):
self
.
transfer_stats_accumulator
:
Optional
[
KVConnectorStats
]
=
None
def
observe
(
self
,
transfer_stats_data
:
dict
[
str
,
Any
]):
# Should not be called when a KVConnector is not configured.
assert
self
.
connector_cls
is
not
None
# Called periodically when connector syncs with the scheduler.
# Note that this is not the same as the logging interval.
# We expect transfer_stats_data to be aggregated across all workers and
# consist of observations from a single connector or a MultiConnector.
transfer_stats
=
self
.
connector_cls
.
build_kv_connector_stats
(
transfer_stats_data
)
if
transfer_stats
is
None
:
logger
.
warning_once
(
"The connector %s is collecting stats but "
"does not implement the "
"`build_kv_connector_stats` method. "
"Stats will not be logged."
,
self
.
connector_cls
)
return
if
self
.
transfer_stats_accumulator
is
None
:
self
.
transfer_stats_accumulator
=
transfer_stats
else
:
# Accumulate last interval stats.
self
.
transfer_stats_accumulator
=
\
self
.
transfer_stats_accumulator
.
aggregate
(
transfer_stats
)
def
log
(
self
,
log_fn
=
logger
.
info
):
"""Log transfer metrics periodically, similar to throughput logging"""
if
(
self
.
transfer_stats_accumulator
and
not
self
.
transfer_stats_accumulator
.
is_empty
()):
# Produce a single cumulative stats object for the last time
# interval from the recorded observations.
xfer_metrics
=
self
.
transfer_stats_accumulator
.
reduce
()
xfer_metrics_str
=
", "
.
join
(
f
"
{
k
}
=
{
v
}
"
for
k
,
v
in
xfer_metrics
.
items
())
log_fn
(
"KV Transfer metrics: %s"
,
xfer_metrics_str
)
# Reset metrics for next interval
self
.
reset
()
\ No newline at end of file
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
a3d087ad
...
@@ -9,19 +9,21 @@ import torch
...
@@ -9,19 +9,21 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.outputs
import
KVConnectorOutput
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -33,6 +35,43 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
...
@@ -33,6 +35,43 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
extra_async_saves
:
Optional
[
dict
[
str
,
int
]]
=
None
extra_async_saves
:
Optional
[
dict
[
str
,
int
]]
=
None
@
dataclass
class
MultiKVConnectorStats
(
KVConnectorStats
):
"""
Maintain a dict of KVConnectorStats objects, one for each connector.
This is used to aggregate the stats from all connectors separately.
"""
def
aggregate
(
self
,
other
:
KVConnectorStats
)
->
KVConnectorStats
:
for
connector_id
,
stats
in
other
.
data
.
items
():
if
connector_id
not
in
self
.
data
:
self
[
connector_id
]
=
stats
else
:
assert
isinstance
(
stats
,
type
(
self
.
data
[
connector_id
]))
self
[
connector_id
]
=
self
[
connector_id
].
aggregate
(
stats
)
return
self
def
reset
(
self
):
for
stats
in
self
.
data
.
values
():
stats
.
reset
()
def
reduce
(
self
)
->
dict
[
str
,
Any
]:
# TODO (NickLucche) Adjust for logging on separate lines
return
{
connector_id
:
stats
.
reduce
()
for
connector_id
,
stats
in
self
.
data
.
items
()
}
def
is_empty
(
self
)
->
bool
:
return
all
(
stats
.
is_empty
()
for
stats
in
self
.
data
.
values
())
def
__getitem__
(
self
,
connector_id
:
str
)
->
KVConnectorStats
:
return
self
.
data
[
connector_id
]
def
__setitem__
(
self
,
connector_id
:
str
,
stats
:
KVConnectorStats
):
self
.
data
[
connector_id
]
=
stats
class
MultiConnector
(
KVConnectorBase_V1
):
class
MultiConnector
(
KVConnectorBase_V1
):
"""
"""
A wrapper for using multiple KVConnectors at the same time.
A wrapper for using multiple KVConnectors at the same time.
...
@@ -46,6 +85,7 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -46,6 +85,7 @@ class MultiConnector(KVConnectorBase_V1):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
self
.
_connectors
:
list
[
KVConnectorBase_V1
]
=
[]
self
.
_connectors
:
list
[
KVConnectorBase_V1
]
=
[]
self
.
_ktc_kv_transfer_config
=
[]
ktcs
=
vllm_config
.
kv_transfer_config
.
kv_connector_extra_config
.
get
(
ktcs
=
vllm_config
.
kv_transfer_config
.
kv_connector_extra_config
.
get
(
"connectors"
)
"connectors"
)
assert
ktcs
is
not
None
assert
ktcs
is
not
None
...
@@ -57,6 +97,7 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -57,6 +97,7 @@ class MultiConnector(KVConnectorBase_V1):
**
ktc
,
engine_id
=
engine_id
)
**
ktc
,
engine_id
=
engine_id
)
self
.
_connectors
.
append
(
self
.
_connectors
.
append
(
KVConnectorFactory
.
create_connector
(
temp_config
,
role
))
KVConnectorFactory
.
create_connector
(
temp_config
,
role
))
self
.
_ktc_kv_transfer_config
.
append
(
temp_config
.
kv_transfer_config
)
# A mapping from request id to the index of the connector chosen to
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
# load the request from (if any).
...
@@ -227,7 +268,7 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -227,7 +268,7 @@ class MultiConnector(KVConnectorBase_V1):
return
async_saves
>
0
,
kv_txfer_params
return
async_saves
>
0
,
kv_txfer_params
def
take_events
(
self
)
->
Iterable
[
KVCacheEvent
]:
def
take_events
(
self
)
->
Iterable
[
"
KVCacheEvent
"
]:
for
c
in
self
.
_connectors
:
for
c
in
self
.
_connectors
:
yield
from
c
.
take_events
()
yield
from
c
.
take_events
()
...
@@ -264,3 +305,24 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -264,3 +305,24 @@ class MultiConnector(KVConnectorBase_V1):
f
"(
{
', '
.
join
(
layouts
)
}
)."
f
"(
{
', '
.
join
(
layouts
)
}
)."
f
"All connectors must use the same layout."
)
f
"All connectors must use the same layout."
)
return
next
(
iter
(
layouts
),
None
)
return
next
(
iter
(
layouts
),
None
)
@
classmethod
def
build_kv_connector_stats
(
cls
,
data
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
Optional
[
KVConnectorStats
]:
return
MultiKVConnectorStats
(
data
=
data
)
if
data
is
not
None
\
else
MultiKVConnectorStats
()
def
get_kv_connector_stats
(
self
)
->
Optional
[
MultiKVConnectorStats
]:
# Group connector stats by connector type.
stats_by_connector
:
Optional
[
MultiKVConnectorStats
]
=
None
for
c
in
self
.
_connectors
:
stats
=
c
.
get_kv_connector_stats
()
if
stats
is
None
:
continue
if
stats_by_connector
is
None
:
# Lazy init to allow optional return value.
stats_by_connector
=
MultiKVConnectorStats
()
stats_by_connector
[
c
.
__class__
.
__name__
]
=
stats
return
stats_by_connector
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
a3d087ad
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
contextlib
import
copy
import
logging
import
logging
import
math
import
math
import
queue
import
queue
...
@@ -11,7 +12,7 @@ from collections import defaultdict
...
@@ -11,7 +12,7 @@ from collections import defaultdict
from
collections.abc
import
Iterator
from
collections.abc
import
Iterator
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
import
msgspec
import
msgspec
import
numpy
as
np
import
numpy
as
np
...
@@ -23,6 +24,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend
...
@@ -23,6 +24,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
CopyBlocksOp
,
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
CopyBlocksOp
,
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tp_group
)
get_tp_group
)
...
@@ -33,7 +36,6 @@ from vllm.platforms import _Backend, current_platform
...
@@ -33,7 +36,6 @@ from vllm.platforms import _Backend, current_platform
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.request
import
RequestStatus
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
@@ -206,6 +208,18 @@ class NixlConnector(KVConnectorBase_V1):
...
@@ -206,6 +208,18 @@ class NixlConnector(KVConnectorBase_V1):
assert
self
.
connector_worker
is
not
None
assert
self
.
connector_worker
is
not
None
return
self
.
connector_worker
.
get_finished
()
return
self
.
connector_worker
.
get_finished
()
def
get_kv_connector_stats
(
self
)
->
Optional
[
KVConnectorStats
]:
assert
self
.
connector_worker
is
not
None
return
self
.
connector_worker
.
get_kv_connector_stats
()
@
classmethod
def
build_kv_connector_stats
(
cls
,
data
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
Optional
[
KVConnectorStats
]:
return
NixlKVConnectorStats
(
data
=
data
)
if
data
is
not
None
\
else
NixlKVConnectorStats
()
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
**
kwargs
)
->
None
:
assert
self
.
connector_worker
is
not
None
assert
self
.
connector_worker
is
not
None
...
@@ -377,6 +391,7 @@ class NixlConnectorScheduler:
...
@@ -377,6 +391,7 @@ class NixlConnectorScheduler:
Once a request is finished, determine whether request blocks
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
should be freed now or will be sent asynchronously and freed later.
"""
"""
from
vllm.v1.request
import
RequestStatus
params
=
request
.
kv_transfer_params
params
=
request
.
kv_transfer_params
logger
.
debug
(
logger
.
debug
(
...
@@ -550,6 +565,7 @@ class NixlConnectorWorker:
...
@@ -550,6 +565,7 @@ class NixlConnectorWorker:
# With heterogeneous TP, P must wait for all assigned D TP workers to
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
# finish reading before safely freeing the blocks.
self
.
consumer_notification_counts_by_req
=
defaultdict
[
ReqId
,
int
](
int
)
self
.
consumer_notification_counts_by_req
=
defaultdict
[
ReqId
,
int
](
int
)
self
.
xfer_stats
=
NixlKVConnectorStats
()
def
__del__
(
self
):
def
__del__
(
self
):
"""Cleanup background threads on destruction."""
"""Cleanup background threads on destruction."""
...
@@ -1097,6 +1113,8 @@ class NixlConnectorWorker:
...
@@ -1097,6 +1113,8 @@ class NixlConnectorWorker:
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
if
xfer_state
==
"DONE"
:
if
xfer_state
==
"DONE"
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
# TODO (NickLucche) Get from NIXL telemetry once integrated
self
.
xfer_stats
.
record_transfer
()
elif
xfer_state
==
"PROC"
:
elif
xfer_state
==
"PROC"
:
in_progress
=
True
in_progress
=
True
continue
continue
...
@@ -1248,7 +1266,6 @@ class NixlConnectorWorker:
...
@@ -1248,7 +1266,6 @@ class NixlConnectorWorker:
self
.
nixl_wrapper
.
transfer
(
handle
)
self
.
nixl_wrapper
.
transfer
(
handle
)
# Use handle to check completion in future step().
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time
self
.
_recving_transfers
[
request_id
].
append
(
self
.
_recving_transfers
[
request_id
].
append
(
(
handle
,
time
.
perf_counter
()))
(
handle
,
time
.
perf_counter
()))
...
@@ -1300,6 +1317,15 @@ class NixlConnectorWorker:
...
@@ -1300,6 +1317,15 @@ class NixlConnectorWorker:
block_len
=
self
.
block_len
block_len
=
self
.
block_len
return
block_len
return
block_len
def
get_kv_connector_stats
(
self
)
->
Optional
[
KVConnectorStats
]:
"""
Get the KV transfer stats for the connector.
"""
# Clear stats for next iteration
if
not
self
.
xfer_stats
.
is_empty
():
return
self
.
xfer_stats
.
clone_and_reset
()
return
None
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
zmq_ctx
(
socket_type
:
Any
,
addr
:
str
)
->
Iterator
[
zmq
.
Socket
]:
def
zmq_ctx
(
socket_type
:
Any
,
addr
:
str
)
->
Iterator
[
zmq
.
Socket
]:
...
@@ -1318,3 +1344,39 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
...
@@ -1318,3 +1344,39 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
finally
:
finally
:
if
ctx
is
not
None
:
if
ctx
is
not
None
:
ctx
.
destroy
(
linger
=
0
)
ctx
.
destroy
(
linger
=
0
)
@
dataclass
class
NixlKVConnectorStats
(
KVConnectorStats
):
"""Container for transfer performance metrics"""
def
__post_init__
(
self
):
if
"num_successful_transfers"
not
in
self
.
data
:
self
.
data
[
"num_successful_transfers"
]
=
0
def
reset
(
self
):
self
.
data
=
{
"num_successful_transfers"
:
0
}
def
record_transfer
(
self
):
# TODO: record actual transfer stats when available
self
.
data
[
"num_successful_transfers"
]
+=
1
def
clone_and_reset
(
self
)
->
"NixlKVConnectorStats"
:
old
=
copy
.
copy
(
self
)
self
.
reset
()
return
old
def
is_empty
(
self
)
->
bool
:
return
self
.
data
[
"num_successful_transfers"
]
==
0
def
aggregate
(
self
,
other
:
KVConnectorStats
)
->
KVConnectorStats
:
if
not
other
.
is_empty
():
self
.
data
[
"num_successful_transfers"
]
+=
other
.
data
[
"num_successful_transfers"
]
return
self
def
reduce
(
self
)
->
dict
[
str
,
Union
[
int
,
float
]]:
# TODO: reduce stats to a single value, calculate latency/throughput
return
{
"num_successful_transfers"
:
self
.
data
[
"num_successful_transfers"
]
}
\ No newline at end of file
vllm/v1/core/sched/scheduler.py
View file @
a3d087ad
...
@@ -15,6 +15,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
...
@@ -15,6 +15,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory
)
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
)
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
...
@@ -869,9 +871,12 @@ class Scheduler(SchedulerInterface):
...
@@ -869,9 +871,12 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
pooler_outputs
=
model_runner_output
.
pooler_output
pooler_outputs
=
model_runner_output
.
pooler_output
num_nans_in_logits
=
model_runner_output
.
num_nans_in_logits
num_nans_in_logits
=
model_runner_output
.
num_nans_in_logits
kv_connector_output
=
model_runner_output
.
kv_connector_output
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]]
=
defaultdict
(
list
)
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]]
=
defaultdict
(
list
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
kv_connector_stats
=
(
kv_connector_output
.
kv_connector_stats
if
kv_connector_output
else
None
)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
# the below loop can be a performance bottleneck. We should do our best
...
@@ -1007,7 +1012,8 @@ class Scheduler(SchedulerInterface):
...
@@ -1007,7 +1012,8 @@ class Scheduler(SchedulerInterface):
finished_requests
=
finished_set
)
finished_requests
=
finished_set
)
finished_req_ids
.
clear
()
finished_req_ids
.
clear
()
if
(
stats
:
=
self
.
make_stats
(
spec_decoding_stats
))
is
not
None
:
if
(
stats
:
=
self
.
make_stats
(
spec_decoding_stats
,
kv_connector_stats
))
is
not
None
:
# Return stats to only one of the front-ends.
# Return stats to only one of the front-ends.
if
(
eco
:
=
next
(
iter
(
engine_core_outputs
.
values
()),
None
))
is
None
:
if
(
eco
:
=
next
(
iter
(
engine_core_outputs
.
values
()),
None
))
is
None
:
# We must return the stats even if there are no request
# We must return the stats even if there are no request
...
@@ -1172,20 +1178,21 @@ class Scheduler(SchedulerInterface):
...
@@ -1172,20 +1178,21 @@ class Scheduler(SchedulerInterface):
def
make_stats
(
def
make_stats
(
self
,
self
,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
,
kv_connector_stats
:
Optional
[
KVConnectorStats
]
=
None
,
)
->
Optional
[
SchedulerStats
]:
)
->
Optional
[
SchedulerStats
]:
if
not
self
.
log_stats
:
if
not
self
.
log_stats
:
return
None
return
None
prefix_cache_stats
=
self
.
kv_cache_manager
.
make_prefix_cache_stats
()
prefix_cache_stats
=
self
.
kv_cache_manager
.
make_prefix_cache_stats
()
assert
prefix_cache_stats
is
not
None
assert
prefix_cache_stats
is
not
None
return
SchedulerStats
(
return
SchedulerStats
(
num_running_reqs
=
len
(
self
.
running
),
num_running_reqs
=
len
(
self
.
running
),
num_waiting_reqs
=
len
(
self
.
waiting
),
num_waiting_reqs
=
len
(
self
.
waiting
),
kv_cache_usage
=
self
.
kv_cache_manager
.
usage
,
kv_cache_usage
=
self
.
kv_cache_manager
.
usage
,
prefix_cache_stats
=
prefix_cache_stats
,
prefix_cache_stats
=
prefix_cache_stats
,
spec_decoding_stats
=
spec_decoding_stats
,
spec_decoding_stats
=
spec_decoding_stats
,
num_corrupted_reqs
=
sum
(
req
.
is_output_corrupted
num_corrupted_reqs
=
sum
(
req
.
is_output_corrupted
for
req
in
self
.
running
),
for
req
in
self
.
running
),
)
kv_connector_stats
=
kv_connector_stats
.
data
if
kv_connector_stats
else
None
)
def
make_spec_decoding_stats
(
def
make_spec_decoding_stats
(
self
,
self
,
...
...
vllm/v1/metrics/loggers.py
View file @
a3d087ad
...
@@ -9,6 +9,8 @@ from typing import Callable, Optional, Union
...
@@ -9,6 +9,8 @@ from typing import Callable, Optional, Union
import
prometheus_client
import
prometheus_client
from
vllm.config
import
SupportsMetricsInfo
,
VllmConfig
from
vllm.config
import
SupportsMetricsInfo
,
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorLogging
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
from
vllm.v1.engine
import
FinishReason
...
@@ -59,6 +61,8 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -59,6 +61,8 @@ class LoggingStatLogger(StatLoggerBase):
# TODO: Make the interval configurable.
# TODO: Make the interval configurable.
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
spec_decoding_logging
=
SpecDecodingLogging
()
self
.
spec_decoding_logging
=
SpecDecodingLogging
()
kv_tranfer_config
=
self
.
vllm_config
.
kv_transfer_config
self
.
kv_transfer_logging
=
KVConnectorLogging
(
kv_tranfer_config
)
self
.
last_prompt_throughput
:
float
=
0.0
self
.
last_prompt_throughput
:
float
=
0.0
self
.
last_generation_throughput
:
float
=
0.0
self
.
last_generation_throughput
:
float
=
0.0
...
@@ -97,7 +101,8 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -97,7 +101,8 @@ class LoggingStatLogger(StatLoggerBase):
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_logging
.
observe
(
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
scheduler_stats
.
spec_decoding_stats
)
if
kv_connector_stats
:
=
scheduler_stats
.
kv_connector_stats
:
self
.
kv_transfer_logging
.
observe
(
kv_connector_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
def
log
(
self
):
...
@@ -136,6 +141,7 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -136,6 +141,7 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
)
)
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
kv_transfer_logging
.
log
(
log_fn
=
log_fn
)
def
log_engine_initialized
(
self
):
def
log_engine_initialized
(
self
):
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
...
...
vllm/v1/metrics/stats.py
View file @
a3d087ad
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
import
time
import
time
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
...
@@ -43,6 +43,7 @@ class SchedulerStats:
...
@@ -43,6 +43,7 @@ class SchedulerStats:
default_factory
=
PrefixCacheStats
)
default_factory
=
PrefixCacheStats
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
kv_connector_stats
:
Optional
[
dict
[
str
,
Any
]]
=
None
num_corrupted_reqs
:
int
=
0
num_corrupted_reqs
:
int
=
0
...
...
vllm/v1/outputs.py
View file @
a3d087ad
...
@@ -3,10 +3,14 @@
...
@@ -3,10 +3,14 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
NamedTuple
,
Optional
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
import
torch
import
torch
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
class
LogprobsLists
(
NamedTuple
):
class
LogprobsLists
(
NamedTuple
):
...
@@ -77,6 +81,11 @@ class KVConnectorOutput:
...
@@ -77,6 +81,11 @@ class KVConnectorOutput:
# [req_ids]
# [req_ids]
finished_sending
:
Optional
[
set
[
str
]]
=
None
finished_sending
:
Optional
[
set
[
str
]]
=
None
finished_recving
:
Optional
[
set
[
str
]]
=
None
finished_recving
:
Optional
[
set
[
str
]]
=
None
kv_connector_stats
:
Optional
[
"KVConnectorStats"
]
=
None
def
is_empty
(
self
):
return
(
not
self
.
finished_sending
and
not
self
.
finished_recving
and
not
self
.
kv_connector_stats
)
# ModelRunnerOutput is serialized and sent to the scheduler process.
# ModelRunnerOutput is serialized and sent to the scheduler process.
...
...
vllm/v1/worker/kv_connector_model_runner_mixin.py
View file @
a3d087ad
...
@@ -13,6 +13,8 @@ from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
...
@@ -13,6 +13,8 @@ from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
get_kv_transfer_group
,
get_kv_transfer_group
,
has_kv_transfer_group
)
has_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
,
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
,
...
@@ -119,4 +121,11 @@ class KVConnectorModelRunnerMixin:
...
@@ -119,4 +121,11 @@ class KVConnectorModelRunnerMixin:
output
.
finished_sending
,
output
.
finished_recving
=
(
output
.
finished_sending
,
output
.
finished_recving
=
(
kv_connector
.
get_finished
(
scheduler_output
.
finished_req_ids
))
kv_connector
.
get_finished
(
scheduler_output
.
finished_req_ids
))
kv_connector
.
clear_connector_metadata
()
output
.
kv_connector_stats
=
KVConnectorModelRunnerMixin
.
\
get_kv_connector_stats
()
@
staticmethod
def
get_kv_connector_stats
()
->
Optional
[
KVConnectorStats
]:
if
has_kv_transfer_group
():
return
get_kv_transfer_group
().
get_kv_connector_stats
()
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment