Unverified commit 002b07c4, authored by gmagogsfm, committed by GitHub
Browse files

[Bugfix] vLLM should check Inductor config for compile cache enablement status (#27637)


Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
parent 752ddeac
...@@ -33,6 +33,7 @@ from .compiler_interface import ( ...@@ -33,6 +33,7 @@ from .compiler_interface import (
EagerAdaptor, EagerAdaptor,
InductorAdaptor, InductorAdaptor,
InductorStandaloneAdaptor, InductorStandaloneAdaptor,
is_compile_cache_enabled,
) )
from .counter import compilation_counter from .counter import compilation_counter
from .inductor_pass import InductorPass from .inductor_pass import InductorPass
...@@ -239,7 +240,7 @@ class CompilerManager: ...@@ -239,7 +240,7 @@ class CompilerManager:
assert compiled_graph is not None, "Failed to compile the graph" assert compiled_graph is not None, "Failed to compile the graph"
# store the artifact in the cache # store the artifact in the cache
if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: if is_compile_cache_enabled(additional_inductor_config) and handle is not None:
self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle
compilation_counter.num_cache_entries_updated += 1 compilation_counter.num_cache_entries_updated += 1
self.is_cache_updated = True self.is_cache_updated = True
...@@ -611,7 +612,9 @@ class VllmBackend: ...@@ -611,7 +612,9 @@ class VllmBackend:
os.makedirs(local_cache_dir, exist_ok=True) os.makedirs(local_cache_dir, exist_ok=True)
self.compilation_config.local_cache_dir = local_cache_dir self.compilation_config.local_cache_dir = local_cache_dir
disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE disable_cache = not is_compile_cache_enabled(
self.compilation_config.inductor_compile_config
)
if disable_cache: if disable_cache:
logger.info_once("vLLM's torch.compile cache is disabled.", scope="local") logger.info_once("vLLM's torch.compile cache is disabled.", scope="local")
......
...@@ -163,6 +163,23 @@ def get_inductor_factors() -> list[Any]: ...@@ -163,6 +163,23 @@ def get_inductor_factors() -> list[Any]:
return factors return factors
def is_compile_cache_enabled(
    vllm_additional_inductor_config: dict[str, Any],
) -> bool:
    """Return whether the compile cache is enabled.

    The cache counts as enabled only when none of these switches is set:

    * ``envs.VLLM_DISABLE_COMPILE_CACHE``
    * ``torch._inductor.config.force_disable_caches``
    * the ``"force_disable_caches"`` entry of
      ``vllm_additional_inductor_config`` (treated as ``False`` when the
      key is absent)

    Args:
        vllm_additional_inductor_config: Inductor config overrides to
            inspect for a ``"force_disable_caches"`` entry.

    Returns:
        ``True`` if every switch above is falsy, ``False`` otherwise.
    """
    # Read the per-call override up front so a malformed config fails
    # regardless of how the other switches are set.
    disabled_by_vllm_config = vllm_additional_inductor_config.get(
        "force_disable_caches", False
    )

    # TODO(gmagogsfm): Replace torch._inductor.config.force_disable_caches
    # with torch.compiler.config.force_disable_caches when minimum PyTorch
    # version reaches 2.10
    return not (
        envs.VLLM_DISABLE_COMPILE_CACHE
        or torch._inductor.config.force_disable_caches
        or disabled_by_vllm_config
    )
class InductorStandaloneAdaptor(CompilerInterface): class InductorStandaloneAdaptor(CompilerInterface):
""" """
The adaptor for the Inductor compiler. The adaptor for the Inductor compiler.
...@@ -222,7 +239,8 @@ class InductorStandaloneAdaptor(CompilerInterface): ...@@ -222,7 +239,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
# Save the compiled artifact to disk in the specified path # Save the compiled artifact to disk in the specified path
assert key is not None assert key is not None
path = os.path.join(self.cache_dir, key) path = os.path.join(self.cache_dir, key)
if not envs.VLLM_DISABLE_COMPILE_CACHE:
if is_compile_cache_enabled(compiler_config):
compiled_graph.save(path=path, format=self.save_format) compiled_graph.save(path=path, format=self.save_format)
compilation_counter.num_compiled_artifacts_saved += 1 compilation_counter.num_compiled_artifacts_saved += 1
return compiled_graph, (key, path) return compiled_graph, (key, path)
...@@ -472,10 +490,8 @@ class InductorAdaptor(CompilerInterface): ...@@ -472,10 +490,8 @@ class InductorAdaptor(CompilerInterface):
config_patches=current_config, config_patches=current_config,
) )
# We treat VLLM_DISABLE_COMPILE_CACHE as the overall switch for torch # Turn off the checks if we disable the compilation cache.
# compilation cache. So turn off the checks if we disable the if is_compile_cache_enabled(compiler_config):
# compilation cache.
if not envs.VLLM_DISABLE_COMPILE_CACHE:
if hash_str is None: if hash_str is None:
raise RuntimeError( raise RuntimeError(
"vLLM failed to compile the model. The most " "vLLM failed to compile the model. The most "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment