Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54756b61
Unverified
Commit
54756b61
authored
Mar 06, 2026
by
Richard Zou
Committed by
GitHub
Mar 06, 2026
Browse files
[compile] Stop unconditionally patching constrain_to_fx_strides (#36152)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
39f9ea0d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
42 deletions
+44
-42
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+44
-1
vllm/env_override.py
vllm/env_override.py
+0
-41
No files found.
vllm/compilation/compiler_interface.py
View file @
54756b61
...
...
@@ -225,6 +225,48 @@ def _patch_standalone_compile_atomic_save() -> None:
logger
.
debug
(
"Patched %s.save for atomic writes (torch < 2.10)"
,
cls
.
__name__
)
def
_patch_constrain_to_fx_strides
()
->
contextlib
.
AbstractContextManager
:
"""Context manager that patches inductor's ``constrain_to_fx_strides``
to handle opaque (non-tensor) arguments.
The original calls ``.stride()`` on every FX arg's meta value, which
crashes on ``FakeScriptObject`` (the compile-time proxy for hoisted
opaque types). The patched version skips args whose meta value is
not a ``torch.Tensor``.
Returns ``nullcontext`` on torch < 2.11.
Upstream issue: https://github.com/pytorch/pytorch/issues/175973
"""
if
not
is_torch_equal_or_newer
(
"2.11.0.dev"
):
return
contextlib
.
nullcontext
()
import
torch._inductor.ir
as
_ir
import
torch._inductor.lowering
as
_lowering
from
torch._inductor.virtualized
import
V
as
_V
def
_patched
(
fx_node
,
*
args
,
**
kwargs
):
def
apply_constraint
(
arg
,
fx_arg
):
if
isinstance
(
arg
,
_ir
.
IRNode
):
meta_val
=
fx_arg
.
meta
.
get
(
"val"
)
if
isinstance
(
meta_val
,
torch
.
Tensor
):
stride_order
=
_ir
.
get_stride_order
(
meta_val
.
stride
(),
_V
.
graph
.
sizevars
.
shape_env
)
return
_ir
.
ExternKernel
.
require_stride_order
(
arg
,
stride_order
)
return
arg
if
isinstance
(
arg
,
dict
):
return
{
key
:
apply_constraint
(
arg
[
key
],
fx_arg
[
key
])
for
key
in
arg
}
return
arg
args
=
tuple
(
apply_constraint
(
arg
,
fx_arg
)
for
arg
,
fx_arg
in
zip
(
args
,
fx_node
.
args
)
)
kwargs
=
{
k
:
apply_constraint
(
v
,
fx_node
.
kwargs
[
k
])
for
k
,
v
in
kwargs
.
items
()}
return
args
,
kwargs
return
patch
.
object
(
_lowering
,
"constrain_to_fx_strides"
,
_patched
)
class
InductorStandaloneAdaptor
(
CompilerInterface
):
"""
The adaptor for the Inductor compiler.
...
...
@@ -312,7 +354,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
"torch._inductor.compile_fx._recursive_pre_grad_passes"
,
lambda
gm
,
_
:
gm
,
)
with
ctx
:
with
ctx
,
_patch_constrain_to_fx_strides
()
:
compiled_graph
=
standalone_compile
(
graph
,
example_inputs
,
**
compile_kwargs
)
if
use_aot
:
...
...
@@ -555,6 +597,7 @@ class InductorAdaptor(CompilerInterface):
stack
.
enter_context
(
torch
.
_functorch
.
config
.
patch
(
enable_remote_autograd_cache
=
False
)
)
stack
.
enter_context
(
_patch_constrain_to_fx_strides
())
compiled_graph
=
compile_fx
(
graph
,
...
...
vllm/env_override.py
View file @
54756b61
...
...
@@ -482,44 +482,3 @@ if is_torch_equal("2.9.0"):
PythonWrapperCodegen
.
memory_plan_reuse
=
memory_plan_reuse_patched
GraphLowering
.
_update_scheduler
=
_update_scheduler_patched
# ===================================================
# torch 2.11 Inductor constrain_to_fx_strides monkeypatch
# ===================================================
# Patch the inductor's `constrain_to_fx_strides` to handle opaque
# (non-tensor) arguments. The original calls `.stride()` on every FX
# arg's meta value, which crashes on FakeScriptObject (the compile-time
# proxy for hoisted opaque types). The patched version skips args
# whose meta value is not a torch.Tensor.
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
if
is_torch_equal_or_newer
(
"2.11.0.dev"
):
import
torch._inductor.ir
as
_ir
import
torch._inductor.lowering
as
_lowering
from
torch._inductor.virtualized
import
V
as
_V
_orig_constrain
=
_lowering
.
constrain_to_fx_strides
def
_patched_constrain_to_fx_strides
(
fx_node
,
*
args
,
**
kwargs
):
def
apply_constraint
(
arg
,
fx_arg
):
if
isinstance
(
arg
,
_ir
.
IRNode
):
meta_val
=
fx_arg
.
meta
.
get
(
"val"
)
if
isinstance
(
meta_val
,
torch
.
Tensor
):
stride_order
=
_ir
.
get_stride_order
(
meta_val
.
stride
(),
_V
.
graph
.
sizevars
.
shape_env
)
return
_ir
.
ExternKernel
.
require_stride_order
(
arg
,
stride_order
)
return
arg
if
isinstance
(
arg
,
dict
):
return
{
key
:
apply_constraint
(
arg
[
key
],
fx_arg
[
key
])
for
key
in
arg
}
return
arg
args
=
tuple
(
apply_constraint
(
arg
,
fx_arg
)
for
arg
,
fx_arg
in
zip
(
args
,
fx_node
.
args
)
)
kwargs
=
{
k
:
apply_constraint
(
v
,
fx_node
.
kwargs
[
k
])
for
k
,
v
in
kwargs
.
items
()}
return
args
,
kwargs
_lowering
.
constrain_to_fx_strides
=
_patched_constrain_to_fx_strides
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment