Unverified Commit 03c160af authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: vllm mock workers, Rusty skeleton (#1033)


Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
parent 84377e5d
......@@ -28,6 +28,7 @@ pub mod hub;
pub mod key_value_store;
pub mod kv_router;
pub use kv_router::DEFAULT_KV_BLOCK_SIZE;
pub mod mocker;
pub mod model_card;
pub mod model_type;
pub mod preprocessor;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod evictor;
pub mod kv_manager;
pub mod protocols;
pub mod scheduler;
pub mod sequence;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Eq;
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::time::Instant;
/// An LRU evictor that maintains objects and evicts them based on their
/// last accessed time. Implements a "lazy" eviction mechanism where:
/// 1. The priority queue does not immediately reflect updates or removes
/// 2. Objects are pushed to the queue in order of increasing priority (older objects first)
/// 3. The user must ensure objects are added in correct priority (temporal order)
/// 4. Remove and update operations are lazy - entries remain in the queue until
/// they are either evicted or cleaned up during maintenance
#[derive(Debug)]
pub struct LRUEvictor<T: Clone + Eq + Hash> {
free_table: HashMap<T, f64>,
priority_queue: VecDeque<(T, f64)>,
cleanup_threshold: usize,
start_time: Instant,
}
impl<T: Clone + Eq + Hash> Default for LRUEvictor<T> {
fn default() -> Self {
Self {
free_table: HashMap::new(),
priority_queue: VecDeque::new(),
cleanup_threshold: 50,
start_time: Instant::now(),
}
}
}
impl<T: Clone + Eq + Hash> LRUEvictor<T> {
/// Create a new LRUEvictor with the default cleanup threshold
pub fn new(cleanup_threshold: usize) -> Self {
Self {
cleanup_threshold,
..Default::default()
}
}
/// Get the current timestamp as seconds since initialization
pub fn current_timestamp(&self) -> f64 {
self.start_time.elapsed().as_secs_f64()
}
/// Get an iterator over the keys in the evictor
pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, f64> {
self.free_table.keys()
}
/// Insert or update an object in the evictor with current timestamp
pub fn insert(&mut self, object: T) {
let timestamp = self.current_timestamp();
self._insert(object, timestamp);
}
/// Check if the evictor contains the given object
pub fn contains(&self, object: &T) -> bool {
self.free_table.contains_key(object)
}
/// Evict an object based on LRU policy
/// Returns the evicted object or None if no objects are available
pub fn evict(&mut self) -> Option<T> {
if self.free_table.is_empty() {
return None;
}
while let Some((object, last_accessed)) = self.priority_queue.pop_front() {
let Some(&current_last_accessed) = self.free_table.get(&object) else {
continue; // entry is already removed
};
if current_last_accessed == last_accessed {
self.free_table.remove(&object);
return Some(object);
} // otherwise entry is stale
}
None
}
/// Insert or update an object in the evictor
fn _insert(&mut self, object: T, last_accessed: f64) {
self.free_table.insert(object.clone(), last_accessed);
self.priority_queue.push_back((object, last_accessed));
self.cleanup_if_necessary();
}
/// Remove an object from the evictor
/// We don't remove from the priority queue immediately, as that would be inefficient
/// Outdated entries will be filtered out during eviction or cleanup
pub fn remove(&mut self, object: &T) -> bool {
self.free_table.remove(object).is_some()
}
/// Get the number of objects in the evictor
pub fn len(&self) -> usize {
self.free_table.len()
}
/// Check if the evictor is empty
pub fn is_empty(&self) -> bool {
self.free_table.is_empty()
}
/// Check if cleanup is necessary and perform it if needed
fn cleanup_if_necessary(&mut self) {
if self.priority_queue.len() > self.cleanup_threshold * self.free_table.len() {
self.cleanup();
}
}
/// Clean up the priority queue by removing outdated entries
fn cleanup(&mut self) {
let mut new_priority_queue = VecDeque::new();
for (object, timestamp) in self.priority_queue.drain(..) {
let Some(&current_timestamp) = self.free_table.get(&object) else {
continue;
};
if current_timestamp == timestamp {
new_priority_queue.push_back((object, timestamp));
}
}
self.priority_queue = new_priority_queue;
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
fn test_lru_evictor_eviction_order(#[case] threshold: usize) {
// Create a new LRUEvictor with the given cleanup threshold
let mut evictor = LRUEvictor::<i32>::new(threshold);
// Add items in the specified order with small delays between each
evictor.insert(4);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(3);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(2);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(1);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(5);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(1); // Updates timestamp for 1
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(4); // Updates timestamp for 4
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(2); // Updates timestamp for 2
// Verify the eviction order
println!("Testing with threshold {}", threshold);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 3);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 5);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 1);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 4);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 2);
let evicted = evictor.evict();
assert_eq!(evicted, None);
assert_eq!(evictor.len(), 0);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # KV Manager
//! A synchronous implementation of a block manager that handles MoveBlock signals for caching KV blocks.
//!
//! ## Block Operations
//! The KV manager processes four types of MoveBlock signals:
//!
//! ### Use
//! - Checks if block exists in active pool → increment reference count
//! - If in inactive pool → move to active pool
//! - If neither → try evicting from inactive pool to make room
//! - If inactive pool is empty → pre-empt the oldest running request
//!
//! ### Destroy
//! - Removes the block from the active pool
//!
//! ### Deref
//! - Decrements reference count of a block in active pool
//! - If count reaches zero → move block to inactive pool
//!
//! ### Promote
//! - Converts a partial block (uuid) into a full block (global block hash)
//!
//! ## Preemption
//! If a Use operation fails (typically due to insufficient space), a false boolean signal
//! is returned to the scheduler for preemption. Initial KV block allocations for new requests
//! should not fail due to the watermark checking.
//!
//! ## NOTE
//! For simplicity (or non-simplicity), reference counting is tracked manually instead of using
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager.
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock};
use crate::mocker::sequence::ActiveSequence;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
#[derive(Getters)]
pub struct KvManager {
#[getter(copy)]
max_capacity: usize,
#[getter(copy)]
block_size: usize,
active_blocks: HashMap<UniqueBlock, usize>,
inactive_blocks: LRUEvictor<UniqueBlock>,
all_blocks: HashSet<UniqueBlock>,
}
impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self {
let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default();
let all_blocks = HashSet::new();
KvManager {
max_capacity,
block_size,
active_blocks,
inactive_blocks,
all_blocks,
}
}
/// Process a MoveBlock instruction synchronously
pub fn process(&mut self, event: &MoveBlock) -> bool {
match event {
MoveBlock::Use(hashes, _) => {
for hash in hashes {
// First check if it already exists in active blocks
if let Some(ref_count) = self.active_blocks.get_mut(hash) {
// Block already active, just increment reference count
*ref_count += 1;
continue;
}
// Then check if it exists in inactive and move it to active if found
if self.inactive_blocks.remove(hash) {
// Insert into active with reference count 1
self.active_blocks.insert(hash.clone(), 1);
continue;
}
// Get counts for capacity check
let active_count = self.active_blocks.len();
let inactive_count = self.inactive_blocks.len();
// If at max capacity, evict the oldest entry from inactive blocks
if active_count + inactive_count >= self.max_capacity {
if let Some(evicted) = self.inactive_blocks.evict() {
// Remove evicted block from all_blocks
self.all_blocks.remove(&evicted);
} else {
// Cannot evict block, meaning no free blocks left in inactive pool
// Send a signal, scheduler would expect to handle preemption upon receiving this
return false;
}
}
// Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1);
// Add to all_blocks as it's a new block
self.all_blocks.insert(hash.clone());
}
}
MoveBlock::Destroy(hashes) => {
// Loop in inverse direction
for hash in hashes.iter().rev() {
self.active_blocks.remove(hash).unwrap();
// Remove from all_blocks when destroyed
assert!(self.all_blocks.remove(hash));
}
}
MoveBlock::Deref(hashes) => {
// Loop in inverse direction
for hash in hashes.iter().rev() {
// Decrement reference count and check if we need to move to inactive
if let Some(ref_count) = self.active_blocks.get_mut(hash) {
if *ref_count == 0 {
panic!("Negative reference count would be encountered after Deref.");
}
*ref_count -= 1;
// If reference count reaches zero, remove from active and move to inactive
if *ref_count == 0 {
self.active_blocks.remove(hash);
// Use the LRUEvictor's timing functionality
self.inactive_blocks.insert(hash.clone());
}
}
}
}
MoveBlock::Promote(uuid, hash) => {
let uuid_block = UniqueBlock::PartialBlock(*uuid);
let hash_block = UniqueBlock::FullBlock(*hash);
let Some(ref_count) = self.active_blocks.remove(&uuid_block) else {
let in_all_blocks = self.all_blocks.contains(&uuid_block);
panic!(
"Missing active block for promotion: {:?}. Block still exists: {}",
uuid_block, in_all_blocks
);
};
// Replace with hash block, keeping the same reference count
self.active_blocks.insert(hash_block.clone(), ref_count);
// Update all_blocks
assert!(self.all_blocks.remove(&uuid_block));
self.all_blocks.insert(hash_block);
}
}
// Return true if we made it this far
true
}
/// Get the count of blocks in the input list that aren't in all_blocks
pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
blocks
.iter()
.filter(|&block| !self.all_blocks.contains(block))
.count()
}
/// Get the current capacity (active blocks + inactive blocks)
pub fn current_capacity(&self) -> usize {
let active = self.active_blocks.len();
let inactive = self.inactive_blocks.len();
active + inactive
}
/// Get the current capacity as a percentage of the maximum capacity
pub fn current_capacity_perc(&self) -> f64 {
let current = self.current_capacity() as f64;
current / self.max_capacity as f64
}
/// Get the number of active blocks
pub fn num_active_blocks(&self) -> usize {
self.active_blocks.len()
}
/// Get the number of inactive blocks
pub fn num_inactive_blocks(&self) -> usize {
self.inactive_blocks.len()
}
/// Get the keys of inactive blocks
pub fn get_inactive_blocks(&self) -> Vec<&UniqueBlock> {
self.inactive_blocks.keys().collect()
}
/// Get the keys of active blocks
pub fn get_active_blocks(&self) -> Vec<&UniqueBlock> {
self.active_blocks.keys().collect()
}
/// Check if a sequence can be scheduled and calculate cost if possible
pub fn try_schedule(
&self,
sequence: &ActiveSequence,
watermark: f64,
tokens_budget: usize,
) -> Option<PrefillCost> {
// Return None immediately if tokens_budget is 0
if tokens_budget == 0 {
return None;
}
// Get unique blocks from the sequence
let unique_blocks = sequence.unique_blocks();
// Get the count of new blocks
let new_blocks = self.probe_new_blocks(unique_blocks);
// Calculate current usage and available capacity
let active_count = self.active_blocks.len();
// Check if we can schedule based on the watermark
if (active_count + new_blocks) as f64 > (1.0 - watermark) * self.max_capacity as f64 {
return None;
}
// Calculate overlap blocks
let overlap_blocks = unique_blocks.len() - new_blocks;
// Calculate new tokens
let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size;
// // Print the full equation with actual values substituted
// println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)",
// new_tokens,
// sequence.num_input_tokens(),
// overlap_blocks,
// self.block_size);
// Return None if new_tokens exceeds tokens_budget
if new_tokens > tokens_budget {
return None;
}
// Calculate prefill compute
let prefill_compute =
new_tokens as f64 * (new_tokens + overlap_blocks * self.block_size) as f64;
Some(PrefillCost {
new_tokens,
prefill_compute,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_failure_on_max_capacity() {
// Create a KvManager with 10 blocks capacity
let mut manager = KvManager::new(10, 16);
// Helper function to use multiple blocks that returns the response
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks, None))
}
// First use 10 blocks (0 to 9) in a batch
let response = use_blocks(&mut manager, (0..10).collect());
assert!(response, "Expected success response");
// Verify we are at capacity
assert_eq!(manager.current_capacity(), 10);
// The 11th block should return false, not panic
let response = use_blocks(&mut manager, vec![10]);
assert!(
!response,
"Expected failure response when exceeding max capacity"
);
}
#[test]
// This is taken directly from the example in the vllm v1 prefix caching docs
fn test_block_lifecycle_stringent() {
// Create a KvManager with 10 blocks capacity
let mut manager = KvManager::new(10, 16);
// Helper function to use multiple blocks
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks, None));
}
// Helper function to destroy multiple blocks
fn destroy_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Destroy(blocks));
}
// Helper function to deref multiple blocks
fn deref_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Deref(blocks));
}
// Helper function to check if active blocks contain expected blocks with expected ref counts
fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
assert_eq!(
manager.active_blocks().len(),
expected_blocks.len(),
"Active blocks count doesn't match expected"
);
for &(id, ref_count) in expected_blocks {
let block = UniqueBlock::FullBlock(id);
assert!(
manager.active_blocks().contains_key(&block),
"Block {} not found in active blocks",
id
);
assert_eq!(
manager.active_blocks().get(&block),
Some(&ref_count),
"Block {} has wrong reference count",
id
);
}
}
// Helper function to check if inactive blocks contain expected blocks
fn assert_inactive_blocks(
manager: &KvManager,
expected_size: usize,
expected_blocks: &[u64],
) {
let inactive_blocks = manager.get_inactive_blocks();
let inactive_blocks_count = manager.inactive_blocks().len();
assert_eq!(
inactive_blocks_count, expected_size,
"Inactive blocks count doesn't match expected"
);
for &id in expected_blocks {
let block = UniqueBlock::FullBlock(id);
assert!(
inactive_blocks.iter().any(|&b| *b == block),
"Block {} not found in inactive blocks",
id
);
}
}
// First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks(&mut manager, (0..5).collect());
// Then use blocks 0, 1, 5, 6 in a batch
use_blocks(&mut manager, vec![0, 1, 5, 6]);
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks(
&manager,
&[(0, 2), (1, 2), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)],
);
// Now destroy block 4
destroy_blocks(&mut manager, vec![4]);
// And deref blocks 3, 2, 1, 0 in this order as a batch
deref_blocks(&mut manager, vec![0, 1, 2, 3]);
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
assert_inactive_blocks(&manager, 2, &[3, 2]);
assert_active_blocks(&manager, &[(0, 1), (1, 1), (5, 1), (6, 1)]);
// Now destroy block 6
destroy_blocks(&mut manager, vec![6]);
// And deref blocks 5, 1, 0 as a batch
deref_blocks(&mut manager, vec![0, 1, 5]);
// Check that the inactive_blocks is size 5, and contains 0, 1, 2, 3, 5
assert_inactive_blocks(&manager, 5, &[0, 1, 2, 3, 5]);
assert_active_blocks(&manager, &[]);
// Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);
// Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks(&manager, 2, &[3, 5]);
assert_active_blocks(&manager, &[(0, 1), (1, 1), (2, 1), (7, 1), (8, 1), (9, 1)]);
// Test the new_blocks method - only block 4 should be new out of [0,1,2,3,4]
let blocks_to_check: Vec<UniqueBlock> = vec![0, 1, 2, 3, 4]
.into_iter()
.map(UniqueBlock::FullBlock)
.collect();
assert_eq!(manager.probe_new_blocks(&blocks_to_check), 1);
// Now use blocks 10, 11, 12 as a batch
use_blocks(&mut manager, vec![10, 11, 12]);
// Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks(&manager, 1, &[5]);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use uuid::Uuid;
pub type Token = u32;
pub type LocalBlockHash = u64;
/// A global hash identifier for blocks
pub type GlobalHash = u64;
pub type NumBlocks = usize;
/// Represents an active block in the cache with a reference count
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub enum UniqueBlock {
/// Block identified by UUID
PartialBlock(Uuid),
/// Block identified by hash
FullBlock(GlobalHash),
}
impl Default for UniqueBlock {
fn default() -> Self {
// Generate a random UUID when default is used
Self::PartialBlock(Uuid::new_v4())
}
}
/// Represents different block movement operations in the cache
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
Use(Vec<UniqueBlock>, Option<f64>),
Destroy(Vec<UniqueBlock>),
Deref(Vec<UniqueBlock>),
Promote(Uuid, GlobalHash),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
pub tokens: Vec<Token>,
pub max_output_tokens: usize,
pub uuid: Option<Uuid>,
}
/// Represents the cost of prefilling content in the cache
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
pub new_tokens: usize,
pub prefill_compute: f64,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_unique_block_default_uniqueness() {
// Create 10 default UniqueBlock instances
let blocks: Vec<UniqueBlock> = (0..10).map(|_| UniqueBlock::default()).collect();
// Extract UUIDs from each block
let mut uuids = Vec::new();
for block in blocks {
match block {
UniqueBlock::PartialBlock(uuid) => uuids.push(uuid),
_ => panic!("Expected UuidIdentifier variant"),
}
}
// Check that all UUIDs are unique by comparing each with every other
for i in 0..uuids.len() {
for j in i + 1..uuids.len() {
assert_ne!(
uuids[i], uuids[j],
"UUID at index {} and {} are identical",
i, j
);
}
}
}
}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::mocker::protocols::{MoveBlock, UniqueBlock};
use crate::tokens::{TokenBlockSequence, Tokens};
use derive_getters::Getters;
use rand::random;
use uuid;
/// Create unique blocks from a TokenBlockSequence
fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: usize,
) -> Vec<UniqueBlock> {
let mut unique_blocks: Vec<UniqueBlock> = tokens
.blocks()
.iter()
.map(|block| UniqueBlock::FullBlock(block.sequence_hash()))
.collect();
// Only push the partial block if tokens count isn't a multiple of block_size
if tokens.total_tokens() % block_size != 0 {
unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
});
}
unique_blocks
}
/// A sequence that is actively being built, with the ability to add tokens and commit to hashes
/// TODO: reuse tokens
#[derive(Debug, Getters)]
pub struct ActiveSequence {
unique_blocks: Vec<UniqueBlock>,
tokens: TokenBlockSequence,
#[getter(copy)]
block_size: usize,
#[getter(copy)]
chunk_size: usize, // TODO: not actually used
#[getter(copy)]
max_output_tokens: usize,
#[getter(copy)]
generated_tokens: usize,
#[getter(copy)]
num_input_tokens: usize,
creation_signal: Option<MoveBlock>,
}
impl ActiveSequence {
/// Create a new ActiveSequence instance with the provided tokens
pub fn new(
tokens: Vec<u32>,
max_output_tokens: usize,
block_size: Option<usize>,
chunk_size: Option<usize>,
) -> Self {
let block_size = block_size.unwrap_or(64);
assert!(block_size > 1, "block_size must be greater than 1");
let chunk_size = chunk_size.unwrap_or(256);
let num_input_tokens = tokens.len();
let tokens = Tokens::from(tokens).into_sequence(block_size, None);
let unique_blocks = create_unique_blocks_from_sequence(&tokens, None, block_size);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone(), None));
Self {
unique_blocks,
tokens,
block_size,
chunk_size,
max_output_tokens,
generated_tokens: 0,
num_input_tokens,
creation_signal,
}
}
pub fn extra_tokens(&self) -> usize {
self.len() % self.block_size
}
pub fn len(&self) -> usize {
self.tokens.total_tokens()
}
pub fn is_empty(&self) -> bool {
self.tokens.total_tokens() == 0
}
/// Create a new ActiveSequence instance and return the creation signal
pub fn new_with_signal(
tokens: Vec<u32>,
max_output_tokens: usize,
block_size: Option<usize>,
chunk_size: Option<usize>,
) -> (Self, Option<MoveBlock>) {
let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size);
let signal = sequence.creation_signal.take();
(sequence, signal)
}
/// Push a token to the sequence
pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
self.tokens.append(token).expect("Token push failed.");
self.generated_tokens += 1;
if self.len() % self.block_size != 1 {
return None;
}
// Add a partial block for the first token in a new partial sequence
// Send Use signal (to allocate space for this new generation block)
let mut signals = Vec::new();
// Replace last partial block with full block if it exists
if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
let last_block_hash = self.tokens.last_complete_block().unwrap().sequence_hash();
self.unique_blocks.pop();
self.unique_blocks
.push(UniqueBlock::FullBlock(last_block_hash));
signals.push(MoveBlock::Promote(uuid, last_block_hash));
}
let new_partial_block = UniqueBlock::default();
self.unique_blocks.push(new_partial_block.clone());
signals.push(MoveBlock::Use(vec![new_partial_block], None));
Some(signals)
}
/// Generate a random token, push it to the sequence, and increment generation count.
///
/// This function:
/// - Generates a random token and adds it to the current sequence
/// - Acquires a new partial block if needed or promotes an existing partial block to a full block
/// - Returns appropriate signals for the KvManager to process
///
/// # Panics
///
/// Calling this function when max_output_tokens has already been reached will cause a panic.
/// Always check `generated_tokens < max_output_tokens` before calling this method.
pub fn generate(&mut self) -> Vec<MoveBlock> {
// Assert that we haven't reached the maximum output tokens
assert!(
self.generated_tokens < self.max_output_tokens,
"Cannot generate more tokens: reached max_output_tokens limit"
);
// Generate a random token
let token = random::<u32>();
// Collect signals
let mut signals = Vec::new();
// Push the token to the sequence and collect any signals
if let Some(move_blocks) = self.push(token) {
signals.extend(move_blocks);
}
// Check if we've reached the limit after pushing
if self.generated_tokens != self.max_output_tokens {
return signals;
}
// Free all blocks when we reach max tokens
signals.extend(self.free_signal());
signals
}
/// Free all blocks, generating appropriate signals for each block type
pub fn free_signal(&self) -> Vec<MoveBlock> {
self.unique_blocks
.iter()
.rev()
.map(|block| match block {
UniqueBlock::PartialBlock(uuid) => {
MoveBlock::Destroy(vec![UniqueBlock::PartialBlock(*uuid)])
}
UniqueBlock::FullBlock(hash) => {
MoveBlock::Deref(vec![UniqueBlock::FullBlock(*hash)])
}
})
.collect()
}
/// Reset the sequence to its initial state and return the free signals from freeing current blocks
/// maintaining the uuid of the last partial block
pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
let free_signal = self.free_signal();
self.tokens.truncate(self.num_input_tokens).unwrap();
self.unique_blocks =
create_unique_blocks_from_sequence(&self.tokens, None, self.block_size);
self.generated_tokens = 0;
self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone(), None));
free_signal
}
/// Pops last token in the sequence.
pub fn pop(&mut self) {
self.tokens.pop();
self.generated_tokens = self.generated_tokens.saturating_sub(1);
// Reverts to the last full block
if self.tokens.total_tokens() % self.block_size == 0 {
self.unique_blocks.pop();
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_active_sequence_push() {
// Create a sequence with block size 16 initialized with tokens [0..15]
let initial_tokens: Vec<u32> = (0..15).collect();
let (mut seq1, signal1) =
ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), Some(256));
assert_eq!(seq1.num_input_tokens(), 15);
assert_eq!(seq1.len(), 15);
// Check that we got a Use signal
assert!(signal1.is_some());
match &signal1 {
Some(MoveBlock::Use(blocks, _)) => {
assert_eq!(blocks.len(), 1);
}
_ => panic!("Expected Use signal"),
}
// Push token 15 which should complete the block (no signals yet)
let signal_15 = seq1.push(15);
assert!(
signal_15.is_none(),
"Completing a block should not trigger signals"
);
// Push token 16 which should trigger both Promote and Use signals
let signal_16 = seq1.push(16);
assert!(signal_16.is_some());
let signal_16 = signal_16.unwrap();
assert_eq!(signal_16.len(), 2);
// Second signal should be Use for new partial block
match &signal_16[1] {
MoveBlock::Use(blocks, _) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as first signal"),
}
// First signal should be Promote for the previous block
match &signal_16[0] {
MoveBlock::Promote(uuid, _) => {
// The uuid is generated dynamically, so we just check it exists
let _ = uuid;
}
_ => panic!("Expected Promote signal as second signal"),
}
// Verify state after pushing tokens
assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
assert_eq!(seq1.len(), 17);
assert_eq!(seq1.len() % seq1.block_size(), 1);
// Create another sequence with block size 16 initialized with tokens [0..17]
let extended_tokens: Vec<u32> = (0..16).collect();
let (mut seq2, _) =
ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), Some(256));
seq2.push(16);
seq2.pop();
seq2.push(16);
// Simplified assertions
assert_eq!(
seq1.unique_blocks()[0],
seq2.unique_blocks()[0],
"First blocks should be the same"
);
assert_ne!(
seq1.unique_blocks()[1],
seq2.unique_blocks()[1],
"Second blocks should be different"
);
// Reset partial block on seq1 and push back token 16
seq1.push(17);
seq1.pop();
seq1.pop();
seq1.push(16);
// Now push tokens 17..32 to both sequences
for token in 17..33 {
seq1.push(token);
seq2.push(token);
}
// Both sequences should now have 2 blocks:
// 1. FullBlock for tokens 0-15
// 2. FullBlock for tokens 16-31
// 3. No partial block since there are no remaining tokens
assert_eq!(
seq1.unique_blocks().len(),
3,
"seq1 should have exactly 3 blocks"
);
assert_eq!(
seq2.unique_blocks().len(),
3,
"seq2 should have exactly 3 blocks"
);
assert_eq!(
seq1.len() % seq1.block_size(),
1,
"seq1 should have 1 partial token"
);
assert_eq!(
seq2.len() % seq2.block_size(),
1,
"seq2 should have 1 partial token"
);
// Verify that both sequences have identical blocks up to the second position
assert_eq!(
&seq1.unique_blocks()[0..2],
&seq2.unique_blocks()[0..2],
"First two blocks should be identical"
);
// Reset seq1 and check that it equals the original clone
let free_signals = seq1.reset_with_signal();
// Verify the reset signals include proper cleanup events
assert!(!free_signals.is_empty());
}
#[test]
fn test_active_sequence_generate_signals() {
// Create a sequence with block size 16, max_output_tokens 4, initialized with tokens [0..14)
let initial_tokens: Vec<u32> = (0..14).collect();
let (mut seq, signal) =
ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), Some(256));
// Initial signal - should have received a Use signal for the partial block
assert!(signal.is_some());
match signal {
Some(MoveBlock::Use(blocks, _)) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal for the initial partial block"),
}
// Generate first two tokens - should not trigger new signals
seq.generate();
let signals_first = seq.generate();
assert_eq!(signals_first.len(), 0);
// Generate third token - this fills the block and should trigger both Promote and Use signals
let signals_second = seq.generate();
assert_eq!(signals_second.len(), 2);
// First signal should be Use for new partial block
match &signals_second[1] {
MoveBlock::Use(blocks, _) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as second signal after second token"),
}
// Second signal should be Promote
match &signals_second[0] {
MoveBlock::Promote(uuid, hash) => {
// The uuid and hash values are generated dynamically, so we just check the event type
let _ = uuid;
let _ = hash;
}
_ => panic!("Expected Promote signal as first signal after second token"),
}
// Generate fourth token - should not trigger new signals as it's adding to partial block
let signals_third = seq.generate();
assert_eq!(signals_third.len(), 0);
// Generate last token - we reach max_output_tokens, should trigger Destroy and Deref signals
let signals_last = seq.generate();
assert_eq!(signals_last.len(), 2);
// First signal should be Destroy for the partial block
match &signals_last[0] {
MoveBlock::Destroy(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Destroy signal for partial block after fourth token"),
}
// Second signal should be Deref for the full block
match &signals_last[1] {
MoveBlock::Deref(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::FullBlock(_)));
}
_ => panic!("Expected Deref signal for full block after fourth token"),
}
}
}
......@@ -188,7 +188,7 @@ pub enum TokenBlockError {
///
/// This structure accumulates tokens until it reaches the specified `block_size`,
/// at which point it can be [`commit`](PartialTokenBlock::commit)ted into a full [`TokenBlock`].
#[derive(Debug)] // No Clone: intended to be unique within a sequence
#[derive(Debug, PartialEq)] // No Clone: intended to be unique within a sequence
pub struct PartialTokenBlock {
tokens: Tokens,
block_size: usize,
......@@ -478,7 +478,7 @@ impl TokenBlock {
/// - [`BlockHash`]: Hash of tokens within a single block (seeded by [`SaltHash`]).
/// - [`SequenceHash`]: Hash combining the previous block's [`SequenceHash`] and the current
/// block's [`BlockHash`] (also seeded by [`SaltHash`]).
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct TokenBlockSequence {
blocks: Vec<TokenBlock>,
current_block: PartialTokenBlock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment