scheduler.rs 17.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Neelay Shah's avatar
Neelay Shah committed
16
17
use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher;
18
use rand::Rng;
19
use serde::{Deserialize, Serialize};
20
use std::collections::HashMap;
21
22
23
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
24

25
26
use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;
27
use crate::kv_router::indexer::OverlapScores;
28
29
use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
30
use crate::kv_router::KvRouterConfig;
31
use crate::kv_router::KV_HIT_RATE_SUBJECT;
32
use crate::tokens::SequenceHash;
33
use dynamo_runtime::component::Instance;
34

35
36
37
38
/// Event emitted by the scheduling loop for every routing decision and published on
/// `KV_HIT_RATE_SUBJECT`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    /// Worker the request was routed to.
    pub worker_id: i64,
    /// Blocks required by the request, as reported by the worker selector.
    pub isl_blocks: usize,
    /// Overlap blocks reported by the worker selector for the chosen worker.
    pub overlap_blocks: u32,
}

42
43
44
45
46
47
48
49
50
51
52
53
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints aviailable to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

54
55
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
56
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57
58
59
pub struct Endpoint {
    pub name: String,
    pub subject: String,
60
    pub data: LoadMetrics,
61
62
63
}

impl Endpoint {
GuanLuo's avatar
GuanLuo committed
64
65
    pub fn worker_id(&self) -> i64 {
        i64::from_str_radix(
66
            self.subject
GuanLuo's avatar
GuanLuo committed
67
                .split("-")
68
69
70
71
                .last()
                .expect("invalid subject")
                .to_string()
                .as_str(),
GuanLuo's avatar
GuanLuo committed
72
            16,
73
        )
GuanLuo's avatar
GuanLuo committed
74
        .expect("invalid worker id")
75
76
77
    }
}

78
79
80
/// Answer sent back by the scheduling loop for one [`SchedulingRequest`].
#[derive(Debug)]
pub struct SchedulingResponse {
    /// Id of the worker chosen for the request.
    pub best_worker_id: i64,
    /// Overlap blocks reported by the selector for the chosen worker.
    pub overlap_blocks: u32,
    /// New set of worker ids if the instance list changed since the previous response,
    /// `None` otherwise.
    pub endpoints_changed: Option<Vec<i64>>,
}

85
pub struct SchedulingRequest {
86
    pub isl_tokens: usize,
87
    pub overlaps: OverlapScores,
88
    pub potential_blocks: HashMap<i64, usize>,
89
    pub potential_tokens: HashMap<i64, usize>,
90
    resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
91
92
93
}

impl SchedulingRequest {
94
95
96
    pub fn respond(self, response: SchedulingResponse) {
        if self.resp_tx.send(response).is_err() {
            tracing::error!("failed to send response to requestor");
97
98
99
100
101
102
        }
    }
}

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
103
    sequences: Arc<Mutex<ActiveSequencesMultiWorker>>,
104
105
106
107
}

impl KvScheduler {
    pub async fn start(
108
        ns: Namespace,
109
        block_size: u32,
110
        mut instances_rx: tokio::sync::watch::Receiver<Vec<Instance>>, // Changed from ProcessedEndpoints
111
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
112
    ) -> Result<Self, KvSchedulerError> {
113
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
114
115
116
117
        let mut instances: Vec<Instance> = instances_rx.borrow_and_update().clone();

        // Get worker IDs from instances
        let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
118

119
120
121
122
        let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
        tokio::spawn(async move {
            let mut event_rx = event_rx;
            while let Some(event) = event_rx.recv().await {
123
                if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
124
125
126
127
128
                    tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                }
            }
        });

129
130
        let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new(
            block_size as usize,
131
            worker_ids,
132
133
        )));

134
        // Channel to accept new scheduling requests
135
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
136
137
138
139
        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request: SchedulingRequest;
            let mut request_rx = request_rx;
140
            let mut pending_endpoint_update: Option<Vec<i64>> = None;
141
            tracing::trace!("scheduler background task started");
142
143
144
145
146

            'outer: loop {
                request = tokio::select! {
                    biased;

147
148
149
150
                    _ = instances_rx.changed() => {
                        instances = instances_rx.borrow_and_update().clone();
                        let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
                        pending_endpoint_update = Some(worker_ids);
151
                        continue 'outer;
152
                    }
153
154
155
156
157
158
159
160
161

                    maybe_new_request = request_rx.recv() => {
                        let Some(new_request) = maybe_new_request else {
                            tracing::warn!("scheduler shutdown");
                            break 'outer;
                        };
                        tracing::trace!("received request to be scheduled");
                        new_request
                    }
162
                };
163

164
                loop {
165
166
                    // When calling selector.select_worker, we need to adapt
                    match selector.select_worker(&instances, &request, block_size) {
167
                        Ok(selection) => {
168
169
170
171
172
173
174
175
176
177
                            if let Err(e) = event_tx.send(KVHitRateEvent {
                                worker_id: selection.worker_id,
                                isl_blocks: selection.required_blocks as usize,
                                overlap_blocks: selection.overlap_blocks,
                            }) {
                                tracing::warn!("Failed to send KV hit rate event: {:?}", e);
                            }

                            let response = SchedulingResponse {
                                best_worker_id: selection.worker_id,
178
                                overlap_blocks: selection.overlap_blocks,
179
180
181
                                endpoints_changed: pending_endpoint_update.take(),
                            };
                            request.respond(response);
182
183
                            continue 'outer;
                        }
184
185
                        Err(KvSchedulerError::NoEndpoints) => {
                            tracing::trace!("no endpoints available; waiting for endpoints update");
186
187
188
189
190
                            instances_rx.changed().await.ok();
                            instances = instances_rx.borrow_and_update().clone();
                            let worker_ids: Vec<i64> =
                                instances.iter().map(|i| i.instance_id).collect();
                            pending_endpoint_update = Some(worker_ids);
191
192
                            continue;
                        }
193
                        // TODO: this is not actually hooked up
194
                        Err(KvSchedulerError::AllWorkersBusy) => {
195
                            tracing::trace!("all workers busy; waiting for more capacity");
196
197
                            tokio::time::sleep(Duration::from_millis(5)).await;
                            continue;
198
199
                        }
                        Err(e) => {
200
                            tracing::error!("error scheduling request: {:?}", e);
201
202
203
204
205
206
                            break 'outer;
                        }
                    }
                }
            }

207
            tracing::trace!("background endpoint subscriber shutting down");
208
209
        });

210
211
212
213
        Ok(KvScheduler {
            request_tx,
            sequences,
        })
214
215
216
217
    }

    pub async fn schedule(
        &self,
218
        request_id: String,
219
        isl_tokens: usize,
220
        token_seq: Vec<SequenceHash>,
221
        overlaps: OverlapScores,
GuanLuo's avatar
GuanLuo committed
222
    ) -> Result<i64, KvSchedulerError> {
223
224
        let mut sequences = self.sequences.lock().await;

225
        let (potential_blocks, potential_tokens) =
226
            sequences.potential_blocks_and_tokens(token_seq.clone(), isl_tokens, overlaps.clone());
227

228
229
230
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            isl_tokens,
231
            overlaps,
232
            potential_blocks,
233
            potential_tokens,
234
235
236
237
238
239
            resp_tx,
        };
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
240
        let response = resp_rx
241
242
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
243
244
245
246
247

        if let Some(new_worker_ids) = response.endpoints_changed {
            sequences.update_workers(new_worker_ids);
        }

248
249
        sequences.add_request(
            request_id,
250
251
            token_seq,
            isl_tokens,
252
253
254
            response.overlap_blocks,
            response.best_worker_id,
        );
255
256
257
258

        Ok(response.best_worker_id)
    }

259
    pub async fn mark_prefill_completed(&self, request_id: &String) {
260
        let mut sequences = self.sequences.lock().await;
261
        sequences.mark_prefill_completed(request_id)
262
263
    }

264
265
266
267
268
    /// Free all blocks associated with a request
    pub async fn free(&self, request_id: &String) {
        let mut sequences = self.sequences.lock().await;
        sequences.free(request_id)
    }
269
270
}

271
272
273
274
275
276
// Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
    // Guard: if temperature is 0, return the key with the smallest logit value
    if temperature == 0.0 {
        // Find the minimum logit value
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        // Collect all keys with the minimum logit value (to handle ties)
        let min_keys: Vec<_> = logits
            .iter()
            .filter(|(_, &v)| v == min_logit)
            .map(|(k, _)| *k)
            .collect();

        // Randomly select from the minimum keys (handles single key case naturally)
        let mut rng = rand::rng();
        let index = rng.random_range(0..min_keys.len());
        return min_keys[index];
    }

295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    // Find min and max for normalization
    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All values are the same, uniform probability
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
        // Normalize values
        let normalized: Vec<_> = values
            .iter()
            .map(|&v| {
                let norm = v / (max_val - min_val);
                // Lower is better, so negate
                -norm
            })
            .collect();

        // Apply temperature and softmax
        let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();

        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();

        let sum_exp: f64 = exp_values.iter().sum();
        exp_values.iter().map(|&v| v / sum_exp).collect()
    };

    // Sample from the probability distribution
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
            return keys[i];
        }
    }

    // Fallback to last key (shouldn't normally reach here)
    keys[keys.len() - 1]
}

342
// Default implementation matching the Python _cost_function
343
344
345
346
347
348
349
350
351
352
353
354
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        Self {
            kv_router_config: kv_router_config.unwrap_or_default(),
        }
    }
}
355
356
357
358

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
359
        workers: &[Instance],
360
        request: &SchedulingRequest,
361
        block_size: u32,
362
363
364
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

365
        if workers.is_empty() {
366
367
368
            return Err(KvSchedulerError::NoEndpoints);
        }

369
370
371
372
373
        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;

        // active blocks for decoding
374
        let potential_active_blocks = &request.potential_blocks;
375
376
        // active tokens in the batch (processed by the linear layers), mostly prefill tokens
        let potential_active_tokens = &request.potential_tokens;
377

378
        let mut worker_logits = HashMap::new();
379
        let mut max_logit = f64::NEG_INFINITY;
380

381
        // Calculate logits for each worker
382
383
        for instance in workers.iter() {
            let worker_id = instance.instance_id;
384
            // this is the number of tokens each worker would have if the request were scheduled there
385
            let potential_tokens = *potential_active_tokens.get(&worker_id).unwrap_or_else(|| {
386
387
388
389
390
                tracing::warn!(
                    "assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
                );
                &isl
            }) as f64;
391

392
            // this is the number of blocks each worker would have if the request were scheduled there
393
            let potential_blocks = *potential_active_blocks.get(&worker_id).unwrap_or_else(||
394
395
                {tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet");
                &request_blocks
396
            }) as f64;
397

398
399
            let potential_prefill_blocks = potential_tokens / (block_size as f64);

400
            // Calculate logit (lower is better)
401
402
            let logit = self.kv_router_config.overlap_score_weight * potential_prefill_blocks
                + potential_blocks;
403
            max_logit = max_logit.max(logit);
404

405
            worker_logits.insert(worker_id, logit);
406
407

            tracing::info!(
408
                "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3}  (cached_blocks: {})",
409
                self.kv_router_config.overlap_score_weight,
410
                overlaps.get(&worker_id).unwrap_or(&0),
411
412
413
            );
        }

414
        // Normalize by dividing by max value
415
416
417
418
        if max_logit > 0.0 {
            for logit in worker_logits.values_mut() {
                *logit /= max_logit;
            }
419
        }
420

421
        // Use softmax sampling to select worker
422
        let temperature = self.kv_router_config.router_temperature;
423
424
        let best_worker_id = softmax_sample(&worker_logits, temperature);

425
        let overlap_blocks = overlaps.get(&best_worker_id).copied().unwrap_or(0);
426
427
428
        let best_logit = worker_logits[&best_worker_id];

        tracing::info!(
429
            "Selected worker: {}, normalized logit: {:.3}",
430
431
432
            best_worker_id,
            best_logit
        );
433
434

        Ok(WorkerSelectionResult {
435
436
            worker_id: best_worker_id,
            required_blocks: request_blocks as u64,
437
438
            overlap_blocks,
        })
439
440
    }
}
441
442
443
444
445

#[cfg(test)]
mod tests {
    use super::*;

    /// A single-entry map must always yield that entry's key, whatever the temperature or
    /// the logit value.
    #[test]
    fn test_softmax_sample_single_key() {
        let worker_id = 42;
        let mut logits = HashMap::from([(worker_id, 0.5)]);

        for temperature in [0.1, 1.0, 10.0] {
            let result = softmax_sample(&logits, temperature);
            assert_eq!(result, worker_id, "Should return the only available worker");
        }

        // Very negative, very positive and zero logits.
        for value in [-100.0, 100.0, 0.0] {
            logits.clear();
            logits.insert(worker_id, value);
            assert_eq!(softmax_sample(&logits, 1.0), worker_id);
        }
    }

    /// At temperature 0 the sampler is greedy and must return the key with the smallest logit.
    #[test]
    fn test_softmax_sample_zero_temperature() {
        // Worker 2 holds the smallest logit.
        let logits = HashMap::from([(1, 5.0), (2, 3.0), (3, 7.0), (4, 3.5)]);
        for _ in 0..10 {
            let result = softmax_sample(&logits, 0.0);
            assert_eq!(
                result, 2,
                "Should return worker with smallest logit when temperature is 0"
            );
        }

        // With negative logits the most negative one (worker 20) wins.
        let logits = HashMap::from([(10, -1.0), (20, -5.0), (30, 0.0)]);
        let result = softmax_sample(&logits, 0.0);
        assert_eq!(result, 20, "Should handle negative logits correctly");
    }
}