Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ebb9cc5f
Unverified
Commit
ebb9cc5f
authored
Mar 07, 2026
by
Matthew Bonanni
Committed by
GitHub
Mar 07, 2026
Browse files
[UX][Startup] Account for CUDA graphs during memory profiling (#30515)
parent
85f50eb4
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
360 additions
and
61 deletions
+360
-61
vllm/compilation/cuda_graph.py
vllm/compilation/cuda_graph.py
+19
-1
vllm/envs.py
vllm/envs.py
+7
-0
vllm/v1/cudagraph_dispatcher.py
vllm/v1/cudagraph_dispatcher.py
+5
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+227
-52
vllm/v1/worker/gpu_ubatch_wrapper.py
vllm/v1/worker/gpu_ubatch_wrapper.py
+11
-2
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+91
-4
No files found.
vllm/compilation/cuda_graph.py
View file @
ebb9cc5f
...
...
@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
import
weakref
from
collections
import
Counter
from
collections.abc
import
Callable
from
contextlib
import
ExitStack
from
typing
import
Any
from
typing
import
Any
,
ClassVar
from
unittest.mock
import
patch
import
torch
...
...
@@ -162,6 +163,14 @@ class CUDAGraphWrapper:
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""
_all_instances
:
ClassVar
[
weakref
.
WeakSet
[
"CUDAGraphWrapper"
]]
=
weakref
.
WeakSet
()
@
classmethod
def
clear_all_graphs
(
cls
)
->
None
:
"""Clear captured graphs from all CUDAGraphWrapper instances."""
for
instance
in
list
(
cls
.
_all_instances
):
instance
.
clear_graphs
()
def
__init__
(
self
,
runnable
:
Callable
[...,
Any
],
...
...
@@ -192,6 +201,8 @@ class CUDAGraphWrapper:
# cudagraphs for.
self
.
concrete_cudagraph_entries
:
dict
[
BatchDescriptor
,
CUDAGraphEntry
]
=
{}
CUDAGraphWrapper
.
_all_instances
.
add
(
self
)
def
__getattr__
(
self
,
key
:
str
)
->
Any
:
# allow accessing the attributes of the runnable.
if
hasattr
(
self
.
runnable
,
key
):
...
...
@@ -205,6 +216,13 @@ class CUDAGraphWrapper:
# in case we need to access the original runnable.
return
self
.
runnable
@
property
def
cudagraph_wrapper
(
self
)
->
"CUDAGraphWrapper"
:
return
self
def
clear_graphs
(
self
)
->
None
:
self
.
concrete_cudagraph_entries
.
clear
()
def
__call__
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
|
None
:
forward_context
=
get_forward_context
()
batch_descriptor
=
forward_context
.
batch_descriptor
...
...
vllm/envs.py
View file @
ebb9cc5f
...
...
@@ -244,6 +244,7 @@ if TYPE_CHECKING:
VLLM_CUDA_COMPATIBILITY_PATH
:
str
|
None
=
None
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH
:
bool
=
False
VLLM_ELASTIC_EP_DRAIN_REQUESTS
:
bool
=
False
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
:
bool
=
False
def
get_default_cache_root
():
...
...
@@ -1628,6 +1629,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ELASTIC_EP_DRAIN_REQUESTS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ELASTIC_EP_DRAIN_REQUESTS"
,
"0"
))
),
# If set to 1, enable CUDA graph memory estimation during memory profiling.
# This profiles CUDA graph memory usage to provide more accurate KV cache
# memory allocation. Disabled by default to preserve existing behavior.
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS"
,
"0"
))
),
}
...
...
vllm/v1/cudagraph_dispatcher.py
View file @
ebb9cc5f
...
...
@@ -334,8 +334,11 @@ class CudagraphDispatcher:
for
mode
in
[
CUDAGraphMode
.
PIECEWISE
,
CUDAGraphMode
.
FULL
]:
descs
=
list
(
self
.
cudagraph_keys
[
mode
])
if
descs
:
# Sort by num_tokens descending (largest first)
descs
.
sort
(
key
=
lambda
d
:
d
.
num_tokens
,
reverse
=
True
)
# Sort by (num_tokens, num_active_loras) descending
descs
.
sort
(
key
=
lambda
d
:
(
d
.
num_tokens
,
d
.
num_active_loras
),
reverse
=
True
,
)
result
.
append
((
mode
,
descs
))
return
result
vllm/v1/worker/gpu_model_runner.py
View file @
ebb9cc5f
...
...
@@ -29,6 +29,7 @@ from vllm.config import (
CUDAGraphMode
,
VllmConfig
,
get_layers_from_vllm_config
,
set_current_vllm_config
,
update_config
,
)
from
vllm.distributed.ec_transfer
import
get_ec_transfer
,
has_ec_transfer
...
...
@@ -94,6 +95,7 @@ from vllm.multimodal.inputs import (
PlaceholderRange
,
)
from
vllm.multimodal.utils
import
group_and_batch_mm_kwargs
from
vllm.platforms
import
current_platform
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingType
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -596,6 +598,17 @@ class GPUModelRunner(
self
.
async_output_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
prepare_inputs_event
=
torch
.
Event
()
# self.cudagraph_batch_sizes sorts in ascending order.
if
(
self
.
compilation_config
.
cudagraph_capture_sizes
and
self
.
compilation_config
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
):
self
.
cudagraph_batch_sizes
=
sorted
(
self
.
compilation_config
.
cudagraph_capture_sizes
)
else
:
self
.
cudagraph_batch_sizes
=
[]
# Cache the device properties.
self
.
_init_device_properties
()
...
...
@@ -4727,6 +4740,7 @@ class GPUModelRunner(
remove_lora
:
bool
=
True
,
is_graph_capturing
:
bool
=
False
,
num_active_loras
:
int
=
0
,
profile_seq_lens
:
int
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Run a dummy forward pass to warm up/profile run or capture the
...
...
@@ -4751,6 +4765,9 @@ class GPUModelRunner(
remove_lora: If False, dummy LoRAs are not destroyed after the run
num_active_loras: Number of distinct active LoRAs to capture for.
LoRA is activated when num_active_loras > 0.
profile_seq_lens: If provided, use this value for seq_lens instead
of max_query_len. Used to profile attention workspace that
scales with context length.
"""
mm_config
=
self
.
vllm_config
.
model_config
.
multimodal_config
if
mm_config
and
mm_config
.
mm_encoder_only
:
...
...
@@ -4881,11 +4898,13 @@ class GPUModelRunner(
# If force_attention is True, we always capture attention.
# Otherwise, it only happens for cudagraph_runtime_mode=FULL.
if
force_attention
or
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
:
if
create_mixed_batch
:
if
profile_seq_lens
is
not
None
:
seq_lens
=
profile_seq_lens
# type: ignore[assignment]
elif
create_mixed_batch
:
# In the mixed batch mode (used for FI warmup), we use
# shorter sequence lengths to run faster.
# TODO(luka) better system for describing dummy batches
seq_lens
=
[
1
]
*
num_decode_tokens
+
[
num_prefill_tokens
+
1
]
seq_lens
=
[
1
]
*
num_decode_tokens
+
[
num_prefill_tokens
+
1
]
# type: ignore[assignment]
else
:
seq_lens
=
max_query_len
# type: ignore[assignment]
self
.
seq_lens
.
np
[:
num_reqs
]
=
seq_lens
...
...
@@ -5298,24 +5317,34 @@ class GPUModelRunner(
self
.
encoder_cache
.
clear
()
gc
.
collect
()
@
instrument
(
span_name
=
"Capture model"
)
def
capture_model
(
self
)
->
int
:
if
self
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
:
logger
.
warning
(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
def
_init_minimal_kv_cache_for_profiling
(
self
)
->
None
:
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config_from_groups
,
get_kv_cache_groups
,
)
return
0
compilation_counter
.
num_gpu_runner_capture_triggers
+=
1
kv_cache_spec
=
self
.
get_kv_cache_spec
()
kv_cache_groups
=
get_kv_cache_groups
(
self
.
vllm_config
,
kv_cache_spec
)
min_blocks
=
self
.
compilation_config
.
max_cudagraph_capture_size
or
1
if
kv_cache_groups
:
page_size
=
kv_cache_groups
[
0
].
kv_cache_spec
.
page_size_bytes
group_size
=
max
(
len
(
g
.
layer_names
)
for
g
in
kv_cache_groups
)
available_memory
=
min_blocks
*
page_size
*
group_size
else
:
available_memory
=
1
# Attention-free model
start_time
=
time
.
perf_counter
()
minimal_config
=
get_kv_cache_config_from_groups
(
self
.
vllm_config
,
kv_cache_groups
,
available_memory
=
available_memory
)
self
.
initialize_kv_cache
(
minimal_config
)
self
.
cache_config
.
num_gpu_blocks
=
minimal_config
.
num_blocks
logger
.
debug
(
"Initialized minimal KV cache for CUDA graph profiling"
)
@
staticmethod
@
contextmanager
def
freeze_gc
():
# Optimize garbage collection during CUDA graph capture.
# Clean up, then freeze all remaining objects from being included
# in future collections.
def
_freeze_gc
():
gc
.
collect
()
should_freeze
=
not
envs
.
VLLM_ENABLE_CUDAGRAPH_GC
if
should_freeze
:
...
...
@@ -5327,11 +5356,148 @@ class GPUModelRunner(
gc
.
unfreeze
()
gc
.
collect
()
def
_cleanup_profiling_kv_cache
(
self
)
->
None
:
torch
.
accelerator
.
synchronize
()
if
hasattr
(
self
,
"kv_caches"
)
and
self
.
kv_caches
:
for
i
in
range
(
len
(
self
.
kv_caches
)):
self
.
kv_caches
[
i
]
=
None
# type: ignore
self
.
kv_caches
.
clear
()
if
hasattr
(
self
,
"cross_layers_kv_cache"
):
self
.
cross_layers_kv_cache
=
None
self
.
cross_layers_attn_backend
=
None
if
hasattr
(
self
,
"attn_groups"
):
self
.
attn_groups
.
clear
()
if
hasattr
(
self
,
"kv_cache_config"
):
delattr
(
self
,
"kv_cache_config"
)
self
.
cache_config
.
num_gpu_blocks
=
None
for
layer
in
self
.
compilation_config
.
static_forward_context
.
values
():
if
hasattr
(
layer
,
"kv_cache"
):
layer
.
kv_cache
=
[]
gc
.
collect
()
torch
.
accelerator
.
empty_cache
()
logger
.
debug
(
"Cleaned up profiling KV cache and CUDA graphs"
)
@
torch
.
inference_mode
()
def
profile_cudagraph_memory
(
self
)
->
int
:
with
set_current_vllm_config
(
self
.
vllm_config
):
self
.
_init_minimal_kv_cache_for_profiling
()
saved_num_cudagraph_captured
=
compilation_counter
.
num_cudagraph_captured
capture_descs
=
self
.
cudagraph_dispatcher
.
get_capture_descs
()
total_graphs
=
sum
(
len
(
descs
)
for
_
,
descs
in
capture_descs
)
if
total_graphs
==
0
:
logger
.
debug
(
"No CUDA graphs will be captured, skipping profiling"
)
self
.
_cleanup_profiling_kv_cache
()
return
0
logger
.
info
(
"Profiling CUDA graph memory: %s"
,
", "
.
join
(
f
"
{
mode
.
name
}
=
{
len
(
descs
)
}
(largest=
{
descs
[
0
].
num_tokens
}
)"
for
mode
,
descs
in
capture_descs
if
descs
),
)
# Use a temporary pool for profiling to avoid fragmentation in the main pool.
profiling_pool
=
current_platform
.
graph_pool_handle
()
original_pools
:
dict
[
int
,
Any
]
=
{}
for
instance
in
list
(
CUDAGraphWrapper
.
_all_instances
):
original_pools
[
id
(
instance
)]
=
instance
.
graph_pool
instance
.
graph_pool
=
profiling_pool
set_cudagraph_capturing_enabled
(
True
)
with
self
.
_freeze_gc
(),
graph_capture
(
device
=
self
.
device
):
shared_memory_estimate
=
{}
per_graph_estimate
=
{}
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
empty_cache
()
for
mode
,
descs
in
capture_descs
:
profile_descs
=
descs
[:
2
]
mem_samples
:
list
[
int
]
=
[]
for
i
,
desc
in
enumerate
(
profile_descs
):
mem_before
=
torch
.
cuda
.
mem_get_info
()[
0
]
self
.
_warmup_and_capture
(
desc
,
cudagraph_runtime_mode
=
mode
,
profile_seq_lens
=
(
min
(
self
.
max_model_len
,
self
.
max_num_tokens
//
desc
.
num_tokens
,
)
if
mode
==
CUDAGraphMode
.
FULL
and
i
==
0
else
None
),
)
torch
.
accelerator
.
synchronize
()
free_after
=
torch
.
cuda
.
mem_get_info
()[
0
]
mem_samples
.
append
(
mem_before
-
free_after
)
first_capture
=
mem_samples
[
0
]
# Use at least 1 MiB per graph for driver overhead
per_graph
=
max
(
mem_samples
[
1
]
if
len
(
mem_samples
)
>
1
else
0
,
1
<<
20
)
shared_memory_estimate
[
mode
]
=
first_capture
per_graph_estimate
[
mode
]
=
per_graph
*
(
len
(
descs
)
-
1
)
logger
.
debug
(
"Estimated %s CUDA graph memory: "
"%.2f MiB first-capture + (%d-1) × %.2f MiB per-graph"
,
mode
.
name
,
first_capture
/
(
1
<<
20
),
len
(
descs
),
per_graph
/
(
1
<<
20
),
)
set_cudagraph_capturing_enabled
(
False
)
CUDAGraphWrapper
.
clear_all_graphs
()
for
instance
in
list
(
CUDAGraphWrapper
.
_all_instances
):
if
id
(
instance
)
in
original_pools
:
instance
.
graph_pool
=
original_pools
[
id
(
instance
)]
self
.
maybe_remove_all_loras
(
self
.
lora_config
)
self
.
_cleanup_profiling_kv_cache
()
compilation_counter
.
num_cudagraph_captured
=
saved_num_cudagraph_captured
# FULL and PIECEWISE graphs share the global pool at runtime and are
# never replayed concurrently, so the pool overlays their memory.
# Take the max to avoid double-counting the overlap.
total_estimate
=
max
(
shared_memory_estimate
.
values
())
+
sum
(
per_graph_estimate
.
values
()
)
logger
.
info
(
"Estimated CUDA graph memory: %.2f GiB total"
,
total_estimate
/
(
1
<<
30
),
)
return
int
(
total_estimate
)
@
instrument
(
span_name
=
"Capture model"
)
def
capture_model
(
self
)
->
int
:
if
self
.
compilation_config
.
cudagraph_mode
==
CUDAGraphMode
.
NONE
:
logger
.
warning
(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return
0
compilation_counter
.
num_gpu_runner_capture_triggers
+=
1
start_time
=
time
.
perf_counter
()
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
set_cudagraph_capturing_enabled
(
True
)
with
freeze_gc
(),
graph_capture
(
device
=
self
.
device
):
with
self
.
_freeze_gc
(),
graph_capture
(
device
=
self
.
device
):
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
empty_cache
()
start_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
for
(
...
...
@@ -5342,6 +5508,7 @@ class GPUModelRunner(
batch_descriptors
=
batch_descs
,
cudagraph_runtime_mode
=
runtime_mode
,
)
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
synchronize
()
end_free_gpu_memory
=
torch
.
cuda
.
mem_get_info
()[
0
]
...
...
@@ -5353,6 +5520,9 @@ class GPUModelRunner(
# after here.
set_cudagraph_capturing_enabled
(
False
)
torch
.
accelerator
.
synchronize
()
torch
.
accelerator
.
empty_cache
()
# Lock workspace to prevent resizing during execution.
# Max workspace sizes should have been captured during warmup/profiling.
lock_workspace
()
...
...
@@ -5369,6 +5539,40 @@ class GPUModelRunner(
)
return
cuda_graph_size
def
_warmup_and_capture
(
self
,
desc
:
BatchDescriptor
,
cudagraph_runtime_mode
:
CUDAGraphMode
,
profile_seq_lens
:
int
|
None
=
None
,
allow_microbatching
:
bool
=
False
,
num_warmups
:
int
|
None
=
None
,
):
if
num_warmups
is
None
:
num_warmups
=
self
.
compilation_config
.
cudagraph_num_of_warmups
force_attention
=
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
for
_
in
range
(
num_warmups
):
self
.
_dummy_run
(
desc
.
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
force_attention
=
force_attention
,
uniform_decode
=
desc
.
uniform
,
allow_microbatching
=
allow_microbatching
,
skip_eplb
=
True
,
remove_lora
=
False
,
num_active_loras
=
desc
.
num_active_loras
,
)
self
.
_dummy_run
(
desc
.
num_tokens
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
uniform_decode
=
desc
.
uniform
,
allow_microbatching
=
allow_microbatching
,
skip_eplb
=
True
,
remove_lora
=
False
,
num_active_loras
=
desc
.
num_active_loras
,
is_graph_capturing
=
True
,
profile_seq_lens
=
profile_seq_lens
,
)
def
_capture_cudagraphs
(
self
,
batch_descriptors
:
list
[
BatchDescriptor
],
...
...
@@ -5383,15 +5587,6 @@ class GPUModelRunner(
return
uniform_decode
=
batch_descriptors
[
0
].
uniform
force_attention
=
cudagraph_runtime_mode
==
CUDAGraphMode
.
FULL
dummy_run
=
functools
.
partial
(
self
.
_dummy_run
,
uniform_decode
=
uniform_decode
,
skip_eplb
=
True
,
remove_lora
=
False
,
force_attention
=
force_attention
,
)
# Only rank 0 should print progress bar during capture
if
is_global_first_rank
():
...
...
@@ -5406,9 +5601,6 @@ class GPUModelRunner(
# We skip EPLB here since we don't want to record dummy metrics
for
batch_desc
in
batch_descriptors
:
num_tokens
=
batch_desc
.
num_tokens
num_active_loras
=
batch_desc
.
num_active_loras
# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
...
...
@@ -5419,33 +5611,16 @@ class GPUModelRunner(
and
uniform_decode
and
check_ubatch_thresholds
(
config
=
self
.
vllm_config
.
parallel_config
,
num_tokens
=
num_tokens
,
num_tokens
=
batch_desc
.
num_tokens
,
uniform_decode
=
uniform_decode
,
)
)
for
_
in
range
(
self
.
compilation_config
.
cudagraph_num_of_warmups
):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
dummy_run
(
num_tokens
,
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
allow_microbatching
=
allow_microbatching
,
num_active_loras
=
num_active_loras
,
)
# Capture run
dummy_run
(
num_tokens
,
self
.
_warmup_and_capture
(
batch_desc
,
cudagraph_runtime_mode
=
cudagraph_runtime_mode
,
allow_microbatching
=
allow_microbatching
,
num_active_loras
=
num_active_loras
,
is_graph_capturing
=
True
,
)
torch
.
accelerator
.
synchronize
()
self
.
maybe_remove_all_loras
(
self
.
lora_config
)
def
initialize_attn_backend
(
self
,
kv_cache_config
:
KVCacheConfig
)
->
None
:
...
...
vllm/v1/worker/gpu_ubatch_wrapper.py
View file @
ebb9cc5f
...
...
@@ -112,16 +112,25 @@ class UBatchWrapper:
self
.
cudagraphs
:
dict
[
int
,
CUDAGraphMetaData
]
=
{}
self
.
cudagraph_wrapper
=
None
self
.
graph_pool
=
None
if
runtime_mode
is
not
CUDAGraphMode
.
NONE
:
self
.
cudagraph_wrapper
=
CUDAGraphWrapper
(
runnable
,
vllm_config
,
runtime_mode
=
runtime_mode
)
self
.
graph_pool
=
current_platform
.
get_global_graph_pool
()
self
.
sm_control
=
self
.
_create_sm_control_context
(
vllm_config
)
self
.
device
=
device
@
property
def
graph_pool
(
self
):
if
self
.
cudagraph_wrapper
is
not
None
:
return
self
.
cudagraph_wrapper
.
graph_pool
return
None
def
clear_graphs
(
self
)
->
None
:
self
.
cudagraphs
.
clear
()
if
self
.
cudagraph_wrapper
is
not
None
:
self
.
cudagraph_wrapper
.
clear_graphs
()
@
staticmethod
def
_create_sm_control_context
(
vllm_config
:
VllmConfig
):
comm_sms
:
int
=
envs
.
VLLM_DBO_COMM_SMS
...
...
vllm/v1/worker/gpu_worker.py
View file @
ebb9cc5f
...
...
@@ -44,6 +44,7 @@ from vllm.profiler.wrapper import CudaProfilerWrapper, TorchProfilerWrapper
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
SupportedTask
from
vllm.tracing
import
instrument
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_utils
import
MemorySnapshot
,
format_gib
,
memory_profiling
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
...
...
@@ -390,8 +391,36 @@ class Worker(WorkerBase):
)
as
profile_result
:
self
.
model_runner
.
profile_run
()
profile_torch_peak
=
current_platform
.
memory_stats
(
self
.
device
).
get
(
"allocated_bytes.all.peak"
,
0
)
# Profile CUDA graph memory if graphs will be captured.
cudagraph_memory_estimate
=
0
if
not
self
.
model_config
.
enforce_eager
:
cudagraph_memory_estimate
=
self
.
model_runner
.
profile_cudagraph_memory
()
# Use the pre-cudagraph torch peak to avoid double-counting.
profile_result
.
torch_peak_increase
=
(
profile_torch_peak
-
profile_result
.
before_profile
.
torch_peak
)
profile_result
.
non_kv_cache_memory
=
(
profile_result
.
non_torch_increase
+
profile_result
.
torch_peak_increase
+
profile_result
.
weights_memory
)
cudagraph_memory_estimate_applied
=
(
cudagraph_memory_estimate
if
envs
.
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
else
0
)
self
.
non_torch_memory
=
profile_result
.
non_torch_increase
self
.
peak_activation_memory
=
profile_result
.
torch_peak_increase
self
.
peak_activation_memory
=
(
profile_result
.
torch_peak_increase
+
cudagraph_memory_estimate_applied
)
self
.
cudagraph_memory_estimate
=
cudagraph_memory_estimate
free_gpu_memory
=
profile_result
.
after_profile
.
free_memory
# NOTE(woosuk): Here we assume that the other processes using the same
...
...
@@ -406,7 +435,9 @@ class Worker(WorkerBase):
"isolate vLLM in its own container."
)
self
.
available_kv_cache_memory_bytes
=
(
self
.
requested_memory
-
profile_result
.
non_kv_cache_memory
self
.
requested_memory
-
profile_result
.
non_kv_cache_memory
-
cudagraph_memory_estimate_applied
)
unrequested_memory
=
self
.
init_snapshot
.
free_memory
-
self
.
requested_memory
...
...
@@ -428,6 +459,46 @@ class Worker(WorkerBase):
scope
=
"local"
,
)
if
cudagraph_memory_estimate
>
0
:
total_mem
=
self
.
init_snapshot
.
total_memory
current_util
=
self
.
cache_config
.
gpu_memory_utilization
cg_util_delta
=
cudagraph_memory_estimate
/
total_mem
if
envs
.
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
:
equiv_util
=
round
(
current_util
-
cg_util_delta
,
4
)
suggested_util
=
min
(
round
(
current_util
+
cg_util_delta
,
4
),
1.0
,
)
logger
.
info
(
"CUDA graph memory profiling is enabled "
"(VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1). "
"This will become the default in v0.19. "
"The current --gpu-memory-utilization=%.4f is equivalent "
"to --gpu-memory-utilization=%.4f without CUDA graph "
"memory profiling. To maintain the same effective KV "
"cache size as before, increase "
"--gpu-memory-utilization to %.4f."
,
current_util
,
equiv_util
,
suggested_util
,
)
else
:
suggested_util
=
min
(
round
(
current_util
+
cg_util_delta
,
4
),
1.0
,
)
logger
.
info
(
"In v0.19, CUDA graph memory profiling will be enabled "
"by default (VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1), "
"which more accurately accounts for CUDA graph memory "
"during KV cache allocation. To try it now, set "
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1 and increase "
"--gpu-memory-utilization from %.4f to %.4f to maintain "
"the same effective KV cache size."
,
current_util
,
suggested_util
,
)
return
int
(
self
.
available_kv_cache_memory_bytes
)
def
get_kv_connector_handshake_metadata
(
self
)
->
dict
|
None
:
...
...
@@ -487,14 +558,14 @@ class Worker(WorkerBase):
@
instrument
(
span_name
=
"Warmup (GPU)"
)
def
compile_or_warm_up_model
(
self
)
->
float
:
warmup_sizes
=
[]
warmup_sizes
:
list
[
int
]
=
[]
if
self
.
vllm_config
.
compilation_config
.
mode
==
CompilationMode
.
VLLM_COMPILE
:
# warm up sizes that are not in cudagraph capture sizes,
# but users still want to compile for better performance,
# e.g. for the max-num-batched token size in chunked prefill.
compile_sizes
=
self
.
vllm_config
.
compilation_config
.
compile_sizes
warmup_sizes
=
compile_sizes
.
copy
()
if
compile_sizes
is
not
None
else
[]
warmup_sizes
=
compile_sizes
.
copy
()
if
compile_sizes
is
not
None
else
[]
# type: ignore[assignment]
cg_capture_sizes
:
list
[
int
]
=
[]
if
self
.
vllm_config
.
compilation_config
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
:
...
...
@@ -526,6 +597,22 @@ class Worker(WorkerBase):
if
not
self
.
model_config
.
enforce_eager
:
cuda_graph_memory_bytes
=
self
.
model_runner
.
capture_model
()
# Compare actual vs estimated CUDA graph memory (if we did profiling)
if
(
hasattr
(
self
,
"cudagraph_memory_estimate"
)
and
self
.
cudagraph_memory_estimate
>
0
):
GiB
=
lambda
b
:
round
(
b
/
GiB_bytes
,
2
)
diff
=
abs
(
cuda_graph_memory_bytes
-
self
.
cudagraph_memory_estimate
)
logger
.
info
(
"CUDA graph pool memory: %s GiB (actual), %s GiB (estimated), "
"difference: %s GiB (%.1f%%)."
,
GiB
(
cuda_graph_memory_bytes
),
GiB
(
self
.
cudagraph_memory_estimate
),
GiB
(
diff
),
100
*
diff
/
max
(
cuda_graph_memory_bytes
,
1
),
)
if
self
.
cache_config
.
kv_cache_memory_bytes
is
None
and
hasattr
(
self
,
"peak_activation_memory"
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment