worker_monitor.rs 12.5 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
use crate::kv_router::KV_METRICS_SUBJECT;
5
use crate::kv_router::protocols::ActiveLoad;
6
use crate::model_card::ModelDeploymentCard;
7
use dynamo_runtime::component::Client;
8
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
9
10
11
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
12
use std::collections::HashMap;
13
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
14
15
16
use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt;

17
18
19
/// Fixed-point scale factor for storing fractional `f64` thresholds in a `u32`.
///
/// A threshold is multiplied by this factor before being stored and divided by it
/// when read back, so `10000` gives four decimal places of resolution
/// (e.g. `0.85` is stored as `8500`).
const THRESHOLD_SCALE: u32 = 10000;

Yan Ru Pei's avatar
Yan Ru Pei committed
20
21
/// Load figures tracked for a single worker, keyed by `dp_rank`.
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
    /// Active decode block count recorded for each dp_rank.
    pub active_decode_blocks: HashMap<u32, u64>,
    /// Total KV block capacity recorded for each dp_rank.
    pub kv_total_blocks: HashMap<u32, u64>,
    /// Active prefill token count recorded for each dp_rank.
    pub active_prefill_tokens: HashMap<u32, u64>,
}

impl WorkerLoadState {
    /// Reports whether the worker as a whole is busy, i.e. whether EVERY known dp_rank is busy.
    ///
    /// The set of known dp_ranks is the union of the keys of `active_decode_blocks` and
    /// `active_prefill_tokens`. A rank that only appears in `kv_total_blocks` is not part of
    /// that set and therefore does not influence the result. With no known ranks the worker
    /// is reported as not busy.
    ///
    /// A single dp_rank is busy when either of these holds:
    /// 1. its `active_prefill_tokens` entry is strictly greater than
    ///    `active_prefill_tokens_threshold` (a literal token count), or
    /// 2. both its `active_decode_blocks` and `kv_total_blocks` entries exist, the total is
    ///    non-zero, and `active > active_decode_blocks_threshold * total`.
    ///
    /// A rank for which neither check can fire (missing data, zero capacity, or values at or
    /// below the thresholds) counts as free, which makes the whole worker not busy.
    pub fn is_busy(
        &self,
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
    ) -> bool {
        // Without any reported rank there is nothing that could be busy. This guard is
        // required because `all` on an empty iterator would otherwise yield `true`.
        if self.active_decode_blocks.is_empty() && self.active_prefill_tokens.is_empty() {
            return false;
        }

        let rank_is_busy = |rank: &u32| -> bool {
            // Token check: compared against the literal token-count threshold.
            let over_token_limit = self
                .active_prefill_tokens
                .get(rank)
                .is_some_and(|&tokens| tokens > active_prefill_tokens_threshold);
            if over_token_limit {
                return true;
            }

            // Block check: only meaningful when capacity is known and non-zero.
            match (
                self.active_decode_blocks.get(rank),
                self.kv_total_blocks.get(rank),
            ) {
                (Some(&active), Some(&total)) if total > 0 => {
                    (active as f64) > active_decode_blocks_threshold * (total as f64)
                }
                // Missing data or zero capacity: treat the rank as free.
                _ => false,
            }
        };

        // A rank present in both maps is visited twice, which cannot change the outcome
        // of `all`, so no de-duplication is needed.
        self.active_decode_blocks
            .keys()
            .chain(self.active_prefill_tokens.keys())
            .all(rank_is_busy)
    }
}

83
84
/// Worker monitor for tracking KV cache usage and busy states.
///
85
/// Cloning shares state via internal Arc-wrapped fields. This allows multiple pipelines
86
87
/// (e.g., chat and completions) to share the same monitor instance.
#[derive(Clone)]
88
pub struct KvWorkerMonitor {
89
    client: Client,
90
    worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>,
91
92
93
94
    /// Active decode blocks threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
    active_decode_blocks_threshold: Arc<AtomicU32>,
    /// Active prefill tokens threshold stored as literal token count (u64)
    active_prefill_tokens_threshold: Arc<AtomicU64>,
95
96
    /// Guard to ensure start_monitoring() only runs once across clones
    started: Arc<AtomicBool>,
97
98
}

99
impl KvWorkerMonitor {
100
    /// Create a new worker monitor with the given thresholds.
101
    ///
102
103
104
105
106
107
108
109
110
111
    /// - `active_decode_blocks_threshold` (0.0-1.0): Threshold percentage for KV cache block utilization
    /// - `active_prefill_tokens_threshold`: Literal token count threshold for prefill token utilization
    ///
    /// Both thresholds can be dynamically updated via `set_active_decode_blocks_threshold()` and
    /// `set_active_prefill_tokens_threshold()`.
    pub fn new(
        client: Client,
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
    ) -> Self {
112
113
114
        Self {
            client,
            worker_load_states: Arc::new(RwLock::new(HashMap::new())),
115
116
117
118
119
120
            active_decode_blocks_threshold: Arc::new(AtomicU32::new(
                Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold),
            )),
            active_prefill_tokens_threshold: Arc::new(AtomicU64::new(
                active_prefill_tokens_threshold,
            )),
121
            started: Arc::new(AtomicBool::new(false)),
122
123
124
        }
    }

125
    /// Convert a f64 active decode blocks threshold (0.0-1.0) to scaled u32 for atomic storage.
126
    #[inline]
127
    fn active_decode_blocks_threshold_to_scaled(threshold: f64) -> u32 {
128
129
130
        (threshold * THRESHOLD_SCALE as f64) as u32
    }

131
    /// Convert a scaled u32 back to f64 active decode blocks threshold (0.0-1.0).
132
    #[inline]
133
    fn scaled_to_active_decode_blocks_threshold(scaled: u32) -> f64 {
134
135
136
        scaled as f64 / THRESHOLD_SCALE as f64
    }

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
    /// Get the current active decode blocks threshold value as f64.
    pub fn active_decode_blocks_threshold(&self) -> f64 {
        Self::scaled_to_active_decode_blocks_threshold(
            self.active_decode_blocks_threshold.load(Ordering::Relaxed),
        )
    }

    /// Set the active decode blocks threshold value from f64.
    pub fn set_active_decode_blocks_threshold(&self, threshold: f64) {
        self.active_decode_blocks_threshold.store(
            Self::active_decode_blocks_threshold_to_scaled(threshold),
            Ordering::Relaxed,
        );
    }

    /// Get the current active prefill tokens threshold value as u64.
    pub fn active_prefill_tokens_threshold(&self) -> u64 {
        self.active_prefill_tokens_threshold.load(Ordering::Relaxed)
155
156
    }

157
158
159
160
    /// Set the active prefill tokens threshold value from u64.
    pub fn set_active_prefill_tokens_threshold(&self, threshold: u64) {
        self.active_prefill_tokens_threshold
            .store(threshold, Ordering::Relaxed);
161
162
    }

163
    /// Get the worker load states for external access
164
    pub fn load_states(&self) -> Arc<RwLock<HashMap<u64, WorkerLoadState>>> {
165
166
        self.worker_load_states.clone()
    }
167
}
168

169
170
#[async_trait]
impl WorkerLoadMonitor for KvWorkerMonitor {
171
172
173
174
    /// Start background monitoring of worker KV cache usage.
    ///
    /// This is safe to call multiple times (e.g., from cloned monitors shared across
    /// pipelines) - only the first call spawns the background task.
175
    async fn start_monitoring(&self) -> anyhow::Result<()> {
176
177
178
179
180
181
        // Guard: only start once across all clones
        if self.started.swap(true, Ordering::SeqCst) {
            tracing::debug!("Worker monitoring already started, skipping");
            return Ok(());
        }

182
183
184
        let endpoint = &self.client.endpoint;
        let component = endpoint.component();

185
186
187
188
189
190
191
192
193
194
195
        let cancellation_token = component.drt().child_token();

        // Watch for runtime config updates from model deployment cards via discovery interface
        let discovery = component.drt().discovery();
        let discovery_stream = discovery
            .list_and_watch(DiscoveryQuery::AllModels, Some(cancellation_token.clone()))
            .await?;
        let mut config_events_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
196
197
198
199
200
201

        // Subscribe to KV metrics events
        let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;

        let worker_load_states = self.worker_load_states.clone();
        let client = self.client.clone();
202
203
        let active_decode_blocks_threshold = self.active_decode_blocks_threshold.clone();
        let active_prefill_tokens_threshold = self.active_prefill_tokens_threshold.clone();
204
205
206
207
208
209
210
211
212
213
214
215

        // Spawn background monitoring task
        tokio::spawn(async move {
            let mut previous_busy_instances = Vec::new(); // Track previous state

            loop {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("Worker monitoring cancelled");
                        break;
                    }

216
                    // Handle runtime config updates
217
218
219
220
221
222
                    _ = config_events_rx.changed() => {
                        let runtime_configs = config_events_rx.borrow().clone();

                        let mut states = worker_load_states.write().unwrap();
                        states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));

Yan Ru Pei's avatar
Yan Ru Pei committed
223
224
225
226
227
228
229
230
231
232
                        // Update worker load states with total blocks for all dp_ranks
                        for (lease_id, runtime_config) in runtime_configs.iter() {
                            let state = states.entry(*lease_id).or_default();

                            // Populate total_blocks for all dp_ranks (they share the same total)
                            if let Some(total_blocks) = runtime_config.total_kv_blocks {
                                for dp_rank in 0..runtime_config.data_parallel_size {
                                    state.kv_total_blocks.insert(dp_rank, total_blocks);
                                }
                            }
233
234
235
                        }
                    }

236
                    // Handle KV metrics updates (ActiveLoad)
237
238
239
240
241
242
                    kv_event = kv_metrics_rx.next() => {
                        let Some(event) = kv_event else {
                            tracing::debug!("KV metrics stream closed");
                            break;
                        };

243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
                        let Ok(active_load) = serde_json::from_slice::<ActiveLoad>(&event.payload) else {
                            continue;
                        };

                        let worker_id = active_load.worker_id;
                        let dp_rank = active_load.dp_rank;

                        // Update worker load state per dp_rank
                        let mut states = worker_load_states.write().unwrap();
                        let state = states.entry(worker_id).or_default();

                        if let Some(active_blocks) = active_load.active_decode_blocks {
                            state.active_decode_blocks.insert(dp_rank, active_blocks);
                        }
                        if let Some(active_tokens) = active_load.active_prefill_tokens {
                            state.active_prefill_tokens.insert(dp_rank, active_tokens);
                        }
                        drop(states);

                        // Load thresholds dynamically - allows runtime updates
                        let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold(
                            active_decode_blocks_threshold.load(Ordering::Relaxed),
                        );
                        let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed);

                        // Recalculate all busy instances and update
                        let states = worker_load_states.read().unwrap();
                        let busy_instances: Vec<u64> = states
                            .iter()
                            .filter_map(|(&id, state)| {
                                state
                                    .is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold)
                                    .then_some(id)
                            })
                            .collect();
                        drop(states);

                        // Only update if busy_instances has changed
                        if busy_instances != previous_busy_instances {
                            tracing::debug!("Busy instances changed: {:?}", busy_instances);
                            client.update_free_instances(&busy_instances);
                            previous_busy_instances = busy_instances;
285
286
287
288
289
290
291
292
293
294
295
                        }
                    }
                }
            }

            tracing::info!("Worker monitoring task exiting");
        });

        Ok(())
    }
}