Unverified Commit 24523a1c authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat(vllm): add self-benchmark mode to InstrumentedScheduler (#7779)

parent a873045c
...@@ -273,6 +273,41 @@ def update_engine_config_with_dynamo( ...@@ -273,6 +273,41 @@ def update_engine_config_with_dynamo(
f"--scheduler-cls or subclass InstrumentedScheduler." f"--scheduler-cls or subclass InstrumentedScheduler."
) )
if dynamo_config.benchmark_mode is not None:
if dynamo_config.multimodal_worker or dynamo_config.multimodal_decode_worker:
logger.warning(
"--benchmark-mode is not supported for multimodal workers. "
"Benchmark data will be collected but not served via endpoint."
)
existing_cls = getattr(engine_config, "scheduler_cls", None)
if existing_cls is None and not envs.is_set("DYN_FORWARDPASS_METRIC_PORT"):
defaults[
"scheduler_cls"
] = "dynamo.vllm.instrumented_scheduler.InstrumentedScheduler"
logger.info("Benchmark mode: auto-enabling InstrumentedScheduler")
elif existing_cls is not None and "InstrumentedScheduler" not in str(
existing_cls
):
raise ValueError(
f"--benchmark-mode requires InstrumentedScheduler but "
f"--scheduler-cls is set to '{existing_cls}'. Either remove "
f"--scheduler-cls or use a subclass of InstrumentedScheduler."
)
dynamo_config._benchmark_additional_config = { # type: ignore[attr-defined]
"mode": dynamo_config.benchmark_mode,
"prefill_isl_granularity": dynamo_config.benchmark_prefill_granularity,
"decode_length_granularity": dynamo_config.benchmark_decode_length_granularity,
"decode_batch_size_granularity": dynamo_config.benchmark_decode_batch_granularity,
"warmup_iterations": dynamo_config.benchmark_warmup_iterations,
"output_path": dynamo_config.benchmark_output_path,
"timeout": dynamo_config.benchmark_timeout,
}
logger.info(
"Benchmark mode=%s configured (output=%s)",
dynamo_config.benchmark_mode,
dynamo_config.benchmark_output_path,
)
logger.debug("Setting Dynamo defaults for vLLM") logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items(): for key, value in defaults.items():
if hasattr(engine_config, key): if hasattr(engine_config, key):
......
...@@ -171,6 +171,78 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -171,6 +171,78 @@ class DynamoVllmArgGroup(ArgGroup):
), ),
) )
# Benchmark / self-profiling
add_argument(
g,
flag_name="--benchmark-mode",
env_var="DYN_BENCHMARK_MODE",
default=None,
choices=["prefill", "decode", "agg"],
help=(
"Run self-benchmark on startup before accepting requests. "
"Sweeps prefill ISLs and/or decode (context_length x batch_size) "
"points, collecting ForwardPassMetrics at each operating point."
),
)
add_argument(
g,
flag_name="--benchmark-prefill-granularity",
env_var="DYN_BENCHMARK_PREFILL_GRANULARITY",
default=16,
type=int,
help="Number of ISL sample points for prefill sweep (default: 16).",
)
add_argument(
g,
flag_name="--benchmark-decode-length-granularity",
env_var="DYN_BENCHMARK_DECODE_LENGTH_GRANULARITY",
default=6,
type=int,
help=(
"Number of context length sample points for decode sweep "
"(default: 6)."
),
)
add_argument(
g,
flag_name="--benchmark-decode-batch-granularity",
env_var="DYN_BENCHMARK_DECODE_BATCH_GRANULARITY",
default=6,
type=int,
help=(
"Number of batch size sample points per context length " "(default: 6)."
),
)
add_argument(
g,
flag_name="--benchmark-warmup-iterations",
env_var="DYN_BENCHMARK_WARMUP_ITERATIONS",
default=5,
type=int,
help="Warmup iterations before benchmark (default: 5).",
)
add_argument(
g,
flag_name="--benchmark-output-path",
env_var="DYN_BENCHMARK_OUTPUT_PATH",
default="/tmp/benchmark_results.json",
help=(
"Path to write benchmark results JSON "
"(default: /tmp/benchmark_results.json)."
),
)
add_argument(
g,
flag_name="--benchmark-timeout",
env_var="DYN_BENCHMARK_TIMEOUT",
default=300,
type=int,
help=(
"Maximum seconds to wait for benchmark to complete "
"(default: 300). Worker startup fails if exceeded."
),
)
# @dataclass() # @dataclass()
class DynamoVllmConfig(ConfigBase): class DynamoVllmConfig(ConfigBase):
...@@ -204,6 +276,15 @@ class DynamoVllmConfig(ConfigBase): ...@@ -204,6 +276,15 @@ class DynamoVllmConfig(ConfigBase):
# GMS shadow mode # GMS shadow mode
gms_shadow_mode: bool = False gms_shadow_mode: bool = False
# Benchmark / self-profiling
benchmark_mode: Optional[str] = None
benchmark_prefill_granularity: int = 16
benchmark_decode_length_granularity: int = 6
benchmark_decode_batch_granularity: int = 6
benchmark_warmup_iterations: int = 5
benchmark_output_path: str = "/tmp/benchmark_results.json"
benchmark_timeout: int = 300
def validate(self) -> None: def validate(self) -> None:
"""Validate vLLM wrapper configuration.""" """Validate vLLM wrapper configuration."""
self._resolve_disaggregation_mode() self._resolve_disaggregation_mode()
......
...@@ -354,6 +354,8 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -354,6 +354,8 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
Request handler for the generate and clear_kv_blocks endpoints. Request handler for the generate and clear_kv_blocks endpoints.
""" """
_benchmark_results: Optional[dict] = None
def __init__( def __init__(
self, self,
runtime, runtime,
...@@ -683,6 +685,14 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]): ...@@ -683,6 +685,14 @@ class BaseWorkerHandler(ABC, Generic[RequestT, ResponseT]):
except Exception as e: except Exception as e:
yield {"status": "error", "message": str(e)} yield {"status": "error", "message": str(e)}
async def get_perf_metrics(self, request=None):
    """Yield the stored self-benchmark results, or an error payload.

    Exactly one item is produced: the value of ``_benchmark_results`` when
    it is set, otherwise ``{"status": "error", "message": "no benchmark
    data"}``. ``request`` is accepted but not read.
    """
    payload = getattr(self, "_benchmark_results", None)
    if payload is not None:
        yield payload
        return
    yield {"status": "error", "message": "no benchmark data"}
def add_temp_dir(self, temp_dir: tempfile.TemporaryDirectory) -> None: def add_temp_dir(self, temp_dir: tempfile.TemporaryDirectory) -> None:
"""Add a temporary directory to be cleaned up later.""" """Add a temporary directory to be cleaned up later."""
if temp_dir is not None: if temp_dir is not None:
......
...@@ -59,15 +59,13 @@ How metrics are measured ...@@ -59,15 +59,13 @@ How metrics are measured
passes the correct output for the batch being processed, even in passes the correct output for the batch being processed, even in
async mode where multiple batches are in flight). async mode where multiple batches are in flight).
* **queued_requests**: computed from ``self.waiting`` at emit time. * **queued_requests**: computed from ``self.waiting`` at emit time.
* **wall_time** (previous wording, removed by this change): approximates
the schedule-to-update_from_output latency described in
``ForwardPassMetrics``. Measured as the time between consecutive
``update_from_output()`` calls. This works because the EngineCore always
blocks on ``future.result()`` (the GPU forward pass) right before calling
``update_from_output``, so the interval is dominated by GPU compute.
Assumption: CPU overhead (scheduling + output processing) between
consecutive calls is small relative to GPU forward pass time.
``wall_time`` is ``0.0`` for the first message after engine idle and for
heartbeats.
* **wall_time** (new wording, added by this change): approximates the GPU
forward pass time for each batch. In steady state, measured as the
interval between consecutive ``update_from_output()`` calls (accurate
because CPU scheduling overlaps with GPU execution). For the first batch
after engine idle (no previous ``update_from_output``), falls back to a
per-batch ``schedule()``-to-``update_from_output()`` timestamp recorded
via a FIFO queue. ``wall_time`` is ``0.0`` only for heartbeats.
Serialization and ZMQ send are handled by a background thread Serialization and ZMQ send are handled by a background thread
(same approach as vLLM's ZmqEventPublisher) so the scheduler (same approach as vLLM's ZmqEventPublisher) so the scheduler
...@@ -79,19 +77,27 @@ Inject via: ...@@ -79,19 +77,27 @@ Inject via:
from __future__ import annotations from __future__ import annotations
import enum
import json
import logging import logging
import os import os
import queue import queue
import threading import threading
import time import time
from collections import deque
from dataclasses import asdict, dataclass, field
from itertools import count from itertools import count
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Literal
import msgspec.structs import msgspec.structs
import numpy as np
import zmq import zmq
from vllm.sampling_params import SamplingParams
from vllm.utils.hashing import get_hash_fn_by_name
from vllm.v1.core.kv_cache_utils import get_request_block_hasher, init_none_hash
from vllm.v1.core.sched.async_scheduler import AsyncScheduler from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.request import RequestStatus from vllm.v1.request import Request, RequestStatus
from dynamo.common.forward_pass_metrics import ( from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics, ForwardPassMetrics,
...@@ -104,7 +110,6 @@ from dynamo.runtime.logging import configure_dynamo_logging ...@@ -104,7 +110,6 @@ from dynamo.runtime.logging import configure_dynamo_logging
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output import StructuredOutputManager
...@@ -116,6 +121,43 @@ DEFAULT_FPM_PORT = 20380 ...@@ -116,6 +121,43 @@ DEFAULT_FPM_PORT = 20380
ENV_FPM_PORT = "DYN_FORWARDPASS_METRIC_PORT" ENV_FPM_PORT = "DYN_FORWARDPASS_METRIC_PORT"
# ---------------------------------------------------------------------------
# Benchmark mode dataclasses
# ---------------------------------------------------------------------------
@dataclass
class BenchmarkConfig:
    """Settings that drive the scheduler's self-benchmark sweep."""

    # Which sweeps run: prefill only, decode only, or both ("agg").
    mode: Literal["prefill", "decode", "agg"] = "agg"
    # Number of ISL sample points requested for the prefill sweep.
    prefill_isl_granularity: int = 16
    # Number of context-length sample points requested for the decode sweep.
    decode_length_granularity: int = 6
    # Number of batch-size sample points requested per context length.
    decode_batch_size_granularity: int = 6
    # Decode-step budget handed to the single warmup request.
    warmup_iterations: int = 5
    # Destination of the JSON results file.
    output_path: str = "/tmp/benchmark_results.json"
class _BenchPhase(enum.Enum):
IDLE = "idle"
WARMUP = "warmup"
PREFILL_SWEEP = "prefill_sweep"
DECODE_SWEEP = "decode_sweep"
DONE = "done"
@dataclass
class BenchmarkPoint:
    """One operating point of the sweep grid."""

    # Either "prefill" or "decode".
    point_type: str
    # Prompt length; set for prefill points.
    isl: int = 0
    # Context length per request; set for decode points.
    context_length: int = 0
    # Number of concurrent requests; set for decode points.
    batch_size: int = 0
@dataclass
class BenchmarkPointResult:
    """Metrics captured while one ``BenchmarkPoint`` was being measured."""

    # The operating point these metrics belong to.
    point: BenchmarkPoint
    # Collected forward-pass metric records; a fresh list per instance.
    fpms: list = field(default_factory=list)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Background publisher thread # Background publisher thread
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
...@@ -231,8 +273,11 @@ class InstrumentedScheduler(AsyncScheduler): ...@@ -231,8 +273,11 @@ class InstrumentedScheduler(AsyncScheduler):
self._fpm_worker_id = vllm_config.additional_config.get("fpm_worker_id", "") self._fpm_worker_id = vllm_config.additional_config.get("fpm_worker_id", "")
self._fpm_dp_rank = dp_rank self._fpm_dp_rank = dp_rank
self._schedule_times: deque[float] = deque()
self._last_update_time: float = 0.0 self._last_update_time: float = 0.0
self._prompt_len_per_req: dict[str, int] = {} self._prompt_len_per_req: dict[str, int] = {}
self._bench_active: bool = False
self._bench_phase: _BenchPhase = _BenchPhase.IDLE
base_port = int(os.environ.get(ENV_FPM_PORT, str(DEFAULT_FPM_PORT))) base_port = int(os.environ.get(ENV_FPM_PORT, str(DEFAULT_FPM_PORT)))
port = base_port + dp_rank port = base_port + dp_rank
...@@ -250,11 +295,68 @@ class InstrumentedScheduler(AsyncScheduler): ...@@ -250,11 +295,68 @@ class InstrumentedScheduler(AsyncScheduler):
dp_rank, dp_rank,
) )
self._bench_init(vllm_config)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Overrides # Overrides
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def has_requests(self) -> bool:
    """Report whether there is work to schedule.

    Always ``True`` while the benchmark is active; otherwise defers to the
    parent scheduler.
    """
    return True if self._bench_active else super().has_requests()
def schedule(self) -> SchedulerOutput:
    """Produce the next scheduler step, letting an active benchmark drive it.

    While the benchmark is active and not in the IDLE phase,
    ``_bench_step()`` is consulted first:

    * If it returns a ``SchedulerOutput`` (fake-decode injection), that
      output is finalised and returned.
    * If it returns ``None`` while the decode sweep still has live
      benchmark requests, an empty ``SchedulerOutput`` is returned.
    * Otherwise normal scheduling runs.

    If ``_bench_step()`` raises, the benchmark is abandoned, its requests
    are cleaned up, and normal scheduling takes over.

    Returns:
        The ``SchedulerOutput`` for this step.
    """
    if not self._bench_active or self._bench_phase == _BenchPhase.IDLE:
        return self._schedule_and_record_time()

    try:
        output = self._bench_step()
    except Exception:
        # Benchmark failures must not take the scheduler down: log, release
        # benchmark state, and fall back to regular scheduling.
        logger.exception("Benchmark step failed, cleaning up")
        self._bench_cleanup_requests()
        self._bench_active = False
        self._bench_phase = _BenchPhase.IDLE
        return self._schedule_and_record_time()

    if output is not None:
        self.kv_cache_manager.new_step_starts()
        self._update_after_schedule(output)
        # update_from_output() only pops a timestamp for steps that scheduled
        # tokens, so only record one for such steps (same rule as
        # _schedule_and_record_time). A fake-decode batch truncated to zero
        # requests would otherwise leave a stale entry in the queue.
        if output.total_num_scheduled_tokens > 0:
            self._schedule_times.append(time.monotonic())
        return output

    if self._bench_phase == _BenchPhase.DECODE_SWEEP and self._bench_active_req_ids:
        # Decode sweep has requests in flight but nothing new to inject:
        # emit a step that schedules no tokens.
        empty = SchedulerOutput(
            scheduled_new_reqs=[],
            scheduled_cached_reqs=CachedRequestData.make_empty(),
            num_scheduled_tokens={},
            total_num_scheduled_tokens=0,
            scheduled_spec_decode_tokens={},
            scheduled_encoder_inputs={},
            num_common_prefix_blocks=(
                [0] * self.kv_cache_manager.num_kv_cache_groups
            ),
            finished_req_ids=self.finished_req_ids,
            free_encoder_mm_hashes=[],
        )
        self._update_after_schedule(empty)
        return empty

    return self._schedule_and_record_time()
def _schedule_and_record_time(self) -> SchedulerOutput:
    """Run the parent scheduler and timestamp steps that schedule tokens.

    A ``time.monotonic()`` reading is queued on ``_schedule_times`` only
    when the resulting output schedules at least one token.
    """
    scheduled = super().schedule()
    if scheduled.total_num_scheduled_tokens <= 0:
        return scheduled
    self._schedule_times.append(time.monotonic())
    return scheduled
def shutdown(self) -> None:
    """Stop the publisher and the parent scheduler.

    If a benchmark is still active with live requests, a warning is logged
    and those requests are cleaned up first.
    """
    interrupted = self._bench_active and self._bench_active_req_ids
    if interrupted:
        logger.warning(
            "Benchmark interrupted, cleaning up %d requests",
            len(self._bench_active_req_ids),
        )
        self._bench_cleanup_requests()
    self._publisher.shutdown()
    super().shutdown()
...@@ -264,18 +366,28 @@ class InstrumentedScheduler(AsyncScheduler): ...@@ -264,18 +366,28 @@ class InstrumentedScheduler(AsyncScheduler):
model_runner_output: "ModelRunnerOutput", model_runner_output: "ModelRunnerOutput",
): ):
result = super().update_from_output(scheduler_output, model_runner_output) result = super().update_from_output(scheduler_output, model_runner_output)
now = time.monotonic()
if scheduler_output.total_num_scheduled_tokens > 0: if scheduler_output.total_num_scheduled_tokens > 0:
wall_time = ( now = time.monotonic()
now - self._last_update_time if self._last_update_time > 0 else 0.0 t_sched = self._schedule_times.popleft() if self._schedule_times else 0.0
)
if self._last_update_time > 0:
wall_time = now - self._last_update_time
elif t_sched > 0:
wall_time = now - t_sched
else:
wall_time = 0.0
self._last_update_time = now self._last_update_time = now
metrics = self._extract_metrics( metrics = self._extract_metrics(
scheduler_output, self._compute_queued(), wall_time scheduler_output, self._compute_queued(), wall_time
) )
self._publisher.publish(metrics) self._publisher.publish(metrics)
if self._bench_active:
self._bench_current_fpms.append(
json.loads(msgspec.json.encode(metrics))
)
else: else:
self._last_update_time = 0.0 self._last_update_time = 0.0
...@@ -365,3 +477,388 @@ class InstrumentedScheduler(AsyncScheduler): ...@@ -365,3 +477,388 @@ class InstrumentedScheduler(AsyncScheduler):
def _cleanup_finished(self, output: SchedulerOutput) -> None: def _cleanup_finished(self, output: SchedulerOutput) -> None:
for req_id in output.finished_req_ids: for req_id in output.finished_req_ids:
self._prompt_len_per_req.pop(req_id, None) self._prompt_len_per_req.pop(req_id, None)
# ------------------------------------------------------------------
# Benchmark mode
# ------------------------------------------------------------------
def _bench_init(self, vllm_config: "VllmConfig") -> None:
"""Parse benchmark config and initialise state machine."""
bench_cfg = vllm_config.additional_config.get("benchmark")
if not bench_cfg:
self._bench_active = False
return
cfg = bench_cfg if isinstance(bench_cfg, dict) else {}
# additional_config values arrive as strings from JSON; coerce to
# the types that BenchmarkConfig expects.
_INT_FIELDS = {
"prefill_isl_granularity",
"decode_length_granularity",
"decode_batch_size_granularity",
"warmup_iterations",
}
for k in _INT_FIELDS:
if k in cfg and not isinstance(cfg[k], int):
cfg[k] = int(cfg[k])
known = {f.name for f in BenchmarkConfig.__dataclass_fields__.values()}
self._bench_config = BenchmarkConfig(
**{k: v for k, v in cfg.items() if k in known}
)
dp_rank = self._fpm_dp_rank
if dp_rank > 0:
base, ext = os.path.splitext(self._bench_config.output_path)
self._bench_config.output_path = f"{base}_dp{dp_rank}{ext}"
try:
os.unlink(self._bench_config.output_path)
except FileNotFoundError:
pass
self._bench_active = True
self._bench_phase = _BenchPhase.WARMUP
self._bench_grid: deque[BenchmarkPoint] = deque()
self._bench_current_point: BenchmarkPoint | None = None
self._bench_results: list[BenchmarkPointResult] = []
self._bench_current_fpms: list[dict] = []
self._bench_active_req_ids: set[str] = set()
self._bench_seq = 0
self._bench_grid_built = False
self._bench_drain_pending = False
# Build block_hasher so benchmark requests work with prefix caching.
if self.cache_config.enable_prefix_caching:
caching_hash_fn = get_hash_fn_by_name(
self.cache_config.prefix_caching_hash_algo
)
init_none_hash(caching_hash_fn)
self._bench_block_hasher = get_request_block_hasher(
self.block_size, caching_hash_fn
)
else:
self._bench_block_hasher = None
logger.info("Benchmark mode enabled: %s", self._bench_config)
# -- Grid generation ------------------------------------------------
def _bench_build_grid(self) -> None:
"""Generate the sweep grid once scheduler limits are known."""
if self._bench_grid_built:
return
self._bench_grid_built = True
mode = self._bench_config.mode
if mode in ("prefill", "agg"):
self._bench_generate_prefill_grid()
if mode in ("decode", "agg"):
self._bench_generate_decode_grid()
logger.info("Benchmark grid: %d points (%s mode)", len(self._bench_grid), mode)
def _bench_generate_prefill_grid(self) -> None:
    """Append prefill points with ISLs spread from 10 to the token budget.

    The ISLs are evenly spaced integers between 10 and
    ``max_num_scheduled_tokens`` with duplicates removed. Nothing is added
    (and a warning is logged) when the budget is below 10.
    """
    samples = max(1, self._bench_config.prefill_isl_granularity)
    token_budget = self.max_num_scheduled_tokens
    if token_budget < 10:
        logger.warning(
            "max_num_scheduled_tokens=%d too small, skipping prefill grid",
            token_budget,
        )
        return

    unique_isls = np.unique(np.linspace(10, token_budget, samples, dtype=int))
    self._bench_grid.extend(
        BenchmarkPoint(point_type="prefill", isl=int(value)) for value in unique_isls
    )
def _bench_generate_decode_grid(self) -> None:
n_len = max(1, self._bench_config.decode_length_granularity)
n_bs = max(1, self._bench_config.decode_batch_size_granularity)
total_kv_tokens = self.cache_config.num_gpu_blocks * self.block_size
max_ctx = self.max_model_len - 10
if max_ctx < self.block_size:
logger.warning("max_model_len too small for decode grid, skipping")
return
ctx_lens = np.unique(np.linspace(self.block_size, max_ctx, n_len, dtype=int))
for ctx_len in ctx_lens:
ctx_len = int(ctx_len)
max_batch = min(self.max_num_running_reqs, total_kv_tokens // ctx_len)
if max_batch < 1:
continue
batch_sizes = np.unique(np.linspace(1, max_batch, n_bs, dtype=int))
for bs in batch_sizes:
self._bench_grid.append(
BenchmarkPoint(
point_type="decode",
context_length=ctx_len,
batch_size=int(bs),
)
)
# -- Request injection / cleanup ------------------------------------
def _bench_inject_prefill(
self, prompt_len: int, max_tokens: int, n: int = 1
) -> None:
for _ in range(n):
req_id = f"__bench_{self._bench_seq}"
req = Request(
request_id=req_id,
prompt_token_ids=[0] * prompt_len,
sampling_params=SamplingParams(max_tokens=max_tokens),
pooling_params=None,
block_hasher=self._bench_block_hasher,
cache_salt=req_id,
)
self.add_request(req)
self._bench_active_req_ids.add(req_id)
self._bench_seq += 1
def _bench_inject_fake_decode(self, ctx_len: int, batch_size: int) -> SchedulerOutput:
    """Create fake decode requests with pre-allocated KV.

    Up to ``batch_size`` requests are built, each with ``ctx_len`` zero
    prompt tokens marked as already computed and one scheduled token. The
    batch is cut short (with a warning) as soon as slot allocation fails.

    Returns:
        A ``SchedulerOutput`` listing every created request as a new
        request, each with one scheduled token.
    """
    registered: list[NewRequestData] = []
    tokens_per_req: dict[str, int] = {}

    for _ in range(batch_size):
        req_id = f"__bench_{self._bench_seq}"
        prompt = [0] * ctx_len
        req = Request(
            request_id=req_id,
            prompt_token_ids=prompt,
            sampling_params=SamplingParams(max_tokens=100_000),
            pooling_params=None,
            block_hasher=self._bench_block_hasher,
            cache_salt=req_id,
        )

        # NOTE(review): slots are requested for ctx_len tokens while one
        # further token is scheduled below — confirm this is sufficient
        # when ctx_len is an exact multiple of block_size.
        blocks = self.kv_cache_manager.allocate_slots(
            req, ctx_len, delay_cache_blocks=True
        )
        if blocks is None:
            logger.warning(
                "KV exhausted at ctx_len=%d after %d requests, truncating batch",
                ctx_len,
                len(registered),
            )
            break

        # Present the request as running with its whole prompt computed.
        req.num_computed_tokens = ctx_len
        req.status = RequestStatus.RUNNING
        req.append_output_token_ids(0)

        self.requests[req_id] = req
        self.running.append(req)  # type: ignore[has-type]
        self._bench_active_req_ids.add(req_id)
        self._bench_seq += 1

        registered.append(
            NewRequestData(
                req_id=req_id,
                prompt_token_ids=prompt,
                mm_features=[],
                sampling_params=req.sampling_params,
                pooling_params=None,
                block_ids=blocks.get_block_ids(),
                num_computed_tokens=ctx_len,
                lora_request=None,
            )
        )
        tokens_per_req[req_id] = 1

    # Only consult the KV manager for new block ids when the scheduler
    # exposes a truthy needs_kv_cache_zeroing attribute.
    if getattr(self, "needs_kv_cache_zeroing", False):
        block_ids_to_zero = self.kv_cache_manager.take_new_block_ids() or None
    else:
        block_ids_to_zero = None

    return SchedulerOutput(
        scheduled_new_reqs=registered,
        scheduled_cached_reqs=CachedRequestData.make_empty(),
        num_scheduled_tokens=tokens_per_req,
        total_num_scheduled_tokens=len(registered),
        scheduled_spec_decode_tokens={},
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=([0] * self.kv_cache_manager.num_kv_cache_groups),
        finished_req_ids=self.finished_req_ids,
        free_encoder_mm_hashes=[],
        new_block_ids_to_zero=block_ids_to_zero,
    )
def _bench_cleanup_requests(self) -> None:
"""Free all resources held by active benchmark requests."""
for req_id in list(self._bench_active_req_ids):
req = self.requests.get(req_id)
if req:
self.kv_cache_manager.free(req)
self.finished_req_ids.add(req_id)
del self.requests[req_id]
running = self.running # type: ignore[has-type]
self.running = [
r for r in running if r.request_id not in self._bench_active_req_ids
]
self._bench_active_req_ids.clear()
self._schedule_times.clear()
# -- State machine --------------------------------------------------
def _bench_step(self) -> SchedulerOutput | None:
    """Advance the benchmark state machine by one step.

    Ensures the grid exists, then dispatches on the current phase. In the
    DONE phase the results are written and the benchmark is deactivated.

    Returns:
        A custom ``SchedulerOutput`` for fake-decode points, or ``None``
        when normal scheduling should handle the current step (prefill /
        warmup / cleanup cycles).
    """
    self._bench_build_grid()

    phase = self._bench_phase
    if phase == _BenchPhase.WARMUP:
        return self._bench_step_warmup()
    if phase == _BenchPhase.PREFILL_SWEEP:
        return self._bench_step_prefill()
    if phase == _BenchPhase.DECODE_SWEEP:
        return self._bench_step_decode()
    if phase == _BenchPhase.DONE:
        self._bench_write_results()
        self._bench_active = False
        self._bench_phase = _BenchPhase.IDLE
        logger.info("Benchmark complete")
    return None
def _bench_step_warmup(self) -> SchedulerOutput | None:
if not self._bench_active_req_ids:
iters = self._bench_config.warmup_iterations
if iters > 0:
self._bench_inject_prefill(prompt_len=256, max_tokens=iters)
logger.info("Benchmark warmup: 1 prefill + %d decode steps", iters)
else:
self._bench_transition_after_warmup()
return None
still_alive = any(rid in self.requests for rid in self._bench_active_req_ids)
if not still_alive:
self._bench_transition_after_warmup()
return None
def _bench_transition_after_warmup(self) -> None:
    """Leave warmup: release its requests, discard its metrics, pick a sweep.

    Modes "prefill" and "agg" continue with PREFILL_SWEEP; any other mode
    continues with DECODE_SWEEP.
    """
    self._bench_cleanup_requests()
    self._bench_current_fpms.clear()

    if self._bench_config.mode in ("prefill", "agg"):
        self._bench_phase = _BenchPhase.PREFILL_SWEEP
        logger.info("Benchmark: entering PREFILL_SWEEP")
        return
    self._bench_phase = _BenchPhase.DECODE_SWEEP
    logger.info("Benchmark: entering DECODE_SWEEP")
def _bench_drain_if_pending(self) -> bool:
"""If a drain cycle is pending, discard stale FPMs and return True."""
if not self._bench_drain_pending:
return False
self._bench_drain_pending = False
self._bench_current_fpms.clear()
self._schedule_times.clear()
return True
def _bench_step_prefill(self) -> SchedulerOutput | None:
if self._bench_drain_if_pending():
pass # fall through to inject next point
elif self._bench_active_req_ids:
still_alive = any(
rid in self.requests for rid in self._bench_active_req_ids
)
if still_alive:
return None
if not self._bench_current_fpms:
return None
self._bench_save_current_point()
self._bench_cleanup_requests()
self._bench_drain_pending = True
return None
point = self._bench_pop_next("prefill")
if point is None:
if self._bench_config.mode == "agg":
self._bench_phase = _BenchPhase.DECODE_SWEEP
logger.info("Benchmark: entering DECODE_SWEEP")
else:
self._bench_phase = _BenchPhase.DONE
return None
self._bench_current_point = point
self._bench_current_fpms = []
self._bench_inject_prefill(prompt_len=point.isl, max_tokens=1)
logger.info("Benchmark prefill: ISL=%d", point.isl)
return None
def _bench_step_decode(self) -> SchedulerOutput | None:
if self._bench_drain_if_pending():
pass # fall through to inject next point
elif self._bench_active_req_ids:
if not self._bench_current_fpms:
return None
self._bench_save_current_point()
self._bench_cleanup_requests()
self._bench_drain_pending = True
return None
point = self._bench_pop_next("decode")
if point is None:
self._bench_phase = _BenchPhase.DONE
return None
self._bench_current_point = point
self._bench_current_fpms = []
logger.info(
"Benchmark decode: ctx_len=%d batch_size=%d",
point.context_length,
point.batch_size,
)
return self._bench_inject_fake_decode(point.context_length, point.batch_size)
def _bench_pop_next(self, point_type: str) -> BenchmarkPoint | None:
while self._bench_grid:
pt = self._bench_grid[0]
if pt.point_type == point_type:
return self._bench_grid.popleft()
break
return None
def _bench_save_current_point(self) -> None:
if self._bench_current_point is not None and self._bench_current_fpms:
self._bench_results.append(
BenchmarkPointResult(
point=self._bench_current_point,
fpms=list(self._bench_current_fpms),
)
)
self._bench_current_point = None
self._bench_current_fpms = []
# -- Results output -------------------------------------------------
def _bench_write_results(self) -> None:
    """Write config, scheduler limits and per-point results as JSON.

    The document is first written to ``<output_path>.tmp`` and then moved
    over ``output_path`` with ``os.replace``.
    """
    limits = {
        "max_num_scheduled_tokens": self.max_num_scheduled_tokens,
        "max_num_running_reqs": self.max_num_running_reqs,
        "max_model_len": self.max_model_len,
        "block_size": self.block_size,
        "num_gpu_blocks": self.cache_config.num_gpu_blocks,
    }
    per_point = [
        {"point": asdict(result.point), "fpms": result.fpms}
        for result in self._bench_results
    ]
    output = {
        "config": asdict(self._bench_config),
        "limits": limits,
        "results": per_point,
    }

    dest = self._bench_config.output_path
    tmp = dest + ".tmp"
    with open(tmp, "w") as handle:
        json.dump(output, handle, indent=2)
    os.replace(tmp, dest)

    logger.info(
        "Benchmark results written to %s (%d points)",
        dest,
        len(self._bench_results),
    )
...@@ -556,6 +556,15 @@ def setup_vllm_engine( ...@@ -556,6 +556,15 @@ def setup_vllm_engine(
if fpm_worker_id is not None: if fpm_worker_id is not None:
vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id
# Pass benchmark config to InstrumentedScheduler via additional_config.
if hasattr(config, "_benchmark_additional_config"):
bench = config._benchmark_additional_config
if fpm_worker_id and bench["output_path"] == "/tmp/benchmark_results.json":
short_id = fpm_worker_id[-8:]
bench["output_path"] = f"/tmp/benchmark_results_{short_id}.json"
vllm_config.additional_config["benchmark"] = bench
logger.info("Benchmark config injected into additional_config")
factory = [] factory = []
if stat_logger: if stat_logger:
factory.append(stat_logger) factory.append(stat_logger)
......
...@@ -596,3 +596,128 @@ class TestVllmOmniOptionalDependency: ...@@ -596,3 +596,128 @@ class TestVllmOmniOptionalDependency:
sys.modules.pop(mod, None) sys.modules.pop(mod, None)
# Restore original state # Restore original state
sys.modules.update(saved) sys.modules.update(saved)
# ---------------------------------------------------------------------------
# Benchmark mode unit tests
# ---------------------------------------------------------------------------
class TestBenchmarkConfig:
    """Tests for BenchmarkConfig dataclass and grid generation."""

    def test_benchmark_config_defaults(self):
        """A bare BenchmarkConfig carries the documented defaults."""
        from dynamo.vllm.instrumented_scheduler import BenchmarkConfig

        cfg = BenchmarkConfig()
        assert cfg.mode == "agg"
        assert cfg.prefill_isl_granularity == 16
        assert cfg.decode_length_granularity == 6
        assert cfg.decode_batch_size_granularity == 6
        assert cfg.warmup_iterations == 5
        assert cfg.output_path == "/tmp/benchmark_results.json"

    def test_benchmark_config_from_dict(self):
        """Every explicitly supplied field is stored as given."""
        from dynamo.vllm.instrumented_scheduler import BenchmarkConfig

        cfg = BenchmarkConfig(
            mode="decode",
            prefill_isl_granularity=4,
            decode_length_granularity=3,
            decode_batch_size_granularity=3,
            warmup_iterations=2,
            output_path="/tmp/test.json",
        )
        # Check all six overrides, not just a sample of them.
        assert cfg.mode == "decode"
        assert cfg.prefill_isl_granularity == 4
        assert cfg.decode_length_granularity == 3
        assert cfg.decode_batch_size_granularity == 3
        assert cfg.warmup_iterations == 2
        assert cfg.output_path == "/tmp/test.json"

    def test_benchmark_config_kwargs_unpack(self):
        """A partial dict unpacks into kwargs; other fields keep defaults."""
        from dynamo.vllm.instrumented_scheduler import BenchmarkConfig

        d = {"mode": "prefill", "warmup_iterations": 1}
        cfg = BenchmarkConfig(**d)
        assert cfg.mode == "prefill"
        assert cfg.warmup_iterations == 1
        assert cfg.prefill_isl_granularity == 16
class TestBenchmarkGrid:
    """Tests for benchmark grid generation logic (no GPU required)."""

    def _make_grid_helper(self):
        """Return (prefill_grid_fn, decode_grid_fn) that operate on plain params.

        The helpers re-implement the scheduler's grid arithmetic, including
        its guards: granularities are clamped to at least 1, the prefill
        grid is empty below 10 tokens, and decode context lengths stop 10
        tokens short of ``max_model_len``.
        """
        import numpy as np

        def generate_prefill_grid(max_num_scheduled_tokens, granularity):
            if max_num_scheduled_tokens < 10:
                return []
            isls = np.unique(
                np.linspace(
                    10, max_num_scheduled_tokens, max(1, granularity), dtype=int
                )
            )
            return [int(x) for x in isls]

        def generate_decode_grid(
            block_size,
            max_model_len,
            max_num_running_reqs,
            num_gpu_blocks,
            length_granularity,
            batch_granularity,
        ):
            total_kv_tokens = num_gpu_blocks * block_size
            max_ctx = max_model_len - 10
            if max_ctx < block_size:
                return []
            ctx_lens = np.unique(
                np.linspace(
                    block_size, max_ctx, max(1, length_granularity), dtype=int
                )
            )
            points = []
            for ctx_len in ctx_lens:
                ctx_len = int(ctx_len)
                max_batch = min(max_num_running_reqs, total_kv_tokens // ctx_len)
                if max_batch < 1:
                    continue
                batch_sizes = np.unique(
                    np.linspace(1, max_batch, max(1, batch_granularity), dtype=int)
                )
                for bs in batch_sizes:
                    points.append((ctx_len, int(bs)))
            return points

        return generate_prefill_grid, generate_decode_grid

    def test_prefill_grid_count(self):
        gen_prefill, _ = self._make_grid_helper()
        isls = gen_prefill(max_num_scheduled_tokens=8192, granularity=10)
        assert len(isls) == 10
        assert isls[0] == 10
        assert isls[-1] == 8192

    def test_prefill_grid_dedup(self):
        gen_prefill, _ = self._make_grid_helper()
        isls = gen_prefill(max_num_scheduled_tokens=20, granularity=100)
        assert len(isls) == len(set(isls))

    def test_decode_grid_batch_capped(self):
        _, gen_decode = self._make_grid_helper()
        points = gen_decode(
            block_size=16,
            max_model_len=4096,
            max_num_running_reqs=64,
            num_gpu_blocks=256,
            length_granularity=3,
            batch_granularity=3,
        )
        total_kv = 256 * 16
        for ctx_len, bs in points:
            assert bs <= min(64, total_kv // ctx_len)
            assert bs >= 1

    def test_decode_grid_skips_large_ctx(self):
        _, gen_decode = self._make_grid_helper()
        points = gen_decode(
            block_size=16,
            max_model_len=100000,
            max_num_running_reqs=64,
            num_gpu_blocks=100,
            length_granularity=5,
            batch_granularity=3,
        )
        total_kv = 100 * 16
        for ctx_len, bs in points:
            assert ctx_len <= total_kv

    def test_decode_grid_leaves_generation_headroom(self):
        _, gen_decode = self._make_grid_helper()
        points = gen_decode(
            block_size=16,
            max_model_len=4096,
            max_num_running_reqs=64,
            num_gpu_blocks=256,
            length_granularity=3,
            batch_granularity=3,
        )
        assert points
        for ctx_len, _bs in points:
            assert ctx_len <= 4096 - 10
...@@ -4,9 +4,12 @@ ...@@ -4,9 +4,12 @@
"""Worker initialization factory for vLLM workers.""" """Worker initialization factory for vLLM workers."""
import asyncio import asyncio
import json
import logging import logging
import os import os
import time as _time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any, Optional from typing import Any, Optional
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -20,7 +23,7 @@ from dynamo.runtime import DistributedRuntime ...@@ -20,7 +23,7 @@ from dynamo.runtime import DistributedRuntime
from .args import Config from .args import Config
from .constants import DisaggregationMode from .constants import DisaggregationMode
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler from .handlers import DecodeWorkerHandler, PrefillWorkerHandler, get_dp_range_for_worker
from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload from .health_check import VllmHealthCheckPayload, VllmPrefillHealthCheckPayload
from .multimodal_handlers import EncodeWorkerHandler from .multimodal_handlers import EncodeWorkerHandler
from .publisher import StatLoggerFactory from .publisher import StatLoggerFactory
...@@ -30,6 +33,66 @@ logger = logging.getLogger(__name__) ...@@ -30,6 +33,66 @@ logger = logging.getLogger(__name__)
# (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir, component_gauges) # (engine_client, vllm_config, default_sampling_params, prometheus_temp_dir, component_gauges)
EngineSetupResult = tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics] EngineSetupResult = tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
async def _load_benchmark_file(path: Path, deadline: float, timeout: int) -> dict:
    """Poll until *path* exists and holds complete JSON, then return it.

    Args:
        path: Result file expected to be produced by the benchmark.
        deadline: ``time.monotonic()`` value after which waiting stops.
        timeout: Configured timeout in seconds (used in the error message).

    Raises:
        TimeoutError: If the file does not appear before the deadline.
        json.JSONDecodeError / UnicodeDecodeError: If the file exists but is
            still not parseable once the deadline has passed.
    """
    while True:
        if path.exists():
            try:
                with open(path, encoding="utf-8") as f:
                    return json.load(f)
            except (json.JSONDecodeError, UnicodeDecodeError):
                # The file may exist while its producer is still writing it.
                # Keep polling until the deadline; after that, surface the
                # parse error because the file is genuinely unreadable.
                if _time.monotonic() > deadline:
                    raise
        elif _time.monotonic() > deadline:
            raise TimeoutError(
                f"Benchmark did not complete within {timeout}s. " f"Missing: {path}"
            )
        await asyncio.sleep(0.1)


async def _wait_and_load_benchmark(bench_cfg: dict, vllm_config: VllmConfig) -> dict:
    """Wait for benchmark result files and aggregate across DP ranks.

    One result file is expected per data-parallel rank handled by this
    worker: rank 0 uses ``bench_cfg["output_path"]`` unchanged, every other
    rank uses the same path with ``_dp<rank>`` inserted before the extension.
    Each entry of a file's ``"results"`` list gets its ``["point"]["dp_rank"]``
    set to the rank the file belongs to.

    Args:
        bench_cfg: Benchmark settings; reads ``"output_path"`` (required) and
            ``"timeout"`` in seconds (optional, 300 when missing or ``None``).
        vllm_config: Engine config passed to ``get_dp_range_for_worker``.

    Returns:
        The first rank's JSON document, with the ``"results"`` of all other
        ranks appended to its ``"results"`` list.

    Raises:
        TimeoutError: If a result file is still missing after the timeout.
    """
    base_path = Path(bench_cfg["output_path"])
    # A key that is present but None must fall back to the default instead of
    # crashing in int(None).
    raw_timeout = bench_cfg.get("timeout")
    timeout = 300 if raw_timeout is None else int(raw_timeout)

    try:
        dp_start, dp_size = get_dp_range_for_worker(vllm_config)
    except Exception:
        # Deliberate best-effort: fall back to a single rank rather than
        # failing worker start-up.
        logger.warning(
            "Could not determine DP range, assuming single rank",
            exc_info=True,
        )
        dp_start, dp_size = 0, 1

    dp_ranks = range(dp_start, dp_start + dp_size)
    # os.path.splitext is kept (rather than Path.stem/suffix) so the derived
    # names stay exactly as before; it is loop-invariant, so compute it once.
    stem, ext = os.path.splitext(str(base_path))
    rank_paths = [
        base_path if dp_rank == 0 else Path(f"{stem}_dp{dp_rank}{ext}")
        for dp_rank in dp_ranks
    ]

    logger.info(
        "Waiting for benchmark to complete (files: %s, timeout: %ds)...",
        rank_paths,
        timeout,
    )

    # A single deadline is shared by all rank files.
    deadline = _time.monotonic() + timeout
    merged: dict = {}
    for dp_rank, path in zip(dp_ranks, rank_paths):
        data = await _load_benchmark_file(path, deadline, timeout)
        for r in data.get("results", []):
            r["point"]["dp_rank"] = dp_rank
        if dp_rank == dp_start:
            # The first rank's document is the base of the merged output.
            merged = data
        else:
            merged.setdefault("results", []).extend(data.get("results", []))

    logger.info(
        "Benchmark complete, %d points across %d rank(s)",
        len(merged.get("results", [])),
        len(rank_paths),
    )
    return merged
SetupVllmEngineFn = Callable[..., EngineSetupResult] SetupVllmEngineFn = Callable[..., EngineSetupResult]
SetupKvEventPublisherFn = Callable[..., Optional[Any]] SetupKvEventPublisherFn = Callable[..., Optional[Any]]
RegisterVllmModelFn = Callable[..., Awaitable[None]] RegisterVllmModelFn = Callable[..., Awaitable[None]]
...@@ -64,6 +127,9 @@ class WorkerFactory: ...@@ -64,6 +127,9 @@ class WorkerFactory:
) -> None: ) -> None:
"""Create the appropriate multimodal worker based on config flags.""" """Create the appropriate multimodal worker based on config flags."""
# NOTE: --benchmark-mode is only supported for prefill/decode workers.
# The encode worker path does not wire benchmark waiting or
# the get_perf_metrics endpoint.
if config.disaggregation_mode == DisaggregationMode.ENCODE: if config.disaggregation_mode == DisaggregationMode.ENCODE:
await self._create_multimodal_encode_worker( await self._create_multimodal_encode_worker(
runtime, config, shutdown_event, shutdown_endpoints runtime, config, shutdown_event, shutdown_endpoints
...@@ -296,6 +362,13 @@ class WorkerFactory: ...@@ -296,6 +362,13 @@ class WorkerFactory:
handler._quiesce_controller.mark_resumed() handler._quiesce_controller.mark_resumed()
logger.info("[Shadow] Engine awake, registering with discovery") logger.info("[Shadow] Engine awake, registering with discovery")
# Wait for self-benchmark to complete before registering.
bench_cfg = vllm_config.additional_config.get("benchmark")
if bench_cfg:
handler._benchmark_results = await _wait_and_load_benchmark(
bench_cfg, vllm_config
)
await self.register_vllm_model( await self.register_vllm_model(
model_input, model_input,
model_type, model_type,
...@@ -309,6 +382,11 @@ class WorkerFactory: ...@@ -309,6 +382,11 @@ class WorkerFactory:
engine_client, use_text_input=config.use_vllm_tokenizer engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict() ).to_dict()
perf_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.get_perf_metrics"
)
shutdown_endpoints.append(perf_endpoint)
try: try:
logger.debug("Starting serve_endpoint for decode worker") logger.debug("Starting serve_endpoint for decode worker")
...@@ -336,6 +414,10 @@ class WorkerFactory: ...@@ -336,6 +414,10 @@ class WorkerFactory:
handler.clear_kv_blocks, handler.clear_kv_blocks,
metrics_labels=model_metrics_labels, metrics_labels=model_metrics_labels,
), ),
perf_endpoint.serve_endpoint(
handler.get_perf_metrics,
metrics_labels=model_metrics_labels,
),
] ]
if lora_enabled: if lora_enabled:
...@@ -467,7 +549,17 @@ class WorkerFactory: ...@@ -467,7 +549,17 @@ class WorkerFactory:
"Registered engine routes: /engine/sleep, /engine/wake_up, /engine/scale_elastic_ep" "Registered engine routes: /engine/sleep, /engine/wake_up, /engine/scale_elastic_ep"
) )
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint] # Wait for self-benchmark to complete before registering.
bench_cfg = vllm_config.additional_config.get("benchmark")
if bench_cfg:
handler._benchmark_results = await _wait_and_load_benchmark(
bench_cfg, vllm_config
)
perf_endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.get_perf_metrics"
)
shutdown_endpoints[:] = [generate_endpoint, clear_endpoint, perf_endpoint]
# Register prefill model with ModelType.Prefill # Register prefill model with ModelType.Prefill
model_input = ( model_input = (
...@@ -486,18 +578,7 @@ class WorkerFactory: ...@@ -486,18 +578,7 @@ class WorkerFactory:
engine_client, use_text_input=config.use_vllm_tokenizer engine_client, use_text_input=config.use_vllm_tokenizer
).to_dict() ).to_dict()
try: prefill_metrics_labels = [
logger.debug("Starting serve_endpoint for prefill worker")
await asyncio.gather(
# for prefill, we want to shutdown the engine after all prefill requests are finished because
# (temp reason): we don't support re-routing prefill requests
# (long-term reason): prefill engine should pull from a global queue so there is
# only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint(
handler.generate, # type: ignore
graceful_shutdown=True,
# In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels=[
( (
prometheus_names.labels.MODEL, prometheus_names.labels.MODEL,
config.served_model_name or config.model, config.served_model_name or config.model,
...@@ -506,21 +587,24 @@ class WorkerFactory: ...@@ -506,21 +587,24 @@ class WorkerFactory:
prometheus_names.labels.MODEL_NAME, prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model, config.served_model_name or config.model,
), ),
], ]
try:
logger.debug("Starting serve_endpoint for prefill worker")
await asyncio.gather(
generate_endpoint.serve_endpoint(
handler.generate, # type: ignore
graceful_shutdown=True,
metrics_labels=prefill_metrics_labels,
health_check_payload=health_check_payload, health_check_payload=health_check_payload,
), ),
clear_endpoint.serve_endpoint( clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, # type: ignore handler.clear_kv_blocks, # type: ignore
metrics_labels=[ metrics_labels=prefill_metrics_labels,
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
), ),
], perf_endpoint.serve_endpoint(
handler.get_perf_metrics,
metrics_labels=prefill_metrics_labels,
), ),
) )
logger.debug("serve_endpoint completed for prefill worker") logger.debug("serve_endpoint completed for prefill worker")
......
# vLLM RFC: Per-Iteration Forward Pass Metrics via ZMQ
> For submission to https://github.com/vllm-project/vllm/issues/new?template=750-RFC.yml
---
## Title
`[RFC]: Per-iteration forward pass metrics with accurate engine-level timing`
---
## Motivation
**Problem: orchestration systems need per-iteration scheduler telemetry, but vLLM only exposes aggregated Prometheus metrics.**
Inference orchestrators (autoscalers, routers, disaggregated serving planners) need to understand the *per-iteration* cost structure of a running vLLM engine:
- How many prefill vs decode requests were in each batch?
- What was the KV cache depth distribution across decode requests?
- How many tokens were computed vs cache-hit?
- How long did the GPU forward pass actually take?
- How many requests are queued and waiting?
Today, vLLM exposes Prometheus gauge/histogram metrics that are **scraped asynchronously** by an external collector. This has fundamental limitations for per-iteration telemetry:
1. **Lossy**: Prometheus scraping is pull-based at a configurable interval. With iteration times of 10-100ms, the scraper can miss 90%+ of iterations. Gauge values reflect only the most recent state at scrape time, not the full distribution. Aggregated metrics inevitably lose information.
2. **Unsynchronized**: The scraper runs on a separate timer from the engine loop. Metrics from different gauges may reflect different iterations, making it impossible to correlate prefill/decode counts with wall time for the same batch.
3. **No per-iteration history**: There is no way to reconstruct the sequence of batch compositions over time. An autoscaler cannot build a cost model from Prometheus data because it only sees snapshots.
4. **Latency**: Push-based Prometheus (Pushgateway) uses HTTP, adding latency and overhead proportional to push frequency. For per-iteration emission at 100+ iterations/second, this is prohibitive.
**Why this matters for the ecosystem:**
- **NVIDIA Dynamo** currently implements this as an out-of-tree `--scheduler-cls` subclass ([InstrumentedScheduler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/vllm/instrumented_scheduler.py)), but measuring wall time from the scheduler is inherently imprecise because the scheduler cannot observe the GPU forward pass boundary (see Proposed Change).
- **Autoscalers** (Kubernetes HPA, custom planners) need per-iteration throughput signals to make scaling decisions within seconds, not minutes.
---
## Proposed Change
### 1. Add `wall_time` measurement in EngineCore
Measure the GPU forward pass time at the exact boundary -- around `future.result()` in `EngineCore.step()` / `step_with_batch_queue()`:
```python
# In EngineCore.step():
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
...
t_start = time.monotonic()
model_output = future.result() # blocks until GPU finishes
wall_time = time.monotonic() - t_start
...
self.scheduler.update_from_output(scheduler_output, model_output, wall_time=wall_time)
```
This is the **only** place in the codebase with direct access to both the GPU wait boundary and the scheduler output. The scheduler cannot measure this accurately because:
- In sync mode: `schedule()` returns before `execute_model` runs
- In async mode: `schedule(N+1)` runs concurrently with GPU batch N, so scheduler-side timestamps include overlap from adjacent batches
Pass `wall_time` to `update_from_output()` as a new optional kwarg so the scheduler can include it in metrics.
### 2. Define a per-iteration metrics struct
A compact, versioned struct emitted once per forward pass:
```python
class ForwardPassMetrics(msgspec.Struct, frozen=True):
version: int = 1 # can include more info in later versions
# Identity
worker_id: str = "" # unique engine instance identifier
dp_rank: int = 0 # data parallel rank
counter_id: int = 0 # monotonic sequence number
# Timing (measured in EngineCore)
wall_time: float = 0.0 # seconds, GPU forward pass time
# Scheduled batch composition
num_prefill_requests: int = 0
sum_prefill_tokens: int = 0 # tokens being computed this iteration
var_prefill_length: float = 0.0 # variance of total prompt lengths
sum_prefill_kv_tokens: int = 0 # KV tokens read (cache hits + prior chunks)
num_decode_requests: int = 0
sum_decode_kv_tokens: int = 0 # total KV depth across decode requests
var_decode_kv_tokens: float = 0.0
# Queue state
num_queued_prefill: int = 0
sum_queued_prefill_tokens: int = 0
num_queued_decode: int = 0 # preempted requests waiting
sum_queued_decode_kv_tokens: int = 0
```
**Why these specific fields:**
- An autoscaler needs `wall_time` + `num_prefill_requests` + `num_decode_requests` + token counts to build a cost model of the form `latency = f(prefill_tokens, decode_batch_size, kv_depth)`.
- Variance fields enable detecting heterogeneous batches (mix of short and long sequences) which affect padding overhead and CUDA graph efficiency.
- Queue metrics enable load-aware routing and backpressure signals.
- `msgspec.Struct` is zero-copy serializable and already used by vLLM for KV cache events.
### 3. Emit via ZMQ PUB/SUB (not Prometheus)
Publish the struct over a ZMQ PUB socket bound to a configurable localhost port, using msgpack serialization:
```
ZMQ message: [topic_bytes, sequence_bytes, msgpack_payload]
```
**Why ZMQ over Prometheus:**
| | ZMQ PUB/SUB | Prometheus |
|---|---|---|
| **Delivery** | Push, every iteration | Pull, scraper interval |
| **Completeness** | Every iteration captured | Can miss 90%+ of iterations |
| **Correlation** | All fields from same iteration in one message | Gauges may reflect different iterations |
| **Latency** | ~10us per message (IPC) | HTTP round-trip per scrape |
| **CPU overhead** | Background thread, non-blocking send | Metric registry lock contention |
| **Consumers** | Multiple SUB sockets, zero-copy | One scraper endpoint |
| **Format** | Versioned, typed, extensible (msgspec) | Flat key-value gauges |
The ZMQ publisher runs in a background daemon thread (same pattern as vLLM's existing `ZmqEventPublisher` for KV cache events). The scheduler hot path only pays for `queue.put_nowait()` on a bounded queue -- no serialization, no I/O.
**Backward compatibility: Prometheus "most recent" gauges.** For users who only need approximate metrics via existing Prometheus infrastructure, we can optionally expose the most recent `ForwardPassMetrics` as Prometheus gauges (updated in-place each iteration, scraped at whatever interval the collector uses). This is strictly less capable than the ZMQ stream but maintains compatibility with existing monitoring dashboards.
### 4. Data parallel support
Each DP rank runs its own EngineCore with its own scheduler. Each rank binds its own ZMQ PUB socket on `base_port + dp_rank`, emitting independent FPM streams tagged with `dp_rank`.
**Attention DP (non-MoE):** Each rank is fully independent (`dp_size=1` locally). Each rank emits its own FPM stream. No cross-rank coordination needed -- the consumer (autoscaler, planner) subscribes to each rank's ZMQ port independently and aggregates as needed.
**DP+EP (MoE):** Each rank has its own scheduler and emits its own FPM. Although the GPU forward pass is synchronized across ranks via collectives (`coordinate_batch_across_dp`), each rank's `wall_time` is measured locally at its own `future.result()` boundary. The measurements are nearly identical across ranks (collectives force sync), so any rank's data is representative. Consumers can average or use rank 0's data.
This is the **same approach used by KV cache events** today: each DP rank publishes to its own ZMQ port, and the relay/consumer layer handles multi-rank aggregation outside the engine.
### 5. Activation
Controlled by a new engine argument:
```
--forward-pass-metrics-port PORT # 0 = disabled (default), >0 = ZMQ PUB base port
```
For DP deployments, rank N binds on `PORT + N`. When enabled, the scheduler base class (or a thin mixin) handles metric extraction and ZMQ publishing. No subclass override needed -- this should work with any scheduler implementation.
### 6. Wire format and versioning
- **Serialization**: msgpack via `msgspec.msgpack.Encoder` (same as KV cache events)
- **ZMQ multipart**: `[b"", seq.to_bytes(8, "big"), msgpack_payload]`
- Empty topic allows future topic-based filtering
- 8-byte big-endian sequence number for ordering / gap detection
- msgpack payload is the serialized `ForwardPassMetrics`
- **Versioning**: `version` field in the struct. Consumers must check version before interpreting fields. Bump on incompatible changes.
### 7. Implementation scope
| Component | Change |
|-----------|--------|
| `EngineCore.step()` / `step_with_batch_queue()` | Measure `wall_time` around `future.result()`, pass to `update_from_output()` |
| `Scheduler.update_from_output()` | Accept optional `wall_time` kwarg |
| `SchedulerInterface` | New optional method `get_forward_pass_metrics()` or mixin |
| New: `ForwardPassMetrics` struct | In `vllm/v1/metrics/` or `vllm/v1/core/sched/` |
| New: `FpmPublisher` (ZMQ background thread) | Modeled after existing `ZmqEventPublisher` |
| `AsyncEngineArgs` | New `--forward-pass-metrics-port` argument |
| Optional: Prometheus stat logger | Expose most-recent FPM fields as gauges |
---
## Feedback Period
2 weeks.
---
## CC List
@simon-mo @youkaichao @WoosukKwon @robertgshaw2-redhat
---
## Any Other Things
**Reference implementation:** NVIDIA Dynamo's [InstrumentedScheduler](https://github.com/ai-dynamo/dynamo/blob/main/components/src/dynamo/vllm/instrumented_scheduler.py) implements this as an out-of-tree scheduler subclass with scheduler-side timing. Moving the timing into EngineCore and the ZMQ publisher into vLLM core would:
1. Eliminate the need for `--scheduler-cls` overrides for metrics
2. Provide accurate GPU timing (not scheduler-approximate)
3. Allow any orchestration system (not just Dynamo) to consume per-iteration metrics
4. Reuse existing ZMQ infrastructure from KV cache events
**Existing ZMQ precedent in vLLM:** The KV cache event system (`KVEventsConfig`, `ZmqEventPublisher`) already uses this exact pattern -- ZMQ PUB on localhost, msgpack serialization, background thread. Forward pass metrics would follow the same architecture.
**Not in scope:** How consumers (Dynamo, custom autoscalers, etc.) subscribe, relay, or aggregate these metrics. That is consumer-side logic. This RFC only covers emission from vLLM.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment