Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
48f30902
Unverified
Commit
48f30902
authored
Oct 03, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Oct 03, 2025
Browse files
[NIXL][Misc] Expose metrics from NIXL for logging to CLI (#25388)
Signed-off-by:
NickLucche
<
nlucches@redhat.com
>
parent
0e93ac0b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
127 additions
and
28 deletions
+127
-28
requirements/kv_connectors.txt
requirements/kv_connectors.txt
+1
-1
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+54
-11
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+69
-13
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+3
-3
No files found.
requirements/kv_connectors.txt
View file @
48f30902
lmcache
lmcache
nixl >= 0.
5.1
# Required for disaggregated prefill
nixl >= 0.
6.0
# Required for disaggregated prefill
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
48f30902
...
@@ -57,6 +57,26 @@ def clear_kv_transfer():
...
@@ -57,6 +57,26 @@ def clear_kv_transfer():
ensure_kv_transfer_shutdown
()
ensure_kv_transfer_shutdown
()
def
get_default_xfer_telemetry
(
xferDurationS
:
float
=
1
,
postDurationS
:
float
=
1
,
totalBytes
:
int
=
1
,
descCount
:
int
=
1
)
->
dict
:
class
AttributeDict
(
dict
):
__slots__
=
()
__getattr__
=
dict
.
__getitem__
__setattr__
=
dict
.
__setitem__
# type: ignore[assignment]
# We can't instantiate nixlXferTelemetry because it's read only and
# ray env does not have NIXL, so we must fake it
return
AttributeDict
(
xferDuration
=
xferDurationS
*
1e6
,
# in us
postDuration
=
postDurationS
*
1e6
,
# in us
totalBytes
=
totalBytes
,
descCount
=
descCount
,
)
class
FakeNixlWrapper
:
class
FakeNixlWrapper
:
"""Mock implementation of NixlWrapper for testing.
"""Mock implementation of NixlWrapper for testing.
...
@@ -132,6 +152,9 @@ class FakeNixlWrapper:
...
@@ -132,6 +152,9 @@ class FakeNixlWrapper:
def
transfer
(
self
,
handle
:
int
)
->
str
:
def
transfer
(
self
,
handle
:
int
)
->
str
:
return
"PROC"
return
"PROC"
def
get_xfer_telemetry
(
self
,
handle
:
int
)
->
dict
:
return
get_default_xfer_telemetry
()
############################################################
############################################################
# Follow are for changing the behavior during testing.
# Follow are for changing the behavior during testing.
############################################################
############################################################
...
@@ -169,6 +192,11 @@ nixl_agent = FakeNixlWrapper
...
@@ -169,6 +192,11 @@ nixl_agent = FakeNixlWrapper
with
open
(
os
.
path
.
join
(
pkg_root
,
"__init__.py"
),
"w"
)
as
f
:
with
open
(
os
.
path
.
join
(
pkg_root
,
"__init__.py"
),
"w"
)
as
f
:
f
.
write
(
stub
)
f
.
write
(
stub
)
# Mock nixlXferTelemetry class
pkg_root2
=
os
.
path
.
join
(
td
,
"nixl"
,
"_bindings"
)
os
.
makedirs
(
pkg_root2
,
exist_ok
=
True
)
with
open
(
os
.
path
.
join
(
pkg_root2
,
"__init__.py"
),
"w"
)
as
f
:
f
.
write
(
"class nixlXferTelemetry: pass"
)
# touch parent package
# touch parent package
open
(
os
.
path
.
join
(
td
,
"nixl"
,
"__init__.py"
),
"w"
).
close
()
open
(
os
.
path
.
join
(
td
,
"nixl"
,
"__init__.py"
),
"w"
).
close
()
yield
td
yield
td
...
@@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init):
...
@@ -575,7 +603,7 @@ def test_kv_connector_stats(dist_init):
# Verify stats values are recorded
# Verify stats values are recorded
assert
not
stats_after_transfer
.
is_empty
()
assert
not
stats_after_transfer
.
is_empty
()
assert
stats_after_transfer
.
data
[
"
num_successful_transfers
"
]
==
1
assert
stats_after_transfer
.
num_successful_transfers
==
1
# Verify stats are reset after retrieval
# Verify stats are reset after retrieval
stats_after_reset
=
connector
.
get_kv_connector_stats
()
stats_after_reset
=
connector
.
get_kv_connector_stats
()
...
@@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation():
...
@@ -599,16 +627,21 @@ def test_kv_connector_stats_aggregation():
# Record different transfers on each worker
# Record different transfers on each worker
# Worker 1: 2 transfers
# Worker 1: 2 transfers
worker1_stats
.
record_transfer
()
stats
=
get_default_xfer_telemetry
()
worker1_stats
.
record_transfer
()
worker1_stats
.
record_transfer
(
stats
)
worker1_stats
.
record_transfer
(
stats
)
# Worker 2: 1 transfer
# Worker 2: 1 transfer
worker2_stats
.
record_transfer
()
worker2_stats
.
record_transfer
(
stats
)
# Worker 3: 3 transfers
# Worker 3: 3 transfers
worker3_stats
.
record_transfer
()
stats
=
get_default_xfer_telemetry
(
xferDurationS
=
2
,
worker3_stats
.
record_transfer
()
postDurationS
=
2
,
worker3_stats
.
record_transfer
()
totalBytes
=
2
,
descCount
=
2
)
worker3_stats
.
record_transfer
(
stats
)
worker3_stats
.
record_transfer
(
stats
)
worker3_stats
.
record_transfer
(
stats
)
# Create ModelRunnerOutput instances for each worker
# Create ModelRunnerOutput instances for each worker
worker_outputs
=
[]
worker_outputs
=
[]
...
@@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation():
...
@@ -636,7 +669,12 @@ def test_kv_connector_stats_aggregation():
aggregated_output
.
kv_connector_output
.
kv_connector_stats
aggregated_output
.
kv_connector_output
.
kv_connector_stats
assert
isinstance
(
kv_connector_stats
,
NixlKVConnectorStats
)
assert
isinstance
(
kv_connector_stats
,
NixlKVConnectorStats
)
# Number of total transfers across all workers.
# Number of total transfers across all workers.
assert
kv_connector_stats
.
data
[
"num_successful_transfers"
]
==
6
assert
kv_connector_stats
.
num_successful_transfers
==
6
# Logging proc, call reduce() to get CLI-friendly stats.
cli_stats
=
kv_connector_stats
.
reduce
()
assert
cli_stats
[
"Avg xfer time (ms)"
]
==
1500.0
assert
cli_stats
[
"Avg post time (ms)"
]
==
1500.0
assert
cli_stats
[
"Avg number of descriptors"
]
==
1.5
def
test_multi_kv_connector_stats_aggregation
():
def
test_multi_kv_connector_stats_aggregation
():
...
@@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation():
...
@@ -649,6 +687,7 @@ def test_multi_kv_connector_stats_aggregation():
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
# Mock a KVConnectorStats class for testing aggregation over connectors.
@
dataclass
@
dataclass
class
FooKVConnectorStats
(
KVConnectorStats
):
class
FooKVConnectorStats
(
KVConnectorStats
):
...
@@ -676,7 +715,7 @@ def test_multi_kv_connector_stats_aggregation():
...
@@ -676,7 +715,7 @@ def test_multi_kv_connector_stats_aggregation():
if
nixl_count
>
0
:
if
nixl_count
>
0
:
nixl_stats
=
NixlKVConnectorStats
()
nixl_stats
=
NixlKVConnectorStats
()
for
_
in
range
(
nixl_count
):
for
_
in
range
(
nixl_count
):
nixl_stats
.
record_transfer
()
nixl_stats
.
record_transfer
(
get_default_xfer_telemetry
()
)
data
[
"NixlConnector"
]
=
nixl_stats
data
[
"NixlConnector"
]
=
nixl_stats
if
foo_count
>
0
:
if
foo_count
>
0
:
foo_stats
=
FooKVConnectorStats
()
foo_stats
=
FooKVConnectorStats
()
...
@@ -712,8 +751,10 @@ def test_multi_kv_connector_stats_aggregation():
...
@@ -712,8 +751,10 @@ def test_multi_kv_connector_stats_aggregation():
assert
isinstance
(
kv_connector_stats
,
MultiKVConnectorStats
)
assert
isinstance
(
kv_connector_stats
,
MultiKVConnectorStats
)
# Validate per-connector totals across workers
# Validate per-connector totals across workers
assert
kv_connector_stats
[
"NixlConnector"
].
data
[
assert
isinstance
(
kv_connector_stats
[
"NixlConnector"
],
"num_successful_transfers"
]
==
5
NixlKVConnectorStats
)
assert
kv_connector_stats
[
"NixlConnector"
].
num_successful_transfers
==
5
assert
isinstance
(
kv_connector_stats
[
"FooConnector"
],
FooKVConnectorStats
)
assert
kv_connector_stats
[
"FooConnector"
].
data
[
"num_foo_transfers"
]
==
6
assert
kv_connector_stats
[
"FooConnector"
].
data
[
"num_foo_transfers"
]
==
6
...
@@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
...
@@ -755,6 +796,8 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
"working_dir"
:
working_dir
,
# ship fake nixl package
"working_dir"
:
working_dir
,
# ship fake nixl package
"env_vars"
:
{
"env_vars"
:
{
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT"
:
str
(
timeout
),
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT"
:
str
(
timeout
),
# TODO: for ray to carry over, remove once we set
"NIXL_TELEMETRY_ENABLE"
:
"1"
,
},
},
}
}
ray
.
init
(
runtime_env
=
runtime_env
)
ray
.
init
(
runtime_env
=
runtime_env
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
48f30902
...
@@ -4,6 +4,7 @@ import contextlib
...
@@ -4,6 +4,7 @@ import contextlib
import
copy
import
copy
import
logging
import
logging
import
math
import
math
import
os
import
queue
import
queue
import
threading
import
threading
import
time
import
time
...
@@ -54,10 +55,12 @@ logger = init_logger(__name__)
...
@@ -54,10 +55,12 @@ logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try
:
try
:
from
nixl._api
import
nixl_agent
as
NixlWrapper
from
nixl._api
import
nixl_agent
as
NixlWrapper
from
nixl._bindings
import
nixlXferTelemetry
logger
.
info
(
"NIXL is available"
)
logger
.
info
(
"NIXL is available"
)
except
ImportError
:
except
ImportError
:
logger
.
warning
(
"NIXL is not available"
)
logger
.
warning
(
"NIXL is not available"
)
NixlWrapper
=
None
NixlWrapper
=
None
nixlXferTelemetry
=
None
try
:
try
:
from
nixl._api
import
nixl_agent_config
from
nixl._api
import
nixl_agent_config
...
@@ -476,6 +479,9 @@ class NixlConnectorWorker:
...
@@ -476,6 +479,9 @@ class NixlConnectorWorker:
self
.
nixl_backends
=
\
self
.
nixl_backends
=
\
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"backends"
,
[
"UCX"
])
"backends"
,
[
"UCX"
])
# TODO temporary, once nixl allows for telemetry flag in config
# (next release), we can remove this env var.
os
.
environ
[
"NIXL_TELEMETRY_ENABLE"
]
=
"1"
# Agent.
# Agent.
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
non_ucx_backends
=
[
b
for
b
in
self
.
nixl_backends
if
b
!=
"UCX"
]
if
nixl_agent_config
is
None
:
if
nixl_agent_config
is
None
:
...
@@ -1175,9 +1181,10 @@ class NixlConnectorWorker:
...
@@ -1175,9 +1181,10 @@ class NixlConnectorWorker:
for
handle
,
_xfer_stime
in
handles
:
for
handle
,
_xfer_stime
in
handles
:
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
if
xfer_state
==
"DONE"
:
if
xfer_state
==
"DONE"
:
# Get telemetry from NIXL
res
=
self
.
nixl_wrapper
.
get_xfer_telemetry
(
handle
)
self
.
xfer_stats
.
record_transfer
(
res
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
# TODO (NickLucche) Get from NIXL telemetry once integrated
self
.
xfer_stats
.
record_transfer
()
elif
xfer_state
==
"PROC"
:
elif
xfer_state
==
"PROC"
:
in_progress
=
True
in_progress
=
True
continue
continue
...
@@ -1449,15 +1456,25 @@ class NixlKVConnectorStats(KVConnectorStats):
...
@@ -1449,15 +1456,25 @@ class NixlKVConnectorStats(KVConnectorStats):
"""Container for transfer performance metrics"""
"""Container for transfer performance metrics"""
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
"num_successful_transfers"
not
in
self
.
data
:
if
not
self
.
data
:
self
.
data
[
"num_successful_transfers"
]
=
0
# Empty container init, no data is passed in.
self
.
reset
()
def
reset
(
self
):
def
reset
(
self
):
self
.
data
=
{
"num_successful_transfers"
:
0
}
# Must be serializable
self
.
data
:
dict
[
str
,
list
[
float
]]
=
{
"transfer_duration"
:
[],
"post_duration"
:
[],
"bytes_transferred"
:
[],
"num_descriptors"
:
[],
}
def
record_transfer
(
self
):
def
record_transfer
(
self
,
res
:
nixlXferTelemetry
):
# TODO: record actual transfer stats when available
# Keep metrics units consistent with rest of the code: time us->s
self
.
data
[
"num_successful_transfers"
]
+=
1
self
.
data
[
"transfer_duration"
].
append
(
res
.
xferDuration
/
1e6
)
self
.
data
[
"post_duration"
].
append
(
res
.
postDuration
/
1e6
)
self
.
data
[
"bytes_transferred"
].
append
(
res
.
totalBytes
)
self
.
data
[
"num_descriptors"
].
append
(
res
.
descCount
)
def
clone_and_reset
(
self
)
->
"NixlKVConnectorStats"
:
def
clone_and_reset
(
self
)
->
"NixlKVConnectorStats"
:
old
=
copy
.
copy
(
self
)
old
=
copy
.
copy
(
self
)
...
@@ -1465,16 +1482,55 @@ class NixlKVConnectorStats(KVConnectorStats):
...
@@ -1465,16 +1482,55 @@ class NixlKVConnectorStats(KVConnectorStats):
return
old
return
old
def
is_empty
(
self
)
->
bool
:
def
is_empty
(
self
)
->
bool
:
return
self
.
data
[
"
num_successful_transfers
"
]
==
0
return
self
.
num_successful_transfers
==
0
def
aggregate
(
self
,
other
:
KVConnectorStats
)
->
KVConnectorStats
:
def
aggregate
(
self
,
other
:
KVConnectorStats
)
->
KVConnectorStats
:
if
not
other
.
is_empty
():
if
not
other
.
is_empty
():
self
.
data
[
"num_successful_transfers"
]
+=
other
.
data
[
for
k
,
v
in
other
.
data
.
items
():
"num_successful_transfers"
]
accumulator
=
self
.
data
[
k
]
assert
isinstance
(
accumulator
,
list
)
accumulator
.
extend
(
v
)
return
self
return
self
def
reduce
(
self
)
->
dict
[
str
,
Union
[
int
,
float
]]:
def
reduce
(
self
)
->
dict
[
str
,
Union
[
int
,
float
]]:
# TODO: reduce stats to a single value, calculate latency/throughput
# Compute compact representative stats suitable for CLI logging
if
self
.
is_empty
():
return
{
return
{
"num_successful_transfers"
:
self
.
data
[
"num_successful_transfers"
]
"Num successful transfers"
:
0
,
"Avg xfer time (ms)"
:
0
,
"P90 xfer time (ms)"
:
0
,
"Avg post time (ms)"
:
0
,
"P90 post time (ms)"
:
0
,
"Avg MB per transfer"
:
0
,
"Throughput (MB/s)"
:
0
,
"Avg number of descriptors"
:
0
,
}
}
xfer_time
=
np
.
asarray
(
self
.
data
[
"transfer_duration"
])
post_time
=
np
.
asarray
(
self
.
data
[
"post_duration"
])
# Convert to MB for CLI logging.
mb
=
np
.
asarray
(
self
.
data
[
"bytes_transferred"
])
/
2
**
20
descs
=
np
.
asarray
(
self
.
data
[
"num_descriptors"
],
dtype
=
np
.
uint32
)
n
=
len
(
descs
)
assert
n
==
self
.
num_successful_transfers
total_mb
=
mb
.
sum
()
avg_mb
=
total_mb
/
n
total_time_seconds
=
xfer_time
.
sum
()
throughput_mb_s
=
total_mb
/
total_time_seconds
return
{
"Num successful transfers"
:
n
,
"Avg xfer time (ms)"
:
round
(
xfer_time
.
mean
()
*
1e3
,
3
),
"P90 xfer time (ms)"
:
round
(
np
.
percentile
(
xfer_time
,
90
)
*
1e3
,
3
),
"Avg post time (ms)"
:
round
(
post_time
.
mean
()
*
1e3
,
3
),
"P90 post time (ms)"
:
round
(
np
.
percentile
(
post_time
,
90
)
*
1e3
,
3
),
"Avg MB per transfer"
:
round
(
avg_mb
,
3
),
"Throughput (MB/s)"
:
round
(
throughput_mb_s
,
3
),
"Avg number of descriptors"
:
round
(
descs
.
mean
(),
1
),
}
@
property
def
num_successful_transfers
(
self
)
->
int
:
return
len
(
self
.
data
[
"transfer_duration"
])
\ No newline at end of file
vllm/v1/metrics/loggers.py
View file @
48f30902
...
@@ -62,7 +62,7 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -62,7 +62,7 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
prefix_caching_metrics
=
PrefixCachingMetrics
()
self
.
spec_decoding_logging
=
SpecDecodingLogging
()
self
.
spec_decoding_logging
=
SpecDecodingLogging
()
kv_tranfer_config
=
self
.
vllm_config
.
kv_transfer_config
kv_tranfer_config
=
self
.
vllm_config
.
kv_transfer_config
self
.
kv_
transfe
r_logging
=
KVConnectorLogging
(
kv_tranfer_config
)
self
.
kv_
connecto
r_logging
=
KVConnectorLogging
(
kv_tranfer_config
)
self
.
last_prompt_throughput
:
float
=
0.0
self
.
last_prompt_throughput
:
float
=
0.0
self
.
last_generation_throughput
:
float
=
0.0
self
.
last_generation_throughput
:
float
=
0.0
...
@@ -101,7 +101,7 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -101,7 +101,7 @@ class LoggingStatLogger(StatLoggerBase):
self
.
spec_decoding_logging
.
observe
(
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
scheduler_stats
.
spec_decoding_stats
)
if
kv_connector_stats
:
=
scheduler_stats
.
kv_connector_stats
:
if
kv_connector_stats
:
=
scheduler_stats
.
kv_connector_stats
:
self
.
kv_
transfe
r_logging
.
observe
(
kv_connector_stats
)
self
.
kv_
connecto
r_logging
.
observe
(
kv_connector_stats
)
self
.
last_scheduler_stats
=
scheduler_stats
self
.
last_scheduler_stats
=
scheduler_stats
def
log
(
self
):
def
log
(
self
):
...
@@ -140,7 +140,7 @@ class LoggingStatLogger(StatLoggerBase):
...
@@ -140,7 +140,7 @@ class LoggingStatLogger(StatLoggerBase):
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
self
.
prefix_caching_metrics
.
hit_rate
*
100
,
)
)
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
kv_
transfe
r_logging
.
log
(
log_fn
=
log_fn
)
self
.
kv_
connecto
r_logging
.
log
(
log_fn
=
log_fn
)
def
log_engine_initialized
(
self
):
def
log_engine_initialized
(
self
):
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment