Unverified Commit cc4c3516 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: mocker disagg (#3833)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent e3346dab
......@@ -196,14 +196,20 @@ for i in $(seq 1 $NUM_WORKERS); do
# Run mocker engine (no GPU assignment needed)
MOCKER_ARGS=()
MOCKER_ARGS+=("--model-path" "$MODEL_PATH")
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
# Set endpoint based on worker mode
if [ "$MODE" = "prefill" ]; then
MOCKER_ARGS+=("--endpoint" "dyn://test.prefill.generate")
MOCKER_ARGS+=("--is-prefill-worker")
elif [ "$MODE" = "decode" ]; then
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
MOCKER_ARGS+=("--is-decode-worker")
else
MOCKER_ARGS+=("--endpoint" "dyn://test.mocker.generate")
fi
if [ "$DATA_PARALLEL_SIZE" -gt 1 ]; then
MOCKER_ARGS+=("--data-parallel-size" "$DATA_PARALLEL_SIZE")
fi
MOCKER_ARGS+=("${EXTRA_ARGS[@]}")
......
......@@ -12,6 +12,7 @@ from . import __version__
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate"
logger = logging.getLogger(__name__)
......@@ -85,8 +86,8 @@ def parse_args():
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string (default: {DEFAULT_ENDPOINT})",
default=None,
help=f"Dynamo endpoint string (default: {DEFAULT_ENDPOINT} for aggregated/decode, {DEFAULT_PREFILL_ENDPOINT} for prefill)",
)
parser.add_argument(
"--model-name",
......@@ -199,4 +200,14 @@ def parse_args():
args = parser.parse_args()
validate_worker_type_args(args)
# Set endpoint default based on worker type if not explicitly provided
if args.endpoint is None:
if args.is_prefill_worker:
args.endpoint = DEFAULT_PREFILL_ENDPOINT
logger.debug(f"Using default prefill endpoint: {args.endpoint}")
else:
args.endpoint = DEFAULT_ENDPOINT
logger.debug(f"Using default endpoint: {args.endpoint}")
return args
......@@ -60,7 +60,6 @@ pub enum DynamoLlmResult {
pub unsafe extern "C" fn dynamo_llm_init(
namespace_c_str: *const c_char,
component_c_str: *const c_char,
worker_id: i64,
kv_block_size: u32,
) -> DynamoLlmResult {
initialize_tracing();
......@@ -102,7 +101,7 @@ pub unsafe extern "C" fn dynamo_llm_init(
match result {
Ok(_) => match KV_PUB.get_or_try_init(move || {
dynamo_create_kv_publisher(namespace, component, worker_id, kv_block_size)
dynamo_create_kv_publisher(namespace, component, kv_block_size)
}) {
Ok(_) => DynamoLlmResult::OK,
Err(e) => {
......@@ -144,7 +143,6 @@ pub extern "C" fn dynamo_llm_load_publisher_create() -> DynamoLlmResult {
fn dynamo_create_kv_publisher(
namespace: String,
component: String,
worker_id: i64,
kv_block_size: u32,
) -> Result<KvEventPublisher, anyhow::Error> {
tracing::info!("Creating KV Publisher for model: {}", component);
......@@ -154,7 +152,7 @@ fn dynamo_create_kv_publisher(
{
Ok(drt) => {
let backend = drt.namespace(namespace)?.component(component)?;
KvEventPublisher::new(backend, worker_id as u64, kv_block_size, None)
KvEventPublisher::new(backend, kv_block_size, None)
}
Err(e) => Err(e),
}
......
......@@ -143,7 +143,6 @@ impl ZmqKvEventPublisher {
fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner,
config.worker_id,
config.kv_block_size as u32,
Some(KvEventSourceConfig::Zmq {
endpoint: config.zmq_endpoint,
......@@ -239,20 +238,14 @@ pub(crate) struct KvEventPublisher {
#[pymethods]
impl KvEventPublisher {
#[new]
#[pyo3(signature = (component, worker_id, kv_block_size, dp_rank=0))]
fn new(
component: Component,
worker_id: WorkerId,
kv_block_size: usize,
dp_rank: DpRank,
) -> PyResult<Self> {
#[pyo3(signature = (component, kv_block_size, dp_rank=0))]
fn new(component: Component, kv_block_size: usize, dp_rank: DpRank) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner,
worker_id,
kv_block_size as u32,
None,
)
......
......@@ -354,9 +354,10 @@ impl RadixTree {
None => {
tracing::warn!(
worker_id = worker.worker_id.to_string(),
dp_rank = ?worker.dp_rank,
dp_rank = worker.dp_rank,
id,
parent_hash = ?op.parent_hash,
num_blocks = op.blocks.len(),
"Failed to find parent block; skipping store operation"
);
return Err(KvCacheEventError::ParentBlockNotFound);
......@@ -412,8 +413,10 @@ impl RadixTree {
Some(entry) => entry.clone(),
None => {
tracing::warn!(
worker_id = worker_id.to_string(),
worker_id = worker.worker_id.to_string(),
dp_rank = worker.dp_rank,
id,
block_hash = ?block,
"Failed to find block to remove; skipping remove operation"
);
return Err(KvCacheEventError::BlockNotFound);
......
......@@ -213,8 +213,12 @@ impl
let (req, context) = request.into_parts();
let request_id = context.id().to_string();
// Prepare prefill request with linked context for cancellation propagation
let prefill_req = req.clone();
// Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens;
// Prepare prefill request with max_tokens = 1
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Link the prefill context as a child so that kill signals propagate
......@@ -230,6 +234,8 @@ impl
// Update request with disaggregated_params and router config
let mut decode_req = req;
decode_req.disaggregated_params = Some(disaggregated_params);
// Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens;
// Set router_config_override for decode: overlap_score_weight = 0
let existing_override = decode_req.router_config_override.take();
......
......@@ -97,7 +97,6 @@ pub struct KvEventPublisher {
impl KvEventPublisher {
pub fn new(
component: Component,
worker_id: u64,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
) -> Result<Self> {
......@@ -105,6 +104,9 @@ impl KvEventPublisher {
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
// Infer worker_id from component's connection
let worker_id = component.drt().connection_id();
// Create our event source (if any)
let mut source = None;
if let Some(config) = source_config {
......
......@@ -5,5 +5,6 @@ pub mod engine;
pub mod evictor;
pub mod kv_manager;
pub mod protocols;
pub mod running_mean;
pub mod scheduler;
pub mod sequence;
......@@ -8,7 +8,7 @@
use crate::kv_router::publisher::WorkerMetricsPublisher;
use crate::mocker::protocols::DirectRequest;
use crate::mocker::protocols::{MockEngineArgs, OutputSignal};
use crate::mocker::protocols::{MockEngineArgs, OutputSignal, WorkerType};
use crate::mocker::scheduler::Scheduler;
use crate::protocols::TokenIdType;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
......@@ -23,9 +23,6 @@ use dynamo_runtime::{
pipeline::{AsyncEngine, Error, ManyOut, ResponseStream, SingleIn, async_trait},
traits::DistributedRuntimeProvider,
};
use crate::kv_router::protocols::{KvCacheEvent, KvCacheEventData};
use crate::kv_router::publisher::KvEventPublisher;
use futures::StreamExt;
use rand::Rng;
use std::collections::HashMap;
......@@ -37,10 +34,9 @@ use uuid::Uuid;
pub const MOCKER_COMPONENT: &str = "mocker";
/// Generate a random token ID in the range 1000..2000 (upper bound exclusive)
fn generate_random_token() -> TokenIdType {
let mut rng = rand::rng();
rng.random_range(1000..5000)
rng.random_range(1000..2000)
}
/// AsyncEngine wrapper around the Scheduler that generates random character tokens
......@@ -71,26 +67,25 @@ impl MockVllmEngine {
tracing::info!("Engine startup simulation completed");
}
let (schedulers, kv_event_receiver) = self.start_schedulers(
// Pass component to schedulers only if prefix caching is enabled and not a decode worker
let scheduler_component = if self.engine_args.enable_prefix_caching
&& self.engine_args.worker_type != WorkerType::Decode
{
Some(component.clone())
} else {
None
};
let schedulers = self.start_schedulers(
self.engine_args.clone(),
self.active_requests.clone(),
scheduler_component,
cancel_token.clone(),
);
Self::start_metrics_publishing(&schedulers, Some(component.clone()), cancel_token.clone())
.await?;
// Start KV events publishing with the actual receivers from schedulers
if self.engine_args.enable_prefix_caching {
Self::start_kv_events_publishing(
kv_event_receiver,
Some(component.clone()),
self.engine_args.block_size,
cancel_token.clone(),
)
.await?;
}
Ok(())
}
......@@ -100,18 +95,14 @@ impl MockVllmEngine {
}
/// Create schedulers and spawn their background tasks for distributing token notifications
/// Returns schedulers and their corresponding KV event receivers
fn start_schedulers(
&self,
args: MockEngineArgs,
active_requests: Arc<Mutex<HashMap<Uuid, mpsc::UnboundedSender<OutputSignal>>>>,
component: Option<Component>,
cancel_token: CancellationToken,
) -> (
Vec<Scheduler>,
Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
) {
) -> Vec<Scheduler> {
let mut schedulers = Vec::<Scheduler>::new();
let mut kv_event_receivers = Vec::new();
let mut senders = Vec::with_capacity(args.dp_size as usize);
// Create multiple schedulers and their background tasks
......@@ -119,20 +110,16 @@ impl MockVllmEngine {
// Create a shared output channel that this scheduler will use
let (output_tx, mut output_rx) = mpsc::unbounded_channel::<OutputSignal>();
// Create a channel for KV events from this scheduler
let (kv_events_tx, kv_events_rx) = mpsc::unbounded_channel::<KvCacheEventData>();
let scheduler = Scheduler::new(
args.clone(),
dp_rank,
Some(output_tx),
Some(kv_events_tx), // Pass the KV events sender to scheduler
component.clone(),
Some(cancel_token.clone()),
);
senders.push(scheduler.request_sender());
schedulers.push(scheduler);
kv_event_receivers.push(kv_events_rx);
// Spawn a background task for this scheduler to distribute token notifications to active requests
// let output_rx = Arc::new(Mutex::new(output_rx));
......@@ -166,7 +153,7 @@ impl MockVllmEngine {
.set(senders)
.expect("Already initialized");
(schedulers, kv_event_receivers)
schedulers
}
/// Start background tasks to publish metrics on change
......@@ -228,78 +215,6 @@ impl MockVllmEngine {
tracing::info!("Metrics background tasks started");
Ok(())
}
/// Start background tasks to collect and publish KV events from schedulers
async fn start_kv_events_publishing(
kv_event_receivers: Vec<mpsc::UnboundedReceiver<KvCacheEventData>>,
component: Option<Component>,
block_size: usize,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::debug!("Starting KV events publishing");
// Only start KV events publishing if we have a component
let Some(comp) = component else {
tracing::warn!("No component provided, skipping KV events publishing");
return Ok(());
};
tracing::debug!("Component found for KV events publishing");
tracing::debug!("Getting worker_id");
let worker_id = comp.drt().connection_id();
tracing::debug!("Worker_id set to: {worker_id}");
tracing::debug!("Creating KV event publisher");
let kv_event_publisher = Arc::new(KvEventPublisher::new(
comp.clone(),
worker_id,
block_size as u32,
None,
)?);
tracing::debug!("KV event publisher created");
tracing::debug!(
"Starting KV event background tasks for {} receivers",
kv_event_receivers.len()
);
for (dp_rank, mut kv_events_rx) in kv_event_receivers.into_iter().enumerate() {
tracing::debug!("Starting background task for DP rank {dp_rank}");
let publisher = kv_event_publisher.clone();
let dp_rank = dp_rank as u32;
let cancel_token = cancel_token.clone();
tokio::spawn(async move {
tracing::debug!("Background task started for DP rank {dp_rank}");
loop {
tokio::select! {
// Receive actual KV events from the scheduler
Some(event_data) = kv_events_rx.recv() => {
// Convert KvCacheEventData to KvCacheEvent with random UUID as event_id
let event = KvCacheEvent {
event_id: Uuid::new_v4().as_u128() as u64,
data: event_data,
dp_rank,
};
// Publish the event
if let Err(e) = publisher.publish(event) {
tracing::warn!("Failed to publish KV event for DP rank {dp_rank}: {e}");
} else {
tracing::trace!("Published KV event for DP rank {dp_rank}");
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("KV events publishing cancelled for DP rank {dp_rank}");
break;
}
}
}
});
}
tracing::info!("All KV event background tasks started");
Ok(())
}
}
#[async_trait]
......@@ -325,14 +240,21 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let request_uuid = ctx.id().parse().unwrap_or(Uuid::new_v4());
// For prefill workers, override max_tokens to 1
let is_prefill = self.engine_args.worker_type == WorkerType::Prefill;
let max_output_tokens = if is_prefill {
1
} else {
request
.stop_conditions
.max_tokens
.expect("max_output_tokens must be specified for mocker") as usize
};
// Convert PreprocessedRequest to DirectRequest for scheduler
let direct_request = DirectRequest {
tokens: request.token_ids.clone(),
max_output_tokens: request
.stop_conditions
.max_tokens
.expect("max_output_tokens must be specified for mocker")
as usize,
max_output_tokens,
uuid: Some(request_uuid),
dp_rank,
};
......@@ -351,7 +273,6 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
let active_requests = self.active_requests.clone();
let async_context = ctx.context();
let max_tokens = request.stop_conditions.max_tokens.unwrap_or(100) as usize;
// Spawn a task to handle the complex async logic
tokio::spawn(async move {
......@@ -378,11 +299,16 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<LLMEngineOutput>, Error>
top_logprobs: None,
finish_reason: None,
index: None,
disaggregated_params: None,
// Add dummy disaggregated_params for prefill workers
disaggregated_params: if is_prefill {
Some(serde_json::json!("dummy"))
} else {
None
},
extra_args: None,
};
if signal.completed && token_count < max_tokens {
if signal.completed && token_count < max_output_tokens {
let _ = stream_tx.send(LLMEngineOutput::error("Completion signal received before max tokens reached".to_string()));
break;
}
......
......@@ -33,13 +33,20 @@
//! the more idiomatic built-in Arc reference counter. This can be considered a shadow / mirror
//! implementation of the main block manager.
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use crate::kv_router::publisher::KvEventPublisher;
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::protocols::{MoveBlock, MoveBlockResponse, PrefillCost};
use crate::mocker::protocols::{MoveBlock, PrefillCost};
use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{BlockHash, SequenceHash};
use derive_getters::Getters;
use dynamo_runtime::component::Component;
use std::collections::{HashMap, HashSet};
use tokio::sync::mpsc;
use std::sync::Arc;
#[derive(Getters)]
pub struct KvManager {
......@@ -55,60 +62,113 @@ pub struct KvManager {
all_blocks: HashSet<UniqueBlock>,
move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>,
kv_event_publisher: Option<Arc<KvEventPublisher>>,
#[getter(copy)]
dp_rank: u32,
next_event_id: u64,
}
impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self {
Self::new_with_sender(max_capacity, block_size, None)
Self::new_with_publisher(max_capacity, block_size, None, 0)
}
pub fn new_with_sender(
pub fn new_with_publisher(
max_capacity: usize,
block_size: usize,
move_block_response_tx: Option<mpsc::UnboundedSender<MoveBlockResponse>>,
component: Option<Component>,
dp_rank: u32,
) -> Self {
let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default();
let all_blocks = HashSet::new();
let kv_event_publisher = component.map(|comp| {
tracing::info!(
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}"
);
Arc::new(
KvEventPublisher::new(comp, block_size as u32, None)
.expect("Failed to create KV event publisher"),
)
});
KvManager {
max_capacity,
block_size,
active_blocks,
inactive_blocks,
all_blocks,
move_block_response_tx,
kv_event_publisher,
dp_rank,
next_event_id: 0,
}
}
/// Utility method to send block responses with optional reversing
fn send_block_response(
&self,
mut blocks: Vec<u64>,
reverse: bool,
store: bool,
/// Converts stored/removed blocks into KvCacheEventData and publishes if publisher is available
fn publish_kv_event(
&mut self,
full_blocks: Vec<SequenceHash>,
local_hashes: &[BlockHash],
parent_hash: Option<u64>,
is_store: bool,
) {
if let Some(ref tx) = self.move_block_response_tx
&& !blocks.is_empty()
{
if reverse {
blocks.reverse();
}
let response = if store {
MoveBlockResponse::Store(blocks, parent_hash)
} else {
MoveBlockResponse::Remove(blocks)
};
tx.send(response).unwrap();
if full_blocks.is_empty() {
return;
}
let Some(ref publisher) = self.kv_event_publisher else {
return;
};
let event_data = if is_store {
let num_blocks = full_blocks.len();
let local_hashes_slice = &local_hashes[local_hashes
.len()
.checked_sub(num_blocks)
.expect("local hashes fewer than stored blocks")..];
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: full_blocks
.into_iter()
.zip(local_hashes_slice.iter())
.map(|(global_hash, local_hash)| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(global_hash),
tokens_hash: LocalBlockHash(*local_hash),
})
.collect(),
})
} else {
KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: full_blocks
.into_iter()
.map(ExternalSequenceBlockHash)
.collect(),
})
};
// Use incremental event ID starting from 0
let event_id = self.next_event_id;
self.next_event_id += 1;
let event = KvCacheEvent {
event_id,
data: event_data,
dp_rank: self.dp_rank,
};
if let Err(e) = publisher.publish(event) {
tracing::warn!("Failed to publish KV event: {e}");
}
}
/// Process a MoveBlock instruction synchronously
pub fn process(&mut self, event: &MoveBlock) -> bool {
match event {
MoveBlock::Use(hashes) => {
MoveBlock::Use(hashes, local_hashes) => {
let mut blocks_stored = Vec::<u64>::new();
let mut parent_block: Option<&UniqueBlock> = None;
......@@ -138,16 +198,20 @@ impl KvManager {
let Some(evicted) = self.inactive_blocks.evict() else {
return false;
};
tracing::trace!(
"Evicting block from inactive pool: {evicted:?}, dp_rank={}",
self.dp_rank
);
self.all_blocks.remove(&evicted);
if let UniqueBlock::FullBlock(evicted_full_block) = evicted {
self.send_block_response(vec![evicted_full_block], false, false, None);
self.publish_kv_event(vec![evicted_full_block], &[], None, false);
}
}
// Now insert the new block in active blocks with reference count 1
self.active_blocks.insert(hash.clone(), 1);
self.all_blocks.insert(hash.clone());
if self.move_block_response_tx.is_some()
if self.kv_event_publisher.is_some()
&& let UniqueBlock::FullBlock(stored_full_block) = hash
{
blocks_stored.push(*stored_full_block);
......@@ -159,32 +223,32 @@ impl KvManager {
Some(UniqueBlock::FullBlock(block)) => Some(*block),
Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"),
};
self.send_block_response(blocks_stored, false, true, parent_hash);
self.publish_kv_event(blocks_stored, local_hashes, parent_hash, true);
}
MoveBlock::Destroy(hashes) => {
let mut blocks_destroyed = Vec::<u64>::new();
// Loop in inverse direction
for hash in hashes.iter().rev() {
// Process blocks in order (already reversed by caller if needed)
for hash in hashes.iter() {
self.active_blocks.remove(hash).unwrap();
// Remove from all_blocks when destroyed
assert!(self.all_blocks.remove(hash));
// Track blocks for batch sending
if self.move_block_response_tx.is_some()
if self.kv_event_publisher.is_some()
&& let UniqueBlock::FullBlock(destroyed_full_block) = hash
{
blocks_destroyed.push(*destroyed_full_block);
}
}
self.send_block_response(blocks_destroyed, true, false, None);
self.publish_kv_event(blocks_destroyed, &[], None, false);
}
MoveBlock::Deref(hashes) => {
// Loop in inverse direction
for hash in hashes.iter().rev() {
// Process blocks in order (already reversed by caller if needed)
for hash in hashes.iter() {
// Decrement reference count and check if we need to move to inactive
if let Some(ref_count) = self.active_blocks.get_mut(hash) {
if *ref_count == 0 {
......@@ -202,24 +266,30 @@ impl KvManager {
}
}
MoveBlock::Promote(uuid, hash, parent_hash) => {
MoveBlock::Promote(uuid, hash, parent_hash, local_hash) => {
let uuid_block = UniqueBlock::PartialBlock(*uuid);
let hash_block = UniqueBlock::FullBlock(*hash);
let Some(ref_count) = self.active_blocks.remove(&uuid_block) else {
let in_all_blocks = self.all_blocks.contains(&uuid_block);
panic!(
"Missing active block for promotion: {uuid_block:?}. Block still exists: {in_all_blocks}"
);
assert_eq!(
self.active_blocks.remove(&uuid_block),
Some(1),
"uuid_block {uuid_block:?} should exist and be unique with ref_count=1"
);
let hash_ref_count = if let Some(ref_count) = self.active_blocks.get(&hash_block) {
*ref_count
} else if self.inactive_blocks.remove(&hash_block) {
0
} else {
self.publish_kv_event(vec![*hash], &[*local_hash], *parent_hash, true);
0
};
// Replace with hash block, keeping the same reference count
self.active_blocks.insert(hash_block.clone(), ref_count);
self.active_blocks
.insert(hash_block.clone(), hash_ref_count + 1);
// Update all_blocks
assert!(self.all_blocks.remove(&uuid_block));
self.all_blocks.insert(hash_block);
self.send_block_response(vec![*hash], false, true, *parent_hash);
}
}
......@@ -291,7 +361,6 @@ impl KvManager {
#[cfg(test)]
mod tests {
use super::*;
use tokio::sync::mpsc;
#[test]
fn test_failure_on_max_capacity() {
......@@ -300,8 +369,9 @@ mod tests {
// Helper function to use multiple blocks that returns the response
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> bool {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks))
let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
let hashes: Vec<_> = ids.into_iter().collect();
manager.process(&MoveBlock::Use(blocks, hashes))
}
// First use 10 blocks (0 to 9) in a batch
......@@ -321,16 +391,14 @@ mod tests {
#[test]
fn test_block_lifecycle_stringent() {
// Create a channel to listen to block responses
let (tx, mut rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
// Create a KvManager with 10 blocks capacity and the response sender
let mut manager = KvManager::new_with_sender(10, 16, Some(tx));
// Create a KvManager with 10 blocks capacity (no KV event publisher for tests)
let mut manager = KvManager::new(10, 16);
// Helper function to use multiple blocks
fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
manager.process(&MoveBlock::Use(blocks));
let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
let hashes: Vec<_> = ids.into_iter().collect();
manager.process(&MoveBlock::Use(blocks, hashes));
}
// Helper function to destroy multiple blocks
......@@ -345,56 +413,6 @@ mod tests {
manager.process(&MoveBlock::Deref(blocks));
}
// Helper function to assert block responses
fn assert_block_response(
rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
expected_type: &str,
expected_blocks: Vec<u64>,
description: &str,
) {
let response = rx
.try_recv()
.unwrap_or_else(|_| panic!("Expected {expected_type} response {description}"));
match (&response, expected_type) {
(MoveBlockResponse::Store(blocks, _parent_hash), "Store") => {
assert_eq!(
blocks.len(),
expected_blocks.len(),
"Expected {} blocks in Store response {}",
expected_blocks.len(),
description
);
assert_eq!(
*blocks, expected_blocks,
"Store blocks don't match expected {description}"
);
}
(MoveBlockResponse::Remove(blocks), "Remove") => {
assert_eq!(
blocks.len(),
expected_blocks.len(),
"Expected {} blocks in Remove response {}",
expected_blocks.len(),
description
);
assert_eq!(
*blocks, expected_blocks,
"Remove blocks don't match expected {description}"
);
}
_ => panic!("Expected {expected_type} response, got {response:?} {description}"),
}
}
// Helper function to assert no response is received
fn assert_no_response(
rx: &mut mpsc::UnboundedReceiver<MoveBlockResponse>,
description: &str,
) {
assert!(rx.try_recv().is_err(), "Expected no response {description}",);
}
// Helper function to check if active blocks contain expected blocks with expected ref counts
fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
assert_eq!(
......@@ -442,11 +460,9 @@ mod tests {
// First use blocks 0, 1, 2, 3, 4 in a batch
use_blocks(&mut manager, (0..5).collect());
assert_block_response(&mut rx, "Store", vec![0, 1, 2, 3, 4], "after first use");
// Then use blocks 0, 1, 5, 6 in a batch
use_blocks(&mut manager, vec![0, 1, 5, 6]);
assert_block_response(&mut rx, "Store", vec![5, 6], "after second use");
// Check that the blocks 0 and 1 are in active blocks, both with reference counts of 2
assert_active_blocks(
......@@ -456,11 +472,9 @@ mod tests {
// Now destroy block 4
destroy_blocks(&mut manager, vec![4]);
assert_block_response(&mut rx, "Remove", vec![4], "after destroy block 4");
// And deref blocks 3, 2, 1, 0 in this order as a batch
deref_blocks(&mut manager, vec![0, 1, 2, 3]);
assert_no_response(&mut rx, "after deref operation");
// Check that the inactive_blocks is size 2 (via num_objects) and contains 3 and 2
assert_inactive_blocks(&manager, 2, &[3, 2]);
......@@ -468,7 +482,6 @@ mod tests {
// Now destroy block 6
destroy_blocks(&mut manager, vec![6]);
assert_block_response(&mut rx, "Remove", vec![6], "after block 6 eviction");
// And deref blocks 5, 1, 0 as a batch
deref_blocks(&mut manager, vec![0, 1, 5]);
......@@ -479,7 +492,6 @@ mod tests {
// Now use 0, 1, 2, 7, 8, 9 as a batch
use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);
assert_block_response(&mut rx, "Store", vec![7, 8, 9], "after [7, 8, 9] use");
// Check that the inactive_blocks is size 2, and contains 3 and 5
assert_inactive_blocks(&manager, 2, &[3, 5]);
......@@ -494,14 +506,10 @@ mod tests {
// Now use blocks 10, 11, 12 as a batch
use_blocks(&mut manager, vec![10, 11, 12]);
assert_block_response(&mut rx, "Remove", vec![3], "after block 5 eviction");
assert_block_response(&mut rx, "Store", vec![10, 11, 12], "after [10, 11, 12] use");
// Check that the inactive_blocks is size 1 and contains only 5
assert_inactive_blocks(&manager, 1, &[5]);
use_blocks(&mut manager, vec![13]);
assert_block_response(&mut rx, "Remove", vec![5], "after block 5 eviction");
assert_block_response(&mut rx, "Store", vec![13], "after block 13 use");
}
}
......@@ -7,23 +7,19 @@ use std::collections::{HashMap, HashSet};
use std::path::Path;
use uuid::Uuid;
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{BlockHash, SequenceHash, Token};
pub type NumBlocks = usize;
/// Represents different block movement operations in the cache
/// For Use and Promote variants, parent hash is the second field
/// For Use and Promote variants, block hashes are included for KV event publishing
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlock {
Use(Vec<UniqueBlock>),
Use(Vec<UniqueBlock>, Vec<BlockHash>),
Destroy(Vec<UniqueBlock>),
Deref(Vec<UniqueBlock>),
Promote(Uuid, SequenceHash, Option<u64>),
Promote(Uuid, SequenceHash, Option<u64>, BlockHash),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
......@@ -50,7 +46,7 @@ pub struct PrefillCost {
impl PrefillCost {
pub fn predict_prefill_compute(&self, new_tokens: Option<usize>) -> f64 {
let tokens = new_tokens.unwrap_or(self.new_tokens);
1.25e-6 * (tokens as f64).powi(2) + 7.41e-2 * (tokens as f64) + 2.62e1
4.209989e-07 * (tokens as f64).powi(2) + 1.518344e-02 * (tokens as f64) + 1.650142e+01
}
}
......@@ -260,49 +256,6 @@ impl MockEngineArgs {
}
}
/// Converts a MoveBlockResponse from the mocker backend into a KvCacheEventData.
///
/// This function assumes that the stored sequence hashes in the response always
/// correspond to the tail part of the local hashes array. This is the expected
/// behavior of KV block storage, where blocks are stored sequentially and the
/// response contains the most recent blocks that were stored.
///
/// # Panics
/// Panics if the number of blocks in the Store response exceeds the length
/// of local_hashes.
pub fn block_response_to_kv_event(
response: MoveBlockResponse,
local_hashes: &[BlockHash],
) -> KvCacheEventData {
match response {
MoveBlockResponse::Store(full_blocks, parent_hash) => {
let num_blocks = full_blocks.len();
let local_hashes_slice = &local_hashes[local_hashes
.len()
.checked_sub(num_blocks)
.expect("local hashes fewer than block response signal")..];
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: full_blocks
.into_iter()
.zip(local_hashes_slice.iter())
.map(|(global_hash, local_hash)| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(global_hash),
tokens_hash: LocalBlockHash(*local_hash),
})
.collect(),
})
}
MoveBlockResponse::Remove(full_blocks) => KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: full_blocks
.into_iter()
.map(ExternalSequenceBlockHash)
.collect(),
}),
}
}
#[cfg(test)]
mod tests {
use super::*;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::VecDeque;
use std::ops::{Add, Div, Sub};
/// A generic running mean over a fixed-size sliding window.
///
/// The most recent `max_size` values are kept in a queue together with their
/// running sum, so [`RunningMean::mean`] divides a cached sum instead of
/// re-scanning the window.
///
/// The sum is maintained incrementally (subtract the evicted value, add the new
/// one) using `T`'s own `Add`/`Sub`; any overflow or rounding behaviour of those
/// operators therefore carries over to the reported mean.
#[derive(Debug, Clone)]
pub struct RunningMean<T> {
    /// Maximum number of values retained in the window.
    max_size: u16,
    /// Sum of every value currently stored in `values`.
    sum: T,
    /// Window contents, oldest value at the front.
    values: VecDeque<T>,
}

impl<T> RunningMean<T>
where
    T: Copy + Add<Output = T> + Sub<Output = T> + Div<Output = T> + Default + From<u16>,
{
    /// Creates an empty window that retains at most `max_size` values.
    ///
    /// A `max_size` of 0 yields a window that never retains anything, so
    /// [`RunningMean::mean`] always returns `T::default()`.
    pub fn new(max_size: u16) -> Self {
        Self {
            max_size,
            sum: T::default(),
            values: VecDeque::with_capacity(usize::from(max_size)),
        }
    }

    /// Appends `value` to the window, evicting the oldest value first when the
    /// window is already full.
    pub fn push(&mut self, value: T) {
        let capacity = usize::from(self.max_size);

        // A zero-sized window holds nothing. Without this guard the first push
        // would find nothing to evict and still append, leaving one element in
        // a window that was declared to hold none.
        if capacity == 0 {
            return;
        }

        // Evict before adding so the running sum never covers more than
        // `capacity` values at once.
        let evicted = if self.values.len() >= capacity {
            self.values.pop_front()
        } else {
            None
        };
        if let Some(oldest) = evicted {
            self.sum = self.sum - oldest;
        }

        self.sum = self.sum + value;
        self.values.push_back(value);
    }

    /// Returns the mean of the values currently in the window, or
    /// `T::default()` when the window is empty.
    pub fn mean(&self) -> T {
        if self.values.is_empty() {
            return T::default();
        }

        // `push` never lets the window grow past `max_size`, which is a `u16`,
        // so this cast cannot truncate.
        let count = self.values.len() as u16;
        self.sum / T::from(count)
    }

    /// Number of values currently held in the window.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns `true` when the window holds no values.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Clear all values from the window and reset the running sum.
    pub fn clear(&mut self) {
        self.sum = T::default();
        self.values.clear();
    }
}
......@@ -28,16 +28,16 @@
//! ## NOTE
//! The current prefill and decoding time simulations are not scientific at all and are WIP
use crate::kv_router::protocols::{ForwardPassMetrics, KvCacheEventData, KvStats, WorkerStats};
use crate::kv_router::protocols::{ForwardPassMetrics, KvStats, WorkerStats};
use crate::mocker::evictor::LRUEvictor;
use crate::mocker::kv_manager::KvManager;
use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse};
use crate::mocker::protocols::{MoveBlock, OutputSignal, PrefillCost, block_response_to_kv_event};
use crate::mocker::protocols::{
DirectRequest, MockEngineArgs, MoveBlock, OutputSignal, PrefillCost, WorkerType,
};
use crate::mocker::running_mean::RunningMean;
use crate::mocker::sequence::ActiveSequence;
use crate::tokens::BlockHash;
use crate::tokens::blocks::UniqueBlock;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::collections::{HashMap, VecDeque};
use tokio::sync::mpsc;
use tokio::time::Duration;
use tokio_util::sync::CancellationToken;
......@@ -111,9 +111,8 @@ impl SchedulerState {
/// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where:
/// - `prefill_compute`: The compute time in milliseconds for this prefill operation
/// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
/// - `block_hashes`: Block hashes of the sequence being prefilled
/// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, Vec<BlockHash>, bool)> {
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, bool)> {
let uuid = self.prefill.pop_front()?;
// Remove and extract prefill_compute from prefill_costs
......@@ -168,7 +167,6 @@ impl SchedulerState {
Some((
prefill_compute,
sequence.take_creation_signal(),
sequence.block_hashes(),
is_full_prefill,
))
}
......@@ -247,17 +245,9 @@ impl Scheduler {
args: MockEngineArgs,
dp_rank: u32,
output_tx: Option<mpsc::UnboundedSender<OutputSignal>>,
kv_events_tx: Option<mpsc::UnboundedSender<KvCacheEventData>>,
component: Option<dynamo_runtime::component::Component>,
cancellation_token: Option<CancellationToken>,
) -> Self {
// Create internal channel for KV events only if needed
let (block_resp_tx, mut block_resp_rx) = if kv_events_tx.is_some() {
let (tx, rx) = mpsc::unbounded_channel::<MoveBlockResponse>();
(Some(tx), Some(rx))
} else {
(None, None)
};
// Assert speedup_ratio is greater than 0
assert!(
args.speedup_ratio > 0.0,
......@@ -278,121 +268,64 @@ impl Scheduler {
tokio::spawn(async move {
// Create state and kv_manager as local variables owned by this task
let mut state = SchedulerState::new(args.max_num_batched_tokens);
let mut kv_manager =
KvManager::new_with_sender(args.num_gpu_blocks, args.block_size, block_resp_tx);
let mut hit_rates = VecDeque::with_capacity(1000);
let mut should_schedule = true;
let mut kv_manager = KvManager::new_with_publisher(
args.num_gpu_blocks,
args.block_size,
component,
dp_rank,
);
let mut hit_rates = RunningMean::new(1000);
loop {
{
// Enqueue new request, blocks until at least one is received, so no redundant work is done
if state.is_empty() {
let Some(request) = request_rx.recv().await else {
tracing::warn!("request sender is dropped");
// 1. Receive requests
if state.is_empty() {
// Fully idle - block until new request arrives
tokio::select! {
biased;
Some(request) = request_rx.recv() => {
state.receive(request);
}
_ = cancel_token_clone.cancelled() => {
break;
};
state.receive(request);
}
}
}
tokio::select! {
biased;
// Enqueue new request
Some(request) = request_rx.recv() => {
} else {
// Has active/waiting work - collect any pending requests without blocking
while let Ok(request) = request_rx.try_recv() {
state.receive(request);
}
// Try Scheduling Requests - runs on normal interval or after simulation
_ = tokio::task::yield_now() => {
// Skip if we just ran scheduling after simulation to prevent consecutive runs
if !should_schedule {
continue;
}
// Process DirectRequests, converting them to ActiveSequence and scheduling them until we can't
// schedule anymore.
let mut current_blocks = kv_manager.num_active_blocks();
let mut current_tokens = state.active_tokens + state.waiting_tokens;
let mut current_seqs = state.num_active_requests();
while let Some((uuid, request)) = state.next() {
let active_sequence = get_active_sequence(request, args.block_size, args.enable_prefix_caching);
// Update predictive budgets
let prefill_cost = kv_manager.get_prefill_cost(&active_sequence);
let total_tokens = active_sequence.len();
// this is conservative, assumes no cache hit so never over-schedules
let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
let new_tokens = prefill_cost.new_tokens;
current_blocks += new_blocks;
current_tokens += new_tokens;
current_seqs += 1;
// Check various budgets to see if possible to schedule
let under_block_budget = current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64;
// If chunked prefill is enabled, we can be under token budget when scheduling
let comparison_tokens = if args.enable_chunked_prefill {current_tokens - new_tokens} else {current_tokens};
let under_token_budget = args.max_num_batched_tokens.is_none_or(|limit| comparison_tokens <= limit);
let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
// Cannot schedule, put first in line instead
if !(under_block_budget && under_token_budget && under_seq_budget) {
state.first_in_line(uuid, Request::Active(active_sequence));
break;
}
// Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() { 1.0 - (new_tokens as f32 / active_sequence.len() as f32) } else { 0.0 };
hit_rates.push_back(hit_rate);
if hit_rates.len() > 1000 {
hit_rates.pop_front();
}
state.move_to_prefill(uuid, active_sequence, prefill_cost);
should_schedule = false;
}
}
// Check for cancellation
_ = cancel_token_clone.cancelled() => {
if cancel_token_clone.is_cancelled() {
break;
}
}
// Simulates prefill + decode
// Base time needed for decoding using active percentage and quadratic formula
let active_perc = kv_manager.get_active_perc();
let decoding_time = -5.47 * active_perc.powi(2) + 43.88 * active_perc + 19.44;
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
// Start timing for this forward pass (schedule + simulate)
let iteration_start = std::time::Instant::now();
// 2. Schedule waiting requests (once per iteration)
try_schedule(&mut state, &kv_manager, &mut hit_rates, &args);
// 3. Simulate prefill + decode
let mut total_time = Duration::ZERO;
// Process prefilling
while let Some((
prefill_compute,
maybe_creation_signal,
block_hashes,
is_full_prefill,
)) = state.try_prefill()
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
state.try_prefill()
{
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
if let Some(creation_signal) = maybe_creation_signal {
if !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal))
{
panic!("Block allocation for prefilling cannot fail.");
}
// For decode workers, skip adding prefill compute time
if args.worker_type != WorkerType::Decode {
total_time += Duration::from_secs_f64(prefill_compute / 1000.0);
}
// Drain KV events and forward to relay after prefill signal processing
if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ =
relay_tx.send(block_response_to_kv_event(event, &block_hashes));
}
}
};
if let Some(creation_signal) = maybe_creation_signal
&& !process_signals(&mut kv_manager, std::slice::from_ref(&creation_signal))
{
panic!("Block allocation for prefilling cannot fail.");
}
// Impossible to schedule more prefills if we encounter one incomplete (chunked) prefill
if !is_full_prefill {
......@@ -400,13 +333,15 @@ impl Scheduler {
}
}
let active_perc = kv_manager.get_active_perc();
// TODO: share the same logic with Planner
let decoding_time = -25.74 * active_perc.powi(2) + 54.01 * active_perc + 5.74;
total_time += Duration::from_secs_f64(decoding_time / 1000.0);
state.reset_active_tokens();
// Process decoding
let uuids: Vec<Uuid> = state.decode.keys().cloned().collect();
if !uuids.is_empty() {
should_schedule = true
};
for uuid in uuids {
let Some(sequence) = state.run(uuid) else {
continue;
......@@ -423,14 +358,6 @@ impl Scheduler {
continue;
}
// Drain KV events and forward to relay after decode signal processing
if let (Some(relay_tx), Some(rx)) = (&kv_events_tx, &mut block_resp_rx) {
while let Ok(event) = rx.try_recv() {
let _ = relay_tx
.send(block_response_to_kv_event(event, &sequence.block_hashes()));
}
}
// Check completion and send notification
let is_complete = sequence.generated_tokens() >= sequence.max_output_tokens();
let should_output =
......@@ -465,11 +392,13 @@ impl Scheduler {
let _ = metrics_tx.send(metrics);
}
// Sleep once for the adjusted duration
let adjusted_time =
// 4. Sleep to maintain target iteration timing
let target_duration =
Duration::from_secs_f64(total_time.as_secs_f64() / args.speedup_ratio);
if adjusted_time.as_millis() > 0 {
tokio::time::sleep(adjusted_time).await;
let elapsed = iteration_start.elapsed();
if elapsed < target_duration {
tokio::time::sleep(target_duration - elapsed).await;
}
}
});
......@@ -499,7 +428,7 @@ impl Scheduler {
fn get_fwd_pass_metrics(
state: &SchedulerState,
kv_manager: &KvManager,
hit_rates: &VecDeque<f32>,
hit_rates: &RunningMean<f32>,
dp_rank: u32,
) -> ForwardPassMetrics {
// Get state metrics
......@@ -507,7 +436,7 @@ fn get_fwd_pass_metrics(
let num_requests_waiting = state.waiting.len() as u64;
// Get KV manager metrics
let active_blocks_count = kv_manager.active_blocks().len() as u64;
let active_blocks_count = kv_manager.num_active_blocks() as u64;
let total_capacity = kv_manager.max_capacity() as u64;
let gpu_cache_usage_perc = if total_capacity > 0 {
active_blocks_count as f32 / total_capacity as f32
......@@ -515,13 +444,8 @@ fn get_fwd_pass_metrics(
0.0
};
// Get hit rate metrics
let gpu_prefix_cache_hit_rate = if hit_rates.is_empty() {
0.0
} else {
let sum: f32 = hit_rates.iter().sum();
sum / hit_rates.len() as f32
};
// Get hit rate metrics - O(1) access
let gpu_prefix_cache_hit_rate = hit_rates.mean();
let worker_stats = WorkerStats {
data_parallel_rank: Some(dp_rank),
......@@ -546,26 +470,75 @@ fn get_fwd_pass_metrics(
}
}
/// Convert a Request to an ActiveSequence
fn get_active_sequence(
request: Request,
block_size: usize,
enable_prefix_caching: bool,
) -> ActiveSequence {
if let Request::Active(active_seq) = request {
return active_seq;
}
/// Attempts to schedule waiting requests from the state queue.
/// Returns the number of requests successfully scheduled.
fn try_schedule(
state: &mut SchedulerState,
kv_manager: &KvManager,
hit_rates: &mut RunningMean<f32>,
args: &MockEngineArgs,
) -> usize {
let mut scheduled_count = 0;
let mut current_blocks = kv_manager.num_active_blocks();
let mut current_tokens = state.active_tokens + state.waiting_tokens;
let mut current_seqs = state.num_active_requests();
while let Some((uuid, request)) = state.next() {
// Convert Request to ActiveSequence
let active_sequence = match request {
Request::Active(active_seq) => active_seq,
Request::Direct(direct_request) => ActiveSequence::new(
direct_request.tokens,
direct_request.max_output_tokens,
Some(args.block_size),
args.enable_prefix_caching,
),
};
let Request::Direct(direct_request) = request else {
unreachable!("Request must be either Direct or Active");
};
// Update predictive budgets
let prefill_cost = kv_manager.get_prefill_cost(&active_sequence);
let total_tokens = active_sequence.len();
// this is conservative, assumes no cache hit so never over-schedules
let new_blocks = (total_tokens as u32).div_ceil(args.block_size as u32) as usize;
let new_tokens = prefill_cost.new_tokens;
current_blocks += new_blocks;
current_tokens += new_tokens;
current_seqs += 1;
// Check various budgets to see if possible to schedule
let under_block_budget =
current_blocks as f64 <= (1. - args.watermark) * kv_manager.max_capacity() as f64;
// If chunked prefill is enabled, we can be under token budget when scheduling
let comparison_tokens = if args.enable_chunked_prefill {
current_tokens - new_tokens
} else {
current_tokens
};
let under_token_budget = args
.max_num_batched_tokens
.is_none_or(|limit| comparison_tokens <= limit);
let under_seq_budget = args.max_num_seqs.is_none_or(|limit| current_seqs <= limit);
// Cannot schedule, put first in line instead
if !(under_block_budget && under_token_budget && under_seq_budget) {
state.first_in_line(uuid, Request::Active(active_sequence));
break;
}
// Compute and store hit rate
let hit_rate = if !active_sequence.is_empty() {
1.0 - (new_tokens as f32 / active_sequence.len() as f32)
} else {
0.0
};
hit_rates.push(hit_rate);
state.move_to_prefill(uuid, active_sequence, prefill_cost);
scheduled_count += 1;
}
ActiveSequence::new(
direct_request.tokens,
direct_request.max_output_tokens,
Some(block_size),
enable_prefix_caching,
)
scheduled_count
}
/// Processes MoveBlock signals with the KvManager.
......@@ -582,7 +555,7 @@ fn process_signals(kv_manager: &mut KvManager, signals: &[MoveBlock]) -> bool {
}
// Check we have a Use signal with blocks
let MoveBlock::Use(blocks) = signal else {
let MoveBlock::Use(blocks, _hashes) = signal else {
panic!(
"Failed signal is Invalid. Has to fail on generation signal, but failed on {signal:?}"
);
......
......@@ -6,12 +6,10 @@ use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{TokenBlockSequence, Tokens};
use derive_getters::Getters;
use rand::random;
use uuid;
/// Create unique blocks from a TokenBlockSequence
fn create_unique_blocks_from_sequence(
tokens: &TokenBlockSequence,
uuid: Option<uuid::Uuid>,
block_size: usize,
enable_prefix_caching: bool,
) -> Vec<UniqueBlock> {
......@@ -29,10 +27,7 @@ fn create_unique_blocks_from_sequence(
// Only push the partial block if tokens count isn't a multiple of block_size
if !tokens.total_tokens().is_multiple_of(block_size) {
unique_blocks.push(match uuid {
Some(uuid) => UniqueBlock::PartialBlock(uuid),
None => UniqueBlock::default(),
});
unique_blocks.push(UniqueBlock::default());
}
unique_blocks
}
......@@ -80,8 +75,9 @@ impl ActiveSequence {
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
let unique_blocks =
create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone()));
create_unique_blocks_from_sequence(&tokens, block_size, enable_prefix_caching);
let block_hashes = tokens.blocks().iter().map(|b| b.block_hash()).collect();
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone(), block_hashes));
Self {
unique_blocks,
......@@ -132,17 +128,6 @@ impl ActiveSequence {
(sequence, signal)
}
/// Get the parent hash from the second-to-last block if it exists and is a FullBlock
fn get_parent_hash(&self) -> Option<u64> {
if self.unique_blocks.len() < 2 {
return None;
}
match &self.unique_blocks[self.unique_blocks.len() - 2] {
UniqueBlock::FullBlock(hash) => Some(*hash),
_ => panic!("Cannot have a partial block as parent"),
}
}
/// Push a token to the sequence
pub fn push(&mut self, token: u32) -> Option<Vec<MoveBlock>> {
self.tokens.append(token).expect("Token push failed.");
......@@ -158,24 +143,33 @@ impl ActiveSequence {
// Replace last partial block with full block if it exists
if let Some(UniqueBlock::PartialBlock(uuid)) = self.unique_blocks.last().cloned() {
let last_block_hash = if self.enable_prefix_caching {
let last_seq_hash = if self.enable_prefix_caching {
self.tokens.last_complete_block().unwrap().sequence_hash()
} else {
random::<u64>()
};
let last_block_hash = self.tokens.last_complete_block().unwrap().block_hash();
self.unique_blocks.pop();
// After pop, the last element is the parent block
let second_to_last_hash = self.unique_blocks.last().map(|block| match block {
UniqueBlock::FullBlock(hash) => *hash,
UniqueBlock::PartialBlock(_) => panic!("Cannot have a partial block as parent"),
});
self.unique_blocks
.push(UniqueBlock::FullBlock(last_block_hash));
.push(UniqueBlock::FullBlock(last_seq_hash));
signals.push(MoveBlock::Promote(
uuid,
last_seq_hash,
second_to_last_hash,
last_block_hash,
self.get_parent_hash(),
));
}
let new_partial_block = UniqueBlock::default();
self.unique_blocks.push(new_partial_block.clone());
signals.push(MoveBlock::Use(vec![new_partial_block]));
signals.push(MoveBlock::Use(vec![new_partial_block], vec![]));
Some(signals)
}
......@@ -241,13 +235,15 @@ impl ActiveSequence {
self.tokens.truncate(self.num_input_tokens).unwrap();
self.unique_blocks = create_unique_blocks_from_sequence(
&self.tokens,
None,
self.block_size,
self.enable_prefix_caching,
);
self.already_generated_tokens = self.generated_tokens.max(self.already_generated_tokens);
self.generated_tokens = 0;
self.creation_signal = Some(MoveBlock::Use(self.unique_blocks.clone()));
self.creation_signal = Some(MoveBlock::Use(
self.unique_blocks.clone(),
self.block_hashes(),
));
free_signal
}
......@@ -280,7 +276,7 @@ mod tests {
// Check that we got a Use signal
assert!(signal1.is_some());
match &signal1 {
Some(MoveBlock::Use(blocks)) => {
Some(MoveBlock::Use(blocks, _hashes)) => {
assert_eq!(blocks.len(), 1);
}
_ => panic!("Expected Use signal"),
......@@ -301,7 +297,7 @@ mod tests {
// First signal should be Promote for the previous block
match &signal_16[0] {
MoveBlock::Promote(_, _, parent_hash) => {
MoveBlock::Promote(_, _, parent_hash, _hash) => {
assert_eq!(*parent_hash, None);
}
_ => panic!("Expected Promote signal as second signal"),
......@@ -309,7 +305,7 @@ mod tests {
// Second signal should be Use for new partial block
match &signal_16[1] {
MoveBlock::Use(blocks) => {
MoveBlock::Use(blocks, _hashes) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
......@@ -396,7 +392,7 @@ mod tests {
// Check that signal[0] is promote
match &signal[0] {
MoveBlock::Promote(_, _, parent_hash) => {
MoveBlock::Promote(_, _, parent_hash, _hash) => {
// Check that the parent_hash matches unique_blocks[1], which should be a full block
if let UniqueBlock::FullBlock(expected_hash) = seq1.unique_blocks()[1] {
assert_eq!(
......@@ -430,7 +426,7 @@ mod tests {
// Initial signal - should have received a Use signal for the partial block
assert!(signal.is_some());
match signal {
Some(MoveBlock::Use(blocks)) => {
Some(MoveBlock::Use(blocks, _hashes)) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
......@@ -448,7 +444,7 @@ mod tests {
// First signal should be Promote
match &signals_second[0] {
MoveBlock::Promote(_, _, parent_hash) => {
MoveBlock::Promote(_, _, parent_hash, _hash) => {
assert_eq!(*parent_hash, None);
}
_ => panic!("Expected Promote signal as first signal after second token"),
......@@ -456,7 +452,7 @@ mod tests {
// Second signal should be Use for new partial block
match &signals_second[1] {
MoveBlock::Use(blocks) => {
MoveBlock::Use(blocks, _hashes) => {
assert_eq!(blocks.len(), 1);
assert!(matches!(blocks[0], UniqueBlock::PartialBlock(_)));
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment