worker_monitor.rs 7.02 KB
Newer Older
1
2
3
4
5
6
7
8
9
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// TODO: Make load comparisons and runtime metrics a generic trait so this monitoring
// system is not tied to KV cache concepts, which are LLM-specific. This would allow
// different types of workers to define their own load metrics and busy thresholds.

use crate::component::{Client, InstanceSource};
use crate::traits::DistributedRuntimeProvider;
10
use crate::traits::events::EventSubscriber;
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
use crate::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio::sync::watch;
use tokio_stream::StreamExt;

// Constants for monitoring configuration
const KV_METRICS_SUBJECT: &str = "kv_metrics";
const MODEL_ROOT_PATH: &str = "models";

// Internal structs for deserializing metrics events
#[derive(serde::Deserialize)]
struct LoadEvent {
    worker_id: i64,
    data: ForwardPassMetrics,
}

#[derive(serde::Deserialize)]
struct ForwardPassMetrics {
    kv_stats: KvStats,
}

#[derive(serde::Deserialize)]
struct KvStats {
    kv_active_blocks: u64,
}

#[derive(serde::Deserialize)]
struct ModelEntry {
    runtime_config: Option<RuntimeConfig>,
}

#[derive(serde::Deserialize)]
struct RuntimeConfig {
    total_kv_blocks: Option<u64>,
}

/// Worker load monitoring state
#[derive(Clone, Debug)]
pub struct WorkerLoadState {
    pub kv_active_blocks: Option<u64>,
    pub kv_total_blocks: Option<u64>,
}

impl WorkerLoadState {
    pub fn is_busy(&self, threshold: f64) -> bool {
        match (self.kv_active_blocks, self.kv_total_blocks) {
            (Some(active), Some(total)) if total > 0 => {
                (active as f64) > (threshold * total as f64)
            }
            _ => false,
        }
    }
}

/// Worker monitor for tracking KV cache usage and busy states
pub struct WorkerMonitor {
    client: Arc<Client>,
    worker_load_states: Arc<RwLock<HashMap<i64, WorkerLoadState>>>,
    busy_threshold: f64,
}

impl WorkerMonitor {
    /// Create a new worker monitor with custom threshold
    pub fn new_with_threshold(client: Arc<Client>, busy_threshold: f64) -> Self {
        Self {
            client,
            worker_load_states: Arc::new(RwLock::new(HashMap::new())),
            busy_threshold,
        }
    }

    /// Get the worker load states for external access
    pub fn load_states(&self) -> Arc<RwLock<HashMap<i64, WorkerLoadState>>> {
        self.worker_load_states.clone()
    }

    /// Start background monitoring of worker KV cache usage
    pub async fn start_monitoring(&self) -> anyhow::Result<()> {
        let endpoint = &self.client.endpoint;
        let component = endpoint.component();

        let Some(etcd_client) = component.drt().etcd_client() else {
            // Static mode, no monitoring needed
            return Ok(());
        };

        let runtime_configs_watcher = watch_prefix_with_extraction(
            etcd_client,
            MODEL_ROOT_PATH,
            key_extractors::lease_id,
            |entry: ModelEntry| entry.runtime_config.and_then(|rc| rc.total_kv_blocks),
            component.drt().child_token(),
        )
        .await?;
        let mut config_events_rx = runtime_configs_watcher.receiver();

        // Subscribe to KV metrics events
        let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;

        let worker_load_states = self.worker_load_states.clone();
        let client = self.client.clone();
        let cancellation_token = component.drt().child_token();
        let busy_threshold = self.busy_threshold; // Capture threshold for the closure

        // Spawn background monitoring task
        tokio::spawn(async move {
            let mut previous_busy_instances = Vec::new(); // Track previous state

            loop {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("Worker monitoring cancelled");
                        break;
                    }

                    // Handle runtime config updates - now receives full HashMap
                    _ = config_events_rx.changed() => {
                        let runtime_configs = config_events_rx.borrow().clone();

                        let mut states = worker_load_states.write().unwrap();
                        states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));

                        // Update worker load states with total blocks
                        for (lease_id, total_blocks) in runtime_configs.iter() {
                            let state = states.entry(*lease_id).or_insert(WorkerLoadState {
                                kv_active_blocks: None,
                                kv_total_blocks: None,
                            });
                            state.kv_total_blocks = Some(*total_blocks);
                        }
                    }

                    // Handle KV metrics updates
                    kv_event = kv_metrics_rx.next() => {
                        let Some(event) = kv_event else {
                            tracing::debug!("KV metrics stream closed");
                            break;
                        };

                        if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
                            let worker_id = load_event.worker_id;
                            let active_blocks = load_event.data.kv_stats.kv_active_blocks;

                            // Update worker load state
                            let mut states = worker_load_states.write().unwrap();
                            let state = states.entry(worker_id).or_insert(WorkerLoadState {
                                kv_active_blocks: None,
                                kv_total_blocks: None,
                            });
                            state.kv_active_blocks = Some(active_blocks);
                            drop(states);

                            // Recalculate all busy instances and update
                            let states = worker_load_states.read().unwrap();
                            let busy_instances: Vec<i64> = states
                                .iter()
                                .filter_map(|(&id, state)| {
                                    state.is_busy(busy_threshold).then_some(id)
                                })
                                .collect();
                            drop(states);

                            // Only update if busy_instances has changed
                            if busy_instances != previous_busy_instances {
                                tracing::debug!("Busy instances changed: {:?}", busy_instances);
                                client.update_free_instances(&busy_instances);
                                previous_busy_instances = busy_instances;
                            }
                        }
                    }
                }
            }

            tracing::info!("Worker monitoring task exiting");
        });

        Ok(())
    }
}