Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69520bc6
Unverified
Commit
69520bc6
authored
Dec 02, 2025
by
Yong Hoon Shin
Committed by
GitHub
Dec 03, 2025
Browse files
Add logging for cudagraph related info (#29825)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
3a775148
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
161 additions
and
6 deletions
+161
-6
vllm/compilation/cuda_graph.py
vllm/compilation/cuda_graph.py
+94
-0
vllm/config/observability.py
vllm/config/observability.py
+4
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+7
-1
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+14
-0
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+3
-0
vllm/v1/outputs.py
vllm/v1/outputs.py
+4
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+28
-4
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-1
No files found.
vllm/compilation/cuda_graph.py
View file @
69520bc6
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
collections
import
Counter
from
collections.abc
import
Callable
from
contextlib
import
ExitStack
from
typing
import
Any
...
...
@@ -22,6 +23,99 @@ from vllm.utils.torch_utils import weak_ref_tensors
logger
=
init_logger
(
__name__
)
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
CUDAGraphStat
:
num_unpadded_tokens
:
int
num_padded_tokens
:
int
num_paddings
:
int
runtime_mode
:
str
class
CUDAGraphLogging
:
"""Aggregate and log cudagraph metrics"""
COLUMN_HEADERS
=
[
"Unpadded Tokens"
,
"Padded Tokens"
,
"Num Paddings"
,
"Runtime Mode"
,
"Count"
,
]
def
__init__
(
self
,
cg_mode
:
CUDAGraphMode
,
cg_capture_sizes
:
list
[
int
]
|
None
):
self
.
reset
()
self
.
cg_mode
=
str
(
cg_mode
)
self
.
cg_capture_sizes
=
str
(
cg_capture_sizes
or
[])
self
.
settings_header
=
(
"**CUDAGraph Config Settings:**
\n\n
"
f
"- Mode:
{
self
.
cg_mode
}
\n
"
f
"- Capture sizes:
{
self
.
cg_capture_sizes
}
\n\n
"
"**CUDAGraph Stats:**
\n\n
"
)
def
reset
(
self
):
self
.
stats
=
[]
def
observe
(
self
,
cudagraph_stat
:
CUDAGraphStat
):
self
.
stats
.
append
(
cudagraph_stat
)
def
generate_metric_table
(
self
)
->
str
:
stats_counts
=
Counter
(
self
.
stats
)
# Convert stats to rows of strings, in descending order of observed frequencies
rows
=
[]
for
stat
,
count
in
sorted
(
stats_counts
.
items
(),
key
=
lambda
item
:
item
[
1
],
reverse
=
True
):
rows
.
append
(
[
str
(
stat
.
num_unpadded_tokens
),
str
(
stat
.
num_padded_tokens
),
str
(
stat
.
num_paddings
),
stat
.
runtime_mode
,
str
(
count
),
]
)
# Calculate column widths (max of header and data)
col_widths
=
[]
for
i
,
header_text
in
enumerate
(
self
.
COLUMN_HEADERS
):
max_width
=
len
(
header_text
)
for
row
in
rows
:
max_width
=
max
(
max_width
,
len
(
row
[
i
]))
col_widths
.
append
(
max_width
)
table_header_list
=
[
h
.
ljust
(
w
)
for
h
,
w
in
zip
(
self
.
COLUMN_HEADERS
,
col_widths
)
]
table_header
=
"| "
+
" | "
.
join
(
table_header_list
)
+
" |
\n
"
table_separator
=
"|"
+
"|"
.
join
(
"-"
*
(
w
+
2
)
for
w
in
col_widths
)
+
"|
\n
"
# Create data rows with proper alignment
data_rows
=
[]
for
row
in
rows
:
formatted_row
=
[
str
(
val
).
ljust
(
width
)
for
val
,
width
in
zip
(
row
,
col_widths
)
]
data_rows
.
append
(
"| "
+
" | "
.
join
(
formatted_row
)
+
" |"
)
return
(
self
.
settings_header
+
table_header
+
table_separator
+
"
\n
"
.
join
(
data_rows
)
+
"
\n
"
)
def
log
(
self
,
log_fn
=
logger
.
info
):
if
not
self
.
stats
:
return
log_fn
(
self
.
generate_metric_table
())
self
.
reset
()
@
dataclasses
.
dataclass
class
CUDAGraphEntry
:
batch_descriptor
:
BatchDescriptor
...
...
vllm/config/observability.py
View file @
69520bc6
...
...
@@ -55,6 +55,10 @@ class ObservabilityConfig:
kv_cache_metrics_sample
:
float
=
Field
(
default
=
0.01
,
gt
=
0
,
le
=
1
)
"""Sampling rate for KV cache metrics (0.0, 1.0]. Default 0.01 = 1% of blocks."""
cudagraph_metrics
:
bool
=
False
"""Enable CUDA graph metrics (number of padded/unpadded tokens, runtime cudagraph
dispatch modes, and their observed frequencies at every logging interval)."""
@
cached_property
def
collect_model_forward_time
(
self
)
->
bool
:
"""Whether to collect model forward time for the request."""
...
...
vllm/engine/arg_utils.py
View file @
69520bc6
...
...
@@ -518,6 +518,7 @@ class EngineArgs:
kv_cache_metrics_sample
:
float
=
get_field
(
ObservabilityConfig
,
"kv_cache_metrics_sample"
)
cudagraph_metrics
:
bool
=
ObservabilityConfig
.
cudagraph_metrics
scheduling_policy
:
SchedulerPolicy
=
SchedulerConfig
.
policy
scheduler_cls
:
str
|
type
[
object
]
|
None
=
SchedulerConfig
.
scheduler_cls
...
...
@@ -1021,6 +1022,10 @@ class EngineArgs:
"--kv-cache-metrics-sample"
,
**
observability_kwargs
[
"kv_cache_metrics_sample"
],
)
observability_group
.
add_argument
(
"--cudagraph-metrics"
,
**
observability_kwargs
[
"cudagraph_metrics"
],
)
# Scheduler arguments
scheduler_kwargs
=
get_kwargs
(
SchedulerConfig
)
...
...
@@ -1698,6 +1703,7 @@ class EngineArgs:
collect_detailed_traces
=
self
.
collect_detailed_traces
,
kv_cache_metrics
=
self
.
kv_cache_metrics
,
kv_cache_metrics_sample
=
self
.
kv_cache_metrics_sample
,
cudagraph_metrics
=
self
.
cudagraph_metrics
,
)
# Compilation config overrides
...
...
vllm/v1/core/sched/scheduler.py
View file @
69520bc6
...
...
@@ -7,6 +7,7 @@ from collections.abc import Iterable
from
typing
import
Any
from
vllm
import
envs
from
vllm.compilation.cuda_graph
import
CUDAGraphStat
from
vllm.config
import
VllmConfig
from
vllm.distributed.ec_transfer.ec_connector.base
import
(
ECConnectorMetadata
,
...
...
@@ -1037,6 +1038,7 @@ class Scheduler(SchedulerInterface):
pooler_outputs
=
model_runner_output
.
pooler_output
num_nans_in_logits
=
model_runner_output
.
num_nans_in_logits
kv_connector_output
=
model_runner_output
.
kv_connector_output
cudagraph_stats
=
model_runner_output
.
cudagraph_stats
outputs
:
dict
[
int
,
list
[
EngineCoreOutput
]]
=
defaultdict
(
list
)
spec_decoding_stats
:
SpecDecodingStats
|
None
=
None
...
...
@@ -1219,7 +1221,9 @@ class Scheduler(SchedulerInterface):
finished_req_ids
.
clear
()
if
(
stats
:
=
self
.
make_stats
(
spec_decoding_stats
,
kv_connector_stats
)
stats
:
=
self
.
make_stats
(
spec_decoding_stats
,
kv_connector_stats
,
cudagraph_stats
)
)
is
not
None
:
# Return stats to only one of the front-ends.
if
(
eco
:
=
next
(
iter
(
engine_core_outputs
.
values
()),
None
))
is
None
:
...
...
@@ -1420,6 +1424,7 @@ class Scheduler(SchedulerInterface):
self
,
spec_decoding_stats
:
SpecDecodingStats
|
None
=
None
,
kv_connector_stats
:
KVConnectorStats
|
None
=
None
,
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
,
)
->
SchedulerStats
|
None
:
if
not
self
.
log_stats
:
return
None
...
...
@@ -1444,6 +1449,7 @@ class Scheduler(SchedulerInterface):
kv_cache_eviction_events
=
eviction_events
,
spec_decoding_stats
=
spec_stats
,
kv_connector_stats
=
connector_stats_payload
,
cudagraph_stats
=
cudagraph_stats
,
)
def
make_spec_decoding_stats
(
...
...
vllm/v1/metrics/loggers.py
View file @
69520bc6
...
...
@@ -10,6 +10,7 @@ from typing import TypeAlias
from
prometheus_client
import
Counter
,
Gauge
,
Histogram
import
vllm.envs
as
envs
from
vllm.compilation.cuda_graph
import
CUDAGraphLogging
from
vllm.config
import
SupportsMetricsInfo
,
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorLogging
,
...
...
@@ -106,6 +107,12 @@ class LoggingStatLogger(StatLoggerBase):
self
.
spec_decoding_logging
=
SpecDecodingLogging
()
kv_transfer_config
=
self
.
vllm_config
.
kv_transfer_config
self
.
kv_connector_logging
=
KVConnectorLogging
(
kv_transfer_config
)
self
.
cudagraph_logging
=
None
if
self
.
vllm_config
.
observability_config
.
cudagraph_metrics
:
self
.
cudagraph_logging
=
CUDAGraphLogging
(
self
.
vllm_config
.
compilation_config
.
cudagraph_mode
,
self
.
vllm_config
.
compilation_config
.
cudagraph_capture_sizes
,
)
self
.
last_prompt_throughput
:
float
=
0.0
self
.
last_generation_throughput
:
float
=
0.0
self
.
engine_is_idle
=
False
...
...
@@ -161,6 +168,11 @@ class LoggingStatLogger(StatLoggerBase):
self
.
spec_decoding_logging
.
observe
(
scheduler_stats
.
spec_decoding_stats
)
if
kv_connector_stats
:
=
scheduler_stats
.
kv_connector_stats
:
self
.
kv_connector_logging
.
observe
(
kv_connector_stats
)
if
(
self
.
cudagraph_logging
is
not
None
and
scheduler_stats
.
cudagraph_stats
is
not
None
):
self
.
cudagraph_logging
.
observe
(
scheduler_stats
.
cudagraph_stats
)
if
not
self
.
aggregated
:
self
.
last_scheduler_stats
=
scheduler_stats
if
mm_cache_stats
:
...
...
@@ -240,6 +252,8 @@ class LoggingStatLogger(StatLoggerBase):
self
.
spec_decoding_logging
.
log
(
log_fn
=
log_fn
)
self
.
kv_connector_logging
.
log
(
log_fn
=
log_fn
)
if
self
.
cudagraph_logging
is
not
None
:
self
.
cudagraph_logging
.
log
(
log_fn
=
log_fn
)
def
log_engine_initialized
(
self
):
if
self
.
vllm_config
.
cache_config
.
num_gpu_blocks
:
...
...
vllm/v1/metrics/stats.py
View file @
69520bc6
...
...
@@ -7,6 +7,7 @@ from dataclasses import dataclass, field
from
typing
import
TYPE_CHECKING
,
Any
import
vllm.envs
as
envs
from
vllm.compilation.cuda_graph
import
CUDAGraphStat
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
if
TYPE_CHECKING
:
...
...
@@ -183,6 +184,8 @@ class SchedulerStats:
waiting_lora_adapters
:
dict
[
str
,
int
]
=
field
(
default_factory
=
dict
)
running_lora_adapters
:
dict
[
str
,
int
]
=
field
(
default_factory
=
dict
)
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
@
dataclass
class
RequestStateStats
:
...
...
vllm/v1/outputs.py
View file @
69520bc6
...
...
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, NamedTuple
import
numpy
as
np
import
torch
from
vllm.compilation.cuda_graph
import
CUDAGraphStat
from
vllm.v1.core.sched.output
import
SchedulerOutput
if
TYPE_CHECKING
:
...
...
@@ -169,6 +170,9 @@ class ModelRunnerOutput:
# req_id -> num_nans_in_logits
num_nans_in_logits
:
dict
[
str
,
int
]
|
None
=
None
# information related to cudagraph execution
cudagraph_stats
:
CUDAGraphStat
|
None
=
None
# ModelRunnerOutput wrapper for async scheduling.
class
AsyncModelRunnerOutput
(
ABC
):
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
69520bc6
...
...
@@ -27,7 +27,7 @@ from vllm.attention.backends.abstract import (
)
from
vllm.attention.layer
import
Attention
,
MLAAttention
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.cuda_graph
import
CUDAGraphWrapper
from
vllm.compilation.cuda_graph
import
CUDAGraphStat
,
CUDAGraphWrapper
from
vllm.compilation.monitor
import
set_cudagraph_capturing_enabled
from
vllm.config
import
(
CompilationMode
,
...
...
@@ -257,6 +257,7 @@ class ExecuteModelState(NamedTuple):
sample_hidden_states
:
torch
.
Tensor
aux_hidden_states
:
list
[
torch
.
Tensor
]
|
None
ec_connector_output
:
ECConnectorOutput
|
None
cudagraph_stats
:
CUDAGraphStat
|
None
class
GPUModelRunner
(
...
...
@@ -2755,7 +2756,11 @@ class GPUModelRunner(
force_uniform_decode
:
bool
|
None
=
None
,
force_has_lora
:
bool
|
None
=
None
,
)
->
tuple
[
CUDAGraphMode
,
BatchDescriptor
,
UBatchSlices
|
None
,
torch
.
Tensor
|
None
CUDAGraphMode
,
BatchDescriptor
,
UBatchSlices
|
None
,
torch
.
Tensor
|
None
,
CUDAGraphStat
|
None
,
]:
num_tokens_padded
=
self
.
_pad_for_sequence_parallelism
(
num_tokens
)
uniform_decode
=
(
...
...
@@ -2820,7 +2825,22 @@ class GPUModelRunner(
# num_tokens_across_dp will no-longer be valid
assert
batch_descriptor
.
num_tokens
==
num_tokens_padded
return
cudagraph_mode
,
batch_descriptor
,
ubatch_slices
,
num_tokens_across_dp
cudagraph_stats
=
None
if
self
.
vllm_config
.
observability_config
.
cudagraph_metrics
:
cudagraph_stats
=
CUDAGraphStat
(
num_unpadded_tokens
=
num_tokens
,
num_padded_tokens
=
batch_descriptor
.
num_tokens
,
num_paddings
=
batch_descriptor
.
num_tokens
-
num_tokens
,
runtime_mode
=
str
(
cudagraph_mode
),
)
return
(
cudagraph_mode
,
batch_descriptor
,
ubatch_slices
,
num_tokens_across_dp
,
cudagraph_stats
,
)
@
torch
.
inference_mode
()
def
execute_model
(
...
...
@@ -2918,6 +2938,7 @@ class GPUModelRunner(
batch_desc
,
ubatch_slices
,
num_tokens_across_dp
,
cudagraph_stats
,
)
=
self
.
_determine_batch_execution_and_padding
(
num_tokens
=
num_tokens_unpadded
,
num_reqs
=
num_reqs
,
...
...
@@ -3067,6 +3088,7 @@ class GPUModelRunner(
sample_hidden_states
,
aux_hidden_states
,
ec_connector_output
,
cudagraph_stats
,
)
self
.
kv_connector_output
=
kv_connector_output
return
None
...
...
@@ -3102,6 +3124,7 @@ class GPUModelRunner(
sample_hidden_states
,
aux_hidden_states
,
ec_connector_output
,
cudagraph_stats
,
)
=
self
.
execute_model_state
# Clear ephemeral state.
self
.
execute_model_state
=
None
...
...
@@ -3217,6 +3240,7 @@ class GPUModelRunner(
if
self
.
supports_mm_inputs
else
None
,
num_nans_in_logits
=
num_nans_in_logits
,
cudagraph_stats
=
cudagraph_stats
,
)
if
not
self
.
use_async_scheduling
:
...
...
@@ -3937,7 +3961,7 @@ class GPUModelRunner(
num_sampled_tokens
=
np
.
ones
(
num_reqs
,
dtype
=
np
.
int32
)
_cudagraph_mode
,
batch_desc
,
ubatch_slices
,
num_tokens_across_dp
=
(
_cudagraph_mode
,
batch_desc
,
ubatch_slices
,
num_tokens_across_dp
,
_
=
(
self
.
_determine_batch_execution_and_padding
(
num_tokens
=
num_tokens_unpadded
,
num_reqs
=
num_reqs
,
...
...
vllm/v1/worker/gpu_worker.py
View file @
69520bc6
...
...
@@ -564,7 +564,7 @@ class Worker(WorkerBase):
# TODO(lucas): This is pretty gross; ideally we should only ever call
# `_determine_batch_execution_and_padding` once (will get called again
# in `execute_model`) but this requires a larger refactor of PP.
_
,
batch_desc
,
_
,
_
=
(
_
,
batch_desc
,
_
,
_
,
_
=
(
self
.
model_runner
.
_determine_batch_execution_and_padding
(
num_tokens
=
num_scheduled_tokens
,
num_reqs
=
len
(
num_scheduled_tokens_np
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment