Unverified Commit 09b6f998 authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[compile] aot_compile should respect VLLM_DISABLE_COMPILE_CACHE (#36358)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent c87fb515
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import functools import functools
import hashlib import hashlib
import multiprocessing import multiprocessing
import os
import pickle import pickle
import tempfile import tempfile
from contextlib import contextmanager from contextlib import contextmanager
...@@ -19,6 +20,7 @@ from vllm.compilation.caching import ( ...@@ -19,6 +20,7 @@ from vllm.compilation.caching import (
StandaloneCompiledArtifacts, StandaloneCompiledArtifacts,
VllmSerializableFunction, VllmSerializableFunction,
) )
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ( from vllm.config import (
CompilationConfig, CompilationConfig,
...@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration: ...@@ -763,3 +765,115 @@ class TestStandaloneCompiledArtifactsIntegration:
assert isinstance(config, dict) assert isinstance(config, dict)
assert "bundled_autograd_cache" in config assert "bundled_autograd_cache" in config
assert config["bundled_autograd_cache"] is True assert config["bundled_autograd_cache"] is True
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_save(
monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
"""When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be saved."""
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
disable_envs_cache()
args = (torch.randn(10, 10),)
expected = reference_fn(*args)
vllm_config = make_vllm_config()
with (
use_vllm_config(vllm_config),
compilation_counter.expect(
num_aot_compiles=1,
num_aot_artifacts_saved=0,
num_aot_artifacts_loaded=0,
),
):
mod = CompiledMod(vllm_config=vllm_config)
actual = mod(*args)
assert torch.allclose(actual, expected)
# No cached artifact should exist on disk
aot_dir = os.path.join(fresh_vllm_cache, "torch_compile_cache", "torch_aot_compile")
if os.path.isdir(aot_dir):
for root, _dirs, files in os.walk(aot_dir):
for f in files:
assert f != "model", (
f"AOT artifact unexpectedly saved at {os.path.join(root, f)}"
)
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_disable_compile_cache_skips_aot_load(
monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
"""When VLLM_DISABLE_COMPILE_CACHE=1, AOT artifacts must not be loaded."""
# Phase 1: compile and save with cache enabled
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
disable_envs_cache()
args = (torch.randn(10, 10),)
vllm_config = make_vllm_config()
with (
use_vllm_config(vllm_config),
compilation_counter.expect(num_aot_artifacts_saved=1),
):
CompiledMod(vllm_config=vllm_config)(*args)
# Phase 2: disable cache, compile again — should NOT load from disk
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
disable_envs_cache()
torch._dynamo.reset()
vllm_config = make_vllm_config()
with (
use_vllm_config(vllm_config),
compilation_counter.expect(
num_aot_compiles=1,
num_aot_artifacts_saved=0,
num_aot_artifacts_loaded=0,
),
):
mod = CompiledMod(vllm_config=vllm_config)
mod(*args)
assert not mod.was_aot_compile_fn_loaded_from_disk
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_aot_counters_on_save_and_load(
monkeypatch: pytest.MonkeyPatch, fresh_vllm_cache: str
):
"""Verify AOT counters are incremented correctly on save and load."""
monkeypatch.setenv("VLLM_USE_AOT_COMPILE", "1")
disable_envs_cache()
args = (torch.randn(10, 10),)
# Phase 1: fresh compile + save
vllm_config = make_vllm_config()
with (
use_vllm_config(vllm_config),
compilation_counter.expect(
num_aot_compiles=1,
num_aot_artifacts_saved=1,
num_aot_artifacts_loaded=0,
),
):
CompiledMod(vllm_config=vllm_config)(*args)
# Phase 2: load from cache
monkeypatch.setenv("VLLM_FORCE_AOT_LOAD", "1")
disable_envs_cache()
vllm_config = make_vllm_config()
with (
use_vllm_config(vllm_config),
compilation_counter.expect(
num_aot_compiles=0,
num_aot_artifacts_saved=0,
num_aot_artifacts_loaded=1,
),
):
CompiledMod(vllm_config=vllm_config)(*args)
...@@ -31,6 +31,12 @@ class CompilationCounter: ...@@ -31,6 +31,12 @@ class CompilationCounter:
num_compiled_artifacts_saved: int = 0 num_compiled_artifacts_saved: int = 0
# The number of standalone_compile compiled artifacts loaded from cache # The number of standalone_compile compiled artifacts loaded from cache
num_compiled_artifacts_loaded: int = 0 num_compiled_artifacts_loaded: int = 0
# The number of AOT compile invocations
num_aot_compiles: int = 0
# The number of AOT compiled artifacts saved to disk
num_aot_artifacts_saved: int = 0
# The number of AOT compiled artifacts loaded from disk
num_aot_artifacts_loaded: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0 stock_torch_compile_count: int = 0
......
...@@ -266,6 +266,51 @@ def _verify_source_unchanged( ...@@ -266,6 +266,51 @@ def _verify_source_unchanged(
) )
def _try_load_aot_compiled_fn(
model: Any,
aot_compilation_path: str,
) -> Any | None:
"""Try to load an AOT-compiled function from disk.
Returns the loaded callable on success, or None on failure.
Re-raises on failure when ``VLLM_FORCE_AOT_LOAD`` is set.
"""
try:
with monitor_torch_compile(model.vllm_config):
with (
set_current_vllm_config(model.vllm_config),
open(aot_compilation_path, "rb") as f,
):
loaded_fn = torch.compiler.load_compiled_function(
f, f_globals=model.forward.__globals__
)
_verify_source_unchanged(loaded_fn.source_info(), model.vllm_config)
ds_config = model.compilation_config.dynamic_shapes_config
if not ds_config.evaluate_guards:
loaded_fn.disable_guard_check()
# Eagerly load compiled artifacts now that traced_files
# is populated by _verify_source_unchanged.
with maybe_use_cudagraph_partition_wrapper(model.vllm_config):
loaded_fn._artifacts.compiled_fn.finalize_loading(model.vllm_config)
compilation_counter.num_aot_artifacts_loaded += 1
logger.info("Directly load AOT compilation from path %s", aot_compilation_path)
return loaded_fn
except Exception as e:
if os.path.exists(aot_compilation_path):
if isinstance(e, EOFError):
message = "Compile cache file corrupted."
else:
message = str(e)
logger.warning(
"Compiling model again due to a load failure from %s, reason: %s",
aot_compilation_path,
message,
)
if envs.VLLM_FORCE_AOT_LOAD:
raise e
return None
def _support_torch_compile( def _support_torch_compile(
cls: type[_T], cls: type[_T],
dynamic_arg_dims: dict[str, int | list[int]], dynamic_arg_dims: dict[str, int | list[int]],
...@@ -438,45 +483,11 @@ def _support_torch_compile( ...@@ -438,45 +483,11 @@ def _support_torch_compile(
dp_rank = self.vllm_config.parallel_config.data_parallel_index dp_rank = self.vllm_config.parallel_config.data_parallel_index
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
aot_compilation_path = os.path.join(cache_dir, "model") aot_compilation_path = os.path.join(cache_dir, "model")
try: if not envs.VLLM_DISABLE_COMPILE_CACHE:
with monitor_torch_compile(self.vllm_config): loaded_fn = _try_load_aot_compiled_fn(self, aot_compilation_path)
with ( if loaded_fn is not None:
set_current_vllm_config(self.vllm_config),
open(aot_compilation_path, "rb") as f,
):
loaded_fn = torch.compiler.load_compiled_function(
f, f_globals=self.forward.__globals__
)
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
ds_config = self.compilation_config.dynamic_shapes_config
if not ds_config.evaluate_guards:
loaded_fn.disable_guard_check()
# Eagerly load compiled artifacts now that traced_files
# is populated by _verify_source_unchanged.
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
loaded_fn._artifacts.compiled_fn.finalize_loading(
self.vllm_config
)
self.aot_compiled_fn = loaded_fn self.aot_compiled_fn = loaded_fn
self.was_aot_compile_fn_loaded_from_disk = True self.was_aot_compile_fn_loaded_from_disk = True
except Exception as e:
if os.path.exists(aot_compilation_path):
if isinstance(e, EOFError):
message = "Compile cache file corrupted."
else:
message = str(e)
logger.warning(
"Compiling model again due to a load failure from %s, "
"reason: %s",
aot_compilation_path,
message,
)
if envs.VLLM_FORCE_AOT_LOAD:
raise e
if getattr(self, "aot_compiled_fn", None) is not None:
logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path
)
with ( with (
monitor_profiling_run(), monitor_profiling_run(),
maybe_use_cudagraph_partition_wrapper(self.vllm_config), maybe_use_cudagraph_partition_wrapper(self.vllm_config),
...@@ -570,6 +581,7 @@ def _support_torch_compile( ...@@ -570,6 +581,7 @@ def _support_torch_compile(
self._aot_cache_dir = cache_dir self._aot_cache_dir = cache_dir
with monitor_torch_compile(self.vllm_config): with monitor_torch_compile(self.vllm_config):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs) self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
compilation_counter.num_aot_compiles += 1
# All compilation is done at this point, save the # All compilation is done at this point, save the
# AOT artifact. # AOT artifact.
self.save_aot_compiled_function() self.save_aot_compiled_function()
...@@ -593,6 +605,9 @@ def _support_torch_compile( ...@@ -593,6 +605,9 @@ def _support_torch_compile(
# triggers VllmSerializableFunction.serialize() # triggers VllmSerializableFunction.serialize()
def save_aot_compiled_function(self: type[_T]) -> None: def save_aot_compiled_function(self: type[_T]) -> None:
if envs.VLLM_DISABLE_COMPILE_CACHE:
return
if self.was_aot_compile_fn_loaded_from_disk: if self.was_aot_compile_fn_loaded_from_disk:
logger.debug("AOT compiled function was loaded from cache, skipping save") logger.debug("AOT compiled function was loaded from cache, skipping save")
return return
...@@ -608,6 +623,7 @@ def _support_torch_compile( ...@@ -608,6 +623,7 @@ def _support_torch_compile(
tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp" tmp_file = f"{self._aot_compilation_path}.{os.getpid()}.tmp"
self.aot_compiled_fn.save_compiled_function(tmp_file) self.aot_compiled_fn.save_compiled_function(tmp_file)
os.replace(tmp_file, self._aot_compilation_path) os.replace(tmp_file, self._aot_compilation_path)
compilation_counter.num_aot_artifacts_saved += 1
logger.info_once( logger.info_once(
"saved AOT compiled function to %s", "saved AOT compiled function to %s",
self._aot_compilation_path, self._aot_compilation_path,
......
...@@ -349,6 +349,9 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None: ...@@ -349,6 +349,9 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None:
compilation_counter.num_cache_entries_updated = 0 compilation_counter.num_cache_entries_updated = 0
compilation_counter.num_compiled_artifacts_saved = 0 compilation_counter.num_compiled_artifacts_saved = 0
compilation_counter.stock_torch_compile_count = 0 compilation_counter.stock_torch_compile_count = 0
compilation_counter.num_aot_compiles = 0
compilation_counter.num_aot_artifacts_saved = 0
compilation_counter.num_aot_artifacts_loaded = 0
# Clear the AOT compiled function so the model is forced to # Clear the AOT compiled function so the model is forced to
# recompile on the next call. Without this, decorators.py # recompile on the next call. Without this, decorators.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment