// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};

use dashmap::DashMap;

10
use crate::kv_router::KV_METRICS_SUBJECT;
11
use crate::kv_router::protocols::ActiveLoad;
12
use crate::model_card::ModelDeploymentCard;
13
use dynamo_runtime::component::Client;
14
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
15
16
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
17
use dynamo_runtime::transports::event_plane::EventSubscriber;
18

19
20
21
/// Fixed-point scale used to keep a fractional (f64) threshold in an atomic `u32`.
///
/// A threshold `x` is stored as `x * THRESHOLD_SCALE`, i.e. with 4 decimal places of precision.
const THRESHOLD_SCALE: u32 = 10_000;

/// Worker load monitoring state per dp_rank
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
    /// Active decode KV blocks, keyed by dp_rank.
    pub active_decode_blocks: HashMap<u32, u64>,
    /// Total KV blocks available, keyed by dp_rank.
    pub kv_total_blocks: HashMap<u32, u64>,
    /// Active prefill tokens, keyed by dp_rank.
    pub active_prefill_tokens: HashMap<u32, u64>,
}

impl WorkerLoadState {
    /// Returns true if ALL known dp_ranks are considered busy.
    ///
    /// The known dp_ranks are the union of the keys of `active_decode_blocks`,
    /// `active_prefill_tokens` and `kv_total_blocks`. A dp_rank is busy when either:
    /// 1. its `active_prefill_tokens` exceeds `active_prefill_tokens_threshold` (a literal
    ///    token count), or
    /// 2. both `active_decode_blocks` and a non-zero `kv_total_blocks` are known for it and
    ///    `active_decode_blocks > active_decode_blocks_threshold * kv_total_blocks`.
    ///
    /// A dp_rank for which neither check fires is considered free. This includes a rank that
    /// has a known capacity but has not reported any load yet. A state with no known dp_ranks
    /// is not busy.
    pub fn is_busy(
        &self,
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
    ) -> bool {
        // Ranks known only through `kv_total_blocks` must take part too; otherwise an idle,
        // not-yet-reporting rank would be skipped and the worker could be flagged busy.
        let mut dp_ranks = self
            .active_decode_blocks
            .keys()
            .chain(self.active_prefill_tokens.keys())
            .chain(self.kv_total_blocks.keys())
            .copied()
            .peekable();

        // With no known dp_ranks the worker is not busy (`all` alone would return true here).
        if dp_ranks.peek().is_none() {
            return false;
        }

        // A rank may appear in several maps. Re-checking it is harmless and avoids building a
        // set on every call (this runs for every worker on every metrics event).
        dp_ranks.all(|dp_rank| {
            self.is_dp_rank_busy(
                dp_rank,
                active_decode_blocks_threshold,
                active_prefill_tokens_threshold,
            )
        })
    }

    /// Returns true if a single dp_rank exceeds either the prefill-token or the decode-block
    /// threshold. Missing data means the corresponding check does not fire.
    fn is_dp_rank_busy(
        &self,
        dp_rank: u32,
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
    ) -> bool {
        // First check: prefill tokens threshold (literal token count)
        let prefill_busy = self
            .active_prefill_tokens
            .get(&dp_rank)
            .is_some_and(|&active_tokens| active_tokens > active_prefill_tokens_threshold);
        if prefill_busy {
            return true;
        }

        // Second check: blocks threshold.
        // Skip if total_blocks is 0 (no capacity means threshold check is meaningless).
        match (
            self.active_decode_blocks.get(&dp_rank),
            self.kv_total_blocks.get(&dp_rank),
        ) {
            (Some(&active_blocks), Some(&total_blocks)) if total_blocks > 0 => {
                (active_blocks as f64) > active_decode_blocks_threshold * total_blocks as f64
            }
            _ => false,
        }
    }
}

/// Worker monitor for tracking KV cache usage and busy states.
///
87
/// Cloning shares state via internal Arc-wrapped fields. This allows multiple pipelines
88
89
/// (e.g., chat and completions) to share the same monitor instance.
#[derive(Clone)]
90
pub struct KvWorkerMonitor {
91
    client: Client,
92
    worker_load_states: Arc<DashMap<u64, WorkerLoadState>>,
93
94
95
96
    /// Active decode blocks threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
    active_decode_blocks_threshold: Arc<AtomicU32>,
    /// Active prefill tokens threshold stored as literal token count (u64)
    active_prefill_tokens_threshold: Arc<AtomicU64>,
97
98
    /// Guard to ensure start_monitoring() only runs once across clones
    started: Arc<AtomicBool>,
99
100
}

101
impl KvWorkerMonitor {
102
    /// Create a new worker monitor with the given thresholds.
103
    ///
104
105
106
107
108
109
110
111
112
113
    /// - `active_decode_blocks_threshold` (0.0-1.0): Threshold percentage for KV cache block utilization
    /// - `active_prefill_tokens_threshold`: Literal token count threshold for prefill token utilization
    ///
    /// Both thresholds can be dynamically updated via `set_active_decode_blocks_threshold()` and
    /// `set_active_prefill_tokens_threshold()`.
    pub fn new(
        client: Client,
        active_decode_blocks_threshold: f64,
        active_prefill_tokens_threshold: u64,
    ) -> Self {
114
115
        Self {
            client,
116
            worker_load_states: Arc::new(DashMap::new()),
117
118
119
120
121
122
            active_decode_blocks_threshold: Arc::new(AtomicU32::new(
                Self::active_decode_blocks_threshold_to_scaled(active_decode_blocks_threshold),
            )),
            active_prefill_tokens_threshold: Arc::new(AtomicU64::new(
                active_prefill_tokens_threshold,
            )),
123
            started: Arc::new(AtomicBool::new(false)),
124
125
126
        }
    }

127
    /// Convert a f64 active decode blocks threshold (0.0-1.0) to scaled u32 for atomic storage.
128
    #[inline]
129
    fn active_decode_blocks_threshold_to_scaled(threshold: f64) -> u32 {
130
131
132
        (threshold * THRESHOLD_SCALE as f64) as u32
    }

133
    /// Convert a scaled u32 back to f64 active decode blocks threshold (0.0-1.0).
134
    #[inline]
135
    fn scaled_to_active_decode_blocks_threshold(scaled: u32) -> f64 {
136
137
138
        scaled as f64 / THRESHOLD_SCALE as f64
    }

139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
    /// Get the current active decode blocks threshold value as f64.
    pub fn active_decode_blocks_threshold(&self) -> f64 {
        Self::scaled_to_active_decode_blocks_threshold(
            self.active_decode_blocks_threshold.load(Ordering::Relaxed),
        )
    }

    /// Set the active decode blocks threshold value from f64.
    pub fn set_active_decode_blocks_threshold(&self, threshold: f64) {
        self.active_decode_blocks_threshold.store(
            Self::active_decode_blocks_threshold_to_scaled(threshold),
            Ordering::Relaxed,
        );
    }

    /// Get the current active prefill tokens threshold value as u64.
    pub fn active_prefill_tokens_threshold(&self) -> u64 {
        self.active_prefill_tokens_threshold.load(Ordering::Relaxed)
157
158
    }

159
160
161
162
    /// Set the active prefill tokens threshold value from u64.
    pub fn set_active_prefill_tokens_threshold(&self, threshold: u64) {
        self.active_prefill_tokens_threshold
            .store(threshold, Ordering::Relaxed);
163
164
    }

165
    /// Get the worker load states for external access
166
    pub fn load_states(&self) -> Arc<DashMap<u64, WorkerLoadState>> {
167
168
        self.worker_load_states.clone()
    }
169
}
170

171
172
#[async_trait]
impl WorkerLoadMonitor for KvWorkerMonitor {
173
174
175
176
    /// Start background monitoring of worker KV cache usage.
    ///
    /// This is safe to call multiple times (e.g., from cloned monitors shared across
    /// pipelines) - only the first call spawns the background task.
177
    async fn start_monitoring(&self) -> anyhow::Result<()> {
178
179
180
181
182
183
        // Guard: only start once across all clones
        if self.started.swap(true, Ordering::SeqCst) {
            tracing::debug!("Worker monitoring already started, skipping");
            return Ok(());
        }

184
185
186
        let endpoint = &self.client.endpoint;
        let component = endpoint.component();

187
188
189
190
191
192
193
194
195
196
197
        let cancellation_token = component.drt().child_token();

        // Watch for runtime config updates from model deployment cards via discovery interface
        let discovery = component.drt().discovery();
        let discovery_stream = discovery
            .list_and_watch(DiscoveryQuery::AllModels, Some(cancellation_token.clone()))
            .await?;
        let mut config_events_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
198

199
200
201
202
203
        // Subscribe to KV metrics events using EventSubscriber (Msgpack payloads)
        let mut kv_metrics_rx =
            EventSubscriber::for_namespace(component.namespace(), KV_METRICS_SUBJECT)
                .await?
                .typed::<ActiveLoad>();
204
205
206

        let worker_load_states = self.worker_load_states.clone();
        let client = self.client.clone();
207
208
        let active_decode_blocks_threshold = self.active_decode_blocks_threshold.clone();
        let active_prefill_tokens_threshold = self.active_prefill_tokens_threshold.clone();
209
210
211
212
213
214
215
216
217
218
219
220

        // Spawn background monitoring task
        tokio::spawn(async move {
            let mut previous_busy_instances = Vec::new(); // Track previous state

            loop {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("Worker monitoring cancelled");
                        break;
                    }

221
                    // Handle runtime config updates
222
223
224
                    _ = config_events_rx.changed() => {
                        let runtime_configs = config_events_rx.borrow().clone();

225
                        worker_load_states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));
226

Yan Ru Pei's avatar
Yan Ru Pei committed
227
228
                        // Update worker load states with total blocks for all dp_ranks
                        for (lease_id, runtime_config) in runtime_configs.iter() {
229
                            let mut state = worker_load_states.entry(*lease_id).or_default();
Yan Ru Pei's avatar
Yan Ru Pei committed
230
231
232
233
234
235
236

                            // Populate total_blocks for all dp_ranks (they share the same total)
                            if let Some(total_blocks) = runtime_config.total_kv_blocks {
                                for dp_rank in 0..runtime_config.data_parallel_size {
                                    state.kv_total_blocks.insert(dp_rank, total_blocks);
                                }
                            }
237
238
239
                        }
                    }

240
                    // Handle KV metrics updates (ActiveLoad)
241
                    kv_event = kv_metrics_rx.next() => {
242
                        let Some(event_result) = kv_event else {
243
244
245
246
                            tracing::debug!("KV metrics stream closed");
                            break;
                        };

247
248
                        let Ok((_envelope, active_load)) = event_result else {
                            tracing::error!("Error receiving KV metrics event: {event_result:?}");
249
250
251
252
253
254
255
                            continue;
                        };

                        let worker_id = active_load.worker_id;
                        let dp_rank = active_load.dp_rank;

                        // Update worker load state per dp_rank
256
257
                        {
                            let mut state = worker_load_states.entry(worker_id).or_default();
258

259
260
261
262
263
264
                            if let Some(active_blocks) = active_load.active_decode_blocks {
                                state.active_decode_blocks.insert(dp_rank, active_blocks);
                            }
                            if let Some(active_tokens) = active_load.active_prefill_tokens {
                                state.active_prefill_tokens.insert(dp_rank, active_tokens);
                            }
265
266
267
268
269
270
271
272
273
                        }

                        // Load thresholds dynamically - allows runtime updates
                        let current_active_decode_blocks_threshold = Self::scaled_to_active_decode_blocks_threshold(
                            active_decode_blocks_threshold.load(Ordering::Relaxed),
                        );
                        let current_active_prefill_tokens_threshold = active_prefill_tokens_threshold.load(Ordering::Relaxed);

                        // Recalculate all busy instances and update
274
                        let busy_instances: Vec<u64> = worker_load_states
275
                            .iter()
276
277
                            .filter_map(|r| {
                                r.value()
278
                                    .is_busy(current_active_decode_blocks_threshold, current_active_prefill_tokens_threshold)
279
                                    .then_some(*r.key())
280
281
282
283
284
285
286
287
                            })
                            .collect();

                        // Only update if busy_instances has changed
                        if busy_instances != previous_busy_instances {
                            tracing::debug!("Busy instances changed: {:?}", busy_instances);
                            client.update_free_instances(&busy_instances);
                            previous_busy_instances = busy_instances;
288
289
290
291
292
293
294
295
296
297
298
                        }
                    }
                }
            }

            tracing::info!("Worker monitoring task exiting");
        });

        Ok(())
    }
}