Unverified Commit d9aa39a3 authored by Richard Zou, committed by GitHub
Browse files

[torch.compile] Speed up MOE handling in forward_context (#33184)


Signed-off-by: Richard Zou <zou3519@gmail.com>
parent 3a6d5cbe
...@@ -715,7 +715,7 @@ def test_mixtral_moe( ...@@ -715,7 +715,7 @@ def test_mixtral_moe(
# need to override the forward context for unittests, otherwise it assumes # need to override the forward context for unittests, otherwise it assumes
# we're running the model forward pass (the model specified in vllm_config) # we're running the model forward pass (the model specified in vllm_config)
get_forward_context().remaining_moe_layers = None get_forward_context().all_moe_layers = None
# Run forward passes for both MoE blocks # Run forward passes for both MoE blocks
hf_states, _ = hf_moe.forward(hf_inputs) hf_states, _ = hf_moe.forward(hf_inputs)
......
...@@ -597,6 +597,10 @@ class CompilationConfig: ...@@ -597,6 +597,10 @@ class CompilationConfig:
Map from layer name to layer objects that need to be accessed outside Map from layer name to layer objects that need to be accessed outside
model code, e.g., Attention, FusedMOE when dp_size>1.""" model code, e.g., Attention, FusedMOE when dp_size>1."""
static_all_moe_layers: list[str] = field(default_factory=list, init=False)
"""The names of all the MOE layers in the model
"""
# Attention ops; used for piecewise cudagraphs # Attention ops; used for piecewise cudagraphs
# Use PyTorch operator format: "namespace::name" # Use PyTorch operator format: "namespace::name"
_attention_ops: ClassVar[list[str]] = [ _attention_ops: ClassVar[list[str]] = [
......
...@@ -217,9 +217,11 @@ class ForwardContext: ...@@ -217,9 +217,11 @@ class ForwardContext:
# the graph. # the graph.
# #
# The workaround is to store a list of the strings that each of those # The workaround is to store a list of the strings that each of those
# custom ops needs, in reverse order, in the ForwardContext. # custom ops needs in the ForwardContext (all_moe_layers)
# as well as a counter (moe_layer_index).
# The ForwardContext object is alive for the duration of the forward pass. # The ForwardContext object is alive for the duration of the forward pass.
# When the custom op needs the string, pop the string from this list. # When the custom op needs a layer string, get the next string
# from all_moe_layers and increment the counter.
# #
# This assumes that the custom operators will always be executed in # This assumes that the custom operators will always be executed in
# order and that torch.compile will not try to reorder these # order and that torch.compile will not try to reorder these
...@@ -233,7 +235,8 @@ class ForwardContext: ...@@ -233,7 +235,8 @@ class ForwardContext:
# #
# If this value is None (like in some tests), then we end up baking the string # If this value is None (like in some tests), then we end up baking the string
# into the graph. Otherwise, the moe custom ops will pop a string from this list. # into the graph. Otherwise, the moe custom ops will pop a string from this list.
remaining_moe_layers: list[str] | None = None all_moe_layers: list[str] | None = None
moe_layer_index: int = 0
additional_kwargs: dict[str, Any] = field(default_factory=dict) additional_kwargs: dict[str, Any] = field(default_factory=dict)
...@@ -271,17 +274,9 @@ def create_forward_context( ...@@ -271,17 +274,9 @@ def create_forward_context(
additional_kwargs: dict[str, Any] | None = None, additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
): ):
no_compile_layers = vllm_config.compilation_config.static_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
remaining_moe_layers = [
name for name, layer in no_compile_layers.items() if isinstance(layer, FusedMoE)
]
remaining_moe_layers.reverse()
return ForwardContext( return ForwardContext(
no_compile_layers=no_compile_layers, no_compile_layers=vllm_config.compilation_config.static_forward_context,
remaining_moe_layers=remaining_moe_layers, all_moe_layers=vllm_config.compilation_config.static_all_moe_layers,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {}, slot_mapping=slot_mapping or {},
......
...@@ -407,6 +407,7 @@ class FusedMoE(CustomOp): ...@@ -407,6 +407,7 @@ class FusedMoE(CustomOp):
if prefix in compilation_config.static_forward_context: if prefix in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(prefix)) raise ValueError("Duplicate layer name: {}".format(prefix))
compilation_config.static_forward_context[prefix] = self compilation_config.static_forward_context[prefix] = self
compilation_config.static_all_moe_layers.append(prefix)
self.layer_name = prefix self.layer_name = prefix
self.enable_eplb = enable_eplb self.enable_eplb = enable_eplb
...@@ -1566,7 +1567,7 @@ class FusedMoE(CustomOp): ...@@ -1566,7 +1567,7 @@ class FusedMoE(CustomOp):
# Can be unavailable or None in unittests # Can be unavailable or None in unittests
if ( if (
is_forward_context_available() is_forward_context_available()
and get_forward_context().remaining_moe_layers is not None and get_forward_context().all_moe_layers is not None
): ):
return "from_forward_context" return "from_forward_context"
return self.layer_name return self.layer_name
...@@ -1987,13 +1988,17 @@ class FusedMoE(CustomOp): ...@@ -1987,13 +1988,17 @@ class FusedMoE(CustomOp):
def get_layer_from_name(layer_name: str) -> FusedMoE: def get_layer_from_name(layer_name: str) -> FusedMoE:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
if layer_name == "from_forward_context": if layer_name == "from_forward_context":
if not forward_context.remaining_moe_layers: all_moe_layers = forward_context.all_moe_layers
assert all_moe_layers is not None
moe_layer_index = forward_context.moe_layer_index
if moe_layer_index >= len(all_moe_layers):
raise AssertionError( raise AssertionError(
"We expected the number of MOE layers in `remaining_moe_layers` " "We expected the number of MOE layers in `all_moe_layers` "
"to be equal to the number of " "to be equal to the number of "
"{vllm.moe_forward, vllm.moe_forward_shared} calls." "{vllm.moe_forward, vllm.moe_forward_shared} calls."
) )
layer_name = forward_context.remaining_moe_layers.pop() layer_name = all_moe_layers[moe_layer_index]
forward_context.moe_layer_index += 1
self = cast(FusedMoE, forward_context.no_compile_layers[layer_name]) self = cast(FusedMoE, forward_context.no_compile_layers[layer_name])
return self return self
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment