Unverified Commit a337113a authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: prefill tokens threshold based on max num batched tokens frac (#5867)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 902eabd9
......@@ -241,6 +241,12 @@ def parse_args():
default=None,
help="Literal token count threshold for determining when a worker is considered busy based on prefill token utilization. When active prefill tokens exceed this threshold, the worker is marked as busy. If not set, tokens-based busy detection is disabled.",
)
parser.add_argument(
"--active-prefill-tokens-threshold-frac",
type=float,
default=None,
help="Fraction of max_num_batched_tokens for busy detection. Worker is busy when active_prefill_tokens > frac * max_num_batched_tokens. Default 1.5 (disabled). Uses OR logic with --active-prefill-tokens-threshold.",
)
parser.add_argument(
"--model-name",
type=validate_model_name,
......@@ -408,6 +414,7 @@ async def async_main():
kv_router_config,
active_decode_blocks_threshold=flags.active_decode_blocks_threshold,
active_prefill_tokens_threshold=flags.active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac=flags.active_prefill_tokens_threshold_frac,
enforce_disagg=flags.enforce_disagg,
),
}
......
......@@ -1398,8 +1398,11 @@ pub async fn create_worker_selection_pipeline_chat(
let router_config = dynamo_llm::entrypoint::RouterConfig {
router_mode,
kv_router_config: kv_router_config.unwrap_or_default(),
load_threshold_config: dynamo_llm::discovery::LoadThresholdConfig {
active_decode_blocks_threshold: busy_threshold,
active_prefill_tokens_threshold: None,
active_prefill_tokens_threshold_frac: None,
},
enforce_disagg,
};
// Create metrics for migration tracking (not exposed via /metrics in C bindings)
......@@ -1496,7 +1499,16 @@ pub async fn create_worker_selection_pipeline_chat(
// Create worker monitor if busy_threshold is set
// Note: C bindings don't register with ModelManager, so HTTP endpoint won't see this
let worker_monitor = busy_threshold.map(|t| KvWorkerMonitor::new(client.clone(), t, 1000000));
let worker_monitor = busy_threshold.map(|t| {
KvWorkerMonitor::new(
client.clone(),
dynamo_llm::discovery::LoadThresholdConfig {
active_decode_blocks_threshold: Some(t),
active_prefill_tokens_threshold: None,
active_prefill_tokens_threshold_frac: None,
},
)
});
// Clone chooser before passing to build_routed_pipeline (which takes ownership)
let kv_router = chooser.clone();
......
......@@ -10,6 +10,7 @@ use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::TaskLocals;
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
use dynamo_llm::entrypoint::EngineFactoryCallback;
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
......@@ -94,18 +95,21 @@ pub struct RouterConfig {
active_decode_blocks_threshold: Option<f64>,
/// Threshold for active prefill tokens utilization (literal token count)
active_prefill_tokens_threshold: Option<u64>,
/// Threshold for active prefill tokens as fraction of max_num_batched_tokens
active_prefill_tokens_threshold_frac: Option<f64>,
enforce_disagg: bool,
}
#[pymethods]
impl RouterConfig {
#[new]
#[pyo3(signature = (mode, config=None, active_decode_blocks_threshold=None, active_prefill_tokens_threshold=None, enforce_disagg=false))]
#[pyo3(signature = (mode, config=None, active_decode_blocks_threshold=None, active_prefill_tokens_threshold=None, active_prefill_tokens_threshold_frac=None, enforce_disagg=false))]
pub fn new(
mode: RouterMode,
config: Option<KvRouterConfig>,
active_decode_blocks_threshold: Option<f64>,
active_prefill_tokens_threshold: Option<u64>,
active_prefill_tokens_threshold_frac: Option<f64>,
enforce_disagg: bool,
) -> Self {
Self {
......@@ -113,6 +117,7 @@ impl RouterConfig {
kv_router_config: config.unwrap_or_default(),
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac,
enforce_disagg,
}
}
......@@ -123,8 +128,11 @@ impl From<RouterConfig> for RsRouterConfig {
RsRouterConfig {
router_mode: rc.router_mode.into(),
kv_router_config: rc.kv_router_config.inner,
load_threshold_config: RsLoadThresholdConfig {
active_decode_blocks_threshold: rc.active_decode_blocks_threshold,
active_prefill_tokens_threshold: rc.active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac: rc.active_prefill_tokens_threshold_frac,
},
enforce_disagg: rc.enforce_disagg,
}
}
......
......@@ -11,4 +11,4 @@ mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher};
mod worker_monitor;
pub use worker_monitor::{KvWorkerMonitor, WorkerLoadState};
pub use worker_monitor::{KvWorkerMonitor, LoadThresholdConfig, WorkerLoadState};
......@@ -10,8 +10,8 @@ use dashmap::{DashMap, mapref::entry::Entry};
use parking_lot::RwLock;
use tokio::sync::oneshot;
use crate::discovery::KvWorkerMonitor;
use crate::discovery::runtime_configs::RuntimeConfigs;
use super::worker_monitor::LoadThresholdConfig;
use super::{KvWorkerMonitor, RuntimeConfigs};
use dynamo_runtime::{
component::{Client, Endpoint, build_transport_type},
......@@ -74,14 +74,8 @@ pub struct ModelManager {
kv_choosers: DashMap<EndpointId, Arc<KvRouter>>,
prefill_router_activators: DashMap<String, PrefillActivationState>,
/// Per-model worker monitors for dynamic KV cache load rejection.
/// Key: model name, Value: cloneable monitor (all fields are Arc).
/// HTTP endpoint can update thresholds via monitor.set_threshold().
worker_monitors: RwLock<HashMap<String, KvWorkerMonitor>>,
/// Runtime configs per endpoint using DashMap for lock-free access.
/// Outer DashMap: keyed by EndpointId
/// Inner RuntimeConfigs: shared with KvScheduler
// Per-model monitoring: worker_monitors for load-based rejection, runtime_configs for KvScheduler
worker_monitors: DashMap<String, KvWorkerMonitor>,
runtime_configs: DashMap<EndpointId, Arc<RuntimeConfigs>>,
}
......@@ -102,7 +96,7 @@ impl ModelManager {
cards: DashMap::new(),
kv_choosers: DashMap::new(),
prefill_router_activators: DashMap::new(),
worker_monitors: RwLock::new(HashMap::new()),
worker_monitors: DashMap::new(),
runtime_configs: DashMap::new(),
}
}
......@@ -501,110 +495,35 @@ impl ModelManager {
crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
}
/// Gets or sets the busy threshold for a model via its worker monitor.
///
/// Get or set the active decode blocks threshold for a model's worker monitor.
///
/// This is the primary API for HTTP endpoints and external callers.
/// The threshold (0.0 to 1.0) controls when workers are marked as "busy"
/// based on KV cache block utilization.
///
/// # Arguments
///
/// * `model` - The model name
/// * `threshold` - `Some(value)` to set, `None` to get existing
///
/// # Returns
///
/// The threshold value as f64, or `None` if no monitor exists for this model.
pub fn active_decode_blocks_threshold(
&self,
model: &str,
threshold: Option<f64>,
) -> Option<f64> {
let monitors = self.worker_monitors.read();
let monitor = monitors.get(model)?;
match threshold {
Some(value) => {
monitor.set_active_decode_blocks_threshold(value);
Some(value)
}
None => Some(monitor.active_decode_blocks_threshold()),
}
}
/// Get or set the active prefill tokens threshold for a model's worker monitor.
///
/// The threshold is a literal token count (not a percentage).
///
/// # Arguments
///
/// * `model` - The model name
/// * `threshold` - `Some(value)` to set, `None` to get existing
///
/// # Returns
///
/// The threshold value as u64, or `None` if no monitor exists for this model.
pub fn active_prefill_tokens_threshold(
/// Gets or sets the load threshold config for a model's worker monitor.
/// Pass `Some(config)` to update, `None` to get. Returns `None` if no monitor exists.
pub fn load_threshold_config(
&self,
model: &str,
threshold: Option<u64>,
) -> Option<u64> {
let monitors = self.worker_monitors.read();
let monitor = monitors.get(model)?;
match threshold {
Some(value) => {
monitor.set_active_prefill_tokens_threshold(value);
Some(value)
}
None => Some(monitor.active_prefill_tokens_threshold()),
}
}
/// Gets or creates a worker monitor for a model.
///
/// If a monitor already exists, updates its thresholds and returns a clone.
/// If no monitor exists, creates one with the given client and thresholds.
///
/// # Arguments
///
/// * `model` - The model name
/// * `client` - The client for subscribing to KV metrics (only used if creating new)
/// * `active_decode_blocks_threshold` - The initial/updated active decode blocks threshold value (0.0-1.0)
/// * `active_prefill_tokens_threshold` - The initial/updated active prefill tokens threshold value (literal token count)
///
/// # Returns
///
/// A cloneable monitor that shares state with the stored instance.
config: Option<&LoadThresholdConfig>,
) -> Option<LoadThresholdConfig> {
let monitor = self.worker_monitors.get(model)?;
if let Some(cfg) = config {
monitor.set_load_threshold_config(cfg);
}
Some(monitor.load_threshold_config())
}
/// Gets or creates a worker monitor for a model. Updates thresholds if monitor exists.
pub fn get_or_create_worker_monitor(
&self,
model: &str,
client: Client,
active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
config: LoadThresholdConfig,
) -> KvWorkerMonitor {
let mut monitors = self.worker_monitors.write();
if let Some(existing) = monitors.get(model) {
existing.set_active_decode_blocks_threshold(active_decode_blocks_threshold);
existing.set_active_prefill_tokens_threshold(active_prefill_tokens_threshold);
existing.clone()
} else {
let monitor = KvWorkerMonitor::new(
client,
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
);
monitors.insert(model.to_string(), monitor.clone());
monitor
if let Some(existing) = self.worker_monitors.get(model) {
existing.set_load_threshold_config(&config);
return existing.clone();
}
}
/// Gets an existing worker monitor for a model, if one exists.
pub fn get_worker_monitor(&self, model: &str) -> Option<KvWorkerMonitor> {
self.worker_monitors.read().get(model).cloned()
let monitor = KvWorkerMonitor::new(client, config);
self.worker_monitors
.insert(model.to_string(), monitor.clone());
monitor
}
/// Get or create a runtime config watcher for an endpoint.
......@@ -651,20 +570,11 @@ impl ModelManager {
config_ref.as_ref()?.disaggregated_endpoint.clone()
}
/// Lists all models that have worker monitors (and thus busy thresholds) configured.
///
/// Returns a vector of (model_name, active_decode_blocks_threshold, active_prefill_tokens_threshold) tuples.
pub fn list_busy_thresholds(&self) -> Vec<(String, f64, u64)> {
/// Lists all models with worker monitors configured.
pub fn list_busy_thresholds(&self) -> Vec<(String, LoadThresholdConfig)> {
self.worker_monitors
.read()
.iter()
.map(|(k, monitor)| {
(
k.clone(),
monitor.active_decode_blocks_threshold(),
monitor.active_prefill_tokens_threshold(),
)
})
.map(|entry| (entry.key().clone(), entry.value().load_threshold_config()))
.collect()
}
}
......
......@@ -460,24 +460,12 @@ impl ModelWatcher {
// Get or create the worker monitor for this model
// This allows dynamic threshold updates via the ModelManager
// Create monitor if either threshold is configured
let worker_monitor = if self.router_config.active_decode_blocks_threshold.is_some()
|| self.router_config.active_prefill_tokens_threshold.is_some()
{
// Default thresholds: active_decode_blocks=1.0 (disabled), active_prefill_tokens=1000000 (effectively disabled)
let active_decode_blocks = self
.router_config
.active_decode_blocks_threshold
.unwrap_or(1.0);
let active_prefill_tokens = self
.router_config
.active_prefill_tokens_threshold
.unwrap_or(1000000);
// Create monitor if any threshold is configured
let worker_monitor = if self.router_config.load_threshold_config.is_configured() {
Some(self.manager.get_or_create_worker_monitor(
card.name(),
client.clone(),
active_decode_blocks,
active_prefill_tokens,
self.router_config.load_threshold_config.clone(),
))
} else {
None
......
......@@ -6,6 +6,7 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::protocols::ActiveLoad;
......@@ -19,29 +20,70 @@ use dynamo_runtime::transports::event_plane::EventSubscriber;
/// Scale factor for storing f64 thresholds as u32 (10000 = 4 decimal places)
const THRESHOLD_SCALE: u32 = 10000;
/// Default value for max_num_batched_tokens and active_prefill_tokens_threshold
/// when not configured. Set high enough to effectively disable busy detection.
const DEFAULT_MAX_TOKENS: u64 = 10_000_000;
/// Configuration for worker load thresholds used in busy detection.
///
/// All thresholds are optional. When not set, defaults are applied:
/// - `active_decode_blocks_threshold`: 1.0 (effectively disabled)
/// - `active_prefill_tokens_threshold`: 10,000,000 (effectively disabled)
/// - `active_prefill_tokens_threshold_frac`: 1.5 (effectively disabled)
/// - `max_num_batched_tokens` (from runtime config): 10,000,000 if not reported
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct LoadThresholdConfig {
/// KV cache block utilization threshold (0.0-1.0).
/// Worker is busy when `active_decode_blocks / total_blocks > threshold`.
#[serde(skip_serializing_if = "Option::is_none")]
pub active_decode_blocks_threshold: Option<f64>,
/// Absolute prefill token count threshold.
/// Worker is busy when `active_prefill_tokens > threshold`.
#[serde(skip_serializing_if = "Option::is_none")]
pub active_prefill_tokens_threshold: Option<u64>,
/// Fraction of max_num_batched_tokens (0.0-1.5+).
/// Worker is busy when `active_prefill_tokens > frac * max_num_batched_tokens`.
#[serde(skip_serializing_if = "Option::is_none")]
pub active_prefill_tokens_threshold_frac: Option<f64>,
}
impl LoadThresholdConfig {
/// Returns true if any threshold is configured.
pub fn is_configured(&self) -> bool {
self.active_decode_blocks_threshold.is_some()
|| self.active_prefill_tokens_threshold.is_some()
|| self.active_prefill_tokens_threshold_frac.is_some()
}
}
/// Worker load monitoring state per dp_rank
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
pub active_decode_blocks: HashMap<u32, u64>,
pub kv_total_blocks: HashMap<u32, u64>,
pub active_prefill_tokens: HashMap<u32, u64>,
/// max_num_batched_tokens from runtime config (same for all dp_ranks)
pub max_num_batched_tokens: HashMap<u32, u64>,
}
impl WorkerLoadState {
/// Returns true if ALL dp_ranks are considered busy based on the dual-threshold logic:
/// Returns true if ALL dp_ranks are considered busy based on the threshold logic.
///
/// For each dp_rank:
/// 1. If `active_prefill_tokens` is available, check if tokens exceed the literal threshold.
/// If so, that dp_rank is busy.
/// 2. If not, check if `active_decode_blocks` and `kv_total_blocks` are both available,
/// and if blocks exceed threshold. If so, that dp_rank is busy.
/// 3. If neither check can be performed (missing data), that dp_rank is considered free.
/// For each dp_rank, a dp_rank is busy if ANY of these conditions is met (OR logic):
/// 1. `active_prefill_tokens > active_prefill_tokens_threshold` (absolute threshold)
/// 2. `active_prefill_tokens > frac * max_num_batched_tokens` (fraction-based threshold)
/// 3. `active_decode_blocks / total_blocks > active_decode_blocks_threshold` (blocks threshold)
///
/// If none of these checks can be performed (missing data), that dp_rank is considered free.
///
/// The worker is busy only if ALL dp_ranks are busy.
pub fn is_busy(
&self,
active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
active_prefill_tokens_threshold_frac: f64,
) -> bool {
// Get all dp_ranks we know about
let all_dp_ranks: std::collections::HashSet<_> = self
......@@ -58,14 +100,26 @@ impl WorkerLoadState {
// Check if ALL dp_ranks are busy
all_dp_ranks.iter().all(|&dp_rank| {
// First check: prefill tokens threshold (literal token count)
if let Some(&active_tokens) = self.active_prefill_tokens.get(&dp_rank)
&& active_tokens > active_prefill_tokens_threshold
{
return true; // This dp_rank is busy due to tokens
// Check 1: prefill tokens threshold (absolute token count)
if let Some(&active_tokens) = self.active_prefill_tokens.get(&dp_rank) {
if active_tokens > active_prefill_tokens_threshold {
return true; // This dp_rank is busy due to absolute token threshold
}
// Second check: blocks threshold
// Check 2: prefill tokens threshold (fraction of max_num_batched_tokens)
let max_batched = self
.max_num_batched_tokens
.get(&dp_rank)
.copied()
.unwrap_or(DEFAULT_MAX_TOKENS);
let frac_threshold =
(active_prefill_tokens_threshold_frac * max_batched as f64) as u64;
if active_tokens > frac_threshold {
return true; // This dp_rank is busy due to frac-based token threshold
}
}
// Check 3: blocks threshold
// Skip if total_blocks is 0 (no capacity means threshold check is meaningless)
if let (Some(&active_blocks), Some(&total_blocks)) = (
self.active_decode_blocks.get(&dp_rank),
......@@ -76,7 +130,7 @@ impl WorkerLoadState {
return true; // This dp_rank is busy due to blocks
}
// If we can't perform either check, this dp_rank is considered free
// If we can't perform any check or no threshold exceeded, this dp_rank is free
false
})
}
......@@ -94,61 +148,64 @@ pub struct KvWorkerMonitor {
active_decode_blocks_threshold: Arc<AtomicU32>,
/// Active prefill tokens threshold stored as literal token count (u64)
active_prefill_tokens_threshold: Arc<AtomicU64>,
/// Active prefill tokens threshold as fraction of max_num_batched_tokens, stored scaled
active_prefill_tokens_threshold_frac: Arc<AtomicU32>,
/// Guard to ensure start_monitoring() only runs once across clones
started: Arc<AtomicBool>,
}
impl KvWorkerMonitor {
/// Create a new worker monitor with the given thresholds.
/// Create a new worker monitor with the given threshold configuration.
///
/// - `active_decode_blocks_threshold` (0.0-1.0): Threshold percentage for KV cache block utilization
/// - `active_prefill_tokens_threshold`: Literal token count threshold for prefill token utilization
/// All thresholds can be dynamically updated via setter methods or
/// `set_load_threshold_config()`.
///
/// Both thresholds can be dynamically updated via `set_active_decode_blocks_threshold()` and
/// `set_active_prefill_tokens_threshold()`.
pub fn new(
client: Client,
active_decode_blocks_threshold: f64,
active_prefill_tokens_threshold: u64,
) -> Self {
/// Defaults are applied for any threshold not specified in the config:
/// - `active_decode_blocks_threshold`: 1.0 (effectively disabled)
/// - `active_prefill_tokens_threshold`: DEFAULT_MAX_TOKENS (effectively disabled)
/// - `active_prefill_tokens_threshold_frac`: 1.5 (effectively disabled)
pub fn new(client: Client, config: LoadThresholdConfig) -> Self {
let active_decode_blocks = config.active_decode_blocks_threshold.unwrap_or(1.0);
let active_prefill_tokens = config
.active_prefill_tokens_threshold
.unwrap_or(DEFAULT_MAX_TOKENS);
let active_prefill_tokens_frac = config.active_prefill_tokens_threshold_frac.unwrap_or(1.5);
Self {
client,
worker_load_states: Arc::new(DashMap::new()),
active_decode_blocks_threshold: Arc::new(AtomicU32::new(
Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold),
)),
active_prefill_tokens_threshold: Arc::new(AtomicU64::new(
active_prefill_tokens_threshold,
)),
active_decode_blocks_threshold: Arc::new(AtomicU32::new(Self::f64_to_scaled(
active_decode_blocks,
))),
active_prefill_tokens_threshold: Arc::new(AtomicU64::new(active_prefill_tokens)),
active_prefill_tokens_threshold_frac: Arc::new(AtomicU32::new(Self::f64_to_scaled(
active_prefill_tokens_frac,
))),
started: Arc::new(AtomicBool::new(false)),
}
}
/// Convert a f64 active decode blocks threshold (0.0-1.0) to scaled u32 for atomic storage.
/// Convert a f64 threshold to scaled u32 for atomic storage.
#[inline]
fn active_decode_blocks_threshold_to_scaled(threshold: f64) -> u32 {
fn f64_to_scaled(threshold: f64) -> u32 {
(threshold * THRESHOLD_SCALE as f64) as u32
}
/// Convert a scaled u32 back to f64 active decode blocks threshold (0.0-1.0).
/// Convert a scaled u32 back to f64 threshold.
#[inline]
fn scaled_to_active_decode_blocks_threshold(scaled: u32) -> f64 {
fn scaled_to_f64(scaled: u32) -> f64 {
scaled as f64 / THRESHOLD_SCALE as f64
}
/// Get the current active decode blocks threshold value as f64.
pub fn active_decode_blocks_threshold(&self) -> f64 {
Self::scaled_to_active_decode_blocks_threshold(
self.active_decode_blocks_threshold.load(Ordering::Relaxed),
)
Self::scaled_to_f64(self.active_decode_blocks_threshold.load(Ordering::Relaxed))
}
/// Set the active decode blocks threshold value from f64.
pub fn set_active_decode_blocks_threshold(&self, threshold: f64) {
self.active_decode_blocks_threshold.store(
Self::active_decode_blocks_threshold_to_scaled(threshold),
Ordering::Relaxed,
);
self.active_decode_blocks_threshold
.store(Self::f64_to_scaled(threshold), Ordering::Relaxed);
}
/// Get the current active prefill tokens threshold value as u64.
......@@ -162,9 +219,41 @@ impl KvWorkerMonitor {
.store(threshold, Ordering::Relaxed);
}
/// Get the worker load states for external access
pub fn load_states(&self) -> Arc<DashMap<u64, WorkerLoadState>> {
self.worker_load_states.clone()
/// Get the current active prefill tokens threshold frac value as f64.
pub fn active_prefill_tokens_threshold_frac(&self) -> f64 {
Self::scaled_to_f64(
self.active_prefill_tokens_threshold_frac
.load(Ordering::Relaxed),
)
}
/// Set the active prefill tokens threshold frac value from f64.
pub fn set_active_prefill_tokens_threshold_frac(&self, frac: f64) {
self.active_prefill_tokens_threshold_frac
.store(Self::f64_to_scaled(frac), Ordering::Relaxed);
}
/// Get the current load threshold configuration.
pub fn load_threshold_config(&self) -> LoadThresholdConfig {
LoadThresholdConfig {
active_decode_blocks_threshold: Some(self.active_decode_blocks_threshold()),
active_prefill_tokens_threshold: Some(self.active_prefill_tokens_threshold()),
active_prefill_tokens_threshold_frac: Some(self.active_prefill_tokens_threshold_frac()),
}
}
/// Update all thresholds from a LoadThresholdConfig.
/// Only updates fields that are Some in the config.
pub fn set_load_threshold_config(&self, config: &LoadThresholdConfig) {
if let Some(threshold) = config.active_decode_blocks_threshold {
self.set_active_decode_blocks_threshold(threshold);
}
if let Some(threshold) = config.active_prefill_tokens_threshold {
self.set_active_prefill_tokens_threshold(threshold);
}
if let Some(frac) = config.active_prefill_tokens_threshold_frac {
self.set_active_prefill_tokens_threshold_frac(frac);
}
}
}
......@@ -206,6 +295,8 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let client = self.client.clone();
let active_decode_blocks_threshold = self.active_decode_blocks_threshold.clone();
let active_prefill_tokens_threshold = self.active_prefill_tokens_threshold.clone();
let active_prefill_tokens_threshold_frac =
self.active_prefill_tokens_threshold_frac.clone();
// Spawn background monitoring task
tokio::spawn(async move {
......@@ -224,7 +315,7 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
worker_load_states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
// Update worker load states with total blocks for all dp_ranks
// Update worker load states with runtime config values for all dp_ranks
for (lease_id, runtime_config) in runtime_configs.iter() {
let mut state = worker_load_states.entry(*lease_id).or_default();
......@@ -234,6 +325,13 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
state.kv_total_blocks.insert(dp_rank, total_blocks);
}
}
// Populate max_num_batched_tokens for all dp_ranks
if let Some(max_batched) = runtime_config.max_num_batched_tokens {
for dp_rank in 0..runtime_config.data_parallel_size {
state.max_num_batched_tokens.insert(dp_rank, max_batched);
}
}
}
}
......@@ -255,7 +353,6 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
// Update worker load state per dp_rank
{
let mut state = worker_load_states.entry(worker_id).or_default();
if let Some(active_blocks) = active_load.active_decode_blocks {
state.active_decode_blocks.insert(dp_rank, active_blocks);
}
......@@ -265,18 +362,25 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
}
// Load thresholds dynamically - allows runtime updates
let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold(
active_decode_blocks_threshold.load(Ordering::Relaxed),
);
let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed);
let current_active_decode_blocks_threshold =
Self::scaled_to_f64(active_decode_blocks_threshold.load(Ordering::Relaxed));
let current_active_prefill_tokens_threshold =
active_prefill_tokens_threshold.load(Ordering::Relaxed);
let current_active_prefill_tokens_threshold_frac =
Self::scaled_to_f64(active_prefill_tokens_threshold_frac.load(Ordering::Relaxed));
// Recalculate all busy instances and update
let busy_instances: Vec<u64> = worker_load_states
.iter()
.filter_map(|r| {
r.value()
.is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold)
.then_some(*r.key())
.filter_map(|entry| {
entry
.value()
.is_busy(
current_active_decode_blocks_threshold,
current_active_prefill_tokens_threshold,
current_active_prefill_tokens_threshold_frac,
)
.then_some(*entry.key())
})
.collect();
......
......@@ -15,8 +15,8 @@ use std::sync::Arc;
use dynamo_runtime::pipeline::RouterMode;
use crate::{
backend::ExecutionContext, engines::StreamingEngine, kv_router::KvRouterConfig,
local_model::LocalModel, model_card::ModelDeploymentCard,
backend::ExecutionContext, discovery::LoadThresholdConfig, engines::StreamingEngine,
kv_router::KvRouterConfig, local_model::LocalModel, model_card::ModelDeploymentCard,
types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine,
};
......@@ -34,10 +34,8 @@ pub type EngineFactoryCallback = Arc<
pub struct RouterConfig {
pub router_mode: RouterMode,
pub kv_router_config: KvRouterConfig,
/// Threshold for active decode blocks utilization (0.0-1.0)
pub active_decode_blocks_threshold: Option<f64>,
/// Threshold for active prefill tokens utilization (literal token count)
pub active_prefill_tokens_threshold: Option<u64>,
/// Load threshold configuration for busy detection
pub load_threshold_config: LoadThresholdConfig,
pub enforce_disagg: bool,
}
......@@ -46,19 +44,13 @@ impl RouterConfig {
Self {
router_mode,
kv_router_config,
active_decode_blocks_threshold: None,
active_prefill_tokens_threshold: None,
load_threshold_config: LoadThresholdConfig::default(),
enforce_disagg: false,
}
}
pub fn with_active_decode_blocks_threshold(mut self, threshold: Option<f64>) -> Self {
self.active_decode_blocks_threshold = threshold;
self
}
pub fn with_active_prefill_tokens_threshold(mut self, threshold: Option<u64>) -> Self {
self.active_prefill_tokens_threshold = threshold;
pub fn with_load_threshold_config(mut self, config: LoadThresholdConfig) -> Self {
self.load_threshold_config = config;
self
}
......
......@@ -17,9 +17,9 @@
//! **Set thresholds:**
//! ```json
//! // Request
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000, "active_prefill_tokens_threshold_frac": 0.8}
//! // Response
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000, "active_prefill_tokens_threshold_frac": 0.8}
//! ```
//!
//! **Get thresholds (omit thresholds):**
......@@ -27,9 +27,9 @@
//! // Request
//! {"model": "llama-3-70b"}
//! // Response (if configured)
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000, "active_prefill_tokens_threshold_frac": 0.8}
//! // Response (if not configured)
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": null, "active_prefill_tokens_threshold": null}
//! {"model": "llama-3-70b", "active_decode_blocks_threshold": null, "active_prefill_tokens_threshold": null, "active_prefill_tokens_threshold_frac": null}
//! ```
//!
//! ### GET /busy_threshold
......@@ -38,10 +38,11 @@
//!
//! ```json
//! // Response
//! {"thresholds": [{"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000}]}
//! {"thresholds": [{"model": "llama-3-70b", "active_decode_blocks_threshold": 0.85, "active_prefill_tokens_threshold": 1000, "active_prefill_tokens_threshold_frac": 0.8}]}
//! ```
use super::{RouteDoc, service_v2};
use crate::discovery::LoadThresholdConfig;
use axum::{
Json, Router,
extract::Request,
......@@ -65,6 +66,8 @@ pub struct BusyThresholdRequest {
pub active_decode_blocks_threshold: Option<f64>,
/// The active prefill tokens threshold value (literal token count), or null to just get the current value
pub active_prefill_tokens_threshold: Option<u64>,
/// The active prefill tokens threshold as fraction of max_num_batched_tokens, or null to just get the current value
pub active_prefill_tokens_threshold_frac: Option<f64>,
}
/// Response for a threshold operation
......@@ -76,6 +79,8 @@ pub struct BusyThresholdResponse {
pub active_decode_blocks_threshold: Option<f64>,
/// The active prefill tokens threshold value (null if no threshold is configured)
pub active_prefill_tokens_threshold: Option<u64>,
/// The active prefill tokens threshold as fraction of max_num_batched_tokens
pub active_prefill_tokens_threshold_frac: Option<f64>,
}
/// Response for listing all thresholds
......@@ -155,19 +160,26 @@ async fn busy_threshold_handler(
let manager = state.manager();
// Build LoadThresholdConfig from request if any threshold is being set
let is_setting = request.active_decode_blocks_threshold.is_some()
|| request.active_prefill_tokens_threshold.is_some()
|| request.active_prefill_tokens_threshold_frac.is_some();
let update_config = if is_setting {
Some(LoadThresholdConfig {
active_decode_blocks_threshold: request.active_decode_blocks_threshold,
active_prefill_tokens_threshold: request.active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac: request.active_prefill_tokens_threshold_frac,
})
} else {
None
};
// Get or set the thresholds via the model's worker monitor
let active_decode_blocks_threshold = manager
.active_decode_blocks_threshold(&request.model, request.active_decode_blocks_threshold);
let active_prefill_tokens_threshold = manager
.active_prefill_tokens_threshold(&request.model, request.active_prefill_tokens_threshold);
let config = manager.load_threshold_config(&request.model, update_config.as_ref());
// If trying to SET but model has no monitor, return 404
let is_setting = request.active_decode_blocks_threshold.is_some()
|| request.active_prefill_tokens_threshold.is_some();
if is_setting
&& active_decode_blocks_threshold.is_none()
&& active_prefill_tokens_threshold.is_none()
{
if is_setting && config.is_none() {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!(ErrorResponse {
......@@ -182,18 +194,30 @@ async fn busy_threshold_handler(
if is_setting {
tracing::info!(
model = %request.model,
active_decode_blocks_threshold = ?active_decode_blocks_threshold,
active_prefill_tokens_threshold = ?active_prefill_tokens_threshold,
config = ?config,
"Updated busy thresholds"
);
}
let (
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac,
) = config.map_or((None, None, None), |c| {
(
c.active_decode_blocks_threshold,
c.active_prefill_tokens_threshold,
c.active_prefill_tokens_threshold_frac,
)
});
(
StatusCode::OK,
Json(serde_json::json!(BusyThresholdResponse {
model: request.model,
active_decode_blocks_threshold,
active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac,
})),
)
}
......@@ -207,15 +231,12 @@ async fn list_busy_thresholds_handler(
let response = ListBusyThresholdsResponse {
thresholds: thresholds
.into_iter()
.map(
|(model, active_decode_blocks_threshold, active_prefill_tokens_threshold)| {
BusyThresholdResponse {
.map(|(model, config)| BusyThresholdResponse {
model,
active_decode_blocks_threshold: Some(active_decode_blocks_threshold),
active_prefill_tokens_threshold: Some(active_prefill_tokens_threshold),
}
},
)
active_decode_blocks_threshold: config.active_decode_blocks_threshold,
active_prefill_tokens_threshold: config.active_prefill_tokens_threshold,
active_prefill_tokens_threshold_frac: config.active_prefill_tokens_threshold_frac,
})
.collect(),
};
......
......@@ -48,6 +48,7 @@ class KVRouterProcess(ManagedProcess):
enforce_disagg: bool = False,
blocks_threshold: float | None = None,
tokens_threshold: float | None = None,
tokens_threshold_frac: float | None = None,
request_plane: str = "nats",
):
command = [
......@@ -75,6 +76,11 @@ class KVRouterProcess(ManagedProcess):
if tokens_threshold is not None:
command.extend(["--active-prefill-tokens-threshold", str(tokens_threshold)])
if tokens_threshold_frac is not None:
command.extend(
["--active-prefill-tokens-threshold-frac", str(tokens_threshold_frac)]
)
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
......@@ -2394,6 +2400,52 @@ def _test_busy_threshold_endpoint(
f"POST /busy_threshold (invalid tokens) response: {data}"
)
# Test 10: Set active_prefill_tokens_threshold_frac (fraction of max_num_batched_tokens)
test_frac_threshold = 0.8
logger.info(
f"Testing POST /busy_threshold to set active_prefill_tokens_threshold_frac={test_frac_threshold}"
)
async with session.post(
busy_threshold_url,
json={
"model": model_name,
"active_prefill_tokens_threshold_frac": test_frac_threshold,
},
) as response:
assert (
response.status == 200
), f"POST /busy_threshold (set frac) failed with status {response.status}"
data = await response.json()
assert (
data.get("active_prefill_tokens_threshold_frac")
== test_frac_threshold
), f"Expected active_prefill_tokens_threshold_frac={test_frac_threshold}: {data}"
logger.info(f"POST /busy_threshold (set frac) response: {data}")
# Test 11: Verify frac threshold appears in GET /busy_threshold list
logger.info(
"Testing GET /busy_threshold to verify frac threshold in list"
)
async with session.get(busy_threshold_url) as response:
assert (
response.status == 200
), f"GET /busy_threshold failed with status {response.status}"
data = await response.json()
thresholds = data.get("thresholds", [])
model_entry = next(
(t for t in thresholds if t["model"] == model_name), None
)
assert (
model_entry is not None
), f"Expected model '{model_name}' in thresholds: {data}"
assert (
model_entry.get("active_prefill_tokens_threshold_frac")
== test_frac_threshold
), f"Expected active_prefill_tokens_threshold_frac={test_frac_threshold}: {data}"
logger.info(
f"GET /busy_threshold (after set frac) response: {data}"
)
logger.info("All busy_threshold endpoint tests passed!")
asyncio.run(test_busy_threshold_api())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment