scheduler.rs 16.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Neelay Shah's avatar
Neelay Shah committed
16
17
use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher;
18
use rand::Rng;
19
use serde::{Deserialize, Serialize};
20
use std::collections::HashMap;
21
22
23
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
24

25
26
use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;
27
use crate::kv_router::indexer::OverlapScores;
28
29
use crate::kv_router::indexer::WorkerId;
use crate::kv_router::protocols::LoadMetrics;
30
use crate::kv_router::scoring::ProcessedEndpoints;
31
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
32
use crate::kv_router::KvRouterConfig;
33
use crate::kv_router::KV_HIT_RATE_SUBJECT;
34
use crate::tokens::TokenBlockSequence;
35

36
37
38
39
40
41
42
/// Event describing the KV-cache overlap of a single routing decision.
///
/// The scheduler's background task emits one of these for every request it
/// successfully schedules; a separate task publishes them on
/// `KV_HIT_RATE_SUBJECT`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    /// ID of the worker the request was routed to.
    pub worker_id: i64,
    /// Number of blocks the request requires (the selection's
    /// `required_blocks`). "isl" presumably stands for input sequence length.
    pub isl_blocks: usize,
    /// Number of blocks reported as overlapping for the selected worker.
    pub overlap_blocks: usize,
}

43
44
45
46
47
48
49
50
51
52
53
54
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints aviailable to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

55
56
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
57
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
58
59
60
pub struct Endpoint {
    pub name: String,
    pub subject: String,
61
    pub data: LoadMetrics,
62
63
64
}

impl Endpoint {
GuanLuo's avatar
GuanLuo committed
65
66
    pub fn worker_id(&self) -> i64 {
        i64::from_str_radix(
67
            self.subject
GuanLuo's avatar
GuanLuo committed
68
                .split("-")
69
70
71
72
                .last()
                .expect("invalid subject")
                .to_string()
                .as_str(),
GuanLuo's avatar
GuanLuo committed
73
            16,
74
        )
GuanLuo's avatar
GuanLuo committed
75
        .expect("invalid worker id")
76
77
78
    }
}

79
80
81
82
83
84
/// Reply sent by the scheduler's background task for one `SchedulingRequest`.
#[derive(Debug)]
pub struct SchedulingResponse {
    /// ID of the worker selected for the request.
    pub best_worker_id: i64,
    /// Worker IDs of the latest endpoint set if the endpoints were refreshed
    /// since the previous response was sent, otherwise `None`.
    pub endpoints_changed: Option<Vec<i64>>,
}

85
pub struct SchedulingRequest {
86
87
    pub isl_tokens: usize,
    pub overlap: OverlapScores,
88
89
    pub potential_blocks: HashMap<i64, usize>,
    resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
90
91
92
}

impl SchedulingRequest {
93
94
95
    pub fn respond(self, response: SchedulingResponse) {
        if self.resp_tx.send(response).is_err() {
            tracing::error!("failed to send response to requestor");
96
97
98
99
100
101
        }
    }
}

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
102
    sequences: Arc<Mutex<ActiveSequencesMultiWorker>>,
103
104
105
106
}

impl KvScheduler {
    pub async fn start(
107
        ns: Namespace,
108
        block_size: u32,
109
110
        endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
111
    ) -> Result<Self, KvSchedulerError> {
112
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
113
        let mut endpoints_rx = endpoints_rx;
114
        let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone();
115

116
117
118
119
        let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
        tokio::spawn(async move {
            let mut event_rx = event_rx;
            while let Some(event) = event_rx.recv().await {
120
                if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
121
122
123
124
125
                    tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                }
            }
        });

126
127
128
129
130
        let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new(
            block_size as usize,
            endpoints.worker_ids(),
        )));

131
        // Channel to accept new scheduling requests
132
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
133
134
135
136
        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request: SchedulingRequest;
            let mut request_rx = request_rx;
137
            let mut pending_endpoint_update: Option<Vec<i64>> = None;
138
            tracing::trace!("scheduler background task started");
139
140
141
142
143
144
145
146

            'outer: loop {
                request = tokio::select! {
                    biased;

                    new_request = request_rx.recv() => {
                        match new_request {
                            Some(new_request) => {
147
                                tracing::trace!("received request to be scheduled");
148
149
150
                                new_request
                            },
                            None => {
151
                                tracing::trace!("scheduler shutdown");
152
153
154
155
156
                                break 'outer;
                            }
                        }
                    }

157
158
                    _ = endpoints_rx.changed() => {
                        endpoints = endpoints_rx.borrow_and_update().clone();
159
                        pending_endpoint_update = Some(endpoints.worker_ids());
160
                        continue 'outer;
161
162
                    }
                };
163

164
                loop {
165
166
                    match selector.select_worker(&endpoints, &request, block_size) {
                        Ok(selection) => {
167
168
169
170
171
172
173
174
175
176
177
178
179
                            if let Err(e) = event_tx.send(KVHitRateEvent {
                                worker_id: selection.worker_id,
                                isl_blocks: selection.required_blocks as usize,
                                overlap_blocks: selection.overlap_blocks,
                            }) {
                                tracing::warn!("Failed to send KV hit rate event: {:?}", e);
                            }

                            let response = SchedulingResponse {
                                best_worker_id: selection.worker_id,
                                endpoints_changed: pending_endpoint_update.take(),
                            };
                            request.respond(response);
180
181
182
                            continue 'outer;
                        }
                        Err(KvSchedulerError::AllWorkersBusy) => {
183
                            tracing::trace!("all workers busy; waiting for more capacity");
184
185
                            tokio::time::sleep(Duration::from_millis(5)).await;
                            continue;
186
187
                        }
                        Err(e) => {
188
                            tracing::error!("error scheduling request: {:?}", e);
189
190
191
192
193
194
                            break 'outer;
                        }
                    }
                }
            }

195
            tracing::trace!("background endpoint subscriber shutting down");
196
197
        });

198
199
200
201
        Ok(KvScheduler {
            request_tx,
            sequences,
        })
202
203
204
205
    }

    pub async fn schedule(
        &self,
206
        request_id: String,
207
        isl_tokens: usize,
208
209
210
        block_size: u32,
        tokens: &[u32],
        overlap: OverlapScores,
GuanLuo's avatar
GuanLuo committed
211
    ) -> Result<i64, KvSchedulerError> {
212
213
214
215
216
        let mut sequences = self.sequences.lock().await;

        let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
        let potential_blocks = sequences.potential_blocks(token_sequence);

217
218
219
220
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            isl_tokens,
            overlap,
221
            potential_blocks,
222
223
224
225
226
227
            resp_tx,
        };
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
228
        let response = resp_rx
229
230
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259

        if let Some(new_worker_ids) = response.endpoints_changed {
            sequences.update_workers(new_worker_ids);
        }

        let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
        sequences.add_request(request_id, token_sequence, response.best_worker_id);

        Ok(response.best_worker_id)
    }

    /// Find the potential blocks for each worker if the sequence were routed there
    pub async fn potential_blocks(
        &self,
        token_sequence: TokenBlockSequence,
    ) -> HashMap<i64, usize> {
        let sequences = self.sequences.lock().await;
        sequences.potential_blocks(token_sequence)
    }

    /// Add a new request with its initial tokens to a specific worker
    pub async fn add_request(
        &self,
        request_id: String,
        token_sequence: TokenBlockSequence,
        worker_id: WorkerId,
    ) {
        let mut sequences = self.sequences.lock().await;
        sequences.add_request(request_id, token_sequence, worker_id)
260
261
    }

262
263
    /// Push tokens to a specific request's sequence
    pub async fn push(&self, request_id: &String, tokens: &[u32]) {
264
        let mut sequences = self.sequences.lock().await;
265
        sequences.push(request_id, tokens)
266
267
    }

268
269
270
271
272
    /// Free all blocks associated with a request
    pub async fn free(&self, request_id: &String) {
        let mut sequences = self.sequences.lock().await;
        sequences.free(request_id)
    }
273
274
}

275
276
277
278
279
280
// Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
    // Guard: if temperature is 0, return the key with the smallest logit value
    if temperature == 0.0 {
        // Find the minimum logit value
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        // Collect all keys with the minimum logit value (to handle ties)
        let min_keys: Vec<_> = logits
            .iter()
            .filter(|(_, &v)| v == min_logit)
            .map(|(k, _)| *k)
            .collect();

        // Randomly select from the minimum keys (handles single key case naturally)
        let mut rng = rand::rng();
        let index = rng.random_range(0..min_keys.len());
        return min_keys[index];
    }

299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    // Find min and max for normalization
    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All values are the same, uniform probability
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
        // Normalize values
        let normalized: Vec<_> = values
            .iter()
            .map(|&v| {
                let norm = v / (max_val - min_val);
                // Lower is better, so negate
                -norm
            })
            .collect();

        // Apply temperature and softmax
        let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();

        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();

        let sum_exp: f64 = exp_values.iter().sum();
        exp_values.iter().map(|&v| v / sum_exp).collect()
    };

    // Sample from the probability distribution
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
            return keys[i];
        }
    }

    // Fallback to last key (shouldn't normally reach here)
    keys[keys.len() - 1]
}

346
// Default implementation matching the Python _cost_function
347
348
349
350
351
352
353
354
355
356
357
358
/// Worker selector that scores each worker from its prefill and potential
/// block counts and then samples a worker via `softmax_sample`.
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    /// Router configuration; `select_worker` reads `overlap_score_weight` and
    /// `router_temperature` from it.
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    /// Builds a selector from `kv_router_config`, falling back to
    /// `KvRouterConfig::default()` when it is `None`.
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        let kv_router_config = kv_router_config.unwrap_or_default();
        Self { kv_router_config }
    }
}
359
360
361
362
363
364

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
        workers: &ProcessedEndpoints,
        request: &SchedulingRequest,
365
        block_size: u32,
366
367
368
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

369
370
371
372
        if workers.endpoints.is_empty() {
            return Err(KvSchedulerError::NoEndpoints);
        }

373
        let request_blocks = request.isl_tokens.div_ceil(block_size as usize);
374
375
        let potential_active_blocks = &request.potential_blocks;

376
        let mut worker_logits = HashMap::new();
377
        let mut max_logit = f64::NEG_INFINITY;
378

379
        // Calculate logits for each worker
380
381
382
        for (worker_id, _) in workers.endpoints.iter() {
            let cached_blocks = request.overlap.scores.get(worker_id).copied().unwrap_or(0) as f64;
            let prefill_blocks = request_blocks as f64 - cached_blocks;
383

384
385
386
387
388
            // this is the number of blocks each worker would have if the request were scheduled there
            let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(||
                {tracing::warn!("assuming 0 decoding blocks for {worker_id}, as the load metrics endpoint does not exist yet");
                &0
            }) as f64;
389

390
            // Calculate logit (lower is better)
391
392
393
            let logit =
                self.kv_router_config.overlap_score_weight * prefill_blocks + potential_blocks;
            max_logit = max_logit.max(logit);
394

395
            worker_logits.insert(*worker_id, logit);
396
397

            tracing::info!(
398
                "Formula for {worker_id}: {logit:.3} = {:.1} * {prefill_blocks:.3} + {potential_blocks:.3}",
399
                self.kv_router_config.overlap_score_weight,
400
401
402
            );
        }

403
        // Normalize by dividing by max value
404
405
406
407
        if max_logit > 0.0 {
            for logit in worker_logits.values_mut() {
                *logit /= max_logit;
            }
408
        }
409

410
        // Use softmax sampling to select worker
411
        let temperature = self.kv_router_config.router_temperature;
412
413
414
415
416
417
418
419
420
421
422
        let best_worker_id = softmax_sample(&worker_logits, temperature);

        let overlap_blocks = request
            .overlap
            .scores
            .get(&best_worker_id)
            .copied()
            .unwrap_or(0) as usize;
        let best_logit = worker_logits[&best_worker_id];

        tracing::info!(
423
            "Selected worker: {}, normalized logit: {:.3}",
424
425
426
            best_worker_id,
            best_logit
        );
427
428

        Ok(WorkerSelectionResult {
429
430
            worker_id: best_worker_id,
            required_blocks: request_blocks as u64,
431
432
            overlap_blocks,
        })
433
434
    }
}
435
436
437
438
439

#[cfg(test)]
mod tests {
    use super::*;

    /// With exactly one candidate, sampling must return it regardless of the
    /// temperature or the logit value.
    #[test]
    fn test_softmax_sample_single_key() {
        let worker_id = 42;
        let single = |logit: f64| HashMap::from([(worker_id, logit)]);

        // The logit value is irrelevant here; vary the temperature.
        let logits = single(0.5);
        for temperature in [0.1, 1.0, 10.0] {
            let result = softmax_sample(&logits, temperature);
            assert_eq!(result, worker_id, "Should return the only available worker");
        }

        // Very negative, very positive and zero logits at a fixed temperature.
        for logit in [-100.0, 100.0, 0.0] {
            assert_eq!(softmax_sample(&single(logit), 1.0), worker_id);
        }
    }

    /// With temperature 0 the key holding the smallest logit always wins.
    #[test]
    fn test_softmax_sample_zero_temperature() {
        // Worker 2 has the smallest logit.
        let logits = HashMap::from([(1, 5.0), (2, 3.0), (3, 7.0), (4, 3.5)]);

        for _ in 0..10 {
            let result = softmax_sample(&logits, 0.0);
            assert_eq!(
                result, 2,
                "Should return worker with smallest logit when temperature is 0"
            );
        }

        // Negative values: worker 20 has the smallest logit.
        let logits = HashMap::from([(10, -1.0), (20, -5.0), (30, 0.0)]);

        let result = softmax_sample(&logits, 0.0);
        assert_eq!(result, 20, "Should handle negative logits correctly");
    }
}