Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9aa39a3
Unverified
Commit
d9aa39a3
authored
Jan 27, 2026
by
Richard Zou
Committed by
GitHub
Jan 27, 2026
Browse files
[torch.compile] Speed up MOE handling in forward_context (#33184)
Signed-off-by:
Richard Zou
<
zou3519@gmail.com
>
parent
3a6d5cbe
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
18 deletions
+22
-18
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+1
-1
vllm/config/compilation.py
vllm/config/compilation.py
+4
-0
vllm/forward_context.py
vllm/forward_context.py
+8
-13
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+9
-4
No files found.
tests/kernels/moe/test_moe.py
View file @
d9aa39a3
...
@@ -715,7 +715,7 @@ def test_mixtral_moe(
...
@@ -715,7 +715,7 @@ def test_mixtral_moe(
# need to override the forward context for unittests, otherwise it assumes
# need to override the forward context for unittests, otherwise it assumes
# we're running the model forward pass (the model specified in vllm_config)
# we're running the model forward pass (the model specified in vllm_config)
get_forward_context
().
remaining
_moe_layers
=
None
get_forward_context
().
all
_moe_layers
=
None
# Run forward passes for both MoE blocks
# Run forward passes for both MoE blocks
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
hf_states
,
_
=
hf_moe
.
forward
(
hf_inputs
)
...
...
vllm/config/compilation.py
View file @
d9aa39a3
...
@@ -597,6 +597,10 @@ class CompilationConfig:
...
@@ -597,6 +597,10 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside
Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1."""
model code, e.g., Attention, FusedMOE when dp_size>1."""
static_all_moe_layers
:
list
[
str
]
=
field
(
default_factory
=
list
,
init
=
False
)
"""The names of all the MOE layers in the model
"""
# Attention ops; used for piecewise cudagraphs
# Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name"
# Use PyTorch operator format: "namespace::name"
_attention_ops
:
ClassVar
[
list
[
str
]]
=
[
_attention_ops
:
ClassVar
[
list
[
str
]]
=
[
...
...
vllm/forward_context.py
View file @
d9aa39a3
...
@@ -217,9 +217,11 @@ class ForwardContext:
...
@@ -217,9 +217,11 @@ class ForwardContext:
# the graph.
# the graph.
#
#
# The workaround is to store a list of the strings that each of those
# The workaround is to store a list of the strings that each of those
# custom ops needs, in reverse order, in the ForwardContext.
# custom ops needs in the ForwardContext (all_moe_layers)
# as well as a counter (moe_layer_index).
# The ForwardContext object is alive for the duration of the forward pass.
# The ForwardContext object is alive for the duration of the forward pass.
# When the custom op needs the string, pop the string from this list.
# When the custom op needs a layer string, get the next string
# from all_moe_layers and increment the counter.
#
#
# This assumes that the custom operators will always be executed in
# This assumes that the custom operators will always be executed in
# order and that torch.compile will not try to reorder these
# order and that torch.compile will not try to reorder these
...
@@ -233,7 +235,8 @@ class ForwardContext:
...
@@ -233,7 +235,8 @@ class ForwardContext:
#
#
# If this value is None (like in some tests), then we end up baking the string
# If this value is None (like in some tests), then we end up baking the string
# into the graph. Otherwise, the moe custom ops will pop a string from this list.
# into the graph. Otherwise, the moe custom ops will pop a string from this list.
remaining_moe_layers
:
list
[
str
]
|
None
=
None
all_moe_layers
:
list
[
str
]
|
None
=
None
moe_layer_index
:
int
=
0
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
additional_kwargs
:
dict
[
str
,
Any
]
=
field
(
default_factory
=
dict
)
...
@@ -271,17 +274,9 @@ def create_forward_context(
...
@@ -271,17 +274,9 @@ def create_forward_context(
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
additional_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
,
skip_compiled
:
bool
=
False
,
skip_compiled
:
bool
=
False
,
):
):
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
from
vllm.model_executor.layers.fused_moe.layer
import
FusedMoE
remaining_moe_layers
=
[
name
for
name
,
layer
in
no_compile_layers
.
items
()
if
isinstance
(
layer
,
FusedMoE
)
]
remaining_moe_layers
.
reverse
()
return
ForwardContext
(
return
ForwardContext
(
no_compile_layers
=
no_compile_layers
,
no_compile_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
remaining
_moe_layers
=
remaining
_moe_layers
,
all
_moe_layers
=
vllm_config
.
compilation_config
.
static_all
_moe_layers
,
virtual_engine
=
virtual_engine
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
,
attn_metadata
=
attn_metadata
,
slot_mapping
=
slot_mapping
or
{},
slot_mapping
=
slot_mapping
or
{},
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
d9aa39a3
...
@@ -407,6 +407,7 @@ class FusedMoE(CustomOp):
...
@@ -407,6 +407,7 @@ class FusedMoE(CustomOp):
if
prefix
in
compilation_config
.
static_forward_context
:
if
prefix
in
compilation_config
.
static_forward_context
:
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
prefix
))
raise
ValueError
(
"Duplicate layer name: {}"
.
format
(
prefix
))
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_forward_context
[
prefix
]
=
self
compilation_config
.
static_all_moe_layers
.
append
(
prefix
)
self
.
layer_name
=
prefix
self
.
layer_name
=
prefix
self
.
enable_eplb
=
enable_eplb
self
.
enable_eplb
=
enable_eplb
...
@@ -1566,7 +1567,7 @@ class FusedMoE(CustomOp):
...
@@ -1566,7 +1567,7 @@ class FusedMoE(CustomOp):
# Can be unavailable or None in unittests
# Can be unavailable or None in unittests
if
(
if
(
is_forward_context_available
()
is_forward_context_available
()
and
get_forward_context
().
remaining
_moe_layers
is
not
None
and
get_forward_context
().
all
_moe_layers
is
not
None
):
):
return
"from_forward_context"
return
"from_forward_context"
return
self
.
layer_name
return
self
.
layer_name
...
@@ -1987,13 +1988,17 @@ class FusedMoE(CustomOp):
...
@@ -1987,13 +1988,17 @@ class FusedMoE(CustomOp):
def
get_layer_from_name
(
layer_name
:
str
)
->
FusedMoE
:
def
get_layer_from_name
(
layer_name
:
str
)
->
FusedMoE
:
forward_context
:
ForwardContext
=
get_forward_context
()
forward_context
:
ForwardContext
=
get_forward_context
()
if
layer_name
==
"from_forward_context"
:
if
layer_name
==
"from_forward_context"
:
if
not
forward_context
.
remaining_moe_layers
:
all_moe_layers
=
forward_context
.
all_moe_layers
assert
all_moe_layers
is
not
None
moe_layer_index
=
forward_context
.
moe_layer_index
if
moe_layer_index
>=
len
(
all_moe_layers
):
raise
AssertionError
(
raise
AssertionError
(
"We expected the number of MOE layers in `
remaining
_moe_layers` "
"We expected the number of MOE layers in `
all
_moe_layers` "
"to be equal to the number of "
"to be equal to the number of "
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
"{vllm.moe_forward, vllm.moe_forward_shared} calls."
)
)
layer_name
=
forward_context
.
remaining_moe_layers
.
pop
()
layer_name
=
all_moe_layers
[
moe_layer_index
]
forward_context
.
moe_layer_index
+=
1
self
=
cast
(
FusedMoE
,
forward_context
.
no_compile_layers
[
layer_name
])
self
=
cast
(
FusedMoE
,
forward_context
.
no_compile_layers
[
layer_name
])
return
self
return
self
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment