Unverified Commit 9b6504c3 authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[BugFix] Work around graph partition x torch.compile cache issue (#26956)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent e19b16dd
...@@ -337,9 +337,8 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor: ...@@ -337,9 +337,8 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
def test_toy_llama( def test_toy_llama(
backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
): ):
# We disable the vLLM compile cache into a new tmp dir for 2 reasons: # We disable the vLLM compile cache into a new tmp dir for 1 reason:
# 1. To make sure we can properly track the number of Inductor compilations. # 1. To make sure we can properly track the number of Inductor compilations.
# 2. Inductor partitioning does not play nicely with Autograd cache (below)
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"): if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
...@@ -369,15 +368,6 @@ def test_toy_llama( ...@@ -369,15 +368,6 @@ def test_toy_llama(
cudagraph_capture_sizes=[1, 2], cudagraph_capture_sizes=[1, 2],
) )
# FIXME(luka/boyuan): the graph from the previous test case
# (no inductor partition) gets cached by AotAutograd so then the
# compilation with inductor partitioning incorrectly loads an unpartitioned
# graph and never partitions. I think this is a bug with custom inductor
# partitioning but does not affect vLLM more generally as vLLM uses its own
# cache (which takes inductor partitioning into account).
if use_inductor_graph_partition:
compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
compile_config_split = deepcopy(compile_config_no_split) compile_config_split = deepcopy(compile_config_no_split)
compile_config_split.splitting_ops = ["silly::attention"] compile_config_split.splitting_ops = ["silly::attention"]
......
...@@ -110,6 +110,27 @@ class PostGradPassManager(CustomGraphPass): ...@@ -110,6 +110,27 @@ class PostGradPassManager(CustomGraphPass):
self.post_cleanup = PostCleanupPass(config) self.post_cleanup = PostCleanupPass(config)
self.fix_functionalization = FixFunctionalizationPass(config) self.fix_functionalization = FixFunctionalizationPass(config)
# [HACK: Bug with Inductor graph partition and torch.compile cache]
# In PyTorch 2.9, torch.compile has a bug where the graph
# partition is not taken into account during caching.
# Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
# Inductor graph partition, and VLLM_COMPILE implies there
# is a PostGradPassManager, we put the list of operators to graph
# partition into the PostGradPassManager's uuid (which
# then gets incorporated into Inductor's FX graph cache key).
# Remove this hack whenever torch.compile fixes it.
# This is the list of operators that vLLM asks Inductor to split.
self.inductor_splitting_ops = []
if (
config.compilation_config.use_inductor_graph_partition
and config.compilation_config.splitting_ops is not None
):
# Sort them so we're not dependent on the ordering.
self.inductor_splitting_ops = sorted(
config.compilation_config.splitting_ops
)
def add(self, pass_: InductorPass): def add(self, pass_: InductorPass):
assert isinstance(pass_, InductorPass) assert isinstance(pass_, InductorPass)
self.passes.append(pass_) self.passes.append(pass_)
...@@ -120,8 +141,16 @@ class PostGradPassManager(CustomGraphPass): ...@@ -120,8 +141,16 @@ class PostGradPassManager(CustomGraphPass):
affects compilation caching. Its uuid depends on the UUIDs of all affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info. dependent passes and the pass config. See InductorPass for more info.
""" """
state = {"pass_config": self.pass_config.uuid(), "passes": []} state = {
"pass_config": self.pass_config.uuid(),
"passes": [],
"inductor_splitting_ops": [],
}
for pass_ in self.passes: for pass_ in self.passes:
state["passes"].append(pass_.uuid()) state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid()) state["passes"].append(self.fix_functionalization.uuid())
# See [HACK: Bug with Inductor graph partition and torch.compile cache]
state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
return InductorPass.hash_dict(state) return InductorPass.hash_dict(state)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment