Unverified Commit 803bfa81 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: proper local hashes for mockers + router watches endpoints (#2132)

parent e82bc4ec
......@@ -27,9 +27,9 @@ use clap::Parser;
use dynamo_llm::kv_router::{
protocols::WorkerSelectionResult,
scheduler::{DefaultWorkerSelector, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
KvRouter, WorkerSelector,
};
use dynamo_runtime::component::Instance;
use dynamo_runtime::{
logging, pipeline::network::Ingress, DistributedRuntime, Result, Runtime, Worker,
};
......@@ -86,7 +86,7 @@ pub struct CustomWorkerSelector(DefaultWorkerSelector);
impl WorkerSelector for CustomWorkerSelector {
fn select_worker(
&self,
workers: &ProcessedEndpoints,
workers: &[Instance],
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
......
......@@ -34,7 +34,7 @@ use crate::{
compute_block_hash_for_seq, KvIndexer, KvIndexerInterface, KvRouterError,
OverlapScores, RouterEvent,
},
metrics_aggregator::EndpointCollector,
// metrics_aggregator::EndpointCollector,
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
......@@ -43,6 +43,7 @@ use crate::{
protocols::common::llm_backend::LLMEngineOutput,
};
use dynamo_runtime::component::Instance;
use dynamo_runtime::traits::events::EventSubscriber;
// [gluo TODO] shouldn't need to be public
......@@ -55,7 +56,7 @@ pub const KV_METRICS_ENDPOINT: &str = "load_metrics";
pub trait WorkerSelector {
fn select_worker(
&self,
workers: &ProcessedEndpoints,
workers: &[Instance],
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError>;
......@@ -151,8 +152,16 @@ impl KvRouter {
.primary_lease()
.expect("Cannot KV route static workers")
.primary_token();
let metrics_aggregator =
EndpointCollector::new(component.clone(), cancellation_token.clone()).await;
let generate_endpoint = component.endpoint("generate");
let client = generate_endpoint.client().await?;
let instances_rx = match client.instance_source.as_ref() {
InstanceSource::Dynamic(rx) => rx.clone(),
InstanceSource::Static => {
panic!("Expected dynamic instance source for KV routing");
}
};
let indexer = if use_kv_events {
Indexer::KvIndexer(KvIndexer::new(cancellation_token.clone(), block_size))
......@@ -168,7 +177,7 @@ impl KvRouter {
let scheduler = KvScheduler::start(
component.namespace().clone(),
block_size,
metrics_aggregator.endpoints_watcher(),
instances_rx,
selector,
)
.await?;
......@@ -325,6 +334,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let isl = backend_input.token_ids.len();
backend_input.estimated_prefix_hit_num_blocks = Some(overlap_amount);
let updated_request = context.map(|_| backend_input);
// if request has the annotation "query_instance_id", for example
// curl -d '{... ,"nvext": { "annotations": ["query_instance_id"]}}'
// request will not be routed to worker immediately
......@@ -333,11 +343,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
let response =
Annotated::from_annotation("worker_instance_id", &instance_id_str)?;
let stream = stream::iter(vec![response]);
Ok(ResponseStream::new(Box::pin(stream), stream_context))
} else {
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
// Get the response stream from the worker
let mut response_stream =
self.inner.direct(updated_request, instance_id).await?;
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
// Wrap the stream to track tokens
let stream_context = response_stream.context();
......@@ -390,5 +399,4 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
}
}
}
}
......@@ -26,11 +26,11 @@ use super::protocols::WorkerSelectionResult;
use super::WorkerSelector;
use crate::kv_router::indexer::OverlapScores;
use crate::kv_router::protocols::LoadMetrics;
use crate::kv_router::scoring::ProcessedEndpoints;
use crate::kv_router::sequence::ActiveSequencesMultiWorker;
use crate::kv_router::KvRouterConfig;
use crate::kv_router::KV_HIT_RATE_SUBJECT;
use crate::tokens::TokenBlockSequence;
use dynamo_runtime::component::Instance;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVHitRateEvent {
......@@ -107,12 +107,14 @@ impl KvScheduler {
pub async fn start(
ns: Namespace,
block_size: u32,
endpoints_rx: tokio::sync::watch::Receiver<ProcessedEndpoints>,
mut instances_rx: tokio::sync::watch::Receiver<Vec<Instance>>, // Changed from ProcessedEndpoints
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let mut endpoints_rx = endpoints_rx;
let mut endpoints: ProcessedEndpoints = endpoints_rx.borrow_and_update().clone();
let mut instances: Vec<Instance> = instances_rx.borrow_and_update().clone();
// Get worker IDs from instances
let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<KVHitRateEvent>();
tokio::spawn(async move {
......@@ -126,7 +128,7 @@ impl KvScheduler {
let sequences = Arc::new(Mutex::new(ActiveSequencesMultiWorker::new(
block_size as usize,
endpoints.worker_ids(),
worker_ids,
)));
// Channel to accept new scheduling requests
......@@ -142,9 +144,10 @@ impl KvScheduler {
request = tokio::select! {
biased;
_ = endpoints_rx.changed() => {
endpoints = endpoints_rx.borrow_and_update().clone();
pending_endpoint_update = Some(endpoints.worker_ids());
_ = instances_rx.changed() => {
instances = instances_rx.borrow_and_update().clone();
let worker_ids: Vec<i64> = instances.iter().map(|i| i.instance_id).collect();
pending_endpoint_update = Some(worker_ids);
continue 'outer;
}
......@@ -159,7 +162,8 @@ impl KvScheduler {
};
loop {
match selector.select_worker(&endpoints, &request, block_size) {
// When calling selector.select_worker, we need to adapt
match selector.select_worker(&instances, &request, block_size) {
Ok(selection) => {
if let Err(e) = event_tx.send(KVHitRateEvent {
worker_id: selection.worker_id,
......@@ -179,9 +183,11 @@ impl KvScheduler {
}
Err(KvSchedulerError::NoEndpoints) => {
tracing::trace!("no endpoints available; waiting for endpoints update");
endpoints_rx.changed().await.ok();
endpoints = endpoints_rx.borrow_and_update().clone();
pending_endpoint_update = Some(endpoints.worker_ids());
instances_rx.changed().await.ok();
instances = instances_rx.borrow_and_update().clone();
let worker_ids: Vec<i64> =
instances.iter().map(|i| i.instance_id).collect();
pending_endpoint_update = Some(worker_ids);
continue;
}
// TODO: this is not actually hooked up
......@@ -353,13 +359,13 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector {
fn select_worker(
&self,
workers: &ProcessedEndpoints,
workers: &[Instance],
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
assert!(request.isl_tokens > 0);
if workers.endpoints.is_empty() {
if workers.is_empty() {
return Err(KvSchedulerError::NoEndpoints);
}
......@@ -376,9 +382,10 @@ impl WorkerSelector for DefaultWorkerSelector {
let mut max_logit = f64::NEG_INFINITY;
// Calculate logits for each worker
for (worker_id, _) in workers.endpoints.iter() {
for instance in workers.iter() {
let worker_id = instance.instance_id;
// this is the number of tokens each worker would have if the request were scheduled there
let potential_tokens = *potential_active_tokens.get(worker_id).unwrap_or_else(|| {
let potential_tokens = *potential_active_tokens.get(&worker_id).unwrap_or_else(|| {
tracing::warn!(
"assuming {isl} tokens for {worker_id}, as the endpoint does not exist yet"
);
......@@ -386,7 +393,7 @@ impl WorkerSelector for DefaultWorkerSelector {
}) as f64;
// this is the number of blocks each worker would have if the request were scheduled there
let potential_blocks = *potential_active_blocks.get(worker_id).unwrap_or_else(||
let potential_blocks = *potential_active_blocks.get(&worker_id).unwrap_or_else(||
{tracing::warn!("assuming {request_blocks} decoding blocks for {worker_id}, as the endpoint does not exist yet");
&request_blocks
}) as f64;
......@@ -398,12 +405,12 @@ impl WorkerSelector for DefaultWorkerSelector {
+ potential_blocks;
max_logit = max_logit.max(logit);
worker_logits.insert(*worker_id, logit);
worker_logits.insert(worker_id, logit);
tracing::info!(
"Formula for {worker_id}: {logit:.3} = {:.1} * {potential_prefill_blocks:.3} + {potential_blocks:.3} (cached_blocks: {})",
self.kv_router_config.overlap_score_weight,
overlaps.get(worker_id).unwrap_or(&0),
overlaps.get(&worker_id).unwrap_or(&0),
);
}
......
......@@ -24,9 +24,8 @@ use crate::kv_router::protocols::{
KvCacheStoredBlockData, LocalBlockHash,
};
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::{BlockHash, SequenceHash, Token};
pub type Token = u32;
pub type GlobalHash = u64;
pub type NumBlocks = usize;
/// Represents different block movement operations in the cache
......@@ -36,13 +35,13 @@ pub enum MoveBlock {
Use(Vec<UniqueBlock>),
Destroy(Vec<UniqueBlock>),
Deref(Vec<UniqueBlock>),
Promote(Uuid, GlobalHash, Option<u64>),
Promote(Uuid, SequenceHash, Option<u64>),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MoveBlockResponse {
Store(Vec<GlobalHash>, Option<u64>),
Remove(Vec<GlobalHash>),
Store(Vec<SequenceHash>, Option<u64>),
Remove(Vec<SequenceHash>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
......@@ -222,18 +221,36 @@ impl MockEngineArgs {
}
}
/// Note: This assumes block_hash and tokens_hash are the same, which is not correct in rare cases
/// where the sequence-aware hash differs from the token content hash.
pub fn block_response_to_kv_event(response: MoveBlockResponse) -> KvCacheEventData {
/// Converts a MoveBlockResponse from the mocker backend into a KvCacheEventData.
///
/// This function assumes that the stored sequence hashes in the response always
/// correspond to the tail part of the local hashes array. This is the expected
/// behavior of KV block storage, where blocks are stored sequentially and the
/// response contains the most recent blocks that were stored.
///
/// # Panics
/// Panics if the number of blocks in the Store response exceeds the length
/// of local_hashes.
pub fn block_response_to_kv_event(
response: MoveBlockResponse,
local_hashes: &[BlockHash],
) -> KvCacheEventData {
match response {
MoveBlockResponse::Store(full_blocks, parent_hash) => {
let num_blocks = full_blocks.len();
let local_hashes_slice = &local_hashes[local_hashes
.len()
.checked_sub(num_blocks)
.expect("local hashes fewer than block response signal")..];
KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: parent_hash.map(ExternalSequenceBlockHash),
blocks: full_blocks
.into_iter()
.map(|block| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block),
tokens_hash: LocalBlockHash(block),
.zip(local_hashes_slice.iter())
.map(|(global_hash, local_hash)| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(global_hash),
tokens_hash: LocalBlockHash(*local_hash),
})
.collect(),
})
......
......@@ -47,6 +47,7 @@ use crate::mocker::protocols::{block_response_to_kv_event, MoveBlock, OutputSign
use crate::mocker::protocols::{DirectRequest, MockEngineArgs, MoveBlockResponse};
use crate::mocker::sequence::ActiveSequence;
use crate::tokens::blocks::UniqueBlock;
use crate::tokens::BlockHash;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
......@@ -123,8 +124,9 @@ impl SchedulerState {
/// Returns `Some((prefill_compute, creation_signal, is_full_prefill))` where:
/// - `prefill_compute`: The compute time in milliseconds for this prefill operation
/// - `creation_signal`: Optional MoveBlock signal for KV cache block creation
/// - `block_hashes`: Block hashes of the sequence beign prefilled
/// - `is_full_prefill`: true if the entire sequence was prefilled, false if chunked
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, bool)> {
fn try_prefill(&mut self) -> Option<(f64, Option<MoveBlock>, Vec<BlockHash>, bool)> {
let uuid = self.prefill.pop_front()?;
// Remove and extract prefill_compute from prefill_costs
......@@ -179,6 +181,7 @@ impl SchedulerState {
Some((
prefill_compute,
sequence.take_creation_signal(),
sequence.block_hashes(),
is_full_prefill,
))
}
......@@ -401,8 +404,12 @@ impl Scheduler {
let mut total_time = Duration::from_secs_f64(decoding_time / 1000.0);
// Process prefilling
while let Some((prefill_compute, maybe_creation_signal, is_full_prefill)) =
state_guard.try_prefill()
while let Some((
prefill_compute,
maybe_creation_signal,
block_hashes,
is_full_prefill,
)) = state_guard.try_prefill()
{
// NOTE: Prefill cost/time is always incremented for new blocks, even if they
// could be cached by other requests in the same batch. This matches vLLM behavior.
......@@ -421,7 +428,8 @@ impl Scheduler {
(&kv_events_tx, &mut block_resp_rx)
{
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
let _ =
relay_tx.send(block_response_to_kv_event(event, &block_hashes));
}
}
};
......@@ -460,7 +468,8 @@ impl Scheduler {
(&kv_events_tx, &mut block_resp_rx)
{
while let Ok(event) = rx.try_recv() {
let _ = relay_tx.send(block_response_to_kv_event(event));
let _ = relay_tx
.send(block_response_to_kv_event(event, &sequence.block_hashes()));
}
}
......
......@@ -90,7 +90,7 @@ impl ActiveSequence {
assert!(block_size > 1, "block_size must be greater than 1");
let num_input_tokens = tokens.len();
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, None);
let tokens = Tokens::from(tokens).into_sequence(block_size as u32, Some(1337));
let unique_blocks =
create_unique_blocks_from_sequence(&tokens, None, block_size, enable_prefix_caching);
let creation_signal = Some(MoveBlock::Use(unique_blocks.clone()));
......@@ -124,6 +124,14 @@ impl ActiveSequence {
self.creation_signal.take()
}
pub fn block_hashes(&self) -> Vec<u64> {
self.tokens
.blocks()
.iter()
.map(|block| block.block_hash())
.collect()
}
/// Create a new ActiveSequence instance and return the creation signal
pub fn new_with_signal(
tokens: Vec<u32>,
......
......@@ -18,6 +18,7 @@ logger = logging.getLogger(__name__)
MODEL_NAME = "Qwen/Qwen3-0.6B"
NUM_MOCKERS = 2
BLOCK_SIZE = 16
SPEEDUP_RATIO = 10.0
NUM_REQUESTS = 100
PORT = 8090 # Starting port for mocker instances
......@@ -59,6 +60,8 @@ class KVRouterProcess(ManagedProcess):
"python",
"-m",
"dynamo.frontend",
"--kv-cache-block-size",
str(BLOCK_SIZE),
"--router-mode",
"kv",
"--http-port",
......@@ -100,7 +103,7 @@ def test_mocker_kv_router(request, runtime_services):
logger.info("Starting mocker KV router test")
# Create mocker args file
mocker_args = {"speedup_ratio": SPEEDUP_RATIO}
mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
mocker_args_file = os.path.join(request.node.name, "mocker_args.json")
with open(mocker_args_file, "w") as f:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment