// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Asynchronous Scheduler for LLM Request Management
//!
//! This module implements an asynchronous scheduler that handles three main functions:
//! 1. Receiving new requests and placing them in the waiting queue
//! 2. Scheduling waiting requests against available KV cache resources
//! 3. Simulating the execution of running requests with realistic timing
//!
//! ## Scheduling Process
//! The scheduler checks direct block capacity to determine whether there is sufficient
//! KV cache space for new requests. It also enforces a batched-tokens budget (and an
//! optional limit on the number of running sequences) to prevent oversubscription of
//! computational resources. Only requests that can be allocated these resources are
//! moved from the waiting state to the running (prefill, then decode) state.
//!
//! ## Request Simulation
//! The simulation models two key phases. Their durations are predicted by the configured
//! performance model (`perf_model`) and divided by `speedup_ratio` (decode additionally by
//! `decode_speedup_ratio`) before the scheduler sleeps:
//! - Prefill phase: charged only for the newly computed tokens of each chunk; the reference
//!   heuristic is roughly quadratic: (cached_tokens + new_tokens) * new_tokens
//! - Decode phase: driven by the number of active KV tokens and the average context
//!   length; the reference heuristic is roughly proportional to active KV blocks (linear)
//!
//! ## Resource Management
//! The scheduler communicates with the KvManager through MoveBlock signals at each
//! stage of request processing. When resources become constrained, it employs a
//! preemption strategy (LIFO by default, matching vLLM v1). Under this strategy a running
//! request is evicted and its sequence is reset. The request is then placed at the front
//! of the waiting queue to be rescheduled later.
//!
//! ## NOTE
//! The current prefill and decode timing simulations are rough approximations, not
//! calibrated models, and remain a work in progress.

31
use crate::common::protocols::{
32
    DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PreemptionMode,
33
    WorkerType,
Yan Ru Pei's avatar
Yan Ru Pei committed
34
};
35
36
use crate::common::running_mean::RunningMean;
use crate::common::sequence::ActiveSequence;
37
use crate::common::utils::sleep_until_precise;
38
use crate::kv_manager::KvManager;
39
use dynamo_kv_router::protocols::DpRank;
40
use dynamo_tokens::blocks::UniqueBlock;
Yan Ru Pei's avatar
Yan Ru Pei committed
41
use std::collections::{HashMap, VecDeque};
42
use std::sync::Arc;
43
use std::time::Instant;
44
use tokio::sync::mpsc;
45
use tokio::time::Duration;
46
47
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
48
use validator::Validate;
49

50
51
52
53
54
55
56
/// Simple metrics snapshot for the mocker's internal use.
///
/// The scheduler task publishes one of these on a `tokio::sync::watch` channel after
/// every simulated forward pass; read it through `Scheduler::metrics_receiver`.
#[derive(Clone, Default, Debug)]
pub struct MockerMetrics {
    /// Data-parallel rank of the scheduler that produced this snapshot.
    pub dp_rank: DpRank,
    /// Value of `KvManager::num_active_blocks()` at the end of the forward pass.
    /// NOTE(review): despite the name this is the total active block count, which
    /// presumably also includes blocks held by requests still in prefill — confirm.
    pub active_decode_blocks: u64,
}

57
58
59
60
61
62
63
64
65
/// A request tracked by the scheduler, in one of two representations.
///
/// A request starts as `Direct` when received and is converted to `Active` when it
/// is first admitted from the waiting queue (`SchedulerState::admit_one`). Preempted
/// requests stay `Active` (with a reset sequence) while they wait to be re-admitted.
pub enum Request {
    /// The request exactly as submitted; not yet admitted.
    Direct(DirectRequest),
    /// A live sequence that tracks its own KV block allocation and generated tokens.
    Active(ActiveSequence),
}

#[derive(Default)]
struct SchedulerState {
    waiting: VecDeque<Uuid>,
66
    prefill: VecDeque<Uuid>,
67
    decode: VecDeque<Uuid>,
68
69
70
71
    requests: HashMap<Uuid, Request>,
}

impl SchedulerState {
72
73
74
75
    fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

76
77
78
79
80
81
82
    fn receive(&mut self, request: DirectRequest) -> Uuid {
        let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
        self.requests.insert(uuid, Request::Direct(request));
        self.waiting.push_back(uuid);
        uuid
    }

83
84
85
86
87
88
    /// Try to admit one request from waiting → prefill.
    /// Converts DirectRequest → ActiveSequence if needed. PrefillCost is computed
    /// later in simulate_prefill when the request reaches the front of the queue.
    fn admit_one(&mut self, args: &MockEngineArgs) -> bool {
        let Some(&uuid) = self.waiting.front() else {
            return false;
89
        };
90
91
92
93
        let num_active = self.prefill.len() + self.decode.len();
        if args.max_num_seqs.is_some_and(|limit| num_active >= limit) {
            return false;
        }
94

95
        self.waiting.pop_front();
96

97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
        // Convert DirectRequest → ActiveSequence if needed
        if let Some(Request::Direct(_)) = self.requests.get(&uuid) {
            let Some(Request::Direct(direct)) = self.requests.remove(&uuid) else {
                unreachable!()
            };
            self.requests.insert(
                uuid,
                Request::Active(ActiveSequence::new(
                    direct.tokens,
                    direct.max_output_tokens,
                    Some(args.block_size),
                    args.enable_prefix_caching,
                    args.zmq_kv_events_port.is_some(),
                )),
            );
        }
113

114
115
        self.prefill.push_back(uuid);
        true
116
117
    }

118
119
120
121
122
123
124
125
    fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
        if !self.decode.contains(&uuid) {
            return None;
        }
        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        Some(sequence)
126
127
128
129
    }

    /// Remove a UUID and its associated Request from collections.
    fn complete(&mut self, uuid: &Uuid) {
130
        tracing::trace!("Request {uuid} will complete");
131
        self.decode.retain(|u| u != uuid);
132
133
134
        self.requests.remove(uuid);
    }

135
136
137
138
139
140
141
142
143
144
    /// Preempt a running request by evicting it from decode, resetting the sequence,
    /// and adding it back to the front of the waiting queue.
    /// In LIFO mode, evicts the newest request (matches vLLM v1).
    /// In FIFO mode, evicts the oldest request.
    fn preempt(&mut self, mode: PreemptionMode) -> Vec<MoveBlock> {
        let uuid = match mode {
            PreemptionMode::Lifo => self.decode.pop_back(),
            PreemptionMode::Fifo => self.decode.pop_front(),
        }
        .expect("Nothing to evict for preemption.");
145
146
147
148
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
149
        tracing::warn!("Request {uuid} will be preempted");
150

151
        // Reset the sequence and re-queue for prefill
152
153
154
155
156
        let Request::Active(mut active_sequence) = request else {
            panic!("Expected ActiveSequence in running queue")
        };
        let signals = active_sequence.reset_with_signal();

157
158
        self.requests.insert(uuid, Request::Active(active_sequence));
        self.waiting.push_front(uuid);
159
160

        signals
161
162
163
    }
}

164
165
166
167
168
169
170
171
172
173
/// Guard that cancels the wrapped token when it is dropped.
///
/// `Scheduler` keeps it behind an `Arc`, so all clones share a single guard and the
/// background task is only cancelled once the last `Scheduler` handle is gone.
struct CancelGuard(CancellationToken);

impl Drop for CancelGuard {
    fn drop(&mut self) {
        let Self(token) = self;
        token.cancel();
    }
}

174
175
176
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
pub struct Scheduler {
177
    request_tx: mpsc::UnboundedSender<DirectRequest>,
178
    metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
179
    _cancel_guard: Arc<CancelGuard>,
180
181
182
183
184
}

impl Scheduler {
    /// Create a new Scheduler with the given parameters
    pub fn new(
185
        args: MockEngineArgs,
Yan Ru Pei's avatar
Yan Ru Pei committed
186
        dp_rank: u32,
187
        output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
188
        kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
189
190
        cancellation_token: Option<CancellationToken>,
    ) -> Self {
191
        args.validate().expect("invalid MockEngineArgs");
192

193
194
        // Create channel for request handling
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
195
196
197
198
        let initial_metrics = MockerMetrics {
            dp_rank,
            active_decode_blocks: 0,
        };
199
        let (metrics_tx, metrics_rx) =
200
            tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
201

202
203
204
        let cancel_token = cancellation_token.unwrap_or_default();
        let cancel_token_clone = cancel_token.clone();
        let cancel_guard = Arc::new(CancelGuard(cancel_token));
205
206
207

        // Spawn main background task with cancellation token
        tokio::spawn(async move {
208
            // Create state and kv_manager as local variables owned by this task
209
            let mut state = SchedulerState::default();
210
            let mut kv_manager = KvManager::new_with_event_sink(
Yan Ru Pei's avatar
Yan Ru Pei committed
211
212
                args.num_gpu_blocks,
                args.block_size,
213
                kv_event_sink,
Yan Ru Pei's avatar
Yan Ru Pei committed
214
215
216
                dp_rank,
            );
            let mut hit_rates = RunningMean::new(1000);
217
218

            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
219
                // 1. Receive requests
220
221
222
223
224
                if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
                    .await
                    .is_none()
                {
                    break;
225
                }
226

227
228
                // 2. Simulate prefill + decode
                simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
229

230
                simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
231

232
                // 3. Send metrics once per forward pass (after all prefill and decode processing)
233
                let _ = metrics_tx.send(MockerMetrics {
234
                    dp_rank,
235
236
                    active_decode_blocks: kv_manager.num_active_blocks() as u64,
                });
237
238
239
240
241
            }
        });

        Self {
            request_tx,
242
            metrics_rx,
243
            _cancel_guard: cancel_guard,
244
245
246
        }
    }

247
    /// Add a new request to the prefill queue
248
    pub async fn receive(&self, request: DirectRequest) {
249
250
251
252
253
        let _ = self.request_tx.send(request);
    }

    pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
        self.request_tx.clone()
254
255
    }

256
    /// Get a watch receiver for forward pass metrics
257
    pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
258
259
260
        self.metrics_rx.clone()
    }
}
261

262
263
264
265
266
267
268
269
270
271
272
273
/// Move newly submitted requests from `request_rx` into `state`.
///
/// Returns `Some(())` when the scheduler loop should run another forward pass. Returns
/// `None` when it should stop: cancellation was requested, or the request channel closed
/// while no request was being tracked.
///
/// * Idle (`state` is empty): waits for either the next request or cancellation.
///   Cancellation is polled first (`biased`).
/// * Busy: drains every already-queued request without waiting, then yields to the
///   Tokio scheduler once. A forward pass only awaits when it sleeps, and it skips the
///   sleep when the speedup ratio is not positive or the predicted time is zero.
///   Without this yield, such a loop would never give up its worker thread. That would
///   starve request senders and output consumers, and everything else on a
///   current-thread runtime.
async fn receive_requests(
    state: &mut SchedulerState,
    request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
    cancel_token: &CancellationToken,
) -> Option<()> {
    if cancel_token.is_cancelled() {
        return None;
    }

    if state.is_empty() {
        // Fully idle - block until new request arrives or shutdown
        tokio::select! {
            biased;
            _ = cancel_token.cancelled() => {
                return None;
            }
            result = request_rx.recv() => {
                let Some(request) = result else {
                    return None; // channel closed
                };
                state.receive(request);
                return Some(());
            }
        }
    }

    // Has active/waiting work - collect any pending requests without blocking
    while let Ok(request) = request_rx.try_recv() {
        state.receive(request);
    }

    // Guarantee at least one suspension point per forward pass (see doc comment).
    tokio::task::yield_now().await;

    Some(())
}

/// Simulate the prefill phase for one scheduler step.
299
300
301
302
303
///
/// Handles token budget, block allocation, and preemption inline.
/// The front of `state.prefill` is worked on until it is fully allocated. Only when the
/// prefill queue is empty is another request admitted from `waiting`
/// (`SchedulerState::admit_one`), one at a time.
///
/// Token budget: `max_num_batched_tokens - decode.len()` (1 token per decode request),
/// or unlimited when `max_num_batched_tokens` is `None`. With chunked prefill a request
/// may use part of the budget and resume in a later step. Without it, a request whose
/// remaining tokens exceed the budget stops this step's prefill entirely (head-of-line).
/// When blocks are unavailable, decode requests are preempted (LIFO by default)
/// to free capacity, matching vLLM v1 behavior.
///
/// Sleeps until `start + total_time / speedup_ratio` when `speedup_ratio > 0` and some
/// compute was charged, and returns the unscaled accumulated prefill time.
304
async fn simulate_prefill(
305
306
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
307
308
    hit_rates: &mut RunningMean<f32>,
    args: &MockEngineArgs,
309
) -> Duration {
310
    let start_time = Instant::now();
311
312
    let mut total_time = Duration::ZERO;

313
314
315
316
317
318
319
320
    // Decode requests each reserve one token of the batch budget; saturates at 0.
    let mut token_budget = args
        .max_num_batched_tokens
        .map_or(usize::MAX, |t| t.saturating_sub(state.decode.len()));

    'prefill: while token_budget > 0 {
        // Drain prefill first, then pull from waiting one at a time
        if state.prefill.is_empty() && !state.admit_one(args) {
            break;
321
        }
322
        let uuid = state.prefill[0];
323

324
325
326
327
328
329
330
331
332
333
334
335
        // Read-only view of the sequence to price this step; re-borrowed mutably below.
        let Some(Request::Active(seq)) = state.requests.get(&uuid) else {
            panic!("Request does not exist.");
        };
        let prefill_cost = kv_manager.get_prefill_cost(seq);
        let sequence_len = seq.len();
        let allocated_tokens = seq.num_allocated_tokens();
        // Used below as the uncached token count (see the hit-rate formula).
        // NOTE(review): exact semantics come from `get_prefill_cost` — confirm.
        let remaining = prefill_cost.new_tokens;

        // Token budget check
        let tokens_left = sequence_len - allocated_tokens;
        if !args.enable_chunked_prefill && tokens_left > token_budget {
            break;
336
        }
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
        // With chunked prefill this may be a partial chunk; `cumulative` is the token
        // count the sequence should have allocated after this chunk.
        let chunk = tokens_left.min(token_budget);
        let cumulative = allocated_tokens + chunk;

        // Allocate blocks. process() returns the number of blocks committed.
        // On partial success, preempt a decode request and retry — the next
        // loop iteration re-prepares from the updated num_allocated_tokens.
        let Some(Request::Active(seq)) = state.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        if let Some(signal) = seq.prepare_allocation(cumulative) {
            let expected = match &signal {
                MoveBlock::Use(blocks, ..) => blocks.len(),
                _ => unreachable!(),
            };
            let allocated = kv_manager.process(&signal);
            // Commit the blocks that were actually allocated
            let committed_tokens = if allocated == expected {
                cumulative
            } else {
                // Partial: compute token boundary from block count
                // (blocks owned before this chunk, capped by the sequence's unique blocks,
                // plus the newly allocated ones; clamped to `cumulative` below).
                let prev_blocks = allocated_tokens
                    .div_ceil(seq.block_size())
                    .min(seq.unique_blocks().len());
                (prev_blocks + allocated) * seq.block_size()
            };
            seq.commit_allocation(committed_tokens.min(cumulative));
363

364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
            if allocated < expected {
                // Nothing left to preempt: keep the partial allocation and stop; a later
                // step retries from the committed boundary. No compute is charged here.
                if state.decode.is_empty() {
                    break;
                }
                for signal in state.preempt(args.preemption_mode) {
                    kv_manager.process(&signal);
                }
                continue 'prefill; // retry with freed capacity
            }
        } else {
            // No new blocks needed for this chunk; just advance the allocation boundary.
            seq.commit_allocation(cumulative);
        }

        // Accumulate prefill compute time (only for the new tokens in this chunk)
        // Decode-only workers skip prefill compute time entirely.
        let new_tokens_in_chunk = chunk.min(remaining);
        if args.worker_type != WorkerType::Decode && new_tokens_in_chunk > 0 {
            total_time += Duration::from_secs_f64(
                prefill_cost.predict_prefill_compute(Some(new_tokens_in_chunk), &args.perf_model)
                    / 1000.0,
            );
        }

        // Hit rate: fraction of tokens that were already cached
        // (recorded once per processed chunk).
        let hit_rate = if sequence_len > 0 {
            1.0 - (remaining as f32 / sequence_len as f32)
        } else {
            0.0
        };
        hit_rates.push(hit_rate);

        // Cannot underflow: `chunk <= token_budget`.
        token_budget -= chunk;

        if cumulative >= sequence_len {
            // Fully prefilled — promote to decode queue
            state.prefill.pop_front();
            state.decode.push_back(uuid);
        } else {
            // Partially prefilled — the budget is exhausted; the next scheduler step
            // resumes this request from its updated allocated_tokens
402
403
404
            break;
        }
    }
405

406
407
    if args.speedup_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
408
409
        let deadline = start_time + sleep_duration;

410
        sleep_until_precise(deadline).await;
411
    }
412

413
414
415
416
417
    total_time
}

/// Simulate decode phase for all active decode requests.
/// Returns the total decode compute time.
418
async fn simulate_decode(
419
420
421
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
422
    args: &MockEngineArgs,
423
) -> Duration {
424
425
    let start_time = Instant::now();

426
    // Compute decode timing
427
    let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
428

429
    // Compute average context length across all active decode requests
430
    let total_length: usize = state
431
        .decode
432
        .iter()
433
434
435
        .map(|uuid| {
            if let Request::Active(seq) = state.requests.get(uuid).unwrap() {
                seq.len()
436
            } else {
437
                0
438
            }
439
440
441
442
        })
        .sum();
    let count = state.decode.len();

443
    let context_length = if count > 0 { total_length / count } else { 0 };
444
445
446
    let decoding_time = args
        .perf_model
        .predict_decode_time(active_kv_tokens, context_length);
447
448
449
    let total_time = Duration::from_secs_f64(decoding_time / 1000.0);

    // Process decoding
450
    let uuids: Vec<Uuid> = state.decode.iter().copied().collect();
451
    for uuid in uuids {
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
        // Try to generate; if allocation fails, preempt until it succeeds
        // or nothing is left to preempt (matches vLLM v1 scheduler loop).
        // Reborrow sequence each iteration so the mutable ref doesn't
        // conflict with state.preempt().
        let mut allocated = false;
        loop {
            let Some(sequence) = state.run(uuid) else {
                break;
            };
            let signals = sequence.generate();
            if process_signals(kv_manager, &signals) {
                allocated = true;
                break;
            }
            sequence.pop(); // revert the failed generation
467

468
469
470
471
472
            if state.decode.is_empty() {
                break;
            }

            // Preempt one request and free its blocks
473
            for signal in state.preempt(args.preemption_mode) {
474
475
                kv_manager.process(&signal);
            }
476
477
478
479
480
481
482
483

            // If the current request was the one preempted, stop retrying
            if !state.decode.contains(&uuid) {
                break;
            }
        }

        if !allocated {
484
485
486
            continue;
        }

487
488
489
490
        let Some(sequence) = state.run(uuid) else {
            continue;
        };

491
492
493
        // Check completion and send notification
        let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();

494
495
496
497
498
499
500
        let send_failed = output_tx.as_ref().is_some_and(|tx| {
            tx.send(OutputSignal {
                uuid,
                completed: is_complete,
            })
            .is_err()
        });
501
502
503
504
505
506
507
508
509
510
511

        if send_failed {
            for signal in &sequence.free_signal() {
                kv_manager.process(signal);
            }
        }

        if send_failed || is_complete {
            state.complete(&uuid);
        }
    }
512

513
514
515
    let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
    if effective_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / effective_ratio);
516
517
        let deadline = start_time + sleep_duration;

518
        sleep_until_precise(deadline).await;
519
    }
520

521
522
523
    total_time
}

524
525
526
527
528
529
530
/// Processes MoveBlock signals with the KvManager.
///
/// When a signal fails, this function verifies that the failure is for an expected case:
/// specifically a single signal attempting to create a single partial (generation) block.
/// This validation is important because in normal operation, the only legitimate failure
/// case should be when trying to acquire a new generation block - any other failures would
/// indicate an unexpected state in the system.
531
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
532
    for signal in signals {
533
        if kv_manager.process(signal) > 0 {
534
535
536
537
            continue;
        }

        // Check we have a Use signal with blocks
538
        let MoveBlock::Use(blocks, _hashes, ..) = signal else {
539
540
541
            panic!(
                "Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
            );
542
543
544
        };

        // Verify the signal contains exactly one block
545
        let num_blocks = blocks.len();
546
        let num_active_blocks = kv_manager.num_active_blocks();
547
548
549
550
        if num_blocks != 1 {
            panic!(
                "Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
            );
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
        }

        // Verify the block is a PartialBlock (generation block)
        if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
            panic!("Failed signal is Invalid. Generation block has to be partial.");
        }

        return false;
    }

    true
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Duration;
569
    use tokio::time::interval;
570

571
572
    /// Helper function to verify that the scheduler is idle (no active KV blocks)
    fn assert_scheduler_idle(metrics: &MockerMetrics) {
573
        assert_eq!(
574
            metrics.active_decode_blocks, 0,
575
            "Expected 0 active blocks, got {}",
576
            metrics.active_decode_blocks
577
578
579
        );
    }

580
    #[rstest]
581
582
583
584
585
586
587
588
    #[case::case_1(false, false, false)]
    #[case::case_2(false, true, false)]
    #[case::case_3(true, false, false)]
    #[case::case_4(true, true, false)]
    #[case::case_5(false, false, true)]
    #[case::case_6(false, true, true)]
    #[case::case_7(true, false, true)]
    #[case::case_8(true, true, true)]
589
    #[tokio::test]
590
591
592
    async fn test_scheduler_token_generation_patterns(
        #[case] use_shared_tokens: bool,
        #[case] enable_prefix_caching: bool,
593
        #[case] enable_chunked_prefill: bool,
594
    ) {
595
        unsafe { std::env::set_var("RUST_LOG", "debug") };
596
597

        let kv_capacity: usize = 500;
598
        let block_size: usize = 64;
599
        let num_requests: usize = 200;
600
601
602
603
        let input_len: usize = 1000;
        let max_output_tokens: usize = 100;

        // Create channel for token output
604
605
606
607
608
609
610
611
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args using builder - now including enable_prefix_caching
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(kv_capacity)
            .block_size(block_size)
            .speedup_ratio(10.0)
            .enable_prefix_caching(enable_prefix_caching)
612
            .enable_chunked_prefill(enable_chunked_prefill)
613
614
615
616
            .build()
            .unwrap();

        // Create scheduler with new args struct
Yan Ru Pei's avatar
Yan Ru Pei committed
617
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647

        // Create shared tokens for caching case
        let shared_tokens = if use_shared_tokens {
            Some(
                (0..input_len / 2)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>(),
            )
        } else {
            None
        };

        // Create test requests
        for _ in 0..num_requests {
            let input_tokens = if let Some(ref shared) = shared_tokens {
                // For caching case: use shared tokens for first half, random for second half
                let mut tokens = shared.clone();
                tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
                tokens
            } else {
                // For random case: create unique random token vector for each request
                (0..input_len)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>()
            };

            let request = DirectRequest {
                tokens: input_tokens,
                max_output_tokens,
                uuid: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
648
                dp_rank: 0,
649
650
651
652
653
654
655
656
657
658
659
660
661
662
            };
            scheduler.receive(request).await;
        }

        let start_time = std::time::Instant::now();

        // Collect all generated tokens (should be num_requests * max_output_tokens)
        let expected_tokens = num_requests * max_output_tokens;
        let mut received_tokens = 0;

        // Set up a timeout that causes the test to panic if no tokens are received for 2 seconds
        let timeout = tokio::time::sleep(Duration::from_secs(2));
        tokio::pin!(timeout);

663
664
665
        // Get metrics receiver
        let metrics_rx = scheduler.metrics_receiver();

666
667
668
669
670
671
672
673
674
        // Set up debug ticker interval
        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                // Manual debug ticker that prints forward pass metrics
                _ = debug_interval.tick() => {
675
                    let _metrics = metrics_rx.borrow().clone();
676
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
                }

                Some(_) = output_rx.recv() => {
                    received_tokens += 1;
                    // Reset timeout whenever we receive a token
                    timeout.set(tokio::time::sleep(Duration::from_secs(2)));
                }

                _ = &mut timeout => {
                    // Break instead of panicking when timeout occurs
                    break;
                }
            }
        }

        // Calculate and print elapsed time
        let elapsed = start_time.elapsed();
        println!(
695
            "Test completed in: {elapsed:?} for {} case with prefix_caching={enable_prefix_caching} and chunked_prefill={enable_chunked_prefill}",
696
697
698
699
            if use_shared_tokens {
                "caching"
            } else {
                "random"
700
            }
701
702
703
704
        );

        // Assert that we received the expected number of tokens
        assert!(
705
706
707
            received_tokens == expected_tokens,
            "Received {received_tokens} tokens but expected exactly {expected_tokens}"
        );
708

709
710
        // Wait a bit for final metrics update to propagate
        tokio::time::sleep(Duration::from_millis(100)).await;
711

712
713
        let metrics = scheduler.metrics_receiver().borrow().clone();
        assert_scheduler_idle(&metrics);
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
    }

    #[tokio::test]
    async fn test_cache_hit_rate_with_identical_requests() {
        let block_size: usize = 64;
        let max_output_tokens: usize = 10;
        let speedup_ratio = 10.0;
        let num_requests = 10;
        let token_length = 65;

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(100) // Large enough to not be a constraint
            .block_size(block_size)
            .speedup_ratio(speedup_ratio)
            .build()
            .unwrap();

        // Create scheduler
Yan Ru Pei's avatar
Yan Ru Pei committed
736
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
737
738
739
740
741
742
743
744
745
746

        // Create identical tokens for all requests
        let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();

        // Send all requests with identical tokens
        for _ in 0..num_requests {
            let request = DirectRequest {
                tokens: identical_tokens.clone(),
                max_output_tokens,
                uuid: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
747
                dp_rank: 0,
748
749
750
751
752
753
754
755
756
757
758
759
760
            };
            scheduler.receive(request).await;
            // Sleep for 0.1 second after each request
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        // Collect all generated tokens
        let mut received_tokens = 0;

        // Set up a timeout that resets to 0.5 seconds on each received token
        let timeout = tokio::time::sleep(Duration::from_millis(500));
        tokio::pin!(timeout);

761
762
763
        // Get metrics receiver
        let metrics_rx = scheduler.metrics_receiver();

764
765
766
767
768
769
770
771
772
        // Set up debug ticker interval
        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                // Manual debug ticker that prints forward pass metrics
                _ = debug_interval.tick() => {
773
                    let _metrics = metrics_rx.borrow().clone();
774
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
                }

                Some(_signal) = output_rx.recv() => {
                    received_tokens += 1;
                    // Reset timeout whenever we receive a token
                    timeout.set(tokio::time::sleep(Duration::from_millis(500)));
                }

                _ = &mut timeout => {
                    // Break when timeout occurs (no more tokens for 0.5 seconds)
                    break;
                }
            }
        }

790
791
792
        // Wait a bit for final metrics update
        tokio::time::sleep(Duration::from_millis(100)).await;

793
        // Verify forward pass metrics - scheduler should be idle after completing all requests
794
        let metrics = metrics_rx.borrow().clone();
795
        assert_scheduler_idle(&metrics);
796

797
        println!("Test passed! Received {received_tokens} tokens");
798
799
    }

800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
    /// White-box unit test that directly creates SchedulerState + KvManager,
    /// manually invokes simulate_prefill / simulate_decode, and asserts on
    /// queue states and block counts after each step.
    #[tokio::test]
    async fn test_scheduler_internal_state_transitions() {
        let args = MockEngineArgs::builder()
            .block_size(4)
            .num_gpu_blocks(6)
            .max_num_batched_tokens(Some(12))
            .max_num_seqs(Some(3))
            .enable_chunked_prefill(true)
            .enable_prefix_caching(false)
            .speedup_ratio(0.0)
            .build()
            .unwrap();

        let mut state = SchedulerState::default();
        let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
        let mut hit_rates = RunningMean::new(1000);
        let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;

        let r1_uuid = Uuid::from_u128(1);
        let r2_uuid = Uuid::from_u128(2);
        let r3_uuid = Uuid::from_u128(3);

        // ── Step 1: Receive 3 requests ──
        // R1: 8 input, 2 max_output → 2 blocks
        // R2: 8 input, 2 max_output → 2 blocks
        // R3: 12 input, 2 max_output → 3 blocks
        state.receive(DirectRequest {
            tokens: (0..8).collect(),
            max_output_tokens: 2,
            uuid: Some(r1_uuid),
            dp_rank: 0,
        });
        state.receive(DirectRequest {
            tokens: (100..108).collect(),
            max_output_tokens: 2,
            uuid: Some(r2_uuid),
            dp_rank: 0,
        });
        state.receive(DirectRequest {
            tokens: (200..212).collect(),
            max_output_tokens: 2,
            uuid: Some(r3_uuid),
            dp_rank: 0,
        });

        assert_eq!(state.waiting.len(), 3);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 0);
        assert_eq!(kv_manager.num_active_blocks(), 0);

        // ── Step 2: First simulate_prefill ──
        // Budget=12. R1 takes 8 tokens (2 blocks), fully prefilled → decode.
        // R2 takes 4 tokens (1 block, chunked), partially prefilled → stays in prefill.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 1);
        assert_eq!(state.prefill.len(), 1);
        assert_eq!(state.decode.len(), 1);
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(state.prefill[0], r2_uuid);
        assert_eq!(state.waiting[0], r3_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 3); // 2 for R1 + 1 for R2

        let seq = match state.requests.get(&r1_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 8);
        assert_eq!(seq.generated_tokens(), 0);

        let seq = match state.requests.get(&r2_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 4);
        assert_eq!(seq.generated_tokens(), 0);

        // ── Step 3: First simulate_decode ──
        // R1 generates 1 token, gains a partial block.
882
        simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916

        assert_eq!(state.decode.len(), 1);
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 4); // +1 partial for R1

        let seq = match state.requests.get(&r1_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.generated_tokens(), 1);

        // ── Step 4: Second simulate_prefill ──
        // Budget=11. R2 finishes (4 more tokens, 1 block → active=5, decode).
        // R3 admitted, needs 2 blocks for chunk of 7. Only 1 free slot → partial.
        // Preempt R2 (LIFO) → R2 back to waiting. Retry R3 → evicts R2's
        // inactive blocks, allocates 2 more → R3 allocated_tokens=11.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 1, "R2 preempted back to waiting");
        assert_eq!(state.waiting[0], r2_uuid);
        assert_eq!(state.prefill.len(), 1, "R3 partially prefilled");
        assert_eq!(state.prefill[0], r3_uuid);
        assert_eq!(state.decode.len(), 1, "R1 still decoding");
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 6); // at capacity

        let seq = match state.requests.get(&r3_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 11);

        // ── Step 5: Second simulate_decode ──
        // R1 generates 2nd token → complete. Frees 3 blocks (1 destroyed, 2 deactivated).
917
        simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939

        assert!(!state.requests.contains_key(&r1_uuid), "R1 completed");
        assert_eq!(state.decode.len(), 0);
        assert_eq!(state.prefill.len(), 1);
        assert_eq!(state.waiting.len(), 1);
        assert_eq!(kv_manager.num_active_blocks(), 3); // only R3's 3 blocks

        // ── Step 6: Third simulate_prefill ──
        // R3 finishes prefill (1 token left, no new blocks) → decode.
        // R2 re-admitted, fully prefilled (2 blocks via inactive eviction) → decode.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 0);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 2);
        assert!(state.decode.contains(&r3_uuid));
        assert!(state.decode.contains(&r2_uuid));
        assert_eq!(kv_manager.num_active_blocks(), 5); // 3 for R3 + 2 for R2

        // ── Steps 7+: Cycle until all requests complete ──
        loop {
            simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
940
            simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
941
942
943
944
945
946
947
948
949
950
951
952

            if state.is_empty() {
                break;
            }
        }

        assert_eq!(state.waiting.len(), 0);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 0);
        assert_eq!(kv_manager.num_active_blocks(), 0);
    }

953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
    #[tokio::test]
    async fn test_receiver_drop_cleans_up_resources() {
        let block_size: usize = 64;
        let input_tokens = 256;
        let max_output_tokens = 200; // More than we'll receive

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(10) // Enough for 256 tokens (4 blocks)
            .block_size(block_size)
            .speedup_ratio(100.0) // Fast simulation
            .build()
            .unwrap();

        // Create scheduler
Yan Ru Pei's avatar
Yan Ru Pei committed
971
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
972
973
974
975
976
977
978

        // Create request with 256 tokens
        let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
        let request = DirectRequest {
            tokens,
            max_output_tokens,
            uuid: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
979
            dp_rank: 0,
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
        };

        scheduler.receive(request).await;

        // Receive exactly 129 tokens
        let mut received_count = 0;
        while received_count < 129 {
            if let Some(_signal) = output_rx.recv().await {
                received_count += 1;
            } else {
                panic!("Channel closed before receiving 129 tokens");
            }
        }

        // Drop the receiver immediately
        drop(output_rx);

        // Wait for 1 second to allow cleanup
        tokio::time::sleep(Duration::from_secs(1)).await;

        // Check forward pass metrics
1001
1002
        let metrics_rx = scheduler.metrics_receiver();
        let metrics = metrics_rx.borrow().clone();
1003

1004
        assert_scheduler_idle(&metrics);
1005
1006
    }
}