Unverified Commit 0b33c1df authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: sglang disagg routing fixes and optimizations [DYN-1692] (#5106)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarIshan Dhanani <ishandhanani@gmail.com>
Co-authored-by: default avatarSean SH Choi <sechoi@nvidia.com>
Co-authored-by: default avatarishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent e49834c9
...@@ -49,13 +49,10 @@ use std::{ ...@@ -49,13 +49,10 @@ use std::{
}; };
use tokio::sync::{broadcast, mpsc, oneshot}; use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337;
use crate::kv_router::approx::{BlockEntry, PruneConfig, PruneManager}; use crate::kv_router::approx::{BlockEntry, PruneConfig, PruneManager};
use crate::kv_router::protocols::{BlockExtraInfo, *}; use crate::kv_router::protocols::*;
use crate::tokens::{SequenceHash, TokenBlockSequence}; use crate::tokens::SequenceHash;
/// Errors that can occur in the KV Router. /// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
...@@ -89,119 +86,6 @@ pub enum KvCacheEventError { ...@@ -89,119 +86,6 @@ pub enum KvCacheEventError {
/// A shared reference to a [`RadixBlock`]. /// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>; type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
pub fn compute_hash(data: &[u8]) -> u64 {
xxh3::xxh3_64_with_seed(data, XXH3_SEED)
}
/// Compute the hash of a local block.
///
/// ### Arguments
///
/// * `data` - A byte slice representing the data to hash.
///
/// ### Returns
///
/// A `LocalBlockHash` representing the computed hash.
pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(compute_hash(data))
}
// /// Updated version of the `compute_block_hash` function that included the lora_id
// pub fn compute_block_hash_v2(token_id: &[u32], lora_id: u64) {
// let mut bytes = Vec::new();
// for token in token_id {
// bytes.extend_from_slice(&token.to_le_bytes());
// }
// bytes.extend_from_slice(&lora_id.to_le_bytes());
// let hash = xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED);
// }
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata.
///
/// When multimodal extra info is provided, the mm_hashes are included in the hash computation
/// to ensure that blocks with identical tokens but different multimodal objects produce
/// different hashes.
///
/// ### Arguments
///
/// * `tokens` - A vector of `u32` tokens.
/// * `kv_block_size` - The size of each block in tokens.
/// * `block_mm_infos` - Optional per-block multimodal metadata.
///
/// ### Returns
///
/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
pub fn compute_block_hash_for_seq(
tokens: &[u32],
kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<LocalBlockHash> {
tokens
.chunks_exact(kv_block_size as usize)
.enumerate()
.map(|(block_idx, chunk)| {
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
// Include MM hashes in the block hash computation if present
if let Some(mm_infos) = block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{
// The order of different multimodal hashes does not matter.
// Only which multimodal infos are present in a block is important.
// The order may differ in different code paths, so the hashes are sorted
// to keep the block hash stable.
let mut mm_hashes: Vec<u64> = block_mm_info
.mm_objects
.iter()
.map(|obj| obj.mm_hash)
.collect();
mm_hashes.sort_unstable();
// Append sorted mm_hashes to the byte array
for mm_hash in mm_hashes {
bytes.extend_from_slice(&mm_hash.to_le_bytes());
}
}
compute_block_hash(&bytes)
})
.collect()
}
/// Compute rolling sequence hashes for a vector of block hashes.
///
/// This mirrors the behavior in tokens.rs where:
/// - The first block's sequence hash equals its block hash
/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
///
/// ### Arguments
///
/// * `block_hashes` - A vector of `LocalBlockHash` values representing the block hashes.
///
/// ### Returns
///
/// A vector of u64 values representing the sequence hashes for each block.
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
if block_hashes.is_empty() {
return Vec::new();
}
let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
sequence_hashes.push(block_hashes[0].0);
for i in 1..block_hashes.len() {
let parent_seq_hash = sequence_hashes[i - 1];
let current_block_hash = block_hashes[i].0;
let combined = [parent_seq_hash, current_block_hash];
let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
let seq_hash = compute_hash(&bytes);
sequence_hashes.push(seq_hash);
}
sequence_hashes
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`]. /// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RouterEvent { pub struct RouterEvent {
...@@ -897,29 +781,18 @@ pub trait KvIndexerInterface { ...@@ -897,29 +781,18 @@ pub trait KvIndexerInterface {
/// A vector of RouterEvents representing the current state of the tree. /// A vector of RouterEvents representing the current state of the tree.
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError>; async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError>;
/// Process a routing decision with pre-computed hashes.
///
/// ### Arguments
///
/// * `worker` - The worker (with dp_rank) that was selected.
/// * `local_hashes` - The local hashes of the tokens sent to the worker.
/// * `sequence_hashes` - The sequence hashes of the tokens sent to the worker.
async fn process_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError>;
/// Process a routing decision for a request with tokens. /// Process a routing decision for a request with tokens.
/// ///
/// Uses TokensWithHashes for lazy hash computation - if hashes were already
/// computed (e.g., by find_best_match), they will be reused.
///
/// ### Arguments /// ### Arguments
/// ///
/// * `tokens` - A vector of `u32` tokens. /// * `tokens_with_hashes` - Tokens with lazily computed hashes.
/// * `worker` - The worker (with dp_rank) that was selected. /// * `worker` - The worker (with dp_rank) that was selected.
async fn process_routing_decision_for_request( async fn process_routing_decision_for_request(
&self, &self,
tokens: &[u32], tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
) -> Result<(), KvRouterError>; ) -> Result<(), KvRouterError>;
} }
...@@ -1304,7 +1177,22 @@ impl KvIndexerInterface for KvIndexer { ...@@ -1304,7 +1177,22 @@ impl KvIndexerInterface for KvIndexer {
.map_err(|_| KvRouterError::IndexerDroppedRequest) .map_err(|_| KvRouterError::IndexerDroppedRequest)
} }
async fn process_routing_decision( async fn process_routing_decision_for_request(
&self,
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec();
let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec();
self.process_routing_decision_internal(worker, local_hashes, sequence_hashes)
.await
}
}
impl KvIndexer {
/// Internal method to process a routing decision with pre-computed hashes.
async fn process_routing_decision_internal(
&self, &self,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>, local_hashes: Vec<LocalBlockHash>,
...@@ -1320,23 +1208,6 @@ impl KvIndexerInterface for KvIndexer { ...@@ -1320,23 +1208,6 @@ impl KvIndexerInterface for KvIndexer {
.map_err(|_| KvRouterError::IndexerDroppedRequest)?; .map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(()) Ok(())
} }
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
let sequence_hashes = sequence
.blocks()
.iter()
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
self.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
} }
impl Drop for KvIndexer { impl Drop for KvIndexer {
...@@ -1617,28 +1488,15 @@ impl KvIndexerInterface for LocalKvIndexer { ...@@ -1617,28 +1488,15 @@ impl KvIndexerInterface for LocalKvIndexer {
self.indexer.dump_events().await self.indexer.dump_events().await
} }
async fn process_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
async fn process_routing_decision_for_request( async fn process_routing_decision_for_request(
&self, &self,
tokens: &[u32], tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method? // TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully // Keeping it here now to implement the trait fully
self.indexer self.indexer
.process_routing_decision_for_request(tokens, worker) .process_routing_decision_for_request(tokens_with_hashes, worker)
.await .await
} }
} }
...@@ -2075,7 +1933,22 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -2075,7 +1933,22 @@ impl KvIndexerInterface for KvIndexerSharded {
Ok(all_events) Ok(all_events)
} }
async fn process_routing_decision( async fn process_routing_decision_for_request(
&self,
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec();
let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec();
self.process_routing_decision_internal(worker, local_hashes, sequence_hashes)
.await
}
}
impl KvIndexerSharded {
/// Internal method to process a routing decision with pre-computed hashes.
async fn process_routing_decision_internal(
&self, &self,
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>, local_hashes: Vec<LocalBlockHash>,
...@@ -2098,23 +1971,6 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -2098,23 +1971,6 @@ impl KvIndexerInterface for KvIndexerSharded {
.map_err(|_| KvRouterError::IndexerDroppedRequest)?; .map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(()) Ok(())
} }
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
let sequence_hashes = sequence
.blocks()
.iter()
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
self.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
} }
impl Drop for KvIndexerSharded { impl Drop for KvIndexerSharded {
......
...@@ -15,7 +15,7 @@ use dynamo_runtime::{ ...@@ -15,7 +15,7 @@ use dynamo_runtime::{
AsyncEngine, AsyncEngineContextProvider, Context, ManyOut, Operator, PushRouter, AsyncEngine, AsyncEngineContextProvider, Context, ManyOut, Operator, PushRouter,
RouterMode, ServerStreamingEngine, SingleIn, async_trait, RouterMode, ServerStreamingEngine, SingleIn, async_trait,
}, },
protocols::{annotated::Annotated, maybe_error::MaybeError}, protocols::{EndpointId, annotated::Annotated, maybe_error::MaybeError},
}; };
use crate::{ use crate::{
...@@ -23,7 +23,7 @@ use crate::{ ...@@ -23,7 +23,7 @@ use crate::{
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride}, kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult}, protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::timing::RequestPhase, protocols::common::timing::{RequestPhase, RequestTracker},
protocols::openai::nvext::WorkerIdInfo, protocols::openai::nvext::WorkerIdInfo,
}; };
...@@ -54,14 +54,29 @@ enum InnerPrefillRouter { ...@@ -54,14 +54,29 @@ enum InnerPrefillRouter {
} }
impl InnerPrefillRouter { impl InnerPrefillRouter {
/// Execute prefill generation through the underlying router /// Generate with optional direct routing to specific worker.
async fn generate( /// For KvRouter, target_worker is ignored since prefill_worker_id is already set on the request.
/// For SimpleRouter, target_worker triggers direct routing via router.direct().
async fn generate_to_worker(
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> { ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
match (self, target_worker) {
// KvRouter: prefill_worker_id already set on request, KvPushRouter::select_worker uses it
(InnerPrefillRouter::KvRouter(router), _) => router.generate(request).await,
(InnerPrefillRouter::SimpleRouter(router), Some(worker_id)) => {
router.direct(request, worker_id).await
}
(InnerPrefillRouter::SimpleRouter(router), None) => router.generate(request).await,
}
}
/// Select next worker (for non-KV modes only)
fn select_next_worker(&self) -> Option<u64> {
match self { match self {
InnerPrefillRouter::KvRouter(router) => router.generate(request).await, InnerPrefillRouter::SimpleRouter(router) => router.select_next_worker(),
InnerPrefillRouter::SimpleRouter(router) => router.generate(request).await, InnerPrefillRouter::KvRouter(_) => None,
} }
} }
} }
...@@ -76,6 +91,8 @@ impl InnerPrefillRouter { ...@@ -76,6 +91,8 @@ impl InnerPrefillRouter {
/// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined. /// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined.
pub struct PrefillRouter { pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>, prefill_router: OnceLock<InnerPrefillRouter>,
model_manager: Arc<ModelManager>,
endpoint_id: OnceLock<EndpointId>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
router_mode: RouterMode, router_mode: RouterMode,
enforce_disagg: bool, enforce_disagg: bool,
...@@ -83,9 +100,15 @@ pub struct PrefillRouter { ...@@ -83,9 +100,15 @@ pub struct PrefillRouter {
impl PrefillRouter { impl PrefillRouter {
/// Create a disabled prefill router that will never activate (passthrough only) /// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled(router_mode: RouterMode, enforce_disagg: bool) -> Arc<Self> { pub fn disabled(
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
enforce_disagg: bool,
) -> Arc<Self> {
Arc::new(Self { Arc::new(Self {
prefill_router: OnceLock::new(), prefill_router: OnceLock::new(),
model_manager,
endpoint_id: OnceLock::new(),
cancel_token: CancellationToken::new(), cancel_token: CancellationToken::new(),
router_mode, router_mode,
enforce_disagg, enforce_disagg,
...@@ -105,6 +128,8 @@ impl PrefillRouter { ...@@ -105,6 +128,8 @@ impl PrefillRouter {
let router = Arc::new(Self { let router = Arc::new(Self {
prefill_router, prefill_router,
model_manager: model_manager.clone(),
endpoint_id: OnceLock::new(),
cancel_token: cancel_token.clone(), cancel_token: cancel_token.clone(),
router_mode, router_mode,
enforce_disagg, enforce_disagg,
...@@ -151,6 +176,15 @@ impl PrefillRouter { ...@@ -151,6 +176,15 @@ impl PrefillRouter {
"Activating prefill router" "Activating prefill router"
); );
// Store endpoint_id for later use in build_bootstrap_info
let _ = self.endpoint_id.set(endpoint.id());
// Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
// This must be done before creating the router so bootstrap info is available
model_manager
.get_or_create_runtime_config_watcher(&endpoint)
.await?;
let inner_router = if self.router_mode.is_kv_routing() { let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the endpoint // Create KV chooser using the endpoint
let kv_chooser = model_manager let kv_chooser = model_manager
...@@ -205,22 +239,18 @@ impl PrefillRouter { ...@@ -205,22 +239,18 @@ impl PrefillRouter {
/// Build bootstrap_info for disaggregated serving /// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly. /// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker. /// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
async fn build_bootstrap_info( async fn build_bootstrap_info(
&self, &self,
req: &PreprocessedRequest, req: &PreprocessedRequest,
preselected_worker: Option<u64>, preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> { ) -> Option<(u64, u32, BootstrapInfo)> {
let endpoint_id = self.endpoint_id.get()?;
let prefill_router = self.prefill_router.get()?; let prefill_router = self.prefill_router.get()?;
// Only works with KvRouter // Worker selection
let kv_router = match prefill_router {
InnerPrefillRouter::KvRouter(r) => r,
InnerPrefillRouter::SimpleRouter(_) => return None,
};
// Use pre-selected worker (GAIE Stage 2) or query for best worker
let (worker_id, dp_rank) = if let Some(id) = preselected_worker { let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
// GAIE Stage 2: use pre-selected worker
let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0); let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0);
tracing::debug!( tracing::debug!(
worker_id = id, worker_id = id,
...@@ -228,7 +258,12 @@ impl PrefillRouter { ...@@ -228,7 +258,12 @@ impl PrefillRouter {
"Using pre-selected prefill worker for bootstrap" "Using pre-selected prefill worker for bootstrap"
); );
(id, dp_rank) (id, dp_rank)
} else { } else if self.router_mode.is_kv_routing() {
// KV mode: use find_best_match
let kv_router = match prefill_router {
InnerPrefillRouter::KvRouter(r) => r,
_ => return None,
};
match kv_router match kv_router
.chooser .chooser
.find_best_match(None, &req.token_ids, None, false) .find_best_match(None, &req.token_ids, None, false)
...@@ -237,13 +272,16 @@ impl PrefillRouter { ...@@ -237,13 +272,16 @@ impl PrefillRouter {
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank), Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None, Err(_) => return None,
} }
} else {
// Non-KV mode: use PushRouter's stateful selection
let worker_id = prefill_router.select_next_worker()?;
(worker_id, 0)
}; };
// Look up bootstrap endpoint from discovery // Get bootstrap info from ModelManager (works for ANY mode)
let endpoint = kv_router let endpoint = self
.chooser .model_manager
.get_disaggregated_endpoint(worker_id) .get_disaggregated_endpoint(endpoint_id, worker_id)?;
.await?;
let host = endpoint.bootstrap_host?; let host = endpoint.bootstrap_host?;
let port = endpoint.bootstrap_port?; let port = endpoint.bootstrap_port?;
...@@ -255,6 +293,7 @@ impl PrefillRouter { ...@@ -255,6 +293,7 @@ impl PrefillRouter {
bootstrap_host = %host, bootstrap_host = %host,
bootstrap_port = port, bootstrap_port = port,
bootstrap_room = bootstrap_room, bootstrap_room = bootstrap_room,
router_mode = ?self.router_mode,
"Built bootstrap_info upfront before prefill" "Built bootstrap_info upfront before prefill"
); );
...@@ -270,13 +309,15 @@ impl PrefillRouter { ...@@ -270,13 +309,15 @@ impl PrefillRouter {
} }
/// Execute prefill with the given router and extract structured result /// Execute prefill with the given router and extract structured result
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization)
async fn execute_prefill( async fn execute_prefill(
router: Option<InnerPrefillRouter>, router: Option<InnerPrefillRouter>,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> { ) -> Result<(PrefillResult, Option<u64>), PrefillError> {
let router = router.ok_or(PrefillError::NotActivated)?; let router = router.ok_or(PrefillError::NotActivated)?;
let mut prefill_response = router let mut prefill_response = router
.generate(request) .generate_to_worker(request, target_worker)
.await .await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?; .map_err(|e| PrefillError::PrefillError(e.to_string()))?;
...@@ -339,11 +380,16 @@ impl PrefillRouter { ...@@ -339,11 +380,16 @@ impl PrefillRouter {
} }
/// Spawn prefill as a background task /// Spawn prefill as a background task
fn spawn_prefill_task(&self, prefill_request: SingleIn<PreprocessedRequest>) { /// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization)
fn spawn_prefill_task(
&self,
prefill_request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
) {
let router = self.prefill_router.get().cloned(); let router = self.prefill_router.get().cloned();
tokio::spawn(async move { tokio::spawn(async move {
match Self::execute_prefill(router, prefill_request).await { match Self::execute_prefill(router, prefill_request, target_worker).await {
Ok(_) => { Ok(_) => {
tracing::debug!("Prefill background task completed"); tracing::debug!("Prefill background task completed");
} }
...@@ -359,7 +405,8 @@ impl PrefillRouter { ...@@ -359,7 +405,8 @@ impl PrefillRouter {
&self, &self,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> { ) -> Result<(PrefillResult, Option<u64>), PrefillError> {
Self::execute_prefill(self.prefill_router.get().cloned(), request).await // For call_prefill path, routing is handled by the router itself (no direct routing needed)
Self::execute_prefill(self.prefill_router.get().cloned(), request, None).await
} }
} }
...@@ -439,7 +486,7 @@ impl ...@@ -439,7 +486,7 @@ impl
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>, next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> { ) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
// Extract request data while preserving context // Extract request data while preserving context
let (req, context) = request.into_parts(); let (mut req, context) = request.into_parts();
let request_id = context.id().to_string(); let request_id = context.id().to_string();
let engine_ctx = context.context(); let engine_ctx = context.context();
...@@ -452,7 +499,26 @@ impl ...@@ -452,7 +499,26 @@ impl
// Save original max_tokens for decode // Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens; let original_max_tokens = req.stop_conditions.max_tokens;
// Prepare prefill request with max_tokens = 1 // GAIE Stage 1: Check if prefill router is activated - if not, skip to decode
if is_gaie_stage1 && self.prefill_router.get().is_none() {
tracing::debug!("GAIE Stage 1: Prefill router not activated, skipping to decode");
if self.enforce_disagg {
return Err(anyhow::anyhow!(PrefillError::NotActivated));
}
// Fall back to decode-only
return next.generate(context.map(|_| req)).await;
}
// Ensure tracker exists for routing decisions in disaggregated mode.
// Create one if not provided by the upstream DeltaGenerator.
if req.tracker.is_none() {
req.tracker = Some(Arc::new(RequestTracker::new()));
}
let tracker = req.tracker.as_ref().unwrap();
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let mut prefill_req = req.clone(); let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1); prefill_req.stop_conditions.max_tokens = Some(1);
...@@ -465,70 +531,31 @@ impl ...@@ -465,70 +531,31 @@ impl
.routing .routing
.as_ref() .as_ref()
.and_then(|r| r.prefill_worker_id); .and_then(|r| r.prefill_worker_id);
let prefill_result = if !is_gaie_stage1 {
if let Some((worker_id, dp_rank, bootstrap_info)) = self let prefill_result = if !is_gaie_stage1
&& let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker) .build_bootstrap_info(&prefill_req, preselected_worker)
.await .await
{ {
let bootstrap_room = bootstrap_info.bootstrap_room; // Bootstrap optimization path: spawn prefill in background
// Prepare request with bootstrap_room and force routing to specific worker
let routing = prefill_req.routing_mut(); let routing = prefill_req.routing_mut();
routing.backend_instance_id = Some(worker_id); routing.prefill_worker_id = Some(worker_id);
routing.backend_instance_id = Some(worker_id); // Route prefill to the SAME worker we got bootstrap_info from
routing.dp_rank = Some(dp_rank); routing.dp_rank = Some(dp_rank);
let extra_args = prefill_req prefill_req.bootstrap_info = Some(bootstrap_info.clone());
.extra_args
.get_or_insert_with(|| serde_json::json!({}));
if let Some(obj) = extra_args.as_object_mut() {
obj.insert(
"bootstrap_room".to_string(),
serde_json::json!(bootstrap_room),
);
}
// Set phase to Prefill and record prefill start time if tracking is enabled
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
self.spawn_prefill_task(prefill_context); self.spawn_prefill_task(prefill_context, Some(worker_id));
Ok((None, Some(worker_id), Some(bootstrap_info))) Ok((None, Some(worker_id), Some(bootstrap_info)))
} else { } else {
// Fallback to original: Wait for prefill to complete // Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path"); tracing::debug!(
is_gaie_stage1 = is_gaie_stage1,
// Set phase to Prefill and record prefill start time if tracking is enabled "Using original prefill path"
if let Some(ref tracker) = req.tracker { );
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
}
} else {
// GAIE Stage 1: Use original path (no bootstrap optimization)
// But first check if prefill router is activated - if not, skip to avoid setting phase
if self.prefill_router.get().is_none() {
tracing::debug!("GAIE Stage 1: Prefill router not activated, skipping to decode");
Err(PrefillError::NotActivated)
} else {
tracing::debug!("Using original prefill path (GAIE Stage 1)");
// Set phase to Prefill and record prefill start time if tracking is enabled
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
let prefill_context = Context::with_id(prefill_req, request_id.clone()); let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context()); engine_ctx.link_child(prefill_context.context());
...@@ -536,7 +563,6 @@ impl ...@@ -536,7 +563,6 @@ impl
self.call_prefill(prefill_context) self.call_prefill(prefill_context)
.await .await
.map(|(result, worker_id)| (Some(result), worker_id, None)) .map(|(result, worker_id)| (Some(result), worker_id, None))
}
}; };
// Abort if cancelled during prefill // Abort if cancelled during prefill
......
...@@ -4,6 +4,82 @@ ...@@ -4,6 +4,82 @@
use crate::tokens::{SequenceHash, Token}; use crate::tokens::{SequenceHash, Token};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use uuid::Uuid; use uuid::Uuid;
use xxhash_rust::xxh3;
/// Seed for XXH3 hashing, consistent with indexer.rs
pub const XXH3_SEED: u64 = 1337;
/// Compute hash of data using XXH3 with the standard seed.
pub fn compute_hash(data: &[u8]) -> u64 {
xxh3::xxh3_64_with_seed(data, XXH3_SEED)
}
/// Compute the hash of a local block.
pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(compute_hash(data))
}
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata.
///
/// When multimodal extra info is provided, the mm_hashes are included in the hash computation
/// to ensure that blocks with identical tokens but different multimodal objects produce
/// different hashes.
pub fn compute_block_hash_for_seq(
tokens: &[u32],
kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<LocalBlockHash> {
tokens
.chunks_exact(kv_block_size as usize)
.enumerate()
.map(|(block_idx, chunk)| {
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
// Include MM hashes in the block hash computation if present
if let Some(mm_infos) = block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{
let mut mm_hashes: Vec<u64> = block_mm_info
.mm_objects
.iter()
.map(|obj| obj.mm_hash)
.collect();
mm_hashes.sort_unstable();
for mm_hash in mm_hashes {
bytes.extend_from_slice(&mm_hash.to_le_bytes());
}
}
compute_block_hash(&bytes)
})
.collect()
}
/// Compute rolling sequence hashes for a vector of block hashes.
///
/// - The first block's sequence hash equals its block hash
/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
if block_hashes.is_empty() {
return Vec::new();
}
let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
sequence_hashes.push(block_hashes[0].0);
for i in 1..block_hashes.len() {
let parent_seq_hash = sequence_hashes[i - 1];
let current_block_hash = block_hashes[i].0;
let combined = [parent_seq_hash, current_block_hash];
let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
let seq_hash = compute_hash(&bytes);
sequence_hashes.push(seq_hash);
}
sequence_hashes
}
/// A worker identifier. /// A worker identifier.
pub type WorkerId = u64; pub type WorkerId = u64;
...@@ -439,6 +515,103 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash { ...@@ -439,6 +515,103 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
} }
} }
// ------
// TokensWithHashes
// ------
/// A container for tokens with lazily computed block and sequence hashes.
///
/// This struct avoids redundant hash computations by caching results:
/// - `get_or_compute_block_hashes()` computes block hashes if not cached
/// - `get_or_compute_seq_hashes()` computes seq hashes if not cached,
/// and will also compute block hashes first if needed (since seq hashes depend on them)
#[derive(Debug, Clone)]
pub struct TokensWithHashes {
tokens: Vec<u32>,
block_size: u32,
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
block_hashes: Option<Vec<LocalBlockHash>>,
seq_hashes: Option<Vec<SequenceHash>>,
}
impl TokensWithHashes {
/// Creates a new TokensWithHashes from tokens and block size.
pub fn new(tokens: Vec<u32>, block_size: u32) -> Self {
Self {
tokens,
block_size,
block_mm_infos: None,
block_hashes: None,
seq_hashes: None,
}
}
/// Adds multimodal extra info for blocks.
pub fn with_mm_infos(mut self, infos: Vec<Option<BlockExtraInfo>>) -> Self {
self.block_mm_infos = Some(infos);
self
}
/// Returns a reference to the tokens.
pub fn tokens(&self) -> &[u32] {
&self.tokens
}
/// Returns the number of tokens.
pub fn len(&self) -> usize {
self.tokens.len()
}
/// Returns true if there are no tokens.
pub fn is_empty(&self) -> bool {
self.tokens.is_empty()
}
/// Returns the block size.
pub fn block_size(&self) -> u32 {
self.block_size
}
/// Returns the multimodal extra info, if set.
pub fn block_mm_infos(&self) -> Option<&[Option<BlockExtraInfo>]> {
self.block_mm_infos.as_deref()
}
/// Returns block hashes, computing them if not already cached.
pub fn get_or_compute_block_hashes(&mut self) -> &[LocalBlockHash] {
if self.block_hashes.is_none() {
self.block_hashes = Some(compute_block_hash_for_seq(
&self.tokens,
self.block_size,
self.block_mm_infos.as_deref(),
));
}
self.block_hashes.as_ref().unwrap()
}
/// Returns sequence hashes, computing them if not already cached.
/// This will also compute block hashes if they haven't been computed yet,
/// since sequence hashes depend on block hashes.
pub fn get_or_compute_seq_hashes(&mut self) -> &[SequenceHash] {
if self.seq_hashes.is_none() {
// Ensure block hashes are computed first
let block_hashes = self.get_or_compute_block_hashes();
self.seq_hashes = Some(compute_seq_hash_for_block(block_hashes));
}
self.seq_hashes.as_ref().unwrap()
}
/// Returns cached block hashes without computing. Returns None if not yet computed.
pub fn block_hashes(&self) -> Option<&[LocalBlockHash]> {
self.block_hashes.as_deref()
}
/// Returns cached seq hashes without computing. Returns None if not yet computed.
pub fn seq_hashes(&self) -> Option<&[SequenceHash]> {
self.seq_hashes.as_deref()
}
}
// ------ // ------
// Tests // Tests
// ------ // ------
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig}; use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result; use anyhow::Result;
use dashmap::DashMap;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher; use dynamo_runtime::traits::events::EventPublisher;
...@@ -11,7 +12,7 @@ use serde::{Deserialize, Serialize}; ...@@ -11,7 +12,7 @@ use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::sync::{RwLock, watch}; use tokio::sync::watch;
use super::KV_HIT_RATE_SUBJECT; use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig; use super::KvRouterConfig;
...@@ -90,8 +91,6 @@ impl SchedulingRequest { ...@@ -90,8 +91,6 @@ impl SchedulingRequest {
pub struct KvScheduler { pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>, request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker>, slots: Arc<ActiveSequencesMultiWorker>,
/// Worker runtime configs for looking up disaggregated endpoints
workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>>,
} }
impl KvScheduler { impl KvScheduler {
...@@ -99,92 +98,71 @@ impl KvScheduler { ...@@ -99,92 +98,71 @@ impl KvScheduler {
component: Component, component: Component,
block_size: u32, block_size: u32,
instance_ids_rx: watch::Receiver<Vec<u64>>, instance_ids_rx: watch::Receiver<Vec<u64>>,
runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>, workers_with_configs: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool, replica_sync: bool,
router_uuid: String, router_uuid: String,
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instance_ids: Vec<u64> = instance_ids_rx.borrow().clone();
let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();
// Create shared workers_with_configs wrapped in Arc<RwLock> // Get initial workers from DashMap for slot initialization
let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = { let initial_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_with_configs
let mut initial_map = HashMap::new(); .iter()
for worker_id in &instance_ids { .map(|r| (*r.key(), r.value().clone()))
let config = runtime_configs.get(worker_id).cloned(); .collect();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
initial_map.insert(*worker_id, config);
}
Arc::new(RwLock::new(initial_map))
};
let slots = Arc::new(ActiveSequencesMultiWorker::new( let slots = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(), component.clone(),
block_size as usize, block_size as usize,
workers_with_configs.read().await.clone(), // this includes dp_size info initial_workers,
replica_sync, replica_sync,
router_uuid, router_uuid,
)); ));
// Spawn background task to monitor and update workers_with_configs // Spawn background task to monitor workers_with_configs changes and update slots
let workers_monitor = workers_with_configs.clone();
let slots_monitor = slots.clone(); let slots_monitor = slots.clone();
let workers_monitor = workers_with_configs.clone();
let mut instance_ids_monitor_rx = instance_ids_rx.clone(); let mut instance_ids_monitor_rx = instance_ids_rx.clone();
let mut configs_monitor_rx = runtime_configs_rx.clone();
let monitor_cancel_token = component.drt().child_token(); let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move { tokio::spawn(async move {
tracing::trace!("workers monitoring task started"); tracing::trace!("KvScheduler workers monitoring task started");
let mut last_workers: HashSet<WorkerId> = HashSet::new();
loop { loop {
// Wait for either instances or configs to change // Wait for instance changes (ModelManager handles config updates to the DashMap)
tokio::select! { tokio::select! {
_ = monitor_cancel_token.cancelled() => { _ = monitor_cancel_token.cancelled() => {
tracing::trace!("workers monitoring task shutting down"); tracing::trace!("KvScheduler workers monitoring task shutting down");
break; break;
} }
result = instance_ids_monitor_rx.changed() => { result = instance_ids_monitor_rx.changed() => {
if result.is_err() { if result.is_err() {
tracing::warn!("instance IDs watch sender shutdown in monitor"); tracing::warn!("instance IDs watch sender shutdown in KvScheduler monitor");
break;
}
}
result = configs_monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("runtime configs watch sender shutdown in monitor");
break; break;
} }
} }
} }
// Get the latest values from both channels // Get current workers from DashMap
let new_instance_ids = instance_ids_monitor_rx.borrow_and_update().clone(); let current_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> =
let new_configs = configs_monitor_rx.borrow_and_update().clone(); workers_monitor
.iter()
// Build the new workers_with_configs map .map(|r| (*r.key(), r.value().clone()))
let mut new_workers_with_configs = HashMap::new(); .collect();
for worker_id in &new_instance_ids { let current_worker_ids: HashSet<WorkerId> =
let config = new_configs.get(worker_id).cloned(); current_workers.keys().copied().collect();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
new_workers_with_configs.insert(*worker_id, config);
}
// Update workers when instances change
slots_monitor.update_workers(new_workers_with_configs.clone());
// Update the shared workers_with_configs // Only update slots if workers have changed
let mut workers_map = workers_monitor.write().await; if current_worker_ids != last_workers {
*workers_map = new_workers_with_configs; slots_monitor.update_workers(current_workers);
last_workers = current_worker_ids;
tracing::trace!( tracing::trace!(
"Updated workers_with_configs with {} workers", "KvScheduler: Updated slots with {} workers",
workers_map.len() last_workers.len()
); );
} }
tracing::trace!("workers monitoring task shutting down"); }
tracing::trace!("KvScheduler workers monitoring task shutting down");
}); });
let slots_clone = slots.clone(); let slots_clone = slots.clone();
...@@ -222,8 +200,11 @@ impl KvScheduler { ...@@ -222,8 +200,11 @@ impl KvScheduler {
request.decode_blocks = decode_blocks; request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens; request.prefill_tokens = prefill_tokens;
// Read the current workers configuration // Read the current workers configuration from DashMap
let workers = workers_scheduler.read().await.clone(); let workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_scheduler
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
match selector.select_worker(&workers, &request, block_size) { match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => { Ok(selection) => {
...@@ -289,11 +270,7 @@ impl KvScheduler { ...@@ -289,11 +270,7 @@ impl KvScheduler {
tracing::trace!("background endpoint subscriber shutting down"); tracing::trace!("background endpoint subscriber shutting down");
}); });
Ok(KvScheduler { Ok(KvScheduler { request_tx, slots })
request_tx,
slots,
workers_with_configs,
})
} }
pub async fn schedule( pub async fn schedule(
...@@ -352,17 +329,6 @@ impl KvScheduler { ...@@ -352,17 +329,6 @@ impl KvScheduler {
self.slots.free(&request_id.to_string()).await self.slots.free(&request_id.to_string()).await
} }
pub async fn get_disaggregated_endpoint(
&self,
worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
let workers = self.workers_with_configs.read().await;
workers
.get(&worker_id)
.and_then(|config| config.as_ref())
.and_then(|config| config.disaggregated_endpoint.clone())
}
pub async fn get_potential_loads( pub async fn get_potential_loads(
&self, &self,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
......
...@@ -198,6 +198,34 @@ where ...@@ -198,6 +198,34 @@ where
.await .await
} }
/// Select the next worker according to the routing mode.
/// Increments round-robin counter if applicable.
/// Panics if called on Direct or KV mode - those have their own selection mechanisms.
pub fn select_next_worker(&self) -> Option<u64> {
let instance_ids = self.client.instance_ids_avail();
let count = instance_ids.len();
if count == 0 {
return None;
}
match self.router_mode {
RouterMode::RoundRobin => {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
Some(instance_ids[counter % count])
}
RouterMode::Random => {
let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count])
}
_ => {
panic!(
"select_next_worker should not be called for {:?} routing mode",
self.router_mode
)
}
}
}
/* /*
pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> { pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject(); let subject = self.client.endpoint.subject();
......
...@@ -1894,6 +1894,7 @@ def _test_router_decisions( ...@@ -1894,6 +1894,7 @@ def _test_router_decisions(
request, request,
test_dp_rank: bool = False, test_dp_rank: bool = False,
block_size: int = BLOCK_SIZE, block_size: int = BLOCK_SIZE,
use_kv_events: bool = True,
): ):
"""Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes. """Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes.
...@@ -1912,12 +1913,17 @@ def _test_router_decisions( ...@@ -1912,12 +1913,17 @@ def _test_router_decisions(
model_name: Name of the model model_name: Name of the model
request: Pytest request fixture request: Pytest request fixture
test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups) test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups)
use_kv_events: If True (default), uses KV events from workers. If False, uses
approximate routing with TTL-based expiration (--no-kv-events mode).
Raises: Raises:
AssertionError: If routing decisions don't follow KV cache prefix reuse as expected AssertionError: If routing decisions don't follow KV cache prefix reuse as expected
""" """
# Create KvRouterConfig with lower snapshot threshold for testing # Create KvRouterConfig with lower snapshot threshold for testing
kv_router_config = KvRouterConfig(router_snapshot_threshold=20) kv_router_config = KvRouterConfig(
router_snapshot_threshold=20,
use_kv_events=use_kv_events,
)
kv_push_router = KvPushRouter( kv_push_router = KvPushRouter(
endpoint=endpoint, endpoint=endpoint,
block_size=block_size, block_size=block_size,
......
...@@ -596,30 +596,49 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -596,30 +596,49 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up @pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize("use_nats_core", [False, True], ids=["jetstream", "nats_core"]) @pytest.mark.parametrize(
"use_nats_core,use_kv_events",
[
(False, True), # JetStream mode (default)
(True, True), # NATS Core + local indexer mode
(False, False), # Approximate mode (--no-kv-events)
],
ids=["jetstream", "nats_core", "no_kv_events"],
)
def test_router_decisions( def test_router_decisions(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
use_nats_core, use_nats_core,
use_kv_events,
request_plane, request_plane,
): ):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
Parameterized to test both JetStream (default) and NATS Core (local indexer) modes. Parameterized to test:
- JetStream mode (default): KV events via JetStream
- NATS Core mode: KV events via NATS Core with local indexer on workers
- Approximate mode (--no-kv-events): No KV events, router predicts cache state
based on routing decisions with TTL-based expiration and pruning
""" """
# runtime_services_dynamic_ports handles NATS and etcd startup # runtime_services_dynamic_ports handles NATS and etcd startup
mode = "NATS Core (local indexer)" if use_nats_core else "JetStream" if not use_kv_events:
mode = "Approximate (no-kv-events)"
elif use_nats_core:
mode = "NATS Core (local indexer)"
else:
mode = "JetStream"
logger.info( logger.info(
f"Starting test router prefix reuse and KV events synchronization ({mode})" f"Starting test router prefix reuse and KV events synchronization ({mode})"
) )
# Create mocker args dictionary with dp_size=4 # Create mocker args dictionary with dp_size=4
# Note: enable_local_indexer only applies when use_kv_events=True and use_nats_core=True
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"dp_size": 4, "dp_size": 4,
"enable_local_indexer": use_nats_core, "enable_local_indexer": use_nats_core and use_kv_events,
} }
try: try:
...@@ -645,7 +664,12 @@ def test_router_decisions( ...@@ -645,7 +664,12 @@ def test_router_decisions(
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
_test_router_decisions( _test_router_decisions(
mockers, endpoint, MODEL_NAME, request, test_dp_rank=True mockers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=True,
use_kv_events=use_kv_events,
) )
finally: finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment