Unverified Commit 4c1bc4ee authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: dynamic setting of thresholds for rejection (#4673)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 9fb5f03a
...@@ -31,7 +31,7 @@ The main KV-aware routing arguments: ...@@ -31,7 +31,7 @@ The main KV-aware routing arguments:
- `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management. - `--no-track-active-blocks`: Disables tracking of active blocks (blocks being used for ongoing generation/decode phases). By default, the router tracks active blocks for load balancing. Disable this when routing to workers that only perform prefill (no decode phase), as tracking decode load is not relevant. This reduces router overhead and simplifies state management.
- `--busy-threshold`: Threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`. - `--busy-threshold`: Initial threshold (0.0-1.0) for determining when a worker is considered busy based on KV cache usage. When a worker's KV cache active blocks exceed this percentage of total blocks, it will be marked as busy and excluded from routing. If not set, busy detection is disabled. This feature works with all routing modes (`--router-mode kv|round-robin|random`) as long as backend engines emit `ForwardPassMetrics`. The threshold can be dynamically updated at runtime via the `/busy_threshold` HTTP endpoint (see [Dynamic Threshold Configuration](#dynamic-threshold-configuration)).
- `--router-ttl`: Time-to-live in seconds for blocks in the router's local cache predictions. Blocks older than this duration will be automatically expired and removed from the router's radix tree. Defaults to 120.0 seconds when `--no-kv-events` is used. This helps manage memory usage by removing stale cache predictions that are unlikely to be accurate. - `--router-ttl`: Time-to-live in seconds for blocks in the router's local cache predictions. Blocks older than this duration will be automatically expired and removed from the router's radix tree. Defaults to 120.0 seconds when `--no-kv-events` is used. This helps manage memory usage by removing stale cache predictions that are unlikely to be accurate.
...@@ -582,3 +582,31 @@ This approach gives you complete control over routing decisions, allowing you to ...@@ -582,3 +582,31 @@ This approach gives you complete control over routing decisions, allowing you to
- **Balance load**: Consider both `potential_prefill_tokens` and `potential_decode_blocks` together - **Balance load**: Consider both `potential_prefill_tokens` and `potential_decode_blocks` together
See [KV Router Architecture](../router/README.md) for performance tuning details. See [KV Router Architecture](../router/README.md) for performance tuning details.
## Dynamic Threshold Configuration
The busy threshold can be updated at runtime without restarting the frontend. The frontend exposes HTTP endpoints at `/busy_threshold`:
**Get or set a model's threshold (POST):**
```bash
# Set threshold for a model
curl -X POST http://localhost:8000/busy_threshold \
-H "Content-Type: application/json" \
-d '{"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}
# Get current threshold (omit threshold field)
curl -X POST http://localhost:8000/busy_threshold \
-H "Content-Type: application/json" \
-d '{"model": "meta-llama/Llama-2-7b-hf"}'
# Response: {"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}
# Or if not configured: {"model": "...", "threshold": null}
```
**List all configured thresholds (GET):**
```bash
curl http://localhost:8000/busy_threshold
# Response: {"thresholds": [{"model": "meta-llama/Llama-2-7b-hf", "threshold": 0.85}]}
```
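This allows you to tune the busy threshold based on observed system behavior without service interruption. Note that a threshold can only be set for a model that already has a worker monitor, which is created when the frontend is started with `--busy-threshold`; setting a threshold for any other model returns `404 Not Found`.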
\ No newline at end of file
...@@ -6,10 +6,12 @@ use libc::c_char; ...@@ -6,10 +6,12 @@ use libc::c_char;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use std::borrow::Cow; use std::borrow::Cow;
use std::ffi::CStr; use std::ffi::CStr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::atomic::{AtomicU32, Ordering};
use dynamo_llm::kv_router::{ use dynamo_llm::{
indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher, discovery::{KvWorkerMonitor, ModelWatcher},
kv_router::{indexer::compute_block_hash_for_seq, protocols::*, publisher::KvEventPublisher},
}; };
use dynamo_runtime::{DistributedRuntime, Worker}; use dynamo_runtime::{DistributedRuntime, Worker};
static WK: OnceCell<Worker> = OnceCell::new(); static WK: OnceCell<Worker> = OnceCell::new();
...@@ -960,8 +962,7 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -960,8 +962,7 @@ pub async fn create_worker_selection_pipeline_chat(
tracing::debug!("Looking for model: {}", model_name); tracing::debug!("Looking for model: {}", model_name);
tracing::debug!("Namespace: {}", namespace); tracing::debug!("Namespace: {}", namespace);
use dynamo_llm::discovery::ModelWatcher; let model_manager = Arc::new(ModelManager::new());
let model_manager = std::sync::Arc::new(ModelManager::new());
let router_config = dynamo_llm::entrypoint::RouterConfig { let router_config = dynamo_llm::entrypoint::RouterConfig {
router_mode, router_mode,
kv_router_config: kv_router_config.unwrap_or_default(), kv_router_config: kv_router_config.unwrap_or_default(),
...@@ -1028,6 +1029,10 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -1028,6 +1029,10 @@ pub async fn create_worker_selection_pipeline_chat(
.tokenizer_hf() .tokenizer_hf()
.with_context(|| format!("Failed to load tokenizer for: {}", card.display_name))?; .with_context(|| format!("Failed to load tokenizer for: {}", card.display_name))?;
// Create worker monitor if busy_threshold is set
// Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(Arc::new(client.clone()), t));
let engine = build_routed_pipeline::< let engine = build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
...@@ -1035,7 +1040,7 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -1035,7 +1040,7 @@ pub async fn create_worker_selection_pipeline_chat(
&card_with_local_files, &card_with_local_files,
&client, &client,
router_mode, router_mode,
busy_threshold, worker_monitor,
chooser, chooser,
hf_tokenizer, hf_tokenizer,
None, // prefill_chooser None, // prefill_chooser
......
...@@ -9,8 +9,10 @@ use std::{ ...@@ -9,8 +9,10 @@ use std::{
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use crate::discovery::KvWorkerMonitor;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Endpoint, build_transport_type}, component::{Client, Endpoint, build_transport_type},
discovery::DiscoverySpec, discovery::DiscoverySpec,
prelude::DistributedRuntimeProvider, prelude::DistributedRuntimeProvider,
protocols::EndpointId, protocols::EndpointId,
...@@ -47,7 +49,12 @@ pub enum ModelManagerError { ...@@ -47,7 +49,12 @@ pub enum ModelManagerError {
ModelAlreadyExists(String), ModelAlreadyExists(String),
} }
// Don't implement Clone for this, put it in an Arc instead. /// Central manager for model engines, routing, and configuration.
///
/// Manages model lifecycle including engines, KV routers, prefill coordination,
/// and per-model busy thresholds for load-based request rejection.
///
/// Note: Don't implement Clone for this, put it in an Arc instead.
pub struct ModelManager { pub struct ModelManager {
// We read a lot and write rarely, so these three are RwLock // We read a lot and write rarely, so these three are RwLock
completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>, completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
...@@ -61,6 +68,11 @@ pub struct ModelManager { ...@@ -61,6 +68,11 @@ pub struct ModelManager {
cards: Mutex<HashMap<String, ModelDeploymentCard>>, cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<EndpointId, Arc<KvRouter>>>, kv_choosers: Mutex<HashMap<EndpointId, Arc<KvRouter>>>,
prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>, prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
/// Per-model worker monitors for dynamic KV cache load rejection.
/// Key: model name, Value: cloneable monitor (all fields are Arc).
/// HTTP endpoint can update thresholds via monitor.set_threshold().
worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>,
} }
impl Default for ModelManager { impl Default for ModelManager {
...@@ -80,6 +92,7 @@ impl ModelManager { ...@@ -80,6 +92,7 @@ impl ModelManager {
cards: Mutex::new(HashMap::new()), cards: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()), kv_choosers: Mutex::new(HashMap::new()),
prefill_router_activators: Mutex::new(HashMap::new()), prefill_router_activators: Mutex::new(HashMap::new()),
worker_monitors: RwLock::new(HashMap::new()),
} }
} }
...@@ -471,6 +484,83 @@ impl ModelManager { ...@@ -471,6 +484,83 @@ impl ModelManager {
crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser) crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
} }
/// Gets or sets the busy threshold for a model via its worker monitor.
///
/// This is the primary API for HTTP endpoints and external callers.
/// The threshold (0.0 to 1.0) controls when workers are marked as "busy"
/// based on KV cache utilization.
///
/// # Arguments
///
/// * `model` - The model name
/// * `threshold` - `Some(value)` to set, `None` to get existing
///
/// # Returns
///
/// The threshold currently held by the monitor, or `None` if no monitor exists
/// for this model. After a set, this is the value the monitor actually stored,
/// which may differ slightly from the requested one because the monitor keeps
/// the threshold in a scaled integer form.
/// Note: Setting a threshold for a non-existent model returns `None` (monitor
/// must be created via `get_or_create_worker_monitor` during model discovery).
pub fn busy_threshold(&self, model: &str, threshold: Option<f64>) -> Option<f64> {
    let monitors = self.worker_monitors.read();
    let monitor = monitors.get(model)?;
    if let Some(value) = threshold {
        monitor.set_threshold(value);
    }
    // Read back from the monitor instead of echoing the input so that a "set"
    // response always agrees with what a subsequent "get" will report.
    Some(monitor.threshold())
}
/// Returns the worker monitor registered for `model`, creating it on first use.
///
/// When the model already has a monitor, its threshold is overwritten with
/// `threshold` and a clone of the stored monitor is handed back; `client` is
/// not used in that case. Otherwise a new monitor is built from `client` and
/// `threshold`, stored under the model name, and returned.
///
/// # Arguments
///
/// * `model` - The model name used as the registry key
/// * `client` - The client handed to the new monitor (only used when creating)
/// * `threshold` - The initial/updated threshold value (0.0-1.0)
///
/// # Returns
///
/// A clone of the monitor kept in the registry.
pub fn get_or_create_worker_monitor(
    &self,
    model: &str,
    client: Arc<Client>,
    threshold: f64,
) -> KvWorkerMonitor {
    // The write lock is held for the whole lookup-or-insert sequence.
    let mut registry = self.worker_monitors.write();
    match registry.get(model) {
        Some(known) => {
            known.set_threshold(threshold);
            known.clone()
        }
        None => {
            let fresh = KvWorkerMonitor::new(client, threshold);
            registry.insert(model.to_string(), fresh.clone());
            fresh
        }
    }
}
/// Gets an existing worker monitor for a model, if one exists.
pub fn get_worker_monitor(&self, model: &str) -> Option<KvWorkerMonitor> {
self.worker_monitors.read().get(model).cloned()
}
/// Snapshots the busy threshold of every model that has a worker monitor.
///
/// Returns one `(model_name, threshold_value)` tuple per registered monitor,
/// in the registry's iteration order.
pub fn list_busy_thresholds(&self) -> Vec<(String, f64)> {
    let registry = self.worker_monitors.read();
    let mut listing = Vec::with_capacity(registry.len());
    for (name, monitor) in registry.iter() {
        listing.push((name.clone(), monitor.threshold()));
    }
    listing
}
} }
pub struct ModelEngines<E> { pub struct ModelEngines<E> {
......
...@@ -402,6 +402,16 @@ impl ModelWatcher { ...@@ -402,6 +402,16 @@ impl ModelWatcher {
) )
}); });
// Get or create the worker monitor for this model
// This allows dynamic threshold updates via the ModelManager
let worker_monitor = self.router_config.busy_threshold.map(|threshold| {
self.manager.get_or_create_worker_monitor(
card.name(),
Arc::new(client.clone()),
threshold,
)
});
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if card.model_type.supports_chat() { if card.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
...@@ -411,7 +421,7 @@ impl ModelWatcher { ...@@ -411,7 +421,7 @@ impl ModelWatcher {
card, card,
&client, &client,
self.router_config.router_mode, self.router_config.router_mode,
self.router_config.busy_threshold, worker_monitor.clone(),
kv_chooser.clone(), kv_chooser.clone(),
tokenizer_hf.clone(), tokenizer_hf.clone(),
prefill_chooser.clone(), prefill_chooser.clone(),
...@@ -442,7 +452,7 @@ impl ModelWatcher { ...@@ -442,7 +452,7 @@ impl ModelWatcher {
card, card,
&client, &client,
self.router_config.router_mode, self.router_config.router_mode,
self.router_config.busy_threshold, worker_monitor,
kv_chooser, kv_chooser,
preprocessor, preprocessor,
tokenizer_hf, tokenizer_hf,
......
...@@ -10,9 +10,13 @@ use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait}; ...@@ -10,9 +10,13 @@ use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber; use dynamo_runtime::traits::events::EventSubscriber;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
/// Scale factor for storing f64 thresholds as u32 (10000 = 4 decimal places)
const THRESHOLD_SCALE: u32 = 10000;
/// Worker load monitoring state per dp_rank /// Worker load monitoring state per dp_rank
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct WorkerLoadState { pub struct WorkerLoadState {
...@@ -49,23 +53,57 @@ impl WorkerLoadState { ...@@ -49,23 +53,57 @@ impl WorkerLoadState {
} }
} }
/// Worker monitor for tracking KV cache usage and busy states /// Worker monitor for tracking KV cache usage and busy states.
///
/// All fields are `Arc`, so cloning shares state. This allows multiple pipelines
/// (e.g., chat and completions) to share the same monitor instance.
#[derive(Clone)]
pub struct KvWorkerMonitor { pub struct KvWorkerMonitor {
client: Arc<Client>, client: Arc<Client>,
worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>, worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>,
busy_threshold: f64, /// Threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
busy_threshold: Arc<AtomicU32>,
/// Guard to ensure start_monitoring() only runs once across clones
started: Arc<AtomicBool>,
} }
impl KvWorkerMonitor { impl KvWorkerMonitor {
/// Create a new worker monitor with custom threshold /// Create a new worker monitor with the given threshold.
pub fn new(client: Arc<Client>, busy_threshold: f64) -> Self { ///
/// The threshold (0.0-1.0) controls when workers are considered busy based on
/// KV cache utilization. It can be dynamically updated via `set_threshold()`.
pub fn new(client: Arc<Client>, threshold: f64) -> Self {
Self { Self {
client, client,
worker_load_states: Arc::new(RwLock::new(HashMap::new())), worker_load_states: Arc::new(RwLock::new(HashMap::new())),
busy_threshold, busy_threshold: Arc::new(AtomicU32::new(Self::threshold_to_scaled(threshold))),
started: Arc::new(AtomicBool::new(false)),
} }
} }
/// Convert a f64 threshold (0.0-1.0) to scaled u32 for atomic storage.
///
/// The scaled value is rounded to the nearest integer rather than truncated.
/// Decimal fractions are not exactly representable in binary floating point, so
/// a product such as `0.57 * 10000.0` evaluates to `5699.999999999999`; plain
/// truncation would store 5699 and the threshold would read back as 0.5699.
///
/// The final float-to-int `as` cast saturates: negative inputs and NaN map to 0.
#[inline]
fn threshold_to_scaled(threshold: f64) -> u32 {
    (threshold * f64::from(THRESHOLD_SCALE)).round() as u32
}
/// Map a scaled u32 (units of `1 / THRESHOLD_SCALE`) back to its f64 threshold.
#[inline]
fn scaled_to_threshold(scaled: u32) -> f64 {
    let numerator = f64::from(scaled);
    let denominator = f64::from(THRESHOLD_SCALE);
    numerator / denominator
}
/// Read the current threshold and return it as an f64 fraction.
pub fn threshold(&self) -> f64 {
    let scaled = self.busy_threshold.load(Ordering::Relaxed);
    Self::scaled_to_threshold(scaled)
}
/// Replace the stored threshold with the scaled form of `threshold`.
pub fn set_threshold(&self, threshold: f64) {
    let scaled = Self::threshold_to_scaled(threshold);
    self.busy_threshold.store(scaled, Ordering::Relaxed);
}
/// Get the worker load states for external access /// Get the worker load states for external access
pub fn load_states(&self) -> Arc<RwLock<HashMap<u64, WorkerLoadState>>> { pub fn load_states(&self) -> Arc<RwLock<HashMap<u64, WorkerLoadState>>> {
self.worker_load_states.clone() self.worker_load_states.clone()
...@@ -74,8 +112,17 @@ impl KvWorkerMonitor { ...@@ -74,8 +112,17 @@ impl KvWorkerMonitor {
#[async_trait] #[async_trait]
impl WorkerLoadMonitor for KvWorkerMonitor { impl WorkerLoadMonitor for KvWorkerMonitor {
/// Start background monitoring of worker KV cache usage /// Start background monitoring of worker KV cache usage.
///
/// This is safe to call multiple times (e.g., from cloned monitors shared across
/// pipelines) - only the first call spawns the background task.
async fn start_monitoring(&self) -> anyhow::Result<()> { async fn start_monitoring(&self) -> anyhow::Result<()> {
// Guard: only start once across all clones
if self.started.swap(true, Ordering::SeqCst) {
tracing::debug!("Worker monitoring already started, skipping");
return Ok(());
}
let endpoint = &self.client.endpoint; let endpoint = &self.client.endpoint;
let component = endpoint.component(); let component = endpoint.component();
...@@ -96,7 +143,7 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -96,7 +143,7 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let worker_load_states = self.worker_load_states.clone(); let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone(); let client = self.client.clone();
let busy_threshold = self.busy_threshold; let busy_threshold = self.busy_threshold.clone();
// Spawn background monitoring task // Spawn background monitoring task
tokio::spawn(async move { tokio::spawn(async move {
...@@ -147,12 +194,16 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -147,12 +194,16 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
state.kv_active_blocks.insert(dp_rank, active_blocks); state.kv_active_blocks.insert(dp_rank, active_blocks);
drop(states); drop(states);
// Load threshold dynamically - allows runtime updates
let scaled_threshold = busy_threshold.load(Ordering::Relaxed);
let current_threshold = Self::scaled_to_threshold(scaled_threshold);
// Recalculate all busy instances and update // Recalculate all busy instances and update
let states = worker_load_states.read().unwrap(); let states = worker_load_states.read().unwrap();
let busy_instances: Vec<u64> = states let busy_instances: Vec<u64> = states
.iter() .iter()
.filter_map(|(&id, state)| { .filter_map(|(&id, state)| {
state.is_busy(busy_threshold).then_some(id) state.is_busy(current_threshold).then_some(id)
}) })
.collect(); .collect();
drop(states); drop(states);
......
...@@ -5,7 +5,7 @@ use std::pin::Pin; ...@@ -5,7 +5,7 @@ use std::pin::Pin;
use crate::{ use crate::{
backend::{Backend, ExecutionContext}, backend::{Backend, ExecutionContext},
discovery::{ModelManager, ModelWatcher}, discovery::{KvWorkerMonitor, ModelManager, ModelWatcher},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{EngineConfig, RouterConfig}, entrypoint::{EngineConfig, RouterConfig},
kv_router::{KvPushRouter, KvRouter, PrefillRouter}, kv_router::{KvPushRouter, KvRouter, PrefillRouter},
...@@ -166,7 +166,7 @@ pub async fn build_routed_pipeline<Req, Resp>( ...@@ -166,7 +166,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
card: &ModelDeploymentCard, card: &ModelDeploymentCard,
client: &Client, client: &Client,
router_mode: RouterMode, router_mode: RouterMode,
busy_threshold: Option<f64>, worker_monitor: Option<KvWorkerMonitor>,
chooser: Option<Arc<KvRouter>>, chooser: Option<Arc<KvRouter>>,
hf_tokenizer: tokenizers::Tokenizer, hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>, prefill_chooser: Option<Arc<PrefillRouter>>,
...@@ -189,7 +189,7 @@ where ...@@ -189,7 +189,7 @@ where
card, card,
client, client,
router_mode, router_mode,
busy_threshold, worker_monitor,
chooser, chooser,
preprocessor, preprocessor,
hf_tokenizer, hf_tokenizer,
...@@ -204,7 +204,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>( ...@@ -204,7 +204,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
card: &ModelDeploymentCard, card: &ModelDeploymentCard,
client: &Client, client: &Client,
router_mode: RouterMode, router_mode: RouterMode,
busy_threshold: Option<f64>, worker_monitor: Option<KvWorkerMonitor>,
chooser: Option<Arc<KvRouter>>, chooser: Option<Arc<KvRouter>>,
preprocessor: Arc<OpenAIPreprocessor>, preprocessor: Arc<OpenAIPreprocessor>,
hf_tokenizer: tokenizers::Tokenizer, hf_tokenizer: tokenizers::Tokenizer,
...@@ -236,20 +236,17 @@ where ...@@ -236,20 +236,17 @@ where
client.clone() client.clone()
}; };
// Create worker monitor only if busy_threshold is set // Get threshold value and wrap monitor for PushRouter
let worker_monitor = busy_threshold.map(|threshold| { let threshold_value = worker_monitor.as_ref().map(|m| m.threshold());
Arc::new(crate::discovery::KvWorkerMonitor::new( let monitor_arc =
Arc::new(router_client.clone()), worker_monitor.map(|m| Arc::new(m) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>);
threshold,
)) as Arc<dyn dynamo_runtime::pipeline::WorkerLoadMonitor>
});
let router = let router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold( PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
router_client, router_client,
router_mode, router_mode,
busy_threshold, threshold_value,
worker_monitor, monitor_arc,
) )
.await?; .await?;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
mod openai; mod openai;
pub mod busy_threshold;
pub mod custom_backend_metrics; pub mod custom_backend_metrics;
pub mod disconnect; pub mod disconnect;
pub mod error; pub mod error;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! HTTP endpoint for dynamically getting/setting the busy threshold per model.
//!
//! The busy threshold controls when workers are marked as "busy" based on their
//! KV cache utilization. When all workers for a model exceed their threshold,
//! new requests are rejected with a 503 Service Unavailable response.
//!
//! ## Endpoints
//!
//! ### POST /busy_threshold
//!
//! Get or set a model's busy threshold.
//!
//! **Set threshold:**
//! ```json
//! // Request
//! {"model": "llama-3-70b", "threshold": 0.85}
//! // Response
//! {"model": "llama-3-70b", "threshold": 0.85}
//! ```
//!
//! **Get threshold (omit or null threshold):**
//! ```json
//! // Request
//! {"model": "llama-3-70b"}
//! // Response (if configured)
//! {"model": "llama-3-70b", "threshold": 0.85}
//! // Response (if not configured)
//! {"model": "llama-3-70b", "threshold": null}
//! ```
//!
//! ### GET /busy_threshold
//!
//! List all configured busy thresholds.
//!
//! ```json
//! // Response
//! {"thresholds": [{"model": "llama-3-70b", "threshold": 0.85}]}
//! ```
use super::{RouteDoc, service_v2};
use axum::{
Json, Router,
http::{Method, StatusCode},
response::IntoResponse,
routing::{get, post},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Request body for getting or setting a busy threshold.
///
/// - If `threshold` is provided: updates the threshold of the model's existing
///   worker monitor and returns the resulting value. Nothing is created on
///   demand: the handler answers 404 when the model has no monitor.
/// - If `threshold` is null/omitted: returns the existing threshold if any
#[derive(Debug, Deserialize)]
pub struct BusyThresholdRequest {
/// The model name
pub model: String,
/// The threshold value (0.0 to 1.0), or null to just get the current value.
/// The handler rejects values outside this range with 400 Bad Request.
pub threshold: Option<f64>,
}
/// Response for a threshold operation.
///
/// Returned by the POST handler and also used as the element type of the
/// list response.
#[derive(Debug, Serialize)]
pub struct BusyThresholdResponse {
/// The model name
pub model: String,
/// The threshold value (null if no threshold is configured)
pub threshold: Option<f64>,
}
/// Response for listing all thresholds (`GET` on the busy-threshold route).
#[derive(Debug, Serialize)]
pub struct ListBusyThresholdsResponse {
/// List of model thresholds, one entry per model with a configured threshold
pub thresholds: Vec<BusyThresholdResponse>,
}
/// Error response body.
///
/// Sent by the POST handler with 400 (threshold out of range) and 404
/// (setting a threshold for a model without a monitor).
#[derive(Debug, Serialize)]
pub struct ErrorResponse {
/// Human-readable description of why the request was rejected
pub error: String,
}
/// Builds the router exposing the busy-threshold endpoints.
///
/// Both the POST (get/set one model) and GET (list all) handlers are mounted
/// on the same path: `path` when given, `/busy_threshold` otherwise.
///
/// Returns the route documentation entries together with the router, which
/// already carries `state`.
pub fn busy_threshold_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let mount_point = match path {
        Some(custom) => custom,
        None => "/busy_threshold".to_string(),
    };

    let router = Router::new()
        .route(&mount_point, post(busy_threshold_handler))
        .route(&mount_point, get(list_busy_thresholds_handler))
        .with_state(state);

    let docs: Vec<RouteDoc> = vec![
        RouteDoc::new(Method::POST, &mount_point),
        RouteDoc::new(Method::GET, &mount_point),
    ];

    (docs, router)
}
async fn busy_threshold_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
Json(request): Json<BusyThresholdRequest>,
) -> impl IntoResponse {
// Validate threshold range if provided
if let Some(threshold) = request.threshold
&& !(0.0..=1.0).contains(&threshold)
{
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!(ErrorResponse {
error: format!("Threshold must be between 0.0 and 1.0, got {}", threshold),
})),
);
}
let manager = state.manager();
// Get or set the threshold via the model's worker monitor
let threshold = manager.busy_threshold(&request.model, request.threshold);
// If trying to SET but model has no monitor, return 404
if request.threshold.is_some() && threshold.is_none() {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!(ErrorResponse {
error: format!(
"Model '{}' not found. Thresholds can only be set for discovered models.",
request.model
),
})),
);
}
if request.threshold.is_some() {
tracing::info!(
model = %request.model,
threshold = ?threshold,
"Updated busy threshold"
);
}
(
StatusCode::OK,
Json(serde_json::json!(BusyThresholdResponse {
model: request.model,
threshold,
})),
)
}
/// GET handler: returns every model/threshold pair known to the manager as
/// `{"thresholds": [{"model": ..., "threshold": ...}, ...]}`.
async fn list_busy_thresholds_handler(
    axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse {
    let mut entries = Vec::new();
    for (model, value) in state.manager().list_busy_thresholds() {
        entries.push(BusyThresholdResponse {
            model,
            threshold: Some(value),
        });
    }

    let payload = ListBusyThresholdsResponse {
        thresholds: entries,
    };
    Json(serde_json::json!(payload))
}
...@@ -369,6 +369,7 @@ impl HttpServiceConfigBuilder { ...@@ -369,6 +369,7 @@ impl HttpServiceConfigBuilder {
super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()), super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()), super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
super::busy_threshold::busy_threshold_router(state.clone(), None),
]; ];
let endpoint_routes = let endpoint_routes =
......
...@@ -37,6 +37,7 @@ class KVRouterProcess(ManagedProcess): ...@@ -37,6 +37,7 @@ class KVRouterProcess(ManagedProcess):
namespace: str, namespace: str,
store_backend: str = "etcd", store_backend: str = "etcd",
enforce_disagg: bool = False, enforce_disagg: bool = False,
busy_threshold: float | None = None,
): ):
command = [ command = [
"python3", "python3",
...@@ -57,6 +58,9 @@ class KVRouterProcess(ManagedProcess): ...@@ -57,6 +58,9 @@ class KVRouterProcess(ManagedProcess):
if enforce_disagg: if enforce_disagg:
command.append("--enforce-disagg") command.append("--enforce-disagg")
if busy_threshold is not None:
command.extend(["--busy-threshold", str(busy_threshold)])
super().__init__( super().__init__(
command=command, command=command,
timeout=60, timeout=60,
...@@ -1882,3 +1886,196 @@ def _test_router_decisions( ...@@ -1882,3 +1886,196 @@ def _test_router_decisions(
f"All events correctly routed to worker_id={expected_worker_id} as expected. " f"All events correctly routed to worker_id={expected_worker_id} as expected. "
f"KV events synchronized correctly." f"KV events synchronized correctly."
) )
def _test_busy_threshold_endpoint(
engine_workers,
block_size: int,
request,
frontend_port: int,
test_payload: dict,
store_backend: str = "etcd",
):
"""Test that the /busy_threshold endpoint can be hit and responds correctly.
TODO: This doesn't actually test any e2e rejection for now. A proper test would:
1. Set a very low threshold
2. Send enough requests to exceed the threshold
3. Verify that subsequent requests are rejected with 503
For now, this test only verifies the endpoint is accessible and returns valid responses.
Args:
engine_workers: Backend workers (mocker/vllm) already initialized with __enter__()
block_size: Block size for KV cache
request: Pytest request fixture for managing resources
frontend_port: Port for the frontend HTTP server
test_payload: Base test payload (used to extract model name)
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
Raises:
AssertionError: If endpoint responses are incorrect
"""
# Initial threshold - we need to start with one so the monitor is created
initial_threshold = 0.9
try:
# Start KV router frontend with initial busy_threshold to create monitor
logger.info(f"Starting KV router frontend on port {frontend_port}")
kv_router = KVRouterProcess(
request,
block_size,
frontend_port,
engine_workers.namespace,
store_backend,
busy_threshold=initial_threshold,
)
kv_router.__enter__()
frontend_url = f"http://localhost:{frontend_port}"
busy_threshold_url = f"{frontend_url}/busy_threshold"
# Wait for workers to register with frontend
logger.info("Waiting for workers to register with frontend...")
asyncio.run(
wait_for_frontend_ready(
frontend_url=frontend_url,
expected_num_workers=engine_workers.num_workers,
timeout=120,
)
)
model_name = test_payload.get("model", "test-model")
async def test_busy_threshold_api():
    """Drive the /busy_threshold endpoint through list, get, set, update and reject cases.

    Reads ``busy_threshold_url``, ``model_name`` and ``initial_threshold`` from
    the enclosing scope. Each step asserts the HTTP status and, where a value
    is expected, the returned payload.

    Raises:
        AssertionError: If any response carries an unexpected status or payload.
    """
    async with aiohttp.ClientSession() as session:
        # Test 1: GET /busy_threshold - list all thresholds
        logger.info("Testing GET /busy_threshold (list all)")
        async with session.get(busy_threshold_url) as response:
            assert (
                response.status == 200
            ), f"GET /busy_threshold failed with status {response.status}"
            data = await response.json()
            assert (
                "thresholds" in data
            ), f"Expected 'thresholds' key in response: {data}"
            # Test 6 iterates this value as an array of {model, threshold}
            # objects, so pin the container type here. The initial value for
            # the model itself is verified by Test 2.
            assert isinstance(
                data["thresholds"], list
            ), f"Expected 'thresholds' to be a list: {data}"
            logger.info(f"GET /busy_threshold response: {data}")

        # Test 2: POST /busy_threshold with model only (get threshold)
        # Should return the initial threshold since we started with --busy-threshold
        logger.info(
            f"Testing POST /busy_threshold to get threshold for model '{model_name}'"
        )
        async with session.post(
            busy_threshold_url,
            json={"model": model_name},
        ) as response:
            assert (
                response.status == 200
            ), f"POST /busy_threshold (get) failed with status {response.status}"
            data = await response.json()
            assert (
                data.get("threshold") == initial_threshold
            ), f"Expected initial threshold={initial_threshold}: {data}"
            logger.info(
                f"POST /busy_threshold (get) response: status={response.status}, data={data}"
            )

        # Test 3: POST /busy_threshold to set a threshold
        test_threshold = 0.75
        logger.info(
            f"Testing POST /busy_threshold to set threshold={test_threshold}"
        )
        async with session.post(
            busy_threshold_url,
            json={"model": model_name, "threshold": test_threshold},
        ) as response:
            assert (
                response.status == 200
            ), f"POST /busy_threshold (set) failed with status {response.status}"
            data = await response.json()
            assert (
                data.get("model") == model_name
            ), f"Expected model={model_name}: {data}"
            assert (
                data.get("threshold") == test_threshold
            ), f"Expected threshold={test_threshold}: {data}"
            logger.info(f"POST /busy_threshold (set) response: {data}")

        # Test 4: POST /busy_threshold to get the threshold we just set
        logger.info("Testing POST /busy_threshold to verify threshold was set")
        async with session.post(
            busy_threshold_url,
            json={"model": model_name},
        ) as response:
            assert (
                response.status == 200
            ), f"POST /busy_threshold (get after set) failed with status {response.status}"
            data = await response.json()
            assert (
                data.get("threshold") == test_threshold
            ), f"Expected threshold={test_threshold}: {data}"
            logger.info(
                f"POST /busy_threshold (get after set) response: {data}"
            )

        # Test 5: POST /busy_threshold to update the threshold
        new_threshold = 0.5
        logger.info(
            f"Testing POST /busy_threshold to update threshold={new_threshold}"
        )
        async with session.post(
            busy_threshold_url,
            json={"model": model_name, "threshold": new_threshold},
        ) as response:
            assert (
                response.status == 200
            ), f"POST /busy_threshold (update) failed with status {response.status}"
            data = await response.json()
            assert (
                data.get("threshold") == new_threshold
            ), f"Expected threshold={new_threshold}: {data}"
            logger.info(f"POST /busy_threshold (update) response: {data}")

        # Test 6: GET /busy_threshold - verify threshold appears in list
        logger.info("Testing GET /busy_threshold to verify threshold in list")
        async with session.get(busy_threshold_url) as response:
            assert (
                response.status == 200
            ), f"GET /busy_threshold failed with status {response.status}"
            data = await response.json()
            thresholds = data.get("thresholds", [])
            # thresholds is an array of {model, threshold} objects
            model_thresholds = {t["model"]: t["threshold"] for t in thresholds}
            assert (
                model_name in model_thresholds
            ), f"Expected model '{model_name}' in thresholds: {data}"
            assert (
                model_thresholds[model_name] == new_threshold
            ), f"Expected threshold={new_threshold} for model '{model_name}': {data}"
            logger.info(f"GET /busy_threshold (after set) response: {data}")

        # Test 7: Invalid threshold value (should fail validation)
        logger.info(
            "Testing POST /busy_threshold with invalid threshold (>1.0)"
        )
        async with session.post(
            busy_threshold_url,
            json={"model": model_name, "threshold": 1.5},
        ) as response:
            assert (
                response.status == 400
            ), f"Expected 400 for invalid threshold, got {response.status}"
            # The error body is parsed as JSON only so it can be logged.
            data = await response.json()
            logger.info(f"POST /busy_threshold (invalid) response: {data}")

        logger.info("All busy_threshold endpoint tests passed!")
asyncio.run(test_busy_threshold_api())
finally:
if "kv_router" in locals():
kv_router.__exit__(None, None, None)
...@@ -7,6 +7,7 @@ from typing import Any, Dict, Optional ...@@ -7,6 +7,7 @@ from typing import Any, Dict, Optional
import pytest import pytest
from tests.router.common import ( # utilities from tests.router.common import ( # utilities
_test_busy_threshold_endpoint,
_test_python_router_bindings, _test_python_router_bindings,
_test_router_basic, _test_router_basic,
_test_router_decisions, _test_router_decisions,
...@@ -67,6 +68,7 @@ def get_unique_ports( ...@@ -67,6 +68,7 @@ def get_unique_ports(
"test_mocker_kv_router_overload_503": 200, "test_mocker_kv_router_overload_503": 200,
"test_query_instance_id_returns_worker_and_tokens": 300, "test_query_instance_id_returns_worker_and_tokens": 300,
"test_router_disagg_decisions": 400, "test_router_disagg_decisions": 400,
"test_busy_threshold_endpoint": 500,
} }
base_offset = test_offsets.get(test_name, 0) base_offset = test_offsets.get(test_name, 0)
...@@ -676,3 +678,47 @@ def test_router_disagg_decisions( ...@@ -676,3 +678,47 @@ def test_router_disagg_decisions(
decode_workers.__exit__(None, None, None) decode_workers.__exit__(None, None, None)
if prefill_workers is not None: if prefill_workers is not None:
prefill_workers.__exit__(None, None, None) prefill_workers.__exit__(None, None, None)
@pytest.mark.pre_merge
@pytest.mark.gpu_0
@pytest.mark.integration
@pytest.mark.parallel
@pytest.mark.model(MODEL_NAME)
def test_busy_threshold_endpoint(
    request, runtime_services_session, predownload_tokenizers
):
    """Check that the /busy_threshold endpoint is reachable and answers sensibly.

    Starts the mocker workers, reserves one frontend port, and hands both to
    the shared ``_test_busy_threshold_endpoint`` helper.

    TODO: End-to-end rejection is not covered yet. A complete test would:
        1. Set a very low threshold
        2. Send enough requests to exceed the threshold
        3. Verify that subsequent requests are rejected with 503

    Until then, only endpoint accessibility and response validity are verified.
    """
    logger.info("Starting busy_threshold endpoint test")

    mocker_config = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}

    # Sentinel so the cleanup below can tell whether the workers were created.
    mockers = None
    try:
        logger.info(f"Starting {NUM_MOCKERS} mocker instances")
        mockers = MockerProcess(
            request,
            mocker_args=mocker_config,
            num_mockers=NUM_MOCKERS,
        )
        logger.info(f"All mockers using endpoint: {mockers.endpoint}")
        mockers.__enter__()

        port_list = get_unique_ports(request, num_ports=1)
        _test_busy_threshold_endpoint(
            engine_workers=mockers,
            block_size=BLOCK_SIZE,
            request=request,
            frontend_port=port_list[0],
            test_payload=TEST_PAYLOAD,
        )
    finally:
        # Tear down whenever construction succeeded, even if __enter__ failed.
        if mockers is not None:
            mockers.__exit__(None, None, None)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment