Unverified Commit 7b346ba8 authored by Huy Do's avatar Huy Do Committed by GitHub
Browse files

[Bugfix] Propagate compilation_time from workers to main process for TP>1 (#35503)


Signed-off-by: default avatarHuy Do <huydhn@gmail.com>
parent dea26833
...@@ -115,7 +115,15 @@ class Executor(ABC): ...@@ -115,7 +115,15 @@ class Executor(ABC):
underlying workers. underlying workers.
""" """
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,)) self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
self.collective_rpc("compile_or_warm_up_model") compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model")
# Propagate compilation time from workers back to the main process.
# With TP>1, compilation happens in worker processes, so the main
# process config is never updated. Use max across workers since they
# compile in parallel.
if compilation_times:
self.vllm_config.compilation_config.compilation_time = max(
compilation_times
)
def register_failure_callback(self, callback: FailureCallback): # noqa: B027 def register_failure_callback(self, callback: FailureCallback): # noqa: B027
""" """
......
...@@ -118,11 +118,12 @@ class CPUWorker(Worker): ...@@ -118,11 +118,12 @@ class CPUWorker(Worker):
def determine_available_memory(self) -> int: def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0 return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> None: def compile_or_warm_up_model(self) -> float:
# Reset the seed to ensure that the random state is not affected by # Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model() self.model_runner.warming_up_model()
return self.compilation_config.compilation_time
def _get_autobind_cpu_ids( def _get_autobind_cpu_ids(
self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]] self, cpu_selector: Callable[[list[LogicalCPUInfo]], list[LogicalCPUInfo]]
......
...@@ -480,7 +480,7 @@ class Worker(WorkerBase): ...@@ -480,7 +480,7 @@ class Worker(WorkerBase):
self.model_runner.initialize_kv_cache(kv_cache_config) self.model_runner.initialize_kv_cache(kv_cache_config)
@instrument(span_name="Warmup (GPU)") @instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> None: def compile_or_warm_up_model(self) -> float:
warmup_sizes = [] warmup_sizes = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
...@@ -605,6 +605,8 @@ class Worker(WorkerBase): ...@@ -605,6 +605,8 @@ class Worker(WorkerBase):
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
return self.compilation_config.compilation_time
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache() self.model_runner.reset_mm_cache()
......
...@@ -87,8 +87,12 @@ class WorkerBase: ...@@ -87,8 +87,12 @@ class WorkerBase:
"""Get specifications for KV cache implementation.""" """Get specifications for KV cache implementation."""
raise NotImplementedError raise NotImplementedError
def compile_or_warm_up_model(self) -> None: def compile_or_warm_up_model(self) -> float:
"""Prepare model for execution through compilation/warmup.""" """Prepare model for execution through compilation/warmup.
Returns:
The accumulated compilation time in seconds.
"""
raise NotImplementedError raise NotImplementedError
def check_health(self) -> None: def check_health(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment