Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e82fbeec
Unverified
Commit
e82fbeec
authored
Mar 01, 2026
by
Richard Zou
Committed by
GitHub
Mar 01, 2026
Browse files
[torch.compile] Undo the fast_moe_cold_start hack in torch>=2.11 (#35475)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
62904708
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
109 additions
and
10 deletions
+109
-10
vllm/config/vllm.py
vllm/config/vllm.py
+7
-1
vllm/env_override.py
vllm/env_override.py
+41
-0
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
...el_executor/layers/fused_moe/runner/default_moe_runner.py
+26
-9
vllm/utils/torch_utils.py
vllm/utils/torch_utils.py
+35
-0
No files found.
vllm/config/vllm.py
View file @
e82fbeec
...
...
@@ -883,7 +883,13 @@ class VllmConfig:
self
.
compilation_config
.
pass_config
.
enable_sp
=
False
self
.
compilation_config
.
pass_config
.
fuse_gemm_comms
=
False
if
self
.
compilation_config
.
fast_moe_cold_start
is
None
:
from
vllm.utils.torch_utils
import
HAS_OPAQUE_TYPE
if
HAS_OPAQUE_TYPE
:
# On torch >= 2.11 the hoisted OpaqueObject approach supersedes
# fast_moe_cold_start, so force it off.
self
.
compilation_config
.
fast_moe_cold_start
=
False
elif
self
.
compilation_config
.
fast_moe_cold_start
is
None
:
# resolve default behavior: try to be as safe as possible
# this config is unsafe if any spec decoding draft model has a MOE.
# We'll conservatively turn it off if we see spec decoding.
...
...
vllm/env_override.py
View file @
e82fbeec
...
...
@@ -482,3 +482,44 @@ if is_torch_equal("2.9.0"):
PythonWrapperCodegen
.
memory_plan_reuse
=
memory_plan_reuse_patched
GraphLowering
.
_update_scheduler
=
_update_scheduler_patched
# ===================================================
# torch 2.11 Inductor constrain_to_fx_strides monkeypatch
# ===================================================
# Patch the inductor's `constrain_to_fx_strides` to handle opaque
# (non-tensor) arguments. The original calls `.stride()` on every FX
# arg's meta value, which crashes on FakeScriptObject (the compile-time
# proxy for hoisted opaque types). The patched version skips args
# whose meta value is not a torch.Tensor.
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
from
vllm.utils.torch_utils
import
is_torch_equal_or_newer
if
is_torch_equal_or_newer
(
"2.11.0.dev"
):
import
torch._inductor.ir
as
_ir
import
torch._inductor.lowering
as
_lowering
from
torch._inductor.virtualized
import
V
as
_V
_orig_constrain
=
_lowering
.
constrain_to_fx_strides
def
_patched_constrain_to_fx_strides
(
fx_node
,
*
args
,
**
kwargs
):
def
apply_constraint
(
arg
,
fx_arg
):
if
isinstance
(
arg
,
_ir
.
IRNode
):
meta_val
=
fx_arg
.
meta
.
get
(
"val"
)
if
isinstance
(
meta_val
,
torch
.
Tensor
):
stride_order
=
_ir
.
get_stride_order
(
meta_val
.
stride
(),
_V
.
graph
.
sizevars
.
shape_env
)
return
_ir
.
ExternKernel
.
require_stride_order
(
arg
,
stride_order
)
return
arg
if
isinstance
(
arg
,
dict
):
return
{
key
:
apply_constraint
(
arg
[
key
],
fx_arg
[
key
])
for
key
in
arg
}
return
arg
args
=
tuple
(
apply_constraint
(
arg
,
fx_arg
)
for
arg
,
fx_arg
in
zip
(
args
,
fx_node
.
args
)
)
kwargs
=
{
k
:
apply_constraint
(
v
,
fx_node
.
kwargs
[
k
])
for
k
,
v
in
kwargs
.
items
()}
return
args
,
kwargs
_lowering
.
constrain_to_fx_strides
=
_patched_constrain_to_fx_strides
vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
View file @
e82fbeec
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
contextlib
import
nullcontext
from
typing
import
TYPE_CHECKING
import
torch
import
torch.nn.functional
as
F
...
...
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.utils.torch_utils
import
(
HAS_OPAQUE_TYPE
,
ModuleName
,
aux_stream
,
current_stream
,
direct_register_custom_op
,
...
...
@@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
return
forward_context
.
no_compile_layers
[
layer_name
]
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
# on older versions it remains a plain str.
if
TYPE_CHECKING
:
from
typing
import
TypeAlias
_layer_name_type
:
TypeAlias
=
str
|
ModuleName
else
:
_layer_name_type
=
ModuleName
if
HAS_OPAQUE_TYPE
else
str
def
_resolve_layer_name
(
layer_name
:
str
|
ModuleName
)
->
str
:
return
layer_name
.
value
if
isinstance
(
layer_name
,
ModuleName
)
else
layer_name
def
_moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
layer_name
:
str
,
layer_name
:
_layer_name_type
,
)
->
torch
.
Tensor
:
layer
=
get_layer_from_name
(
layer_name
)
layer
=
get_layer_from_name
(
_resolve_
layer_name
(
layer_name
)
)
# TODO(bnell): this can be removed after MK migration is complete.
layer
.
ensure_moe_quant_config_init
()
return
layer
.
runner
.
forward_impl
(
...
...
@@ -74,7 +91,7 @@ def _moe_forward_fake(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
layer_name
:
str
,
layer_name
:
_layer_name_type
,
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
...
...
@@ -83,9 +100,9 @@ def _moe_forward_shared(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
layer_name
:
str
,
layer_name
:
_layer_name_type
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
layer
=
get_layer_from_name
(
layer_name
)
layer
=
get_layer_from_name
(
_resolve_
layer_name
(
layer_name
)
)
# TODO(bnell): this can be removed after MK migration is complete.
layer
.
ensure_moe_quant_config_init
()
return
layer
.
runner
.
forward_impl
(
...
...
@@ -97,7 +114,7 @@ def _moe_forward_shared_fake(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
layer_name
:
str
,
layer_name
:
_layer_name_type
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
...
...
@@ -105,12 +122,10 @@ def _moe_forward_shared_fake(
# hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out
=
torch
.
empty_like
(
hidden_states
)
if
shared_experts_input
is
not
None
:
shared_out
=
torch
.
empty_like
(
shared_experts_input
)
else
:
shared_out
=
torch
.
empty_like
(
hidden_states
)
return
shared_out
,
fused_out
...
...
@@ -367,7 +382,9 @@ class DefaultMoERunner(MoERunner):
assert
len
(
trunc_sizes
)
==
1
return
func
(
states
,
trunc_sizes
[
0
])
def
_encode_layer_name
(
self
)
->
str
:
def
_encode_layer_name
(
self
)
->
str
|
ModuleName
:
if
HAS_OPAQUE_TYPE
:
return
ModuleName
(
self
.
layer_name
)
# Can be unavailable or None in unittests
if
(
is_forward_context_available
()
...
...
vllm/utils/torch_utils.py
View file @
e82fbeec
...
...
@@ -740,6 +740,41 @@ def is_torch_equal(target: str) -> bool:
return
Version
(
importlib
.
metadata
.
version
(
"torch"
))
==
Version
(
target
)
HAS_OPAQUE_TYPE
=
is_torch_equal_or_newer
(
"2.11.0.dev"
)
if
HAS_OPAQUE_TYPE
:
from
torch._opaque_base
import
OpaqueBase
else
:
OpaqueBase
=
object
# type: ignore[misc, assignment]
class
ModuleName
(
OpaqueBase
):
# type: ignore[misc]
"""Wraps a module name string for use as a torch opaque type.
When torch >= 2.11, this is registered as a hoisted value-type opaque
object so that torch.compile lifts it as a graph input instead of baking
it as a constant. This avoids per-layer recompilation for MOE ops.
"""
def
__init__
(
self
,
value
:
str
):
self
.
value
=
value
def
__eq__
(
self
,
other
):
return
isinstance
(
other
,
ModuleName
)
and
self
.
value
==
other
.
value
def
__hash__
(
self
):
return
hash
(
self
.
value
)
def
__fx_repr__
(
self
):
return
(
f
"ModuleName(
{
self
.
value
!
r
}
)"
,
{
ModuleName
})
if
HAS_OPAQUE_TYPE
:
from
torch._library.opaque_object
import
register_opaque_type
register_opaque_type
(
ModuleName
,
typ
=
"value"
,
hoist
=
True
)
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
def
supports_xccl
()
->
bool
:
return
torch
.
distributed
.
is_xccl_available
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment