// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! KV Cache Sequence Management for LLM Inference
//!
//! This module provides efficient management of token sequences and their associated KV cache blocks
//! for distributed LLM inference. It implements a shared block system where multiple requests can
//! reuse the same KV cache blocks for common token prefixes, significantly reducing memory usage.
//!
//! # Key Components
//!
//! - [`ActiveSequences`]: Per-worker sequence manager that tracks active requests and their
//!   token sequences, managing shared KV cache blocks efficiently.
//!
//! # Architecture
//!
//! The system uses a block-based approach where token sequences are divided into fixed-size blocks.
//! Each block is identified by a hash of its contents, allowing for deduplication when multiple
//! requests share common prefixes (e.g., system prompts, few-shot examples).

use derive_getters::Getters;
use dynamo_tokens::SequenceHash;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::Instant;
use uuid::Uuid;

/// Duration after which stale requests may be expired (5 minutes).
const EXPIRY_DURATION: Duration = Duration::from_secs(300);

/// How often we *check* for stale requests (30 seconds). This is not
/// the expiration time, that is `EXPIRY_DURATION`.
const CHECK_EXPIRY_FREQUENCY: Duration = Duration::from_secs(30);

// TODO: use the common request_id if it exists in the repo
/// Identifier of a request; used as the key of every per-request map in this module.
pub type RequestId = String;

/// A multi-request sequence manager that handles multiple active sequences with shared KV cache
#[derive(Debug, Getters)]
pub struct ActiveSequences {
    active_seqs: HashMap<RequestId, Vec<(SequenceHash, Arc<()>)>>,

    prefill_tokens: HashMap<RequestId, usize>,

    /// Expected output tokens per request (used for resource estimation)
    expected_output_tokens: HashMap<RequestId, u32>,

    unique_blocks: HashMap<SequenceHash, std::sync::Weak<()>>,

    /// Fractional block counts for blocks that are partially cached
    /// When a block is in both unique_blocks and fractional_blocks,
    /// it contributes the fractional value instead of 1 to active_blocks()
    fractional_blocks: HashMap<SequenceHash, f64>,

    #[getter(copy)]
    block_size: usize,

    #[getter(copy)]
    active_tokens: usize,

62
63
    // Request timestamps, for expiration.
    request_timestamps: HashMap<RequestId, Instant>,
64

65
    last_expiry_check_time: Instant,
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
}

impl ActiveSequences {
    /// Create a new, empty [`ActiveSequences`] manager.
    ///
    /// `block_size` is the number of tokens per block. The expiry-check timer starts at
    /// the moment of construction.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is not greater than 1.
    pub fn new(block_size: usize) -> Self {
        // TODO: make this not a hard req
        assert!(block_size > 1, "block_size must be greater than 1");

        Self {
            active_seqs: HashMap::new(),
            prefill_tokens: HashMap::new(),
            expected_output_tokens: HashMap::new(),
            unique_blocks: HashMap::new(),
            fractional_blocks: HashMap::new(),
            block_size,
            active_tokens: 0,
            request_timestamps: HashMap::new(),
            last_expiry_check_time: Instant::now(),
        }
    }

    fn touch_block(&mut self, block: &SequenceHash) -> Arc<()> {
        if let Some(weak) = self.unique_blocks.get(block)
            && let Some(rc) = weak.upgrade()
        {
            return rc;
        }

        let rc = Arc::new(());
        self.unique_blocks.insert(*block, Arc::downgrade(&rc));
        rc
    }

    fn try_remove_block(&mut self, block: &SequenceHash) {
        if let Some(weak) = self.unique_blocks.get(block)
            && weak.strong_count() == 0
        {
            self.unique_blocks.remove(block);
            self.fractional_blocks.remove(block);
        }
    }

    /// Number of active blocks, rounded to the nearest whole block.
    ///
    /// Every tracked block counts as 1, except blocks that also have an entry in
    /// `fractional_blocks`, which contribute their fractional value instead.
    pub fn active_blocks(&self) -> usize {
        let weighted_total = self
            .fractional_blocks
            .iter()
            .filter(|(hash, _)| self.unique_blocks.contains_key(*hash))
            // Swap the full block (1.0) for its fractional weight.
            .fold(self.unique_blocks.len() as f64, |acc, (_, frac)| {
                acc - 1.0 + frac
            });

        weighted_total.round() as usize
    }

    /// Assign `fraction` as the fractional weight of every block of `request_id` that is
    /// held by exactly one strong reference (i.e. only by this request).
    ///
    /// Existing fractional entries for those blocks are overwritten. Logs a warning and
    /// does nothing when the request is not active.
    pub fn set_single_ref_blocks_as_fractional(&mut self, request_id: &RequestId, fraction: f64) {
        match self.active_seqs.get(request_id) {
            None => {
                tracing::warn!(
                    "Request {request_id} not found for set_single_ref_blocks_as_fractional"
                );
            }
            Some(blocks) => {
                // strong_count == 1 means no other request shares the block.
                let exclusive = blocks
                    .iter()
                    .filter(|(_, handle)| Arc::strong_count(handle) == 1)
                    .map(|(hash, _)| (*hash, fraction));
                self.fractional_blocks.extend(exclusive);
            }
        }
    }

    /// Add a new request with its initial tokens
    /// Returns the set of expired request IDs that were removed during cleanup
    pub fn add_request(
        &mut self,
        request_id: RequestId,
        token_sequence: Option<Vec<SequenceHash>>,
        isl: usize,
        overlap: u32,
        expected_output_tokens: Option<u32>,
    ) -> HashSet<RequestId> {
        // Check for double-add and log error, returning early
        if self.active_seqs.contains_key(&request_id) {
            tracing::error!("Request {request_id} is already active. Ignoring duplicate add.");
            return HashSet::new();
        }

        // Lazily check and clean up expired requests, capturing removed IDs
        let removed_requests = self.force_expiry();

        let prefill_tokens = self.new_tokens(isl, overlap);
        self.prefill_tokens
            .insert(request_id.clone(), prefill_tokens);
        self.active_tokens += prefill_tokens;

        // Store expected output tokens if provided
        if let Some(tokens) = expected_output_tokens {
            self.expected_output_tokens
                .insert(request_id.clone(), tokens);
        }

        if let Some(sequence) = token_sequence {
            let sequence_with_refs: Vec<(SequenceHash, Arc<()>)> = sequence
                .iter()
                .map(|block| (*block, self.touch_block(block)))
                .collect();
            self.active_seqs
                .insert(request_id.clone(), sequence_with_refs);
        } else {
            // dummy empty sequence
            self.active_seqs.insert(request_id.clone(), Vec::new());
        }
178
179
        self.request_timestamps
            .insert(request_id.clone(), Instant::now());
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242

        removed_requests
    }

    /// Stop counting the prefill tokens of `request_id` towards `active_tokens`.
    ///
    /// Does nothing when the request has no pending prefill entry (unknown request or
    /// already marked).
    ///
    /// # Panics
    ///
    /// Panics if the request's prefill tokens exceed the current `active_tokens`.
    pub fn mark_prefill_completed(&mut self, request_id: &RequestId) {
        let Some(pending) = self.prefill_tokens.remove(request_id) else {
            return;
        };

        let remaining = self.active_tokens.checked_sub(pending);
        self.active_tokens = remaining.expect("active_tokens underflow");
    }

    /// Number of tokens that still need prefill for an input of `isl` tokens when
    /// `overlap` blocks are already cached.
    ///
    /// Computed as `isl - overlap * block_size`; if the cached tokens exceed `isl`, an
    /// error is logged and 0 is returned.
    pub fn new_tokens(&self, isl: usize, overlap: u32) -> usize {
        let cached_tokens = (overlap as usize) * self.block_size;

        match isl.checked_sub(cached_tokens) {
            Some(uncached) => uncached,
            None => {
                tracing::error!(
                    "prefill_tokens < 0 with ISL {isl} < cached_tokens {cached_tokens} (overlap {overlap} * block_size {}), returning 0",
                    self.block_size
                );
                0
            }
        }
    }

    /// Block and token totals that would result from adding a request, without adding it.
    ///
    /// Returns `(potential_blocks, potential_tokens)`: the current active blocks plus the
    /// not-yet-tracked blocks of `token_sequence` (none when it is `None`), and the
    /// current active tokens plus the new prefill tokens for `isl` / `overlap`.
    pub fn potential_blocks_and_tokens(
        &self,
        token_sequence: Option<&[SequenceHash]>,
        isl: usize,
        overlap: u32,
    ) -> (usize, usize) {
        let added_blocks = token_sequence.map_or(0, |seq| self.new_blocks(seq));
        let added_tokens = self.new_tokens(isl, overlap);

        (
            added_blocks + self.active_blocks(),
            added_tokens + self.active_tokens,
        )
    }

    /// Count the entries of `token_sequence` whose hash is not currently tracked,
    /// i.e. the number of new blocks the sequence would add.
    pub fn new_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        let mut untracked = 0;
        for hash in token_sequence {
            if !self.unique_blocks.contains_key(hash) {
                untracked += 1;
            }
        }
        untracked
    }

    /// Total number of blocks that would be in use if `token_sequence` were added:
    /// the blocks it would newly introduce plus the currently active blocks.
    pub fn potential_blocks(&self, token_sequence: &[SequenceHash]) -> usize {
        let additional = self.new_blocks(token_sequence);
        let in_use = self.active_blocks();
        additional + in_use
    }

    /// Free all blocks associated with a request
    pub fn free(&mut self, request_id: &RequestId) -> usize {
        self.mark_prefill_completed(request_id);

        // Remove expected output tokens tracking
        self.expected_output_tokens.remove(request_id);

        // Remove from active_seqs and get the token sequence
243
        self.request_timestamps.remove(request_id);
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
        let token_seq = match self.active_seqs.remove(request_id) {
            Some(seq) => seq,
            None => {
                tracing::warn!("Trying to free non-existent request {request_id}");
                return self.active_blocks();
            }
        };

        // Drop each Rc reference, then clean up the corresponding weak reference
        for (block_hash, rc) in token_seq {
            drop(rc);
            self.try_remove_block(&block_hash);
        }

        self.active_blocks()
    }

    /// Append a freshly generated output block to a request, optionally applying a
    /// fractional decay weight.
    ///
    /// Used during generation to track output blocks as they are produced. The block hash
    /// is random (taken from a new UUID). When `decay_fraction` is given, it is applied to
    /// every single-reference block of the request and expresses how "temporary" those
    /// blocks are:
    /// - 1.0 means fully counted (early in generation)
    /// - 0.0 means not counted (near end of expected output)
    /// - Computed as: 1 - (current_osl / expected_output_tokens)
    ///
    /// Returns true if the block was added, false if the request was not found.
    pub fn add_output_block(
        &mut self,
        request_id: &RequestId,
        decay_fraction: Option<f64>,
    ) -> bool {
        // Bail out before touching any state if the request is unknown.
        if !self.active_seqs.contains_key(request_id) {
            tracing::warn!("Request {request_id} not found for add_output_block");
            return false;
        }

        // Random hash for the output block, derived from a UUID.
        let (output_hash, _): (SequenceHash, u64) = Uuid::new_v4().as_u64_pair();

        // Register the block in unique_blocks and obtain its strong handle.
        let handle = self.touch_block(&output_hash);

        // The request was verified above, so the lookup cannot fail.
        let blocks = self.active_seqs.get_mut(request_id).unwrap();
        blocks.push((output_hash, handle));

        // Apply the decay weight to all single-ref blocks of this request, if requested.
        match decay_fraction {
            Some(frac) => self.set_single_ref_blocks_as_fractional(request_id, frac),
            None => {}
        }

        true
    }

    /// Force expiry of stale requests if the timer has elapsed
    /// Returns the set of expired request IDs that were removed
    pub fn force_expiry(&mut self) -> HashSet<RequestId> {
        let now = Instant::now();

306
307
        // Early return if timer hasn't expired yet.
        if now < self.last_expiry_check_time + CHECK_EXPIRY_FREQUENCY {
308
309
310
            return HashSet::new();
        }

311
312
313
314
315
316
317
318
319
320
        self.last_expiry_check_time = now;
        let expired_requests_time = now - EXPIRY_DURATION;

        let mut expired_requests: HashSet<RequestId> = HashSet::new();
        for (request_id, timestamp) in self.request_timestamps.iter() {
            if *timestamp < expired_requests_time {
                expired_requests.insert(request_id.clone());
            }
        }

321
        for request_id in &expired_requests {
322
            tracing::warn!("Expiring stale request: {}", request_id);
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
            self.free(request_id);
        }

        expired_requests
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Blocks shared between requests are counted once and only disappear when the last
    /// request referencing them is freed.
    #[test]
    fn test_active_sequences_shared_blocks() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        seq_manager.add_request("request_1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
        assert_eq!(seq_manager.active_blocks(), 3);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.add_request("request_2".to_string(), Some(vec![4]), 4, 0, None);
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 16);

        // Fully overlapping request: no new blocks, no new prefill tokens.
        seq_manager.add_request("request_3".to_string(), Some(vec![1, 2, 3, 4]), 16, 4, None);
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 16);

        // Block 4 is still held by request_3.
        seq_manager.free(&"request_2".to_string());
        assert_eq!(seq_manager.active_blocks(), 4);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.free(&"request_3".to_string());
        assert_eq!(seq_manager.active_blocks(), 3);
        assert_eq!(seq_manager.active_tokens(), 12);

        seq_manager.free(&"request_1".to_string());
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(), 0);
    }

    /// Output blocks and fractional decay weights change the active block count as expected.
    #[test]
    fn test_output_blocks_with_fractional_decay() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        // Add request with 3 prefill blocks
        seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
        assert_eq!(seq_manager.active_blocks(), 3);

        // Add output block with 0.5 decay fraction.
        // This adds a random block and sets all single-ref blocks to 0.5.
        assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.5)));
        // 4 unique blocks, all single-ref → all fractional at 0.5
        // active_blocks = 4 - 4 + 4*0.5 = 2
        assert_eq!(seq_manager.active_blocks(), 2);

        // Add second request sharing prefix [1, 2]
        seq_manager.add_request("r2".to_string(), Some(vec![1, 2]), 8, 0, None);
        // Blocks 1,2 now have strong_count=2 but still have fractional 0.5 from before
        // No new unique blocks → active_blocks = 4 - 4 + 2.0 = 2
        assert_eq!(seq_manager.active_blocks(), 2);

        // Add another output block with 0.0 decay for r1.
        // set_single_ref_blocks_as_fractional updates only single-ref blocks:
        //   blocks 1,2: strong_count=2, NOT updated (remain 0.5)
        //   block 3, old output, new output: strong_count=1, set to 0.0
        // active_blocks = 5 - 5 + (0.5+0.5+0.0+0.0+0.0) = 1
        assert!(seq_manager.add_output_block(&"r1".to_string(), Some(0.0)));
        assert_eq!(seq_manager.active_blocks(), 1);

        // Free both requests, verify clean state
        seq_manager.free(&"r2".to_string());
        seq_manager.free(&"r1".to_string());
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(), 0);
    }

    /// Completing prefill removes the request's tokens exactly once.
    #[test]
    fn test_mark_prefill_completed() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        // Add request with isl=12, overlap=0 → active_tokens=12
        seq_manager.add_request("r1".to_string(), Some(vec![1, 2, 3]), 12, 0, None);
        assert_eq!(seq_manager.active_tokens(), 12);

        // Mark prefill completed → active_tokens drops to 0
        seq_manager.mark_prefill_completed(&"r1".to_string());
        assert_eq!(seq_manager.active_tokens(), 0);

        // Double-mark: no panic, still 0
        seq_manager.mark_prefill_completed(&"r1".to_string());
        assert_eq!(seq_manager.active_tokens(), 0);

        // Add second request with isl=8
        seq_manager.add_request("r2".to_string(), Some(vec![4, 5]), 8, 0, None);
        assert_eq!(seq_manager.active_tokens(), 8);

        // Free it (internally calls mark_prefill_completed) → active_tokens=0
        seq_manager.free(&"r2".to_string());
        assert_eq!(seq_manager.active_tokens(), 0);
    }

    /// Stale requests are only scanned every CHECK_EXPIRY_FREQUENCY and only expired once
    /// they are older than EXPIRY_DURATION.
    #[tokio::test(start_paused = true)]
    async fn test_force_expiry() {
        let block_size = 4;
        let mut seq_manager = ActiveSequences::new(block_size);

        // Add two requests at time 0 (paused clock)
        seq_manager.add_request("r1".to_string(), Some(vec![1, 2]), 8, 0, None);
        seq_manager.add_request("r2".to_string(), Some(vec![3, 4]), 8, 0, None);
        assert_eq!(seq_manager.active_blocks(), 4);

        // Advance 20s: check interval (CHECK_EXPIRY_FREQUENCY = 30s) not reached,
        // force_expiry returns without running the check.
        tokio::time::advance(Duration::from_secs(20)).await;
        let expired = seq_manager.force_expiry();
        assert!(expired.is_empty(), "no check before CHECK_EXPIRY_FREQUENCY");
        assert_eq!(seq_manager.active_blocks(), 4);

        // Advance to 31s: first time we pass the check interval. Requests are 31s old,
        // still under EXPIRY_DURATION (300s), so none are expired.
        tokio::time::advance(Duration::from_secs(11)).await;
        let expired = seq_manager.force_expiry();
        assert!(expired.is_empty(), "requests not old enough to expire");
        assert_eq!(seq_manager.active_blocks(), 4);

        // Advance to 301s: requests are now older than EXPIRY_DURATION.
        // force_expiry runs and expires r1, r2.
        tokio::time::advance(Duration::from_secs(270)).await;
        let expired = seq_manager.force_expiry();
        assert_eq!(expired, HashSet::from(["r1".to_string(), "r2".to_string()]));
        assert_eq!(seq_manager.active_blocks(), 0);
        assert_eq!(seq_manager.active_tokens(), 0);

        // add_request calls force_expiry internally. Add r3; no old requests remain,
        // so expired set is empty and only r3 is active.
        tokio::time::advance(Duration::from_secs(31)).await;
        let expired = seq_manager.add_request("r3".to_string(), Some(vec![5]), 4, 0, None);
        assert!(expired.is_empty());
        assert_eq!(seq_manager.active_blocks(), 1);
        assert_eq!(seq_manager.active_tokens(), 4);
    }
}