Unverified Commit c08f3b2a authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

Measure encoder compile time seperate from llm backbone (#39240)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
parent f02b3269
...@@ -16,7 +16,7 @@ import shutil ...@@ -16,7 +16,7 @@ import shutil
import tempfile import tempfile
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any from typing import Any, NamedTuple
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
...@@ -27,6 +27,82 @@ from vllm.benchmarks.lib.utils import ( ...@@ -27,6 +27,82 @@ from vllm.benchmarks.lib.utils import (
) )
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
PERCENTAGES = [10, 25, 50, 75, 90, 99]
class MetricDesc(NamedTuple):
"""Descriptor for a metric to collect from each iteration."""
iter_key: str # key in the iteration result dict
suffix: str # result key suffix, e.g. "startup", "compilation"
display_name: str
class MetricStats(NamedTuple):
"""Aggregated statistics for a single benchmark metric."""
key: str # e.g. "cold_startup", "warm_encoder_compilation"
display_name: str
values: list[float]
avg: float
percentiles: dict[int, float]
_BASE_METRICS = [
MetricDesc("total_startup_time", "startup", "Startup time"),
MetricDesc("compilation_time", "compilation", "Compilation time"),
]
_ENCODER_METRIC = MetricDesc(
"encoder_compilation_time",
"encoder_compilation",
"Encoder compilation time",
)
def _compute_metric(
phase: str,
desc: MetricDesc,
iterations: list[dict[str, float]],
) -> MetricStats:
values = [m[desc.iter_key] for m in iterations]
arr = np.array(values)
return MetricStats(
key=f"{phase}_{desc.suffix}",
display_name=desc.display_name,
values=values,
avg=float(np.mean(arr)),
percentiles=dict(zip(PERCENTAGES, np.percentile(arr, PERCENTAGES).tolist())),
)
def _collect_phase_metrics(
phase: str,
iterations: list[dict[str, float]],
has_encoder: bool,
) -> list[MetricStats]:
metrics = [_compute_metric(phase, desc, iterations) for desc in _BASE_METRICS]
if has_encoder:
metrics.append(_compute_metric(phase, _ENCODER_METRIC, iterations))
return metrics
def _print_phase(phase_name: str, metrics: list[MetricStats]) -> None:
print(f"\n{phase_name}:")
for m in metrics:
print(f"Avg {m.display_name.lower()}: {m.avg:.2f} seconds")
for m in metrics:
print(f"{m.display_name} percentiles:")
for pct, val in m.percentiles.items():
print(f" {pct}%: {val:.2f} seconds")
def _metric_to_json(m: MetricStats) -> dict[str, Any]:
return {
f"avg_{m.key}_time": m.avg,
f"{m.key}_times": m.values,
f"{m.key}_percentiles": m.percentiles,
}
@contextmanager @contextmanager
def cold_startup(): def cold_startup():
...@@ -72,6 +148,7 @@ def run_startup_in_subprocess(engine_args, result_queue): ...@@ -72,6 +148,7 @@ def run_startup_in_subprocess(engine_args, result_queue):
# Extract compilation time if available # Extract compilation time if available
compilation_time = 0.0 compilation_time = 0.0
encoder_compilation_time = 0.0
if hasattr(llm.llm_engine, "vllm_config"): if hasattr(llm.llm_engine, "vllm_config"):
vllm_config = llm.llm_engine.vllm_config vllm_config = llm.llm_engine.vllm_config
if ( if (
...@@ -79,11 +156,15 @@ def run_startup_in_subprocess(engine_args, result_queue): ...@@ -79,11 +156,15 @@ def run_startup_in_subprocess(engine_args, result_queue):
and vllm_config.compilation_config is not None and vllm_config.compilation_config is not None
): ):
compilation_time = vllm_config.compilation_config.compilation_time compilation_time = vllm_config.compilation_config.compilation_time
encoder_compilation_time = (
vllm_config.compilation_config.encoder_compilation_time
)
result_queue.put( result_queue.put(
{ {
"total_startup_time": total_startup_time, "total_startup_time": total_startup_time,
"compilation_time": compilation_time, "compilation_time": compilation_time,
"encoder_compilation_time": encoder_compilation_time,
} }
) )
...@@ -93,65 +174,20 @@ def run_startup_in_subprocess(engine_args, result_queue): ...@@ -93,65 +174,20 @@ def run_startup_in_subprocess(engine_args, result_queue):
def save_to_pytorch_benchmark_format( def save_to_pytorch_benchmark_format(
args: argparse.Namespace, results: dict[str, Any] args: argparse.Namespace, metrics: list[MetricStats]
) -> None: ) -> None:
base_name = os.path.splitext(args.output_json)[0] base_name = os.path.splitext(args.output_json)[0]
for m in metrics:
cold_startup_records = convert_to_pytorch_benchmark_format( records = convert_to_pytorch_benchmark_format(
args=args, args=args,
metrics={ metrics={f"avg_{m.key}_time": [m.avg]},
"avg_cold_startup_time": [results["avg_cold_startup_time"]],
},
extra_info={ extra_info={
"cold_startup_times": results["cold_startup_times"], f"{m.key}_times": m.values,
"cold_startup_percentiles": results["cold_startup_percentiles"], f"{m.key}_percentiles": m.percentiles,
}, },
) )
if cold_startup_records: if records:
write_to_json(f"{base_name}.cold_startup.pytorch.json", cold_startup_records) write_to_json(f"{base_name}.{m.key}.pytorch.json", records)
cold_compilation_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_cold_compilation_time": [results["avg_cold_compilation_time"]],
},
extra_info={
"cold_compilation_times": results["cold_compilation_times"],
"cold_compilation_percentiles": results["cold_compilation_percentiles"],
},
)
if cold_compilation_records:
write_to_json(
f"{base_name}.cold_compilation.pytorch.json", cold_compilation_records
)
warm_startup_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_warm_startup_time": [results["avg_warm_startup_time"]],
},
extra_info={
"warm_startup_times": results["warm_startup_times"],
"warm_startup_percentiles": results["warm_startup_percentiles"],
},
)
if warm_startup_records:
write_to_json(f"{base_name}.warm_startup.pytorch.json", warm_startup_records)
warm_compilation_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_warm_compilation_time": [results["avg_warm_compilation_time"]],
},
extra_info={
"warm_compilation_times": results["warm_compilation_times"],
"warm_compilation_percentiles": results["warm_compilation_percentiles"],
},
)
if warm_compilation_records:
write_to_json(
f"{base_name}.warm_compilation.pytorch.json", warm_compilation_records
)
def add_cli_args(parser: argparse.ArgumentParser): def add_cli_args(parser: argparse.ArgumentParser):
...@@ -224,97 +260,46 @@ def main(args: argparse.Namespace): ...@@ -224,97 +260,46 @@ def main(args: argparse.Namespace):
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n") print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")
# Collect cold startup iterations
print("Measuring cold startup time...\n") print("Measuring cold startup time...\n")
cold_startup_times = [] cold_iterations = []
cold_compilation_times = []
for i in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"): for i in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
with cold_startup(): with cold_startup():
metrics = create_llm_and_measure_startup() cold_iterations.append(create_llm_and_measure_startup())
cold_startup_times.append(metrics["total_startup_time"])
cold_compilation_times.append(metrics["compilation_time"])
# Warmup for warm startup # Warmup for warm startup
print("\nWarming up for warm startup measurement...\n") print("\nWarming up for warm startup measurement...\n")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
create_llm_and_measure_startup() create_llm_and_measure_startup()
# Collect warm startup iterations
print("\nMeasuring warm startup time...\n") print("\nMeasuring warm startup time...\n")
warm_startup_times = [] warm_iterations = []
warm_compilation_times = []
for i in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"): for i in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
metrics = create_llm_and_measure_startup() warm_iterations.append(create_llm_and_measure_startup())
warm_startup_times.append(metrics["total_startup_time"])
warm_compilation_times.append(metrics["compilation_time"]) # Determine if encoder compilation occurred in any iteration
has_encoder = any(
# Calculate statistics m["encoder_compilation_time"] > 0 for m in cold_iterations + warm_iterations
cold_startup_array = np.array(cold_startup_times) )
cold_compilation_array = np.array(cold_compilation_times)
warm_startup_array = np.array(warm_startup_times) cold_metrics = _collect_phase_metrics("cold", cold_iterations, has_encoder)
warm_compilation_array = np.array(warm_compilation_times) warm_metrics = _collect_phase_metrics("warm", warm_iterations, has_encoder)
all_metrics = cold_metrics + warm_metrics
avg_cold_startup = np.mean(cold_startup_array)
avg_cold_compilation = np.mean(cold_compilation_array)
avg_warm_startup = np.mean(warm_startup_array)
avg_warm_compilation = np.mean(warm_compilation_array)
percentages = [10, 25, 50, 75, 90, 99]
cold_startup_percentiles = np.percentile(cold_startup_array, percentages)
cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages)
warm_startup_percentiles = np.percentile(warm_startup_array, percentages)
warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages)
# Print results
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("STARTUP TIME BENCHMARK RESULTS") print("STARTUP TIME BENCHMARK RESULTS")
print("=" * 60) print("=" * 60)
_print_phase("COLD STARTUP", cold_metrics)
# Cold startup statistics _print_phase("WARM STARTUP", warm_metrics)
print("\nCOLD STARTUP:")
print(f"Avg total startup time: {avg_cold_startup:.2f} seconds")
print(f"Avg compilation time: {avg_cold_compilation:.2f} seconds")
print("Startup time percentiles:")
for percentage, percentile in zip(percentages, cold_startup_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
print("Compilation time percentiles:")
for percentage, percentile in zip(percentages, cold_compilation_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
# Warm startup statistics
print("\nWARM STARTUP:")
print(f"Avg total startup time: {avg_warm_startup:.2f} seconds")
print(f"Avg compilation time: {avg_warm_compilation:.2f} seconds")
print("Startup time percentiles:")
for percentage, percentile in zip(percentages, warm_startup_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
print("Compilation time percentiles:")
for percentage, percentile in zip(percentages, warm_compilation_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
print("=" * 60) print("=" * 60)
# Output JSON results if specified # Output JSON results if specified
if args.output_json: if args.output_json:
results = { results: dict[str, Any] = {}
"avg_cold_startup_time": float(avg_cold_startup), for m in all_metrics:
"avg_cold_compilation_time": float(avg_cold_compilation), results.update(_metric_to_json(m))
"cold_startup_times": cold_startup_times,
"cold_compilation_times": cold_compilation_times,
"cold_startup_percentiles": dict(
zip(percentages, cold_startup_percentiles.tolist())
),
"cold_compilation_percentiles": dict(
zip(percentages, cold_compilation_percentiles.tolist())
),
"avg_warm_startup_time": float(avg_warm_startup),
"avg_warm_compilation_time": float(avg_warm_compilation),
"warm_startup_times": warm_startup_times,
"warm_compilation_times": warm_compilation_times,
"warm_startup_percentiles": dict(
zip(percentages, warm_startup_percentiles.tolist())
),
"warm_compilation_percentiles": dict(
zip(percentages, warm_compilation_percentiles.tolist())
),
}
with open(args.output_json, "w") as f: with open(args.output_json, "w") as f:
json.dump(results, f, indent=4) json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results) save_to_pytorch_benchmark_format(args, all_metrics)
...@@ -265,6 +265,7 @@ class CompilerManager: ...@@ -265,6 +265,7 @@ class CompilerManager:
compile_range: Range, compile_range: Range,
graph_index: int = 0, graph_index: int = 0,
num_graphs: int = 1, num_graphs: int = 1,
is_encoder: bool = False,
) -> Any: ) -> Any:
if graph_index == 0: if graph_index == 0:
# before compiling the first graph, record the start time # before compiling the first graph, record the start time
...@@ -282,6 +283,9 @@ class CompilerManager: ...@@ -282,6 +283,9 @@ class CompilerManager:
# after loading the last graph for this shape, record the time. # after loading the last graph for this shape, record the time.
# there can be multiple graphs due to piecewise compilation. # there can be multiple graphs due to piecewise compilation.
elapsed = time.perf_counter() - compilation_start_time elapsed = time.perf_counter() - compilation_start_time
if is_encoder:
compilation_config.encoder_compilation_time += elapsed
else:
compilation_config.compilation_time += elapsed compilation_config.compilation_time += elapsed
logger.info_once( logger.info_once(
"Directly load the compiled graph(s) for compile range %s " "Directly load the compiled graph(s) for compile range %s "
...@@ -387,6 +391,9 @@ class CompilerManager: ...@@ -387,6 +391,9 @@ class CompilerManager:
# after compiling the last graph, record the end time # after compiling the last graph, record the end time
if graph_index == num_graphs - 1: if graph_index == num_graphs - 1:
elapsed = time.perf_counter() - compilation_start_time elapsed = time.perf_counter() - compilation_start_time
if is_encoder:
compilation_config.encoder_compilation_time += elapsed
else:
compilation_config.compilation_time += elapsed compilation_config.compilation_time += elapsed
logger.info_once( logger.info_once(
"Compiling a graph for compile range %s takes %.2f s", "Compiling a graph for compile range %s takes %.2f s",
...@@ -1130,6 +1137,9 @@ class VllmBackend: ...@@ -1130,6 +1137,9 @@ class VllmBackend:
logger.info_once( logger.info_once(
"Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local" "Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
) )
if self.is_encoder:
self.compilation_config.encoder_compilation_time += dynamo_time
else:
self.compilation_config.compilation_time += dynamo_time self.compilation_config.compilation_time += dynamo_time
# Record Dynamo time in tracing if available # Record Dynamo time in tracing if available
......
...@@ -270,6 +270,7 @@ class PiecewiseBackend: ...@@ -270,6 +270,7 @@ class PiecewiseBackend:
compile_range=range_entry.compile_range, compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index, graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles, num_graphs=self.total_piecewise_compiles,
is_encoder=self.vllm_backend.is_encoder,
) )
range_entry.compiled = True range_entry.compiled = True
......
...@@ -710,6 +710,8 @@ class CompilationConfig: ...@@ -710,6 +710,8 @@ class CompilationConfig:
"""files that are traced for compilation""" """files that are traced for compilation"""
compilation_time: float = field(default=0.0, init=False) compilation_time: float = field(default=0.0, init=False)
"""time taken for compilation""" """time taken for compilation"""
encoder_compilation_time: float = field(default=0.0, init=False)
"""time taken for multimodal encoder compilation"""
static_forward_context: dict[str, Any] = field(default_factory=dict, init=False) static_forward_context: dict[str, Any] = field(default_factory=dict, init=False)
"""Per-model forward context """Per-model forward context
...@@ -756,6 +758,7 @@ class CompilationConfig: ...@@ -756,6 +758,7 @@ class CompilationConfig:
"local_cache_dir", "local_cache_dir",
"traced_files", "traced_files",
"compilation_time", "compilation_time",
"encoder_compilation_time",
"static_forward_context", "static_forward_context",
"pass_config", # handled separately below "pass_config", # handled separately below
"dynamic_shapes_config", # handled separately below "dynamic_shapes_config", # handled separately below
...@@ -775,6 +778,7 @@ class CompilationConfig: ...@@ -775,6 +778,7 @@ class CompilationConfig:
"enabled_custom_ops": True, "enabled_custom_ops": True,
"disabled_custom_ops": True, "disabled_custom_ops": True,
"compilation_time": True, "compilation_time": True,
"encoder_compilation_time": True,
"traced_files": True, "traced_files": True,
"inductor_compile_config": { "inductor_compile_config": {
"post_grad_custom_post_pass": True, "post_grad_custom_post_pass": True,
......
...@@ -282,8 +282,30 @@ class EngineCore: ...@@ -282,8 +282,30 @@ class EngineCore:
self.model_executor.initialize_from_config(kv_cache_configs) self.model_executor.initialize_from_config(kv_cache_configs)
elapsed = time.time() - start elapsed = time.time() - start
compile_time = vllm_config.compilation_config.compilation_time
encoder_compile_time = vllm_config.compilation_config.encoder_compilation_time
if encoder_compile_time > 0:
logger.info_once( logger.info_once(
"init engine (profile, create kv cache, warmup model) took %.2f seconds", "init engine (profile, create kv cache, warmup model) took "
"%.2f s (compilation: %.2f s — language_model: %.2f s, "
"encoder: %.2f s)",
elapsed,
compile_time + encoder_compile_time,
compile_time,
encoder_compile_time,
scope="local",
)
elif compile_time > 0:
logger.info_once(
"init engine (profile, create kv cache, warmup model) took "
"%.2f s (compilation: %.2f s)",
elapsed,
compile_time,
scope="local",
)
else:
logger.info_once(
"init engine (profile, create kv cache, warmup model) took %.2f s",
elapsed, elapsed,
scope="local", scope="local",
) )
......
...@@ -22,7 +22,7 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput ...@@ -22,7 +22,7 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest from vllm.v1.engine import ReconfigureDistributedRequest
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import CompilationTimes, WorkerBase
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
...@@ -121,14 +121,19 @@ class Executor(ABC): ...@@ -121,14 +121,19 @@ class Executor(ABC):
underlying workers. underlying workers.
""" """
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,)) self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model") compilation_times: list[CompilationTimes] = self.collective_rpc(
"compile_or_warm_up_model"
)
# Propagate compilation time from workers back to the main process. # Propagate compilation time from workers back to the main process.
# With TP>1, compilation happens in worker processes, so the main # With TP>1, compilation happens in worker processes, so the main
# process config is never updated. Use max across workers since they # process config is never updated. Use max across workers since they
# compile in parallel. # compile in parallel.
if compilation_times: if compilation_times:
self.vllm_config.compilation_config.compilation_time = max( self.vllm_config.compilation_config.compilation_time = max(
compilation_times t.language_model for t in compilation_times
)
self.vllm_config.compilation_config.encoder_compilation_time = max(
t.encoder for t in compilation_times
) )
def register_failure_callback(self, callback: FailureCallback): # noqa: B027 def register_failure_callback(self, callback: FailureCallback): # noqa: B027
......
...@@ -13,6 +13,7 @@ from vllm.profiler.wrapper import TorchProfilerWrapper ...@@ -13,6 +13,7 @@ from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.cpu_model_runner import CPUModelRunner from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.worker_base import CompilationTimes
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -104,12 +105,15 @@ class CPUWorker(Worker): ...@@ -104,12 +105,15 @@ class CPUWorker(Worker):
def determine_available_memory(self) -> int: def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0 return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> float: def compile_or_warm_up_model(self) -> CompilationTimes:
# Reset the seed to ensure that the random state is not affected by # Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model() self.model_runner.warming_up_model()
return self.compilation_config.compilation_time return CompilationTimes(
language_model=self.compilation_config.compilation_time,
encoder=self.compilation_config.encoder_compilation_time,
)
def profile(self, is_start: bool = True, profile_prefix: str | None = None): def profile(self, is_start: bool = True, profile_prefix: str | None = None):
if self.profiler is None: if self.profiler is None:
......
...@@ -56,7 +56,7 @@ from vllm.v1.outputs import ( ...@@ -56,7 +56,7 @@ from vllm.v1.outputs import (
) )
from vllm.v1.utils import compute_iteration_details, report_usage_stats from vllm.v1.utils import compute_iteration_details, report_usage_stats
from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import CompilationTimes, WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager from vllm.v1.worker.workspace import init_workspace_manager
from ...model_executor.model_loader import TensorizerLoader from ...model_executor.model_loader import TensorizerLoader
...@@ -547,7 +547,7 @@ class Worker(WorkerBase): ...@@ -547,7 +547,7 @@ class Worker(WorkerBase):
self.model_runner._init_kv_zero_meta() self.model_runner._init_kv_zero_meta()
@instrument(span_name="Warmup (GPU)") @instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> float: def compile_or_warm_up_model(self) -> CompilationTimes:
warmup_sizes: list[int] = [] warmup_sizes: list[int] = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE: if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
...@@ -689,7 +689,10 @@ class Worker(WorkerBase): ...@@ -689,7 +689,10 @@ class Worker(WorkerBase):
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
return self.compilation_config.compilation_time return CompilationTimes(
language_model=self.compilation_config.compilation_time,
encoder=self.compilation_config.encoder_compilation_time,
)
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache() self.model_runner.reset_mm_cache()
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeVar from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -30,6 +30,11 @@ logger = init_logger(__name__) ...@@ -30,6 +30,11 @@ logger = init_logger(__name__)
_R = TypeVar("_R") _R = TypeVar("_R")
class CompilationTimes(NamedTuple):
language_model: float
encoder: float
class WorkerBase: class WorkerBase:
"""Worker interface that allows vLLM to cleanly separate implementations for """Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to different hardware. Also abstracts control plane communication, e.g., to
...@@ -86,11 +91,11 @@ class WorkerBase: ...@@ -86,11 +91,11 @@ class WorkerBase:
"""Get specifications for KV cache implementation.""" """Get specifications for KV cache implementation."""
raise NotImplementedError raise NotImplementedError
def compile_or_warm_up_model(self) -> float: def compile_or_warm_up_model(self) -> CompilationTimes:
"""Prepare model for execution through compilation/warmup. """Prepare model for execution through compilation/warmup.
Returns: Returns:
The accumulated compilation time in seconds. Compilation times (language_model, encoder) in seconds.
""" """
raise NotImplementedError raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment