single.rs 30.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! KV Cache Sequence Management for LLM Inference
//!
//! This module provides efficient management of token sequences and their associated KV cache blocks
//! for distributed LLM inference. It implements a shared block system where multiple requests can
//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage.
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their
//!   token sequences, managing shared KV cache blocks efficiently.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).

use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;

28
29
30
#[cfg(test)]
use rustc_hash::FxHashSet;

31
use super::block_tracker::BlockTracker;
32
33
34
#[cfg(test)]
use super::prefill_tracker::added_prefill_tokens;
use super::prefill_tracker::{PrefillLoadState, PrefillLoadTracker};
35
use super::prompt_registry::WorkerLoadSnapshot;
36
37
use crate::protocols::PrefillLoadHint;

38
/// Age after which a request is considered stale and may be expired (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(5 * 60);

/// Polling interval for the stale-request sweep (30 seconds).
///
/// This is how often ages are *checked*, not the age at which a request expires;
/// that threshold is [`EXPIRY_DURATION`].
const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30);

// TODO: use the common request_id if it exists in the repo
/// Identifier of a tracked request.
pub type RequestId = String;

48
49
#[derive(Debug)]
pub(super) struct RequestState {
50
51
    prompt_blocks: Vec<(SequenceHash, Arc<()>)>,
    output_blocks: Vec<(SequenceHash, Arc<()>)>,
52
53
54
55
    started_at: Instant,
    expected_output_tokens: Option<u32>,
}

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
impl RequestState {
    fn all_blocks(&self) -> impl Iterator<Item = &(SequenceHash, Arc<()>)> {
        self.prompt_blocks.iter().chain(self.output_blocks.iter())
    }
}

/// A contiguous run of prompt block hashes that newly became present on the worker.
///
/// `parent` is the prompt block immediately preceding the run, or `None` when the run
/// begins at the first prompt block.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(super) struct PromptMembershipStore {
    pub parent: Option<SequenceHash>,
    pub hashes: Vec<SequenceHash>,
}

/// A run of prompt block hashes that stopped being present on the worker.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(super) struct PromptMembershipRemove {
    pub hashes: Vec<SequenceHash>,
}

/// Prompt-block membership transitions accumulated while mutating the sequence state.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(super) struct PromptMembershipDelta {
    /// Runs that became present, in the order they were recorded.
    pub stores: Vec<PromptMembershipStore>,
    /// Runs that became absent, in the order they were recorded.
    pub removes: Vec<PromptMembershipRemove>,
}

impl PromptMembershipDelta {
    /// Append every store and remove from `other` after the ones already recorded.
    fn extend(&mut self, other: Self) {
        let Self { stores, removes } = other;
        self.stores.extend(stores);
        self.removes.extend(removes);
    }

    /// Record a store of `hashes` under `parent`; an empty run is silently dropped.
    fn push_store(&mut self, parent: Option<SequenceHash>, hashes: Vec<SequenceHash>) {
        if !hashes.is_empty() {
            self.stores.push(PromptMembershipStore { parent, hashes });
        }
    }

    /// Record a removal of `hashes`; an empty run is silently dropped.
    fn push_remove(&mut self, hashes: Vec<SequenceHash>) {
        if !hashes.is_empty() {
            self.removes.push(PromptMembershipRemove { hashes });
        }
    }
}

/// What a mutation of [`ActiveSequences`] changed, beyond its own return value.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub(super) struct SequenceMutationOutcome {
    /// Prompt membership transitions, including those caused by stale-request expiry.
    pub membership_delta: PromptMembershipDelta,
    /// Requests removed by stale-request expiry during the mutation.
    pub expired_request_ids: HashSet<RequestId>,
}

106
/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
107
#[derive(Debug)]
108
pub struct ActiveSequences {
109
110
111
    requests: HashMap<RequestId, RequestState>,
    prefill: PrefillLoadTracker,
    blocks: BlockTracker,
112
    #[cfg(test)]
113
    block_size: usize,
114
    last_expiry_check_time: Instant,
115
116
117
118
}

impl ActiveSequences {
    /// Create a new SharedSequenceManager instance
119
    pub(super) fn new(block_size: usize) -> Self {
120
        assert!(block_size > 0, "block_size must be greater than 0");
121
122

        Self {
123
124
125
            requests: HashMap::new(),
            prefill: PrefillLoadTracker::default(),
            blocks: BlockTracker::default(),
126
            #[cfg(test)]
127
            block_size,
128
            last_expiry_check_time: Instant::now(),
129
130
131
        }
    }

132
133
    #[cfg(any(test, debug_assertions))]
    fn assert_consistent(&self) {
134
135
136
137
138
139
        self.prefill.assert_consistent();
        let active_prefills: HashSet<RequestId> = self.prefill.prefills.keys().cloned().collect();
        let active_requests: HashSet<RequestId> = self.requests.keys().cloned().collect();
        assert!(
            active_prefills.is_subset(&active_requests),
            "prefill tracker cannot reference missing request state",
140
141
142
143
144
145
146
147
        );
        assert!(
            self.blocks
                .fractional_blocks
                .keys()
                .all(|hash| self.blocks.unique_blocks.contains_key(hash)),
            "fractional_blocks cannot reference non-active blocks",
        );
148
149
    }

150
151
152
153
    #[inline]
    fn validate_state(&self) {
        #[cfg(any(test, debug_assertions))]
        self.assert_consistent();
154
155
    }

156
    pub(super) fn active_blocks(&self) -> usize {
157
158
159
        self.blocks.active_blocks()
    }

160
161
162
    #[cfg(test)]
    pub(super) fn active_tokens(&self, decay_now: Instant) -> usize {
        self.prefill.snapshot().active_tokens_at(decay_now)
163
164
    }

165
    /// Add a new request with optional prompt-token load accounting.
166
    /// Returns block membership transitions plus any expired request IDs removed during cleanup.
167
    #[allow(clippy::too_many_arguments)]
168
    pub(super) fn add_request_with_prefill_tracking(
169
170
171
172
173
        &mut self,
        request_id: RequestId,
        token_sequence: Option<Vec<SequenceHash>>,
        expected_output_tokens: Option<u32>,
        track_prefill_tokens: bool,
174
175
        prefill_load_hint: Option<PrefillLoadHint>,
        decay_now: Instant,
176
    ) -> SequenceMutationOutcome {
177
        if self.requests.contains_key(&request_id) {
178
            tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
179
            return SequenceMutationOutcome::default();
180
181
        }

182
        let mut outcome = self.force_expiry();
183
184
        let started_at = Instant::now();

185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
        let prompt_blocks = match token_sequence {
            Some(sequence) => {
                let mut first_new_prompt_idx = None;
                let prompt_blocks: Vec<_> = sequence
                    .into_iter()
                    .enumerate()
                    .map(|(idx, block)| {
                        let acquire = self.blocks.touch_block(&block);
                        if acquire.became_present_on_worker && first_new_prompt_idx.is_none() {
                            first_new_prompt_idx = Some(idx);
                        }
                        (block, acquire.rc)
                    })
                    .collect();

                if let Some(first_new_prompt_idx) = first_new_prompt_idx {
                    debug_assert!(
                        prompt_blocks[first_new_prompt_idx..]
                            .iter()
                            .all(|(hash, _)| self.blocks.unique_blocks.contains_key(hash))
                    );
                    let parent = first_new_prompt_idx
                        .checked_sub(1)
                        .map(|idx| prompt_blocks[idx].0);
                    let hashes = prompt_blocks[first_new_prompt_idx..]
                        .iter()
                        .map(|(hash, _)| *hash)
                        .collect();
                    outcome.membership_delta.push_store(parent, hashes);
                }

                prompt_blocks
            }
218
219
            None => Vec::new(),
        };
220

221
        let prefill = if track_prefill_tokens {
222
223
224
225
226
            prefill_load_hint.and_then(|hint| {
                (hint.initial_effective_prefill_tokens > 0).then_some(PrefillLoadState {
                    initial_effective_prefill_tokens: hint.initial_effective_prefill_tokens,
                    expected_prefill_duration: hint.expected_prefill_duration,
                })
227
            })
228
        } else {
229
            None
230
        };
231

232
233
234
        self.requests.insert(
            request_id.clone(),
            RequestState {
235
236
                prompt_blocks,
                output_blocks: Vec::new(),
237
238
239
240
241
242
                started_at,
                expected_output_tokens,
            },
        );

        if let Some(prefill) = prefill {
243
            self.prefill.insert(&request_id, prefill, decay_now);
244
245
        }

246
        self.validate_state();
247
        outcome
248
249
    }

250
    /// Mark prefill as completed for a request, removing it from prompt-load tracking.
251
252
    pub(super) fn mark_prefill_completed(&mut self, request_id: &RequestId, decay_now: Instant) {
        let _ = self.prefill.remove(request_id, decay_now);
253
        self.validate_state();
254
255
    }

256
257
258
259
    /// Free all blocks associated with a request.
    ///
    /// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
    /// to invoke both when the request is finishing.
260
261
262
263
264
265
    pub(super) fn free(
        &mut self,
        request_id: &RequestId,
        decay_now: Instant,
    ) -> PromptMembershipDelta {
        let _ = self.prefill.remove(request_id, decay_now);
266
267
268

        let Some(request_state) = self.requests.remove(request_id) else {
            tracing::warn!("Trying to free non-existent request {request_id}");
269
            return PromptMembershipDelta::default();
270
271
        };

272
        let _ = request_state.expected_output_tokens;
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
        let mut membership_delta = PromptMembershipDelta::default();
        let mut first_absent_prompt_idx = None;
        let prompt_hashes: Vec<_> = request_state
            .prompt_blocks
            .iter()
            .map(|(hash, _)| *hash)
            .collect();

        for (idx, (block_hash, rc)) in request_state.prompt_blocks.into_iter().enumerate() {
            drop(rc);
            if self.blocks.try_remove_block(&block_hash) && first_absent_prompt_idx.is_none() {
                first_absent_prompt_idx = Some(idx);
            }
        }

        if let Some(first_absent_prompt_idx) = first_absent_prompt_idx {
            let prompt_remove = prompt_hashes[first_absent_prompt_idx..].to_vec();
            membership_delta.push_remove(prompt_remove);
        }

        for (block_hash, rc) in request_state.output_blocks {
294
            drop(rc);
295
            self.blocks.try_remove_block(&block_hash);
296
297
        }

298
        self.validate_state();
299
        membership_delta
300
301
302
303
304
    }

    /// Add an output block with a random hash and optional fractional decay weight.
    ///
    /// This is used during generation to track output blocks as they are created.
305
    pub(super) fn add_output_block(
306
307
308
        &mut self,
        request_id: &RequestId,
        decay_fraction: Option<f64>,
309
    ) -> Option<SequenceHash> {
310
        if !self.requests.contains_key(request_id) {
311
            tracing::warn!("Request {request_id} not found for add_output_block");
312
            return None;
313
314
        }

315
316
        // TODO: Output blocks still use random hashes, so indexing them mainly simplifies
        // generic block bookkeeping and usually adds little real reuse signal.
317
        let random_hash: SequenceHash = Uuid::new_v4().as_u64_pair().0;
318
        let acquire = self.blocks.touch_block(&random_hash);
319
        self.requests
320
            .get_mut(request_id)
321
            .expect("request existence was checked above")
322
323
            .output_blocks
            .push((random_hash, acquire.rc));
324
325
326
327
328

        if let Some(frac) = decay_fraction {
            self.set_single_ref_blocks_as_fractional(request_id, frac);
        }

329
        self.validate_state();
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
        acquire.became_present_on_worker.then_some(random_hash)
    }

    #[cfg(test)]
    fn potential_blocks_and_tokens_with_prefill_tracking(
        &self,
        token_sequence: Option<&[SequenceHash]>,
        isl: usize,
        overlap: u32,
        track_prefill_tokens: bool,
        decay_now: Instant,
    ) -> (usize, usize) {
        let potential_blocks = if let Some(token_seq) = token_sequence {
            self.new_blocks(token_seq) + self.active_blocks()
        } else {
            self.active_blocks()
        };
        let active_tokens = self.active_tokens(decay_now);
        let potential_tokens = if track_prefill_tokens {
349
            added_prefill_tokens(self.block_size, isl, overlap) + active_tokens
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
        } else {
            active_tokens
        };

        (potential_blocks, potential_tokens)
    }

    /// Match a request against existing blocks and return the number of new blocks that would be added
    pub(super) fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        token_sequence
            .iter()
            .filter(|block| !self.blocks.unique_blocks.contains_key(block))
            .count()
    }

    /// Return the total number of blocks that would be used if the token sequence was added.
    pub(super) fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        self.new_blocks(token_sequence) + self.active_blocks()
368
369
    }

370
    /// Force expiry of stale requests if the timer has elapsed.
371
372
    /// Returns block membership transitions plus the set of expired request IDs that were removed.
    pub(super) fn force_expiry(&mut self) -> SequenceMutationOutcome {
373
374
        let now = Instant::now();

375
        if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
376
            return SequenceMutationOutcome::default();
377
378
        }

379
380
        self.last_expiry_check_time = now;
        let expired_requests_time = now - EXPIRY_DURATION;
381
        let expired_request_ids: HashSet<RequestId> = self
382
383
384
385
386
            .requests
            .iter()
            .filter(|(_, state)| state.started_at < expired_requests_time)
            .map(|(request_id, _)| request_id.clone())
            .collect();
387

388
389
390
391
392
393
        let mut outcome = SequenceMutationOutcome {
            expired_request_ids,
            ..Default::default()
        };

        for request_id in &outcome.expired_request_ids {
394
            tracing::warn!("Expiring stale request: {}", request_id);
395
            outcome.membership_delta.extend(self.free(request_id, now));
396
397
        }

398
        self.validate_state();
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
        outcome
    }

    /// Find all blocks in a request that have only a single strong reference (only used by this request)
    /// and insert them into fractional_blocks with the given fraction value.
    fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
        let Some(request_state) = self.requests.get(request_id) else {
            tracing::warn!(
                "Request {request_id} not found for set_single_ref_blocks_as_fractional"
            );
            return;
        };

        for (hash, rc) in request_state.all_blocks() {
            if Arc::strong_count(rc) == 1 {
                self.blocks.fractional_blocks.insert(*hash, fraction);
            }
        }
    }

    pub(super) fn worker_load_snapshot(&self) -> WorkerLoadSnapshot {
        WorkerLoadSnapshot {
            active_blocks: self.active_blocks(),
            prefill: self.prefill.snapshot(),
        }
    }

    #[cfg(test)]
    pub(super) fn active_block_hashes(&self) -> FxHashSet<SequenceHash> {
        self.blocks.unique_blocks.keys().copied().collect()
    }

    #[cfg(test)]
    pub(super) fn active_prompt_hashes(&self) -> FxHashSet<SequenceHash> {
        self.requests
            .values()
            .flat_map(|state| state.prompt_blocks.iter().map(|(hash, _)| *hash))
            .collect()
437
438
439
440
441
442
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;

    /// Owned request id from a literal.
    fn rid(name: &str) -> RequestId {
        name.to_string()
    }

    /// Hint with an explicit prefill token count and a decay window of `duration_secs`.
    fn prefill_hint(tokens: usize, duration_secs: u64) -> PrefillLoadHint {
        PrefillLoadHint {
            initial_effective_prefill_tokens: tokens,
            expected_prefill_duration: Some(Duration::from_secs(duration_secs)),
        }
    }

    /// Hint sized by `added_prefill_tokens`; `None` when that estimate is zero.
    fn tracking_hint(block_size: usize, isl: usize, overlap: u32) -> Option<PrefillLoadHint> {
        let tokens = added_prefill_tokens(block_size, isl, overlap);
        if tokens > 0 {
            Some(PrefillLoadHint {
                initial_effective_prefill_tokens: tokens,
                expected_prefill_duration: None,
            })
        } else {
            None
        }
    }

    /// Add `name` with prompt `blocks`, prefill tracking on and no output-token estimate.
    fn add_tracked(
        seqs: &mut ActiveSequences,
        name: &str,
        blocks: Vec<SequenceHash>,
        hint: Option<PrefillLoadHint>,
        decay_now: Instant,
    ) -> SequenceMutationOutcome {
        seqs.add_request_with_prefill_tracking(
            rid(name),
            Some(blocks),
            None,
            true,
            hint,
            decay_now,
        )
    }

    /// Add `name` with prompt `blocks`, prefill tracking off and no hint.
    fn add_untracked(
        seqs: &mut ActiveSequences,
        name: &str,
        blocks: Vec<SequenceHash>,
        decay_now: Instant,
    ) -> SequenceMutationOutcome {
        seqs.add_request_with_prefill_tracking(
            rid(name),
            Some(blocks),
            None,
            false,
            None,
            decay_now,
        )
    }

    /// Check active block count, then decayed prompt-token load.
    #[track_caller]
    fn assert_load(seqs: &ActiveSequences, decay_now: Instant, blocks: usize, tokens: usize) {
        assert_eq!(seqs.active_blocks(), blocks);
        assert_eq!(seqs.active_tokens(decay_now), tokens);
    }

    #[test]
    fn test_prompt_membership_delta_only_reports_first_add_and_last_remove() {
        let mut seqs = ActiveSequences::new(4);
        let now = Instant::now();

        let first = add_tracked(&mut seqs, "r1", vec![1, 2], tracking_hint(4, 8, 0), now);
        let expected_first = PromptMembershipDelta {
            stores: vec![PromptMembershipStore {
                parent: None,
                hashes: vec![1, 2],
            }],
            removes: Vec::new(),
        };
        assert_eq!(first.membership_delta, expected_first);
        assert!(first.expired_request_ids.is_empty());

        // Only block 3 is new; it hangs off the shared block 2.
        let second = add_tracked(&mut seqs, "r2", vec![1, 2, 3], tracking_hint(4, 12, 0), now);
        let expected_second = PromptMembershipDelta {
            stores: vec![PromptMembershipStore {
                parent: Some(2),
                hashes: vec![3],
            }],
            removes: Vec::new(),
        };
        assert_eq!(second.membership_delta, expected_second);

        // r2 still holds every block of r1, so freeing r1 reports nothing.
        let first_free = seqs.free(&rid("r1"), now);
        assert!(first_free.removes.is_empty());
        assert!(first_free.stores.is_empty());

        let second_free = seqs.free(&rid("r2"), now);
        assert!(second_free.stores.is_empty());
        assert_eq!(
            second_free.removes,
            vec![PromptMembershipRemove {
                hashes: vec![1, 2, 3],
            }]
        );
    }

    #[test]
    fn test_generic_block_membership_includes_output_blocks() {
        let mut seqs = ActiveSequences::new(4);
        let now = Instant::now();
        let r1 = rid("r1");

        let outcome = add_tracked(&mut seqs, "r1", vec![1, 2, 3], tracking_hint(4, 12, 0), now);
        assert_eq!(
            outcome.membership_delta.stores,
            vec![PromptMembershipStore {
                parent: None,
                hashes: vec![1, 2, 3],
            }]
        );
        let prompt_hashes: FxHashSet<SequenceHash> = [1, 2, 3].into_iter().collect();
        assert_eq!(seqs.active_block_hashes(), prompt_hashes);

        let output_hash = seqs
            .add_output_block(&r1, Some(0.5))
            .expect("request exists");
        let all_hashes: FxHashSet<SequenceHash> = [1, 2, 3, output_hash].into_iter().collect();
        assert_eq!(seqs.active_block_hashes(), all_hashes);

        // Completing prefill drops prompt-token load but keeps every block present.
        seqs.mark_prefill_completed(&r1, now);
        assert_eq!(seqs.active_tokens(now), 0);
        assert_eq!(seqs.active_block_hashes(), all_hashes);

        // Only the prompt blocks are reported when the request is freed.
        let free_delta = seqs.free(&r1, now);
        assert_eq!(
            free_delta.removes,
            vec![PromptMembershipRemove {
                hashes: vec![1, 2, 3],
            }]
        );
    }

    #[test]
    fn test_active_sequences_shared_blocks() {
        let block_size = 4;
        let mut seqs = ActiveSequences::new(block_size);
        let now = Instant::now();

        let hint = tracking_hint(block_size, 12, 0);
        add_tracked(&mut seqs, "request_1", vec![1, 2, 3], hint, now);
        assert_load(&seqs, now, 3, 12);

        let hint = tracking_hint(block_size, 4, 0);
        add_tracked(&mut seqs, "request_2", vec![4], hint, now);
        assert_load(&seqs, now, 4, 16);

        // request_3 consists entirely of blocks already held by the first two requests.
        let hint = tracking_hint(block_size, 16, 4);
        add_tracked(&mut seqs, "request_3", vec![1, 2, 3, 4], hint, now);
        assert_load(&seqs, now, 4, 16);

        seqs.free(&rid("request_2"), now);
        assert_load(&seqs, now, 4, 12);

        seqs.free(&rid("request_3"), now);
        assert_load(&seqs, now, 3, 12);

        seqs.free(&rid("request_1"), now);
        assert_load(&seqs, now, 0, 0);
    }

    #[test]
    fn test_output_blocks_with_fractional_decay() {
        let block_size = 4;
        let mut seqs = ActiveSequences::new(block_size);
        let now = Instant::now();
        let r1 = rid("r1");

        add_tracked(&mut seqs, "r1", vec![1, 2, 3], tracking_hint(block_size, 12, 0), now);
        assert_eq!(seqs.active_blocks(), 3);

        assert!(seqs.add_output_block(&r1, Some(0.5)).is_some());
        assert_eq!(seqs.active_blocks(), 2);

        add_tracked(&mut seqs, "r2", vec![1, 2], tracking_hint(block_size, 8, 0), now);
        assert_eq!(seqs.active_blocks(), 2);

        assert!(seqs.add_output_block(&r1, Some(0.0)).is_some());
        assert_eq!(seqs.active_blocks(), 1);

        seqs.free(&rid("r2"), now);
        seqs.free(&r1, now);
        assert_load(&seqs, now, 0, 0);
    }

    #[test]
    fn test_mark_prefill_completed() {
        let block_size = 4;
        let mut seqs = ActiveSequences::new(block_size);
        let now = Instant::now();
        let r1 = rid("r1");

        add_tracked(&mut seqs, "r1", vec![1, 2, 3], tracking_hint(block_size, 12, 0), now);
        assert_eq!(seqs.active_tokens(now), 12);

        // A repeated completion must not change the load again.
        for _ in 0..2 {
            seqs.mark_prefill_completed(&r1, now);
            assert_eq!(seqs.active_tokens(now), 0);
        }

        add_tracked(&mut seqs, "r2", vec![4, 5], tracking_hint(block_size, 8, 0), now);
        assert_eq!(seqs.active_tokens(now), 8);

        seqs.free(&rid("r2"), now);
        assert_eq!(seqs.active_tokens(now), 0);
    }

    #[test]
    fn test_add_request_without_prefill_tracking_keeps_active_tokens_zero() {
        let mut seqs = ActiveSequences::new(4);
        let now = Instant::now();
        let r1 = rid("r1");

        add_untracked(&mut seqs, "r1", vec![1, 2, 3], now);

        assert_eq!(seqs.active_tokens(now), 0);
        assert!(seqs.prefill.prefill_order.is_empty());
        assert_eq!(seqs.prefill.prefill_full_tokens_sum, 0);

        seqs.mark_prefill_completed(&r1, now);
        assert_eq!(seqs.active_tokens(now), 0);
        seqs.free(&r1, now);
        assert_eq!(seqs.active_blocks(), 0);
    }

    #[test]
    fn test_potential_blocks_and_tokens_without_prefill_tracking_ignores_prompt_load() {
        let mut seqs = ActiveSequences::new(4);
        let now = Instant::now();

        add_untracked(&mut seqs, "r1", vec![1, 2, 3], now);

        let (blocks, tokens) = seqs.potential_blocks_and_tokens_with_prefill_tracking(
            Some(&[1, 2, 3, 4]),
            16,
            0,
            false,
            now,
        );
        assert_eq!(blocks, 4);
        assert_eq!(tokens, 0);
    }

    #[test]
    fn test_prefill_queue_and_sum_invariants_survive_idempotent_cleanup() {
        let mut seqs = ActiveSequences::new(4);
        let now = Instant::now();
        let (r1, r2) = (rid("r1"), rid("r2"));

        add_tracked(&mut seqs, "r1", vec![1], Some(prefill_hint(50, 10)), now);
        add_tracked(&mut seqs, "r2", vec![2], Some(prefill_hint(30, 10)), now);

        assert_eq!(seqs.prefill.prefill_full_tokens_sum, 80);
        assert_eq!(
            seqs.prefill.prefill_order,
            VecDeque::from([r1.clone(), r2.clone()])
        );

        // Completing r1's prefill twice subtracts its tokens only once.
        seqs.mark_prefill_completed(&r1, now);
        seqs.mark_prefill_completed(&r1, now);
        assert_eq!(seqs.prefill.prefill_full_tokens_sum, 30);
        assert_eq!(seqs.prefill.prefill_order, VecDeque::from([r2.clone()]));

        // Freeing r1 (twice) afterwards leaves r2's accounting untouched.
        seqs.free(&r1, now);
        seqs.free(&r1, now);
        assert_eq!(seqs.prefill.prefill_full_tokens_sum, 30);
        assert_eq!(seqs.prefill.prefill_order, VecDeque::from([r2.clone()]));

        seqs.free(&r2, now);
        assert_eq!(seqs.prefill.prefill_full_tokens_sum, 0);
        assert!(seqs.prefill.prefill_order.is_empty());
        assert!(seqs.requests.is_empty());
    }

    #[tokio::test(start_paused = true)]
    async fn test_force_expiry() {
        let block_size = 4;
        let mut seqs = ActiveSequences::new(block_size);

        let hint = tracking_hint(block_size, 8, 0);
        add_tracked(&mut seqs, "r1", vec![1, 2], hint, Instant::now());
        let hint = tracking_hint(block_size, 8, 0);
        add_tracked(&mut seqs, "r2", vec![3, 4], hint, Instant::now());
        assert_eq!(seqs.active_blocks(), 4);

        // t = 20s: the check interval has not elapsed, so nothing is examined.
        tokio::time::advance(Duration::from_secs(20)).await;
        let outcome = seqs.force_expiry();
        assert!(
            outcome.expired_request_ids.is_empty(),
            "no check before CHECK_EXPIRY_FREQUENCY"
        );
        assert_eq!(seqs.active_blocks(), 4);

        // t = 31s: the check runs, but neither request is old enough.
        tokio::time::advance(Duration::from_secs(11)).await;
        let outcome = seqs.force_expiry();
        assert!(
            outcome.expired_request_ids.is_empty(),
            "requests not old enough to expire"
        );
        assert_eq!(seqs.active_blocks(), 4);
        seqs.assert_consistent();

        // t = 301s: both requests are past EXPIRY_DURATION.
        tokio::time::advance(Duration::from_secs(270)).await;
        let outcome = seqs.force_expiry();
        let both: HashSet<RequestId> = [rid("r1"), rid("r2")].into_iter().collect();
        assert_eq!(outcome.expired_request_ids, both);
        assert_eq!(seqs.active_blocks(), 0);
        assert_eq!(seqs.active_tokens(Instant::now()), 0);
        seqs.assert_consistent();

        // t = 332s: a fresh request triggers a sweep that finds nothing stale.
        tokio::time::advance(Duration::from_secs(31)).await;
        let hint = tracking_hint(block_size, 4, 0);
        let outcome = add_tracked(&mut seqs, "r3", vec![5], hint, Instant::now());
        assert!(outcome.expired_request_ids.is_empty());
        assert_eq!(seqs.active_blocks(), 1);
        assert_eq!(seqs.active_tokens(Instant::now()), 4);
        seqs.assert_consistent();
    }

    #[tokio::test(start_paused = true)]
    async fn test_force_expiry_reanchors_new_oldest_request() {
        let mut seqs = ActiveSequences::new(4);
        let first_decay_now = Instant::now();

        add_tracked(
            &mut seqs,
            "r1",
            vec![1],
            Some(prefill_hint(40, 100)),
            first_decay_now,
        );
        tokio::time::advance(Duration::from_secs(250)).await;
        add_tracked(
            &mut seqs,
            "r2",
            vec![2],
            Some(prefill_hint(30, 100)),
            Instant::now(),
        );

        tokio::time::advance(Duration::from_secs(60)).await;
        let outcome = seqs.force_expiry();
        assert_eq!(outcome.expired_request_ids, HashSet::from([rid("r1")]));

        assert_eq!(seqs.active_tokens(Instant::now()), 30);
        assert!(
            seqs.prefill
                .anchored_prefill
                .as_ref()
                .is_some_and(|(request_id, _)| request_id == "r2")
        );

        tokio::time::advance(Duration::from_secs(20)).await;
        assert_eq!(seqs.active_tokens(Instant::now()), 24);
    }
}