Unverified Commit f0652d89 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: vllm mocker enhancement (#1236)

parent 0d6cae85
......@@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod engine;
pub mod evictor;
pub mod kv_manager;
pub mod protocols;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! MockSchedulerEngine - AsyncEngine wrapper around the Scheduler
//!
//! This module provides an AsyncEngine implementation that wraps the Scheduler
//! to provide streaming token generation with realistic timing simulation.
use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal};
use crate::mocker::scheduler::Scheduler;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use crate::protocols::TokenIdType;
use dynamo_runtime::protocols::annotated::Annotated;
use dynamo_runtime::DistributedRuntime;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::{
component::Component,
engine::AsyncEngineContextProvider,
pipeline::{async_trait, AsyncEngine, Error, ManyOut, ResponseStream, SingleIn},
traits::DistributedRuntimeProvider,
Result,
};
use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use crate::kv_router::publisher::KvEventPublisher;
use futures::StreamExt;
use rand::Rng;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex, OnceCell};
use tokio::time::{interval, Duration};
use tokio_stream::wrappers::ReceiverStream;
use uuid::Uuid;
pub const MOCKER_COMPONENT: &str = "mocker";
/// Generate a random token ID from 1k to 5k
fn generate_random_token() -> TokenIdType {
let mut rng = rand::rng();
rng.random_range(1000..5000)
}
/// AsyncEngine wrapper around the Scheduler that generates random character tokens
#[derive(Clone)]
pub struct MockVllmEngine {
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
request_senders: Arc<OnceCell<Vec<mpsc::UnboundedSender<DirectRequest>>>>,
engine_args: MockEngineArgs,
}
impl MockVllmEngine {
/// Create a new MockVllmEngine with the given parameters
pub fn new(args: MockEngineArgs) -> Self {
Self {
active_requests: Arc::new(Mutex::new(HashMap::new())),
request_senders: Arc::new(OnceCell::new()),
engine_args: args,
}
}
pub async fn start(&self, component: Component) -> Result<()> {
let cancel_token = component.drt().runtime().child_token();
let (schedulers, kv_event_receiver) = self.start_schedulers(
self.engine_args.clone(),
self.active_requests.clone(),
cancel_token.clone(),
);
Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone())
.await?;
// Start KV events publishing with the actual receivers from schedulers
if self.engine_args.enable_prefix_caching {
Self::start_kv_events_publishing(
kv_event_receiver,
Some(component.clone()),
self.engine_args.block_size,
cancel_token.clone(),
)
.await?;
}
Ok(())
}
pub fn direct(&self, request: DirectRequest, dp_rank: usize) {
let senders = self.request_senders.get().expect("Not initialized");
let _ = senders[dp_rank].send(request);
}
/// Create schedulers and spawn their background tasks for distributing token notifications
/// Returns schedulers and their corresponding KV event receivers
fn start_schedulers(
&self,
args: MockEngineArgs,
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
cancel_token: CancellationToken,
) -> (
Vec<Scheduler>,
Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
) {
let mut schedulers = Vec::<Scheduler>::new();
let mut kv_event_receivers = Vec::new();
let mut senders = Vec::with_capacity(args.dp_size as usize);
// Create multiple schedulers and their background tasks
for dp_rank in 0..args.dp_size {
// Create a shared output channel that this scheduler will use
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create a channel for KV events from this scheduler
let (kv_events_tx, kv_events_rx) = mpsc::unbounded_channel::<KvCacheEventData>();
let scheduler = Scheduler::new(
args.clone(),
Some(dp_rank),
Some(output_tx),
Some(kv_events_tx), // Pass the KV events sender to scheduler
Some(cancel_token.clone()),
);
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
kv_event_receivers.push(kv_events_rx);
// Spawn a background task for this scheduler to distribute token notifications to active requests
// let output_rx = Arc::new(Mutex::new(output_rx));
let active_requests_clone = active_requests.clone();
let cancel_token_cloned = cancel_token.clone();
tokio::spawn(async move {
loop {
tokio::select! {
signal_result = output_rx.recv() => {
let Some(signal) = signal_result else {
break; // Channel closed
};
// Notify the specific request that a token was generated
let active = active_requests_clone.lock().await;
if let Some(request_tx) = active.get(&signal.uuid) {
let _ = request_tx.send(signal);
}
}
_ = cancel_token_cloned.cancelled() => {
break;
}
}
}
});
}
// Set the senders once
self.request_senders
.set(senders)
.expect("Already initialized");
(schedulers, kv_event_receivers)
}
/// Start background tasks to poll and publish metrics every second
async fn start_metrics_publishing(
schedulers: &[Scheduler],
component: Option<Component>,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::info!("Creating metrics publisher");
let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
tracing::info!("Metrics publisher created");
if let Some(comp) = component {
tracing::info!("Creating metrics endpoint");
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
if let Err(e) = publisher.create_endpoint(comp.clone()).await {
tracing::error!("Metrics endpoint failed: {e}");
}
}
});
// Give it a moment to start
tokio::time::sleep(Duration::from_millis(100)).await;
tracing::info!("Metrics endpoint started (background)");
}
tracing::info!("Starting metrics background tasks");
for (dp_rank, scheduler) in schedulers.iter().enumerate() {
let scheduler = scheduler.clone();
let publisher = metrics_publisher.clone();
let dp_rank = dp_rank as u32;
let cancel_token = cancel_token.clone();
tokio::spawn(async move {
let mut interval = interval(Duration::from_millis(100));
loop {
tokio::select! {
_ = interval.tick() => {
// Get metrics from scheduler
let metrics = scheduler.get_forward_pass_metrics().await;
// Publish metrics
if let Err(e) = publisher.publish(Arc::new(metrics)) {
tracing::warn!("Failed to publish metrics for DP rank {dp_rank}: {e}");
} else {
tracing::trace!("Published metrics for DP rank {}", dp_rank);
}
}
_ = cancel_token.cancelled() => {
tracing::info!("Metrics publishing cancelled for DP rank {dp_rank}");
break;
}
}
}
});
}
tracing::info!("Metrics background tasks started");
Ok(())
}
/// Start background tasks to collect and publish KV events from schedulers
async fn start_kv_events_publishing(
kv_event_receivers: Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
component: Option<Component>,
block_size: usize,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::info!("Starting KV events publishing");
// Only start KV events publishing if we have a component
let Some(comp) = component else {
tracing::warn!("No component provided, skipping KV events publishing");
return Ok(());
};
tracing::info!("Component found for KV events publishing");
tracing::debug!("Getting worker_id");
let worker_id = comp
.drt()
.primary_lease()
.expect("Cannot publish KV events without lease") // ← This will PANIC on static!
.id();
// let worker_id = 0;
tracing::debug!("Worker_id set to: {worker_id}");
tracing::info!("Creating KV event publisher");
let kv_event_publisher = Arc::new(KvEventPublisher::new(
comp.clone(),
worker_id,
block_size as u32,
None,
)?);
tracing::info!("KV event publisher created");
tracing::info!(
"Starting KV event background tasks for {} receivers",
kv_event_receivers.len()
);
for (dp_rank, mut kv_events_rx) in kv_event_receivers.into_iter().enumerate() {
tracing::debug!("Starting background task for DP rank {dp_rank}");
let publisher = kv_event_publisher.clone();
let dp_rank = dp_rank as u32;
let cancel_token = cancel_token.clone();
tokio::spawn(async move {
tracing::debug!("Background task started for DP rank {dp_rank}");
loop {
tokio::select! {
// Receive actual KV events from the scheduler
Some(event_data) = kv_events_rx.recv() => {
// Convert KvCacheEventData to KvCacheEvent with random UUID as event_id
let event = KvCacheEvent {
event_id: Uuid::new_v4().as_u128() as u64,
data: event_data,
};
// Publish the event
if let Err(e) = publisher.publish(event) {
tracing::warn!("Failed to publish KV event for DP rank {dp_rank}: {e}");
} else {
tracing::trace!("Published KV event for DP rank {dp_rank}");
}
}
_ = cancel_token.cancelled() => {
tracing::info!("KV events publishing cancelled for DP rank {dp_rank}");
break;
}
}
}
});
}
tracing::info!("All KV event background tasks started");
Ok(())
}
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
for MockVllmEngine
{
async fn generate(
&self,
input: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<LLMEngineOutput>, Error> {
let (request, ctx) = input.into_parts();
// Extract dp_rank from annotations if present
let dp_rank = request
.annotations
.iter()
.find_map(|ann| {
if ann.starts_with("dp_rank:") {
ann.strip_prefix("dp_rank:").and_then(|s| s.parse().ok())
} else {
None
}
})
.unwrap_or(0);
// Validate dp_rank
if dp_rank >= self.engine_args.dp_size {
return Err(Error::msg(format!(
"dp_rank {} is out of bounds for dp_size {}",
dp_rank, self.engine_args.dp_size
)));
}
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// Convert PreprocessedRequest to DirectRequest for scheduler
let direct_request = DirectRequest {
tokens: request.token_ids.clone(),
max_output_tokens: request
.stop_conditions
.max_tokens
.expect("max_output_tokens must be specified for mocker")
as usize,
uuid: Some(request_uuid),
dp_rank: Some(dp_rank),
};
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<OutputSignal>();
{
let mut active = self.active_requests.lock().await;
active.insert(request_uuid, request_tx);
}
// Send the request to the appropriate scheduler based on dp_rank
self.direct(direct_request, dp_rank as usize);
// Create a simple channel for the stream
let (stream_tx, stream_rx) = mpsc::channel::<LLMEngineOutput>(64);
let active_requests = self.active_requests.clone();
let async_context = ctx.context();
let max_tokens = request.stop_conditions.max_tokens.unwrap_or(100) as usize;
// Spawn a task to handle the complex async logic
tokio::spawn(async move {
let mut token_count = 0;
loop {
tokio::select! {
maybe_signal = request_rx.recv() => {
let Some(signal) = maybe_signal else {
let _ = stream_tx.send(LLMEngineOutput::error("All output transmitters closed".to_string())).await;
break;
};
if signal.completed && token_count < max_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string())).await;
break;
}
if signal.completed {
let _ = stream_tx.send(LLMEngineOutput::length()).await;
break;
}
// Generate a new token
let token_id = generate_random_token();
token_count += 1;
let output = LLMEngineOutput {
token_ids: vec![token_id],
tokens: None, // Let backend handle detokenization
text: None,
cum_log_probs: None,
log_probs: None,
finish_reason: None,
index: None,
};
if stream_tx.send(output).await.is_err() {
break;
}
}
_ = async_context.stopped() => {
let _ = stream_tx.send(LLMEngineOutput::cancelled()).await;
break;
}
}
}
// Clean up: remove this request from active requests
let mut active = active_requests.lock().await;
active.remove(&request_uuid);
});
// Create a simple ReceiverStream which is naturally Send + Sync
let stream = ReceiverStream::new(stream_rx);
Ok(ResponseStream::new(Box::pin(stream), ctx.context()))
}
}
pub struct AnnotatedMockEngine {
inner: Arc<MockVllmEngine>,
}
impl AnnotatedMockEngine {
pub fn new(
inner: MockVllmEngine,
distributed_runtime: DistributedRuntime,
endpoint: dynamo_runtime::protocols::Endpoint,
) -> Self {
let inner = Arc::new(inner);
let inner_clone = inner.clone();
// Start background task to wait for component service and start the engine
tokio::spawn(async move {
loop {
// Try to create component
let Ok(namespace) = distributed_runtime.namespace(&endpoint.namespace) else {
tracing::debug!("Namespace not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
let Ok(component) = namespace.component(&endpoint.component) else {
tracing::debug!("Component not available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
// Check if service is available by trying to list instances
let Ok(instances) = component.list_instances().await else {
tracing::debug!("Cannot list instances yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
};
if instances.is_empty() {
tracing::debug!("No instances available yet, retrying...");
tokio::time::sleep(Duration::from_millis(100)).await;
continue;
}
tracing::info!("Component service is now available, starting mocker engine");
// Start the engine with the component
if let Err(e) = inner_clone.start(component).await {
tracing::error!("Failed to start mocker engine: {e}");
}
break;
}
});
Self { inner }
}
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for AnnotatedMockEngine
{
async fn generate(
&self,
input: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
let stream = self.inner.generate(input).await?;
let context = stream.context();
// Convert stream of LLMEngineOutput to Annotated<LLMEngineOutput>
let annotated_stream = stream.map(Annotated::from_data);
Ok(ResponseStream::new(Box::pin(annotated_stream), context))
}
}
/// Create a mocker engine as ExecutionContext
pub async fn make_mocker_engine(
distributed_runtime: DistributedRuntime,
endpoint: dynamo_runtime::protocols::Endpoint,
args: MockEngineArgs,
) -> Result<crate::backend::ExecutionContext, Error> {
// Create the mocker engine
tracing::info!("Creating mocker engine (service will be started in background)");
let annotated_engine =
AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint);
Ok(Arc::new(annotated_engine))
}
#[cfg(test)]
mod integration_tests {
use super::*;
use crate::kv_router::indexer::RouterEvent;
use crate::kv_router::KV_EVENT_SUBJECT;
use crate::protocols::common::{SamplingOptions, StopConditions};
use dynamo_runtime::{
pipeline::Context,
pipeline::{network::Ingress, PushRouter},
traits::events::EventSubscriber,
DistributedRuntime, Worker,
};
use futures::StreamExt;
use tokio::time::timeout;
#[tokio::test]
#[ignore] // Run with: cargo test -- --ignored
async fn test_mock_vllm_engine_full_integration() -> Result<()> {
const DP_SIZE: u32 = 2;
const TOKENS_PER_REQUEST: usize = 20;
const BLOCK_SIZE: usize = 2;
// Create runtime and distributed runtime
let worker = Worker::from_settings()?;
let runtime = worker.runtime();
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
tracing::info!("✓ Runtime and distributed runtime created");
// Create component for MockVllmEngine (needed for publishers)
let test_component = distributed
.namespace("test")?
.component(MOCKER_COMPONENT)?
.service_builder()
.create()
.await?;
tracing::info!("✓ Test component created");
// Create MockVllmEngine WITH component (enables publishers)
let args = MockEngineArgs::builder()
.speedup_ratio(10.0)
.dp_size(DP_SIZE)
.block_size(BLOCK_SIZE)
.build()
.unwrap();
let engine = MockVllmEngine::new(args);
engine.start(test_component.clone()).await?;
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
let engine = Arc::new(engine);
tracing::info!("✓ MockVllmEngine created with DP_SIZE: {DP_SIZE}");
// Set up KV events subscriber
let mut kv_events_subscriber = test_component.subscribe(KV_EVENT_SUBJECT).await?;
tracing::info!("✓ KV events subscriber created");
// Wrap with Ingress and register with component/endpoint
let ingress = Ingress::for_engine(engine)?;
tracing::info!("✓ Ingress wrapper created");
// Start the server in background
let server_handle = tokio::spawn({
let test_component = test_component.clone();
async move {
if let Err(e) = test_component
.endpoint("generate")
.endpoint_builder()
.handler(ingress)
.start()
.await
{
eprintln!("❌ Generate endpoint failed: {e}");
}
}
});
tracing::info!("✓ Server started in background");
// Give server time to start
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
tracing::info!("✓ Server startup delay completed");
// Print all registered instances from etcd
match test_component.list_instances().await {
Ok(instances) => {
tracing::info!("📋 Found {} registered instances:", instances.len());
for instance in instances {
tracing::info!(
" • {}/{}/{} (ID: {})",
instance.namespace,
instance.component,
instance.endpoint,
instance.instance_id
);
}
}
Err(e) => {
tracing::error!("❌ Failed to list instances: {e}");
}
}
// Create client
let client = distributed
.namespace("test")?
.component(MOCKER_COMPONENT)?
.endpoint("generate")
.client()
.await?;
tracing::info!("✓ Client created");
let router = PushRouter::from_client(client, Default::default()).await?;
tracing::info!("✓ Router created");
// Create test requests for both DP workers
let create_request = |tokens: Vec<TokenIdType>, dp_rank: u32| PreprocessedRequest {
token_ids: tokens,
batch_token_ids: None,
stop_conditions: StopConditions {
max_tokens: Some(TOKENS_PER_REQUEST as u32),
..Default::default()
},
sampling_options: SamplingOptions::default(),
eos_token_ids: vec![],
mdc_sum: None,
annotations: vec![format!("dp_rank:{dp_rank}")],
estimated_prefix_hit_num_blocks: None,
};
let requests = vec![
create_request(vec![1, 2, 3, 4, 5], 0),
create_request(vec![1, 2, 3, 4, 5], 0),
create_request(vec![1, 2, 3, 4, 5], 1),
create_request(vec![1, 2, 3, 4, 5], 1),
];
tracing::info!(
"✓ Test requests created ({} requests total)",
requests.len()
);
// Test each request
for (i, request) in requests.into_iter().enumerate() {
tracing::info!("Testing request {}", i + 1);
let response_stream = router.generate(Context::new(request)).await?;
let responses: Vec<LLMEngineOutput> = response_stream.collect().await;
// Should have at least one response
assert!(
!responses.is_empty(),
"Request {} should produce at least one response",
i + 1
);
// Count total tokens generated (excluding final message)
let mut total_tokens = 0;
let mut has_finish_reason = false;
for response in &responses {
total_tokens += response.token_ids.len();
if response.finish_reason.is_some() {
has_finish_reason = true;
}
}
// Should have a finish reason in the last response
assert!(
has_finish_reason,
"Request {} should have a finish reason",
i + 1
);
// Verify we got approximately the expected number of tokens
assert!(
total_tokens <= TOKENS_PER_REQUEST + 1, // +1 for potential final empty response
"Request {} generated {} tokens, expected at most {}",
i + 1,
total_tokens,
TOKENS_PER_REQUEST + 1
);
tracing::info!(
"✓ Request {} completed successfully with {} tokens",
i + 1,
total_tokens
);
}
tracing::info!("🎉 All requests completed successfully!");
// Try to receive at least one KV event with 100ms timeout
tracing::info!("Waiting for KV event with 100ms timeout...");
let msg = timeout(Duration::from_millis(100), kv_events_subscriber.next())
.await
.map_err(|_| Error::msg("Timeout waiting for KV event"))?
.ok_or_else(|| Error::msg("KV events stream ended unexpectedly"))?;
match serde_json::from_slice::<RouterEvent>(&msg.payload) {
Ok(event) => {
tracing::info!("✓ Received KV event: {event:?}");
}
Err(e) => {
return Err(Error::msg(format!("Failed to deserialize KV event: {e}")));
}
}
// Use KvMetricsAggregator to get metrics more easily
let cancel_token = test_component.drt().runtime().child_token();
let metrics_aggregator = crate::kv_router::metrics_aggregator::KvMetricsAggregator::new(
test_component.clone(),
cancel_token,
)
.await;
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
let processed_endpoints = metrics_aggregator.get_endpoints();
tracing::info!(
"Found {} metrics endpoints",
processed_endpoints.endpoints.len()
);
// Verify we found at least one metrics endpoint
assert!(
!processed_endpoints.endpoints.is_empty(),
"Should find at least one metrics endpoint"
);
tracing::info!(
"✓ Successfully found {} metrics endpoints",
processed_endpoints.endpoints.len()
);
// Verify the metrics endpoints contain valid data
for (worker_id, endpoint) in &processed_endpoints.endpoints {
tracing::info!("✓ Worker {} metrics: {:?}", worker_id, endpoint.data);
}
tracing::info!("🎉 Event verification completed!");
// Cleanup
distributed.shutdown();
server_handle.await?;
Ok(())
}
}
......@@ -13,167 +13,158 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::cmp::Eq;
use std::collections::{HashMap, VecDeque};
use std::cmp::{Eq, Ordering};
use std::collections::{BTreeSet, HashMap};
use std::hash::Hash;
use std::time::Instant;
/// A wrapper for (T, counter) that implements Ord based only on counter
#[derive(Debug, Clone, Eq, PartialEq)]
struct PriorityItem<T> {
item: T,
counter: i64,
}
impl<T: Eq> Ord for PriorityItem<T> {
fn cmp(&self, other: &Self) -> Ordering {
self.counter.cmp(&other.counter)
}
}
impl<T: Eq> PartialOrd for PriorityItem<T> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
/// An LRU evictor that maintains objects and evicts them based on their
/// last accessed time. Implements a "lazy" eviction mechanism where:
/// 1. The priority queue does not immediately reflect updates or removes
/// 2. Objects are pushed to the queue in order of increasing priority (older objects first)
/// 3. The user must ensure objects are added in correct priority (temporal order)
/// 4. Remove and update operations are lazy - entries remain in the queue until
/// they are either evicted or cleaned up during maintenance
/// priority counter. Lower counter values are evicted first.
#[derive(Debug)]
pub struct LRUEvictor<T: Clone + Eq + Hash> {
free_table: HashMap<T, f64>,
priority_queue: VecDeque<(T, f64)>,
cleanup_threshold: usize,
start_time: Instant,
free_table: HashMap<T, i64>,
priority_queue: BTreeSet<PriorityItem<T>>,
positive_counter: i64,
negative_counter: i64,
}
impl<T: Clone + Eq + Hash> Default for LRUEvictor<T> {
fn default() -> Self {
Self {
free_table: HashMap::new(),
priority_queue: VecDeque::new(),
cleanup_threshold: 50,
start_time: Instant::now(),
priority_queue: BTreeSet::new(),
positive_counter: 0,
negative_counter: 0,
}
}
}
impl<T: Clone + Eq + Hash> LRUEvictor<T> {
/// Create a new LRUEvictor with the default cleanup threshold
pub fn new(cleanup_threshold: usize) -> Self {
Self {
cleanup_threshold,
..Default::default()
}
pub fn new(_cleanup_threshold: usize) -> Self {
Self::default()
}
/// Get the current timestamp as seconds since initialization
pub fn current_timestamp(&self) -> f64 {
self.start_time.elapsed().as_secs_f64()
pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, i64> {
self.free_table.keys()
}
/// Get an iterator over the keys in the evictor
pub fn keys(&self) -> std::collections::hash_map::Keys<'_, T, f64> {
self.free_table.keys()
fn update(&mut self, object: T, counter: i64) {
self.free_table.insert(object.clone(), counter);
self.priority_queue.insert(PriorityItem {
item: object,
counter,
});
}
/// Insert or update an object in the evictor with current timestamp
pub fn insert(&mut self, object: T) {
let timestamp = self.current_timestamp();
self._insert(object, timestamp);
// Remove old entry if it exists
if let Some(&old_counter) = self.free_table.get(&object) {
self.priority_queue.remove(&PriorityItem {
item: object.clone(),
counter: old_counter,
});
}
/// Check if the evictor contains the given object
pub fn contains(&self, object: &T) -> bool {
self.free_table.contains_key(object)
// Increment positive counter and insert
self.positive_counter += 1;
let counter = self.positive_counter;
self.update(object, counter);
}
/// Evict an object based on LRU policy
/// Returns the evicted object or None if no objects are available
pub fn evict(&mut self) -> Option<T> {
if self.free_table.is_empty() {
return None;
/// Push an object to the front with negative counter (highest priority for eviction)
pub fn push_front(&mut self, object: T) {
// Remove old entry if it exists
if let Some(&old_counter) = self.free_table.get(&object) {
self.priority_queue.remove(&PriorityItem {
item: object.clone(),
counter: old_counter,
});
}
while let Some((object, last_accessed)) = self.priority_queue.pop_front() {
let Some(&current_last_accessed) = self.free_table.get(&object) else {
continue; // entry is already removed
};
// Decrement negative counter and insert
self.negative_counter -= 1;
let counter = self.negative_counter;
if current_last_accessed == last_accessed {
self.free_table.remove(&object);
return Some(object);
} // otherwise entry is stale
self.update(object, counter);
}
None
pub fn contains(&self, object: &T) -> bool {
self.free_table.contains_key(object)
}
/// Insert or update an object in the evictor
fn _insert(&mut self, object: T, last_accessed: f64) {
self.free_table.insert(object.clone(), last_accessed);
self.priority_queue.push_back((object, last_accessed));
self.cleanup_if_necessary();
/// Evict an object based on LRU policy (lowest counter value)
/// Returns the evicted object or None if no objects are available
pub fn evict(&mut self) -> Option<T> {
self.priority_queue.pop_first().map(|item| {
self.free_table.remove(&item.item);
item.item
})
}
/// Remove an object from the evictor
/// We don't remove from the priority queue immediately, as that would be inefficient
/// Outdated entries will be filtered out during eviction or cleanup
pub fn remove(&mut self, object: &T) -> bool {
self.free_table.remove(object).is_some()
let Some(&counter) = self.free_table.get(object) else {
return false;
};
self.free_table.remove(object);
self.priority_queue.remove(&PriorityItem {
item: object.clone(),
counter,
});
true
}
/// Get the number of objects in the evictor
pub fn len(&self) -> usize {
self.free_table.len()
}
/// Check if the evictor is empty
pub fn is_empty(&self) -> bool {
self.free_table.is_empty()
}
/// Check if cleanup is necessary and perform it if needed
fn cleanup_if_necessary(&mut self) {
if self.priority_queue.len() > self.cleanup_threshold * self.free_table.len() {
self.cleanup();
}
}
/// Clean up the priority queue by removing outdated entries
fn cleanup(&mut self) {
let mut new_priority_queue = VecDeque::new();
for (object, timestamp) in self.priority_queue.drain(..) {
let Some(&current_timestamp) = self.free_table.get(&object) else {
continue;
};
if current_timestamp == timestamp {
new_priority_queue.push_back((object, timestamp));
}
}
self.priority_queue = new_priority_queue;
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
#[rstest]
#[case(1)]
#[case(2)]
#[case(3)]
fn test_lru_evictor_eviction_order(#[case] threshold: usize) {
// Create a new LRUEvictor with the given cleanup threshold
let mut evictor = LRUEvictor::<i32>::new(threshold);
#[test]
fn test_lru_evictor_eviction_order() {
// Create a new LRUEvictor
let mut evictor = LRUEvictor::<i32>::new(1); // threshold value doesn't matter anymore
// Add items in the specified order with small delays between each
// Add items in the specified order
evictor.insert(4);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(3);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(2);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(1);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(5);
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(1); // Updates timestamp for 1
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(4); // Updates timestamp for 4
std::thread::sleep(std::time::Duration::from_millis(1));
evictor.insert(2); // Updates timestamp for 2
evictor.insert(1); // Updates counter for 1
evictor.insert(4); // Updates counter for 4
evictor.insert(2); // Updates counter for 2
evictor.push_front(4);
// Verify the eviction order
println!("Testing with threshold {}", threshold);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 4);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 3);
let evicted = evictor.evict().unwrap();
......@@ -181,11 +172,11 @@ mod tests {
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 1);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 4);
let evicted = evictor.evict().unwrap();
assert_eq!(evicted, 2);
let evicted = evictor.evict();
assert_eq!(evicted, None);
assert_eq!(evictor.len(), 0);
}
// ... existing test_push_front test ...
}
......@@ -46,10 +46,11 @@
//! implementation of the main block manager.
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock};
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost, UniqueBlock};
use crate::mocker::sequence::ActiveSequence;
use derive_getters::Getters;
use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc;
#[derive(Getters)]
pub struct KvManager {
......@@ -57,17 +58,27 @@ pub struct KvManager {
max_capacity: usize,
#[getter(copy)]
block_size: u32,
block_size: usize,
active_blocks: HashMap<UniqueBlock, usize>,
inactive_blocks: LRUEvictor<UniqueBlock>,
all_blocks: HashSet<UniqueBlock>,
move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>,
}
impl KvManager {
pub fn new(max_capacity: usize, block_size: u32) -> Self {
pub fn new(max_capacity: usize, block_size: usize) -> Self {
Self::new_with_sender(max_capacity, block_size, None)
}
pub fn new_with_sender(
max_capacity: usize,
block_size: usize,
move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>,
) -> Self {
let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default();
let all_blocks = HashSet::new();
......@@ -78,18 +89,46 @@ impl KvManager {
active_blocks,
inactive_blocks,
all_blocks,
move_block_response_tx,
}
}
/// Utility method to send block responses with optional reversing
fn send_block_response(
&self,
mut blocks: Vec<u64>,
reverse: bool,
store: bool,
parent_hash: Option<u64>,
) {
if let Some(ref tx) = self.move_block_response_tx {
if !blocks.is_empty() {
if reverse {
blocks.reverse();
}
let response = if store {
MoveBlockResponse::Store(blocks, parent_hash)
} else {
MoveBlockResponse::Remove(blocks)
};
tx.send(response).unwrap();
}
}
}
/// Process a MoveBlock instruction synchronously
pub fn process(&mut self, event: &MoveBlock) -> bool {
match event {
MoveBlock::Use(hashes, _) => {
MoveBlock::Use(hashes) => {
let mut blocks_stored = Vec::<u64>::new();
let mut parent_block: Option<&UniqueBlock> = None;
for hash in hashes {
// First check if it already exists in active blocks
if let Some(ref_count) = self.active_blocks.get_mut(hash) {
// Block already active, just increment reference count
*ref_count += 1;
parent_block = Some(hash);
continue;
}
......@@ -97,6 +136,7 @@ impl KvManager {
if self.inactive_blocks.remove(hash) {
// Insert into active with reference count 1
self.active_blocks.insert(hash.clone(), 1);
parent_block = Some(hash);
continue;
}
......@@ -106,30 +146,53 @@ impl KvManager {
// If at max capacity, evict the oldest entry from inactive blocks
if active_count + inactive_count >= self.max_capacity {
if let Some(evicted) = self.inactive_blocks.evict() {
// Remove evicted block from all_blocks
self.all_blocks.remove(&evicted);
} else {
// Cannot evict block, meaning no free blocks left in inactive pool
// Send a signal, scheduler would expect to handle preemption upon receiving this
let Some(evicted) = self.inactive_blocks.evict() else {
return false;
};
self.all_blocks.remove(&evicted);
if let UniqueBlock::FullBlock(evicted_full_block) = evicted {
self.send_block_response(vec![evicted_full_block], false, false, None);
}
}
// Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1);
// Add to all_blocks as it's a new block
self.all_blocks.insert(hash.clone());
if self.move_block_response_tx.is_some() {
if let UniqueBlock::FullBlock(stored_full_block) = hash {
blocks_stored.push(*stored_full_block);
}
}
}
let parent_hash = match parent_block {
None => None,
Some(UniqueBlock::FullBlock(block)) => Some(*block),
Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"),
};
self.send_block_response(blocks_stored, false, true, parent_hash);
}
MoveBlock::Destroy(hashes) => {
let mut blocks_destroyed = Vec::<u64>::new();
// Loop in inverse direction
for hash in hashes.iter().rev() {
self.active_blocks.remove(hash).unwrap();
// Remove from all_blocks when destroyed
assert!(self.all_blocks.remove(hash));
// Track blocks for batch sending
if self.move_block_response_tx.is_some() {
if let UniqueBlock::FullBlock(destroyed_full_block) = hash {
blocks_destroyed.push(*destroyed_full_block);
}
}
}
self.send_block_response(blocks_destroyed, true, false, None);
}
MoveBlock::Deref(hashes) => {
// Loop in inverse direction
for hash in hashes.iter().rev() {
......@@ -149,15 +212,15 @@ impl KvManager {
}
}
}
MoveBlock::Promote(uuid, hash) => {
MoveBlock::Promote(uuid, hash, parent_hash) => {
let uuid_block = UniqueBlock::PartialBlock(*uuid);
let hash_block = UniqueBlock::FullBlock(*hash);
let Some(ref_count) = self.active_blocks.remove(&uuid_block) else {
let in_all_blocks = self.all_blocks.contains(&uuid_block);
panic!(
"Missing active block for promotion: {:?}. Block still exists: {}",
uuid_block, in_all_blocks
"Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}"
);
};
......@@ -167,6 +230,7 @@ impl KvManager {
// Update all_blocks
assert!(self.all_blocks.remove(&uuid_block));
self.all_blocks.insert(hash_block);
self.send_block_response(vec![*hash], false, true, *parent_hash);
}
}
......@@ -178,6 +242,7 @@ impl KvManager {
pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
blocks
.iter()
// .filter(|&block| !self.active_blocks.contains_key(block))
.filter(|&block| !self.all_blocks.contains(block))
.count()
}
......@@ -200,6 +265,11 @@ impl KvManager {
self.active_blocks.len()
}
/// Get the percentage of active blocks relative to maximum capacity
pub fn get_active_perc(&self) -> f64 {
self.active_blocks.len() as f64 / self.max_capacity as f64
}
/// Get the number of inactive blocks
pub fn num_inactive_blocks(&self) -> usize {
self.inactive_blocks.len()
......@@ -216,63 +286,28 @@ impl KvManager {
}
/// Check if a sequence can be scheduled and calculate cost if possible
pub fn try_schedule(
&self,
sequence: &ActiveSequence,
watermark: f64,
tokens_budget: usize,
) -> Option<PrefillCost> {
// Return None immediately if tokens_budget is 0
if tokens_budget == 0 {
return None;
}
// Get unique blocks from the sequence
let unique_blocks = sequence.unique_blocks();
// Get the count of new blocks
let new_blocks = self.probe_new_blocks(unique_blocks);
// Calculate current usage and available capacity
let active_count = self.active_blocks.len();
// Check if we can schedule based on the watermark
if (active_count + new_blocks) as f64 > (1.0 - watermark) * self.max_capacity as f64 {
return None;
}
// Calculate overlap blocks
let overlap_blocks = unique_blocks.len() - new_blocks;
// Calculate new tokens
let new_tokens = sequence.num_input_tokens() - overlap_blocks * (self.block_size as usize);
// // Print the full equation with actual values substituted
// println!("{} = {} - ({} * {}) (new_tokens = num_input_tokens - overlap_blocks * block_size)",
// new_tokens,
// sequence.num_input_tokens(),
// overlap_blocks,
// self.block_size);
// Return None if new_tokens exceeds tokens_budget
if new_tokens > tokens_budget {
return None;
}
pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost {
let seq_blocks = sequence.unique_blocks();
let new_blocks = self.probe_new_blocks(seq_blocks);
let overlap_blocks = seq_blocks.len() - new_blocks;
let new_tokens = sequence.num_input_tokens() - overlap_blocks * self.block_size;
// Calculate prefill compute
let prefill_compute =
new_tokens as f64 * (new_tokens + overlap_blocks * (self.block_size as usize)) as f64;
1.25e-6 * (new_tokens as f64).powi(2) + 7.41e-2 * (new_tokens as f64) + 2.62e1;
Some(PrefillCost {
PrefillCost {
new_blocks,
new_tokens,
prefill_compute,
})
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tokio::sync::mpsc;
#[test]
fn test_failure_on_max_capacity() {
......@@ -282,7 +317,7 @@ mod tests {
// Helper function to use multiple blocks that returns the response
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks, None))
manager.process(&MoveBlock::Use(blocks))
}
// First use 10 blocks (0 to 9) in a batch
......@@ -301,15 +336,17 @@ mod tests {
}
#[test]
// This is taken directly from the example in the vllm v1 prefix caching docs
fn test_block_lifecycle_stringent() {
// Create a KvManager with 10 blocks capacity
let mut manager = KvManager::new(10, 16);
// Create a channel to listen to block responses
let (tx, mut rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
// Create a KvManager with 10 blocks capacity and the response sender
let mut manager = KvManager::new_with_sender(10, 16, Some(tx));
// Helper function to use multiple blocks
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks, None));
manager.process(&MoveBlock::Use(blocks));
}
// Helper function to destroy multiple blocks
......@@ -324,6 +361,56 @@ mod tests {
manager.process(&MoveBlock::Deref(blocks));
}
// Helper function to assert block responses
fn assert_block_response(
rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
expected_type: &str,
expected_blocks: Vec<u64>,
description: &str,
) {
let response = rx
.try_recv()
.unwrap_or_else(|_| panic!("Expected {expected_type} response {description}"));
match (&response, expected_type) {
(MoveBlockResponse::Store(blocks, _parent_hash), "Store") => {
assert_eq!(
blocks.len(),
expected_blocks.len(),
"Expected {} blocks in Store response {}",
expected_blocks.len(),
description
);
assert_eq!(
*blocks, expected_blocks,
"Store blocks don't match expected {description}"
);
}
(MoveBlockResponse::Remove(blocks), "Remove") => {
assert_eq!(
blocks.len(),
expected_blocks.len(),
"Expected {} blocks in Remove response {}",
expected_blocks.len(),
description
);
assert_eq!(
*blocks, expected_blocks,
"Remove blocks don't match expected {description}"
);
}
_ => panic!("Expected {expected_type} response, got {response:?} {description}"),
}
}
// Helper function to assert no response is received
fn assert_no_response(
rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
description: &str,
) {
assert!(rx.try_recv().is_err(), "Expected no response {description}",);
}
// Helper function to check if active blocks contain expected blocks with expected ref counts
fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
assert_eq!(
......@@ -336,14 +423,12 @@ mod tests {
let block = UniqueBlock::FullBlock(id);
assert!(
manager.active_blocks().contains_key(&block),
"Block {} not found in active blocks",
id
"Block {id} not found in active blocks",
);
assert_eq!(
manager.active_blocks().get(&block),
Some(&ref_count),
"Block {} has wrong reference count",
id
"Block {id} has wrong reference count",
);
}
}
......@@ -366,17 +451,18 @@ mod tests {
let block = UniqueBlock::FullBlock(id);
assert!(
inactive_blocks.iter().any(|&b| *b == block),
"Block {} not found in inactive blocks",
id
"Block {id} not found in inactive blocks",
);
}
}
// First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks(&mut manager, (0..5).collect());
assert_block_response(&mut rx, "Store", vec![0, 1, 2, 3, 4], "after first use");
// Then use blocks 0, 1, 5, 6 in a batch
use_blocks(&mut manager, vec![0, 1, 5, 6]);
assert_block_response(&mut rx, "Store", vec![5, 6], "after second use");
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks(
......@@ -386,9 +472,11 @@ mod tests {
// Now destroy block 4
destroy_blocks(&mut manager, vec![4]);
assert_block_response(&mut rx, "Remove", vec![4], "after destroy block 4");
// And deref blocks 3, 2, 1, 0 in this order as a batch
deref_blocks(&mut manager, vec![0, 1, 2, 3]);
assert_no_response(&mut rx, "after deref operation");
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
assert_inactive_blocks(&manager, 2, &[3, 2]);
......@@ -396,6 +484,7 @@ mod tests {
// Now destroy block 6
destroy_blocks(&mut manager, vec![6]);
assert_block_response(&mut rx, "Remove", vec![6], "after block 6 eviction");
// And deref blocks 5, 1, 0 as a batch
deref_blocks(&mut manager, vec![0, 1, 5]);
......@@ -406,6 +495,7 @@ mod tests {
// Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);
assert_block_response(&mut rx, "Store", vec![7, 8, 9], "after [7, 8, 9] use");
// Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks(&manager, 2, &[3, 5]);
......@@ -420,8 +510,14 @@ mod tests {
// Now use blocks 10, 11, 12 as a batch
use_blocks(&mut manager, vec![10, 11, 12]);
assert_block_response(&mut rx, "Remove", vec![3], "after block 5 eviction");
assert_block_response(&mut rx, "Store", vec![10, 11, 12], "after [10, 11, 12] use");
// Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks(&manager, 1, &[5]);
use_blocks(&mut manager, vec![13]);
assert_block_response(&mut rx, "Remove", vec![5], "after block 5 eviction");
assert_block_response(&mut rx, "Store", vec![13], "after block 13 use");
}
}
......@@ -13,12 +13,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
pub type Token = u32;
pub type LocalBlockHash = u64;
/// A global hash identifier for blocks
pub type GlobalHash = u64;
pub type NumBlocks = usize;
......@@ -39,12 +43,19 @@ impl Default for UniqueBlock {
}
/// Represents different block movement operations in the cache
/// For Use and Promote variants, parent hash is the second field
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
Use(Vec<UniqueBlock>, Option<f64>),
Use(Vec<UniqueBlock>),
Destroy(Vec<UniqueBlock>),
Deref(Vec<UniqueBlock>),
Promote(Uuid, GlobalHash),
Promote(Uuid, GlobalHash, Option<u64>),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
Store(Vec<GlobalHash>, Option<u64>),
Remove(Vec<GlobalHash>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -52,15 +63,86 @@ pub struct DirectRequest {
pub tokens: Vec<Token>,
pub max_output_tokens: usize,
pub uuid: Option<Uuid>,
pub dp_rank: Option<u32>,
}
/// Represents the cost of prefilling content in the cache
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefillCost {
pub new_blocks: usize,
pub new_tokens: usize,
pub prefill_compute: f64,
}
/// Signal for output token generation with completion status
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputSignal {
pub uuid: Uuid,
pub completed: bool,
}
/// Configuration arguments for MockVllmEngine
#[derive(Debug, Clone, Serialize, Deserialize, Builder)]
#[builder(pattern = "owned", build_fn(public))]
pub struct MockEngineArgs {
#[builder(default = "16384")]
pub num_gpu_blocks: usize,
#[builder(default = "64")]
pub block_size: usize,
// This was 1024 in the past but reverted back to 256
#[builder(default = Some(256))]
pub max_num_seqs: Option<usize>,
// default for open api server, for llm class it's 16384
#[builder(default = Some(8192))]
pub max_num_batched_tokens: Option<usize>,
#[builder(default = true)]
pub enable_prefix_caching: bool,
#[builder(default = "0.01")]
pub watermark: f64,
#[builder(default = "1.0")]
pub speedup_ratio: f64,
#[builder(default = "1")]
pub dp_size: u32,
}
impl MockEngineArgs {
pub fn builder() -> MockEngineArgsBuilder {
MockEngineArgsBuilder::default()
}
}
/// Note: This assumes block_hash and tokens_hash are the same, which is not correct in rare cases
/// where the sequence-aware hash differs from the token content hash.
pub fn block_response_to_kv_event(response: MoveBlockResponse) -> KvCacheEventData {
match response {
MoveBlockResponse::Store(full_blocks, parent_hash) => {
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: full_blocks
.into_iter()
.map(|block| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block),
tokens_hash: LocalBlockHash(block),
})
.collect(),
})
}
MoveBlockResponse::Remove(full_blocks) => KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: full_blocks
.into_iter()
.map(ExternalSequenceBlockHash)
.collect(),
}),
}
}
#[cfg(test)]
mod tests {
use super::*;
......
......@@ -40,11 +40,13 @@
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::kv_router::protocols::ForwardPassMetrics;
use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData};
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager;
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MoveBlock, PrefillCost, UniqueBlock};
use crate::mocker::protocols::{
block_response_to_kv_event, MoveBlock, OutputSignal, PrefillCost, UniqueBlock,
};
use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse};
use crate::mocker::sequence::ActiveSequence;
use std::collections::HashMap;
use std::collections::VecDeque;
......@@ -63,8 +65,8 @@ pub enum Request {
#[derive(Default)]
struct SchedulerState {
waiting: VecDeque<Uuid>,
ready: VecDeque<Uuid>,
running: LRUEvictor<Uuid>,
prefill: VecDeque<Uuid>,
decode: LRUEvictor<Uuid>,
requests: HashMap<Uuid, Request>,
prefill_costs: HashMap<Uuid, Option<PrefillCost>>,
}
......@@ -74,61 +76,70 @@ impl SchedulerState {
fn receive(&mut self, request: DirectRequest) -> Uuid {
// Use the provided UUID if available, otherwise generate a new one
let uuid = request.uuid.unwrap_or_else(Uuid::new_v4);
// Add the request to the map and waiting queue
self.requests.insert(uuid, Request::Direct(request));
self.waiting.push_back(uuid);
uuid
}
/// Get the next UUID from ready or waiting queue and its associated Request.
/// Returns from ready if not empty, otherwise from waiting, or None if both are empty.
/// Also removes the Request from the requests HashMap.
fn next(&mut self) -> Option<(Uuid, Request)> {
let uuid = self
.ready
.pop_front()
.or_else(|| self.waiting.pop_front())?;
let request = self.requests.remove(&uuid)?;
let uuid = self.waiting.pop_front()?;
let request = self
.requests
.remove(&uuid)
.expect("Request does not exist.");
Some((uuid, request))
}
/// Move a UUID and its Request to the waiting queue (front).
fn first_in_line(&mut self, uuid: Uuid, request: Request) {
self.requests.insert(uuid, request);
self.waiting.push_front(uuid);
}
/// Move a UUID and its Request to the ready queue.
fn make_ready(&mut self, uuid: Uuid, active_seq: ActiveSequence) {
fn start_prefill(&mut self, uuid: Uuid, active_seq: ActiveSequence, cost: Option<PrefillCost>) {
self.requests.insert(uuid, Request::Active(active_seq));
self.ready.push_back(uuid);
self.prefill.push_back(uuid);
self.prefill_costs.insert(uuid, cost);
}
/// Schedule the request with the given UUID.
/// Returns the creation signal from the ActiveSequence.
fn run(&mut self, uuid: Uuid, active_seq: ActiveSequence) -> MoveBlock {
// Insert the request into the map
self.requests.insert(uuid, Request::Active(active_seq));
/// Pop from prefill queue and move to decode queue.
/// Returns the prefill_compute value if available.
fn start_decode(&mut self) -> Option<(f64, MoveBlock)> {
let uuid = self.prefill.pop_front()?;
self.decode.insert(uuid);
// Remove and extract prefill_compute from prefill_costs
let prefill_cost = self
.prefill_costs
.remove(&uuid)
.flatten()
.expect("Expects valid prefill cost.");
// Get the creation signal
let Some(Request::Active(sequence)) = self.requests.get(&uuid) else {
panic!("Failed to get ActiveSequence for UUID");
};
let Some(signal) = sequence.creation_signal() else {
panic!("Failed to get creation signal from ActiveSequence");
panic!("Request does not exist.");
};
let creation_signal = sequence
.creation_signal()
.clone()
.expect("Must have creation signal.");
// Add to running requests
self.running.insert(uuid);
signal.clone()
Some((prefill_cost.prefill_compute, creation_signal))
}
/// Set the prefill cost for a UUID
fn set_prefill_cost(&mut self, uuid: Uuid, cost: Option<PrefillCost>) {
self.prefill_costs.insert(uuid, cost);
fn run(&mut self, uuid: Uuid) -> Option<&mut ActiveSequence> {
if !self.decode.contains(&uuid) {
return None;
}
let Some(Request::Active(sequence)) = self.requests.get_mut(&uuid) else {
panic!("Request does not exist.");
};
Some(sequence)
}
/// Get the prefill compute value for a UUID if available
fn get_prefill_compute(&self, uuid: &Uuid) -> Option<f64> {
self.prefill_costs
.get(uuid)
.and_then(|cost| cost.as_ref())
.map(|cost| cost.prefill_compute)
fn num_active_requests(&self) -> usize {
self.prefill.len() + self.decode.len()
}
/// Calculate the current running batched tokens
......@@ -145,7 +156,7 @@ impl SchedulerState {
/// Remove a UUID and its associated Request from collections.
fn complete(&mut self, uuid: &Uuid) {
// println!("Request {} will complete", uuid);
self.running.remove(uuid);
self.decode.remove(uuid);
self.requests.remove(uuid);
self.prefill_costs.remove(uuid);
}
......@@ -153,76 +164,93 @@ impl SchedulerState {
/// Preempt the oldest running request by evicting it from running, resetting the sequence,
/// and adding it back to the waiting queue.
/// Returns the signal from reset_with_signal or None if no requests are running.
fn preempt(&mut self) -> Option<Vec<MoveBlock>> {
fn preempt(&mut self) -> Vec<MoveBlock> {
// Evict the oldest UUID from running
let uuid = self.running.evict()?;
eprintln!("Request {} will be preempted", uuid);
// Remove the request from the requests HashMap and ensure it's an ActiveSequence
let request = self.requests.remove(&uuid)?;
// Remove the prefill cost to force recomputation
let uuid = self
.decode
.evict()
.expect("Nothing to evict for preemption.");
let request = self
.requests
.remove(&uuid)
.expect("Request does not exist.");
self.prefill_costs.remove(&uuid);
eprintln!("Request {uuid} will be preempted");
// Extract the ActiveSequence from the Request enum
// Reset the sequence and get the new sequence and signal
// Insert the new sequence back into the requests map and add to waiting queue
let Request::Active(mut active_sequence) = request else {
panic!("Expected ActiveSequence in running queue")
};
// Reset the sequence and get the new sequence and signal
let signals = active_sequence.reset_with_signal();
// Insert the new sequence back into the requests map and add to waiting queue
self.requests.insert(uuid, Request::Active(active_sequence));
self.waiting.push_back(uuid);
// Note: For preemption, we don't compute hit rate since we don't have access to new_tokens
// and the sequence is being reset anyway. Hit rate tracking is primarily for new scheduling attempts.
self.first_in_line(uuid, Request::Active(active_sequence));
Some(signals)
signals
}
}
/// Manages scheduling of requests using KvManager resources
#[derive(Clone)]
pub struct Scheduler {
dp_rank: Option<u32>,
state: Arc<Mutex<SchedulerState>>,
kv_manager: Arc<Mutex<KvManager>>,
request_tx: mpsc::Sender<DirectRequest>,
request_tx: mpsc::UnboundedSender<DirectRequest>,
hit_rates: Arc<Mutex<VecDeque<f32>>>,
}
impl Scheduler {
/// Create a new Scheduler with the given parameters
pub fn new(
kv_capacity: usize,
watermark: f64,
block_size: u32,
chunk_size: Option<usize>,
output_tx: Option<mpsc::Sender<Uuid>>,
args: MockEngineArgs,
dp_rank: Option<u32>,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
cancellation_token: Option<CancellationToken>,
) -> Self {
// Create KvManager internally
let kv_manager = KvManager::new(kv_capacity, block_size);
let token_capacity: usize = 8192;
let state = Arc::new(Mutex::new(SchedulerState::default()));
let kv_manager = Arc::new(Mutex::new(kv_manager));
let chunk_size = chunk_size.unwrap_or(256);
// Create internal channel for KV events only if needed
let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() {
let (tx, rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
(Some(tx), Some(rx))
} else {
(None, None)
};
let kv_manager = Arc::new(Mutex::new(KvManager::new_with_sender(
args.num_gpu_blocks,
args.block_size,
block_resp_tx,
)));
let hit_rates = Arc::new(Mutex::new(VecDeque::with_capacity(1000)));
// Create channel for request handling
let (request_tx, mut request_rx) = mpsc::channel::<DirectRequest>(1024);
// Assert speedup_ratio is greater than 0
assert!(
args.speedup_ratio > 0.0,
"speedup_ratio must be greater than 0, got: {}",
args.speedup_ratio
);
// Use provided cancellation token or create new one
let cancellation_token = cancellation_token.unwrap_or_default();
let token_clone = cancellation_token.clone();
// Create channel for request handling
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<DirectRequest>();
// Create a clone for the background task
let state_clone = state.clone();
let kv_manager_clone = kv_manager.clone();
let output_tx_clone = output_tx.clone();
let cancel_token_clone = cancellation_token.unwrap_or_default().clone();
let hit_rates_clone = hit_rates.clone();
// Spawn main background task with cancellation token
tokio::spawn(async move {
let mut schedule_interval = interval(Duration::from_millis(5));
let mut simulate_interval = interval(Duration::from_millis(1));
let mut schedule_interval = interval(Duration::from_secs_f64(1e-3));
let mut simulate_interval = interval(Duration::from_secs_f64(1e-4));
let mut should_schedule = true;
loop {
tokio::select! {
......@@ -234,35 +262,63 @@ impl Scheduler {
state.receive(request);
}
// Try Scheduling Requests
// Try Scheduling Requests - runs on normal interval or after simulation
_ = schedule_interval.tick() => {
// Skip if we just ran scheduling after simulation to prevent consecutive runs
if !should_schedule {
continue;
}
let mut state_guard = state_clone.lock().await;
let mut kv_manager_guard = kv_manager_clone.lock().await;
let kv_manager_guard = kv_manager_clone.lock().await;
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore.
let mut current_blocks = kv_manager_guard.num_active_blocks();
let mut current_tokens = state_guard.num_batched_tokens();
let mut current_seqs = state_guard.num_active_requests();
while let Some((uuid, request)) = state_guard.next() {
let active_sequence = get_active_sequence(request, block_size, chunk_size);
let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching);
// Calculate token budget using new_tokens from PrefillCost
let total_prefill_tokens = state_guard.num_batched_tokens();
let tokens_budget = token_capacity.saturating_sub(total_prefill_tokens);
// Update predictive budgets
let prefill_cost = kv_manager_guard.get_prefill_cost(&active_sequence);
let total_tokens = active_sequence.len();
let new_blocks = (total_tokens + 1) / args.block_size; // this is conservative, assumes no cache hit
let new_tokens = prefill_cost.new_tokens;
current_blocks += new_blocks;
current_tokens += new_tokens;
current_seqs += 1;
// Check if it can be scheduled
let Some(prefill_cost) = kv_manager_guard.try_schedule(&active_sequence, watermark, tokens_budget) else {
state_guard.make_ready(uuid, active_sequence);
let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager_guard.max_capacity() as f64;
let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| current_tokens <= limit);
let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
// Cannot schedule, put first in line instead
if !(under_block_budget && under_token_budget && under_seq_budget) {
state_guard.first_in_line(uuid, Request::Active(active_sequence));
break;
};
}
// Get creation signal and schedule the request
let signal = state_guard.run(uuid, active_sequence);
kv_manager_guard.process(&signal);
state_guard.set_prefill_cost(uuid, Some(prefill_cost));
// Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
{
let mut hit_rates_guard = hit_rates_clone.lock().await;
hit_rates_guard.push_back(hit_rate);
if hit_rates_guard.len() > 1000 {
hit_rates_guard.pop_front();
}
}
state_guard.start_prefill(uuid, active_sequence, Some(prefill_cost));
should_schedule = false;
}
}
// Check for cancellation
_ = token_clone.cancelled() => {
_ = cancel_token_clone.cancelled() => {
break;
}
......@@ -271,75 +327,84 @@ impl Scheduler {
let mut state_guard = state_clone.lock().await;
let mut kv_manager_guard = kv_manager_clone.lock().await;
// Base time needed for decoding (assumed memory bound on KV cache)
let active_tokens = kv_manager_guard.num_active_blocks() * (block_size as usize);
// TODO: 2 is a dummy / magic scaling factor
let mut generation_time = Duration::from_micros((active_tokens / 2) as u64);
// Base time needed for decoding using active percentage and quadratic formula
let active_perc = kv_manager_guard.get_active_perc();
let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
// Process each running request
let uuids: Vec<Uuid> = state_guard.running.keys().cloned().collect();
for uuid in uuids {
// Check if UUID is still in running_requests, if not skip this iteration
if !state_guard.running.contains(&uuid) {
continue;
// Process prefilling
while let Some((prefill_compute, creation_signal)) = state_guard.start_decode() {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
let prefill_success = process_signals(&mut kv_manager_guard, std::slice::from_ref(&creation_signal));
if !prefill_success {
panic!("Block allocation for prefilling cannot fail.");
}
// Get prefill compute value first
let prefill_compute = state_guard.get_prefill_compute(&uuid);
// Get the active sequence for this UUID
let sequence = state_guard.requests.get_mut(&uuid)
.and_then(|req| if let Request::Active(seq) = req { Some(seq) } else { None })
.expect("UUID in running_requests must have a corresponding active sequence");
// Drain KV events and forward to relay after prefill signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
}
}
}
// Generate token and get signals
// Process decoding
let uuids: Vec<Uuid> = state_guard.decode.keys().cloned().collect();
if !uuids.is_empty() {should_schedule = true};
for uuid in uuids {
let Some(sequence) = state_guard.run(uuid) else {
continue;
};
let signals = sequence.generate();
// Accumulate sleep duration based on prefill_compute if available
// prefill compute = (cached_tokens + new_tokens) * new_tokens
let sleep_ms = if let Some(compute) = prefill_compute {
// TODO: 1024 is a dummy / magic scaling factor
(compute / 1024.0) as u64
} else { 0 };
generation_time += Duration::from_micros(sleep_ms);
// Process all signals with the KvManager
// Handling of preemption on failure
if !process_signals(&mut kv_manager_guard, &signals) {
sequence.pop(); // revert the failed generation op
// free_signal derefs the preempted blocks
let Some(free_signal) = state_guard.preempt() else {
panic!("Failed to acquire signal to free KV blocks from preemption");
};
for signal in free_signal {
for signal in state_guard.preempt() {
kv_manager_guard.process(&signal);
}
continue;
}
// Send UUID notification for each generated token
// TODO: hook this up to an AsyncEngine
if let Some(tx) = &output_tx_clone {
let _ = tx.try_send(uuid);
// Drain KV events and forward to relay after decode signal processing
if let (Some(ref relay_tx), Some(ref mut rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
}
}
// Check if we're done after generating
if sequence.generated_tokens() >= sequence.max_output_tokens() {
state_guard.complete(&uuid);
continue;
// Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output = sequence.generated_tokens() > sequence.already_generated_tokens();
let mut send_failed = false;
if should_output {
send_failed = output_tx_clone.as_ref().is_some_and(|tx| {
tx.send(OutputSignal { uuid, completed: is_complete }).is_err()
});
}
if send_failed {
for signal in &sequence.free_signal() {
kv_manager_guard.process(signal);
}
}
// Transition to decode (no prefill cost)
if sequence.generated_tokens() == 1 {
state_guard.set_prefill_cost(uuid, None);
if send_failed || is_complete {
state_guard.complete(&uuid);
continue;
}
}
// Sleep once for the accumulated duration
if generation_time.as_millis() > 0 {
tokio::time::sleep(generation_time).await;
// Sleep once for the adjusted duration
drop(kv_manager_guard);
drop(state_guard);
let adjusted_time = Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
if adjusted_time.as_millis() > 0 {
tokio::time::sleep(adjusted_time).await;
}
}
}
......@@ -347,15 +412,22 @@ impl Scheduler {
});
Self {
dp_rank,
state,
kv_manager,
request_tx,
hit_rates,
}
}
/// Add a new request to the waiting queue
pub async fn receive(&self, request: DirectRequest) {
let _ = self.request_tx.send(request).await;
let _ = self.request_tx.send(request);
}
/// Expose the sender
pub fn request_sender(&self) -> mpsc::UnboundedSender<DirectRequest> {
self.request_tx.clone()
}
/// Get the count of waiting requests
......@@ -367,7 +439,7 @@ impl Scheduler {
/// Get the count of running requests
pub async fn running_count(&self) -> usize {
let state = self.state.lock().await;
state.running.len()
state.decode.len()
}
/// Get the current capacity of the KvManager
......@@ -378,35 +450,53 @@ impl Scheduler {
/// Returns forward pass metrics for monitoring purposes
pub async fn get_forward_pass_metrics(&self) -> ForwardPassMetrics {
// Acquire all locks in consistent order: state -> kv_manager -> hit_rates
let state = self.state.lock().await;
let kv_manager = self.kv_manager.lock().await;
let hit_rates_guard = self.hit_rates.lock().await;
// Get the active blocks and total capacity from KvManager
// Get state metrics
let request_active_slots = state.decode.len() as u64;
let num_requests_waiting = state.waiting.len() as u64;
// Get KV manager metrics
let active_blocks_count = kv_manager.active_blocks().len() as u64;
let total_capacity = kv_manager.max_capacity() as u64;
// Calculate GPU cache usage percentage
let gpu_cache_usage_perc = if total_capacity > 0 {
active_blocks_count as f32 / total_capacity as f32
} else {
0.0
};
// Get hit rate metrics
let gpu_prefix_cache_hit_rate = if hit_rates_guard.is_empty() {
0.0
} else {
let sum: f32 = hit_rates_guard.iter().sum();
sum / hit_rates_guard.len() as f32
};
ForwardPassMetrics {
data_parallel_rank: None, // Default for backwards compatibility
request_active_slots: state.running.len() as u64,
request_total_slots: 420, // Dummy value as specified
data_parallel_rank: self.dp_rank,
request_active_slots,
// vllm max_num_seqs for gpu >= 70 vram, otherwise 256, fallback is 128
request_total_slots: 1024,
kv_active_blocks: active_blocks_count,
kv_total_blocks: total_capacity,
num_requests_waiting: state.waiting.len() as u64,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate: 0.0, // Placeholder value as specified
gpu_prefix_cache_hit_rate,
}
// Guards drop naturally here in reverse order (LIFO): hit_rates_guard, kv_manager, state
}
}
/// Convert a Request to an ActiveSequence
fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) -> ActiveSequence {
fn get_active_sequence(
request: Request,
block_size: usize,
enable_prefix_caching: bool,
) -> ActiveSequence {
if let Request::Active(active_seq) = request {
return active_seq;
}
......@@ -419,7 +509,7 @@ fn get_active_sequence(request: Request, block_size: u32, chunk_size: usize) ->
direct_request.tokens,
direct_request.max_output_tokens,
Some(block_size),
Some(chunk_size),
enable_prefix_caching,
)
}
......@@ -440,7 +530,7 @@ fn process_signals(
}
// Check we have a Use signal with blocks
let MoveBlock::Use(blocks, _) = signal else {
let MoveBlock::Use(blocks) = signal else {
panic!("Failed signal is Invalid. Has to fail on generation signal.");
};
......@@ -467,32 +557,37 @@ mod tests {
use std::time::Duration;
#[rstest]
#[case::random(false)]
#[case::caching(true)]
#[case::random_no_prefix_caching(false, false)]
#[case::random_with_prefix_caching(false, true)]
#[case::caching_no_prefix_caching(true, false)]
#[case::caching_with_prefix_caching(true, true)]
#[tokio::test]
async fn test_scheduler_token_generation_patterns(#[case] use_shared_tokens: bool) {
async fn test_scheduler_token_generation_patterns(
#[case] use_shared_tokens: bool,
#[case] enable_prefix_caching: bool,
) {
std::env::set_var("RUST_LOG", "debug");
let kv_capacity: usize = 500;
let watermark: f64 = 0.01; // 1% watermark
let block_size: u32 = 64;
let chunk_size: usize = 256;
let block_size: usize = 64;
let num_requests: usize = 100;
let input_len: usize = 1000;
let max_output_tokens: usize = 100;
// Create channel for token output
let (output_tx, mut output_rx) = mpsc::channel::<Uuid>(1024);
// Create scheduler with internal KvManager
let scheduler = Scheduler::new(
kv_capacity,
watermark,
block_size,
Some(chunk_size),
Some(output_tx),
None,
);
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create scheduler args using builder - now including enable_prefix_caching
let args = MockEngineArgs::builder()
.num_gpu_blocks(kv_capacity)
.block_size(block_size)
.speedup_ratio(10.0)
.enable_prefix_caching(enable_prefix_caching)
.build()
.unwrap();
// Create scheduler with new args struct
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);
// Create shared tokens for caching case
let shared_tokens = if use_shared_tokens {
......@@ -523,6 +618,7 @@ mod tests {
tokens: input_tokens,
max_output_tokens,
uuid: None,
dp_rank: None,
};
scheduler.receive(request).await;
}
......@@ -547,7 +643,7 @@ mod tests {
// Manual debug ticker that prints forward pass metrics
_ = debug_interval.tick() => {
let _metrics = scheduler.get_forward_pass_metrics().await;
// println!("Forward Pass Metrics: {:#?}", _metrics);
println!("Forward Pass Metrics: {_metrics:#?}");
}
Some(_) = output_rx.recv() => {
......@@ -566,21 +662,177 @@ mod tests {
// Calculate and print elapsed time
let elapsed = start_time.elapsed();
println!(
"Test completed in: {:?} for {} case",
"Test completed in: {:?} for {} case with prefix_caching={}",
elapsed,
if use_shared_tokens {
"caching"
} else {
"random"
}
},
enable_prefix_caching
);
// Assert that we received the expected number of tokens
assert!(
received_tokens > expected_tokens,
"Received {} tokens but expected more than {}",
received_tokens,
expected_tokens
received_tokens == expected_tokens,
"Received {received_tokens} tokens but expected exactly {expected_tokens}"
);
}
#[tokio::test]
async fn test_cache_hit_rate_with_identical_requests() {
let block_size: usize = 64;
let max_output_tokens: usize = 10;
let speedup_ratio = 10.0;
let num_requests = 10;
let token_length = 65;
// Create channel for token output
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create scheduler args
let args = MockEngineArgs::builder()
.num_gpu_blocks(100) // Large enough to not be a constraint
.block_size(block_size)
.speedup_ratio(speedup_ratio)
.build()
.unwrap();
// Create scheduler
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);
// Create identical tokens for all requests
let identical_tokens: Vec<u32> = (0..token_length).map(|i| i as u32).collect();
// Send all requests with identical tokens
for _ in 0..num_requests {
let request = DirectRequest {
tokens: identical_tokens.clone(),
max_output_tokens,
uuid: None,
dp_rank: None,
};
scheduler.receive(request).await;
// Sleep for 0.1 second after each request
tokio::time::sleep(Duration::from_millis(100)).await;
}
// Collect all generated tokens
let mut received_tokens = 0;
// Set up a timeout that resets to 0.5 seconds on each received token
let timeout = tokio::time::sleep(Duration::from_millis(500));
tokio::pin!(timeout);
// Set up debug ticker interval
let mut debug_interval = interval(Duration::from_millis(500));
loop {
tokio::select! {
biased;
// Manual debug ticker that prints forward pass metrics
_ = debug_interval.tick() => {
let _metrics = scheduler.get_forward_pass_metrics().await;
println!("Forward Pass Metrics: {_metrics:#?}");
}
Some(_signal) = output_rx.recv() => {
received_tokens += 1;
// Reset timeout whenever we receive a token
timeout.set(tokio::time::sleep(Duration::from_millis(500)));
}
_ = &mut timeout => {
// Break when timeout occurs (no more tokens for 0.5 seconds)
break;
}
}
}
// Verify forward pass metrics
let metrics = scheduler.get_forward_pass_metrics().await;
assert_eq!(
metrics.num_requests_waiting, 0,
"Expected no waiting requests, got {}",
metrics.num_requests_waiting
);
assert!(
metrics.gpu_prefix_cache_hit_rate > 0.8,
"Expected cache hit rate > 0.8, got {}",
metrics.gpu_prefix_cache_hit_rate
);
println!(
"Test passed! Cache hit rate: {:.3}",
metrics.gpu_prefix_cache_hit_rate
);
println!("Received {received_tokens} tokens");
}
#[tokio::test]
async fn test_receiver_drop_cleans_up_resources() {
let block_size: usize = 64;
let input_tokens = 256;
let max_output_tokens = 200; // More than we'll receive
// Create channel for token output
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create scheduler args
let args = MockEngineArgs::builder()
.num_gpu_blocks(10) // Enough for 256 tokens (4 blocks)
.block_size(block_size)
.speedup_ratio(100.0) // Fast simulation
.build()
.unwrap();
// Create scheduler
let scheduler = Scheduler::new(args, None, Some(output_tx), None, None);
// Create request with 256 tokens
let tokens: Vec<u32> = (0..input_tokens).map(|i| i as u32).collect();
let request = DirectRequest {
tokens,
max_output_tokens,
uuid: None,
dp_rank: None,
};
scheduler.receive(request).await;
// Receive exactly 129 tokens
let mut received_count = 0;
while received_count < 129 {
if let Some(_signal) = output_rx.recv().await {
received_count += 1;
} else {
panic!("Channel closed before receiving 129 tokens");
}
}
// Drop the receiver immediately
drop(output_rx);
// Wait for 1 second to allow cleanup
tokio::time::sleep(Duration::from_secs(1)).await;
// Check forward pass metrics
let metrics = scheduler.get_forward_pass_metrics().await;
assert_eq!(
metrics.gpu_cache_usage_perc,
0.0,
"Expected GPU cache usage to be 0%, got {}%",
metrics.gpu_cache_usage_perc * 100.0
);
assert_eq!(
metrics.kv_active_blocks, 0,
"Expected 0 active blocks, got {}",
metrics.kv_active_blocks
);
}
}
......@@ -23,16 +23,23 @@ use uuid;
fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: u32,
block_size: usize,
enable_prefix_caching: bool,
) -> Vec<UniqueBlock> {
let mut unique_blocks: Vec<UniqueBlock> = tokens
.blocks()
.iter()
.map(|block| UniqueBlock::FullBlock(block.sequence_hash()))
.map(|block| {
if enable_prefix_caching {
UniqueBlock::FullBlock(block.sequence_hash())
} else {
UniqueBlock::FullBlock(random::<u64>())
}
})
.collect();
// Only push the partial block if tokens count isn't a multiple of block_size
if tokens.total_tokens() % (block_size as usize) != 0 {
if tokens.total_tokens() % block_size != 0 {
unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
......@@ -50,10 +57,7 @@ pub struct ActiveSequence {
tokens: TokenBlockSequence,
#[getter(copy)]
block_size: u32,
#[getter(copy)]
chunk_size: usize, // TODO: not actually used
block_size: usize,
#[getter(copy)]
max_output_tokens: usize,
......@@ -61,10 +65,16 @@ pub struct ActiveSequence {
#[getter(copy)]
generated_tokens: usize,
#[getter(copy)]
already_generated_tokens: usize,
#[getter(copy)]
num_input_tokens: usize,
creation_signal: Option<MoveBlock>,
#[getter(copy)]
enable_prefix_caching: bool,
}
impl ActiveSequence {
......@@ -72,32 +82,33 @@ impl ActiveSequence {
pub fn new(
tokens: Vec<u32>,
max_output_tokens: usize,
block_size: Option<u32>,
chunk_size: Option<usize>,
block_size: Option<usize>,
enable_prefix_caching: bool,
) -> Self {
let block_size = block_size.unwrap_or(64);
assert!(block_size > 1, "block_size must be greater than 1");
let chunk_size = chunk_size.unwrap_or(256);
let num_input_tokens = tokens.len();
let tokens = Tokens::from(tokens).into_sequence(block_size, None);
let unique_blocks = create_unique_blocks_from_sequence(&tokens, None, block_size);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone(), None));
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, None);
let unique_blocks =
create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone()));
Self {
unique_blocks,
tokens,
block_size,
chunk_size,
max_output_tokens,
generated_tokens: 0,
already_generated_tokens: 0,
num_input_tokens,
creation_signal,
enable_prefix_caching,
}
}
pub fn extra_tokens(&self) -> u32 {
(self.len() % self.block_size as usize) as u32
(self.len() % self.block_size) as u32
}
pub fn len(&self) -> usize {
......@@ -112,20 +123,31 @@ impl ActiveSequence {
pub fn new_with_signal(
tokens: Vec<u32>,
max_output_tokens: usize,
block_size: Option<u32>,
chunk_size: Option<usize>,
block_size: Option<usize>,
enable_prefix_caching: bool,
) -> (Self, Option<MoveBlock>) {
let mut sequence = Self::new(tokens, max_output_tokens, block_size, chunk_size);
let mut sequence = Self::new(tokens, max_output_tokens, block_size, enable_prefix_caching);
let signal = sequence.creation_signal.take();
(sequence, signal)
}
/// Get the parent hash from the second-to-last block if it exists and is a FullBlock
fn get_parent_hash(&self) -> Option<u64> {
if self.unique_blocks.len() < 2 {
return None;
}
match &self.unique_blocks[self.unique_blocks.len() - 2] {
UniqueBlock::FullBlock(hash) => Some(*hash),
_ => panic!("Cannot have a partial block as parent"),
}
}
/// Push a token to the sequence
pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
self.tokens.append(token).expect("Token push failed.");
self.generated_tokens += 1;
if self.len() % (self.block_size as usize) != 1 {
if self.len() % self.block_size != 1 {
return None;
}
......@@ -135,16 +157,24 @@ impl ActiveSequence {
// Replace last partial block with full block if it exists
if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
let last_block_hash = self.tokens.last_complete_block().unwrap().sequence_hash();
let last_block_hash = if self.enable_prefix_caching {
self.tokens.last_complete_block().unwrap().sequence_hash()
} else {
random::<u64>()
};
self.unique_blocks.pop();
self.unique_blocks
.push(UniqueBlock::FullBlock(last_block_hash));
signals.push(MoveBlock::Promote(uuid, last_block_hash));
signals.push(MoveBlock::Promote(
uuid,
last_block_hash,
self.get_parent_hash(),
));
}
let new_partial_block = UniqueBlock::default();
self.unique_blocks.push(new_partial_block.clone());
signals.push(MoveBlock::Use(vec![new_partial_block], None));
signals.push(MoveBlock::Use(vec![new_partial_block]));
Some(signals)
}
......@@ -204,15 +234,19 @@ impl ActiveSequence {
}
/// Reset the sequence to its initial state and return the free signals from freeing current blocks
/// maintaining the uuid of the last partial block
pub fn reset_with_signal(&mut self) -> Vec<MoveBlock> {
let free_signal = self.free_signal();
self.tokens.truncate(self.num_input_tokens).unwrap();
self.unique_blocks =
create_unique_blocks_from_sequence(&self.tokens, None, self.block_size);
self.unique_blocks = create_unique_blocks_from_sequence(
&self.tokens,
None,
self.block_size,
self.enable_prefix_caching,
);
self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens);
self.generated_tokens = 0;
self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone(), None));
self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone()));
free_signal
}
......@@ -223,7 +257,7 @@ impl ActiveSequence {
self.generated_tokens = self.generated_tokens.saturating_sub(1);
// Reverts to the last full block
if self.tokens.total_tokens() % (self.block_size as usize) == 0 {
if self.tokens.total_tokens() % self.block_size == 0 {
self.unique_blocks.pop();
}
}
......@@ -238,14 +272,14 @@ mod tests {
// Create a sequence with block size 16 initialized with tokens [0..15]
let initial_tokens: Vec<u32> = (0..15).collect();
let (mut seq1, signal1) =
ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), Some(256));
ActiveSequence::new_with_signal(initial_tokens, 100, Some(16), true);
assert_eq!(seq1.num_input_tokens(), 15);
assert_eq!(seq1.len(), 15);
// Check that we got a Use signal
assert!(signal1.is_some());
match &signal1 {
Some(MoveBlock::Use(blocks, _)) => {
Some(MoveBlock::Use(blocks)) => {
assert_eq!(blocks.len(), 1);
}
_ => panic!("Expected Use signal"),
......@@ -264,33 +298,31 @@ mod tests {
let signal_16 = signal_16.unwrap();
assert_eq!(signal_16.len(), 2);
// First signal should be Promote for the previous block
match &signal_16[0] {
MoveBlock::Promote(_, _, parent_hash) => {
assert_eq!(*parent_hash, None);
}
_ => panic!("Expected Promote signal as second signal"),
}
// Second signal should be Use for new partial block
match &signal_16[1] {
MoveBlock::Use(blocks, _) => {
MoveBlock::Use(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as first signal"),
}
// First signal should be Promote for the previous block
match &signal_16[0] {
MoveBlock::Promote(uuid, _) => {
// The uuid is generated dynamically, so we just check it exists
let _ = uuid;
}
_ => panic!("Expected Promote signal as second signal"),
}
// Verify state after pushing tokens
assert_eq!(seq1.unique_blocks().len(), 2); // One full block and one partial block
assert_eq!(seq1.len(), 17);
assert_eq!(seq1.len() % (seq1.block_size() as usize), 1);
assert_eq!(seq1.len() % seq1.block_size(), 1);
// Create another sequence with block size 16 initialized with tokens [0..17]
let extended_tokens: Vec<u32> = (0..16).collect();
let (mut seq2, _) =
ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), Some(256));
let (mut seq2, _) = ActiveSequence::new_with_signal(extended_tokens, 100, Some(16), true);
seq2.push(16);
seq2.pop();
seq2.push(16);
......@@ -335,12 +367,12 @@ mod tests {
"seq2 should have exactly 3 blocks"
);
assert_eq!(
seq1.len() % (seq1.block_size() as usize),
seq1.len() % seq1.block_size(),
1,
"seq1 should have 1 partial token"
);
assert_eq!(
seq2.len() % (seq2.block_size() as usize),
seq2.len() % seq2.block_size(),
1,
"seq2 should have 1 partial token"
);
......@@ -352,9 +384,38 @@ mod tests {
"First two blocks should be identical"
);
// Push tokens 34..47 to seq1
for token in 33..48 {
seq1.push(token);
}
// Push token 48 and get the signal - this completes the block and triggers signals
let signal = seq1.push(48);
let signal = signal.unwrap();
// Check that signal[0] is promote
match &signal[0] {
MoveBlock::Promote(_, _, parent_hash) => {
// Check that the parent_hash matches unique_blocks[1], which should be a full block
if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] {
assert_eq!(
*parent_hash,
Some(expected_hash),
"Parent hash should match unique_blocks[1]"
);
} else {
panic!("unique_blocks[1] should be a full block");
}
}
_ => panic!("Expected Promote signal as first signal"),
}
// Reset seq1 and check that it equals the original clone
let free_signals = seq1.reset_with_signal();
// 49 - 15 generated tokens
assert_eq!(seq1.already_generated_tokens, 34);
// Verify the reset signals include proper cleanup events
assert!(!free_signals.is_empty());
}
......@@ -363,13 +424,12 @@ mod tests {
fn test_active_sequence_generate_signals() {
// Create a sequence with block size 16, max_output_tokens 4, initialized with tokens [0..14)
let initial_tokens: Vec<u32> = (0..14).collect();
let (mut seq, signal) =
ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), Some(256));
let (mut seq, signal) = ActiveSequence::new_with_signal(initial_tokens, 5, Some(16), true);
// Initial signal - should have received a Use signal for the partial block
assert!(signal.is_some());
match signal {
Some(MoveBlock::Use(blocks, _)) => {
Some(MoveBlock::Use(blocks)) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
......@@ -385,25 +445,23 @@ mod tests {
let signals_second = seq.generate();
assert_eq!(signals_second.len(), 2);
// First signal should be Use for new partial block
// First signal should be Promote
match &signals_second[0] {
MoveBlock::Promote(_, _, parent_hash) => {
assert_eq!(*parent_hash, None);
}
_ => panic!("Expected Promote signal as first signal after second token"),
}
// Second signal should be Use for new partial block
match &signals_second[1] {
MoveBlock::Use(blocks, _) => {
MoveBlock::Use(blocks) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
_ => panic!("Expected Use signal as second signal after second token"),
}
// Second signal should be Promote
match &signals_second[0] {
MoveBlock::Promote(uuid, hash) => {
// The uuid and hash values are generated dynamically, so we just check the event type
let _ = uuid;
let _ = hash;
}
_ => panic!("Expected Promote signal as first signal after second token"),
}
// Generate fourth token - should not trigger new signals as it's adding to partial block
let signals_third = seq.generate();
assert_eq!(signals_third.len(), 0);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment