scheduler.rs 11.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Neelay Shah's avatar
Neelay Shah committed
16
17
use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher;
18
use rand::Rng;
19
20
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
21
use std::collections::HashMap;
22
23

use crate::kv_router::indexer::OverlapScores;
24
pub use crate::kv_router::protocols::ForwardPassMetrics;
25
use crate::kv_router::scoring::ProcessedEndpoints;
26
use crate::kv_router::KV_HIT_RATE_SUBJECT;
27

28
29
30
use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;

31
32
33
34
35
36
37
/// KV-cache hit-rate sample produced for every routed request.
///
/// The scheduler emits one of these per successful worker selection and a background task
/// publishes it on `KV_HIT_RATE_SUBJECT`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    /// Worker the request was routed to.
    pub worker_id: i64,
    /// Request input length in blocks, as reported by the worker selector (`required_blocks`).
    pub isl_blocks: usize,
    /// Blocks of the request that overlap what the selected worker already holds.
    pub overlap_blocks: usize,
}

38
39
40
41
42
43
44
45
46
47
48
49
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints aviailable to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

50
51
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
52
53
54
55
56
57
58
59
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
    pub name: String,
    pub subject: String,
    pub data: ForwardPassMetrics,
}

impl Endpoint {
GuanLuo's avatar
GuanLuo committed
60
61
    pub fn worker_id(&self) -> i64 {
        i64::from_str_radix(
62
            self.subject
GuanLuo's avatar
GuanLuo committed
63
                .split("-")
64
65
66
67
                .last()
                .expect("invalid subject")
                .to_string()
                .as_str(),
GuanLuo's avatar
GuanLuo committed
68
            16,
69
        )
GuanLuo's avatar
GuanLuo committed
70
        .expect("invalid worker id")
71
72
73
74
    }
}

pub struct SchedulingRequest {
75
76
    pub isl_tokens: usize,
    pub overlap: OverlapScores,
GuanLuo's avatar
GuanLuo committed
77
    resp_tx: tokio::sync::oneshot::Sender<i64>,
78
79
80
}

impl SchedulingRequest {
GuanLuo's avatar
GuanLuo committed
81
    pub fn respond(self, worker_id: i64) {
82
        if self.resp_tx.send(worker_id).is_err() {
83
            tracing::trace!("failed to send response to requestor");
84
85
86
87
88
89
90
91
92
93
        }
    }
}

/// Handle to the KV-aware scheduler's background task.
///
/// Requests are forwarded to that task over a channel created in `KvScheduler::start`.
pub struct KvScheduler {
    /// Sending half of the scheduling-request queue consumed by the background task.
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
}

impl KvScheduler {
    pub async fn start(
94
        ns: Namespace,
95
96
97
        block_size: usize,
        endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
98
    ) -> Result<Self, KvSchedulerError> {
99
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector));
100
        let mut endpoints_rx = endpoints_rx;
101
        let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone();
102

103
104
105
106
        let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
        tokio::spawn(async move {
            let mut event_rx = event_rx;
            while let Some(event) = event_rx.recv().await {
107
                if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
108
109
110
111
112
                    tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                }
            }
        });

113
        // Channel to accept new scheduling requests
114
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
115
116
117
118
        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request: SchedulingRequest;
            let mut request_rx = request_rx;
119
            tracing::trace!("scheduler background task started");
120
121
122
123
124
125
126
127

            'outer: loop {
                request = tokio::select! {
                    biased;

                    new_request = request_rx.recv() => {
                        match new_request {
                            Some(new_request) => {
128
                                tracing::trace!("received request to be scheduled");
129
130
131
                                new_request
                            },
                            None => {
132
                                tracing::trace!("scheduler shutdown");
133
134
135
136
137
                                break 'outer;
                            }
                        }
                    }

138
139
140
                    _ = endpoints_rx.changed() => {
                        endpoints = endpoints_rx.borrow_and_update().clone();
                        continue 'outer;
141
142
143
                    }
                };
                loop {
144
145
146
147
148
149
150
                    match selector.select_worker(&endpoints, &request, block_size) {
                        Ok(selection) => {
                            let worker_id = process_worker_selection(
                                endpoints.borrow_mut(),
                                selection,
                                &event_tx,
                            );
151
152
153
154
                            request.respond(worker_id);
                            continue 'outer;
                        }
                        Err(KvSchedulerError::AllWorkersBusy) => {
155
                            tracing::trace!("all workers busy; waiting for more capacity");
156
157
158
159
                            match endpoints_rx.changed().await {
                                Ok(_) => {}
                                Err(e) => {
                                    tracing::error!("error waiting for endpoints change: {:?}", e);
160
161
162
                                    break 'outer;
                                }
                            };
163
                            endpoints = endpoints_rx.borrow_and_update().clone();
164
165
                        }
                        Err(e) => {
166
                            tracing::error!("error scheduling request: {:?}", e);
167
168
169
170
171
172
                            break 'outer;
                        }
                    }
                }
            }

173
            tracing::trace!("background endpoint subscriber shutting down");
174
175
176
177
178
179
180
181
182
        });

        Ok(KvScheduler { request_tx })
    }

    pub async fn schedule(
        &self,
        overlap: OverlapScores,
        isl_tokens: usize,
GuanLuo's avatar
GuanLuo committed
183
    ) -> Result<i64, KvSchedulerError> {
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            isl_tokens,
            overlap,
            resp_tx,
        };
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        let res = resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        Ok(res)
    }
}

201
202
// This becomes the driver function that handles the selection result
pub fn process_worker_selection(
203
    workers: &mut ProcessedEndpoints,
204
    selection: WorkerSelectionResult,
205
    event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
206
207
208
209
210
211
) -> i64 {
    let worker = workers
        .endpoints
        .get_mut(&selection.worker_id)
        .expect("worker not found");

212
213
214
215
    // Update worker state predictively
    // Will be overwritten on next polling of metrics
    worker.data.num_requests_waiting += 1;
    // Assumes radix attention so KV load is only incremented by uncached blocks
216
217
218
219
    // overlap_blocks can be bigger than required_blocks. I don't know if that's a bug or not.
    worker.data.kv_active_blocks += selection
        .required_blocks
        .saturating_sub(selection.overlap_blocks as u64);
220
221
222
223
224
225
226
227

    // Emit event
    if let Err(e) = event_tx.send(KVHitRateEvent {
        worker_id: selection.worker_id,
        isl_blocks: selection.required_blocks as usize,
        overlap_blocks: selection.overlap_blocks,
    }) {
        tracing::warn!("Failed to send KV hit rate event: {:?}", e);
228
229
    }

230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    selection.worker_id
}

// Default implementation matching the Python _cost_function
#[derive(Default)]
pub struct DefaultWorkerSelector;

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: usize,
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

246
247
248
249
        if workers.endpoints.is_empty() {
            return Err(KvSchedulerError::NoEndpoints);
        }

250
        let mut worker_scores = HashMap::new();
251
        let mut max_waiting = 0.0;
252
253
254
255
256
257
258
259
260
261

        // Calculate worker scores and find max waiting requests
        for (worker_id, ep) in workers.endpoints.iter() {
            // Calculate score similar to Python version
            if let Some(score) = request.overlap.scores.get(worker_id) {
                let score = *score as f64 * block_size as f64 / request.isl_tokens as f64;
                worker_scores.insert(worker_id, score);
            }

            // Track max waiting requests
262
            max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64);
263
264
        }

265
266
        // make immutable
        let worker_scores = worker_scores;
267
        let max_waiting = max_waiting;
268

269
270
271
        // Calculate logits for each worker
        let mut best_logit = f64::NEG_INFINITY;
        let mut best_workers = Vec::new();
272

273
274
        for (worker_id, ep) in workers.endpoints.iter() {
            let worker_id = *worker_id;
275

276
277
            // Get score or default to 0.0
            let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0);
278

279
280
281
            // Calculate normalized metrics
            assert!(ep.data.kv_total_blocks > 0);
            let gpu_cache_usage = ep.data.kv_active_blocks as f64 / ep.data.kv_total_blocks as f64;
282
283
            let normalized_waiting = if max_waiting > 0.0 {
                ep.data.num_requests_waiting as f64 / max_waiting
284
285
286
287
288
            } else {
                0.0
            };

            // Calculate logit using same formula as Python
289
            let logit = 2.0 * score - gpu_cache_usage - normalized_waiting;
290

291
292
            tracing::trace!(
                "Formula for {worker_id}: {logit:.3} = 2.0 * {score:.3} - {gpu_cache_usage:.3} - {normalized_waiting:.3}",
293
294
            );

295
296
297
298
299
300
301
302
303
304
305
306
            // Track best workers
            match logit.partial_cmp(&best_logit) {
                Some(std::cmp::Ordering::Greater) => {
                    best_logit = logit;
                    best_workers.clear();
                    best_workers.push(worker_id);
                }
                Some(std::cmp::Ordering::Equal) => {
                    best_workers.push(worker_id);
                }
                _ => {}
            }
307
308
        }

309
        // Return early if no valid workers found
310
        if best_workers.is_empty() {
311
            return Err(KvSchedulerError::NoEndpoints);
312
        } else if best_logit == 0.0 {
313
            tracing::debug!("best worker logit is 0");
314
        }
315

316
317
318
319
320
321
322
323
        let worker_id = if best_workers.len() == 1 {
            best_workers[0]
        } else {
            // Randomly select from best workers
            let mut rng = rand::rng();
            best_workers[rng.random_range(0..best_workers.len())]
        };

324
325
        // Lower to trace level eventually. Nice to see KV routing working for now.
        tracing::debug!("Selected worker: {worker_id}, logit: {best_logit:.3}");
326

327
        // Log selection metrics
328
        let total_blocks = std::cmp::max(request.isl_tokens / block_size, 1) as u64;
329
330
331
332
333
334
335
        let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;

        Ok(WorkerSelectionResult {
            worker_id,
            required_blocks: total_blocks,
            overlap_blocks,
        })
336
337
    }
}