"tests/vscode:/vscode.git/clone" did not exist on "05f10e93f0640d18c729208d9a54995a2d708e5d"
Unverified Commit 10b01b45 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: early rejection based on active prefill tokens (#4837)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 82577b06
...@@ -190,10 +190,16 @@ def parse_args(): ...@@ -190,10 +190,16 @@ def parse_args():
help="Enforce disaggregated prefill-decode. When set, unactivated prefill router will return an error instead of falling back to decode-only mode.", help="Enforce disaggregated prefill-decode. When set, unactivated prefill router will return an error instead of falling back to decode-only mode.",
) )
parser.add_argument( parser.add_argument(
"--busy-threshold", "--active-decode-blocks-threshold",
type=float, type=float,
default=None, default=None,
help="Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. If not set, busy detection is disabled.", help="Threshold percentage (0.0-1.0) for determining when a worker is considered busy based on KV cache block utilization. If not set, blocks-based busy detection is disabled.",
)
parser.add_argument(
"--active-prefill-tokens-threshold",
type=int,
default=None,
help="Literal token count threshold for determining when a worker is considered busy based on prefill token utilization. When active prefill tokens exceed this threshold, the worker is marked as busy. If not set, tokens-based busy detection is disabled.",
) )
parser.add_argument( parser.add_argument(
"--model-name", "--model-name",
...@@ -316,7 +322,11 @@ async def async_main(): ...@@ -316,7 +322,11 @@ async def async_main():
"http_port": flags.http_port, "http_port": flags.http_port,
"kv_cache_block_size": flags.kv_cache_block_size, "kv_cache_block_size": flags.kv_cache_block_size,
"router_config": RouterConfig( "router_config": RouterConfig(
router_mode, kv_router_config, flags.busy_threshold, flags.enforce_disagg router_mode,
kv_router_config,
active_decode_blocks_threshold=flags.active_decode_blocks_threshold,
active_prefill_tokens_threshold=flags.active_prefill_tokens_threshold,
enforce_disagg=flags.enforce_disagg,
), ),
} }
......
...@@ -20,7 +20,6 @@ import os ...@@ -20,7 +20,6 @@ import os
from typing import Optional from typing import Optional
import numpy as np import numpy as np
import scipy
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -80,6 +79,9 @@ class PrefillInterpolator: ...@@ -80,6 +79,9 @@ class PrefillInterpolator:
self.min_isl = min(self.prefill_isl) self.min_isl = min(self.prefill_isl)
self.max_isl = max(self.prefill_isl) self.max_isl = max(self.prefill_isl)
# Lazy import scipy only when interpolation is actually needed
import scipy.interpolate
# perform 1d interpolation # perform 1d interpolation
self.ttft_interpolator = scipy.interpolate.interp1d( self.ttft_interpolator = scipy.interpolate.interp1d(
self.prefill_isl, self.prefill_ttft, kind="cubic" self.prefill_isl, self.prefill_ttft, kind="cubic"
...@@ -151,6 +153,9 @@ class DecodeInterpolator: ...@@ -151,6 +153,9 @@ class DecodeInterpolator:
self.yi = np.linspace(0, max(self.y_context_length), resolution) self.yi = np.linspace(0, max(self.y_context_length), resolution)
self.X, self.Y = np.meshgrid(self.xi, self.yi) self.X, self.Y = np.meshgrid(self.xi, self.yi)
# Lazy import scipy only when interpolation is actually needed
import scipy.interpolate
# perform 2d interpolation with fallback for NaN values # perform 2d interpolation with fallback for NaN values
self.itl_interpolator = scipy.interpolate.griddata( self.itl_interpolator = scipy.interpolate.griddata(
(self.x_kv_usage, self.y_context_length), (self.x_kv_usage, self.y_context_length),
......
...@@ -31,7 +31,9 @@ The main KV-aware routing arguments: ...@@ -31,7 +31,9 @@ The main KV-aware routing arguments:
- `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management. - `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management.
- `--busy-threshold`: Initial threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`. The threshold can be dynamically updated at runtime via the `/busy_threshold` HTTP endpoint (see [Dynamic Threshold Configuration](#dynamic-threshold-configuration)). - `--active-decode-blocks-threshold`: Initial threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache block utilization. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, blocks-based busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`. The threshold can be dynamically updated at runtime via the `/busy_threshold` HTTP endpoint (see [Dynamic Threshold Configuration](#dynamic-threshold-configuration)).
- `--active-prefill-tokens-threshold`: Literal token count threshold for determining when a worker is considered busy based on prefill token utilization. When active prefill tokens exceed this threshold, the worker is marked as busy. If not set, tokens-based busy detection is disabled.
- `--router-ttl`: Time-to-live in seconds for blocks in the router's local cache predictions. Blocks older than this duration will be automatically expired and removed from the router's radix tree. Defaults to 120.0 seconds when `--no-kv-events` is used. This helps manage memory usage by removing stale cache predictions that are unlikely to be accurate. - `--router-ttl`: Time-to-live in seconds for blocks in the router's local cache predictions. Blocks older than this duration will be automatically expired and removed from the router's radix tree. Defaults to 120.0 seconds when `--no-kv-events` is used. This helps manage memory usage by removing stale cache predictions that are unlikely to be accurate.
...@@ -585,28 +587,32 @@ See [KV Router Architecture](../router/README.md) for performance tuning details ...@@ -585,28 +587,32 @@ See [KV Router Architecture](../router/README.md) for performance tuning details
## Dynamic Threshold Configuration ## Dynamic Threshold Configuration
The busy threshold can be updated at runtime without restarting the frontend. The frontend exposes HTTP endpoints at `/busy_threshold`: The busy thresholds can be updated at runtime without restarting the frontend. The frontend exposes HTTP endpoints at `/busy_threshold`:
**Get or set a model's threshold (POST):** **Get or set a model's thresholds (POST):**
```bash ```bash
# Set threshold for a model # Set both thresholds for a model
curl -X POST http://localhost:8000/busy_threshold \ curl -X POST http://localhost:8000/busy_threshold \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}' -d '{"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85} # Response: {"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
# Get current threshold (omit threshold field) # Set only active decode blocks threshold
curl -X POST http://localhost:8000/busy_threshold \
-H "Content-Type: application/json" \
-d '{"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": <current_value>}
# Get current thresholds (omit threshold fields)
curl -X POST http://localhost:8000/busy_threshold \ curl -X POST http://localhost:8000/busy_threshold \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{"model": "meta-llama/Llama-2-7b-hf"}' -d '{"model": "meta-llama/Llama-2-7b-hf"}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85} # Response: {"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
# Or if not configured: {"model": "...", "threshold": null} # Or if not configured: {"model": "...", "active_decode_blocks_threshold": null, "active_prefill_tokens_threshold": null}
``` ```
**List all configured thresholds (GET):** **List all configured thresholds (GET):**
```bash ```bash
curl http://localhost:8000/busy_threshold curl http://localhost:8000/busy_threshold
# Response: {"thresholds": [{"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}]} # Response: {"thresholds": [{"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}]}
``` ```
This allows you to tune the busy threshold based on observed system behavior without service interruption.
\ No newline at end of file
...@@ -966,7 +966,9 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -966,7 +966,9 @@ pub async fn create_worker_selection_pipeline_chat(
let router_config = dynamo_llm::entrypoint::RouterConfig { let router_config = dynamo_llm::entrypoint::RouterConfig {
router_mode, router_mode,
kv_router_config: kv_router_config.unwrap_or_default(), kv_router_config: kv_router_config.unwrap_or_default(),
busy_threshold, // C bindings only support active_decode_blocks_threshold for now (via busy_threshold param)
active_decode_blocks_threshold: busy_threshold,
active_prefill_tokens_threshold: None,
enforce_disagg: false, enforce_disagg: false,
}; };
let watcher = ModelWatcher::new( let watcher = ModelWatcher::new(
...@@ -1031,7 +1033,8 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -1031,7 +1033,8 @@ pub async fn create_worker_selection_pipeline_chat(
// Create worker monitor if busy_threshold is set // Create worker monitor if busy_threshold is set
// Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this // Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t)); // C bindings only support active_decode_blocks_threshold for now (active_prefill_tokens_threshold defaults to 1000000 tokens = effectively disabled)
let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t, 1000000));
let engine = build_routed_pipeline::< let engine = build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
......
...@@ -77,24 +77,29 @@ impl KvRouterConfig { ...@@ -77,24 +77,29 @@ impl KvRouterConfig {
pub struct RouterConfig { pub struct RouterConfig {
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: KvRouterConfig, kv_router_config: KvRouterConfig,
busy_threshold: Option<f64>, /// Threshold for active decode blocks utilization (0.0-1.0)
active_decode_blocks_threshold: Option<f64>,
/// Threshold for active prefill tokens utilization (literal token count)
active_prefill_tokens_threshold: Option<u64>,
enforce_disagg: bool, enforce_disagg: bool,
} }
#[pymethods] #[pymethods]
impl RouterConfig { impl RouterConfig {
#[new] #[new]
#[pyo3(signature = (mode, config=None, busy_threshold=None, enforce_disagg=false))] #[pyo3(signature = (mode, config=None, active_decode_blocks_threshold=None, active_prefill_tokens_threshold=None, enforce_disagg=false))]
pub fn new( pub fn new(
mode: RouterMode, mode: RouterMode,
config: Option<KvRouterConfig>, config: Option<KvRouterConfig>,
busy_threshold: Option<f64>, active_decode_blocks_threshold: Option<f64>,
active_prefill_tokens_threshold: Option<u64>,
enforce_disagg: bool, enforce_disagg: bool,
) -> Self { ) -> Self {
Self { Self {
router_mode: mode, router_mode: mode,
kv_router_config: config.unwrap_or_default(), kv_router_config: config.unwrap_or_default(),
busy_threshold, active_decode_blocks_threshold,
active_prefill_tokens_threshold,
enforce_disagg, enforce_disagg,
} }
} }
...@@ -105,7 +110,8 @@ impl From<RouterConfig> for RsRouterConfig { ...@@ -105,7 +110,8 @@ impl From<RouterConfig> for RsRouterConfig {
RsRouterConfig { RsRouterConfig {
router_mode: rc.router_mode.into(), router_mode: rc.router_mode.into(),
kv_router_config: rc.kv_router_config.inner, kv_router_config: rc.kv_router_config.inner,
busy_threshold: rc.busy_threshold, active_decode_blocks_threshold: rc.active_decode_blocks_threshold,
active_prefill_tokens_threshold: rc.active_prefill_tokens_threshold,
enforce_disagg: rc.enforce_disagg, enforce_disagg: rc.enforce_disagg,
} }
} }
......
...@@ -487,9 +487,11 @@ impl ModelManager { ...@@ -487,9 +487,11 @@ impl ModelManager {
/// Gets or sets the busy threshold for a model via its worker monitor. /// Gets or sets the busy threshold for a model via its worker monitor.
/// ///
/// Get or set the active decode blocks threshold for a model's worker monitor.
///
/// This is the primary API for HTTP endpoints and external callers. /// This is the primary API for HTTP endpoints and external callers.
/// The threshold (0.0 to 1.0) controls when workers are marked as "busy" /// The threshold (0.0 to 1.0) controls when workers are marked as "busy"
/// based on KV cache utilization. /// based on KV cache block utilization.
/// ///
/// # Arguments /// # Arguments
/// ///
...@@ -499,31 +501,63 @@ impl ModelManager { ...@@ -499,31 +501,63 @@ impl ModelManager {
/// # Returns /// # Returns
/// ///
/// The threshold value as f64, or `None` if no monitor exists for this model. /// The threshold value as f64, or `None` if no monitor exists for this model.
/// Note: Setting a threshold for a non-existent model returns `None` (monitor pub fn active_decode_blocks_threshold(
/// must be created via `get_or_create_worker_monitor` during model discovery). &self,
pub fn busy_threshold(&self, model: &str, threshold: Option<f64>) -> Option<f64> { model: &str,
threshold: Option<f64>,
) -> Option<f64> {
let monitors = self.worker_monitors.read(); let monitors = self.worker_monitors.read();
let monitor = monitors.get(model)?; let monitor = monitors.get(model)?;
match threshold { match threshold {
Some(value) => { Some(value) => {
monitor.set_threshold(value); monitor.set_active_decode_blocks_threshold(value);
Some(value) Some(value)
} }
None => Some(monitor.threshold()), None => Some(monitor.active_decode_blocks_threshold()),
}
}
/// Get or set the active prefill tokens threshold for a model's worker monitor.
///
/// The threshold is a literal token count (not a percentage).
///
/// # Arguments
///
/// * `model` - The model name
/// * `threshold` - `Some(value)` to set, `None` to get existing
///
/// # Returns
///
/// The threshold value as u64, or `None` if no monitor exists for this model.
pub fn active_prefill_tokens_threshold(
&self,
model: &str,
threshold: Option<u64>,
) -> Option<u64> {
let monitors = self.worker_monitors.read();
let monitor = monitors.get(model)?;
match threshold {
Some(value) => {
monitor.set_active_prefill_tokens_threshold(value);
Some(value)
}
None => Some(monitor.active_prefill_tokens_threshold()),
} }
} }
/// Gets or creates a worker monitor for a model. /// Gets or creates a worker monitor for a model.
/// ///
/// If a monitor already exists, updates its threshold and returns a clone. /// If a monitor already exists, updates its thresholds and returns a clone.
/// If no monitor exists, creates one with the given client and threshold. /// If no monitor exists, creates one with the given client and thresholds.
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `model` - The model name /// * `model` - The model name
/// * `client` - The client for subscribing to KV metrics (only used if creating new) /// * `client` - The client for subscribing to KV metrics (only used if creating new)
/// * `threshold` - The initial/updated threshold value (0.0-1.0) /// * `active_decode_blocks_threshold` - The initial/updated active decode blocks threshold value (0.0-1.0)
/// * `active_prefill_tokens_threshold` - The initial/updated active prefill tokens threshold value (literal token count)
/// ///
/// # Returns /// # Returns
/// ///
...@@ -532,15 +566,21 @@ impl ModelManager { ...@@ -532,15 +566,21 @@ impl ModelManager {
&self, &self,
model: &str, model: &str,
client: Client, client: Client,
threshold: f64, active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
) -> KvWorkerMonitor { ) -> KvWorkerMonitor {
let mut monitors = self.worker_monitors.write(); let mut monitors = self.worker_monitors.write();
if let Some(existing) = monitors.get(model) { if let Some(existing) = monitors.get(model) {
existing.set_threshold(threshold); existing.set_active_decode_blocks_threshold(active_decode_blocks_threshold);
existing.set_active_prefill_tokens_threshold(active_prefill_tokens_threshold);
existing.clone() existing.clone()
} else { } else {
let monitor = KvWorkerMonitor::new(client, threshold); let monitor = KvWorkerMonitor::new(
client,
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
);
monitors.insert(model.to_string(), monitor.clone()); monitors.insert(model.to_string(), monitor.clone());
monitor monitor
} }
...@@ -553,12 +593,18 @@ impl ModelManager { ...@@ -553,12 +593,18 @@ impl ModelManager {
/// Lists all models that have worker monitors (and thus busy thresholds) configured. /// Lists all models that have worker monitors (and thus busy thresholds) configured.
/// ///
/// Returns a vector of (model_name, threshold_value) tuples. /// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
pub fn list_busy_thresholds(&self) -> Vec<(String, f64)> { pub fn list_busy_thresholds(&self) -> Vec<(String, f64, u64)> {
self.worker_monitors self.worker_monitors
.read() .read()
.iter() .iter()
.map(|(k, monitor)| (k.clone(), monitor.threshold())) .map(|(k, monitor)| {
(
k.clone(),
monitor.active_decode_blocks_threshold(),
monitor.active_prefill_tokens_threshold(),
)
})
.collect() .collect()
} }
} }
......
...@@ -404,10 +404,28 @@ impl ModelWatcher { ...@@ -404,10 +404,28 @@ impl ModelWatcher {
// Get or create the worker monitor for this model // Get or create the worker monitor for this model
// This allows dynamic threshold updates via the ModelManager // This allows dynamic threshold updates via the ModelManager
let worker_monitor = self.router_config.busy_threshold.map(|threshold| { // Create monitor if either threshold is configured
self.manager let worker_monitor = if self.router_config.active_decode_blocks_threshold.is_some()
.get_or_create_worker_monitor(card.name(), client.clone(), threshold) || self.router_config.active_prefill_tokens_threshold.is_some()
}); {
// Default thresholds: active_decode_blocks=1.0 (disabled), active_prefill_tokens=1000000 (effectively disabled)
let active_decode_blocks = self
.router_config
.active_decode_blocks_threshold
.unwrap_or(1.0);
let active_prefill_tokens = self
.router_config
.active_prefill_tokens_threshold
.unwrap_or(1000000);
Some(self.manager.get_or_create_worker_monitor(
card.name(),
client.clone(),
active_decode_blocks,
active_prefill_tokens,
))
} else {
None
};
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if card.model_type.supports_chat() { if card.model_type.supports_chat() {
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::kv_router::KV_METRICS_SUBJECT; use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::scoring::LoadEvent; use crate::kv_router::protocols::ActiveLoad;
use crate::model_card::ModelDeploymentCard; use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::component::Client; use dynamo_runtime::component::Client;
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field}; use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
...@@ -10,7 +10,7 @@ use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait}; ...@@ -10,7 +10,7 @@ use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber; use dynamo_runtime::traits::events::EventSubscriber;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
...@@ -20,35 +20,62 @@ const THRESHOLD_SCALE: u32 = 10000; ...@@ -20,35 +20,62 @@ const THRESHOLD_SCALE: u32 = 10000;
/// Worker load monitoring state per dp_rank /// Worker load monitoring state per dp_rank
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct WorkerLoadState { pub struct WorkerLoadState {
pub kv_active_blocks: HashMap<u32, u64>, pub active_decode_blocks: HashMap<u32, u64>,
pub kv_total_blocks: HashMap<u32, u64>, pub kv_total_blocks: HashMap<u32, u64>,
pub active_prefill_tokens: HashMap<u32, u64>,
} }
impl WorkerLoadState { impl WorkerLoadState {
/// Returns true if ALL dp_ranks (that have data in both maps) exceed the threshold /// Returns true if ALL dp_ranks are considered busy based on the dual-threshold logic:
pub fn is_busy(&self, threshold: f64) -> bool { ///
// Get all dp_ranks that exist in both active and total blocks /// For each dp_rank:
let common_dp_ranks: Vec<_> = self /// 1. If `active_prefill_tokens` is available, check if tokens exceed the literal threshold.
.kv_active_blocks /// If so, that dp_rank is busy.
/// 2. If not, check if `active_decode_blocks` and `kv_total_blocks` are both available,
/// and if blocks exceed threshold. If so, that dp_rank is busy.
/// 3. If neither check can be performed (missing data), that dp_rank is considered free.
///
/// The worker is busy only if ALL dp_ranks are busy.
pub fn is_busy(
&self,
active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
) -> bool {
// Get all dp_ranks we know about
let all_dp_ranks: std::collections::HashSet<_> = self
.active_decode_blocks
.keys() .keys()
.filter(|dp_rank| self.kv_total_blocks.contains_key(dp_rank)) .chain(self.active_prefill_tokens.keys())
.copied()
.collect(); .collect();
// If no common dp_ranks, not busy // If no dp_ranks known, not busy
if common_dp_ranks.is_empty() { if all_dp_ranks.is_empty() {
return false; return false;
} }
// Check if ALL common dp_ranks exceed threshold // Check if ALL dp_ranks are busy
common_dp_ranks.iter().all(|&&dp_rank| { all_dp_ranks.iter().all(|&dp_rank| {
if let (Some(&active), Some(&total)) = ( // First check: prefill tokens threshold (literal token count)
self.kv_active_blocks.get(&dp_rank), if let Some(&active_tokens) = self.active_prefill_tokens.get(&dp_rank)
&& active_tokens > active_prefill_tokens_threshold
{
return true; // This dp_rank is busy due to tokens
}
// Second check: blocks threshold
// Skip if total_blocks is 0 (no capacity means threshold check is meaningless)
if let (Some(&active_blocks), Some(&total_blocks)) = (
self.active_decode_blocks.get(&dp_rank),
self.kv_total_blocks.get(&dp_rank), self.kv_total_blocks.get(&dp_rank),
) { ) && total_blocks > 0
total > 0 && (active as f64) > (threshold * total as f64) && (active_blocks as f64) > (active_decode_blocks_threshold * total_blocks as f64)
} else { {
false return true; // This dp_rank is busy due to blocks
} }
// If we can't perform either check, this dp_rank is considered free
false
}) })
} }
} }
...@@ -61,47 +88,76 @@ impl WorkerLoadState { ...@@ -61,47 +88,76 @@ impl WorkerLoadState {
pub struct KvWorkerMonitor { pub struct KvWorkerMonitor {
client: Client, client: Client,
worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>, worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>,
/// Threshold stored as parts-per-10000 (e.g., 8500 = 0.85) /// Active decode blocks threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
busy_threshold: Arc<AtomicU32>, active_decode_blocks_threshold: Arc<AtomicU32>,
/// Active prefill tokens threshold stored as literal token count (u64)
active_prefill_tokens_threshold: Arc<AtomicU64>,
/// Guard to ensure start_monitoring() only runs once across clones /// Guard to ensure start_monitoring() only runs once across clones
started: Arc<AtomicBool>, started: Arc<AtomicBool>,
} }
impl KvWorkerMonitor { impl KvWorkerMonitor {
/// Create a new worker monitor with the given threshold. /// Create a new worker monitor with the given thresholds.
///
/// - `active_decode_blocks_threshold` (0.0-1.0): Threshold percentage for KV cache block utilization
/// - `active_prefill_tokens_threshold`: Literal token count threshold for prefill token utilization
/// ///
/// The threshold (0.0-1.0) controls when workers are considered busy based on /// Both thresholds can be dynamically updated via `set_active_decode_blocks_threshold()` and
/// KV cache utilization. It can be dynamically updated via `set_threshold()`. /// `set_active_prefill_tokens_threshold()`.
pub fn new(client: Client, threshold: f64) -> Self { pub fn new(
client: Client,
active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
) -> Self {
Self { Self {
client, client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())), worker_load_states: Arc::new(RwLock::new(HashMap::new())),
busy_threshold: Arc::new(AtomicU32::new(Self::threshold_to_scaled(threshold))), active_decode_blocks_threshold: Arc::new(AtomicU32::new(
Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold),
)),
active_prefill_tokens_threshold: Arc::new(AtomicU64::new(
active_prefill_tokens_threshold,
)),
started: Arc::new(AtomicBool::new(false)), started: Arc::new(AtomicBool::new(false)),
} }
} }
/// Convert a f64 threshold (0.0-1.0) to scaled u32 for atomic storage. /// Convert a f64 active decode blocks threshold (0.0-1.0) to scaled u32 for atomic storage.
#[inline] #[inline]
fn threshold_to_scaled(threshold: f64) -> u32 { fn active_decode_blocks_threshold_to_scaled(threshold: f64) -> u32 {
(threshold * THRESHOLD_SCALE as f64) as u32 (threshold * THRESHOLD_SCALE as f64) as u32
} }
/// Convert a scaled u32 back to f64 threshold (0.0-1.0). /// Convert a scaled u32 back to f64 active decode blocks threshold (0.0-1.0).
#[inline] #[inline]
fn scaled_to_threshold(scaled: u32) -> f64 { fn scaled_to_active_decode_blocks_threshold(scaled: u32) -> f64 {
scaled as f64 / THRESHOLD_SCALE as f64 scaled as f64 / THRESHOLD_SCALE as f64
} }
/// Get the current threshold value as f64. /// Get the current active decode blocks threshold value as f64.
pub fn threshold(&self) -> f64 { pub fn active_decode_blocks_threshold(&self) -> f64 {
Self::scaled_to_threshold(self.busy_threshold.load(Ordering::Relaxed)) Self::scaled_to_active_decode_blocks_threshold(
self.active_decode_blocks_threshold.load(Ordering::Relaxed),
)
} }
/// Set the threshold value from f64. /// Set the active decode blocks threshold value from f64.
pub fn set_threshold(&self, threshold: f64) { pub fn set_active_decode_blocks_threshold(&self, threshold: f64) {
self.busy_threshold self.active_decode_blocks_threshold.store(
.store(Self::threshold_to_scaled(threshold), Ordering::Relaxed); Self::active_decode_blocks_threshold_to_scaled(threshold),
Ordering::Relaxed,
);
}
/// Get the current active prefill tokens threshold value as u64.
pub fn active_prefill_tokens_threshold(&self) -> u64 {
self.active_prefill_tokens_threshold.load(Ordering::Relaxed)
}
/// Set the active prefill tokens threshold value from u64.
pub fn set_active_prefill_tokens_threshold(&self, threshold: u64) {
self.active_prefill_tokens_threshold
.store(threshold, Ordering::Relaxed);
} }
/// Get the worker load states for external access /// Get the worker load states for external access
...@@ -143,7 +199,8 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -143,7 +199,8 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let worker_load_states = self.worker_load_states.clone(); let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone(); let client = self.client.clone();
let busy_threshold = self.busy_threshold.clone(); let active_decode_blocks_threshold = self.active_decode_blocks_threshold.clone();
let active_prefill_tokens_threshold = self.active_prefill_tokens_threshold.clone();
// Spawn background monitoring task // Spawn background monitoring task
tokio::spawn(async move { tokio::spawn(async move {
...@@ -176,34 +233,46 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -176,34 +233,46 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
} }
} }
// Handle KV metrics updates // Handle KV metrics updates (ActiveLoad)
kv_event = kv_metrics_rx.next() => { kv_event = kv_metrics_rx.next() => {
let Some(event) = kv_event else { let Some(event) = kv_event else {
tracing::debug!("KV metrics stream closed"); tracing::debug!("KV metrics stream closed");
break; break;
}; };
if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) { let Ok(active_load) = serde_json::from_slice::<ActiveLoad>(&event.payload) else {
let worker_id = load_event.worker_id; continue;
let active_blocks = load_event.data.kv_stats.kv_active_blocks; };
let dp_rank = load_event.data.worker_stats.data_parallel_rank.unwrap_or(0);
let worker_id = active_load.worker_id;
let dp_rank = active_load.dp_rank;
// Update worker load state per dp_rank // Update worker load state per dp_rank
let mut states = worker_load_states.write().unwrap(); let mut states = worker_load_states.write().unwrap();
let state = states.entry(worker_id).or_default(); let state = states.entry(worker_id).or_default();
state.kv_active_blocks.insert(dp_rank, active_blocks);
if let Some(active_blocks) = active_load.active_decode_blocks {
state.active_decode_blocks.insert(dp_rank, active_blocks);
}
if let Some(active_tokens) = active_load.active_prefill_tokens {
state.active_prefill_tokens.insert(dp_rank, active_tokens);
}
drop(states); drop(states);
// Load threshold dynamically - allows runtime updates // Load thresholds dynamically - allows runtime updates
let scaled_threshold = busy_threshold.load(Ordering::Relaxed); let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold(
let current_threshold = Self::scaled_to_threshold(scaled_threshold); active_decode_blocks_threshold.load(Ordering::Relaxed),
);
let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed);
// Recalculate all busy instances and update // Recalculate all busy instances and update
let states = worker_load_states.read().unwrap(); let states = worker_load_states.read().unwrap();
let busy_instances: Vec<u64> = states let busy_instances: Vec<u64> = states
.iter() .iter()
.filter_map(|(&id, state)| { .filter_map(|(&id, state)| {
state.is_busy(current_threshold).then_some(id) state
.is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold)
.then_some(id)
}) })
.collect(); .collect();
drop(states); drop(states);
...@@ -217,7 +286,6 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -217,7 +286,6 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
} }
} }
} }
}
tracing::info!("Worker monitoring task exiting"); tracing::info!("Worker monitoring task exiting");
}); });
......
...@@ -21,7 +21,10 @@ use crate::{ ...@@ -21,7 +21,10 @@ use crate::{
pub struct RouterConfig { pub struct RouterConfig {
pub router_mode: RouterMode, pub router_mode: RouterMode,
pub kv_router_config: KvRouterConfig, pub kv_router_config: KvRouterConfig,
pub busy_threshold: Option<f64>, /// Threshold for active decode blocks utilization (0.0-1.0)
pub active_decode_blocks_threshold: Option<f64>,
/// Threshold for active prefill tokens utilization (literal token count)
pub active_prefill_tokens_threshold: Option<u64>,
pub enforce_disagg: bool, pub enforce_disagg: bool,
} }
...@@ -30,13 +33,19 @@ impl RouterConfig { ...@@ -30,13 +33,19 @@ impl RouterConfig {
Self { Self {
router_mode, router_mode,
kv_router_config, kv_router_config,
busy_threshold: None, active_decode_blocks_threshold: None,
active_prefill_tokens_threshold: None,
enforce_disagg: false, enforce_disagg: false,
} }
} }
pub fn with_busy_threshold(mut self, threshold: Option<f64>) -> Self { pub fn with_active_decode_blocks_threshold(mut self, threshold: Option<f64>) -> Self {
self.busy_threshold = threshold; self.active_decode_blocks_threshold = threshold;
self
}
pub fn with_active_prefill_tokens_threshold(mut self, threshold: Option<u64>) -> Self {
self.active_prefill_tokens_threshold = threshold;
self self
} }
......
...@@ -237,7 +237,10 @@ where ...@@ -237,7 +237,10 @@ where
}; };
// Get threshold value and wrap monitor for PushRouter // Get threshold value and wrap monitor for PushRouter
let threshold_value = worker_monitor.as_ref().map(|m| m.threshold()); // Note: PushRouter uses active_decode_blocks_threshold for its internal logic
let threshold_value = worker_monitor
.as_ref()
.map(|m| m.active_decode_blocks_threshold());
let monitor_arc = let monitor_arc =
worker_monitor.map(|m| Arc::new(m) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>); worker_monitor.map(|m| Arc::new(m) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//! HTTP endpoint for dynamically getting/setting the busy threshold per model. //! HTTP endpoint for dynamically getting/setting the busy thresholds per model.
//! //!
//! The busy threshold controls when workers are marked as "busy" based on their //! The busy thresholds control when workers are marked as "busy" based on their
//! KV cache utilization. When all workers for a model exceed their threshold, //! KV cache block utilization and prefill token utilization. When all workers
//! new requests are rejected with a 503 Service Unavailable response. //! for a model exceed their thresholds, new requests are rejected with a 503
//! Service Unavailable response.
//! //!
//! ## Endpoints //! ## Endpoints
//! //!
//! ### POST /busy_threshold //! ### POST /busy_threshold
//! //!
//! Get or set a model's busy threshold. //! Get or set a model's busy thresholds.
//! //!
//! **Set threshold:** //! **Set thresholds:**
//! ```json //! ```json
//! // Request //! // Request
//! {"model": "llama-3-70b", "threshold": 0.85} //! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! // Response //! // Response
//! {"model": "llama-3-70b", "threshold": 0.85} //! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! ``` //! ```
//! //!
//! **Get threshold (omit or null threshold):** //! **Get thresholds (omit thresholds):**
//! ```json //! ```json
//! // Request //! // Request
//! {"model": "llama-3-70b"} //! {"model": "llama-3-70b"}
//! // Response (if configured) //! // Response (if configured)
//! {"model": "llama-3-70b", "threshold": 0.85} //! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! // Response (if not configured) //! // Response (if not configured)
//! {"model": "llama-3-70b", "threshold": null} //! {"model": "llama-3-70b", "active_decode_blocks_threshold": null, "active_prefill_tokens_threshold": null}
//! ``` //! ```
//! //!
//! ### GET /busy_threshold //! ### GET /busy_threshold
...@@ -37,29 +38,33 @@ ...@@ -37,29 +38,33 @@
//! //!
//! ```json //! ```json
//! // Response //! // Response
//! {"thresholds": [{"model": "llama-3-70b", "threshold": 0.85}]} //! {"thresholds": [{"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}]}
//! ``` //! ```
use super::{RouteDoc, service_v2}; use super::{RouteDoc, service_v2};
use axum::{ use axum::{
Json, Router, Json, Router,
extract::Request,
http::{Method, StatusCode}, http::{Method, StatusCode},
response::IntoResponse, middleware::Next,
response::{IntoResponse, Response},
routing::{get, post}, routing::{get, post},
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::Arc; use std::sync::Arc;
/// Request body for getting or setting a busy threshold. /// Request body for getting or setting busy thresholds.
/// ///
/// - If `threshold` is provided: sets/creates the threshold and returns the new value /// - If thresholds are provided: sets/creates the thresholds and returns the new values
/// - If `threshold` is null/omitted: returns the existing threshold if any /// - If thresholds are null/omitted: returns the existing thresholds if any
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct BusyThresholdRequest { pub struct BusyThresholdRequest {
/// The model name /// The model name
pub model: String, pub model: String,
/// The threshold value (0.0 to 1.0), or null to just get the current value /// The active decode blocks threshold value (0.0 to 1.0), or null to just get the current value
pub threshold: Option<f64>, pub active_decode_blocks_threshold: Option<f64>,
/// The active prefill tokens threshold value (literal token count), or null to just get the current value
pub active_prefill_tokens_threshold: Option<u64>,
} }
/// Response for a threshold operation /// Response for a threshold operation
...@@ -67,8 +72,10 @@ pub struct BusyThresholdRequest { ...@@ -67,8 +72,10 @@ pub struct BusyThresholdRequest {
pub struct BusyThresholdResponse { pub struct BusyThresholdResponse {
/// The model name /// The model name
pub model: String, pub model: String,
/// The threshold value (null if no threshold is configured) /// The active decode blocks threshold value (null if no threshold is configured)
pub threshold: Option<f64>, pub active_decode_blocks_threshold: Option<f64>,
/// The active prefill tokens threshold value (null if no threshold is configured)
pub active_prefill_tokens_threshold: Option<u64>,
} }
/// Response for listing all thresholds /// Response for listing all thresholds
...@@ -84,6 +91,29 @@ pub struct ErrorResponse { ...@@ -84,6 +91,29 @@ pub struct ErrorResponse {
pub error: String, pub error: String,
} }
/// Middleware to convert 422 Unprocessable Entity responses (from JSON deserialization errors)
/// to JSON format instead of text/plain.
async fn json_error_middleware(request: Request, next: Next) -> Response {
let response = next.run(request).await;
if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
let (_parts, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();
let error_message = String::from_utf8_lossy(&body_bytes).to_string();
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(serde_json::json!(ErrorResponse {
error: error_message,
})),
)
.into_response()
} else {
response
}
}
pub fn busy_threshold_router( pub fn busy_threshold_router(
state: Arc<service_v2::State>, state: Arc<service_v2::State>,
path: Option<String>, path: Option<String>,
...@@ -98,6 +128,7 @@ pub fn busy_threshold_router( ...@@ -98,6 +128,7 @@ pub fn busy_threshold_router(
let router = Router::new() let router = Router::new()
.route(&base_path, post(busy_threshold_handler)) .route(&base_path, post(busy_threshold_handler))
.route(&base_path, get(list_busy_thresholds_handler)) .route(&base_path, get(list_busy_thresholds_handler))
.layer(axum::middleware::from_fn(json_error_middleware))
.with_state(state); .with_state(state);
(docs, router) (docs, router)
...@@ -107,25 +138,36 @@ async fn busy_threshold_handler( ...@@ -107,25 +138,36 @@ async fn busy_threshold_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>, axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
Json(request): Json<BusyThresholdRequest>, Json(request): Json<BusyThresholdRequest>,
) -> impl IntoResponse { ) -> impl IntoResponse {
// Validate threshold range if provided // Validate active decode blocks threshold range if provided (must be 0.0-1.0)
if let Some(threshold) = request.threshold if let Some(threshold) = request.active_decode_blocks_threshold
&& !(0.0..=1.0).contains(&threshold) && !(0.0..=1.0).contains(&threshold)
{ {
return ( return (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(serde_json::json!(ErrorResponse { Json(serde_json::json!(ErrorResponse {
error: format!("Threshold must be between 0.0 and 1.0, got {}", threshold), error: format!(
"active_decode_blocks_threshold must be between 0.0 and 1.0, got {}",
threshold
),
})), })),
); );
} }
let manager = state.manager(); let manager = state.manager();
// Get or set the threshold via the model's worker monitor // Get or set the thresholds via the model's worker monitor
let threshold = manager.busy_threshold(&request.model, request.threshold); let active_decode_blocks_threshold = manager
.active_decode_blocks_threshold(&request.model, request.active_decode_blocks_threshold);
let active_prefill_tokens_threshold = manager
.active_prefill_tokens_threshold(&request.model, request.active_prefill_tokens_threshold);
// If trying to SET but model has no monitor, return 404 // If trying to SET but model has no monitor, return 404
if request.threshold.is_some() && threshold.is_none() { let is_setting = request.active_decode_blocks_threshold.is_some()
|| request.active_prefill_tokens_threshold.is_some();
if is_setting
&& active_decode_blocks_threshold.is_none()
&& active_prefill_tokens_threshold.is_none()
{
return ( return (
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
Json(serde_json::json!(ErrorResponse { Json(serde_json::json!(ErrorResponse {
...@@ -137,11 +179,12 @@ async fn busy_threshold_handler( ...@@ -137,11 +179,12 @@ async fn busy_threshold_handler(
); );
} }
if request.threshold.is_some() { if is_setting {
tracing::info!( tracing::info!(
model = %request.model, model = %request.model,
threshold = ?threshold, active_decode_blocks_threshold = ?active_decode_blocks_threshold,
"Updated busy threshold" active_prefill_tokens_threshold = ?active_prefill_tokens_threshold,
"Updated busy thresholds"
); );
} }
...@@ -149,7 +192,8 @@ async fn busy_threshold_handler( ...@@ -149,7 +192,8 @@ async fn busy_threshold_handler(
StatusCode::OK, StatusCode::OK,
Json(serde_json::json!(BusyThresholdResponse { Json(serde_json::json!(BusyThresholdResponse {
model: request.model, model: request.model,
threshold, active_decode_blocks_threshold,
active_prefill_tokens_threshold,
})), })),
) )
} }
...@@ -163,10 +207,15 @@ async fn list_busy_thresholds_handler( ...@@ -163,10 +207,15 @@ async fn list_busy_thresholds_handler(
let response = ListBusyThresholdsResponse { let response = ListBusyThresholdsResponse {
thresholds: thresholds thresholds: thresholds
.into_iter() .into_iter()
.map(|(model, threshold)| BusyThresholdResponse { .map(
|(model, active_decode_blocks_threshold, active_prefill_tokens_threshold)| {
BusyThresholdResponse {
model, model,
threshold: Some(threshold), active_decode_blocks_threshold: Some(active_decode_blocks_threshold),
}) active_prefill_tokens_threshold: Some(active_prefill_tokens_threshold),
}
},
)
.collect(), .collect(),
}; };
......
...@@ -143,6 +143,21 @@ pub struct SpecDecodeStats { ...@@ -143,6 +143,21 @@ pub struct SpecDecodeStats {
pub num_accepted_tokens_per_pos: Option<Vec<u32>>, pub num_accepted_tokens_per_pos: Option<Vec<u32>>,
} }
/// Active load metrics for a worker, used for busy detection.
///
/// Published by workers (with only `active_decode_blocks`) and by the scheduler
/// (with both `active_decode_blocks` and `active_prefill_tokens`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ActiveLoad {
pub worker_id: WorkerId,
#[serde(default)]
pub dp_rank: DpRank,
/// Number of active KV cache blocks on the worker (decode phase).
pub active_decode_blocks: Option<u64>,
/// Number of active prefill tokens (from scheduler's view).
pub active_prefill_tokens: Option<u64>,
}
/// A [`LocalBlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional /// A [`LocalBlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional
/// lora_id of a block. /// lora_id of a block.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
......
...@@ -26,7 +26,6 @@ use crate::kv_router::{ ...@@ -26,7 +26,6 @@ use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, KV_EVENT_SUBJECT, KV_METRICS_SUBJECT,
indexer::{RouterEvent, compute_block_hash_for_seq}, indexer::{RouterEvent, compute_block_hash_for_seq},
protocols::*, protocols::*,
scoring::LoadEvent,
}; };
use dynamo_runtime::config::environment_names::nats as env_nats; use dynamo_runtime::config::environment_names::nats as env_nats;
...@@ -867,14 +866,16 @@ impl WorkerMetricsPublisher { ...@@ -867,14 +866,16 @@ impl WorkerMetricsPublisher {
// Timer expired - publish if we have pending metrics // Timer expired - publish if we have pending metrics
_ = &mut publish_timer => { _ = &mut publish_timer => {
if let Some(metrics) = pending_publish.take() { if let Some(metrics) = pending_publish.take() {
// Create LoadEvent wrapping the metrics // Create ActiveLoad with only active_decode_blocks (worker doesn't know prefill tokens)
let load_event = LoadEvent { let active_load = ActiveLoad {
worker_id, worker_id,
data: (*metrics).clone(), dp_rank: metrics.worker_stats.data_parallel_rank.unwrap_or(0),
active_decode_blocks: Some(metrics.kv_stats.kv_active_blocks),
active_prefill_tokens: None,
}; };
if let Err(e) = if let Err(e) =
namespace.publish(KV_METRICS_SUBJECT, &load_event).await namespace.publish(KV_METRICS_SUBJECT, &active_load).await
{ {
tracing::warn!("Failed to publish metrics over NATS: {}", e); tracing::warn!("Failed to publish metrics over NATS: {}", e);
} }
...@@ -1239,7 +1240,7 @@ mod test_exponential_backoff { ...@@ -1239,7 +1240,7 @@ mod test_exponential_backoff {
#[cfg(all(test, feature = "integration"))] #[cfg(all(test, feature = "integration"))]
mod test_integration_publisher { mod test_integration_publisher {
use super::*; use super::*;
use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats}; use crate::kv_router::protocols::{ActiveLoad, ForwardPassMetrics, KvStats, WorkerStats};
use dynamo_runtime::distributed_test_utils::create_test_drt_async; use dynamo_runtime::distributed_test_utils::create_test_drt_async;
use dynamo_runtime::traits::events::EventSubscriber; use dynamo_runtime::traits::events::EventSubscriber;
use futures::StreamExt; use futures::StreamExt;
...@@ -1253,7 +1254,7 @@ mod test_integration_publisher { ...@@ -1253,7 +1254,7 @@ mod test_integration_publisher {
// Create a subscriber for the metrics events using subscribe_with_type // Create a subscriber for the metrics events using subscribe_with_type
let mut subscriber = namespace let mut subscriber = namespace
.subscribe_with_type::<LoadEvent>(KV_METRICS_SUBJECT) .subscribe_with_type::<ActiveLoad>(KV_METRICS_SUBJECT)
.await .await
.unwrap(); .unwrap();
...@@ -1301,8 +1302,8 @@ mod test_integration_publisher { ...@@ -1301,8 +1302,8 @@ mod test_integration_publisher {
let event = result.unwrap().unwrap(); // Unwrap the Option and the Result let event = result.unwrap().unwrap(); // Unwrap the Option and the Result
assert_eq!(event.worker_id, worker_id); assert_eq!(event.worker_id, worker_id);
assert_eq!(event.data.kv_stats.kv_active_blocks, 900); // Last value: 9 * 100 assert_eq!(event.active_decode_blocks, Some(900)); // Last value: 9 * 100
assert_eq!(event.data.worker_stats.num_requests_waiting, 90); // Last value: 9 * 10 assert_eq!(event.active_prefill_tokens, None); // Worker doesn't publish prefill tokens
// Ensure no more events are waiting // Ensure no more events are waiting
let no_msg = let no_msg =
......
...@@ -133,7 +133,7 @@ impl KvScheduler { ...@@ -133,7 +133,7 @@ impl KvScheduler {
let slots_monitor = slots.clone(); let slots_monitor = slots.clone();
let mut instance_ids_monitor_rx = instance_ids_rx.clone(); let mut instance_ids_monitor_rx = instance_ids_rx.clone();
let mut configs_monitor_rx = runtime_configs_rx.clone(); let mut configs_monitor_rx = runtime_configs_rx.clone();
let monitor_cancel_token = component.drt().primary_token(); let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move { tokio::spawn(async move {
tracing::trace!("workers monitoring task started"); tracing::trace!("workers monitoring task started");
loop { loop {
......
...@@ -3,16 +3,10 @@ ...@@ -3,16 +3,10 @@
//! Scoring functions for the KV router. //! Scoring functions for the KV router.
use super::protocols::{ForwardPassMetrics, LoadMetrics}; use super::protocols::LoadMetrics;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LoadEvent {
pub worker_id: u64,
pub data: ForwardPassMetrics,
}
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data' /// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional) /// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
......
...@@ -38,8 +38,10 @@ use std::time::Duration; ...@@ -38,8 +38,10 @@ use std::time::Duration;
use tokio::time::Instant; use tokio::time::Instant;
use uuid::Uuid; use uuid::Uuid;
use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, WorkerWithDpRank}; use super::protocols::{
use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT; ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, WorkerWithDpRank,
};
use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT};
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::CancellationToken; use dynamo_runtime::CancellationToken;
...@@ -701,6 +703,9 @@ impl ActiveSequencesMultiWorker { ...@@ -701,6 +703,9 @@ impl ActiveSequencesMultiWorker {
self.request_to_worker.remove(expired_id); self.request_to_worker.remove(expired_id);
} }
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(()) Ok(())
} }
...@@ -744,6 +749,9 @@ impl ActiveSequencesMultiWorker { ...@@ -744,6 +749,9 @@ impl ActiveSequencesMultiWorker {
self.request_to_worker.remove(request_id); self.request_to_worker.remove(request_id);
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(()) Ok(())
} }
...@@ -790,9 +798,66 @@ impl ActiveSequencesMultiWorker { ...@@ -790,9 +798,66 @@ impl ActiveSequencesMultiWorker {
}) })
.map_err(|_| SequenceError::WorkerChannelClosed)?; .map_err(|_| SequenceError::WorkerChannelClosed)?;
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(()) Ok(())
} }
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad
async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let Some(sender) = self.senders.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
return;
};
// Query active blocks
let (blocks_tx, blocks_rx) = tokio::sync::oneshot::channel();
if sender
.send(UpdateSequences::ActiveBlocks { resp_tx: blocks_tx })
.is_err()
{
tracing::warn!("Failed to send ActiveBlocks query to worker {worker:?}");
return;
}
// Query active tokens
let (tokens_tx, tokens_rx) = tokio::sync::oneshot::channel();
if sender
.send(UpdateSequences::ActiveTokens { resp_tx: tokens_tx })
.is_err()
{
tracing::warn!("Failed to send ActiveTokens query to worker {worker:?}");
return;
}
// Await both responses
let (active_blocks, active_tokens) = match tokio::join!(blocks_rx, tokens_rx) {
(Ok(blocks), Ok(tokens)) => (blocks, tokens),
_ => {
tracing::warn!("Failed to receive active blocks/tokens from worker {worker:?}");
return;
}
};
// Publish ActiveLoad
let active_load = ActiveLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
active_decode_blocks: Some(active_blocks as u64),
active_prefill_tokens: Some(active_tokens as u64),
};
if let Err(e) = self
.component
.namespace()
.publish(KV_METRICS_SUBJECT, &active_load)
.await
{
tracing::warn!("Failed to publish ActiveLoad for worker {worker:?}: {e:?}");
}
}
/// Get the number of workers /// Get the number of workers
pub fn num_workers(&self) -> usize { pub fn num_workers(&self) -> usize {
self.senders.len() self.senders.len()
......
...@@ -38,7 +38,8 @@ class KVRouterProcess(ManagedProcess): ...@@ -38,7 +38,8 @@ class KVRouterProcess(ManagedProcess):
namespace: str, namespace: str,
store_backend: str = "etcd", store_backend: str = "etcd",
enforce_disagg: bool = False, enforce_disagg: bool = False,
busy_threshold: float | None = None, blocks_threshold: float | None = None,
tokens_threshold: float | None = None,
request_plane: str = "nats", request_plane: str = "nats",
): ):
command = [ command = [
...@@ -60,8 +61,11 @@ class KVRouterProcess(ManagedProcess): ...@@ -60,8 +61,11 @@ class KVRouterProcess(ManagedProcess):
if enforce_disagg: if enforce_disagg:
command.append("--enforce-disagg") command.append("--enforce-disagg")
if busy_threshold is not None: if blocks_threshold is not None:
command.extend(["--busy-threshold", str(busy_threshold)]) command.extend(["--active-decode-blocks-threshold", str(blocks_threshold)])
if tokens_threshold is not None:
command.extend(["--active-prefill-tokens-threshold", str(tokens_threshold)])
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane env["DYN_REQUEST_PLANE"] = request_plane
...@@ -1156,7 +1160,7 @@ def _test_router_overload_503( ...@@ -1156,7 +1160,7 @@ def _test_router_overload_503(
request, request,
frontend_port: int, frontend_port: int,
test_payload: dict, test_payload: dict,
busy_threshold: float = 0.2, blocks_threshold: float = 0.2,
): ):
"""Test that KV router returns 503 when all workers are busy. """Test that KV router returns 503 when all workers are busy.
...@@ -1169,7 +1173,7 @@ def _test_router_overload_503( ...@@ -1169,7 +1173,7 @@ def _test_router_overload_503(
request: Pytest request fixture for managing resources request: Pytest request fixture for managing resources
frontend_port: Port for the frontend HTTP server frontend_port: Port for the frontend HTTP server
test_payload: Base test payload to send to /v1/chat/completions test_payload: Base test payload to send to /v1/chat/completions
busy_threshold: Busy threshold for the router (default 0.2) blocks_threshold: Active decode blocks threshold for the router (default 0.2)
Raises: Raises:
AssertionError: If 503 response is not received when expected AssertionError: If 503 response is not received when expected
...@@ -1185,8 +1189,8 @@ def _test_router_overload_503( ...@@ -1185,8 +1189,8 @@ def _test_router_overload_503(
"python", "python",
"-m", "-m",
"dynamo.frontend", "dynamo.frontend",
"--busy-threshold", "--active-decode-blocks-threshold",
str(busy_threshold), str(blocks_threshold),
"--kv-cache-block-size", "--kv-cache-block-size",
str(block_size), str(block_size),
"--router-mode", "--router-mode",
...@@ -2038,11 +2042,12 @@ def _test_busy_threshold_endpoint( ...@@ -2038,11 +2042,12 @@ def _test_busy_threshold_endpoint(
Raises: Raises:
AssertionError: If endpoint responses are incorrect AssertionError: If endpoint responses are incorrect
""" """
# Initial threshold - we need to start with one so the monitor is created # Initial thresholds - we need to start with these so the monitor is created
initial_threshold = 0.9 initial_active_decode_blocks_threshold = 0.9
initial_active_prefill_tokens_threshold = 1000 # Literal token count threshold
try: try:
# Start KV router frontend with initial busy_threshold to create monitor # Start KV router frontend with initial thresholds to create monitor
logger.info(f"Starting KV router frontend on port {frontend_port}") logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess( kv_router = KVRouterProcess(
request, request,
...@@ -2050,7 +2055,8 @@ def _test_busy_threshold_endpoint( ...@@ -2050,7 +2055,8 @@ def _test_busy_threshold_endpoint(
frontend_port, frontend_port,
engine_workers.namespace, engine_workers.namespace,
store_backend, store_backend,
busy_threshold=initial_threshold, blocks_threshold=initial_active_decode_blocks_threshold,
tokens_threshold=initial_active_prefill_tokens_threshold,
request_plane=request_plane, request_plane=request_plane,
) )
kv_router.__enter__() kv_router.__enter__()
...@@ -2073,7 +2079,6 @@ def _test_busy_threshold_endpoint( ...@@ -2073,7 +2079,6 @@ def _test_busy_threshold_endpoint(
async def test_busy_threshold_api(): async def test_busy_threshold_api():
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
# Test 1: GET /busy_threshold - list all thresholds # Test 1: GET /busy_threshold - list all thresholds
# Should have the initial threshold since we started with --busy-threshold
logger.info("Testing GET /busy_threshold (list all)") logger.info("Testing GET /busy_threshold (list all)")
async with session.get(busy_threshold_url) as response: async with session.get(busy_threshold_url) as response:
assert ( assert (
...@@ -2083,14 +2088,11 @@ def _test_busy_threshold_endpoint( ...@@ -2083,14 +2088,11 @@ def _test_busy_threshold_endpoint(
assert ( assert (
"thresholds" in data "thresholds" in data
), f"Expected 'thresholds' key in response: {data}" ), f"Expected 'thresholds' key in response: {data}"
thresholds = data.get("thresholds", [])
# Should have at least the model with initial_threshold
logger.info(f"GET /busy_threshold response: {data}") logger.info(f"GET /busy_threshold response: {data}")
# Test 2: POST /busy_threshold with model only (get threshold) # Test 2: POST /busy_threshold with model only (get thresholds)
# Should return the initial threshold since we started with --busy-threshold
logger.info( logger.info(
f"Testing POST /busy_threshold to get threshold for model '{model_name}'" f"Testing POST /busy_threshold to get thresholds for model '{model_name}'"
) )
async with session.post( async with session.post(
busy_threshold_url, busy_threshold_url,
...@@ -2101,99 +2103,173 @@ def _test_busy_threshold_endpoint( ...@@ -2101,99 +2103,173 @@ def _test_busy_threshold_endpoint(
), f"POST /busy_threshold (get) failed with status {response.status}" ), f"POST /busy_threshold (get) failed with status {response.status}"
data = await response.json() data = await response.json()
assert ( assert (
data.get("threshold") == initial_threshold data.get("active_decode_blocks_threshold")
), f"Expected initial threshold={initial_threshold}: {data}" == initial_active_decode_blocks_threshold
), f"Expected initial active_decode_blocks_threshold={initial_active_decode_blocks_threshold}: {data}"
assert (
data.get("active_prefill_tokens_threshold")
== initial_active_prefill_tokens_threshold
), f"Expected initial active_prefill_tokens_threshold={initial_active_prefill_tokens_threshold}: {data}"
logger.info( logger.info(
f"POST /busy_threshold (get) response: status={response.status}, data={data}" f"POST /busy_threshold (get) response: status={response.status}, data={data}"
) )
# Test 3: POST /busy_threshold to set a threshold # Test 3: POST /busy_threshold to set active_decode_blocks_threshold only
test_threshold = 0.75 test_active_decode_blocks_threshold = 0.75
logger.info( logger.info(
f"Testing POST /busy_threshold to set threshold={test_threshold}" f"Testing POST /busy_threshold to set active_decode_blocks_threshold={test_active_decode_blocks_threshold}"
) )
async with session.post( async with session.post(
busy_threshold_url, busy_threshold_url,
json={"model": model_name, "threshold": test_threshold}, json={
"model": model_name,
"active_decode_blocks_threshold": test_active_decode_blocks_threshold,
},
) as response: ) as response:
assert ( assert (
response.status == 200 response.status == 200
), f"POST /busy_threshold (set) failed with status {response.status}" ), f"POST /busy_threshold (set blocks) failed with status {response.status}"
data = await response.json() data = await response.json()
assert ( assert (
data.get("model") == model_name data.get("model") == model_name
), f"Expected model={model_name}: {data}" ), f"Expected model={model_name}: {data}"
assert ( assert (
data.get("threshold") == test_threshold data.get("active_decode_blocks_threshold")
), f"Expected threshold={test_threshold}: {data}" == test_active_decode_blocks_threshold
logger.info(f"POST /busy_threshold (set) response: {data}") ), f"Expected active_decode_blocks_threshold={test_active_decode_blocks_threshold}: {data}"
logger.info(f"POST /busy_threshold (set blocks) response: {data}")
# Test 4: POST /busy_threshold to get the threshold we just set # Test 4: POST /busy_threshold to set active_prefill_tokens_threshold only
logger.info("Testing POST /busy_threshold to verify threshold was set") test_active_prefill_tokens_threshold = (
2000 # Literal token count threshold
)
logger.info(
f"Testing POST /busy_threshold to set active_prefill_tokens_threshold={test_active_prefill_tokens_threshold}"
)
async with session.post( async with session.post(
busy_threshold_url, busy_threshold_url,
json={"model": model_name}, json={
"model": model_name,
"active_prefill_tokens_threshold": test_active_prefill_tokens_threshold,
},
) as response: ) as response:
assert ( assert (
response.status == 200 response.status == 200
), f"POST /busy_threshold (get after set) failed with status {response.status}" ), f"POST /busy_threshold (set tokens) failed with status {response.status}"
data = await response.json() data = await response.json()
assert ( assert (
data.get("threshold") == test_threshold data.get("active_prefill_tokens_threshold")
), f"Expected threshold={test_threshold}: {data}" == test_active_prefill_tokens_threshold
logger.info( ), f"Expected active_prefill_tokens_threshold={test_active_prefill_tokens_threshold}: {data}"
f"POST /busy_threshold (get after set) response: {data}" logger.info(f"POST /busy_threshold (set tokens) response: {data}")
)
# Test 5: POST /busy_threshold to update the threshold # Test 5: POST /busy_threshold to set both thresholds
new_threshold = 0.5 new_active_decode_blocks_threshold = 0.5
new_active_prefill_tokens_threshold = (
1200 # Literal token count threshold
)
logger.info( logger.info(
f"Testing POST /busy_threshold to update threshold={new_threshold}" f"Testing POST /busy_threshold to set both thresholds: "
f"active_decode_blocks={new_active_decode_blocks_threshold}, active_prefill_tokens={new_active_prefill_tokens_threshold}"
) )
async with session.post( async with session.post(
busy_threshold_url, busy_threshold_url,
json={"model": model_name, "threshold": new_threshold}, json={
"model": model_name,
"active_decode_blocks_threshold": new_active_decode_blocks_threshold,
"active_prefill_tokens_threshold": new_active_prefill_tokens_threshold,
},
) as response: ) as response:
assert ( assert (
response.status == 200 response.status == 200
), f"POST /busy_threshold (update) failed with status {response.status}" ), f"POST /busy_threshold (set both) failed with status {response.status}"
data = await response.json() data = await response.json()
assert ( assert (
data.get("threshold") == new_threshold data.get("active_decode_blocks_threshold")
), f"Expected threshold={new_threshold}: {data}" == new_active_decode_blocks_threshold
logger.info(f"POST /busy_threshold (update) response: {data}") ), f"Expected active_decode_blocks_threshold={new_active_decode_blocks_threshold}: {data}"
assert (
data.get("active_prefill_tokens_threshold")
== new_active_prefill_tokens_threshold
), f"Expected active_prefill_tokens_threshold={new_active_prefill_tokens_threshold}: {data}"
logger.info(f"POST /busy_threshold (set both) response: {data}")
# Test 6: GET /busy_threshold - verify threshold appears in list # Test 6: GET /busy_threshold - verify thresholds appear in list
logger.info("Testing GET /busy_threshold to verify threshold in list") logger.info("Testing GET /busy_threshold to verify thresholds in list")
async with session.get(busy_threshold_url) as response: async with session.get(busy_threshold_url) as response:
assert ( assert (
response.status == 200 response.status == 200
), f"GET /busy_threshold failed with status {response.status}" ), f"GET /busy_threshold failed with status {response.status}"
data = await response.json() data = await response.json()
thresholds = data.get("thresholds", []) thresholds = data.get("thresholds", [])
# thresholds is an array of {model, threshold} objects model_entry = next(
model_thresholds = {t["model"]: t["threshold"] for t in thresholds} (t for t in thresholds if t["model"] == model_name), None
)
assert ( assert (
model_name in model_thresholds model_entry is not None
), f"Expected model '{model_name}' in thresholds: {data}" ), f"Expected model '{model_name}' in thresholds: {data}"
assert ( assert (
model_thresholds[model_name] == new_threshold model_entry.get("active_decode_blocks_threshold")
), f"Expected threshold={new_threshold} for model '{model_name}': {data}" == new_active_decode_blocks_threshold
), f"Expected active_decode_blocks_threshold={new_active_decode_blocks_threshold}: {data}"
assert (
model_entry.get("active_prefill_tokens_threshold")
== new_active_prefill_tokens_threshold
), f"Expected active_prefill_tokens_threshold={new_active_prefill_tokens_threshold}: {data}"
logger.info(f"GET /busy_threshold (after set) response: {data}") logger.info(f"GET /busy_threshold (after set) response: {data}")
# Test 7: Invalid threshold value (should fail validation) # Test 7: Invalid active_decode_blocks_threshold value (should fail validation)
logger.info( logger.info(
"Testing POST /busy_threshold with invalid threshold (>1.0)" "Testing POST /busy_threshold with invalid active_decode_blocks_threshold (>1.0)"
) )
async with session.post( async with session.post(
busy_threshold_url, busy_threshold_url,
json={"model": model_name, "threshold": 1.5}, json={"model": model_name, "active_decode_blocks_threshold": 1.5},
) as response: ) as response:
assert ( assert (
response.status == 400 response.status == 400
), f"Expected 400 for invalid threshold, got {response.status}" ), f"Expected 400 for invalid active_decode_blocks_threshold, got {response.status}"
data = await response.json() data = await response.json()
logger.info(f"POST /busy_threshold (invalid) response: {data}") logger.info(
f"POST /busy_threshold (invalid blocks) response: {data}"
)
# Test 8: active_prefill_tokens_threshold accepts large values (should be valid)
logger.info(
"Testing POST /busy_threshold with large active_prefill_tokens_threshold (valid)"
)
async with session.post(
busy_threshold_url,
json={"model": model_name, "active_prefill_tokens_threshold": 5000},
) as response:
assert (
response.status == 200
), f"Expected 200 for large active_prefill_tokens_threshold, got {response.status}"
data = await response.json()
assert (
data.get("active_prefill_tokens_threshold") == 5000
), f"Expected active_prefill_tokens_threshold=5000: {data}"
logger.info(
f"POST /busy_threshold (large tokens threshold) response: {data}"
)
# Test 9: Invalid active_prefill_tokens_threshold value (should fail validation for < 0)
# Note: Returns 422 because -1.0 can't be deserialized into u64 (type validation)
# vs Test 7 which returns 400 because 1.5 is a valid f64 but fails range validation
logger.info(
"Testing POST /busy_threshold with invalid active_prefill_tokens_threshold (< 0)"
)
async with session.post(
busy_threshold_url,
json={"model": model_name, "active_prefill_tokens_threshold": -1.0},
) as response:
assert (
response.status == 422
), f"Expected 422 for negative active_prefill_tokens_threshold, got {response.status}"
data = await response.json()
logger.info(
f"POST /busy_threshold (invalid tokens) response: {data}"
)
logger.info("All busy_threshold endpoint tests passed!") logger.info("All busy_threshold endpoint tests passed!")
......
...@@ -426,7 +426,7 @@ def test_mocker_kv_router_overload_503( ...@@ -426,7 +426,7 @@ def test_mocker_kv_router_overload_503(
request=request, request=request,
frontend_port=frontend_port, frontend_port=frontend_port,
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
busy_threshold=0.2, blocks_threshold=0.2,
) )
finally: finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment