Unverified Commit 341eed3d authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[torch.compile] Disable recursive pre_grad_passes (#34092)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent 6f2f59f2
...@@ -257,6 +257,19 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -257,6 +257,19 @@ class InductorStandaloneAdaptor(CompilerInterface):
if use_aot: if use_aot:
compile_kwargs["aot"] = True # type: ignore[assignment] compile_kwargs["aot"] = True # type: ignore[assignment]
# Inductor's pre-grad passes don't do anything for vLLM.
# The pre-grad passes get run even on cache-hit and negatively impact
# vllm cold compile times by O(1s)
# Can remove this after the following issue gets fixed
# https://github.com/pytorch/pytorch/issues/174502
if envs.VLLM_ENABLE_PREGRAD_PASSES:
ctx: Any = contextlib.nullcontext()
else:
ctx = patch(
"torch._inductor.compile_fx._recursive_pre_grad_passes",
lambda gm, _: gm,
)
with ctx:
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs) compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
if use_aot: if use_aot:
......
...@@ -132,6 +132,7 @@ if TYPE_CHECKING: ...@@ -132,6 +132,7 @@ if TYPE_CHECKING:
VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_RANK_LOCAL: int = -1
VLLM_DP_SIZE: int = 1 VLLM_DP_SIZE: int = 1
VLLM_USE_STANDALONE_COMPILE: bool = True VLLM_USE_STANDALONE_COMPILE: bool = True
VLLM_ENABLE_PREGRAD_PASSES: bool = False
VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0 VLLM_DP_MASTER_PORT: int = 0
VLLM_MOE_DP_CHUNK_SIZE: int = 256 VLLM_MOE_DP_CHUNK_SIZE: int = 256
...@@ -568,6 +569,15 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -568,6 +569,15 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_STANDALONE_COMPILE", "1" "VLLM_USE_STANDALONE_COMPILE", "1"
) )
== "1", == "1",
# Inductor's pre-grad passes don't do anything for vLLM.
# The pre-grad passes get run even on cache-hit and negatively impact
# vllm cold compile times by O(1s)
# Can remove this after the following issue gets fixed
# https://github.com/pytorch/pytorch/issues/174502
"VLLM_ENABLE_PREGRAD_PASSES": lambda: os.environ.get(
"VLLM_ENABLE_PREGRAD_PASSES", "0"
)
== "1",
# Debug pattern matching inside custom passes. # Debug pattern matching inside custom passes.
# Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3').
"VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get( "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment