Unverified Commit 4df841fe authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[torch.compile] Add an option to force-enable the MOE cold start optimization (#33735)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent a263aa61
......@@ -593,7 +593,7 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
fast_moe_cold_start = True
fast_moe_cold_start: bool | None = None
"""Optimization for fast MOE cold start.
This is a bit of a hack that assumes that:
......@@ -604,8 +604,14 @@ class CompilationConfig:
When the above two conditions hold, this option greatly decreases cold start
time for MOE models.
If the above two conditions don't hold, then this option will lead to silent
incorrectness. The only condition in which this doesn't hold is speculative
The options are:
- True: optimization is always on
- False: optimization is always off
- None: optimization is on usually but off for speculative decoding
If conditions 1&2 don't hold then this option will lead to silent
incorrectness.
The only condition in which this doesn't hold is speculative
decoding, where there is a draft model that may have MOEs in them.
NB: We're working on a longer-term solution that doesn't need these assumptions.
......
......@@ -806,6 +806,14 @@ class VllmConfig:
else:
self.compilation_config.custom_ops.append("+rms_norm")
if self.compilation_config.fast_moe_cold_start is None:
# resolve default behavior: try to be as safe as possible
# this config is unsafe if any spec decoding draft model has a MOE.
# We'll conservatively turn it off if we see spec decoding.
self.compilation_config.fast_moe_cold_start = (
self.speculative_config is None
)
if current_platform.support_static_graph_mode():
# if cudagraph_mode has full cudagraphs, we need to check support
if model_config := self.model_config:
......
......@@ -287,15 +287,7 @@ def create_forward_context(
skip_compiled: bool = False,
):
if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None:
all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
else:
logger.warning_once(
"vllm_config.compilation_config.fast_moe_cold_start is not "
"compatible with speculative decoding so we are ignoring "
"fast_moe_cold_start."
)
all_moe_layers = None
all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
else:
all_moe_layers = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment