Commit ffe20629 (unverified), authored by Yan Ru Pei, committed by GitHub
Browse files

fix: latch busy load until both cool off (#7912)


Signed-off-by: AmeenP <ameenp360@gmail.com>
Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: AmeenP <ameenp360@gmail.com>
parent 95a750f4
......@@ -140,7 +140,9 @@ class DynamoSglangPublisher:
else self.dp_rank
)
active_decode_blocks = kv_metrics.kv_active_blocks
self.metrics_publisher.publish(dp_rank, active_decode_blocks)
self.metrics_publisher.publish(
dp_rank, kv_used_blocks=active_decode_blocks
)
dp_rank_str = str(dp_rank)
# Publish total blocks (always available in KvMetrics)
self.component_gauges.set_total_blocks(
......@@ -185,7 +187,7 @@ class DynamoSglangPublisher:
def init_engine_metrics_publish(self) -> None:
"""Publish initial dummy metrics to bootstrap the metrics endpoint."""
logging.info("Sending dummy metrics to initialize")
self.metrics_publisher.publish(self.dp_rank, 0)
self.metrics_publisher.publish(self.dp_rank, kv_used_blocks=0)
dp_rank_str = str(self.dp_rank)
self.component_gauges.set_total_blocks(dp_rank_str, 0)
self.component_gauges.set_gpu_cache_usage(dp_rank_str, 0.0)
......
......@@ -410,7 +410,7 @@ class Publisher:
# Publish initial metrics with 0 active blocks
# TRT-LLM doesn't use data parallelism currently (dp_rank="0")
self.metrics_publisher.publish(None, 0)
self.metrics_publisher.publish(None, kv_used_blocks=0)
self.component_gauges.set_total_blocks("0", 0)
self.component_gauges.set_gpu_cache_usage("0", 0.0)
......@@ -478,7 +478,7 @@ class Publisher:
logging.debug(f"Publishing stats: kv_active_blocks: {kv_active_blocks}")
# TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus)
assert self.metrics_publisher is not None
self.metrics_publisher.publish(None, kv_active_blocks)
self.metrics_publisher.publish(None, kv_used_blocks=kv_active_blocks)
# Publish Prometheus metrics
self.component_gauges.set_total_blocks("0", kv_total_blocks)
......
......@@ -58,7 +58,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
**kwargs: object,
) -> None:
active_decode_blocks = int(self.num_gpu_block * scheduler_stats.kv_cache_usage)
self.inner.publish(self.dp_rank, active_decode_blocks)
self.inner.publish(self.dp_rank, kv_used_blocks=active_decode_blocks)
dp_rank_str = str(self.dp_rank)
self.component_gauges.set_total_blocks(dp_rank_str, self.num_gpu_block)
......@@ -72,7 +72,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
)
def init_publish(self) -> None:
self.inner.publish(self.dp_rank, 0)
self.inner.publish(self.dp_rank, kv_used_blocks=0)
dp_rank_str = str(self.dp_rank)
self.component_gauges.set_total_blocks(dp_rank_str, 0)
self.component_gauges.set_gpu_cache_usage(dp_rank_str, 0.0)
......
......@@ -234,11 +234,17 @@ impl WorkerMetricsPublisher {
///
/// # Arguments
/// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
/// * `active_decode_blocks` - Number of active KV cache blocks
#[pyo3(signature = (dp_rank, active_decode_blocks))]
fn publish(&self, dp_rank: Option<u32>, active_decode_blocks: u64) -> PyResult<()> {
/// * `active_decode_blocks` - Scheduler-compatible active decode block count
/// * `kv_used_blocks` - Authoritative total KV blocks currently in use
#[pyo3(signature = (dp_rank=None, active_decode_blocks=None, kv_used_blocks=None))]
fn publish(
&self,
dp_rank: Option<u32>,
active_decode_blocks: Option<u64>,
kv_used_blocks: Option<u64>,
) -> PyResult<()> {
self.inner
.publish(dp_rank, active_decode_blocks)
.publish(dp_rank, active_decode_blocks, kv_used_blocks)
.map_err(to_pyerr)
}
}
......
......@@ -431,15 +431,17 @@ class WorkerMetricsPublisher:
def publish(
self,
dp_rank: Optional[int],
active_decode_blocks: int,
dp_rank: Optional[int] = None,
active_decode_blocks: int | None = None,
kv_used_blocks: int | None = None,
) -> None:
"""
Publish worker metrics for load monitoring.
Args:
dp_rank: Data parallel rank of the worker (None defaults to 0)
active_decode_blocks: Number of active KV cache blocks
active_decode_blocks: Optional scheduler-compatible decode-block signal
kv_used_blocks: Optional authoritative total KV blocks currently in use
"""
...
......
......@@ -361,17 +361,23 @@ pub struct WorkerSelectionResult {
/// Active load metrics for a worker, used for busy detection.
///
/// Published by workers (with only `active_decode_blocks`) and by the scheduler
/// (with both `active_decode_blocks` and `active_prefill_tokens`).
/// Published by workers (with `kv_used_blocks`) and by the scheduler (with
/// `active_decode_blocks` and `active_prefill_tokens`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ActiveLoad {
pub worker_id: WorkerId,
#[serde(default)]
pub dp_rank: DpRank,
/// Number of active KV cache blocks on the worker (decode phase).
/// Scheduler-reported decode block load.
pub active_decode_blocks: Option<u64>,
/// Number of active prefill tokens (from scheduler's view).
pub active_prefill_tokens: Option<u64>,
/// Total KV blocks currently in use on the worker.
///
/// This is published by workers only and is the authoritative signal for
/// backend KV occupancy used by busy detection.
#[serde(default)]
pub kv_used_blocks: Option<u64>,
}
/// A [`LocalBlockHash`] is a hash computed from the token IDs, optional multimodal metadata,
......
......@@ -660,6 +660,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
dp_rank: worker.dp_rank,
active_decode_blocks: Some(active_blocks as u64),
active_prefill_tokens: Some(active_tokens as u64),
kv_used_blocks: None,
};
self.publisher.publish_load(active_load);
......
......@@ -88,22 +88,138 @@ impl LoadThresholdConfig {
}
/// Decode-busy latch for a single dp_rank.
///
/// The latch is set as soon as either decode signal (`kv_used_blocks` or
/// `active_decode_blocks`) is reported above the threshold, and it is released
/// only once both `*_cleared` flags are true.
#[derive(Clone, Debug)]
struct DecodeBusyLatchState {
/// True while this dp_rank is held busy by the latch.
latched_busy: bool,
/// False when the most recent `kv_used_blocks` report was above the
/// threshold; true otherwise (including when it was never reported).
kv_used_blocks_cleared: bool,
/// False when the most recent `active_decode_blocks` report was above the
/// threshold; true otherwise (including when it was never reported).
active_decode_blocks_cleared: bool,
}
// `Default` is written by hand because the derived impl would start both
// `*_cleared` flags at `false`. A fresh latch must be not busy with both
// signals counted as cleared, so that a signal which is never reported cannot
// hold the latch set on its own.
impl Default for DecodeBusyLatchState {
fn default() -> Self {
Self {
latched_busy: false,
kv_used_blocks_cleared: true,
active_decode_blocks_cleared: true,
}
}
}
/// Load state tracked for one worker; every map is keyed by dp_rank.
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
/// Latest `active_decode_blocks` value taken from an `ActiveLoad` event.
pub active_decode_blocks: HashMap<u32, u64>,
/// Latest `kv_used_blocks` value taken from an `ActiveLoad` event.
pub kv_used_blocks: HashMap<u32, u64>,
/// Total block count that the decode threshold fraction is applied to.
/// NOTE(review): populated outside the code shown here - confirm the source.
pub kv_total_blocks: HashMap<u32, u64>,
/// Latest `active_prefill_tokens` value taken from an `ActiveLoad` event.
pub active_prefill_tokens: HashMap<u32, u64>,
/// max_num_batched_tokens from runtime config (same for all dp_ranks)
pub max_num_batched_tokens: HashMap<u32, u64>,
/// Decode-busy latch per dp_rank; an entry is created the first time a load
/// event is processed for a dp_rank whose total block count is known and
/// non-zero.
decode_busy_latches: HashMap<u32, DecodeBusyLatchState>,
}
impl WorkerLoadState {
/// Returns `true` when `used_blocks` is strictly greater than
/// `active_decode_blocks_threshold * total_blocks`.
///
/// A `total_blocks` of zero is never busy: with no capacity the fraction
/// check has no meaning.
fn is_decode_signal_busy(
    used_blocks: u64,
    total_blocks: u64,
    active_decode_blocks_threshold: f64,
) -> bool {
    if total_blocks == 0 {
        return false;
    }
    let busy_limit = active_decode_blocks_threshold * total_blocks as f64;
    used_blocks as f64 > busy_limit
}
/// Checks the decode threshold directly against the block counts stored for
/// `dp_rank`, without consulting any latch.
///
/// Returns `true` if the stored `kv_used_blocks` value or the stored
/// `active_decode_blocks` value exceeds the threshold. A dp_rank with no
/// recorded `kv_total_blocks` is reported as not busy.
fn current_decode_busy(&self, dp_rank: u32, active_decode_blocks_threshold: f64) -> bool {
    let total_blocks = match self.kv_total_blocks.get(&dp_rank) {
        Some(&total) => total,
        None => return false,
    };

    // Same test for both signals; a missing value counts as not busy.
    let exceeds = |blocks: Option<&u64>| {
        blocks.is_some_and(|&count| {
            Self::is_decode_signal_busy(count, total_blocks, active_decode_blocks_threshold)
        })
    };

    exceeds(self.kv_used_blocks.get(&dp_rank))
        || exceeds(self.active_decode_blocks.get(&dp_rank))
}
/// Updates the decode-busy latch for `dp_rank` from the signals carried by
/// one load event.
///
/// * A signal reported above the threshold sets the latch and marks that
///   signal as not cleared.
/// * A signal reported at or below the threshold marks that signal as cleared.
/// * A signal that is `None` leaves its cleared flag as it was.
/// * The latch is released once both signals are marked cleared.
///
/// Nothing is recorded (and no latch entry is created) when the total block
/// count for `dp_rank` is unknown or zero.
fn update_decode_busy_latch(
    &mut self,
    dp_rank: u32,
    active_decode_blocks: Option<u64>,
    kv_used_blocks: Option<u64>,
    active_decode_blocks_threshold: f64,
) {
    let total_blocks = match self.kv_total_blocks.get(&dp_rank) {
        Some(&total) if total != 0 => total,
        _ => return,
    };

    // Evaluate each reported signal once: `Some(true)` means busy,
    // `Some(false)` means cleared, `None` means not reported in this event.
    let over_threshold = |blocks: u64| {
        Self::is_decode_signal_busy(blocks, total_blocks, active_decode_blocks_threshold)
    };
    let active_signal = active_decode_blocks.map(&over_threshold);
    let kv_signal = kv_used_blocks.map(&over_threshold);

    let latch = self.decode_busy_latches.entry(dp_rank).or_default();

    if active_signal.unwrap_or(false) || kv_signal.unwrap_or(false) {
        latch.latched_busy = true;
    }
    if let Some(busy) = active_signal {
        latch.active_decode_blocks_cleared = !busy;
    }
    if let Some(busy) = kv_signal {
        latch.kv_used_blocks_cleared = !busy;
    }

    // Release only after both signals have cooled off.
    let both_cleared = latch.kv_used_blocks_cleared && latch.active_decode_blocks_cleared;
    if latch.latched_busy && both_cleared {
        latch.latched_busy = false;
    }
}
/// Folds one `ActiveLoad` event into this state.
///
/// Each metric present on the event overwrites the value stored for the
/// event's dp_rank; metrics that are `None` leave the stored value untouched.
/// The decode-busy latch is then updated from the values this event carried.
fn update_from_active_load(
    &mut self,
    active_load: &ActiveLoad,
    active_decode_blocks_threshold: f64,
) {
    let dp_rank = active_load.dp_rank;
    let decode_blocks = active_load.active_decode_blocks;
    let used_blocks = active_load.kv_used_blocks;

    if let Some(blocks) = decode_blocks {
        self.active_decode_blocks.insert(dp_rank, blocks);
    }
    if let Some(blocks) = used_blocks {
        self.kv_used_blocks.insert(dp_rank, blocks);
    }
    if let Some(tokens) = active_load.active_prefill_tokens {
        self.active_prefill_tokens.insert(dp_rank, tokens);
    }

    self.update_decode_busy_latch(
        dp_rank,
        decode_blocks,
        used_blocks,
        active_decode_blocks_threshold,
    );
}
/// Returns true if ALL dp_ranks are considered busy based on the threshold logic.
///
/// For each dp_rank, a dp_rank is busy if ANY of these conditions is met (OR logic):
/// 1. `active_prefill_tokens > active_prefill_tokens_threshold` (absolute threshold)
/// 2. `active_prefill_tokens > frac * max_num_batched_tokens` (fraction-based threshold)
/// 3. `active_decode_blocks / total_blocks > active_decode_blocks_threshold` (blocks threshold)
/// 3. decode busy latch set by either `kv_used_blocks` or `active_decode_blocks`
///
/// If none of these checks can be performed (missing data), that dp_rank is considered free.
///
......@@ -118,6 +234,8 @@ impl WorkerLoadState {
let all_dp_ranks: std::collections::HashSet<_> = self
.active_decode_blocks
.keys()
.chain(self.kv_used_blocks.keys())
.chain(self.decode_busy_latches.keys())
.chain(self.active_prefill_tokens.keys())
.copied()
.collect();
......@@ -148,15 +266,13 @@ impl WorkerLoadState {
}
}
// Check 3: blocks threshold
// Skip if total_blocks is 0 (no capacity means threshold check is meaningless)
if let (Some(&active_blocks), Some(&total_blocks)) = (
self.active_decode_blocks.get(&dp_rank),
self.kv_total_blocks.get(&dp_rank),
) && total_blocks > 0
&& (active_blocks as f64) > (active_decode_blocks_threshold * total_blocks as f64)
{
return true; // This dp_rank is busy due to blocks
// Check 3: decode busy latch
if let Some(latch) = self.decode_busy_latches.get(&dp_rank) {
if latch.latched_busy {
return true;
}
} else if self.current_decode_busy(dp_rank, active_decode_blocks_threshold) {
return true;
}
// If we can't perform any check or no threshold exceeded, this dp_rank is free
......@@ -504,18 +620,6 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
.or_default()
.insert(dp_rank);
// Update worker load state per dp_rank (for busy detection only)
// Note: Prometheus gauges are updated directly by sequence.rs
{
let mut state = worker_load_states.entry(worker_id).or_default();
if let Some(active_blocks) = active_load.active_decode_blocks {
state.active_decode_blocks.insert(dp_rank, active_blocks);
}
if let Some(active_tokens) = active_load.active_prefill_tokens {
state.active_prefill_tokens.insert(dp_rank, active_tokens);
}
}
// Load thresholds dynamically - allows runtime updates
let current_active_decode_blocks_threshold =
Self::scaled_to_f64(active_decode_blocks_threshold.load(Ordering::Relaxed));
......@@ -524,6 +628,16 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let current_active_prefill_tokens_threshold_frac =
Self::scaled_to_f64(active_prefill_tokens_threshold_frac.load(Ordering::Relaxed));
// Update worker load state per dp_rank (for busy detection only)
// Note: Prometheus gauges are updated directly by sequence.rs
{
let mut state = worker_load_states.entry(worker_id).or_default();
state.update_from_active_load(
&active_load,
current_active_decode_blocks_threshold,
);
}
// Recalculate all busy instances and update
let busy_instances: Vec<u64> = worker_load_states
.iter()
......@@ -648,3 +762,187 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::WorkerLoadState;
    use dynamo_kv_router::protocols::ActiveLoad;

    /// Fresh state in which dp_rank 0 has a total of 100 blocks.
    fn state_with_capacity() -> WorkerLoadState {
        let mut state = WorkerLoadState::default();
        state.kv_total_blocks.insert(0, 100);
        state
    }

    /// Feeds one load event for worker 1 / dp_rank 0 into `state`, using a
    /// decode threshold of 0.6 and no prefill-token information.
    fn report(
        state: &mut WorkerLoadState,
        active_decode_blocks: Option<u64>,
        kv_used_blocks: Option<u64>,
    ) {
        let event = ActiveLoad {
            worker_id: 1,
            dp_rank: 0,
            active_decode_blocks,
            active_prefill_tokens: None,
            kv_used_blocks,
        };
        state.update_from_active_load(&event, 0.6);
    }

    // The first three tests write the maps directly, so no latch entry exists
    // for dp_rank 0.

    #[test]
    fn is_busy_prefers_kv_used_blocks_over_active_decode_blocks() {
        let mut state = state_with_capacity();
        state.active_decode_blocks.insert(0, 10);
        state.kv_used_blocks.insert(0, 90);

        assert!(state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn is_busy_falls_back_to_active_decode_blocks_when_kv_used_missing() {
        let mut state = state_with_capacity();
        state.active_decode_blocks.insert(0, 90);

        assert!(state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn is_busy_recognizes_dp_rank_known_only_from_kv_used_blocks() {
        let mut state = state_with_capacity();
        state.kv_used_blocks.insert(0, 90);

        assert!(state.is_busy(0.6, u64::MAX, 2.0));
    }

    // The remaining tests drive the latch through load events.

    #[test]
    fn decode_busy_latch_sets_busy_if_any_signal_is_busy() {
        let mut state = state_with_capacity();

        report(&mut state, None, Some(90));
        assert!(state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_only_clears_after_both_signals_report_nonbusy() {
        let mut state = state_with_capacity();

        // kv_used_blocks goes hot and sets the latch.
        report(&mut state, None, Some(90));
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        // A cool active_decode_blocks report alone must not release it.
        report(&mut state, Some(10), None);
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        // Once kv_used_blocks cools off as well, the latch releases.
        report(&mut state, None, Some(10));
        assert!(!state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_clears_with_only_kv_used_blocks_signal() {
        let mut state = state_with_capacity();

        report(&mut state, None, Some(90));
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        report(&mut state, None, Some(10));
        assert!(!state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_clears_with_only_active_decode_blocks_signal() {
        let mut state = state_with_capacity();

        report(&mut state, Some(90), None);
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        report(&mut state, Some(10), None);
        assert!(!state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_clears_when_both_signals_are_nonbusy_in_same_event() {
        let mut state = state_with_capacity();

        report(&mut state, Some(90), None);
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        report(&mut state, Some(10), Some(10));
        assert!(!state.is_busy(0.6, u64::MAX, 2.0));
    }
}
......@@ -1225,7 +1225,8 @@ mod test_integration_publisher {
// Test 1: Publish 10 different metrics with 0.5ms intervals
// Only the last one should be published after 1ms of stability
for i in 0..10 {
publisher.publish(None, (i * 100) as u64).unwrap();
let value = (i * 100) as u64;
publisher.publish(None, None, Some(value)).unwrap();
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
......@@ -1240,8 +1241,9 @@ mod test_integration_publisher {
let (_envelope, event) = result.unwrap().unwrap(); // Unwrap the Option and the Result
assert_eq!(event.worker_id, worker_id);
assert_eq!(event.active_decode_blocks, Some(900)); // Last value: 9 * 100
assert_eq!(event.active_decode_blocks, None); // Worker publisher sends kv_used_blocks
assert_eq!(event.active_prefill_tokens, None); // Worker doesn't publish prefill tokens
assert_eq!(event.kv_used_blocks, Some(900));
// Ensure no more events are waiting
let no_msg =
......@@ -1250,7 +1252,7 @@ mod test_integration_publisher {
// Test 2: Publish 10 more metrics with the same kv_used_blocks value - should not trigger publish
for _ in 0..10 {
publisher.publish(None, 900).unwrap(); // Keep same as last published
publisher.publish(None, None, Some(900)).unwrap(); // Keep same as last published
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
......
......@@ -13,7 +13,8 @@ use crate::kv_router::KV_METRICS_SUBJECT;
#[derive(Debug, Clone, Default, PartialEq)]
struct WorkerMetrics {
dp_rank: DpRank,
active_decode_blocks: u64,
active_decode_blocks: Option<u64>,
kv_used_blocks: Option<u64>,
}
pub struct WorkerMetricsPublisher {
......@@ -27,15 +28,26 @@ impl WorkerMetricsPublisher {
Ok(Self { tx, rx })
}
pub fn publish(&self, dp_rank: Option<DpRank>, active_decode_blocks: u64) -> Result<()> {
pub fn publish(
&self,
dp_rank: Option<DpRank>,
active_decode_blocks: Option<u64>,
kv_used_blocks: Option<u64>,
) -> Result<()> {
if active_decode_blocks.is_none() && kv_used_blocks.is_none() {
anyhow::bail!("worker metrics publish requires at least one load metric");
}
let metrics = WorkerMetrics {
dp_rank: dp_rank.unwrap_or(0),
active_decode_blocks,
kv_used_blocks,
};
tracing::trace!(
"Publish metrics: dp_rank={}, active_decode_blocks={}",
"Publish metrics: dp_rank={}, active_decode_blocks={:?}, kv_used_blocks={:?}",
metrics.dp_rank,
metrics.active_decode_blocks
metrics.active_decode_blocks,
metrics.kv_used_blocks
);
self.tx
.send(metrics)
......@@ -95,8 +107,9 @@ impl WorkerMetricsPublisher {
let active_load = ActiveLoad {
worker_id,
dp_rank: metrics.dp_rank,
active_decode_blocks: Some(metrics.active_decode_blocks),
active_decode_blocks: metrics.active_decode_blocks,
active_prefill_tokens: None,
kv_used_blocks: metrics.kv_used_blocks,
};
if let Err(e) = event_publisher.publish(&active_load).await {
......
......@@ -552,7 +552,11 @@ impl MockEngine {
let metrics = metrics_rx.borrow().clone();
// Publish metrics using flat API
if let Err(e) = publisher.publish(Some(metrics.dp_rank), metrics.active_decode_blocks) {
if let Err(e) = publisher.publish(
Some(metrics.dp_rank),
None,
Some(metrics.active_decode_blocks),
) {
tracing::warn!("Failed to publish metrics for DP rank {}: {e}", metrics.dp_rank);
} else {
tracing::trace!("Published metrics for DP rank {}", metrics.dp_rank);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment