Unverified Commit 0b33c1df authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix: sglang disagg routing fixes and optimizations [DYN-1692] (#5106)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Co-authored-by: default avatarIshan Dhanani <ishandhanani@gmail.com>
Co-authored-by: default avatarSean SH Choi <sechoi@nvidia.com>
Co-authored-by: default avatarishandhanani <82981111+ishandhanani@users.noreply.github.com>
parent e49834c9
......@@ -49,13 +49,10 @@ use std::{
};
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337;
use crate::kv_router::approx::{BlockEntry, PruneConfig, PruneManager};
use crate::kv_router::protocols::{BlockExtraInfo, *};
use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::protocols::*;
use crate::tokens::SequenceHash;
/// Errors that can occur in the KV Router.
#[derive(Debug, thiserror::Error)]
......@@ -89,119 +86,6 @@ pub enum KvCacheEventError {
/// A shared reference to a [`RadixBlock`].
type SharedRadixBlock = Rc<RefCell<RadixBlock>>;
pub fn compute_hash(data: &[u8]) -> u64 {
xxh3::xxh3_64_with_seed(data, XXH3_SEED)
}
/// Compute the hash of a local block.
///
/// ### Arguments
///
/// * `data` - A byte slice representing the data to hash.
///
/// ### Returns
///
/// A `LocalBlockHash` representing the computed hash.
pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(compute_hash(data))
}
// /// Updated version of the `compute_block_hash` function that included the lora_id
// pub fn compute_block_hash_v2(token_id: &[u32], lora_id: u64) {
// let mut bytes = Vec::new();
// for token in token_id {
// bytes.extend_from_slice(&token.to_le_bytes());
// }
// bytes.extend_from_slice(&lora_id.to_le_bytes());
// let hash = xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED);
// }
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata.
///
/// When multimodal extra info is provided, the mm_hashes are included in the hash computation
/// to ensure that blocks with identical tokens but different multimodal objects produce
/// different hashes.
///
/// ### Arguments
///
/// * `tokens` - A vector of `u32` tokens.
/// * `kv_block_size` - The size of each block in tokens.
/// * `block_mm_infos` - Optional per-block multimodal metadata.
///
/// ### Returns
///
/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
pub fn compute_block_hash_for_seq(
tokens: &[u32],
kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<LocalBlockHash> {
tokens
.chunks_exact(kv_block_size as usize)
.enumerate()
.map(|(block_idx, chunk)| {
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
// Include MM hashes in the block hash computation if present
if let Some(mm_infos) = block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{
// The order of different multimodal hashes does not matter.
// Only which multimodal infos are present in a block is important.
// The order may differ in different code paths, so the hashes are sorted
// to keep the block hash stable.
let mut mm_hashes: Vec<u64> = block_mm_info
.mm_objects
.iter()
.map(|obj| obj.mm_hash)
.collect();
mm_hashes.sort_unstable();
// Append sorted mm_hashes to the byte array
for mm_hash in mm_hashes {
bytes.extend_from_slice(&mm_hash.to_le_bytes());
}
}
compute_block_hash(&bytes)
})
.collect()
}
/// Compute rolling sequence hashes for a vector of block hashes.
///
/// This mirrors the behavior in tokens.rs where:
/// - The first block's sequence hash equals its block hash
/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
///
/// ### Arguments
///
/// * `block_hashes` - A vector of `LocalBlockHash` values representing the block hashes.
///
/// ### Returns
///
/// A vector of u64 values representing the sequence hashes for each block.
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
if block_hashes.is_empty() {
return Vec::new();
}
let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
sequence_hashes.push(block_hashes[0].0);
for i in 1..block_hashes.len() {
let parent_seq_hash = sequence_hashes[i - 1];
let current_block_hash = block_hashes[i].0;
let combined = [parent_seq_hash, current_block_hash];
let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
let seq_hash = compute_hash(&bytes);
sequence_hashes.push(seq_hash);
}
sequence_hashes
}
/// A [`KvCacheEvent`] on a specific LLM worker denoted by [`WorkerId`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RouterEvent {
......@@ -897,29 +781,18 @@ pub trait KvIndexerInterface {
/// A vector of RouterEvents representing the current state of the tree.
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError>;
/// Process a routing decision with pre-computed hashes.
///
/// ### Arguments
///
/// * `worker` - The worker (with dp_rank) that was selected.
/// * `local_hashes` - The local hashes of the tokens sent to the worker.
/// * `sequence_hashes` - The sequence hashes of the tokens sent to the worker.
async fn process_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError>;
/// Process a routing decision for a request with tokens.
///
/// Uses TokensWithHashes for lazy hash computation - if hashes were already
/// computed (e.g., by find_best_match), they will be reused.
///
/// ### Arguments
///
/// * `tokens` - A vector of `u32` tokens.
/// * `tokens_with_hashes` - Tokens with lazily computed hashes.
/// * `worker` - The worker (with dp_rank) that was selected.
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError>;
}
......@@ -1304,7 +1177,22 @@ impl KvIndexerInterface for KvIndexer {
.map_err(|_| KvRouterError::IndexerDroppedRequest)
}
async fn process_routing_decision(
async fn process_routing_decision_for_request(
&self,
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec();
let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec();
self.process_routing_decision_internal(worker, local_hashes, sequence_hashes)
.await
}
}
impl KvIndexer {
/// Internal method to process a routing decision with pre-computed hashes.
async fn process_routing_decision_internal(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
......@@ -1320,23 +1208,6 @@ impl KvIndexerInterface for KvIndexer {
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(())
}
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
let sequence_hashes = sequence
.blocks()
.iter()
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
self.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
}
impl Drop for KvIndexer {
......@@ -1617,28 +1488,15 @@ impl KvIndexerInterface for LocalKvIndexer {
self.indexer.dump_events().await
}
async fn process_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision_for_request(tokens, worker)
.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
}
......@@ -2075,7 +1933,22 @@ impl KvIndexerInterface for KvIndexerSharded {
Ok(all_events)
}
async fn process_routing_decision(
async fn process_routing_decision_for_request(
&self,
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = tokens_with_hashes.get_or_compute_block_hashes().to_vec();
let sequence_hashes = tokens_with_hashes.get_or_compute_seq_hashes().to_vec();
self.process_routing_decision_internal(worker, local_hashes, sequence_hashes)
.await
}
}
impl KvIndexerSharded {
/// Internal method to process a routing decision with pre-computed hashes.
async fn process_routing_decision_internal(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
......@@ -2098,23 +1971,6 @@ impl KvIndexerInterface for KvIndexerSharded {
.map_err(|_| KvRouterError::IndexerDroppedRequest)?;
Ok(())
}
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
let sequence_hashes = sequence
.blocks()
.iter()
.map(|b| b.sequence_hash())
.collect::<Vec<_>>();
self.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
}
impl Drop for KvIndexerSharded {
......
......@@ -15,7 +15,7 @@ use dynamo_runtime::{
AsyncEngine, AsyncEngineContextProvider, Context, ManyOut, Operator, PushRouter,
RouterMode, ServerStreamingEngine, SingleIn, async_trait,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
protocols::{EndpointId, annotated::Annotated, maybe_error::MaybeError},
};
use crate::{
......@@ -23,7 +23,7 @@ use crate::{
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::timing::RequestPhase,
protocols::common::timing::{RequestPhase, RequestTracker},
protocols::openai::nvext::WorkerIdInfo,
};
......@@ -54,14 +54,29 @@ enum InnerPrefillRouter {
}
impl InnerPrefillRouter {
/// Execute prefill generation through the underlying router
async fn generate(
/// Generate with optional direct routing to specific worker.
/// For KvRouter, target_worker is ignored since prefill_worker_id is already set on the request.
/// For SimpleRouter, target_worker triggers direct routing via router.direct().
async fn generate_to_worker(
&self,
request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
match (self, target_worker) {
// KvRouter: prefill_worker_id already set on request, KvPushRouter::select_worker uses it
(InnerPrefillRouter::KvRouter(router), _) => router.generate(request).await,
(InnerPrefillRouter::SimpleRouter(router), Some(worker_id)) => {
router.direct(request, worker_id).await
}
(InnerPrefillRouter::SimpleRouter(router), None) => router.generate(request).await,
}
}
/// Select next worker (for non-KV modes only)
fn select_next_worker(&self) -> Option<u64> {
match self {
InnerPrefillRouter::KvRouter(router) => router.generate(request).await,
InnerPrefillRouter::SimpleRouter(router) => router.generate(request).await,
InnerPrefillRouter::SimpleRouter(router) => router.select_next_worker(),
InnerPrefillRouter::KvRouter(_) => None,
}
}
}
......@@ -76,6 +91,8 @@ impl InnerPrefillRouter {
/// - Non-GAIE: like GAIE Stage 2 but the worker ids have to be determined.
pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>,
model_manager: Arc<ModelManager>,
endpoint_id: OnceLock<EndpointId>,
cancel_token: CancellationToken,
router_mode: RouterMode,
enforce_disagg: bool,
......@@ -83,9 +100,15 @@ pub struct PrefillRouter {
impl PrefillRouter {
/// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled(router_mode: RouterMode, enforce_disagg: bool) -> Arc<Self> {
pub fn disabled(
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
enforce_disagg: bool,
) -> Arc<Self> {
Arc::new(Self {
prefill_router: OnceLock::new(),
model_manager,
endpoint_id: OnceLock::new(),
cancel_token: CancellationToken::new(),
router_mode,
enforce_disagg,
......@@ -105,6 +128,8 @@ impl PrefillRouter {
let router = Arc::new(Self {
prefill_router,
model_manager: model_manager.clone(),
endpoint_id: OnceLock::new(),
cancel_token: cancel_token.clone(),
router_mode,
enforce_disagg,
......@@ -151,6 +176,15 @@ impl PrefillRouter {
"Activating prefill router"
);
// Store endpoint_id for later use in build_bootstrap_info
let _ = self.endpoint_id.set(endpoint.id());
// Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
// This must be done before creating the router so bootstrap info is available
model_manager
.get_or_create_runtime_config_watcher(&endpoint)
.await?;
let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the endpoint
let kv_chooser = model_manager
......@@ -205,22 +239,18 @@ impl PrefillRouter {
/// Build bootstrap_info for disaggregated serving
/// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker.
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
async fn build_bootstrap_info(
&self,
req: &PreprocessedRequest,
preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> {
let endpoint_id = self.endpoint_id.get()?;
let prefill_router = self.prefill_router.get()?;
// Only works with KvRouter
let kv_router = match prefill_router {
InnerPrefillRouter::KvRouter(r) => r,
InnerPrefillRouter::SimpleRouter(_) => return None,
};
// Use pre-selected worker (GAIE Stage 2) or query for best worker
// Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
// GAIE Stage 2: use pre-selected worker
let dp_rank = req.routing.as_ref().and_then(|r| r.dp_rank).unwrap_or(0);
tracing::debug!(
worker_id = id,
......@@ -228,7 +258,12 @@ impl PrefillRouter {
"Using pre-selected prefill worker for bootstrap"
);
(id, dp_rank)
} else {
} else if self.router_mode.is_kv_routing() {
// KV mode: use find_best_match
let kv_router = match prefill_router {
InnerPrefillRouter::KvRouter(r) => r,
_ => return None,
};
match kv_router
.chooser
.find_best_match(None, &req.token_ids, None, false)
......@@ -237,13 +272,16 @@ impl PrefillRouter {
Ok((worker, _overlap)) => (worker.worker_id, worker.dp_rank),
Err(_) => return None,
}
} else {
// Non-KV mode: use PushRouter's stateful selection
let worker_id = prefill_router.select_next_worker()?;
(worker_id, 0)
};
// Look up bootstrap endpoint from discovery
let endpoint = kv_router
.chooser
.get_disaggregated_endpoint(worker_id)
.await?;
// Get bootstrap info from ModelManager (works for ANY mode)
let endpoint = self
.model_manager
.get_disaggregated_endpoint(endpoint_id, worker_id)?;
let host = endpoint.bootstrap_host?;
let port = endpoint.bootstrap_port?;
......@@ -255,6 +293,7 @@ impl PrefillRouter {
bootstrap_host = %host,
bootstrap_port = port,
bootstrap_room = bootstrap_room,
router_mode = ?self.router_mode,
"Built bootstrap_info upfront before prefill"
);
......@@ -270,13 +309,15 @@ impl PrefillRouter {
}
/// Execute prefill with the given router and extract structured result
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization)
async fn execute_prefill(
router: Option<InnerPrefillRouter>,
request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
let router = router.ok_or(PrefillError::NotActivated)?;
let mut prefill_response = router
.generate(request)
.generate_to_worker(request, target_worker)
.await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?;
......@@ -339,11 +380,16 @@ impl PrefillRouter {
}
/// Spawn prefill as a background task
fn spawn_prefill_task(&self, prefill_request: SingleIn<PreprocessedRequest>) {
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization)
fn spawn_prefill_task(
&self,
prefill_request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
) {
let router = self.prefill_router.get().cloned();
tokio::spawn(async move {
match Self::execute_prefill(router, prefill_request).await {
match Self::execute_prefill(router, prefill_request, target_worker).await {
Ok(_) => {
tracing::debug!("Prefill background task completed");
}
......@@ -359,7 +405,8 @@ impl PrefillRouter {
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<(PrefillResult, Option<u64>), PrefillError> {
Self::execute_prefill(self.prefill_router.get().cloned(), request).await
// For call_prefill path, routing is handled by the router itself (no direct routing needed)
Self::execute_prefill(self.prefill_router.get().cloned(), request, None).await
}
}
......@@ -439,7 +486,7 @@ impl
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
// Extract request data while preserving context
let (req, context) = request.into_parts();
let (mut req, context) = request.into_parts();
let request_id = context.id().to_string();
let engine_ctx = context.context();
......@@ -452,7 +499,26 @@ impl
// Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens;
// Prepare prefill request with max_tokens = 1
// GAIE Stage 1: Check if prefill router is activated - if not, skip to decode
if is_gaie_stage1 && self.prefill_router.get().is_none() {
tracing::debug!("GAIE Stage 1: Prefill router not activated, skipping to decode");
if self.enforce_disagg {
return Err(anyhow::anyhow!(PrefillError::NotActivated));
}
// Fall back to decode-only
return next.generate(context.map(|_| req)).await;
}
// Ensure tracker exists for routing decisions in disaggregated mode.
// Create one if not provided by the upstream DeltaGenerator.
if req.tracker.is_none() {
req.tracker = Some(Arc::new(RequestTracker::new()));
}
let tracker = req.tracker.as_ref().unwrap();
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
......@@ -465,78 +531,38 @@ impl
.routing
.as_ref()
.and_then(|r| r.prefill_worker_id);
let prefill_result = if !is_gaie_stage1 {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
let prefill_result = if !is_gaie_stage1
&& let Some((worker_id, dp_rank, bootstrap_info)) = self
.build_bootstrap_info(&prefill_req, preselected_worker)
.await
{
let bootstrap_room = bootstrap_info.bootstrap_room;
// Prepare request with bootstrap_room and force routing to specific worker
let routing = prefill_req.routing_mut();
routing.backend_instance_id = Some(worker_id);
routing.dp_rank = Some(dp_rank);
let extra_args = prefill_req
.extra_args
.get_or_insert_with(|| serde_json::json!({}));
if let Some(obj) = extra_args.as_object_mut() {
obj.insert(
"bootstrap_room".to_string(),
serde_json::json!(bootstrap_room),
);
}
// Set phase to Prefill and record prefill start time if tracking is enabled
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.spawn_prefill_task(prefill_context);
Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// Fallback to original: Wait for prefill to complete
tracing::debug!("Using original prefill path");
{
// Bootstrap optimization path: spawn prefill in background
let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id);
routing.backend_instance_id = Some(worker_id); // Route prefill to the SAME worker we got bootstrap_info from
routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone());
// Set phase to Prefill and record prefill start time if tracking is enabled
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.spawn_prefill_task(prefill_context, Some(worker_id));
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
}
Ok((None, Some(worker_id), Some(bootstrap_info)))
} else {
// GAIE Stage 1: Use original path (no bootstrap optimization)
// But first check if prefill router is activated - if not, skip to avoid setting phase
if self.prefill_router.get().is_none() {
tracing::debug!("GAIE Stage 1: Prefill router not activated, skipping to decode");
Err(PrefillError::NotActivated)
} else {
tracing::debug!("Using original prefill path (GAIE Stage 1)");
// Set phase to Prefill and record prefill start time if tracking is enabled
if let Some(ref tracker) = req.tracker {
tracker.set_phase(RequestPhase::Prefill);
tracker.record_prefill_start();
}
// Original prefill path: wait for prefill to complete
tracing::debug!(
is_gaie_stage1 = is_gaie_stage1,
"Using original prefill path"
);
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
}
self.call_prefill(prefill_context)
.await
.map(|(result, worker_id)| (Some(result), worker_id, None))
};
// Abort if cancelled during prefill
......
......@@ -4,6 +4,82 @@
use crate::tokens::{SequenceHash, Token};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use xxhash_rust::xxh3;
/// Seed for XXH3 hashing, consistent with indexer.rs
pub const XXH3_SEED: u64 = 1337;
/// Compute hash of data using XXH3 with the standard seed.
pub fn compute_hash(data: &[u8]) -> u64 {
xxh3::xxh3_64_with_seed(data, XXH3_SEED)
}
/// Compute the hash of a local block.
pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
LocalBlockHash(compute_hash(data))
}
/// Compute the hash for a sequence of tokens, optionally including multimodal metadata.
///
/// When multimodal extra info is provided, the mm_hashes are included in the hash computation
/// to ensure that blocks with identical tokens but different multimodal objects produce
/// different hashes.
pub fn compute_block_hash_for_seq(
tokens: &[u32],
kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<LocalBlockHash> {
tokens
.chunks_exact(kv_block_size as usize)
.enumerate()
.map(|(block_idx, chunk)| {
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
// Include MM hashes in the block hash computation if present
if let Some(mm_infos) = block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{
let mut mm_hashes: Vec<u64> = block_mm_info
.mm_objects
.iter()
.map(|obj| obj.mm_hash)
.collect();
mm_hashes.sort_unstable();
for mm_hash in mm_hashes {
bytes.extend_from_slice(&mm_hash.to_le_bytes());
}
}
compute_block_hash(&bytes)
})
.collect()
}
/// Compute rolling sequence hashes for a vector of block hashes.
///
/// - The first block's sequence hash equals its block hash
/// - Subsequent blocks' sequence hash = hash([parent_sequence_hash, current_block_hash], seed)
pub fn compute_seq_hash_for_block(block_hashes: &[LocalBlockHash]) -> Vec<SequenceHash> {
if block_hashes.is_empty() {
return Vec::new();
}
let mut sequence_hashes = Vec::with_capacity(block_hashes.len());
sequence_hashes.push(block_hashes[0].0);
for i in 1..block_hashes.len() {
let parent_seq_hash = sequence_hashes[i - 1];
let current_block_hash = block_hashes[i].0;
let combined = [parent_seq_hash, current_block_hash];
let bytes: Vec<u8> = combined.iter().flat_map(|&num| num.to_le_bytes()).collect();
let seq_hash = compute_hash(&bytes);
sequence_hashes.push(seq_hash);
}
sequence_hashes
}
/// A worker identifier.
pub type WorkerId = u64;
......@@ -439,6 +515,103 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
}
}
// ------
// TokensWithHashes
// ------
/// A container for tokens with lazily computed block and sequence hashes.
///
/// This struct avoids redundant hash computations by caching results:
/// - `get_or_compute_block_hashes()` computes block hashes if not cached
/// - `get_or_compute_seq_hashes()` computes seq hashes if not cached,
/// and will also compute block hashes first if needed (since seq hashes depend on them)
#[derive(Debug, Clone)]
pub struct TokensWithHashes {
tokens: Vec<u32>,
block_size: u32,
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
block_hashes: Option<Vec<LocalBlockHash>>,
seq_hashes: Option<Vec<SequenceHash>>,
}
impl TokensWithHashes {
/// Creates a new TokensWithHashes from tokens and block size.
pub fn new(tokens: Vec<u32>, block_size: u32) -> Self {
Self {
tokens,
block_size,
block_mm_infos: None,
block_hashes: None,
seq_hashes: None,
}
}
/// Adds multimodal extra info for blocks.
pub fn with_mm_infos(mut self, infos: Vec<Option<BlockExtraInfo>>) -> Self {
self.block_mm_infos = Some(infos);
self
}
/// Returns a reference to the tokens.
pub fn tokens(&self) -> &[u32] {
&self.tokens
}
/// Returns the number of tokens.
pub fn len(&self) -> usize {
self.tokens.len()
}
/// Returns true if there are no tokens.
pub fn is_empty(&self) -> bool {
self.tokens.is_empty()
}
/// Returns the block size.
pub fn block_size(&self) -> u32 {
self.block_size
}
/// Returns the multimodal extra info, if set.
pub fn block_mm_infos(&self) -> Option<&[Option<BlockExtraInfo>]> {
self.block_mm_infos.as_deref()
}
/// Returns block hashes, computing them if not already cached.
pub fn get_or_compute_block_hashes(&mut self) -> &[LocalBlockHash] {
if self.block_hashes.is_none() {
self.block_hashes = Some(compute_block_hash_for_seq(
&self.tokens,
self.block_size,
self.block_mm_infos.as_deref(),
));
}
self.block_hashes.as_ref().unwrap()
}
/// Returns sequence hashes, computing them if not already cached.
/// This will also compute block hashes if they haven't been computed yet,
/// since sequence hashes depend on block hashes.
pub fn get_or_compute_seq_hashes(&mut self) -> &[SequenceHash] {
if self.seq_hashes.is_none() {
// Ensure block hashes are computed first
let block_hashes = self.get_or_compute_block_hashes();
self.seq_hashes = Some(compute_seq_hash_for_block(block_hashes));
}
self.seq_hashes.as_ref().unwrap()
}
/// Returns cached block hashes without computing. Returns None if not yet computed.
pub fn block_hashes(&self) -> Option<&[LocalBlockHash]> {
self.block_hashes.as_deref()
}
/// Returns cached seq hashes without computing. Returns None if not yet computed.
pub fn seq_hashes(&self) -> Option<&[SequenceHash]> {
self.seq_hashes.as_deref()
}
}
// ------
// Tests
// ------
......
......@@ -28,10 +28,7 @@ use futures::StreamExt;
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE,
WORKER_KV_INDEXER_QUERY_SUBJECT,
indexer::{
KvIndexerMetrics, LocalKvIndexer, RouterEvent, WorkerKvQueryRequest,
compute_block_hash_for_seq,
},
indexer::{KvIndexerMetrics, LocalKvIndexer, RouterEvent, WorkerKvQueryRequest},
protocols::*,
};
use dynamo_runtime::config::environment_names::nats as env_nats;
......@@ -1105,7 +1102,7 @@ impl WorkerMetricsPublisher {
#[cfg(test)]
mod test_event_processing {
use super::*;
use crate::kv_router::indexer::compute_block_hash_for_seq;
use crate::kv_router::protocols::compute_block_hash_for_seq;
// ---------------------------------------------------------------------
// create_stored_block_from_parts --------------------------------------
......@@ -2143,449 +2140,3 @@ mod test_integration_publisher {
);
}
}
#[cfg(all(test, feature = "integration"))]
mod test_integration_publisher_with_kvindexer {
use super::*;
use crate::kv_router::scheduler::DefaultWorkerSelector;
use crate::kv_router::{KvPushRouter, KvRouter, KvRouterConfig};
use crate::local_model::LocalModelBuilder;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::mocker::engine::{MOCKER_COMPONENT, MockVllmEngine};
use crate::mocker::protocols::MockEngineArgs;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::distributed_test_utils::create_test_shared_drt_async;
use dynamo_runtime::engine::AsyncEngine;
use dynamo_runtime::pipeline::{Context, PushRouter, RouterMode, network::Ingress};
use dynamo_runtime::protocols::annotated::Annotated;
/// Integration test: KvPushRouter end-to-end routing with mock engines.
#[tokio::test(flavor = "multi_thread")]
#[ignore] // Requires NATS/etcd. Run with: cargo test --package dynamo-llm --lib --features integration test_distributed_kvindexer_e2e -- --ignored --nocapture
async fn test_distributed_kvindexer_e2e() -> anyhow::Result<()> {
const BLOCK_SIZE: u32 = 4;
const NUM_REQUESTS: usize = 4;
dynamo_runtime::logging::init();
// === SETUP: Distributed runtimes and namespaces ===
let shared_store_dir = tempfile::tempdir()?;
let shared_store_path = shared_store_dir.path().to_path_buf();
// Make both runtimes point at the same file-backed storage backend so worker
// registrations and heartbeats remain visible to every DRT instance.
let distributed1 = create_test_shared_drt_async(&shared_store_path).await;
let distributed2 = create_test_shared_drt_async(&shared_store_path).await;
let component1 = distributed1
.namespace("test_e2e_router")?
.component(MOCKER_COMPONENT)?;
let component2 = distributed2
.namespace("test_e2e_router")?
.component(MOCKER_COMPONENT)?;
// === SETUP: Start mocker workers ===
let mocker_args = MockEngineArgs::builder()
.block_size(BLOCK_SIZE as usize)
.dp_size(1) // single worker per runtime
.enable_prefix_caching(true)
.enable_local_indexer(true) // affects scheduler/publisher args
.build()?;
let worker_components = vec![component1.clone(), component2.clone()];
let mut server_handles = Vec::new();
let mut worker_ids = Vec::new();
for comp in worker_components {
let engine = Arc::new(MockVllmEngine::new(mocker_args.clone()));
engine.start(comp.clone()).await?;
tracing::info!("MockVllmEngine started for {:?}", comp);
// Register MDC with runtime_config so router can discover enable_local_indexer.
// (Without this step, the MDC-based assert in query_worker() in worker_query.rs will fail.)
// This inlines code which in the Python path would be performed by:
// - local_model.rs: LocalModelBuilder::build() sets runtime_config from MockEngineArgs
// - entrypoint/input/endpoint.rs: LocalModel::attach() registers MDC via discovery
let endpoint = comp.endpoint("generate");
let runtime_config = ModelRuntimeConfig {
enable_local_indexer: true,
..Default::default()
};
let mut builder = LocalModelBuilder::default();
builder
.model_name(Some("mock".to_string()))
.kv_cache_block_size(Some(BLOCK_SIZE))
.runtime_config(runtime_config);
let mut local_model = builder.build().await?;
local_model
.attach(
&endpoint,
crate::model_type::ModelType::Chat,
crate::model_type::ModelInput::Tokens,
None,
)
.await?;
let ingress = Ingress::for_engine(engine.clone())?;
let endpoint_component = comp.clone();
let handle = tokio::spawn(async move {
if let Err(e) = endpoint_component
.endpoint("generate")
.endpoint_builder()
.handler(ingress)
.start()
.await
{
tracing::error!("Generate endpoint failed: {e}");
}
});
server_handles.push(handle);
worker_ids.push(comp.drt().connection_id());
}
tracing::info!("Generate endpoint servers launched");
tokio::time::sleep(Duration::from_millis(500)).await;
// === SETUP: Build KvPushRouter ===
let router_distributed = create_test_shared_drt_async(&shared_store_path).await;
let router_namespace = router_distributed.namespace("test_e2e_router")?;
let backend_component = router_namespace.component(MOCKER_COMPONENT)?;
let backend_endpoint = backend_component.endpoint("generate");
let client = backend_endpoint.client().await?;
let kv_router_config = KvRouterConfig::default();
let selector = Box::new(DefaultWorkerSelector::new(Some(kv_router_config)));
let consumer_id = format!("test-router-{}", router_distributed.connection_id());
let kv_router: Arc<KvRouter> = Arc::new(
KvRouter::new(
backend_endpoint.clone(),
client.clone(),
BLOCK_SIZE,
Some(selector),
Some(kv_router_config),
consumer_id,
)
.await?,
);
let push_router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
None,
None,
)
.await?;
let kv_push_router = KvPushRouter::new(push_router, kv_router.clone());
// ===== TEST PART 1: ROUTE & SEND REQUESTS TO WORKERS (ROUTER -> WORKER) =====
let create_request = |tokens: Vec<u32>| {
PreprocessedRequest::builder()
.model("mock".to_string())
.token_ids(tokens)
.stop_conditions(StopConditions {
max_tokens: Some(10),
..Default::default()
})
.sampling_options(SamplingOptions::default())
.output_options(OutputOptions::default())
.build()
.unwrap()
}; // from mocker/engine.rs
for i in 0..NUM_REQUESTS {
tracing::info!("Sending routed request {}", i + 1);
let tokens = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, i as u32];
let request = create_request(tokens.clone());
let response_stream = kv_push_router.generate(Context::new(request)).await?;
let responses: Vec<Annotated<LLMEngineOutput>> = response_stream.collect().await;
assert!(
!responses.is_empty(),
"Request {} should produce at least one response",
i + 1
);
}
tracing::info!("KvPushRouter generate() succeeded for {NUM_REQUESTS} requests");
// ===== TEST PART 2: QUERY WORKER-LOCAL KVINDEXERS DIRECTLY =====
// TODO: This could be refactored as router function (e.g. router.refresh_from_worker(worker_id))
// (which should also update the global kvIndexer with the buffer from the local kvIndexer)
let mut best_worker_info: Option<(u64, usize)> = None;
// Exactly one worker should have been routed requests. Find that worker
for &worker_id in &worker_ids {
let response = kv_router
.query_worker_local_kv(worker_id, None, None)
.await?;
let events = match response {
crate::kv_router::indexer::WorkerKvQueryResponse::Events(e) => e,
crate::kv_router::indexer::WorkerKvQueryResponse::TreeDump(e) => e,
_ => vec![],
};
if events.is_empty() {
continue;
}
let event_count = events.len();
tracing::info!(
worker_id,
events = event_count,
"Worker query on worker {worker_id} returned buffered KV events"
);
best_worker_info = Some((worker_id, event_count));
break;
}
// Verify that only one worker has KV events in buffer
let (best_worker_id, best_worker_event_count) =
best_worker_info.expect("At least one worker should have buffered KV events");
tracing::info!(
"Best worker is {best_worker_id} with {best_worker_event_count} buffered KV events"
);
for &worker_id in &worker_ids {
if worker_id == best_worker_id {
continue;
}
let response = kv_router
.query_worker_local_kv(worker_id, None, None)
.await?;
let events = match response {
crate::kv_router::indexer::WorkerKvQueryResponse::Events(e) => e,
crate::kv_router::indexer::WorkerKvQueryResponse::TreeDump(e) => e,
_ => vec![],
};
assert!(
events.is_empty(),
"Worker {worker_id} should not report buffered KV events; best worker {best_worker_id} reported {best_worker_event_count}"
);
}
// === Cleanup ===
for handle in server_handles {
handle.abort();
}
distributed1.shutdown();
distributed2.shutdown();
router_distributed.shutdown();
Ok(())
}
#[tokio::test(flavor = "multi_thread")]
#[ignore]
async fn test_distributed_kvindexer_e2e_startup() -> anyhow::Result<()> {
const BLOCK_SIZE: u32 = 4;
dynamo_runtime::logging::init();
// === SETUP: Distributed runtimes and namespaces ===
let shared_store_dir = tempfile::tempdir()?;
let shared_store_path = shared_store_dir.path().to_path_buf();
// Use a unique namespace per test run for full isolation
let test_namespace = format!("test_e2e_{}", uuid::Uuid::new_v4().simple());
// Make both runtimes point at the same file-backed storage backend so worker
// registrations and heartbeats remain visible to every DRT instance.
let distributed1 = create_test_shared_drt_async(&shared_store_path).await;
let distributed2 = create_test_shared_drt_async(&shared_store_path).await;
let component1 = distributed1
.namespace(&test_namespace)?
.component(MOCKER_COMPONENT)?;
let component2 = distributed2
.namespace(&test_namespace)?
.component(MOCKER_COMPONENT)?;
// === SETUP: Start mocker workers ===
let mocker_args = MockEngineArgs::builder()
.block_size(BLOCK_SIZE as usize)
.dp_size(1) // single worker per runtime
.enable_prefix_caching(true)
.enable_local_indexer(true) // affects scheduler/publisher args
.build()?;
let worker_components = vec![component1.clone(), component2.clone()];
let mut server_handles = Vec::new();
let mut worker_ids = Vec::new();
for comp in worker_components {
let engine: Arc<MockVllmEngine> = Arc::new(MockVllmEngine::new(mocker_args.clone()));
engine.start(comp.clone()).await?;
tracing::info!("MockVllmEngine started for {:?}", comp);
// Register MDC with runtime_config so router can discover enable_local_indexer.
// (Without this step, the MDC-based assert in query_worker() in worker_query.rs will fail.)
// This inlines code which in the Python path would be performed by:
// - local_model.rs: LocalModelBuilder::build() sets runtime_config from MockEngineArgs
// - entrypoint/input/endpoint.rs: LocalModel::attach() registers MDC via discovery
let endpoint = comp.endpoint("generate");
let runtime_config = ModelRuntimeConfig {
enable_local_indexer: true,
..Default::default()
};
let mut builder = LocalModelBuilder::default();
builder
.model_name(Some("mock".to_string()))
.kv_cache_block_size(Some(BLOCK_SIZE))
.runtime_config(runtime_config);
let mut local_model = builder.build().await?;
local_model
.attach(
&endpoint,
crate::model_type::ModelType::Chat,
crate::model_type::ModelInput::Tokens,
None,
)
.await?;
let ingress = Ingress::for_engine(engine.clone())?;
let endpoint_component = comp.clone();
let handle = tokio::spawn(async move {
if let Err(e) = endpoint_component
.endpoint("generate")
.endpoint_builder()
.handler(ingress)
.start()
.await
{
tracing::error!("Generate endpoint failed: {e}");
}
});
server_handles.push(handle);
worker_ids.push(comp.drt().connection_id());
}
tracing::info!("Generate endpoint servers launched");
tokio::time::sleep(Duration::from_millis(500)).await;
// === STEP 1: Send request to worker_ids[0] to populate its local indexer ===
// This simulates a situation where KvPushRouter is initialized
// to route to workers which already have KV events
let pre_router_distributed = create_test_shared_drt_async(&shared_store_path).await;
let pre_backend_endpoint = pre_router_distributed
.namespace(&test_namespace)?
.component(MOCKER_COMPONENT)?
.endpoint("generate");
let pre_client = pre_backend_endpoint.client().await?;
// Wait for the client to discover both workers
let discovery_timeout = Duration::from_secs(5);
let discovery_start = std::time::Instant::now();
loop {
let instances = pre_client.instance_source.as_ref().borrow().clone();
if instances.len() >= 2 {
tracing::info!("Discovered {} workers", instances.len());
break;
}
if discovery_start.elapsed() > discovery_timeout {
anyhow::bail!(
"Timed out waiting for worker discovery: expected 2, found {}",
instances.len()
);
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
// Create a PushRouter to send requests directly to a specific worker
let pre_push_router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
pre_client,
RouterMode::Random, // We'll use direct() so mode doesn't matter
None,
None,
)
.await?;
// Force sending one requests each to the two workers
for &worker_id in &worker_ids {
let tokens: Vec<u32> = vec![0, 1, 2, 3];
let request = PreprocessedRequest::builder()
.model("mock".to_string())
.token_ids(tokens.clone())
.sampling_options(SamplingOptions::default())
.output_options(OutputOptions::default())
.stop_conditions(StopConditions {
max_tokens: Some(5),
..Default::default()
})
.build()?;
let response_stream = pre_push_router
.direct(Context::new(request), worker_id)
.await?;
// Consume the stream to complete the request
let _responses: Vec<_> = response_stream.collect().await;
tracing::debug!(
"Sent request {:?} directly to worker {} to populate its local indexer",
tokens,
worker_id
);
}
tokio::time::sleep(Duration::from_millis(1000)).await;
// === SETUP: Build KvPushRouter ===
let router_distributed = create_test_shared_drt_async(&shared_store_path).await;
let router_namespace = router_distributed.namespace(&test_namespace)?;
let backend_component = router_namespace.component(MOCKER_COMPONENT)?;
let backend_endpoint = backend_component.endpoint("generate");
let client = backend_endpoint.client().await?;
let kv_router_config = KvRouterConfig::default();
let selector = Box::new(DefaultWorkerSelector::new(Some(kv_router_config)));
let consumer_id = format!("test-router-{}", router_distributed.connection_id());
let kv_router: Arc<KvRouter> = Arc::new(
KvRouter::new(
backend_endpoint.clone(),
client.clone(),
BLOCK_SIZE,
Some(selector),
Some(kv_router_config),
consumer_id,
)
.await?,
);
// The KvRouter now starts its subscriber asynchronously in a background task
// that waits for runtime_configs. Poll until events appear or timeout.
// Each request generates 2 events: input block (parent_hash: None) + output block (parent_hash: Some)
// With 2 workers, that's 4 events total.
let expected_events = 4;
let max_wait = Duration::from_secs(10);
let poll_interval = Duration::from_millis(100);
let start = std::time::Instant::now();
let global_kv_events = loop {
let events = kv_router.indexer.dump_events().await?;
tracing::debug!("Global KV events ({}): {:?}", events.len(), events);
if events.len() >= expected_events {
break events;
}
if start.elapsed() > max_wait {
anyhow::bail!(
"Timed out waiting for KV events: expected {}, got {}",
expected_events,
events.len()
);
}
tokio::time::sleep(poll_interval).await;
};
assert_eq!(global_kv_events.len(), expected_events); // 2 workers × 2 events per request (input + output)
// === Cleanup ===
for handle in server_handles {
handle.abort();
}
distributed1.shutdown();
distributed2.shutdown();
router_distributed.shutdown();
Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::local_model::runtime_config::{DisaggregatedEndpoint, ModelRuntimeConfig};
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dashmap::DashMap;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
......@@ -11,7 +12,7 @@ use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, watch};
use tokio::sync::watch;
use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
......@@ -90,8 +91,6 @@ impl SchedulingRequest {
pub struct KvScheduler {
request_tx: tokio::sync::mpsc::Sender<SchedulingRequest>,
slots: Arc<ActiveSequencesMultiWorker>,
/// Worker runtime configs for looking up disaggregated endpoints
workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>>,
}
impl KvScheduler {
......@@ -99,92 +98,71 @@ impl KvScheduler {
component: Component,
block_size: u32,
instance_ids_rx: watch::Receiver<Vec<u64>>,
runtime_configs_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
workers_with_configs: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
let instance_ids: Vec<u64> = instance_ids_rx.borrow().clone();
let runtime_configs: HashMap<WorkerId, ModelRuntimeConfig> =
runtime_configs_rx.borrow().clone();
// Create shared workers_with_configs wrapped in Arc<RwLock>
let workers_with_configs: Arc<RwLock<HashMap<WorkerId, Option<ModelRuntimeConfig>>>> = {
let mut initial_map = HashMap::new();
for worker_id in &instance_ids {
let config = runtime_configs.get(worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
initial_map.insert(*worker_id, config);
}
Arc::new(RwLock::new(initial_map))
};
// Get initial workers from DashMap for slot initialization
let initial_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_with_configs
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new(
component.clone(),
block_size as usize,
workers_with_configs.read().await.clone(), // this includes dp_size info
initial_workers,
replica_sync,
router_uuid,
));
// Spawn background task to monitor and update workers_with_configs
let workers_monitor = workers_with_configs.clone();
// Spawn background task to monitor workers_with_configs changes and update slots
let slots_monitor = slots.clone();
let workers_monitor = workers_with_configs.clone();
let mut instance_ids_monitor_rx = instance_ids_rx.clone();
let mut configs_monitor_rx = runtime_configs_rx.clone();
let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move {
tracing::trace!("workers monitoring task started");
tracing::trace!("KvScheduler workers monitoring task started");
let mut last_workers: HashSet<WorkerId> = HashSet::new();
loop {
// Wait for either instances or configs to change
// Wait for instance changes (ModelManager handles config updates to the DashMap)
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("workers monitoring task shutting down");
tracing::trace!("KvScheduler workers monitoring task shutting down");
break;
}
result = instance_ids_monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("instance IDs watch sender shutdown in monitor");
break;
}
}
result = configs_monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("runtime configs watch sender shutdown in monitor");
tracing::warn!("instance IDs watch sender shutdown in KvScheduler monitor");
break;
}
}
}
// Get the latest values from both channels
let new_instance_ids = instance_ids_monitor_rx.borrow_and_update().clone();
let new_configs = configs_monitor_rx.borrow_and_update().clone();
// Build the new workers_with_configs map
let mut new_workers_with_configs = HashMap::new();
for worker_id in &new_instance_ids {
let config = new_configs.get(worker_id).cloned();
if config.is_some() {
tracing::info!("Runtime config found for worker_id: {}", worker_id);
}
new_workers_with_configs.insert(*worker_id, config);
// Get current workers from DashMap
let current_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> =
workers_monitor
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
let current_worker_ids: HashSet<WorkerId> =
current_workers.keys().copied().collect();
// Only update slots if workers have changed
if current_worker_ids != last_workers {
slots_monitor.update_workers(current_workers);
last_workers = current_worker_ids;
tracing::trace!(
"KvScheduler: Updated slots with {} workers",
last_workers.len()
);
}
// Update workers when instances change
slots_monitor.update_workers(new_workers_with_configs.clone());
// Update the shared workers_with_configs
let mut workers_map = workers_monitor.write().await;
*workers_map = new_workers_with_configs;
tracing::trace!(
"Updated workers_with_configs with {} workers",
workers_map.len()
);
}
tracing::trace!("workers monitoring task shutting down");
tracing::trace!("KvScheduler workers monitoring task shutting down");
});
let slots_clone = slots.clone();
......@@ -222,8 +200,11 @@ impl KvScheduler {
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
// Read the current workers configuration
let workers = workers_scheduler.read().await.clone();
// Read the current workers configuration from DashMap
let workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_scheduler
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => {
......@@ -289,11 +270,7 @@ impl KvScheduler {
tracing::trace!("background endpoint subscriber shutting down");
});
Ok(KvScheduler {
request_tx,
slots,
workers_with_configs,
})
Ok(KvScheduler { request_tx, slots })
}
pub async fn schedule(
......@@ -352,17 +329,6 @@ impl KvScheduler {
self.slots.free(&request_id.to_string()).await
}
pub async fn get_disaggregated_endpoint(
&self,
worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
let workers = self.workers_with_configs.read().await;
workers
.get(&worker_id)
.and_then(|config| config.as_ref())
.and_then(|config| config.disaggregated_endpoint.clone())
}
pub async fn get_potential_loads(
&self,
token_seq: Option<Vec<SequenceHash>>,
......
......@@ -198,6 +198,34 @@ where
.await
}
/// Select the next worker according to the routing mode.
/// Increments round-robin counter if applicable.
/// Panics if called on Direct or KV mode - those have their own selection mechanisms.
pub fn select_next_worker(&self) -> Option<u64> {
let instance_ids = self.client.instance_ids_avail();
let count = instance_ids.len();
if count == 0 {
return None;
}
match self.router_mode {
RouterMode::RoundRobin => {
let counter = self.round_robin_counter.fetch_add(1, Ordering::Relaxed) as usize;
Some(instance_ids[counter % count])
}
RouterMode::Random => {
let counter = rand::rng().random::<u64>() as usize;
Some(instance_ids[counter % count])
}
_ => {
panic!(
"select_next_worker should not be called for {:?} routing mode",
self.router_mode
)
}
}
}
/*
pub async fn r#static(&self, request: SingleIn<T>) -> anyhow::Result<ManyOut<U>> {
let subject = self.client.endpoint.subject();
......
......@@ -1894,6 +1894,7 @@ def _test_router_decisions(
request,
test_dp_rank: bool = False,
block_size: int = BLOCK_SIZE,
use_kv_events: bool = True,
):
"""Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes.
......@@ -1912,12 +1913,17 @@ def _test_router_decisions(
model_name: Name of the model
request: Pytest request fixture
test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups)
use_kv_events: If True (default), uses KV events from workers. If False, uses
approximate routing with TTL-based expiration (--no-kv-events mode).
Raises:
AssertionError: If routing decisions don't follow KV cache prefix reuse as expected
"""
# Create KvRouterConfig with lower snapshot threshold for testing
kv_router_config = KvRouterConfig(router_snapshot_threshold=20)
kv_router_config = KvRouterConfig(
router_snapshot_threshold=20,
use_kv_events=use_kv_events,
)
kv_push_router = KvPushRouter(
endpoint=endpoint,
block_size=block_size,
......
......@@ -596,30 +596,49 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize("use_nats_core", [False, True], ids=["jetstream", "nats_core"])
@pytest.mark.parametrize(
"use_nats_core,use_kv_events",
[
(False, True), # JetStream mode (default)
(True, True), # NATS Core + local indexer mode
(False, False), # Approximate mode (--no-kv-events)
],
ids=["jetstream", "nats_core", "no_kv_events"],
)
def test_router_decisions(
request,
runtime_services_dynamic_ports,
predownload_tokenizers,
use_nats_core,
use_kv_events,
request_plane,
):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
Parameterized to test both JetStream (default) and NATS Core (local indexer) modes.
Parameterized to test:
- JetStream mode (default): KV events via JetStream
- NATS Core mode: KV events via NATS Core with local indexer on workers
- Approximate mode (--no-kv-events): No KV events, router predicts cache state
based on routing decisions with TTL-based expiration and pruning
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
mode = "NATS Core (local indexer)" if use_nats_core else "JetStream"
if not use_kv_events:
mode = "Approximate (no-kv-events)"
elif use_nats_core:
mode = "NATS Core (local indexer)"
else:
mode = "JetStream"
logger.info(
f"Starting test router prefix reuse and KV events synchronization ({mode})"
)
# Create mocker args dictionary with dp_size=4
# Note: enable_local_indexer only applies when use_kv_events=True and use_nats_core=True
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"dp_size": 4,
"enable_local_indexer": use_nats_core,
"enable_local_indexer": use_nats_core and use_kv_events,
}
try:
......@@ -645,7 +664,12 @@ def test_router_decisions(
endpoint = component.endpoint("generate")
_test_router_decisions(
mockers, endpoint, MODEL_NAME, request, test_dp_rank=True
mockers,
endpoint,
MODEL_NAME,
request,
test_dp_rank=True,
use_kv_events=use_kv_events,
)
finally:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment