scheduler.rs 11.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Neelay Shah's avatar
Neelay Shah committed
16
17
use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher;
18
use rand::Rng;
19
20
use serde::{Deserialize, Serialize};
use std::borrow::BorrowMut;
21
use std::collections::HashMap;
22
23

use crate::kv_router::indexer::OverlapScores;
24
pub use crate::kv_router::protocols::ForwardPassMetrics;
25
use crate::kv_router::scoring::ProcessedEndpoints;
26
use crate::kv_router::KV_HIT_RATE_SUBJECT;
27

28
29
30
use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;

31
32
33
34
35
36
37
/// Event emitted each time a worker is selected, recording how much of the
/// request's input was counted as overlapping that worker's KV cache.
/// Published to `KV_HIT_RATE_SUBJECT` by the task spawned in `KvScheduler::start`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    /// Id of the worker the request was routed to.
    pub worker_id: i64,
    /// Input sequence length of the request, measured in KV blocks.
    pub isl_blocks: usize,
    /// Blocks of the input counted as overlapping per the router's overlap scores.
    pub overlap_blocks: usize,
}

38
39
40
41
42
43
44
45
46
47
48
49
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints aviailable to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

50
51
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
    // Human-readable endpoint name.
    pub name: String,
    // Subject this endpoint is addressed on; its hex suffix after the last
    // '-' encodes the worker id (see `worker_id`).
    pub subject: String,
    // Latest forward-pass metrics reported by this worker (always present,
    // unlike EndpointInfo's optional data).
    pub data: ForwardPassMetrics,
}

impl Endpoint {
GuanLuo's avatar
GuanLuo committed
60
61
    pub fn worker_id(&self) -> i64 {
        i64::from_str_radix(
62
            self.subject
GuanLuo's avatar
GuanLuo committed
63
                .split("-")
64
65
66
67
                .last()
                .expect("invalid subject")
                .to_string()
                .as_str(),
GuanLuo's avatar
GuanLuo committed
68
            16,
69
        )
GuanLuo's avatar
GuanLuo committed
70
        .expect("invalid worker id")
71
72
73
74
    }
}

/// A single routing request waiting for the scheduler to pick a worker.
pub struct SchedulingRequest {
    // Input sequence length of the request, in tokens.
    pub isl_tokens: usize,
    // Per-worker KV-cache overlap scores computed for this request.
    pub overlap: OverlapScores,
    // Oneshot channel used by `respond` to hand the chosen worker id back
    // to the caller blocked in `KvScheduler::schedule`.
    resp_tx: tokio::sync::oneshot::Sender<i64>,
}

impl SchedulingRequest {
GuanLuo's avatar
GuanLuo committed
81
    pub fn respond(self, worker_id: i64) {
82
        if self.resp_tx.send(worker_id).is_err() {
83
            tracing::trace!("failed to send response to requestor");
84
85
86
87
88
89
90
91
92
93
        }
    }
}

/// Handle to the background scheduling task spawned by `KvScheduler::start`;
/// `schedule` submits requests over this channel.
pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
}

impl KvScheduler {
    /// Spawns the scheduler's background tasks and returns a handle used to
    /// submit scheduling requests.
    ///
    /// Two tasks are spawned:
    /// 1. An event forwarder that publishes `KVHitRateEvent`s to `ns` under
    ///    `KV_HIT_RATE_SUBJECT`.
    /// 2. The scheduling loop: it keeps a local snapshot of endpoint state
    ///    from `endpoints_rx` and, for each incoming request, asks
    ///    `selector` (default: `DefaultWorkerSelector`) for a worker,
    ///    retrying while all workers are busy.
    ///
    /// As written this never returns `Err`; the `Result` return type is part
    /// of the public interface.
    pub async fn start(
        ns: Namespace,
        block_size: usize,
        endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
    ) -> Result<Self, KvSchedulerError> {
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector));
        let mut endpoints_rx = endpoints_rx;
        // Snapshot current endpoint state; refreshed on every watch update.
        let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone();

        // Unbounded so the scheduling loop never blocks on event publishing.
        let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
        tokio::spawn(async move {
            let mut event_rx = event_rx;
            // Runs until every `event_tx` clone is dropped.
            while let Some(event) = event_rx.recv().await {
                if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
                    tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                }
            }
        });

        // Channel to accept new scheduling requests
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        tracing::debug!("scheduler starting");

        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request: SchedulingRequest;
            let mut request_rx = request_rx;
            tracing::debug!("scheduler background task started");

            'outer: loop {
                request = tokio::select! {
                    // `biased` polls arms in written order: pending requests
                    // are served before endpoint-state refreshes.
                    biased;

                    new_request = request_rx.recv() => {
                        match new_request {
                            Some(new_request) => {
                                tracing::trace!("received request to be scheduled");
                                new_request
                            },
                            None => {
                                // All request senders dropped: shut down.
                                tracing::trace!("scheduler shutdown");
                                break 'outer;
                            }
                        }
                    }

                    _ = endpoints_rx.changed() => {
                        // No request pending; just refresh the snapshot.
                        endpoints = endpoints_rx.borrow_and_update().clone();
                        continue 'outer;
                    }
                };
                tracing::debug!("selected");
                // Retry selection for this one request until a worker is
                // found or an unrecoverable error occurs.
                loop {
                    match selector.select_worker(&endpoints, &request, block_size) {
                        Ok(selection) => {
                            // Apply the selection to the local snapshot
                            // (predictive load update + hit-rate event) and
                            // answer the caller.
                            let worker_id = process_worker_selection(
                                endpoints.borrow_mut(),
                                selection,
                                &event_tx,
                            );
                            request.respond(worker_id);
                            continue 'outer;
                        }
                        Err(KvSchedulerError::AllWorkersBusy) => {
                            tracing::trace!("all workers busy; waiting for more capacity");
                            // Block until endpoint state changes, then retry
                            // the same request against fresh data.
                            match endpoints_rx.changed().await {
                                Ok(_) => {}
                                Err(e) => {
                                    tracing::error!("error waiting for endpoints change: {:?}", e);
                                    break 'outer;
                                }
                            };
                            endpoints = endpoints_rx.borrow_and_update().clone();
                        }
                        Err(e) => {
                            // NOTE(review): the in-flight request is dropped
                            // without a response here; its caller observes
                            // SubscriberShutdown when resp_tx is dropped.
                            tracing::error!("error scheduling request: {:?}", e);
                            break 'outer;
                        }
                    }
                }
            }

            tracing::trace!("background endpoint subscriber shutting down");
        });

        Ok(KvScheduler { request_tx })
    }

    /// Submits a request to the background scheduler and waits for the
    /// selected worker id.
    ///
    /// # Errors
    /// Returns `KvSchedulerError::SubscriberShutdown` if the scheduling task
    /// has exited (either channel closed before a response arrived).
    pub async fn schedule(
        &self,
        overlap: OverlapScores,
        isl_tokens: usize,
    ) -> Result<i64, KvSchedulerError> {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            isl_tokens,
            overlap,
            resp_tx,
        };
        tracing::debug!("before sending request");
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        tracing::debug!("after sending request");

        let res = resp_rx
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
        tracing::debug!("after receiving response");
        Ok(res)
    }
}

207
208
// This becomes the driver function that handles the selection result
pub fn process_worker_selection(
209
    workers: &mut ProcessedEndpoints,
210
    selection: WorkerSelectionResult,
211
    event_tx: &tokio::sync::mpsc::UnboundedSender<KVHitRateEvent>,
212
213
214
215
216
217
) -> i64 {
    let worker = workers
        .endpoints
        .get_mut(&selection.worker_id)
        .expect("worker not found");

218
219
220
221
    // Update worker state predictively
    // Will be overwritten on next polling of metrics
    worker.data.num_requests_waiting += 1;
    // Assumes radix attention so KV load is only incremented by uncached blocks
222
223
224
225
226
227
228
229
230
    worker.data.kv_active_blocks += selection.required_blocks - selection.overlap_blocks as u64;

    // Emit event
    if let Err(e) = event_tx.send(KVHitRateEvent {
        worker_id: selection.worker_id,
        isl_blocks: selection.required_blocks as usize,
        overlap_blocks: selection.overlap_blocks,
    }) {
        tracing::warn!("Failed to send KV hit rate event: {:?}", e);
231
232
    }

233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
    selection.worker_id
}

// Default implementation matching the Python _cost_function
/// Cost-based worker selector used when no custom selector is supplied to
/// `KvScheduler::start`.
#[derive(Default)]
pub struct DefaultWorkerSelector;

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
        block_size: usize,
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

        let mut worker_scores = HashMap::new();
250
        let mut max_waiting = 0.0;
251
252
253
254
255
256
257
258
259
260

        // Calculate worker scores and find max waiting requests
        for (worker_id, ep) in workers.endpoints.iter() {
            // Calculate score similar to Python version
            if let Some(score) = request.overlap.scores.get(worker_id) {
                let score = *score as f64 * block_size as f64 / request.isl_tokens as f64;
                worker_scores.insert(worker_id, score);
            }

            // Track max waiting requests
261
            max_waiting = f64::max(max_waiting, ep.data.num_requests_waiting as f64);
262
263
        }

264
        if max_waiting == 0.0 {
265
266
            return Err(KvSchedulerError::NoEndpoints);
        }
267

268
269
        // make immutable
        let worker_scores = worker_scores;
270
        let max_waiting = max_waiting;
271

272
273
274
        // Calculate logits for each worker
        let mut best_logit = f64::NEG_INFINITY;
        let mut best_workers = Vec::new();
275

276
277
        for (worker_id, ep) in workers.endpoints.iter() {
            let worker_id = *worker_id;
278

279
280
            // Get score or default to 0.0
            let score = worker_scores.get(&worker_id).copied().unwrap_or(0.0);
281

282
283
284
            // Calculate normalized metrics
            assert!(ep.data.kv_total_blocks > 0);
            let gpu_cache_usage = ep.data.kv_active_blocks as f64 / ep.data.kv_total_blocks as f64;
285
286
            let normalized_waiting = if max_waiting > 0.0 {
                ep.data.num_requests_waiting as f64 / max_waiting
287
288
289
290
291
            } else {
                0.0
            };

            // Calculate logit using same formula as Python
292
            let logit = 2.0 * score - gpu_cache_usage - normalized_waiting;
293
294
295

            tracing::info!(
                "Formula for {}: {:.3} = 2.0 * {:.3} - {:.3} - {:.3}",
296
                worker_id,
297
298
299
                logit,
                score,
                gpu_cache_usage,
300
                normalized_waiting
301
302
            );

303
304
305
306
307
308
309
310
311
312
313
314
            // Track best workers
            match logit.partial_cmp(&best_logit) {
                Some(std::cmp::Ordering::Greater) => {
                    best_logit = logit;
                    best_workers.clear();
                    best_workers.push(worker_id);
                }
                Some(std::cmp::Ordering::Equal) => {
                    best_workers.push(worker_id);
                }
                _ => {}
            }
315
316
        }

317
        // Return early if no valid workers found
318
        if best_workers.is_empty() {
319
            return Err(KvSchedulerError::NoEndpoints);
320
321
        } else if best_logit == 0.0 {
            tracing::warn!("best worker logit is 0");
322
        }
323

324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
        let worker_id = if best_workers.len() == 1 {
            best_workers[0]
        } else {
            // Randomly select from best workers
            let mut rng = rand::rng();
            best_workers[rng.random_range(0..best_workers.len())]
        };

        // Log selection metrics
        tracing::info!("Selected worker: {}, logit: {:.3}", worker_id, best_logit);

        let total_blocks = std::cmp::min(request.isl_tokens / block_size, 1) as u64;
        let overlap_blocks = request.overlap.scores.get(&worker_id).copied().unwrap_or(0) as usize;

        Ok(WorkerSelectionResult {
            worker_id,
            required_blocks: total_blocks,
            overlap_blocks,
        })
343
344
    }
}