Unverified commit 03c160af, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: vllm mock workers, Rusty skeleton (#1033)


Signed-off-by: Yan Ru Pei <yanrpei@gmail.com>
parent 84377e5d
...@@ -28,6 +28,7 @@ pub mod hub; ...@@ -28,6 +28,7 @@ pub mod hub;
pub mod key_value_store; pub mod key_value_store;
pub mod kv_router; pub mod kv_router;
pub use kv_router::DEFAULT_KV_BLOCK_SIZE; pub use kv_router::DEFAULT_KV_BLOCK_SIZE;
pub mod mocker;
pub mod model_card; pub mod model_card;
pub mod model_type; pub mod model_type;
pub mod preprocessor; pub mod preprocessor;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod evictor;
pub mod kv_manager;
pub mod protocols;
pub mod scheduler;
pub mod sequence;
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Eq;
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;
use std::time::Instant;
/// An LRU evictor that maintains objects and evicts them based on their
/// last accessed time. Implements a "lazy" eviction mechanism where:
/// 1. The priority queue does not immediately reflect updates or removes
/// 2. Objects are pushed to the queue in order of increasing priority (older objects first)
/// 3. Timestamps recorded by the evictor are forced to be strictly increasing, so a
///    re-inserted object can always be told apart from its older (stale) queue entries
/// 4. Remove and update operations are lazy - entries remain in the queue until
///    they are either evicted or cleaned up during maintenance
#[derive(Debug)]
pub struct LRUEvictor<T: Clone + Eq + Hash> {
    /// Live objects mapped to the timestamp recorded at their latest insertion.
    free_table: HashMap<T, f64>,
    /// Insertion log, oldest entry at the front. May hold stale entries whose
    /// object was removed or re-inserted later.
    priority_queue: VecDeque<(T, f64)>,
    /// Cleanup runs once the queue holds more than
    /// `cleanup_threshold * free_table.len()` entries.
    cleanup_threshold: usize,
    /// Reference point for `current_timestamp`.
    start_time: Instant,
    /// Latest timestamp recorded so far (always non-negative); used to keep
    /// recorded timestamps strictly increasing.
    last_timestamp: f64,
}

impl<T: Clone + Eq + Hash> Default for LRUEvictor<T> {
    /// Creates an empty evictor with a cleanup threshold of 50 whose clock starts now.
    fn default() -> Self {
        Self {
            free_table: HashMap::new(),
            priority_queue: VecDeque::new(),
            cleanup_threshold: 50,
            start_time: Instant::now(),
            last_timestamp: 0.0,
        }
    }
}

impl<T: Clone + Eq + Hash> LRUEvictor<T> {
    /// Create a new, empty LRUEvictor with the given cleanup threshold.
    ///
    /// The internal queue is compacted whenever it grows beyond
    /// `cleanup_threshold` entries per live object.
    pub fn new(cleanup_threshold: usize) -> Self {
        Self {
            cleanup_threshold,
            ..Default::default()
        }
    }

    /// Get the current timestamp as seconds since initialization
    pub fn current_timestamp(&self) -> f64 {
        self.start_time.elapsed().as_secs_f64()
    }

    /// Get an iterator over the keys in the evictor (arbitrary order)
    pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, f64> {
        self.free_table.keys()
    }

    /// Insert or update an object in the evictor with current timestamp
    pub fn insert(&mut self, object: T) {
        let timestamp = self.current_timestamp();
        self._insert(object, timestamp);
    }

    /// Check if the evictor contains the given object
    pub fn contains(&self, object: &T) -> bool {
        self.free_table.contains_key(object)
    }

    /// Evict an object based on LRU policy
    /// Returns the evicted object or None if no objects are available
    pub fn evict(&mut self) -> Option<T> {
        if self.free_table.is_empty() {
            return None;
        }
        while let Some((object, last_accessed)) = self.priority_queue.pop_front() {
            let Some(&current_last_accessed) = self.free_table.get(&object) else {
                continue; // entry is already removed
            };
            // Exact equality is intentional: a queue entry is live only if it carries
            // the very timestamp recorded at the object's latest insertion.
            if current_last_accessed == last_accessed {
                self.free_table.remove(&object);
                return Some(object);
            } // otherwise entry is stale
        }
        None
    }

    /// Insert or update an object in the evictor.
    ///
    /// The recorded timestamp is `last_accessed`, nudged forward when necessary so
    /// that it is strictly greater than every timestamp recorded before it.
    fn _insert(&mut self, object: T, last_accessed: f64) {
        let last_accessed = self.next_unique_timestamp(last_accessed);
        self.free_table.insert(object.clone(), last_accessed);
        self.priority_queue.push_back((object, last_accessed));
        self.cleanup_if_necessary();
    }

    /// Returns `candidate` if it is strictly newer than every timestamp handed out so
    /// far, otherwise the smallest representable value above the previous timestamp.
    ///
    /// Without this, re-inserting an object within the clock's resolution would give
    /// its stale queue entry the same timestamp as the live one, and `evict` would
    /// hand the object out ahead of genuinely older objects.
    fn next_unique_timestamp(&mut self, candidate: f64) -> f64 {
        let previous = self.last_timestamp;
        let timestamp = if candidate > previous {
            candidate
        } else if previous < f64::MAX {
            // `previous` is non-negative and finite here, so the adjacent bit
            // pattern is the next larger float.
            f64::from_bits(previous.to_bits() + 1)
        } else {
            previous
        };
        self.last_timestamp = timestamp;
        timestamp
    }

    /// Remove an object from the evictor
    /// We don't remove from the priority queue immediately, as that would be inefficient
    /// Outdated entries will be filtered out during eviction or cleanup
    ///
    /// Returns `true` if the object was present.
    pub fn remove(&mut self, object: &T) -> bool {
        self.free_table.remove(object).is_some()
    }

    /// Get the number of objects in the evictor
    pub fn len(&self) -> usize {
        self.free_table.len()
    }

    /// Check if the evictor is empty
    pub fn is_empty(&self) -> bool {
        self.free_table.is_empty()
    }

    /// Check if cleanup is necessary and perform it if needed
    fn cleanup_if_necessary(&mut self) {
        // Saturating: a huge threshold must disable cleanup, not overflow.
        let limit = self.cleanup_threshold.saturating_mul(self.free_table.len());
        if self.priority_queue.len() > limit {
            self.cleanup();
        }
    }

    /// Clean up the priority queue by removing outdated entries, keeping the
    /// relative order of the live ones.
    fn cleanup(&mut self) {
        let free_table = &self.free_table;
        self.priority_queue
            .retain(|(object, timestamp)| free_table.get(object) == Some(timestamp));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Inserts a fixed sequence of keys (some of them repeated, which refreshes
    /// their recency) and checks that eviction returns them least-recently-used
    /// first, regardless of the cleanup threshold in use.
    #[rstest]
    #[case(1)]
    #[case(2)]
    #[case(3)]
    fn test_lru_evictor_eviction_order(#[case] threshold: usize) {
        let mut evictor = LRUEvictor::<i32>::new(threshold);

        // The trailing 1, 4, 2 are re-insertions that refresh those keys.
        let insertion_sequence = [4, 3, 2, 1, 5, 1, 4, 2];
        let pause = std::time::Duration::from_millis(1);
        for (step, &key) in insertion_sequence.iter().enumerate() {
            // A short pause separates consecutive insertions so each one
            // observes a distinct timestamp.
            if step > 0 {
                std::thread::sleep(pause);
            }
            evictor.insert(key);
        }

        println!("Testing with threshold {}", threshold);

        // 3 and 5 were never refreshed, so they leave first; the refreshed keys
        // follow in the order of their refresh.
        for &expected in &[3, 5, 1, 4, 2] {
            let evicted = evictor.evict().unwrap();
            assert_eq!(evicted, expected);
        }

        // Nothing is left afterwards.
        assert_eq!(evictor.evict(), None);
        assert_eq!(evictor.len(), 0);
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # KV Manager
//! A synchronous implementation of a block manager that handles MoveBlock signals for caching KV blocks.
//!
//! ## Block Operations
//! The KV manager processes four types of MoveBlock signals:
//!
//! ### Use
//! - Checks if block exists in active pool → increment reference count
//! - If in inactive pool → move to active pool
//! - If neither → try evicting from inactive pool to make room
//! - If inactive pool is empty → pre-empt the oldest running request
//!
//! ### Destroy
//! - Removes the block from the active pool
//!
//! ### Deref
//! - Decrements reference count of a block in active pool
//! - If count reaches zero → move block to inactive pool
//!
//! ### Promote
//! - Converts a partial block (uuid) into a full block (global block hash)
//!
//! ## Preemption
//! If a Use operation fails (typically due to insufficient space), a false boolean signal
//! is returned to the scheduler for preemption. Initial KV block allocations for new requests
//! should not fail due to the watermark checking.
//!
//! ## NOTE
//! For simplicity (or non-simplicity), reference counting is tracked manually instead of using
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager.
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock};
use crate::mocker::sequence::ActiveSequence;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
/// Synchronous mock of a KV-cache block manager driven by `MoveBlock` signals.
///
/// Blocks live either in the active pool (reference counted) or in the
/// inactive pool (LRU-evictable). `derive(Getters)` generates a read accessor
/// per field; the two `usize` fields are returned by value.
#[derive(Getters)]
pub struct KvManager {
/// Upper bound on the number of tracked blocks (active + inactive).
#[getter(copy)]
max_capacity: usize,
/// Number of tokens per block; used when converting block counts to token counts.
#[getter(copy)]
block_size: usize,
/// Blocks currently in use, mapped to their reference count.
active_blocks: HashMap<UniqueBlock, usize>,
/// Blocks whose reference count dropped to zero; candidates for LRU eviction.
inactive_blocks: LRUEvictor<UniqueBlock>,
/// Every block currently tracked in either pool; consulted by `probe_new_blocks`.
all_blocks: HashSet<UniqueBlock>,
}
impl KvManager {
    /// Creates an empty manager that tracks at most `max_capacity` blocks of
    /// `block_size` tokens each.
    pub fn new(max_capacity: usize, block_size: usize) -> Self {
        KvManager {
            max_capacity,
            block_size,
            active_blocks: HashMap::new(),
            inactive_blocks: LRUEvictor::default(),
            all_blocks: HashSet::new(),
        }
    }

    /// Process a MoveBlock instruction synchronously.
    ///
    /// Returns `false` only when a `Use` needs a new block, the manager is at
    /// capacity and nothing can be evicted from the inactive pool; the caller is
    /// expected to preempt a request in response. Blocks of the batch that were
    /// handled before the failing one stay applied. Every other outcome returns `true`.
    ///
    /// # Panics
    /// - `Destroy` of a block that is not active or not tracked.
    /// - `Deref` of an active block whose reference count is already zero.
    /// - `Promote` of a partial block that is not active.
    pub fn process(&mut self, event: &MoveBlock) -> bool {
        match event {
            MoveBlock::Use(hashes, _) => {
                for hash in hashes {
                    // Already active: just take another reference.
                    if let Some(ref_count) = self.active_blocks.get_mut(hash) {
                        *ref_count += 1;
                        continue;
                    }

                    // Parked in the inactive pool: reactivate with a single reference.
                    if self.inactive_blocks.remove(hash) {
                        self.active_blocks.insert(hash.clone(), 1);
                        continue;
                    }

                    // Brand-new block: make room first when at capacity.
                    if self.current_capacity() >= self.max_capacity {
                        let Some(evicted) = self.inactive_blocks.evict() else {
                            // Nothing left to evict; signal the scheduler so it can preempt.
                            return false;
                        };
                        self.all_blocks.remove(&evicted);
                    }

                    self.active_blocks.insert(hash.clone(), 1);
                    self.all_blocks.insert(hash.clone());
                }
            }
            MoveBlock::Destroy(hashes) => {
                // Loop in inverse direction
                for hash in hashes.iter().rev() {
                    self.active_blocks
                        .remove(hash)
                        .expect("Destroy received for a block that is not in the active pool");
                    // Remove from all_blocks when destroyed
                    assert!(self.all_blocks.remove(hash));
                }
            }
            MoveBlock::Deref(hashes) => {
                // Loop in inverse direction
                for hash in hashes.iter().rev() {
                    // Blocks that are not active are silently ignored.
                    if let Some(ref_count) = self.active_blocks.get_mut(hash) {
                        if *ref_count == 0 {
                            panic!("Negative reference count would be encountered after Deref.");
                        }
                        *ref_count -= 1;
                        // Last reference gone: park the block in the inactive pool,
                        // timestamped by the evictor.
                        if *ref_count == 0 {
                            self.active_blocks.remove(hash);
                            self.inactive_blocks.insert(hash.clone());
                        }
                    }
                }
            }
            MoveBlock::Promote(uuid, hash) => {
                let uuid_block = UniqueBlock::PartialBlock(*uuid);
                let hash_block = UniqueBlock::FullBlock(*hash);
                let Some(ref_count) = self.active_blocks.remove(&uuid_block) else {
                    let in_all_blocks = self.all_blocks.contains(&uuid_block);
                    panic!(
                        "Missing active block for promotion: {:?}. Block still exists: {}",
                        uuid_block, in_all_blocks
                    );
                };

                // The full block may already be tracked (another sequence produced the
                // same content). If it is parked in the inactive pool, pull it out so it
                // is not counted in both pools.
                self.inactive_blocks.remove(&hash_block);
                // Merge rather than overwrite, so references held through an
                // already-active copy of the same full block are not lost.
                *self.active_blocks.entry(hash_block.clone()).or_insert(0) += ref_count;

                // Update all_blocks
                assert!(self.all_blocks.remove(&uuid_block));
                self.all_blocks.insert(hash_block);
            }
        }
        // Return true if we made it this far
        true
    }

    /// Get the count of blocks in the input list that aren't in all_blocks
    pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
        blocks
            .iter()
            .filter(|&block| !self.all_blocks.contains(block))
            .count()
    }

    /// Get the current capacity (active blocks + inactive blocks)
    pub fn current_capacity(&self) -> usize {
        self.active_blocks.len() + self.inactive_blocks.len()
    }

    /// Get the current capacity as a fraction of the maximum capacity
    /// (1.0 means full; the value is not multiplied by 100).
    ///
    /// Not guarded against `max_capacity == 0`.
    pub fn current_capacity_perc(&self) -> f64 {
        let current = self.current_capacity() as f64;
        current / self.max_capacity as f64
    }

    /// Get the number of active blocks
    pub fn num_active_blocks(&self) -> usize {
        self.active_blocks.len()
    }

    /// Get the number of inactive blocks
    pub fn num_inactive_blocks(&self) -> usize {
        self.inactive_blocks.len()
    }

    /// Get the keys of inactive blocks (arbitrary order)
    pub fn get_inactive_blocks(&self) -> Vec<&UniqueBlock> {
        self.inactive_blocks.keys().collect()
    }

    /// Get the keys of active blocks (arbitrary order)
    pub fn get_active_blocks(&self) -> Vec<&UniqueBlock> {
        self.active_blocks.keys().collect()
    }

    /// Check if a sequence can be scheduled and calculate cost if possible.
    ///
    /// Returns `None` when `tokens_budget` is zero, when the active blocks plus the
    /// sequence's new blocks would exceed `(1 - watermark) * max_capacity`, or when
    /// the number of tokens that still need prefilling exceeds `tokens_budget`.
    /// Otherwise returns the prefill cost, with
    /// `prefill_compute = new_tokens * (new_tokens + cached_tokens)`.
    pub fn try_schedule(
        &self,
        sequence: &ActiveSequence,
        watermark: f64,
        tokens_budget: usize,
    ) -> Option<PrefillCost> {
        if tokens_budget == 0 {
            return None;
        }

        // Blocks of the sequence that are not tracked yet need fresh space.
        let unique_blocks = sequence.unique_blocks();
        let new_blocks = self.probe_new_blocks(unique_blocks);

        // Watermark check: only active blocks count, inactive ones can be evicted.
        let active_count = self.active_blocks.len();
        if (active_count + new_blocks) as f64 > (1.0 - watermark) * self.max_capacity as f64 {
            return None;
        }

        // Blocks that are already tracked do not need to be prefilled again.
        let overlap_blocks = unique_blocks.len() - new_blocks;
        let cached_tokens = overlap_blocks * self.block_size;
        let new_tokens = sequence.num_input_tokens() - cached_tokens;

        if new_tokens > tokens_budget {
            return None;
        }

        let prefill_compute = new_tokens as f64 * (new_tokens + cached_tokens) as f64;
        Some(PrefillCost {
            new_tokens,
            prefill_compute,
        })
    }
}
// Unit tests for KvManager: capacity exhaustion and a full block lifecycle
// (use / destroy / deref / reuse / eviction).
#[cfg(test)]
mod tests {
use super::*;
// A Use that needs a new block while the manager is full and the inactive pool
// is empty must report failure through the return value instead of panicking.
#[test]
fn test_failure_on_max_capacity() {
// Create a KvManager with 10 blocks capacity
let mut manager = KvManager::new(10, 16);
// Helper function to use multiple blocks that returns the response
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks, None))
}
// First use 10 blocks (0 to 9) in a batch
let response = use_blocks(&mut manager, (0..10).collect());
assert!(response, "Expected success response");
// Verify we are at capacity
assert_eq!(manager.current_capacity(), 10);
// The 11th block should return false, not panic
let response = use_blocks(&mut manager, vec![10]);
assert!(
!response,
"Expected failure response when exceeding max capacity"
);
}
#[test]
// This is taken directly from the example in the vllm v1 prefix caching docs
fn test_block_lifecycle_stringent() {
// Create a KvManager with 10 blocks capacity
let mut manager = KvManager::new(10, 16);
// Helper function to use multiple blocks (the success flag is ignored here)
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks, None));
}
// Helper function to destroy multiple blocks
fn destroy_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Destroy(blocks));
}
// Helper function to deref multiple blocks
fn deref_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Deref(blocks));
}
// Helper function to check if active blocks contain expected blocks with expected ref counts
fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
assert_eq!(
manager.active_blocks().len(),
expected_blocks.len(),
"Active blocks count doesn't match expected"
);
for &(id, ref_count) in expected_blocks {
let block = UniqueBlock::FullBlock(id);
assert!(
manager.active_blocks().contains_key(&block),
"Block {} not found in active blocks",
id
);
assert_eq!(
manager.active_blocks().get(&block),
Some(&ref_count),
"Block {} has wrong reference count",
id
);
}
}
// Helper function to check if inactive blocks contain expected blocks
fn assert_inactive_blocks(
manager: &KvManager,
expected_size: usize,
expected_blocks: &[u64],
) {
let inactive_blocks = manager.get_inactive_blocks();
let inactive_blocks_count = manager.inactive_blocks().len();
assert_eq!(
inactive_blocks_count, expected_size,
"Inactive blocks count doesn't match expected"
);
for &id in expected_blocks {
let block = UniqueBlock::FullBlock(id);
assert!(
inactive_blocks.iter().any(|&b| *b == block),
"Block {} not found in inactive blocks",
id
);
}
}
// First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks(&mut manager, (0..5).collect());
// Then use blocks 0, 1, 5, 6 in a batch
use_blocks(&mut manager, vec![0, 1, 5, 6]);
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks(
&manager,
&[(0, 2), (1, 2), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)],
);
// Now destroy block 4
destroy_blocks(&mut manager, vec![4]);
// And deref blocks 3, 2, 1, 0 in this order as a batch
// (Deref walks the list in reverse, so [0, 1, 2, 3] is processed as 3, 2, 1, 0)
deref_blocks(&mut manager, vec![0, 1, 2, 3]);
// Check that the inactive_blocks is size 2 and contains 3 and 2
assert_inactive_blocks(&manager, 2, &[3, 2]);
assert_active_blocks(&manager, &[(0, 1), (1, 1), (5, 1), (6, 1)]);
// Now destroy block 6
destroy_blocks(&mut manager, vec![6]);
// And deref blocks 5, 1, 0 as a batch
deref_blocks(&mut manager, vec![0, 1, 5]);
// Check that the inactive_blocks is size 5, and contains 0, 1, 2, 3, 5
assert_inactive_blocks(&manager, 5, &[0, 1, 2, 3, 5]);
assert_active_blocks(&manager, &[]);
// Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);
// Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks(&manager, 2, &[3, 5]);
assert_active_blocks(&manager, &[(0, 1), (1, 1), (2, 1), (7, 1), (8, 1), (9, 1)]);
// Test the probe_new_blocks method - only block 4 should be new out of [0,1,2,3,4]
let blocks_to_check: Vec<UniqueBlock> = vec![0, 1, 2, 3, 4]
.into_iter()
.map(UniqueBlock::FullBlock)
.collect();
assert_eq!(manager.probe_new_blocks(&blocks_to_check), 1);
// Now use blocks 10, 11, 12 as a batch
// (10 and 11 fill the manager to capacity; 12 forces eviction of the oldest inactive block, 3)
use_blocks(&mut manager, vec![10, 11, 12]);
// Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks(&manager, 1, &[5]);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// A single token id, as carried in `DirectRequest::tokens`.
pub type Token = u32;
/// Hash identifying a block locally.
// NOTE(review): not used in the visible code; presumably a hash over a block's
// own tokens only, as opposed to `GlobalHash` - confirm against sequence.rs.
pub type LocalBlockHash = u64;
/// A global hash identifier for blocks
pub type GlobalHash = u64;
/// A count of blocks.
pub type NumBlocks = usize;
/// Identifies a KV-cache block, either by a UUID or by a global hash.
///
/// This is only the identifier: reference counts are kept by `KvManager`,
/// not by this type. `MoveBlock::Promote` replaces a `PartialBlock` with
/// a `FullBlock`.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub enum UniqueBlock {
/// Block identified by UUID
PartialBlock(Uuid),
/// Block identified by hash
FullBlock(GlobalHash),
}
impl Default for UniqueBlock {
fn default() -> Self {
// Generate a random UUID when default is used
Self::PartialBlock(Uuid::new_v4())
}
}
/// Represents different block movement operations in the cache
///
/// These are the signals consumed by `KvManager::process`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
/// Acquire a reference to each listed block, allocating blocks that are not tracked yet.
/// The optional `f64` is ignored by `KvManager::process`.
// NOTE(review): meaning of the `Option<f64>` payload is not visible here - confirm with the producer.
Use(Vec<UniqueBlock>, Option<f64>),
/// Drop the listed blocks from the active pool entirely.
Destroy(Vec<UniqueBlock>),
/// Release one reference to each listed block; a block whose count reaches zero becomes inactive.
Deref(Vec<UniqueBlock>),
/// Replace the partial block with the given UUID by the full block with the given hash.
Promote(Uuid, GlobalHash),
}
/// A raw request as submitted to the scheduler, before it is turned into an
/// active sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectRequest {
/// Input token ids.
pub tokens: Vec<Token>,
/// Upper bound on the number of tokens to generate.
pub max_output_tokens: usize,
/// Optional caller-chosen request id; the scheduler generates a random one when absent.
pub uuid: Option<Uuid>,
}
/// Represents the cost of prefilling content in the cache
///
/// Produced by `KvManager::try_schedule`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
/// Number of input tokens not covered by already-tracked blocks.
pub new_tokens: usize,
/// Compute estimate: `new_tokens * (new_tokens + cached_tokens)`.
pub prefill_compute: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Default-constructed blocks must be partial blocks whose UUIDs are
    /// pairwise distinct.
    #[test]
    fn test_unique_block_default_uniqueness() {
        // Create 10 default UniqueBlock instances and pull out their UUIDs
        let uuids: Vec<_> = (0..10)
            .map(|_| match UniqueBlock::default() {
                UniqueBlock::PartialBlock(uuid) => uuid,
                UniqueBlock::FullBlock(hash) => {
                    panic!("Expected PartialBlock variant, got FullBlock({})", hash)
                }
            })
            .collect();

        // Compare each UUID with every later one so a failure names both indices
        for (i, first) in uuids.iter().enumerate() {
            for (j, second) in uuids.iter().enumerate().skip(i + 1) {
                assert_ne!(
                    first, second,
                    "UUID at index {} and {} are identical",
                    i, j
                );
            }
        }
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Asynchronous Scheduler for LLM Request Management
//!
//! This module implements an asynchronous scheduler that handles three main functions:
//! 1. Receiving new requests and placing them in the waiting queue
//! 2. Scheduling waiting requests against available KV cache resources
//! 3. Simulating the execution of running requests with realistic timing
//!
//! ## Scheduling Process
//! The scheduler uses a watermark-based approach to determine if there's sufficient
//! KV cache space for new requests. It also enforces a batched tokens budget to prevent
//! oversubscription of computational resources. Only requests that can be allocated
//! these resources are moved from waiting to running state.
//!
//! ## Request Simulation
//! The simulation models two key phases:
//! - Prefill phase: Uses a quadratic cost function: (cached_tokens + new_tokens) * new_tokens
//! - Decode phase: Uses a cost function proportional to active KV blocks (linear)
//!
//! ## Resource Management
//! The scheduler communicates with the KvManager through MoveBlock signals at each
//! stage of request processing. When resources become constrained, it employs an
//! LRU-based preemption strategy where the oldest running request is evicted and
//! placed at the back of the waiting queue to be rescheduled later.
//!
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::kv_router::protocols::ForwardPassMetrics;
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager;
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock};
use crate::mocker::sequence::ActiveSequence;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tokio::time::{interval, Duration};
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
/// Enum representing either a direct request or an active sequence
///
/// This is the value type of `SchedulerState::requests`.
pub enum Request {
/// A request as received, not yet converted into an active sequence.
Direct(DirectRequest),
/// A request that has been converted into an active sequence (deferred,
/// running, or preempted and waiting to be rescheduled).
Active(ActiveSequence),
}
/// Bookkeeping for the scheduler: which requests exist and which stage each is in.
#[derive(Default)]
struct SchedulerState {
/// Ids of newly received and of preempted requests, in arrival order.
waiting: VecDeque<Uuid>,
/// Ids of requests that could not be scheduled yet; `next` drains this queue before `waiting`.
ready: VecDeque<Uuid>,
/// Ids of running requests; the evictor's LRU order selects the preemption victim.
running: LRUEvictor<Uuid>,
/// Request payload for every id that is queued or running.
requests: HashMap<Uuid, Request>,
/// Per-request prefill cost: `Some` while the request is still prefilling,
/// `None` once it is decoding (then counted as one batched token).
prefill_costs: HashMap<Uuid, Option<PrefillCost>>,
}
impl SchedulerState {
    /// Registers an incoming request under its own UUID (or a freshly generated
    /// one when it carries none), queues that id at the back of `waiting`, and
    /// returns it.
    fn receive(&mut self, request: DirectRequest) -> Uuid {
        let id = match request.uuid {
            Some(provided) => provided,
            None => Uuid::new_v4(),
        };
        self.requests.insert(id, Request::Direct(request));
        self.waiting.push_back(id);
        id
    }

    /// Pops the next request to consider: the front of `ready` if there is one,
    /// otherwise the front of `waiting`. The request is taken out of `requests`.
    /// Returns `None` when both queues are empty or the id has no stored request.
    fn next(&mut self) -> Option<(Uuid, Request)> {
        let id = match self.ready.pop_front() {
            Some(id) => id,
            None => self.waiting.pop_front()?,
        };
        self.requests.remove(&id).map(|request| (id, request))
    }

    /// Stores the sequence under `uuid` and appends the id to the `ready` queue.
    fn make_ready(&mut self, uuid: Uuid, active_seq: ActiveSequence) {
        let entry = Request::Active(active_seq);
        self.requests.insert(uuid, entry);
        self.ready.push_back(uuid);
    }

    /// Marks the request as running and returns a copy of the sequence's
    /// creation signal.
    ///
    /// Panics if the sequence has no creation signal.
    fn run(&mut self, uuid: Uuid, active_seq: ActiveSequence) -> MoveBlock {
        // Store the sequence first, then read the signal back from the stored copy.
        self.requests.insert(uuid, Request::Active(active_seq));
        let creation_signal = match self.requests.get(&uuid) {
            Some(Request::Active(stored)) => match stored.creation_signal() {
                Some(signal) => signal.clone(),
                None => panic!("Failed to get creation signal from ActiveSequence"),
            },
            _ => panic!("Failed to get ActiveSequence for UUID"),
        };
        self.running.insert(uuid);
        creation_signal
    }

    /// Records (or overwrites) the prefill cost associated with `uuid`.
    fn set_prefill_cost(&mut self, uuid: Uuid, cost: Option<PrefillCost>) {
        self.prefill_costs.insert(uuid, cost);
    }

    /// Returns the prefill compute for `uuid`, or `None` when no entry exists or
    /// the entry holds no cost.
    fn get_prefill_compute(&self, uuid: &Uuid) -> Option<f64> {
        let entry = self.prefill_costs.get(uuid)?;
        entry.as_ref().map(|cost| cost.prefill_compute)
    }

    /// Sums the batched tokens over all tracked requests: `new_tokens` for
    /// entries with a prefill cost, a single token for entries without one.
    fn num_batched_tokens(&self) -> usize {
        self.prefill_costs
            .values()
            .map(|entry| entry.as_ref().map_or(1, |cost| cost.new_tokens))
            .sum()
    }

    /// Forgets a finished request: drops it from `running`, `requests` and
    /// `prefill_costs`.
    fn complete(&mut self, uuid: &Uuid) {
        self.running.remove(uuid);
        self.requests.remove(uuid);
        self.prefill_costs.remove(uuid);
    }

    /// Preempts the least recently inserted running request: the sequence is
    /// reset, stored again and queued at the back of `waiting`.
    /// Returns the signals produced by the reset, or `None` when nothing is
    /// running or the evicted id has no stored request.
    ///
    /// Panics if the stored request is not an active sequence.
    fn preempt(&mut self) -> Option<Vec<MoveBlock>> {
        let victim = self.running.evict()?;
        eprintln!("Request {} will be preempted", victim);

        let stored = self.requests.remove(&victim)?;
        // Drop the cached cost so it is recomputed when the request is rescheduled.
        self.prefill_costs.remove(&victim);

        let mut sequence = match stored {
            Request::Active(sequence) => sequence,
            Request::Direct(_) => panic!("Expected ActiveSequence in running queue"),
        };

        let release_signals = sequence.reset_with_signal();
        self.requests.insert(victim, Request::Active(sequence));
        self.waiting.push_back(victim);
        Some(release_signals)
    }
}
/// Manages scheduling of requests using KvManager resources
///
/// Cloning is cheap: clones share the same state, KV manager and request channel.
#[derive(Clone)]
pub struct Scheduler {
/// Queues and per-request bookkeeping, shared with the background task.
state: Arc<Mutex<SchedulerState>>,
/// Mock KV block manager, shared with the background task.
kv_manager: Arc<Mutex<KvManager>>,
/// Sending half of the channel the background task drains to enqueue new requests.
request_tx: mpsc::Sender<DirectRequest>,
}
impl Scheduler {
/// Create a new Scheduler with the given parameters
pub fn new(
kv_capacity: usize,
watermark: f64,
block_size: usize,
chunk_size: Option<usize>,
output_tx: Option<mpsc::Sender<Uuid>>,
cancellation_token: Option<CancellationToken>,
) -> Self {
// Create KvManager internally
let kv_manager = KvManager::new(kv_capacity, block_size);
let token_capacity: usize = 8192;
let state = Arc::new(Mutex::new(SchedulerState::default()));
let kv_manager = Arc::new(Mutex::new(kv_manager));
let chunk_size = chunk_size.unwrap_or(256);
// Create channel for request handling
let (request_tx, mut request_rx) = mpsc::channel::<DirectRequest>(1024);
// Use provided cancellation token or create new one
let cancellation_token = cancellation_token.unwrap_or_default();
let token_clone = cancellation_token.clone();
// Create a clone for the background task
let state_clone = state.clone();
let kv_manager_clone = kv_manager.clone();
let output_tx_clone = output_tx.clone();
// Spawn main background task with cancellation token
tokio::spawn(async move {
let mut schedule_interval = interval(Duration::from_millis(5));
let mut simulate_interval = interval(Duration::from_millis(1));
loop {
tokio::select! {
biased;
// Enqueue new request
Some(request) = request_rx.recv() => {
let mut state = state_clone.lock().await;
state.receive(request);
}
// Try Scheduling Requests
_ = schedule_interval.tick() => {
let mut state_guard = state_clone.lock().await;
let mut kv_manager_guard = kv_manager_clone.lock().await;
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore.
while let Some((uuid, request)) = state_guard.next() {
let active_sequence = get_active_sequence(request, block_size, chunk_size);
// Calculate token budget using new_tokens from PrefillCost
let total_prefill_tokens = state_guard.num_batched_tokens();
let tokens_budget = token_capacity.saturating_sub(total_prefill_tokens);
// Check if it can be scheduled
let Some(prefill_cost) = kv_manager_guard.try_schedule(&active_sequence, watermark, tokens_budget) else {
state_guard.make_ready(uuid, active_sequence);
break;
};
// Get creation signal and schedule the request
let signal = state_guard.run(uuid, active_sequence);
kv_manager_guard.process(&signal);
state_guard.set_prefill_cost(uuid, Some(prefill_cost));
}
}
// Check for cancellation
_ = token_clone.cancelled() => {
break;
}
// Simulate running requests (prefill + decode)
_ = simulate_interval.tick() => {
let mut state_guard = state_clone.lock().await;
let mut kv_manager_guard = kv_manager_clone.lock().await;
// Base time needed for decoding (assumed memory bound on KV cache)
let active_tokens = kv_manager_guard.num_active_blocks() * block_size;
// TODO: 2 is a dummy / magic scaling factor
let mut generation_time = Duration::from_micros((active_tokens / 2) as u64);
// Process each running request
let uuids: Vec<Uuid> = state_guard.running.keys().cloned().collect();
for uuid in uuids {
// Check if UUID is still in running_requests, if not skip this iteration
if !state_guard.running.contains(&uuid) {
continue;
}
// Get prefill compute value first
let prefill_compute = state_guard.get_prefill_compute(&uuid);
// Get the active sequence for this UUID
let sequence = state_guard.requests.get_mut(&uuid)
.and_then(|req| if let Request::Active(seq) = req { Some(seq) } else { None })
.expect("UUID in running_requests must have a corresponding active sequence");
// Generate token and get signals
let signals = sequence.generate();
// Accumulate sleep duration based on prefill_compute if available
// prefill compute = (cached_tokens + new_tokens) * new_tokens
let sleep_ms = if let Some(compute) = prefill_compute {
// TODO: 1024 is a dummy / magic scaling factor
(compute / 1024.0) as u64
} else { 0 };
generation_time += Duration::from_micros(sleep_ms);
// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(&mut kv_manager_guard, &signals) {
sequence.pop(); // revert the failed generation op
// free_signal derefs the preempted blocks
let Some(free_signal) = state_guard.preempt() else {
panic!("Failed to acquire signal to free KV blocks from preemption");
};
for signal in free_signal {
kv_manager_guard.process(&signal);
}
continue;
}
// Send UUID notification for each generated token
// TODO: hook this up to an AsyncEngine
if let Some(tx) = &output_tx_clone {
let _ = tx.try_send(uuid);
}
// Check if we're done after generating
if sequence.generated_tokens() >= sequence.max_output_tokens() {
state_guard.complete(&uuid);
continue;
}
// Transition to decode (no prefill cost)
if sequence.generated_tokens() == 1 {
state_guard.set_prefill_cost(uuid, None);
}
}
// Sleep once for the accumulated duration
if generation_time.as_millis() > 0 {
tokio::time::sleep(generation_time).await;
}
}
}
}
});
Self {
state,
kv_manager,
request_tx,
}
}
/// Hand a new request to the scheduler through its request channel.
///
/// The outcome of the channel send is intentionally discarded, so a failed
/// send is not reported to the caller.
pub async fn receive(&self, request: DirectRequest) {
    // Best-effort hand-off: the send result is dropped on purpose.
    let _send_outcome = self.request_tx.send(request).await;
}
/// Number of requests currently in the `waiting` queue.
///
/// Takes the scheduler-state lock for the duration of the read.
pub async fn waiting_count(&self) -> usize {
    self.state.lock().await.waiting.len()
}
/// Number of requests currently tracked in `running`.
///
/// Takes the scheduler-state lock for the duration of the read.
pub async fn running_count(&self) -> usize {
    self.state.lock().await.running.len()
}
/// Report the KvManager's current capacity figure.
///
/// Locks the KV manager and returns whatever `current_capacity_perc()` yields.
pub async fn kv_usage_perc(&self) -> f64 {
    self.kv_manager.lock().await.current_capacity_perc()
}
/// Snapshot the scheduler and KV-cache counters as a [`ForwardPassMetrics`].
///
/// Both the scheduler state and the KV manager are locked while the snapshot is
/// taken. `request_total_slots` and `gpu_prefix_cache_hit_rate` are placeholder
/// values, not measurements.
pub async fn get_forward_pass_metrics(&self) -> ForwardPassMetrics {
    // Acquire in the same order as the background task: state first, then KV manager.
    let scheduler_state = self.state.lock().await;
    let manager = self.kv_manager.lock().await;

    let kv_active_blocks = manager.active_blocks().len() as u64;
    let kv_total_blocks = manager.max_capacity() as u64;

    // Fraction of blocks in use; a zero-capacity manager reports 0.0 instead of dividing by zero.
    let gpu_cache_usage_perc = match kv_total_blocks {
        0 => 0.0,
        total => kv_active_blocks as f32 / total as f32,
    };

    ForwardPassMetrics {
        request_active_slots: scheduler_state.running.len() as u64,
        request_total_slots: 420, // Dummy value as specified
        kv_active_blocks,
        kv_total_blocks,
        num_requests_waiting: scheduler_state.waiting.len() as u64,
        gpu_cache_usage_perc,
        gpu_prefix_cache_hit_rate: 0.0, // Placeholder value as specified
    }
}
}
/// Turn a [`Request`] into an [`ActiveSequence`].
///
/// An already-active request is returned unchanged. A direct request is used to
/// build a fresh sequence from its tokens and `max_output_tokens`, with the given
/// `block_size` and `chunk_size`.
fn get_active_sequence(request: Request, block_size: usize, chunk_size: usize) -> ActiveSequence {
    match request {
        Request::Active(active_seq) => active_seq,
        Request::Direct(direct_request) => ActiveSequence::new(
            direct_request.tokens,
            direct_request.max_output_tokens,
            Some(block_size),
            Some(chunk_size),
        ),
        // Kept so the function behaves the same even if `Request` has further variants.
        #[allow(unreachable_patterns)]
        _ => unreachable!("Request must be either Direct or Active"),
    }
}
/// Feed a batch of [`MoveBlock`] signals to the KvManager, stopping at the first failure.
///
/// Returns `true` when every signal was processed and `false` when one was rejected.
/// Signals before the rejected one have already been applied at that point.
///
/// # Panics
///
/// The only rejection treated as legitimate is a `Use` signal carrying exactly one
/// partial (generation) block. A rejected signal of any other shape is considered an
/// unexpected system state and panics.
fn process_signals(
    kv_manager_guard: &mut tokio::sync::MutexGuard<'_, KvManager>,
    signals: &[MoveBlock],
) -> bool {
    for signal in signals {
        let accepted = kv_manager_guard.process(signal);
        if !accepted {
            // Validate that the rejected signal has the one shape we expect to fail.
            match signal {
                MoveBlock::Use(blocks, _) => {
                    let [only_block] = blocks.as_slice() else {
                        panic!("Failed signal is Invalid. Can have only one generation signal.");
                    };
                    if !matches!(only_block, UniqueBlock::PartialBlock(_)) {
                        panic!("Failed signal is Invalid. Generation block has to be partial.");
                    }
                }
                _ => panic!("Failed signal is Invalid. Has to fail on generation signal."),
            }
            return false;
        }
    }
    true
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
use std::time::Duration;
// Submits `num_requests` requests to a scheduler and counts the per-token UUID
// notifications arriving on the output channel until the channel stays idle for 2 seconds.
// The `caching` case gives every request the same first half of its prompt; the `random`
// case uses fully random prompts.
#[rstest]
#[case::random(false)]
#[case::caching(true)]
#[tokio::test]
async fn test_scheduler_token_generation_patterns(#[case] use_shared_tokens: bool) {
std::env::set_var("RUST_LOG", "debug");
let kv_capacity: usize = 500;
let watermark: f64 = 0.01; // 1% watermark
let block_size: usize = 64;
let chunk_size: usize = 256;
let num_requests: usize = 100;
let input_len: usize = 1000;
let max_output_tokens: usize = 100;
// Create channel for token output
let (output_tx, mut output_rx) = mpsc::channel::<Uuid>(1024);
// Create scheduler with internal KvManager
let scheduler = Scheduler::new(
kv_capacity,
watermark,
block_size,
Some(chunk_size),
Some(output_tx),
None,
);
// Create shared tokens for caching case (first half of every prompt)
let shared_tokens = if use_shared_tokens {
Some(
(0..input_len / 2)
.map(|_| rand::random::<u32>() % 50000)
.collect::<Vec<_>>(),
)
} else {
None
};
// Create test requests
for _ in 0..num_requests {
let input_tokens = if let Some(ref shared) = shared_tokens {
// For caching case: use shared tokens for first half, random for second half
let mut tokens = shared.clone();
tokens.extend((0..input_len / 2).map(|_| rand::random::<u32>() % 50000));
tokens
} else {
// For random case: create unique random token vector for each request
(0..input_len)
.map(|_| rand::random::<u32>() % 50000)
.collect::<Vec<_>>()
};
let request = DirectRequest {
tokens: input_tokens,
max_output_tokens,
uuid: None,
};
scheduler.receive(request).await;
}
let start_time = std::time::Instant::now();
// Baseline token count: one notification per generated token for every request
let expected_tokens = num_requests * max_output_tokens;
let mut received_tokens = 0;
// Idle timeout: if no token is received for 2 seconds the loop below ends.
// It breaks out of the loop; it does not panic.
let timeout = tokio::time::sleep(Duration::from_secs(2));
tokio::pin!(timeout);
// Set up debug ticker interval
let mut debug_interval = interval(Duration::from_millis(500));
loop {
tokio::select! {
biased;
// Manual debug ticker that fetches forward pass metrics (printing is disabled)
_ = debug_interval.tick() => {
let _metrics = scheduler.get_forward_pass_metrics().await;
// println!("Forward Pass Metrics: {:#?}", _metrics);
}
Some(_) = output_rx.recv() => {
received_tokens += 1;
// Reset timeout whenever we receive a token
timeout.set(tokio::time::sleep(Duration::from_secs(2)));
}
_ = &mut timeout => {
// Break instead of panicking when timeout occurs
break;
}
}
}
// Calculate and print elapsed time (includes the final 2 second idle wait)
let elapsed = start_time.elapsed();
println!(
"Test completed in: {:?} for {} case",
elapsed,
if use_shared_tokens {
"caching"
} else {
"random"
}
);
// Assert on the number of received tokens.
// NOTE(review): the comparison is a strict `>`, so the test only passes when MORE than
// num_requests * max_output_tokens notifications arrive - confirm this is intended
// (rather than `==` or `>=`).
assert!(
received_tokens > expected_tokens,
"Received {} tokens but expected more than {}",
received_tokens,
expected_tokens
);
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::mocker::protocols::{MoveBlock, UniqueBlock};
use crate::tokens::{TokenBlockSequence, Tokens};
use derive_getters::Getters;
use rand::random;
use uuid;
/// Derive the list of [`UniqueBlock`]s that mirrors a [`TokenBlockSequence`].
///
/// Every block returned by `tokens.blocks()` becomes a `FullBlock` keyed by its
/// sequence hash. When the total token count is not a multiple of `block_size`,
/// one trailing partial block is appended: `PartialBlock(uuid)` if a `uuid` was
/// supplied, otherwise `UniqueBlock::default()`.
fn create_unique_blocks_from_sequence(
    tokens: &TokenBlockSequence,
    uuid: Option<uuid::Uuid>,
    block_size: usize,
) -> Vec<UniqueBlock> {
    // A trailing block is only needed when some tokens spill past the last full block.
    let has_leftover_tokens = tokens.total_tokens() % block_size != 0;
    let trailing_block = has_leftover_tokens
        .then(|| uuid.map_or_else(UniqueBlock::default, UniqueBlock::PartialBlock));

    tokens
        .blocks()
        .iter()
        .map(|block| UniqueBlock::FullBlock(block.sequence_hash()))
        .chain(trailing_block)
        .collect()
}
/// A sequence that is actively being built, with the ability to add tokens and commit to hashes
/// TODO: reuse tokens
#[derive(Debug, Getters)]
pub struct ActiveSequence {
/// Block identifiers for this sequence: `FullBlock`s for completed blocks, optionally
/// followed by one `PartialBlock` for the block still being filled.
unique_blocks: Vec<UniqueBlock>,
/// The underlying token storage (prompt tokens plus any pushed tokens).
tokens: TokenBlockSequence,
/// Number of tokens per block (`new` defaults this to 64).
#[getter(copy)]
block_size: usize,
/// Stored from the constructor (`new` defaults this to 256).
#[getter(copy)]
chunk_size: usize, // TODO: not actually used
/// Upper bound on generated tokens; `generate` asserts it has not been reached.
#[getter(copy)]
max_output_tokens: usize,
/// Tokens pushed so far: incremented by `push`, decremented by `pop`,
/// and zeroed by `reset_with_signal`.
#[getter(copy)]
generated_tokens: usize,
/// Length of the token vector the sequence was created with; `reset_with_signal`
/// truncates back to this length.
#[getter(copy)]
num_input_tokens: usize,
/// `Use` signal covering the initial blocks; `new_with_signal` takes it out, leaving `None`.
creation_signal: Option<MoveBlock>,
}
impl ActiveSequence {
    /// Build a sequence from the prompt `tokens`.
    ///
    /// `block_size` falls back to 64 and `chunk_size` to 256 when `None`. The
    /// initial blocks are derived from the tokens and a matching `Use` signal is
    /// stored as the creation signal.
    ///
    /// # Panics
    ///
    /// Panics if the resolved block size is not greater than 1.
    pub fn new(
        tokens: Vec<u32>,
        max_output_tokens: usize,
        block_size: Option<usize>,
        chunk_size: Option<usize>,
    ) -> Self {
        let block_size = block_size.unwrap_or(64);
        assert!(block_size > 1, "block_size must be greater than 1");
        let chunk_size = chunk_size.unwrap_or(256);

        let num_input_tokens = tokens.len();
        let sequence = Tokens::from(tokens).into_sequence(block_size, None);
        let initial_blocks = create_unique_blocks_from_sequence(&sequence, None, block_size);
        let creation_signal = Some(MoveBlock::Use(initial_blocks.clone(), None));

        Self {
            unique_blocks: initial_blocks,
            tokens: sequence,
            block_size,
            chunk_size,
            max_output_tokens,
            generated_tokens: 0,
            num_input_tokens,
            creation_signal,
        }
    }

    /// Number of tokens sitting past the last block boundary (`len % block_size`).
    pub fn extra_tokens(&self) -> usize {
        self.tokens.total_tokens() % self.block_size
    }

    /// Total number of tokens currently in the sequence.
    pub fn len(&self) -> usize {
        self.tokens.total_tokens()
    }

    /// `true` when the sequence holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Same as [`ActiveSequence::new`], but also hands back the creation signal.
    ///
    /// The signal is moved out of the sequence, so its stored `creation_signal`
    /// is `None` afterwards.
    pub fn new_with_signal(
        tokens: Vec<u32>,
        max_output_tokens: usize,
        block_size: Option<usize>,
        chunk_size: Option<usize>,
    ) -> (Self, Option<MoveBlock>) {
        let mut created = Self::new(tokens, max_output_tokens, block_size, chunk_size);
        let creation = created.creation_signal.take();
        (created, creation)
    }

    /// Append `token` and count it as generated.
    ///
    /// Returns `None` unless the token is the first one of a new block. In that
    /// case the returned signals are, in order: an optional `Promote` (if the
    /// previous last block was partial, it becomes a full block keyed by the hash
    /// of the last complete block) and a `Use` for a freshly created partial block.
    ///
    /// # Panics
    ///
    /// Panics if appending to the underlying token sequence fails.
    pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
        self.tokens.append(token).expect("Token push failed.");
        self.generated_tokens += 1;

        // Signals are only needed when this token opens a new block.
        if self.len() % self.block_size != 1 {
            return None;
        }

        let mut signals = Vec::with_capacity(2);

        // The previously partial block is now complete: swap it for a full block in place.
        if let Some(&UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last() {
            let completed_hash = self.tokens.last_complete_block().unwrap().sequence_hash();
            let tail = self.unique_blocks.len() - 1;
            self.unique_blocks[tail] = UniqueBlock::FullBlock(completed_hash);
            signals.push(MoveBlock::Promote(uuid, completed_hash));
        }

        // Request space for the new generation block.
        let fresh_partial = UniqueBlock::default();
        signals.push(MoveBlock::Use(vec![fresh_partial.clone()], None));
        self.unique_blocks.push(fresh_partial);

        Some(signals)
    }

    /// Generate one random token and push it onto the sequence.
    ///
    /// The returned vector contains whatever signals [`ActiveSequence::push`]
    /// produced. If this token brings `generated_tokens` up to
    /// `max_output_tokens`, the signals from [`ActiveSequence::free_signal`]
    /// are appended as well.
    ///
    /// # Panics
    ///
    /// Panics when `max_output_tokens` has already been reached; check
    /// `generated_tokens < max_output_tokens` before calling.
    pub fn generate(&mut self) -> Vec<MoveBlock> {
        assert!(
            self.generated_tokens < self.max_output_tokens,
            "Cannot generate more tokens: reached max_output_tokens limit"
        );

        let mut signals = self.push(random::<u32>()).unwrap_or_default();

        // Last allowed token: release every block the sequence holds.
        if self.generated_tokens == self.max_output_tokens {
            signals.extend(self.free_signal());
        }
        signals
    }

    /// Build the signals that release every block of this sequence.
    ///
    /// Blocks are visited from last to first; each partial block yields a
    /// `Destroy` signal and each full block a `Deref` signal. The sequence
    /// itself is not modified.
    pub fn free_signal(&self) -> Vec<MoveBlock> {
        let mut signals = Vec::with_capacity(self.unique_blocks.len());
        for block in self.unique_blocks.iter().rev() {
            signals.push(match block {
                UniqueBlock::PartialBlock(uuid) => {
                    MoveBlock::Destroy(vec![UniqueBlock::PartialBlock(*uuid)])
                }
                UniqueBlock::FullBlock(hash) => {
                    MoveBlock::Deref(vec![UniqueBlock::FullBlock(*hash)])
                }
            });
        }
        signals
    }

    /// Roll the sequence back to its prompt and return the signals freeing its current blocks.
    ///
    /// Tokens are truncated to `num_input_tokens`, the block list is rebuilt from
    /// the remaining tokens (no uuid is passed in, so a trailing partial block is a
    /// fresh default one), `generated_tokens` is zeroed, and a new creation signal
    /// is stored.
    ///
    /// # Panics
    ///
    /// Panics if truncating the underlying token sequence fails.
    pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
        // Capture the release signals before the block list is replaced.
        let released = self.free_signal();

        self.tokens.truncate(self.num_input_tokens).unwrap();
        let rebuilt = create_unique_blocks_from_sequence(&self.tokens, None, self.block_size);
        self.creation_signal = Some(MoveBlock::Use(rebuilt.clone(), None));
        self.unique_blocks = rebuilt;
        self.generated_tokens = 0;

        released
    }

    /// Remove the last token from the sequence.
    ///
    /// `generated_tokens` is decremented (saturating at zero). If the remaining
    /// token count lands exactly on a block boundary, the last entry of the block
    /// list is dropped too.
    pub fn pop(&mut self) {
        self.tokens.pop();
        self.generated_tokens = self.generated_tokens.saturating_sub(1);

        // Back on a block boundary: the trailing block entry no longer has any tokens.
        let on_block_boundary = self.len() % self.block_size == 0;
        if on_block_boundary {
            self.unique_blocks.pop();
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Exercises push/pop across block boundaries and checks the emitted signals and block lists.
#[test]
fn test_active_sequence_push() {
// Create a sequence with block size 16 initialized with 15 tokens (0..15, i.e. 0 through 14)
let initial_tokens: Vec<u32> = (0..15).collect();
let (mut seq1, signal1) =
ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), Some(256));
assert_eq!(seq1.num_input_tokens(), 15);
assert_eq!(seq1.len(), 15);
// Check that we got a Use signal carrying a single (partial) block
assert!(signal1.is_some());
match &signal1 {
Some(MoveBlock::Use(blocks, _)) => {
assert_eq!(blocks.len(), 1);
}
_ => panic!("Expected Use signal"),
}
// Push token 15 which should complete the block (no signals yet)
let signal_15 = seq1.push(15);
assert!(
signal_15.is_none(),
"Completing a block should not trigger signals"
);
// Push token 16 which should trigger both Promote and Use signals
let signal_16 = seq1.push(16);
assert!(signal_16.is_some());
let signal_16 = signal_16.unwrap();
assert_eq!(signal_16.len(), 2);
// Second signal (index 1) should be Use for new partial block
match &signal_16[1] {
MoveBlock::Use(blocks, _) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as first signal"),
}
// First signal (index 0) should be Promote for the previous block
match &signal_16[0] {
MoveBlock::Promote(uuid, _) => {
// The uuid is generated dynamically, so we just check it exists
let _ = uuid;
}
_ => panic!("Expected Promote signal as second signal"),
}
// Verify state after pushing tokens
assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
assert_eq!(seq1.len(), 17);
assert_eq!(seq1.len() % seq1.block_size(), 1);
// Create another sequence with block size 16 initialized with 16 tokens (0..16),
// i.e. exactly one full block, then push/pop/push token 16 to reach the same length
let extended_tokens: Vec<u32> = (0..16).collect();
let (mut seq2, _) =
ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), Some(256));
seq2.push(16);
seq2.pop();
seq2.push(16);
// Simplified assertions
assert_eq!(
seq1.unique_blocks()[0],
seq2.unique_blocks()[0],
"First blocks should be the same"
);
assert_ne!(
seq1.unique_blocks()[1],
seq2.unique_blocks()[1],
"Second blocks should be different"
);
// Reset partial block on seq1 and push back token 16
seq1.push(17);
seq1.pop();
seq1.pop();
seq1.push(16);
// Now push tokens 17 through 32 (range 17..33) to both sequences, bringing each to 33 tokens
for token in 17..33 {
seq1.push(token);
seq2.push(token);
}
// Both sequences should now have 3 blocks:
// 1. FullBlock for tokens 0-15
// 2. FullBlock for tokens 16-31
// 3. A partial block holding the single remaining token (32)
assert_eq!(
seq1.unique_blocks().len(),
3,
"seq1 should have exactly 3 blocks"
);
assert_eq!(
seq2.unique_blocks().len(),
3,
"seq2 should have exactly 3 blocks"
);
assert_eq!(
seq1.len() % seq1.block_size(),
1,
"seq1 should have 1 partial token"
);
assert_eq!(
seq2.len() % seq2.block_size(),
1,
"seq2 should have 1 partial token"
);
// Verify that both sequences have identical blocks up to the second position
assert_eq!(
&seq1.unique_blocks()[0..2],
&seq2.unique_blocks()[0..2],
"First two blocks should be identical"
);
// Reset seq1 and collect the signals that free its blocks
let free_signals = seq1.reset_with_signal();
// Only checks that the reset produced at least one cleanup signal
assert!(!free_signals.is_empty());
}
// Walks `generate` up to max_output_tokens and checks the signals emitted at each step.
#[test]
fn test_active_sequence_generate_signals() {
// Create a sequence with block size 16, max_output_tokens 5, initialized with tokens [0..14)
let initial_tokens: Vec<u32> = (0..14).collect();
let (mut seq, signal) =
ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), Some(256));
// Initial signal - should have received a Use signal for the partial block
assert!(signal.is_some());
match signal {
Some(MoveBlock::Use(blocks, _)) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal for the initial partial block"),
}
// Generate first two tokens (reaching 16 tokens) - should not trigger new signals
seq.generate();
let signals_first = seq.generate();
assert_eq!(signals_first.len(), 0);
// Generate third token - it is the first token of a new block (17 tokens total) and
// should trigger both Promote and Use signals
let signals_second = seq.generate();
assert_eq!(signals_second.len(), 2);
// Second signal (index 1) should be Use for new partial block
match &signals_second[1] {
MoveBlock::Use(blocks, _) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as second signal after second token"),
}
// First signal (index 0) should be Promote
match &signals_second[0] {
MoveBlock::Promote(uuid, hash) => {
// The uuid and hash values are generated dynamically, so we just check the event type
let _ = uuid;
let _ = hash;
}
_ => panic!("Expected Promote signal as first signal after second token"),
}
// Generate fourth token - should not trigger new signals as it's adding to partial block
let signals_third = seq.generate();
assert_eq!(signals_third.len(), 0);
// Generate fifth (last) token - we reach max_output_tokens, should trigger Destroy and Deref signals
let signals_last = seq.generate();
assert_eq!(signals_last.len(), 2);
// First signal should be Destroy for the partial block
match &signals_last[0] {
MoveBlock::Destroy(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Destroy signal for partial block after fourth token"),
}
// Second signal should be Deref for the full block
match &signals_last[1] {
MoveBlock::Deref(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::FullBlock(_)));
}
_ => panic!("Expected Deref signal for full block after fourth token"),
}
}
}
...@@ -188,7 +188,7 @@ pub enum TokenBlockError { ...@@ -188,7 +188,7 @@ pub enum TokenBlockError {
/// ///
/// This structure accumulates tokens until it reaches the specified `block_size`, /// This structure accumulates tokens until it reaches the specified `block_size`,
/// at which point it can be [`commit`](PartialTokenBlock::commit)ted into a full [`TokenBlock`]. /// at which point it can be [`commit`](PartialTokenBlock::commit)ted into a full [`TokenBlock`].
#[derive(Debug)] // No Clone: intended to be unique within a sequence #[derive(Debug, PartialEq)] // No Clone: intended to be unique within a sequence
pub struct PartialTokenBlock { pub struct PartialTokenBlock {
tokens: Tokens, tokens: Tokens,
block_size: usize, block_size: usize,
...@@ -478,7 +478,7 @@ impl TokenBlock { ...@@ -478,7 +478,7 @@ impl TokenBlock {
/// - [`BlockHash`]: Hash of tokens within a single block (seeded by [`SaltHash`]). /// - [`BlockHash`]: Hash of tokens within a single block (seeded by [`SaltHash`]).
/// - [`SequenceHash`]: Hash combining the previous block's [`SequenceHash`] and the current /// - [`SequenceHash`]: Hash combining the previous block's [`SequenceHash`] and the current
/// block's [`BlockHash`] (also seeded by [`SaltHash`]). /// block's [`BlockHash`] (also seeded by [`SaltHash`]).
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub struct TokenBlockSequence { pub struct TokenBlockSequence {
blocks: Vec<TokenBlock>, blocks: Vec<TokenBlock>,
current_block: PartialTokenBlock, current_block: PartialTokenBlock,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment