Unverified Commit 81fe69ca authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[torch.compile] Stop compiling identical artifacts (#34003)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent dd6a6e11
...@@ -37,12 +37,12 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache): ...@@ -37,12 +37,12 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
# The forward pass consists of 32 transformer layers. # The forward pass consists of 32 transformer layers.
# Then, we split on the attention operation. This results in # Then, we split on the attention operation. This results in
# 33 subgraphs (not including the attention operation). # 33 subgraphs (not including the attention operation).
# The 33 subgraphs then get standalone_compile'd. # We then generate compiled artifacts for the unique subgraphs.
# #
# There are actually only 3 unique subgraphs for this model # There are actually only 3 unique subgraphs for this model
# (all of its transformer layers are the same modulo weights); # (all of its transformer layers are the same modulo weights);
# this is true for most vLLM models. # this is true for most vLLM models.
# So we test that during cold start, the aot_autograd cache # So we test that during cold start, we are only compling
# misses for 3 subgraphs and hits for the rest. # for 3 unique subgraphs.
assert counters["aot_autograd"]["autograd_cache_miss"] == 3 assert counters["aot_autograd"]["autograd_cache_miss"] == 3
assert counters["aot_autograd"]["autograd_cache_hit"] == 30 assert counters["aot_autograd"]["autograd_cache_hit"] == 0
...@@ -134,6 +134,7 @@ class CompilerManager: ...@@ -134,6 +134,7 @@ class CompilerManager:
self.is_cache_updated = False self.is_cache_updated = False
self.compilation_config = compilation_config self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config) self.compiler = make_compiler(compilation_config)
self.loaded_artifacts: dict[str, Any] = {}
def compute_hash(self, vllm_config: VllmConfig) -> str: def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config) return self.compiler.compute_hash(vllm_config)
...@@ -282,6 +283,49 @@ class CompilerManager: ...@@ -282,6 +283,49 @@ class CompilerManager:
maybe_key += f"{compile_range.start}_{compile_range.end}" maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}" maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range): with self.compile_context(compile_range):
# There is a compilation time optimization here.
#
# If the (input metdata, graph, compiler config) are the same, then
# we want to avoid compiling the same artifact again. If we didn't
# do this optimization, the backend compilation (InductorAdaptor or
# InductorStandaloneAdaptor)
# is able to cache hit and produce an artifact faster if it was
# already created, but it is still a duplicate artifact that
# requires unnecessary things e.g. disk IO.
#
# The optimization is: If the backend compilation cache hits,
# then do an early return from the backend compilation and look up
# which of the previous in-memory artifacts we created to reuse.
#
# We implemented this by monkey-patching torch (torch does not
# easily expose the cache_key function), but in the future torch
# should expose the cache_key function that we can just call
# directly before invoking backend compilation.
cache_key = None
orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key
def autograd_cache_key(*args, **kwargs):
result = orig(*args, **kwargs)
if result is None:
return None
nonlocal cache_key
cache_key = result[0]
if cache_key in self.loaded_artifacts:
raise StopCompiling()
return result
from unittest.mock import patch
with (
# Graphs that are isometric (different node names but same
# structure) should be treated as the same.
torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
patch(
"torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
autograd_cache_key,
),
):
try:
compiled_graph, handle = self.compiler.compile( compiled_graph, handle = self.compiler.compile(
graph, graph,
example_inputs, example_inputs,
...@@ -289,6 +333,11 @@ class CompilerManager: ...@@ -289,6 +333,11 @@ class CompilerManager:
compile_range, compile_range,
maybe_key, maybe_key,
) )
except StopCompiling:
assert cache_key is not None
return self.loaded_artifacts[cache_key]
if cache_key is not None and compiled_graph is not None:
self.loaded_artifacts[cache_key] = compiled_graph
assert compiled_graph is not None, "Failed to compile the graph" assert compiled_graph is not None, "Failed to compile the graph"
...@@ -326,6 +375,10 @@ class CompilerManager: ...@@ -326,6 +375,10 @@ class CompilerManager:
return compiled_graph return compiled_graph
class StopCompiling(BaseException):
pass
@dataclasses.dataclass @dataclasses.dataclass
class SplitItem: class SplitItem:
submod_name: str submod_name: str
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment