// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! KV Cache Sequence Management for LLM Inference
//!
//! This module provides efficient management of token sequences and their associated KV cache blocks
//! for distributed LLM inference. It implements a shared block system where multiple requests can
//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage.
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their
//!   token sequences, managing shared KV cache blocks efficiently.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).

use derive_getters::Getters;
use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;

/// Duration after which stale requests may be expired (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(300);

/// How often we *check* for stale requests (30 seconds). This is not
/// the expiration time; that is [`EXPIRY_DURATION`].
const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30);

// TODO: use the common request_id if it exists in the repo
/// Identifier used as the key for every per-request map in this module.
pub type RequestId = String;

/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
    active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>,

    prefill_tokens: HashMap<RequestId, usize>,

    /// Expected output tokens per request (used for resource estimation)
    expected_output_tokens: HashMap<RequestId, u32>,

    unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,

    /// Fractional block counts for blocks that are partially cached
    /// When a block is in both unique_blocks and fractional_blocks,
    /// it contributes the fractional value instead of 1 to active_blocks()
    fractional_blocks: HashMap<SequenceHash, f64>,

    #[getter(copy)]
    block_size: usize,

    #[getter(copy)]
    active_tokens: usize,

62
63
    // Request timestamps, for expiration.
    request_timestamps: HashMap<RequestId, Instant>,
64

65
    last_expiry_check_time: Instant,
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
}

impl ActiveSequences {
    /// Create a new, empty [`ActiveSequences`] manager.
    ///
    /// `block_size` is the factor used to convert a block overlap into cached tokens.
    /// The expiry-check timer starts at the moment of construction.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is not greater than 1.
    pub fn new(block_size: usize) -> Self {
        // TODO: make this not a hard req
        assert!(block_size > 1, "block_size must be greater than 1");

        Self {
            active_seqs: HashMap::new(),
            prefill_tokens: HashMap::new(),
            expected_output_tokens: HashMap::new(),
            unique_blocks: HashMap::new(),
            fractional_blocks: HashMap::new(),
            block_size,
            active_tokens: 0,
            request_timestamps: HashMap::new(),
            last_expiry_check_time: Instant::now(),
        }
    }

    fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
        if let Some(weak) = self.unique_blocks.get(block)
            && let Some(rc) = weak.upgrade()
        {
            return rc;
        }

        let rc = Arc::new(());
        self.unique_blocks.insert(*block, Arc::downgrade(&rc));
        rc
    }

    fn try_remove_block(&mut self, block: &SequenceHash) {
        if let Some(weak) = self.unique_blocks.get(block)
            && weak.strong_count() == 0
        {
            self.unique_blocks.remove(block);
            self.fractional_blocks.remove(block);
        }
    }

    /// Number of active blocks, rounded to the nearest integer.
    ///
    /// Every entry in `unique_blocks` counts as one block, except that a block which also
    /// appears in `fractional_blocks` contributes its fractional value instead.
    pub fn active_blocks(&self) -> usize {
        let weighted = self
            .fractional_blocks
            .iter()
            .filter(|(hash, _)| self.unique_blocks.contains_key(*hash))
            // Swap the full block (1.0) for its fractional weight.
            .fold(self.unique_blocks.len() as f64, |acc, (_, frac)| {
                acc - 1.0 + frac
            });
        weighted.round() as usize
    }

    /// Assign `fraction` to every block of `request_id` that only this request references.
    ///
    /// A block qualifies when the strong count of its handle is exactly 1. Qualifying
    /// blocks are inserted into `fractional_blocks`, overwriting any previous value; all
    /// other blocks keep whatever fractional value they already had. If the request is
    /// unknown a warning is logged and nothing changes.
    pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
        match self.active_seqs.get(request_id) {
            None => {
                tracing::warn!(
                    "Request {request_id} not found for set_single_ref_blocks_as_fractional"
                );
            }
            Some(blocks) => {
                let exclusive = blocks
                    .iter()
                    // strong_count == 1 means only this request holds a reference
                    .filter(|(_, rc)| Arc::strong_count(rc) == 1)
                    .map(|(hash, _)| (*hash, fraction));
                self.fractional_blocks.extend(exclusive);
            }
        }
    }

    /// Add a new request with its initial tokens, counting its prompt tokens as load.
    ///
    /// This is [`Self::add_request_with_prefill_tracking`] with prefill tracking enabled.
    /// Returns the set of expired request IDs that were removed during cleanup.
    pub fn add_request(
        &mut self,
        request_id: RequestId,
        token_sequence: Option<Vec<SequenceHash>>,
        isl: usize,
        overlap: u32,
        expected_output_tokens: Option<u32>,
    ) -> HashSet<RequestId> {
        let track_prefill_tokens = true;
        self.add_request_with_prefill_tracking(
            request_id,
            token_sequence,
            isl,
            overlap,
            expected_output_tokens,
            track_prefill_tokens,
        )
    }

    /// Add a new request with optional prompt-token load accounting.
    /// Returns the set of expired request IDs that were removed during cleanup.
    pub fn add_request_with_prefill_tracking(
        &mut self,
        request_id: RequestId,
        token_sequence: Option<Vec<SequenceHash>>,
        isl: usize,
        overlap: u32,
        expected_output_tokens: Option<u32>,
        track_prefill_tokens: bool,
167
168
169
170
171
172
173
174
175
176
    ) -> HashSet<RequestId> {
        // Check for double-add and log error, returning early
        if self.active_seqs.contains_key(&request_id) {
            tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
            return HashSet::new();
        }

        // Lazily check and clean up expired requests, capturing removed IDs
        let removed_requests = self.force_expiry();

177
178
179
180
181
        let prefill_tokens = if track_prefill_tokens {
            self.new_tokens(isl, overlap)
        } else {
            0
        };
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
        self.prefill_tokens
            .insert(request_id.clone(), prefill_tokens);
        self.active_tokens += prefill_tokens;

        // Store expected output tokens if provided
        if let Some(tokens) = expected_output_tokens {
            self.expected_output_tokens
                .insert(request_id.clone(), tokens);
        }

        if let Some(sequence) = token_sequence {
            let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence
                .iter()
                .map(|block| (*block, self.touch_block(block)))
                .collect();
            self.active_seqs
                .insert(request_id.clone(), sequence_with_refs);
        } else {
            // dummy empty sequence
            self.active_seqs.insert(request_id.clone(), Vec::new());
        }
203
204
        self.request_timestamps
            .insert(request_id.clone(), Instant::now());
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235

        removed_requests
    }

    /// Stop counting a request's prompt tokens as active load.
    ///
    /// Removes the request from `prefill_tokens` and subtracts its tokens from
    /// `active_tokens`. Calling this for a request with no prefill entry is a no-op.
    ///
    /// # Panics
    ///
    /// Panics if the subtraction would underflow `active_tokens`.
    pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
        let Some(tokens) = self.prefill_tokens.remove(request_id) else {
            return;
        };

        let remaining = self.active_tokens.checked_sub(tokens);
        self.active_tokens = remaining.expect("active_tokens underflow");
    }

    /// Number of prompt tokens that are not covered by cached blocks.
    ///
    /// Computed as `isl - overlap * block_size`. If the cached token count exceeds `isl`,
    /// an error is logged and 0 is returned.
    pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
        let cached_tokens = (overlap as usize) * self.block_size;
        match isl.checked_sub(cached_tokens) {
            Some(uncached) => uncached,
            None => {
                tracing::error!(
                    "prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
                    self.block_size
                );
                0
            }
        }
    }

    /// Blocks and tokens that would be active if the described request were added,
    /// counting its prompt tokens as load.
    ///
    /// This is [`Self::potential_blocks_and_tokens_with_prefill_tracking`] with prefill
    /// tracking enabled. Returns `(potential_blocks, potential_tokens)`.
    pub fn potential_blocks_and_tokens(
        &self,
        token_sequence: Option<&[SequenceHash]>,
        isl: usize,
        overlap: u32,
    ) -> (usize, usize) {
        let track_prefill_tokens = true;
        self.potential_blocks_and_tokens_with_prefill_tracking(
            token_sequence,
            isl,
            overlap,
            track_prefill_tokens,
        )
    }

    pub fn potential_blocks_and_tokens_with_prefill_tracking(
        &self,
        token_sequence: Option<&[SequenceHash]>,
        isl: usize,
        overlap: u32,
        track_prefill_tokens: bool,
246
247
248
249
250
251
    ) -> (usize, usize) {
        let potential_blocks = if let Some(token_seq) = token_sequence {
            self.new_blocks(token_seq) + self.active_blocks()
        } else {
            self.active_blocks()
        };
252
253
254
255
256
        let potential_tokens = if track_prefill_tokens {
            self.new_tokens(isl, overlap) + self.active_tokens
        } else {
            self.active_tokens
        };
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
        (potential_blocks, potential_tokens)
    }

    /// Count the entries of `token_sequence` whose hash is not present in `unique_blocks`.
    ///
    /// Every element is checked independently, so a hash that repeats within
    /// `token_sequence` and is untracked is counted once per occurrence.
    pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        let mut untracked = 0;
        for block in token_sequence {
            if !self.unique_blocks.contains_key(block) {
                untracked += 1;
            }
        }
        untracked
    }

    /// Return the total number of blocks that would be used if the token sequence was added
    /// This is the sum of new blocks that would be added plus the current active blocks
    pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        self.new_blocks(token_sequence) + self.active_blocks()
    }

274
275
276
277
    /// Free all blocks associated with a request.
    ///
    /// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
    /// to invoke both when the request is finishing.
278
279
280
281
282
283
284
    pub fn free(&mut self, request_id: &RequestId) -> usize {
        self.mark_prefill_completed(request_id);

        // Remove expected output tokens tracking
        self.expected_output_tokens.remove(request_id);

        // Remove from active_seqs and get the token sequence
285
        self.request_timestamps.remove(request_id);
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
        let token_seq = match self.active_seqs.remove(request_id) {
            Some(seq) => seq,
            None => {
                tracing::warn!("Trying to free non-existent request {request_id}");
                return self.active_blocks();
            }
        };

        // Drop each Rc reference, then clean up the corresponding weak reference
        for (block_hash, rc) in token_seq {
            drop(rc);
            self.try_remove_block(&block_hash);
        }

        self.active_blocks()
    }

    /// Append a randomly-hashed output block to a request, optionally applying a
    /// fractional decay weight.
    ///
    /// Used during generation to track output blocks as they are created. When
    /// `decay_fraction` is provided it is applied, via
    /// [`Self::set_single_ref_blocks_as_fractional`], to every block of the request that
    /// only this request references. The fraction expresses how "temporary" a block is:
    /// - 1.0 means fully counted (early in generation)
    /// - 0.0 means not counted (near end of expected output)
    /// - Computed as: 1 - (current_osl / expected_output_tokens)
    ///
    /// Returns `true` if the block was added, `false` if the request was not found.
    pub fn add_output_block(
        &mut self,
        request_id: &RequestId,
        decay_fraction: Option<f64>,
    ) -> bool {
        // Bail out before touching any block state if the request is unknown.
        if !self.active_seqs.contains_key(request_id) {
            tracing::warn!("Request {request_id} not found for add_output_block");
            return false;
        }

        // Output blocks have no content hash yet; use the high half of a fresh UUID.
        let (random_hash, _) = Uuid::new_v4().as_u64_pair();
        let handle = self.touch_block(&random_hash);

        // The request was verified above and nothing since could have removed it.
        let blocks = self.active_seqs.get_mut(request_id).unwrap();
        blocks.push((random_hash, handle));

        let Some(frac) = decay_fraction else {
            return true;
        };
        self.set_single_ref_blocks_as_fractional(request_id, frac);
        true
    }

    /// Force expiry of stale requests if the timer has elapsed
    /// Returns the set of expired request IDs that were removed
    pub fn force_expiry(&mut self) -> HashSet<RequestId> {
        let now = Instant::now();

348
349
        // Early return if timer hasn't expired yet.
        if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
350
351
352
            return HashSet::new();
        }

353
354
355
356
357
358
359
360
361
362
        self.last_expiry_check_time = now;
        let expired_requests_time = now - EXPIRY_DURATION;

        let mut expired_requests: HashSet<RequestId> = HashSet::new();
        for (request_id, timestamp) in self.request_timestamps.iter() {
            if *timestamp < expired_requests_time {
                expired_requests.insert(request_id.clone());
            }
        }

363
        for request_id in &expired_requests {
364
            tracing::warn!("Expiring stale request: {}", request_id);
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
            self.free(request_id);
        }

        expired_requests
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Blocks shared between requests are counted once and only released when the last
    /// request referencing them is freed.
    #[test]
    fn test_active_sequences_shared_blocks() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
        assert_eq!(seq_manager.active_blocks(), 3);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 16);

        seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 16);

        seq_manager.free(&"request_2".to_string());
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.free(&"request_3".to_string());
        assert_eq!(seq_manager.active_blocks(), 3);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.free(&"request_1".to_string());
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(), 0);
    }

    /// Fractional decay only updates blocks referenced by a single request.
    #[test]
    fn test_output_blocks_with_fractional_decay() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        // Add request with 3 prefill blocks
        seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
        assert_eq!(seq_manager.active_blocks(), 3);

        // Add output block with 0.5 decay fraction.
        // This adds a random block and sets all single-ref blocks to 0.5.
        assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.5)));
        // 4 unique blocks, all single-ref → all fractional at 0.5
        // active_blocks = 4 - 4 + 4*0.5 = 2
        assert_eq!(seq_manager.active_blocks(), 2);

        // Add second request sharing prefix [1, 2]
        seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None);
        // Blocks 1,2 now have strong_count=2 but still have fractional 0.5 from before
        // No new unique blocks → active_blocks = 4 - 4 + 2.0 = 2
        assert_eq!(seq_manager.active_blocks(), 2);

        // Add another output block with 0.0 decay for r1.
        // set_single_ref_blocks_as_fractional updates only single-ref blocks:
        //   blocks 1,2: strong_count=2, NOT updated (remain 0.5)
        //   block 3, old output, new output: strong_count=1, set to 0.0
        // active_blocks = 5 - 5 + (0.5+0.5+0.0+0.0+0.0) = 1
        assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.0)));
        assert_eq!(seq_manager.active_blocks(), 1);

        // Free both requests, verify clean state
        seq_manager.free(&"r2".to_string());
        seq_manager.free(&"r1".to_string());
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(), 0);
    }

    /// Marking prefill completed is idempotent and is implied by `free`.
    #[test]
    fn test_mark_prefill_completed() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        // Add request with isl=12, overlap=0 → active_tokens=12
        seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
        assert_eq!(seq_manager.active_tokens(), 12);

        // Mark prefill completed → active_tokens drops to 0
        seq_manager.mark_prefill_completed(&"r1".to_string());
        assert_eq!(seq_manager.active_tokens(), 0);

        // Double-mark: no panic, still 0
        seq_manager.mark_prefill_completed(&"r1".to_string());
        assert_eq!(seq_manager.active_tokens(), 0);

        // Add second request with isl=8
        seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None);
        assert_eq!(seq_manager.active_tokens(), 8);

        // Free it (internally calls mark_prefill_completed) → active_tokens=0
        seq_manager.free(&"r2".to_string());
        assert_eq!(seq_manager.active_tokens(), 0);
    }

    /// With prefill tracking disabled, prompt tokens never count toward `active_tokens`.
    #[test]
    fn test_add_request_without_prefill_tracking_keeps_active_tokens_zero() {
        let mut seq_manager = ActiveSequences::new(4);

        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1, 2, 3]),
            12,
            0,
            None,
            false,
        );

        assert_eq!(seq_manager.active_tokens(), 0);
        seq_manager.mark_prefill_completed(&"r1".to_string());
        assert_eq!(seq_manager.active_tokens(), 0);
        seq_manager.free(&"r1".to_string());
        assert_eq!(seq_manager.active_blocks(), 0);
    }

    /// With prefill tracking disabled, the potential token count ignores the prompt.
    #[test]
    fn test_potential_blocks_and_tokens_without_prefill_tracking_ignores_prompt_load() {
        let mut seq_manager = ActiveSequences::new(4);
        seq_manager.add_request_with_prefill_tracking(
            "r1".to_string(),
            Some(vec![1, 2, 3]),
            12,
            0,
            None,
            false,
        );

        let (blocks, tokens) = seq_manager.potential_blocks_and_tokens_with_prefill_tracking(
            Some(&[1, 2, 3, 4]),
            16,
            0,
            false,
        );
        assert_eq!(blocks, 4);
        assert_eq!(tokens, 0);
    }

    /// Expiry checks are rate-limited and only remove requests older than
    /// `EXPIRY_DURATION`.
    #[tokio::test(start_paused = true)]
    async fn test_force_expiry() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        // Add two requests at time 0 (paused clock)
        seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None);
        seq_manager.add_request("r2".to_string(), Some(vec![3, 4]), 8, 0, None);
        assert_eq!(seq_manager.active_blocks(), 4);

        // Advance 20s: check interval (CHECK_EXPIRY_FREQUENCY = 30s) not reached,
        // force_expiry returns without running the check.
        tokio::time::advance(Duration::from_secs(20)).await;
        let expired = seq_manager.force_expiry();
        assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY");
        assert_eq!(seq_manager.active_blocks(), 4);

        // Advance to 31s: first time we pass the check interval. Requests are 31s old,
        // still under EXPIRY_DURATION (300s), so none are expired.
        tokio::time::advance(Duration::from_secs(11)).await;
        let expired = seq_manager.force_expiry();
        assert!(expired.is_empty(), "requests not old enough to expire");
        assert_eq!(seq_manager.active_blocks(), 4);

        // Advance to 301s: requests are now older than EXPIRY_DURATION.
        // force_expiry runs and expires r1, r2.
        tokio::time::advance(Duration::from_secs(270)).await;
        let expired = seq_manager.force_expiry();
        assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()]));
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(), 0);

        // add_request calls force_expiry internally. Add r3; no old requests remain,
        // so expired set is empty and only r3 is active.
        tokio::time::advance(Duration::from_secs(31)).await;
        let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None);
        assert!(expired.is_empty());
        assert_eq!(seq_manager.active_blocks(), 1);
        assert_eq!(seq_manager.active_tokens(), 4);
    }
}