worker_monitor.rs 7.25 KB
Newer Older
1
2
3
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

4
5
6
7
8
9
10
11
use crate::kv_router::KV_METRICS_SUBJECT;
use crate::kv_router::scoring::LoadEvent;
use crate::model_card::{self, ModelDeploymentCard};
use dynamo_runtime::component::Client;
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction};
12
13
14
15
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt;

Yan Ru Pei's avatar
Yan Ru Pei committed
16
17
/// Worker load monitoring state, tracked per data-parallel rank (`dp_rank`).
///
/// A dp_rank only takes part in busy detection once it has an entry in BOTH maps.
#[derive(Clone, Debug, Default)]
pub struct WorkerLoadState {
    /// KV cache blocks currently in use, keyed by dp_rank.
    pub kv_active_blocks: HashMap<u32, u64>,
    /// Total KV cache block capacity, keyed by dp_rank.
    pub kv_total_blocks: HashMap<u32, u64>,
}

impl WorkerLoadState {
    /// Returns `true` if every dp_rank that has both an active and a total block count
    /// is using more than `threshold` (a fraction of total capacity) of its KV blocks.
    ///
    /// Returns `false` when no dp_rank has both counts yet. A dp_rank whose total is
    /// zero is never considered over the threshold, so it makes the worker not busy.
    pub fn is_busy(&self, threshold: f64) -> bool {
        // Pair each active count with its known capacity in one pass; dp_ranks whose
        // capacity has not been reported yet are skipped entirely.
        let mut ranks = self
            .kv_active_blocks
            .iter()
            .filter_map(|(dp_rank, &active)| {
                self.kv_total_blocks
                    .get(dp_rank)
                    .map(|&total| (active, total))
            })
            .peekable();

        // No dp_rank with complete data: we cannot claim the worker is busy.
        // (Without this guard `all` on an empty iterator would return `true`.)
        if ranks.peek().is_none() {
            return false;
        }

        ranks.all(|(active, total)| total > 0 && (active as f64) > threshold * total as f64)
    }
}

/// Worker monitor for tracking KV cache usage and busy states
53
pub struct KvWorkerMonitor {
54
    client: Arc<Client>,
55
    worker_load_states: Arc<RwLock<HashMap<u64, WorkerLoadState>>>,
56
57
58
    busy_threshold: f64,
}

59
impl KvWorkerMonitor {
60
    /// Create a new worker monitor with custom threshold
61
    pub fn new(client: Arc<Client>, busy_threshold: f64) -> Self {
62
63
64
65
66
67
68
69
        Self {
            client,
            worker_load_states: Arc::new(RwLock::new(HashMap::new())),
            busy_threshold,
        }
    }

    /// Get the worker load states for external access
70
    pub fn load_states(&self) -> Arc<RwLock<HashMap<u64, WorkerLoadState>>> {
71
72
        self.worker_load_states.clone()
    }
73
}
74

75
76
#[async_trait]
impl WorkerLoadMonitor for KvWorkerMonitor {
77
    /// Start background monitoring of worker KV cache usage
78
    async fn start_monitoring(&self) -> anyhow::Result<()> {
79
80
81
82
83
84
85
86
        let endpoint = &self.client.endpoint;
        let component = endpoint.component();

        let Some(etcd_client) = component.drt().etcd_client() else {
            // Static mode, no monitoring needed
            return Ok(());
        };

87
        // Watch for runtime config updates from model deployment cards
88
89
        let runtime_configs_watcher = watch_prefix_with_extraction(
            etcd_client,
90
            model_card::ROOT_PATH,
91
            key_extractors::lease_id,
92
            |card: ModelDeploymentCard| Some(card.runtime_config),
93
94
95
96
97
98
99
100
101
102
103
            component.drt().child_token(),
        )
        .await?;
        let mut config_events_rx = runtime_configs_watcher.receiver();

        // Subscribe to KV metrics events
        let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;

        let worker_load_states = self.worker_load_states.clone();
        let client = self.client.clone();
        let cancellation_token = component.drt().child_token();
104
        let busy_threshold = self.busy_threshold;
105
106
107
108
109
110
111
112
113
114
115
116

        // Spawn background monitoring task
        tokio::spawn(async move {
            let mut previous_busy_instances = Vec::new(); // Track previous state

            loop {
                tokio::select! {
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("Worker monitoring cancelled");
                        break;
                    }

117
                    // Handle runtime config updates
118
119
120
121
122
123
                    _ = config_events_rx.changed() => {
                        let runtime_configs = config_events_rx.borrow().clone();

                        let mut states = worker_load_states.write().unwrap();
                        states.retain(|lease_id, _| runtime_configs.contains_key(lease_id));

Yan Ru Pei's avatar
Yan Ru Pei committed
124
125
126
127
128
129
130
131
132
133
                        // Update worker load states with total blocks for all dp_ranks
                        for (lease_id, runtime_config) in runtime_configs.iter() {
                            let state = states.entry(*lease_id).or_default();

                            // Populate total_blocks for all dp_ranks (they share the same total)
                            if let Some(total_blocks) = runtime_config.total_kv_blocks {
                                for dp_rank in 0..runtime_config.data_parallel_size {
                                    state.kv_total_blocks.insert(dp_rank, total_blocks);
                                }
                            }
134
135
136
137
138
139
140
141
142
143
144
145
146
                        }
                    }

                    // Handle KV metrics updates
                    kv_event = kv_metrics_rx.next() => {
                        let Some(event) = kv_event else {
                            tracing::debug!("KV metrics stream closed");
                            break;
                        };

                        if let Ok(load_event) = serde_json::from_slice::<LoadEvent>(&event.payload) {
                            let worker_id = load_event.worker_id;
                            let active_blocks = load_event.data.kv_stats.kv_active_blocks;
Yan Ru Pei's avatar
Yan Ru Pei committed
147
                            let dp_rank = load_event.data.worker_stats.data_parallel_rank.unwrap_or(0);
148

Yan Ru Pei's avatar
Yan Ru Pei committed
149
                            // Update worker load state per dp_rank
150
                            let mut states = worker_load_states.write().unwrap();
Yan Ru Pei's avatar
Yan Ru Pei committed
151
152
                            let state = states.entry(worker_id).or_default();
                            state.kv_active_blocks.insert(dp_rank, active_blocks);
153
154
155
156
                            drop(states);

                            // Recalculate all busy instances and update
                            let states = worker_load_states.read().unwrap();
157
                            let busy_instances: Vec<u64> = states
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
                                .iter()
                                .filter_map(|(&id, state)| {
                                    state.is_busy(busy_threshold).then_some(id)
                                })
                                .collect();
                            drop(states);

                            // Only update if busy_instances has changed
                            if busy_instances != previous_busy_instances {
                                tracing::debug!("Busy instances changed: {:?}", busy_instances);
                                client.update_free_instances(&busy_instances);
                                previous_busy_instances = busy_instances;
                            }
                        }
                    }
                }
            }

            tracing::info!("Worker monitoring task exiting");
        });

        Ok(())
    }
}