scheduler.rs 25.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
use crate::discovery::RuntimeConfigWatch;
5
use crate::local_model::runtime_config::ModelRuntimeConfig;
6
use anyhow::Result;
7
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
8
use dynamo_runtime::traits::DistributedRuntimeProvider;
9
use rand::Rng;
10
use serde::{Deserialize, Serialize};
11
use std::collections::{HashMap, HashSet};
12
13
use std::sync::Arc;
use std::time::Duration;
14
15
#[cfg(feature = "bench")]
use std::time::Instant;
16
use tokio::sync::Notify;
17

18
use super::KvRouterConfig;
19
use super::RouterConfigOverride;
20
use super::WorkerSelector;
21
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
22
23
use super::queue::SchedulerQueue;
use super::sequence::{ActiveSequencesMultiWorker, SequenceError, SequenceRequest};
24

25
use dynamo_tokens::SequenceHash;
26

27
28
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
29
30
    pub worker_id: WorkerId,
    pub dp_rank: DpRank,
31
32
33
34
    pub potential_prefill_tokens: usize,
    pub potential_decode_blocks: usize,
}

35
36
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
37
    #[error("no endpoints available to route work")]
38
39
40
41
    NoEndpoints,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
42
43
44

    #[error("failed to initialize event publisher: {0}")]
    InitFailed(String),
45
46
}

47
48
#[derive(Debug)]
pub struct SchedulingResponse {
Yan Ru Pei's avatar
Yan Ru Pei committed
49
    pub best_worker: WorkerWithDpRank,
50
    pub overlap_blocks: u32,
51
52
}

53
pub struct SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
54
    pub maybe_request_id: Option<String>,
55
    pub token_seq: Option<Vec<SequenceHash>>,
56
    pub isl_tokens: usize,
57
    pub overlaps: OverlapScores,
Yan Ru Pei's avatar
Yan Ru Pei committed
58
59
    pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
    pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
60
61
    // Router config overrides for this specific request
    pub router_config_override: Option<RouterConfigOverride>,
62
63
    // Whether to update scheduler states (false for query_instance_id requests)
    pub update_states: bool,
64
65
    // LORA adapter name extracted from request.model field
    pub lora_name: Option<String>,
66
67
    /// Priority jump in seconds; decreases effective arrival time in the queue.
    pub priority_jump: f64,
68
69
    // Option to take it out to send the response without moving the struct
    resp_tx: Option<tokio::sync::oneshot::Sender<SchedulingResponse>>,
70
71
72
}

impl SchedulingRequest {
73
74
75
76
77
78
79
80
81
    pub fn respond(&mut self, response: SchedulingResponse) {
        // Changed to &mut self
        if let Some(tx) = self.resp_tx.take() {
            // Use take() to extract the sender
            if tx.send(response).is_err() {
                tracing::error!("failed to send response to requestor");
            }
        } else {
            tracing::error!("respond called multiple times on same request");
82
83
84
85
86
87
        }
    }
}

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
88
    slots: Arc<ActiveSequencesMultiWorker>,
89
    queue: Arc<SchedulerQueue>,
90
91
92
}

impl KvScheduler {
93
    #[allow(clippy::too_many_arguments)]
94
    pub async fn start(
95
        component: Component,
96
        block_size: u32,
97
        workers_with_configs: RuntimeConfigWatch,
98
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
99
        replica_sync: bool,
100
        router_id: u64,
101
        worker_type: &'static str,
102
        queue_threshold: Option<f64>,
103
    ) -> Result<Self, KvSchedulerError> {
104
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
105

106
107
108
109
        // Get initial workers from watch receiver.
        // Caller must ensure at least one worker is present (via wait_for).
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
110

111
112
113
114
115
116
117
        let slots = Arc::new(
            ActiveSequencesMultiWorker::new(
                component.clone(),
                block_size as usize,
                initial_workers,
                replica_sync,
                router_id,
118
                worker_type,
119
120
121
122
            )
            .await
            .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?,
        );
123

124
        // Spawn background task to sync slots when the watch value changes.
Yan Ru Pei's avatar
Yan Ru Pei committed
125
        let slots_monitor = slots.clone();
126
        let mut monitor_rx = workers_with_configs.clone();
127
        let monitor_cancel_token = component.drt().child_token();
128
        tokio::spawn(async move {
129
            tracing::trace!("KvScheduler workers monitoring task started");
130
            let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
131

132
            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
133
134
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
135
                        tracing::trace!("KvScheduler workers monitoring task shutting down");
136
137
                        break;
                    }
138
                    result = monitor_rx.changed() => {
139
140
141
142
143
                        if result.is_err() {
                            tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                            break;
                        }
                    }
144
145
                }

146
147
148
149
150
                let current_workers = monitor_rx.borrow_and_update().clone();

                if current_workers != last_workers {
                    slots_monitor.update_workers(current_workers.clone());
                    last_workers = current_workers;
Yan Ru Pei's avatar
Yan Ru Pei committed
151
152
153
154
155
                }
            }
        });

        let slots_clone = slots.clone();
156
        let scheduler_rx = workers_with_configs.clone();
Yan Ru Pei's avatar
Yan Ru Pei committed
157
158
159
        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();

160
161
162
163
164
165
166
167
168
169
        // Create queue with shared notify for waking the scheduler loop
        let ready_notify = Arc::new(Notify::new());
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
            ready_notify.clone(),
            queue_threshold,
        ));
        let queue_clone = queue.clone();

Yan Ru Pei's avatar
Yan Ru Pei committed
170
171
172
        // Background task to handle scheduling requests
        tokio::spawn(async move {
            let mut request_rx = request_rx;
173
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
Yan Ru Pei's avatar
Yan Ru Pei committed
174
175
176
            tracing::trace!("scheduler background task started");

            loop {
177
178
179
180
181
182
183
184
185
186
                // Use select! to wait on: new request, ready_notify, periodic recheck, or cancellation
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
187
                        };
188
189
190
191
192
193
194
195
196
197
198
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
                    }
                    _ = ready_notify.notified() => {
                        // Woken by update() after prefill_complete/free - just continue to drain ready queue
                    }
                    _ = recheck_interval.tick() => {
                        // Periodic recheck to prevent requests stuck in pending
                        queue_clone.update().await;
                    }
                }
199

200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
                // Drain ALL ready requests (each iteration uses fresh slot state)
                while let Some(mut request) = queue_clone.try_dequeue().await {
                    let (decode_blocks, prefill_tokens) = slots_clone
                        .potential_blocks_and_tokens(
                            request.token_seq.clone(),
                            request.isl_tokens,
                            request.overlaps.clone(),
                        )
                        .await;
                    request.decode_blocks = decode_blocks;
                    request.prefill_tokens = prefill_tokens;

                    // Read the current workers configuration from watch receiver
                    let workers: HashMap<WorkerId, ModelRuntimeConfig> =
                        scheduler_rx.borrow().clone();

                    match selector.select_worker(&workers, &request, block_size) {
                        Ok(selection) => {
                            let response = SchedulingResponse {
                                best_worker: selection.worker,
                                overlap_blocks: selection.overlap_blocks,
                            };
                            request.respond(response);

                            // Skip state update if not requested
                            if !request.update_states {
                                continue;
                            }

                            let Some(request_id) = request.maybe_request_id else {
                                tracing::error!(
                                    "No request_id provided to add_request to the slot tracker"
                                );
                                continue;
                            };

                            if let Err(e) = slots_clone
                                .add_request(SequenceRequest {
                                    request_id: request_id.clone(),
                                    token_sequence: request.token_seq,
                                    isl: request.isl_tokens,
                                    overlap: selection.overlap_blocks,
                                    expected_output_tokens: None,
                                    worker: selection.worker,
                                    lora_name: request.lora_name.clone(),
                                })
                                .await
                            {
                                tracing::warn!("Failed to add request {request_id}: {e}");
                            }
250
                        }
251
252
253
254
255
                        Err(KvSchedulerError::NoEndpoints) => {
                            tracing::warn!("no endpoints available, dropping request");
                        }
                        Err(e) => {
                            tracing::error!("error scheduling request: {:?}", e);
256
                        }
257
258
259
260
                    }
                }
            }

261
            tracing::trace!("background endpoint subscriber shutting down");
262
263
        });

264
265
266
267
268
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
269
270
    }

271
    #[allow(clippy::too_many_arguments)]
272
273
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
274
        maybe_request_id: Option<String>,
275
        isl_tokens: usize,
276
        token_seq: Option<Vec<SequenceHash>>,
277
        overlaps: OverlapScores,
278
        router_config_override: Option<&RouterConfigOverride>,
279
        update_states: bool,
280
        lora_name: Option<String>,
281
        priority_jump: f64,
Yan Ru Pei's avatar
Yan Ru Pei committed
282
    ) -> Result<WorkerWithDpRank, KvSchedulerError> {
283
284
285
        #[cfg(feature = "bench")]
        let start = Instant::now();

286
287
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
288
            maybe_request_id,
289
            token_seq,
290
            isl_tokens,
291
            overlaps,
292
293
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
294
            router_config_override: router_config_override.cloned(),
295
            update_states,
296
            lora_name,
297
298
            priority_jump,
            resp_tx: Some(resp_tx),
299
        };
300

301
302
303
304
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
305
306
307
308

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

309
        let response = resp_rx
310
311
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
312

313
314
315
316
317
318
319
320
321
322
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

Yan Ru Pei's avatar
Yan Ru Pei committed
323
        Ok(response.best_worker)
324
325
    }

326
327
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
328
329
    }

330
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
331
        self.slots
332
            .mark_prefill_completed(&request_id.to_string())
333
334
335
            .await?;
        self.queue.update().await;
        Ok(())
336
337
    }

338
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
339
340
341
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
        Ok(())
342
    }
343

344
345
346
347
348
349
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

350
351
352
353
354
355
356
357
358
359
    pub async fn add_output_block(
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
            .await
    }

360
361
    pub async fn get_potential_loads(
        &self,
362
        token_seq: Option<Vec<SequenceHash>>,
363
364
365
366
367
368
369
370
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
        let (decode_blocks, prefill_tokens) = self
            .slots
            .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps)
            .await;

Yan Ru Pei's avatar
Yan Ru Pei committed
371
372
373
374
        // Get all unique WorkerWithDpRank from both hashmaps
        let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
375
376
377

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
378
        for worker in workers {
379
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
380
381
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
382
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
383
                    .get(&worker)
384
385
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
386
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
387
388
389
390
391
            });
        }

        loads
    }
392
393
394
395
396

    /// Get active request counts grouped by LORA name
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
397
398
}

399
// Helper function for softmax sampling
400
401
402
403
404
// Returns a vec of workers: multiple if tied, single if sampled
fn softmax_sample(
    logits: &HashMap<WorkerWithDpRank, f64>,
    temperature: f64,
) -> Vec<WorkerWithDpRank> {
405
406
407
408
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

409
    // Guard: if temperature is 0, return all keys with the smallest logit value (ties)
410
411
412
413
414
415
416
    if temperature == 0.0 {
        // Find the minimum logit value
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        // Collect all keys with the minimum logit value (to handle ties)
        let min_keys: Vec<_> = logits
            .iter()
417
            .filter(|&(_, &v)| v == min_logit)
418
419
420
            .map(|(k, _)| *k)
            .collect();

421
        return min_keys;
422
423
    }

424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    // Find min and max for normalization
    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All values are the same, uniform probability
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
        // Normalize values
        let normalized: Vec<_> = values
            .iter()
            .map(|&v| {
                // Lower is better, so negate
440
441
                // Note we don't need to do actual min-max norm here, just off by an offset
                let norm = v / (max_val - min_val);
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
                -norm
            })
            .collect();

        // Apply temperature and softmax
        let scaled: Vec<_> = normalized.iter().map(|&v| v / temperature).collect();

        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
        let exp_values: Vec<_> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();

        let sum_exp: f64 = exp_values.iter().sum();
        exp_values.iter().map(|&v| v / sum_exp).collect()
    };

    // Sample from the probability distribution
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
464
            return vec![keys[i]];
465
466
467
468
        }
    }

    // Fallback to last key (shouldn't normally reach here)
469
    vec![keys[keys.len() - 1]]
470
471
}

472
// Default implementation matching the Python _cost_function
473
474
475
476
477
478
479
480
481
482
483
484
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        Self {
            kv_router_config: kv_router_config.unwrap_or_default(),
        }
    }
}
485
486
487
488

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
489
        workers: &HashMap<WorkerId, ModelRuntimeConfig>,
490
        request: &SchedulingRequest,
491
        block_size: u32,
492
493
494
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

495
        if workers.is_empty() {
496
497
498
            return Err(KvSchedulerError::NoEndpoints);
        }

499
500
501
502
        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;

503
504
        let decode_blocks = &request.decode_blocks;
        let prefill_tokens = &request.prefill_tokens;
505

506
        let mut worker_logits = HashMap::new();
507

508
509
510
511
512
513
514
        // Use override if provided, otherwise use default config
        let overlap_weight = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.overlap_score_weight)
            .unwrap_or(self.kv_router_config.overlap_score_weight);

Yan Ru Pei's avatar
Yan Ru Pei committed
515
516
517
518
        // Calculate logits for each worker with dp_rank
        // Outer loop: iterate over all workers from runtime config
        // Inner loop: iterate over all dp_ranks for each worker
        for (worker_id, config) in workers.iter() {
519
            let data_parallel_size = config.data_parallel_size;
Yan Ru Pei's avatar
Yan Ru Pei committed
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549

            for dp_rank in 0..data_parallel_size {
                let worker = WorkerWithDpRank::new(*worker_id, dp_rank);

                // Get overlap for this worker (defaults to 0 if not in overlaps)
                let overlap = *overlaps.get(&worker).unwrap_or(&0);

                // this is the number of prefill tokens the worker would have if the request were scheduled there
                let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
                let potential_prefill_block = (prefill_token as f64) / (block_size as f64);

                // this is the number of decode blocks the worker would have if the request were scheduled there
                let decode_block = *decode_blocks
                    .get(&worker)
                    .unwrap_or(&(potential_prefill_block.floor() as usize))
                    as f64;

                // Calculate logit (lower is better)
                let logit = overlap_weight * potential_prefill_block + decode_block;

                worker_logits.insert(worker, logit);

                tracing::info!(
                    "Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
                     = {overlap_weight:.1} * prefill_blocks + decode_blocks \
                     = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
                    worker.worker_id,
                    worker.dp_rank
                );
            }
550
551
        }

552
        // Use softmax sampling to select worker(s)
553
554
555
556
557
558
        // Use override if provided, otherwise use default config
        let temperature = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.router_temperature)
            .unwrap_or(self.kv_router_config.router_temperature);
559
560
561
        let candidates = softmax_sample(&worker_logits, temperature);

        // If multiple candidates (tied), use tree size as tie-breaker
562
        // If tree sizes are also equal, use random selection to avoid bias
563
564
        let best_worker = if candidates.len() > 1 {
            tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
565
            let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = candidates
566
                .iter()
567
568
569
570
571
572
573
574
575
                .map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
                .collect();

            if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
                let idx = rand::rng().random_range(0..candidates.len());
                candidates[idx]
            } else {
                *tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap().1
            }
576
577
578
579
        } else {
            candidates[0]
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
580
581
582
        let best_logit = worker_logits[&best_worker];

        let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
583

Yan Ru Pei's avatar
Yan Ru Pei committed
584
        // this is a runtime config set on a per worker basis, not per dp-rank
585
        let total_blocks_info = workers
Yan Ru Pei's avatar
Yan Ru Pei committed
586
            .get(&best_worker.worker_id)
587
588
589
590
            .and_then(|cfg| cfg.total_kv_blocks)
            .map(|blocks| format!(", total blocks: {}", blocks))
            .unwrap_or_default();

591
592
593
594
595
596
597
        let tree_size = request
            .overlaps
            .tree_sizes
            .get(&best_worker)
            .copied()
            .unwrap_or(0);

598
        tracing::info!(
599
            "Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
Yan Ru Pei's avatar
Yan Ru Pei committed
600
601
            best_worker.worker_id,
            best_worker.dp_rank,
602
603
            best_logit,
            best_overlap,
604
            tree_size,
605
            total_blocks_info
606
        );
607
608

        Ok(WorkerSelectionResult {
Yan Ru Pei's avatar
Yan Ru Pei committed
609
            worker: best_worker,
610
            required_blocks: request_blocks as u64,
Yan Ru Pei's avatar
Yan Ru Pei committed
611
            overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
612
        })
613
614
    }
}
615
616
617
618
619

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a logits map from `(worker, logit)` pairs.
    fn logits_of(entries: &[(WorkerWithDpRank, f64)]) -> HashMap<WorkerWithDpRank, f64> {
        entries.iter().copied().collect()
    }

    #[test]
    fn test_softmax_sample_single_key() {
        // With a single candidate, softmax_sample must always return it,
        // whatever the temperature or the logit value.
        let worker = WorkerWithDpRank::from_worker_id(42);

        for temperature in [0.1, 1.0, 10.0] {
            let result = softmax_sample(&logits_of(&[(worker, 0.5)]), temperature);
            assert_eq!(result.len(), 1, "Should return exactly one worker");
            assert_eq!(result[0], worker, "Should return the only available worker");
        }

        // Very negative, very positive, and zero logits.
        for logit in [-100.0, 100.0, 0.0] {
            let result = softmax_sample(&logits_of(&[(worker, logit)]), 1.0);
            assert_eq!(result.len(), 1);
            assert_eq!(result[0], worker);
        }
    }

    #[test]
    fn test_softmax_sample_zero_temperature() {
        // At temperature 0 every key sharing the smallest logit is returned.
        let worker1 = WorkerWithDpRank::from_worker_id(1);
        let worker2 = WorkerWithDpRank::from_worker_id(2);
        let worker3 = WorkerWithDpRank::from_worker_id(3);
        let worker4 = WorkerWithDpRank::from_worker_id(4);
        let worker5 = WorkerWithDpRank::from_worker_id(5);
        let worker6 = WorkerWithDpRank::from_worker_id(6);

        // Unique minimum: worker2.
        let unique = logits_of(&[(worker1, 5.0), (worker2, 3.0), (worker3, 7.0), (worker4, 3.5)]);
        let result = softmax_sample(&unique, 0.0);
        assert_eq!(
            result.len(),
            1,
            "Should return one worker when there's no tie"
        );
        assert_eq!(
            result[0], worker2,
            "Should return worker with smallest logit when temperature is 0"
        );

        // Tied minimum: worker2 and worker5.
        let tied = logits_of(&[(worker1, 5.0), (worker2, 3.0), (worker5, 3.0), (worker6, 7.0)]);
        let result = softmax_sample(&tied, 0.0);
        assert_eq!(
            result.len(),
            2,
            "Should return all workers with smallest logit when tied"
        );
        assert!(
            result.contains(&worker2) && result.contains(&worker5),
            "Should contain both tied workers"
        );

        // Negative logits: worker20 has the smallest.
        let worker10 = WorkerWithDpRank::from_worker_id(10);
        let worker20 = WorkerWithDpRank::from_worker_id(20);
        let worker30 = WorkerWithDpRank::from_worker_id(30);
        let negative = logits_of(&[(worker10, -1.0), (worker20, -5.0), (worker30, 0.0)]);
        let result = softmax_sample(&negative, 0.0);
        assert_eq!(result.len(), 1);
        assert_eq!(
            result[0], worker20,
            "Should handle negative logits correctly"
        );
    }
}