Unverified commit 0283f303, authored by Lucas Kabela, committed by GitHub
Browse files

[BE] Fix compile time message to be consistent (use monitoring) (#40641)


Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
parent ac58e2a1
...@@ -283,10 +283,6 @@ class CompilerManager: ...@@ -283,10 +283,6 @@ class CompilerManager:
# after loading the last graph for this shape, record the time. # after loading the last graph for this shape, record the time.
# there can be multiple graphs due to piecewise compilation. # there can be multiple graphs due to piecewise compilation.
elapsed = time.perf_counter() - compilation_start_time elapsed = time.perf_counter() - compilation_start_time
if is_encoder:
compilation_config.encoder_compilation_time += elapsed
else:
compilation_config.compilation_time += elapsed
logger.info_once( logger.info_once(
"Directly load the compiled graph(s) for compile range %s " "Directly load the compiled graph(s) for compile range %s "
"from the cache, took %.3f s", "from the cache, took %.3f s",
...@@ -388,10 +384,6 @@ class CompilerManager: ...@@ -388,10 +384,6 @@ class CompilerManager:
# after compiling the last graph, record the end time # after compiling the last graph, record the end time
if graph_index == num_graphs - 1: if graph_index == num_graphs - 1:
elapsed = time.perf_counter() - compilation_start_time elapsed = time.perf_counter() - compilation_start_time
if is_encoder:
compilation_config.encoder_compilation_time += elapsed
else:
compilation_config.compilation_time += elapsed
logger.info_once( logger.info_once(
"Compiling a graph for compile range %s takes %.2f s", "Compiling a graph for compile range %s takes %.2f s",
str(compile_range), str(compile_range),
...@@ -1129,11 +1121,10 @@ class VllmBackend: ...@@ -1129,11 +1121,10 @@ class VllmBackend:
from .monitor import torch_compile_start_time from .monitor import torch_compile_start_time
dynamo_time = time.perf_counter() - torch_compile_start_time dynamo_time = time.perf_counter() - torch_compile_start_time
logger.info_once("Dynamo bytecode transform time: %.2f s", dynamo_time) logger.info_once(
if self.is_encoder: "Dynamo bytecode transform time: %.2f s",
self.compilation_config.encoder_compilation_time += dynamo_time dynamo_time,
else: )
self.compilation_config.compilation_time += dynamo_time
# Record Dynamo time in tracing if available # Record Dynamo time in tracing if available
start_time = int(torch_compile_start_time * 1e9) start_time = int(torch_compile_start_time * 1e9)
......
...@@ -285,7 +285,7 @@ def _try_load_aot_compiled_fn( ...@@ -285,7 +285,7 @@ def _try_load_aot_compiled_fn(
Re-raises on failure when ``VLLM_FORCE_AOT_LOAD`` is set. Re-raises on failure when ``VLLM_FORCE_AOT_LOAD`` is set.
""" """
try: try:
with monitor_torch_compile(model.vllm_config): with monitor_torch_compile(model.vllm_config, is_encoder=model._is_encoder):
with ( with (
set_current_vllm_config(model.vllm_config), set_current_vllm_config(model.vllm_config),
open(aot_compilation_path, "rb") as f, open(aot_compilation_path, "rb") as f,
...@@ -617,7 +617,9 @@ def _support_torch_compile( ...@@ -617,7 +617,9 @@ def _support_torch_compile(
# store the path for saving after warmup # store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir self._aot_cache_dir = cache_dir
with monitor_torch_compile(self.vllm_config): with monitor_torch_compile(
self.vllm_config, is_encoder=self._is_encoder
):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs) self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
compilation_counter.num_aot_compiles += 1 compilation_counter.num_aot_compiles += 1
# All compilation is done at this point, save the # All compilation is done at this point, save the
...@@ -631,6 +633,7 @@ def _support_torch_compile( ...@@ -631,6 +633,7 @@ def _support_torch_compile(
self.vllm_config, self.vllm_config,
"torch.compile and initial profiling/warmup " "torch.compile and initial profiling/warmup "
"run together took %.2f s in total", "run together took %.2f s in total",
is_encoder=self._is_encoder,
): ):
output = TorchCompileWithNoGuardsWrapper.__call__( output = TorchCompileWithNoGuardsWrapper.__call__(
self, # type: ignore[arg-type] self, # type: ignore[arg-type]
......
...@@ -18,6 +18,7 @@ torch_compile_start_time: float = 0.0 ...@@ -18,6 +18,7 @@ torch_compile_start_time: float = 0.0
def monitor_torch_compile( def monitor_torch_compile(
vllm_config: VllmConfig, vllm_config: VllmConfig,
message: str = "torch.compile took %.2f s in total", message: str = "torch.compile took %.2f s in total",
is_encoder: bool = False,
) -> Generator[None, None, None]: ) -> Generator[None, None, None]:
"""Context manager that times torch.compile and manages depyf debugging. """Context manager that times torch.compile and manages depyf debugging.
...@@ -45,6 +46,10 @@ def monitor_torch_compile( ...@@ -45,6 +46,10 @@ def monitor_torch_compile(
else: else:
total_compile_time = time.perf_counter() - torch_compile_start_time total_compile_time = time.perf_counter() - torch_compile_start_time
if compilation_config.mode == CompilationMode.VLLM_COMPILE: if compilation_config.mode == CompilationMode.VLLM_COMPILE:
if is_encoder:
compilation_config.encoder_compilation_time += total_compile_time
else:
compilation_config.compilation_time += total_compile_time
logger.info_once(message, total_compile_time) logger.info_once(message, total_compile_time)
finally: finally:
if depyf_cm is not None: if depyf_cm is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment