Unverified commit 803bfa81, authored by Yan Ru Pei, committed by GitHub
Browse files

feat: proper local hashes for mockers + router watches endpoints (#2132)

parent e82bc4ec
...@@ -27,9 +27,9 @@ use clap::Parser; ...@@ -27,9 +27,9 @@ use clap::Parser;
use dynamo_llm::kv_router::{ use dynamo_llm::kv_router::{
protocols::WorkerSelectionResult, protocols::WorkerSelectionResult,
scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest}, scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
KvRouter, WorkerSelector, KvRouter, WorkerSelector,
}; };
use dynamo_runtime::component::Instance;
use dynamo_runtime::{ use dynamo_runtime::{
logging, pipeline::network::Ingress, DistributedRuntime, Result, Runtime, Worker, logging, pipeline::network::Ingress, DistributedRuntime, Result, Runtime, Worker,
}; };
...@@ -86,7 +86,7 @@ pub struct CustomWorkerSelector(DefaultWorkerSelector); ...@@ -86,7 +86,7 @@ pub struct CustomWorkerSelector(DefaultWorkerSelector);
impl WorkerSelector for CustomWorkerSelector { impl WorkerSelector for CustomWorkerSelector {
fn select_worker( fn select_worker(
&self, &self,
workers: &ProcessedEndpoints, workers: &[Instance],
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
......
...@@ -34,7 +34,7 @@ use crate::{ ...@@ -34,7 +34,7 @@ use crate::{
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError, compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
OverlapScores, RouterEvent, OverlapScores, RouterEvent,
}, },
metrics_aggregator::EndpointCollector, // metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
...@@ -43,6 +43,7 @@ use crate::{ ...@@ -43,6 +43,7 @@ use crate::{
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
}; };
use dynamo_runtime::component::Instance;
use dynamo_runtime::traits::events::EventSubscriber; use dynamo_runtime::traits::events::EventSubscriber;
// [gluo TODO] shouldn't need to be public // [gluo TODO] shouldn't need to be public
...@@ -55,7 +56,7 @@ pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; ...@@ -55,7 +56,7 @@ pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
pub trait WorkerSelector { pub trait WorkerSelector {
fn select_worker( fn select_worker(
&self, &self,
workers: &ProcessedEndpoints, workers: &[Instance],
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>; ) -> Result<WorkerSelectionResult, KvSchedulerError>;
...@@ -151,8 +152,16 @@ impl KvRouter { ...@@ -151,8 +152,16 @@ impl KvRouter {
.primary_lease() .primary_lease()
.expect("Cannot KV route static workers") .expect("Cannot KV route static workers")
.primary_token(); .primary_token();
let metrics_aggregator =
EndpointCollector::new(component.clone(), cancellation_token.clone()).await; let generate_endpoint = component.endpoint("generate");
let client = generate_endpoint.client().await?;
let instances_rx = match client.instance_source.as_ref() {
InstanceSource::Dynamic(rx) => rx.clone(),
InstanceSource::Static => {
panic!("Expected dynamic instance source for KV routing");
}
};
let indexer = if use_kv_events { let indexer = if use_kv_events {
Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size)) Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
...@@ -168,7 +177,7 @@ impl KvRouter { ...@@ -168,7 +177,7 @@ impl KvRouter {
let scheduler = KvScheduler::start( let scheduler = KvScheduler::start(
component.namespace().clone(), component.namespace().clone(),
block_size, block_size,
metrics_aggregator.endpoints_watcher(), instances_rx,
selector, selector,
) )
.await?; .await?;
...@@ -325,6 +334,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -325,6 +334,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let isl = backend_input.token_ids.len(); let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount); backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input); let updated_request = context.map(|_| backend_input);
// if request has the annotation "query_instance_id", for example // if request has the annotation "query_instance_id", for example
// curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}' // curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}'
// request will not be routed to worker immediately // request will not be routed to worker immediately
...@@ -333,61 +343,59 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -333,61 +343,59 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let response = let response =
Annotated::from_annotation("worker_instance_id", &instance_id_str)?; Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
let stream = stream::iter(vec![response]); let stream = stream::iter(vec![response]);
Ok(ResponseStream::new(Box::pin(stream), stream_context)) return Ok(ResponseStream::new(Box::pin(stream), stream_context));
} else { }
// Get the response stream from the worker // Get the response stream from the worker
let mut response_stream = let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens
// Wrap the stream to track tokens let stream_context = response_stream.context();
let stream_context = response_stream.context(); let chooser = self.chooser.clone();
let chooser = self.chooser.clone(); let request_id = context_id.clone();
let request_id = context_id.clone(); let block_size = chooser.block_size() as usize;
let block_size = chooser.block_size() as usize;
let wrapped_stream = Box::pin(async_stream::stream! {
let wrapped_stream = Box::pin(async_stream::stream! { let mut accumulated_tokens = Vec::new();
let mut accumulated_tokens = Vec::new(); let mut total_output_length = 0usize;
let mut total_output_length = 0usize; let mut last_block_index = (isl.saturating_sub(1)) / block_size;
let mut last_block_index = (isl.saturating_sub(1)) / block_size; let mut first_push_done = false;
let mut first_push_done = false;
while let Some(item) = response_stream.next().await {
while let Some(item) = response_stream.next().await { // Track tokens if they exist in the response
// Track tokens if they exist in the response let Some(ref output) = item.data else {
let Some(ref output) = item.data else { yield item;
yield item; continue;
continue; };
}; if output.token_ids.is_empty() {
if output.token_ids.is_empty() { yield item;
yield item; continue;
continue; }
}
// Add tokens to accumulator // Add tokens to accumulator
accumulated_tokens.extend_from_slice(&output.token_ids); accumulated_tokens.extend_from_slice(&output.token_ids);
total_output_length += output.token_ids.len(); total_output_length += output.token_ids.len();
// Always push for the first generated token (to mark prefill done) // Always push for the first generated token (to mark prefill done)
// or when we've moved to a new block // or when we've moved to a new block
let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size; let current_block_index = (isl + total_output_length).saturating_sub(1) / block_size;
let should_push = (!first_push_done && total_output_length >= 1) || let should_push = (!first_push_done && total_output_length >= 1) ||
(first_push_done && current_block_index > last_block_index); (first_push_done && current_block_index > last_block_index);
if should_push { if should_push {
chooser.push(&request_id, &accumulated_tokens).await; chooser.push(&request_id, &accumulated_tokens).await;
accumulated_tokens.clear(); accumulated_tokens.clear();
last_block_index = current_block_index; last_block_index = current_block_index;
if !first_push_done { if !first_push_done {
first_push_done = true; first_push_done = true;
}
} }
yield item;
} }
chooser.free(&request_id).await; yield item;
}); }
Ok(ResponseStream::new(wrapped_stream, stream_context))
} chooser.free(&request_id).await;
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
} }
} }
} }
......
...@@ -26,11 +26,11 @@ use super::protocols::WorkerSelectionResult; ...@@ -26,11 +26,11 @@ use super::protocols::WorkerSelectionResult;
use super::WorkerSelector; use super::WorkerSelector;
use crate::kv_router::indexer::OverlapScores; use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::protocols::LoadMetrics; use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::sequence::ActiveSequencesMultiWorker; use crate::kv_router::sequence::ActiveSequencesMultiWorker;
use crate::kv_router::KvRouterConfig; use crate::kv_router::KvRouterConfig;
use crate::kv_router::KV_HIT_RATE_SUBJECT; use crate::kv_router::KV_HIT_RATE_SUBJECT;
use crate::tokens::TokenBlockSequence; use crate::tokens::TokenBlockSequence;
use dynamo_runtime::component::Instance;
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent { pub struct KVHitRateEvent {
...@@ -107,12 +107,14 @@ impl KvScheduler { ...@@ -107,12 +107,14 @@ impl KvScheduler {
pub async fn start( pub async fn start(
ns: Namespace, ns: Namespace,
block_size: u32, block_size: u32,
endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>, mut instances_rx: tokio::sync::watch::Receiver<Vec<Instance>>, // Changed from ProcessedEndpoints
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let mut endpoints_rx = endpoints_rx; let mut instances: Vec<Instance> = instances_rx.borrow_and_update().clone();
let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone();
// Get worker IDs from instances
let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>(); let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
tokio::spawn(async move { tokio::spawn(async move {
...@@ -126,7 +128,7 @@ impl KvScheduler { ...@@ -126,7 +128,7 @@ impl KvScheduler {
let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new( let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new(
block_size as usize, block_size as usize,
endpoints.worker_ids(), worker_ids,
))); )));
// Channel to accept new scheduling requests // Channel to accept new scheduling requests
...@@ -142,9 +144,10 @@ impl KvScheduler { ...@@ -142,9 +144,10 @@ impl KvScheduler {
request = tokio::select! { request = tokio::select! {
biased; biased;
_ = endpoints_rx.changed() => { _ = instances_rx.changed() => {
endpoints = endpoints_rx.borrow_and_update().clone(); instances = instances_rx.borrow_and_update().clone();
pending_endpoint_update = Some(endpoints.worker_ids()); let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
pending_endpoint_update = Some(worker_ids);
continue 'outer; continue 'outer;
} }
...@@ -159,7 +162,8 @@ impl KvScheduler { ...@@ -159,7 +162,8 @@ impl KvScheduler {
}; };
loop { loop {
match selector.select_worker(&endpoints, &request, block_size) { // When calling selector.select_worker, we need to adapt
match selector.select_worker(&instances, &request, block_size) {
Ok(selection) => { Ok(selection) => {
if let Err(e) = event_tx.send(KVHitRateEvent { if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: selection.worker_id, worker_id: selection.worker_id,
...@@ -179,9 +183,11 @@ impl KvScheduler { ...@@ -179,9 +183,11 @@ impl KvScheduler {
} }
Err(KvSchedulerError::NoEndpoints) => { Err(KvSchedulerError::NoEndpoints) => {
tracing::trace!("no endpoints available; waiting for endpoints update"); tracing::trace!("no endpoints available; waiting for endpoints update");
endpoints_rx.changed().await.ok(); instances_rx.changed().await.ok();
endpoints = endpoints_rx.borrow_and_update().clone(); instances = instances_rx.borrow_and_update().clone();
pending_endpoint_update = Some(endpoints.worker_ids()); let worker_ids: Vec<i64> =
instances.iter().map(|i| i.instance_id).collect();
pending_endpoint_update = Some(worker_ids);
continue; continue;
} }
// TODO: this is not actually hooked up // TODO: this is not actually hooked up
...@@ -353,13 +359,13 @@ impl DefaultWorkerSelector { ...@@ -353,13 +359,13 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector { impl WorkerSelector for DefaultWorkerSelector {
fn select_worker( fn select_worker(
&self, &self,
workers: &ProcessedEndpoints, workers: &[Instance],
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0); assert!(request.isl_tokens > 0);
if workers.endpoints.is_empty() { if workers.is_empty() {
return Err(KvSchedulerError::NoEndpoints); return Err(KvSchedulerError::NoEndpoints);
} }
...@@ -376,9 +382,10 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -376,9 +382,10 @@ impl WorkerSelector for DefaultWorkerSelector {
let mut max_logit = f64::NEG_INFINITY; let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker // Calculate logits for each worker
for (worker_id, _) in workers.endpoints.iter() { for instance in workers.iter() {
let worker_id = instance.instance_id;
// this is the number of tokens each worker would have if the request were scheduled there // this is the number of tokens each worker would have if the request were scheduled there
let potential_tokens = *potential_active_tokens.get(worker_id).unwrap_or_else(|| { let potential_tokens = *potential_active_tokens.get(&worker_id).unwrap_or_else(|| {
tracing::warn!( tracing::warn!(
"assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet" "assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
); );
...@@ -386,7 +393,7 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -386,7 +393,7 @@ impl WorkerSelector for DefaultWorkerSelector {
}) as f64; }) as f64;
// this is the number of blocks each worker would have if the request were scheduled there // this is the number of blocks each worker would have if the request were scheduled there
let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(|| let potential_blocks = *potential_active_blocks.get(&worker_id).unwrap_or_else(||
{tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet"); {tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet");
&request_blocks &request_blocks
}) as f64; }) as f64;
...@@ -398,12 +405,12 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -398,12 +405,12 @@ impl WorkerSelector for DefaultWorkerSelector {
+ potential_blocks; + potential_blocks;
max_logit = max_logit.max(logit); max_logit = max_logit.max(logit);
worker_logits.insert(*worker_id, logit); worker_logits.insert(worker_id, logit);
tracing::info!( tracing::info!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})", "Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})",
self.kv_router_config.overlap_score_weight, self.kv_router_config.overlap_score_weight,
overlaps.get(worker_id).unwrap_or(&0), overlaps.get(&worker_id).unwrap_or(&0),
); );
} }
......
...@@ -24,9 +24,8 @@ use crate::kv_router::protocols::{ ...@@ -24,9 +24,8 @@ use crate::kv_router::protocols::{
KvCacheStoredBlockData, LocalBlockHash, KvCacheStoredBlockData, LocalBlockHash,
}; };
use crate::tokens::blocks::UniqueBlock; use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{BlockHash, SequenceHash, Token};
pub type Token = u32;
pub type GlobalHash = u64;
pub type NumBlocks = usize; pub type NumBlocks = usize;
/// Represents different block movement operations in the cache /// Represents different block movement operations in the cache
...@@ -36,13 +35,13 @@ pub enum MoveBlock { ...@@ -36,13 +35,13 @@ pub enum MoveBlock {
Use(Vec<UniqueBlock>), Use(Vec<UniqueBlock>),
Destroy(Vec<UniqueBlock>), Destroy(Vec<UniqueBlock>),
Deref(Vec<UniqueBlock>), Deref(Vec<UniqueBlock>),
Promote(Uuid, GlobalHash, Option<u64>), Promote(Uuid, SequenceHash, Option<u64>),
} }
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse { pub enum MoveBlockResponse {
Store(Vec<GlobalHash>, Option<u64>), Store(Vec<SequenceHash>, Option<u64>),
Remove(Vec<GlobalHash>), Remove(Vec<SequenceHash>),
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
...@@ -222,18 +221,36 @@ impl MockEngineArgs { ...@@ -222,18 +221,36 @@ impl MockEngineArgs {
} }
} }
/// Note: This assumes block_hash and tokens_hash are the same, which is not correct in rare cases /// Converts a MoveBlockResponse from the mocker backend into a KvCacheEventData.
/// where the sequence-aware hash differs from the token content hash. ///
pub fn block_response_to_kv_event(response: MoveBlockResponse) -> KvCacheEventData { /// This function assumes that the stored sequence hashes in the response always
/// correspond to the tail part of the local hashes array. This is the expected
/// behavior of KV block storage, where blocks are stored sequentially and the
/// response contains the most recent blocks that were stored.
///
/// # Panics
/// Panics if the number of blocks in the Store response exceeds the length
/// of local_hashes.
pub fn block_response_to_kv_event(
response: MoveBlockResponse,
local_hashes: &[BlockHash],
) -> KvCacheEventData {
match response { match response {
MoveBlockResponse::Store(full_blocks, parent_hash) => { MoveBlockResponse::Store(full_blocks, parent_hash) => {
let num_blocks = full_blocks.len();
let local_hashes_slice = &local_hashes[local_hashes
.len()
.checked_sub(num_blocks)
.expect("local hashes fewer than block response signal")..];
KvCacheEventData::Stored(KvCacheStoreData { KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash), parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: full_blocks blocks: full_blocks
.into_iter() .into_iter()
.map(|block| KvCacheStoredBlockData { .zip(local_hashes_slice.iter())
block_hash: ExternalSequenceBlockHash(block), .map(|(global_hash, local_hash)| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(block), block_hash: ExternalSequenceBlockHash(global_hash),
tokens_hash: LocalBlockHash(*local_hash),
}) })
.collect(), .collect(),
}) })
......
...@@ -47,6 +47,7 @@ use crate::mocker::protocols::{block_response_to_kv_event, MoveBlock, OutputSign ...@@ -47,6 +47,7 @@ use crate::mocker::protocols::{block_response_to_kv_event, MoveBlock, OutputSign
use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse}; use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse};
use crate::mocker::sequence::ActiveSequence; use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock; use crate::tokens::blocks::UniqueBlock;
use crate::tokens::BlockHash;
use std::collections::HashMap; use std::collections::HashMap;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::sync::Arc; use std::sync::Arc;
...@@ -123,8 +124,9 @@ impl SchedulerState { ...@@ -123,8 +124,9 @@ impl SchedulerState {
/// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where: /// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where:
/// - `prefill_compute`: The compute time in milliseconds for this prefill operation /// - `prefill_compute`: The compute time in milliseconds for this prefill operation
/// - `creation_signal`: Optional MoveBlock signal for KV cache block creation /// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
/// - `block_hashes`: Block hashes of the sequence being prefilled
/// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked /// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, bool)> { fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, Vec<BlockHash>, bool)> {
let uuid = self.prefill.pop_front()?; let uuid = self.prefill.pop_front()?;
// Remove and extract prefill_compute from prefill_costs // Remove and extract prefill_compute from prefill_costs
...@@ -179,6 +181,7 @@ impl SchedulerState { ...@@ -179,6 +181,7 @@ impl SchedulerState {
Some(( Some((
prefill_compute, prefill_compute,
sequence.take_creation_signal(), sequence.take_creation_signal(),
sequence.block_hashes(),
is_full_prefill, is_full_prefill,
)) ))
} }
...@@ -401,8 +404,12 @@ impl Scheduler { ...@@ -401,8 +404,12 @@ impl Scheduler {
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0); let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
// Process prefilling // Process prefilling
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) = while let Some((
state_guard.try_prefill() prefill_compute,
maybe_creation_signal,
block_hashes,
is_full_prefill,
)) = state_guard.try_prefill()
{ {
// NOTE: Prefill cost/time is always incremented for new blocks, even if they // NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior. // could be cached by other requests in the same batch. This matches vLLM behavior.
...@@ -421,7 +428,8 @@ impl Scheduler { ...@@ -421,7 +428,8 @@ impl Scheduler {
(&kv_events_tx, &mut block_resp_rx) (&kv_events_tx, &mut block_resp_rx)
{ {
while let Ok(event) = rx.try_recv() { while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event)); let _ =
relay_tx.send(block_response_to_kv_event(event, &block_hashes));
} }
} }
}; };
...@@ -460,7 +468,8 @@ impl Scheduler { ...@@ -460,7 +468,8 @@ impl Scheduler {
(&kv_events_tx, &mut block_resp_rx) (&kv_events_tx, &mut block_resp_rx)
{ {
while let Ok(event) = rx.try_recv() { while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event)); let _ = relay_tx
.send(block_response_to_kv_event(event, &sequence.block_hashes()));
} }
} }
......
...@@ -90,7 +90,7 @@ impl ActiveSequence { ...@@ -90,7 +90,7 @@ impl ActiveSequence {
assert!(block_size > 1, "block_size must be greater than 1"); assert!(block_size > 1, "block_size must be greater than 1");
let num_input_tokens = tokens.len(); let num_input_tokens = tokens.len();
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, None); let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
let unique_blocks = let unique_blocks =
create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching); create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone())); let creation_signal = Some(MoveBlock::Use(unique_blocks.clone()));
...@@ -124,6 +124,14 @@ impl ActiveSequence { ...@@ -124,6 +124,14 @@ impl ActiveSequence {
self.creation_signal.take() self.creation_signal.take()
} }
/// Collects the hash of each block exposed by `self.tokens.blocks()`.
///
/// The returned vector holds one entry per block, in the same order the
/// blocks are iterated.
pub fn block_hashes(&self) -> Vec<u64> {
    let mut hashes = Vec::new();
    for block in self.tokens.blocks().iter() {
        hashes.push(block.block_hash());
    }
    hashes
}
/// Create a new ActiveSequence instance and return the creation signal /// Create a new ActiveSequence instance and return the creation signal
pub fn new_with_signal( pub fn new_with_signal(
tokens: Vec<u32>, tokens: Vec<u32>,
......
...@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__) ...@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
MODEL_NAME = "Qwen/Qwen3-0.6B" MODEL_NAME = "Qwen/Qwen3-0.6B"
NUM_MOCKERS = 2 NUM_MOCKERS = 2
BLOCK_SIZE = 16
SPEEDUP_RATIO = 10.0 SPEEDUP_RATIO = 10.0
NUM_REQUESTS = 100 NUM_REQUESTS = 100
PORT = 8090 # Starting port for mocker instances PORT = 8090 # Starting port for mocker instances
...@@ -59,6 +60,8 @@ class KVRouterProcess(ManagedProcess): ...@@ -59,6 +60,8 @@ class KVRouterProcess(ManagedProcess):
"python", "python",
"-m", "-m",
"dynamo.frontend", "dynamo.frontend",
"--kv-cache-block-size",
str(BLOCK_SIZE),
"--router-mode", "--router-mode",
"kv", "kv",
"--http-port", "--http-port",
...@@ -100,7 +103,7 @@ def test_mocker_kv_router(request, runtime_services): ...@@ -100,7 +103,7 @@ def test_mocker_kv_router(request, runtime_services):
logger.info("Starting mocker KV router test") logger.info("Starting mocker KV router test")
# Create mocker args file # Create mocker args file
mocker_args = {"speedup_ratio": SPEEDUP_RATIO} mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
mocker_args_file = os.path.join(request.node.name, "mocker_args.json") mocker_args_file = os.path.join(request.node.name, "mocker_args.json")
with open(mocker_args_file, "w") as f: with open(mocker_args_file, "w") as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment