scheduler.rs 25.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
use crate::discovery::RuntimeConfigWatch;
5
use crate::local_model::runtime_config::ModelRuntimeConfig;
6
use anyhow::Result;
7
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
8
use dynamo_runtime::traits::DistributedRuntimeProvider;
9
use dynamo_runtime::transports::event_plane::EventPublisher;
10
use rand::Rng;
11
use serde::{Deserialize, Serialize};
12
use std::collections::{HashMap, HashSet};
13
14
use std::sync::Arc;
use std::time::Duration;
15
16
#[cfg(feature = "bench")]
use std::time::Instant;
17

18
19
use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
20
use super::RouterConfigOverride;
21
use super::WorkerSelector;
22
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
23
use super::sequence::{ActiveSequencesMultiWorker, SequenceError};
24

25
use dynamo_tokens::SequenceHash;
26

27
28
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
Yan Ru Pei's avatar
Yan Ru Pei committed
29
30
31
    pub worker_id: WorkerId,
    #[serde(default)]
    pub dp_rank: DpRank,
32
    pub isl_blocks: usize,
33
    pub overlap_blocks: u32,
34
35
}

36
37
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
38
39
    pub worker_id: WorkerId,
    pub dp_rank: DpRank,
40
41
42
43
    pub potential_prefill_tokens: usize,
    pub potential_decode_blocks: usize,
}

44
45
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
46
    #[error("no endpoints available to route work")]
47
48
49
50
    NoEndpoints,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
51
52
53

    #[error("failed to initialize event publisher: {0}")]
    InitFailed(String),
54
55
}

56
57
#[derive(Debug)]
pub struct SchedulingResponse {
Yan Ru Pei's avatar
Yan Ru Pei committed
58
    pub best_worker: WorkerWithDpRank,
59
    pub overlap_blocks: u32,
60
61
}

62
pub struct SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
63
    pub maybe_request_id: Option<String>,
64
    pub token_seq: Option<Vec<SequenceHash>>,
65
    pub isl_tokens: usize,
66
    pub overlaps: OverlapScores,
Yan Ru Pei's avatar
Yan Ru Pei committed
67
68
    pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
    pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
69
70
    // Router config overrides for this specific request
    pub router_config_override: Option<RouterConfigOverride>,
71
72
    // Whether to update scheduler states (false for query_instance_id requests)
    pub update_states: bool,
73
74
    // LORA adapter name extracted from request.model field
    pub lora_name: Option<String>,
75
76
    // Option to take it out to send the response without moving the struct
    resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
77
78
79
}

impl SchedulingRequest {
80
81
82
83
84
85
86
87
88
    pub fn respond(&mut self, response: SchedulingResponse) {
        // Changed to &mut self
        if let Some(tx) = self.resp_tx.take() {
            // Use take() to extract the sender
            if tx.send(response).is_err() {
                tracing::error!("failed to send response to requestor");
            }
        } else {
            tracing::error!("respond called multiple times on same request");
89
90
91
92
93
94
        }
    }
}

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
95
    slots: Arc<ActiveSequencesMultiWorker>,
96
97
98
99
}

impl KvScheduler {
    pub async fn start(
100
        component: Component,
101
        block_size: u32,
102
        workers_with_configs: RuntimeConfigWatch,
103
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
104
        replica_sync: bool,
105
        router_id: u64,
106
        worker_type: &'static str,
107
    ) -> Result<Self, KvSchedulerError> {
108
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
109

110
111
112
113
        // Get initial workers from watch receiver.
        // Caller must ensure at least one worker is present (via wait_for).
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
114

115
116
117
118
119
120
121
        let slots = Arc::new(
            ActiveSequencesMultiWorker::new(
                component.clone(),
                block_size as usize,
                initial_workers,
                replica_sync,
                router_id,
122
                worker_type,
123
124
125
126
            )
            .await
            .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?,
        );
127

128
        // Spawn background task to sync slots when the watch value changes.
Yan Ru Pei's avatar
Yan Ru Pei committed
129
        let slots_monitor = slots.clone();
130
        let mut monitor_rx = workers_with_configs.clone();
131
        let monitor_cancel_token = component.drt().child_token();
132
        tokio::spawn(async move {
133
            tracing::trace!("KvScheduler workers monitoring task started");
134
            let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
135

136
            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
137
138
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
139
                        tracing::trace!("KvScheduler workers monitoring task shutting down");
140
141
                        break;
                    }
142
                    result = monitor_rx.changed() => {
143
144
145
146
147
                        if result.is_err() {
                            tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                            break;
                        }
                    }
148
149
                }

150
151
152
153
154
                let current_workers = monitor_rx.borrow_and_update().clone();

                if current_workers != last_workers {
                    slots_monitor.update_workers(current_workers.clone());
                    last_workers = current_workers;
Yan Ru Pei's avatar
Yan Ru Pei committed
155
156
157
158
159
                }
            }
        });

        let slots_clone = slots.clone();
160
        let scheduler_rx = workers_with_configs.clone();
Yan Ru Pei's avatar
Yan Ru Pei committed
161
162
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();
163
164
165
166
        let hit_rate_publisher =
            EventPublisher::for_namespace(component.namespace(), KV_HIT_RATE_SUBJECT)
                .await
                .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
Yan Ru Pei's avatar
Yan Ru Pei committed
167
168
169
170
171
172
173
174
175
176
177

        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request_rx = request_rx;
            tracing::trace!("scheduler background task started");

            loop {
                // Check for cancellation at beginning of loop
                if scheduler_cancel_token.is_cancelled() {
                    tracing::trace!("scheduler background task shutting down");
                    break;
178
179
180
                }

                // Wait for a new request
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
                let Some(mut request) = request_rx.recv().await else {
                    tracing::warn!("scheduler shutdown");
                    break;
                };
                tracing::trace!("received request to be scheduled");

                let (decode_blocks, prefill_tokens) = slots_clone
                    .potential_blocks_and_tokens(
                        request.token_seq.clone(),
                        request.isl_tokens,
                        request.overlaps.clone(),
                    )
                    .await;
                request.decode_blocks = decode_blocks;
                request.prefill_tokens = prefill_tokens;

197
198
                // Read the current workers configuration from watch receiver
                let workers: HashMap<WorkerId, ModelRuntimeConfig> = scheduler_rx.borrow().clone();
Yan Ru Pei's avatar
Yan Ru Pei committed
199
200

                match selector.select_worker(&workers, &request, block_size) {
201
                    Ok(selection) => {
Yan Ru Pei's avatar
Yan Ru Pei committed
202
                        let event = KVHitRateEvent {
Yan Ru Pei's avatar
Yan Ru Pei committed
203
204
                            worker_id: selection.worker.worker_id,
                            dp_rank: selection.worker.dp_rank,
205
206
                            isl_blocks: selection.required_blocks as usize,
                            overlap_blocks: selection.overlap_blocks,
Yan Ru Pei's avatar
Yan Ru Pei committed
207
                        };
208
                        if let Err(e) = hit_rate_publisher.publish(&event).await {
Yan Ru Pei's avatar
Yan Ru Pei committed
209
                            tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
210
                        }
211
212

                        let response = SchedulingResponse {
Yan Ru Pei's avatar
Yan Ru Pei committed
213
                            best_worker: selection.worker,
214
215
216
217
                            overlap_blocks: selection.overlap_blocks,
                        };
                        request.respond(response);

218
219
220
                        // Skip state update if not requested
                        if !request.update_states {
                            continue;
221
                        }
222

Yan Ru Pei's avatar
Yan Ru Pei committed
223
224
225
226
227
228
229
                        let Some(request_id) = request.maybe_request_id else {
                            tracing::error!(
                                "No request_id provided to add_request to the slot tracker"
                            );
                            continue;
                        };

230
231
232
233
234
235
                        if let Err(e) = slots_clone
                            .add_request(
                                request_id.clone(),
                                request.token_seq,
                                request.isl_tokens,
                                selection.overlap_blocks,
236
                                None, // expected_output_tokens not available in scheduler loop
Yan Ru Pei's avatar
Yan Ru Pei committed
237
                                selection.worker,
238
                                request.lora_name.clone(),
239
240
241
                            )
                            .await
                        {
242
                            tracing::warn!("Failed to add request {request_id}: {e}");
243
                        }
244
245
246
247
248
249
250
251
252
                    }
                    Err(KvSchedulerError::NoEndpoints) => {
                        tracing::trace!("no endpoints available; waiting for endpoints update");
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        continue;
                    }
                    Err(e) => {
                        tracing::error!("error scheduling request: {:?}", e);
                        break;
253
254
255
256
                    }
                }
            }

257
            tracing::trace!("background endpoint subscriber shutting down");
258
259
        });

260
        Ok(KvScheduler { request_tx, slots })
261
262
    }

263
    #[allow(clippy::too_many_arguments)]
264
265
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
266
        maybe_request_id: Option<String>,
267
        isl_tokens: usize,
268
        token_seq: Option<Vec<SequenceHash>>,
269
        overlaps: OverlapScores,
270
        router_config_override: Option<&RouterConfigOverride>,
271
        update_states: bool,
272
        lora_name: Option<String>,
Yan Ru Pei's avatar
Yan Ru Pei committed
273
    ) -> Result<WorkerWithDpRank, KvSchedulerError> {
274
275
276
        #[cfg(feature = "bench")]
        let start = Instant::now();

277
278
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
279
            maybe_request_id,
280
            token_seq,
281
            isl_tokens,
282
            overlaps,
283
284
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
285
            router_config_override: router_config_override.cloned(),
286
            update_states,
287
            lora_name,
288
            resp_tx: Some(resp_tx), // Wrap in Some()
289
        };
290

291
292
293
294
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
295
296
297
298

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

299
        let response = resp_rx
300
301
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
302

303
304
305
306
307
308
309
310
311
312
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

Yan Ru Pei's avatar
Yan Ru Pei committed
313
        Ok(response.best_worker)
314
315
    }

316
    #[allow(clippy::too_many_arguments)]
317
318
319
    pub async fn add_request(
        &self,
        request_id: String,
320
        token_sequence: Option<Vec<SequenceHash>>,
321
322
        isl: usize,
        overlap: u32,
323
        expected_output_tokens: Option<u32>,
Yan Ru Pei's avatar
Yan Ru Pei committed
324
        worker: WorkerWithDpRank,
325
        lora_name: Option<String>,
326
327
    ) -> Result<(), SequenceError> {
        self.slots
328
329
330
331
332
333
334
            .add_request(
                request_id,
                token_sequence,
                isl,
                overlap,
                expected_output_tokens,
                worker,
335
                lora_name,
336
            )
337
            .await
338
339
    }

340
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
341
        self.slots
342
            .mark_prefill_completed(&request_id.to_string())
343
            .await
344
345
    }

346
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
347
        self.slots.free(&request_id.to_string()).await
348
    }
349

350
351
352
353
354
355
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

356
357
358
359
360
361
362
363
364
365
    pub async fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
            .await
    }

366
367
    pub async fn get_potential_loads(
        &self,
368
        token_seq: Option<Vec<SequenceHash>>,
369
370
371
372
373
374
375
376
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
        let (decode_blocks, prefill_tokens) = self
            .slots
            .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
            .await;

Yan Ru Pei's avatar
Yan Ru Pei committed
377
378
379
380
        // Get all unique WorkerWithDpRank from both hashmaps
        let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
381
382
383

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
384
        for worker in workers {
385
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
386
387
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
388
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
389
                    .get(&worker)
390
391
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
392
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
393
394
395
396
397
            });
        }

        loads
    }
398
399
400
401
402

    /// Get active request counts grouped by LORA name
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
403
404
}

405
// Helper function for softmax sampling
406
407
408
409
410
// Returns a vec of workers: multiple if tied, single if sampled
fn softmax_sample(
    logits: &HashMap<WorkerWithDpRank, f64>,
    temperature: f64,
) -> Vec<WorkerWithDpRank> {
411
412
413
414
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

415
    // Guard: if temperature is 0, return all keys with the smallest logit value (ties)
416
417
418
419
420
421
422
    if temperature == 0.0 {
        // Find the minimum logit value
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        // Collect all keys with the minimum logit value (to handle ties)
        let min_keys: Vec<_> = logits
            .iter()
423
            .filter(|&(_, &v)| v == min_logit)
424
425
426
            .map(|(k, _)| *k)
            .collect();

427
        return min_keys;
428
429
    }

430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    // Find min and max for normalization
    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All values are the same, uniform probability
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
        // Normalize values
        let normalized: Vec<_> = values
            .iter()
            .map(|&v| {
                // Lower is better, so negate
446
447
                // Note we don't need to do actual min-max norm here, just off by an offset
                let norm = v / (max_val - min_val);
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
                -norm
            })
            .collect();

        // Apply temperature and softmax
        let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();

        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();

        let sum_exp: f64 = exp_values.iter().sum();
        exp_values.iter().map(|&v| v / sum_exp).collect()
    };

    // Sample from the probability distribution
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
470
            return vec![keys[i]];
471
472
473
474
        }
    }

    // Fallback to last key (shouldn't normally reach here)
475
    vec![keys[keys.len() - 1]]
476
477
}

478
// Default implementation matching the Python _cost_function
479
480
481
482
483
484
485
486
487
488
489
490
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    /// Creates a selector from `kv_router_config`, using the default configuration when `None`.
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        kv_router_config.map_or_else(Self::default, |config| Self {
            kv_router_config: config,
        })
    }
}
491
492
493
494

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
495
        workers: &HashMap<WorkerId, ModelRuntimeConfig>,
496
        request: &SchedulingRequest,
497
        block_size: u32,
498
499
500
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

501
        if workers.is_empty() {
502
503
504
            return Err(KvSchedulerError::NoEndpoints);
        }

505
506
507
508
        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;

509
510
        let decode_blocks = &request.decode_blocks;
        let prefill_tokens = &request.prefill_tokens;
511

512
        let mut worker_logits = HashMap::new();
513

514
515
516
517
518
519
520
        // Use override if provided, otherwise use default config
        let overlap_weight = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.overlap_score_weight)
            .unwrap_or(self.kv_router_config.overlap_score_weight);

Yan Ru Pei's avatar
Yan Ru Pei committed
521
522
523
524
        // Calculate logits for each worker with dp_rank
        // Outer loop: iterate over all workers from runtime config
        // Inner loop: iterate over all dp_ranks for each worker
        for (worker_id, config) in workers.iter() {
525
            let data_parallel_size = config.data_parallel_size;
Yan Ru Pei's avatar
Yan Ru Pei committed
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555

            for dp_rank in 0..data_parallel_size {
                let worker = WorkerWithDpRank::new(*worker_id, dp_rank);

                // Get overlap for this worker (defaults to 0 if not in overlaps)
                let overlap = *overlaps.get(&worker).unwrap_or(&0);

                // this is the number of prefill tokens the worker would have if the request were scheduled there
                let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
                let potential_prefill_block = (prefill_token as f64) / (block_size as f64);

                // this is the number of decode blocks the worker would have if the request were scheduled there
                let decode_block = *decode_blocks
                    .get(&worker)
                    .unwrap_or(&(potential_prefill_block.floor() as usize))
                    as f64;

                // Calculate logit (lower is better)
                let logit = overlap_weight * potential_prefill_block + decode_block;

                worker_logits.insert(worker, logit);

                tracing::info!(
                    "Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
                     = {overlap_weight:.1} * prefill_blocks + decode_blocks \
                     = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
                    worker.worker_id,
                    worker.dp_rank
                );
            }
556
557
        }

558
        // Use softmax sampling to select worker(s)
559
560
561
562
563
564
        // Use override if provided, otherwise use default config
        let temperature = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.router_temperature)
            .unwrap_or(self.kv_router_config.router_temperature);
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
        let candidates = softmax_sample(&worker_logits, temperature);

        // If multiple candidates (tied), use tree size as tie-breaker
        // If tree sizes are also equal, min_by_key uses HashMap iteration order (pseudo-random)
        let best_worker = if candidates.len() > 1 {
            tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
            *candidates
                .iter()
                .min_by_key(|worker| {
                    request
                        .overlaps
                        .tree_sizes
                        .get(worker)
                        .copied()
                        .unwrap_or(0)
                })
                .expect("candidates should not be empty")
        } else {
            candidates[0]
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
586
587
588
        let best_logit = worker_logits[&best_worker];

        let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
589

Yan Ru Pei's avatar
Yan Ru Pei committed
590
        // this is a runtime config set on a per worker basis, not per dp-rank
591
        let total_blocks_info = workers
Yan Ru Pei's avatar
Yan Ru Pei committed
592
            .get(&best_worker.worker_id)
593
594
595
596
            .and_then(|cfg| cfg.total_kv_blocks)
            .map(|blocks| format!(", total blocks: {}", blocks))
            .unwrap_or_default();

597
598
599
600
601
602
603
        let tree_size = request
            .overlaps
            .tree_sizes
            .get(&best_worker)
            .copied()
            .unwrap_or(0);

604
        tracing::info!(
605
            "Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
Yan Ru Pei's avatar
Yan Ru Pei committed
606
607
            best_worker.worker_id,
            best_worker.dp_rank,
608
609
            best_logit,
            best_overlap,
610
            tree_size,
611
            total_blocks_info
612
        );
613
614

        Ok(WorkerSelectionResult {
Yan Ru Pei's avatar
Yan Ru Pei committed
615
            worker: best_worker,
616
            required_blocks: request_blocks as u64,
Yan Ru Pei's avatar
Yan Ru Pei committed
617
            overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
618
        })
619
620
    }
}
621
622
623
624
625

#[cfg(test)]
mod tests {
    use super::*;

    /// With a single candidate, `softmax_sample` must return it for any temperature or logit.
    #[test]
    fn test_softmax_sample_single_key() {
        let worker = WorkerWithDpRank::from_worker_id(42);

        // The logit value doesn't matter when there is only one key.
        let logits = HashMap::from([(worker, 0.5)]);
        for temperature in [0.1, 1.0, 10.0] {
            let result = softmax_sample(&logits, temperature);
            assert_eq!(result.len(), 1, "Should return exactly one worker");
            assert_eq!(result[0], worker, "Should return the only available worker");
        }

        // Very negative, very positive and zero logit values.
        for logit in [-100.0, 100.0, 0.0] {
            let logits = HashMap::from([(worker, logit)]);
            let result = softmax_sample(&logits, 1.0);
            assert_eq!(result.len(), 1);
            assert_eq!(result[0], worker);
        }
    }

    /// With temperature 0, `softmax_sample` returns every key holding the smallest logit.
    #[test]
    fn test_softmax_sample_zero_temperature() {
        let worker1 = WorkerWithDpRank::from_worker_id(1);
        let worker2 = WorkerWithDpRank::from_worker_id(2);
        let worker3 = WorkerWithDpRank::from_worker_id(3);
        let worker4 = WorkerWithDpRank::from_worker_id(4);

        // worker2 has the unique smallest logit.
        let logits = HashMap::from([
            (worker1, 5.0),
            (worker2, 3.0),
            (worker3, 7.0),
            (worker4, 3.5),
        ]);
        let result = softmax_sample(&logits, 0.0);
        assert_eq!(
            result.len(),
            1,
            "Should return one worker when there's no tie"
        );
        assert_eq!(
            result[0], worker2,
            "Should return worker with smallest logit when temperature is 0"
        );

        // worker2 and worker5 are tied for the smallest logit.
        let worker5 = WorkerWithDpRank::from_worker_id(5);
        let worker6 = WorkerWithDpRank::from_worker_id(6);
        let logits = HashMap::from([
            (worker1, 5.0),
            (worker2, 3.0),
            (worker5, 3.0),
            (worker6, 7.0),
        ]);
        let result = softmax_sample(&logits, 0.0);
        assert_eq!(
            result.len(),
            2,
            "Should return all workers with smallest logit when tied"
        );
        assert!(
            result.contains(&worker2) && result.contains(&worker5),
            "Should contain both tied workers"
        );

        // Negative values: worker20 has the smallest logit.
        let worker10 = WorkerWithDpRank::from_worker_id(10);
        let worker20 = WorkerWithDpRank::from_worker_id(20);
        let worker30 = WorkerWithDpRank::from_worker_id(30);
        let logits = HashMap::from([(worker10, -1.0), (worker20, -5.0), (worker30, 0.0)]);
        let result = softmax_sample(&logits, 0.0);
        assert_eq!(result.len(), 1);
        assert_eq!(
            result[0], worker20,
            "Should handle negative logits correctly"
        );
    }
}