scheduler.rs 22.3 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
// SPDX-License-Identifier: Apache-2.0
3

4
5
6
7
8
9
use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::queue::SchedulerQueue;
use super::sequence::{ActiveSequencesMultiWorker, SequenceError, SequenceRequest};
10
use crate::discovery::RuntimeConfigWatch;
11
use crate::local_model::runtime_config::ModelRuntimeConfig;
12
use anyhow::Result;
13
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
14
use dynamo_runtime::traits::DistributedRuntimeProvider;
15
use rand::Rng;
16
use serde::{Deserialize, Serialize};
17
use std::collections::{HashMap, HashSet};
18
19
use std::sync::Arc;
use std::time::Duration;
20
21
#[cfg(feature = "bench")]
use std::time::Instant;
22

23
use dynamo_tokens::SequenceHash;
24

25
26
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
27
28
    pub worker_id: WorkerId,
    pub dp_rank: DpRank,
29
30
31
32
    pub potential_prefill_tokens: usize,
    pub potential_decode_blocks: usize,
}

33
34
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
35
    #[error("no endpoints available to route work")]
36
37
38
39
    NoEndpoints,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
40
41
42

    #[error("failed to initialize event publisher: {0}")]
    InitFailed(String),
43
44
}

45
46
#[derive(Debug)]
pub struct SchedulingResponse {
Yan Ru Pei's avatar
Yan Ru Pei committed
47
    pub best_worker: WorkerWithDpRank,
48
    pub overlap_blocks: u32,
49
50
}

51
pub struct SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
52
    pub maybe_request_id: Option<String>,
53
    pub token_seq: Option<Vec<SequenceHash>>,
54
    pub isl_tokens: usize,
55
    pub overlaps: OverlapScores,
Yan Ru Pei's avatar
Yan Ru Pei committed
56
57
    pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
    pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
58
59
    // Router config overrides for this specific request
    pub router_config_override: Option<RouterConfigOverride>,
60
61
    // Whether to update scheduler states (false for query_instance_id requests)
    pub update_states: bool,
62
63
    // LORA adapter name extracted from request.model field
    pub lora_name: Option<String>,
64
65
    /// Priority jump in seconds; decreases effective arrival time in the queue.
    pub priority_jump: f64,
66
67
    /// Optional set of allowed worker IDs to restrict routing decisions (EPP).
    pub allowed_worker_ids: Option<HashSet<WorkerId>>,
68
    resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>,
69
70
71
}

impl SchedulingRequest {
72
73
    pub fn respond(&mut self, result: Result<SchedulingResponse, KvSchedulerError>) {
        let Some(tx) = self.resp_tx.take() else {
74
            tracing::error!("respond called multiple times on same request");
75
76
77
78
            return;
        };
        if tx.send(result).is_err() {
            tracing::error!("failed to send response to requestor");
79
80
81
82
83
84
        }
    }
}

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
85
    slots: Arc<ActiveSequencesMultiWorker>,
86
    queue: Arc<SchedulerQueue>,
87
88
89
90
}

impl KvScheduler {
    pub async fn start(
91
        component: Component,
92
        block_size: u32,
93
        workers_with_configs: RuntimeConfigWatch,
94
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
95
        kv_router_config: &KvRouterConfig,
96
        worker_type: &'static str,
97
    ) -> Result<Self, KvSchedulerError> {
98
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
99

100
101
102
103
        // Get initial workers from watch receiver.
        // Caller must ensure at least one worker is present (via wait_for).
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
104

105
        let router_id = component.drt().discovery().instance_id();
106
107
108
109
110
        let slots = Arc::new(
            ActiveSequencesMultiWorker::new(
                component.clone(),
                block_size as usize,
                initial_workers,
111
                kv_router_config.router_replica_sync,
112
                router_id,
113
                worker_type,
114
115
116
117
            )
            .await
            .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?,
        );
118

119
        // Spawn background task to sync slots when the watch value changes.
Yan Ru Pei's avatar
Yan Ru Pei committed
120
        let slots_monitor = slots.clone();
121
        let mut monitor_rx = workers_with_configs.clone();
122
        let monitor_cancel_token = component.drt().child_token();
123
        tokio::spawn(async move {
124
            tracing::trace!("KvScheduler workers monitoring task started");
125
            let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
126

127
            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
128
129
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
130
                        tracing::trace!("KvScheduler workers monitoring task shutting down");
131
132
                        break;
                    }
133
                    result = monitor_rx.changed() => {
134
135
136
137
138
                        if result.is_err() {
                            tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                            break;
                        }
                    }
139
140
                }

141
142
143
144
145
                let current_workers = monitor_rx.borrow_and_update().clone();

                if current_workers != last_workers {
                    slots_monitor.update_workers(current_workers.clone());
                    last_workers = current_workers;
Yan Ru Pei's avatar
Yan Ru Pei committed
146
147
148
149
150
151
152
                }
            }
        });

        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();

153
154
155
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
156
            kv_router_config.router_queue_threshold,
157
158
            block_size,
            selector,
159
160
161
        ));
        let queue_clone = queue.clone();

162
        // Background task: receive requests and periodically recheck pending
Yan Ru Pei's avatar
Yan Ru Pei committed
163
164
        tokio::spawn(async move {
            let mut request_rx = request_rx;
165
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
Yan Ru Pei's avatar
Yan Ru Pei committed
166
167
168
            tracing::trace!("scheduler background task started");

            loop {
169
170
171
172
173
174
175
176
177
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
178
                        };
179
180
181
182
183
184
185
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
                    }
                    _ = recheck_interval.tick() => {
                        queue_clone.update().await;
                    }
                }
186
187
            }

188
            tracing::trace!("background endpoint subscriber shutting down");
189
190
        });

191
192
193
194
195
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
196
197
    }

198
    #[allow(clippy::too_many_arguments)]
199
200
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
201
        maybe_request_id: Option<String>,
202
        isl_tokens: usize,
203
        token_seq: Option<Vec<SequenceHash>>,
204
        overlaps: OverlapScores,
205
        router_config_override: Option<&RouterConfigOverride>,
206
        update_states: bool,
207
        lora_name: Option<String>,
208
        priority_jump: f64,
209
210
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
211
212
213
        #[cfg(feature = "bench")]
        let start = Instant::now();

214
215
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
216
            maybe_request_id,
217
            token_seq,
218
            isl_tokens,
219
            overlaps,
220
221
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
222
            router_config_override: router_config_override.cloned(),
223
            update_states,
224
            lora_name,
225
            priority_jump,
226
            allowed_worker_ids,
227
            resp_tx: Some(resp_tx),
228
        };
229

230
231
232
233
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
234
235
236
237

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

238
        let response = resp_rx
239
            .await
240
            .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
241

242
243
244
245
246
247
248
249
250
251
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

252
        Ok(response)
253
254
    }

255
256
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
257
258
    }

259
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
260
        self.slots
261
            .mark_prefill_completed(&request_id.to_string())
262
263
264
            .await?;
        self.queue.update().await;
        Ok(())
265
266
    }

267
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
268
269
270
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
        Ok(())
271
    }
272

273
274
275
276
277
278
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

279
    pub fn add_output_block(
280
281
282
283
284
285
286
287
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }

288
    pub fn get_potential_loads(
289
        &self,
290
        token_seq: Option<Vec<SequenceHash>>,
291
292
293
294
295
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
        let (decode_blocks, prefill_tokens) = self
            .slots
296
            .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps);
297

Yan Ru Pei's avatar
Yan Ru Pei committed
298
299
300
301
        // Get all unique WorkerWithDpRank from both hashmaps
        let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
302
303
304

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
305
        for worker in workers {
306
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
307
308
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
309
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
310
                    .get(&worker)
311
312
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
313
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
314
315
316
317
318
            });
        }

        loads
    }
319
320
321
322
323

    /// Get active request counts grouped by LORA name
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
324
325
}

326
// Helper function for softmax sampling
327
328
329
330
331
// Returns a vec of workers: multiple if tied, single if sampled
fn softmax_sample(
    logits: &HashMap<WorkerWithDpRank, f64>,
    temperature: f64,
) -> Vec<WorkerWithDpRank> {
332
333
334
335
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

336
    // Guard: if temperature is 0, return all keys with the smallest logit value (ties)
337
338
339
340
341
342
343
    if temperature == 0.0 {
        // Find the minimum logit value
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        // Collect all keys with the minimum logit value (to handle ties)
        let min_keys: Vec<_> = logits
            .iter()
344
            .filter(|&(_, &v)| v == min_logit)
345
346
347
            .map(|(k, _)| *k)
            .collect();

348
        return min_keys;
349
350
    }

351
352
353
354
355
356
357
358
359
360
361
    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    // Find min and max for normalization
    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All values are the same, uniform probability
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
362
363
364
        // Fused normalize → negate → scale → exp, then normalize probabilities
        let range = max_val - min_val;
        let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
365
        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
366
367
368
369
        let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
        let sum: f64 = probs.iter().sum();
        probs.iter_mut().for_each(|p| *p /= sum);
        probs
370
371
372
373
374
375
376
377
378
379
    };

    // Sample from the probability distribution
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
380
            return vec![keys[i]];
381
382
383
384
        }
    }

    // Fallback to last key (shouldn't normally reach here)
385
    vec![keys[keys.len() - 1]]
386
387
}

388
// Default implementation matching the Python _cost_function
389
390
391
392
393
394
395
396
397
398
399
400
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        Self {
            kv_router_config: kv_router_config.unwrap_or_default(),
        }
    }
}
401
402
403
404

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
405
        workers: &HashMap<WorkerId, ModelRuntimeConfig>,
406
        request: &SchedulingRequest,
407
        block_size: u32,
408
409
410
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

411
412
413
414
415
        let allowed_ids = request.allowed_worker_ids.as_ref();

        if allowed_ids.map_or(workers.is_empty(), |ids| {
            !workers.keys().any(|wid| ids.contains(wid))
        }) {
416
417
418
            return Err(KvSchedulerError::NoEndpoints);
        }

419
420
421
422
        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;

423
424
        let decode_blocks = &request.decode_blocks;
        let prefill_tokens = &request.prefill_tokens;
425

426
        let mut worker_logits = HashMap::new();
427

428
429
430
431
432
433
434
        // Use override if provided, otherwise use default config
        let overlap_weight = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.overlap_score_weight)
            .unwrap_or(self.kv_router_config.overlap_score_weight);

435
436
437
438
        for (worker_id, config) in workers
            .iter()
            .filter(|(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid)))
        {
439
            let data_parallel_size = config.data_parallel_size;
Yan Ru Pei's avatar
Yan Ru Pei committed
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469

            for dp_rank in 0..data_parallel_size {
                let worker = WorkerWithDpRank::new(*worker_id, dp_rank);

                // Get overlap for this worker (defaults to 0 if not in overlaps)
                let overlap = *overlaps.get(&worker).unwrap_or(&0);

                // this is the number of prefill tokens the worker would have if the request were scheduled there
                let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
                let potential_prefill_block = (prefill_token as f64) / (block_size as f64);

                // this is the number of decode blocks the worker would have if the request were scheduled there
                let decode_block = *decode_blocks
                    .get(&worker)
                    .unwrap_or(&(potential_prefill_block.floor() as usize))
                    as f64;

                // Calculate logit (lower is better)
                let logit = overlap_weight * potential_prefill_block + decode_block;

                worker_logits.insert(worker, logit);

                tracing::info!(
                    "Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
                     = {overlap_weight:.1} * prefill_blocks + decode_blocks \
                     = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
                    worker.worker_id,
                    worker.dp_rank
                );
            }
470
471
        }

472
        // Use softmax sampling to select worker(s)
473
474
475
476
477
478
        // Use override if provided, otherwise use default config
        let temperature = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.router_temperature)
            .unwrap_or(self.kv_router_config.router_temperature);
479
480
481
        let candidates = softmax_sample(&worker_logits, temperature);

        // If multiple candidates (tied), use tree size as tie-breaker
482
        // If tree sizes are also equal, use random selection to avoid bias
483
484
        let best_worker = if candidates.len() > 1 {
            tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
485
            let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = candidates
486
                .iter()
487
488
489
490
491
492
493
494
495
                .map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
                .collect();

            if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
                let idx = rand::rng().random_range(0..candidates.len());
                candidates[idx]
            } else {
                *tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap().1
            }
496
497
498
499
        } else {
            candidates[0]
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
500
501
502
        let best_logit = worker_logits[&best_worker];

        let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
503

Yan Ru Pei's avatar
Yan Ru Pei committed
504
        // this is a runtime config set on a per worker basis, not per dp-rank
505
        let total_blocks_info = workers
Yan Ru Pei's avatar
Yan Ru Pei committed
506
            .get(&best_worker.worker_id)
507
508
509
510
            .and_then(|cfg| cfg.total_kv_blocks)
            .map(|blocks| format!(", total blocks: {}", blocks))
            .unwrap_or_default();

511
512
513
514
515
516
517
        let tree_size = request
            .overlaps
            .tree_sizes
            .get(&best_worker)
            .copied()
            .unwrap_or(0);

518
        tracing::info!(
519
            "Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
Yan Ru Pei's avatar
Yan Ru Pei committed
520
521
            best_worker.worker_id,
            best_worker.dp_rank,
522
523
            best_logit,
            best_overlap,
524
            tree_size,
525
            total_blocks_info
526
        );
527
528

        Ok(WorkerSelectionResult {
Yan Ru Pei's avatar
Yan Ru Pei committed
529
            worker: best_worker,
530
            required_blocks: request_blocks as u64,
Yan Ru Pei's avatar
Yan Ru Pei committed
531
            overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
532
        })
533
534
    }
}
535
536
537
538
539

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a logits map keyed by `WorkerWithDpRank::from_worker_id(id)`.
    fn logits_from(pairs: &[(WorkerId, f64)]) -> HashMap<WorkerWithDpRank, f64> {
        pairs
            .iter()
            .map(|&(id, logit)| (WorkerWithDpRank::from_worker_id(id), logit))
            .collect()
    }

    #[test]
    fn test_softmax_sample_single_key() {
        // A single-entry map must always yield that entry, regardless of the
        // temperature or the magnitude/sign of its logit.
        let only = WorkerWithDpRank::from_worker_id(42);

        for &temperature in &[0.1, 1.0, 10.0] {
            let picked = softmax_sample(&logits_from(&[(42, 0.5)]), temperature);
            assert_eq!(picked.len(), 1, "Should return exactly one worker");
            assert_eq!(picked[0], only, "Should return the only available worker");
        }

        for &logit in &[-100.0, 100.0, 0.0] {
            let picked = softmax_sample(&logits_from(&[(42, logit)]), 1.0);
            assert_eq!(picked.len(), 1);
            assert_eq!(picked[0], only);
        }
    }

    #[test]
    fn test_softmax_sample_zero_temperature() {
        // Temperature 0 is greedy: every key sharing the smallest logit is returned.
        let unique_min = softmax_sample(
            &logits_from(&[(1, 5.0), (2, 3.0), (3, 7.0), (4, 3.5)]),
            0.0,
        );
        assert_eq!(
            unique_min.len(),
            1,
            "Should return one worker when there's no tie"
        );
        assert_eq!(
            unique_min[0],
            WorkerWithDpRank::from_worker_id(2),
            "Should return worker with smallest logit when temperature is 0"
        );

        let tied_min = softmax_sample(
            &logits_from(&[(1, 5.0), (2, 3.0), (5, 3.0), (6, 7.0)]),
            0.0,
        );
        assert_eq!(
            tied_min.len(),
            2,
            "Should return all workers with smallest logit when tied"
        );
        assert!(
            tied_min.contains(&WorkerWithDpRank::from_worker_id(2))
                && tied_min.contains(&WorkerWithDpRank::from_worker_id(5)),
            "Should contain both tied workers"
        );

        let negative = softmax_sample(&logits_from(&[(10, -1.0), (20, -5.0), (30, 0.0)]), 0.0);
        assert_eq!(negative.len(), 1);
        assert_eq!(
            negative[0],
            WorkerWithDpRank::from_worker_id(20),
            "Should handle negative logits correctly"
        );
    }
}