// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use crate::common::protocols::MoveBlock; use derive_getters::Getters; use dynamo_tokens::blocks::UniqueBlock; use dynamo_tokens::{TokenBlockSequence, Tokens}; use rand::random; use validator::Validate; /// Create unique blocks from a TokenBlockSequence fn create_unique_blocks_from_sequence( tokens: &TokenBlockSequence, block_size: usize, enable_prefix_caching: bool, ) -> Vec { let mut unique_blocks: Vec = tokens .blocks() .iter() .map(|block| { if enable_prefix_caching { UniqueBlock::FullBlock(block.sequence_hash()) } else { UniqueBlock::FullBlock(random::()) } }) .collect(); // Only push the partial block if tokens count isn't a multiple of block_size if !tokens.total_tokens().is_multiple_of(block_size) { unique_blocks.push(UniqueBlock::default()); } unique_blocks } /// A sequence that is actively being built, with the ability to add tokens and commit to hashes /// TODO: reuse tokens #[derive(Debug, Getters, Validate)] pub struct ActiveSequence { unique_blocks: Vec, tokens: TokenBlockSequence, #[getter(copy)] #[validate(range(min = 2))] block_size: usize, #[getter(copy)] max_output_tokens: usize, #[getter(copy)] generated_tokens: usize, #[getter(copy)] num_input_tokens: usize, #[getter(copy)] num_allocated_tokens: usize, #[getter(copy)] enable_prefix_caching: bool, #[getter(copy)] emit_token_ids: bool, } impl ActiveSequence { /// Create a new ActiveSequence instance with the provided tokens pub fn new( tokens: Vec, max_output_tokens: usize, block_size: Option, enable_prefix_caching: bool, emit_token_ids: bool, ) -> Self { let block_size = block_size.unwrap_or(64); let num_input_tokens = tokens.len(); let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337)); let unique_blocks = create_unique_blocks_from_sequence(&tokens, block_size, enable_prefix_caching); let seq = Self { unique_blocks, tokens, block_size, max_output_tokens, generated_tokens: 0, num_input_tokens, num_allocated_tokens: 0, enable_prefix_caching, emit_token_ids, }; seq.validate().expect("invalid ActiveSequence"); seq } pub fn extra_tokens(&self) -> u32 { (self.len() % self.block_size) as u32 } pub fn len(&self) -> usize { self.tokens.total_tokens() } pub fn is_empty(&self) -> bool { self.tokens.total_tokens() == 0 } /// Build a `MoveBlock::Use` signal for blocks up to `cumulative_tokens` /// without updating internal state. Returns `None` if no new blocks are needed. /// Call `commit_allocation` after the signal is successfully processed. pub fn prepare_allocation(&self, cumulative_tokens: usize) -> Option { let prev_blocks = self .num_allocated_tokens .div_ceil(self.block_size) .min(self.unique_blocks.len()); let target_blocks = cumulative_tokens .div_ceil(self.block_size) .min(self.unique_blocks.len()); if target_blocks <= prev_blocks { return None; } let range = prev_blocks..target_blocks; let blocks = self.unique_blocks[range.clone()].to_vec(); let all_hashes = self.block_hashes(); let num_full = all_hashes.len(); let hash_start = prev_blocks.min(num_full); let hash_end = target_blocks.min(num_full); let hashes = all_hashes[hash_start..hash_end].to_vec(); let token_ids = if self.emit_token_ids && hash_start < hash_end { let all_token_ids: Vec> = self .tokens .blocks() .iter() .map(|b| b.tokens().to_vec()) .collect(); Some(all_token_ids[hash_start..hash_end].to_vec()) } else { None }; let parent = if prev_blocks > 0 { Some(self.unique_blocks[prev_blocks - 1].clone()) } else { None }; Some(MoveBlock::Use(blocks, hashes, token_ids, parent)) } /// Commit a successful allocation by advancing `num_allocated_tokens`. pub fn commit_allocation(&mut self, cumulative_tokens: usize) { self.num_allocated_tokens = cumulative_tokens; } /// Prepare + commit in one call (convenience for paths where failure is impossible). pub fn allocate_blocks_for_chunk(&mut self, cumulative_tokens: usize) -> Option { let signal = self.prepare_allocation(cumulative_tokens); self.commit_allocation(cumulative_tokens); signal } /// Allocate all remaining blocks at once (backward compat). pub fn take_creation_signal(&mut self) -> Option { self.allocate_blocks_for_chunk(self.len()) } pub fn block_hashes(&self) -> Vec { self.tokens .blocks() .iter() .map(|block| block.block_hash()) .collect() } /// Create a new ActiveSequence instance and return the creation signal pub fn new_with_signal( tokens: Vec, max_output_tokens: usize, block_size: Option, enable_prefix_caching: bool, ) -> (Self, Option) { let mut sequence = Self::new( tokens, max_output_tokens, block_size, enable_prefix_caching, false, ); let signal = sequence.take_creation_signal(); (sequence, signal) } /// Push a token to the sequence pub fn push(&mut self, token: u32) -> Option> { self.tokens.append(token).expect("Token push failed."); self.generated_tokens += 1; if self.len() % self.block_size != 1 { return None; } // Add a partial block for the first token in a new partial sequence // Send Use signal (to allocate space for this new generation block) let mut signals = Vec::new(); // Replace last partial block with full block if it exists if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() { let last_complete = self.tokens.last_complete_block().unwrap(); let last_seq_hash = if self.enable_prefix_caching { last_complete.sequence_hash() } else { random::() }; let last_block_hash = last_complete.block_hash(); let promote_token_ids = if self.emit_token_ids { Some(last_complete.tokens().to_vec()) } else { None }; self.unique_blocks.pop(); // After pop, the last element is the parent block let second_to_last_hash = self.unique_blocks.last().map(|block| match block { UniqueBlock::FullBlock(hash) => *hash, UniqueBlock::PartialBlock(_) => panic!("Cannot have a partial block as parent"), }); self.unique_blocks .push(UniqueBlock::FullBlock(last_seq_hash)); signals.push(MoveBlock::Promote( uuid, last_seq_hash, second_to_last_hash, last_block_hash, promote_token_ids, )); } let new_partial_block = UniqueBlock::default(); self.unique_blocks.push(new_partial_block.clone()); signals.push(MoveBlock::Use(vec![new_partial_block], vec![], None, None)); Some(signals) } /// Generate a random token, push it to the sequence, and increment generation count. /// /// This function: /// - Generates a random token and adds it to the current sequence /// - Acquires a new partial block if needed or promotes an existing partial block to a full block /// - Returns appropriate signals for the KvManager to process /// /// # Panics /// /// Calling this function when max_output_tokens has already been reached will cause a panic. /// Always check `generated_tokens < max_output_tokens` before calling this method. pub fn generate(&mut self) -> Vec { // Assert that we haven't reached the maximum output tokens assert!( self.generated_tokens < self.max_output_tokens, "Cannot generate more tokens: reached max_output_tokens limit" ); // Generate a random token let token = random::(); // Collect signals let mut signals = Vec::new(); // Push the token to the sequence and collect any signals if let Some(move_blocks) = self.push(token) { signals.extend(move_blocks); } // Check if we've reached the limit after pushing if self.generated_tokens != self.max_output_tokens { return signals; } // Free all blocks when we reach max tokens signals.extend(self.free_signal()); signals } /// Free all blocks, generating appropriate signals for each block type pub fn free_signal(&self) -> Vec { self.unique_blocks .iter() .rev() .map(|block| match block { UniqueBlock::PartialBlock(uuid) => { MoveBlock::Destroy(vec![UniqueBlock::PartialBlock(*uuid)]) } UniqueBlock::FullBlock(hash) => { MoveBlock::Deref(vec![UniqueBlock::FullBlock(*hash)]) } }) .collect() } /// Move the request to a preempted state and return the free signals from freeing current blocks. /// Upon preemption, the sequence retains the tokens generated during the decode phase (if any). /// Resets `num_allocated_tokens` so re-admission will re-allocate from scratch. pub fn reset_with_signal(&mut self) -> Vec { let free_signal = self.free_signal(); self.num_allocated_tokens = 0; free_signal } /// Pops last token in the sequence. pub fn pop(&mut self) { self.tokens.pop(); self.generated_tokens = self.generated_tokens.saturating_sub(1); // Reverts to the last full block if self.tokens.total_tokens().is_multiple_of(self.block_size) { self.unique_blocks.pop(); } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_active_sequence_push() { // Create a sequence with block size 16 initialized with tokens [0..15] let initial_tokens: Vec = (0..15).collect(); let (mut seq1, signal1) = ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true); assert_eq!(seq1.num_input_tokens(), 15); assert_eq!(seq1.len(), 15); // Check that we got a Use signal assert!(signal1.is_some()); match &signal1 { Some(MoveBlock::Use(blocks, _hashes, ..)) => { assert_eq!(blocks.len(), 1); } _ => panic!("Expected Use signal"), } // Push token 15 which should complete the block (no signals yet) let signal_15 = seq1.push(15); assert!( signal_15.is_none(), "Completing a block should not trigger signals" ); // Push token 16 which should trigger both Promote and Use signals let signal_16 = seq1.push(16); assert!(signal_16.is_some()); let signal_16 = signal_16.unwrap(); assert_eq!(signal_16.len(), 2); // First signal should be Promote for the previous block match &signal_16[0] { MoveBlock::Promote(_, _, parent_hash, _hash, ..) => { assert_eq!(*parent_hash, None); } _ => panic!("Expected Promote signal as second signal"), } // Second signal should be Use for new partial block match &signal_16[1] { MoveBlock::Use(blocks, _hashes, ..) => { assert_eq!(blocks.len(), 1); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); } _ => panic!("Expected Use signal as first signal"), } // Verify state after pushing tokens assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block assert_eq!(seq1.len(), 17); assert_eq!(seq1.len() % seq1.block_size(), 1); // Create another sequence with block size 16 initialized with tokens [0..17] let extended_tokens: Vec = (0..16).collect(); let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true); seq2.push(16); seq2.pop(); seq2.push(16); // Simplified assertions assert_eq!( seq1.unique_blocks()[0], seq2.unique_blocks()[0], "First blocks should be the same" ); assert_ne!( seq1.unique_blocks()[1], seq2.unique_blocks()[1], "Second blocks should be different" ); // Reset partial block on seq1 and push back token 16 seq1.push(17); seq1.pop(); seq1.pop(); seq1.push(16); // Now push tokens 17..32 to both sequences for token in 17..33 { seq1.push(token); seq2.push(token); } // Both sequences should now have 2 blocks: // 1. FullBlock for tokens 0-15 // 2. FullBlock for tokens 16-31 // 3. No partial block since there are no remaining tokens assert_eq!( seq1.unique_blocks().len(), 3, "seq1 should have exactly 3 blocks" ); assert_eq!( seq2.unique_blocks().len(), 3, "seq2 should have exactly 3 blocks" ); assert_eq!( seq1.len() % seq1.block_size(), 1, "seq1 should have 1 partial token" ); assert_eq!( seq2.len() % seq2.block_size(), 1, "seq2 should have 1 partial token" ); // Verify that both sequences have identical blocks up to the second position assert_eq!( &seq1.unique_blocks()[0..2], &seq2.unique_blocks()[0..2], "First two blocks should be identical" ); // Push tokens 34..47 to seq1 for token in 33..48 { seq1.push(token); } // Push token 48 and get the signal - this completes the block and triggers signals let signal = seq1.push(48); let signal = signal.unwrap(); // Check that signal[0] is promote match &signal[0] { MoveBlock::Promote(_, _, parent_hash, _hash, ..) => { // Check that the parent_hash matches unique_blocks[1], which should be a full block if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] { assert_eq!( *parent_hash, Some(expected_hash), "Parent hash should match unique_blocks[1]" ); } else { panic!("unique_blocks[1] should be a full block"); } } _ => panic!("Expected Promote signal as first signal"), } // Reset seq1 and check that it equals the original clone let free_signals = seq1.reset_with_signal(); // 49 - 15 generated tokens assert_eq!(seq1.generated_tokens(), 34); // Verify the reset signals include proper cleanup events assert!(!free_signals.is_empty()); } #[test] fn test_active_sequence_generate_signals() { // Create a sequence with block size 16, max_output_tokens 4, initialized with tokens [0..14) let initial_tokens: Vec = (0..14).collect(); let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true); // Initial signal - should have received a Use signal for the partial block assert!(signal.is_some()); match signal { Some(MoveBlock::Use(blocks, _hashes, ..)) => { assert_eq!(blocks.len(), 1); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); } _ => panic!("Expected Use signal for the initial partial block"), } // Generate first two tokens - should not trigger new signals seq.generate(); let signals_first = seq.generate(); assert_eq!(signals_first.len(), 0); // Generate third token - this fills the block and should trigger both Promote and Use signals let signals_second = seq.generate(); assert_eq!(signals_second.len(), 2); // First signal should be Promote match &signals_second[0] { MoveBlock::Promote(_, _, parent_hash, _hash, ..) => { assert_eq!(*parent_hash, None); } _ => panic!("Expected Promote signal as first signal after second token"), } // Second signal should be Use for new partial block match &signals_second[1] { MoveBlock::Use(blocks, _hashes, ..) => { assert_eq!(blocks.len(), 1); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); } _ => panic!("Expected Use signal as second signal after second token"), } // Generate fourth token - should not trigger new signals as it's adding to partial block let signals_third = seq.generate(); assert_eq!(signals_third.len(), 0); // Generate last token - we reach max_output_tokens, should trigger Destroy and Deref signals let signals_last = seq.generate(); assert_eq!(signals_last.len(), 2); // First signal should be Destroy for the partial block match &signals_last[0] { MoveBlock::Destroy(blocks) => { assert_eq!(blocks.len(), 1); assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_))); } _ => panic!("Expected Destroy signal for partial block after fourth token"), } // Second signal should be Deref for the full block match &signals_last[1] { MoveBlock::Deref(blocks) => { assert_eq!(blocks.len(), 1); assert!(matches!(blocks[0], UniqueBlock::FullBlock(_))); } _ => panic!("Expected Deref signal for full block after fourth token"), } } }