"docs/vscode:/vscode.git/clone" did not exist on "4eefbf9609e5ddb996e3ac37e192e92466ec35cc"
Unverified Commit f0652d89 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: vllm mocker enhancement (#1236)

parent 0d6cae85
......@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod engine;
pub mod evictor;
pub mod kv_manager;
pub mod protocols;
......
This diff is collapsed.
......@@ -13,167 +13,158 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Eq;
use std::collections::{HashMap, VecDeque};
use std::cmp::{Eq, Ordering};
use std::collections::{BTreeSet, HashMap};
use std::hash::Hash;
use std::time::Instant;
/// A wrapper for `(T, counter)` whose ordering *and* equality are based only on `counter`.
///
/// The standard library requires `Ord`, `PartialOrd` and `PartialEq` to agree with
/// each other: `a.cmp(&b) == Ordering::Equal` must hold exactly when `a == b`.
/// Deriving `PartialEq` would also compare `item`, so two entries with the same
/// counter but different items would be "equal" for ordering purposes yet unequal
/// under `==`. Equality is therefore implemented by hand on `counter` alone.
#[derive(Debug, Clone)]
struct PriorityItem<T> {
    /// The payload carried alongside the priority; ignored by all comparisons.
    item: T,
    /// The priority key; lower values sort first.
    counter: i64,
}

impl<T> PartialEq for PriorityItem<T> {
    /// Two entries are equal when their counters are equal.
    fn eq(&self, other: &Self) -> bool {
        self.counter == other.counter
    }
}

impl<T> Eq for PriorityItem<T> {}

impl<T> Ord for PriorityItem<T> {
    /// Total order by `counter` only.
    fn cmp(&self, other: &Self) -> Ordering {
        self.counter.cmp(&other.counter)
    }
}

impl<T> PartialOrd for PriorityItem<T> {
    /// Delegates to [`Ord::cmp`] so both orderings always agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// An LRU evictor that maintains objects and evicts them based on their
/// last accessed time. Implements a "lazy" eviction mechanism where:
/// 1. The priority queue does not immediately reflect updates or removes
/// 2. Objects are pushed to the queue in order of increasing priority (older objects first)
/// 3. The user must ensure objects are added in correct priority (temporal order)
/// 4. Remove and update operations are lazy - entries remain in the queue until
/// they are either evicted or cleaned up during maintenance
/// priority counter. Lower counter values are evicted first.
#[derive(Debug)]
pub struct LRUEvictor<T: Clone + Eq + Hash> {
free_table: HashMap<T, f64>,
priority_queue: VecDeque<(T, f64)>,
cleanup_threshold: usize,
start_time: Instant,
free_table: HashMap<T, i64>,
priority_queue: BTreeSet<PriorityItem<T>>,
positive_counter: i64,
negative_counter: i64,
}
impl<T: Clone + Eq + Hash> Default for LRUEvictor<T> {
fn default() -> Self {
Self {
free_table: HashMap::new(),
priority_queue: VecDeque::new(),
cleanup_threshold: 50,
start_time: Instant::now(),
priority_queue: BTreeSet::new(),
positive_counter: 0,
negative_counter: 0,
}
}
}
impl<T: Clone + Eq + Hash> LRUEvictor<T> {
/// Create a new LRUEvictor with the default cleanup threshold
pub fn new(cleanup_threshold: usize) -> Self {
Self {
cleanup_threshold,
..Default::default()
}
pub fn new(_cleanup_threshold: usize) -> Self {
Self::default()
}
/// Get the current timestamp as seconds since initialization
pub fn current_timestamp(&self) -> f64 {
self.start_time.elapsed().as_secs_f64()
pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, i64> {
self.free_table.keys()
}
/// Get an iterator over the keys in the evictor
pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, f64> {
self.free_table.keys()
/// Record `object` under `counter` in both the lookup table and the ordered queue.
///
/// Callers are expected to have removed any previous queue entry for `object`
/// beforehand; this helper only inserts.
fn update(&mut self, object: T, counter: i64) {
    let entry = PriorityItem {
        item: object.clone(),
        counter,
    };
    self.priority_queue.insert(entry);
    self.free_table.insert(object, counter);
}
/// Insert or update an object in the evictor with current timestamp
pub fn insert(&mut self, object: T) {
let timestamp = self.current_timestamp();
self._insert(object, timestamp);
// Remove old entry if it exists
if let Some(&old_counter) = self.free_table.get(&object) {
self.priority_queue.remove(&PriorityItem {
item: object.clone(),
counter: old_counter,
});
}
/// Check if the evictor contains the given object
pub fn contains(&self, object: &T) -> bool {
self.free_table.contains_key(object)
// Increment positive counter and insert
self.positive_counter += 1;
let counter = self.positive_counter;
self.update(object, counter);
}
/// Evict an object based on LRU policy
/// Returns the evicted object or None if no objects are available
pub fn evict(&mut self) -> Option<T> {
if self.free_table.is_empty() {
return None;
/// Push an object to the front with negative counter (highest priority for eviction)
pub fn push_front(&mut self, object: T) {
// Remove old entry if it exists
if let Some(&old_counter) = self.free_table.get(&object) {
self.priority_queue.remove(&PriorityItem {
item: object.clone(),
counter: old_counter,
});
}
while let Some((object, last_accessed)) = self.priority_queue.pop_front() {
let Some(&current_last_accessed) = self.free_table.get(&object) else {
continue; // entry is already removed
};
// Decrement negative counter and insert
self.negative_counter -= 1;
let counter = self.negative_counter;
if current_last_accessed == last_accessed {
self.free_table.remove(&object);
return Some(object);
} // otherwise entry is stale
self.update(object, counter);
}
None
/// Check whether `object` currently has an entry in the evictor's `free_table`.
pub fn contains(&self, object: &T) -> bool {
self.free_table.contains_key(object)
}
/// Insert or update an object in the evictor
fn _insert(&mut self, object: T, last_accessed: f64) {
self.free_table.insert(object.clone(), last_accessed);
self.priority_queue.push_back((object, last_accessed));
self.cleanup_if_necessary();
/// Evict an object based on LRU policy (lowest counter value)
/// Returns the evicted object or None if no objects are available
pub fn evict(&mut self) -> Option<T> {
self.priority_queue.pop_first().map(|item| {
self.free_table.remove(&item.item);
item.item
})
}
/// Remove an object from the evictor
/// We don't remove from the priority queue immediately, as that would be inefficient
/// Outdated entries will be filtered out during eviction or cleanup
pub fn remove(&mut self, object: &T) -> bool {
self.free_table.remove(object).is_some()
let Some(&counter) = self.free_table.get(object) else {
return false;
};
self.free_table.remove(object);
self.priority_queue.remove(&PriorityItem {
item: object.clone(),
counter,
});
true
}
/// Get the number of objects in the evictor.
///
/// This is the size of `free_table`, i.e. the count of currently tracked objects.
pub fn len(&self) -> usize {
self.free_table.len()
}
/// Check if the evictor is empty.
///
/// Returns `true` when `free_table` holds no objects.
pub fn is_empty(&self) -> bool {
self.free_table.is_empty()
}
/// Run `cleanup` once the queue has grown past `cleanup_threshold` entries
/// per live object in the table.
fn cleanup_if_necessary(&mut self) {
    let limit = self.cleanup_threshold * self.free_table.len();
    let queued = self.priority_queue.len();
    if queued > limit {
        self.cleanup();
    }
}
/// Discard outdated queue entries, keeping the rest in their existing order.
///
/// An entry is kept only if the table still holds its object with exactly the
/// same timestamp; entries for removed objects or superseded timestamps go away.
fn cleanup(&mut self) {
    let live = &self.free_table;
    self.priority_queue
        .retain(|(object, timestamp)| live.get(object) == Some(timestamp));
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
fn test_lru_evictor_eviction_order(#[case] threshold: usize) {
// Create a new LRUEvictor with the given cleanup threshold
let mut evictor = LRUEvictor::<i32>::new(threshold);
#[test]
fn test_lru_evictor_eviction_order() {
// Create a new LRUEvictor
let mut evictor = LRUEvictor::<i32>::new(1); // threshold value doesn't matter anymore
// Add items in the specified order with small delays between each
// Add items in the specified order
evictor.insert(4);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(3);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(2);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(1);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(5);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(1); // Updates timestamp for 1
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(4); // Updates timestamp for 4
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(2); // Updates timestamp for 2
evictor.insert(1); // Updates counter for 1
evictor.insert(4); // Updates counter for 4
evictor.insert(2); // Updates counter for 2
evictor.push_front(4);
// Verify the eviction order
println!("Testing with threshold {}", threshold);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 4);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 3);
let evicted = evictor.evict().unwrap();
......@@ -181,11 +172,11 @@ mod tests {
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 1);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 4);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 2);
let evicted = evictor.evict();
assert_eq!(evicted, None);
assert_eq!(evictor.len(), 0);
}
// ... existing test_push_front test ...
}
......@@ -46,10 +46,11 @@
//! implementation of the main block manager.
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock};
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost, UniqueBlock};
use crate::mocker::sequence::ActiveSequence;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc;
#[derive(Getters)]
pub struct KvManager {
......@@ -57,17 +58,27 @@ pub struct KvManager {
max_capacity: usize,
#[getter(copy)]
block_size: u32,
block_size: usize,
active_blocks: HashMap<UniqueBlock, usize>,
inactive_blocks: LRUEvictor<UniqueBlock>,
all_blocks: HashSet<UniqueBlock>,
move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>,
}
impl KvManager {
pub fn new(max_capacity: usize, block_size: u32) -> Self {
pub fn new(max_capacity: usize, block_size: usize) -> Self {
Self::new_with_sender(max_capacity, block_size, None)
}
pub fn new_with_sender(
max_capacity: usize,
block_size: usize,
move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>,
) -> Self {
let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default();
let all_blocks = HashSet::new();
......@@ -78,18 +89,46 @@ impl KvManager {
active_blocks,
inactive_blocks,
all_blocks,
move_block_response_tx,
}
}
/// Utility method to send block responses with optional reversing
fn send_block_response(
&self,
mut blocks: Vec<u64>,
reverse: bool,
store: bool,
parent_hash: Option<u64>,
) {
if let Some(ref tx) = self.move_block_response_tx {
if !blocks.is_empty() {
if reverse {
blocks.reverse();
}
let response = if store {
MoveBlockResponse::Store(blocks, parent_hash)
} else {
MoveBlockResponse::Remove(blocks)
};
tx.send(response).unwrap();
}
}
}
/// Process a MoveBlock instruction synchronously
pub fn process(&mut self, event: &MoveBlock) -> bool {
match event {
MoveBlock::Use(hashes, _) => {
MoveBlock::Use(hashes) => {
let mut blocks_stored = Vec::<u64>::new();
let mut parent_block: Option<&UniqueBlock> = None;
for hash in hashes {
// First check if it already exists in active blocks
if let Some(ref_count) = self.active_blocks.get_mut(hash) {
// Block already active, just increment reference count
*ref_count += 1;
parent_block = Some(hash);
continue;
}
......@@ -97,6 +136,7 @@ impl KvManager {
if self.inactive_blocks.remove(hash) {
// Insert into active with reference count 1
self.active_blocks.insert(hash.clone(), 1);
parent_block = Some(hash);
continue;
}
......@@ -106,30 +146,53 @@ impl KvManager {
// If at max capacity, evict the oldest entry from inactive blocks
if active_count + inactive_count >= self.max_capacity {
if let Some(evicted) = self.inactive_blocks.evict() {
// Remove evicted block from all_blocks
self.all_blocks.remove(&evicted);
} else {
// Cannot evict block, meaning no free blocks left in inactive pool
// Send a signal, scheduler would expect to handle preemption upon receiving this
let Some(evicted) = self.inactive_blocks.evict() else {
return false;
};
self.all_blocks.remove(&evicted);
if let UniqueBlock::FullBlock(evicted_full_block) = evicted {
self.send_block_response(vec![evicted_full_block], false, false, None);
}
}
// Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1);
// Add to all_blocks as it's a new block
self.all_blocks.insert(hash.clone());
if self.move_block_response_tx.is_some() {
if let UniqueBlock::FullBlock(stored_full_block) = hash {
blocks_stored.push(*stored_full_block);
}
}
}
let parent_hash = match parent_block {
None => None,
Some(UniqueBlock::FullBlock(block)) => Some(*block),
Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"),
};
self.send_block_response(blocks_stored, false, true, parent_hash);
}
MoveBlock::Destroy(hashes) => {
let mut blocks_destroyed = Vec::<u64>::new();
// Loop in inverse direction
for hash in hashes.iter().rev() {
self.active_blocks.remove(hash).unwrap();
// Remove from all_blocks when destroyed
assert!(self.all_blocks.remove(hash));
// Track blocks for batch sending
if self.move_block_response_tx.is_some() {
if let UniqueBlock::FullBlock(destroyed_full_block) = hash {
blocks_destroyed.push(*destroyed_full_block);
}
}
}
self.send_block_response(blocks_destroyed, true, false, None);
}
MoveBlock::Deref(hashes) => {
// Loop in inverse direction
for hash in hashes.iter().rev() {
......@@ -149,15 +212,15 @@ impl KvManager {
}
}
}
MoveBlock::Promote(uuid, hash) => {
MoveBlock::Promote(uuid, hash, parent_hash) => {
let uuid_block = UniqueBlock::PartialBlock(*uuid);
let hash_block = UniqueBlock::FullBlock(*hash);
let Some(ref_count) = self.active_blocks.remove(&uuid_block) else {
let in_all_blocks = self.all_blocks.contains(&uuid_block);
panic!(
"Missing active block for promotion: {:?}. Block still exists: {}",
uuid_block, in_all_blocks
"Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}"
);
};
......@@ -167,6 +230,7 @@ impl KvManager {
// Update all_blocks
assert!(self.all_blocks.remove(&uuid_block));
self.all_blocks.insert(hash_block);
self.send_block_response(vec![*hash], false, true, *parent_hash);
}
}
......@@ -178,6 +242,7 @@ impl KvManager {
pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
blocks
.iter()
// .filter(|&block| !self.active_blocks.contains_key(block))
.filter(|&block| !self.all_blocks.contains(block))
.count()
}
......@@ -200,6 +265,11 @@ impl KvManager {
self.active_blocks.len()
}
/// Get the share of active blocks relative to maximum capacity.
///
/// Despite the name, the value is a plain ratio (`active_blocks.len() / max_capacity`),
/// not multiplied by 100.
///
/// Returns `0.0` when `max_capacity` is zero, so callers never receive `NaN`
/// or an infinite value from a division by zero.
pub fn get_active_perc(&self) -> f64 {
    if self.max_capacity == 0 {
        return 0.0;
    }
    self.active_blocks.len() as f64 / self.max_capacity as f64
}
/// Get the number of inactive blocks
pub fn num_inactive_blocks(&self) -> usize {
self.inactive_blocks.len()
......@@ -216,63 +286,28 @@ impl KvManager {
}
/// Check if a sequence can be scheduled and calculate cost if possible
pub fn try_schedule(
&self,
sequence: &ActiveSequence,
watermark: f64,
tokens_budget: usize,
) -> Option<PrefillCost> {
// Return None immediately if tokens_budget is 0
if tokens_budget == 0 {
return None;
}
// Get unique blocks from the sequence
let unique_blocks = sequence.unique_blocks();
// Get the count of new blocks
let new_blocks = self.probe_new_blocks(unique_blocks);
// Calculate current usage and available capacity
let active_count = self.active_blocks.len();
// Check if we can schedule based on the watermark
if (active_count + new_blocks) as f64 > (1.0 - watermark) * self.max_capacity as f64 {
return None;
}
// Calculate overlap blocks
let overlap_blocks = unique_blocks.len() - new_blocks;
// Calculate new tokens
let new_tokens = sequence.num_input_tokens() - overlap_blocks * (self.block_size as usize);
// // Print the full equation with actual values substituted
// println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)",
// new_tokens,
// sequence.num_input_tokens(),
// overlap_blocks,
// self.block_size);
// Return None if new_tokens exceeds tokens_budget
if new_tokens > tokens_budget {
return None;
}
pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost {
let seq_blocks = sequence.unique_blocks();
let new_blocks = self.probe_new_blocks(seq_blocks);
let overlap_blocks = seq_blocks.len() - new_blocks;
let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size;
// Calculate prefill compute
let prefill_compute =
new_tokens as f64 * (new_tokens + overlap_blocks * (self.block_size as usize)) as f64;
1.25e-6 * (new_tokens as f64).powi(2) + 7.41e-2 * (new_tokens as f64) + 2.62e1;
Some(PrefillCost {
PrefillCost {
new_blocks,
new_tokens,
prefill_compute,
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::sync::mpsc;
#[test]
fn test_failure_on_max_capacity() {
......@@ -282,7 +317,7 @@ mod tests {
// Helper function to use multiple blocks that returns the response
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks, None))
manager.process(&MoveBlock::Use(blocks))
}
// First use 10 blocks (0 to 9) in a batch
......@@ -301,15 +336,17 @@ mod tests {
}
#[test]
// This is taken directly from the example in the vllm v1 prefix caching docs
fn test_block_lifecycle_stringent() {
// Create a KvManager with 10 blocks capacity
let mut manager = KvManager::new(10, 16);
// Create a channel to listen to block responses
let (tx, mut rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
// Create a KvManager with 10 blocks capacity and the response sender
let mut manager = KvManager::new_with_sender(10, 16, Some(tx));
// Helper function to use multiple blocks
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks, None));
manager.process(&MoveBlock::Use(blocks));
}
// Helper function to destroy multiple blocks
......@@ -324,6 +361,56 @@ mod tests {
manager.process(&MoveBlock::Deref(blocks));
}
// Helper function to assert block responses
fn assert_block_response(
rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
expected_type: &str,
expected_blocks: Vec<u64>,
description: &str,
) {
let response = rx
.try_recv()
.unwrap_or_else(|_| panic!("Expected {expected_type} response {description}"));
match (&response, expected_type) {
(MoveBlockResponse::Store(blocks, _parent_hash), "Store") => {
assert_eq!(
blocks.len(),
expected_blocks.len(),
"Expected {} blocks in Store response {}",
expected_blocks.len(),
description
);
assert_eq!(
*blocks, expected_blocks,
"Store blocks don't match expected {description}"
);
}
(MoveBlockResponse::Remove(blocks), "Remove") => {
assert_eq!(
blocks.len(),
expected_blocks.len(),
"Expected {} blocks in Remove response {}",
expected_blocks.len(),
description
);
assert_eq!(
*blocks, expected_blocks,
"Remove blocks don't match expected {description}"
);
}
_ => panic!("Expected {expected_type} response, got {response:?} {description}"),
}
}
// Helper function to assert no response is received
fn assert_no_response(
rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
description: &str,
) {
assert!(rx.try_recv().is_err(), "Expected no response {description}",);
}
// Helper function to check if active blocks contain expected blocks with expected ref counts
fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
assert_eq!(
......@@ -336,14 +423,12 @@ mod tests {
let block = UniqueBlock::FullBlock(id);
assert!(
manager.active_blocks().contains_key(&block),
"Block {} not found in active blocks",
id
"Block {id} not found in active blocks",
);
assert_eq!(
manager.active_blocks().get(&block),
Some(&ref_count),
"Block {} has wrong reference count",
id
"Block {id} has wrong reference count",
);
}
}
......@@ -366,17 +451,18 @@ mod tests {
let block = UniqueBlock::FullBlock(id);
assert!(
inactive_blocks.iter().any(|&b| *b == block),
"Block {} not found in inactive blocks",
id
"Block {id} not found in inactive blocks",
);
}
}
// First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks(&mut manager, (0..5).collect());
assert_block_response(&mut rx, "Store", vec![0, 1, 2, 3, 4], "after first use");
// Then use blocks 0, 1, 5, 6 in a batch
use_blocks(&mut manager, vec![0, 1, 5, 6]);
assert_block_response(&mut rx, "Store", vec![5, 6], "after second use");
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks(
......@@ -386,9 +472,11 @@ mod tests {
// Now destroy block 4
destroy_blocks(&mut manager, vec![4]);
assert_block_response(&mut rx, "Remove", vec![4], "after destroy block 4");
// And deref blocks 3, 2, 1, 0 in this order as a batch
deref_blocks(&mut manager, vec![0, 1, 2, 3]);
assert_no_response(&mut rx, "after deref operation");
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
assert_inactive_blocks(&manager, 2, &[3, 2]);
......@@ -396,6 +484,7 @@ mod tests {
// Now destroy block 6
destroy_blocks(&mut manager, vec![6]);
assert_block_response(&mut rx, "Remove", vec![6], "after block 6 eviction");
// And deref blocks 5, 1, 0 as a batch
deref_blocks(&mut manager, vec![0, 1, 5]);
......@@ -406,6 +495,7 @@ mod tests {
// Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);
assert_block_response(&mut rx, "Store", vec![7, 8, 9], "after [7, 8, 9] use");
// Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks(&manager, 2, &[3, 5]);
......@@ -420,8 +510,14 @@ mod tests {
// Now use blocks 10, 11, 12 as a batch
use_blocks(&mut manager, vec![10, 11, 12]);
assert_block_response(&mut rx, "Remove", vec![3], "after block 5 eviction");
assert_block_response(&mut rx, "Store", vec![10, 11, 12], "after [10, 11, 12] use");
// Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks(&manager, 1, &[5]);
use_blocks(&mut manager, vec![13]);
assert_block_response(&mut rx, "Remove", vec![5], "after block 5 eviction");
assert_block_response(&mut rx, "Store", vec![13], "after block 13 use");
}
}
......@@ -13,12 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
pub type Token = u32;
pub type LocalBlockHash = u64;
/// A global hash identifier for blocks
pub type GlobalHash = u64;
pub type NumBlocks = usize;
......@@ -39,12 +43,19 @@ impl Default for UniqueBlock {
}
/// Represents different block movement operations in the cache
/// For Use and Promote variants, parent hash is the second field
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
Use(Vec<UniqueBlock>, Option<f64>),
Use(Vec<UniqueBlock>),
Destroy(Vec<UniqueBlock>),
Deref(Vec<UniqueBlock>),
Promote(Uuid, GlobalHash),
Promote(Uuid, GlobalHash, Option<u64>),
}
/// Notification describing which full blocks were stored or removed.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
/// Block hashes that were stored, together with an optional parent block hash.
Store(Vec<GlobalHash>, Option<u64>),
/// Block hashes that were removed.
Remove(Vec<GlobalHash>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -52,15 +63,86 @@ pub struct DirectRequest {
pub tokens: Vec<Token>,
pub max_output_tokens: usize,
pub uuid: Option<Uuid>,
pub dp_rank: Option<u32>,
}
/// Represents the cost of prefilling content in the cache
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
// Number of blocks that would have to be newly added — presumably those not already cached; confirm against the producer.
pub new_blocks: usize,
// Number of tokens that would have to be newly processed — presumably input tokens minus those covered by cached blocks; confirm.
pub new_tokens: usize,
// Estimated compute for the prefill; NOTE(review): unit and formula are defined by the code that fills this in — confirm there.
pub prefill_compute: f64,
}
/// Signal for output token generation with completion status
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
// Identifier this signal refers to — presumably the request's uuid; confirm against the sender.
pub uuid: Uuid,
// Completion status carried by the signal.
pub completed: bool,
}
/// Configuration arguments for MockVllmEngine
///
/// Built through the derived owned-pattern builder; every field has a builder
/// default, listed on the field below.
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
// Builder default: 16384. Presumably the total number of KV blocks available — confirm against the engine.
#[builder(default = "16384")]
pub num_gpu_blocks: usize,
// Builder default: 64. Presumably the number of tokens per block — confirm.
#[builder(default = "64")]
pub block_size: usize,
// This was 1024 in the past but reverted back to 256
#[builder(default = Some(256))]
pub max_num_seqs: Option<usize>,
// default for open api server, for llm class it's 16384
#[builder(default = Some(8192))]
pub max_num_batched_tokens: Option<usize>,
// Builder default: true.
#[builder(default = true)]
pub enable_prefix_caching: bool,
// Builder default: 0.01. NOTE(review): looks like a fraction of capacity kept free — confirm against the scheduler.
#[builder(default = "0.01")]
pub watermark: f64,
// Builder default: 1.0. NOTE(review): presumably a factor applied to simulated timings — confirm.
#[builder(default = "1.0")]
pub speedup_ratio: f64,
// Builder default: 1. Presumably the data-parallel size — confirm.
#[builder(default = "1")]
pub dp_size: u32,
}
impl MockEngineArgs {
/// Create a `MockEngineArgsBuilder` in its default state.
pub fn builder() -> MockEngineArgsBuilder {
MockEngineArgsBuilder::default()
}
}
/// Translate a `MoveBlockResponse` into the router's `KvCacheEventData`.
///
/// `Store` becomes a `Stored` event carrying the optional parent hash and one
/// stored-block record per hash; `Remove` becomes a `Removed` event listing the hashes.
///
/// Caveat: each hash is reused as both `block_hash` and `tokens_hash`. That
/// assumption is not correct in the rare cases where the sequence-aware hash
/// differs from the token content hash.
pub fn block_response_to_kv_event(response: MoveBlockResponse) -> KvCacheEventData {
    match response {
        MoveBlockResponse::Remove(full_blocks) => {
            let block_hashes = full_blocks
                .into_iter()
                .map(ExternalSequenceBlockHash)
                .collect();
            KvCacheEventData::Removed(KvCacheRemoveData { block_hashes })
        }
        MoveBlockResponse::Store(full_blocks, parent_hash) => {
            let to_stored = |hash| KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(hash),
                tokens_hash: LocalBlockHash(hash),
            };
            let blocks = full_blocks.into_iter().map(to_stored).collect();
            let parent_hash = parent_hash.map(ExternalSequenceBlockHash);
            KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash,
                blocks,
            })
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
......
This diff is collapsed.
......@@ -23,16 +23,23 @@ use uuid;
fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: u32,
block_size: usize,
enable_prefix_caching: bool,
) -> Vec<UniqueBlock> {
let mut unique_blocks: Vec<UniqueBlock> = tokens
.blocks()
.iter()
.map(|block| UniqueBlock::FullBlock(block.sequence_hash()))
.map(|block| {
if enable_prefix_caching {
UniqueBlock::FullBlock(block.sequence_hash())
} else {
UniqueBlock::FullBlock(random::<u64>())
}
})
.collect();
// Only push the partial block if tokens count isn't a multiple of block_size
if tokens.total_tokens() % (block_size as usize) != 0 {
if tokens.total_tokens() % block_size != 0 {
unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
......@@ -50,10 +57,7 @@ pub struct ActiveSequence {
tokens: TokenBlockSequence,
#[getter(copy)]
block_size: u32,
#[getter(copy)]
chunk_size: usize, // TODO: not actually used
block_size: usize,
#[getter(copy)]
max_output_tokens: usize,
......@@ -61,10 +65,16 @@ pub struct ActiveSequence {
#[getter(copy)]
generated_tokens: usize,
#[getter(copy)]
already_generated_tokens: usize,
#[getter(copy)]
num_input_tokens: usize,
creation_signal: Option<MoveBlock>,
#[getter(copy)]
enable_prefix_caching: bool,
}
impl ActiveSequence {
......@@ -72,32 +82,33 @@ impl ActiveSequence {
pub fn new(
tokens: Vec<u32>,
max_output_tokens: usize,
block_size: Option<u32>,
chunk_size: Option<usize>,
block_size: Option<usize>,
enable_prefix_caching: bool,
) -> Self {
let block_size = block_size.unwrap_or(64);
assert!(block_size > 1, "block_size must be greater than 1");
let chunk_size = chunk_size.unwrap_or(256);
let num_input_tokens = tokens.len();
let tokens = Tokens::from(tokens).into_sequence(block_size, None);
let unique_blocks = create_unique_blocks_from_sequence(&tokens, None, block_size);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone(), None));
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, None);
let unique_blocks =
create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone()));
Self {
unique_blocks,
tokens,
block_size,
chunk_size,
max_output_tokens,
generated_tokens: 0,
already_generated_tokens: 0,
num_input_tokens,
creation_signal,
enable_prefix_caching,
}
}
pub fn extra_tokens(&self) -> u32 {
(self.len() % self.block_size as usize) as u32
(self.len() % self.block_size) as u32
}
pub fn len(&self) -> usize {
......@@ -112,20 +123,31 @@ impl ActiveSequence {
pub fn new_with_signal(
tokens: Vec<u32>,
max_output_tokens: usize,
block_size: Option<u32>,
chunk_size: Option<usize>,
block_size: Option<usize>,
enable_prefix_caching: bool,
) -> (Self, Option<MoveBlock>) {
let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size);
let mut sequence = Self::new(tokens, max_output_tokens, block_size, enable_prefix_caching);
let signal = sequence.creation_signal.take();
(sequence, signal)
}
/// Return the hash of the second-to-last unique block, which acts as the parent.
///
/// Yields `None` when fewer than two blocks exist.
///
/// # Panics
/// Panics if the second-to-last block is not a `FullBlock`.
fn get_parent_hash(&self) -> Option<u64> {
    // `rev().nth(1)` is the second-to-last element, or `None` if there are < 2.
    let parent = self.unique_blocks.iter().rev().nth(1)?;
    if let UniqueBlock::FullBlock(hash) = parent {
        Some(*hash)
    } else {
        panic!("Cannot have a partial block as parent")
    }
}
/// Push a token to the sequence
pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
self.tokens.append(token).expect("Token push failed.");
self.generated_tokens += 1;
if self.len() % (self.block_size as usize) != 1 {
if self.len() % self.block_size != 1 {
return None;
}
......@@ -135,16 +157,24 @@ impl ActiveSequence {
// Replace last partial block with full block if it exists
if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
let last_block_hash = self.tokens.last_complete_block().unwrap().sequence_hash();
let last_block_hash = if self.enable_prefix_caching {
self.tokens.last_complete_block().unwrap().sequence_hash()
} else {
random::<u64>()
};
self.unique_blocks.pop();
self.unique_blocks
.push(UniqueBlock::FullBlock(last_block_hash));
signals.push(MoveBlock::Promote(uuid, last_block_hash));
signals.push(MoveBlock::Promote(
uuid,
last_block_hash,
self.get_parent_hash(),
));
}
let new_partial_block = UniqueBlock::default();
self.unique_blocks.push(new_partial_block.clone());
signals.push(MoveBlock::Use(vec![new_partial_block], None));
signals.push(MoveBlock::Use(vec![new_partial_block]));
Some(signals)
}
......@@ -204,15 +234,19 @@ impl ActiveSequence {
}
/// Reset the sequence to its initial state and return the free signals from freeing current blocks
/// maintaining the uuid of the last partial block
pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
let free_signal = self.free_signal();
self.tokens.truncate(self.num_input_tokens).unwrap();
self.unique_blocks =
create_unique_blocks_from_sequence(&self.tokens, None, self.block_size);
self.unique_blocks = create_unique_blocks_from_sequence(
&self.tokens,
None,
self.block_size,
self.enable_prefix_caching,
);
self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens);
self.generated_tokens = 0;
self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone(), None));
self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone()));
free_signal
}
......@@ -223,7 +257,7 @@ impl ActiveSequence {
self.generated_tokens = self.generated_tokens.saturating_sub(1);
// Reverts to the last full block
if self.tokens.total_tokens() % (self.block_size as usize) == 0 {
if self.tokens.total_tokens() % self.block_size == 0 {
self.unique_blocks.pop();
}
}
......@@ -238,14 +272,14 @@ mod tests {
// Create a sequence with block size 16 initialized with tokens [0..15]
let initial_tokens: Vec<u32> = (0..15).collect();
let (mut seq1, signal1) =
ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), Some(256));
ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
assert_eq!(seq1.num_input_tokens(), 15);
assert_eq!(seq1.len(), 15);
// Check that we got a Use signal
assert!(signal1.is_some());
match &signal1 {
Some(MoveBlock::Use(blocks, _)) => {
Some(MoveBlock::Use(blocks)) => {
assert_eq!(blocks.len(), 1);
}
_ => panic!("Expected Use signal"),
......@@ -264,33 +298,31 @@ mod tests {
let signal_16 = signal_16.unwrap();
assert_eq!(signal_16.len(), 2);
// First signal should be Promote for the previous block
match &signal_16[0] {
MoveBlock::Promote(_, _, parent_hash) => {
assert_eq!(*parent_hash, None);
}
_ => panic!("Expected Promote signal as second signal"),
}
// Second signal should be Use for new partial block
match &signal_16[1] {
MoveBlock::Use(blocks, _) => {
MoveBlock::Use(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as first signal"),
}
// First signal should be Promote for the previous block
match &signal_16[0] {
MoveBlock::Promote(uuid, _) => {
// The uuid is generated dynamically, so we just check it exists
let _ = uuid;
}
_ => panic!("Expected Promote signal as second signal"),
}
// Verify state after pushing tokens
assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
assert_eq!(seq1.len(), 17);
assert_eq!(seq1.len() % (seq1.block_size() as usize), 1);
assert_eq!(seq1.len() % seq1.block_size(), 1);
// Create another sequence with block size 16 initialized with tokens [0..16)
let extended_tokens: Vec<u32> = (0..16).collect();
let (mut seq2, _) =
ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), Some(256));
let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true);
seq2.push(16);
seq2.pop();
seq2.push(16);
......@@ -335,12 +367,12 @@ mod tests {
"seq2 should have exactly 3 blocks"
);
assert_eq!(
seq1.len() % (seq1.block_size() as usize),
seq1.len() % seq1.block_size(),
1,
"seq1 should have 1 partial token"
);
assert_eq!(
seq2.len() % (seq2.block_size() as usize),
seq2.len() % seq2.block_size(),
1,
"seq2 should have 1 partial token"
);
......@@ -352,9 +384,38 @@ mod tests {
"First two blocks should be identical"
);
// Push tokens 33..47 to seq1
for token in 33..48 {
seq1.push(token);
}
// Push token 48 and get the signal - this completes the block and triggers signals
let signal = seq1.push(48);
let signal = signal.unwrap();
// Check that signal[0] is promote
match &signal[0] {
MoveBlock::Promote(_, _, parent_hash) => {
// Check that the parent_hash matches unique_blocks[1], which should be a full block
if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] {
assert_eq!(
*parent_hash,
Some(expected_hash),
"Parent hash should match unique_blocks[1]"
);
} else {
panic!("unique_blocks[1] should be a full block");
}
}
_ => panic!("Expected Promote signal as first signal"),
}
// Reset seq1 and check that it equals the original clone
let free_signals = seq1.reset_with_signal();
// 49 - 15 generated tokens
assert_eq!(seq1.already_generated_tokens, 34);
// Verify the reset signals include proper cleanup events
assert!(!free_signals.is_empty());
}
......@@ -363,13 +424,12 @@ mod tests {
fn test_active_sequence_generate_signals() {
// Create a sequence with block size 16, max_output_tokens 5, initialized with tokens [0..14)
let initial_tokens: Vec<u32> = (0..14).collect();
let (mut seq, signal) =
ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), Some(256));
let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true);
// Initial signal - should have received a Use signal for the partial block
assert!(signal.is_some());
match signal {
Some(MoveBlock::Use(blocks, _)) => {
Some(MoveBlock::Use(blocks)) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
......@@ -385,25 +445,23 @@ mod tests {
let signals_second = seq.generate();
assert_eq!(signals_second.len(), 2);
// First signal should be Use for new partial block
// First signal should be Promote
match &signals_second[0] {
MoveBlock::Promote(_, _, parent_hash) => {
assert_eq!(*parent_hash, None);
}
_ => panic!("Expected Promote signal as first signal after second token"),
}
// Second signal should be Use for new partial block
match &signals_second[1] {
MoveBlock::Use(blocks, _) => {
MoveBlock::Use(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as second signal after second token"),
}
// Second signal should be Promote
match &signals_second[0] {
MoveBlock::Promote(uuid, hash) => {
// The uuid and hash values are generated dynamically, so we just check the event type
let _ = uuid;
let _ = hash;
}
_ => panic!("Expected Promote signal as first signal after second token"),
}
// Generate fourth token - should not trigger new signals as it's adding to partial block
let signals_third = seq.generate();
assert_eq!(signals_third.len(), 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment