// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Asynchronous Scheduler for LLM Request Management
//!
//! This module implements an asynchronous scheduler that handles three main functions:
//! 1. Receiving new requests and placing them in the waiting queue
//! 2. Scheduling waiting requests against available KV cache resources
//! 3. Simulating the execution of running requests with realistic timing
//!
//! ## Scheduling Process
//! The scheduler checks direct block capacity to determine whether there is sufficient
//! KV cache space for new requests. It also enforces a batched-tokens budget to prevent
//! oversubscription of computational resources. Only requests that can be allocated
//! these resources are moved from the waiting state to the running state.
//!
//! ## Request Simulation
//! The simulation models two key phases:
//! - Prefill phase: Uses a quadratic cost function: (cached_tokens + new_tokens) * new_tokens
//! - Decode phase: Uses a cost function proportional to active KV blocks (linear)
//!
//! ## Resource Management
//! The scheduler communicates with the KvManager through MoveBlock signals at each
//! stage of request processing. When resources become constrained, it employs a
//! preemption strategy (LIFO by default, matching vLLM v1) where a running request
//! is evicted and placed at the front of the waiting queue to be rescheduled later.
//!
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP.

use crate::common::perf_model::PerfModel;
use crate::common::protocols::{
    DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PreemptionMode,
    WorkerType,
};
use crate::common::running_mean::RunningMean;
use crate::common::sequence::ActiveSequence;
use crate::common::utils::sleep_until_precise;
use crate::kv_manager::KvManager;
use dynamo_kv_router::protocols::DpRank;
use dynamo_tokens::blocks::UniqueBlock;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use validator::Validate;

51
52
53
54
55
56
57
/// Simple metrics struct for mocker's internal use
#[derive(Clone, Default, Debug)]
pub struct MockerMetrics {
    pub dp_rank: DpRank,
    pub active_decode_blocks: u64,
}

58
59
60
61
62
63
64
65
66
/// Enum representing either a direct request or an active sequence
pub enum Request {
    Direct(DirectRequest),
    Active(ActiveSequence),
}

#[derive(Default)]
struct SchedulerState {
    waiting: VecDeque<Uuid>,
67
    prefill: VecDeque<Uuid>,
68
    decode: VecDeque<Uuid>,
69
70
71
72
    requests: HashMap<Uuid, Request>,
}

impl SchedulerState {
73
74
75
76
    fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

77
78
79
80
81
82
83
    fn receive(&mut self, request: DirectRequest) -> Uuid {
        let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
        self.requests.insert(uuid, Request::Direct(request));
        self.waiting.push_back(uuid);
        uuid
    }

84
85
86
87
88
89
    /// Try to admit one request from waiting → prefill.
    /// Converts DirectRequest → ActiveSequence if needed. PrefillCost is computed
    /// later in simulate_prefill when the request reaches the front of the queue.
    fn admit_one(&mut self, args: &MockEngineArgs) -> bool {
        let Some(&uuid) = self.waiting.front() else {
            return false;
90
        };
91
92
93
94
        let num_active = self.prefill.len() + self.decode.len();
        if args.max_num_seqs.is_some_and(|limit| num_active >= limit) {
            return false;
        }
95

96
        self.waiting.pop_front();
97

98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
        // Convert DirectRequest → ActiveSequence if needed
        if let Some(Request::Direct(_)) = self.requests.get(&uuid) {
            let Some(Request::Direct(direct)) = self.requests.remove(&uuid) else {
                unreachable!()
            };
            self.requests.insert(
                uuid,
                Request::Active(ActiveSequence::new(
                    direct.tokens,
                    direct.max_output_tokens,
                    Some(args.block_size),
                    args.enable_prefix_caching,
                    args.zmq_kv_events_port.is_some(),
                )),
            );
        }
114

115
116
        self.prefill.push_back(uuid);
        true
117
118
    }

119
120
121
122
123
124
125
126
    fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
        if !self.decode.contains(&uuid) {
            return None;
        }
        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        Some(sequence)
127
128
129
130
    }

    /// Remove a UUID and its associated Request from collections.
    fn complete(&mut self, uuid: &Uuid) {
131
        tracing::trace!("Request {uuid} will complete");
132
        self.decode.retain(|u| u != uuid);
133
134
135
        self.requests.remove(uuid);
    }

136
137
138
139
140
141
142
143
144
145
    /// Preempt a running request by evicting it from decode, resetting the sequence,
    /// and adding it back to the front of the waiting queue.
    /// In LIFO mode, evicts the newest request (matches vLLM v1).
    /// In FIFO mode, evicts the oldest request.
    fn preempt(&mut self, mode: PreemptionMode) -> Vec<MoveBlock> {
        let uuid = match mode {
            PreemptionMode::Lifo => self.decode.pop_back(),
            PreemptionMode::Fifo => self.decode.pop_front(),
        }
        .expect("Nothing to evict for preemption.");
146
147
148
149
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
150
        tracing::warn!("Request {uuid} will be preempted");
151

152
        // Reset the sequence and re-queue for prefill
153
154
155
156
157
        let Request::Active(mut active_sequence) = request else {
            panic!("Expected ActiveSequence in running queue")
        };
        let signals = active_sequence.reset_with_signal();

158
159
        self.requests.insert(uuid, Request::Active(active_sequence));
        self.waiting.push_front(uuid);
160
161

        signals
162
163
164
    }
}

165
166
167
168
169
170
171
172
173
174
/// Cancels its token when dropped. Shared via Arc so the background task is
/// only cancelled when the last Scheduler clone is dropped.
struct CancelGuard(CancellationToken);

impl Drop for CancelGuard {
    fn drop(&mut self) {
        self.0.cancel();
    }
}

175
176
177
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
pub struct Scheduler {
178
    request_tx: mpsc::UnboundedSender<DirectRequest>,
179
    metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
180
    _cancel_guard: Arc<CancelGuard>,
181
182
183
184
185
}

impl Scheduler {
    /// Create a new Scheduler with the given parameters
    pub fn new(
186
        args: MockEngineArgs,
Yan Ru Pei's avatar
Yan Ru Pei committed
187
        dp_rank: u32,
188
        output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
189
        kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
190
191
        cancellation_token: Option<CancellationToken>,
    ) -> Self {
192
        args.validate().expect("invalid MockEngineArgs");
193

194
195
        // Create channel for request handling
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
196
197
198
199
        let initial_metrics = MockerMetrics {
            dp_rank,
            active_decode_blocks: 0,
        };
200
        let (metrics_tx, metrics_rx) =
201
            tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
202

203
204
205
        let cancel_token = cancellation_token.unwrap_or_default();
        let cancel_token_clone = cancel_token.clone();
        let cancel_guard = Arc::new(CancelGuard(cancel_token));
206
207
208

        // Spawn main background task with cancellation token
        tokio::spawn(async move {
209
            // Create state and kv_manager as local variables owned by this task
210
            let mut state = SchedulerState::default();
211
            let mut kv_manager = KvManager::new_with_event_sink(
Yan Ru Pei's avatar
Yan Ru Pei committed
212
213
                args.num_gpu_blocks,
                args.block_size,
214
                kv_event_sink,
Yan Ru Pei's avatar
Yan Ru Pei committed
215
216
217
                dp_rank,
            );
            let mut hit_rates = RunningMean::new(1000);
218
219

            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
220
                // 1. Receive requests
221
222
223
224
225
                if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
                    .await
                    .is_none()
                {
                    break;
226
                }
227

228
229
                // 2. Simulate prefill + decode
                simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
230

231
                simulate_decode(
232
233
234
235
236
                    &mut state,
                    &mut kv_manager,
                    &output_tx,
                    &args.perf_model,
                    args.block_size,
237
                    args.speedup_ratio,
238
                    args.preemption_mode,
239
240
                )
                .await;
241

242
                // 3. Send metrics once per forward pass (after all prefill and decode processing)
243
                let _ = metrics_tx.send(MockerMetrics {
244
                    dp_rank,
245
246
                    active_decode_blocks: kv_manager.num_active_blocks() as u64,
                });
247
248
249
250
251
            }
        });

        Self {
            request_tx,
252
            metrics_rx,
253
            _cancel_guard: cancel_guard,
254
255
256
        }
    }

257
    /// Add a new request to the prefill queue
258
    pub async fn receive(&self, request: DirectRequest) {
259
260
261
262
263
        let _ = self.request_tx.send(request);
    }

    pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
        self.request_tx.clone()
264
265
    }

266
    /// Get a watch receiver for forward pass metrics
267
    pub fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
268
269
270
        self.metrics_rx.clone()
    }
}
271

272
273
274
275
276
277
278
279
280
281
282
283
/// Receive requests from the channel.
/// Returns `Some(())` to continue the loop, `None` to break (on cancellation).
async fn receive_requests(
    state: &mut SchedulerState,
    request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
    cancel_token: &CancellationToken,
) -> Option<()> {
    if cancel_token.is_cancelled() {
        return None;
    }

    if state.is_empty() {
284
        // Fully idle - block until new request arrives or shutdown
285
286
287
288
289
        tokio::select! {
            biased;
            _ = cancel_token.cancelled() => {
                return None;
            }
290
291
292
293
            result = request_rx.recv() => {
                let Some(request) = result else {
                    return None; // channel closed
                };
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
                state.receive(request);
                return Some(());
            }
        }
    }

    // Has active/waiting work - collect any pending requests without blocking
    while let Ok(request) = request_rx.try_recv() {
        state.receive(request);
    }

    Some(())
}

/// Simulate prefill phase for all pending prefill requests.
309
310
311
312
313
///
/// Handles token budget, block allocation, and preemption inline.
/// Token budget: `max_num_batched_tokens - decode.len()` (1 token per decode request).
/// When blocks are unavailable, decode requests are preempted (LIFO by default)
/// to free capacity, matching vLLM v1 behavior.
314
async fn simulate_prefill(
315
316
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
317
318
    hit_rates: &mut RunningMean<f32>,
    args: &MockEngineArgs,
319
) -> Duration {
320
    let start_time = Instant::now();
321
322
    let mut total_time = Duration::ZERO;

323
324
325
326
327
328
329
330
    let mut token_budget = args
        .max_num_batched_tokens
        .map_or(usize::MAX, |t| t.saturating_sub(state.decode.len()));

    'prefill: while token_budget > 0 {
        // Drain prefill first, then pull from waiting one at a time
        if state.prefill.is_empty() && !state.admit_one(args) {
            break;
331
        }
332
        let uuid = state.prefill[0];
333

334
335
336
337
338
339
340
341
342
343
344
345
        let Some(Request::Active(seq)) = state.requests.get(&uuid) else {
            panic!("Request does not exist.");
        };
        let prefill_cost = kv_manager.get_prefill_cost(seq);
        let sequence_len = seq.len();
        let allocated_tokens = seq.num_allocated_tokens();
        let remaining = prefill_cost.new_tokens;

        // Token budget check
        let tokens_left = sequence_len - allocated_tokens;
        if !args.enable_chunked_prefill && tokens_left > token_budget {
            break;
346
        }
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
        let chunk = tokens_left.min(token_budget);
        let cumulative = allocated_tokens + chunk;

        // Allocate blocks. process() returns the number of blocks committed.
        // On partial success, preempt a decode request and retry — the next
        // loop iteration re-prepares from the updated num_allocated_tokens.
        let Some(Request::Active(seq)) = state.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        if let Some(signal) = seq.prepare_allocation(cumulative) {
            let expected = match &signal {
                MoveBlock::Use(blocks, ..) => blocks.len(),
                _ => unreachable!(),
            };
            let allocated = kv_manager.process(&signal);
            // Commit the blocks that were actually allocated
            let committed_tokens = if allocated == expected {
                cumulative
            } else {
                // Partial: compute token boundary from block count
                let prev_blocks = allocated_tokens
                    .div_ceil(seq.block_size())
                    .min(seq.unique_blocks().len());
                (prev_blocks + allocated) * seq.block_size()
            };
            seq.commit_allocation(committed_tokens.min(cumulative));
373

374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
            if allocated < expected {
                if state.decode.is_empty() {
                    break;
                }
                for signal in state.preempt(args.preemption_mode) {
                    kv_manager.process(&signal);
                }
                continue 'prefill; // retry with freed capacity
            }
        } else {
            seq.commit_allocation(cumulative);
        }

        // Accumulate prefill compute time (only for the new tokens in this chunk)
        let new_tokens_in_chunk = chunk.min(remaining);
        if args.worker_type != WorkerType::Decode && new_tokens_in_chunk > 0 {
            total_time += Duration::from_secs_f64(
                prefill_cost.predict_prefill_compute(Some(new_tokens_in_chunk), &args.perf_model)
                    / 1000.0,
            );
        }

        // Hit rate: fraction of tokens that were already cached
        let hit_rate = if sequence_len > 0 {
            1.0 - (remaining as f32 / sequence_len as f32)
        } else {
            0.0
        };
        hit_rates.push(hit_rate);

        token_budget -= chunk;

        if cumulative >= sequence_len {
            // Fully prefilled — promote to decode queue
            state.prefill.pop_front();
            state.decode.push_back(uuid);
        } else {
            // Partially prefilled — resume next iteration with updated allocated_tokens
412
413
414
            break;
        }
    }
415

416
417
    if args.speedup_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
418
419
        let deadline = start_time + sleep_duration;

420
        sleep_until_precise(deadline).await;
421
    }
422

423
424
425
426
427
    total_time
}

/// Simulate decode phase for all active decode requests.
/// Returns the total decode compute time.
428
async fn simulate_decode(
429
430
431
432
433
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
    perf_model: &PerfModel,
    block_size: usize,
434
    speedup_ratio: f64,
435
    preemption_mode: PreemptionMode,
436
) -> Duration {
437
438
    let start_time = Instant::now();

439
440
    // Compute decode timing
    let active_kv_tokens = kv_manager.num_active_blocks() * block_size;
441

442
    // Compute average context length across all active decode requests
443
    let total_length: usize = state
444
        .decode
445
        .iter()
446
447
448
        .map(|uuid| {
            if let Request::Active(seq) = state.requests.get(uuid).unwrap() {
                seq.len()
449
            } else {
450
                0
451
            }
452
453
454
455
        })
        .sum();
    let count = state.decode.len();

456
457
458
459
460
    let context_length = if count > 0 { total_length / count } else { 0 };
    let decoding_time = perf_model.predict_decode_time(active_kv_tokens, context_length);
    let total_time = Duration::from_secs_f64(decoding_time / 1000.0);

    // Process decoding
461
    let uuids: Vec<Uuid> = state.decode.iter().copied().collect();
462
    for uuid in uuids {
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
        // Try to generate; if allocation fails, preempt until it succeeds
        // or nothing is left to preempt (matches vLLM v1 scheduler loop).
        // Reborrow sequence each iteration so the mutable ref doesn't
        // conflict with state.preempt().
        let mut allocated = false;
        loop {
            let Some(sequence) = state.run(uuid) else {
                break;
            };
            let signals = sequence.generate();
            if process_signals(kv_manager, &signals) {
                allocated = true;
                break;
            }
            sequence.pop(); // revert the failed generation
478

479
480
481
482
483
484
            if state.decode.is_empty() {
                break;
            }

            // Preempt one request and free its blocks
            for signal in state.preempt(preemption_mode) {
485
486
                kv_manager.process(&signal);
            }
487
488
489
490
491
492
493
494

            // If the current request was the one preempted, stop retrying
            if !state.decode.contains(&uuid) {
                break;
            }
        }

        if !allocated {
495
496
497
            continue;
        }

498
499
500
501
        let Some(sequence) = state.run(uuid) else {
            continue;
        };

502
503
504
        // Check completion and send notification
        let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();

505
506
507
508
509
510
511
        let send_failed = output_tx.as_ref().is_some_and(|tx| {
            tx.send(OutputSignal {
                uuid,
                completed: is_complete,
            })
            .is_err()
        });
512
513
514
515
516
517
518
519
520
521
522

        if send_failed {
            for signal in &sequence.free_signal() {
                kv_manager.process(signal);
            }
        }

        if send_failed || is_complete {
            state.complete(&uuid);
        }
    }
523
524
525
526
527

    if speedup_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / speedup_ratio);
        let deadline = start_time + sleep_duration;

528
        sleep_until_precise(deadline).await;
529
    }
530

531
532
533
    total_time
}

534
535
536
537
538
539
540
/// Processes MoveBlock signals with the KvManager.
///
/// When a signal fails, this function verifies that the failure is for an expected case:
/// specifically a single signal attempting to create a single partial (generation) block.
/// This validation is important because in normal operation, the only legitimate failure
/// case should be when trying to acquire a new generation block - any other failures would
/// indicate an unexpected state in the system.
541
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
542
    for signal in signals {
543
        if kv_manager.process(signal) > 0 {
544
545
546
547
            continue;
        }

        // Check we have a Use signal with blocks
548
        let MoveBlock::Use(blocks, _hashes, ..) = signal else {
549
550
551
            panic!(
                "Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
            );
552
553
554
        };

        // Verify the signal contains exactly one block
555
        let num_blocks = blocks.len();
556
        let num_active_blocks = kv_manager.num_active_blocks();
557
558
559
560
        if num_blocks != 1 {
            panic!(
                "Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
            );
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
        }

        // Verify the block is a PartialBlock (generation block)
        if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
            panic!("Failed signal is Invalid. Generation block has to be partial.");
        }

        return false;
    }

    true
}

#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::time::Duration;
579
    use tokio::time::interval;
580

581
582
    /// Helper function to verify that the scheduler is idle (no active KV blocks)
    fn assert_scheduler_idle(metrics: &MockerMetrics) {
583
        assert_eq!(
584
            metrics.active_decode_blocks, 0,
585
            "Expected 0 active blocks, got {}",
586
            metrics.active_decode_blocks
587
588
589
        );
    }

590
    #[rstest]
591
592
593
594
595
596
597
598
    #[case::case_1(false, false, false)]
    #[case::case_2(false, true, false)]
    #[case::case_3(true, false, false)]
    #[case::case_4(true, true, false)]
    #[case::case_5(false, false, true)]
    #[case::case_6(false, true, true)]
    #[case::case_7(true, false, true)]
    #[case::case_8(true, true, true)]
599
    #[tokio::test]
600
601
602
    async fn test_scheduler_token_generation_patterns(
        #[case] use_shared_tokens: bool,
        #[case] enable_prefix_caching: bool,
603
        #[case] enable_chunked_prefill: bool,
604
    ) {
605
        unsafe { std::env::set_var("RUST_LOG", "debug") };
606
607

        let kv_capacity: usize = 500;
608
        let block_size: usize = 64;
609
        let num_requests: usize = 200;
610
611
612
613
        let input_len: usize = 1000;
        let max_output_tokens: usize = 100;

        // Create channel for token output
614
615
616
617
618
619
620
621
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args using builder - now including enable_prefix_caching
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(kv_capacity)
            .block_size(block_size)
            .speedup_ratio(10.0)
            .enable_prefix_caching(enable_prefix_caching)
622
            .enable_chunked_prefill(enable_chunked_prefill)
623
624
625
626
            .build()
            .unwrap();

        // Create scheduler with new args struct
Yan Ru Pei's avatar
Yan Ru Pei committed
627
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657

        // Create shared tokens for caching case
        let shared_tokens = if use_shared_tokens {
            Some(
                (0..input_len / 2)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>(),
            )
        } else {
            None
        };

        // Create test requests
        for _ in 0..num_requests {
            let input_tokens = if let Some(ref shared) = shared_tokens {
                // For caching case: use shared tokens for first half, random for second half
                let mut tokens = shared.clone();
                tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
                tokens
            } else {
                // For random case: create unique random token vector for each request
                (0..input_len)
                    .map(|_| rand::random::<u32>() % 50000)
                    .collect::<Vec<_>>()
            };

            let request = DirectRequest {
                tokens: input_tokens,
                max_output_tokens,
                uuid: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
658
                dp_rank: 0,
659
660
661
662
663
664
665
666
667
668
669
670
671
672
            };
            scheduler.receive(request).await;
        }

        let start_time = std::time::Instant::now();

        // Collect all generated tokens (should be num_requests * max_output_tokens)
        let expected_tokens = num_requests * max_output_tokens;
        let mut received_tokens = 0;

        // Set up a timeout that causes the test to panic if no tokens are received for 2 seconds
        let timeout = tokio::time::sleep(Duration::from_secs(2));
        tokio::pin!(timeout);

673
674
675
        // Get metrics receiver
        let metrics_rx = scheduler.metrics_receiver();

676
677
678
679
680
681
682
683
684
        // Set up debug ticker interval
        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                // Manual debug ticker that prints forward pass metrics
                _ = debug_interval.tick() => {
685
                    let _metrics = metrics_rx.borrow().clone();
686
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
                }

                Some(_) = output_rx.recv() => {
                    received_tokens += 1;
                    // Reset timeout whenever we receive a token
                    timeout.set(tokio::time::sleep(Duration::from_secs(2)));
                }

                _ = &mut timeout => {
                    // Break instead of panicking when timeout occurs
                    break;
                }
            }
        }

        // Calculate and print elapsed time
        let elapsed = start_time.elapsed();
        println!(
705
            "Test completed in: {elapsed:?} for {} case with prefix_caching={enable_prefix_caching} and chunked_prefill={enable_chunked_prefill}",
706
707
708
709
            if use_shared_tokens {
                "caching"
            } else {
                "random"
710
            }
711
712
713
714
        );

        // Assert that we received the expected number of tokens
        assert!(
715
716
717
            received_tokens == expected_tokens,
            "Received {received_tokens} tokens but expected exactly {expected_tokens}"
        );
718

719
720
        // Wait a bit for final metrics update to propagate
        tokio::time::sleep(Duration::from_millis(100)).await;
721

722
723
        let metrics = scheduler.metrics_receiver().borrow().clone();
        assert_scheduler_idle(&metrics);
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
    }

    #[tokio::test]
    async fn test_cache_hit_rate_with_identical_requests() {
        let block_size: usize = 64;
        let max_output_tokens: usize = 10;
        let speedup_ratio = 10.0;
        let num_requests = 10;
        let token_length = 65;

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(100) // Large enough to not be a constraint
            .block_size(block_size)
            .speedup_ratio(speedup_ratio)
            .build()
            .unwrap();

        // Create scheduler
Yan Ru Pei's avatar
Yan Ru Pei committed
746
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
747
748
749
750
751
752
753
754
755
756

        // Create identical tokens for all requests
        let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();

        // Send all requests with identical tokens
        for _ in 0..num_requests {
            let request = DirectRequest {
                tokens: identical_tokens.clone(),
                max_output_tokens,
                uuid: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
757
                dp_rank: 0,
758
759
760
761
762
763
764
765
766
767
768
769
770
            };
            scheduler.receive(request).await;
            // Sleep for 0.1 second after each request
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        // Collect all generated tokens
        let mut received_tokens = 0;

        // Set up a timeout that resets to 0.5 seconds on each received token
        let timeout = tokio::time::sleep(Duration::from_millis(500));
        tokio::pin!(timeout);

771
772
773
        // Get metrics receiver
        let metrics_rx = scheduler.metrics_receiver();

774
775
776
777
778
779
780
781
782
        // Set up debug ticker interval
        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                // Manual debug ticker that prints forward pass metrics
                _ = debug_interval.tick() => {
783
                    let _metrics = metrics_rx.borrow().clone();
784
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
                }

                Some(_signal) = output_rx.recv() => {
                    received_tokens += 1;
                    // Reset timeout whenever we receive a token
                    timeout.set(tokio::time::sleep(Duration::from_millis(500)));
                }

                _ = &mut timeout => {
                    // Break when timeout occurs (no more tokens for 0.5 seconds)
                    break;
                }
            }
        }

800
801
802
        // Wait a bit for final metrics update
        tokio::time::sleep(Duration::from_millis(100)).await;

803
        // Verify forward pass metrics - scheduler should be idle after completing all requests
804
        let metrics = metrics_rx.borrow().clone();
805
        assert_scheduler_idle(&metrics);
806

807
        println!("Test passed! Received {received_tokens} tokens");
808
809
    }

810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
    /// White-box unit test that directly creates SchedulerState + KvManager,
    /// manually invokes simulate_prefill / simulate_decode, and asserts on
    /// queue states and block counts after each step.
    ///
    /// Setup: 6 GPU blocks of 4 tokens, a 12-token batched budget, at most 3
    /// sequences, chunked prefill on and prefix caching off. Three requests
    /// (R1, R2, R3) are pushed through a scripted sequence of prefill/decode
    /// steps that exercises chunking, LIFO preemption, re-admission, and
    /// completion. The test then drains the scheduler and checks that every
    /// block is released.
    #[tokio::test]
    async fn test_scheduler_internal_state_transitions() {
        let args = MockEngineArgs::builder()
            .block_size(4)
            .num_gpu_blocks(6)
            .max_num_batched_tokens(Some(12))
            .max_num_seqs(Some(3))
            .enable_chunked_prefill(true)
            .enable_prefix_caching(false)
            .speedup_ratio(0.0)
            .build()
            .unwrap();

        let mut state = SchedulerState::default();
        let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
        let mut hit_rates = RunningMean::new(1000);
        let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;

        let r1_uuid = Uuid::from_u128(1);
        let r2_uuid = Uuid::from_u128(2);
        let r3_uuid = Uuid::from_u128(3);

        // ── Step 1: Receive 3 requests ──
        // R1: 8 input, 2 max_output → 2 blocks
        // R2: 8 input, 2 max_output → 2 blocks
        // R3: 12 input, 2 max_output → 3 blocks
        state.receive(DirectRequest {
            tokens: (0..8).collect(),
            max_output_tokens: 2,
            uuid: Some(r1_uuid),
            dp_rank: 0,
        });
        state.receive(DirectRequest {
            tokens: (100..108).collect(),
            max_output_tokens: 2,
            uuid: Some(r2_uuid),
            dp_rank: 0,
        });
        state.receive(DirectRequest {
            tokens: (200..212).collect(),
            max_output_tokens: 2,
            uuid: Some(r3_uuid),
            dp_rank: 0,
        });

        assert_eq!(state.waiting.len(), 3);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 0);
        assert_eq!(kv_manager.num_active_blocks(), 0);

        // ── Step 2: First simulate_prefill ──
        // Budget=12. R1 takes 8 tokens (2 blocks), fully prefilled → decode.
        // R2 takes 4 tokens (1 block, chunked), partially prefilled → stays in prefill.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 1);
        assert_eq!(state.prefill.len(), 1);
        assert_eq!(state.decode.len(), 1);
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(state.prefill[0], r2_uuid);
        assert_eq!(state.waiting[0], r3_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 3); // 2 for R1 + 1 for R2

        let seq = match state.requests.get(&r1_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 8);
        assert_eq!(seq.generated_tokens(), 0);

        let seq = match state.requests.get(&r2_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 4);
        assert_eq!(seq.generated_tokens(), 0);

        // ── Step 3: First simulate_decode ──
        // R1 generates 1 token, gains a partial block.
        simulate_decode(
            &mut state,
            &mut kv_manager,
            &output_tx,
            &args.perf_model,
            args.block_size,
            args.speedup_ratio,
            args.preemption_mode,
        )
        .await;

        assert_eq!(state.decode.len(), 1);
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 4); // +1 partial for R1

        let seq = match state.requests.get(&r1_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.generated_tokens(), 1);

        // ── Step 4: Second simulate_prefill ──
        // Budget=11. R2 finishes (4 more tokens, 1 block → active=5, decode).
        // R3 admitted, needs 2 blocks for chunk of 7. Only 1 free slot → partial.
        // Preempt R2 (LIFO) → R2 back to waiting. Retry R3 → evicts R2's
        // inactive blocks, allocates 2 more → R3 allocated_tokens=11.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 1, "R2 preempted back to waiting");
        assert_eq!(state.waiting[0], r2_uuid);
        assert_eq!(state.prefill.len(), 1, "R3 partially prefilled");
        assert_eq!(state.prefill[0], r3_uuid);
        assert_eq!(state.decode.len(), 1, "R1 still decoding");
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 6); // at capacity

        let seq = match state.requests.get(&r3_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 11);

        // ── Step 5: Second simulate_decode ──
        // R1 generates 2nd token → complete. Frees 3 blocks (1 destroyed, 2 deactivated).
        simulate_decode(
            &mut state,
            &mut kv_manager,
            &output_tx,
            &args.perf_model,
            args.block_size,
            args.speedup_ratio,
            args.preemption_mode,
        )
        .await;

        assert!(!state.requests.contains_key(&r1_uuid), "R1 completed");
        assert_eq!(state.decode.len(), 0);
        assert_eq!(state.prefill.len(), 1);
        assert_eq!(state.waiting.len(), 1);
        assert_eq!(kv_manager.num_active_blocks(), 3); // only R3's 3 blocks

        // ── Step 6: Third simulate_prefill ──
        // R3 finishes prefill (1 token left, no new blocks) → decode.
        // R2 re-admitted, fully prefilled (2 blocks via inactive eviction) → decode.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 0);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 2);
        assert!(state.decode.contains(&r3_uuid));
        assert!(state.decode.contains(&r2_uuid));
        assert_eq!(kv_manager.num_active_blocks(), 5); // 3 for R3 + 2 for R2

        // ── Steps 7+: Cycle until all requests complete ──
        // The loop is bounded so that a scheduling regression that strands a
        // request makes the test fail instead of hanging forever. Only R2 and R3
        // (max_output_tokens = 2 each) remain, so the cap is very generous.
        const MAX_DRAIN_CYCLES: usize = 100;
        for _ in 0..MAX_DRAIN_CYCLES {
            simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
            simulate_decode(
                &mut state,
                &mut kv_manager,
                &output_tx,
                &args.perf_model,
                args.block_size,
                args.speedup_ratio,
                args.preemption_mode,
            )
            .await;

            if state.is_empty() {
                break;
            }
        }
        assert!(
            state.is_empty(),
            "scheduler did not drain within {MAX_DRAIN_CYCLES} prefill/decode cycles"
        );

        assert_eq!(state.waiting.len(), 0);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 0);
        assert_eq!(kv_manager.num_active_blocks(), 0);
    }

990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
    #[tokio::test]
    async fn test_receiver_drop_cleans_up_resources() {
        let block_size: usize = 64;
        let input_tokens = 256;
        let max_output_tokens = 200; // More than we'll receive

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(10) // Enough for 256 tokens (4 blocks)
            .block_size(block_size)
            .speedup_ratio(100.0) // Fast simulation
            .build()
            .unwrap();

        // Create scheduler
Yan Ru Pei's avatar
Yan Ru Pei committed
1008
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
1009
1010
1011
1012
1013
1014
1015

        // Create request with 256 tokens
        let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
        let request = DirectRequest {
            tokens,
            max_output_tokens,
            uuid: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
1016
            dp_rank: 0,
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
        };

        scheduler.receive(request).await;

        // Receive exactly 129 tokens
        let mut received_count = 0;
        while received_count < 129 {
            if let Some(_signal) = output_rx.recv().await {
                received_count += 1;
            } else {
                panic!("Channel closed before receiving 129 tokens");
            }
        }

        // Drop the receiver immediately
        drop(output_rx);

        // Wait for 1 second to allow cleanup
        tokio::time::sleep(Duration::from_secs(1)).await;

        // Check forward pass metrics
1038
1039
        let metrics_rx = scheduler.metrics_receiver();
        let metrics = metrics_rx.borrow().clone();
1040

1041
        assert_scheduler_idle(&metrics);
1042
1043
    }
}