Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ba4a78eb
Unverified
Commit
ba4a78eb
authored
Apr 09, 2026
by
Richard Zou
Committed by
GitHub
Apr 08, 2026
Browse files
[torch.compile] Allow usage of Opaque Objects in PyTorch 2.11 (#39286)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
f3c7941e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
69 additions
and
45 deletions
+69
-45
vllm/compilation/compiler_interface.py
vllm/compilation/compiler_interface.py
+4
-44
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+7
-0
vllm/env_override.py
vllm/env_override.py
+54
-0
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-0
No files found.
vllm/compilation/compiler_interface.py
View file @
ba4a78eb
...
...
@@ -16,6 +16,7 @@ import vllm.envs as envs
from
vllm.compilation.counter
import
compilation_counter
from
vllm.config
import
VllmConfig
from
vllm.config.utils
import
Range
from
vllm.env_override
import
_apply_constrain_to_fx_strides_patch
from
vllm.logger
import
init_logger
from
vllm.utils.hashing
import
safe_hash
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
...
...
@@ -225,48 +226,6 @@ def _patch_standalone_compile_atomic_save() -> None:
logger
.
debug
(
"Patched %s.save for atomic writes (torch < 2.10)"
,
cls
.
__name__
)
def _patch_constrain_to_fx_strides() -> contextlib.AbstractContextManager:
    """Build a context manager that temporarily replaces inductor's
    ``constrain_to_fx_strides`` with a variant that tolerates opaque
    (non-tensor) arguments.

    The stock implementation calls ``.stride()`` on the meta value of every
    FX arg, which crashes on ``FakeScriptObject`` (the compile-time proxy for
    hoisted opaque types). The replacement only applies a stride constraint
    when the meta value is a ``torch.Tensor`` and hands every other argument
    back unchanged.

    On torch < 2.11 a ``nullcontext`` is returned and nothing is patched.

    Upstream issue: https://github.com/pytorch/pytorch/issues/175973
    """
    if not is_torch_equal_or_newer("2.11.0.dev"):
        return contextlib.nullcontext()

    import torch._inductor.ir as inductor_ir
    import torch._inductor.lowering as inductor_lowering
    from torch._inductor.virtualized import V as virtualized

    def _constrain(ir_arg, fx_arg):
        # IR nodes get a stride-order requirement, but only when the FX
        # node's recorded "val" is a real tensor.
        if isinstance(ir_arg, inductor_ir.IRNode):
            fake_val = fx_arg.meta.get("val")
            if not isinstance(fake_val, torch.Tensor):
                return ir_arg
            order = inductor_ir.get_stride_order(
                fake_val.stride(), virtualized.graph.sizevars.shape_env
            )
            return inductor_ir.ExternKernel.require_stride_order(ir_arg, order)
        # Dict arguments are walked key-by-key against the matching FX dict.
        if isinstance(ir_arg, dict):
            return {name: _constrain(ir_arg[name], fx_arg[name]) for name in ir_arg}
        return ir_arg

    def _patched(fx_node, *args, **kwargs):
        constrained_args = []
        for ir_arg, fx_arg in zip(args, fx_node.args):
            constrained_args.append(_constrain(ir_arg, fx_arg))

        constrained_kwargs = {}
        for name, value in kwargs.items():
            constrained_kwargs[name] = _constrain(value, fx_node.kwargs[name])

        return tuple(constrained_args), constrained_kwargs

    return patch.object(inductor_lowering, "constrain_to_fx_strides", _patched)
class
InductorStandaloneAdaptor
(
CompilerInterface
):
"""
The adaptor for the Inductor compiler.
...
...
@@ -304,6 +263,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
compile_range
:
Range
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
_apply_constrain_to_fx_strides_patch
()
compilation_counter
.
num_inductor_compiles
+=
1
current_config
=
{}
if
compiler_config
is
not
None
:
...
...
@@ -387,7 +347,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
else
:
fake_mode_ctx
=
contextlib
.
nullcontext
()
with
pregrad_ctx
,
fake_mode_ctx
,
_patch_constrain_to_fx_strides
()
:
with
pregrad_ctx
,
fake_mode_ctx
:
compiled_graph
=
standalone_compile
(
graph
,
example_inputs
,
**
compile_kwargs
)
if
use_aot
:
...
...
@@ -502,6 +462,7 @@ class InductorAdaptor(CompilerInterface):
compile_range
:
Range
,
key
:
str
|
None
=
None
,
)
->
tuple
[
Callable
[...,
Any
]
|
None
,
Any
|
None
]:
_apply_constrain_to_fx_strides_patch
()
compilation_counter
.
num_inductor_compiles
+=
1
from
torch._inductor.compile_fx
import
compile_fx
...
...
@@ -630,7 +591,6 @@ class InductorAdaptor(CompilerInterface):
stack
.
enter_context
(
torch
.
_functorch
.
config
.
patch
(
enable_remote_autograd_cache
=
False
)
)
stack
.
enter_context
(
_patch_constrain_to_fx_strides
())
# Clear the tracing context before calling compile_fx.
# vLLM calls compile_fx from within a PiecewiseCompileInterpreter
...
...
vllm/compilation/wrapper.py
View file @
ba4a78eb
...
...
@@ -143,6 +143,13 @@ class TorchCompileWithNoGuardsWrapper:
compiled_ptr
=
self
.
check_invariants_and_forward
# Apply the constrain_to_fx_strides patch before first compilation.
# This covers STOCK_TORCH_COMPILE and DYNAMO_ONCE paths. The VLLM
# compile paths call this from their own compile() methods too.
from
vllm.env_override
import
_apply_constrain_to_fx_strides_patch
_apply_constrain_to_fx_strides_patch
()
aot_context
=
nullcontext
()
if
envs
.
VLLM_USE_AOT_COMPILE
:
if
hasattr
(
torch
.
_dynamo
.
config
,
"enable_aot_compile"
):
...
...
vllm/env_override.py
View file @
ba4a78eb
...
...
@@ -500,6 +500,60 @@ if is_torch_equal("2.9.0"):
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
# and can be removed once torch >=2.12 is the minimum supported version.
# ===================================================
# torch >= 2.11 Inductor constrain_to_fx_strides monkeypatch
# ===================================================
# Inductor's constrain_to_fx_strides calls .stride() on every FX arg's meta
# value, which crashes on FakeScriptObject (the compile-time proxy for hoisted
# opaque types). The patched version skips args whose meta value is not a
# torch.Tensor.
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
_constrain_to_fx_strides_patched
=
False
def _apply_constrain_to_fx_strides_patch():
    """Install the opaque-object-safe ``constrain_to_fx_strides`` replacement.

    The replacement is assigned onto ``torch._inductor.lowering`` and is not
    undone by this function. Calling this more than once is safe: a
    module-level flag is set on the first call, and every later call returns
    immediately, whether or not the first call ended up installing anything.

    Nothing is installed unless the torch version is >= 2.11 and < 2.12.
    """
    global _constrain_to_fx_strides_patched
    if _constrain_to_fx_strides_patched:
        return
    # Set the flag before the version check so later calls skip it as well.
    _constrain_to_fx_strides_patched = True

    in_affected_range = is_torch_equal_or_newer(
        "2.11.0.dev"
    ) and not is_torch_equal_or_newer("2.12.0.dev")
    if not in_affected_range:
        return

    import torch._inductor.ir as inductor_ir
    import torch._inductor.lowering as inductor_lowering
    from torch._inductor.virtualized import V as virtualized

    def _constrain(ir_arg, fx_arg):
        # IR nodes get a stride-order requirement, but only when the FX
        # node's recorded "val" is a real tensor; anything else (such as an
        # opaque-object proxy) is passed through untouched.
        if isinstance(ir_arg, inductor_ir.IRNode):
            fake_val = fx_arg.meta.get("val")
            if not isinstance(fake_val, torch.Tensor):
                return ir_arg
            order = inductor_ir.get_stride_order(
                fake_val.stride(), virtualized.graph.sizevars.shape_env
            )
            return inductor_ir.ExternKernel.require_stride_order(ir_arg, order)
        # Dict arguments are walked key-by-key against the matching FX dict.
        if isinstance(ir_arg, dict):
            return {name: _constrain(ir_arg[name], fx_arg[name]) for name in ir_arg}
        return ir_arg

    def _patched(fx_node, *args, **kwargs):
        # ``map`` over two iterables stops at the shorter one, like ``zip``.
        new_args = tuple(map(_constrain, args, fx_node.args))
        new_kwargs = {
            name: _constrain(value, fx_node.kwargs[name])
            for name, value in kwargs.items()
        }
        return new_args, new_kwargs

    inductor_lowering.constrain_to_fx_strides = _patched
if
is_torch_equal_or_newer
(
"2.10.0"
)
and
not
is_torch_equal_or_newer
(
"2.12.0"
):
import
builtins
as
_builtins
import
pickle
...
...
vllm/utils/torch_utils.py
View file @
ba4a78eb
...
...
@@ -706,7 +706,7 @@ def is_torch_equal(target: str) -> bool:
return
Version
(
importlib
.
metadata
.
version
(
"torch"
))
==
Version
(
target
)
HAS_OPAQUE_TYPE
=
is_torch_equal_or_newer
(
"2.1
2
.0.dev"
)
HAS_OPAQUE_TYPE
=
is_torch_equal_or_newer
(
"2.1
1
.0.dev"
)
if
HAS_OPAQUE_TYPE
:
from
torch._opaque_base
import
OpaqueBase
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
ba4a78eb
...
...
@@ -4857,6 +4857,9 @@ class GPUModelRunner(
self
.
vllm_config
.
compilation_config
.
mode
==
CompilationMode
.
STOCK_TORCH_COMPILE
):
from
vllm.env_override
import
_apply_constrain_to_fx_strides_patch
_apply_constrain_to_fx_strides_patch
()
backend
=
self
.
vllm_config
.
compilation_config
.
init_backend
(
self
.
vllm_config
)
compilation_counter
.
stock_torch_compile_count
+=
1
self
.
model
.
compile
(
fullgraph
=
True
,
backend
=
backend
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment