Unverified commit ffe20629, authored by Yan Ru Pei and committed by GitHub
Browse files

fix: latch busy load until both cool off (#7912)


Signed-off-by: AmeenP <ameenp360@gmail.com>
Signed-off-by: PeaBrane <yanrpei@gmail.com>
Co-authored-by: AmeenP <ameenp360@gmail.com>
parent 95a750f4
...@@ -140,7 +140,9 @@ class DynamoSglangPublisher: ...@@ -140,7 +140,9 @@ class DynamoSglangPublisher:
else self.dp_rank else self.dp_rank
) )
active_decode_blocks = kv_metrics.kv_active_blocks active_decode_blocks = kv_metrics.kv_active_blocks
self.metrics_publisher.publish(dp_rank, active_decode_blocks) self.metrics_publisher.publish(
dp_rank, kv_used_blocks=active_decode_blocks
)
dp_rank_str = str(dp_rank) dp_rank_str = str(dp_rank)
# Publish total blocks (always available in KvMetrics) # Publish total blocks (always available in KvMetrics)
self.component_gauges.set_total_blocks( self.component_gauges.set_total_blocks(
...@@ -185,7 +187,7 @@ class DynamoSglangPublisher: ...@@ -185,7 +187,7 @@ class DynamoSglangPublisher:
def init_engine_metrics_publish(self) -> None: def init_engine_metrics_publish(self) -> None:
"""Publish initial dummy metrics to bootstrap the metrics endpoint.""" """Publish initial dummy metrics to bootstrap the metrics endpoint."""
logging.info("Sending dummy metrics to initialize") logging.info("Sending dummy metrics to initialize")
self.metrics_publisher.publish(self.dp_rank, 0) self.metrics_publisher.publish(self.dp_rank, kv_used_blocks=0)
dp_rank_str = str(self.dp_rank) dp_rank_str = str(self.dp_rank)
self.component_gauges.set_total_blocks(dp_rank_str, 0) self.component_gauges.set_total_blocks(dp_rank_str, 0)
self.component_gauges.set_gpu_cache_usage(dp_rank_str, 0.0) self.component_gauges.set_gpu_cache_usage(dp_rank_str, 0.0)
......
...@@ -410,7 +410,7 @@ class Publisher: ...@@ -410,7 +410,7 @@ class Publisher:
# Publish initial metrics with 0 active blocks # Publish initial metrics with 0 active blocks
# TRT-LLM doesn't use data parallelism currently (dp_rank="0") # TRT-LLM doesn't use data parallelism currently (dp_rank="0")
self.metrics_publisher.publish(None, 0) self.metrics_publisher.publish(None, kv_used_blocks=0)
self.component_gauges.set_total_blocks("0", 0) self.component_gauges.set_total_blocks("0", 0)
self.component_gauges.set_gpu_cache_usage("0", 0.0) self.component_gauges.set_gpu_cache_usage("0", 0.0)
...@@ -478,7 +478,7 @@ class Publisher: ...@@ -478,7 +478,7 @@ class Publisher:
logging.debug(f"Publishing stats: kv_active_blocks: {kv_active_blocks}") logging.debug(f"Publishing stats: kv_active_blocks: {kv_active_blocks}")
# TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus) # TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus)
assert self.metrics_publisher is not None assert self.metrics_publisher is not None
self.metrics_publisher.publish(None, kv_active_blocks) self.metrics_publisher.publish(None, kv_used_blocks=kv_active_blocks)
# Publish Prometheus metrics # Publish Prometheus metrics
self.component_gauges.set_total_blocks("0", kv_total_blocks) self.component_gauges.set_total_blocks("0", kv_total_blocks)
......
...@@ -58,7 +58,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -58,7 +58,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
**kwargs: object, **kwargs: object,
) -> None: ) -> None:
active_decode_blocks = int(self.num_gpu_block * scheduler_stats.kv_cache_usage) active_decode_blocks = int(self.num_gpu_block * scheduler_stats.kv_cache_usage)
self.inner.publish(self.dp_rank, active_decode_blocks) self.inner.publish(self.dp_rank, kv_used_blocks=active_decode_blocks)
dp_rank_str = str(self.dp_rank) dp_rank_str = str(self.dp_rank)
self.component_gauges.set_total_blocks(dp_rank_str, self.num_gpu_block) self.component_gauges.set_total_blocks(dp_rank_str, self.num_gpu_block)
...@@ -72,7 +72,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -72,7 +72,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
) )
def init_publish(self) -> None: def init_publish(self) -> None:
self.inner.publish(self.dp_rank, 0) self.inner.publish(self.dp_rank, kv_used_blocks=0)
dp_rank_str = str(self.dp_rank) dp_rank_str = str(self.dp_rank)
self.component_gauges.set_total_blocks(dp_rank_str, 0) self.component_gauges.set_total_blocks(dp_rank_str, 0)
self.component_gauges.set_gpu_cache_usage(dp_rank_str, 0.0) self.component_gauges.set_gpu_cache_usage(dp_rank_str, 0.0)
......
...@@ -234,11 +234,17 @@ impl WorkerMetricsPublisher { ...@@ -234,11 +234,17 @@ impl WorkerMetricsPublisher {
/// ///
/// # Arguments /// # Arguments
/// * `dp_rank` - Data parallel rank of the worker (None defaults to 0) /// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
/// * `active_decode_blocks` - Number of active KV cache blocks /// * `active_decode_blocks` - Scheduler-compatible active decode block count
#[pyo3(signature = (dp_rank, active_decode_blocks))] /// * `kv_used_blocks` - Authoritative total KV blocks currently in use
fn publish(&self, dp_rank: Option<u32>, active_decode_blocks: u64) -> PyResult<()> { #[pyo3(signature = (dp_rank=None, active_decode_blocks=None, kv_used_blocks=None))]
fn publish(
&self,
dp_rank: Option<u32>,
active_decode_blocks: Option<u64>,
kv_used_blocks: Option<u64>,
) -> PyResult<()> {
self.inner self.inner
.publish(dp_rank, active_decode_blocks) .publish(dp_rank, active_decode_blocks, kv_used_blocks)
.map_err(to_pyerr) .map_err(to_pyerr)
} }
} }
......
...@@ -431,15 +431,17 @@ class WorkerMetricsPublisher: ...@@ -431,15 +431,17 @@ class WorkerMetricsPublisher:
def publish( def publish(
self, self,
dp_rank: Optional[int], dp_rank: Optional[int] = None,
active_decode_blocks: int, active_decode_blocks: int | None = None,
kv_used_blocks: int | None = None,
) -> None: ) -> None:
""" """
Publish worker metrics for load monitoring. Publish worker metrics for load monitoring.
Args: Args:
dp_rank: Data parallel rank of the worker (None defaults to 0) dp_rank: Data parallel rank of the worker (None defaults to 0)
active_decode_blocks: Number of active KV cache blocks active_decode_blocks: Optional scheduler-compatible decode-block signal
kv_used_blocks: Optional authoritative total KV blocks currently in use
""" """
... ...
......
...@@ -361,17 +361,23 @@ pub struct WorkerSelectionResult { ...@@ -361,17 +361,23 @@ pub struct WorkerSelectionResult {
/// Active load metrics for a worker, used for busy detection. /// Active load metrics for a worker, used for busy detection.
/// ///
/// Published by workers (with only `active_decode_blocks`) and by the scheduler /// Published by workers (with `kv_used_blocks`) and by the scheduler (with
/// (with both `active_decode_blocks` and `active_prefill_tokens`). /// `active_decode_blocks` and `active_prefill_tokens`).
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
pub struct ActiveLoad { pub struct ActiveLoad {
pub worker_id: WorkerId, pub worker_id: WorkerId,
#[serde(default)] #[serde(default)]
pub dp_rank: DpRank, pub dp_rank: DpRank,
/// Number of active KV cache blocks on the worker (decode phase). /// Scheduler-reported decode block load.
pub active_decode_blocks: Option<u64>, pub active_decode_blocks: Option<u64>,
/// Number of active prefill tokens (from scheduler's view). /// Number of active prefill tokens (from scheduler's view).
pub active_prefill_tokens: Option<u64>, pub active_prefill_tokens: Option<u64>,
/// Total KV blocks currently in use on the worker.
///
/// This is published by workers only and is the authoritative signal for
/// backend KV occupancy used by busy detection.
#[serde(default)]
pub kv_used_blocks: Option<u64>,
} }
/// A [`LocalBlockHash`] is a hash computed from the token IDs, optional multimodal metadata, /// A [`LocalBlockHash`] is a hash computed from the token IDs, optional multimodal metadata,
......
...@@ -660,6 +660,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -660,6 +660,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
dp_rank: worker.dp_rank, dp_rank: worker.dp_rank,
active_decode_blocks: Some(active_blocks as u64), active_decode_blocks: Some(active_blocks as u64),
active_prefill_tokens: Some(active_tokens as u64), active_prefill_tokens: Some(active_tokens as u64),
kv_used_blocks: None,
}; };
self.publisher.publish_load(active_load); self.publisher.publish_load(active_load);
......
...@@ -88,22 +88,138 @@ impl LoadThresholdConfig { ...@@ -88,22 +88,138 @@ impl LoadThresholdConfig {
} }
/// Worker load monitoring state per dp_rank /// Worker load monitoring state per dp_rank
/// Decode-busy latch for a single dp_rank.
///
/// The latch is raised as soon as any reported decode signal exceeds the
/// threshold and is lowered only once every signal's most recent report is
/// below it.
#[derive(Clone, Debug)]
struct DecodeBusyLatchState {
    /// True while the dp_rank is held busy by the latch.
    latched_busy: bool,
    /// Whether the most recent `kv_used_blocks` report was under the threshold.
    kv_used_blocks_cleared: bool,
    /// Whether the most recent `active_decode_blocks` report was under the threshold.
    active_decode_blocks_cleared: bool,
}

impl Default for DecodeBusyLatchState {
    /// Starts unlatched with both signals marked as cleared, so a signal that
    /// is never reported does not keep the latch from releasing.
    fn default() -> Self {
        let (latched_busy, cleared) = (false, true);
        DecodeBusyLatchState {
            latched_busy,
            kv_used_blocks_cleared: cleared,
            active_decode_blocks_cleared: cleared,
        }
    }
}
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct WorkerLoadState { pub struct WorkerLoadState {
pub active_decode_blocks: HashMap<u32, u64>, pub active_decode_blocks: HashMap<u32, u64>,
pub kv_used_blocks: HashMap<u32, u64>,
pub kv_total_blocks: HashMap<u32, u64>, pub kv_total_blocks: HashMap<u32, u64>,
pub active_prefill_tokens: HashMap<u32, u64>, pub active_prefill_tokens: HashMap<u32, u64>,
/// max_num_batched_tokens from runtime config (same for all dp_ranks) /// max_num_batched_tokens from runtime config (same for all dp_ranks)
pub max_num_batched_tokens: HashMap<u32, u64>, pub max_num_batched_tokens: HashMap<u32, u64>,
decode_busy_latches: HashMap<u32, DecodeBusyLatchState>,
} }
impl WorkerLoadState { impl WorkerLoadState {
/// Reports whether a block count exceeds the configured share of capacity.
///
/// Returns `true` only when `total_blocks` is non-zero and `used_blocks` is
/// strictly greater than `active_decode_blocks_threshold * total_blocks`.
fn is_decode_signal_busy(
    used_blocks: u64,
    total_blocks: u64,
    active_decode_blocks_threshold: f64,
) -> bool {
    // With zero capacity the ratio is meaningless, so never report busy.
    if total_blocks == 0 {
        return false;
    }
    let busy_limit = active_decode_blocks_threshold * total_blocks as f64;
    (used_blocks as f64) > busy_limit
}
/// Evaluates the stored decode signals for `dp_rank` against the threshold.
///
/// Returns `true` when the stored `kv_used_blocks` or `active_decode_blocks`
/// value for this rank exceeds `active_decode_blocks_threshold` of its total
/// blocks. Returns `false` when the rank has no recorded total.
fn current_decode_busy(&self, dp_rank: u32, active_decode_blocks_threshold: f64) -> bool {
    let total_blocks = match self.kv_total_blocks.get(&dp_rank) {
        Some(&total) => total,
        None => return false,
    };

    // kv_used_blocks is consulted first, then active_decode_blocks; a rank
    // missing from a map simply contributes nothing.
    [&self.kv_used_blocks, &self.active_decode_blocks]
        .iter()
        .filter_map(|signal| signal.get(&dp_rank).copied())
        .any(|blocks| {
            Self::is_decode_signal_busy(blocks, total_blocks, active_decode_blocks_threshold)
        })
}
/// Folds one report of decode signals into the busy latch for `dp_rank`.
///
/// Does nothing when the rank has no recorded total blocks or the total is
/// zero. Otherwise the latch is raised if either reported signal exceeds the
/// threshold, each reported signal refreshes its own "cleared" flag, and the
/// latch is released once both flags are set. A signal passed as `None`
/// leaves its flag untouched.
fn update_decode_busy_latch(
    &mut self,
    dp_rank: u32,
    active_decode_blocks: Option<u64>,
    kv_used_blocks: Option<u64>,
    active_decode_blocks_threshold: f64,
) {
    let total_blocks = match self.kv_total_blocks.get(&dp_rank).copied() {
        Some(total) if total != 0 => total,
        _ => return,
    };

    // `None` means "not reported in this event"; `Some(busy)` carries the
    // threshold verdict for a reported value.
    let verdict = |reported: Option<u64>| {
        reported.map(|blocks| {
            Self::is_decode_signal_busy(blocks, total_blocks, active_decode_blocks_threshold)
        })
    };
    let decode_verdict = verdict(active_decode_blocks);
    let kv_verdict = verdict(kv_used_blocks);

    let latch = self.decode_busy_latches.entry(dp_rank).or_default();

    if decode_verdict == Some(true) || kv_verdict == Some(true) {
        latch.latched_busy = true;
    }
    if let Some(busy) = decode_verdict {
        latch.active_decode_blocks_cleared = !busy;
    }
    if let Some(busy) = kv_verdict {
        latch.kv_used_blocks_cleared = !busy;
    }

    // Release only after every signal's latest report is below the threshold.
    let all_cleared = latch.kv_used_blocks_cleared && latch.active_decode_blocks_cleared;
    if latch.latched_busy && all_cleared {
        latch.latched_busy = false;
    }
}
/// Applies an [`ActiveLoad`] event to this worker's per-dp_rank state.
///
/// Each metric present in the event overwrites the stored value for the
/// event's dp_rank; metrics that are `None` leave the stored value alone.
/// The decode busy latch is then updated from the event's decode signals.
fn update_from_active_load(
    &mut self,
    active_load: &ActiveLoad,
    active_decode_blocks_threshold: f64,
) {
    let rank = active_load.dp_rank;
    let reported_decode = active_load.active_decode_blocks;
    let reported_kv_used = active_load.kv_used_blocks;

    for (gauge, reported) in [
        (&mut self.active_decode_blocks, reported_decode),
        (&mut self.kv_used_blocks, reported_kv_used),
        (
            &mut self.active_prefill_tokens,
            active_load.active_prefill_tokens,
        ),
    ] {
        if let Some(value) = reported {
            gauge.insert(rank, value);
        }
    }

    self.update_decode_busy_latch(
        rank,
        reported_decode,
        reported_kv_used,
        active_decode_blocks_threshold,
    );
}
/// Returns true if ALL dp_ranks are considered busy based on the threshold logic. /// Returns true if ALL dp_ranks are considered busy based on the threshold logic.
/// ///
/// For each dp_rank, a dp_rank is busy if ANY of these conditions is met (OR logic): /// For each dp_rank, a dp_rank is busy if ANY of these conditions is met (OR logic):
/// 1. `active_prefill_tokens > active_prefill_tokens_threshold` (absolute threshold) /// 1. `active_prefill_tokens > active_prefill_tokens_threshold` (absolute threshold)
/// 2. `active_prefill_tokens > frac * max_num_batched_tokens` (fraction-based threshold) /// 2. `active_prefill_tokens > frac * max_num_batched_tokens` (fraction-based threshold)
/// 3. `active_decode_blocks / total_blocks > active_decode_blocks_threshold` (blocks threshold) /// 3. decode busy latch set by either `kv_used_blocks` or `active_decode_blocks`
/// ///
/// If none of these checks can be performed (missing data), that dp_rank is considered free. /// If none of these checks can be performed (missing data), that dp_rank is considered free.
/// ///
...@@ -118,6 +234,8 @@ impl WorkerLoadState { ...@@ -118,6 +234,8 @@ impl WorkerLoadState {
let all_dp_ranks: std::collections::HashSet<_> = self let all_dp_ranks: std::collections::HashSet<_> = self
.active_decode_blocks .active_decode_blocks
.keys() .keys()
.chain(self.kv_used_blocks.keys())
.chain(self.decode_busy_latches.keys())
.chain(self.active_prefill_tokens.keys()) .chain(self.active_prefill_tokens.keys())
.copied() .copied()
.collect(); .collect();
...@@ -148,15 +266,13 @@ impl WorkerLoadState { ...@@ -148,15 +266,13 @@ impl WorkerLoadState {
} }
} }
// Check 3: blocks threshold // Check 3: decode busy latch
// Skip if total_blocks is 0 (no capacity means threshold check is meaningless) if let Some(latch) = self.decode_busy_latches.get(&dp_rank) {
if let (Some(&active_blocks), Some(&total_blocks)) = ( if latch.latched_busy {
self.active_decode_blocks.get(&dp_rank), return true;
self.kv_total_blocks.get(&dp_rank), }
) && total_blocks > 0 } else if self.current_decode_busy(dp_rank, active_decode_blocks_threshold) {
&& (active_blocks as f64) > (active_decode_blocks_threshold * total_blocks as f64) return true;
{
return true; // This dp_rank is busy due to blocks
} }
// If we can't perform any check or no threshold exceeded, this dp_rank is free // If we can't perform any check or no threshold exceeded, this dp_rank is free
...@@ -504,18 +620,6 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -504,18 +620,6 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
.or_default() .or_default()
.insert(dp_rank); .insert(dp_rank);
// Update worker load state per dp_rank (for busy detection only)
// Note: Prometheus gauges are updated directly by sequence.rs
{
let mut state = worker_load_states.entry(worker_id).or_default();
if let Some(active_blocks) = active_load.active_decode_blocks {
state.active_decode_blocks.insert(dp_rank, active_blocks);
}
if let Some(active_tokens) = active_load.active_prefill_tokens {
state.active_prefill_tokens.insert(dp_rank, active_tokens);
}
}
// Load thresholds dynamically - allows runtime updates // Load thresholds dynamically - allows runtime updates
let current_active_decode_blocks_threshold = let current_active_decode_blocks_threshold =
Self::scaled_to_f64(active_decode_blocks_threshold.load(Ordering::Relaxed)); Self::scaled_to_f64(active_decode_blocks_threshold.load(Ordering::Relaxed));
...@@ -524,6 +628,16 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -524,6 +628,16 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
let current_active_prefill_tokens_threshold_frac = let current_active_prefill_tokens_threshold_frac =
Self::scaled_to_f64(active_prefill_tokens_threshold_frac.load(Ordering::Relaxed)); Self::scaled_to_f64(active_prefill_tokens_threshold_frac.load(Ordering::Relaxed));
// Update worker load state per dp_rank (for busy detection only)
// Note: Prometheus gauges are updated directly by sequence.rs
{
let mut state = worker_load_states.entry(worker_id).or_default();
state.update_from_active_load(
&active_load,
current_active_decode_blocks_threshold,
);
}
// Recalculate all busy instances and update // Recalculate all busy instances and update
let busy_instances: Vec<u64> = worker_load_states let busy_instances: Vec<u64> = worker_load_states
.iter() .iter()
...@@ -648,3 +762,187 @@ impl WorkerLoadMonitor for KvWorkerMonitor { ...@@ -648,3 +762,187 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
Ok(()) Ok(())
} }
} }
#[cfg(test)]
mod tests {
    use super::WorkerLoadState;
    use dynamo_kv_router::protocols::ActiveLoad;

    /// Fresh state whose dp_rank 0 has a capacity of 100 blocks.
    fn state_with_capacity() -> WorkerLoadState {
        let mut state = WorkerLoadState::default();
        state.kv_total_blocks.insert(0, 100);
        state
    }

    /// Feeds one load event for worker 1 / dp_rank 0 with a 0.6 decode threshold.
    fn report(
        state: &mut WorkerLoadState,
        active_decode_blocks: Option<u64>,
        kv_used_blocks: Option<u64>,
    ) {
        let event = ActiveLoad {
            worker_id: 1,
            dp_rank: 0,
            active_decode_blocks,
            active_prefill_tokens: None,
            kv_used_blocks,
        };
        state.update_from_active_load(&event, 0.6);
    }

    #[test]
    fn is_busy_prefers_kv_used_blocks_over_active_decode_blocks() {
        let mut state = WorkerLoadState::default();
        state.active_decode_blocks.insert(0, 10);
        state.kv_used_blocks.insert(0, 90);
        state.kv_total_blocks.insert(0, 100);

        assert!(state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn is_busy_falls_back_to_active_decode_blocks_when_kv_used_missing() {
        let mut state = WorkerLoadState::default();
        state.active_decode_blocks.insert(0, 90);
        state.kv_total_blocks.insert(0, 100);

        assert!(state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn is_busy_recognizes_dp_rank_known_only_from_kv_used_blocks() {
        let mut state = WorkerLoadState::default();
        state.kv_used_blocks.insert(0, 90);
        state.kv_total_blocks.insert(0, 100);

        assert!(state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_sets_busy_if_any_signal_is_busy() {
        let mut state = state_with_capacity();

        report(&mut state, None, Some(90));

        assert!(state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_only_clears_after_both_signals_report_nonbusy() {
        let mut state = state_with_capacity();

        report(&mut state, None, Some(90));
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        // A quiet scheduler signal alone must not release the latch.
        report(&mut state, Some(10), None);
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        report(&mut state, None, Some(10));
        assert!(!state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_clears_with_only_kv_used_blocks_signal() {
        let mut state = state_with_capacity();

        report(&mut state, None, Some(90));
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        report(&mut state, None, Some(10));
        assert!(!state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_clears_with_only_active_decode_blocks_signal() {
        let mut state = state_with_capacity();

        report(&mut state, Some(90), None);
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        report(&mut state, Some(10), None);
        assert!(!state.is_busy(0.6, u64::MAX, 2.0));
    }

    #[test]
    fn decode_busy_latch_clears_when_both_signals_are_nonbusy_in_same_event() {
        let mut state = state_with_capacity();

        report(&mut state, Some(90), None);
        assert!(state.is_busy(0.6, u64::MAX, 2.0));

        report(&mut state, Some(10), Some(10));
        assert!(!state.is_busy(0.6, u64::MAX, 2.0));
    }
}
...@@ -1225,7 +1225,8 @@ mod test_integration_publisher { ...@@ -1225,7 +1225,8 @@ mod test_integration_publisher {
// Test 1: Publish 10 different metrics with 0.5ms intervals // Test 1: Publish 10 different metrics with 0.5ms intervals
// Only the last one should be published after 1ms of stability // Only the last one should be published after 1ms of stability
for i in 0..10 { for i in 0..10 {
publisher.publish(None, (i * 100) as u64).unwrap(); let value = (i * 100) as u64;
publisher.publish(None, None, Some(value)).unwrap();
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
} }
...@@ -1240,8 +1241,9 @@ mod test_integration_publisher { ...@@ -1240,8 +1241,9 @@ mod test_integration_publisher {
let (_envelope, event) = result.unwrap().unwrap(); // Unwrap the Option and the Result let (_envelope, event) = result.unwrap().unwrap(); // Unwrap the Option and the Result
assert_eq!(event.worker_id, worker_id); assert_eq!(event.worker_id, worker_id);
assert_eq!(event.active_decode_blocks, Some(900)); // Last value: 9 * 100 assert_eq!(event.active_decode_blocks, None); // Worker publisher sends kv_used_blocks
assert_eq!(event.active_prefill_tokens, None); // Worker doesn't publish prefill tokens assert_eq!(event.active_prefill_tokens, None); // Worker doesn't publish prefill tokens
assert_eq!(event.kv_used_blocks, Some(900));
// Ensure no more events are waiting // Ensure no more events are waiting
let no_msg = let no_msg =
...@@ -1250,7 +1252,7 @@ mod test_integration_publisher { ...@@ -1250,7 +1252,7 @@ mod test_integration_publisher {
// Test 2: Publish 10 more metrics with same active_decode_blocks - should not trigger publish // Test 2: Publish 10 more metrics with same active_decode_blocks - should not trigger publish
for _ in 0..10 { for _ in 0..10 {
publisher.publish(None, 900).unwrap(); // Keep same as last published publisher.publish(None, None, Some(900)).unwrap(); // Keep same as last published
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await; tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
} }
......
...@@ -13,7 +13,8 @@ use crate::kv_router::KV_METRICS_SUBJECT; ...@@ -13,7 +13,8 @@ use crate::kv_router::KV_METRICS_SUBJECT;
#[derive(Debug, Clone, Default, PartialEq)] #[derive(Debug, Clone, Default, PartialEq)]
struct WorkerMetrics { struct WorkerMetrics {
dp_rank: DpRank, dp_rank: DpRank,
active_decode_blocks: u64, active_decode_blocks: Option<u64>,
kv_used_blocks: Option<u64>,
} }
pub struct WorkerMetricsPublisher { pub struct WorkerMetricsPublisher {
...@@ -27,15 +28,26 @@ impl WorkerMetricsPublisher { ...@@ -27,15 +28,26 @@ impl WorkerMetricsPublisher {
Ok(Self { tx, rx }) Ok(Self { tx, rx })
} }
pub fn publish(&self, dp_rank: Option<DpRank>, active_decode_blocks: u64) -> Result<()> { pub fn publish(
&self,
dp_rank: Option<DpRank>,
active_decode_blocks: Option<u64>,
kv_used_blocks: Option<u64>,
) -> Result<()> {
if active_decode_blocks.is_none() && kv_used_blocks.is_none() {
anyhow::bail!("worker metrics publish requires at least one load metric");
}
let metrics = WorkerMetrics { let metrics = WorkerMetrics {
dp_rank: dp_rank.unwrap_or(0), dp_rank: dp_rank.unwrap_or(0),
active_decode_blocks, active_decode_blocks,
kv_used_blocks,
}; };
tracing::trace!( tracing::trace!(
"Publish metrics: dp_rank={}, active_decode_blocks={}", "Publish metrics: dp_rank={}, active_decode_blocks={:?}, kv_used_blocks={:?}",
metrics.dp_rank, metrics.dp_rank,
metrics.active_decode_blocks metrics.active_decode_blocks,
metrics.kv_used_blocks
); );
self.tx self.tx
.send(metrics) .send(metrics)
...@@ -95,8 +107,9 @@ impl WorkerMetricsPublisher { ...@@ -95,8 +107,9 @@ impl WorkerMetricsPublisher {
let active_load = ActiveLoad { let active_load = ActiveLoad {
worker_id, worker_id,
dp_rank: metrics.dp_rank, dp_rank: metrics.dp_rank,
active_decode_blocks: Some(metrics.active_decode_blocks), active_decode_blocks: metrics.active_decode_blocks,
active_prefill_tokens: None, active_prefill_tokens: None,
kv_used_blocks: metrics.kv_used_blocks,
}; };
if let Err(e) = event_publisher.publish(&active_load).await { if let Err(e) = event_publisher.publish(&active_load).await {
......
...@@ -552,7 +552,11 @@ impl MockEngine { ...@@ -552,7 +552,11 @@ impl MockEngine {
let metrics = metrics_rx.borrow().clone(); let metrics = metrics_rx.borrow().clone();
// Publish metrics using flat API // Publish metrics using flat API
if let Err(e) = publisher.publish(Some(metrics.dp_rank), metrics.active_decode_blocks) { if let Err(e) = publisher.publish(
Some(metrics.dp_rank),
None,
Some(metrics.active_decode_blocks),
) {
tracing::warn!("Failed to publish metrics for DP rank {}: {e}", metrics.dp_rank); tracing::warn!("Failed to publish metrics for DP rank {}: {e}", metrics.dp_rank);
} else { } else {
tracing::trace!("Published metrics for DP rank {}", metrics.dp_rank); tracing::trace!("Published metrics for DP rank {}", metrics.dp_rank);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment