Unverified commit 2dde535d, authored by Richard Zou and committed by GitHub
Browse files

[compile] Split compile/warmup monitoring (#36098)

parent 379689d5
...@@ -189,13 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -189,13 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
self.shape_env = None self.shape_env = None
self.vllm_backend = vllm_backend self.vllm_backend = vllm_backend
self.sym_tensor_indices = sym_tensor_indices self.sym_tensor_indices = sym_tensor_indices
self._fake_mode: Any | None = None
import torch._functorch.config as functorch_config import torch._functorch.config as functorch_config
self.aot_autograd_config = ( self.aot_autograd_config = (
aot_autograd_config or functorch_config.save_config_portable() aot_autograd_config or functorch_config.save_config_portable()
) )
sym_input = next( sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
) )
...@@ -217,6 +217,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -217,6 +217,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state.pop("optimized_call") state.pop("optimized_call")
state.pop("shape_env") state.pop("shape_env")
state.pop("vllm_backend", None) state.pop("vllm_backend", None)
state.pop("_fake_mode", None)
for node in state["graph_module"].graph.nodes: for node in state["graph_module"].graph.nodes:
node.meta.pop("source_fn_stack", None) node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None) node.meta.pop("nn_module_stack", None)
...@@ -351,8 +352,31 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -351,8 +352,31 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return fn.optimized_call(*example_inputs) return fn.optimized_call(*example_inputs)
fn = cls(**state, optimized_call=optimized_call) fn = cls(**state, optimized_call=optimized_call)
fn._fake_mode = fake_mode
return fn return fn
def finalize_loading(self, vllm_config: VllmConfig) -> None:
    """Build the compiled backend now and install its optimized callable.

    Call this only after _verify_source_unchanged has filled in
    compilation_config.traced_files; the cache dir computation depends
    on it.

    Args:
        vllm_config: Configuration handed to the ``VllmBackend``
            constructor together with ``self.prefix`` and
            ``self.is_encoder``.
    """
    pending_fake_mode = self._fake_mode
    if pending_fake_mode is None:
        # Already finalized, or mega path (no _fake_mode set)
        return

    # Imported lazily, inside the method, exactly as the loading path needs them.
    from torch._guards import TracingContext, tracing
    from vllm.compilation.backends import VllmBackend

    backend = VllmBackend(vllm_config, self.prefix, self.is_encoder)
    # Run the backend under a tracing context built from the stored fake mode.
    with tracing(TracingContext(pending_fake_mode)):
        compiled = backend(self.graph_module, list(self.example_inputs))

    self.optimized_call = compiled.optimized_call
    self.vllm_backend = backend
    # Clearing the fake mode marks this function as finalized, so a second
    # call returns immediately.
    self._fake_mode = None
@property @property
def co_name(self) -> Literal["VllmSerializableFunction"]: def co_name(self) -> Literal["VllmSerializableFunction"]:
""" """
......
...@@ -30,7 +30,7 @@ from vllm.sequence import IntermediateTensors ...@@ -30,7 +30,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
from .monitor import start_monitoring_torch_compile from .monitor import monitor_profiling_run, monitor_torch_compile
if TYPE_CHECKING: if TYPE_CHECKING:
# Only added on nightly/2.10 so wrap # Only added on nightly/2.10 so wrap
...@@ -434,17 +434,24 @@ def _support_torch_compile( ...@@ -434,17 +434,24 @@ def _support_torch_compile(
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}") cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
aot_compilation_path = os.path.join(cache_dir, "model") aot_compilation_path = os.path.join(cache_dir, "model")
try: try:
with monitor_torch_compile(self.vllm_config):
with ( with (
set_current_vllm_config(self.vllm_config), set_current_vllm_config(self.vllm_config),
open(aot_compilation_path, "rb") as f, open(aot_compilation_path, "rb") as f,
): ):
start_monitoring_torch_compile(self.vllm_config)
loaded_fn = torch.compiler.load_compiled_function( loaded_fn = torch.compiler.load_compiled_function(
f, f_globals=self.forward.__globals__ f, f_globals=self.forward.__globals__
) )
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config) _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
if not self.compilation_config.dynamic_shapes_config.evaluate_guards: ds_config = self.compilation_config.dynamic_shapes_config
if not ds_config.evaluate_guards:
loaded_fn.disable_guard_check() loaded_fn.disable_guard_check()
# Eagerly load compiled artifacts now that traced_files
# is populated by _verify_source_unchanged.
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
loaded_fn._artifacts.compiled_fn.finalize_loading(
self.vllm_config
)
self.aot_compiled_fn = loaded_fn self.aot_compiled_fn = loaded_fn
self.was_aot_compile_fn_loaded_from_disk = True self.was_aot_compile_fn_loaded_from_disk = True
except Exception as e: except Exception as e:
...@@ -465,12 +472,11 @@ def _support_torch_compile( ...@@ -465,12 +472,11 @@ def _support_torch_compile(
logger.info( logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path "Directly load AOT compilation from path %s", aot_compilation_path
) )
# Apply partition wrapper context for proper CUDA graph capture with (
from .monitor import end_monitoring_torch_compile monitor_profiling_run(),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
with maybe_use_cudagraph_partition_wrapper(self.vllm_config): ):
output = self.aot_compiled_fn(self, *args, **kwargs) output = self.aot_compiled_fn(self, *args, **kwargs)
end_monitoring_torch_compile(self.vllm_config)
return output return output
if self.compiled: if self.compiled:
...@@ -489,8 +495,6 @@ def _support_torch_compile( ...@@ -489,8 +495,6 @@ def _support_torch_compile(
**kwargs, **kwargs,
) )
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
original_code_object = self.original_code_object() original_code_object = self.original_code_object()
logger.debug("Start compiling function %s", original_code_object) logger.debug("Start compiling function %s", original_code_object)
...@@ -559,16 +563,26 @@ def _support_torch_compile( ...@@ -559,16 +563,26 @@ def _support_torch_compile(
# store the path for saving after warmup # store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir self._aot_cache_dir = cache_dir
with monitor_torch_compile(self.vllm_config):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs) self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
# All compilation is done at this point, save the AOT artifact. # All compilation is done at this point, save the
# AOT artifact.
self.save_aot_compiled_function() self.save_aot_compiled_function()
with monitor_profiling_run():
output = self.aot_compiled_fn(self, *args, **kwargs) output = self.aot_compiled_fn(self, *args, **kwargs)
else: else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type] with monitor_torch_compile(
self.vllm_config,
from .monitor import end_monitoring_torch_compile "torch.compile and initial profiling/warmup "
"run together took %.2f s in total",
):
output = TorchCompileWithNoGuardsWrapper.__call__(
self, # type: ignore[arg-type]
*args,
**kwargs,
)
end_monitoring_torch_compile(self.vllm_config)
self.compiled = True self.compiled = True
return output return output
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import time import time
from collections.abc import Generator
from vllm.config import CompilationConfig, CompilationMode, VllmConfig from vllm.config import CompilationMode, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
context_manager = None # Shared global so backends.py can read the start time for Dynamo timing.
torch_compile_start_time: float = 0.0 torch_compile_start_time: float = 0.0
@contextlib.contextmanager
def monitor_torch_compile(
    vllm_config: VllmConfig,
    message: str = "torch.compile took %.2f s in total",
) -> Generator[None, None, None]:
    """Time a torch.compile region and manage the optional depyf debug dump.

    On entry the current ``time.perf_counter()`` value is stored in the
    module-level ``torch_compile_start_time``. When the compilation mode is
    ``CompilationMode.VLLM_COMPILE`` and a debug dump path is configured,
    the dump directory is created and a depyf debug context is entered.

    On normal exit the elapsed time is logged with ``message`` (only in
    ``VLLM_COMPILE`` mode). On an exception nothing is logged and the
    exception propagates. In both cases the depyf context, if one was
    entered, is exited; a failure during that cleanup is logged as a
    warning rather than raised.

    Args:
        vllm_config: Supplies ``compilation_config`` and
            ``compile_debug_dump_path()``.
        message: %-style log format taking the elapsed seconds as its
            single argument.
    """
    global torch_compile_start_time
    torch_compile_start_time = time.perf_counter()

    comp_cfg = vllm_config.compilation_config
    debug_ctx = None
    dump_dir = vllm_config.compile_debug_dump_path()
    if comp_cfg.mode == CompilationMode.VLLM_COMPILE and dump_dir:
        import depyf

        dump_dir.mkdir(parents=True, exist_ok=True)
        logger.debug("Dumping depyf output to %s", dump_dir)
        debug_ctx = depyf.prepare_debug(dump_dir.as_posix())
        debug_ctx.__enter__()

    try:
        yield
        # Reached only when the wrapped body finished without raising, so a
        # failed compilation never reports a compile time.
        elapsed = time.perf_counter() - torch_compile_start_time
        if comp_cfg.mode == CompilationMode.VLLM_COMPILE:
            logger.info_once(message, elapsed, scope="local")
    finally:
        if debug_ctx is not None:
            try:
                # depyf is always exited as if no error occurred.
                debug_ctx.__exit__(None, None, None)
            except Exception:
                logger.warning("Exception during depyf cleanup.", exc_info=True)
@contextlib.contextmanager
def monitor_profiling_run() -> Generator[None, None, None]:
    """Time the initial profiling run and check that it compiled nothing.

    The number of backend compilations is sampled before the wrapped body
    runs and compared afterwards; any change trips an assertion, since all
    compilation should have completed before this point. When the body
    finishes without raising and the check passes, the elapsed time is
    logged.
    """
    from vllm.compilation.counter import compilation_counter

    compilations_at_entry = compilation_counter.num_backend_compilations
    started_at = time.perf_counter()

    yield

    # Capture the duration before the invariant check, as the timing should
    # not include the check itself.
    duration = time.perf_counter() - started_at
    assert compilation_counter.num_backend_compilations == compilations_at_entry, (
        "backend compilation occurred during the initial profiling run; "
        "all compilation should be complete before the profiling run starts."
    )
    logger.info_once(
        "Initial profiling/warmup run took %.2f s",
        duration,
        scope="local",
    )
cudagraph_capturing_enabled: bool = True cudagraph_capturing_enabled: bool = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment