Unverified commit cbdab502, authored by Hongkuan Zhou, committed by GitHub
Browse files

fix: vllm forwardpass metric for async scheduling (#7537)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent b7fe46b1
...@@ -10,7 +10,8 @@ etcd / file) and prints each metric message as JSON. ...@@ -10,7 +10,8 @@ etcd / file) and prints each metric message as JSON.
Usage: Usage:
python -m dynamo.common.recv_forward_pass_metrics \\ python -m dynamo.common.recv_forward_pass_metrics \\
--namespace dynamo --component backend --endpoint generate \\ --namespace dynamo --component backend --endpoint generate \\
[--discovery-backend etcd] [--request-plane nats] [--discovery-backend etcd] [--request-plane nats] \\
[--save-plot metrics.png]
""" """
import argparse import argparse
...@@ -18,18 +19,58 @@ import asyncio ...@@ -18,18 +19,58 @@ import asyncio
import json import json
import logging import logging
import os import os
import time
import matplotlib
import matplotlib.pyplot as plt
import msgspec import msgspec
from dynamo.common.forward_pass_metrics import decode from dynamo.common.forward_pass_metrics import ForwardPassMetrics, decode
from dynamo.llm import FpmEventSubscriber
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
matplotlib.use("Agg")
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _save_plot(path: str, history: list[tuple[float, ForwardPassMetrics]]) -> None:
    """Render a 5-panel time-series plot of *history* and save it to *path*.

    One panel is drawn per series, top to bottom: ``num_prefill_requests``,
    ``sum_prefill_tokens``, ``num_decode_requests``, ``sum_decode_kv_tokens``
    (all read from ``metrics.scheduled_requests``) and ``wall_time``.
    All panels share the x axis, which is the first element of each sample.

    Args:
        path: Output file path handed to ``Figure.savefig``.
        history: ``(time_seconds, metrics)`` samples in the order they were
            collected. When empty, a warning is logged and nothing is
            written.
    """
    if not history:
        logger.warning("No data collected, skipping plot.")
        return

    ts = [t for t, _ in history]
    scheduled = [m.scheduled_requests for _, m in history]

    # (y-axis label, series) for each panel, top to bottom.
    panels = [
        ("num_prefill_requests", [s.num_prefill_requests for s in scheduled]),
        ("sum_prefill_tokens", [s.sum_prefill_tokens for s in scheduled]),
        ("num_decode_requests", [s.num_decode_requests for s in scheduled]),
        ("sum_decode_kv_tokens", [s.sum_decode_kv_tokens for s in scheduled]),
        ("wall_time (s)", [m.wall_time for _, m in history]),
    ]

    fig, axes = plt.subplots(len(panels), 1, figsize=(12, 14), sharex=True)
    try:
        for ax, (label, data) in zip(axes, panels):
            ax.plot(ts, data, linewidth=0.8)
            ax.set_ylabel(label)
            ax.grid(True, alpha=0.3)

        axes[-1].set_xlabel("Time (s)")
        fig.suptitle("ForwardPassMetrics", fontsize=14)
        fig.tight_layout()
        fig.savefig(path, dpi=150)
    finally:
        # Always release the figure from pyplot's registry, even when
        # layout or saving fails (e.g. unwritable path), so it is not leaked.
        plt.close(fig)

    logger.info("Plot saved to %s (%d data points)", path, len(history))
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Receive ForwardPassMetrics from the Dynamo event plane" description="Receive ForwardPassMetrics from the Dynamo event plane"
...@@ -53,12 +94,20 @@ def main() -> None: ...@@ -53,12 +94,20 @@ def main() -> None:
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Request plane (default: nats)", help="Request plane (default: nats)",
) )
parser.add_argument(
"--save-plot",
metavar="PATH",
default=None,
help="Save a time-series plot to the given PNG path on exit",
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(run(args)) asyncio.run(run(args))
async def run(args: argparse.Namespace) -> None: async def run(args: argparse.Namespace) -> None:
from dynamo.llm import FpmEventSubscriber
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
event_plane = os.environ.get("DYN_EVENT_PLANE", "nats") event_plane = os.environ.get("DYN_EVENT_PLANE", "nats")
enable_nats = args.request_plane == "nats" or event_plane == "nats" enable_nats = args.request_plane == "nats" or event_plane == "nats"
...@@ -77,6 +126,9 @@ async def run(args: argparse.Namespace) -> None: ...@@ -77,6 +126,9 @@ async def run(args: argparse.Namespace) -> None:
args.component, args.component,
) )
history: list[tuple[float, ForwardPassMetrics]] = []
start_time: float | None = None
try: try:
while True: while True:
data = await asyncio.to_thread(subscriber.recv) data = await asyncio.to_thread(subscriber.recv)
...@@ -86,6 +138,14 @@ async def run(args: argparse.Namespace) -> None: ...@@ -86,6 +138,14 @@ async def run(args: argparse.Namespace) -> None:
metrics = decode(data) metrics = decode(data)
if metrics is None: if metrics is None:
continue continue
now = time.monotonic()
if start_time is None:
start_time = now
if args.save_plot:
history.append((now - start_time, metrics))
pretty = json.loads(json_encoder.encode(metrics)) pretty = json.loads(json_encoder.encode(metrics))
logger.info( logger.info(
"[worker=%s dp=%d counter=%d] %s", "[worker=%s dp=%d counter=%d] %s",
...@@ -98,6 +158,8 @@ async def run(args: argparse.Namespace) -> None: ...@@ -98,6 +158,8 @@ async def run(args: argparse.Namespace) -> None:
logger.info("Stopped.") logger.info("Stopped.")
finally: finally:
subscriber.shutdown() subscriber.shutdown()
if args.save_plot and history:
_save_plot(args.save_plot, history)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -2,11 +2,73 @@ ...@@ -2,11 +2,73 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" """
InstrumentedScheduler -- vLLM Scheduler subclass that emits InstrumentedScheduler -- vLLM AsyncScheduler subclass that emits
ForwardPassMetrics over ZMQ PUB on every iteration. ForwardPassMetrics over ZMQ PUB on every forward pass completion.
Scheduling modes
----------------
vLLM's EngineCore has two execution modes selected at startup:
* **Sync** (``batch_queue`` is None, uses ``EngineCore.step``):
``schedule() -> execute_model() [blocking] -> update_from_output()``
One schedule per forward pass, CPU blocks while GPU runs.
* **Async** (``batch_queue_size=2``, uses ``step_with_batch_queue``):
The engine overlaps scheduling with GPU execution to hide CPU overhead.
``schedule(N)`` is called and the batch is submitted, then the engine
returns early. On the next loop iteration ``schedule(N+1)`` runs
(while the GPU is still processing batch N), then the engine blocks
until batch N completes and calls ``update_from_output(N)``.
This means ``schedule()`` is called **twice** before the first
``update_from_output()``.
``AsyncScheduler`` handles this by adding *output placeholders* in
``_update_after_schedule()``: ``num_output_placeholders += 1`` keeps
``num_new_tokens == 1`` for every running request, so the next
``schedule()`` can schedule all requests again without waiting for
the sampled token from ``update_from_output()``.
Why we extend AsyncScheduler (not Scheduler)
---------------------------------------------
vLLM's ``--scheduler-cls`` only accepts a single class; it does not
auto-select between ``Scheduler`` and ``AsyncScheduler`` based on the
engine mode. We extend ``AsyncScheduler`` because:
1. If we extended ``Scheduler`` (without placeholders), the second
``schedule()`` call in async mode would see ``num_new_tokens == 0``
for all requests already advanced by ``_update_after_schedule``,
producing partial batches (e.g. 22/28 split of 50 requests) with
incorrect per-batch ``sum_decode_kv_tokens`` and other metrics.
2. ``AsyncScheduler`` is a thin wrapper (adds placeholders in
``_update_after_schedule`` and decrements them in
``_update_request_with_output``). The placeholder logic is
harmless in sync mode: placeholders are added and immediately
consumed within the same step (``0 -> 1 -> 0`` per iteration).
3. A single subclass that works correctly in both sync and async
engine modes avoids the need for mode detection or two classes.
How metrics are measured
------------------------
* **Emission point**: ``update_from_output()``, called once per
completed GPU forward pass (after the engine pops the batch result).
Empty batches (``total_num_scheduled_tokens == 0``) are skipped.
* **scheduled_requests**: extracted from the ``SchedulerOutput``
parameter passed to ``update_from_output`` (the EngineCore always
passes the correct output for the batch being processed, even in
async mode where multiple batches are in flight).
* **queued_requests**: computed from ``self.waiting`` at emit time.
* **wall_time**: approximates the schedule-to-update_from_output
latency described in ``ForwardPassMetrics``. Measured as the time
between consecutive ``update_from_output()`` calls. This works
because the EngineCore always blocks on ``future.result()`` (the
GPU forward pass) right before calling ``update_from_output``, so
the interval is dominated by GPU compute. Assumption: CPU overhead
(scheduling + output processing) between consecutive calls is small
relative to GPU forward pass time. ``wall_time`` is ``0.0`` for
the first message after engine idle and for heartbeats.
The scheduler thread does a single-pass accumulation (count, sum,
sum_of_squares) and produces a final ForwardPassMetrics struct.
Serialization and ZMQ send are handled by a background thread Serialization and ZMQ send are handled by a background thread
(same approach as vLLM's ZmqEventPublisher) so the scheduler (same approach as vLLM's ZmqEventPublisher) so the scheduler
hot path only pays for accumulation + queue.put(). hot path only pays for accumulation + queue.put().
...@@ -27,8 +89,8 @@ from typing import TYPE_CHECKING ...@@ -27,8 +89,8 @@ from typing import TYPE_CHECKING
import msgspec.structs import msgspec.structs
import zmq import zmq
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import RequestStatus from vllm.v1.request import RequestStatus
from dynamo.common.forward_pass_metrics import ( from dynamo.common.forward_pass_metrics import (
...@@ -148,7 +210,7 @@ class _FpmPublisherThread: ...@@ -148,7 +210,7 @@ class _FpmPublisherThread:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class InstrumentedScheduler(Scheduler): class InstrumentedScheduler(AsyncScheduler):
def __init__( def __init__(
self, self,
vllm_config: "VllmConfig", vllm_config: "VllmConfig",
...@@ -169,9 +231,7 @@ class InstrumentedScheduler(Scheduler): ...@@ -169,9 +231,7 @@ class InstrumentedScheduler(Scheduler):
self._fpm_worker_id = vllm_config.additional_config.get("fpm_worker_id", "") self._fpm_worker_id = vllm_config.additional_config.get("fpm_worker_id", "")
self._fpm_dp_rank = dp_rank self._fpm_dp_rank = dp_rank
self._schedule_time: float = 0.0 self._last_update_time: float = 0.0
self._pending_output: SchedulerOutput | None = None
self._pending_queued: QueuedRequestMetrics | None = None
self._prompt_len_per_req: dict[str, int] = {} self._prompt_len_per_req: dict[str, int] = {}
base_port = int(os.environ.get(ENV_FPM_PORT, str(DEFAULT_FPM_PORT))) base_port = int(os.environ.get(ENV_FPM_PORT, str(DEFAULT_FPM_PORT)))
...@@ -198,38 +258,28 @@ class InstrumentedScheduler(Scheduler): ...@@ -198,38 +258,28 @@ class InstrumentedScheduler(Scheduler):
self._publisher.shutdown() self._publisher.shutdown()
super().shutdown() super().shutdown()
def schedule(self) -> SchedulerOutput:
self._schedule_time = time.monotonic()
output = super().schedule()
self._pending_output = output
self._pending_queued = self._compute_queued()
return output
def update_from_output( def update_from_output(
self, self,
scheduler_output: SchedulerOutput, scheduler_output: SchedulerOutput,
model_runner_output: "ModelRunnerOutput", model_runner_output: "ModelRunnerOutput",
): ):
result = super().update_from_output(scheduler_output, model_runner_output) result = super().update_from_output(scheduler_output, model_runner_output)
now = time.monotonic()
wall_time = time.monotonic() - self._schedule_time if scheduler_output.total_num_scheduled_tokens > 0:
wall_time = (
now - self._last_update_time if self._last_update_time > 0 else 0.0
)
self._last_update_time = now
if self._pending_output is not None:
metrics = self._extract_metrics( metrics = self._extract_metrics(
self._pending_output, scheduler_output, self._compute_queued(), wall_time
self._pending_queued,
wall_time,
) )
self._publisher.publish(metrics) self._publisher.publish(metrics)
else:
self._pending_output = None self._last_update_time = 0.0
self._pending_queued = None
self._cleanup_finished(scheduler_output) self._cleanup_finished(scheduler_output)
return result return result
# ------------------------------------------------------------------ # ------------------------------------------------------------------
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment