Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9aa39a3
Unverified
Commit
d9aa39a3
authored
Jan 27, 2026
by
Richard Zou
Committed by
GitHub
Jan 27, 2026
Browse files
[torch.compile] Speed up MOE handling in forward_context (#33184)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
3a6d5cbe
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
18 deletions
+22
-18
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+1
-1
vllm/config/compilation.py
vllm/config/compilation.py
+4
-0
vllm/forward_context.py
vllm/forward_context.py
+8
-13
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-4
No files found.
tests/kernels/moe/test_moe.py
View file @
d9aa39a3
...
...
@@ -715,7 +715,7 @@ def test_mixtral_moe(
# need to override the forward context for unittests, otherwise it assumes
# we're running the model forward pass (the model specified in vllm_config)
get_forward_context
().
remaining
_moe_layers
=
None
get_forward_context
().
all
_moe_layers
=
None
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
...
...
vllm/config/compilation.py
View file @
d9aa39a3
...
...
@@ -597,6 +597,10 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
static_all_moe_layers
:
list
[
str
]
=
field
(
default_factory
=
list
,
init
=
False
)
"""The names of all the MOE layers in the model
"""
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
_attention_ops
:
ClassVar
[
list
[
str
]]
=
[
...
...
vllm/forward_context.py
View file @
d9aa39a3
...
...
@@ -217,9 +217,11 @@ class ForwardContext:
# the graph.
#
# The workaround is to store a list of the strings that each of those
# custom ops needs, in reverse order, in the ForwardContext.
# custom ops needs in the ForwardContext (all_moe_layers)
# as well as a counter (moe_layer_index).
# The ForwardContext object is alive for the duration of the forward pass.
# When the custom op needs the string, pop the string from this list.
# When the custom op needs a layer string, get the next string
# from all_moe_layers and increment the counter.
#
# This assumes that the custom operators will always be executed in
# order and that torch.compile will not try to reorder these
...
...
@@ -233,7 +235,8 @@ class ForwardContext:
#
# If this value is None (like in some tests), then we end up baking the string
# into the graph. Otherwise, the moe custom ops will read the next string from
# this list (at position moe_layer_index) and then increment moe_layer_index.
remaining_moe_layers
:
list
[
str
]
|
None
=
None
all_moe_layers
:
list
[
str
]
|
None
=
None
moe_layer_index
:
int
=
0
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
...
...
@@ -271,17 +274,9 @@ def create_forward_context(
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
):
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
remaining_moe_layers
=
[
name
for
name
,
layer
in
no_compile_layers
.
items
()
if
isinstance
(
layer
,
FusedMoE
)
]
remaining_moe_layers
.
reverse
()
return
ForwardContext
(
no_compile_layers
=
no_compile_layers
,
remaining
_moe_layers
=
remaining
_moe_layers
,
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
all
_moe_layers
=
vllm_config
.
compilation_config
.
static_all
_moe_layers
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
slot_mapping
=
slot_mapping
or
{},
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
d9aa39a3
...
...
@@ -407,6 +407,7 @@ class FusedMoE(CustomOp):
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
prefix
))
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_all_moe_layers
.
append
(
prefix
)
self
.
layer_name
=
prefix
self
.
enable_eplb
=
enable_eplb
...
...
@@ -1566,7 +1567,7 @@ class FusedMoE(CustomOp):
# Can be unavailable or None in unittests
if
(
is_forward_context_available
()
and
get_forward_context
().
remaining
_moe_layers
is
not
None
and
get_forward_context
().
all
_moe_layers
is
not
None
):
return
"from_forward_context"
return
self
.
layer_name
...
...
@@ -1987,13 +1988,17 @@ class FusedMoE(CustomOp):
def
get_layer_from_name
(
layer_name
:
str
)
->
FusedMoE
:
forward_context
:
ForwardContext
=
get_forward_context
()
if
layer_name
==
"from_forward_context"
:
if
not
forward_context
.
remaining_moe_layers
:
all_moe_layers
=
forward_context
.
all_moe_layers
assert
all_moe_layers
is
not
None
moe_layer_index
=
forward_context
.
moe_layer_index
if
moe_layer_index
>=
len
(
all_moe_layers
):
raise
AssertionError
(
"We expected the number of MOE layers in `
remaining
_moe_layers` "
"We expected the number of MOE layers in `
all
_moe_layers` "
"to be equal to the number of "
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
)
layer_name
=
forward_context
.
remaining_moe_layers
.
pop
()
layer_name
=
all_moe_layers
[
moe_layer_index
]
forward_context
.
moe_layer_index
+=
1
self
=
cast
(
FusedMoE
,
forward_context
.
no_compile_layers
[
layer_name
])
return
self
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment