Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a3d087ad
Unverified
Commit
a3d087ad
authored
Sep 19, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Sep 19, 2025
Browse files
[P/D][Nixl] Introduce `KVTransferMetrics` and aggregation strategy (#22188)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
058525b9
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
525 additions
and
25 deletions
+525
-25
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+210
-1
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+18
-3
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+21
-1
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
+100
-0
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+65
-3
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+65
-3
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+17
-10
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+7
-1
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+2
-1
vllm/v1/outputs.py
vllm/v1/outputs.py
+10
-1
vllm/v1/worker/kv_connector_model_runner_mixin.py
vllm/v1/worker/kv_connector_model_runner_mixin.py
+10
-1
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
a3d087ad
...
...
@@ -18,12 +18,18 @@ import torch
from
vllm
import
LLM
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
KVOutputAggregator
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.distributed.kv_transfer.kv_connector.v1.multi_connector
import
(
MultiKVConnectorStats
)
from
vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector
import
(
KVConnectorRole
,
NixlAgentMetadata
,
NixlConnector
,
NixlConnectorMetadata
,
NixlConnectorWorker
)
NixlConnectorWorker
,
NixlKVConnectorStats
)
from
vllm.forward_context
import
ForwardContext
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.outputs
import
KVConnectorOutput
,
ModelRunnerOutput
from
.utils
import
create_request
,
create_scheduler
,
create_vllm_config
...
...
@@ -475,6 +481,209 @@ class TestNixlHandshake:
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then
# the rest of the tests.
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
FakeNixlWrapper
)
def
test_kv_connector_stats
(
dist_init
):
"""Test that KV transfer stats are properly recorded and retrieved."""
vllm_config
=
create_vllm_config
()
# Test worker role in decode server.
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
.
connector_worker
=
FakeNixlConnectorWorker
(
vllm_config
,
connector
.
engine_id
,
hand_shake_latency
=
0
)
# Verify that xfer_stats starts empty
initial_stats
=
connector
.
get_kv_connector_stats
()
assert
initial_stats
is
None
# Create transfer metadata
request_id
=
"test_req_for_stats"
metadata
=
NixlConnectorMetadata
()
metadata
.
add_new_req
(
request_id
=
request_id
,
local_block_ids
=
[
1
,
2
,
3
],
kv_transfer_params
=
{
"remote_block_ids"
:
[
4
,
5
,
6
],
"remote_engine_id"
:
FakeNixlConnectorWorker
.
REMOTE_ENGINE_ID
,
"remote_host"
:
"localhost"
,
"remote_port"
:
1234
,
"remote_tp_size"
:
1
,
})
connector
.
bind_connector_metadata
(
metadata
)
# Start the transfer
dummy_ctx
=
ForwardContext
(
no_compile_layers
=
{},
attn_metadata
=
{},
virtual_engine
=
0
,
)
connector
.
start_load_kv
(
dummy_ctx
)
# Verify stats are recorded after transfer is complete
max_iterations
=
2
# Clear metadata before start_load_kv to prevent reprocessing same request
connector
.
bind_connector_metadata
(
NixlConnectorMetadata
())
for
_
in
range
(
max_iterations
):
# Need to call start_load_kv to process completed handshakes
connector
.
start_load_kv
(
dummy_ctx
)
_
,
done_recving
=
connector
.
get_finished
(
finished_req_ids
=
set
())
if
len
(
done_recving
)
>
0
and
request_id
in
done_recving
:
break
time
.
sleep
(
0.1
)
# Small delay to allow background handshake to complete
else
:
assert
"Transfer did not complete within expected iterations"
# Now check that stats were recorded
stats_after_transfer
=
connector
.
get_kv_connector_stats
()
assert
isinstance
(
stats_after_transfer
,
NixlKVConnectorStats
)
# Verify stats values are recorded
assert
not
stats_after_transfer
.
is_empty
()
assert
stats_after_transfer
.
data
[
"num_successful_transfers"
]
==
1
# Verify stats are reset after retrieval
stats_after_reset
=
connector
.
get_kv_connector_stats
()
assert
stats_after_reset
is
None
def
test_kv_connector_stats_aggregation
():
"""
Test KV transfer stats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
# Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
# done in MultiprocExecutor.execute_model
aggregator
=
KVOutputAggregator
(
world_size
=
3
)
# Create stats for multiple workers with different transfer patterns
worker1_stats
=
NixlKVConnectorStats
()
worker2_stats
=
NixlKVConnectorStats
()
worker3_stats
=
NixlKVConnectorStats
()
# Record different transfers on each worker
# Worker 1: 2 transfers
worker1_stats
.
record_transfer
()
worker1_stats
.
record_transfer
()
# Worker 2: 1 transfer
worker2_stats
.
record_transfer
()
# Worker 3: 3 transfers
worker3_stats
.
record_transfer
()
worker3_stats
.
record_transfer
()
worker3_stats
.
record_transfer
()
# Create ModelRunnerOutput instances for each worker
worker_outputs
=
[]
for
i
,
worker_stats
in
enumerate
(
[
worker1_stats
,
worker2_stats
,
worker3_stats
]):
output
=
ModelRunnerOutput
(
req_ids
=
[
f
"req_
{
i
}
"
],
req_id_to_index
=
{
f
"req_
{
i
}
"
:
0
},
sampled_token_ids
=
[[
123
]],
# dummy token
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
set
([
f
"req_
{
i
}
_send"
])
if
i
<
2
else
None
,
# Workers 0,1 finished sending
finished_recving
=
set
([
f
"req_
{
i
}
_recv"
])
if
i
>
0
else
None
,
# Workers 1,2 finished receiving
kv_connector_stats
=
worker_stats
,
))
worker_outputs
.
append
(
output
)
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
aggregated_output
=
aggregator
.
aggregate
(
worker_outputs
,
output_rank
=
0
)
kv_connector_stats
=
\
aggregated_output
.
kv_connector_output
.
kv_connector_stats
assert
isinstance
(
kv_connector_stats
,
NixlKVConnectorStats
)
# Number of total transfers across all workers.
assert
kv_connector_stats
.
data
[
"num_successful_transfers"
]
==
6
def
test_multi_kv_connector_stats_aggregation
():
"""
Test MultiKVConnectorStats aggregation across TP ranks using
KVOutputAggregator (used by MultiprocExecutor).
"""
aggregator
=
KVOutputAggregator
(
world_size
=
3
)
from
dataclasses
import
dataclass
@
dataclass
class
FooKVConnectorStats
(
KVConnectorStats
):
def
reset
(
self
):
self
.
data
=
{
"num_foo_transfers"
:
0
}
def
record_transfer
(
self
):
if
"num_foo_transfers"
not
in
self
.
data
:
self
.
data
[
"num_foo_transfers"
]
=
0
self
.
data
[
"num_foo_transfers"
]
+=
1
def
is_empty
(
self
)
->
bool
:
return
self
.
data
[
"num_foo_transfers"
]
==
0
def
aggregate
(
self
,
other
:
"FooKVConnectorStats"
)
->
"FooKVConnectorStats"
:
if
not
other
.
is_empty
():
self
.
data
[
"num_foo_transfers"
]
+=
other
.
data
[
"num_foo_transfers"
]
return
self
def
make_multi_stats
(
nixl_count
:
int
,
foo_count
:
int
)
->
MultiKVConnectorStats
:
data
:
dict
[
str
,
KVConnectorStats
]
=
{}
if
nixl_count
>
0
:
nixl_stats
=
NixlKVConnectorStats
()
for
_
in
range
(
nixl_count
):
nixl_stats
.
record_transfer
()
data
[
"NixlConnector"
]
=
nixl_stats
if
foo_count
>
0
:
foo_stats
=
FooKVConnectorStats
()
for
_
in
range
(
foo_count
):
foo_stats
.
record_transfer
()
data
[
"FooConnector"
]
=
foo_stats
return
MultiKVConnectorStats
(
data
=
data
)
# Create heterogeneous stats across 3 workers
worker_patterns
=
[(
2
,
1
),
(
3
,
0
),
(
0
,
5
)]
# (Nixl, Foo)
worker_outputs
:
list
[
ModelRunnerOutput
]
=
[]
for
i
,
(
nixl
,
foo
)
in
enumerate
(
worker_patterns
):
stats
=
make_multi_stats
(
nixl
,
foo
)
output
=
ModelRunnerOutput
(
req_ids
=
[
f
"req_
{
i
}
"
],
req_id_to_index
=
{
f
"req_
{
i
}
"
:
0
},
sampled_token_ids
=
[[
123
]],
logprobs
=
None
,
prompt_logprobs_dict
=
{},
pooler_output
=
[
None
],
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
set
([
f
"req_
{
i
}
_send"
])
if
i
<
2
else
None
,
finished_recving
=
set
([
f
"req_
{
i
}
_recv"
])
if
i
>
0
else
None
,
kv_connector_stats
=
stats
,
),
)
worker_outputs
.
append
(
output
)
aggregated_output
=
aggregator
.
aggregate
(
worker_outputs
,
output_rank
=
0
)
kv_connector_stats
=
\
aggregated_output
.
kv_connector_output
.
kv_connector_stats
assert
isinstance
(
kv_connector_stats
,
MultiKVConnectorStats
)
# Validate per-connector totals across workers
assert
kv_connector_stats
[
"NixlConnector"
].
data
[
"num_successful_transfers"
]
==
5
assert
kv_connector_stats
[
"FooConnector"
].
data
[
"num_foo_transfers"
]
==
6
@
pytest
.
mark
.
parametrize
(
"distributed_executor_backend"
,
[
"ray"
,
None
])
@
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper"
,
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
a3d087ad
...
...
@@ -129,7 +129,7 @@ class KVOutputAggregator:
def
aggregate
(
self
,
outputs
:
list
[
ModelRunnerOutput
],
output_rank
:
int
=
0
)
->
ModelRunnerOutput
:
#
a
ggregate kv_connector_output from all workers
#
A
ggregate kv_connector_output from all workers
def
update_finished_set
(
req_ids
:
Optional
[
set
[
str
]],
remaining_count_dict
:
dict
[
str
,
int
],
...
...
@@ -142,8 +142,9 @@ class KVOutputAggregator:
finished_sending
=
set
[
str
]()
finished_recving
=
set
[
str
]()
for
output
in
outputs
:
output
=
output
.
kv_connector_output
aggregated_kv_connector_stats
=
None
for
model_runner_output
in
outputs
:
output
=
model_runner_output
.
kv_connector_output
if
not
output
:
continue
update_finished_set
(
output
.
finished_sending
,
...
...
@@ -151,12 +152,26 @@ class KVOutputAggregator:
update_finished_set
(
output
.
finished_recving
,
self
.
_recv_remaining_count
,
finished_recving
)
# Aggregate kv_connector_stats from all workers.
if
aggregated_kv_connector_stats
is
None
:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats
=
output
.
kv_connector_stats
elif
kv_connector_stats
:
=
output
.
kv_connector_stats
:
if
aggregated_kv_connector_stats
is
None
:
aggregated_kv_connector_stats
=
kv_connector_stats
else
:
assert
isinstance
(
aggregated_kv_connector_stats
,
type
(
kv_connector_stats
))
aggregated_kv_connector_stats
=
\
aggregated_kv_connector_stats
.
aggregate
(
kv_connector_stats
)
# select output of the worker specified by output_rank
output
=
outputs
[
output_rank
]
output
.
kv_connector_output
=
KVConnectorOutput
(
finished_sending
=
finished_sending
or
None
,
finished_recving
=
finished_recving
or
None
,
kv_connector_stats
=
aggregated_kv_connector_stats
or
None
,
)
return
output
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
a3d087ad
...
...
@@ -49,6 +49,8 @@ if TYPE_CHECKING:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
...
...
@@ -235,6 +237,12 @@ class KVConnectorBase_V1(ABC):
"""
return
None
def
get_kv_connector_stats
(
self
)
->
Optional
[
"KVConnectorStats"
]:
"""
Get the KV connector stats collected during the last interval.
"""
return
None
# ==============================
# Scheduler-side methods
# ==============================
...
...
@@ -365,4 +373,16 @@ class KVConnectorBase_V1(ABC):
int: expected sending or receiving completion count.
"""
return
None
\ No newline at end of file
return
None
@
classmethod
def
build_kv_connector_stats
(
cls
,
data
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
Optional
[
"KVConnectorStats"
]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return
None
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
0 → 100644
View file @
a3d087ad
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Optional
,
Union
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_transfer_state
import
(
has_kv_transfer_group
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
@
dataclass
class
KVConnectorStats
:
"""
Base class for KV Connector Stats, a container for transfer performance
metrics or otherwise important telemetry from the connector.
All sub-classes need to be serializable as stats are sent from worker to
logger process.
"""
data
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
def
reset
(
self
):
"""Reset the stats, clear the state."""
raise
NotImplementedError
def
aggregate
(
self
,
other
:
"KVConnectorStats"
)
->
"KVConnectorStats"
:
"""
Aggregate stats with another `KVConnectorStats` object.
"""
raise
NotImplementedError
def
reduce
(
self
)
->
dict
[
str
,
Union
[
int
,
float
]]:
"""
Reduce the observations collected during a time interval to one or
more representative values (eg avg/median/sum of the series).
This is meant to be called by the logger to produce a summary of the
stats for the last time interval.
"""
raise
NotImplementedError
def
is_empty
(
self
)
->
bool
:
"""Return True if the stats are empty."""
raise
NotImplementedError
class
KVConnectorLogging
:
def
__init__
(
self
,
kv_tranfer_config
:
KVTransferConfig
):
# This should be called on frontend process.
assert
not
has_kv_transfer_group
()
# Instantiate the connector's stats class.
if
kv_tranfer_config
and
kv_tranfer_config
.
kv_connector
:
self
.
connector_cls
=
KVConnectorFactory
.
get_connector_class
(
kv_tranfer_config
)
self
.
reset
()
def
reset
(
self
):
self
.
transfer_stats_accumulator
:
Optional
[
KVConnectorStats
]
=
None
def
observe
(
self
,
transfer_stats_data
:
dict
[
str
,
Any
]):
# Should not be called when a KVConnector is not configured.
assert
self
.
connector_cls
is
not
None
# Called periodically when connector syncs with the scheduler.
# Note that this is not the same as the logging interval.
# We expect transfer_stats_data to be aggregated across all workers and
# consist of observations from a single connector or a MultiConnector.
transfer_stats
=
self
.
connector_cls
.
build_kv_connector_stats
(
transfer_stats_data
)
if
transfer_stats
is
None
:
logger
.
warning_once
(
"The connector %s is collecting stats but "
"does not implement the "
"`build_kv_connector_stats` method. "
"Stats will not be logged."
,
self
.
connector_cls
)
return
if
self
.
transfer_stats_accumulator
is
None
:
self
.
transfer_stats_accumulator
=
transfer_stats
else
:
# Accumulate last interval stats.
self
.
transfer_stats_accumulator
=
\
self
.
transfer_stats_accumulator
.
aggregate
(
transfer_stats
)
def
log
(
self
,
log_fn
=
logger
.
info
):
"""Log transfer metrics periodically, similar to throughput logging"""
if
(
self
.
transfer_stats_accumulator
and
not
self
.
transfer_stats_accumulator
.
is_empty
()):
# Produce a single cumulative stats object for the last time
# interval from the recorded observations.
xfer_metrics
=
self
.
transfer_stats_accumulator
.
reduce
()
xfer_metrics_str
=
", "
.
join
(
f
"
{
k
}
=
{
v
}
"
for
k
,
v
in
xfer_metrics
.
items
())
log_fn
(
"KV Transfer metrics: %s"
,
xfer_metrics_str
)
# Reset metrics for next interval
self
.
reset
()
\ No newline at end of file
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
a3d087ad
...
...
@@ -9,19 +9,21 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
...
...
@@ -33,6 +35,43 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
extra_async_saves
:
Optional
[
dict
[
str
,
int
]]
=
None
@
dataclass
class
MultiKVConnectorStats
(
KVConnectorStats
):
"""
Maintain a dict of KVConnectorStats objects, one for each connector.
This is used to aggregate the stats from all connectors separately.
"""
def
aggregate
(
self
,
other
:
KVConnectorStats
)
->
KVConnectorStats
:
for
connector_id
,
stats
in
other
.
data
.
items
():
if
connector_id
not
in
self
.
data
:
self
[
connector_id
]
=
stats
else
:
assert
isinstance
(
stats
,
type
(
self
.
data
[
connector_id
]))
self
[
connector_id
]
=
self
[
connector_id
].
aggregate
(
stats
)
return
self
def
reset
(
self
):
for
stats
in
self
.
data
.
values
():
stats
.
reset
()
def
reduce
(
self
)
->
dict
[
str
,
Any
]:
# TODO (NickLucche) Adjust for logging on separate lines
return
{
connector_id
:
stats
.
reduce
()
for
connector_id
,
stats
in
self
.
data
.
items
()
}
def
is_empty
(
self
)
->
bool
:
return
all
(
stats
.
is_empty
()
for
stats
in
self
.
data
.
values
())
def
__getitem__
(
self
,
connector_id
:
str
)
->
KVConnectorStats
:
return
self
.
data
[
connector_id
]
def
__setitem__
(
self
,
connector_id
:
str
,
stats
:
KVConnectorStats
):
self
.
data
[
connector_id
]
=
stats
class
MultiConnector
(
KVConnectorBase_V1
):
"""
A wrapper for using multiple KVConnectors at the same time.
...
...
@@ -46,6 +85,7 @@ class MultiConnector(KVConnectorBase_V1):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
self
.
_connectors
:
list
[
KVConnectorBase_V1
]
=
[]
self
.
_ktc_kv_transfer_config
=
[]
ktcs
=
vllm_config
.
kv_transfer_config
.
kv_connector_extra_config
.
get
(
"connectors"
)
assert
ktcs
is
not
None
...
...
@@ -57,6 +97,7 @@ class MultiConnector(KVConnectorBase_V1):
**
ktc
,
engine_id
=
engine_id
)
self
.
_connectors
.
append
(
KVConnectorFactory
.
create_connector
(
temp_config
,
role
))
self
.
_ktc_kv_transfer_config
.
append
(
temp_config
.
kv_transfer_config
)
# A mapping from request id to the index of the connector chosen to
# load the request from (if any).
...
...
@@ -227,7 +268,7 @@ class MultiConnector(KVConnectorBase_V1):
return
async_saves
>
0
,
kv_txfer_params
def
take_events
(
self
)
->
Iterable
[
KVCacheEvent
]:
def
take_events
(
self
)
->
Iterable
[
"
KVCacheEvent
"
]:
for
c
in
self
.
_connectors
:
yield
from
c
.
take_events
()
...
...
@@ -264,3 +305,24 @@ class MultiConnector(KVConnectorBase_V1):
f
"(
{
', '
.
join
(
layouts
)
}
)."
f
"All connectors must use the same layout."
)
return
next
(
iter
(
layouts
),
None
)
@
classmethod
def
build_kv_connector_stats
(
cls
,
data
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
Optional
[
KVConnectorStats
]:
return
MultiKVConnectorStats
(
data
=
data
)
if
data
is
not
None
\
else
MultiKVConnectorStats
()
def
get_kv_connector_stats
(
self
)
->
Optional
[
MultiKVConnectorStats
]:
# Group connector stats by connector type.
stats_by_connector
:
Optional
[
MultiKVConnectorStats
]
=
None
for
c
in
self
.
_connectors
:
stats
=
c
.
get_kv_connector_stats
()
if
stats
is
None
:
continue
if
stats_by_connector
is
None
:
# Lazy init to allow optional return value.
stats_by_connector
=
MultiKVConnectorStats
()
stats_by_connector
[
c
.
__class__
.
__name__
]
=
stats
return
stats_by_connector
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
a3d087ad
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
contextlib
import
copy
import
logging
import
math
import
queue
...
...
@@ -11,7 +12,7 @@ from collections import defaultdict
from
collections.abc
import
Iterator
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
import
msgspec
import
numpy
as
np
...
...
@@ -23,6 +24,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
CopyBlocksOp
,
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tp_group
)
...
...
@@ -33,7 +36,6 @@ from vllm.platforms import _Backend, current_platform
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
from
vllm.v1.attention.backends.utils
import
get_kv_cache_layout
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.request
import
RequestStatus
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
...
...
@@ -206,6 +208,18 @@ class NixlConnector(KVConnectorBase_V1):
assert
self
.
connector_worker
is
not
None
return
self
.
connector_worker
.
get_finished
()
def
get_kv_connector_stats
(
self
)
->
Optional
[
KVConnectorStats
]:
assert
self
.
connector_worker
is
not
None
return
self
.
connector_worker
.
get_kv_connector_stats
()
@
classmethod
def
build_kv_connector_stats
(
cls
,
data
:
Optional
[
dict
[
str
,
Any
]]
=
None
)
->
Optional
[
KVConnectorStats
]:
return
NixlKVConnectorStats
(
data
=
data
)
if
data
is
not
None
\
else
NixlKVConnectorStats
()
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
assert
self
.
connector_worker
is
not
None
...
...
@@ -377,6 +391,7 @@ class NixlConnectorScheduler:
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
from
vllm.v1.request
import
RequestStatus
params
=
request
.
kv_transfer_params
logger
.
debug
(
...
...
@@ -550,6 +565,7 @@ class NixlConnectorWorker:
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
self
.
consumer_notification_counts_by_req
=
defaultdict
[
ReqId
,
int
](
int
)
self
.
xfer_stats
=
NixlKVConnectorStats
()
def
__del__
(
self
):
"""Cleanup background threads on destruction."""
...
...
@@ -1097,6 +1113,8 @@ class NixlConnectorWorker:
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
if
xfer_state
==
"DONE"
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
# TODO (NickLucche) Get from NIXL telemetry once integrated
self
.
xfer_stats
.
record_transfer
()
elif
xfer_state
==
"PROC"
:
in_progress
=
True
continue
...
...
@@ -1248,7 +1266,6 @@ class NixlConnectorWorker:
self
.
nixl_wrapper
.
transfer
(
handle
)
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time
self
.
_recving_transfers
[
request_id
].
append
(
(
handle
,
time
.
perf_counter
()))
...
...
@@ -1300,6 +1317,15 @@ class NixlConnectorWorker:
block_len
=
self
.
block_len
return
block_len
def
get_kv_connector_stats
(
self
)
->
Optional
[
KVConnectorStats
]:
"""
Get the KV transfer stats for the connector.
"""
# Clear stats for next iteration
if
not
self
.
xfer_stats
.
is_empty
():
return
self
.
xfer_stats
.
clone_and_reset
()
return
None
@
contextlib
.
contextmanager
def
zmq_ctx
(
socket_type
:
Any
,
addr
:
str
)
->
Iterator
[
zmq
.
Socket
]:
...
...
@@ -1318,3 +1344,39 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
finally
:
if
ctx
is
not
None
:
ctx
.
destroy
(
linger
=
0
)
@
dataclass
class
NixlKVConnectorStats
(
KVConnectorStats
):
"""Container for transfer performance metrics"""
def
__post_init__
(
self
):
if
"num_successful_transfers"
not
in
self
.
data
:
self
.
data
[
"num_successful_transfers"
]
=
0
def
reset
(
self
):
self
.
data
=
{
"num_successful_transfers"
:
0
}
def
record_transfer
(
self
):
# TODO: record actual transfer stats when available
self
.
data
[
"num_successful_transfers"
]
+=
1
def
clone_and_reset
(
self
)
->
"NixlKVConnectorStats"
:
old
=
copy
.
copy
(
self
)
self
.
reset
()
return
old
def
is_empty
(
self
)
->
bool
:
return
self
.
data
[
"num_successful_transfers"
]
==
0
def
aggregate
(
self
,
other
:
KVConnectorStats
)
->
KVConnectorStats
:
if
not
other
.
is_empty
():
self
.
data
[
"num_successful_transfers"
]
+=
other
.
data
[
"num_successful_transfers"
]
return
self
def
reduce
(
self
)
->
dict
[
str
,
Union
[
int
,
float
]]:
# TODO: reduce stats to a single value, calculate latency/throughput
return
{
"num_successful_transfers"
:
self
.
data
[
"num_successful_transfers"
]
}
\ No newline at end of file
vllm/v1/core/sched/scheduler.py
View file @
a3d087ad
...
...
@@ -15,6 +15,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
)
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalRegistry
from
vllm.v1.core.encoder_cache_manager
import
(
EncoderCacheManager
,
...
...
@@ -869,9 +871,12 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
pooler_outputs
=
model_runner_output
.
pooler_output
num_nans_in_logits
=
model_runner_output
.
num_nans_in_logits
kv_connector_output
=
model_runner_output
.
kv_connector_output
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]]
=
defaultdict
(
list
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
kv_connector_stats
=
(
kv_connector_output
.
kv_connector_stats
if
kv_connector_output
else
None
)
# NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more,
# the below loop can be a performance bottleneck. We should do our best
...
...
@@ -1007,7 +1012,8 @@ class Scheduler(SchedulerInterface):
finished_requests
=
finished_set
)
finished_req_ids
.
clear
()
if
(
stats
:
=
self
.
make_stats
(
spec_decoding_stats
))
is
not
None
:
if
(
stats
:
=
self
.
make_stats
(
spec_decoding_stats
,
kv_connector_stats
))
is
not
None
:
# Return stats to only one of the front-ends.
if
(
eco
:
=
next
(
iter
(
engine_core_outputs
.
values
()),
None
))
is
None
:
# We must return the stats even if there are no request
...
...
@@ -1172,20 +1178,21 @@ class Scheduler(SchedulerInterface):
def
make_stats
(
self
,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
,
kv_connector_stats
:
Optional
[
KVConnectorStats
]
=
None
,
)
->
Optional
[
SchedulerStats
]:
if
not
self
.
log_stats
:
return
None
prefix_cache_stats
=
self
.
kv_cache_manager
.
make_prefix_cache_stats
()
assert
prefix_cache_stats
is
not
None
return
SchedulerStats
(
num_
runn
ing_reqs
=
len
(
self
.
runn
ing
),
num_waiting_reqs
=
len
(
self
.
waiting
)
,
kv_cache_usage
=
self
.
kv_cache_manager
.
usage
,
prefix_cache_stats
=
prefix_cache
_stats
,
spec_decoding_stats
=
spec_decoding_stats
,
num_corrupted_reqs
=
sum
(
req
.
is_output_corrupted
for
req
in
self
.
running
),
)
return
SchedulerStats
(
num_running_reqs
=
len
(
self
.
running
),
num_
wait
ing_reqs
=
len
(
self
.
wait
ing
),
kv_cache_usage
=
self
.
kv_cache_manager
.
usage
,
prefix_cache_stats
=
prefix_cache_stats
,
spec_decoding_stats
=
spec_decoding
_stats
,
num_corrupted_reqs
=
sum
(
req
.
is_output_corrupted
for
req
in
self
.
running
),
kv_connector_stats
=
kv_connector_stats
.
data
if
kv_connector_stats
else
None
)
def
make_spec_decoding_stats
(
self
,
...
...
vllm/v1/metrics/loggers.py
View file @
a3d087ad
...
...
@@ -9,6 +9,8 @@ from typing import Callable, Optional, Union
import
prometheus_client
from
vllm.config
import
SupportsMetricsInfo
,
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorLogging
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
PrefixCachingMetrics
from
vllm.v1.engine
import
FinishReason
...
...
@@ -59,6 +61,8 @@ class LoggingStatLogger(StatLoggerBase):
# TODO: Make the interval configurable.
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
spec_decoding_logging
=
SpecDecodingLogging
()
kv_tranfer_config
=
self
.
vllm_config
.
kv_transfer_config
self
.
kv_transfer_logging
=
KVConnectorLogging
(
kv_tranfer_config
)
self
.
last_prompt_throughput
:
float
=
0.0
self
.
last_generation_throughput
:
float
=
0.0
...
...
@@ -97,7 +101,8 @@ class LoggingStatLogger(StatLoggerBase):
if
scheduler_stats
.
spec_decoding_stats
is
not
None
:
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
if
kv_connector_stats
:
=
scheduler_stats
.
kv_connector_stats
:
self
.
kv_transfer_logging
.
observe
(
kv_connector_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
...
...
@@ -136,6 +141,7 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
)
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
kv_transfer_logging
.
log
(
log_fn
=
log_fn
)
def
log_engine_initialized
(
self
):
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
...
...
vllm/v1/metrics/stats.py
View file @
a3d087ad
...
...
@@ -3,7 +3,7 @@
import
time
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
...
...
@@ -43,6 +43,7 @@ class SchedulerStats:
default_factory
=
PrefixCacheStats
)
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
kv_connector_stats
:
Optional
[
dict
[
str
,
Any
]]
=
None
num_corrupted_reqs
:
int
=
0
...
...
vllm/v1/outputs.py
View file @
a3d087ad
...
...
@@ -3,10 +3,14 @@
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
NamedTuple
,
Optional
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
Optional
import
torch
if
TYPE_CHECKING
:
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
class
LogprobsLists
(
NamedTuple
):
...
...
@@ -77,6 +81,11 @@ class KVConnectorOutput:
# [req_ids]
finished_sending
:
Optional
[
set
[
str
]]
=
None
finished_recving
:
Optional
[
set
[
str
]]
=
None
kv_connector_stats
:
Optional
[
"KVConnectorStats"
]
=
None
def
is_empty
(
self
):
return
(
not
self
.
finished_sending
and
not
self
.
finished_recving
and
not
self
.
kv_connector_stats
)
# ModelRunnerOutput is serialized and sent to the scheduler process.
...
...
vllm/v1/worker/kv_connector_model_runner_mixin.py
View file @
a3d087ad
...
...
@@ -13,6 +13,8 @@ from vllm.distributed.kv_transfer import (ensure_kv_transfer_shutdown,
get_kv_transfer_group
,
has_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBase
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorStats
)
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
,
...
...
@@ -119,4 +121,11 @@ class KVConnectorModelRunnerMixin:
output
.
finished_sending
,
output
.
finished_recving
=
(
kv_connector
.
get_finished
(
scheduler_output
.
finished_req_ids
))
kv_connector
.
clear_connector_metadata
()
output
.
kv_connector_stats
=
KVConnectorModelRunnerMixin
.
\
get_kv_connector_stats
()
@
staticmethod
def
get_kv_connector_stats
()
->
Optional
[
KVConnectorStats
]:
if
has_kv_transfer_group
():
return
get_kv_transfer_group
().
get_kv_connector_stats
()
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment