Unverified Commit a58bcc31 authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

refactor: load planner using new forwardpass metric and many improvements (#7351)


Signed-off-by: default avatarhongkuanz <hongkuanz@nvidia.com>
parent db14d63f
...@@ -26,13 +26,13 @@ The Dynamo **Planner** is an autoscaler purpose-built for these constraints. It ...@@ -26,13 +26,13 @@ The Dynamo **Planner** is an autoscaler purpose-built for these constraints. It
The Planner supports two scaling modes that can run independently or together: The Planner supports two scaling modes that can run independently or together:
- **Throughput-based scaling**: Uses pre-deployment profiling data and traffic prediction to compute the replica count needed to meet TTFT and ITL targets. Adjusts on a longer interval (default 180s). This is the primary mode for production deployments. - **Throughput-based scaling**: Uses pre-deployment profiling data and traffic prediction to compute the replica count needed to meet TTFT and ITL targets. Adjusts on a longer interval (default 180s). This is the primary mode for production deployments.
- **Load-based scaling (Experimental)**: Uses real-time per-worker load metrics (active prefill tokens, active KV blocks) from the router and fits an online linear regression to make scaling decisions. No profiling data required. Adjusts on a short interval (default 5s) to respond quickly to bursts. - **Load-based scaling**: Uses ForwardPassMetrics (FPM) from the Dynamo event plane and fits an online linear regression to make scaling decisions. No profiling data or KV Router required. Adjusts on a short interval (default 5s) to respond quickly to bursts.
When both modes are enabled, throughput-based scaling provides a capacity floor (long-term planning) while load-based scaling handles real-time adjustments above that floor. When both modes are enabled, throughput-based scaling provides a capacity floor (long-term planning) while load-based scaling handles real-time adjustments above that floor.
## Feature Matrix ## Feature Matrix
| Feature | Throughput-Based | Load-Based (Experimental) | | Feature | Throughput-Based | Load-Based |
|---------|:----------------:|:-------------------------:| |---------|:----------------:|:-------------------------:|
| **Deployment** | | | | **Deployment** | | |
| Disaggregated | Supported | Supported | | Disaggregated | Supported | Supported |
...@@ -99,13 +99,11 @@ kubectl apply -f examples/backends/vllm/deploy/disagg_planner.yaml -n $NAMESPACE ...@@ -99,13 +99,11 @@ kubectl apply -f examples/backends/vllm/deploy/disagg_planner.yaml -n $NAMESPACE
## Current Limitations ## Current Limitations
### Load-based scaling (Experimental) ### Load-based scaling
Load-based scaling is experimental and has the following known limitations. These are actively being addressed as part of the metrics refactor work. Throughput-based scaling is not affected by any of these. Load-based scaling has the following known limitations. Throughput-based scaling is not affected by any of these.
**Requires the KV Router.** Load-based scaling relies on per-worker engine metrics (active prefill tokens, active KV blocks) published by the [KV Router](../router/README.md). Other routing strategies (round-robin, random) do not emit these metrics, so load-based scaling cannot operate without the KV Router. **Requires ForwardPassMetrics (FPM).** Load-based scaling uses per-engine per-iteration metrics delivered via the Dynamo event plane (ForwardPassMetrics). FPM is currently only available for vllm and is automatically enabled when the engine uses `InstrumentedScheduler` and `DYN_FORWARDPASS_METRIC_PORT` is set. The KV Router is **not** required for load-based scaling.
**Scale-down with idle workers.** If a worker receives no requests (for example, because the router is not distributing traffic evenly), the router does not publish metrics for that worker. Without metrics, the Planner cannot evaluate whether the worker is underutilized, which can prevent scale-down decisions. **Workaround:** Ensure traffic distribution reaches all workers. If you observe workers stuck at zero load, check your router configuration.
### General ### General
...@@ -144,7 +142,7 @@ Load-based scaling is experimental and has the following known limitations. Thes ...@@ -144,7 +142,7 @@ Load-based scaling is experimental and has the following known limitations. Thes
| `--profile-results-dir` | `profiling_results` | Path to profiling data (NPZ/JSON) | | `--profile-results-dir` | `profiling_results` | Path to profiling data (NPZ/JSON) |
| `--load-predictor` | `arima` | Prediction model (`arima`, `prophet`, `kalman`, `constant`) | | `--load-predictor` | `arima` | Prediction model (`arima`, `prophet`, `kalman`, `constant`) |
| `--no-correction` | `true` | Disable correction factors (auto-disabled when load-based scaling is on) | | `--no-correction` | `true` | Disable correction factors (auto-disabled when load-based scaling is on) |
| **Load-based scaling (Experimental)** | | | | **Load-based scaling** | | |
| `--enable-loadbased-scaling` | `false` | Enable load-based scaling | | `--enable-loadbased-scaling` | `false` | Enable load-based scaling |
| `--disable-throughput-scaling` | `false` | Disable throughput-based scaling (required for `agg` mode) | | `--disable-throughput-scaling` | `false` | Disable throughput-based scaling (required for `agg` mode) |
| `--loadbased-router-metrics-url` | auto-discovered | URL to router's `/metrics` endpoint | | `--loadbased-router-metrics-url` | auto-discovered | URL to router's `/metrics` endpoint |
...@@ -186,7 +184,7 @@ The dashboard shows: ...@@ -186,7 +184,7 @@ The dashboard shows:
- TTFT and ITL distributions - TTFT and ITL distributions
- Input/output sequence lengths - Input/output sequence lengths
**Load-based scaling** pulls per-engine status directly from the frontend's `/metrics` endpoint: **Load-based scaling** uses ForwardPassMetrics (FPM) from the Dynamo event plane:
- Active prefill tokens per worker - Per-iteration wall time, scheduled prefill/decode tokens, and queued request status
- Active decode blocks per worker - Delivered via `FpmEventSubscriber` with automatic engine discovery and lifecycle tracking
- Last observed TTFT, ITL, and ISL per worker - No router `/metrics` scraping required
...@@ -72,7 +72,6 @@ When throughput-based scaling is enabled, the planner needs interpolation curves ...@@ -72,7 +72,6 @@ When throughput-based scaling is enabled, the planner needs interpolation curves
| `load_scaling_down_sensitivity` | int | `80` | Scale-down sensitivity 0–100 (0=never, 100=aggressive). | | `load_scaling_down_sensitivity` | int | `80` | Scale-down sensitivity 0–100 (0=never, 100=aggressive). |
| `load_metric_samples` | int | `10` | Number of metric samples to collect per decision. | | `load_metric_samples` | int | `10` | Number of metric samples to collect per decision. |
| `load_min_observations` | int | `5` | Minimum observations before making scaling decisions. | | `load_min_observations` | int | `5` | Minimum observations before making scaling decisions. |
| `load_router_metrics_url` | string | `null` | Router metrics endpoint. Auto-discovered in Kubernetes mode. |
### General Settings ### General Settings
......
...@@ -165,30 +165,31 @@ After the delay: ...@@ -165,30 +165,31 @@ After the delay:
- **Interpolation accuracy vs profiling cost**: Higher `prefillInterpolationGranularity` and `decodeInterpolationGranularity` in the profiling sweep produce more accurate interpolation but increase profiling time linearly. Default granularity (16 prefill, 6 decode) balances accuracy with profiling duration. - **Interpolation accuracy vs profiling cost**: Higher `prefillInterpolationGranularity` and `decodeInterpolationGranularity` in the profiling sweep produce more accurate interpolation but increase profiling time linearly. Default granularity (16 prefill, 6 decode) balances accuracy with profiling duration.
- **Predictor warm-up period**: All predictors need observation history before making reliable forecasts. ARIMA and Prophet need multiple adjustment intervals of data. Kalman starts forecasting after `--kalman-min-points` observations. During warm-up, the planner uses the constant predictor as fallback. - **Predictor warm-up period**: All predictors need observation history before making reliable forecasts. ARIMA and Prophet need multiple adjustment intervals of data. Kalman starts forecasting after `--kalman-min-points` observations. During warm-up, the planner uses the constant predictor as fallback.
## Load-Based Scaling (Experimental) ## Load-Based Scaling
The load-based mode uses real-time per-worker metrics from the router to make SLA-aware scaling decisions without requiring profiling data. The load-based mode uses ForwardPassMetrics (FPM) from the Dynamo event plane to make SLA-aware scaling decisions without requiring profiling data or the KV Router.
### Metrics ### Metrics
The planner pulls per-worker load metrics directly from the frontend's `/metrics` endpoint: Each engine emits per-iteration `ForwardPassMetrics` via ZMQ -> FpmEventRelay -> event plane. The planner subscribes via `FpmEventSubscriber` with automatic engine discovery and MDC-based lifecycle tracking. Key fields used:
- **Active prefill tokens**: pending prefill tokens per worker - **wall_time**: per-iteration execution time (regression target)
- **Active decode blocks**: active KV blocks per worker - **scheduled_requests.sum_prefill_tokens**: prefill regression input
- **Last TTFT, ITL, ISL**: most recent observed latencies per worker - **scheduled_requests.sum_decode_kv_tokens**: decode regression input
- **queued_requests**: queued prefill/decode load for TTFT/ITL simulation
- Idle heartbeats (wall_time=0) are skipped
### Regression Model ### Regression Models
A sliding-window linear regression maps load to latency: Three specialized regression models (`fpm_regression.py`):
- Prefill: `(active_prefill_tokens + ISL)` -> `TTFT` - **PrefillRegressionModel**: 1D regression `sum_prefill_tokens -> wall_time`. Estimates TTFT by simulating chunked prefill scheduling (chunks of `max_num_batched_tokens`).
- Decode: `active_decode_blocks` -> `ITL` - **DecodeRegressionModel**: 1D regression `sum_decode_kv_tokens -> wall_time`. Estimates ITL for total decode load (scheduled + queued + avg decode length).
- **AggRegressionModel**: 2D regression `(sum_prefill_tokens, sum_decode_kv_tokens) -> wall_time`. Estimates both TTFT (simulated prefill with piggybacked decode) and ITL (decode with average piggybacked prefill).
Given a TTFT/ITL SLA target, the model reverse-solves for the maximum load that satisfies the SLA.
### Scaling Decisions ### Scaling Decisions
- **Scale up**: if ALL workers' recent load exceeds the regression-derived target - **Prefill/Decode**: Scale up if ALL engines' estimated TTFT/ITL > SLA; scale down if ALL < SLA * sensitivity
- **Scale down**: if ALL workers' recent load is below the target adjusted by `(num_workers - 1) / num_workers * sensitivity / 100` - **Agg**: Scale up if (ALL TTFT > SLA) OR (ALL ITL > SLA); scale down if (ALL TTFT < SLA * sensitivity) AND (ALL ITL < SLA * sensitivity)
- Only scales by +/-1 per interval (blocking) - Only scales by +/-1 per interval (non-blocking with pending-desired guard: metrics continue to be observed while scaling is in progress, but no new scaling action is issued until the previous one completes)
### Co-existence with Throughput-Based Scaling ### Co-existence with Throughput-Based Scaling
......
...@@ -1763,6 +1763,7 @@ dependencies = [ ...@@ -1763,6 +1763,7 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"clap", "clap",
"dashmap 6.1.0",
"dynamo-kv-router", "dynamo-kv-router",
"dynamo-llm", "dynamo-llm",
"dynamo-mocker", "dynamo-mocker",
...@@ -1774,6 +1775,7 @@ dependencies = [ ...@@ -1774,6 +1775,7 @@ dependencies = [
"pyo3", "pyo3",
"pyo3-async-runtimes", "pyo3-async-runtimes",
"pythonize", "pythonize",
"rmp",
"serde", "serde",
"serde_json", "serde_json",
"thiserror 2.0.18", "thiserror 2.0.18",
......
...@@ -36,7 +36,9 @@ dynamo-parsers = { path = "../../parsers" } ...@@ -36,7 +36,9 @@ dynamo-parsers = { path = "../../parsers" }
anyhow = { version = "1" } anyhow = { version = "1" }
async-trait = { version = "0.1" } async-trait = { version = "0.1" }
dashmap = { version = "6.1" }
futures = { version = "0.3" } futures = { version = "0.3" }
rmp = { version = "0.8" }
once_cell = { version = "1.20.3" } once_cell = { version = "1.20.3" }
parking_lot = { version = "0.12.4" } parking_lot = { version = "0.12.4" }
serde = { version = "1" } serde = { version = "1" }
......
...@@ -4,16 +4,26 @@ ...@@ -4,16 +4,26 @@
//! Python bindings for Forward Pass Metrics (FPM = ForwardPassMetrics) event plane integration. //! Python bindings for Forward Pass Metrics (FPM = ForwardPassMetrics) event plane integration.
//! //!
//! - `FpmEventRelay`: thin wrapper around `dynamo_llm::fpm_publisher::FpmEventRelay` //! - `FpmEventRelay`: thin wrapper around `dynamo_llm::fpm_publisher::FpmEventRelay`
//! - `FpmEventSubscriber`: wraps `EventSubscriber::for_component` for the consumer side //! - `FpmEventSubscriber`: wraps `EventSubscriber::for_component` for the consumer side.
//! Supports two mutually exclusive modes:
use std::sync::Arc; //! - **recv mode**: call `recv()` to pull one message at a time (existing behaviour).
//! - **tracking mode**: call `start_tracking()` once, then `get_recent_stats()` to
//! retrieve the latest FPM bytes keyed by `(worker_id, dp_rank)`.
use dashmap::{DashMap, DashSet};
use futures::StreamExt;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*; use pyo3::prelude::*;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use super::*; use super::*;
use crate::Endpoint; use crate::Endpoint;
use crate::to_pyerr; use crate::to_pyerr;
use dynamo_runtime::component::Component;
use dynamo_runtime::discovery::{DiscoveryEvent, DiscoveryQuery};
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventSubscriber; use dynamo_runtime::transports::event_plane::EventSubscriber;
...@@ -54,6 +64,204 @@ impl FpmEventRelay { ...@@ -54,6 +64,204 @@ impl FpmEventRelay {
} }
} }
// ---------------------------------------------------------------------------
// Helpers: partial msgpack decode
// ---------------------------------------------------------------------------
/// Extract `(worker_id, dp_rank)` from a msgspec-encoded `ForwardPassMetrics`.
///
/// msgspec.Struct (without `array_like=True`) encodes as a msgpack **map**:
/// `{"version": 1, "worker_id": "...", "dp_rank": 0, ...}`
///
/// We iterate through the map entries, read "worker_id" and "dp_rank",
/// and skip all other values. Breaks early once both keys are found.
fn extract_fpm_key(data: &[u8]) -> Option<(String, i64)> {
use rmp::decode::{read_int, read_map_len, read_str_len};
let mut cursor = std::io::Cursor::new(data);
let map_len = read_map_len(&mut cursor).ok()?;
let mut worker_id: Option<String> = None;
let mut dp_rank: Option<i64> = None;
for _ in 0..map_len {
// Read key (always a string in msgspec map encoding)
let key_len = read_str_len(&mut cursor).ok()? as usize;
let pos = cursor.position() as usize;
if pos + key_len > data.len() {
return None;
}
let key = std::str::from_utf8(&data[pos..pos + key_len]).ok()?;
cursor.set_position((pos + key_len) as u64);
match key {
"worker_id" => {
let str_len = read_str_len(&mut cursor).ok()? as usize;
let pos = cursor.position() as usize;
if pos + str_len > data.len() {
return None;
}
worker_id = Some(
std::str::from_utf8(&data[pos..pos + str_len])
.ok()?
.to_owned(),
);
cursor.set_position((pos + str_len) as u64);
}
"dp_rank" => {
dp_rank = Some(read_int(&mut cursor).ok()?);
}
_ => {
skip_msgpack_value(&mut cursor)?;
}
}
if worker_id.is_some() && dp_rank.is_some() {
break;
}
}
Some((worker_id?, dp_rank?))
}
/// Advance the cursor past one msgpack value of any type.
///
/// Handles all msgpack formats needed for `ForwardPassMetrics` fields:
/// positive/negative fixint, uint/int 8-64, float 32/64, fixstr/str 8-32,
/// bool, nil, fixarray/array 16-32, fixmap/map 16-32, bin 8-32.
fn skip_msgpack_value(cursor: &mut std::io::Cursor<&[u8]>) -> Option<()> {
use rmp::Marker;
let marker = rmp::decode::read_marker(cursor).ok()?;
match marker {
// Integers
Marker::FixPos(_) | Marker::FixNeg(_) => {}
Marker::U8 | Marker::I8 => skip_bytes(cursor, 1)?,
Marker::U16 | Marker::I16 => skip_bytes(cursor, 2)?,
Marker::U32 | Marker::I32 | Marker::F32 => skip_bytes(cursor, 4)?,
Marker::U64 | Marker::I64 | Marker::F64 => skip_bytes(cursor, 8)?,
// Nil / Bool
Marker::Null | Marker::True | Marker::False => {}
// Strings
Marker::FixStr(len) => skip_bytes(cursor, len as u64)?,
Marker::Str8 => {
let len = read_u8(cursor)? as u64;
skip_bytes(cursor, len)?;
}
Marker::Str16 => {
let len = read_u16(cursor)? as u64;
skip_bytes(cursor, len)?;
}
Marker::Str32 => {
let len = read_u32(cursor)? as u64;
skip_bytes(cursor, len)?;
}
// Binary
Marker::Bin8 => {
let len = read_u8(cursor)? as u64;
skip_bytes(cursor, len)?;
}
Marker::Bin16 => {
let len = read_u16(cursor)? as u64;
skip_bytes(cursor, len)?;
}
Marker::Bin32 => {
let len = read_u32(cursor)? as u64;
skip_bytes(cursor, len)?;
}
// Arrays (recurse to skip each element)
Marker::FixArray(len) => {
for _ in 0..len {
skip_msgpack_value(cursor)?;
}
}
Marker::Array16 => {
let len = read_u16(cursor)?;
for _ in 0..len {
skip_msgpack_value(cursor)?;
}
}
Marker::Array32 => {
let len = read_u32(cursor)?;
for _ in 0..len {
skip_msgpack_value(cursor)?;
}
}
// Maps (recurse to skip each key-value pair)
Marker::FixMap(len) => {
for _ in 0..len {
skip_msgpack_value(cursor)?;
skip_msgpack_value(cursor)?;
}
}
Marker::Map16 => {
let len = read_u16(cursor)?;
for _ in 0..len {
skip_msgpack_value(cursor)?;
skip_msgpack_value(cursor)?;
}
}
Marker::Map32 => {
let len = read_u32(cursor)?;
for _ in 0..len {
skip_msgpack_value(cursor)?;
skip_msgpack_value(cursor)?;
}
}
// Ext types
Marker::FixExt1 => skip_bytes(cursor, 2)?,
Marker::FixExt2 => skip_bytes(cursor, 3)?,
Marker::FixExt4 => skip_bytes(cursor, 5)?,
Marker::FixExt8 => skip_bytes(cursor, 9)?,
Marker::FixExt16 => skip_bytes(cursor, 17)?,
Marker::Ext8 => {
let len = read_u8(cursor)? as u64;
skip_bytes(cursor, 1 + len)?;
}
Marker::Ext16 => {
let len = read_u16(cursor)? as u64;
skip_bytes(cursor, 1 + len)?;
}
Marker::Ext32 => {
let len = read_u32(cursor)? as u64;
skip_bytes(cursor, 1 + len)?;
}
Marker::Reserved => return None,
}
Some(())
}
fn skip_bytes(cursor: &mut std::io::Cursor<&[u8]>, n: u64) -> Option<()> {
let new_pos = cursor.position().checked_add(n)?;
if new_pos > cursor.get_ref().len() as u64 {
return None;
}
cursor.set_position(new_pos);
Some(())
}
fn read_u8(cursor: &mut std::io::Cursor<&[u8]>) -> Option<u8> {
use std::io::Read;
let mut buf = [0u8; 1];
cursor.read_exact(&mut buf).ok()?;
Some(buf[0])
}
fn read_u16(cursor: &mut std::io::Cursor<&[u8]>) -> Option<u16> {
use std::io::Read;
let mut buf = [0u8; 2];
cursor.read_exact(&mut buf).ok()?;
Some(u16::from_be_bytes(buf))
}
fn read_u32(cursor: &mut std::io::Cursor<&[u8]>) -> Option<u32> {
use std::io::Read;
let mut buf = [0u8; 4];
cursor.read_exact(&mut buf).ok()?;
Some(u32::from_be_bytes(buf))
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Subscriber: event plane -> consumer // Subscriber: event plane -> consumer
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
...@@ -61,89 +269,346 @@ impl FpmEventRelay { ...@@ -61,89 +269,346 @@ impl FpmEventRelay {
/// Subscriber for ForwardPassMetrics from the event plane. /// Subscriber for ForwardPassMetrics from the event plane.
/// ///
/// Auto-discovers engine publishers via the discovery plane (K8s CRD / etcd / file). /// Auto-discovers engine publishers via the discovery plane (K8s CRD / etcd / file).
/// Returns raw msgspec-serialized bytes that Python decodes with ///
/// `forward_pass_metrics.decode()`. /// Two mutually exclusive usage modes:
///
/// 1. **recv mode** (default): call `recv()` to pull individual messages.
/// 2. **tracking mode**: call `start_tracking()` once, then poll `get_recent_stats()`
/// to retrieve the latest FPM bytes keyed by `(worker_id, dp_rank)`.
///
/// # Tracking mode concurrency design
///
/// Three concurrent actors access shared state:
///
/// - **Task 1** (event consumption, tokio): writes to `latest_stats` on every FPM.
/// - **Task 2** (MDC discovery watch, tokio): maintains `known_workers` set and
/// removes dead-worker entries from `latest_stats` on `Removed` events.
/// - **`get_recent_stats()`** (Python thread): reads both `latest_stats` and
/// `known_workers` to produce a filtered snapshot.
///
/// Both collections use `DashMap`/`DashSet` (sharded concurrent maps) so that
/// `get_recent_stats()` never blocks Task 1's high-frequency writes. Per-shard
/// locking means readers and writers only contend if they happen to hit the same
/// shard, which is rare in practice.
///
/// Ghost entries (FPM arriving after its worker's MDC `Removed` event) are
/// filtered out by the `known_workers` check in `get_recent_stats()` and eagerly
/// pruned from `latest_stats` on `Removed` events.
#[pyclass] #[pyclass]
pub(crate) struct FpmEventSubscriber { pub(crate) struct FpmEventSubscriber {
rx: Arc<std::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>>>, component: Component,
cancel: CancellationToken, cancel: CancellationToken,
// recv mode state (lazily initialised on first recv() call)
recv_started: Arc<AtomicBool>,
rx: Arc<std::sync::Mutex<Option<tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>>>>,
// tracking mode state
tracking_started: Arc<AtomicBool>,
latest_stats: Arc<DashMap<(String, i64), Vec<u8>>>,
// Worker IDs currently registered in MDC. Maintained by Task 2
// (insert on Added, remove on Removed). Used by get_recent_stats()
// to filter out ghost entries without contending with Task 1's writes.
known_workers: Arc<DashSet<String>>,
} }
#[pymethods] #[pymethods]
impl FpmEventSubscriber { impl FpmEventSubscriber {
/// Create a subscriber that auto-discovers FPM publishers. /// Create a subscriber that auto-discovers FPM publishers.
/// ///
/// No background tasks are started until `recv()` or `start_tracking()` is called.
///
/// Args: /// Args:
/// endpoint: Dynamo component endpoint (provides runtime + discovery). /// endpoint: Dynamo component endpoint (provides runtime + discovery).
#[new] #[new]
#[pyo3(signature = (endpoint,))] #[pyo3(signature = (endpoint,))]
fn new(endpoint: Endpoint) -> PyResult<Self> { fn new(endpoint: Endpoint) -> PyResult<Self> {
let component = endpoint.inner.component().clone(); let component = endpoint.inner.component().clone();
Ok(Self {
component,
cancel: CancellationToken::new(),
recv_started: Arc::new(AtomicBool::new(false)),
rx: Arc::new(std::sync::Mutex::new(None)),
tracking_started: Arc::new(AtomicBool::new(false)),
latest_stats: Arc::new(DashMap::new()),
known_workers: Arc::new(DashSet::new()),
})
}
/// Blocking receive of next message bytes. Releases the GIL while waiting.
///
/// On the first call a background subscriber task is spawned (recv mode).
/// Cannot be used after `start_tracking()`.
///
/// Returns the raw msgspec payload, or None if the stream is closed.
fn recv(&self, py: Python) -> PyResult<Option<Vec<u8>>> {
if self.tracking_started.load(Ordering::SeqCst) {
return Err(PyRuntimeError::new_err(
"Cannot call recv() after start_tracking()",
));
}
// Lazily start the recv-mode subscriber task on the first call.
if !self.recv_started.swap(true, Ordering::SeqCst) {
let component = self.component.clone();
let cancel = self.cancel.clone();
let (tx, rx_new) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
{
let mut guard = self.rx.lock().map_err(|e| to_pyerr(format!("{e}")))?;
*guard = Some(rx_new);
}
let rt = component.drt().runtime().secondary();
rt.spawn(async move {
let mut subscriber =
match EventSubscriber::for_component(&component, FPM_TOPIC).await {
Ok(s) => s,
Err(e) => {
tracing::error!("FPM subscriber (recv): failed to create: {e}");
return;
}
};
tracing::info!("FPM subscriber (recv): listening for forward-pass-metrics events");
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::info!("FPM subscriber (recv): shutting down");
break;
}
event = subscriber.next() => {
match event {
Some(Ok(envelope)) => {
if tx.send(envelope.payload.to_vec()).is_err() {
tracing::info!(
"FPM subscriber (recv): receiver dropped, exiting"
);
break;
}
}
Some(Err(e)) => {
tracing::warn!("FPM subscriber (recv): event error: {e}");
}
None => {
tracing::info!("FPM subscriber (recv): stream ended");
break;
}
}
}
}
}
});
}
let rx = self.rx.clone();
py.allow_threads(move || {
let mut guard = rx
.lock()
.map_err(|e| to_pyerr(format!("lock poisoned: {e}")))?;
match guard.as_mut() {
Some(rx) => Ok(rx.blocking_recv()),
None => Ok(None),
}
})
}
/// Start background tracking of the latest FPM per `(worker_id, dp_rank)`.
///
/// Spawns two background tasks:
///
/// 1. **Event consumption** (Task 1): subscribes to FPM events, extracts
/// `(worker_id, dp_rank)` from the msgpack payload, stores the latest
/// raw bytes in `latest_stats`. Uses per-shard locking via `DashMap`
/// so contention with concurrent readers is minimal.
///
/// 2. **MDC discovery watch** (Task 2): monitors `ComponentModels` for the
/// target component. Maintains `known_workers` (the set of currently
/// alive worker IDs) and eagerly removes dead-worker entries from
/// `latest_stats` on `Removed` events.
///
/// After calling this method, `recv()` will raise an error.
fn start_tracking(&self) -> PyResult<()> {
if self.recv_started.load(Ordering::SeqCst) {
return Err(PyRuntimeError::new_err(
"Cannot call start_tracking() after recv()",
));
}
if self.tracking_started.swap(true, Ordering::SeqCst) {
return Err(PyRuntimeError::new_err("Tracking already started"));
}
let component = self.component.clone();
let rt = component.drt().runtime().secondary(); let rt = component.drt().runtime().secondary();
let cancel = CancellationToken::new(); let cancel = self.cancel.clone();
let cancel_clone = cancel.clone(); let stats = self.latest_stats.clone();
let known = self.known_workers.clone();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>(); // Task 1: event consumption.
//
// Inserts every FPM into latest_stats without checking known_workers.
// Ghost entries (from workers that have already been removed) are
// filtered out by get_recent_stats() at read time. DashMap's
// per-shard locking keeps contention low but does not eliminate it
// entirely -- a concurrent reader hitting the same shard will briefly
// wait for the insert to complete.
rt.spawn({
let cancel = cancel.clone();
let component = component.clone();
let stats = stats.clone();
async move {
let mut subscriber =
match EventSubscriber::for_component(&component, FPM_TOPIC).await {
Ok(s) => s,
Err(e) => {
tracing::error!("FPM tracker: failed to create subscriber: {e}");
return;
}
};
tracing::info!("FPM tracker: listening for forward-pass-metrics events");
rt.spawn(async move { loop {
let mut subscriber = match EventSubscriber::for_component(&component, FPM_TOPIC).await { tokio::select! {
Ok(s) => s, biased;
Err(e) => { _ = cancel.cancelled() => {
tracing::error!("FPM subscriber: failed to create: {e}"); tracing::info!("FPM tracker: shutting down event task");
return; break;
}
event = subscriber.next() => {
match event {
Some(Ok(envelope)) => {
let payload = envelope.payload.to_vec();
if let Some(key) = extract_fpm_key(&payload) {
stats.insert(key, payload);
} else {
tracing::warn!(
"FPM tracker: failed to extract key from payload ({} bytes)",
envelope.payload.len()
);
}
}
Some(Err(e)) => {
tracing::warn!("FPM tracker: event error: {e}");
}
None => {
tracing::info!("FPM tracker: event stream ended");
break;
}
}
}
}
} }
}; }
});
tracing::info!("FPM subscriber: listening for forward-pass-metrics events"); // Task 2: MDC discovery watch.
//
// Maintains known_workers (insert on Added, remove on Removed) and
// eagerly prunes latest_stats on Removed events. This handles the
// normal scale-down path. Any ghost entries created by the race
// condition (FPM arriving *after* the Removed event) are caught by the
// known_workers filter in get_recent_stats().
rt.spawn({
let cancel = cancel.clone();
let component = component.clone();
let stats = stats.clone();
let known = known.clone();
async move {
let discovery = component.drt().discovery();
let query = DiscoveryQuery::ComponentModels {
namespace: component.namespace().name(),
component: component.name().to_string(),
};
loop { let stream = match discovery.list_and_watch(query, Some(cancel.clone())).await {
tokio::select! { Ok(s) => s,
biased; Err(e) => {
_ = cancel_clone.cancelled() => { tracing::error!("FPM tracker: failed to create discovery watch: {e}");
tracing::info!("FPM subscriber: shutting down"); return;
break;
} }
event = subscriber.next() => { };
match event {
Some(Ok(envelope)) => { tracing::info!("FPM tracker: watching MDC discovery for engine lifecycle");
if tx.send(envelope.payload.to_vec()).is_err() {
tracing::info!("FPM subscriber: receiver dropped, exiting"); let mut stream = stream;
loop {
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::info!("FPM tracker: shutting down discovery task");
break;
}
event = stream.next() => {
match event {
Some(Ok(DiscoveryEvent::Added(instance))) => {
let wid = instance.instance_id().to_string();
known.insert(wid.clone());
tracing::debug!("FPM tracker: worker {wid} added to known set");
}
Some(Ok(DiscoveryEvent::Removed(id))) => {
let removed_id = id.instance_id().to_string();
known.remove(&removed_id);
// Eagerly prune latest_stats for the common case
// (worker removed cleanly before any late FPMs arrive).
let before = stats.len();
stats.retain(|(worker_id, _), _| *worker_id != removed_id);
let removed = before - stats.len();
if removed > 0 {
tracing::info!(
"FPM tracker: removed {removed} entries for \
worker_id={removed_id} (MDC removed)"
);
}
}
Some(Err(e)) => {
tracing::warn!("FPM tracker: discovery error: {e}");
}
None => {
tracing::info!("FPM tracker: discovery stream ended");
break; break;
} }
} }
Some(Err(e)) => {
tracing::warn!("FPM subscriber: event error: {e}");
}
None => {
tracing::info!("FPM subscriber: stream ended");
break;
}
} }
} }
} }
} }
}); });
Ok(Self { Ok(())
rx: Arc::new(std::sync::Mutex::new(rx)),
cancel,
})
} }
/// Blocking receive of next message bytes. Releases the GIL while waiting. /// Return the latest FPM bytes for every tracked `(worker_id, dp_rank)`.
/// ///
/// Returns the raw msgspec payload, or None if the stream is closed. /// The returned snapshot is filtered against `known_workers` so that
fn recv(&self, py: Python) -> PyResult<Option<Vec<u8>>> { /// ghost entries (late FPMs from already-removed workers) are excluded.
let rx = self.rx.clone(); /// Uses `DashMap`/`DashSet` with per-shard locking so contention with
py.allow_threads(move || { /// the hot-path writer is minimal (but not zero -- a reader and writer
let mut guard = rx /// hitting the same shard will briefly contend).
.lock() ///
.map_err(|e| to_pyerr(format!("lock poisoned: {e}")))?; /// Returns:
Ok(guard.blocking_recv()) /// dict mapping `(worker_id: str, dp_rank: int)` to raw msgspec bytes.
}) fn get_recent_stats(&self) -> PyResult<HashMap<(String, i64), Vec<u8>>> {
if !self.tracking_started.load(Ordering::SeqCst) {
return Err(PyRuntimeError::new_err(
"start_tracking() has not been called",
));
}
let snapshot = self
.latest_stats
.iter()
.filter(|entry| self.known_workers.contains(&entry.key().0))
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect();
Ok(snapshot)
} }
/// Shut down the subscriber. /// Shut down the subscriber (all background tasks).
fn shutdown(&self) { fn shutdown(&self) {
self.cancel.cancel(); self.cancel.cancel();
} }
......
...@@ -841,12 +841,23 @@ class FpmEventSubscriber: ...@@ -841,12 +841,23 @@ class FpmEventSubscriber:
""" """
Subscriber for ForwardPassMetrics from the Dynamo event plane. Subscriber for ForwardPassMetrics from the Dynamo event plane.
Auto-discovers engine publishers via the discovery plane. Auto-discovers engine publishers via the discovery plane.
Two mutually exclusive usage modes:
1. **recv mode** (default): call ``recv()`` to pull individual messages.
2. **tracking mode**: call ``start_tracking()`` once, then poll
``get_recent_stats()`` to retrieve the latest FPM bytes keyed by
``(worker_id, dp_rank)``. Stale entries are cleaned up when
workers are removed (via discovery watch).
""" """
def __init__(self, endpoint: Endpoint) -> None: def __init__(self, endpoint: Endpoint) -> None:
""" """
Create a subscriber that auto-discovers FPM publishers. Create a subscriber that auto-discovers FPM publishers.
No background tasks are started until ``recv()`` or
``start_tracking()`` is called.
Args: Args:
endpoint: Dynamo component endpoint (provides runtime + discovery). endpoint: Dynamo component endpoint (provides runtime + discovery).
""" """
...@@ -857,13 +868,48 @@ class FpmEventSubscriber: ...@@ -857,13 +868,48 @@ class FpmEventSubscriber:
Blocking receive of the next message (raw msgspec bytes). Blocking receive of the next message (raw msgspec bytes).
Releases the GIL while waiting. Releases the GIL while waiting.
On the first call a background subscriber task is spawned (recv mode).
Cannot be used after ``start_tracking()``.
Returns: Returns:
Raw msgspec payload, or None if the stream is closed. Raw msgspec payload, or None if the stream is closed.
""" """
... ...
def start_tracking(self) -> None:
"""
Start background tracking of the latest FPM per (worker_id, dp_rank).
Spawns two background tasks:
1. Event consumption: subscribes to FPM events, extracts the composite
key (worker_id, dp_rank) from the msgpack payload, stores latest
raw bytes in an internal map.
2. MDC discovery watch: monitors ComponentModels for the target
component. When a model is removed, all entries whose
worker_id matches the removed instance_id are purged.
After calling this, ``recv()`` will raise RuntimeError.
"""
...
def get_recent_stats(self) -> dict[tuple[str, int], bytes]:
"""
Return the latest FPM bytes for every tracked (worker_id, dp_rank).
Cleanup of removed engines is handled by the MDC discovery watch
task spawned by ``start_tracking()``.
Raises RuntimeError if ``start_tracking()`` has not been called.
Returns:
dict mapping ``(worker_id, dp_rank)`` to raw msgspec bytes.
Decode each value with ``forward_pass_metrics.decode(data)``.
"""
...
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shut down the subscriber.""" """Shut down the subscriber (all background tasks)."""
... ...
......
...@@ -23,6 +23,7 @@ from dynamo.planner.utils.planner_core import ( ...@@ -23,6 +23,7 @@ from dynamo.planner.utils.planner_core import (
) )
from dynamo.planner.utils.prefill_planner import PrefillPlanner from dynamo.planner.utils.prefill_planner import PrefillPlanner
from dynamo.planner.utils.prometheus import Metrics from dynamo.planner.utils.prometheus import Metrics
from dynamo.planner.worker_info import WorkerInfo
pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0] pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0]
...@@ -56,12 +57,12 @@ class PlannerHarness: ...@@ -56,12 +57,12 @@ class PlannerHarness:
target_replicas = [ target_replicas = [
{ {
"sub_component_type": "prefill", "sub_component_type": "prefill",
"component_name": self.prefill_planner.prefill_component_name, "component_name": self.prefill_planner.prefill_worker_info.k8s_name,
"desired_replicas": next_num_p, "desired_replicas": next_num_p,
}, },
{ {
"sub_component_type": "decode", "sub_component_type": "decode",
"component_name": self.prefill_planner.decode_component_name, "component_name": self.prefill_planner.decode_worker_info.k8s_name,
"desired_replicas": next_num_d, "desired_replicas": next_num_d,
}, },
] ]
...@@ -83,12 +84,12 @@ class PlannerHarness: ...@@ -83,12 +84,12 @@ class PlannerHarness:
} }
prefill_attrs = { prefill_attrs = {
"prefill_interpolator", "prefill_interpolator",
"prefill_component_name", "prefill_worker_info",
"p_correction_factor", "p_correction_factor",
} }
decode_attrs = { decode_attrs = {
"decode_interpolator", "decode_interpolator",
"decode_component_name", "decode_worker_info",
"d_correction_factor", "d_correction_factor",
} }
if name == "last_metrics": if name == "last_metrics":
...@@ -185,6 +186,20 @@ def planner(): ...@@ -185,6 +186,20 @@ def planner():
decode_planner = DecodePlanner(mock_runtime, config, shared_state=shared_state) decode_planner = DecodePlanner(mock_runtime, config, shared_state=shared_state)
planner = PlannerHarness(prefill_planner, decode_planner, shared_state) planner = PlannerHarness(prefill_planner, decode_planner, shared_state)
# Set up WorkerInfo for both planners
prefill_planner.prefill_worker_info = WorkerInfo(
k8s_name="VllmPrefillWorker",
component_name="prefill",
endpoint="generate",
)
prefill_planner.decode_worker_info = WorkerInfo(
k8s_name="VllmDecodeWorker",
component_name="backend",
endpoint="generate",
)
decode_planner.prefill_worker_info = prefill_planner.prefill_worker_info
decode_planner.decode_worker_info = prefill_planner.decode_worker_info
# Mock the interpolators to return fixed values for testing # Mock the interpolators to return fixed values for testing
planner.prefill_interpolator = Mock() planner.prefill_interpolator = Mock()
planner.decode_interpolator = Mock() planner.decode_interpolator = Mock()
......
...@@ -6,12 +6,27 @@ from unittest.mock import Mock, patch ...@@ -6,12 +6,27 @@ from unittest.mock import Mock, patch
import pytest import pytest
try:
import msgspec # noqa: F401
except ImportError:
pytest.skip("msgspec required for FPM tests", allow_module_level=True)
from dynamo.common.forward_pass_metrics import (
ForwardPassMetrics,
QueuedRequestMetrics,
ScheduledRequestMetrics,
encode,
)
from dynamo.planner.utils.decode_planner import DecodePlanner from dynamo.planner.utils.decode_planner import DecodePlanner
from dynamo.planner.utils.load_based_regression import LoadBasedRegressionModel from dynamo.planner.utils.fpm_regression import (
AggRegressionModel,
DecodeRegressionModel,
PrefillRegressionModel,
)
from dynamo.planner.utils.planner_config import PlannerConfig from dynamo.planner.utils.planner_config import PlannerConfig
from dynamo.planner.utils.planner_core import PlannerSharedState from dynamo.planner.utils.planner_core import PlannerSharedState
from dynamo.planner.utils.prefill_planner import PrefillPlanner from dynamo.planner.utils.prefill_planner import PrefillPlanner
from dynamo.planner.utils.prometheus import CachedLoadMetrics, DirectRouterMetricsClient from dynamo.planner.worker_info import WorkerInfo
pytestmark = [ pytestmark = [
pytest.mark.gpu_0, pytest.mark.gpu_0,
...@@ -21,203 +36,199 @@ pytestmark = [ ...@@ -21,203 +36,199 @@ pytestmark = [
] ]
# ── LoadBasedRegressionModel tests ────────────────────────────────────── def _make_fpm(
*,
sum_prefill_tokens: int = 0,
num_prefill_requests: int = 0,
sum_decode_kv_tokens: int = 0,
num_decode_requests: int = 0,
queued_prefill_tokens: int = 0,
queued_decode_kv_tokens: int = 0,
wall_time: float = 0.01,
worker_id: str = "w1",
dp_rank: int = 0,
) -> ForwardPassMetrics:
return ForwardPassMetrics(
worker_id=worker_id,
dp_rank=dp_rank,
wall_time=wall_time,
scheduled_requests=ScheduledRequestMetrics(
sum_prefill_tokens=sum_prefill_tokens,
num_prefill_requests=num_prefill_requests,
sum_decode_kv_tokens=sum_decode_kv_tokens,
num_decode_requests=num_decode_requests,
),
queued_requests=QueuedRequestMetrics(
sum_prefill_tokens=queued_prefill_tokens,
sum_decode_kv_tokens=queued_decode_kv_tokens,
),
)
# ── PrefillRegressionModel tests ─────────────────────────────────────
class TestLoadBasedRegressionModel: class TestPrefillRegressionModel:
def test_insufficient_data(self): def test_insufficient_data(self):
model = LoadBasedRegressionModel(window_size=50, min_observations=5) model = PrefillRegressionModel(window_size=50, min_observations=5)
assert not model.has_sufficient_data() assert not model.has_sufficient_data()
assert model.predict_x_from_sla(100.0) is None assert model.estimate_next_ttft(0, 2048) is None
def test_heartbeat_skipped(self):
model = PrefillRegressionModel(window_size=50, min_observations=3)
fpm = _make_fpm(wall_time=0.0, sum_prefill_tokens=100, num_prefill_requests=1)
model.add_observation(fpm)
assert model.num_observations == 0
def test_basic_regression_and_ttft_estimate(self):
model = PrefillRegressionModel(window_size=50, min_observations=3)
# wall_time = 0.001 * prefill_tokens + 0.002 (linear relationship)
for tokens in [500, 1000, 1500, 2000, 2500]:
fpm = _make_fpm(
sum_prefill_tokens=tokens,
num_prefill_requests=1,
wall_time=0.001 * tokens + 0.002,
)
model.add_observation(fpm)
def test_basic_linear_prediction(self):
model = LoadBasedRegressionModel(window_size=50, min_observations=3)
# y = 2x + 10: x in [1..5], y in [12..20]
for x in range(1, 6):
model.add_observation(float(x), 2.0 * x + 10.0)
assert model.has_sufficient_data() assert model.has_sufficient_data()
# Reverse: x = (y - 10) / 2, y=100 => x=45
result = model.predict_x_from_sla(100.0)
assert result is not None
assert abs(result - 45.0) < 0.5
def test_negative_slope_fallback_points_below_sla(self):
model = LoadBasedRegressionModel(window_size=50, min_observations=3)
# Negative slope: higher x => lower y
# x=1 -> y=98, x=2 -> y=96, x=3 -> y=94, x=4 -> y=92, x=5 -> y=90
for x in range(1, 6):
model.add_observation(float(x), 100.0 - 2.0 * x)
# target_y=95 => points below: x=3(y=94), x=4(y=92), x=5(y=90)
# min x among those is 3
result = model.predict_x_from_sla(95.0)
assert result is not None
assert abs(result - 3.0) < 0.01
def test_negative_slope_fallback_all_above_sla(self):
model = LoadBasedRegressionModel(window_size=50, min_observations=3)
# Negative slope: x=1 -> y=98, x=2 -> y=96, ..., x=5 -> y=90
for x in range(1, 6):
model.add_observation(float(x), 100.0 - 2.0 * x)
# target_y=50 => all points have y >= 90 > 50, none below
# fallback returns smallest x overall = 1
result = model.predict_x_from_sla(50.0)
assert result is not None
assert abs(result - 1.0) < 0.01
def test_sliding_window_evicts_old(self):
model = LoadBasedRegressionModel(window_size=5, min_observations=3)
# Add 10 observations; only last 5 should remain
for i in range(10):
model.add_observation(float(i), float(i) * 2)
assert model.num_observations == 5
def test_result_clamped_to_non_negative(self): # Single iteration: queued=0, avg_isl should be mean of [500..2500]=1500
model = LoadBasedRegressionModel(window_size=50, min_observations=3) # total_tokens = 0 + avg_isl ≈ 1500
# y = 10x + 100: intercept=100, slope=10 # 1 iteration at max_num_batched_tokens=2048 (1500 < 2048)
for x in range(1, 6): est = model.estimate_next_ttft(
model.add_observation(float(x), 10.0 * x + 100.0) queued_prefill_tokens=0, max_num_batched_tokens=2048
# target_y=5 => x = (5-100)/10 = -9.5 => clamped to 0
result = model.predict_x_from_sla(5.0)
assert result == 0.0
def test_slope_and_intercept_properties(self):
model = LoadBasedRegressionModel(window_size=50, min_observations=3)
for x in range(1, 6):
model.add_observation(float(x), 3.0 * x + 5.0)
assert model.slope is not None
assert abs(model.slope - 3.0) < 0.01
assert model.intercept is not None
assert abs(model.intercept - 5.0) < 0.01
# ── DirectRouterMetricsClient tests ─────────────────────────────────────
class TestDirectRouterMetricsClient:
def test_parse_prometheus_text_basic(self):
"""Metrics with dynamo_namespace/model labels are grouped by worker_type."""
client = DirectRouterMetricsClient("http://localhost:8000/metrics", "test-ns")
text = (
"# HELP dynamo_frontend_worker_active_prefill_tokens Active prefill tokens\n"
"# TYPE dynamo_frontend_worker_active_prefill_tokens gauge\n"
'dynamo_frontend_worker_active_prefill_tokens{dynamo_namespace="test-ns",model="TestModel",worker_type="prefill",worker_id="w1"} 1234\n'
'dynamo_frontend_worker_active_decode_blocks{dynamo_namespace="test-ns",model="TestModel",worker_type="decode",worker_id="w2"} 56\n'
'dynamo_frontend_worker_last_time_to_first_token_seconds{dynamo_namespace="test-ns",model="TestModel",worker_type="prefill",worker_id="w1"} 0.25\n'
'dynamo_frontend_worker_last_input_sequence_tokens{dynamo_namespace="test-ns",model="TestModel",worker_type="prefill",worker_id="w1"} 3000\n'
'dynamo_frontend_worker_last_inter_token_latency_seconds{dynamo_namespace="test-ns",model="TestModel",worker_type="decode",worker_id="w2"} 0.04\n'
)
result = client._parse_prometheus_text(text)
assert "prefill" in result
assert "w1" in result["prefill"]
assert result["prefill"]["w1"]["active_prefill_tokens"] == 1234.0
assert abs(result["prefill"]["w1"]["last_ttft"] - 0.25) < 1e-6
assert result["prefill"]["w1"]["last_isl"] == 3000.0
assert "decode" in result
assert "w2" in result["decode"]
assert result["decode"]["w2"]["active_decode_blocks"] == 56.0
assert abs(result["decode"]["w2"]["last_itl"] - 0.04) < 1e-6
def test_parse_ignores_extra_labels(self):
"""Parser extracts metrics regardless of extra labels like dynamo_namespace/model."""
client = DirectRouterMetricsClient("http://localhost:8000/metrics", "ns")
text = 'dynamo_frontend_worker_active_prefill_tokens{dynamo_namespace="any-ns",model="mymodel",worker_type="prefill",worker_id="w1"} 100\n'
result = client._parse_prometheus_text(text)
assert "prefill" in result
assert "w1" in result["prefill"]
assert result["prefill"]["w1"]["active_prefill_tokens"] == 100.0
def test_get_recent_and_averaged_empty_buffer(self):
client = DirectRouterMetricsClient("http://localhost:8000/metrics", "ns")
assert client.get_recent_and_averaged_metrics("prefill") is None
def test_get_recent_and_averaged_single_sample(self):
client = DirectRouterMetricsClient("http://localhost:8000/metrics", "ns")
client._sample_buffer = [
{
"prefill": {"w1": {"active_prefill_tokens": 100.0}},
"decode": {"w2": {"active_decode_blocks": 50.0}},
}
]
result = client.get_recent_and_averaged_metrics("prefill")
assert result is not None
recent, per_worker_avg, cluster_avg = result
assert recent["w1"]["active_prefill_tokens"] == 100.0
assert per_worker_avg["w1"]["active_prefill_tokens"] == 100.0
assert cluster_avg["active_prefill_tokens"] == 100.0
# decode workers not included
assert "w2" not in recent
result_d = client.get_recent_and_averaged_metrics("decode")
assert result_d is not None
recent_d, per_worker_avg_d, cluster_avg_d = result_d
assert recent_d["w2"]["active_decode_blocks"] == 50.0
assert per_worker_avg_d["w2"]["active_decode_blocks"] == 50.0
assert cluster_avg_d["active_decode_blocks"] == 50.0
def test_get_recent_and_averaged_multiple_samples(self):
client = DirectRouterMetricsClient("http://localhost:8000/metrics", "ns")
client._sample_buffer = [
{"prefill": {"w1": {"active_prefill_tokens": 100.0}}},
{"prefill": {"w1": {"active_prefill_tokens": 200.0}}},
{"prefill": {"w1": {"active_prefill_tokens": 300.0}}},
]
result = client.get_recent_and_averaged_metrics("prefill")
assert result is not None
recent, per_worker_avg, cluster_avg = result
# Recent should be the last sample
assert abs(recent["w1"]["active_prefill_tokens"] - 300.0) < 1e-6
# Per-worker averaged over time
assert abs(per_worker_avg["w1"]["active_prefill_tokens"] - 200.0) < 1e-6
# Cluster averaged (same as per-worker when only 1 worker)
assert abs(cluster_avg["active_prefill_tokens"] - 200.0) < 1e-6
def test_parse_multiple_workers(self):
client = DirectRouterMetricsClient("http://localhost:8000/metrics", "ns")
text = (
'dynamo_frontend_worker_active_prefill_tokens{dynamo_namespace="ns",model="M",worker_type="prefill",worker_id="w1"} 100\n'
'dynamo_frontend_worker_active_prefill_tokens{dynamo_namespace="ns",model="M",worker_type="prefill",worker_id="w2"} 200\n'
) )
result = client._parse_prometheus_text(text) assert est is not None
assert len(result.get("prefill", {})) == 2 assert est > 0
assert result["prefill"]["w1"]["active_prefill_tokens"] == 100.0
assert result["prefill"]["w2"]["active_prefill_tokens"] == 200.0 def test_chunked_ttft_simulation(self):
model = PrefillRegressionModel(window_size=50, min_observations=3)
def test_parse_rust_labels_separates_worker_types(self): # Simple: wall_time = 0.001 * prefill_tokens (slope=0.001, intercept≈0)
"""Rust KV router emits all metrics for all workers; parser must separate by worker_type.""" for tokens in [100, 200, 300, 400, 500]:
client = DirectRouterMetricsClient("http://localhost:8000/metrics", "ns") fpm = _make_fpm(
text = ( sum_prefill_tokens=tokens,
"# HELP dynamo_frontend_worker_active_prefill_tokens Active prefill tokens\n" num_prefill_requests=1,
"# TYPE dynamo_frontend_worker_active_prefill_tokens gauge\n" wall_time=0.001 * tokens,
'dynamo_frontend_worker_active_prefill_tokens{worker_id="123",dp_rank="0",worker_type="prefill"} 500\n' )
'dynamo_frontend_worker_active_prefill_tokens{worker_id="456",dp_rank="0",worker_type="decode"} 0\n' model.add_observation(fpm)
'dynamo_frontend_worker_active_decode_blocks{worker_id="123",dp_rank="0",worker_type="prefill"} 0\n'
'dynamo_frontend_worker_active_decode_blocks{worker_id="456",dp_rank="0",worker_type="decode"} 30\n' # avg_isl = mean([100,200,300,400,500]) = 300
'dynamo_frontend_worker_last_time_to_first_token_seconds{worker_id="123",dp_rank="0",worker_type="prefill"} 0.15\n' # total_tokens = 5000 (queued) + 300 (next ISL) = 5300
'dynamo_frontend_worker_last_input_sequence_tokens{worker_id="123",dp_rank="0",worker_type="prefill"} 2000\n' # max_num_batched_tokens = 2048
'dynamo_frontend_worker_last_inter_token_latency_seconds{worker_id="456",dp_rank="0",worker_type="decode"} 0.03\n' # iterations: ceil(5300/2048) = 3
# chunk1=2048, chunk2=2048, chunk3=1204
est = model.estimate_next_ttft(
queued_prefill_tokens=5000, max_num_batched_tokens=2048
) )
result = client._parse_prometheus_text(text) assert est is not None
assert est > 0.003 # at least 3 iterations worth
def test_avg_isl_tracking(self):
model = PrefillRegressionModel(window_size=50, min_observations=3)
for isl in [1000, 2000, 3000]:
fpm = _make_fpm(
sum_prefill_tokens=isl, num_prefill_requests=1, wall_time=0.01
)
model.add_observation(fpm)
assert abs(model.avg_isl - 2000.0) < 1.0
def test_sliding_window_eviction(self):
model = PrefillRegressionModel(window_size=5, min_observations=3)
for i in range(10):
fpm = _make_fpm(sum_prefill_tokens=100 * (i + 1), wall_time=0.01)
model.add_observation(fpm)
assert model.num_observations == 5
# Prefill worker 123 grouped under "prefill"
assert "prefill" in result
assert "123" in result["prefill"]
assert result["prefill"]["123"]["active_prefill_tokens"] == 500.0
assert result["prefill"]["123"]["last_ttft"] == 0.15
assert result["prefill"]["123"]["last_isl"] == 2000.0
# Decode worker 456 grouped under "decode" # ── DecodeRegressionModel tests ──────────────────────────────────────
assert "decode" in result
assert "456" in result["decode"]
assert result["decode"]["456"]["active_decode_blocks"] == 30.0
assert abs(result["decode"]["456"]["last_itl"] - 0.03) < 1e-6
# Cross-type metrics are stored under the correct worker_type
# (prefill worker's decode_blocks=0 stored under "prefill", not "decode")
assert "456" not in result["prefill"]
assert "123" not in result["decode"]
class TestDecodeRegressionModel:
def test_insufficient_data(self):
model = DecodeRegressionModel(window_size=50, min_observations=5)
assert not model.has_sufficient_data()
assert model.estimate_next_itl(0, 0) is None
def test_heartbeat_skipped(self):
model = DecodeRegressionModel(window_size=50, min_observations=3)
fpm = _make_fpm(wall_time=0.0, sum_decode_kv_tokens=100, num_decode_requests=1)
model.add_observation(fpm)
assert model.num_observations == 0
def test_basic_itl_estimate(self):
model = DecodeRegressionModel(window_size=50, min_observations=3)
# wall_time = 0.0001 * decode_kv + 0.001
for kv in [1000, 2000, 3000, 4000, 5000]:
fpm = _make_fpm(
sum_decode_kv_tokens=kv,
num_decode_requests=10,
wall_time=0.0001 * kv + 0.001,
)
model.add_observation(fpm)
assert model.has_sufficient_data()
est = model.estimate_next_itl(scheduled_decode_kv=3000, queued_decode_kv=0)
assert est is not None
assert est > 0
def test_avg_decode_length_tracking(self):
model = DecodeRegressionModel(window_size=50, min_observations=3)
for total_kv, num_req in [(1000, 10), (2000, 10), (3000, 10)]:
fpm = _make_fpm(
sum_decode_kv_tokens=total_kv,
num_decode_requests=num_req,
wall_time=0.01,
)
model.add_observation(fpm)
assert abs(model.avg_decode_length - 200.0) < 1.0
# ── PrefillPlanner load-based scaling tests ─────────────────────────────
# ── AggRegressionModel tests ─────────────────────────────────────────
class TestAggRegressionModel:
def test_insufficient_data(self):
model = AggRegressionModel(window_size=50, min_observations=5)
assert not model.has_sufficient_data()
assert model.estimate_next_ttft(0, 2048, 0) is None
assert model.estimate_next_itl(0, 0) is None
def test_heartbeat_skipped(self):
model = AggRegressionModel(window_size=50, min_observations=3)
fpm = _make_fpm(wall_time=0.0, sum_prefill_tokens=100, sum_decode_kv_tokens=200)
model.add_observation(fpm)
assert model.num_observations == 0
def test_2d_regression(self):
model = AggRegressionModel(window_size=50, min_observations=3)
# wall_time = 0.001 * prefill + 0.0001 * decode_kv + 0.001
for p, d in [(100, 1000), (200, 2000), (300, 3000), (400, 4000), (500, 5000)]:
fpm = _make_fpm(
sum_prefill_tokens=p,
num_prefill_requests=1,
sum_decode_kv_tokens=d,
num_decode_requests=10,
wall_time=0.001 * p + 0.0001 * d + 0.001,
)
model.add_observation(fpm)
assert model.has_sufficient_data()
ttft = model.estimate_next_ttft(
queued_prefill_tokens=0,
max_num_batched_tokens=2048,
current_decode_kv=3000,
)
assert ttft is not None
assert ttft > 0
itl = model.estimate_next_itl(scheduled_decode_kv=3000, queued_decode_kv=0)
assert itl is not None
assert itl > 0
# ── Planner integration tests (with mocked FPM subscriber) ──────────
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -253,7 +264,6 @@ def _build_load_config(**overrides) -> PlannerConfig: ...@@ -253,7 +264,6 @@ def _build_load_config(**overrides) -> PlannerConfig:
mode="disagg", mode="disagg",
enable_load_scaling=True, enable_load_scaling=True,
enable_throughput_scaling=True, enable_throughput_scaling=True,
load_router_metrics_url="http://router:8000/metrics",
load_adjustment_interval=5, load_adjustment_interval=5,
load_learning_window=50, load_learning_window=50,
load_scaling_down_sensitivity=80, load_scaling_down_sensitivity=80,
...@@ -264,219 +274,155 @@ def _build_load_config(**overrides) -> PlannerConfig: ...@@ -264,219 +274,155 @@ def _build_load_config(**overrides) -> PlannerConfig:
return PlannerConfig.model_construct(**defaults) return PlannerConfig.model_construct(**defaults)
def _avg(per_worker: dict[str, dict[str, float]]) -> dict[str, float]: def _mock_fpm_subscriber(fpm_stats: dict[tuple[str, int], ForwardPassMetrics]):
"""Compute flat averaged metrics from per-worker dicts (for test convenience).""" """Create a mock FPM subscriber that returns encoded FPM stats."""
sums: dict[str, float] = {} mock = Mock()
counts: dict[str, int] = {} encoded = {k: encode(v) for k, v in fpm_stats.items()}
for metrics in per_worker.values(): mock.get_recent_stats.return_value = encoded
for k, v in metrics.items(): return mock
sums[k] = sums.get(k, 0.0) + v
counts[k] = counts.get(k, 0) + 1
return {k: sums[k] / counts[k] for k in sums}
class TestPrefillLoadBasedScaling: class TestPrefillFpmScaling:
def test_scale_up_all_workers_above_target(self): def test_scale_up_all_engines_above_sla(self):
"""When all workers have active_prefill_tokens above the regression target, scale up.""" """All engines have high queued prefill -> estimated TTFT > SLA -> scale up."""
config = _build_load_config() config = _build_load_config(ttft=5.0) # 5ms SLA (easy to exceed)
shared_state = PlannerSharedState() shared_state = PlannerSharedState()
shared_state.num_p_workers = 2 shared_state.num_p_workers = 2
planner = PrefillPlanner(None, config, shared_state=shared_state) planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model" planner.model_name = "test-model"
planner.prefill_worker_info = WorkerInfo(max_num_batched_tokens=2048)
# Feed regression data: TTFT = 0.1 * (active_prefill_tokens + ISL) + 100
# With TTFT SLA = 500ms: x_sla = (500 - 100) / 0.1 = 4000 # Train regression: wall_time grows linearly with prefill tokens
# If ISL avg = 3000, target_active_tokens = 4000 - 3000 = 1000 for tokens in range(200, 1200, 100):
for i in range(10): fpm = _make_fpm(
x = 2000 + i * 200 # active_tokens + ISL sum_prefill_tokens=tokens,
y = 0.1 * x + 100 # TTFT in ms num_prefill_requests=1,
planner.ttft_regression.add_observation(x, y) wall_time=0.001 * tokens,
)
# Set per-worker metrics: all workers ABOVE target (1000) planner.ttft_regression.add_observation(fpm)
metrics = {
"w1": { # Both engines have heavy queued prefill -> high estimated TTFT
"active_prefill_tokens": 1500.0, stats = {
"last_isl": 3000.0, ("w1", 0): _make_fpm(
"last_ttft": 0.35, worker_id="w1",
}, queued_prefill_tokens=10000,
"w2": { sum_prefill_tokens=500,
"active_prefill_tokens": 1200.0, num_prefill_requests=1,
"last_isl": 3000.0, wall_time=0.5,
"last_ttft": 0.30, ),
}, ("w2", 0): _make_fpm(
worker_id="w2",
queued_prefill_tokens=8000,
sum_prefill_tokens=600,
num_prefill_requests=1,
wall_time=0.6,
),
} }
planner.cached_load_metrics = CachedLoadMetrics( planner.fpm_subscriber = _mock_fpm_subscriber(stats)
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
result = planner.load_plan_adjustment() result = planner.load_plan_adjustment()
assert result == 3 # scale up from 2 to 3 assert result == 3
def test_scale_down_all_workers_below_boundary(self): def test_scale_down_all_engines_below_sla(self):
"""When all workers are below the scale-down boundary, scale down.""" """All engines have low queued prefill -> estimated TTFT < SLA * sensitivity."""
config = _build_load_config(load_scaling_down_sensitivity=100) config = _build_load_config(ttft=500.0, load_scaling_down_sensitivity=100)
shared_state = PlannerSharedState() shared_state = PlannerSharedState()
shared_state.num_p_workers = 3 shared_state.num_p_workers = 3
planner = PrefillPlanner(None, config, shared_state=shared_state) planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model" planner.model_name = "test-model"
planner.prefill_worker_info = WorkerInfo(max_num_batched_tokens=2048)
# Feed regression: TTFT = 0.1 * x + 100
# x_sla = (500-100)/0.1 = 4000, target = 4000-3000 = 1000 # Train with short ISL (100 tokens each) so avg_isl stays low.
# boundary = 1000 * (3-1)/3 * 1.0 = 666.67 # Regression: wall_time ≈ 0.001 * prefill_tokens
for i in range(10): for tokens in range(100, 600, 50):
x = 2000 + i * 200 fpm = _make_fpm(
y = 0.1 * x + 100 sum_prefill_tokens=tokens,
planner.ttft_regression.add_observation(x, y) num_prefill_requests=1,
wall_time=0.001 * tokens,
# All workers below boundary (666.67) )
metrics = { planner.ttft_regression.add_observation(fpm)
"w1": {
"active_prefill_tokens": 100.0, # All engines idle (no queued prefill).
"last_isl": 3000.0, # estimate_next_ttft: total = 0 + avg_isl(~100) = ~100 tokens
"last_ttft": 0.15, # predicted wall_time ≈ 0.001 * 100 = 0.1s = 100ms < 500ms SLA
}, stats = {
"w2": { (f"w{i}", 0): _make_fpm(
"active_prefill_tokens": 200.0, worker_id=f"w{i}",
"last_isl": 3000.0, queued_prefill_tokens=0,
"last_ttft": 0.16, sum_prefill_tokens=100,
}, num_prefill_requests=1,
"w3": { wall_time=0.1,
"active_prefill_tokens": 150.0, )
"last_isl": 3000.0, for i in range(3)
"last_ttft": 0.15,
},
}
planner.cached_load_metrics = CachedLoadMetrics(
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
result = planner.load_plan_adjustment()
assert result == 2 # scale down from 3 to 2
def test_no_change_mixed_workers(self):
"""When workers are mixed (some above, some below), no scaling."""
config = _build_load_config()
shared_state = PlannerSharedState()
shared_state.num_p_workers = 2
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
for i in range(10):
x = 2000 + i * 200
y = 0.1 * x + 100
planner.ttft_regression.add_observation(x, y)
# Mixed: one above target, one below
metrics = {
"w1": {
"active_prefill_tokens": 1500.0,
"last_isl": 3000.0,
"last_ttft": 0.35,
},
"w2": {
"active_prefill_tokens": 100.0,
"last_isl": 3000.0,
"last_ttft": 0.15,
},
} }
planner.cached_load_metrics = CachedLoadMetrics( planner.fpm_subscriber = _mock_fpm_subscriber(stats)
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
result = planner.load_plan_adjustment() result = planner.load_plan_adjustment()
assert result is None assert result == 2
def test_cold_start_returns_none(self): def test_cold_start_returns_none(self):
"""With insufficient data, load_plan_adjustment returns None."""
config = _build_load_config() config = _build_load_config()
shared_state = PlannerSharedState() shared_state = PlannerSharedState()
shared_state.num_p_workers = 2 shared_state.num_p_workers = 2
planner = PrefillPlanner(None, config, shared_state=shared_state) planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model" planner.model_name = "test-model"
planner.prefill_worker_info = WorkerInfo(max_num_batched_tokens=2048)
# Only 2 observations (min is 5) # Only 2 observations, need 5
planner.ttft_regression.add_observation(1000.0, 200.0) for tokens in [100, 200]:
planner.ttft_regression.add_observation(2000.0, 300.0) fpm = _make_fpm(sum_prefill_tokens=tokens, wall_time=0.01)
planner.ttft_regression.add_observation(fpm)
metrics = { stats = {("w1", 0): _make_fpm(queued_prefill_tokens=5000, wall_time=0.5)}
"w1": { planner.fpm_subscriber = _mock_fpm_subscriber(stats)
"active_prefill_tokens": 5000.0,
"last_isl": 3000.0,
"last_ttft": 0.5,
},
}
planner.cached_load_metrics = CachedLoadMetrics(
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
result = planner.load_plan_adjustment() result = planner.load_plan_adjustment()
assert result is None assert result is None
class TestDecodeLoadBasedScaling: class TestDecodeFpmScaling:
def test_scale_up_all_workers_above_target(self): def test_scale_up_all_engines_above_sla(self):
"""When all workers have active_decode_blocks above x_sla, scale up.""" """All engines have high decode load -> estimated ITL > SLA -> scale up."""
config = _build_load_config() config = _build_load_config(itl=5.0) # 5ms SLA
shared_state = PlannerSharedState() shared_state = PlannerSharedState()
shared_state.num_d_workers = 2 shared_state.num_d_workers = 2
planner = DecodePlanner(None, config, shared_state=shared_state) planner = DecodePlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model" planner.model_name = "test-model"
# Feed regression: ITL = 0.5 * active_decode_blocks + 10 for kv in range(1000, 6000, 500):
# x_sla = (50 - 10) / 0.5 = 80 fpm = _make_fpm(
for i in range(10): sum_decode_kv_tokens=kv,
x = 20 + i * 10 num_decode_requests=10,
y = 0.5 * x + 10 wall_time=0.0001 * kv + 0.001,
planner.itl_regression.add_observation(x, y) )
planner.itl_regression.add_observation(fpm)
# All workers above x_sla (80)
metrics = { stats = {
"w1": {"active_decode_blocks": 100.0, "last_itl": 0.06}, ("w1", 0): _make_fpm(
"w2": {"active_decode_blocks": 95.0, "last_itl": 0.055}, worker_id="w1",
sum_decode_kv_tokens=5000,
queued_decode_kv_tokens=3000,
num_decode_requests=20,
wall_time=0.6,
),
("w2", 0): _make_fpm(
worker_id="w2",
sum_decode_kv_tokens=4500,
queued_decode_kv_tokens=2500,
num_decode_requests=18,
wall_time=0.55,
),
} }
planner.cached_load_metrics = CachedLoadMetrics( planner.fpm_subscriber = _mock_fpm_subscriber(stats)
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
result = planner.load_plan_adjustment() result = planner.load_plan_adjustment()
assert result == 3 assert result == 3
def test_scale_down_all_workers_below_boundary(self):
"""When all decode workers are below boundary, scale down."""
config = _build_load_config(load_scaling_down_sensitivity=100)
shared_state = PlannerSharedState()
shared_state.num_d_workers = 3
planner = DecodePlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
# ITL = 0.5 * x + 10, x_sla = (50-10)/0.5 = 80
# boundary = 80 * (3-1)/3 * 1.0 = 53.33
for i in range(10):
x = 20 + i * 10
y = 0.5 * x + 10
planner.itl_regression.add_observation(x, y)
# All workers below boundary (53.33)
metrics = {
"w1": {"active_decode_blocks": 10.0, "last_itl": 0.02},
"w2": {"active_decode_blocks": 15.0, "last_itl": 0.025},
"w3": {"active_decode_blocks": 20.0, "last_itl": 0.03},
}
planner.cached_load_metrics = CachedLoadMetrics(
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
result = planner.load_plan_adjustment()
assert result == 2
def test_cold_start_returns_none(self): def test_cold_start_returns_none(self):
"""Decode cold start also returns None."""
config = _build_load_config() config = _build_load_config()
shared_state = PlannerSharedState() shared_state = PlannerSharedState()
shared_state.num_d_workers = 2 shared_state.num_d_workers = 2
...@@ -484,257 +430,40 @@ class TestDecodeLoadBasedScaling: ...@@ -484,257 +430,40 @@ class TestDecodeLoadBasedScaling:
planner = DecodePlanner(None, config, shared_state=shared_state) planner = DecodePlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model" planner.model_name = "test-model"
planner.itl_regression.add_observation(10.0, 15.0) fpm = _make_fpm(sum_decode_kv_tokens=1000, wall_time=0.01)
planner.itl_regression.add_observation(fpm)
metrics = { stats = {("w1", 0): _make_fpm(sum_decode_kv_tokens=5000, wall_time=0.5)}
"w1": {"active_decode_blocks": 200.0, "last_itl": 0.1}, planner.fpm_subscriber = _mock_fpm_subscriber(stats)
}
planner.cached_load_metrics = CachedLoadMetrics(
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
result = planner.load_plan_adjustment() result = planner.load_plan_adjustment()
assert result is None assert result is None
class TestLowerBoundEnforcement:
def test_throughput_lower_bound_respected(self):
"""Load-based scaling should never go below throughput lower bound."""
config = _build_load_config()
shared_state = PlannerSharedState()
shared_state.num_p_workers = 5
# Throughput says we need at least 4 prefill workers
shared_state.throughput_lower_bound_p = 4
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
# Regression says we should scale down to 4 (from 5)
for i in range(10):
x = 2000 + i * 200
y = 0.1 * x + 100
planner.ttft_regression.add_observation(x, y)
# Workers all lightly loaded => wants to scale down to 4
metrics = {
f"w{i}": {
"active_prefill_tokens": 50.0,
"last_isl": 3000.0,
"last_ttft": 0.12,
}
for i in range(5)
}
planner.cached_load_metrics = CachedLoadMetrics(
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
result = planner.load_plan_adjustment()
# Even though load-based wants to scale down, the result should be
# at least 4 after lower bound enforcement (done in the loop, not in
# load_plan_adjustment itself)
# load_plan_adjustment returns raw desired value
assert result == 4 # raw value from load-based
def test_scaling_down_sensitivity_zero_never_scales_down(self):
"""With sensitivity=0, scale-down boundary is 0 so never scale down."""
config = _build_load_config(load_scaling_down_sensitivity=0)
shared_state = PlannerSharedState()
shared_state.num_p_workers = 3
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
for i in range(10):
x = 2000 + i * 200
y = 0.1 * x + 100
planner.ttft_regression.add_observation(x, y)
# All workers at zero load
metrics = {
f"w{i}": {
"active_prefill_tokens": 0.0,
"last_isl": 3000.0,
"last_ttft": 0.12,
}
for i in range(3)
}
planner.cached_load_metrics = CachedLoadMetrics(
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
# boundary = target * (3-1)/3 * 0/100 = 0
# all workers at 0 which is NOT less than 0 (it's equal)
result = planner.load_plan_adjustment()
assert result is None # no scaling happens
# ── Correction factor auto-disable tests ───────────────────────────── # ── Correction factor auto-disable tests ─────────────────────────────
class TestCorrectionFactorAutoDisable: class TestCorrectionFactorAutoDisable:
def test_correction_factor_disabled_when_load_enabled(self): def test_correction_factor_disabled_when_load_enabled(self):
"""Correction factor should be auto-disabled when load-based scaling is on."""
config = PlannerConfig( config = PlannerConfig(
enable_load_scaling=True, enable_load_scaling=True,
enable_throughput_scaling=True, enable_throughput_scaling=True,
no_correction=False, no_correction=False,
load_router_metrics_url="http://router:8000/metrics",
) )
assert config.no_correction is True assert config.no_correction is True
def test_correction_factor_stays_disabled_if_already_set(self): def test_correction_factor_stays_disabled_if_already_set(self):
"""If user already set no_correction, it stays True."""
config = PlannerConfig( config = PlannerConfig(
enable_load_scaling=True, enable_load_scaling=True,
enable_throughput_scaling=True, enable_throughput_scaling=True,
no_correction=True, no_correction=True,
load_router_metrics_url="http://router:8000/metrics",
) )
assert config.no_correction is True assert config.no_correction is True
def test_correction_factor_not_disabled_without_loadbased(self): def test_correction_factor_not_disabled_without_loadbased(self):
"""Without load-based scaling, correction factor should respect user setting."""
config = PlannerConfig( config = PlannerConfig(
enable_load_scaling=False, enable_load_scaling=False,
enable_throughput_scaling=True, enable_throughput_scaling=True,
no_correction=False, no_correction=False,
) )
assert config.no_correction is False assert config.no_correction is False
# ── DGD worker count reconciliation tests ────────────────────────────
class TestWorkerCountReconciliation:
@pytest.mark.asyncio
async def test_prefill_observe_gets_only_prefill_workers(self):
"""observe_engine_load_stats for prefill queries get_recent_and_averaged_metrics('prefill')."""
config = _build_load_config()
shared_state = PlannerSharedState()
shared_state.num_p_workers = 1
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
# get_recent_and_averaged_metrics("prefill") returns (recent, per_worker_avg, cluster_avg)
prefill_metrics = {
"w1": {
"active_prefill_tokens": 500.0,
"last_ttft": 0.2,
"last_isl": 3000.0,
},
}
planner.prometheus_engine_client = Mock()
planner.prometheus_engine_client.get_recent_and_averaged_metrics.return_value = (
prefill_metrics,
prefill_metrics,
_avg(prefill_metrics),
)
await planner.observe_engine_load_stats()
planner.prometheus_engine_client.get_recent_and_averaged_metrics.assert_called_once_with(
"prefill"
)
assert len(planner.cached_load_metrics.recent) == 1
assert "w1" in planner.cached_load_metrics.recent
@pytest.mark.asyncio
async def test_decode_observe_gets_only_decode_workers(self):
"""observe_engine_load_stats for decode queries get_recent_and_averaged_metrics('decode')."""
config = _build_load_config()
shared_state = PlannerSharedState()
shared_state.num_d_workers = 1
planner = DecodePlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
decode_metrics = {
"w2": {"active_decode_blocks": 50.0, "last_itl": 0.04},
}
planner.prometheus_engine_client = Mock()
planner.prometheus_engine_client.get_recent_and_averaged_metrics.return_value = (
decode_metrics,
decode_metrics,
_avg(decode_metrics),
)
await planner.observe_engine_load_stats()
planner.prometheus_engine_client.get_recent_and_averaged_metrics.assert_called_once_with(
"decode"
)
assert len(planner.cached_load_metrics.recent) == 1
assert "w2" in planner.cached_load_metrics.recent
def test_worker_count_mismatch_detected(self):
"""When DGD and Prometheus worker counts differ, the mismatch should be detectable."""
config = _build_load_config()
shared_state = PlannerSharedState()
# DGD says 3 prefill workers
shared_state.num_p_workers = 3
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
# But router only reports 2 prefill workers
metrics = {
"w1": {
"active_prefill_tokens": 500.0,
"last_isl": 3000.0,
"last_ttft": 0.2,
},
"w2": {
"active_prefill_tokens": 600.0,
"last_isl": 3000.0,
"last_ttft": 0.25,
},
}
planner.cached_load_metrics = CachedLoadMetrics(
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
# The mismatch should be detectable by comparing counts
prom_count = len(planner.cached_load_metrics.recent)
dgd_count = shared_state.num_p_workers
assert prom_count != dgd_count
assert prom_count == 2
assert dgd_count == 3
def test_worker_count_match_allows_scaling(self):
"""When DGD and Prometheus counts match, scaling proceeds normally."""
config = _build_load_config()
shared_state = PlannerSharedState()
shared_state.num_p_workers = 2
planner = PrefillPlanner(None, config, shared_state=shared_state)
planner.model_name = "test-model"
metrics = {
"w1": {
"active_prefill_tokens": 1500.0,
"last_isl": 3000.0,
"last_ttft": 0.35,
},
"w2": {
"active_prefill_tokens": 1200.0,
"last_isl": 3000.0,
"last_ttft": 0.30,
},
}
planner.cached_load_metrics = CachedLoadMetrics(
recent=metrics, per_worker_averaged=metrics, cluster_averaged=_avg(metrics)
)
prom_count = len(planner.cached_load_metrics.recent)
dgd_count = shared_state.num_p_workers
assert prom_count == dgd_count
# With matching counts and sufficient regression data, scaling should work
for i in range(10):
x = 2000 + i * 200
y = 0.1 * x + 100
planner.ttft_regression.add_observation(x, y)
result = planner.load_plan_adjustment()
assert result is not None # scaling proceeds
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment