Unverified Commit af30b779 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: forward pass metric via ZMQ in vllm (#7200)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent 834eea61
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
ForwardPassMetrics schema for per-iteration scheduler telemetry.
Published over ZMQ PUB by InstrumentedScheduler, consumed by the
planner or any ZMQ SUB listener.
Uses msgspec.Struct for zero-copy serialization (same approach as
vLLM's KV cache events).
TODO: hook to our rust infra for discovery
TODO: add metrics for Trtllm/SGLang
TODO: planner consuming these metrics instead of frontend/router metrics
"""
from __future__ import annotations
import msgspec
class ScheduledRequestMetrics(
msgspec.Struct,
frozen=True,
gc=False,
):
"""Metrics for requests scheduled in this iteration"""
# Number of prefill requests (new requests + chunked prefill continuations).
num_prefill_requests: int = 0
# Total tokens being freshly computed for prefill requests in this
# iteration. Does NOT include prefix-cached or previously-chunked tokens
# (those are in sum_prefill_kv_tokens). For chunked prefill, this is the
# chunk size being computed this step.
sum_prefill_tokens: int = 0
# Population variance of total prompt lengths (not chunk sizes) across
# prefill requests. A request with a 10k-token prompt counts as 10k even
# if only a 2k chunk is computed this iteration.
var_prefill_length: float = 0.0
# Total KV cache tokens that must be read (not computed) for prefill
# requests. Includes prefix cache hits for new requests and previously
# computed chunks for chunked prefill continuations.
sum_prefill_kv_tokens: int = 0
# Number of decode requests (generating output tokens).
num_decode_requests: int = 0
# Total KV context length across all decode requests (prompt + generated
# tokens so far). Reflects the memory pressure from decoding.
sum_decode_kv_tokens: int = 0
# Population variance of KV context lengths across decode requests.
# High variance means a mix of short and long sequences decoding together.
var_decode_kv_tokens: float = 0.0
class QueuedRequestMetrics(
msgspec.Struct,
frozen=True,
gc=False,
):
"""Metrics for requests waiting in the queue (not scheduled this iteration).
All token counts here are raw totals -- prefix cache effects are unknown
until a request is actually scheduled.
"""
# Number of queued prefill requests (status=WAITING).
num_prefill_requests: int = 0
# Total prompt token count of queued prefill requests.
sum_prefill_tokens: int = 0
# Population variance of prompt lengths for queued prefill requests.
var_prefill_length: float = 0.0
# Number of queued decode requests (preempted -- were decoding but got
# evicted back to the waiting queue due to memory pressure).
num_decode_requests: int = 0
# Total KV context length of queued decode (preempted) requests.
sum_decode_kv_tokens: int = 0
# Population variance of KV context lengths for queued decode requests.
var_decode_kv_tokens: float = 0.0
class ForwardPassMetrics(
msgspec.Struct,
frozen=True,
gc=False,
):
"""Per-iteration metrics emitted by InstrumentedScheduler.
One message is emitted per scheduler iteration (one per forward pass).
An idle heartbeat (all zeros, wall_time=0) is emitted once when the
engine transitions from active to idle.
"""
# Unique worker identifier (Dynamo runtime connection_id).
worker_id: str = ""
# Data parallel rank. Each DP rank has its own scheduler and ZMQ port.
dp_rank: int = 0
# Wall-clock time of this iteration: from schedule() to update_from_output().
# Covers scheduling + model forward pass + output processing.
# 0.0 for idle heartbeat messages.
wall_time: float = 0.0
# Requests that were scheduled and executed in this iteration.
scheduled_requests: ScheduledRequestMetrics = ScheduledRequestMetrics()
# Requests that exist in the waiting queue but were not scheduled.
queued_requests: QueuedRequestMetrics = QueuedRequestMetrics()
_encoder = msgspec.msgpack.Encoder()
_decoder = msgspec.msgpack.Decoder(ForwardPassMetrics)
def encode(metrics: ForwardPassMetrics) -> bytes:
return _encoder.encode(metrics)
def decode(data: bytes) -> ForwardPassMetrics:
return _decoder.decode(data)
...@@ -259,6 +259,24 @@ def update_engine_config_with_dynamo( ...@@ -259,6 +259,24 @@ def update_engine_config_with_dynamo(
f"(use_kv_events={dynamo_config.use_kv_events})" f"(use_kv_events={dynamo_config.use_kv_events})"
) )
if envs.is_set("DYN_VLLM_FORWARDPASS_METRIC_PORT"):
existing_cls = getattr(engine_config, "scheduler_cls", None)
if existing_cls is None:
defaults[
"scheduler_cls"
] = "dynamo.vllm.instrumented_scheduler.InstrumentedScheduler"
logger.info(
"Forward pass metrics enabled: scheduler_cls set to InstrumentedScheduler "
f"(port={envs.DYN_VLLM_FORWARDPASS_METRIC_PORT})"
)
else:
logger.warning(
f"DYN_VLLM_FORWARDPASS_METRIC_PORT is set but scheduler_cls "
f"is already '{existing_cls}'. InstrumentedScheduler will NOT "
f"be injected. To use forward pass metrics, either remove "
f"--scheduler-cls or subclass InstrumentedScheduler."
)
logger.debug("Setting Dynamo defaults for vLLM") logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items(): for key, value in defaults.items():
if hasattr(engine_config, key): if hasattr(engine_config, key):
......
...@@ -20,6 +20,7 @@ REGISTERED_PORT_MAX = 49151 ...@@ -20,6 +20,7 @@ REGISTERED_PORT_MAX = 49151
if TYPE_CHECKING: if TYPE_CHECKING:
DYN_VLLM_KV_EVENT_PORT: int = 20080 DYN_VLLM_KV_EVENT_PORT: int = 20080
DYN_VLLM_FORWARDPASS_METRIC_PORT: int = 20380
def _resolve_port(env_var: str, default_port: int) -> int: def _resolve_port(env_var: str, default_port: int) -> int:
...@@ -61,6 +62,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -61,6 +62,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Port used for KV events publishing to the frontend # Port used for KV events publishing to the frontend
# Note: This env variable is ignored if explicitly using --kv-events-config '' # Note: This env variable is ignored if explicitly using --kv-events-config ''
"DYN_VLLM_KV_EVENT_PORT": lambda: _resolve_port("DYN_VLLM_KV_EVENT_PORT", 20080), "DYN_VLLM_KV_EVENT_PORT": lambda: _resolve_port("DYN_VLLM_KV_EVENT_PORT", 20080),
"DYN_VLLM_FORWARDPASS_METRIC_PORT": lambda: _resolve_port(
"DYN_VLLM_FORWARDPASS_METRIC_PORT", 20380
),
} }
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
InstrumentedScheduler -- vLLM Scheduler subclass that emits
ForwardPassMetrics over ZMQ PUB on every iteration.
The scheduler thread does a single-pass accumulation (count, sum,
sum_of_squares) and produces a final ForwardPassMetrics struct.
Serialization and ZMQ send are handled by a background thread
(same approach as vLLM's ZmqEventPublisher) so the scheduler
hot path only pays for accumulation + queue.put().
Inject via:
--scheduler-cls "dynamo.vllm.instrumented_scheduler.InstrumentedScheduler"
"""
from __future__ import annotations
import logging
import os
import queue
import threading
import time
from itertools import count
from typing import TYPE_CHECKING
import zmq
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import RequestStatus
from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
QueuedRequestMetrics,
ScheduledRequestMetrics,
encode,
)
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.structured_output import StructuredOutputManager
logger = logging.getLogger(__name__)
DEFAULT_FPM_PORT = 20380
ENV_FPM_PORT = "DYN_VLLM_FORWARDPASS_METRIC_PORT"
class _Accum:
"""Welford's online algorithm for count / sum / population-variance.
Numerically stable single-pass computation -- avoids catastrophic
cancellation that sum-of-squares can suffer with large values.
"""
__slots__ = ("n", "s", "_mean", "_m2")
def __init__(self) -> None:
self.n = 0
self.s = 0
self._mean = 0.0
self._m2 = 0.0
def add(self, v: int) -> None:
self.n += 1
self.s += v
delta = v - self._mean
self._mean += delta / self.n
delta2 = v - self._mean
self._m2 += delta * delta2
def variance(self) -> float:
if self.n == 0:
return 0.0
return self._m2 / self.n
# ---------------------------------------------------------------------------
# Background publisher thread
# ---------------------------------------------------------------------------
class _FpmPublisherThread:
"""Background thread that serializes and sends ForwardPassMetrics over ZMQ.
Also emits periodic heartbeats when idle.
"""
SHUTDOWN_TIMEOUT: float = 1.0
HEARTBEAT_INTERVAL: float = 1.0
def __init__(
self,
endpoint: str,
worker_id: str,
dp_rank: int,
max_queue_size: int = 10_000,
) -> None:
self._queue: queue.Queue[ForwardPassMetrics | None] = queue.Queue(
maxsize=max_queue_size
)
self._seq = count()
self._worker_id = worker_id
self._dp_rank = dp_rank
self._ctx = zmq.Context.instance()
self._pub = self._ctx.socket(zmq.PUB)
self._pub.bind(endpoint)
self._running = True
self._thread = threading.Thread(
target=self._run, daemon=True, name="fpm-zmq-publisher"
)
self._thread.start()
def publish(self, metrics: ForwardPassMetrics) -> None:
if not self._running:
return
try:
self._queue.put_nowait(metrics)
except queue.Full:
pass
def shutdown(self) -> None:
self._running = False
try:
self._queue.put_nowait(None)
except queue.Full:
pass
self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
try:
self._pub.close(linger=0)
except Exception:
pass
def _run(self) -> None:
topic = b""
last_publish = time.monotonic()
while self._running or not self._queue.empty():
try:
metrics = self._queue.get(timeout=self.HEARTBEAT_INTERVAL)
if metrics is None:
break
except queue.Empty:
if time.monotonic() - last_publish >= self.HEARTBEAT_INTERVAL:
metrics = ForwardPassMetrics(
worker_id=self._worker_id,
dp_rank=self._dp_rank,
)
else:
continue
try:
payload = encode(metrics)
seq_bytes = next(self._seq).to_bytes(8, "big")
self._pub.send_multipart((topic, seq_bytes, payload), flags=zmq.NOBLOCK)
last_publish = time.monotonic()
except zmq.Again:
pass
except Exception:
logger.warning("FPM publisher send failed", exc_info=True)
# ---------------------------------------------------------------------------
# Scheduler subclass
# ---------------------------------------------------------------------------
class InstrumentedScheduler(Scheduler):
def __init__(
self,
vllm_config: "VllmConfig",
kv_cache_config: "KVCacheConfig",
structured_output_manager: "StructuredOutputManager",
block_size: int,
**kwargs,
) -> None:
super().__init__(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
structured_output_manager=structured_output_manager,
block_size=block_size,
**kwargs,
)
dp_rank = getattr(vllm_config.parallel_config, "data_parallel_rank", 0) or 0
self._fpm_worker_id = vllm_config.additional_config.get("fpm_worker_id", "")
self._fpm_dp_rank = dp_rank
self._schedule_time: float = 0.0
self._pending_output: SchedulerOutput | None = None
self._pending_queued: QueuedRequestMetrics | None = None
self._prompt_len_per_req: dict[str, int] = {}
base_port = int(os.environ.get(ENV_FPM_PORT, str(DEFAULT_FPM_PORT)))
port = base_port + dp_rank
self._publisher = _FpmPublisherThread(
f"tcp://*:{port}",
worker_id=self._fpm_worker_id,
dp_rank=dp_rank,
)
logger.info(
"InstrumentedScheduler: ZMQ PUB bound on tcp://*:%d "
"(worker_id=%s, dp_rank=%d)",
port,
self._fpm_worker_id,
dp_rank,
)
# ------------------------------------------------------------------
# Overrides
# ------------------------------------------------------------------
def shutdown(self) -> None:
self._publisher.shutdown()
super().shutdown()
def schedule(self) -> SchedulerOutput:
self._schedule_time = time.monotonic()
output = super().schedule()
self._pending_output = output
self._pending_queued = self._compute_queued()
return output
def update_from_output(
self,
scheduler_output: SchedulerOutput,
model_runner_output: "ModelRunnerOutput",
):
result = super().update_from_output(scheduler_output, model_runner_output)
wall_time = time.monotonic() - self._schedule_time
if self._pending_output is not None:
metrics = self._extract_metrics(
self._pending_output,
self._pending_queued,
wall_time,
)
self._publisher.publish(metrics)
self._pending_output = None
self._pending_queued = None
self._cleanup_finished(scheduler_output)
return result
# ------------------------------------------------------------------
# Metric extraction (single-pass with _Accum, no lists)
# ------------------------------------------------------------------
def _extract_metrics(
self,
output: SchedulerOutput,
queued: QueuedRequestMetrics | None,
wall_time: float,
) -> ForwardPassMetrics:
return ForwardPassMetrics(
worker_id=self._fpm_worker_id,
dp_rank=self._fpm_dp_rank,
wall_time=wall_time,
scheduled_requests=self._extract_scheduled(output),
queued_requests=queued or QueuedRequestMetrics(),
)
def _extract_scheduled(self, output: SchedulerOutput) -> ScheduledRequestMetrics:
new_reqs: list[NewRequestData] = output.scheduled_new_reqs
cached: CachedRequestData = output.scheduled_cached_reqs
num_scheduled = output.num_scheduled_tokens
num_prefill = 0
sum_prefill_tokens = 0
prefill_lengths = _Accum()
sum_prefill_kv_tokens = 0
decode_kv = _Accum()
for req in new_reqs:
num_prefill += 1
sum_prefill_tokens += num_scheduled.get(req.req_id, 0)
prompt_len = len(req.prompt_token_ids) if req.prompt_token_ids else 0
prefill_lengths.add(prompt_len)
sum_prefill_kv_tokens += req.num_computed_tokens
self._prompt_len_per_req[req.req_id] = prompt_len
for i, req_id in enumerate(cached.req_ids):
if cached.is_context_phase(req_id):
num_prefill += 1
sum_prefill_tokens += num_scheduled.get(req_id, 0)
prefill_lengths.add(self._prompt_len_per_req.get(req_id, 0))
sum_prefill_kv_tokens += cached.num_computed_tokens[i]
else:
decode_kv.add(cached.num_computed_tokens[i])
return ScheduledRequestMetrics(
num_prefill_requests=num_prefill,
sum_prefill_tokens=sum_prefill_tokens,
var_prefill_length=prefill_lengths.variance(),
sum_prefill_kv_tokens=sum_prefill_kv_tokens,
num_decode_requests=decode_kv.n,
sum_decode_kv_tokens=decode_kv.s,
var_decode_kv_tokens=decode_kv.variance(),
)
def _compute_queued(self) -> QueuedRequestMetrics:
"""Single-pass aggregation over self.waiting -- no intermediate list."""
prefill = _Accum()
decode_kv = _Accum()
for request in self.waiting:
if request.status == RequestStatus.PREEMPTED:
decode_kv.add(request.num_computed_tokens)
else:
prefill.add(request.num_tokens)
return QueuedRequestMetrics(
num_prefill_requests=prefill.n,
sum_prefill_tokens=prefill.s,
var_prefill_length=prefill.variance(),
num_decode_requests=decode_kv.n,
sum_decode_kv_tokens=decode_kv.s,
var_decode_kv_tokens=decode_kv.variance(),
)
# ------------------------------------------------------------------
# State cleanup
# ------------------------------------------------------------------
def _cleanup_finished(self, output: SchedulerOutput) -> None:
for req_id in output.finished_req_ids:
self._prompt_len_per_req.pop(req_id, None)
...@@ -370,7 +370,9 @@ def setup_kv_event_publisher( ...@@ -370,7 +370,9 @@ def setup_kv_event_publisher(
def setup_vllm_engine( def setup_vllm_engine(
config: Config, stat_logger: Optional[StatLoggerFactory] = None config: Config,
stat_logger: Optional[StatLoggerFactory] = None,
fpm_worker_id: Optional[str] = None,
) -> tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]: ) -> tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]:
# vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object # vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object
# instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
...@@ -485,6 +487,10 @@ def setup_vllm_engine( ...@@ -485,6 +487,10 @@ def setup_vllm_engine(
# dataclass fields; monkey-patching attributes onto VllmConfig is no longer safe). # dataclass fields; monkey-patching attributes onto VllmConfig is no longer safe).
vllm_config.additional_config["consolidator_endpoints"] = consolidator_endpoints vllm_config.additional_config["consolidator_endpoints"] = consolidator_endpoints
# Pass worker identity to InstrumentedScheduler via additional_config.
if fpm_worker_id is not None:
vllm_config.additional_config["fpm_worker_id"] = fpm_worker_id
factory = [] factory = []
if stat_logger: if stat_logger:
factory.append(stat_logger) factory.append(stat_logger)
...@@ -627,7 +633,9 @@ async def init_prefill( ...@@ -627,7 +633,9 @@ async def init_prefill(
default_sampling_params, default_sampling_params,
prometheus_temp_dir, prometheus_temp_dir,
_component_gauges, _component_gauges,
) = setup_vllm_engine(config) ) = setup_vllm_engine(
config, fpm_worker_id=str(generate_endpoint.connection_id())
)
handler = PrefillWorkerHandler( handler = PrefillWorkerHandler(
runtime, runtime,
...@@ -808,7 +816,9 @@ async def init( ...@@ -808,7 +816,9 @@ async def init(
default_sampling_params, default_sampling_params,
prometheus_temp_dir, prometheus_temp_dir,
component_gauges, component_gauges,
) = setup_vllm_engine(config, factory) ) = setup_vllm_engine(
config, factory, fpm_worker_id=str(generate_endpoint.connection_id())
)
# TODO Hack to get data, move this to registering in TBD # TODO Hack to get data, move this to registering in TBD
factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks) factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
......
...@@ -24,6 +24,8 @@ dependencies = [ ...@@ -24,6 +24,8 @@ dependencies = [
"click<8.2.0", "click<8.2.0",
"setuptools", "setuptools",
"prometheus_client>=0.23.1,<1.0", "prometheus_client>=0.23.1,<1.0",
"msgspec>=0.19.0",
"pyzmq>=26.0.0",
"msgpack==1.1.2", "msgpack==1.1.2",
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment