worker_monitor.rs 9.27 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::scoring::LoadEvent;
6
use crate::model_card::ModelDeploymentCard;
7
use dynamo_runtime::component::Client;
8
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
9
10
11
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
12
use std::collections::HashMap;
13
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
14
15
16
use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt;

17
18
19
/// Fixed-point scale used to keep `f64` busy thresholds in an `AtomicU32`.
///
/// Thresholds are stored in units of 1/10_000, i.e. with 4 decimal places of precision
/// (a stored value of 8500 represents 0.85).
const THRESHOLD_SCALE: u32 = 10_000;

Yan Ru Pei's avatar
Yan Ru Pei committed
20
21
/// Worker load monitoring state, tracked per data-parallel rank (`dp_rank`).
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
    /// Active KV blocks reported for each dp_rank.
    pub kv_active_blocks: HashMap<u32, u64>,
    /// Total KV blocks known for each dp_rank.
    pub kv_total_blocks: HashMap<u32, u64>,
}

impl WorkerLoadState {
    /// Decide whether this worker counts as busy for the given utilization `threshold`.
    ///
    /// Only dp_ranks that appear in BOTH `kv_active_blocks` and `kv_total_blocks` are
    /// examined. The worker is busy when at least one such rank exists and every one of
    /// them satisfies `active > threshold * total` with a non-zero `total`. Consequently a
    /// rank whose total is zero makes the whole worker "not busy", and a worker with no
    /// rank present in both maps is never busy.
    pub fn is_busy(&self, threshold: f64) -> bool {
        let mut found_common_rank = false;

        for (rank, &active) in &self.kv_active_blocks {
            // Ranks without a known total are ignored entirely.
            let Some(&total) = self.kv_total_blocks.get(rank) else {
                continue;
            };
            found_common_rank = true;

            let over_threshold = total != 0 && active as f64 > threshold * total as f64;
            if !over_threshold {
                // One rank with spare capacity is enough to keep the worker available.
                return false;
            }
        }

        found_common_rank
    }
}

56
57
58
59
60
/// Worker monitor for tracking KV cache usage and busy states.
///
/// All fields are `Arc`, so cloning shares state. This allows multiple pipelines
/// (e.g., chat and completions) to share the same monitor instance.
#[derive(Clone)]
61
pub struct KvWorkerMonitor {
62
    client: Arc<Client>,
63
    worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>,
64
65
66
67
    /// Threshold stored as parts-per-10000 (e.g., 8500 = 0.85)
    busy_threshold: Arc<AtomicU32>,
    /// Guard to ensure start_monitoring() only runs once across clones
    started: Arc<AtomicBool>,
68
69
}

70
impl KvWorkerMonitor {
71
72
73
74
75
    /// Create a new worker monitor with the given threshold.
    ///
    /// The threshold (0.0-1.0) controls when workers are considered busy based on
    /// KV cache utilization. It can be dynamically updated via `set_threshold()`.
    pub fn new(client: Arc<Client>, threshold: f64) -> Self {
76
77
78
        Self {
            client,
            worker_load_states: Arc::new(RwLock::new(HashMap::new())),
79
80
            busy_threshold: Arc::new(AtomicU32::new(Self::threshold_to_scaled(threshold))),
            started: Arc::new(AtomicBool::new(false)),
81
82
83
        }
    }

84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
    /// Convert a f64 threshold (0.0-1.0) to scaled u32 for atomic storage.
    #[inline]
    fn threshold_to_scaled(threshold: f64) -> u32 {
        (threshold * THRESHOLD_SCALE as f64) as u32
    }

    /// Convert a scaled u32 back to f64 threshold (0.0-1.0).
    #[inline]
    fn scaled_to_threshold(scaled: u32) -> f64 {
        scaled as f64 / THRESHOLD_SCALE as f64
    }

    /// Get the current threshold value as f64.
    pub fn threshold(&self) -> f64 {
        Self::scaled_to_threshold(self.busy_threshold.load(Ordering::Relaxed))
    }

    /// Set the threshold value from f64.
    pub fn set_threshold(&self, threshold: f64) {
        self.busy_threshold
            .store(Self::threshold_to_scaled(threshold), Ordering::Relaxed);
    }

107
    /// Get the worker load states for external access
108
    pub fn load_states(&self) -> Arc<RwLock<HashMap<u64, WorkerLoadState>>> {
109
110
        self.worker_load_states.clone()
    }
111
}
112

113
114
#[async_trait]
impl WorkerLoadMonitor for KvWorkerMonitor {
115
116
117
118
    /// Start background monitoring of worker KV cache usage.
    ///
    /// This is safe to call multiple times (e.g., from cloned monitors shared across
    /// pipelines) - only the first call spawns the background task.
119
    async fn start_monitoring(&self) -> anyhow::Result<()> {
120
121
122
123
124
125
        // Guard: only start once across all clones
        if self.started.swap(true, Ordering::SeqCst) {
            tracing::debug!("Worker monitoring already started, skipping");
            return Ok(());
        }

126
127
128
        let endpoint = &self.client.endpoint;
        let component = endpoint.component();

129
130
131
132
133
134
135
136
137
138
139
        let cancellation_token = component.drt().child_token();

        // Watch for runtime config updates from model deployment cards via discovery interface
        let discovery = component.drt().discovery();
        let discovery_stream = discovery
            .list_and_watch(DiscoveryQuery::AllModels, Some(cancellation_token.clone()))
            .await?;
        let mut config_events_rx =
            watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
                card.runtime_config
            });
140
141
142
143
144
145

        // Subscribe to KV metrics events
        let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;

        let worker_load_states = self.worker_load_states.clone();
        let client = self.client.clone();
146
        let busy_threshold = self.busy_threshold.clone();
147
148
149
150
151
152
153
154
155
156
157
158

        // Spawn background monitoring task
        tokio::spawn(async move {
            let mut previous_busy_instances = Vec::new(); // Track previous state

            loop {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("Worker monitoring cancelled");
                        break;
                    }

159
                    // Handle runtime config updates
160
161
162
163
164
165
                    _ = config_events_rx.changed() => {
                        let runtime_configs = config_events_rx.borrow().clone();

                        let mut states = worker_load_states.write().unwrap();
                        states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));

Yan Ru Pei's avatar
Yan Ru Pei committed
166
167
168
169
170
171
172
173
174
175
                        // Update worker load states with total blocks for all dp_ranks
                        for (lease_id, runtime_config) in runtime_configs.iter() {
                            let state = states.entry(*lease_id).or_default();

                            // Populate total_blocks for all dp_ranks (they share the same total)
                            if let Some(total_blocks) = runtime_config.total_kv_blocks {
                                for dp_rank in 0..runtime_config.data_parallel_size {
                                    state.kv_total_blocks.insert(dp_rank, total_blocks);
                                }
                            }
176
177
178
179
180
181
182
183
184
185
186
187
188
                        }
                    }

                    // Handle KV metrics updates
                    kv_event = kv_metrics_rx.next() => {
                        let Some(event) = kv_event else {
                            tracing::debug!("KV metrics stream closed");
                            break;
                        };

                        if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
                            let worker_id = load_event.worker_id;
                            let active_blocks = load_event.data.kv_stats.kv_active_blocks;
Yan Ru Pei's avatar
Yan Ru Pei committed
189
                            let dp_rank = load_event.data.worker_stats.data_parallel_rank.unwrap_or(0);
190

Yan Ru Pei's avatar
Yan Ru Pei committed
191
                            // Update worker load state per dp_rank
192
                            let mut states = worker_load_states.write().unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
193
194
                            let state = states.entry(worker_id).or_default();
                            state.kv_active_blocks.insert(dp_rank, active_blocks);
195
196
                            drop(states);

197
198
199
200
                            // Load threshold dynamically - allows runtime updates
                            let scaled_threshold = busy_threshold.load(Ordering::Relaxed);
                            let current_threshold = Self::scaled_to_threshold(scaled_threshold);

201
202
                            // Recalculate all busy instances and update
                            let states = worker_load_states.read().unwrap();
203
                            let busy_instances: Vec<u64> = states
204
205
                                .iter()
                                .filter_map(|(&id, state)| {
206
                                    state.is_busy(current_threshold).then_some(id)
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
                                })
                                .collect();
                            drop(states);

                            // Only update if busy_instances has changed
                            if busy_instances != previous_busy_instances {
                                tracing::debug!("Busy instances changed: {:?}", busy_instances);
                                client.update_free_instances(&busy_instances);
                                previous_busy_instances = busy_instances;
                            }
                        }
                    }
                }
            }

            tracing::info!("Worker monitoring task exiting");
        });

        Ok(())
    }
}