Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
654a71fc
Unverified
Commit
654a71fc
authored
Jan 22, 2026
by
Richard Zou
Committed by
GitHub
Jan 22, 2026
Browse files
[torch.compile] Improve Cold Start for MoEs (#32805)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
15e302df
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
71 additions
and
9 deletions
+71
-9
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+5
-1
vllm/forward_context.py
vllm/forward_context.py
+34
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+32
-7
No files found.
tests/kernels/moe/test_moe.py
View file @
654a71fc
...
...
@@ -23,7 +23,7 @@ from tests.kernels.utils import opcheck, stack_and_dev, torch_experts, torch_moe
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.distributed.parallel_state
import
init_distributed_environment
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.model_executor.layers.fused_moe
import
(
fused_topk
,
)
...
...
@@ -713,6 +713,10 @@ def test_mixtral_moe(
vllm_moe
.
experts
.
quant_method
.
process_weights_after_loading
(
vllm_moe
.
experts
)
# need to override the forward context for unittests, otherwise it assumes
# we're running the model forward pass (the model specified in vllm_config)
get_forward_context
().
remaining_moe_layers
=
None
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
vllm_states
=
vllm_moe
.
forward
(
vllm_inputs
)
...
...
vllm/forward_context.py
View file @
654a71fc
...
...
@@ -210,6 +210,30 @@ class ForwardContext:
# If True, bypass the compiled model call, e.g. by using .forward() directly
skip_compiled
:
bool
=
False
# For torch.compile cold start times, we need to avoid hard-coding
# any strings into the graph. Right now, the vllm.moe_forward
# and vllm.moe_forward_shared custom operators hard-code strings into
# the graph.
#
# The workaround is to store a list of the strings that each of those
# custom ops needs, in reverse order, in the ForwardContext.
# The ForwardContext object is alive for the duration of the forward pass.
# When the custom op needs the string, pop the string from this list.
#
# This assumes that the custom operators will always be executed in
# order and that torch.compile will not try to reorder these
# operations with respect to each other.
#
# TODO(https://github.com/vllm-project/vllm/issues/31985):
# There are longer-term solutions, like unwrapping the moe custom operator,
# that aren't ready yet.
# We could also treat the string as a "symbolic input" to the graph but
# the PyTorch-side bits for that aren't ready yet either.
#
# If this value is None (like in some tests), then we end up baking the string
# into the graph. Otherwise, the moe custom ops will pop a string from this list.
remaining_moe_layers
:
list
[
str
]
|
None
=
None
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
def
__post_init__
(
self
):
...
...
@@ -245,8 +269,17 @@ def create_forward_context(
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
):
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
remaining_moe_layers
=
[
name
for
name
,
layer
in
no_compile_layers
.
items
()
if
isinstance
(
layer
,
FusedMoE
)
]
remaining_moe_layers
.
reverse
()
return
ForwardContext
(
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
no_compile_layers
=
no_compile_layers
,
remaining_moe_layers
=
remaining_moe_layers
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
dp_metadata
=
dp_metadata
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
654a71fc
...
...
@@ -22,7 +22,11 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce
,
)
from
vllm.distributed.eplb.eplb_state
import
EplbLayerState
,
EplbState
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
(
ForwardContext
,
get_forward_context
,
is_forward_context_available
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.fused_moe.config
import
(
...
...
@@ -1564,6 +1568,15 @@ class FusedMoE(CustomOp):
states
=
self
.
maybe_all_reduce_tensor_model_parallel
(
states
)
return
states
def
encode_layer_name
()
->
str
:
# Can be unavailable or None in unittests
if
(
is_forward_context_available
()
and
get_forward_context
().
remaining_moe_layers
is
not
None
):
return
"from_forward_context"
return
self
.
layer_name
if
self
.
shared_experts
is
None
:
if
current_platform
.
is_tpu
()
or
current_platform
.
is_cpu
():
# TODO: Once the OOM issue for the TPU backend is resolved, we
...
...
@@ -1573,7 +1586,7 @@ class FusedMoE(CustomOp):
assert
not
isinstance
(
fused_output
,
tuple
)
else
:
fused_output
=
torch
.
ops
.
vllm
.
moe_forward
(
hidden_states
,
router_logits
,
self
.
layer_name
hidden_states
,
router_logits
,
encode_
layer_name
()
)
return
reduce_output
(
fused_output
)[...,
:
og_hidden_states
]
else
:
...
...
@@ -1586,7 +1599,7 @@ class FusedMoE(CustomOp):
)
else
:
shared_output
,
fused_output
=
torch
.
ops
.
vllm
.
moe_forward_shared
(
hidden_states
,
router_logits
,
self
.
layer_name
hidden_states
,
router_logits
,
encode_
layer_name
()
)
return
(
reduce_output
(
shared_output
)[...,
:
og_hidden_states
],
...
...
@@ -1936,13 +1949,26 @@ class FusedMoE(CustomOp):
return
s
def
get_layer_from_name
(
layer_name
:
str
)
->
FusedMoE
:
forward_context
:
ForwardContext
=
get_forward_context
()
if
layer_name
==
"from_forward_context"
:
if
not
forward_context
.
remaining_moe_layers
:
raise
AssertionError
(
"We expected the number of MOE layers in `remaining_moe_layers` "
"to be equal to the number of "
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
)
layer_name
=
forward_context
.
remaining_moe_layers
.
pop
()
self
=
cast
(
FusedMoE
,
forward_context
.
no_compile_layers
[
layer_name
])
return
self
def
moe_forward
(
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
torch
.
Tensor
:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
get_layer_from_name
(
layer_name
)
assert
self
.
shared_experts
is
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
...
...
@@ -1969,8 +1995,7 @@ def moe_forward_shared(
router_logits
:
torch
.
Tensor
,
layer_name
:
str
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
forward_context
:
ForwardContext
=
get_forward_context
()
self
=
forward_context
.
no_compile_layers
[
layer_name
]
self
=
get_layer_from_name
(
layer_name
)
assert
self
.
shared_experts
is
not
None
return
self
.
forward_impl
(
hidden_states
,
router_logits
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment