worker_monitor.rs 12.7 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
// SPDX-License-Identifier: Apache-2.0

4
use crate::kv_router::KV_METRICS_SUBJECT;
5
use crate::kv_router::protocols::ActiveLoad;
6
use crate::model_card::ModelDeploymentCard;
7
use dynamo_runtime::component::Client;
8
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
9
10
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
11
use dynamo_runtime::transports::event_plane::EventSubscriber;
12
use std::collections::HashMap;
13
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
14
15
use std::sync::{Arc, RwLock};

16
17
18
/// Fixed-point scale used to keep an f64 threshold inside a `u32`: one unit is 1/10_000,
/// giving four decimal places of precision.
const THRESHOLD_SCALE: u32 = 10_000;

Yan Ru Pei's avatar
Yan Ru Pei committed
19
20
/// Worker load monitoring state per dp_rank
///
/// Every map in this struct is keyed by data-parallel rank (`dp_rank`).
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
    /// Active decode KV blocks reported for each dp_rank.
    pub active_decode_blocks: HashMap<u32, u64>,
    /// Total KV blocks available to each dp_rank.
    pub kv_total_blocks: HashMap<u32, u64>,
    /// Active prefill tokens reported for each dp_rank.
    pub active_prefill_tokens: HashMap<u32, u64>,
}

impl WorkerLoadState {
    /// Decide whether this worker should be treated as busy.
    ///
    /// The dp_ranks considered are those that have an entry in `active_decode_blocks` or in
    /// `active_prefill_tokens`; ranks that only appear in `kv_total_blocks` are ignored.
    ///
    /// A considered rank is busy when either:
    /// 1. its active prefill token count is strictly greater than
    ///    `active_prefill_tokens_threshold` (a literal token count), or
    /// 2. both its active decode block count and a non-zero total block count are known and
    ///    `active_blocks > active_decode_blocks_threshold * total_blocks`.
    ///
    /// A rank satisfying neither condition (including a rank with incomplete data) is free.
    ///
    /// Returns `true` only when every considered rank is busy; a worker with no known
    /// ranks is never busy.
    pub fn is_busy(
        &self,
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
    ) -> bool {
        let no_known_ranks =
            self.active_decode_blocks.is_empty() && self.active_prefill_tokens.is_empty();
        if no_known_ranks {
            return false;
        }

        let rank_is_busy = |dp_rank: &u32| -> bool {
            // Token check first: a literal count, no capacity data needed.
            let prefill_saturated = self
                .active_prefill_tokens
                .get(dp_rank)
                .is_some_and(|&tokens| tokens > active_prefill_tokens_threshold);
            if prefill_saturated {
                return true;
            }

            // Block check needs both numbers; a zero total means there is no capacity to
            // compare against, so the ratio test is skipped.
            match (
                self.active_decode_blocks.get(dp_rank),
                self.kv_total_blocks.get(dp_rank),
            ) {
                (Some(&active_blocks), Some(&total_blocks)) if total_blocks > 0 => {
                    (active_blocks as f64) > (active_decode_blocks_threshold * total_blocks as f64)
                }
                _ => false,
            }
        };

        // A rank present in both maps is visited twice; `all` gives the same answer for
        // repeated items, so no de-duplicating set has to be allocated.
        self.active_decode_blocks
            .keys()
            .chain(self.active_prefill_tokens.keys())
            .all(rank_is_busy)
    }
}

82
83
/// Worker monitor for tracking KV cache usage and busy states.
///
84
/// Cloning shares state via internal Arc-wrapped fields. This allows multiple pipelines
85
86
/// (e.g., chat and completions) to share the same monitor instance.
#[derive(Clone)]
87
pub struct KvWorkerMonitor {
88
    client: Client,
89
    worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>,
90
91
92
93
    /// Active decode blocks threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
    active_decode_blocks_threshold: Arc<AtomicU32>,
    /// Active prefill tokens threshold stored as literal token count (u64)
    active_prefill_tokens_threshold: Arc<AtomicU64>,
94
95
    /// Guard to ensure start_monitoring() only runs once across clones
    started: Arc<AtomicBool>,
96
97
}

98
impl KvWorkerMonitor {
99
    /// Create a new worker monitor with the given thresholds.
100
    ///
101
102
103
104
105
106
107
108
109
110
    /// - `active_decode_blocks_threshold` (0.0-1.0): Threshold percentage for KV cache block utilization
    /// - `active_prefill_tokens_threshold`: Literal token count threshold for prefill token utilization
    ///
    /// Both thresholds can be dynamically updated via `set_active_decode_blocks_threshold()` and
    /// `set_active_prefill_tokens_threshold()`.
    pub fn new(
        client: Client,
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
    ) -> Self {
111
112
113
        Self {
            client,
            worker_load_states: Arc::new(RwLock::new(HashMap::new())),
114
115
116
117
118
119
            active_decode_blocks_threshold: Arc::new(AtomicU32::new(
                Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold),
            )),
            active_prefill_tokens_threshold: Arc::new(AtomicU64::new(
                active_prefill_tokens_threshold,
            )),
120
            started: Arc::new(AtomicBool::new(false)),
121
122
123
        }
    }

124
    /// Convert a f64 active decode blocks threshold (0.0-1.0) to scaled u32 for atomic storage.
125
    #[inline]
126
    fn active_decode_blocks_threshold_to_scaled(threshold: f64) -> u32 {
127
128
129
        (threshold * THRESHOLD_SCALE as f64) as u32
    }

130
    /// Convert a scaled u32 back to f64 active decode blocks threshold (0.0-1.0).
131
    #[inline]
132
    fn scaled_to_active_decode_blocks_threshold(scaled: u32) -> f64 {
133
134
135
        scaled as f64 / THRESHOLD_SCALE as f64
    }

136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
    /// Get the current active decode blocks threshold value as f64.
    pub fn active_decode_blocks_threshold(&self) -> f64 {
        Self::scaled_to_active_decode_blocks_threshold(
            self.active_decode_blocks_threshold.load(Ordering::Relaxed),
        )
    }

    /// Set the active decode blocks threshold value from f64.
    pub fn set_active_decode_blocks_threshold(&self, threshold: f64) {
        self.active_decode_blocks_threshold.store(
            Self::active_decode_blocks_threshold_to_scaled(threshold),
            Ordering::Relaxed,
        );
    }

    /// Get the current active prefill tokens threshold value as u64.
    pub fn active_prefill_tokens_threshold(&self) -> u64 {
        self.active_prefill_tokens_threshold.load(Ordering::Relaxed)
154
155
    }

156
157
158
159
    /// Set the active prefill tokens threshold value from u64.
    pub fn set_active_prefill_tokens_threshold(&self, threshold: u64) {
        self.active_prefill_tokens_threshold
            .store(threshold, Ordering::Relaxed);
160
161
    }

162
    /// Get the worker load states for external access
163
    pub fn load_states(&self) -> Arc<RwLock<HashMap<u64, WorkerLoadState>>> {
164
165
        self.worker_load_states.clone()
    }
166
}
167

168
169
#[async_trait]
impl WorkerLoadMonitor for KvWorkerMonitor {
170
171
172
173
    /// Start background monitoring of worker KV cache usage.
    ///
    /// This is safe to call multiple times (e.g., from cloned monitors shared across
    /// pipelines) - only the first call spawns the background task.
174
    async fn start_monitoring(&self) -> anyhow::Result<()> {
175
176
177
178
179
180
        // Guard: only start once across all clones
        if self.started.swap(true, Ordering::SeqCst) {
            tracing::debug!("Worker monitoring already started, skipping");
            return Ok(());
        }

181
182
183
        let endpoint = &self.client.endpoint;
        let component = endpoint.component();

184
185
186
187
188
189
190
191
192
193
194
        let cancellation_token = component.drt().child_token();

        // Watch for runtime config updates from model deployment cards via discovery interface
        let discovery = component.drt().discovery();
        let discovery_stream = discovery
            .list_and_watch(DiscoveryQuery::AllModels, Some(cancellation_token.clone()))
            .await?;
        let mut config_events_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
195

196
197
198
199
200
        // Subscribe to KV metrics events using EventSubscriber (Msgpack payloads)
        let mut kv_metrics_rx =
            EventSubscriber::for_namespace(component.namespace(), KV_METRICS_SUBJECT)
                .await?
                .typed::<ActiveLoad>();
201
202
203

        let worker_load_states = self.worker_load_states.clone();
        let client = self.client.clone();
204
205
        let active_decode_blocks_threshold = self.active_decode_blocks_threshold.clone();
        let active_prefill_tokens_threshold = self.active_prefill_tokens_threshold.clone();
206
207
208
209
210
211
212
213
214
215
216
217

        // Spawn background monitoring task
        tokio::spawn(async move {
            let mut previous_busy_instances = Vec::new(); // Track previous state

            loop {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("Worker monitoring cancelled");
                        break;
                    }

218
                    // Handle runtime config updates
219
220
221
222
223
224
                    _ = config_events_rx.changed() => {
                        let runtime_configs = config_events_rx.borrow().clone();

                        let mut states = worker_load_states.write().unwrap();
                        states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));

Yan Ru Pei's avatar
Yan Ru Pei committed
225
226
227
228
229
230
231
232
233
234
                        // Update worker load states with total blocks for all dp_ranks
                        for (lease_id, runtime_config) in runtime_configs.iter() {
                            let state = states.entry(*lease_id).or_default();

                            // Populate total_blocks for all dp_ranks (they share the same total)
                            if let Some(total_blocks) = runtime_config.total_kv_blocks {
                                for dp_rank in 0..runtime_config.data_parallel_size {
                                    state.kv_total_blocks.insert(dp_rank, total_blocks);
                                }
                            }
235
236
237
                        }
                    }

238
                    // Handle KV metrics updates (ActiveLoad)
239
                    kv_event = kv_metrics_rx.next() => {
240
                        let Some(event_result) = kv_event else {
241
242
243
244
                            tracing::debug!("KV metrics stream closed");
                            break;
                        };

245
246
                        let Ok((_envelope, active_load)) = event_result else {
                            tracing::error!("Error receiving KV metrics event: {event_result:?}");
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
                            continue;
                        };

                        let worker_id = active_load.worker_id;
                        let dp_rank = active_load.dp_rank;

                        // Update worker load state per dp_rank
                        let mut states = worker_load_states.write().unwrap();
                        let state = states.entry(worker_id).or_default();

                        if let Some(active_blocks) = active_load.active_decode_blocks {
                            state.active_decode_blocks.insert(dp_rank, active_blocks);
                        }
                        if let Some(active_tokens) = active_load.active_prefill_tokens {
                            state.active_prefill_tokens.insert(dp_rank, active_tokens);
                        }
                        drop(states);

                        // Load thresholds dynamically - allows runtime updates
                        let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold(
                            active_decode_blocks_threshold.load(Ordering::Relaxed),
                        );
                        let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed);

                        // Recalculate all busy instances and update
                        let states = worker_load_states.read().unwrap();
                        let busy_instances: Vec<u64> = states
                            .iter()
                            .filter_map(|(&id, state)| {
                                state
                                    .is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold)
                                    .then_some(id)
                            })
                            .collect();
                        drop(states);

                        // Only update if busy_instances has changed
                        if busy_instances != previous_busy_instances {
                            tracing::debug!("Busy instances changed: {:?}", busy_instances);
                            client.update_free_instances(&busy_instances);
                            previous_busy_instances = busy_instances;
288
289
290
291
292
293
294
295
296
297
298
                        }
                    }
                }
            }

            tracing::info!("Worker monitoring task exiting");
        });

        Ok(())
    }
}