Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
87aee9ed
Unverified
Commit
87aee9ed
authored
Dec 08, 2025
by
Laith Sakka
Committed by
GitHub
Dec 08, 2025
Browse files
Add evaluate_guards option to DynamicShapesConfig (#27432)
Signed-off-by:
Laith Sakka
<
lsakka@meta.com
>
parent
184076c3
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
218 additions
and
27 deletions
+218
-27
tests/compile/test_dynamic_shapes_compilation.py
tests/compile/test_dynamic_shapes_compilation.py
+132
-7
vllm/compilation/backends.py
vllm/compilation/backends.py
+24
-2
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+2
-2
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+46
-13
vllm/config/compilation.py
vllm/config/compilation.py
+14
-3
No files found.
tests/compile/test_dynamic_shapes_compilation.py
View file @
87aee9ed
...
...
@@ -2,12 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
tempfile
from
contextlib
import
contextmanager
import
pytest
import
torch
from
vllm
import
LLM
,
SamplingParams
from
vllm.config.compilation
import
CompilationMode
,
DynamicShapesType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config.compilation
import
(
CompilationMode
,
DynamicShapesConfig
,
DynamicShapesType
,
)
from
vllm.forward_context
import
set_forward_context
from
vllm.tokenizers
import
get_tokenizer
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
...
@@ -29,18 +38,19 @@ def get_test_models():
)
@
pytest
.
mark
.
parametrize
(
"use_aot_compile"
,
[
"0"
])
@
pytest
.
mark
.
parametrize
(
"use_bytecode_hook"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"evaluate_guards"
,
[
False
,
True
])
@
pytest
.
mark
.
skipif
(
not
is_torch_equal_or_newer
(
"2.10.0.dev"
),
reason
=
"requires torch 2.10"
)
def
test_dynamic_shapes_compilation
(
monkeypatch
,
model_name
,
shapes_type
,
use_aot_compile
,
use_bytecode_hook
monkeypatch
,
model_name
,
shapes_type
,
use_aot_compile
,
use_bytecode_hook
,
evaluate_guards
,
):
"""Test that all dynamic shapes types compile successfully"""
print
(
f
"
\n
Testing model:
{
model_name
}
with
{
shapes_type
.
name
}
, "
f
"AOT compile:
{
use_aot_compile
}
, "
f
"Bytecode hook:
{
use_bytecode_hook
}
"
)
if
use_bytecode_hook
and
shapes_type
==
DynamicShapesType
.
UNBACKED
:
pytest
.
skip
(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0"
)
...
...
@@ -58,6 +68,7 @@ def test_dynamic_shapes_compilation(
"mode"
:
CompilationMode
.
VLLM_COMPILE
,
"dynamic_shapes_config"
:
{
"type"
:
shapes_type
.
value
,
"evaluate_guards"
:
evaluate_guards
,
},
},
)
...
...
@@ -86,3 +97,117 @@ def test_dynamic_shapes_compilation(
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
synchronize
()
print
(
"GPU memory cleared"
)
@pytest.mark.parametrize("use_aot_compile", ["0", "1"])
@pytest.mark.parametrize(
    "dynamic_shapes_type",
    [
        DynamicShapesType.BACKED,
        DynamicShapesType.BACKED_SIZE_OBLIVIOUS,
    ],
)
@pytest.mark.parametrize("evaluate_guards", [False, True])
def test_model_specialization_with_evaluate_guards(
    monkeypatch, use_aot_compile, dynamic_shapes_type, evaluate_guards
):
    """Check that evaluate_guards surfaces shape-specialization violations.

    Two tiny models that branch on ``x.shape[0]`` are compiled with one input
    and then called with a second input that takes the other branch. When
    ``evaluate_guards`` is enabled the second call is expected to raise a
    ``RuntimeError``; otherwise it is expected to run normally. The one
    exception is the 0/1-specialization case under BACKED shapes, which is
    expected to run normally even with ``evaluate_guards`` enabled.
    """
    backed = dynamic_shapes_type == DynamicShapesType.BACKED

    if use_aot_compile == "1" and backed and evaluate_guards:
        pytest.skip("evaluate_guards for backed does not work with aot_compile =1")

    @support_torch_compile
    class ModelWithSizeCheck(torch.nn.Module):
        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x: torch.Tensor):
            # Branching on the size makes torch.compile specialize, i.e. it
            # will add a guard on x.shape[0].
            if x.shape[0] >= 10:
                return x * 10
            return x * 10

    @support_torch_compile
    class ModelWithOneSizeCheck(torch.nn.Module):
        def __init__(self, **kwargs):
            super().__init__()

        def forward(self, x: torch.Tensor):
            # Comparing the size against 0 and 1 triggers 0/1 specializations.
            if x.shape[0] == 0:
                return x * 10
            if x.shape[0] == 1:
                return x * 10
            return x * 10

    @contextmanager
    def use_vllm_config(vllm_config: VllmConfig):
        # Enter the forward context first, then the current-config context,
        # so that both are active while the model runs.
        with set_forward_context({}, vllm_config):
            with set_current_vllm_config(vllm_config):
                yield

    for env_name, env_value in (
        ("TOKENIZERS_PARALLELISM", "true"),
        ("VLLM_USE_AOT_COMPILE", use_aot_compile),
        ("VLLM_USE_BYTECODE_HOOK", "0"),
    ):
        monkeypatch.setenv(env_name, env_value)

    # Build the vllm config carrying the parametrized dynamic-shapes settings.
    from vllm.config import CompilationMode

    shapes_config = DynamicShapesConfig(
        type=dynamic_shapes_type,
        evaluate_guards=evaluate_guards,
    )
    compile_config = CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        dynamic_shapes_config=shapes_config,
    )
    vllm_config = VllmConfig(compilation_config=compile_config)

    def run_case(model_class, first_input, second_input, is_01_specialization=False):
        """Compile on ``first_input`` and then call with ``second_input``."""
        with (
            torch.no_grad(),
            use_vllm_config(vllm_config),
            tempfile.TemporaryDirectory() as cache_root,
        ):
            # Use a fresh cache directory for each case.
            monkeypatch.setenv("VLLM_CACHE_ROOT", cache_root)

            model = model_class(vllm_config=vllm_config).cuda()
            model(first_input)

            # The 0/1-specialization case under BACKED shapes is the one
            # combination that is expected to pass with evaluate_guards on.
            tolerated = is_01_specialization and backed
            expect_guard_failure = evaluate_guards and not tolerated

            if not expect_guard_failure:
                model(second_input)
                return

            # Guards were kept, so the second shape must violate one of them.
            with pytest.raises(RuntimeError) as excinfo:
                model(second_input)

            error_msg = str(excinfo.value)
            assert (
                "GuardManager check failed" in error_msg
                or "Detected recompile when torch.compile stance" in error_msg
            ), error_msg

    run_case(
        ModelWithSizeCheck,
        torch.randn(20, 10).cuda(),
        torch.randn(5, 10).cuda(),
    )
    run_case(
        ModelWithSizeCheck,
        torch.randn(5, 10).cuda(),
        torch.randn(20, 10).cuda(),
    )
    run_case(
        ModelWithOneSizeCheck,
        torch.randn(20, 10).cuda(),
        torch.randn(1, 10).cuda(),
        is_01_specialization=True,
    )
vllm/compilation/backends.py
View file @
87aee9ed
...
...
@@ -26,6 +26,7 @@ from vllm.compilation.partition_rules import (
should_split
,
)
from
vllm.config
import
CompilationConfig
,
CUDAGraphMode
,
VllmConfig
from
vllm.config.compilation
import
DynamicShapesType
from
vllm.config.utils
import
Range
,
hash_factors
from
vllm.logger
import
init_logger
from
vllm.logging_utils
import
lazy
...
...
@@ -722,6 +723,29 @@ class VllmBackend:
self
.
split_gm
,
submod_names_to_compile
,
self
.
vllm_config
,
self
).
run
(
*
fake_args
)
from
torch._guards
import
detect_fake_mode
fake_mode
=
detect_fake_mode
()
if
(
self
.
compilation_config
.
dynamic_shapes_config
.
evaluate_guards
and
self
.
compilation_config
.
dynamic_shapes_config
.
type
==
DynamicShapesType
.
BACKED
):
from
torch.utils._sympy.value_ranges
import
ValueRanges
# Drop 0/1 specialization guards. For backed dynamic shapes, torch.compile
# either specializes on 0/1 inputs or guards that the shape is >= 2, because
# it is really hard not to hit a check against 0/1. When we evaluate shape
# guards, we must exclude those guards (we would always fail otherwise).
# We do that by updating the ranges of backed sizes: for any size whose
# lower bound is 2, we reset the lower bound to 0.
for
s
,
r
in
fake_mode
.
shape_env
.
var_to_range
.
items
():
if
r
.
lower
==
2
:
fake_mode
.
shape_env
.
var_to_range
[
s
]
=
ValueRanges
(
0
,
r
.
upper
)
graph_path
=
os
.
path
.
join
(
local_cache_dir
,
"computation_graph.py"
)
if
not
os
.
path
.
exists
(
graph_path
):
# code adapted from
...
...
@@ -749,8 +773,6 @@ class VllmBackend:
graph
,
example_inputs
,
self
.
prefix
,
self
.
split_gm
)
# if we need to copy input buffers for cudagraph
#
# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
...
...
vllm/compilation/decorators.py
View file @
87aee9ed
...
...
@@ -392,7 +392,6 @@ def _support_torch_compile(
factors
.
append
(
_model_hash_key
(
self
.
forward
))
hash_key
=
hashlib
.
sha256
(
str
(
factors
).
encode
()).
hexdigest
()
cache_dir
=
os
.
path
.
join
(
envs
.
VLLM_CACHE_ROOT
,
"torch_aot_compile"
,
...
...
@@ -413,6 +412,7 @@ def _support_torch_compile(
f
,
f_globals
=
self
.
forward
.
__globals__
)
_verify_source_unchanged
(
loaded_fn
.
source_info
(),
self
.
vllm_config
)
if
not
self
.
compilation_config
.
dynamic_shapes_config
.
evaluate_guards
:
loaded_fn
.
disable_guard_check
()
self
.
aot_compiled_fn
=
loaded_fn
except
Exception
as
e
:
...
...
vllm/compilation/wrapper.py
View file @
87aee9ed
...
...
@@ -4,7 +4,7 @@
import
os
import
sys
from
abc
import
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
,
nullcontext
from
types
import
CodeType
from
typing
import
Any
...
...
@@ -13,6 +13,7 @@ import torch._C._dynamo.guards
import
vllm.envs
as
envs
from
vllm.config
import
CompilationMode
,
CUDAGraphMode
,
get_current_vllm_config
from
vllm.config.compilation
import
DynamicShapesType
from
vllm.logger
import
init_logger
from
vllm.utils.nvtx_pytorch_hooks
import
layerwise_nvtx_marker_context
...
...
@@ -125,23 +126,49 @@ class TorchCompileWithNoGuardsWrapper:
if
isinstance
(
backend
,
str
)
and
backend
==
"inductor"
:
options
=
vllm_config
.
compilation_config
.
inductor_compile_config
self
.
first_compile
=
True
self
.
evaluate_guards
=
(
vllm_config
.
compilation_config
.
dynamic_shapes_config
.
evaluate_guards
)
ds_type
=
vllm_config
.
compilation_config
.
dynamic_shapes_config
.
type
if
mode
!=
CompilationMode
.
STOCK_TORCH_COMPILE
:
# Drop all the guards.
if
self
.
evaluate_guards
:
assert
not
envs
.
VLLM_USE_BYTECODE_HOOK
,
(
"compilation_config.dynamic_shapes_config.evaluate_guards "
"requires VLLM_USE_BYTECODE_HOOK=0. "
)
if
envs
.
VLLM_USE_AOT_COMPILE
:
# disabled until https://github.com/pytorch/pytorch/pull/169239
# is picked up.
assert
ds_type
!=
DynamicShapesType
.
BACKED
,
(
"evaluate_guards for backed shapes requires "
"VLLM_USE_AOT_COMPILE=False. "
)
options
[
"guard_filter_fn"
]
=
lambda
x
:
[
entry
.
guard_type
==
"SHAPE_ENV"
for
entry
in
x
]
else
:
options
[
"guard_filter_fn"
]
=
lambda
x
:
[
False
for
_
in
x
]
compiled_ptr
:
Any
=
self
.
forward
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
from
vllm.compilation.decorators
import
DynamicShapesType
ds_type
=
vllm_config
.
compilation_config
.
dynamic_shapes_config
.
type
compiled_ptr
:
Any
=
self
.
forward
if
ds_type
==
DynamicShapesType
.
UNBACKED
:
if
envs
.
VLLM_USE_BYTECODE_HOOK
:
# reason is that bytecode does this hack torch._dynamo.eval_frame.
# reason is that bytecode does torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation.
raise
ValueError
(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
# re-compilation. And if we use
# compiled_ptr = self.check_invariants_and_forward
# it will reset all entries.
assert
not
envs
.
VLLM_USE_BYTECODE_HOOK
,
(
"UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
)
assert
not
self
.
evaluate_guards
,
"UNBACKED dynamic shapes do not add guards"
compiled_ptr
=
self
.
check_invariants_and_forward
if
envs
.
VLLM_USE_AOT_COMPILE
:
...
...
@@ -195,7 +222,13 @@ class TorchCompileWithNoGuardsWrapper:
self
.
forward
,
*
args
,
**
kwargs
)
else
:
with
_compilation_context
():
ctx
=
(
nullcontext
()
if
self
.
first_compile
or
not
self
.
evaluate_guards
else
torch
.
compiler
.
set_stance
(
"fail_on_recompile"
)
)
self
.
first_compile
=
False
with
_compilation_context
(),
ctx
:
return
self
.
_call_with_optional_nvtx_range
(
self
.
_compiled_callable
,
*
args
,
**
kwargs
)
...
...
vllm/config/compilation.py
View file @
87aee9ed
...
...
@@ -344,7 +344,18 @@ class DynamicShapesConfig:
backed/unbacked.
"""
# TODO add a debug mode to fail
evaluate_guards
:
bool
=
False
"""
A debug mode to detect and fail if Dynamo ever specializes a dynamic shape by
guarding on it. When True, dynamic shape guards are not dropped from dynamo.
And a failure will be triggered if a recompilation ever happens due to that.
This mode requires VLLM_USE_BYTECODE_HOOK to be 0.
Enabling this also allows observing the dynamic shape guards in the tlparse
artifacts.
When type is backed, aot_compile must be disabled for this mode to work,
until the change in https://github.com/pytorch/pytorch/pull/169239 is picked up.
"""
def
compute_hash
(
self
)
->
str
:
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment