"vllm/vscode:/vscode.git/clone" did not exist on "efe73d0575951767180468dac8202739cb479074"
Unverified Commit 10b01b45 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: early rejection based on active prefill tokens (#4837)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 82577b06
......@@ -190,10 +190,16 @@ def parse_args():
help="Enforce disaggregated prefill-decode. When set, unactivated prefill router will return an error instead of falling back to decode-only mode.",
)
parser.add_argument(
"--busy-threshold",
"--active-decode-blocks-threshold",
type=float,
default=None,
help="Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. If not set, busy detection is disabled.",
help="Threshold percentage (0.0-1.0) for determining when a worker is considered busy based on KV cache block utilization. If not set, blocks-based busy detection is disabled.",
)
parser.add_argument(
"--active-prefill-tokens-threshold",
type=int,
default=None,
help="Literal token count threshold for determining when a worker is considered busy based on prefill token utilization. When active prefill tokens exceed this threshold, the worker is marked as busy. If not set, tokens-based busy detection is disabled.",
)
parser.add_argument(
"--model-name",
......@@ -316,7 +322,11 @@ async def async_main():
"http_port": flags.http_port,
"kv_cache_block_size": flags.kv_cache_block_size,
"router_config": RouterConfig(
router_mode, kv_router_config, flags.busy_threshold, flags.enforce_disagg
router_mode,
kv_router_config,
active_decode_blocks_threshold=flags.active_decode_blocks_threshold,
active_prefill_tokens_threshold=flags.active_prefill_tokens_threshold,
enforce_disagg=flags.enforce_disagg,
),
}
......
......@@ -20,7 +20,6 @@ import os
from typing import Optional
import numpy as np
import scipy
from dynamo.runtime.logging import configure_dynamo_logging
......@@ -80,6 +79,9 @@ class PrefillInterpolator:
self.min_isl = min(self.prefill_isl)
self.max_isl = max(self.prefill_isl)
# Lazy import scipy only when interpolation is actually needed
import scipy.interpolate
# perform 1d interpolation
self.ttft_interpolator = scipy.interpolate.interp1d(
self.prefill_isl, self.prefill_ttft, kind="cubic"
......@@ -151,6 +153,9 @@ class DecodeInterpolator:
self.yi = np.linspace(0, max(self.y_context_length), resolution)
self.X, self.Y = np.meshgrid(self.xi, self.yi)
# Lazy import scipy only when interpolation is actually needed
import scipy.interpolate
# perform 2d interpolation with fallback for NaN values
self.itl_interpolator = scipy.interpolate.griddata(
(self.x_kv_usage, self.y_context_length),
......
......@@ -31,7 +31,9 @@ The main KV-aware routing arguments:
- `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management.
- `--busy-threshold`: Initial threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`. The threshold can be dynamically updated at runtime via the `/busy_threshold` HTTP endpoint (see [Dynamic Threshold Configuration](#dynamic-threshold-configuration)).
- `--active-decode-blocks-threshold`: Initial threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache block utilization. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, blocks-based busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`. The threshold can be dynamically updated at runtime via the `/busy_threshold` HTTP endpoint (see [Dynamic Threshold Configuration](#dynamic-threshold-configuration)).
- `--active-prefill-tokens-threshold`: Literal token count threshold for determining when a worker is considered busy based on prefill token utilization. When active prefill tokens exceed this threshold, the worker is marked as busy. If not set, tokens-based busy detection is disabled.
- `--router-ttl`: Time-to-live in seconds for blocks in the router's local cache predictions. Blocks older than this duration will be automatically expired and removed from the router's radix tree. Defaults to 120.0 seconds when `--no-kv-events` is used. This helps manage memory usage by removing stale cache predictions that are unlikely to be accurate.
......@@ -585,28 +587,32 @@ See [KV Router Architecture](../router/README.md) for performance tuning details
## Dynamic Threshold Configuration
The busy threshold can be updated at runtime without restarting the frontend. The frontend exposes HTTP endpoints at `/busy_threshold`:
The busy thresholds can be updated at runtime without restarting the frontend. The frontend exposes HTTP endpoints at `/busy_threshold`:
**Get or set a model's threshold (POST):**
**Get or set a model's thresholds (POST):**
```bash
# Set threshold for a model
# Set both thresholds for a model
curl -X POST http://localhost:8000/busy_threshold \
-H "Content-Type: application/json" \
-d '{"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}
-d '{"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
# Get current threshold (omit threshold field)
# Set only active decode blocks threshold
curl -X POST http://localhost:8000/busy_threshold \
-H "Content-Type: application/json" \
-d '{"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": <current_value>}
# Get current thresholds (omit threshold fields)
curl -X POST http://localhost:8000/busy_threshold \
-H "Content-Type: application/json" \
-d '{"model": "meta-llama/Llama-2-7b-hf"}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}
# Or if not configured: {"model": "...", "threshold": null}
# Response: {"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
# Or if not configured: {"model": "...", "active_decode_blocks_threshold": null, "active_prefill_tokens_threshold": null}
```
**List all configured thresholds (GET):**
```bash
curl http://localhost:8000/busy_threshold
# Response: {"thresholds": [{"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}]}
# Response: {"thresholds": [{"model": "meta-llama/Llama-2-7b-hf", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}]}
```
This allows you to tune the busy threshold based on observed system behavior without service interruption.
\ No newline at end of file
......@@ -966,7 +966,9 @@ pub async fn create_worker_selection_pipeline_chat(
let router_config = dynamo_llm::entrypoint::RouterConfig {
router_mode,
kv_router_config: kv_router_config.unwrap_or_default(),
busy_threshold,
// C bindings only support active_decode_blocks_threshold for now (via busy_threshold param)
active_decode_blocks_threshold: busy_threshold,
active_prefill_tokens_threshold: None,
enforce_disagg: false,
};
let watcher = ModelWatcher::new(
......@@ -1031,7 +1033,8 @@ pub async fn create_worker_selection_pipeline_chat(
// Create worker monitor if busy_threshold is set
// Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t));
// C bindings only support active_decode_blocks_threshold for now (active_prefill_tokens_threshold defaults to 1000000 tokens = effectively disabled)
let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t, 1000000));
let engine = build_routed_pipeline::<
NvCreateChatCompletionRequest,
......
......@@ -77,24 +77,29 @@ impl KvRouterConfig {
pub struct RouterConfig {
router_mode: RouterMode,
kv_router_config: KvRouterConfig,
busy_threshold: Option<f64>,
/// Threshold for active decode blocks utilization (0.0-1.0)
active_decode_blocks_threshold: Option<f64>,
/// Threshold for active prefill tokens utilization (literal token count)
active_prefill_tokens_threshold: Option<u64>,
enforce_disagg: bool,
}
#[pymethods]
impl RouterConfig {
#[new]
#[pyo3(signature = (mode, config=None, busy_threshold=None, enforce_disagg=false))]
#[pyo3(signature = (mode, config=None, active_decode_blocks_threshold=None, active_prefill_tokens_threshold=None, enforce_disagg=false))]
pub fn new(
mode: RouterMode,
config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
active_decode_blocks_threshold: Option<f64>,
active_prefill_tokens_threshold: Option<u64>,
enforce_disagg: bool,
) -> Self {
Self {
router_mode: mode,
kv_router_config: config.unwrap_or_default(),
busy_threshold,
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
enforce_disagg,
}
}
......@@ -105,7 +110,8 @@ impl From<RouterConfig> for RsRouterConfig {
RsRouterConfig {
router_mode: rc.router_mode.into(),
kv_router_config: rc.kv_router_config.inner,
busy_threshold: rc.busy_threshold,
active_decode_blocks_threshold: rc.active_decode_blocks_threshold,
active_prefill_tokens_threshold: rc.active_prefill_tokens_threshold,
enforce_disagg: rc.enforce_disagg,
}
}
......
......@@ -487,9 +487,11 @@ impl ModelManager {
/// Gets or sets the busy threshold for a model via its worker monitor.
///
/// Get or set the active decode blocks threshold for a model's worker monitor.
///
/// This is the primary API for HTTP endpoints and external callers.
/// The threshold (0.0 to 1.0) controls when workers are marked as "busy"
/// based on KV cache utilization.
/// based on KV cache block utilization.
///
/// # Arguments
///
......@@ -499,31 +501,63 @@ impl ModelManager {
/// # Returns
///
/// The threshold value as f64, or `None` if no monitor exists for this model.
/// Note: Setting a threshold for a non-existent model returns `None` (monitor
/// must be created via `get_or_create_worker_monitor` during model discovery).
pub fn busy_threshold(&self, model: &str, threshold: Option<f64>) -> Option<f64> {
pub fn active_decode_blocks_threshold(
&self,
model: &str,
threshold: Option<f64>,
) -> Option<f64> {
let monitors = self.worker_monitors.read();
let monitor = monitors.get(model)?;
match threshold {
Some(value) => {
monitor.set_threshold(value);
monitor.set_active_decode_blocks_threshold(value);
Some(value)
}
None => Some(monitor.threshold()),
None => Some(monitor.active_decode_blocks_threshold()),
}
}
/// Get or set the active prefill tokens threshold for a model's worker monitor.
///
/// The threshold is a literal token count (not a percentage).
///
/// # Arguments
///
/// * `model` - The model name
/// * `threshold` - `Some(value)` to set, `None` to get existing
///
/// # Returns
///
/// The threshold value as u64, or `None` if no monitor exists for this model.
pub fn active_prefill_tokens_threshold(
&self,
model: &str,
threshold: Option<u64>,
) -> Option<u64> {
let monitors = self.worker_monitors.read();
let monitor = monitors.get(model)?;
match threshold {
Some(value) => {
monitor.set_active_prefill_tokens_threshold(value);
Some(value)
}
None => Some(monitor.active_prefill_tokens_threshold()),
}
}
/// Gets or creates a worker monitor for a model.
///
/// If a monitor already exists, updates its threshold and returns a clone.
/// If no monitor exists, creates one with the given client and threshold.
/// If a monitor already exists, updates its thresholds and returns a clone.
/// If no monitor exists, creates one with the given client and thresholds.
///
/// # Arguments
///
/// * `model` - The model name
/// * `client` - The client for subscribing to KV metrics (only used if creating new)
/// * `threshold` - The initial/updated threshold value (0.0-1.0)
/// * `active_decode_blocks_threshold` - The initial/updated active decode blocks threshold value (0.0-1.0)
/// * `active_prefill_tokens_threshold` - The initial/updated active prefill tokens threshold value (literal token count)
///
/// # Returns
///
......@@ -532,15 +566,21 @@ impl ModelManager {
&self,
model: &str,
client: Client,
threshold: f64,
active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
) -> KvWorkerMonitor {
let mut monitors = self.worker_monitors.write();
if let Some(existing) = monitors.get(model) {
existing.set_threshold(threshold);
existing.set_active_decode_blocks_threshold(active_decode_blocks_threshold);
existing.set_active_prefill_tokens_threshold(active_prefill_tokens_threshold);
existing.clone()
} else {
let monitor = KvWorkerMonitor::new(client, threshold);
let monitor = KvWorkerMonitor::new(
client,
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
);
monitors.insert(model.to_string(), monitor.clone());
monitor
}
......@@ -553,12 +593,18 @@ impl ModelManager {
/// Lists all models that have worker monitors (and thus busy thresholds) configured.
///
/// Returns a vector of (model_name, threshold_value) tuples.
pub fn list_busy_thresholds(&self) -> Vec<(String, f64)> {
/// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
pub fn list_busy_thresholds(&self) -> Vec<(String, f64, u64)> {
self.worker_monitors
.read()
.iter()
.map(|(k, monitor)| (k.clone(), monitor.threshold()))
.map(|(k, monitor)| {
(
k.clone(),
monitor.active_decode_blocks_threshold(),
monitor.active_prefill_tokens_threshold(),
)
})
.collect()
}
}
......
......@@ -404,10 +404,28 @@ impl ModelWatcher {
// Get or create the worker monitor for this model
// This allows dynamic threshold updates via the ModelManager
let worker_monitor = self.router_config.busy_threshold.map(|threshold| {
self.manager
.get_or_create_worker_monitor(card.name(), client.clone(), threshold)
});
// Create monitor if either threshold is configured
let worker_monitor = if self.router_config.active_decode_blocks_threshold.is_some()
|| self.router_config.active_prefill_tokens_threshold.is_some()
{
// Default thresholds: active_decode_blocks=1.0 (disabled), active_prefill_tokens=1000000 (effectively disabled)
let active_decode_blocks = self
.router_config
.active_decode_blocks_threshold
.unwrap_or(1.0);
let active_prefill_tokens = self
.router_config
.active_prefill_tokens_threshold
.unwrap_or(1000000);
Some(self.manager.get_or_create_worker_monitor(
card.name(),
client.clone(),
active_decode_blocks,
active_prefill_tokens,
))
} else {
None
};
// Add chat engine only if the model supports chat
if card.model_type.supports_chat() {
......
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::scoring::LoadEvent;
use crate::kv_router::protocols::ActiveLoad;
use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::component::Client;
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
......@@ -10,7 +10,7 @@ use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt;
......@@ -20,35 +20,62 @@ const THRESHOLD_SCALE: u32 = 10000;
/// Worker load monitoring state per dp_rank
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
pub kv_active_blocks: HashMap<u32, u64>,
pub active_decode_blocks: HashMap<u32, u64>,
pub kv_total_blocks: HashMap<u32, u64>,
pub active_prefill_tokens: HashMap<u32, u64>,
}
impl WorkerLoadState {
/// Returns true if ALL dp_ranks (that have data in both maps) exceed the threshold
pub fn is_busy(&self, threshold: f64) -> bool {
// Get all dp_ranks that exist in both active and total blocks
let common_dp_ranks: Vec<_> = self
.kv_active_blocks
/// Returns true if ALL dp_ranks are considered busy based on the dual-threshold logic:
///
/// For each dp_rank:
/// 1. If `active_prefill_tokens` is available, check if tokens exceed the literal threshold.
/// If so, that dp_rank is busy.
/// 2. If not, check if `active_decode_blocks` and `kv_total_blocks` are both available,
/// and if blocks exceed threshold. If so, that dp_rank is busy.
/// 3. If neither check can be performed (missing data), that dp_rank is considered free.
///
/// The worker is busy only if ALL dp_ranks are busy.
pub fn is_busy(
&self,
active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
) -> bool {
// Get all dp_ranks we know about
let all_dp_ranks: std::collections::HashSet<_> = self
.active_decode_blocks
.keys()
.filter(|dp_rank| self.kv_total_blocks.contains_key(dp_rank))
.chain(self.active_prefill_tokens.keys())
.copied()
.collect();
// If no common dp_ranks, not busy
if common_dp_ranks.is_empty() {
// If no dp_ranks known, not busy
if all_dp_ranks.is_empty() {
return false;
}
// Check if ALL common dp_ranks exceed threshold
common_dp_ranks.iter().all(|&&dp_rank| {
if let (Some(&active), Some(&total)) = (
self.kv_active_blocks.get(&dp_rank),
// Check if ALL dp_ranks are busy
all_dp_ranks.iter().all(|&dp_rank| {
// First check: prefill tokens threshold (literal token count)
if let Some(&active_tokens) = self.active_prefill_tokens.get(&dp_rank)
&& active_tokens > active_prefill_tokens_threshold
{
return true; // This dp_rank is busy due to tokens
}
// Second check: blocks threshold
// Skip if total_blocks is 0 (no capacity means threshold check is meaningless)
if let (Some(&active_blocks), Some(&total_blocks)) = (
self.active_decode_blocks.get(&dp_rank),
self.kv_total_blocks.get(&dp_rank),
) {
total > 0 && (active as f64) > (threshold * total as f64)
} else {
false
) && total_blocks > 0
&& (active_blocks as f64) > (active_decode_blocks_threshold * total_blocks as f64)
{
return true; // This dp_rank is busy due to blocks
}
// If we can't perform either check, this dp_rank is considered free
false
})
}
}
......@@ -61,47 +88,76 @@ impl WorkerLoadState {
pub struct KvWorkerMonitor {
client: Client,
worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>,
/// Threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
busy_threshold: Arc<AtomicU32>,
/// Active decode blocks threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
active_decode_blocks_threshold: Arc<AtomicU32>,
/// Active prefill tokens threshold stored as literal token count (u64)
active_prefill_tokens_threshold: Arc<AtomicU64>,
/// Guard to ensure start_monitoring() only runs once across clones
started: Arc<AtomicBool>,
}
impl KvWorkerMonitor {
/// Create a new worker monitor with the given threshold.
/// Create a new worker monitor with the given thresholds.
///
/// - `active_decode_blocks_threshold` (0.0-1.0): Threshold percentage for KV cache block utilization
/// - `active_prefill_tokens_threshold`: Literal token count threshold for prefill token utilization
///
/// The threshold (0.0-1.0) controls when workers are considered busy based on
/// KV cache utilization. It can be dynamically updated via `set_threshold()`.
pub fn new(client: Client, threshold: f64) -> Self {
/// Both thresholds can be dynamically updated via `set_active_decode_blocks_threshold()` and
/// `set_active_prefill_tokens_threshold()`.
pub fn new(
client: Client,
active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
) -> Self {
Self {
client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())),
busy_threshold: Arc::new(AtomicU32::new(Self::threshold_to_scaled(threshold))),
active_decode_blocks_threshold: Arc::new(AtomicU32::new(
Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold),
)),
active_prefill_tokens_threshold: Arc::new(AtomicU64::new(
active_prefill_tokens_threshold,
)),
started: Arc::new(AtomicBool::new(false)),
}
}
/// Convert a f64 threshold (0.0-1.0) to scaled u32 for atomic storage.
/// Convert a f64 active decode blocks threshold (0.0-1.0) to scaled u32 for atomic storage.
#[inline]
fn threshold_to_scaled(threshold: f64) -> u32 {
fn active_decode_blocks_threshold_to_scaled(threshold: f64) -> u32 {
(threshold * THRESHOLD_SCALE as f64) as u32
}
/// Convert a scaled u32 back to f64 threshold (0.0-1.0).
/// Convert a scaled u32 back to f64 active decode blocks threshold (0.0-1.0).
#[inline]
fn scaled_to_threshold(scaled: u32) -> f64 {
fn scaled_to_active_decode_blocks_threshold(scaled: u32) -> f64 {
scaled as f64 / THRESHOLD_SCALE as f64
}
/// Get the current threshold value as f64.
pub fn threshold(&self) -> f64 {
Self::scaled_to_threshold(self.busy_threshold.load(Ordering::Relaxed))
/// Get the current active decode blocks threshold value as f64.
pub fn active_decode_blocks_threshold(&self) -> f64 {
Self::scaled_to_active_decode_blocks_threshold(
self.active_decode_blocks_threshold.load(Ordering::Relaxed),
)
}
/// Set the threshold value from f64.
pub fn set_threshold(&self, threshold: f64) {
self.busy_threshold
.store(Self::threshold_to_scaled(threshold), Ordering::Relaxed);
/// Set the active decode blocks threshold value from f64.
pub fn set_active_decode_blocks_threshold(&self, threshold: f64) {
self.active_decode_blocks_threshold.store(
Self::active_decode_blocks_threshold_to_scaled(threshold),
Ordering::Relaxed,
);
}
/// Get the current active prefill tokens threshold value as u64.
pub fn active_prefill_tokens_threshold(&self) -> u64 {
self.active_prefill_tokens_threshold.load(Ordering::Relaxed)
}
/// Set the active prefill tokens threshold value from u64.
pub fn set_active_prefill_tokens_threshold(&self, threshold: u64) {
self.active_prefill_tokens_threshold
.store(threshold, Ordering::Relaxed);
}
/// Get the worker load states for external access
......@@ -143,7 +199,8 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone();
let busy_threshold = self.busy_threshold.clone();
let active_decode_blocks_threshold = self.active_decode_blocks_threshold.clone();
let active_prefill_tokens_threshold = self.active_prefill_tokens_threshold.clone();
// Spawn background monitoring task
tokio::spawn(async move {
......@@ -176,34 +233,46 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
}
}
// Handle KV metrics updates
// Handle KV metrics updates (ActiveLoad)
kv_event = kv_metrics_rx.next() => {
let Some(event) = kv_event else {
tracing::debug!("KV metrics stream closed");
break;
};
if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
let worker_id = load_event.worker_id;
let active_blocks = load_event.data.kv_stats.kv_active_blocks;
let dp_rank = load_event.data.worker_stats.data_parallel_rank.unwrap_or(0);
let Ok(active_load) = serde_json::from_slice::<ActiveLoad>(&event.payload) else {
continue;
};
let worker_id = active_load.worker_id;
let dp_rank = active_load.dp_rank;
// Update worker load state per dp_rank
let mut states = worker_load_states.write().unwrap();
let state = states.entry(worker_id).or_default();
state.kv_active_blocks.insert(dp_rank, active_blocks);
if let Some(active_blocks) = active_load.active_decode_blocks {
state.active_decode_blocks.insert(dp_rank, active_blocks);
}
if let Some(active_tokens) = active_load.active_prefill_tokens {
state.active_prefill_tokens.insert(dp_rank, active_tokens);
}
drop(states);
// Load threshold dynamically - allows runtime updates
let scaled_threshold = busy_threshold.load(Ordering::Relaxed);
let current_threshold = Self::scaled_to_threshold(scaled_threshold);
// Load thresholds dynamically - allows runtime updates
let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold(
active_decode_blocks_threshold.load(Ordering::Relaxed),
);
let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed);
// Recalculate all busy instances and update
let states = worker_load_states.read().unwrap();
let busy_instances: Vec<u64> = states
.iter()
.filter_map(|(&id, state)| {
state.is_busy(current_threshold).then_some(id)
state
.is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold)
.then_some(id)
})
.collect();
drop(states);
......@@ -217,7 +286,6 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
}
}
}
}
tracing::info!("Worker monitoring task exiting");
});
......
......@@ -21,7 +21,10 @@ use crate::{
pub struct RouterConfig {
pub router_mode: RouterMode,
pub kv_router_config: KvRouterConfig,
pub busy_threshold: Option<f64>,
/// Threshold for active decode blocks utilization (0.0-1.0)
pub active_decode_blocks_threshold: Option<f64>,
/// Threshold for active prefill tokens utilization (literal token count)
pub active_prefill_tokens_threshold: Option<u64>,
pub enforce_disagg: bool,
}
......@@ -30,13 +33,19 @@ impl RouterConfig {
Self {
router_mode,
kv_router_config,
busy_threshold: None,
active_decode_blocks_threshold: None,
active_prefill_tokens_threshold: None,
enforce_disagg: false,
}
}
pub fn with_busy_threshold(mut self, threshold: Option<f64>) -> Self {
self.busy_threshold = threshold;
pub fn with_active_decode_blocks_threshold(mut self, threshold: Option<f64>) -> Self {
self.active_decode_blocks_threshold = threshold;
self
}
pub fn with_active_prefill_tokens_threshold(mut self, threshold: Option<u64>) -> Self {
self.active_prefill_tokens_threshold = threshold;
self
}
......
......@@ -237,7 +237,10 @@ where
};
// Get threshold value and wrap monitor for PushRouter
let threshold_value = worker_monitor.as_ref().map(|m| m.threshold());
// Note: PushRouter uses active_decode_blocks_threshold for its internal logic
let threshold_value = worker_monitor
.as_ref()
.map(|m| m.active_decode_blocks_threshold());
let monitor_arc =
worker_monitor.map(|m| Arc::new(m) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>);
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! HTTP endpoint for dynamically getting/setting the busy threshold per model.
//! HTTP endpoint for dynamically getting/setting the busy thresholds per model.
//!
//! The busy threshold controls when workers are marked as "busy" based on their
//! KV cache utilization. When all workers for a model exceed their threshold,
//! new requests are rejected with a 503 Service Unavailable response.
//! The busy thresholds control when workers are marked as "busy" based on their
//! KV cache block utilization and prefill token utilization. When all workers
//! for a model exceed their thresholds, new requests are rejected with a 503
//! Service Unavailable response.
//!
//! ## Endpoints
//!
//! ### POST /busy_threshold
//!
//! Get or set a model's busy threshold.
//! Get or set a model's busy thresholds.
//!
//! **Set threshold:**
//! **Set thresholds:**
//! ```json
//! // Request
//! {"model": "llama-3-70b", "threshold": 0.85}
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! // Response
//! {"model": "llama-3-70b", "threshold": 0.85}
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! ```
//!
//! **Get threshold (omit or null threshold):**
//! **Get thresholds (omit thresholds):**
//! ```json
//! // Request
//! {"model": "llama-3-70b"}
//! // Response (if configured)
//! {"model": "llama-3-70b", "threshold": 0.85}
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! // Response (if not configured)
//! {"model": "llama-3-70b", "threshold": null}
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": null, "active_prefill_tokens_threshold": null}
//! ```
//!
//! ### GET /busy_threshold
......@@ -37,29 +38,33 @@
//!
//! ```json
//! // Response
//! {"thresholds": [{"model": "llama-3-70b", "threshold": 0.85}]}
//! {"thresholds": [{"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}]}
//! ```
use super::{RouteDoc, service_v2};
use axum::{
Json, Router,
extract::Request,
http::{Method, StatusCode},
response::IntoResponse,
middleware::Next,
response::{IntoResponse, Response},
routing::{get, post},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Request body for getting or setting a busy threshold.
/// Request body for getting or setting busy thresholds.
///
/// - If `threshold` is provided: sets/creates the threshold and returns the new value
/// - If `threshold` is null/omitted: returns the existing threshold if any
/// - If thresholds are provided: sets/creates the thresholds and returns the new values
/// - If thresholds are null/omitted: returns the existing thresholds if any
#[derive(Debug, Deserialize)]
pub struct BusyThresholdRequest {
/// The model name
pub model: String,
/// The threshold value (0.0 to 1.0), or null to just get the current value
pub threshold: Option<f64>,
/// The active decode blocks threshold value (0.0 to 1.0), or null to just get the current value
pub active_decode_blocks_threshold: Option<f64>,
/// The active prefill tokens threshold value (literal token count), or null to just get the current value
pub active_prefill_tokens_threshold: Option<u64>,
}
/// Response for a threshold operation
......@@ -67,8 +72,10 @@ pub struct BusyThresholdRequest {
pub struct BusyThresholdResponse {
/// The model name
pub model: String,
/// The threshold value (null if no threshold is configured)
pub threshold: Option<f64>,
/// The active decode blocks threshold value (null if no threshold is configured)
pub active_decode_blocks_threshold: Option<f64>,
/// The active prefill tokens threshold value (null if no threshold is configured)
pub active_prefill_tokens_threshold: Option<u64>,
}
/// Response for listing all thresholds
......@@ -84,6 +91,29 @@ pub struct ErrorResponse {
pub error: String,
}
/// Middleware to convert 422 Unprocessable Entity responses (from JSON deserialization errors)
/// to JSON format instead of text/plain.
async fn json_error_middleware(request: Request, next: Next) -> Response {
let response = next.run(request).await;
if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
let (_parts, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();
let error_message = String::from_utf8_lossy(&body_bytes).to_string();
(
StatusCode::UNPROCESSABLE_ENTITY,
Json(serde_json::json!(ErrorResponse {
error: error_message,
})),
)
.into_response()
} else {
response
}
}
pub fn busy_threshold_router(
state: Arc<service_v2::State>,
path: Option<String>,
......@@ -98,6 +128,7 @@ pub fn busy_threshold_router(
let router = Router::new()
.route(&base_path, post(busy_threshold_handler))
.route(&base_path, get(list_busy_thresholds_handler))
.layer(axum::middleware::from_fn(json_error_middleware))
.with_state(state);
(docs, router)
......@@ -107,25 +138,36 @@ async fn busy_threshold_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
Json(request): Json<BusyThresholdRequest>,
) -> impl IntoResponse {
// Validate threshold range if provided
if let Some(threshold) = request.threshold
// Validate active decode blocks threshold range if provided (must be 0.0-1.0)
if let Some(threshold) = request.active_decode_blocks_threshold
&& !(0.0..=1.0).contains(&threshold)
{
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!(ErrorResponse {
error: format!("Threshold must be between 0.0 and 1.0, got {}", threshold),
error: format!(
"active_decode_blocks_threshold must be between 0.0 and 1.0, got {}",
threshold
),
})),
);
}
let manager = state.manager();
// Get or set the threshold via the model's worker monitor
let threshold = manager.busy_threshold(&request.model, request.threshold);
// Get or set the thresholds via the model's worker monitor
let active_decode_blocks_threshold = manager
.active_decode_blocks_threshold(&request.model, request.active_decode_blocks_threshold);
let active_prefill_tokens_threshold = manager
.active_prefill_tokens_threshold(&request.model, request.active_prefill_tokens_threshold);
// If trying to SET but model has no monitor, return 404
if request.threshold.is_some() && threshold.is_none() {
let is_setting = request.active_decode_blocks_threshold.is_some()
|| request.active_prefill_tokens_threshold.is_some();
if is_setting
&& active_decode_blocks_threshold.is_none()
&& active_prefill_tokens_threshold.is_none()
{
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!(ErrorResponse {
......@@ -137,11 +179,12 @@ async fn busy_threshold_handler(
);
}
if request.threshold.is_some() {
if is_setting {
tracing::info!(
model = %request.model,
threshold = ?threshold,
"Updated busy threshold"
active_decode_blocks_threshold = ?active_decode_blocks_threshold,
active_prefill_tokens_threshold = ?active_prefill_tokens_threshold,
"Updated busy thresholds"
);
}
......@@ -149,7 +192,8 @@ async fn busy_threshold_handler(
StatusCode::OK,
Json(serde_json::json!(BusyThresholdResponse {
model: request.model,
threshold,
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
})),
)
}
......@@ -163,10 +207,15 @@ async fn list_busy_thresholds_handler(
let response = ListBusyThresholdsResponse {
thresholds: thresholds
.into_iter()
.map(|(model, threshold)| BusyThresholdResponse {
.map(
|(model, active_decode_blocks_threshold, active_prefill_tokens_threshold)| {
BusyThresholdResponse {
model,
threshold: Some(threshold),
})
active_decode_blocks_threshold: Some(active_decode_blocks_threshold),
active_prefill_tokens_threshold: Some(active_prefill_tokens_threshold),
}
},
)
.collect(),
};
......
......@@ -143,6 +143,21 @@ pub struct SpecDecodeStats {
pub num_accepted_tokens_per_pos: Option<Vec<u32>>,
}
/// Active load metrics for a worker, used for busy detection.
///
/// Published by workers (with only `active_decode_blocks`) and by the scheduler
/// (with both `active_decode_blocks` and `active_prefill_tokens`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ActiveLoad {
pub worker_id: WorkerId,
#[serde(default)]
pub dp_rank: DpRank,
/// Number of active KV cache blocks on the worker (decode phase).
pub active_decode_blocks: Option<u64>,
/// Number of active prefill tokens (from scheduler's view).
pub active_prefill_tokens: Option<u64>,
}
/// A [`LocalBlockHash`] is a hash computed from the tokens_ids, extra_token_ids and the optional
/// lora_id of a block.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
......
......@@ -26,7 +26,6 @@ use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT,
indexer::{RouterEvent, compute_block_hash_for_seq},
protocols::*,
scoring::LoadEvent,
};
use dynamo_runtime::config::environment_names::nats as env_nats;
......@@ -867,14 +866,16 @@ impl WorkerMetricsPublisher {
// Timer expired - publish if we have pending metrics
_ = &mut publish_timer => {
if let Some(metrics) = pending_publish.take() {
// Create LoadEvent wrapping the metrics
let load_event = LoadEvent {
// Create ActiveLoad with only active_decode_blocks (worker doesn't know prefill tokens)
let active_load = ActiveLoad {
worker_id,
data: (*metrics).clone(),
dp_rank: metrics.worker_stats.data_parallel_rank.unwrap_or(0),
active_decode_blocks: Some(metrics.kv_stats.kv_active_blocks),
active_prefill_tokens: None,
};
if let Err(e) =
namespace.publish(KV_METRICS_SUBJECT, &load_event).await
namespace.publish(KV_METRICS_SUBJECT, &active_load).await
{
tracing::warn!("Failed to publish metrics over NATS: {}", e);
}
......@@ -1239,7 +1240,7 @@ mod test_exponential_backoff {
#[cfg(all(test, feature = "integration"))]
mod test_integration_publisher {
use super::*;
use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats};
use crate::kv_router::protocols::{ActiveLoad, ForwardPassMetrics, KvStats, WorkerStats};
use dynamo_runtime::distributed_test_utils::create_test_drt_async;
use dynamo_runtime::traits::events::EventSubscriber;
use futures::StreamExt;
......@@ -1253,7 +1254,7 @@ mod test_integration_publisher {
// Create a subscriber for the metrics events using subscribe_with_type
let mut subscriber = namespace
.subscribe_with_type::<LoadEvent>(KV_METRICS_SUBJECT)
.subscribe_with_type::<ActiveLoad>(KV_METRICS_SUBJECT)
.await
.unwrap();
......@@ -1301,8 +1302,8 @@ mod test_integration_publisher {
let event = result.unwrap().unwrap(); // Unwrap the Option and the Result
assert_eq!(event.worker_id, worker_id);
assert_eq!(event.data.kv_stats.kv_active_blocks, 900); // Last value: 9 * 100
assert_eq!(event.data.worker_stats.num_requests_waiting, 90); // Last value: 9 * 10
assert_eq!(event.active_decode_blocks, Some(900)); // Last value: 9 * 100
assert_eq!(event.active_prefill_tokens, None); // Worker doesn't publish prefill tokens
// Ensure no more events are waiting
let no_msg =
......
......@@ -133,7 +133,7 @@ impl KvScheduler {
let slots_monitor = slots.clone();
let mut instance_ids_monitor_rx = instance_ids_rx.clone();
let mut configs_monitor_rx = runtime_configs_rx.clone();
let monitor_cancel_token = component.drt().primary_token();
let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move {
tracing::trace!("workers monitoring task started");
loop {
......
......@@ -3,16 +3,10 @@
//! Scoring functions for the KV router.
use super::protocols::{ForwardPassMetrics, LoadMetrics};
use super::protocols::LoadMetrics;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LoadEvent {
pub worker_id: u64,
pub data: ForwardPassMetrics,
}
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
......
......@@ -38,8 +38,10 @@ use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;
use super::protocols::{ActiveSequenceEvent, ActiveSequenceEventData, WorkerWithDpRank};
use crate::kv_router::ACTIVE_SEQUENCES_SUBJECT;
use super::protocols::{
ActiveLoad, ActiveSequenceEvent, ActiveSequenceEventData, WorkerWithDpRank,
};
use crate::kv_router::{ACTIVE_SEQUENCES_SUBJECT, KV_METRICS_SUBJECT};
use crate::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_runtime::CancellationToken;
......@@ -701,6 +703,9 @@ impl ActiveSequencesMultiWorker {
self.request_to_worker.remove(expired_id);
}
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(())
}
......@@ -744,6 +749,9 @@ impl ActiveSequencesMultiWorker {
self.request_to_worker.remove(request_id);
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(())
}
......@@ -790,9 +798,66 @@ impl ActiveSequencesMultiWorker {
})
.map_err(|_| SequenceError::WorkerChannelClosed)?;
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(())
}
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad
async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let Some(sender) = self.senders.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
return;
};
// Query active blocks
let (blocks_tx, blocks_rx) = tokio::sync::oneshot::channel();
if sender
.send(UpdateSequences::ActiveBlocks { resp_tx: blocks_tx })
.is_err()
{
tracing::warn!("Failed to send ActiveBlocks query to worker {worker:?}");
return;
}
// Query active tokens
let (tokens_tx, tokens_rx) = tokio::sync::oneshot::channel();
if sender
.send(UpdateSequences::ActiveTokens { resp_tx: tokens_tx })
.is_err()
{
tracing::warn!("Failed to send ActiveTokens query to worker {worker:?}");
return;
}
// Await both responses
let (active_blocks, active_tokens) = match tokio::join!(blocks_rx, tokens_rx) {
(Ok(blocks), Ok(tokens)) => (blocks, tokens),
_ => {
tracing::warn!("Failed to receive active blocks/tokens from worker {worker:?}");
return;
}
};
// Publish ActiveLoad
let active_load = ActiveLoad {
worker_id: worker.worker_id,
dp_rank: worker.dp_rank,
active_decode_blocks: Some(active_blocks as u64),
active_prefill_tokens: Some(active_tokens as u64),
};
if let Err(e) = self
.component
.namespace()
.publish(KV_METRICS_SUBJECT, &active_load)
.await
{
tracing::warn!("Failed to publish ActiveLoad for worker {worker:?}: {e:?}");
}
}
/// Get the number of workers
pub fn num_workers(&self) -> usize {
self.senders.len()
......
......@@ -38,7 +38,8 @@ class KVRouterProcess(ManagedProcess):
namespace: str,
store_backend: str = "etcd",
enforce_disagg: bool = False,
busy_threshold: float | None = None,
blocks_threshold: float | None = None,
tokens_threshold: float | None = None,
request_plane: str = "nats",
):
command = [
......@@ -60,8 +61,11 @@ class KVRouterProcess(ManagedProcess):
if enforce_disagg:
command.append("--enforce-disagg")
if busy_threshold is not None:
command.extend(["--busy-threshold", str(busy_threshold)])
if blocks_threshold is not None:
command.extend(["--active-decode-blocks-threshold", str(blocks_threshold)])
if tokens_threshold is not None:
command.extend(["--active-prefill-tokens-threshold", str(tokens_threshold)])
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
......@@ -1156,7 +1160,7 @@ def _test_router_overload_503(
request,
frontend_port: int,
test_payload: dict,
busy_threshold: float = 0.2,
blocks_threshold: float = 0.2,
):
"""Test that KV router returns 503 when all workers are busy.
......@@ -1169,7 +1173,7 @@ def _test_router_overload_503(
request: Pytest request fixture for managing resources
frontend_port: Port for the frontend HTTP server
test_payload: Base test payload to send to /v1/chat/completions
busy_threshold: Busy threshold for the router (default 0.2)
blocks_threshold: Active decode blocks threshold for the router (default 0.2)
Raises:
AssertionError: If 503 response is not received when expected
......@@ -1185,8 +1189,8 @@ def _test_router_overload_503(
"python",
"-m",
"dynamo.frontend",
"--busy-threshold",
str(busy_threshold),
"--active-decode-blocks-threshold",
str(blocks_threshold),
"--kv-cache-block-size",
str(block_size),
"--router-mode",
......@@ -2038,11 +2042,12 @@ def _test_busy_threshold_endpoint(
Raises:
AssertionError: If endpoint responses are incorrect
"""
# Initial threshold - we need to start with one so the monitor is created
initial_threshold = 0.9
# Initial thresholds - we need to start with these so the monitor is created
initial_active_decode_blocks_threshold = 0.9
initial_active_prefill_tokens_threshold = 1000 # Literal token count threshold
try:
# Start KV router frontend with initial busy_threshold to create monitor
# Start KV router frontend with initial thresholds to create monitor
logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(
request,
......@@ -2050,7 +2055,8 @@ def _test_busy_threshold_endpoint(
frontend_port,
engine_workers.namespace,
store_backend,
busy_threshold=initial_threshold,
blocks_threshold=initial_active_decode_blocks_threshold,
tokens_threshold=initial_active_prefill_tokens_threshold,
request_plane=request_plane,
)
kv_router.__enter__()
......@@ -2073,7 +2079,6 @@ def _test_busy_threshold_endpoint(
async def test_busy_threshold_api():
async with aiohttp.ClientSession() as session:
# Test 1: GET /busy_threshold - list all thresholds
# Should have the initial threshold since we started with --busy-threshold
logger.info("Testing GET /busy_threshold (list all)")
async with session.get(busy_threshold_url) as response:
assert (
......@@ -2083,14 +2088,11 @@ def _test_busy_threshold_endpoint(
assert (
"thresholds" in data
), f"Expected 'thresholds' key in response: {data}"
thresholds = data.get("thresholds", [])
# Should have at least the model with initial_threshold
logger.info(f"GET /busy_threshold response: {data}")
# Test 2: POST /busy_threshold with model only (get threshold)
# Should return the initial threshold since we started with --busy-threshold
# Test 2: POST /busy_threshold with model only (get thresholds)
logger.info(
f"Testing POST /busy_threshold to get threshold for model '{model_name}'"
f"Testing POST /busy_threshold to get thresholds for model '{model_name}'"
)
async with session.post(
busy_threshold_url,
......@@ -2101,99 +2103,173 @@ def _test_busy_threshold_endpoint(
), f"POST /busy_threshold (get) failed with status {response.status}"
data = await response.json()
assert (
data.get("threshold") == initial_threshold
), f"Expected initial threshold={initial_threshold}: {data}"
data.get("active_decode_blocks_threshold")
== initial_active_decode_blocks_threshold
), f"Expected initial active_decode_blocks_threshold={initial_active_decode_blocks_threshold}: {data}"
assert (
data.get("active_prefill_tokens_threshold")
== initial_active_prefill_tokens_threshold
), f"Expected initial active_prefill_tokens_threshold={initial_active_prefill_tokens_threshold}: {data}"
logger.info(
f"POST /busy_threshold (get) response: status={response.status}, data={data}"
)
# Test 3: POST /busy_threshold to set a threshold
test_threshold = 0.75
# Test 3: POST /busy_threshold to set active_decode_blocks_threshold only
test_active_decode_blocks_threshold = 0.75
logger.info(
f"Testing POST /busy_threshold to set threshold={test_threshold}"
f"Testing POST /busy_threshold to set active_decode_blocks_threshold={test_active_decode_blocks_threshold}"
)
async with session.post(
busy_threshold_url,
json={"model": model_name, "threshold": test_threshold},
json={
"model": model_name,
"active_decode_blocks_threshold": test_active_decode_blocks_threshold,
},
) as response:
assert (
response.status == 200
), f"POST /busy_threshold (set) failed with status {response.status}"
), f"POST /busy_threshold (set blocks) failed with status {response.status}"
data = await response.json()
assert (
data.get("model") == model_name
), f"Expected model={model_name}: {data}"
assert (
data.get("threshold") == test_threshold
), f"Expected threshold={test_threshold}: {data}"
logger.info(f"POST /busy_threshold (set) response: {data}")
data.get("active_decode_blocks_threshold")
== test_active_decode_blocks_threshold
), f"Expected active_decode_blocks_threshold={test_active_decode_blocks_threshold}: {data}"
logger.info(f"POST /busy_threshold (set blocks) response: {data}")
# Test 4: POST /busy_threshold to get the threshold we just set
logger.info("Testing POST /busy_threshold to verify threshold was set")
# Test 4: POST /busy_threshold to set active_prefill_tokens_threshold only
test_active_prefill_tokens_threshold = (
2000 # Literal token count threshold
)
logger.info(
f"Testing POST /busy_threshold to set active_prefill_tokens_threshold={test_active_prefill_tokens_threshold}"
)
async with session.post(
busy_threshold_url,
json={"model": model_name},
json={
"model": model_name,
"active_prefill_tokens_threshold": test_active_prefill_tokens_threshold,
},
) as response:
assert (
response.status == 200
), f"POST /busy_threshold (get after set) failed with status {response.status}"
), f"POST /busy_threshold (set tokens) failed with status {response.status}"
data = await response.json()
assert (
data.get("threshold") == test_threshold
), f"Expected threshold={test_threshold}: {data}"
logger.info(
f"POST /busy_threshold (get after set) response: {data}"
)
data.get("active_prefill_tokens_threshold")
== test_active_prefill_tokens_threshold
), f"Expected active_prefill_tokens_threshold={test_active_prefill_tokens_threshold}: {data}"
logger.info(f"POST /busy_threshold (set tokens) response: {data}")
# Test 5: POST /busy_threshold to update the threshold
new_threshold = 0.5
# Test 5: POST /busy_threshold to set both thresholds
new_active_decode_blocks_threshold = 0.5
new_active_prefill_tokens_threshold = (
1200 # Literal token count threshold
)
logger.info(
f"Testing POST /busy_threshold to update threshold={new_threshold}"
f"Testing POST /busy_threshold to set both thresholds: "
f"active_decode_blocks={new_active_decode_blocks_threshold}, active_prefill_tokens={new_active_prefill_tokens_threshold}"
)
async with session.post(
busy_threshold_url,
json={"model": model_name, "threshold": new_threshold},
json={
"model": model_name,
"active_decode_blocks_threshold": new_active_decode_blocks_threshold,
"active_prefill_tokens_threshold": new_active_prefill_tokens_threshold,
},
) as response:
assert (
response.status == 200
), f"POST /busy_threshold (update) failed with status {response.status}"
), f"POST /busy_threshold (set both) failed with status {response.status}"
data = await response.json()
assert (
data.get("threshold") == new_threshold
), f"Expected threshold={new_threshold}: {data}"
logger.info(f"POST /busy_threshold (update) response: {data}")
data.get("active_decode_blocks_threshold")
== new_active_decode_blocks_threshold
), f"Expected active_decode_blocks_threshold={new_active_decode_blocks_threshold}: {data}"
assert (
data.get("active_prefill_tokens_threshold")
== new_active_prefill_tokens_threshold
), f"Expected active_prefill_tokens_threshold={new_active_prefill_tokens_threshold}: {data}"
logger.info(f"POST /busy_threshold (set both) response: {data}")
# Test 6: GET /busy_threshold - verify threshold appears in list
logger.info("Testing GET /busy_threshold to verify threshold in list")
# Test 6: GET /busy_threshold - verify thresholds appear in list
logger.info("Testing GET /busy_threshold to verify thresholds in list")
async with session.get(busy_threshold_url) as response:
assert (
response.status == 200
), f"GET /busy_threshold failed with status {response.status}"
data = await response.json()
thresholds = data.get("thresholds", [])
# thresholds is an array of {model, threshold} objects
model_thresholds = {t["model"]: t["threshold"] for t in thresholds}
model_entry = next(
(t for t in thresholds if t["model"] == model_name), None
)
assert (
model_name in model_thresholds
model_entry is not None
), f"Expected model '{model_name}' in thresholds: {data}"
assert (
model_thresholds[model_name] == new_threshold
), f"Expected threshold={new_threshold} for model '{model_name}': {data}"
model_entry.get("active_decode_blocks_threshold")
== new_active_decode_blocks_threshold
), f"Expected active_decode_blocks_threshold={new_active_decode_blocks_threshold}: {data}"
assert (
model_entry.get("active_prefill_tokens_threshold")
== new_active_prefill_tokens_threshold
), f"Expected active_prefill_tokens_threshold={new_active_prefill_tokens_threshold}: {data}"
logger.info(f"GET /busy_threshold (after set) response: {data}")
# Test 7: Invalid threshold value (should fail validation)
# Test 7: Invalid active_decode_blocks_threshold value (should fail validation)
logger.info(
"Testing POST /busy_threshold with invalid threshold (>1.0)"
"Testing POST /busy_threshold with invalid active_decode_blocks_threshold (>1.0)"
)
async with session.post(
busy_threshold_url,
json={"model": model_name, "threshold": 1.5},
json={"model": model_name, "active_decode_blocks_threshold": 1.5},
) as response:
assert (
response.status == 400
), f"Expected 400 for invalid threshold, got {response.status}"
), f"Expected 400 for invalid active_decode_blocks_threshold, got {response.status}"
data = await response.json()
logger.info(f"POST /busy_threshold (invalid) response: {data}")
logger.info(
f"POST /busy_threshold (invalid blocks) response: {data}"
)
# Test 8: active_prefill_tokens_threshold accepts large values (should be valid)
logger.info(
"Testing POST /busy_threshold with large active_prefill_tokens_threshold (valid)"
)
async with session.post(
busy_threshold_url,
json={"model": model_name, "active_prefill_tokens_threshold": 5000},
) as response:
assert (
response.status == 200
), f"Expected 200 for large active_prefill_tokens_threshold, got {response.status}"
data = await response.json()
assert (
data.get("active_prefill_tokens_threshold") == 5000
), f"Expected active_prefill_tokens_threshold=5000: {data}"
logger.info(
f"POST /busy_threshold (large tokens threshold) response: {data}"
)
# Test 9: Invalid active_prefill_tokens_threshold value (should fail validation for < 0)
# Note: Returns 422 because -1.0 can't be deserialized into u64 (type validation)
# vs Test 7 which returns 400 because 1.5 is a valid f64 but fails range validation
logger.info(
"Testing POST /busy_threshold with invalid active_prefill_tokens_threshold (< 0)"
)
async with session.post(
busy_threshold_url,
json={"model": model_name, "active_prefill_tokens_threshold": -1.0},
) as response:
assert (
response.status == 422
), f"Expected 422 for negative active_prefill_tokens_threshold, got {response.status}"
data = await response.json()
logger.info(
f"POST /busy_threshold (invalid tokens) response: {data}"
)
logger.info("All busy_threshold endpoint tests passed!")
......
......@@ -426,7 +426,7 @@ def test_mocker_kv_router_overload_503(
request=request,
frontend_port=frontend_port,
test_payload=TEST_PAYLOAD,
busy_threshold=0.2,
blocks_threshold=0.2,
)
finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment