Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cbf8f702
Unverified
Commit
cbf8f702
authored
Feb 25, 2026
by
Michael Goin
Committed by
GitHub
Feb 25, 2026
Browse files
[UX] Add `--performance-mode {balanced,interactivity,throughput}` (#34936)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
6831650c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
5 deletions
+37
-5
vllm/config/vllm.py
vllm/config/vllm.py
+26
-4
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+11
-1
No files found.
vllm/config/vllm.py
View file @
cbf8f702
...
...
@@ -14,7 +14,7 @@ from datetime import datetime
from
enum
import
IntEnum
from
functools
import
lru_cache
from
pathlib
import
Path
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
,
get_args
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
TypeVar
,
get_args
import
torch
from
pydantic
import
ConfigDict
,
Field
,
model_validator
...
...
@@ -76,6 +76,8 @@ class OptimizationLevel(IntEnum):
"""O3: Currently the same as -O2s."""
PerformanceMode
=
Literal
[
"balanced"
,
"interactivity"
,
"throughput"
]
IS_QUANTIZED
=
False
IS_DENSE
=
False
# The optimizations that depend on these properties are currently set to False
...
...
@@ -312,6 +314,13 @@ class VllmConfig:
performance. -O2 is used by default. See OptimizationLevel for full
description."""
performance_mode
:
PerformanceMode
=
"balanced"
"""Performance mode for runtime behavior, 'balanced' is the default.
'interactivity' favors low end-to-end per-request latency at small batch
sizes (fine-grained CUDA graphs, latency-oriented kernels).
'throughput' favors aggregate tokens/sec at high concurrency (larger CUDA
graphs, more aggressive batching, throughput-oriented kernels)."""
weight_transfer_config
:
WeightTransferConfig
|
None
=
None
"""The configurations for weight transfer during RL training."""
...
...
@@ -643,6 +652,11 @@ class VllmConfig:
# To give each torch profile run a unique instance name.
self
.
instance_id
=
f
"
{
time
.
time_ns
()
}
"
if
self
.
performance_mode
!=
"balanced"
:
logger
.
info_once
(
"Performance mode set to '%s'."
,
self
.
performance_mode
,
scope
=
"local"
)
self
.
try_verify_and_update_config
()
if
self
.
model_config
is
not
None
:
...
...
@@ -1331,6 +1345,12 @@ class VllmConfig:
]
# sort to make sure the sizes are in ascending order
cudagraph_capture_sizes
.
sort
()
else
:
if
self
.
performance_mode
==
"interactivity"
:
# Fine-grained CUDA graphs at small batch sizes
# for minimal padding overhead
interactivity_max
=
min
(
max_cudagraph_capture_size
,
32
)
cudagraph_capture_sizes
=
list
(
range
(
1
,
interactivity_max
+
1
))
else
:
cudagraph_capture_sizes
=
[
i
for
i
in
[
1
,
2
,
4
]
if
i
<=
max_cudagraph_capture_size
...
...
@@ -1345,6 +1365,8 @@ class VllmConfig:
cudagraph_capture_sizes
+=
list
(
range
(
256
,
max_cudagraph_capture_size
+
1
,
16
)
)
# de-duplicate and sort the sizes
cudagraph_capture_sizes
=
sorted
(
set
(
cudagraph_capture_sizes
))
if
(
self
.
parallel_config
.
tensor_parallel_size
>
1
...
...
vllm/engine/arg_utils.py
View file @
cbf8f702
...
...
@@ -89,7 +89,7 @@ from vllm.config.parallel import (
)
from
vllm.config.scheduler
import
SchedulerPolicy
from
vllm.config.utils
import
get_field
from
vllm.config.vllm
import
OptimizationLevel
from
vllm.config.vllm
import
OptimizationLevel
,
PerformanceMode
from
vllm.logger
import
init_logger
,
suppress_logging
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.plugins
import
load_general_plugins
...
...
@@ -596,6 +596,7 @@ class EngineArgs:
kv_sharing_fast_prefill
:
bool
=
CacheConfig
.
kv_sharing_fast_prefill
optimization_level
:
OptimizationLevel
=
VllmConfig
.
optimization_level
performance_mode
:
PerformanceMode
=
VllmConfig
.
performance_mode
kv_offloading_size
:
float
|
None
=
CacheConfig
.
kv_offloading_size
kv_offloading_backend
:
KVOffloadingBackend
=
CacheConfig
.
kv_offloading_backend
...
...
@@ -1264,6 +1265,7 @@ class EngineArgs:
vllm_group
.
add_argument
(
"--optimization-level"
,
**
vllm_kwargs
[
"optimization_level"
]
)
vllm_group
.
add_argument
(
"--performance-mode"
,
**
vllm_kwargs
[
"performance_mode"
])
vllm_group
.
add_argument
(
"--weight-transfer-config"
,
**
vllm_kwargs
[
"weight_transfer_config"
]
)
...
...
@@ -1894,6 +1896,7 @@ class EngineArgs:
profiler_config
=
self
.
profiler_config
,
additional_config
=
self
.
additional_config
,
optimization_level
=
self
.
optimization_level
,
performance_mode
=
self
.
performance_mode
,
weight_transfer_config
=
self
.
weight_transfer_config
,
)
...
...
@@ -2110,6 +2113,13 @@ class EngineArgs:
SchedulerConfig
.
DEFAULT_MAX_NUM_SEQS
,
)
# If throughput mode is set, double max_num_batched_tokens and max_num_seqs, each only when its original value was None.
if
self
.
performance_mode
==
"throughput"
:
if
orig_max_num_batched_tokens
is
None
:
self
.
max_num_batched_tokens
*=
2
if
orig_max_num_seqs
is
None
:
self
.
max_num_seqs
*=
2
if
orig_max_num_batched_tokens
is
None
:
assert
model_config
.
max_model_len
is
not
None
,
(
"max_model_len must be set by this point"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment