Unverified Commit dd6dbd93 authored by Zhengxu Chen's avatar Zhengxu Chen Committed by GitHub
Browse files

[compile] Fix extra cache save on warm start. (#35921)


Signed-off-by: default avatarzhxchen17 <zhxchen17@fb.com>
parent 26366009
...@@ -61,11 +61,11 @@ def test_moe_startup(monkeypatch, vllm_runner, fresh_vllm_cache): ...@@ -61,11 +61,11 @@ def test_moe_startup(monkeypatch, vllm_runner, fresh_vllm_cache):
counters.clear() counters.clear()
with compilation_counter.expect( with compilation_counter.expect(
num_compiled_artifacts_loaded=3, num_compiled_artifacts_loaded=3,
# TODO: warm start should not save any artifacts num_compiled_artifacts_saved=0,
# https://github.com/vllm-project/vllm/issues/35708
num_compiled_artifacts_saved=1,
): ):
_run_vllm(vllm_runner) _run_vllm(vllm_runner)
assert counters["aot_autograd"]["total"] == 30 assert counters["aot_autograd"]["total"] == 30
assert counters["aot_autograd"]["autograd_cache_miss"] == 0 assert counters["aot_autograd"]["autograd_cache_miss"] == 0
assert counters["aot_autograd"]["autograd_cache_hit"] == 1 assert (
counters["aot_autograd"]["autograd_cache_hit"] == 0
) # No miss at aot_autograd level causing disk I/O.
...@@ -221,10 +221,28 @@ class CompilerManager: ...@@ -221,10 +221,28 @@ class CompilerManager:
) -> Callable[..., Any] | None: ) -> Callable[..., Any] | None:
if (compile_range, graph_index, self.compiler.name) not in self.cache: if (compile_range, graph_index, self.compiler.name) not in self.cache:
return None return None
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
def parse_value(value: Any) -> tuple[tuple[str, str], str]:
assert isinstance(value, dict)
handle = value["graph_handle"]
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
cache_key = value["cache_key"]
return handle, cache_key
try:
handle, cache_key = parse_value(
self.cache[(compile_range, graph_index, self.compiler.name)]
)
except Exception:
# When the cache is outdated, we should ignore the existing file.
# This should cause the correct cache to be generated again.
return None
compiled_graph = self.compiler.load( compiled_graph = self.compiler.load(
handle, graph, example_inputs, graph_index, compile_range handle, graph, example_inputs, graph_index, compile_range
) )
self.loaded_artifacts[cache_key] = compiled_graph
logger.debug( logger.debug(
"Directly load the %s-th graph for compile range %sfrom %s via handle %s", "Directly load the %s-th graph for compile range %sfrom %s via handle %s",
graph_index, graph_index,
...@@ -341,7 +359,10 @@ class CompilerManager: ...@@ -341,7 +359,10 @@ class CompilerManager:
# store the artifact in the cache # store the artifact in the cache
if is_compile_cache_enabled(additional_inductor_config) and handle is not None: if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
self.cache[(compile_range, graph_index, self.compiler.name)] = handle self.cache[(compile_range, graph_index, self.compiler.name)] = {
"graph_handle": handle,
"cache_key": cache_key,
}
compilation_counter.num_cache_entries_updated += 1 compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True self.is_cache_updated = True
if graph_index == 0: if graph_index == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment