Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87aee9ed
Unverified
Commit
87aee9ed
authored
Dec 08, 2025
by
Laith Sakka
Committed by
GitHub
Dec 08, 2025
Browse files
Add evaluate_guards option to DynamicShapesConfig (#27432)
Signed-off-by:
Laith Sakka
<
lsakka@meta.com
>
parent
184076c3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
219 additions
and
28 deletions
+219
-28
tests/compile/test_dynamic_shapes_compilation.py
tests/compile/test_dynamic_shapes_compilation.py
+132
-7
vllm/compilation/backends.py
vllm/compilation/backends.py
+24
-2
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+2
-2
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+46
-13
vllm/config/compilation.py
vllm/config/compilation.py
+14
-3
vllm/config/vllm.py
vllm/config/vllm.py
+1
-1
No files found.
tests/compile/test_dynamic_shapes_compilation.py
View file @
87aee9ed
...
...
@@ -2,12 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
tempfile
from
contextlib
import
contextmanager
import
pytest
import
torch
from
vllm
import
LLM
,
SamplingParams
from
vllm.config.compilation
import
CompilationMode
,
DynamicShapesType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config.compilation
import
(
CompilationMode
,
DynamicShapesConfig
,
DynamicShapesType
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.tokenizers
import
get_tokenizer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
...
@@ -29,18 +38,19 @@ def get_test_models():
)
@
pytest
.
mark
.
parametrize
(
"use_aot_compile"
,
[
"0"
])
@
pytest
.
mark
.
parametrize
(
"use_bytecode_hook"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"evaluate_guards"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
def
test_dynamic_shapes_compilation
(
monkeypatch
,
model_name
,
shapes_type
,
use_aot_compile
,
use_bytecode_hook
monkeypatch
,
model_name
,
shapes_type
,
use_aot_compile
,
use_bytecode_hook
,
evaluate_guards
,
):
"""Test that all dynamic shapes types compile successfully"""
print
(
f
"
\n
Testing model:
{
model_name
}
with
{
shapes_type
.
name
}
, "
f
"AOT compile:
{
use_aot_compile
}
, "
f
"Bytecode hook:
{
use_bytecode_hook
}
"
)
if
use_bytecode_hook
and
shapes_type
==
DynamicShapesType
.
UNBACKED
:
pytest
.
skip
(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0"
)
...
...
@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
"mode"
:
CompilationMode
.
VLLM_COMPILE
,
"dynamic_shapes_config"
:
{
"type"
:
shapes_type
.
value
,
"evaluate_guards"
:
evaluate_guards
,
},
},
)
...
...
@@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
synchronize
()
print
(
"GPU memory cleared"
)
@
pytest
.
mark
.
parametrize
(
"use_aot_compile"
,
[
"0"
,
"1"
])
@
pytest
.
mark
.
parametrize
(
"dynamic_shapes_type"
,
[
DynamicShapesType
.
BACKED
,
DynamicShapesType
.
BACKED_SIZE_OBLIVIOUS
,
],
)
@
pytest
.
mark
.
parametrize
(
"evaluate_guards"
,
[
False
,
True
])
def
test_model_specialization_with_evaluate_guards
(
monkeypatch
,
use_aot_compile
,
dynamic_shapes_type
,
evaluate_guards
):
"""Test that evaluate_guards correctly detects shape specialization
violations.
"""
if
(
use_aot_compile
==
"1"
and
dynamic_shapes_type
==
DynamicShapesType
.
BACKED
and
evaluate_guards
):
pytest
.
skip
(
"evaluate_guards for backed does not work with aot_compile =1"
)
@
support_torch_compile
class
ModelWithSizeCheck
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
):
# This will cause specialization - torch.compile will guard on
# sx.shape[0]
if
x
.
shape
[
0
]
>=
10
:
return
x
*
10
else
:
return
x
*
10
@
support_torch_compile
class
ModelWithOneSizeCheck
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
**
kwargs
):
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
):
# This will cause 0/1 specializations.
if
x
.
shape
[
0
]
==
0
:
return
x
*
10
if
x
.
shape
[
0
]
==
1
:
return
x
*
10
else
:
return
x
*
10
@
contextmanager
def
use_vllm_config
(
vllm_config
:
VllmConfig
):
with
set_forward_context
({},
vllm_config
),
set_current_vllm_config
(
vllm_config
):
yield
monkeypatch
.
setenv
(
"TOKENIZERS_PARALLELISM"
,
"true"
)
monkeypatch
.
setenv
(
"VLLM_USE_AOT_COMPILE"
,
use_aot_compile
)
monkeypatch
.
setenv
(
"VLLM_USE_BYTECODE_HOOK"
,
"0"
)
# Create vllm config with the desired settings
from
vllm.config
import
CompilationMode
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
mode
=
CompilationMode
.
VLLM_COMPILE
,
dynamic_shapes_config
=
DynamicShapesConfig
(
type
=
dynamic_shapes_type
,
evaluate_guards
=
evaluate_guards
,
),
)
)
def
test
(
model_class
,
input1
,
input2
,
is_01_specialization
=
False
):
with
(
torch
.
no_grad
(),
use_vllm_config
(
vllm_config
),
tempfile
.
TemporaryDirectory
()
as
tmpdirname
,
):
monkeypatch
.
setenv
(
"VLLM_CACHE_ROOT"
,
tmpdirname
)
model
=
model_class
(
vllm_config
=
vllm_config
).
cuda
()
model
(
input1
)
if
evaluate_guards
and
(
not
(
is_01_specialization
and
dynamic_shapes_type
==
DynamicShapesType
.
BACKED
)
):
# This should fail because guards were added.
with
pytest
.
raises
(
RuntimeError
)
as
excinfo
:
model
(
input2
)
# Expected failure - guard was violated
error_msg
=
str
(
excinfo
.
value
)
assert
(
"GuardManager check failed"
in
error_msg
or
"Detected recompile when torch.compile stance"
in
error_msg
),
error_msg
else
:
model
(
input2
)
test
(
ModelWithSizeCheck
,
torch
.
randn
(
20
,
10
).
cuda
(),
torch
.
randn
(
5
,
10
).
cuda
())
test
(
ModelWithSizeCheck
,
torch
.
randn
(
5
,
10
).
cuda
(),
torch
.
randn
(
20
,
10
).
cuda
())
test
(
ModelWithOneSizeCheck
,
torch
.
randn
(
20
,
10
).
cuda
(),
torch
.
randn
(
1
,
10
).
cuda
(),
is_01_specialization
=
True
,
)
vllm/compilation/backends.py
View file @
87aee9ed
...
...
@@ -26,6 +26,7 @@ from vllm.compilation.partition_rules import (
should_split
,
)
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
,
VllmConfig
from
vllm.config.compilation
import
DynamicShapesType
from
vllm.config.utils
import
Range
,
hash_factors
from
vllm.logger
import
init_logger
from
vllm.logging_utils
import
lazy
...
...
@@ -722,6 +723,29 @@ class VllmBackend:
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
).
run
(
*
fake_args
)
from
torch._guards
import
detect_fake_mode
fake_mode
=
detect_fake_mode
()
if
(
self
.
compilation_config
.
dynamic_shapes_config
.
evaluate_guards
and
self
.
compilation_config
.
dynamic_shapes_config
.
type
==
DynamicShapesType
.
BACKED
):
from
torch.utils._sympy.value_ranges
import
ValueRanges
# Drop counter-0/1 specializations guards; for backed dynamic shapes,
# torch.compile will specialize for 0/1 inputs or otherwise guards that
# shape is >= 2. This is because it's really hard not to hit a check
# against 0/1. When we evaluate shape guards, we exclude checking those
# guards (We would fail always otherwise).
# We avoid that by updating the ranges of backed sizes when the min is
# 2 for any, we assume it's 0.
for
s
,
r
in
fake_mode
.
shape_env
.
var_to_range
.
items
():
if
r
.
lower
==
2
:
fake_mode
.
shape_env
.
var_to_range
[
s
]
=
ValueRanges
(
0
,
r
.
upper
)
graph_path
=
os
.
path
.
join
(
local_cache_dir
,
"computation_graph.py"
)
if
not
os
.
path
.
exists
(
graph_path
):
# code adapted from
...
...
@@ -749,8 +773,6 @@ class VllmBackend:
graph
,
example_inputs
,
self
.
prefix
,
self
.
split_gm
)
# if we need to copy input buffers for cudagraph
#
# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
...
...
vllm/compilation/decorators.py
View file @
87aee9ed
...
...
@@ -392,7 +392,6 @@ def _support_torch_compile(
factors
.
append
(
_model_hash_key
(
self
.
forward
))
hash_key
=
hashlib
.
sha256
(
str
(
factors
).
encode
()).
hexdigest
()
cache_dir
=
os
.
path
.
join
(
envs
.
VLLM_CACHE_ROOT
,
"torch_aot_compile"
,
...
...
@@ -413,7 +412,8 @@ def _support_torch_compile(
f
,
f_globals
=
self
.
forward
.
__globals__
)
_verify_source_unchanged
(
loaded_fn
.
source_info
(),
self
.
vllm_config
)
loaded_fn
.
disable_guard_check
()
if
not
self
.
compilation_config
.
dynamic_shapes_config
.
evaluate_guards
:
loaded_fn
.
disable_guard_check
()
self
.
aot_compiled_fn
=
loaded_fn
except
Exception
as
e
:
if
os
.
path
.
exists
(
aot_compilation_path
):
...
...
vllm/compilation/wrapper.py
View file @
87aee9ed
...
...
@@ -4,7 +4,7 @@
import
os
import
sys
from
abc
import
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
,
nullcontext
from
types
import
CodeType
from
typing
import
Any
...
...
@@ -13,6 +13,7 @@ import torch._C._dynamo.guards
import
vllm.envs
as
envs
from
vllm.config
import
CompilationMode
,
CUDAGraphMode
,
get_current_vllm_config
from
vllm.config.compilation
import
DynamicShapesType
from
vllm.logger
import
init_logger
from
vllm.utils.nvtx_pytorch_hooks
import
layerwise_nvtx_marker_context
...
...
@@ -125,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
if
isinstance
(
backend
,
str
)
and
backend
==
"inductor"
:
options
=
vllm_config
.
compilation_config
.
inductor_compile_config
self
.
first_compile
=
True
self
.
evaluate_guards
=
(
vllm_config
.
compilation_config
.
dynamic_shapes_config
.
evaluate_guards
)
ds_type
=
vllm_config
.
compilation_config
.
dynamic_shapes_config
.
type
if
mode
!=
CompilationMode
.
STOCK_TORCH_COMPILE
:
# Drop all the guards.
options
[
"guard_filter_fn"
]
=
lambda
x
:
[
False
for
_
in
x
]
if
self
.
evaluate_guards
:
assert
not
envs
.
VLLM_USE_BYTECODE_HOOK
,
(
"compilation_config.dynamic_shapes_config.evaluate_guards "
"requires VLLM_USE_BYTECODE_HOOK=0. "
)
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
from
vllm.compilation.decorators
import
DynamicShapesType
if
envs
.
VLLM_USE_AOT_COMPILE
:
# disabled until https://github.com/pytorch/pytorch/pull/169239
# is picked up.
assert
ds_type
!=
DynamicShapesType
.
BACKED
,
(
"evaluate_guards for backed shapes requires "
"VLLM_USE_AOT_COMPILE=False. "
)
options
[
"guard_filter_fn"
]
=
lambda
x
:
[
entry
.
guard_type
==
"SHAPE_ENV"
for
entry
in
x
]
else
:
options
[
"guard_filter_fn"
]
=
lambda
x
:
[
False
for
_
in
x
]
ds_type
=
vllm_config
.
compilation_config
.
dynamic_shapes_config
.
type
compiled_ptr
:
Any
=
self
.
forward
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
if
ds_type
==
DynamicShapesType
.
UNBACKED
:
if
envs
.
VLLM_USE_BYTECODE_HOOK
:
# reason is that bytecode does this hack torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation.
raise
ValueError
(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
)
# reason is that bytecode does torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation. And if we use
# compiled_ptr = self.check_invariants_and_forward
# it will reset all entries.
assert
not
envs
.
VLLM_USE_BYTECODE_HOOK
,
(
"UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
)
assert
not
self
.
evaluate_guards
,
"UNBACKED dynamic shapes do not add guards"
compiled_ptr
=
self
.
check_invariants_and_forward
if
envs
.
VLLM_USE_AOT_COMPILE
:
...
...
@@ -195,7 +222,13 @@ class TorchCompileWithNoGuardsWrapper:
self
.
forward
,
*
args
,
**
kwargs
)
else
:
with
_compilation_context
():
ctx
=
(
nullcontext
()
if
self
.
first_compile
or
not
self
.
evaluate_guards
else
torch
.
compiler
.
set_stance
(
"fail_on_recompile"
)
)
self
.
first_compile
=
False
with
_compilation_context
(),
ctx
:
return
self
.
_call_with_optional_nvtx_range
(
self
.
_compiled_callable
,
*
args
,
**
kwargs
)
...
...
vllm/config/compilation.py
View file @
87aee9ed
...
...
@@ -344,7 +344,18 @@ class DynamicShapesConfig:
backed/unbacked.
"""
# TODO add a debug mode to fail
evaluate_guards
:
bool
=
False
"""
A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
guarding on it. When True, dynamic shape guards are not dropped from dynamo.
And a failure will be triggered if a recompilation ever happens due to that.
This mode requires VLLM_USE_BYTECODE_HOOK to be 0.
Enabling this allow observing the dynamic shapes guards in the tlparse
artifacts also.
When type is backed, aot_compile must be disabled for this mode to work.
until this change picked up https://github.com/pytorch/pytorch/pull/169239.
"""
def
compute_hash
(
self
)
->
str
:
"""
...
...
@@ -455,8 +466,8 @@ class CompilationConfig:
We use string to avoid serialization issues when using compilation in a
distributed setting. When the compilation mode is 1 or 2, the backend is
used for the compilation directly (it sees the whole graph). When the
compilation mode is 3, the backend supports both whole graph and piecewise
compilation, available backends include eager, inductor, and custom backends,
compilation mode is 3, the backend supports both whole graph and piecewise
compilation, available backends include eager, inductor, and custom backends,
the latter of which can be defined via `get_compile_backend`. Furthermore,
compilation is only piecewise if splitting ops is set accordingly and
use_inductor_graph_partition is off. Note that the default options for
...
...
vllm/config/vllm.py
View file @
87aee9ed
...
...
@@ -66,7 +66,7 @@ class OptimizationLevel(IntEnum):
"""O0 : No optimization. no compilation, no cudagraphs, no other
optimization, just starting up immediately"""
O1
=
1
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
"""O1: Quick optimizations. Dynamo+Inductor compilation and Piecewise
cudagraphs"""
O2
=
2
"""O2: Full optimizations. -O1 as well as Full and Piecewise cudagraphs."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment