"fern/pages/guides/request-plane.md" did not exist on "18b64e90b6919f3958c5578833c790bdad5b514f"
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::KvRouterConfig;
use super::RouterConfigOverride;
use super::WorkerSelector;
use super::protocols::{DpRank, OverlapScores, WorkerId, WorkerSelectionResult, WorkerWithDpRank};
use super::queue::SchedulerQueue;
9
10
11
use super::sequence::{
    ActiveSequencesMulti, SequenceError, SequenceRequest, create_multi_worker_sequences,
};
12
use crate::discovery::RuntimeConfigWatch;
13
use crate::local_model::runtime_config::ModelRuntimeConfig;
14
use anyhow::Result;
15
use dynamo_runtime::component::Component;
Yan Ru Pei's avatar
Yan Ru Pei committed
16
use dynamo_runtime::traits::DistributedRuntimeProvider;
17
use rand::Rng;
18
use serde::{Deserialize, Serialize};
19
use std::collections::{HashMap, HashSet};
20
21
use std::sync::Arc;
use std::time::Duration;
22
23
#[cfg(feature = "bench")]
use std::time::Instant;
24

25
use dynamo_tokens::SequenceHash;
26

27
28
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
29
30
    pub worker_id: WorkerId,
    pub dp_rank: DpRank,
31
32
33
34
    pub potential_prefill_tokens: usize,
    pub potential_decode_blocks: usize,
}

35
36
#[derive(Debug, thiserror::Error)]
pub enum KvSchedulerError {
37
    #[error("no endpoints available to route work")]
38
39
40
41
    NoEndpoints,

    #[error("endpoint subscriber shutdown")]
    SubscriberShutdown,
42
43
44

    #[error("failed to initialize event publisher: {0}")]
    InitFailed(String),
45
46
}

47
48
#[derive(Debug)]
pub struct SchedulingResponse {
Yan Ru Pei's avatar
Yan Ru Pei committed
49
    pub best_worker: WorkerWithDpRank,
50
    pub overlap_blocks: u32,
51
52
}

53
pub struct SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
54
    pub maybe_request_id: Option<String>,
55
    pub token_seq: Option<Vec<SequenceHash>>,
56
    pub isl_tokens: usize,
57
    pub overlaps: OverlapScores,
Yan Ru Pei's avatar
Yan Ru Pei committed
58
59
    pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
    pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
60
61
    // Router config overrides for this specific request
    pub router_config_override: Option<RouterConfigOverride>,
62
63
    // Whether to update scheduler states (false for query_instance_id requests)
    pub update_states: bool,
64
65
    // LORA adapter name extracted from request.model field
    pub lora_name: Option<String>,
66
67
    /// Priority jump in seconds; decreases effective arrival time in the queue.
    pub priority_jump: f64,
68
69
    /// Optional set of allowed worker IDs to restrict routing decisions (EPP).
    pub allowed_worker_ids: Option<HashSet<WorkerId>>,
70
    resp_tx: Option<tokio::sync::oneshot::Sender<Result<SchedulingResponse, KvSchedulerError>>>,
71
72
73
}

impl SchedulingRequest {
74
75
    pub fn respond(&mut self, result: Result<SchedulingResponse, KvSchedulerError>) {
        let Some(tx) = self.resp_tx.take() else {
76
            tracing::error!("respond called multiple times on same request");
77
78
79
80
            return;
        };
        if tx.send(result).is_err() {
            tracing::error!("failed to send response to requestor");
81
82
83
84
85
86
        }
    }
}

pub struct KvScheduler {
    request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
87
    slots: Arc<ActiveSequencesMulti>,
88
    queue: Arc<SchedulerQueue>,
89
90
91
92
}

impl KvScheduler {
    pub async fn start(
93
        component: Component,
94
        block_size: u32,
95
        workers_with_configs: RuntimeConfigWatch,
96
        selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
97
        kv_router_config: &KvRouterConfig,
98
        worker_type: &'static str,
99
    ) -> Result<Self, KvSchedulerError> {
100
        let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
101

102
103
104
105
        // Get initial workers from watch receiver.
        // Caller must ensure at least one worker is present (via wait_for).
        let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
            workers_with_configs.borrow().clone();
106

107
        let router_id = component.drt().discovery().instance_id();
108
109
110
111
112
113
114
115
116
117
        let slots = create_multi_worker_sequences(
            component.clone(),
            block_size as usize,
            initial_workers,
            kv_router_config.router_replica_sync,
            router_id,
            worker_type,
        )
        .await
        .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
118

119
        // Spawn background task to sync slots when the watch value changes.
Yan Ru Pei's avatar
Yan Ru Pei committed
120
        let slots_monitor = slots.clone();
121
        let mut monitor_rx = workers_with_configs.clone();
122
        let monitor_cancel_token = component.drt().child_token();
123
        tokio::spawn(async move {
124
            tracing::trace!("KvScheduler workers monitoring task started");
125
            let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
126

127
            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
128
129
                tokio::select! {
                    _ = monitor_cancel_token.cancelled() => {
130
                        tracing::trace!("KvScheduler workers monitoring task shutting down");
131
132
                        break;
                    }
133
                    result = monitor_rx.changed() => {
134
135
136
137
138
                        if result.is_err() {
                            tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
                            break;
                        }
                    }
139
140
                }

141
142
143
                let current_workers = monitor_rx.borrow_and_update().clone();

                if current_workers != last_workers {
144
145
146
147
148
                    let dp_sizes: HashMap<u64, u32> = current_workers
                        .iter()
                        .map(|(&id, c)| (id, c.data_parallel_size))
                        .collect();
                    slots_monitor.update_workers(dp_sizes);
149
                    last_workers = current_workers;
Yan Ru Pei's avatar
Yan Ru Pei committed
150
151
152
153
154
155
156
                }
            }
        });

        let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
        let scheduler_cancel_token = component.drt().primary_token();

157
158
159
        let queue = Arc::new(SchedulerQueue::new(
            slots.clone(),
            workers_with_configs.clone(),
160
            kv_router_config.router_queue_threshold,
161
162
            block_size,
            selector,
163
164
165
        ));
        let queue_clone = queue.clone();

166
        // Background task: receive requests and periodically recheck pending
Yan Ru Pei's avatar
Yan Ru Pei committed
167
168
        tokio::spawn(async move {
            let mut request_rx = request_rx;
169
            let mut recheck_interval = tokio::time::interval(Duration::from_secs(60));
Yan Ru Pei's avatar
Yan Ru Pei committed
170
171
172
            tracing::trace!("scheduler background task started");

            loop {
173
174
175
176
177
178
179
180
181
                tokio::select! {
                    _ = scheduler_cancel_token.cancelled() => {
                        tracing::trace!("scheduler background task shutting down");
                        break;
                    }
                    request = request_rx.recv() => {
                        let Some(request) = request else {
                            tracing::warn!("scheduler shutdown");
                            break;
182
                        };
183
184
185
186
187
188
189
                        tracing::trace!("received request to be scheduled");
                        queue_clone.enqueue(request).await;
                    }
                    _ = recheck_interval.tick() => {
                        queue_clone.update().await;
                    }
                }
190
191
            }

192
            tracing::trace!("background endpoint subscriber shutting down");
193
194
        });

195
196
197
198
199
        Ok(KvScheduler {
            request_tx,
            slots,
            queue,
        })
200
201
    }

202
    #[allow(clippy::too_many_arguments)]
203
204
    pub async fn schedule(
        &self,
Yan Ru Pei's avatar
Yan Ru Pei committed
205
        maybe_request_id: Option<String>,
206
        isl_tokens: usize,
207
        token_seq: Option<Vec<SequenceHash>>,
208
        overlaps: OverlapScores,
209
        router_config_override: Option<&RouterConfigOverride>,
210
        update_states: bool,
211
        lora_name: Option<String>,
212
        priority_jump: f64,
213
214
        allowed_worker_ids: Option<HashSet<WorkerId>>,
    ) -> Result<SchedulingResponse, KvSchedulerError> {
215
216
217
        #[cfg(feature = "bench")]
        let start = Instant::now();

218
219
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        let request = SchedulingRequest {
Yan Ru Pei's avatar
Yan Ru Pei committed
220
            maybe_request_id,
221
            token_seq,
222
            isl_tokens,
223
            overlaps,
224
225
            decode_blocks: HashMap::new(),
            prefill_tokens: HashMap::new(),
226
            router_config_override: router_config_override.cloned(),
227
            update_states,
228
            lora_name,
229
            priority_jump,
230
            allowed_worker_ids,
231
            resp_tx: Some(resp_tx),
232
        };
233

234
235
236
237
        self.request_tx
            .send(request)
            .await
            .map_err(|_| KvSchedulerError::SubscriberShutdown)?;
238
239
240
241

        #[cfg(feature = "bench")]
        let send_elapsed = start.elapsed();

242
        let response = resp_rx
243
            .await
244
            .map_err(|_| KvSchedulerError::SubscriberShutdown)??;
245

246
247
248
249
250
251
252
253
254
255
        #[cfg(feature = "bench")]
        let total_elapsed = start.elapsed();
        #[cfg(feature = "bench")]
        tracing::info!(
            isl_tokens,
            send_us = send_elapsed.as_micros() as u64,
            total_us = total_elapsed.as_micros() as u64,
            "scheduler.schedule completed"
        );

256
        Ok(response)
257
258
    }

259
260
    pub async fn add_request(&self, req: SequenceRequest) -> Result<(), SequenceError> {
        self.slots.add_request(req).await
261
262
    }

263
    pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> {
264
        self.slots
265
            .mark_prefill_completed(&request_id.to_string())
266
267
268
            .await?;
        self.queue.update().await;
        Ok(())
269
270
    }

271
    pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> {
272
273
274
        self.slots.free(&request_id.to_string()).await?;
        self.queue.update().await;
        Ok(())
275
    }
276

277
278
279
280
281
282
    /// Get the worker type for this scheduler ("prefill" or "decode").
    /// Used for Prometheus metric labeling.
    pub fn worker_type(&self) -> &'static str {
        self.slots.worker_type()
    }

283
    pub fn add_output_block(
284
285
286
287
288
289
290
291
        &self,
        request_id: &str,
        decay_fraction: Option<f64>,
    ) -> Result<(), SequenceError> {
        self.slots
            .add_output_block(&request_id.to_string(), decay_fraction)
    }

292
    pub fn get_potential_loads(
293
        &self,
294
        token_seq: Option<Vec<SequenceHash>>,
295
296
297
298
299
        isl_tokens: usize,
        overlaps: OverlapScores,
    ) -> Vec<PotentialLoad> {
        let (decode_blocks, prefill_tokens) = self
            .slots
300
            .potential_blocks_and_tokens(token_seq, isl_tokens, overlaps);
301

Yan Ru Pei's avatar
Yan Ru Pei committed
302
303
304
305
        // Get all unique WorkerWithDpRank from both hashmaps
        let mut workers: HashSet<WorkerWithDpRank> = HashSet::new();
        workers.extend(decode_blocks.keys().copied());
        workers.extend(prefill_tokens.keys().copied());
306
307
308

        // Create PotentialLoad for each worker
        let mut loads = Vec::new();
Yan Ru Pei's avatar
Yan Ru Pei committed
309
        for worker in workers {
310
            loads.push(PotentialLoad {
Yan Ru Pei's avatar
Yan Ru Pei committed
311
312
                worker_id: worker.worker_id,
                dp_rank: worker.dp_rank,
313
                potential_prefill_tokens: prefill_tokens
Yan Ru Pei's avatar
Yan Ru Pei committed
314
                    .get(&worker)
315
316
                    .copied()
                    .unwrap_or(isl_tokens),
Yan Ru Pei's avatar
Yan Ru Pei committed
317
                potential_decode_blocks: decode_blocks.get(&worker).copied().unwrap_or(0),
318
319
320
321
322
            });
        }

        loads
    }
323
324
325
326
327

    /// Get active request counts grouped by LORA name
    pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
        self.slots.get_active_lora_counts()
    }
328
329
}

330
// Helper function for softmax sampling
331
332
333
334
335
// Returns a vec of workers: multiple if tied, single if sampled
fn softmax_sample(
    logits: &HashMap<WorkerWithDpRank, f64>,
    temperature: f64,
) -> Vec<WorkerWithDpRank> {
336
337
338
339
    if logits.is_empty() {
        panic!("Empty logits for softmax sampling");
    }

340
    // Guard: if temperature is 0, return all keys with the smallest logit value (ties)
341
342
343
344
345
346
347
    if temperature == 0.0 {
        // Find the minimum logit value
        let min_logit = logits.values().fold(f64::INFINITY, |a, &b| a.min(b));

        // Collect all keys with the minimum logit value (to handle ties)
        let min_keys: Vec<_> = logits
            .iter()
348
            .filter(|&(_, &v)| v == min_logit)
349
350
351
            .map(|(k, _)| *k)
            .collect();

352
        return min_keys;
353
354
    }

355
356
357
358
359
360
361
362
363
364
365
    let keys: Vec<_> = logits.keys().copied().collect();
    let values: Vec<_> = logits.values().copied().collect();

    // Find min and max for normalization
    let min_val = values.iter().fold(f64::INFINITY, |a, &b| a.min(b));
    let max_val = values.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));

    let probabilities = if min_val == max_val {
        // All values are the same, uniform probability
        vec![1.0 / keys.len() as f64; keys.len()]
    } else {
366
367
368
        // Fused normalize → negate → scale → exp, then normalize probabilities
        let range = max_val - min_val;
        let scaled: Vec<f64> = values.iter().map(|&v| -(v / range) / temperature).collect();
369
        let max_scaled = scaled.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
370
371
372
373
        let mut probs: Vec<f64> = scaled.iter().map(|&v| (v - max_scaled).exp()).collect();
        let sum: f64 = probs.iter().sum();
        probs.iter_mut().for_each(|p| *p /= sum);
        probs
374
375
376
377
378
379
380
381
382
383
    };

    // Sample from the probability distribution
    let mut rng = rand::rng();
    let sample: f64 = rng.random();

    let mut cumsum = 0.0;
    for (i, &prob) in probabilities.iter().enumerate() {
        cumsum += prob;
        if sample <= cumsum {
384
            return vec![keys[i]];
385
386
387
388
        }
    }

    // Fallback to last key (shouldn't normally reach here)
389
    vec![keys[keys.len() - 1]]
390
391
}

392
// Default implementation matching the Python _cost_function
393
394
395
396
397
398
399
400
401
402
403
404
#[derive(Debug, Clone, Default)]
pub struct DefaultWorkerSelector {
    pub kv_router_config: KvRouterConfig,
}

impl DefaultWorkerSelector {
    /// Creates a selector from `kv_router_config`, falling back to
    /// `KvRouterConfig::default()` when `None` is given.
    pub fn new(kv_router_config: Option<KvRouterConfig>) -> Self {
        Self {
            kv_router_config: kv_router_config.unwrap_or_default(),
        }
    }
}
405
406
407
408

impl WorkerSelector for DefaultWorkerSelector {
    fn select_worker(
        &self,
409
        workers: &HashMap<WorkerId, ModelRuntimeConfig>,
410
        request: &SchedulingRequest,
411
        block_size: u32,
412
413
414
    ) -> Result<WorkerSelectionResult, KvSchedulerError> {
        assert!(request.isl_tokens > 0);

415
416
417
418
419
        let allowed_ids = request.allowed_worker_ids.as_ref();

        if allowed_ids.map_or(workers.is_empty(), |ids| {
            !workers.keys().any(|wid| ids.contains(wid))
        }) {
420
421
422
            return Err(KvSchedulerError::NoEndpoints);
        }

423
424
425
426
        let isl = request.isl_tokens;
        let request_blocks = isl.div_ceil(block_size as usize);
        let overlaps = &request.overlaps.scores;

427
428
        let decode_blocks = &request.decode_blocks;
        let prefill_tokens = &request.prefill_tokens;
429

430
        let mut worker_logits = HashMap::new();
431

432
433
434
435
436
437
438
        // Use override if provided, otherwise use default config
        let overlap_weight = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.overlap_score_weight)
            .unwrap_or(self.kv_router_config.overlap_score_weight);

439
440
441
442
        for (worker_id, config) in workers
            .iter()
            .filter(|(wid, _)| allowed_ids.is_none_or(|ids| ids.contains(wid)))
        {
443
            let data_parallel_size = config.data_parallel_size;
Yan Ru Pei's avatar
Yan Ru Pei committed
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473

            for dp_rank in 0..data_parallel_size {
                let worker = WorkerWithDpRank::new(*worker_id, dp_rank);

                // Get overlap for this worker (defaults to 0 if not in overlaps)
                let overlap = *overlaps.get(&worker).unwrap_or(&0);

                // this is the number of prefill tokens the worker would have if the request were scheduled there
                let prefill_token = *prefill_tokens.get(&worker).unwrap_or(&isl);
                let potential_prefill_block = (prefill_token as f64) / (block_size as f64);

                // this is the number of decode blocks the worker would have if the request were scheduled there
                let decode_block = *decode_blocks
                    .get(&worker)
                    .unwrap_or(&(potential_prefill_block.floor() as usize))
                    as f64;

                // Calculate logit (lower is better)
                let logit = overlap_weight * potential_prefill_block + decode_block;

                worker_logits.insert(worker, logit);

                tracing::info!(
                    "Formula for worker_id={} dp_rank={:?} with {overlap} cached blocks: {logit:.3} \
                     = {overlap_weight:.1} * prefill_blocks + decode_blocks \
                     = {overlap_weight:.1} * {potential_prefill_block:.3} + {decode_block:.3}",
                    worker.worker_id,
                    worker.dp_rank
                );
            }
474
475
        }

476
        // Use softmax sampling to select worker(s)
477
478
479
480
481
482
        // Use override if provided, otherwise use default config
        let temperature = request
            .router_config_override
            .as_ref()
            .and_then(|cfg| cfg.router_temperature)
            .unwrap_or(self.kv_router_config.router_temperature);
483
484
485
        let candidates = softmax_sample(&worker_logits, temperature);

        // If multiple candidates (tied), use tree size as tie-breaker
486
        // If tree sizes are also equal, use random selection to avoid bias
487
488
        let best_worker = if candidates.len() > 1 {
            tracing::info!("Multiple workers tied with same logit, using tree size as tie-breaker");
489
            let tree_sizes: Vec<(usize, &WorkerWithDpRank)> = candidates
490
                .iter()
491
492
493
494
495
496
497
498
499
                .map(|w| (request.overlaps.tree_sizes.get(w).copied().unwrap_or(0), w))
                .collect();

            if tree_sizes.iter().all(|(s, _)| *s == tree_sizes[0].0) {
                let idx = rand::rng().random_range(0..candidates.len());
                candidates[idx]
            } else {
                *tree_sizes.iter().min_by_key(|(s, _)| *s).unwrap().1
            }
500
501
502
503
        } else {
            candidates[0]
        };

Yan Ru Pei's avatar
Yan Ru Pei committed
504
505
506
        let best_logit = worker_logits[&best_worker];

        let best_overlap = *overlaps.get(&best_worker).unwrap_or(&0);
507

Yan Ru Pei's avatar
Yan Ru Pei committed
508
        // this is a runtime config set on a per worker basis, not per dp-rank
509
        let total_blocks_info = workers
Yan Ru Pei's avatar
Yan Ru Pei committed
510
            .get(&best_worker.worker_id)
511
512
513
514
            .and_then(|cfg| cfg.total_kv_blocks)
            .map(|blocks| format!(", total blocks: {}", blocks))
            .unwrap_or_default();

515
516
517
518
519
520
521
        let tree_size = request
            .overlaps
            .tree_sizes
            .get(&best_worker)
            .copied()
            .unwrap_or(0);

522
        tracing::info!(
523
            "Selected worker: worker_id={} dp_rank={:?}, logit: {:.3}, cached blocks: {}, tree size: {}{}",
Yan Ru Pei's avatar
Yan Ru Pei committed
524
525
            best_worker.worker_id,
            best_worker.dp_rank,
526
527
            best_logit,
            best_overlap,
528
            tree_size,
529
            total_blocks_info
530
        );
531
532

        Ok(WorkerSelectionResult {
Yan Ru Pei's avatar
Yan Ru Pei committed
533
            worker: best_worker,
534
            required_blocks: request_blocks as u64,
Yan Ru Pei's avatar
Yan Ru Pei committed
535
            overlap_blocks: overlaps.get(&best_worker).copied().unwrap_or(0),
536
        })
537
538
    }
}
539
540
541
542
543

#[cfg(test)]
mod tests {
    use super::*;

    /// With a single key, `softmax_sample` must return that key regardless of
    /// temperature or logit value.
    #[test]
    fn test_softmax_sample_single_key() {
        let worker = WorkerWithDpRank::from_worker_id(42);

        // Test with different temperatures
        let mut logits = HashMap::new();
        logits.insert(worker, 0.5); // The value doesn't matter
        for temperature in &[0.1, 1.0, 10.0] {
            let result = softmax_sample(&logits, *temperature);
            assert_eq!(result.len(), 1, "Should return exactly one worker");
            assert_eq!(result[0], worker, "Should return the only available worker");
        }

        // Test with different logit values: very negative, very positive, zero
        for value in [-100.0, 100.0, 0.0] {
            logits.clear();
            logits.insert(worker, value);

            let result = softmax_sample(&logits, 1.0);
            assert_eq!(result.len(), 1);
            assert_eq!(result[0], worker);
        }
    }

    /// With temperature 0, `softmax_sample` must return all keys that share
    /// the smallest logit.
    #[test]
    fn test_softmax_sample_zero_temperature() {
        let mut logits = HashMap::new();
        let worker1 = WorkerWithDpRank::from_worker_id(1);
        let worker2 = WorkerWithDpRank::from_worker_id(2);
        let worker3 = WorkerWithDpRank::from_worker_id(3);
        let worker4 = WorkerWithDpRank::from_worker_id(4);
        logits.insert(worker1, 5.0);
        logits.insert(worker2, 3.0); // This has the smallest logit
        logits.insert(worker3, 7.0);
        logits.insert(worker4, 3.5);

        // With temperature 0, should always return only worker2 (smallest logit)
        let result = softmax_sample(&logits, 0.0);
        assert_eq!(
            result.len(),
            1,
            "Should return one worker when there's no tie"
        );
        assert_eq!(
            result[0], worker2,
            "Should return worker with smallest logit when temperature is 0"
        );

        // Test with tied minimum logits
        logits.clear();
        let worker5 = WorkerWithDpRank::from_worker_id(5);
        let worker6 = WorkerWithDpRank::from_worker_id(6);
        logits.insert(worker1, 5.0);
        logits.insert(worker2, 3.0); // Tied for smallest
        logits.insert(worker5, 3.0); // Tied for smallest
        logits.insert(worker6, 7.0);

        let result = softmax_sample(&logits, 0.0);
        assert_eq!(
            result.len(),
            2,
            "Should return all workers with smallest logit when tied"
        );
        assert!(
            result.contains(&worker2) && result.contains(&worker5),
            "Should contain both tied workers"
        );

        // Test with negative values
        logits.clear();
        let worker10 = WorkerWithDpRank::from_worker_id(10);
        let worker20 = WorkerWithDpRank::from_worker_id(20);
        let worker30 = WorkerWithDpRank::from_worker_id(30);
        logits.insert(worker10, -1.0);
        logits.insert(worker20, -5.0); // This has the smallest logit
        logits.insert(worker30, 0.0);

        let result = softmax_sample(&logits, 0.0);
        assert_eq!(result.len(), 1);
        assert_eq!(
            result[0], worker20,
            "Should handle negative logits correctly"
        );
    }
}