vllm.rs 31.9 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
8
9
10
11
// SPDX-License-Identifier: Apache-2.0

//! Asynchronous Scheduler for LLM Request Management
//!
//! This module implements an asynchronous scheduler that handles three main functions:
//! 1. Receiving new requests and placing them in the waiting queue
//! 2. Scheduling waiting requests against available KV cache resources
//! 3. Simulating the execution of running requests with realistic timing
//!
//! ## Scheduling Process
//! The scheduler checks direct block capacity to determine if there's sufficient
//! KV cache space for new requests. It also enforces a batched tokens budget to prevent
//! oversubscription of computational resources. Only requests that can be allocated
//! these resources are moved from waiting to running state.
//!
//! ## Request Simulation
//! The simulation models two key phases:
//! - Prefill phase: Uses a quadratic cost function: (cached_tokens + new_tokens) * new_tokens
//! - Decode phase: Uses a cost function proportional to active KV blocks (linear)
//!
//! ## Resource Management
//! The scheduler communicates with the KvManager through MoveBlock signals at each
//! stage of request processing. When resources become constrained, it employs a
//! preemption strategy (LIFO by default, matching vLLM v1) where a running request
//! is evicted and placed at the front of the waiting queue to be rescheduled later.
//!
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP

31
use crate::common::protocols::{
32
    DirectRequest, KvCacheEventSink, MockEngineArgs, MoveBlock, OutputSignal, PreemptionMode,
33
    WorkerType,
Yan Ru Pei's avatar
Yan Ru Pei committed
34
};
35
36
use crate::common::running_mean::RunningMean;
use crate::common::sequence::ActiveSequence;
37
use crate::common::utils::sleep_until_precise;
38
use crate::kv_manager::KvManager;
39
use dynamo_kv_router::protocols::DpRank;
40
use dynamo_tokens::blocks::UniqueBlock;
Yan Ru Pei's avatar
Yan Ru Pei committed
41
use std::collections::{HashMap, VecDeque};
42
use std::sync::Arc;
43
use std::time::Instant;
44
use tokio::sync::mpsc;
45
use tokio::time::Duration;
46
47
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
48
use validator::Validate;
49

50
51
52
53
54
55
56
/// Lightweight metrics snapshot used internally by the mocker.
#[derive(Debug, Clone, Default)]
pub struct MockerMetrics {
    /// Rank value the owning scheduler was constructed with.
    pub dp_rank: DpRank,
    /// Number of active KV blocks reported by the KV manager when the
    /// snapshot was taken.
    pub active_decode_blocks: u64,
}

57
58
59
60
61
62
63
64
65
/// Enum representing either a direct request or an active sequence.
///
/// A request is stored as `Direct` when it is first received and is replaced
/// by `Active` once it has been admitted for prefill.
pub enum Request {
    /// Request as submitted, not yet converted into a sequence.
    Direct(DirectRequest),
    /// Request converted into an `ActiveSequence` on admission.
    Active(ActiveSequence),
}

#[derive(Default)]
struct SchedulerState {
    waiting: VecDeque<Uuid>,
66
    prefill: VecDeque<Uuid>,
67
    decode: VecDeque<Uuid>,
68
69
70
71
    requests: HashMap<Uuid, Request>,
}

impl SchedulerState {
72
73
74
75
    fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }

76
77
78
79
80
81
82
    fn receive(&mut self, request: DirectRequest) -> Uuid {
        let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
        self.requests.insert(uuid, Request::Direct(request));
        self.waiting.push_back(uuid);
        uuid
    }

83
84
85
86
87
88
    /// Try to admit one request from waiting → prefill.
    /// Converts DirectRequest → ActiveSequence if needed. PrefillCost is computed
    /// later in simulate_prefill when the request reaches the front of the queue.
    fn admit_one(&mut self, args: &MockEngineArgs) -> bool {
        let Some(&uuid) = self.waiting.front() else {
            return false;
89
        };
90
91
92
93
        let num_active = self.prefill.len() + self.decode.len();
        if args.max_num_seqs.is_some_and(|limit| num_active >= limit) {
            return false;
        }
94

95
        self.waiting.pop_front();
96

97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
        // Convert DirectRequest → ActiveSequence if needed
        if let Some(Request::Direct(_)) = self.requests.get(&uuid) {
            let Some(Request::Direct(direct)) = self.requests.remove(&uuid) else {
                unreachable!()
            };
            self.requests.insert(
                uuid,
                Request::Active(ActiveSequence::new(
                    direct.tokens,
                    direct.max_output_tokens,
                    Some(args.block_size),
                    args.enable_prefix_caching,
                    args.zmq_kv_events_port.is_some(),
                )),
            );
        }
113

114
115
        self.prefill.push_back(uuid);
        true
116
117
    }

118
119
120
121
122
123
124
125
    fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
        if !self.decode.contains(&uuid) {
            return None;
        }
        let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        Some(sequence)
126
127
128
129
    }

    /// Remove a UUID and its associated Request from collections.
    fn complete(&mut self, uuid: &Uuid) {
130
        tracing::trace!("Request {uuid} will complete");
131
        self.decode.retain(|u| u != uuid);
132
133
134
        self.requests.remove(uuid);
    }

135
136
137
138
139
140
141
142
143
144
    /// Preempt a running request by evicting it from decode, resetting the sequence,
    /// and adding it back to the front of the waiting queue.
    /// In LIFO mode, evicts the newest request (matches vLLM v1).
    /// In FIFO mode, evicts the oldest request.
    fn preempt(&mut self, mode: PreemptionMode) -> Vec<MoveBlock> {
        let uuid = match mode {
            PreemptionMode::Lifo => self.decode.pop_back(),
            PreemptionMode::Fifo => self.decode.pop_front(),
        }
        .expect("Nothing to evict for preemption.");
145
146
147
148
        let request = self
            .requests
            .remove(&uuid)
            .expect("Request does not exist.");
149
        tracing::warn!("Request {uuid} will be preempted");
150

151
        // Reset the sequence and re-queue for prefill
152
153
154
155
156
        let Request::Active(mut active_sequence) = request else {
            panic!("Expected ActiveSequence in running queue")
        };
        let signals = active_sequence.reset_with_signal();

157
158
        self.requests.insert(uuid, Request::Active(active_sequence));
        self.waiting.push_front(uuid);
159
160

        signals
161
162
163
    }
}

164
165
166
167
168
169
170
171
172
173
/// Guard that cancels the wrapped token when it is dropped.
///
/// `Scheduler` keeps it behind an `Arc`, so the token is cancelled only once
/// the last `Scheduler` clone has been dropped.
struct CancelGuard(CancellationToken);

impl Drop for CancelGuard {
    fn drop(&mut self) {
        let Self(token) = self;
        token.cancel();
    }
}

174
175
176
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
pub struct Scheduler {
177
    request_tx: mpsc::UnboundedSender<DirectRequest>,
178
    metrics_rx: tokio::sync::watch::Receiver<MockerMetrics>,
179
    _cancel_guard: Arc<CancelGuard>,
180
181
182
183
184
}

impl Scheduler {
    /// Create a new Scheduler with the given parameters
    pub fn new(
185
        args: MockEngineArgs,
Yan Ru Pei's avatar
Yan Ru Pei committed
186
        dp_rank: u32,
187
        output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
188
        kv_event_sink: Option<Arc<dyn KvCacheEventSink>>,
189
190
        cancellation_token: Option<CancellationToken>,
    ) -> Self {
191
        args.validate().expect("invalid MockEngineArgs");
192

193
194
        // Create channel for request handling
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
195
196
197
198
        let initial_metrics = MockerMetrics {
            dp_rank,
            active_decode_blocks: 0,
        };
199
        let (metrics_tx, metrics_rx) =
200
            tokio::sync::watch::channel::<MockerMetrics>(initial_metrics);
201

202
203
204
        let cancel_token = cancellation_token.unwrap_or_default();
        let cancel_token_clone = cancel_token.clone();
        let cancel_guard = Arc::new(CancelGuard(cancel_token));
205
206
207

        // Spawn main background task with cancellation token
        tokio::spawn(async move {
208
            // Create state and kv_manager as local variables owned by this task
209
            let mut state = SchedulerState::default();
210
            let mut kv_manager = KvManager::new_with_event_sink(
Yan Ru Pei's avatar
Yan Ru Pei committed
211
212
                args.num_gpu_blocks,
                args.block_size,
213
                kv_event_sink,
Yan Ru Pei's avatar
Yan Ru Pei committed
214
215
216
                dp_rank,
            );
            let mut hit_rates = RunningMean::new(1000);
217
218

            loop {
Yan Ru Pei's avatar
Yan Ru Pei committed
219
                // 1. Receive requests
220
221
222
223
224
                if receive_requests(&mut state, &mut request_rx, &cancel_token_clone)
                    .await
                    .is_none()
                {
                    break;
225
                }
226

227
228
                // 2. Simulate prefill + decode
                simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
229

230
                simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
231

232
                // 3. Send metrics once per forward pass (after all prefill and decode processing)
233
                let _ = metrics_tx.send(MockerMetrics {
234
                    dp_rank,
235
236
                    active_decode_blocks: kv_manager.num_active_blocks() as u64,
                });
237
238
239
240
241
            }
        });

        Self {
            request_tx,
242
            metrics_rx,
243
            _cancel_guard: cancel_guard,
244
245
        }
    }
246
}
247

248
249
impl super::SchedulerHandle for Scheduler {
    fn receive(&self, request: DirectRequest) {
250
251
252
        let _ = self.request_tx.send(request);
    }

253
    fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
254
        self.request_tx.clone()
255
256
    }

257
    fn metrics_receiver(&self) -> tokio::sync::watch::Receiver<MockerMetrics> {
258
259
260
        self.metrics_rx.clone()
    }
}
261

262
263
264
265
266
267
268
269
270
271
272
273
/// Pulls incoming requests from the channel into the scheduler state.
///
/// Returns `Some(())` when the scheduling loop should continue and `None` when
/// it should stop: the token is cancelled, or the scheduler is idle and the
/// request channel has been closed.
///
/// While work is pending the channel is drained without blocking. When the
/// state is empty the call waits for either cancellation or the next request.
async fn receive_requests(
    state: &mut SchedulerState,
    request_rx: &mut mpsc::UnboundedReceiver<DirectRequest>,
    cancel_token: &CancellationToken,
) -> Option<()> {
    if cancel_token.is_cancelled() {
        return None;
    }

    if !state.is_empty() {
        // Has active/waiting work - pick up whatever is queued, never block.
        while let Ok(request) = request_rx.try_recv() {
            state.receive(request);
        }
        return Some(());
    }

    // Fully idle - wait for a request or shutdown, checking cancellation first.
    tokio::select! {
        biased;
        _ = cancel_token.cancelled() => None,
        // `None` from `recv` means the channel closed; it maps straight to `None`.
        incoming = request_rx.recv() => incoming.map(|request| {
            state.receive(request);
        }),
    }
}

/// Simulate prefill phase for all pending prefill requests.
299
300
301
302
303
///
/// Handles token budget, block allocation, and preemption inline.
/// Token budget: `max_num_batched_tokens - decode.len()` (1 token per decode request).
/// When blocks are unavailable, decode requests are preempted (LIFO by default)
/// to free capacity, matching vLLM v1 behavior.
304
async fn simulate_prefill(
305
306
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
307
308
    hit_rates: &mut RunningMean<f32>,
    args: &MockEngineArgs,
309
) -> Duration {
310
    let start_time = Instant::now();
311
312
    let mut total_time = Duration::ZERO;

313
314
315
316
317
318
319
320
    let mut token_budget = args
        .max_num_batched_tokens
        .map_or(usize::MAX, |t| t.saturating_sub(state.decode.len()));

    'prefill: while token_budget > 0 {
        // Drain prefill first, then pull from waiting one at a time
        if state.prefill.is_empty() && !state.admit_one(args) {
            break;
321
        }
322
        let uuid = state.prefill[0];
323

324
325
326
327
328
329
330
331
332
333
334
335
        let Some(Request::Active(seq)) = state.requests.get(&uuid) else {
            panic!("Request does not exist.");
        };
        let prefill_cost = kv_manager.get_prefill_cost(seq);
        let sequence_len = seq.len();
        let allocated_tokens = seq.num_allocated_tokens();
        let remaining = prefill_cost.new_tokens;

        // Token budget check
        let tokens_left = sequence_len - allocated_tokens;
        if !args.enable_chunked_prefill && tokens_left > token_budget {
            break;
336
        }
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
        let chunk = tokens_left.min(token_budget);
        let cumulative = allocated_tokens + chunk;

        // Allocate blocks. process() returns the number of blocks committed.
        // On partial success, preempt a decode request and retry — the next
        // loop iteration re-prepares from the updated num_allocated_tokens.
        let Some(Request::Active(seq)) = state.requests.get_mut(&uuid) else {
            panic!("Request does not exist.");
        };
        if let Some(signal) = seq.prepare_allocation(cumulative) {
            let expected = match &signal {
                MoveBlock::Use(blocks, ..) => blocks.len(),
                _ => unreachable!(),
            };
            let allocated = kv_manager.process(&signal);
            // Commit the blocks that were actually allocated
            let committed_tokens = if allocated == expected {
                cumulative
            } else {
                // Partial: compute token boundary from block count
                let prev_blocks = allocated_tokens
                    .div_ceil(seq.block_size())
                    .min(seq.unique_blocks().len());
                (prev_blocks + allocated) * seq.block_size()
            };
            seq.commit_allocation(committed_tokens.min(cumulative));
363

364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
            if allocated < expected {
                if state.decode.is_empty() {
                    break;
                }
                for signal in state.preempt(args.preemption_mode) {
                    kv_manager.process(&signal);
                }
                continue 'prefill; // retry with freed capacity
            }
        } else {
            seq.commit_allocation(cumulative);
        }

        // Accumulate prefill compute time (only for the new tokens in this chunk)
        let new_tokens_in_chunk = chunk.min(remaining);
        if args.worker_type != WorkerType::Decode && new_tokens_in_chunk > 0 {
            total_time += Duration::from_secs_f64(
                prefill_cost.predict_prefill_compute(Some(new_tokens_in_chunk), &args.perf_model)
                    / 1000.0,
            );
        }

        // Hit rate: fraction of tokens that were already cached
        let hit_rate = if sequence_len > 0 {
            1.0 - (remaining as f32 / sequence_len as f32)
        } else {
            0.0
        };
        hit_rates.push(hit_rate);

        token_budget -= chunk;

        if cumulative >= sequence_len {
            // Fully prefilled — promote to decode queue
            state.prefill.pop_front();
            state.decode.push_back(uuid);
        } else {
            // Partially prefilled — resume next iteration with updated allocated_tokens
402
403
404
            break;
        }
    }
405

406
407
    if args.speedup_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
408
409
        let deadline = start_time + sleep_duration;

410
        sleep_until_precise(deadline).await;
411
    }
412

413
414
415
416
417
    total_time
}

/// Simulate decode phase for all active decode requests.
/// Returns the total decode compute time.
418
async fn simulate_decode(
419
420
421
    state: &mut SchedulerState,
    kv_manager: &mut KvManager,
    output_tx: &Option<mpsc::UnboundedSender<OutputSignal>>,
422
    args: &MockEngineArgs,
423
) -> Duration {
424
425
    let start_time = Instant::now();

426
    // Compute decode timing
427
    let active_kv_tokens = kv_manager.num_active_blocks() * args.block_size;
428

429
    // Compute average context length across all active decode requests
430
    let total_length: usize = state
431
        .decode
432
        .iter()
433
434
435
        .map(|uuid| {
            if let Request::Active(seq) = state.requests.get(uuid).unwrap() {
                seq.len()
436
            } else {
437
                0
438
            }
439
440
441
442
        })
        .sum();
    let count = state.decode.len();

443
    let context_length = if count > 0 { total_length / count } else { 0 };
444
445
446
    let decoding_time = args
        .perf_model
        .predict_decode_time(active_kv_tokens, context_length);
447
448
449
    let total_time = Duration::from_secs_f64(decoding_time / 1000.0);

    // Process decoding
450
    let uuids: Vec<Uuid> = state.decode.iter().copied().collect();
451
    for uuid in uuids {
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
        // Try to generate; if allocation fails, preempt until it succeeds
        // or nothing is left to preempt (matches vLLM v1 scheduler loop).
        // Reborrow sequence each iteration so the mutable ref doesn't
        // conflict with state.preempt().
        let mut allocated = false;
        loop {
            let Some(sequence) = state.run(uuid) else {
                break;
            };
            let signals = sequence.generate();
            if process_signals(kv_manager, &signals) {
                allocated = true;
                break;
            }
            sequence.pop(); // revert the failed generation
467

468
469
470
471
472
            if state.decode.is_empty() {
                break;
            }

            // Preempt one request and free its blocks
473
            for signal in state.preempt(args.preemption_mode) {
474
475
                kv_manager.process(&signal);
            }
476
477
478
479
480
481
482
483

            // If the current request was the one preempted, stop retrying
            if !state.decode.contains(&uuid) {
                break;
            }
        }

        if !allocated {
484
485
486
            continue;
        }

487
488
489
490
        let Some(sequence) = state.run(uuid) else {
            continue;
        };

491
492
493
        // Check completion and send notification
        let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();

494
495
496
497
498
499
500
        let send_failed = output_tx.as_ref().is_some_and(|tx| {
            tx.send(OutputSignal {
                uuid,
                completed: is_complete,
            })
            .is_err()
        });
501
502
503
504
505
506
507
508
509
510
511

        if send_failed {
            for signal in &sequence.free_signal() {
                kv_manager.process(signal);
            }
        }

        if send_failed || is_complete {
            state.complete(&uuid);
        }
    }
512

513
514
515
    let effective_ratio = args.speedup_ratio * args.decode_speedup_ratio;
    if effective_ratio > 0.0 && total_time > Duration::ZERO {
        let sleep_duration = Duration::from_secs_f64(total_time.as_secs_f64() / effective_ratio);
516
517
        let deadline = start_time + sleep_duration;

518
        sleep_until_precise(deadline).await;
519
    }
520

521
522
523
    total_time
}

524
525
526
527
528
529
530
/// Processes MoveBlock signals with the KvManager.
///
/// When a signal fails, this function verifies that the failure is for an expected case:
/// specifically a single signal attempting to create a single partial (generation) block.
/// This validation is important because in normal operation, the only legitimate failure
/// case should be when trying to acquire a new generation block - any other failures would
/// indicate an unexpected state in the system.
531
fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
532
    for signal in signals {
533
        if kv_manager.process(signal) > 0 {
534
535
536
537
            continue;
        }

        // Check we have a Use signal with blocks
538
        let MoveBlock::Use(blocks, _hashes, ..) = signal else {
539
540
541
            panic!(
                "Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
            );
542
543
544
        };

        // Verify the signal contains exactly one block
545
        let num_blocks = blocks.len();
546
        let num_active_blocks = kv_manager.num_active_blocks();
547
548
549
550
        if num_blocks != 1 {
            panic!(
                "Failed signal is Invalid. Tried to create (prefill) {num_blocks} blocks on top of {num_active_blocks} active blocks."
            );
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
        }

        // Verify the block is a PartialBlock (generation block)
        if !matches!(blocks[0], UniqueBlock::PartialBlock(_)) {
            panic!("Failed signal is Invalid. Generation block has to be partial.");
        }

        return false;
    }

    true
}

#[cfg(test)]
mod tests {
    use super::*;
567
    use crate::scheduler::SchedulerHandle;
568
569
    use rstest::rstest;
    use std::time::Duration;
570
    use tokio::time::interval;
571

572
573
    /// Helper function to verify that the scheduler is idle (no active KV blocks)
    fn assert_scheduler_idle(metrics: &MockerMetrics) {
574
        assert_eq!(
575
            metrics.active_decode_blocks, 0,
576
            "Expected 0 active blocks, got {}",
577
            metrics.active_decode_blocks
578
579
580
        );
    }

581
    #[rstest]
582
583
584
585
586
587
588
589
    #[case::case_1(false, false, false)]
    #[case::case_2(false, true, false)]
    #[case::case_3(true, false, false)]
    #[case::case_4(true, true, false)]
    #[case::case_5(false, false, true)]
    #[case::case_6(false, true, true)]
    #[case::case_7(true, false, true)]
    #[case::case_8(true, true, true)]
590
    #[tokio::test]
591
592
593
    async fn test_scheduler_token_generation_patterns(
        #[case] use_shared_tokens: bool,
        #[case] enable_prefix_caching: bool,
594
        #[case] enable_chunked_prefill: bool,
595
596
597
598
    ) {
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        let args = MockEngineArgs::builder()
599
600
            .num_gpu_blocks(500)
            .block_size(64)
601
602
            .speedup_ratio(10.0)
            .enable_prefix_caching(enable_prefix_caching)
603
            .enable_chunked_prefill(enable_chunked_prefill)
604
605
606
            .build()
            .unwrap();

Yan Ru Pei's avatar
Yan Ru Pei committed
607
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
608

609
610
611
612
613
614
615
616
617
        crate::scheduler::test_utils::assert_scheduler_completes_all(
            &scheduler,
            &mut output_rx,
            200,
            1000,
            100,
            use_shared_tokens,
        )
        .await;
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
    }

    #[tokio::test]
    async fn test_cache_hit_rate_with_identical_requests() {
        let block_size: usize = 64;
        let max_output_tokens: usize = 10;
        let speedup_ratio = 10.0;
        let num_requests = 10;
        let token_length = 65;

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(100) // Large enough to not be a constraint
            .block_size(block_size)
            .speedup_ratio(speedup_ratio)
            .build()
            .unwrap();

        // Create scheduler
Yan Ru Pei's avatar
Yan Ru Pei committed
640
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
641
642
643
644
645
646
647
648
649
650

        // Create identical tokens for all requests
        let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();

        // Send all requests with identical tokens
        for _ in 0..num_requests {
            let request = DirectRequest {
                tokens: identical_tokens.clone(),
                max_output_tokens,
                uuid: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
651
                dp_rank: 0,
652
            };
653
            scheduler.receive(request);
654
655
656
657
658
659
660
661
662
663
664
            // Sleep for 0.1 second after each request
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        // Collect all generated tokens
        let mut received_tokens = 0;

        // Set up a timeout that resets to 0.5 seconds on each received token
        let timeout = tokio::time::sleep(Duration::from_millis(500));
        tokio::pin!(timeout);

665
666
667
        // Get metrics receiver
        let metrics_rx = scheduler.metrics_receiver();

668
669
670
671
672
673
674
675
676
        // Set up debug ticker interval
        let mut debug_interval = interval(Duration::from_millis(500));

        loop {
            tokio::select! {
                biased;

                // Manual debug ticker that prints forward pass metrics
                _ = debug_interval.tick() => {
677
                    let _metrics = metrics_rx.borrow().clone();
678
                    tracing::debug!("Forward Pass Metrics: {_metrics:#?}");
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
                }

                Some(_signal) = output_rx.recv() => {
                    received_tokens += 1;
                    // Reset timeout whenever we receive a token
                    timeout.set(tokio::time::sleep(Duration::from_millis(500)));
                }

                _ = &mut timeout => {
                    // Break when timeout occurs (no more tokens for 0.5 seconds)
                    break;
                }
            }
        }

694
695
696
        // Wait a bit for final metrics update
        tokio::time::sleep(Duration::from_millis(100)).await;

697
        // Verify forward pass metrics - scheduler should be idle after completing all requests
698
        let metrics = metrics_rx.borrow().clone();
699
        assert_scheduler_idle(&metrics);
700

701
        println!("Test passed! Received {received_tokens} tokens");
702
703
    }

704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
    /// White-box unit test that directly creates SchedulerState + KvManager,
    /// manually invokes simulate_prefill / simulate_decode, and asserts on
    /// queue states and block counts after each step.
    #[tokio::test]
    async fn test_scheduler_internal_state_transitions() {
        let args = MockEngineArgs::builder()
            .block_size(4)
            .num_gpu_blocks(6)
            .max_num_batched_tokens(Some(12))
            .max_num_seqs(Some(3))
            .enable_chunked_prefill(true)
            .enable_prefix_caching(false)
            .speedup_ratio(0.0)
            .build()
            .unwrap();

        let mut state = SchedulerState::default();
        let mut kv_manager = KvManager::new(args.num_gpu_blocks, args.block_size);
        let mut hit_rates = RunningMean::new(1000);
        let output_tx: Option<mpsc::UnboundedSender<OutputSignal>> = None;

        let r1_uuid = Uuid::from_u128(1);
        let r2_uuid = Uuid::from_u128(2);
        let r3_uuid = Uuid::from_u128(3);

        // ── Step 1: Receive 3 requests ──
        // R1: 8 input, 2 max_output → 2 blocks
        // R2: 8 input, 2 max_output → 2 blocks
        // R3: 12 input, 2 max_output → 3 blocks
        state.receive(DirectRequest {
            tokens: (0..8).collect(),
            max_output_tokens: 2,
            uuid: Some(r1_uuid),
            dp_rank: 0,
        });
        state.receive(DirectRequest {
            tokens: (100..108).collect(),
            max_output_tokens: 2,
            uuid: Some(r2_uuid),
            dp_rank: 0,
        });
        state.receive(DirectRequest {
            tokens: (200..212).collect(),
            max_output_tokens: 2,
            uuid: Some(r3_uuid),
            dp_rank: 0,
        });

        assert_eq!(state.waiting.len(), 3);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 0);
        assert_eq!(kv_manager.num_active_blocks(), 0);

        // ── Step 2: First simulate_prefill ──
        // Budget=12. R1 takes 8 tokens (2 blocks), fully prefilled → decode.
        // R2 takes 4 tokens (1 block, chunked), partially prefilled → stays in prefill.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 1);
        assert_eq!(state.prefill.len(), 1);
        assert_eq!(state.decode.len(), 1);
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(state.prefill[0], r2_uuid);
        assert_eq!(state.waiting[0], r3_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 3); // 2 for R1 + 1 for R2

        let seq = match state.requests.get(&r1_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 8);
        assert_eq!(seq.generated_tokens(), 0);

        let seq = match state.requests.get(&r2_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 4);
        assert_eq!(seq.generated_tokens(), 0);

        // ── Step 3: First simulate_decode ──
        // R1 generates 1 token, gains a partial block.
786
        simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820

        assert_eq!(state.decode.len(), 1);
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 4); // +1 partial for R1

        let seq = match state.requests.get(&r1_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.generated_tokens(), 1);

        // ── Step 4: Second simulate_prefill ──
        // Budget=11. R2 finishes (4 more tokens, 1 block → active=5, decode).
        // R3 admitted, needs 2 blocks for chunk of 7. Only 1 free slot → partial.
        // Preempt R2 (LIFO) → R2 back to waiting. Retry R3 → evicts R2's
        // inactive blocks, allocates 2 more → R3 allocated_tokens=11.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 1, "R2 preempted back to waiting");
        assert_eq!(state.waiting[0], r2_uuid);
        assert_eq!(state.prefill.len(), 1, "R3 partially prefilled");
        assert_eq!(state.prefill[0], r3_uuid);
        assert_eq!(state.decode.len(), 1, "R1 still decoding");
        assert_eq!(state.decode[0], r1_uuid);
        assert_eq!(kv_manager.num_active_blocks(), 6); // at capacity

        let seq = match state.requests.get(&r3_uuid).unwrap() {
            Request::Active(s) => s,
            _ => panic!("expected ActiveSequence"),
        };
        assert_eq!(seq.num_allocated_tokens(), 11);

        // ── Step 5: Second simulate_decode ──
        // R1 generates 2nd token → complete. Frees 3 blocks (1 destroyed, 2 deactivated).
        simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;

        assert!(!state.requests.contains_key(&r1_uuid), "R1 completed");
        assert_eq!(state.decode.len(), 0);
        assert_eq!(state.prefill.len(), 1);
        assert_eq!(state.waiting.len(), 1);
        assert_eq!(kv_manager.num_active_blocks(), 3); // only R3's 3 blocks

        // ── Step 6: Third simulate_prefill ──
        // R3 finishes prefill (1 token left, no new blocks) → decode.
        // R2 re-admitted, fully prefilled (2 blocks via inactive eviction) → decode.
        simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;

        assert_eq!(state.waiting.len(), 0);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 2);
        assert!(state.decode.contains(&r3_uuid));
        assert!(state.decode.contains(&r2_uuid));
        assert_eq!(kv_manager.num_active_blocks(), 5); // 3 for R3 + 2 for R2

        // ── Steps 7+: Cycle until all requests complete ──
        loop {
            simulate_prefill(&mut state, &mut kv_manager, &mut hit_rates, &args).await;
            simulate_decode(&mut state, &mut kv_manager, &output_tx, &args).await;

            if state.is_empty() {
                break;
            }
        }

        assert_eq!(state.waiting.len(), 0);
        assert_eq!(state.prefill.len(), 0);
        assert_eq!(state.decode.len(), 0);
        assert_eq!(kv_manager.num_active_blocks(), 0);
    }

    #[tokio::test]
    async fn test_receiver_drop_cleans_up_resources() {
        let block_size: usize = 64;
        let input_tokens = 256;
        let max_output_tokens = 200; // More than we'll receive

        // Create channel for token output
        let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();

        // Create scheduler args
        let args = MockEngineArgs::builder()
            .num_gpu_blocks(10) // Enough for 256 tokens (4 blocks)
            .block_size(block_size)
            .speedup_ratio(100.0) // Fast simulation
            .build()
            .unwrap();

        // Create scheduler
Yan Ru Pei's avatar
Yan Ru Pei committed
875
        let scheduler = Scheduler::new(args, 0, Some(output_tx), None, None);
876
877
878
879
880
881
882

        // Create request with 256 tokens
        let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
        let request = DirectRequest {
            tokens,
            max_output_tokens,
            uuid: None,
Yan Ru Pei's avatar
Yan Ru Pei committed
883
            dp_rank: 0,
884
885
        };

886
        scheduler.receive(request);
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904

        // Receive exactly 129 tokens
        let mut received_count = 0;
        while received_count < 129 {
            if let Some(_signal) = output_rx.recv().await {
                received_count += 1;
            } else {
                panic!("Channel closed before receiving 129 tokens");
            }
        }

        // Drop the receiver immediately
        drop(output_rx);

        // Wait for 1 second to allow cleanup
        tokio::time::sleep(Duration::from_secs(1)).await;

        // Check forward pass metrics
905
906
        let metrics_rx = scheduler.metrics_receiver();
        let metrics = metrics_rx.borrow().clone();
907

908
        assert_scheduler_idle(&metrics);
909
910
    }
}