single.rs 30.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! KV Cache Sequence Management for LLM Inference
//!
//! This module provides efficient management of token sequences and their associated KV cache blocks
//! for distributed LLM inference. It implements a shared block system where multiple requests can
//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage.
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their
//!   token sequences, managing shared KV cache blocks efficiently.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).

use derive_getters::Getters;
use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;

29
30
31
32
use super::block_tracker::BlockTracker;
use super::prefill_tracker::{PrefillLoadState, PrefillLoadTracker};
use crate::protocols::PrefillLoadHint;

33
/// Age after which an active request is considered stale and may be expired (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(300);

/// How often we *check* for stale requests (30 seconds). This is not
/// the expiration time, that is [`EXPIRY_DURATION`].
const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30);

// TODO: use the common request_id if it exists in the repo
/// Identifier used to key active requests.
pub type RequestId = String;

#[derive(Debug)]
pub(super) struct RequestState {
    blocks: Vec<(SequenceHash, Arc<()>)>,
    started_at: Instant,
    prefill: Option<PrefillLoadState>,
    expected_output_tokens: Option<u32>,
}

51
52
53
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
54
55
56
    requests: HashMap<RequestId, RequestState>,
    prefill: PrefillLoadTracker,
    blocks: BlockTracker,
57
58
59
60

    #[getter(copy)]
    block_size: usize,

61
    last_expiry_check_time: Instant,
62
63
64
65
66
67
68
69
}

impl ActiveSequences {
    /// Create a new SharedSequenceManager instance
    pub fn new(block_size: usize) -> Self {
        assert!(block_size > 1, "block_size must be greater than 1");

        Self {
70
71
72
            requests: HashMap::new(),
            prefill: PrefillLoadTracker::default(),
            blocks: BlockTracker::default(),
73
            block_size,
74
            last_expiry_check_time: Instant::now(),
75
76
77
        }
    }

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
    #[cfg(any(test, debug_assertions))]
    fn assert_consistent(&self) {
        let active_prefills: HashSet<RequestId> = self
            .requests
            .iter()
            .filter(|(_, state)| state.prefill.is_some())
            .map(|(request_id, _)| request_id.clone())
            .collect();
        let ordered_prefills: HashSet<RequestId> =
            self.prefill.prefill_order.iter().cloned().collect();
        let recomputed_prefill_sum: usize = self
            .requests
            .values()
            .filter_map(|state| state.prefill)
            .map(|prefill| prefill.initial_effective_prefill_tokens)
            .sum();
        assert_eq!(
            ordered_prefills.len(),
            self.prefill.prefill_order.len(),
            "prefill_order contains duplicate request ids",
        );
        assert_eq!(
            ordered_prefills, active_prefills,
            "prefill_order must match requests with active prefill load",
        );
        assert_eq!(
            self.prefill.prefill_full_tokens_sum, recomputed_prefill_sum,
            "prefill_full_tokens_sum drifted from request state",
        );
        if let Some(oldest_request_id) = self.prefill.prefill_order.front() {
            let Some((anchored_request_id, _)) = self.prefill.anchored_prefill.as_ref() else {
                panic!("anchored_prefill must exist when prefill_order is non-empty");
            };
            assert!(
                self.requests
                    .get(oldest_request_id)
                    .is_some_and(|state| state.prefill.is_some()),
                "prefill_order front must point to an active prefill request",
            );
            assert_eq!(
                anchored_request_id, oldest_request_id,
                "anchored_prefill must match prefill_order.front()",
            );
        } else {
            assert!(
                self.prefill.anchored_prefill.is_none(),
                "anchored_prefill must be absent when no active prefills remain",
            );
126
        }
127
128
129
130
131
132
133
        assert!(
            self.blocks
                .fractional_blocks
                .keys()
                .all(|hash| self.blocks.unique_blocks.contains_key(hash)),
            "fractional_blocks cannot reference non-active blocks",
        );
134
135
    }

136
137
138
139
    #[inline]
    fn validate_state(&self) {
        #[cfg(any(test, debug_assertions))]
        self.assert_consistent();
140
141
142
    }

    pub fn active_blocks(&self) -> usize {
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
        self.blocks.active_blocks()
    }

    fn insert_prefill_load(
        &mut self,
        request_id: &RequestId,
        prefill: PrefillLoadState,
        decay_now: Instant,
    ) {
        self.prefill.insert(request_id, prefill, decay_now);
    }

    fn remove_prefill_load(
        &mut self,
        request_id: &RequestId,
        decay_now: Instant,
    ) -> Option<PrefillLoadState> {
        let prefill = {
            let state = self.requests.get_mut(request_id)?;
            state.prefill.take()?
        };
        self.prefill.remove(request_id, prefill, decay_now);
        Some(prefill)
    }

    fn active_prefill_tokens_at(&self, now: Instant) -> usize {
        let Some((oldest_request_id, oldest_since)) = self.prefill.anchored_prefill.as_ref() else {
            return 0;
        };
        let prefill = self
            .requests
            .get(oldest_request_id)
            .and_then(|state| state.prefill)
            .expect("prefill_order front missing prefill load");
        let oldest_full = prefill.initial_effective_prefill_tokens;
        let oldest_remaining = match prefill.expected_prefill_duration {
            None => oldest_full,
            Some(expected_prefill_duration) if expected_prefill_duration.is_zero() => 0,
            Some(expected_prefill_duration) => {
                let elapsed = now.saturating_duration_since(*oldest_since);
                let remaining_fraction = (1.0
                    - (elapsed.as_secs_f64() / expected_prefill_duration.as_secs_f64()))
                .clamp(0.0, 1.0);
                ((oldest_full as f64) * remaining_fraction).ceil() as usize
187
            }
188
189
190
191
192
193
194
195
196
197
198
        };

        self.prefill
            .prefill_full_tokens_sum
            .checked_sub(oldest_full)
            .expect("prefill_full_tokens_sum smaller than oldest load")
            + oldest_remaining
    }

    pub fn active_tokens(&self, decay_now: Instant) -> usize {
        self.active_prefill_tokens_at(decay_now)
199
200
201
202
203
    }

    /// Find all blocks in a request that have only a single strong reference (only used by this request)
    /// and insert them into fractional_blocks with the given fraction value.
    pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
204
        let Some(request_state) = self.requests.get(request_id) else {
205
206
207
208
209
210
            tracing::warn!(
                "Request {request_id} not found for set_single_ref_blocks_as_fractional"
            );
            return;
        };

211
        for (hash, rc) in &request_state.blocks {
212
            if Arc::strong_count(rc) == 1 {
213
                self.blocks.fractional_blocks.insert(*hash, fraction);
214
215
216
217
            }
        }
    }

218
219
    /// Add a new request with its initial tokens.
    /// Returns the set of expired request IDs that were removed during cleanup.
220
221
222
223
224
225
226
    pub fn add_request(
        &mut self,
        request_id: RequestId,
        token_sequence: Option<Vec<SequenceHash>>,
        isl: usize,
        overlap: u32,
        expected_output_tokens: Option<u32>,
227
        decay_now: Instant,
228
229
230
231
232
233
234
235
    ) -> HashSet<RequestId> {
        self.add_request_with_prefill_tracking(
            request_id,
            token_sequence,
            isl,
            overlap,
            expected_output_tokens,
            true,
236
237
            None,
            decay_now,
238
239
240
241
242
        )
    }

    /// Add a new request with optional prompt-token load accounting.
    /// Returns the set of expired request IDs that were removed during cleanup.
243
    #[allow(clippy::too_many_arguments)]
244
245
246
247
248
249
250
251
    pub fn add_request_with_prefill_tracking(
        &mut self,
        request_id: RequestId,
        token_sequence: Option<Vec<SequenceHash>>,
        isl: usize,
        overlap: u32,
        expected_output_tokens: Option<u32>,
        track_prefill_tokens: bool,
252
253
        prefill_load_hint: Option<PrefillLoadHint>,
        decay_now: Instant,
254
    ) -> HashSet<RequestId> {
255
        if self.requests.contains_key(&request_id) {
256
257
258
259
260
            tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
            return HashSet::new();
        }

        let removed_requests = self.force_expiry();
261
262
263
264
265
266
267
268
269
270
271
272
        let started_at = Instant::now();

        let blocks = match token_sequence {
            Some(sequence) => sequence
                .into_iter()
                .map(|block| {
                    let rc = self.blocks.touch_block(&block);
                    (block, rc)
                })
                .collect(),
            None => Vec::new(),
        };
273

274
275
276
277
278
279
280
281
282
283
284
        let prefill = if track_prefill_tokens {
            let default_tokens = self.new_tokens(isl, overlap);
            let hint = prefill_load_hint.unwrap_or(PrefillLoadHint {
                initial_effective_prefill_tokens: default_tokens,
                expected_prefill_duration: None,
            });

            (hint.initial_effective_prefill_tokens > 0).then_some(PrefillLoadState {
                initial_effective_prefill_tokens: hint.initial_effective_prefill_tokens,
                expected_prefill_duration: hint.expected_prefill_duration,
            })
285
        } else {
286
            None
287
        };
288

289
290
291
292
293
294
295
296
297
298
299
300
        self.requests.insert(
            request_id.clone(),
            RequestState {
                blocks,
                started_at,
                prefill,
                expected_output_tokens,
            },
        );

        if let Some(prefill) = prefill {
            self.insert_prefill_load(&request_id, prefill, decay_now);
301
302
        }

303
        self.validate_state();
304
305
306
        removed_requests
    }

307
308
309
310
    /// Mark prefill as completed for a request, removing it from prompt-load tracking.
    pub fn mark_prefill_completed(&mut self, request_id: &RequestId, decay_now: Instant) {
        let _ = self.remove_prefill_load(request_id, decay_now);
        self.validate_state();
311
312
313
314
    }

    pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
        let cached_tokens = (overlap as usize) * self.block_size;
315
316
317
318
319
320
321
        isl.checked_sub(cached_tokens).unwrap_or_else(|| {
            tracing::error!(
                "prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
                self.block_size
            );
            0
        })
322
323
324
325
326
327
328
    }

    pub fn potential_blocks_and_tokens(
        &self,
        token_sequence: Option<&[SequenceHash]>,
        isl: usize,
        overlap: u32,
329
        decay_now: Instant,
330
    ) -> (usize, usize) {
331
332
333
334
335
336
337
        self.potential_blocks_and_tokens_with_prefill_tracking(
            token_sequence,
            isl,
            overlap,
            true,
            decay_now,
        )
338
339
340
341
342
343
344
345
    }

    pub fn potential_blocks_and_tokens_with_prefill_tracking(
        &self,
        token_sequence: Option<&[SequenceHash]>,
        isl: usize,
        overlap: u32,
        track_prefill_tokens: bool,
346
        decay_now: Instant,
347
348
349
350
351
352
    ) -> (usize, usize) {
        let potential_blocks = if let Some(token_seq) = token_sequence {
            self.new_blocks(token_seq) + self.active_blocks()
        } else {
            self.active_blocks()
        };
353
        let active_tokens = self.active_tokens(decay_now);
354
        let potential_tokens = if track_prefill_tokens {
355
            self.new_tokens(isl, overlap) + active_tokens
356
        } else {
357
            active_tokens
358
        };
359

360
361
362
363
364
365
366
        (potential_blocks, potential_tokens)
    }

    /// Match a request against existing blocks and return the number of new blocks that would be added
    pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        token_sequence
            .iter()
367
            .filter(|block| !self.blocks.unique_blocks.contains_key(block))
368
369
370
            .count()
    }

371
    /// Return the total number of blocks that would be used if the token sequence was added.
372
373
374
375
    pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        self.new_blocks(token_sequence) + self.active_blocks()
    }

376
377
378
379
    /// Free all blocks associated with a request.
    ///
    /// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
    /// to invoke both when the request is finishing.
380
381
382
383
384
385
    pub fn free(&mut self, request_id: &RequestId, decay_now: Instant) -> usize {
        self.mark_prefill_completed(request_id, decay_now);

        let Some(request_state) = self.requests.remove(request_id) else {
            tracing::warn!("Trying to free non-existent request {request_id}");
            return self.active_blocks();
386
387
        };

388
389
        let _ = request_state.expected_output_tokens;
        for (block_hash, rc) in request_state.blocks {
390
            drop(rc);
391
            self.blocks.try_remove_block(&block_hash);
392
393
        }

394
        self.validate_state();
395
396
397
398
399
400
401
402
403
404
405
        self.active_blocks()
    }

    /// Add an output block with a random hash and optional fractional decay weight.
    ///
    /// This is used during generation to track output blocks as they are created.
    pub fn add_output_block(
        &mut self,
        request_id: &RequestId,
        decay_fraction: Option<f64>,
    ) -> bool {
406
        if !self.requests.contains_key(request_id) {
407
408
409
410
411
            tracing::warn!("Request {request_id} not found for add_output_block");
            return false;
        }

        let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;
412
413
        let rc = self.blocks.touch_block(&random_hash);
        self.requests
414
            .get_mut(request_id)
415
416
            .expect("request existence was checked above")
            .blocks
417
418
419
420
421
422
            .push((random_hash, rc));

        if let Some(frac) = decay_fraction {
            self.set_single_ref_blocks_as_fractional(request_id, frac);
        }

423
        self.validate_state();
424
425
426
        true
    }

427
428
    /// Force expiry of stale requests if the timer has elapsed.
    /// Returns the set of expired request IDs that were removed.
429
430
431
    pub fn force_expiry(&mut self) -> HashSet<RequestId> {
        let now = Instant::now();

432
        if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
433
434
435
            return HashSet::new();
        }

436
437
        self.last_expiry_check_time = now;
        let expired_requests_time = now - EXPIRY_DURATION;
438
439
440
441
442
443
        let expired_requests: HashSet<RequestId> = self
            .requests
            .iter()
            .filter(|(_, state)| state.started_at < expired_requests_time)
            .map(|(request_id, _)| request_id.clone())
            .collect();
444

445
        for request_id in &expired_requests {
446
            tracing::warn!("Expiring stale request: {}", request_id);
447
            self.free(request_id, now);
448
449
        }

450
        self.validate_state();
451
452
453
454
455
456
457
        expired_requests
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;

    /// Build a prefill hint of `tokens` expected to finish in `duration_secs` seconds.
    fn prefill_hint(tokens: usize, duration_secs: u64) -> PrefillLoadHint {
        PrefillLoadHint {
            initial_effective_prefill_tokens: tokens,
            expected_prefill_duration: Some(Duration::from_secs(duration_secs)),
        }
    }

    #[test]
    fn test_active_sequences_shared_blocks() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);
        let decay_now = Instant::now();

        seq_manager.add_request(
            "request_1".to_string(),
            Some(vec![1, 2, 3]),
            12,
            0,
            None,
            decay_now,
        );
        assert_eq!(seq_manager.active_blocks(), 3);
        assert_eq!(seq_manager.active_tokens(decay_now), 12);

        seq_manager.add_request(
            "request_2".to_string(),
            Some(vec![4]),
            4,
            0,
            None,
            decay_now,
        );
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(decay_now), 16);

        seq_manager.add_request(
            "request_3".to_string(),
            Some(vec![1, 2, 3, 4]),
            16,
            4,
            None,
            decay_now,
        );
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(decay_now), 16);

        seq_manager.free(&"request_2".to_string(), decay_now);
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(decay_now), 12);

        seq_manager.free(&"request_3".to_string(), decay_now);
        assert_eq!(seq_manager.active_blocks(), 3);
        assert_eq!(seq_manager.active_tokens(decay_now), 12);

        seq_manager.free(&"request_1".to_string(), decay_now);
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(decay_now), 0);
    }

    #[test]
    fn test_output_blocks_with_fractional_decay() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);
        let decay_now = Instant::now();

        seq_manager.add_request(
            "r1".to_string(),
            Some(vec![1, 2, 3]),
            12,
            0,
            None,
            decay_now,
        );
        assert_eq!(seq_manager.active_blocks(), 3);

        assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.5)));
        assert_eq!(seq_manager.active_blocks(), 2);

        seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None, decay_now);
        assert_eq!(seq_manager.active_blocks(), 2);

        assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.0)));
        assert_eq!(seq_manager.active_blocks(), 1);

        seq_manager.free(&"r2".to_string(), decay_now);
        seq_manager.free(&"r1".to_string(), decay_now);
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(decay_now), 0);
    }

    #[test]
    fn test_add_output_block_unknown_request_returns_false() {
        let mut seq_manager = ActiveSequences::new(4);

        assert!(!seq_manager.add_output_block(&"missing".to_string(), None));
        assert!(!seq_manager.add_output_block(&"missing".to_string(), Some(0.5)));
        assert_eq!(seq_manager.active_blocks(), 0);
    }

    #[test]
    fn test_duplicate_add_request_is_ignored() {
        let mut seq_manager = ActiveSequences::new(4);
        let decay_now = Instant::now();

        seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None, decay_now);
        let expired =
            seq_manager.add_request("r1".to_string(), Some(vec![3]), 4, 0, None, decay_now);

        assert!(expired.is_empty());
        assert_eq!(seq_manager.active_blocks(), 2);
        assert_eq!(seq_manager.active_tokens(decay_now), 8);
    }

    #[test]
    fn test_mark_prefill_completed() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);
        let decay_now = Instant::now();

        seq_manager.add_request(
            "r1".to_string(),
            Some(vec![1, 2, 3]),
            12,
            0,
            None,
            decay_now,
        );
        assert_eq!(seq_manager.active_tokens(decay_now), 12);

        seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
        assert_eq!(seq_manager.active_tokens(decay_now), 0);

        seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
        assert_eq!(seq_manager.active_tokens(decay_now), 0);

        seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None, decay_now);
        assert_eq!(seq_manager.active_tokens(decay_now), 8);

        seq_manager.free(&"r2".to_string(), decay_now);
        assert_eq!(seq_manager.active_tokens(decay_now), 0);
    }

    #[test]
    fn test_add_request_without_prefill_tracking_keeps_active_tokens_zero() {
        let mut seq_manager = ActiveSequences::new(4);
        let decay_now = Instant::now();

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1, 2, 3]),
            12,
            0,
            None,
            false,
            None,
            decay_now,
        );

        assert_eq!(seq_manager.active_tokens(decay_now), 0);
        assert!(seq_manager.prefill.prefill_order.is_empty());
        assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 0);

        seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
        assert_eq!(seq_manager.active_tokens(decay_now), 0);
        seq_manager.free(&"r1".to_string(), decay_now);
        assert_eq!(seq_manager.active_blocks(), 0);
    }

    #[test]
    fn test_prefill_hint_overrides_isl_based_estimate() {
        let mut seq_manager = ActiveSequences::new(4);
        let decay_now = Instant::now();

        // overlap * block_size (8) exceeds isl (4), so the ISL-based estimate would be 0;
        // the explicit hint must be used instead.
        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1]),
            4,
            2,
            None,
            true,
            Some(PrefillLoadHint {
                initial_effective_prefill_tokens: 10,
                expected_prefill_duration: None,
            }),
            decay_now,
        );
        assert_eq!(seq_manager.active_tokens(decay_now), 10);

        seq_manager.free(&"r1".to_string(), decay_now);
        assert_eq!(seq_manager.active_tokens(decay_now), 0);
    }

    #[test]
    fn test_potential_blocks_and_tokens_without_prefill_tracking_ignores_prompt_load() {
        let mut seq_manager = ActiveSequences::new(4);
        let decay_now = Instant::now();

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1, 2, 3]),
            12,
            0,
            None,
            false,
            None,
            decay_now,
        );

        let (blocks, tokens) = seq_manager.potential_blocks_and_tokens_with_prefill_tracking(
            Some(&[1, 2, 3, 4]),
            16,
            0,
            false,
            decay_now,
        );
        assert_eq!(blocks, 4);
        assert_eq!(tokens, 0);
    }

    #[test]
    fn test_prefill_decay_only_applies_to_oldest_request() {
        let mut seq_manager = ActiveSequences::new(4);
        let epoch = Instant::now();

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1]),
            100,
            0,
            None,
            true,
            Some(prefill_hint(100, 10)),
            epoch,
        );
        seq_manager.add_request_with_prefill_tracking(
            "r2".to_string(),
            Some(vec![2]),
            60,
            0,
            None,
            true,
            Some(prefill_hint(60, 6)),
            epoch + Duration::from_secs(2),
        );

        assert_eq!(
            seq_manager.active_tokens(epoch + Duration::from_secs(2)),
            140
        );

        let decayed = seq_manager.active_tokens(epoch + Duration::from_secs(5));
        assert_eq!(decayed, 110);
        assert!(decayed <= 160);
        assert!(decayed >= 60);
    }

    #[test]
    fn test_prefill_decay_hands_off_to_next_oldest_request() {
        let mut seq_manager = ActiveSequences::new(4);
        let epoch = Instant::now();

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1]),
            100,
            0,
            None,
            true,
            Some(prefill_hint(100, 10)),
            epoch,
        );
        seq_manager.add_request_with_prefill_tracking(
            "r2".to_string(),
            Some(vec![2]),
            40,
            0,
            None,
            true,
            Some(prefill_hint(40, 8)),
            epoch,
        );

        assert_eq!(
            seq_manager.active_tokens(epoch + Duration::from_secs(3)),
            110
        );

        seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(3));
        assert_eq!(
            seq_manager.active_tokens(epoch + Duration::from_secs(3)),
            40
        );
        assert_eq!(
            seq_manager.prefill.prefill_order,
            VecDeque::from(vec!["r2".to_string()])
        );
        assert!(
            seq_manager
                .prefill
                .anchored_prefill
                .as_ref()
                .is_some_and(|(request_id, _)| request_id == "r2")
        );

        assert_eq!(
            seq_manager.active_tokens(epoch + Duration::from_secs(5)),
            30
        );
    }

    #[test]
    fn test_prefill_decay_resets_when_request_becomes_oldest() {
        let mut seq_manager = ActiveSequences::new(4);
        let epoch = Instant::now();

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1]),
            100,
            0,
            None,
            true,
            Some(prefill_hint(100, 10)),
            epoch,
        );
        seq_manager.add_request_with_prefill_tracking(
            "r2".to_string(),
            Some(vec![2]),
            80,
            0,
            None,
            true,
            Some(prefill_hint(80, 8)),
            epoch + Duration::from_secs(4),
        );

        assert_eq!(
            seq_manager.active_tokens(epoch + Duration::from_secs(8)),
            100
        );

        seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(8));
        assert_eq!(
            seq_manager.active_tokens(epoch + Duration::from_secs(8)),
            80
        );

        assert_eq!(
            seq_manager.active_tokens(epoch + Duration::from_secs(10)),
            60
        );
    }

    #[test]
    fn test_prefill_front_removal_reanchors_queue_front() {
        let mut seq_manager = ActiveSequences::new(4);
        let epoch = Instant::now();

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1]),
            30,
            0,
            None,
            true,
            Some(prefill_hint(30, 6)),
            epoch,
        );
        seq_manager.add_request_with_prefill_tracking(
            "r2".to_string(),
            Some(vec![2]),
            20,
            0,
            None,
            true,
            Some(prefill_hint(20, 4)),
            epoch,
        );

        seq_manager.mark_prefill_completed(&"r1".to_string(), epoch + Duration::from_secs(2));

        assert!(
            seq_manager
                .prefill
                .anchored_prefill
                .as_ref()
                .is_some_and(|(request_id, _)| request_id == "r2")
        );
        assert_eq!(
            seq_manager.active_tokens(epoch + Duration::from_secs(2)),
            20
        );
    }

    #[test]
    fn test_prefill_queue_and_sum_invariants_survive_idempotent_cleanup() {
        let mut seq_manager = ActiveSequences::new(4);
        let decay_now = Instant::now();

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1]),
            50,
            0,
            None,
            true,
            Some(prefill_hint(50, 10)),
            decay_now,
        );
        seq_manager.add_request_with_prefill_tracking(
            "r2".to_string(),
            Some(vec![2]),
            30,
            0,
            None,
            true,
            Some(prefill_hint(30, 10)),
            decay_now,
        );

        assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 80);
        assert_eq!(
            seq_manager.prefill.prefill_order,
            VecDeque::from(vec!["r1".to_string(), "r2".to_string()])
        );

        seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
        seq_manager.mark_prefill_completed(&"r1".to_string(), decay_now);
        assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 30);
        assert_eq!(
            seq_manager.prefill.prefill_order,
            VecDeque::from(vec!["r2".to_string()])
        );

        seq_manager.free(&"r1".to_string(), decay_now);
        seq_manager.free(&"r1".to_string(), decay_now);
        assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 30);
        assert_eq!(
            seq_manager.prefill.prefill_order,
            VecDeque::from(vec!["r2".to_string()])
        );

        seq_manager.free(&"r2".to_string(), decay_now);
        assert_eq!(seq_manager.prefill.prefill_full_tokens_sum, 0);
        assert!(seq_manager.prefill.prefill_order.is_empty());
        assert!(seq_manager.requests.is_empty());
    }

    #[tokio::test(start_paused = true)]
    async fn test_force_expiry() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        seq_manager.add_request(
            "r1".to_string(),
            Some(vec![1, 2]),
            8,
            0,
            None,
            Instant::now(),
        );
        seq_manager.add_request(
            "r2".to_string(),
            Some(vec![3, 4]),
            8,
            0,
            None,
            Instant::now(),
        );
        assert_eq!(seq_manager.active_blocks(), 4);

        tokio::time::advance(Duration::from_secs(20)).await;
        let expired = seq_manager.force_expiry();
        assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY");
        assert_eq!(seq_manager.active_blocks(), 4);

        tokio::time::advance(Duration::from_secs(11)).await;
        let expired = seq_manager.force_expiry();
        assert!(expired.is_empty(), "requests not old enough to expire");
        assert_eq!(seq_manager.active_blocks(), 4);
        seq_manager.assert_consistent();

        tokio::time::advance(Duration::from_secs(270)).await;
        let expired = seq_manager.force_expiry();
        assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()]));
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(Instant::now()), 0);
        seq_manager.assert_consistent();

        tokio::time::advance(Duration::from_secs(31)).await;
        let expired =
            seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None, Instant::now());
        assert!(expired.is_empty());
        assert_eq!(seq_manager.active_blocks(), 1);
        assert_eq!(seq_manager.active_tokens(Instant::now()), 4);
        seq_manager.assert_consistent();
    }

    #[tokio::test(start_paused = true)]
    async fn test_force_expiry_reanchors_new_oldest_request() {
        let mut seq_manager = ActiveSequences::new(4);
        let first_decay_now = Instant::now();

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1]),
            40,
            0,
            None,
            true,
            Some(prefill_hint(40, 100)),
            first_decay_now,
        );
        tokio::time::advance(Duration::from_secs(250)).await;
        seq_manager.add_request_with_prefill_tracking(
            "r2".to_string(),
            Some(vec![2]),
            30,
            0,
            None,
            true,
            Some(prefill_hint(30, 100)),
            Instant::now(),
        );

        tokio::time::advance(Duration::from_secs(60)).await;
        let expired = seq_manager.force_expiry();
        assert_eq!(expired, HashSet::from(["r1".to_string()]));
        assert_eq!(seq_manager.active_tokens(Instant::now()), 30);
        assert!(
            seq_manager
                .prefill
                .anchored_prefill
                .as_ref()
                .is_some_and(|(request_id, _)| request_id == "r2")
        );

        tokio::time::advance(Duration::from_secs(20)).await;
        assert_eq!(seq_manager.active_tokens(Instant::now()), 24);
    }
}