scheduler.rs 17.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Neelay Shah's avatar
Neelay Shah committed
16
17
use dynamo_runtime::component::Namespace;
use dynamo_runtime::traits::events::EventPublisher;
18
use rand::Rng;
19
use serde::{Deserialize, Serialize};
20
use std::collections::HashMap;
21
22
23
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
24

25
26
use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;
27
use crate::kv_router::indexer::OverlapScores;
28
29
use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
30
use crate::kv_router::KvRouterConfig;
31
use crate::kv_router::KV_HIT_RATE_SUBJECT;
32
use crate::tokens::TokenBlockSequence;
33
use dynamo_runtime::component::Instance;
34

35
36
37
38
/// Event describing the outcome of one scheduling decision, published by the
/// scheduler's background event task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
    /// Id of the worker the request was routed to.
    pub worker_id: i64,
    /// Number of blocks the request needs (the selector's `required_blocks`).
    pub isl_blocks: usize,
    /// Number of those blocks reported as overlapping on the selected worker.
    pub overlap_blocks: u32,
}

42
43
44
45
46
47
48
49
50
51
52
53
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
    #[error("no endpoints aviailable to route work")]
    NoEndpoints,

    #[error("all workers busy")]
    AllWorkersBusy,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
}

54
55
/// [gluo FIXME] exactly the same as EndpointInfo except that 'data'
/// is cleaned (not optional)
56
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
57
58
59
pub struct Endpoint {
    pub name: String,
    pub subject: String,
60
    pub data: LoadMetrics,
61
62
63
}

impl Endpoint {
GuanLuo's avatar
GuanLuo committed
64
65
    pub fn worker_id(&self) -> i64 {
        i64::from_str_radix(
66
            self.subject
GuanLuo's avatar
GuanLuo committed
67
                .split("-")
68
69
70
71
                .last()
                .expect("invalid subject")
                .to_string()
                .as_str(),
GuanLuo's avatar
GuanLuo committed
72
            16,
73
        )
GuanLuo's avatar
GuanLuo committed
74
        .expect("invalid worker id")
75
76
77
    }
}

78
79
80
/// Answer sent from the background scheduling task back to the caller of
/// `KvScheduler::schedule`.
#[derive(Debug)]
pub struct SchedulingResponse {
    /// Id of the worker chosen for the request.
    pub best_worker_id: i64,
    /// Overlap block count the selector reported for the chosen worker.
    pub overlap_blocks: u32,
    /// Current worker ids when the instance list changed since the previous
    /// response was sent; `None` otherwise.
    pub endpoints_changed: Option<Vec<i64>>,
}

85
pub struct SchedulingRequest {
86
    pub isl_tokens: usize,
87
    pub overlaps: OverlapScores,
88
    pub potential_blocks: HashMap<i64, usize>,
89
    pub potential_tokens: HashMap<i64, usize>,
90
    resp_tx: tokio::sync::oneshot::Sender<SchedulingResponse>,
91
92
93
}

impl SchedulingRequest {
94
95
96
    pub fn respond(self, response: SchedulingResponse) {
        if self.resp_tx.send(response).is_err() {
            tracing::error!("failed to send response to requestor");
97
98
99
100
101
102
        }
    }
}

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
103
    sequences: Arc<Mutex<ActiveSequencesMultiWorker>>,
104
105
106
107
}

impl KvScheduler {
    pub async fn start(
108
        ns: Namespace,
109
        block_size: u32,
110
        mut instances_rx: tokio::sync::watch::Receiver<Vec<Instance>>, // Changed from ProcessedEndpoints
111
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
112
    ) -> Result<Self, KvSchedulerError> {
113
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
114
115
116
117
        let mut instances: Vec<Instance> = instances_rx.borrow_and_update().clone();

        // Get worker IDs from instances
        let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
118

119
120
121
122
        let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
        tokio::spawn(async move {
            let mut event_rx = event_rx;
            while let Some(event) = event_rx.recv().await {
123
                if let Err(e) = ns.publish(KV_HIT_RATE_SUBJECT, &event).await {
124
125
126
127
128
                    tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
                }
            }
        });

129
130
        let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new(
            block_size as usize,
131
            worker_ids,
132
133
        )));

134
        // Channel to accept new scheduling requests
135
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
136
137
138
139
        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request: SchedulingRequest;
            let mut request_rx = request_rx;
140
            let mut pending_endpoint_update: Option<Vec<i64>> = None;
141
            tracing::trace!("scheduler background task started");
142
143
144
145
146

            'outer: loop {
                request = tokio::select! {
                    biased;

147
148
149
150
                    _ = instances_rx.changed() => {
                        instances = instances_rx.borrow_and_update().clone();
                        let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
                        pending_endpoint_update = Some(worker_ids);
151
                        continue 'outer;
152
                    }
153
154
155
156
157
158
159
160
161

                    maybe_new_request = request_rx.recv() => {
                        let Some(new_request) = maybe_new_request else {
                            tracing::warn!("scheduler shutdown");
                            break 'outer;
                        };
                        tracing::trace!("received request to be scheduled");
                        new_request
                    }
162
                };
163

164
                loop {
165
166
                    // When calling selector.select_worker, we need to adapt
                    match selector.select_worker(&instances, &request, block_size) {
167
                        Ok(selection) => {
168
169
170
171
172
173
174
175
176
177
                            if let Err(e) = event_tx.send(KVHitRateEvent {
                                worker_id: selection.worker_id,
                                isl_blocks: selection.required_blocks as usize,
                                overlap_blocks: selection.overlap_blocks,
                            }) {
                                tracing::warn!("Failed to send KV hit rate event: {:?}", e);
                            }

                            let response = SchedulingResponse {
                                best_worker_id: selection.worker_id,
178
                                overlap_blocks: selection.overlap_blocks,
179
180
181
                                endpoints_changed: pending_endpoint_update.take(),
                            };
                            request.respond(response);
182
183
                            continue 'outer;
                        }
184
185
                        Err(KvSchedulerError::NoEndpoints) => {
                            tracing::trace!("no endpoints available; waiting for endpoints update");
186
187
188
189
190
                            instances_rx.changed().await.ok();
                            instances = instances_rx.borrow_and_update().clone();
                            let worker_ids: Vec<i64> =
                                instances.iter().map(|i| i.instance_id).collect();
                            pending_endpoint_update = Some(worker_ids);
191
192
                            continue;
                        }
193
                        // TODO: this is not actually hooked up
194
                        Err(KvSchedulerError::AllWorkersBusy) => {
195
                            tracing::trace!("all workers busy; waiting for more capacity");
196
197
                            tokio::time::sleep(Duration::from_millis(5)).await;
                            continue;
198
199
                        }
                        Err(e) => {
200
                            tracing::error!("error scheduling request: {:?}", e);
201
202
203
204
205
206
                            break 'outer;
                        }
                    }
                }
            }

207
            tracing::trace!("background endpoint subscriber shutting down");
208
209
        });

210
211
212
213
        Ok(KvScheduler {
            request_tx,
            sequences,
        })
214
215
216
217
    }

    pub async fn schedule(
        &self,
218
        request_id: String,
219
        isl_tokens: usize,
220
221
        block_size: u32,
        tokens: &[u32],
222
        overlaps: OverlapScores,
GuanLuo's avatar
GuanLuo committed
223
    ) -> Result<i64, KvSchedulerError> {
224
225
226
        let mut sequences = self.sequences.lock().await;

        let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
227
228
        let (potential_blocks, potential_tokens) =
            sequences.potential_blocks_and_tokens(token_sequence, overlaps.clone());
229

230
231
232
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
            isl_tokens,
233
            overlaps,
234
            potential_blocks,
235
            potential_tokens,
236
237
238
239
240
241
            resp_tx,
        };
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
242
        let response = resp_rx
243
244
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
245
246
247
248
249
250

        if let Some(new_worker_ids) = response.endpoints_changed {
            sequences.update_workers(new_worker_ids);
        }

        let token_sequence = TokenBlockSequence::from_slice(tokens, block_size, None);
251
252
253
254
255
256
        sequences.add_request(
            request_id,
            token_sequence,
            response.overlap_blocks,
            response.best_worker_id,
        );
257
258
259
260

        Ok(response.best_worker_id)
    }

261
262
    /// Push tokens to a specific request's sequence
    pub async fn push(&self, request_id: &String, tokens: &[u32]) {
263
        let mut sequences = self.sequences.lock().await;
264
        sequences.push(request_id, tokens)
265
266
    }

267
268
269
270
271
    /// Free all blocks associated with a request
    ///
    /// Waits for the `sequences` lock, then asks the active-sequence tracker to
    /// release everything recorded under `request_id`.
    pub async fn free(&self, request_id: &String) {
        self.sequences.lock().await.free(request_id)
    }
272
273
}

274
275
276
277
278
279
// Helper function for softmax sampling
fn softmax_sample(logits: &HashMap<i64, f64>, temperature: f64) -> i64 {
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
    // Guard: if temperature is 0, return the key with the smallest logit value
    if temperature == 0.0 {
        // Find the minimum logit value
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        // Collect all keys with the minimum logit value (to handle ties)
        let min_keys: Vec<_> = logits
            .iter()
            .filter(|(_, &v)| v == min_logit)
            .map(|(k, _)| *k)
            .collect();

        // Randomly select from the minimum keys (handles single key case naturally)
        let mut rng = rand::rng();
        let index = rng.random_range(0..min_keys.len());
        return min_keys[index];
    }

298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    // Find min and max for normalization
    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All values are the same, uniform probability
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
        // Normalize values
        let normalized: Vec<_> = values
            .iter()
            .map(|&v| {
                let norm = v / (max_val - min_val);
                // Lower is better, so negate
                -norm
            })
            .collect();

        // Apply temperature and softmax
        let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();

        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();

        let sum_exp: f64 = exp_values.iter().sum();
        exp_values.iter().map(|&v| v / sum_exp).collect()
    };

    // Sample from the probability distribution
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
            return keys[i];
        }
    }

    // Fallback to last key (shouldn't normally reach here)
    keys[keys.len() - 1]
}

345
// Default implementation matching the Python _cost_function
346
347
348
349
350
351
352
353
354
355
356
357
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    /// Builds a selector from `kv_router_config`, substituting
    /// `KvRouterConfig::default()` when no configuration is supplied.
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        let kv_router_config = kv_router_config.unwrap_or_default();
        Self { kv_router_config }
    }
}
358
359
360
361

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
362
        workers: &[Instance],
363
        request: &SchedulingRequest,
364
        block_size: u32,
365
366
367
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

368
        if workers.is_empty() {
369
370
371
            return Err(KvSchedulerError::NoEndpoints);
        }

372
373
374
375
376
        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;

        // active blocks for decoding
377
        let potential_active_blocks = &request.potential_blocks;
378
379
        // active tokens in the batch (processed by the linear layers), mostly prefill tokens
        let potential_active_tokens = &request.potential_tokens;
380

381
        let mut worker_logits = HashMap::new();
382
        let mut max_logit = f64::NEG_INFINITY;
383

384
        // Calculate logits for each worker
385
386
        for instance in workers.iter() {
            let worker_id = instance.instance_id;
387
            // this is the number of tokens each worker would have if the request were scheduled there
388
            let potential_tokens = *potential_active_tokens.get(&worker_id).unwrap_or_else(|| {
389
390
391
392
393
                tracing::warn!(
                    "assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
                );
                &isl
            }) as f64;
394

395
            // this is the number of blocks each worker would have if the request were scheduled there
396
            let potential_blocks = *potential_active_blocks.get(&worker_id).unwrap_or_else(||
397
398
                {tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet");
                &request_blocks
399
            }) as f64;
400

401
402
            let potential_prefill_blocks = potential_tokens / (block_size as f64);

403
            // Calculate logit (lower is better)
404
405
            let logit = self.kv_router_config.overlap_score_weight * potential_prefill_blocks
                + potential_blocks;
406
            max_logit = max_logit.max(logit);
407

408
            worker_logits.insert(worker_id, logit);
409
410

            tracing::info!(
411
                "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3}  (cached_blocks: {})",
412
                self.kv_router_config.overlap_score_weight,
413
                overlaps.get(&worker_id).unwrap_or(&0),
414
415
416
            );
        }

417
        // Normalize by dividing by max value
418
419
420
421
        if max_logit > 0.0 {
            for logit in worker_logits.values_mut() {
                *logit /= max_logit;
            }
422
        }
423

424
        // Use softmax sampling to select worker
425
        let temperature = self.kv_router_config.router_temperature;
426
427
        let best_worker_id = softmax_sample(&worker_logits, temperature);

428
        let overlap_blocks = overlaps.get(&best_worker_id).copied().unwrap_or(0);
429
430
431
        let best_logit = worker_logits[&best_worker_id];

        tracing::info!(
432
            "Selected worker: {}, normalized logit: {:.3}",
433
434
435
            best_worker_id,
            best_logit
        );
436
437

        Ok(WorkerSelectionResult {
438
439
            worker_id: best_worker_id,
            required_blocks: request_blocks as u64,
440
441
            overlap_blocks,
        })
442
443
    }
}
444
445
446
447
448

#[cfg(test)]
mod tests {
    use super::*;

    /// With a single candidate the sampler has nothing to choose from, whatever
    /// the temperature or the logit value.
    #[test]
    fn test_softmax_sample_single_key() {
        // Test that with a single key, softmax_sample always returns that key
        let mut logits = HashMap::new();
        let worker_id = 42;
        logits.insert(worker_id, 0.5); // The value doesn't matter

        // Test with different temperatures
        for temperature in [0.1, 1.0, 10.0] {
            let result = softmax_sample(&logits, temperature);
            assert_eq!(result, worker_id, "Should return the only available worker");
        }

        // Test with different logit values: very negative, very positive, zero
        for logit in [-100.0, 100.0, 0.0] {
            logits.clear();
            logits.insert(worker_id, logit);
            assert_eq!(softmax_sample(&logits, 1.0), worker_id);
        }
    }

    /// Temperature 0 must pick the smallest logit deterministically.
    #[test]
    fn test_softmax_sample_zero_temperature() {
        // Test that with temperature 0, softmax_sample returns the key with smallest logit
        let mut logits = HashMap::new();
        logits.insert(1, 5.0);
        logits.insert(2, 3.0); // This has the smallest logit
        logits.insert(3, 7.0);
        logits.insert(4, 3.5);

        // With temperature 0, should always return worker 2 (smallest logit)
        for _ in 0..10 {
            let result = softmax_sample(&logits, 0.0);
            assert_eq!(
                result, 2,
                "Should return worker with smallest logit when temperature is 0"
            );
        }

        // Test with negative values
        logits.clear();
        logits.insert(10, -1.0);
        logits.insert(20, -5.0); // This has the smallest logit
        logits.insert(30, 0.0);

        let result = softmax_sample(&logits, 0.0);
        assert_eq!(result, 20, "Should handle negative logits correctly");
    }

    /// With a non-zero temperature the outcome is random, but it must always be
    /// one of the supplied keys.
    #[test]
    fn test_softmax_sample_returns_known_key() {
        let logits: HashMap<i64, f64> = HashMap::from([(1, 0.2), (2, 0.9), (3, 0.5)]);

        for _ in 0..50 {
            let result = softmax_sample(&logits, 1.0);
            assert!(
                logits.contains_key(&result),
                "Sampled worker must be one of the candidates"
            );
        }
    }
}