// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::sync::Arc; use std::time::Instant; use anyhow::Result; use dynamo_kv_router::{ PrefillLoadEstimator, config::{KvRouterConfig, RouterConfigOverride, min_initial_workers_from_env}, indexer::KvRouterError, protocols::KV_EVENT_SUBJECT, protocols::{ BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, PrefillLoadHint, RouterEvent, RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq, }, }; use dynamo_runtime::{ component::{Client, Endpoint}, discovery::DiscoveryQuery, pipeline::{ AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, ResponseStream, SingleIn, async_trait, }, protocols::EndpointId, protocols::annotated::Annotated, traits::DistributedRuntimeProvider, }; use futures::stream; use tracing::Instrument; use validator::Validate; // Re-export from dynamo-kv-router crate pub use dynamo_kv_router::approx; pub use dynamo_kv_router::protocols; pub use dynamo_kv_router::scheduling; pub use dynamo_kv_router::selector; pub mod agent_controller; pub mod indexer; pub mod metrics; pub mod prefill_router; pub mod publisher; pub mod push_router; pub mod scheduler; pub mod sequence; pub mod sticky_sessions; pub use agent_controller::AgentController; pub use indexer::{Indexer, ServedIndexerHandle, ServedIndexerMode, ensure_served_indexer_service}; pub use prefill_router::PrefillRouter; pub use push_router::{DirectRoutingRouter, KvPushRouter}; pub use sticky_sessions::StickySessionRouter; use crate::{ discovery::RuntimeConfigWatch, kv_router::{ scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad}, sequence::{SequenceError, SequenceRequest}, }, local_model::runtime_config::ModelRuntimeConfig, }; use std::collections::HashSet; // [gluo TODO] shouldn't need to be public // this should be discovered from the component // for metric scraping (pull-based) pub const KV_METRICS_ENDPOINT: &str = "load_metrics"; // for metric publishing (push-based) pub const KV_METRICS_SUBJECT: &str = "kv_metrics"; // for inter-router comms pub const PREFILL_SUBJECT: &str = "prefill_events"; pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; // for radix tree snapshot storage pub const RADIX_STATE_BUCKET: &str = "radix-bucket"; pub const RADIX_STATE_FILE: &str = "radix-state"; // for worker-local kvindexer query pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer /// Generates a dp_rank-specific endpoint name for the worker KV indexer query service. /// Each dp_rank has its own LocalKvIndexer and query endpoint to ensure per-dp_rank monotonicity. pub fn worker_kv_indexer_query_endpoint(dp_rank: DpRank) -> String { format!("worker_kv_indexer_query_dp{dp_rank}") } fn log_routing_input_hashes( request_id: Option<&str>, block_size: u32, tokens: &[u32], local_hashes: &[LocalBlockHash], ) { if !tracing::enabled!(tracing::Level::DEBUG) { return; } let local_hash_ids: Vec = local_hashes.iter().map(|hash| hash.0).collect(); tracing::debug!( request_id = request_id.unwrap_or(""), isl_tokens = tokens.len(), block_size, num_blocks = local_hashes.len(), local_hashes = ?local_hash_ids, "[ROUTING_INPUT] request local hashes" ); } // for router discovery registration pub const KV_ROUTER_ENDPOINT: &str = "router-discovery"; /// Creates an EndpointId for the KV router in the given namespace. pub fn router_endpoint_id(namespace: String, component: String) -> EndpointId { EndpointId { namespace, component, name: KV_ROUTER_ENDPOINT.to_string(), } } /// Creates a DiscoveryQuery for the KV router in the given namespace. pub fn router_discovery_query(namespace: String, component: String) -> DiscoveryQuery { DiscoveryQuery::Endpoint { namespace, component, endpoint: KV_ROUTER_ENDPOINT.to_string(), } } /// A KvRouter only decides which worker you should use. It doesn't send you there. /// TODO: Rename this to indicate it only selects a worker, it does not route. pub struct KvRouter where Sel: dynamo_kv_router::selector::WorkerSelector, { indexer: Indexer, scheduler: KvScheduler, workers_with_configs: RuntimeConfigWatch, block_size: u32, kv_router_config: KvRouterConfig, prefill_load_estimator: Option>, cancellation_token: tokio_util::sync::CancellationToken, client: Client, is_eagle: bool, _served_indexer_handle: Option, } impl KvRouter where Sel: dynamo_kv_router::selector::WorkerSelector + Send + Sync + 'static, { #[allow(clippy::too_many_arguments)] pub async fn new( endpoint: Endpoint, client: Client, workers_with_configs: RuntimeConfigWatch, block_size: u32, selector: Sel, kv_router_config: Option, prefill_load_estimator: Option>, worker_type: &'static str, model_name: Option, is_eagle: bool, ) -> Result { let kv_router_config = kv_router_config.unwrap_or_default(); kv_router_config.validate()?; let component = endpoint.component(); let cancellation_token = component.drt().primary_token(); let min_initial_workers = min_initial_workers_from_env()?; let indexer = Indexer::new( component, &kv_router_config, block_size, model_name.as_deref(), ) .await?; if min_initial_workers > 0 && !kv_router_config.skip_initial_worker_wait { let mut startup_watch = workers_with_configs.clone(); let _ = startup_watch .wait_for(|m| m.len() >= min_initial_workers) .await .map_err(|_| { anyhow::anyhow!( "runtime config watch closed before {} workers appeared", min_initial_workers ) })?; } let scheduler = KvScheduler::start( component.clone(), block_size, workers_with_configs.clone(), selector, &kv_router_config, prefill_load_estimator.clone(), worker_type, ) .await?; // Start KV event subscription if needed — skip when using a remote indexer. if kv_router_config.use_remote_indexer { tracing::info!("Skipping KV event subscription (using remote indexer)"); } else if kv_router_config.should_subscribe_to_kv_events() { indexer::start_subscriber(component.clone(), &kv_router_config, indexer.clone()) .await?; } else { tracing::info!( "Skipping KV event subscription (use_kv_events={}, overlap_score_weight={})", kv_router_config.use_kv_events, kv_router_config.overlap_score_weight, ); } let served_indexer_handle = if kv_router_config.serve_indexer { let model_name = model_name.clone().ok_or_else(|| { anyhow::anyhow!("model_name is required when serve_indexer is configured") })?; Some( ensure_served_indexer_service( component.clone(), ServedIndexerMode::from_use_kv_events(kv_router_config.use_kv_events), model_name, indexer.clone(), ) .await?, ) } else { None }; tracing::info!("KV Routing initialized"); Ok(Self { indexer, scheduler, workers_with_configs, block_size, kv_router_config, prefill_load_estimator, cancellation_token, client, is_eagle, _served_indexer_handle: served_indexer_handle, }) } /// Get a reference to the client used by this KvRouter pub fn client(&self) -> &Client { &self.client } pub fn indexer(&self) -> &Indexer { &self.indexer } pub fn kv_router_config(&self) -> &KvRouterConfig { &self.kv_router_config } pub fn is_eagle(&self) -> bool { self.is_eagle } pub async fn record_routing_decision( &self, mut tokens_with_hashes: TokensWithHashes, worker: WorkerWithDpRank, ) -> Result<(), KvRouterError> { self.indexer .process_routing_decision_for_request(&mut tokens_with_hashes, worker) .await } /// Give these tokens, find the worker with the best match in it's KV cache. /// Returns the best worker (with dp_rank) and overlap amount in number of blocks. /// Now also takes optional context_id for request tracking. /// /// When `pinned_worker` is Some, scheduling and queueing are constrained to /// that exact worker/rank. /// /// When `allowed_worker_ids` is Some, only workers in that set are considered for selection. #[allow(clippy::too_many_arguments)] pub async fn find_best_match( &self, context_id: Option<&str>, tokens: &[u32], block_mm_infos: Option<&[Option]>, router_config_override: Option<&RouterConfigOverride>, update_states: bool, lora_name: Option, priority_jump: f64, expected_output_tokens: Option, pinned_worker: Option, allowed_worker_ids: Option>, ) -> anyhow::Result<(WorkerWithDpRank, u32)> { let start = Instant::now(); if update_states && context_id.is_none() { anyhow::bail!("context_id must be provided when update_states is true"); } let isl_tokens = tokens.len(); let hash_options = BlockHashOptions { block_mm_infos, lora_name: lora_name.as_deref(), is_eagle: Some(self.is_eagle), }; let block_hashes = tracing::info_span!("kv_router.compute_block_hashes") .in_scope(|| compute_block_hash_for_seq(tokens, self.block_size, hash_options)); log_routing_input_hashes(context_id, self.block_size, tokens, &block_hashes); let hash_elapsed = start.elapsed(); // Compute seq_hashes only if scheduler needs it for active blocks tracking let maybe_seq_hashes = tracing::info_span!("kv_router.compute_seq_hashes").in_scope(|| { self.kv_router_config.compute_seq_hashes_for_tracking( tokens, self.block_size, router_config_override, hash_options, Some(&block_hashes), ) }); let seq_hash_elapsed = start.elapsed(); let overlap_scores = self .indexer .find_matches(block_hashes) .instrument(tracing::info_span!("kv_router.find_matches")) .await?; let find_matches_elapsed = start.elapsed(); let response = self .scheduler .schedule( context_id.map(|s| s.to_string()), isl_tokens, maybe_seq_hashes, overlap_scores, router_config_override, update_states, lora_name, priority_jump, expected_output_tokens, pinned_worker, allowed_worker_ids, ) .instrument(tracing::info_span!("kv_router.schedule")) .await?; let total_elapsed = start.elapsed(); if let Some(m) = metrics::RoutingOverheadMetrics::get() { m.observe( hash_elapsed, seq_hash_elapsed, find_matches_elapsed, total_elapsed, ); } #[cfg(feature = "bench")] tracing::info!( isl_tokens, hash_us = hash_elapsed.as_micros() as u64, seq_hash_us = (seq_hash_elapsed - hash_elapsed).as_micros() as u64, find_matches_us = (find_matches_elapsed - seq_hash_elapsed).as_micros() as u64, schedule_us = (total_elapsed - find_matches_elapsed).as_micros() as u64, total_us = total_elapsed.as_micros() as u64, "find_best_match completed" ); Ok((response.best_worker, response.overlap_blocks)) } /// Register externally-provided workers in the slot tracker. pub fn register_workers(&self, worker_ids: &HashSet) { self.scheduler.register_workers(worker_ids); } #[allow(clippy::too_many_arguments)] pub async fn add_request( &self, request_id: String, tokens: &[u32], block_mm_infos: Option<&[Option]>, overlap_blocks: u32, expected_output_tokens: Option, worker: WorkerWithDpRank, lora_name: Option, router_config_override: Option<&RouterConfigOverride>, ) { let isl_tokens = tokens.len(); let hash_options = BlockHashOptions { block_mm_infos, lora_name: lora_name.as_deref(), is_eagle: Some(self.is_eagle), }; let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking( tokens, self.block_size, router_config_override, hash_options, None, ); let track_prefill_tokens = self .kv_router_config .track_prefill_tokens(router_config_override); let prefill_load_hint = self.prefill_load_hint_for(isl_tokens, overlap_blocks, track_prefill_tokens); if let Err(e) = self .scheduler .add_request(SequenceRequest { request_id: request_id.clone(), token_sequence: maybe_seq_hashes, track_prefill_tokens, expected_output_tokens, prefill_load_hint, worker, lora_name, }) .await { tracing::warn!("Failed to add request {request_id}: {e}"); } } pub async fn mark_prefill_completed(&self, request_id: &str) -> Result<(), SequenceError> { self.scheduler.mark_prefill_completed(request_id).await } pub async fn free(&self, request_id: &str) -> Result<(), SequenceError> { self.scheduler.free(request_id).await } /// Number of requests currently parked in the scheduler queue. pub fn pending_count(&self) -> usize { self.scheduler.pending_count() } fn prefill_load_hint_for( &self, isl_tokens: usize, overlap_blocks: u32, track_prefill_tokens: bool, ) -> Option { if !track_prefill_tokens { return None; } let prefix = (overlap_blocks as usize) * (self.block_size as usize); let effective_isl = isl_tokens.saturating_sub(prefix); if effective_isl == 0 { return None; } let expected_prefill_duration = match &self.prefill_load_estimator { Some(estimator) => match estimator.predict_prefill_duration(1, effective_isl, prefix) { Ok(expected_prefill_duration) => Some(expected_prefill_duration), Err(error) => { tracing::warn!( effective_isl, prefix, "failed to predict prefill duration for direct add_request path: {error}" ); None } }, None => None, }; Some(PrefillLoadHint { initial_effective_prefill_tokens: effective_isl, expected_prefill_duration, }) } /// Get the worker type for this router ("prefill" or "decode"). /// Used for Prometheus metric labeling. pub fn worker_type(&self) -> &'static str { self.scheduler.worker_type() } /// Return the worker's unique global DP rank when it owns exactly one rank. pub fn unique_dp_rank_for_worker(&self, worker_id: WorkerId) -> Option { let configs = self.workers_with_configs.borrow(); let config = configs.get(&worker_id)?; (config.data_parallel_size == 1).then_some(config.data_parallel_start_rank) } pub fn add_output_block( &self, request_id: &str, decay_fraction: Option, ) -> Result<(), SequenceError> { self.scheduler.add_output_block(request_id, decay_fraction) } pub fn block_size(&self) -> u32 { self.block_size } /// Compute the overlap blocks for a given token sequence and worker. /// This queries the indexer to find how many blocks are already cached. pub async fn get_overlap_blocks( &self, tokens: &[u32], block_mm_infos: Option<&[Option]>, worker: WorkerWithDpRank, lora_name: Option<&str>, ) -> Result { let block_hashes = compute_block_hash_for_seq( tokens, self.block_size, BlockHashOptions { block_mm_infos, lora_name, is_eagle: Some(self.is_eagle), }, ); log_routing_input_hashes(None, self.block_size, tokens, &block_hashes); let overlap_scores = self.indexer.find_matches(block_hashes).await?; Ok(overlap_scores.scores.get(&worker).copied().unwrap_or(0)) } /// Get potential prefill and decode loads for all workers pub async fn get_potential_loads( &self, tokens: &[u32], router_config_override: Option<&RouterConfigOverride>, block_mm_infos: Option<&[Option]>, lora_name: Option<&str>, ) -> Result> { let isl_tokens = tokens.len(); let hash_options = BlockHashOptions { block_mm_infos, lora_name, is_eagle: Some(self.is_eagle), }; let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, hash_options); let maybe_seq_hashes = self.kv_router_config.compute_seq_hashes_for_tracking( tokens, self.block_size, router_config_override, hash_options, Some(&block_hashes), ); let track_prefill_tokens = self .kv_router_config .track_prefill_tokens(router_config_override); let overlap_scores = self.indexer.find_matches(block_hashes).await?; Ok(self.scheduler.get_potential_loads( maybe_seq_hashes, isl_tokens, overlap_scores, track_prefill_tokens, )) } /// Dump all events from the indexer pub async fn dump_events(&self) -> Result, KvRouterError> { self.indexer.dump_events().await } } // NOTE: KVRouter works like a PushRouter, // but without the reverse proxy functionality, but based on contract of 3 request types #[async_trait] impl AsyncEngine, ManyOut>, Error> for KvRouter where Sel: dynamo_kv_router::selector::WorkerSelector + Send + Sync + 'static, { async fn generate( &self, request: SingleIn, ) -> Result>> { let (request, ctx) = request.into_parts(); let context_id = ctx.context().id().to_string(); // Handle different request types let response = match request { RouterRequest::New { tokens, block_mm_infos, } => { let (best_worker, overlap_blocks) = self .find_best_match( Some(&context_id), &tokens, block_mm_infos.as_deref(), None, true, None, 0.0, None, None, None, ) .await?; RouterResponse::New { worker_id: best_worker.worker_id, dp_rank: best_worker.dp_rank, overlap_blocks, } } RouterRequest::MarkPrefill => RouterResponse::PrefillMarked { success: self.mark_prefill_completed(&context_id).await.is_ok(), }, RouterRequest::MarkFree { request_id } => { let request_id = match request_id.as_deref() { Some(request_id) if !request_id.trim().is_empty() => request_id, _ => &context_id, }; RouterResponse::FreeMarked { success: self.free(request_id).await.is_ok(), } } }; let response = Annotated::from_data(response); let stream = stream::iter(vec![response]); Ok(ResponseStream::new(Box::pin(stream), ctx.context())) } } impl Drop for KvRouter where Sel: dynamo_kv_router::selector::WorkerSelector, { fn drop(&mut self) { tracing::info!("Dropping KvRouter - cancelling background tasks"); self.cancellation_token.cancel(); } }