Unverified Commit b1bb18de authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[torch.compile] Significantly speed up cold start times (#33641)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent 2267cb1c
...@@ -37,12 +37,13 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache): ...@@ -37,12 +37,13 @@ def test_moe_compilation_cold_start(monkeypatch, use_fresh_inductor_cache):
# The forward pass consists of 32 transformer layers. # The forward pass consists of 32 transformer layers.
# Then, we split on the attention operation. This results in # Then, we split on the attention operation. This results in
# 33 subgraphs (not including the attention operation). # 33 subgraphs (not including the attention operation).
# The 33 subgraphs then get standalone_compile'd. # We then standalone_compile the unique subgraphs.
# #
# There are actually only 3 unique subgraphs for this model # There are actually only 3 unique subgraphs for this model
# (all of its transformer layers are the same modulo weights); # (all of its transformer layers are the same modulo weights);
# this is true for most vLLM models. # this is true for most vLLM models.
# So we test that during cold start, the aot_autograd cache # So we test that during cold start, only 3 subgraphs are compiled
# misses for 3 subgraphs and hits for the rest. # These 3 subgraphs should cache miss, and then there should be
# no other compilation (so no cache hits).
assert counters["aot_autograd"]["autograd_cache_miss"] == 3 assert counters["aot_autograd"]["autograd_cache_miss"] == 3
assert counters["aot_autograd"]["autograd_cache_hit"] == 30 assert counters["aot_autograd"]["autograd_cache_hit"] == 0
...@@ -121,7 +121,7 @@ class CompilerManager: ...@@ -121,7 +121,7 @@ class CompilerManager:
and compiling the graph. and compiling the graph.
The cache is a dict mapping The cache is a dict mapping
`(runtime_shape, graph_index, backend_name)` `(runtime_shape, graph_hash, backend_name)`
to `any_data` returned from the compiler. to `any_data` returned from the compiler.
When serializing the cache, we save it to a Python file When serializing the cache, we save it to a Python file
...@@ -130,7 +130,7 @@ class CompilerManager: ...@@ -130,7 +130,7 @@ class CompilerManager:
""" """
def __init__(self, compilation_config: CompilationConfig) -> None: def __init__(self, compilation_config: CompilationConfig) -> None:
self.cache: dict[tuple[Range, int, str], Any] = dict() self.cache: dict[tuple[Range, str, str], Any] = dict()
self.is_cache_updated = False self.is_cache_updated = False
self.compilation_config = compilation_config self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config) self.compiler = make_compiler(compilation_config)
...@@ -173,6 +173,7 @@ class CompilerManager: ...@@ -173,6 +173,7 @@ class CompilerManager:
self.disable_cache = disable_cache self.disable_cache = disable_cache
self.cache_dir = cache_dir self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py")
self.loaded_cache_entries: dict[tuple[Range, str, str], Any] = {}
if not disable_cache and os.path.exists(self.cache_file_path): if not disable_cache and os.path.exists(self.cache_file_path):
# load the cache from the file # load the cache from the file
...@@ -186,9 +187,9 @@ class CompilerManager: ...@@ -186,9 +187,9 @@ class CompilerManager:
if not isinstance(value, ty): if not isinstance(value, ty):
raise TypeError(f"Expected {ty} but got {type(value)} for {value}") raise TypeError(f"Expected {ty} but got {type(value)} for {value}")
def parse_key(key: Any) -> tuple[Range, int, str]: def parse_key(key: Any) -> tuple[Range, str, str]:
range_tuple, graph_index, compiler_name = key range_tuple, graph_hash, compiler_name = key
check_type(graph_index, int) check_type(graph_hash, str)
check_type(compiler_name, str) check_type(compiler_name, str)
if isinstance(range_tuple, tuple): if isinstance(range_tuple, tuple):
start, end = range_tuple start, end = range_tuple
...@@ -196,7 +197,7 @@ class CompilerManager: ...@@ -196,7 +197,7 @@ class CompilerManager:
check_type(end, int) check_type(end, int)
range_tuple = Range(start=start, end=end) range_tuple = Range(start=start, end=end)
check_type(range_tuple, Range) check_type(range_tuple, Range)
return range_tuple, graph_index, compiler_name return range_tuple, graph_hash, compiler_name
self.cache = {parse_key(key): value for key, value in cache.items()} self.cache = {parse_key(key): value for key, value in cache.items()}
...@@ -216,18 +217,25 @@ class CompilerManager: ...@@ -216,18 +217,25 @@ class CompilerManager:
self, self,
graph: fx.GraphModule, graph: fx.GraphModule,
example_inputs: list[Any], example_inputs: list[Any],
graph_index: int, graph_hash: str,
compile_range: Range, compile_range: Range,
) -> Callable[..., Any] | None: ) -> Callable[..., Any] | None:
if (compile_range, graph_index, self.compiler.name) not in self.cache: key = (compile_range, graph_hash, self.compiler.name)
# See if we've already loaded this cache entry
if key in self.loaded_cache_entries:
return self.loaded_cache_entries[key]
# Otherwise, go load it from disk
if key not in self.cache:
return None return None
handle = self.cache[(compile_range, graph_index, self.compiler.name)] handle = self.cache[key]
compiled_graph = self.compiler.load( compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, compile_range handle, graph, example_inputs, compile_range
) )
self.loaded_cache_entries[key] = compiled_graph
logger.debug( logger.debug(
"Directly load the %s-th graph for compile range %sfrom %s via handle %s", "Directly load the graph (hash %s) for compile range "
graph_index, "%sfrom %s via handle %s",
graph_hash,
str(compile_range), str(compile_range),
self.compiler.name, self.compiler.name,
handle, handle,
...@@ -249,12 +257,22 @@ class CompilerManager: ...@@ -249,12 +257,22 @@ class CompilerManager:
global compilation_start_time global compilation_start_time
compilation_start_time = time.time() compilation_start_time = time.time()
from torch._functorch._aot_autograd.autograd_cache import (
AOTAutogradCachePickler,
sanitize_gm_for_cache,
)
with sanitize_gm_for_cache(graph):
pickler = AOTAutogradCachePickler(graph)
dumped_graph = pickler.dumps(graph)
graph_hash = hashlib.sha256(dumped_graph).hexdigest()
compilation_counter.num_backend_compilations += 1 compilation_counter.num_backend_compilations += 1
compiled_graph = None compiled_graph = None
# try to load from the cache # try to load from the cache
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) compiled_graph = self.load(graph, example_inputs, graph_hash, compile_range)
if compiled_graph is not None: if compiled_graph is not None:
if graph_index == num_graphs - 1: if graph_index == num_graphs - 1:
# after loading the last graph for this shape, record the time. # after loading the last graph for this shape, record the time.
...@@ -290,9 +308,13 @@ class CompilerManager: ...@@ -290,9 +308,13 @@ class CompilerManager:
assert compiled_graph is not None, "Failed to compile the graph" assert compiled_graph is not None, "Failed to compile the graph"
self.loaded_cache_entries[(compile_range, graph_hash, self.compiler.name)] = (
compiled_graph
)
# store the artifact in the cache # store the artifact in the cache
if is_compile_cache_enabled(additional_inductor_config) and handle is not None: if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
self.cache[(compile_range, graph_index, self.compiler.name)] = handle self.cache[(compile_range, graph_hash, self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1 compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True self.is_cache_updated = True
if graph_index == 0: if graph_index == 0:
......
...@@ -101,7 +101,6 @@ class CompilerInterface: ...@@ -101,7 +101,6 @@ class CompilerInterface:
handle: Any, handle: Any,
graph: fx.GraphModule, graph: fx.GraphModule,
example_inputs: list[Any], example_inputs: list[Any],
graph_index: int,
compile_range: Range, compile_range: Range,
) -> Callable[..., Any]: ) -> Callable[..., Any]:
""" """
...@@ -302,7 +301,6 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -302,7 +301,6 @@ class InductorStandaloneAdaptor(CompilerInterface):
handle: Any, handle: Any,
graph: fx.GraphModule, graph: fx.GraphModule,
example_inputs: list[Any], example_inputs: list[Any],
graph_index: int,
compile_range: Range, compile_range: Range,
) -> Callable[..., Any]: ) -> Callable[..., Any]:
assert isinstance(handle, tuple) assert isinstance(handle, tuple)
...@@ -527,7 +525,6 @@ class InductorAdaptor(CompilerInterface): ...@@ -527,7 +525,6 @@ class InductorAdaptor(CompilerInterface):
handle: Any, handle: Any,
graph: fx.GraphModule, graph: fx.GraphModule,
example_inputs: list[Any], example_inputs: list[Any],
graph_index: int,
compile_range: Range, compile_range: Range,
) -> Callable[..., Any]: ) -> Callable[..., Any]:
assert isinstance(handle, tuple) assert isinstance(handle, tuple)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment