Unverified Commit c08f3b2a authored by Lucas Kabela's avatar Lucas Kabela Committed by GitHub
Browse files

Measure encoder compile time seperate from llm backbone (#39240)


Signed-off-by: default avatarLucas Kabela <lucaskabela@meta.com>
parent f02b3269
......@@ -16,7 +16,7 @@ import shutil
import tempfile
import time
from contextlib import contextmanager
from typing import Any
from typing import Any, NamedTuple
import numpy as np
from tqdm import tqdm
......@@ -27,6 +27,82 @@ from vllm.benchmarks.lib.utils import (
)
from vllm.engine.arg_utils import EngineArgs
PERCENTAGES = [10, 25, 50, 75, 90, 99]
class MetricDesc(NamedTuple):
"""Descriptor for a metric to collect from each iteration."""
iter_key: str # key in the iteration result dict
suffix: str # result key suffix, e.g. "startup", "compilation"
display_name: str
class MetricStats(NamedTuple):
"""Aggregated statistics for a single benchmark metric."""
key: str # e.g. "cold_startup", "warm_encoder_compilation"
display_name: str
values: list[float]
avg: float
percentiles: dict[int, float]
_BASE_METRICS = [
MetricDesc("total_startup_time", "startup", "Startup time"),
MetricDesc("compilation_time", "compilation", "Compilation time"),
]
_ENCODER_METRIC = MetricDesc(
"encoder_compilation_time",
"encoder_compilation",
"Encoder compilation time",
)
def _compute_metric(
phase: str,
desc: MetricDesc,
iterations: list[dict[str, float]],
) -> MetricStats:
values = [m[desc.iter_key] for m in iterations]
arr = np.array(values)
return MetricStats(
key=f"{phase}_{desc.suffix}",
display_name=desc.display_name,
values=values,
avg=float(np.mean(arr)),
percentiles=dict(zip(PERCENTAGES, np.percentile(arr, PERCENTAGES).tolist())),
)
def _collect_phase_metrics(
phase: str,
iterations: list[dict[str, float]],
has_encoder: bool,
) -> list[MetricStats]:
metrics = [_compute_metric(phase, desc, iterations) for desc in _BASE_METRICS]
if has_encoder:
metrics.append(_compute_metric(phase, _ENCODER_METRIC, iterations))
return metrics
def _print_phase(phase_name: str, metrics: list[MetricStats]) -> None:
print(f"\n{phase_name}:")
for m in metrics:
print(f"Avg {m.display_name.lower()}: {m.avg:.2f} seconds")
for m in metrics:
print(f"{m.display_name} percentiles:")
for pct, val in m.percentiles.items():
print(f" {pct}%: {val:.2f} seconds")
def _metric_to_json(m: MetricStats) -> dict[str, Any]:
return {
f"avg_{m.key}_time": m.avg,
f"{m.key}_times": m.values,
f"{m.key}_percentiles": m.percentiles,
}
@contextmanager
def cold_startup():
......@@ -72,6 +148,7 @@ def run_startup_in_subprocess(engine_args, result_queue):
# Extract compilation time if available
compilation_time = 0.0
encoder_compilation_time = 0.0
if hasattr(llm.llm_engine, "vllm_config"):
vllm_config = llm.llm_engine.vllm_config
if (
......@@ -79,11 +156,15 @@ def run_startup_in_subprocess(engine_args, result_queue):
and vllm_config.compilation_config is not None
):
compilation_time = vllm_config.compilation_config.compilation_time
encoder_compilation_time = (
vllm_config.compilation_config.encoder_compilation_time
)
result_queue.put(
{
"total_startup_time": total_startup_time,
"compilation_time": compilation_time,
"encoder_compilation_time": encoder_compilation_time,
}
)
......@@ -93,65 +174,20 @@ def run_startup_in_subprocess(engine_args, result_queue):
def save_to_pytorch_benchmark_format(
args: argparse.Namespace, results: dict[str, Any]
args: argparse.Namespace, metrics: list[MetricStats]
) -> None:
base_name = os.path.splitext(args.output_json)[0]
cold_startup_records = convert_to_pytorch_benchmark_format(
for m in metrics:
records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_cold_startup_time": [results["avg_cold_startup_time"]],
},
metrics={f"avg_{m.key}_time": [m.avg]},
extra_info={
"cold_startup_times": results["cold_startup_times"],
"cold_startup_percentiles": results["cold_startup_percentiles"],
f"{m.key}_times": m.values,
f"{m.key}_percentiles": m.percentiles,
},
)
if cold_startup_records:
write_to_json(f"{base_name}.cold_startup.pytorch.json", cold_startup_records)
cold_compilation_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_cold_compilation_time": [results["avg_cold_compilation_time"]],
},
extra_info={
"cold_compilation_times": results["cold_compilation_times"],
"cold_compilation_percentiles": results["cold_compilation_percentiles"],
},
)
if cold_compilation_records:
write_to_json(
f"{base_name}.cold_compilation.pytorch.json", cold_compilation_records
)
warm_startup_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_warm_startup_time": [results["avg_warm_startup_time"]],
},
extra_info={
"warm_startup_times": results["warm_startup_times"],
"warm_startup_percentiles": results["warm_startup_percentiles"],
},
)
if warm_startup_records:
write_to_json(f"{base_name}.warm_startup.pytorch.json", warm_startup_records)
warm_compilation_records = convert_to_pytorch_benchmark_format(
args=args,
metrics={
"avg_warm_compilation_time": [results["avg_warm_compilation_time"]],
},
extra_info={
"warm_compilation_times": results["warm_compilation_times"],
"warm_compilation_percentiles": results["warm_compilation_percentiles"],
},
)
if warm_compilation_records:
write_to_json(
f"{base_name}.warm_compilation.pytorch.json", warm_compilation_records
)
if records:
write_to_json(f"{base_name}.{m.key}.pytorch.json", records)
def add_cli_args(parser: argparse.ArgumentParser):
......@@ -224,97 +260,46 @@ def main(args: argparse.Namespace):
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
print("Setting VLLM_ENABLE_V1_MULTIPROCESSING=0 to collect startup metrics.\n")
# Collect cold startup iterations
print("Measuring cold startup time...\n")
cold_startup_times = []
cold_compilation_times = []
cold_iterations = []
for i in tqdm(range(args.num_iters_cold), desc="Cold startup iterations"):
with cold_startup():
metrics = create_llm_and_measure_startup()
cold_startup_times.append(metrics["total_startup_time"])
cold_compilation_times.append(metrics["compilation_time"])
cold_iterations.append(create_llm_and_measure_startup())
# Warmup for warm startup
print("\nWarming up for warm startup measurement...\n")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
create_llm_and_measure_startup()
# Collect warm startup iterations
print("\nMeasuring warm startup time...\n")
warm_startup_times = []
warm_compilation_times = []
warm_iterations = []
for i in tqdm(range(args.num_iters_warm), desc="Warm startup iterations"):
metrics = create_llm_and_measure_startup()
warm_startup_times.append(metrics["total_startup_time"])
warm_compilation_times.append(metrics["compilation_time"])
# Calculate statistics
cold_startup_array = np.array(cold_startup_times)
cold_compilation_array = np.array(cold_compilation_times)
warm_startup_array = np.array(warm_startup_times)
warm_compilation_array = np.array(warm_compilation_times)
avg_cold_startup = np.mean(cold_startup_array)
avg_cold_compilation = np.mean(cold_compilation_array)
avg_warm_startup = np.mean(warm_startup_array)
avg_warm_compilation = np.mean(warm_compilation_array)
percentages = [10, 25, 50, 75, 90, 99]
cold_startup_percentiles = np.percentile(cold_startup_array, percentages)
cold_compilation_percentiles = np.percentile(cold_compilation_array, percentages)
warm_startup_percentiles = np.percentile(warm_startup_array, percentages)
warm_compilation_percentiles = np.percentile(warm_compilation_array, percentages)
warm_iterations.append(create_llm_and_measure_startup())
# Determine if encoder compilation occurred in any iteration
has_encoder = any(
m["encoder_compilation_time"] > 0 for m in cold_iterations + warm_iterations
)
cold_metrics = _collect_phase_metrics("cold", cold_iterations, has_encoder)
warm_metrics = _collect_phase_metrics("warm", warm_iterations, has_encoder)
all_metrics = cold_metrics + warm_metrics
# Print results
print("\n" + "=" * 60)
print("STARTUP TIME BENCHMARK RESULTS")
print("=" * 60)
# Cold startup statistics
print("\nCOLD STARTUP:")
print(f"Avg total startup time: {avg_cold_startup:.2f} seconds")
print(f"Avg compilation time: {avg_cold_compilation:.2f} seconds")
print("Startup time percentiles:")
for percentage, percentile in zip(percentages, cold_startup_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
print("Compilation time percentiles:")
for percentage, percentile in zip(percentages, cold_compilation_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
# Warm startup statistics
print("\nWARM STARTUP:")
print(f"Avg total startup time: {avg_warm_startup:.2f} seconds")
print(f"Avg compilation time: {avg_warm_compilation:.2f} seconds")
print("Startup time percentiles:")
for percentage, percentile in zip(percentages, warm_startup_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
print("Compilation time percentiles:")
for percentage, percentile in zip(percentages, warm_compilation_percentiles):
print(f" {percentage}%: {percentile:.2f} seconds")
_print_phase("COLD STARTUP", cold_metrics)
_print_phase("WARM STARTUP", warm_metrics)
print("=" * 60)
# Output JSON results if specified
if args.output_json:
results = {
"avg_cold_startup_time": float(avg_cold_startup),
"avg_cold_compilation_time": float(avg_cold_compilation),
"cold_startup_times": cold_startup_times,
"cold_compilation_times": cold_compilation_times,
"cold_startup_percentiles": dict(
zip(percentages, cold_startup_percentiles.tolist())
),
"cold_compilation_percentiles": dict(
zip(percentages, cold_compilation_percentiles.tolist())
),
"avg_warm_startup_time": float(avg_warm_startup),
"avg_warm_compilation_time": float(avg_warm_compilation),
"warm_startup_times": warm_startup_times,
"warm_compilation_times": warm_compilation_times,
"warm_startup_percentiles": dict(
zip(percentages, warm_startup_percentiles.tolist())
),
"warm_compilation_percentiles": dict(
zip(percentages, warm_compilation_percentiles.tolist())
),
}
results: dict[str, Any] = {}
for m in all_metrics:
results.update(_metric_to_json(m))
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
save_to_pytorch_benchmark_format(args, results)
save_to_pytorch_benchmark_format(args, all_metrics)
......@@ -265,6 +265,7 @@ class CompilerManager:
compile_range: Range,
graph_index: int = 0,
num_graphs: int = 1,
is_encoder: bool = False,
) -> Any:
if graph_index == 0:
# before compiling the first graph, record the start time
......@@ -282,6 +283,9 @@ class CompilerManager:
# after loading the last graph for this shape, record the time.
# there can be multiple graphs due to piecewise compilation.
elapsed = time.perf_counter() - compilation_start_time
if is_encoder:
compilation_config.encoder_compilation_time += elapsed
else:
compilation_config.compilation_time += elapsed
logger.info_once(
"Directly load the compiled graph(s) for compile range %s "
......@@ -387,6 +391,9 @@ class CompilerManager:
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
elapsed = time.perf_counter() - compilation_start_time
if is_encoder:
compilation_config.encoder_compilation_time += elapsed
else:
compilation_config.compilation_time += elapsed
logger.info_once(
"Compiling a graph for compile range %s takes %.2f s",
......@@ -1130,6 +1137,9 @@ class VllmBackend:
logger.info_once(
"Dynamo bytecode transform time: %.2f s", dynamo_time, scope="local"
)
if self.is_encoder:
self.compilation_config.encoder_compilation_time += dynamo_time
else:
self.compilation_config.compilation_time += dynamo_time
# Record Dynamo time in tracing if available
......
......@@ -270,6 +270,7 @@ class PiecewiseBackend:
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
is_encoder=self.vllm_backend.is_encoder,
)
range_entry.compiled = True
......
......@@ -710,6 +710,8 @@ class CompilationConfig:
"""files that are traced for compilation"""
compilation_time: float = field(default=0.0, init=False)
"""time taken for compilation"""
encoder_compilation_time: float = field(default=0.0, init=False)
"""time taken for multimodal encoder compilation"""
static_forward_context: dict[str, Any] = field(default_factory=dict, init=False)
"""Per-model forward context
......@@ -756,6 +758,7 @@ class CompilationConfig:
"local_cache_dir",
"traced_files",
"compilation_time",
"encoder_compilation_time",
"static_forward_context",
"pass_config", # handled separately below
"dynamic_shapes_config", # handled separately below
......@@ -775,6 +778,7 @@ class CompilationConfig:
"enabled_custom_ops": True,
"disabled_custom_ops": True,
"compilation_time": True,
"encoder_compilation_time": True,
"traced_files": True,
"inductor_compile_config": {
"post_grad_custom_post_pass": True,
......
......@@ -282,8 +282,30 @@ class EngineCore:
self.model_executor.initialize_from_config(kv_cache_configs)
elapsed = time.time() - start
compile_time = vllm_config.compilation_config.compilation_time
encoder_compile_time = vllm_config.compilation_config.encoder_compilation_time
if encoder_compile_time > 0:
logger.info_once(
"init engine (profile, create kv cache, warmup model) took %.2f seconds",
"init engine (profile, create kv cache, warmup model) took "
"%.2f s (compilation: %.2f s — language_model: %.2f s, "
"encoder: %.2f s)",
elapsed,
compile_time + encoder_compile_time,
compile_time,
encoder_compile_time,
scope="local",
)
elif compile_time > 0:
logger.info_once(
"init engine (profile, create kv cache, warmup model) took "
"%.2f s (compilation: %.2f s)",
elapsed,
compile_time,
scope="local",
)
else:
logger.info_once(
"init engine (profile, create kv cache, warmup model) took %.2f s",
elapsed,
scope="local",
)
......
......@@ -22,7 +22,7 @@ from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.worker_base import CompilationTimes, WorkerBase
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
......@@ -121,14 +121,19 @@ class Executor(ABC):
underlying workers.
"""
self.collective_rpc("initialize_from_config", args=(kv_cache_configs,))
compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model")
compilation_times: list[CompilationTimes] = self.collective_rpc(
"compile_or_warm_up_model"
)
# Propagate compilation time from workers back to the main process.
# With TP>1, compilation happens in worker processes, so the main
# process config is never updated. Use max across workers since they
# compile in parallel.
if compilation_times:
self.vllm_config.compilation_config.compilation_time = max(
compilation_times
t.language_model for t in compilation_times
)
self.vllm_config.compilation_config.encoder_compilation_time = max(
t.encoder for t in compilation_times
)
def register_failure_callback(self, callback: FailureCallback): # noqa: B027
......
......@@ -13,6 +13,7 @@ from vllm.profiler.wrapper import TorchProfilerWrapper
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.cpu_model_runner import CPUModelRunner
from vllm.v1.worker.gpu_worker import Worker, init_worker_distributed_environment
from vllm.v1.worker.worker_base import CompilationTimes
logger = init_logger(__name__)
......@@ -104,12 +105,15 @@ class CPUWorker(Worker):
def determine_available_memory(self) -> int:
return self.cache_config.cpu_kvcache_space_bytes or 0
def compile_or_warm_up_model(self) -> float:
def compile_or_warm_up_model(self) -> CompilationTimes:
# Reset the seed to ensure that the random state is not affected by
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
self.model_runner.warming_up_model()
return self.compilation_config.compilation_time
return CompilationTimes(
language_model=self.compilation_config.compilation_time,
encoder=self.compilation_config.encoder_compilation_time,
)
def profile(self, is_start: bool = True, profile_prefix: str | None = None):
if self.profiler is None:
......
......@@ -56,7 +56,7 @@ from vllm.v1.outputs import (
)
from vllm.v1.utils import compute_iteration_details, report_usage_stats
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.worker_base import CompilationTimes, WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager
from ...model_executor.model_loader import TensorizerLoader
......@@ -547,7 +547,7 @@ class Worker(WorkerBase):
self.model_runner._init_kv_zero_meta()
@instrument(span_name="Warmup (GPU)")
def compile_or_warm_up_model(self) -> float:
def compile_or_warm_up_model(self) -> CompilationTimes:
warmup_sizes: list[int] = []
if self.vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE:
......@@ -689,7 +689,10 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
return self.compilation_config.compilation_time
return CompilationTimes(
language_model=self.compilation_config.compilation_time,
encoder=self.compilation_config.encoder_compilation_time,
)
def reset_mm_cache(self) -> None:
self.model_runner.reset_mm_cache()
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
import torch
import torch.nn as nn
......@@ -30,6 +30,11 @@ logger = init_logger(__name__)
_R = TypeVar("_R")
class CompilationTimes(NamedTuple):
language_model: float
encoder: float
class WorkerBase:
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to
......@@ -86,11 +91,11 @@ class WorkerBase:
"""Get specifications for KV cache implementation."""
raise NotImplementedError
def compile_or_warm_up_model(self) -> float:
def compile_or_warm_up_model(self) -> CompilationTimes:
"""Prepare model for execution through compilation/warmup.
Returns:
The accumulated compilation time in seconds.
Compilation times (language_model, encoder) in seconds.
"""
raise NotImplementedError
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment