Unverified Commit 4e6c3964 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

feat(mocker): emit forward pass metrics to event plane for planner consumption (#8032)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
Signed-off-by: default avatarjthomson04 <jothomson@nvidia.com>
Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 1142478a
......@@ -11,10 +11,16 @@
//! [`crate::kv_router::publisher::KvEventPublisher`], but is much simpler:
//! no event transformation, no batching, no local indexer — just raw byte relay.
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use futures::StreamExt;
use serde::Serialize;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use dynamo_mocker::common::protocols::{ForwardPassSnapshot, FpmPublisher, FpmSink};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventPublisher;
......@@ -22,6 +28,9 @@ use dynamo_runtime::transports::event_plane::EventPublisher;
use crate::utils::zmq::{connect_sub_socket, multipart_message};
const FPM_TOPIC: &str = "forward-pass-metrics";
const FPM_VERSION: i32 = 1;
/// Matches Python `_FpmPublisherThread.HEARTBEAT_INTERVAL`.
const IDLE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(1);
/// A relay that bridges ForwardPassMetrics from a local raw ZMQ PUB socket
/// to the Dynamo event plane.
......@@ -114,3 +123,588 @@ impl Drop for FpmEventRelay {
self.cancel.cancel();
}
}
// ---------------------------------------------------------------------------
// Direct publisher: mocker scheduler -> event plane (no ZMQ hop)
// ---------------------------------------------------------------------------
/// Serialization struct matching Python `ScheduledRequestMetrics` from
/// `forward_pass_metrics.py`. Field names must match exactly since msgspec
/// (without `array_like=True`) encodes structs as msgpack **maps** with
/// string keys.
#[derive(Serialize)]
struct ScheduledRequestMetricsSer {
num_prefill_requests: i32,
sum_prefill_tokens: i64,
var_prefill_length: f64,
sum_prefill_kv_tokens: i64,
num_decode_requests: i32,
sum_decode_kv_tokens: i64,
var_decode_kv_tokens: f64,
}
/// Serialization struct matching Python `QueuedRequestMetrics`.
#[derive(Serialize)]
struct QueuedRequestMetricsSer {
num_prefill_requests: i32,
sum_prefill_tokens: i64,
var_prefill_length: f64,
num_decode_requests: i32,
sum_decode_kv_tokens: i64,
var_decode_kv_tokens: f64,
}
/// Top-level serialization struct matching Python `ForwardPassMetrics`.
#[derive(Serialize)]
struct ForwardPassMetricsSer {
version: i32,
worker_id: String,
dp_rank: i64,
counter_id: i64,
wall_time: f64,
scheduled_requests: ScheduledRequestMetricsSer,
queued_requests: QueuedRequestMetricsSer,
}
fn serialize_fpm(
snapshot: &ForwardPassSnapshot,
worker_id: &str,
dp_rank: u32,
counter_id: i64,
) -> Result<Vec<u8>> {
let metrics = ForwardPassMetricsSer {
version: FPM_VERSION,
worker_id: worker_id.to_owned(),
dp_rank: dp_rank as i64,
counter_id,
wall_time: snapshot.wall_time_secs,
scheduled_requests: ScheduledRequestMetricsSer {
num_prefill_requests: snapshot.num_prefill_requests as i32,
sum_prefill_tokens: snapshot.sum_prefill_tokens as i64,
var_prefill_length: snapshot.var_prefill_length,
sum_prefill_kv_tokens: snapshot.sum_prefill_kv_tokens as i64,
num_decode_requests: snapshot.num_decode_requests as i32,
sum_decode_kv_tokens: snapshot.sum_decode_kv_tokens as i64,
var_decode_kv_tokens: snapshot.var_decode_kv_tokens,
},
queued_requests: QueuedRequestMetricsSer {
num_prefill_requests: snapshot.num_queued_prefill as i32,
sum_prefill_tokens: snapshot.sum_queued_prefill_tokens as i64,
var_prefill_length: snapshot.var_queued_prefill_length,
num_decode_requests: snapshot.num_queued_decode as i32,
sum_decode_kv_tokens: snapshot.sum_queued_decode_kv_tokens as i64,
var_decode_kv_tokens: snapshot.var_queued_decode_kv_tokens,
},
};
rmp_serde::to_vec_named(&metrics).map_err(|e| anyhow::anyhow!("FPM serialization failed: {e}"))
}
/// Live FPM sink that forwards snapshots to the `FpmDirectPublisher`'s
/// internal serialization pipeline via an mpsc channel.
struct LiveFpmSink {
tx: mpsc::UnboundedSender<ForwardPassSnapshot>,
}
impl FpmSink for LiveFpmSink {
fn publish(&self, snapshot: ForwardPassSnapshot) -> Result<()> {
self.tx
.send(snapshot)
.map_err(|_| anyhow::anyhow!("FPM publisher channel closed"))
}
}
/// Direct FPM publisher for the mocker engine.
///
/// Unlike [`FpmEventRelay`] (which bridges raw ZMQ from a forked vLLM child
/// process), this publishes [`ForwardPassSnapshot`] data directly to the
/// event plane from in-process mocker schedulers.
pub struct FpmDirectPublisher {
cancel: CancellationToken,
}
impl FpmDirectPublisher {
/// Create and start a new direct publisher, returning per-dp-rank sink handles.
///
/// Each returned [`FpmPublisher`] wraps a sink that feeds the shared
/// serialization + event-plane publish pipeline. The scheduler passes
/// one to each engine via the deferred-sink model.
///
/// - `component`: Dynamo component (provides runtime + discovery scope).
/// - `worker_id`: Unique worker identifier (typically `connection_id().to_string()`).
/// - `dp_size`: Number of data-parallel ranks.
pub async fn new(
component: Component,
worker_id: String,
dp_size: u32,
) -> Result<(Self, Vec<FpmPublisher>)> {
let rt = component.drt().runtime().secondary();
let cancel = CancellationToken::new();
let publisher = EventPublisher::for_component(&component, FPM_TOPIC).await?;
// Shared channel: per-dp_rank serialization tasks send bytes here,
// a single publisher task writes them to the event plane.
let (pub_tx, mut pub_rx) = mpsc::unbounded_channel::<Vec<u8>>();
// Publisher task
let cancel_pub = cancel.clone();
rt.spawn(async move {
loop {
tokio::select! {
biased;
_ = cancel_pub.cancelled() => break,
result = pub_rx.recv() => {
match result {
Some(payload) => {
if let Err(e) = publisher.publish_bytes(payload).await {
tracing::warn!("FPM direct publisher: event plane publish failed: {e}");
}
}
None => break,
}
}
}
}
tracing::info!("FPM direct publisher: shutting down");
});
// Per-dp_rank: create internal channels, return sink handles.
//
// Each task forwards active-pass snapshots and emits periodic idle
// heartbeats (zeroed snapshot, wall_time=0.0) when the scheduler is
// idle, matching the Python `_FpmPublisherThread` contract.
let mut fpm_publishers = Vec::with_capacity(dp_size as usize);
for dp_rank in 0..dp_size {
let (fpm_tx, mut fpm_rx) = mpsc::unbounded_channel();
let sink = Arc::new(LiveFpmSink { tx: fpm_tx }) as Arc<dyn FpmSink>;
fpm_publishers.push(FpmPublisher::new(Some(sink)));
let pub_tx = pub_tx.clone();
let worker_id = worker_id.clone();
let cancel_ser = cancel.clone();
rt.spawn(async move {
let mut counter: i64 = 0;
let heartbeat_sleep = tokio::time::sleep(IDLE_HEARTBEAT_INTERVAL);
tokio::pin!(heartbeat_sleep);
loop {
let snapshot = tokio::select! {
biased;
_ = cancel_ser.cancelled() => break,
result = fpm_rx.recv() => {
match result {
Some(snapshot) => {
// Active pass — reset the heartbeat timer.
heartbeat_sleep
.as_mut()
.reset(tokio::time::Instant::now() + IDLE_HEARTBEAT_INTERVAL);
snapshot
}
None => break,
}
}
_ = &mut heartbeat_sleep => {
// No snapshot for IDLE_HEARTBEAT_INTERVAL — emit
// zeroed idle heartbeat, then reset for the next
// interval.
heartbeat_sleep
.as_mut()
.reset(tokio::time::Instant::now() + IDLE_HEARTBEAT_INTERVAL);
ForwardPassSnapshot::default()
}
};
counter += 1;
match serialize_fpm(&snapshot, &worker_id, dp_rank, counter) {
Ok(bytes) => {
let _ = pub_tx.send(bytes);
}
Err(e) => {
tracing::warn!(
"FPM serialization failed for dp_rank {dp_rank}: {e}"
);
}
}
}
});
}
tracing::info!(
worker_id = %worker_id,
"FPM direct publisher started"
);
Ok((Self { cancel }, fpm_publishers))
}
pub fn shutdown(&self) {
self.cancel.cancel();
}
}
impl Drop for FpmDirectPublisher {
fn drop(&mut self) {
self.cancel.cancel();
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde::Deserialize;
use std::collections::HashMap;
/// Verify that serialize_fpm produces valid msgpack that round-trips
/// through deserialization with the exact field names and values
/// expected by the Python `ForwardPassMetrics` schema.
#[test]
fn test_serialize_fpm_round_trip() {
let snapshot = ForwardPassSnapshot {
num_prefill_requests: 2,
sum_prefill_tokens: 256,
var_prefill_length: 100.0,
sum_prefill_kv_tokens: 64,
num_decode_requests: 3,
sum_decode_kv_tokens: 1024,
var_decode_kv_tokens: 50.0,
num_queued_prefill: 1,
sum_queued_prefill_tokens: 128,
var_queued_prefill_length: 0.0,
num_queued_decode: 0,
sum_queued_decode_kv_tokens: 0,
var_queued_decode_kv_tokens: 0.0,
wall_time_secs: 0.025,
};
let bytes = serialize_fpm(&snapshot, "worker-abc", 2, 42).unwrap();
// Deserialize with matching struct (Deserialize derived) to verify
// the wire format round-trips correctly.
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct ScheduledDe {
num_prefill_requests: i32,
sum_prefill_tokens: i64,
var_prefill_length: f64,
sum_prefill_kv_tokens: i64,
num_decode_requests: i32,
sum_decode_kv_tokens: i64,
var_decode_kv_tokens: f64,
}
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct QueuedDe {
num_prefill_requests: i32,
sum_prefill_tokens: i64,
var_prefill_length: f64,
num_decode_requests: i32,
sum_decode_kv_tokens: i64,
var_decode_kv_tokens: f64,
}
#[derive(Deserialize, Debug)]
#[allow(dead_code)]
struct FpmDe {
version: i32,
worker_id: String,
dp_rank: i64,
counter_id: i64,
wall_time: f64,
scheduled_requests: ScheduledDe,
queued_requests: QueuedDe,
}
let decoded: FpmDe = rmp_serde::from_slice(&bytes).expect("round-trip decode failed");
assert_eq!(decoded.version, 1);
assert_eq!(decoded.worker_id, "worker-abc");
assert_eq!(decoded.dp_rank, 2);
assert_eq!(decoded.counter_id, 42);
assert!((decoded.wall_time - 0.025).abs() < 1e-10);
assert_eq!(decoded.scheduled_requests.num_prefill_requests, 2);
assert_eq!(decoded.scheduled_requests.sum_prefill_tokens, 256);
assert!((decoded.scheduled_requests.var_prefill_length - 100.0).abs() < 1e-10);
assert_eq!(decoded.scheduled_requests.sum_prefill_kv_tokens, 64);
assert_eq!(decoded.scheduled_requests.num_decode_requests, 3);
assert_eq!(decoded.scheduled_requests.sum_decode_kv_tokens, 1024);
assert_eq!(decoded.queued_requests.num_prefill_requests, 1);
assert_eq!(decoded.queued_requests.sum_prefill_tokens, 128);
assert_eq!(decoded.queued_requests.num_decode_requests, 0);
}
/// Verify that worker_id and dp_rank can be extracted from the serialized
/// bytes by deserializing into a flat HashMap, simulating the subscriber's
/// `extract_fpm_key` approach of scanning the msgpack map for specific keys.
#[test]
fn test_serialize_fpm_extractable_key() {
let snapshot = ForwardPassSnapshot {
num_prefill_requests: 1,
sum_prefill_tokens: 100,
wall_time_secs: 0.01,
..Default::default()
};
let bytes = serialize_fpm(&snapshot, "my-worker-id", 7, 99).unwrap();
// Deserialize only the top-level flat fields (nested maps become
// opaque), matching the subscriber's partial-decode approach.
#[derive(Deserialize)]
struct PartialFpm {
worker_id: String,
dp_rank: i64,
}
let partial: PartialFpm = rmp_serde::from_slice(&bytes).expect("partial decode failed");
assert_eq!(partial.worker_id, "my-worker-id");
assert_eq!(partial.dp_rank, 7);
}
/// Verify that the idle heartbeat fires when no FPM arrives within
/// IDLE_HEARTBEAT_INTERVAL. We replicate the per-dp_rank serialization
/// task logic with real channels to test the timeout behavior.
#[tokio::test]
async fn test_idle_heartbeat_emits_zeroed_snapshot() {
let (fpm_tx, mut fpm_rx) = mpsc::unbounded_channel::<ForwardPassSnapshot>();
let (pub_tx, mut pub_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let cancel = CancellationToken::new();
let cancel_clone = cancel.clone();
let worker_id = "test-worker".to_string();
let dp_rank: u32 = 0;
// Spawn the same task logic as FpmDirectPublisher
tokio::spawn(async move {
let mut counter: i64 = 0;
let heartbeat_sleep = tokio::time::sleep(IDLE_HEARTBEAT_INTERVAL);
tokio::pin!(heartbeat_sleep);
loop {
let snapshot = tokio::select! {
biased;
_ = cancel_clone.cancelled() => break,
result = fpm_rx.recv() => {
match result {
Some(snapshot) => {
heartbeat_sleep
.as_mut()
.reset(tokio::time::Instant::now() + IDLE_HEARTBEAT_INTERVAL);
snapshot
}
None => break,
}
}
_ = &mut heartbeat_sleep => {
heartbeat_sleep
.as_mut()
.reset(tokio::time::Instant::now() + IDLE_HEARTBEAT_INTERVAL);
ForwardPassSnapshot::default()
}
};
counter += 1;
if let Ok(bytes) = serialize_fpm(&snapshot, &worker_id, dp_rank, counter) {
let _ = pub_tx.send(bytes);
}
}
});
// 1) Send an active snapshot first
let active = ForwardPassSnapshot {
num_prefill_requests: 2,
sum_prefill_tokens: 100,
wall_time_secs: 0.05,
..Default::default()
};
fpm_tx.send(active).unwrap();
// Receive the active snapshot
let bytes = tokio::time::timeout(Duration::from_secs(2), pub_rx.recv())
.await
.expect("timed out waiting for active FPM")
.expect("channel closed");
#[derive(Deserialize)]
struct FpmWallTime {
wall_time: f64,
}
let decoded: FpmWallTime = rmp_serde::from_slice(&bytes).expect("active FPM decode failed");
assert!(
decoded.wall_time > 0.0,
"active snapshot should have wall_time > 0"
);
// 2) Now wait for the idle heartbeat (should arrive within ~1s)
let heartbeat_bytes = tokio::time::timeout(Duration::from_secs(3), pub_rx.recv())
.await
.expect("timed out waiting for idle heartbeat")
.expect("channel closed");
#[derive(Deserialize)]
#[allow(dead_code)]
struct HeartbeatDe {
wall_time: f64,
counter_id: i64,
worker_id: String,
}
let heartbeat: HeartbeatDe =
rmp_serde::from_slice(&heartbeat_bytes).expect("heartbeat decode failed");
assert_eq!(
heartbeat.wall_time, 0.0,
"idle heartbeat should have wall_time=0.0"
);
assert_eq!(heartbeat.counter_id, 2, "heartbeat is the second message");
assert_eq!(heartbeat.worker_id, "test-worker");
cancel.cancel();
}
/// Verify that active snapshots reset the heartbeat timer so heartbeats
/// only fire after a period of true inactivity.
#[tokio::test]
async fn test_active_snapshots_suppress_heartbeat() {
let (fpm_tx, mut fpm_rx) = mpsc::unbounded_channel::<ForwardPassSnapshot>();
let (pub_tx, mut pub_rx) = mpsc::unbounded_channel::<Vec<u8>>();
let cancel = CancellationToken::new();
let cancel_clone = cancel.clone();
tokio::spawn(async move {
let mut counter: i64 = 0;
let heartbeat_sleep = tokio::time::sleep(IDLE_HEARTBEAT_INTERVAL);
tokio::pin!(heartbeat_sleep);
loop {
let snapshot = tokio::select! {
biased;
_ = cancel_clone.cancelled() => break,
result = fpm_rx.recv() => {
match result {
Some(snapshot) => {
heartbeat_sleep
.as_mut()
.reset(tokio::time::Instant::now() + IDLE_HEARTBEAT_INTERVAL);
snapshot
}
None => break,
}
}
_ = &mut heartbeat_sleep => {
heartbeat_sleep
.as_mut()
.reset(tokio::time::Instant::now() + IDLE_HEARTBEAT_INTERVAL);
ForwardPassSnapshot::default()
}
};
counter += 1;
if let Ok(bytes) = serialize_fpm(&snapshot, "w", 0, counter) {
let _ = pub_tx.send(bytes);
}
}
});
// Send active snapshots every 500ms for 2 seconds — heartbeat should
// NOT fire during this time since each send resets the timer.
for _ in 0..4 {
let active = ForwardPassSnapshot {
num_decode_requests: 1,
wall_time_secs: 0.01,
..Default::default()
};
fpm_tx.send(active).unwrap();
tokio::time::sleep(Duration::from_millis(500)).await;
}
// Drain all active snapshots
let mut active_count = 0;
while let Ok(Some(bytes)) =
tokio::time::timeout(Duration::from_millis(100), pub_rx.recv()).await
{
#[derive(Deserialize)]
struct Wt {
wall_time: f64,
}
let wt: Wt = rmp_serde::from_slice(&bytes).unwrap();
assert!(
wt.wall_time > 0.0,
"all messages during active period should have wall_time > 0"
);
active_count += 1;
}
assert_eq!(
active_count, 4,
"should have received exactly 4 active snapshots"
);
// Now wait for the heartbeat (should fire ~1s after last active send)
let heartbeat_bytes = tokio::time::timeout(Duration::from_secs(3), pub_rx.recv())
.await
.expect("timed out waiting for heartbeat after active period")
.expect("channel closed");
#[derive(Deserialize)]
struct Wt2 {
wall_time: f64,
}
let hb: Wt2 = rmp_serde::from_slice(&heartbeat_bytes).unwrap();
assert_eq!(hb.wall_time, 0.0, "heartbeat should have wall_time=0.0");
cancel.cancel();
}
/// Verify all 7 expected field names appear in scheduled_requests and
/// 6 in queued_requests — matching the Python schema exactly.
#[test]
fn test_serialize_fpm_field_names() {
let snapshot = ForwardPassSnapshot::default();
let bytes = serialize_fpm(&snapshot, "", 0, 0).unwrap();
// Deserialize the whole thing as nested HashMaps to inspect field names
#[derive(Deserialize)]
struct Wrapper {
scheduled_requests: HashMap<String, serde_json::Value>,
queued_requests: HashMap<String, serde_json::Value>,
}
let w: Wrapper = rmp_serde::from_slice(&bytes).expect("decode failed");
let expected_sched = [
"num_prefill_requests",
"sum_prefill_tokens",
"var_prefill_length",
"sum_prefill_kv_tokens",
"num_decode_requests",
"sum_decode_kv_tokens",
"var_decode_kv_tokens",
];
for key in &expected_sched {
assert!(
w.scheduled_requests.contains_key(*key),
"scheduled_requests missing field: {key}"
);
}
assert_eq!(
w.scheduled_requests.len(),
expected_sched.len(),
"scheduled_requests has unexpected extra fields"
);
let expected_queued = [
"num_prefill_requests",
"sum_prefill_tokens",
"var_prefill_length",
"num_decode_requests",
"sum_decode_kv_tokens",
"var_decode_kv_tokens",
];
for key in &expected_queued {
assert!(
w.queued_requests.contains_key(*key),
"queued_requests missing field: {key}"
);
}
assert_eq!(
w.queued_requests.len(),
expected_queued.len(),
"queued_requests has unexpected extra fields"
);
}
}
......@@ -303,6 +303,8 @@ pub struct MockEngine {
bootstrap_server: Arc<OnceCell<Arc<BootstrapServer>>>,
/// Keep schedulers alive so their CancelGuards don't fire prematurely.
_schedulers: OnceCell<Vec<Box<dyn SchedulerHandle>>>,
/// Forward pass metrics publisher (kept alive for the engine lifetime).
_fpm_publisher: OnceCell<crate::fpm_publisher::FpmDirectPublisher>,
}
impl MockEngine {
......@@ -316,6 +318,7 @@ impl MockEngine {
unset_dp_rank_counter: AtomicU32::new(0),
bootstrap_server: Arc::new(OnceCell::new()),
_schedulers: OnceCell::new(),
_fpm_publisher: OnceCell::new(),
}
}
......@@ -361,11 +364,33 @@ impl MockEngine {
None
};
// Create FPM publisher upfront and get per-dp-rank sink handles.
let worker_id = component.drt().connection_id().to_string();
let fpm_sinks = match crate::fpm_publisher::FpmDirectPublisher::new(
component.clone(),
worker_id,
self.engine_args.dp_size,
)
.await
{
Ok((publisher, sinks)) => {
let _ = self._fpm_publisher.set(publisher);
sinks
}
Err(e) => {
tracing::error!("Failed to start FPM publisher: {e}");
(0..self.engine_args.dp_size)
.map(|_| dynamo_mocker::common::protocols::FpmPublisher::default())
.collect()
}
};
let schedulers = self
.start_schedulers(kv_component, cancel_token.clone())
.start_schedulers(kv_component, cancel_token.clone(), fpm_sinks)
.await;
Self::start_metrics_publishing(&schedulers, component, cancel_token.clone()).await?;
Self::start_metrics_publishing(&schedulers, component.clone(), cancel_token.clone())
.await?;
let _ = self._schedulers.set(schedulers);
......@@ -395,17 +420,18 @@ impl MockEngine {
let _ = senders[dp_rank].send(request);
}
/// Create schedulers and spawn their background tasks for distributing token notifications
/// Create schedulers and spawn their background tasks for distributing token notifications.
async fn start_schedulers(
&self,
component: Option<&Component>,
cancel_token: CancellationToken,
fpm_sinks: Vec<dynamo_mocker::common::protocols::FpmPublisher>,
) -> Vec<Box<dyn SchedulerHandle>> {
let args = &self.engine_args;
let mut schedulers = Vec::<Box<dyn SchedulerHandle>>::new();
let mut senders = Vec::with_capacity(args.dp_size as usize);
for dp_rank in 0..args.dp_size {
for (dp_rank, fpm_publisher) in (0..args.dp_size).zip(fpm_sinks) {
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let (kv_event_publishers, relay_publisher): (
......@@ -493,6 +519,7 @@ impl MockEngine {
Some(output_tx),
kv_event_publishers,
Some(cancel_token.clone()),
fpm_publisher,
);
senders.push(scheduler.request_sender());
......
......@@ -80,6 +80,59 @@ impl KvEventPublishers {
}
}
/// Per-iteration forward pass snapshot, mirroring the Python `ForwardPassMetrics`
/// schema in `components/src/dynamo/common/forward_pass_metrics.py`.
///
/// Produced by the scheduler core after each `execute_pass_internal()` call.
/// The runtime-dependent layer (`lib/llm`) wraps this with identity fields
/// (worker_id, dp_rank, counter_id) and serializes to msgpack for the event plane.
#[derive(Debug, Clone, Default)]
pub struct ForwardPassSnapshot {
// -- scheduled requests (executed this iteration) --
pub num_prefill_requests: u32,
pub sum_prefill_tokens: u64,
pub var_prefill_length: f64,
pub sum_prefill_kv_tokens: u64,
pub num_decode_requests: u32,
pub sum_decode_kv_tokens: u64,
pub var_decode_kv_tokens: f64,
// -- queued requests (waiting, not scheduled) --
pub num_queued_prefill: u32,
pub sum_queued_prefill_tokens: u64,
pub var_queued_prefill_length: f64,
pub num_queued_decode: u32,
pub sum_queued_decode_kv_tokens: u64,
pub var_queued_decode_kv_tokens: f64,
// -- timing --
pub wall_time_secs: f64,
}
/// Trait for publishing forward pass metrics snapshots.
/// This abstracts the FPM publishing pipeline so mocker schedulers remain generic.
pub trait FpmSink: Send + Sync {
fn publish(&self, snapshot: ForwardPassSnapshot) -> anyhow::Result<()>;
}
/// Optional FPM sink used by schedulers.
/// Wraps `Option<Arc<dyn FpmSink>>` for ergonomic passing and no-op default behavior.
#[derive(Clone, Default)]
pub struct FpmPublisher {
sink: Option<Arc<dyn FpmSink>>,
}
impl FpmPublisher {
pub fn new(sink: Option<Arc<dyn FpmSink>>) -> Self {
Self { sink }
}
pub fn publish(&self, snapshot: ForwardPassSnapshot) -> anyhow::Result<()> {
if let Some(sink) = &self.sink {
sink.publish(snapshot)?;
}
Ok(())
}
}
pub type NumBlocks = usize;
/// Represents different block movement operations in the cache
......
......@@ -6,7 +6,9 @@
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{EngineType, KvEventPublishers, MockEngineArgs, OutputSignal};
use crate::common::protocols::{
EngineType, FpmPublisher, KvEventPublishers, MockEngineArgs, OutputSignal,
};
use crate::scheduler::{Scheduler, SchedulerHandle, SglangScheduler};
/// Create a scheduler for the configured engine type.
......@@ -19,6 +21,7 @@ pub fn create_engine(
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
fpm_publisher: FpmPublisher,
) -> Box<dyn SchedulerHandle> {
match args.engine_type {
EngineType::Vllm => Box::new(Scheduler::new(
......@@ -27,6 +30,7 @@ pub fn create_engine(
output_tx,
kv_event_publishers,
cancellation_token,
fpm_publisher,
)),
EngineType::Sglang => Box::new(SglangScheduler::new(
args,
......@@ -34,6 +38,7 @@ pub fn create_engine(
output_tx,
kv_event_publishers,
cancellation_token,
fpm_publisher,
)),
}
}
......@@ -11,7 +11,7 @@ use tokio::task::JoinSet;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, MockEngineArgs, OutputSignal};
use crate::common::protocols::{DirectRequest, FpmPublisher, MockEngineArgs, OutputSignal};
use crate::loadgen::WorkloadDriver;
use crate::replay::{ReplayPrefillLoadEstimator, ReplayRouterMode, TraceSimulationReport};
use crate::scheduler::{AdmissionEvent, EngineScheduler, SchedulerHandle};
......@@ -68,6 +68,7 @@ impl LiveRuntime {
router.sink(worker_idx as _),
Some(cancel_token.clone()),
Some(admission_tx.clone()),
FpmPublisher::default(),
);
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
......
......@@ -6,7 +6,10 @@ use std::sync::{Arc, Mutex};
use anyhow::Result;
use dynamo_kv_router::protocols::{KvCacheEvent, RouterEvent, WorkerId};
use crate::common::protocols::{KvCacheEventSink, KvEventPublishers, RawKvEvent, RawKvEventSink};
use crate::common::protocols::{
ForwardPassSnapshot, FpmPublisher, KvCacheEventSink, KvEventPublishers, RawKvEvent,
RawKvEventSink,
};
/// Captures router-ready events for offline replay and scheduler tests.
///
......@@ -156,3 +159,30 @@ pub(crate) fn publish_deferred_kv_events(
}
}
}
/// Captures FPM snapshots for the live scheduler so it can flush them at the
/// correct pass phase, matching the deferred KV event pattern.
#[derive(Clone, Default)]
pub(crate) struct DeferredFpmBuffer {
snapshots: Arc<Mutex<Vec<ForwardPassSnapshot>>>,
}
impl DeferredFpmBuffer {
pub(crate) fn push(&self, snapshot: ForwardPassSnapshot) {
self.snapshots.lock().unwrap().push(snapshot);
}
pub(crate) fn drain(&self) -> Vec<ForwardPassSnapshot> {
std::mem::take(&mut *self.snapshots.lock().unwrap())
}
}
/// Forwards buffered FPM snapshots to the real sink once the pass reaches
/// the configured visibility point.
pub(crate) fn publish_deferred_fpm(sink: &FpmPublisher, snapshots: Vec<ForwardPassSnapshot>) {
for snapshot in snapshots {
if let Err(error) = sink.publish(snapshot) {
tracing::warn!("Failed to forward buffered FPM snapshot: {error}");
}
}
}
......@@ -8,16 +8,106 @@ mod kv_event_sink;
pub mod sglang;
pub mod vllm;
use crate::common::protocols::{DirectRequest, KvEventPublishers, OutputSignal};
pub use crate::common::protocols::ForwardPassSnapshot;
use crate::common::protocols::{DirectRequest, FpmPublisher, KvEventPublishers, OutputSignal};
use dynamo_kv_router::protocols::RouterEvent;
pub(crate) use kv_event_sink::{
CapturedRouterEventBuffer, capture_deferred_kv_publish_sink, capture_router_event_sink,
publish_deferred_kv_events,
CapturedRouterEventBuffer, DeferredFpmBuffer, capture_deferred_kv_publish_sink,
capture_router_event_sink, publish_deferred_fpm, publish_deferred_kv_events,
};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
/// Welford's online algorithm for count / sum / population-variance.
///
/// Mirrors the Python `WelfordAccumulator` in `forward_pass_metrics.py`.
#[derive(Default)]
pub(crate) struct WelfordAcc {
pub(crate) count: u32,
pub(crate) sum: f64,
mean: f64,
m2: f64,
}
impl WelfordAcc {
pub(crate) fn add(&mut self, v: f64) {
self.count += 1;
self.sum += v;
let delta = v - self.mean;
self.mean += delta / self.count as f64;
let delta2 = v - self.mean;
self.m2 += delta * delta2;
}
pub(crate) fn variance(&self) -> f64 {
if self.count == 0 {
return 0.0;
}
self.m2 / self.count as f64
}
}
/// Build a [`ForwardPassSnapshot`] from engine-agnostic iterators.
///
/// Each engine (vLLM, SGLang) calls this with its own iterators, avoiding
/// duplicated variance/accumulation logic.
///
/// - `scheduled_prefills`: `(prompt_len, prefix_tokens, tokens_computed)` per request
/// - `scheduled_decodes`: `sequence_len` per request
/// - `queued_prefills`: `prompt_len` per waiting prefill request
/// - `queued_decodes`: `kv_tokens` per preempted decode request
pub(crate) fn build_fpm_snapshot(
scheduled_prefills: impl Iterator<Item = (u64, u64, u64)>,
scheduled_decodes: impl Iterator<Item = u64>,
queued_prefills: impl Iterator<Item = u64>,
queued_decodes: impl Iterator<Item = u64>,
wall_time_secs: f64,
) -> ForwardPassSnapshot {
let mut prefill_acc = WelfordAcc::default();
let mut decode_acc = WelfordAcc::default();
let mut sum_prefill_tokens: u64 = 0;
let mut sum_prefill_kv_tokens: u64 = 0;
for (prompt_len, prefix_tokens, tokens_computed) in scheduled_prefills {
sum_prefill_tokens += tokens_computed;
sum_prefill_kv_tokens += prefix_tokens;
prefill_acc.add(prompt_len as f64);
}
for sequence_len in scheduled_decodes {
decode_acc.add(sequence_len as f64);
}
let mut queued_prefill_acc = WelfordAcc::default();
let mut queued_decode_acc = WelfordAcc::default();
for prompt_len in queued_prefills {
queued_prefill_acc.add(prompt_len as f64);
}
for kv_tokens in queued_decodes {
queued_decode_acc.add(kv_tokens as f64);
}
ForwardPassSnapshot {
num_prefill_requests: prefill_acc.count,
sum_prefill_tokens,
var_prefill_length: prefill_acc.variance(),
sum_prefill_kv_tokens,
num_decode_requests: decode_acc.count,
sum_decode_kv_tokens: decode_acc.sum as u64,
var_decode_kv_tokens: decode_acc.variance(),
num_queued_prefill: queued_prefill_acc.count,
sum_queued_prefill_tokens: queued_prefill_acc.sum as u64,
var_queued_prefill_length: queued_prefill_acc.variance(),
num_queued_decode: queued_decode_acc.count,
sum_queued_decode_kv_tokens: queued_decode_acc.sum as u64,
var_queued_decode_kv_tokens: queued_decode_acc.variance(),
wall_time_secs,
}
}
pub(crate) use sglang::SglangCore;
pub use sglang::SglangScheduler;
pub(crate) use vllm::VllmCore;
......@@ -41,6 +131,8 @@ pub(crate) struct EnginePassResult {
pub(crate) router_event_visibility: RouterEventVisibility,
/// Router-visible KV events emitted during this pass.
pub(crate) kv_events: Vec<RouterEvent>,
/// Forward pass metrics snapshot for this iteration.
pub(crate) fpm: Option<ForwardPassSnapshot>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
......@@ -112,6 +204,7 @@ impl EngineScheduler {
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
fpm_publisher: FpmPublisher,
) -> Self {
match args.engine_type {
crate::common::protocols::EngineType::Vllm => {
......@@ -122,6 +215,7 @@ impl EngineScheduler {
kv_event_publishers,
cancellation_token,
admission_tx,
fpm_publisher,
))
}
crate::common::protocols::EngineType::Sglang => {
......@@ -132,6 +226,7 @@ impl EngineScheduler {
kv_event_publishers,
cancellation_token,
admission_tx,
fpm_publisher,
))
}
}
......@@ -179,3 +274,58 @@ pub trait SchedulerHandle: Send + Sync {
/// Shared test utilities for scheduler stress tests.
#[cfg(test)]
pub(crate) mod test_utils;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn welford_acc_empty() {
let acc = WelfordAcc::default();
assert_eq!(acc.count, 0);
assert_eq!(acc.sum, 0.0);
assert_eq!(acc.variance(), 0.0);
}
#[test]
fn welford_acc_single_value() {
let mut acc = WelfordAcc::default();
acc.add(42.0);
assert_eq!(acc.count, 1);
assert_eq!(acc.sum, 42.0);
assert_eq!(acc.variance(), 0.0);
}
#[test]
fn welford_acc_population_variance() {
let mut acc = WelfordAcc::default();
// Values: 2, 4, 4, 4, 5, 5, 7, 9
// Mean = 5, Population variance = 4.0
for v in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0] {
acc.add(v);
}
assert_eq!(acc.count, 8);
assert_eq!(acc.sum, 40.0);
assert!((acc.variance() - 4.0).abs() < 1e-10);
}
#[test]
fn welford_acc_matches_python() {
// Reproduce the Python WelfordAccumulator behavior:
// values = [100, 200, 300], mean = 200,
// population variance = ((100-200)^2 + (200-200)^2 + (300-200)^2) / 3
// = (10000 + 0 + 10000) / 3 = 6666.666...
let mut acc = WelfordAcc::default();
acc.add(100.0);
acc.add(200.0);
acc.add(300.0);
assert_eq!(acc.count, 3);
assert_eq!(acc.sum, 600.0);
let expected = 20000.0 / 3.0;
assert!(
(acc.variance() - expected).abs() < 1e-10,
"expected {expected}, got {}",
acc.variance()
);
}
}
......@@ -17,7 +17,8 @@ use super::policy::apply_schedule_policy;
use super::prefill::get_new_batch_prefill;
use super::request::{SglangRequest, direct_to_sglang};
use crate::scheduler::{
CapturedRouterEventBuffer, EnginePassResult, RouterEventVisibility, capture_router_event_sink,
CapturedRouterEventBuffer, EnginePassResult, RouterEventVisibility, build_fpm_snapshot,
capture_router_event_sink,
};
pub(crate) struct SglangCore {
......@@ -130,6 +131,9 @@ impl SglangCore {
}
}
// Capture per-request prefill FPM data before dispersing can_run.
let prefill_fpm = admit.prefill_fpm;
let batch_size = admit.can_run.len();
let mean_isl = if batch_size > 0 {
admit.total_isl / batch_size
......@@ -153,6 +157,13 @@ impl SglangCore {
}
}
// Capture scheduled decode data before the decode step modifies running.
let scheduled_decode_lens: Vec<u64> = self
.running
.iter()
.map(|req| req.current_sequence_len() as u64)
.collect();
let decode_start_ms = now_ms + prefill_time.as_secs_f64() * 1000.0;
let mut decode = simulate_decode_step(
&mut self.running,
......@@ -178,6 +189,27 @@ impl SglangCore {
self.new_token_ratio = (self.new_token_ratio - self.config.new_token_ratio_decay_step)
.max(self.config.min_new_token_ratio);
// Build FPM snapshot now that all state has settled.
let fpm = build_fpm_snapshot(
prefill_fpm.iter().map(|p| {
(
p.prompt_len as u64,
p.prefix_tokens as u64,
p.tokens_computed as u64,
)
}),
scheduled_decode_lens.into_iter(),
self.waiting
.iter()
.filter(|req| req.output_len() == 0)
.map(|req| req.prompt_len() as u64),
self.waiting
.iter()
.filter(|req| req.output_len() > 0)
.map(|req| req.current_sequence_len() as u64),
(decode.end_ms - now_ms) / 1000.0,
);
debug_assert_sglang_scheduler_state(&self.waiting, &self.running, self.config.block_size);
EnginePassResult {
end_ms: decode.end_ms,
......@@ -195,6 +227,7 @@ impl SglangCore {
.as_ref()
.map(CapturedRouterEventBuffer::drain)
.unwrap_or_default(),
fpm: Some(fpm),
}
}
......
......@@ -7,11 +7,13 @@ use std::time::Instant;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, KvEventPublishers, MockEngineArgs, OutputSignal};
use crate::common::protocols::{
DirectRequest, FpmPublisher, KvEventPublishers, MockEngineArgs, OutputSignal,
};
use crate::common::utils::sleep_until_precise;
use crate::scheduler::{
AdmissionEvent, MockerMetrics, RouterEventVisibility, SchedulerHandle,
capture_deferred_kv_publish_sink, publish_deferred_kv_events,
AdmissionEvent, DeferredFpmBuffer, MockerMetrics, RouterEventVisibility, SchedulerHandle,
capture_deferred_kv_publish_sink, publish_deferred_fpm, publish_deferred_kv_events,
};
use super::core::SglangCore;
......@@ -39,6 +41,7 @@ impl SglangScheduler {
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
fpm_publisher: FpmPublisher,
) -> Self {
Self::new_internal(
args,
......@@ -47,6 +50,7 @@ impl SglangScheduler {
kv_event_publishers,
cancellation_token,
None,
fpm_publisher,
)
}
......@@ -57,6 +61,7 @@ impl SglangScheduler {
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
fpm_publisher: FpmPublisher,
) -> Self {
Self::new_internal(
args,
......@@ -65,6 +70,7 @@ impl SglangScheduler {
kv_event_publishers,
cancellation_token,
admission_tx,
fpm_publisher,
)
}
......@@ -75,6 +81,7 @@ impl SglangScheduler {
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
fpm_publisher: FpmPublisher,
) -> Self {
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let total_blocks = args.num_gpu_blocks as u64;
......@@ -89,6 +96,7 @@ impl SglangScheduler {
tokio::spawn(async move {
let (deferred_kv_events, buffering_publishers) =
capture_deferred_kv_publish_sink(kv_event_publishers.raw_enabled());
let deferred_fpm = DeferredFpmBuffer::default();
let mut core = SglangCore::new_with_sink(args, dp_rank, buffering_publishers);
loop {
......@@ -111,8 +119,12 @@ impl SglangScheduler {
let _ = admission_tx.send(admission.clone());
}
}
if let Some(fpm) = pass.fpm {
deferred_fpm.push(fpm);
}
if pass.router_event_visibility == RouterEventVisibility::PassStart {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
}
let total_time = std::time::Duration::from_secs_f64(pass.end_ms / 1000.0);
if total_time > std::time::Duration::ZERO {
......@@ -120,10 +132,12 @@ impl SglangScheduler {
}
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
}
let active_decode_blocks = pass.active_decode_blocks;
flush_output_signals(&output_tx, pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
active_decode_blocks,
......
......@@ -8,12 +8,21 @@ use super::config::{SglangConfig, ceil_to_block};
use super::request::SglangRequest;
use crate::kv_manager::SglangKvManager;
/// Per-request prefill data needed for FPM snapshot construction.
pub(super) struct PrefillFpmItem {
pub(super) prompt_len: usize,
pub(super) tokens_computed: usize,
pub(super) prefix_tokens: usize,
}
pub(super) struct AdmitResult {
pub(super) can_run: Vec<SglangRequest>,
pub(super) admissions: Vec<AdmissionEvent>,
pub(super) total_isl: usize,
pub(super) total_prefix: usize,
pub(super) oom: bool,
/// Per-request prefill info for building FPM snapshots.
pub(super) prefill_fpm: Vec<PrefillFpmItem>,
}
pub(super) fn get_new_batch_prefill(
......@@ -50,6 +59,7 @@ pub(super) fn get_new_batch_prefill(
let mut can_run = Vec::new();
let mut admissions = Vec::new();
let mut prefill_fpm = Vec::new();
let mut rejected = VecDeque::new();
let mut oom = false;
let mut total_isl = 0usize;
......@@ -139,6 +149,11 @@ pub(super) fn get_new_batch_prefill(
uuid: req.uuid,
reused_input_tokens: alloc.prefix_len,
});
prefill_fpm.push(PrefillFpmItem {
prompt_len: req.prompt_len(),
tokens_computed: chunk_tokens,
prefix_tokens: alloc.prefix_len,
});
total_isl += chunk_end;
total_prefix += alloc.prefix_len;
......@@ -162,5 +177,6 @@ pub(super) fn get_new_batch_prefill(
total_isl,
total_prefix,
oom,
prefill_fpm,
}
}
......@@ -19,7 +19,8 @@ use super::policy::apply_schedule_policy;
use super::prefill::get_new_batch_prefill;
use super::request::SglangRequest;
use crate::common::protocols::{
DirectRequest, EngineType, KvEventPublishers, MockEngineArgs, OutputSignal, SglangArgs,
DirectRequest, EngineType, FpmPublisher, KvEventPublishers, MockEngineArgs, OutputSignal,
SglangArgs,
};
use crate::kv_manager::SglangKvManager;
use crate::scheduler::test_utils::{
......@@ -95,8 +96,14 @@ mod scheduling {
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let scheduler = SglangScheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::default(),
None,
FpmPublisher::default(),
);
let num_requests = 5;
let max_output = 3;
......@@ -616,8 +623,14 @@ mod router_events {
}))
.build()
.unwrap();
let scheduler =
SglangScheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let scheduler = SglangScheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::default(),
None,
FpmPublisher::default(),
);
assert_sglang_scheduler_completes_all(
&scheduler,
......@@ -836,6 +849,7 @@ mod router_events {
Some(output_tx),
KvEventPublishers::new(Some(sink.clone()), None),
None,
FpmPublisher::default(),
);
for _ in 0..8 {
......@@ -908,3 +922,408 @@ mod router_events {
assert_eq!(signal.handoff_delay_ms, Some(8.0));
}
}
mod forward_pass_metrics {
use super::*;
fn fpm_args() -> MockEngineArgs {
MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.num_gpu_blocks(16)
.max_num_batched_tokens(Some(16))
.max_num_seqs(Some(4))
.speedup_ratio(0.0)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(16),
..Default::default()
}))
.build()
.unwrap()
}
#[test]
fn test_fpm_single_prefill_request() {
let mut core = SglangCore::new(fpm_args());
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let fpm = pass.fpm.expect("FPM should be present");
assert_eq!(fpm.num_prefill_requests, 1);
assert!(
fpm.sum_prefill_tokens > 0,
"prefill tokens should be computed"
);
// In SGLang, after prefill the request immediately joins running and
// participates in the decode step of the same pass.
assert_eq!(fpm.num_decode_requests, 1);
assert_eq!(fpm.num_queued_prefill, 0);
assert_eq!(fpm.num_queued_decode, 0);
assert!(fpm.wall_time_secs > 0.0);
}
#[test]
fn test_fpm_prefill_and_decode_mixed_batch() {
let mut core = SglangCore::new(fpm_args());
// r1: 4-token prompt, 3 output tokens
core.receive(DirectRequest {
tokens: (0..4).collect(),
max_output_tokens: 3,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
// Pass 1: prefill r1
let pass1 = core.execute_pass(&mut collector, 0.0);
let fpm1 = pass1.fpm.expect("FPM should be present");
assert_eq!(fpm1.num_prefill_requests, 1);
// r2: arriving while r1 is decoding
core.receive(DirectRequest {
tokens: (100..104).collect(),
max_output_tokens: 3,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Pass 2: r2 prefill + decode step runs on all running (r1 + r2)
let pass2 = core.execute_pass(&mut collector, 1.0);
let fpm2 = pass2.fpm.expect("FPM should be present");
assert_eq!(fpm2.num_prefill_requests, 1, "r2 is prefilling");
// In SGLang, after r2 prefill completes it joins running alongside r1,
// so the decode step sees both.
assert_eq!(fpm2.num_decode_requests, 2, "r1 + r2 both in decode step");
assert!(
fpm2.sum_decode_kv_tokens > 0,
"decode requests should have KV context"
);
}
#[test]
fn test_fpm_empty_pass_is_zeroed() {
let mut core = SglangCore::new(fpm_args());
// Submit and fully drain a request first so the core isn't empty
// (empty core blocks in receive_requests in live mode, but
// execute_pass_internal works fine on an empty core).
let pass = core.execute_hidden_pass(0.0);
let fpm = pass.fpm.expect("FPM should be present even for empty pass");
assert_eq!(fpm.num_prefill_requests, 0);
assert_eq!(fpm.num_decode_requests, 0);
assert_eq!(fpm.num_queued_prefill, 0);
assert_eq!(fpm.num_queued_decode, 0);
assert_eq!(fpm.sum_prefill_tokens, 0);
assert_eq!(fpm.sum_decode_kv_tokens, 0);
}
#[test]
fn test_fpm_queued_requests() {
// Very limited KV to force queuing.
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.num_gpu_blocks(4)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(2))
.speedup_ratio(0.0)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(8),
..Default::default()
}))
.build()
.unwrap();
let mut core = SglangCore::new(args);
// Two 8-token requests but limited KV
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
core.receive(DirectRequest {
tokens: (100..108).collect(),
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let fpm = pass.fpm.expect("FPM should be present");
let total_scheduled = fpm.num_prefill_requests + fpm.num_decode_requests;
assert!(
total_scheduled >= 1,
"at least one request should be scheduled"
);
// With tight KV, the second request should be queued.
let total_queued = fpm.num_queued_prefill + fpm.num_queued_decode;
assert!(
total_queued >= 1,
"at least one request should be queued, got {total_queued}"
);
}
#[test]
fn test_fpm_var_prefill_length_with_multiple_requests() {
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.num_gpu_blocks(32)
.max_num_batched_tokens(Some(32))
.max_num_seqs(Some(4))
.speedup_ratio(0.0)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(32),
..Default::default()
}))
.build()
.unwrap();
let mut core = SglangCore::new(args);
// Two prefill requests with different prompt lengths
core.receive(DirectRequest {
tokens: (0..4).collect(), // prompt_len = 4
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
core.receive(DirectRequest {
tokens: (100..112).collect(), // prompt_len = 12
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let fpm = pass.fpm.expect("FPM should be present");
assert_eq!(fpm.num_prefill_requests, 2);
// Population variance of [4, 12]: mean=8, var=((4-8)^2+(12-8)^2)/2 = 16
assert!(
(fpm.var_prefill_length - 16.0).abs() < 1e-6,
"expected var=16.0, got {}",
fpm.var_prefill_length
);
}
#[test]
fn test_fpm_chunked_prefill_reports_chunk_not_full_prompt() {
// With chunked_prefill_size=8 and a 16-token prompt, the request
// should be chunked. Each pass should report only the chunk size
// in sum_prefill_tokens, not the full prompt length.
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.num_gpu_blocks(32)
.max_num_batched_tokens(Some(32))
.max_num_seqs(Some(4))
.speedup_ratio(0.0)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(8),
..Default::default()
}))
.build()
.unwrap();
let mut core = SglangCore::new(args);
core.receive(DirectRequest {
tokens: (0..16).collect(),
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
// Pass 1: first chunk
let pass1 = core.execute_pass(&mut collector, 0.0);
let fpm1 = pass1.fpm.expect("FPM should be present");
assert_eq!(fpm1.num_prefill_requests, 1);
assert!(
fpm1.sum_prefill_tokens <= 8,
"chunk should be at most 8 tokens, got {}",
fpm1.sum_prefill_tokens
);
assert!(fpm1.sum_prefill_tokens > 0);
// Pass 2: remaining chunk
let pass2 = core.execute_pass(&mut collector, 1.0);
let fpm2 = pass2.fpm.expect("FPM should be present");
assert_eq!(fpm2.num_prefill_requests, 1, "still prefilling");
assert!(
fpm2.sum_prefill_tokens <= 8,
"second chunk should also be at most 8 tokens, got {}",
fpm2.sum_prefill_tokens
);
// Total across both chunks should equal the full prompt length
assert_eq!(
fpm1.sum_prefill_tokens + fpm2.sum_prefill_tokens,
16,
"total prefill tokens across chunks should equal full prompt"
);
}
#[test]
fn test_fpm_retracted_decode_becomes_queued_decode() {
// Very tight KV to force decode retraction. Fill the KV with running
// requests, then the decode step should retract some, and those should
// appear as queued decodes in the next pass's FPM.
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.num_gpu_blocks(6) // 24 tokens — very tight
.max_num_batched_tokens(Some(32))
.max_num_seqs(Some(4))
.speedup_ratio(0.0)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(32),
..Default::default()
}))
.build()
.unwrap();
let mut core = SglangCore::new(args);
let mut collector = crate::replay::TraceCollector::default();
// Two requests with 4-token prompts and long outputs to fill KV
core.receive(DirectRequest {
tokens: (0..4).collect(),
max_output_tokens: 20,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
core.receive(DirectRequest {
tokens: (100..104).collect(),
max_output_tokens: 20,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Run several passes to build up KV pressure
for i in 0..4 {
core.execute_pass(&mut collector, i as f64);
}
// Add a third request to increase memory pressure
core.receive(DirectRequest {
tokens: (200..212).collect(), // 12 tokens
max_output_tokens: 10,
uuid: Some(Uuid::from_u128(3)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Run more passes — at some point retraction should occur
let mut saw_queued_decode = false;
for i in 4..10 {
let pass = core.execute_pass(&mut collector, i as f64);
let fpm = pass.fpm.expect("FPM should be present");
if fpm.num_queued_decode > 0 {
saw_queued_decode = true;
assert!(
fpm.sum_queued_decode_kv_tokens > 0,
"retracted decode should have KV context"
);
break;
}
}
// If retraction didn't happen (KV was sufficient), that's also valid —
// just verify we always get Some(fpm).
if !saw_queued_decode {
// Verify the requests completed or are still running with valid FPM
let pass = core.execute_hidden_pass(10.0);
assert!(pass.fpm.is_some(), "FPM should always be present");
}
}
#[tokio::test]
async fn test_fpm_sent_through_sink() {
use std::sync::Arc;
use crate::common::protocols::FpmSink;
use crate::scheduler::test_utils::CapturingFpmSink;
let args = MockEngineArgs::builder()
.engine_type(EngineType::Sglang)
.block_size(4)
.num_gpu_blocks(16)
.max_num_batched_tokens(Some(16))
.max_num_seqs(Some(4))
.speedup_ratio(0.0)
.sglang(Some(SglangArgs {
page_size: Some(4),
chunked_prefill_size: Some(16),
..Default::default()
}))
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let fpm_sink = Arc::new(CapturingFpmSink::default());
let fpm_publisher = FpmPublisher::new(Some(fpm_sink.clone() as Arc<dyn FpmSink>));
let scheduler = SglangScheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::default(),
None,
fpm_publisher,
);
scheduler.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Wait for at least one output signal — ensures the scheduler has
// completed at least one pass and drained the deferred FPM buffer.
tokio::time::timeout(Duration::from_secs(5), output_rx.recv())
.await
.expect("timed out waiting for output")
.expect("output channel closed");
let snapshots = fpm_sink.take();
assert!(
!snapshots.is_empty(),
"should have received at least one FPM snapshot"
);
let fpm = &snapshots[0];
assert_eq!(fpm.num_prefill_requests, 1);
assert!(fpm.sum_prefill_tokens > 0);
assert!(fpm.wall_time_secs > 0.0);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::{Arc, Mutex};
use anyhow::anyhow;
use dynamo_kv_router::indexer::{
......@@ -17,8 +17,8 @@ use tokio::task::JoinHandle;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use super::{DirectRequest, OutputSignal, SchedulerHandle};
use crate::common::protocols::KvCacheEventSink;
use super::{DirectRequest, ForwardPassSnapshot, OutputSignal, SchedulerHandle};
use crate::common::protocols::{FpmSink, KvCacheEventSink};
pub(crate) struct RouterIndexerHarness {
indexer: Arc<LocalKvIndexer>,
......@@ -207,6 +207,25 @@ pub(crate) fn removed_event_count(events: &[RouterEvent]) -> usize {
.count()
}
/// Test sink that captures FPM snapshots for assertion.
#[derive(Default)]
pub(crate) struct CapturingFpmSink {
snapshots: Mutex<Vec<ForwardPassSnapshot>>,
}
impl FpmSink for CapturingFpmSink {
fn publish(&self, snapshot: ForwardPassSnapshot) -> anyhow::Result<()> {
self.snapshots.lock().unwrap().push(snapshot);
Ok(())
}
}
impl CapturingFpmSink {
pub(crate) fn take(&self) -> Vec<ForwardPassSnapshot> {
std::mem::take(&mut *self.snapshots.lock().unwrap())
}
}
/// Send `num_requests` to a scheduler, collect all output signals, and assert
/// that the scheduler produces exactly `num_requests * max_output_tokens` signals
/// and returns to idle (0 active decode blocks).
......
......@@ -19,8 +19,8 @@ use crate::common::utils::compute_prefill_handoff_delay_ms;
use crate::kv_manager::KvManager;
use crate::replay::TraceCollector;
use crate::scheduler::{
AdmissionEvent, CapturedRouterEventBuffer, EnginePassResult, RouterEventVisibility,
capture_router_event_sink,
AdmissionEvent, CapturedRouterEventBuffer, EnginePassResult, ForwardPassSnapshot,
RouterEventVisibility, build_fpm_snapshot, capture_router_event_sink,
};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
......@@ -56,6 +56,12 @@ struct ScheduledWork {
total_tokens: usize,
prompt_tokens: usize,
prefix_tokens: usize,
/// Full prompt length, captured at schedule time for FPM variance calculation.
prompt_len: usize,
/// Total sequence length (prompt + generated) at schedule time, used for
/// decode KV context in FPM. Captured here because completed requests are
/// removed from state before `compute_fpm` runs.
sequence_len: usize,
}
enum ScheduleOutcome {
......@@ -380,6 +386,8 @@ impl VllmCore {
let (decode_time, output_signals) = self.emit_ready_tokens(collector, decode_start_ms);
let end_ms = decode_start_ms + decode_time.as_secs_f64() * 1000.0;
let fpm = self.compute_fpm(&scheduled, (end_ms - now_ms) / 1000.0);
debug_assert_vllm_scheduler_state(&self.state);
EnginePassResult {
end_ms,
......@@ -393,6 +401,7 @@ impl VllmCore {
.as_ref()
.map(CapturedRouterEventBuffer::drain)
.unwrap_or_default(),
fpm: Some(fpm),
}
}
......@@ -406,6 +415,52 @@ impl VllmCore {
self.state.complete(&uuid);
}
/// Compute a forward pass metrics snapshot from the just-completed pass.
///
/// `scheduled` contains the work items that were scheduled in this iteration.
/// Per-request metadata (prompt_len, sequence_len) is captured in `ScheduledWork`
/// at schedule time, so this method does not depend on `self.state.requests` for
/// scheduled requests — completed requests may have already been removed.
/// Queue metrics are derived from `self.state.waiting` at the moment of the call.
fn compute_fpm(
&self,
scheduled: &FxHashMap<Uuid, ScheduledWork>,
wall_time_secs: f64,
) -> ForwardPassSnapshot {
let scheduled_prefills = scheduled.values().filter_map(|work| {
(work.prompt_tokens > 0).then_some((
work.prompt_len as u64,
work.prefix_tokens as u64,
work.total_tokens as u64,
))
});
let scheduled_decodes = scheduled
.values()
.filter_map(|work| (work.prompt_tokens == 0).then_some(work.sequence_len as u64));
let queued_prefills = self.state.waiting.iter().filter_map(|uuid| {
let request = self.state.requests.get(uuid)?;
matches!(request.status, RequestStatus::Waiting)
.then_some(request.sequence.num_input_tokens() as u64)
});
let queued_decodes = self.state.waiting.iter().filter_map(|uuid| {
let request = self.state.requests.get(uuid)?;
matches!(request.status, RequestStatus::Preempted).then_some(
(request.sequence.num_input_tokens() + request.sequence.generated_tokens()) as u64,
)
});
build_fpm_snapshot(
scheduled_prefills,
scheduled_decodes,
queued_prefills,
queued_decodes,
wall_time_secs,
)
}
#[allow(clippy::too_many_arguments)]
fn schedule_request(
&mut self,
......@@ -539,12 +594,20 @@ impl VllmCore {
let prompt_after = actual_computed_after.min(prompt_len);
let prompt_tokens = prompt_after.saturating_sub(prompt_before);
let sequence_len = self
.state
.requests
.get(&uuid)
.map(|r| r.sequence.len())
.unwrap_or(0);
scheduled.insert(
uuid,
ScheduledWork {
total_tokens: tokens_used,
prompt_tokens,
prefix_tokens: prompt_before,
prompt_len,
sequence_len,
},
);
if prompt_tokens > 0 && self.args.worker_type != WorkerType::Decode {
......
......@@ -7,11 +7,13 @@ use std::time::Instant;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use crate::common::protocols::{DirectRequest, KvEventPublishers, MockEngineArgs, OutputSignal};
use crate::common::protocols::{
DirectRequest, FpmPublisher, KvEventPublishers, MockEngineArgs, OutputSignal,
};
use crate::common::utils::sleep_until_precise;
use crate::scheduler::{
AdmissionEvent, RouterEventVisibility, SchedulerHandle, capture_deferred_kv_publish_sink,
publish_deferred_kv_events,
AdmissionEvent, DeferredFpmBuffer, RouterEventVisibility, SchedulerHandle,
capture_deferred_kv_publish_sink, publish_deferred_fpm, publish_deferred_kv_events,
};
use super::core::VllmCore;
......@@ -66,6 +68,7 @@ impl Scheduler {
output_tx: Option<mpsc::UnboundedSender<Vec<OutputSignal>>>,
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
fpm_publisher: FpmPublisher,
) -> Self {
Self::new_internal(
args,
......@@ -74,6 +77,7 @@ impl Scheduler {
kv_event_publishers,
cancellation_token,
None,
fpm_publisher,
)
}
......@@ -84,6 +88,7 @@ impl Scheduler {
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
fpm_publisher: FpmPublisher,
) -> Self {
Self::new_internal(
args,
......@@ -92,6 +97,7 @@ impl Scheduler {
kv_event_publishers,
cancellation_token,
admission_tx,
fpm_publisher,
)
}
......@@ -102,6 +108,7 @@ impl Scheduler {
kv_event_publishers: KvEventPublishers,
cancellation_token: Option<CancellationToken>,
admission_tx: Option<mpsc::UnboundedSender<AdmissionEvent>>,
fpm_publisher: FpmPublisher,
) -> Self {
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
let total_blocks = args.num_gpu_blocks as u64;
......@@ -115,6 +122,7 @@ impl Scheduler {
tokio::spawn(async move {
let (deferred_kv_events, buffering_publishers) =
capture_deferred_kv_publish_sink(kv_event_publishers.raw_enabled());
let deferred_fpm = DeferredFpmBuffer::default();
let mut core = VllmCore::new_with_sink(args, dp_rank, buffering_publishers);
loop {
......@@ -128,17 +136,23 @@ impl Scheduler {
let iteration_start = Instant::now();
let pass = core.execute_pass_internal(None, 0.0, admission_tx.as_ref());
let total_time = std::time::Duration::from_secs_f64(pass.end_ms / 1000.0);
if let Some(fpm) = pass.fpm {
deferred_fpm.push(fpm);
}
if pass.router_event_visibility == RouterEventVisibility::PassStart {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
}
if total_time > std::time::Duration::ZERO {
sleep_until_precise(iteration_start + total_time).await;
}
if pass.router_event_visibility == RouterEventVisibility::PassEnd {
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
}
flush_output_signals(&mut core, &output_tx, pass.output_signals);
publish_deferred_kv_events(&kv_event_publishers, deferred_kv_events.drain());
publish_deferred_fpm(&fpm_publisher, deferred_fpm.drain());
let _ = metrics_tx.send(MockerMetrics::new(
dp_rank,
core.kv_manager.num_active_blocks() as u64,
......
......@@ -12,7 +12,7 @@ use tokio::time::interval;
use uuid::Uuid;
use crate::common::protocols::{
DirectRequest, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal,
DirectRequest, FpmPublisher, KvCacheEventSink, KvEventPublishers, MockEngineArgs, OutputSignal,
PreemptionMode, RawKvEvent, RawKvEventSink,
};
use crate::common::sequence::ActiveSequence;
......@@ -577,8 +577,14 @@ mod live_scheduler {
.build()
.unwrap();
let scheduler =
Scheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let scheduler = Scheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::default(),
None,
FpmPublisher::default(),
);
crate::scheduler::test_utils::assert_scheduler_completes_all(
&scheduler,
......@@ -608,8 +614,14 @@ mod live_scheduler {
.build()
.unwrap();
let scheduler =
Scheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let scheduler = Scheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::default(),
None,
FpmPublisher::default(),
);
let identical_tokens: Vec<u32> = (0..token_length).collect();
for _ in 0..num_requests {
......@@ -660,8 +672,14 @@ mod live_scheduler {
.build()
.unwrap();
let scheduler =
Scheduler::new(args, 0, Some(output_tx), KvEventPublishers::default(), None);
let scheduler = Scheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::default(),
None,
FpmPublisher::default(),
);
scheduler.receive(DirectRequest {
tokens: (0..256).collect(),
max_output_tokens: 200,
......@@ -717,6 +735,7 @@ mod live_scheduler {
Some(output_tx),
KvEventPublishers::new(None, Some(sink.clone())),
None,
FpmPublisher::default(),
);
scheduler.receive(DirectRequest {
......@@ -771,6 +790,7 @@ mod live_scheduler {
Some(output_tx),
KvEventPublishers::new(Some(sink.clone()), None),
None,
FpmPublisher::default(),
);
for _ in 0..8 {
......@@ -813,3 +833,439 @@ mod live_scheduler {
harness.shutdown();
}
}
mod forward_pass_metrics {
use super::*;
/// Helper to build args with specific parameters for FPM tests.
fn fpm_args() -> MockEngineArgs {
MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(16)
.max_num_batched_tokens(Some(16))
.max_num_seqs(Some(4))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap()
}
#[test]
fn test_fpm_single_prefill_request() {
let mut core = VllmCore::new(fpm_args());
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let fpm = pass.fpm.expect("FPM should be present");
assert_eq!(fpm.num_prefill_requests, 1);
assert_eq!(fpm.sum_prefill_tokens, 8, "all 8 prompt tokens computed");
assert_eq!(fpm.sum_prefill_kv_tokens, 0, "no prefix cache");
assert_eq!(fpm.num_decode_requests, 0);
assert_eq!(fpm.num_queued_prefill, 0);
assert_eq!(fpm.num_queued_decode, 0);
assert!(fpm.wall_time_secs > 0.0);
}
#[test]
fn test_fpm_prefill_and_decode_mixed_batch() {
let mut core = VllmCore::new(fpm_args());
// r1: 4-token prompt, 3 output tokens
let r1 = Uuid::from_u128(1);
core.receive(DirectRequest {
tokens: (0..4).collect(),
max_output_tokens: 3,
uuid: Some(r1),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
// Pass 1: prefill r1 (4 tokens) + first decode token
let pass1 = core.execute_pass(&mut collector, 0.0);
let fpm1 = pass1.fpm.expect("FPM should be present");
assert_eq!(fpm1.num_prefill_requests, 1);
assert_eq!(fpm1.sum_prefill_tokens, 4);
// r2: 4-token prompt arriving while r1 is decoding
let r2 = Uuid::from_u128(2);
core.receive(DirectRequest {
tokens: (100..104).collect(),
max_output_tokens: 3,
uuid: Some(r2),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Pass 2: r1 decode + r2 prefill (mixed batch)
let pass2 = core.execute_pass(&mut collector, 1.0);
let fpm2 = pass2.fpm.expect("FPM should be present");
assert_eq!(fpm2.num_prefill_requests, 1, "r2 is prefilling");
assert_eq!(fpm2.num_decode_requests, 1, "r1 is decoding");
assert_eq!(fpm2.sum_prefill_tokens, 4);
assert!(
fpm2.sum_decode_kv_tokens > 0,
"decode request should have KV context"
);
}
#[test]
fn test_fpm_completed_requests_metrics_correct() {
// This tests the fix: completed requests should still contribute
// correct metrics even though they're removed from state before
// compute_fpm runs.
let mut core = VllmCore::new(fpm_args());
// Request with 4-token prompt and 1 output token — completes in 1 pass
let r1 = Uuid::from_u128(1);
core.receive(DirectRequest {
tokens: (0..4).collect(),
max_output_tokens: 1,
uuid: Some(r1),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let fpm = pass.fpm.expect("FPM should be present");
// r1 completes in this pass. The bug was that prompt_len would be 0
// because the request was removed from state before compute_fpm ran.
assert_eq!(fpm.num_prefill_requests, 1);
assert_eq!(fpm.sum_prefill_tokens, 4);
// var_prefill_length should reflect the actual prompt length (4), not 0.
// With a single request, variance is 0 regardless, so check sum_prefill_tokens
// as the main indicator.
assert!(pass.completed_requests > 0, "request should have completed");
}
#[test]
fn test_fpm_completed_decode_request_has_kv_context() {
// Decode request that completes — its KV context should be captured
// correctly even though it's removed from state.
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(16)
.max_num_batched_tokens(Some(16))
.max_num_seqs(Some(4))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
let r1 = Uuid::from_u128(1);
core.receive(DirectRequest {
tokens: (0..4).collect(),
max_output_tokens: 2,
uuid: Some(r1),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
// Pass 1: prefill + first decode token
core.execute_pass(&mut collector, 0.0);
// Pass 2: second decode token (completes the request)
let pass2 = core.execute_pass(&mut collector, 1.0);
let fpm2 = pass2.fpm.expect("FPM should be present");
assert_eq!(fpm2.num_decode_requests, 1);
// The completed decode request should have contributed its KV context
// (prompt_len + generated_so_far at schedule time).
assert!(
fpm2.sum_decode_kv_tokens > 0,
"completed decode request should still contribute KV context, got {}",
fpm2.sum_decode_kv_tokens
);
}
#[test]
fn test_fpm_queued_requests() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(4) // Very limited KV — only room for one request
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(2))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
// r1 and r2 both have 8-token prompts but only 4 blocks available
let r1 = Uuid::from_u128(1);
let r2 = Uuid::from_u128(2);
core.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 1,
uuid: Some(r1),
dp_rank: 0,
arrival_timestamp_ms: None,
});
core.receive(DirectRequest {
tokens: (100..108).collect(),
max_output_tokens: 1,
uuid: Some(r2),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let fpm = pass.fpm.expect("FPM should be present");
// At least one request should be scheduled, the other might be queued
// (depending on KV capacity). Some requests may have completed and
// been removed from both scheduled and queued.
let total_scheduled = fpm.num_prefill_requests + fpm.num_decode_requests;
assert!(
total_scheduled >= 1,
"at least one request should be scheduled"
);
}
#[test]
fn test_fpm_var_prefill_length_with_multiple_requests() {
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(32)
.max_num_batched_tokens(Some(32))
.max_num_seqs(Some(4))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
// Two prefill requests with different prompt lengths
core.receive(DirectRequest {
tokens: (0..4).collect(), // prompt_len = 4
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
core.receive(DirectRequest {
tokens: (100..112).collect(), // prompt_len = 12
max_output_tokens: 1,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
let pass = core.execute_pass(&mut collector, 0.0);
let fpm = pass.fpm.expect("FPM should be present");
assert_eq!(fpm.num_prefill_requests, 2);
// Population variance of [4, 12]: mean=8, var=((4-8)^2+(12-8)^2)/2 = 16
assert!(
(fpm.var_prefill_length - 16.0).abs() < 1e-6,
"expected var=16.0, got {}",
fpm.var_prefill_length
);
}
#[test]
fn test_fpm_chunked_prefill_reports_chunk_not_full_prompt() {
// With max_num_batched_tokens=8 and a 16-token prompt, chunked prefill
// should split across two passes. Each pass should report only the
// chunk size in sum_prefill_tokens, not the full prompt length.
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(16)
.max_num_batched_tokens(Some(8))
.max_num_seqs(Some(4))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
core.receive(DirectRequest {
tokens: (0..16).collect(),
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
let mut collector = crate::replay::TraceCollector::default();
// Pass 1: first chunk
let pass1 = core.execute_pass(&mut collector, 0.0);
let fpm1 = pass1.fpm.expect("FPM should be present");
assert_eq!(fpm1.num_prefill_requests, 1);
assert!(
fpm1.sum_prefill_tokens <= 8,
"chunk should be at most 8 tokens, got {}",
fpm1.sum_prefill_tokens
);
assert!(fpm1.sum_prefill_tokens > 0);
// Pass 2: remaining chunk
let pass2 = core.execute_pass(&mut collector, 1.0);
let fpm2 = pass2.fpm.expect("FPM should be present");
assert_eq!(fpm2.num_prefill_requests, 1, "still prefilling");
assert!(
fpm2.sum_prefill_tokens <= 8,
"second chunk should also be at most 8 tokens, got {}",
fpm2.sum_prefill_tokens
);
// Total across both chunks should equal the full prompt length
assert_eq!(
fpm1.sum_prefill_tokens + fpm2.sum_prefill_tokens,
16,
"total prefill tokens across chunks should equal full prompt"
);
// Variance should be over the full prompt length (16) in both passes
assert_eq!(
fpm1.var_prefill_length, 0.0,
"single request → zero variance"
);
assert_eq!(
fpm2.var_prefill_length, 0.0,
"single request → zero variance"
);
}
#[test]
fn test_fpm_preemption_creates_queued_decode() {
// Trigger preemption: fill KV with running requests, then submit a new
// one that forces eviction. The preempted request should appear as a
// queued decode in FPM.
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(6) // 24 tokens of KV — very tight
.max_num_batched_tokens(Some(32))
.max_num_seqs(Some(3))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.preemption_mode(PreemptionMode::Lifo)
.speedup_ratio(0.0)
.build()
.unwrap();
let mut core = VllmCore::new(args);
let mut collector = crate::replay::TraceCollector::default();
// r1: 4-token prompt, long output (stays running)
core.receive(DirectRequest {
tokens: (0..4).collect(),
max_output_tokens: 20,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Prefill r1 and decode a few tokens to build up KV
core.execute_pass(&mut collector, 0.0);
core.execute_pass(&mut collector, 1.0);
core.execute_pass(&mut collector, 2.0);
// r2: another request that will compete for KV
core.receive(DirectRequest {
tokens: (100..116).collect(), // 16 tokens — will pressure KV
max_output_tokens: 5,
uuid: Some(Uuid::from_u128(2)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// This pass should trigger preemption
let pass = core.execute_pass(&mut collector, 3.0);
let fpm = pass.fpm.expect("FPM should be present");
// We should see at least one queued decode (preempted request) OR one
// queued prefill (if the new request couldn't be scheduled). The key
// assertion is that queued metrics are non-zero when KV pressure exists.
let total_queued = fpm.num_queued_prefill + fpm.num_queued_decode;
if total_queued > 0 {
// Preemption occurred — verify the preempted decode has KV context
if fpm.num_queued_decode > 0 {
assert!(
fpm.sum_queued_decode_kv_tokens > 0,
"preempted decode should have KV context"
);
}
}
// Regardless, at least one request should be scheduled
let total_scheduled = fpm.num_prefill_requests + fpm.num_decode_requests;
assert!(total_scheduled >= 1);
}
#[tokio::test]
async fn test_fpm_sent_through_sink() {
use crate::scheduler::test_utils::CapturingFpmSink;
let args = MockEngineArgs::builder()
.block_size(4)
.num_gpu_blocks(16)
.max_num_batched_tokens(Some(16))
.max_num_seqs(Some(4))
.enable_chunked_prefill(true)
.enable_prefix_caching(false)
.speedup_ratio(0.0)
.build()
.unwrap();
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<Vec<OutputSignal>>();
let fpm_sink = Arc::new(CapturingFpmSink::default());
let fpm_publisher = crate::common::protocols::FpmPublisher::new(Some(
fpm_sink.clone() as Arc<dyn crate::common::protocols::FpmSink>
));
let scheduler = Scheduler::new(
args,
0,
Some(output_tx),
KvEventPublishers::default(),
None,
fpm_publisher,
);
scheduler.receive(DirectRequest {
tokens: (0..8).collect(),
max_output_tokens: 2,
uuid: Some(Uuid::from_u128(1)),
dp_rank: 0,
arrival_timestamp_ms: None,
});
// Wait for at least one output signal — ensures the scheduler has
// completed at least one pass and drained the deferred FPM buffer.
tokio::time::timeout(Duration::from_secs(5), output_rx.recv())
.await
.expect("timed out waiting for output")
.expect("output channel closed");
let snapshots = fpm_sink.take();
assert!(
!snapshots.is_empty(),
"should have received at least one FPM snapshot"
);
let fpm = &snapshots[0];
assert_eq!(fpm.num_prefill_requests, 1);
assert!(fpm.sum_prefill_tokens > 0);
assert!(fpm.wall_time_secs > 0.0);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment