Unverified commit 02b1c58a, authored by Yan Ru Pei and committed by GitHub
Browse files

feat(mocker): add offline disagg replay (#7617)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 4b8826b3
...@@ -135,6 +135,7 @@ mod tests { ...@@ -135,6 +135,7 @@ mod tests {
overlaps, overlaps,
decode_blocks: HashMap::new(), decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
track_prefill_tokens: true,
router_config_override: None, router_config_override: None,
update_states: false, update_states: false,
lora_name: None, lora_name: None,
......
...@@ -191,11 +191,14 @@ impl< ...@@ -191,11 +191,14 @@ impl<
/// Run the full scheduling pipeline for a single request: /// Run the full scheduling pipeline for a single request:
/// compute potential load -> select worker -> respond -> book via add_request. /// compute potential load -> select worker -> respond -> book via add_request.
async fn schedule(&self, mut request: SchedulingRequest) { async fn schedule(&self, mut request: SchedulingRequest) {
let (decode_blocks, prefill_tokens) = self.slots.potential_blocks_and_tokens( let (decode_blocks, prefill_tokens) = self
request.token_seq.as_deref(), .slots
request.isl_tokens, .potential_blocks_and_tokens_with_prefill_tracking(
request.overlaps.clone(), request.token_seq.as_deref(),
); request.isl_tokens,
request.overlaps.clone(),
request.track_prefill_tokens,
);
request.decode_blocks = decode_blocks; request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens; request.prefill_tokens = prefill_tokens;
...@@ -235,6 +238,7 @@ impl< ...@@ -235,6 +238,7 @@ impl<
token_sequence: request.token_seq, token_sequence: request.token_seq,
isl: request.isl_tokens, isl: request.isl_tokens,
overlap: selection.overlap_blocks, overlap: selection.overlap_blocks,
track_prefill_tokens: request.track_prefill_tokens,
expected_output_tokens: request.expected_output_tokens, expected_output_tokens: request.expected_output_tokens,
worker: selection.worker, worker: selection.worker,
lora_name: request.lora_name.clone(), lora_name: request.lora_name.clone(),
...@@ -376,6 +380,7 @@ mod tests { ...@@ -376,6 +380,7 @@ mod tests {
overlaps: OverlapScores::default(), overlaps: OverlapScores::default(),
decode_blocks: HashMap::new(), decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
track_prefill_tokens: true,
router_config_override: None, router_config_override: None,
update_states: true, update_states: true,
lora_name: None, lora_name: None,
...@@ -695,6 +700,7 @@ mod tests { ...@@ -695,6 +700,7 @@ mod tests {
overlaps: OverlapScores::default(), overlaps: OverlapScores::default(),
decode_blocks: HashMap::new(), decode_blocks: HashMap::new(),
prefill_tokens: HashMap::new(), prefill_tokens: HashMap::new(),
track_prefill_tokens: true,
router_config_override: None, router_config_override: None,
update_states: true, update_states: true,
lora_name: None, lora_name: None,
...@@ -719,4 +725,31 @@ mod tests { ...@@ -719,4 +725,31 @@ mod tests {
.unwrap(); .unwrap();
slots.free(&"filter-0".to_string()).await.unwrap(); slots.free(&"filter-0".to_string()).await.unwrap();
} }
#[tokio::test(flavor = "multi_thread")]
async fn test_queue_busy_check_ignores_untracked_prefill_tokens() {
    // A request enqueued with prefill tracking disabled must not contribute to the
    // worker's active token count, so a follow-up request is scheduled immediately
    // instead of being parked in the pending queue.
    let (queue, slots) = make_queue(1, 16, 256, Some(0.0));
    let worker = WorkerWithDpRank::new(0, 0);

    // First request: opt out of prompt-token accounting before enqueueing.
    let (mut untracked_request, untracked_rx) = make_request("req-1", 256);
    untracked_request.track_prefill_tokens = false;
    queue.enqueue(untracked_request).await;
    let _untracked_response = untracked_rx.await.unwrap().unwrap();

    let tokens_on_worker = slots.active_tokens().get(&worker).copied();
    assert_eq!(tokens_on_worker, Some(0));

    // Second request keeps the default tracking and must still be scheduled right away.
    let (tracked_request, tracked_rx) = make_request("req-2", 256);
    queue.enqueue(tracked_request).await;
    let _tracked_response = tracked_rx.await.unwrap().unwrap();
    assert_eq!(queue.pending_count(), 0);

    // Best-effort cleanup, in the same order as the requests were issued;
    // the results are intentionally ignored.
    for raw_id in ["req-1", "req-2"] {
        let request_id = raw_id.to_string();
        let _ = slots.mark_prefill_completed(&request_id).await;
        let _ = slots.free(&request_id).await;
    }
}
} }
...@@ -42,6 +42,7 @@ pub struct SchedulingRequest { ...@@ -42,6 +42,7 @@ pub struct SchedulingRequest {
pub overlaps: OverlapScores, pub overlaps: OverlapScores,
pub decode_blocks: HashMap<WorkerWithDpRank, usize>, pub decode_blocks: HashMap<WorkerWithDpRank, usize>,
pub prefill_tokens: HashMap<WorkerWithDpRank, usize>, pub prefill_tokens: HashMap<WorkerWithDpRank, usize>,
pub track_prefill_tokens: bool,
pub router_config_override: Option<RouterConfigOverride>, pub router_config_override: Option<RouterConfigOverride>,
pub update_states: bool, pub update_states: bool,
pub lora_name: Option<String>, pub lora_name: Option<String>,
......
...@@ -97,6 +97,7 @@ pub struct SequenceRequest { ...@@ -97,6 +97,7 @@ pub struct SequenceRequest {
pub token_sequence: Option<Vec<SequenceHash>>, pub token_sequence: Option<Vec<SequenceHash>>,
pub isl: usize, pub isl: usize,
pub overlap: u32, pub overlap: u32,
pub track_prefill_tokens: bool,
pub expected_output_tokens: Option<u32>, pub expected_output_tokens: Option<u32>,
pub worker: WorkerWithDpRank, pub worker: WorkerWithDpRank,
pub lora_name: Option<String>, pub lora_name: Option<String>,
...@@ -221,6 +222,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -221,6 +222,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
} => { } => {
self.request_to_worker self.request_to_worker
...@@ -233,12 +235,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -233,12 +235,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
let table = self.workers.read(); let table = self.workers.read();
if let Some(&idx) = table.index.get(&event.worker) { if let Some(&idx) = table.index.get(&event.worker) {
table.slots[idx].1.write().add_request( table.slots[idx].1.write().add_request_with_prefill_tracking(
event.request_id.clone(), event.request_id.clone(),
token_sequence.clone(), token_sequence.clone(),
*isl, *isl,
*overlap, *overlap,
*expected_output_tokens, *expected_output_tokens,
*track_prefill_tokens,
); );
} else { } else {
tracing::warn!( tracing::warn!(
...@@ -380,6 +383,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -380,6 +383,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
worker, worker,
lora_name, lora_name,
...@@ -409,12 +413,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -409,12 +413,13 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
.get(&worker) .get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?; .ok_or(SequenceError::WorkerNotFound { worker })?;
let mut seq = table.slots[idx].1.write(); let mut seq = table.slots[idx].1.write();
seq.add_request( seq.add_request_with_prefill_tracking(
request_id, request_id,
token_sequence, token_sequence,
isl, isl,
overlap, overlap,
expected_output_tokens, expected_output_tokens,
track_prefill_tokens,
) )
}; };
...@@ -437,6 +442,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -437,6 +442,7 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
token_sequence: req.token_sequence.clone(), token_sequence: req.token_sequence.clone(),
isl: req.isl, isl: req.isl,
overlap: req.overlap, overlap: req.overlap,
track_prefill_tokens: req.track_prefill_tokens,
expected_output_tokens: req.expected_output_tokens, expected_output_tokens: req.expected_output_tokens,
}, },
router_id: self.router_id, router_id: self.router_id,
...@@ -527,6 +533,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -527,6 +533,10 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
/// ///
/// Note: This operation is idempotent. Calling it multiple times for the same request /// Note: This operation is idempotent. Calling it multiple times for the same request
/// will log a warning but not return an error (double free is allowed). /// will log a warning but not return an error (double free is allowed).
///
/// This also performs the underlying prefill-complete cleanup via
/// [`ActiveSequences::free`], so callers do not need to call
/// [`Self::mark_prefill_completed`] before freeing a completed request.
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> { pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) { if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)"); tracing::debug!("Request {request_id} not found, already freed (idempotent)");
...@@ -696,6 +706,19 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -696,6 +706,19 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
) -> ( ) -> (
HashMap<WorkerWithDpRank, usize>, HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>, HashMap<WorkerWithDpRank, usize>,
) {
self.potential_blocks_and_tokens_with_prefill_tracking(token_sequence, isl, overlaps, true)
}
pub fn potential_blocks_and_tokens_with_prefill_tracking(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlaps: OverlapScores,
track_prefill_tokens: bool,
) -> (
HashMap<WorkerWithDpRank, usize>,
HashMap<WorkerWithDpRank, usize>,
) { ) {
#[cfg(feature = "bench")] #[cfg(feature = "bench")]
let start = tokio::time::Instant::now(); let start = tokio::time::Instant::now();
...@@ -711,9 +734,14 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -711,9 +734,14 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
for (worker, lock) in &table.slots { for (worker, lock) in &table.slots {
let overlap = *overlaps.scores.get(worker).unwrap_or(&0); let overlap = *overlaps.scores.get(worker).unwrap_or(&0);
let (blocks, tokens) = let (blocks, tokens) = lock
lock.read() .read()
.potential_blocks_and_tokens(token_sequence, isl, overlap); .potential_blocks_and_tokens_with_prefill_tracking(
token_sequence,
isl,
overlap,
track_prefill_tokens,
);
potential_blocks.insert(*worker, blocks); potential_blocks.insert(*worker, blocks);
potential_tokens.insert(*worker, tokens); potential_tokens.insert(*worker, tokens);
} }
...@@ -832,3 +860,44 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> { ...@@ -832,3 +860,44 @@ impl<P: SequencePublisher + 'static> ActiveSequencesMultiWorker<P> {
}); });
} }
} }
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::test_utils::NoopSequencePublisher;

    /// Build a multi-worker tracker with a no-op publisher and a single registered
    /// worker (id 1).
    // NOTE(review): the positional arguments (4, false, 0, "test") are presumably
    // block size, a replica-sync flag, a router id and a label; confirm against
    // `ActiveSequencesMultiWorker::new`.
    fn make_sequences() -> ActiveSequencesMultiWorker<NoopSequencePublisher> {
        let workers = HashMap::from([(1_u64, (0_u32, 1_u32))]);
        ActiveSequencesMultiWorker::new(NoopSequencePublisher, 4, workers, false, 0, "test")
    }

    /// Adding a request with `track_prefill_tokens: false` must leave the worker's
    /// active token count at zero.
    #[tokio::test]
    async fn add_request_can_skip_prefill_token_tracking() {
        let sequences = make_sequences();
        let worker = WorkerWithDpRank::new(1, 0);

        let untracked = SequenceRequest {
            request_id: "req-1".to_string(),
            token_sequence: Some(vec![1, 2, 3]),
            isl: 12,
            overlap: 0,
            track_prefill_tokens: false,
            expected_output_tokens: None,
            worker,
            lora_name: None,
        };
        sequences.add_request(untracked).await.unwrap();

        let tokens_on_worker = sequences.active_tokens().get(&worker).copied();
        assert_eq!(tokens_on_worker, Some(0));
    }
}
...@@ -143,6 +143,27 @@ impl ActiveSequences { ...@@ -143,6 +143,27 @@ impl ActiveSequences {
isl: usize, isl: usize,
overlap: u32, overlap: u32,
expected_output_tokens: Option<u32>, expected_output_tokens: Option<u32>,
) -> HashSet<RequestId> {
self.add_request_with_prefill_tracking(
request_id,
token_sequence,
isl,
overlap,
expected_output_tokens,
true,
)
}
/// Add a new request with optional prompt-token load accounting.
/// Returns the set of expired request IDs that were removed during cleanup.
pub fn add_request_with_prefill_tracking(
&mut self,
request_id: RequestId,
token_sequence: Option<Vec<SequenceHash>>,
isl: usize,
overlap: u32,
expected_output_tokens: Option<u32>,
track_prefill_tokens: bool,
) -> HashSet<RequestId> { ) -> HashSet<RequestId> {
// Check for double-add and log error, returning early // Check for double-add and log error, returning early
if self.active_seqs.contains_key(&request_id) { if self.active_seqs.contains_key(&request_id) {
...@@ -153,7 +174,11 @@ impl ActiveSequences { ...@@ -153,7 +174,11 @@ impl ActiveSequences {
// Lazily check and clean up expired requests, capturing removed IDs // Lazily check and clean up expired requests, capturing removed IDs
let removed_requests = self.force_expiry(); let removed_requests = self.force_expiry();
let prefill_tokens = self.new_tokens(isl, overlap); let prefill_tokens = if track_prefill_tokens {
self.new_tokens(isl, overlap)
} else {
0
};
self.prefill_tokens self.prefill_tokens
.insert(request_id.clone(), prefill_tokens); .insert(request_id.clone(), prefill_tokens);
self.active_tokens += prefill_tokens; self.active_tokens += prefill_tokens;
...@@ -208,13 +233,27 @@ impl ActiveSequences { ...@@ -208,13 +233,27 @@ impl ActiveSequences {
token_sequence: Option<&[SequenceHash]>, token_sequence: Option<&[SequenceHash]>,
isl: usize, isl: usize,
overlap: u32, overlap: u32,
) -> (usize, usize) {
self.potential_blocks_and_tokens_with_prefill_tracking(token_sequence, isl, overlap, true)
}
pub fn potential_blocks_and_tokens_with_prefill_tracking(
&self,
token_sequence: Option<&[SequenceHash]>,
isl: usize,
overlap: u32,
track_prefill_tokens: bool,
) -> (usize, usize) { ) -> (usize, usize) {
let potential_blocks = if let Some(token_seq) = token_sequence { let potential_blocks = if let Some(token_seq) = token_sequence {
self.new_blocks(token_seq) + self.active_blocks() self.new_blocks(token_seq) + self.active_blocks()
} else { } else {
self.active_blocks() self.active_blocks()
}; };
let potential_tokens = self.new_tokens(isl, overlap) + self.active_tokens; let potential_tokens = if track_prefill_tokens {
self.new_tokens(isl, overlap) + self.active_tokens
} else {
self.active_tokens
};
(potential_blocks, potential_tokens) (potential_blocks, potential_tokens)
} }
...@@ -232,7 +271,10 @@ impl ActiveSequences { ...@@ -232,7 +271,10 @@ impl ActiveSequences {
self.new_blocks(token_sequence) + self.active_blocks() self.new_blocks(token_sequence) + self.active_blocks()
} }
/// Free all blocks associated with a request /// Free all blocks associated with a request.
///
/// This implicitly calls [`Self::mark_prefill_completed`] first, so callers do not need
/// to invoke both when the request is finishing.
pub fn free(&mut self, request_id: &RequestId) -> usize { pub fn free(&mut self, request_id: &RequestId) -> usize {
self.mark_prefill_completed(request_id); self.mark_prefill_completed(request_id);
...@@ -424,6 +466,48 @@ mod tests { ...@@ -424,6 +466,48 @@ mod tests {
assert_eq!(seq_manager.active_tokens(), 0); assert_eq!(seq_manager.active_tokens(), 0);
} }
#[test]
fn test_add_request_without_prefill_tracking_keeps_active_tokens_zero() {
    // With prefill tracking disabled, the request must never show up in the
    // active token count: not after being added, and not after prefill completes.
    let request_id = "r1".to_string();
    let hashes = vec![1, 2, 3];
    let mut seq_manager = ActiveSequences::new(4);

    seq_manager.add_request_with_prefill_tracking(
        request_id.clone(),
        Some(hashes),
        12,
        0,
        None,
        false,
    );
    assert_eq!(seq_manager.active_tokens(), 0);

    seq_manager.mark_prefill_completed(&request_id);
    assert_eq!(seq_manager.active_tokens(), 0);

    // Freeing the request must release every block it held.
    seq_manager.free(&request_id);
    assert_eq!(seq_manager.active_blocks(), 0);
}
#[test]
fn test_potential_blocks_and_tokens_without_prefill_tracking_ignores_prompt_load() {
    // Seed the manager with one untracked request occupying three blocks.
    let mut seq_manager = ActiveSequences::new(4);
    let existing_hashes = vec![1, 2, 3];
    seq_manager.add_request_with_prefill_tracking(
        "r1".to_string(),
        Some(existing_hashes),
        12,
        0,
        None,
        false,
    );

    // Probe a candidate that extends the existing sequence by one block, again
    // with prefill tracking disabled.
    let candidate_hashes = [1, 2, 3, 4];
    let (potential_blocks, potential_tokens) = seq_manager
        .potential_blocks_and_tokens_with_prefill_tracking(Some(&candidate_hashes), 16, 0, false);

    // Blocks are still counted, but no prompt tokens are added to the estimate.
    assert_eq!(potential_blocks, 4);
    assert_eq!(potential_tokens, 0);
}
#[tokio::test(start_paused = true)] #[tokio::test(start_paused = true)]
async fn test_force_expiry() { async fn test_force_expiry() {
let block_size = 4; let block_size = 4;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::time::Instant;
use std::time::{Duration, Instant};
use anyhow::Result; use anyhow::Result;
use dynamo_kv_router::{ use dynamo_kv_router::{
ConcurrentRadixTree, ThreadPoolIndexer,
approx::PruneConfig,
config::{KvRouterConfig, RouterConfigOverride}, config::{KvRouterConfig, RouterConfigOverride},
indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface, KvIndexerMetrics, KvRouterError}, indexer::KvRouterError,
protocols::KV_EVENT_SUBJECT, protocols::KV_EVENT_SUBJECT,
protocols::{ protocols::{
BlockExtraInfo, BlockHashOptions, DpRank, LocalBlockHash, OverlapScores, RouterEvent, BlockExtraInfo, BlockHashOptions, DpRank, RouterEvent, RouterRequest, RouterResponse,
RouterRequest, RouterResponse, TokensWithHashes, WorkerId, WorkerWithDpRank, TokensWithHashes, WorkerId, WorkerWithDpRank, compute_block_hash_for_seq,
compute_block_hash_for_seq,
}, },
}; };
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -29,30 +25,29 @@ use dynamo_runtime::{ ...@@ -29,30 +25,29 @@ use dynamo_runtime::{
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
}; };
use futures::stream; use futures::stream;
use tokio::sync::oneshot;
use tracing::Instrument; use tracing::Instrument;
use validator::Validate; use validator::Validate;
pub mod cache_control; pub mod cache_control;
pub mod indexer;
mod jetstream; mod jetstream;
pub mod metrics; pub mod metrics;
pub mod prefill_router; pub mod prefill_router;
pub mod publisher; pub mod publisher;
pub mod push_router; pub mod push_router;
pub mod remote_indexer;
pub mod scheduler; pub mod scheduler;
pub mod sequence; pub mod sequence;
pub mod subscriber; pub mod subscriber;
pub mod worker_query; pub mod worker_query;
pub use cache_control::{CacheControlClient, spawn_pin_prefix}; pub use cache_control::{CacheControlClient, spawn_pin_prefix};
pub use indexer::Indexer;
pub use prefill_router::PrefillRouter; pub use prefill_router::PrefillRouter;
pub use push_router::{DirectRoutingRouter, KvPushRouter}; pub use push_router::{DirectRoutingRouter, KvPushRouter};
use crate::{ use crate::{
discovery::RuntimeConfigWatch, discovery::RuntimeConfigWatch,
kv_router::{ kv_router::{
remote_indexer::RemoteIndexer,
scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad}, scheduler::{DefaultWorkerSelector, KvScheduler, PotentialLoad},
sequence::{SequenceError, SequenceRequest}, sequence::{SequenceError, SequenceRequest},
}, },
...@@ -108,188 +103,6 @@ pub fn router_discovery_query(namespace: String, component: String) -> Discovery ...@@ -108,188 +103,6 @@ pub fn router_discovery_query(namespace: String, component: String) -> Discovery
} }
} }
/// Backend the router uses to answer KV-cache overlap queries.
///
/// The variant is picked once by `Indexer::new` from the router configuration;
/// every method on this type dispatches on the variant.
#[derive(Clone)]
pub enum Indexer {
    /// Single-threaded radix tree with channel-based event processing.
    /// Supports TTL-based expiration and size-based pruning.
    /// Has the ability to persist and snapshot states.
    KvIndexer(KvIndexer),

    /// Concurrent radix tree with a thread pool for event processing.
    /// Uses sticky worker routing for per-worker event serialization.
    /// Does not support TTL/pruning.
    Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),

    /// Forwards queries to a standalone KV indexer service via the request plane.
    /// The standalone indexer manages its own radix tree and event subscription.
    Remote(Arc<RemoteIndexer>),

    /// Used when we do not wish to use the indexer at all (e.g., when overlap_score_weight is 0).
    /// Note: This will cause KV events to accumulate in JetStream as we do not regularly purge them.
    None,
}
impl Indexer {
    /// Build the indexer backend selected by `kv_router_config`.
    ///
    /// Selection order (first match wins):
    /// 1. `overlap_score_weight == 0.0` -> [`Indexer::None`]
    /// 2. `remote_indexer_component` is set -> [`Indexer::Remote`]
    /// 3. `use_kv_events` is false -> [`Indexer::KvIndexer`] with TTL/size pruning
    /// 4. `router_event_threads > 1` -> [`Indexer::Concurrent`]
    /// 5. otherwise -> [`Indexer::KvIndexer`] without pruning
    ///
    /// # Errors
    /// Returns an error when a remote indexer is configured without `model_name`,
    /// when the remote indexer client cannot be created, or when `router_ttl_secs`
    /// is not a valid duration (negative, NaN or out of range).
    pub async fn new(
        component: &dynamo_runtime::component::Component,
        kv_router_config: &KvRouterConfig,
        block_size: u32,
        model_name: Option<String>,
    ) -> Result<Self> {
        if kv_router_config.overlap_score_weight == 0.0 {
            return Ok(Indexer::None);
        }

        // Remote indexer: forward queries to a standalone KV indexer service.
        if let Some(ref indexer_component_name) = kv_router_config.remote_indexer_component {
            let model_name = model_name.ok_or_else(|| {
                anyhow::anyhow!(
                    "model_name is required when remote_indexer_component is configured"
                )
            })?;
            tracing::info!(
                remote_indexer_component = %indexer_component_name,
                model_name,
                "Using remote KV indexer"
            );
            let remote = RemoteIndexer::new(component, indexer_component_name, model_name).await?;
            return Ok(Indexer::Remote(Arc::new(remote)));
        }

        // Approximate mode (--no-kv-events): always use single-threaded KvIndexer
        // with TTL/pruning regardless of event_threads, since updates come from
        // routing decisions only, not live KV events from workers.
        if !kv_router_config.use_kv_events {
            // `Duration::from_secs_f64` panics on negative or non-finite input;
            // validate first so a bad config surfaces as an error instead.
            let ttl = Duration::try_from_secs_f64(kv_router_config.router_ttl_secs).map_err(
                |e| {
                    anyhow::anyhow!(
                        "invalid router_ttl_secs ({}): {e}",
                        kv_router_config.router_ttl_secs
                    )
                },
            )?;
            let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
            let cancellation_token = component.drt().primary_token();
            let prune_config = Some(PruneConfig {
                ttl,
                max_tree_size: kv_router_config.router_max_tree_size,
                prune_target_ratio: kv_router_config.router_prune_target_ratio,
            });
            return Ok(Indexer::KvIndexer(KvIndexer::new_with_frequency(
                cancellation_token,
                None,
                block_size,
                kv_indexer_metrics,
                prune_config,
            )));
        }

        if kv_router_config.router_event_threads > 1 {
            return Ok(Indexer::Concurrent(Arc::new(ThreadPoolIndexer::new(
                ConcurrentRadixTree::new(),
                kv_router_config.router_event_threads as usize,
                block_size,
            ))));
        }

        let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
        let cancellation_token = component.drt().primary_token();

        Ok(Indexer::KvIndexer(KvIndexer::new_with_frequency(
            cancellation_token,
            None, // expiration_duration for frequency tracking
            block_size,
            kv_indexer_metrics,
            None,
        )))
    }

    /// Look up per-worker overlap scores for `sequence`.
    ///
    /// Any remote query failure is logged and reported as
    /// [`KvRouterError::IndexerOffline`]; a disabled indexer yields empty scores.
    pub(crate) async fn find_matches(
        &self,
        sequence: Vec<LocalBlockHash>,
    ) -> Result<OverlapScores, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.find_matches(sequence).await,
            Indexer::Concurrent(tpi) => tpi.find_matches(sequence).await,
            Indexer::Remote(remote) => remote.find_matches(sequence).await.map_err(|e| {
                tracing::warn!(error = %e, "Remote indexer query failed");
                KvRouterError::IndexerOffline
            }),
            Indexer::None => Ok(OverlapScores::new()),
        }
    }

    /// Dump the events held by the local indexer.
    ///
    /// The remote variant holds no local events and returns an empty list.
    ///
    /// # Errors
    /// Returns [`KvRouterError::IndexerOffline`] when no indexer exists
    /// ([`Indexer::None`]). This is a valid configured state, so it is reported
    /// as an error instead of panicking the calling task.
    pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => indexer.dump_events().await,
            Indexer::Concurrent(tpi) => tpi.dump_events().await,
            Indexer::Remote(_) => Ok(Vec::new()),
            Indexer::None => {
                tracing::warn!(
                    "Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
                );
                Err(KvRouterError::IndexerOffline)
            }
        }
    }

    /// Record a routing decision in the local indexer.
    /// No-op for the remote and disabled variants.
    pub(crate) async fn process_routing_decision_for_request(
        &self,
        tokens_with_hashes: &mut TokensWithHashes,
        worker: WorkerWithDpRank,
    ) -> Result<(), KvRouterError> {
        match self {
            Indexer::KvIndexer(indexer) => {
                indexer
                    .process_routing_decision_for_request(tokens_with_hashes, worker)
                    .await
            }
            Indexer::Concurrent(tpi) => {
                tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
                    .await
            }
            Indexer::Remote(_) => Ok(()),
            Indexer::None => Ok(()),
        }
    }

    /// Feed a KV event to the local indexer. Send failures are logged, not returned.
    pub(crate) async fn apply_event(&self, event: RouterEvent) {
        match self {
            Indexer::KvIndexer(indexer) => {
                if let Err(e) = indexer.event_sender().send(event).await {
                    tracing::warn!("Failed to send event to indexer: {e}");
                }
            }
            Indexer::Concurrent(tpi) => tpi.apply_event(event).await,
            Indexer::Remote(_) => {} // standalone indexer gets events directly
            Indexer::None => {}
        }
    }

    /// Ask the local indexer to drop `worker_id`. Send failures are logged, not returned.
    pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
        match self {
            Indexer::KvIndexer(indexer) => {
                if let Err(e) = indexer.remove_worker_sender().send(worker_id).await {
                    tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
                }
            }
            Indexer::Concurrent(tpi) => {
                KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await;
            }
            Indexer::Remote(_) => {} // standalone indexer manages its own workers
            Indexer::None => {}
        }
    }

    /// List the workers known to the local indexer.
    ///
    /// Returns an empty list for the remote and disabled variants, and also when
    /// the request to the channel-based indexer cannot be sent or answered.
    pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
        match self {
            Indexer::KvIndexer(indexer) => {
                let (resp_tx, resp_rx) = oneshot::channel();
                let req = GetWorkersRequest { resp: resp_tx };
                if let Err(e) = indexer.get_workers_sender().send(req).await {
                    tracing::warn!("Failed to send get_workers request: {e}");
                    return Vec::new();
                }
                resp_rx.await.unwrap_or_default()
            }
            Indexer::Concurrent(tpi) => tpi.backend().get_workers(),
            Indexer::Remote(_) => Vec::new(),
            Indexer::None => Vec::new(),
        }
    }
}
/// A KvRouter only decides which worker you should use. It doesn't send you there. /// A KvRouter only decides which worker you should use. It doesn't send you there.
/// TODO: Rename this to indicate it only selects a worker, it does not route. /// TODO: Rename this to indicate it only selects a worker, it does not route.
pub struct KvRouter<Sel = DefaultWorkerSelector> pub struct KvRouter<Sel = DefaultWorkerSelector>
...@@ -529,6 +342,9 @@ where ...@@ -529,6 +342,9 @@ where
hash_options, hash_options,
None, None,
); );
let track_prefill_tokens = self
.kv_router_config
.track_prefill_tokens(router_config_override);
if let Err(e) = self if let Err(e) = self
.scheduler .scheduler
...@@ -537,6 +353,7 @@ where ...@@ -537,6 +353,7 @@ where
token_sequence: maybe_seq_hashes, token_sequence: maybe_seq_hashes,
isl: isl_tokens, isl: isl_tokens,
overlap: overlap_blocks, overlap: overlap_blocks,
track_prefill_tokens,
expected_output_tokens, expected_output_tokens,
worker, worker,
lora_name, lora_name,
...@@ -623,12 +440,17 @@ where ...@@ -623,12 +440,17 @@ where
hash_options, hash_options,
Some(&block_hashes), Some(&block_hashes),
); );
let track_prefill_tokens = self
.kv_router_config
.track_prefill_tokens(router_config_override);
let overlap_scores = self.indexer.find_matches(block_hashes).await?; let overlap_scores = self.indexer.find_matches(block_hashes).await?;
Ok(self Ok(self.scheduler.get_potential_loads(
.scheduler maybe_seq_hashes,
.get_potential_loads(maybe_seq_hashes, isl_tokens, overlap_scores)) isl_tokens,
overlap_scores,
track_prefill_tokens,
))
} }
/// Dump all events from the indexer /// Dump all events from the indexer
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use futures::StreamExt;
use dynamo_kv_router::{
ConcurrentRadixTree, ThreadPoolIndexer,
approx::PruneConfig,
config::KvRouterConfig,
indexer::{
IndexerQueryRequest, IndexerQueryResponse, KV_INDEXER_QUERY_ENDPOINT, KvIndexer,
KvIndexerInterface, KvIndexerMetrics, KvRouterError,
},
protocols::{
LocalBlockHash, OverlapScores, RouterEvent, TokensWithHashes, WorkerId, WorkerWithDpRank,
},
};
use dynamo_runtime::{
component::Component,
pipeline::{ManyOut, RouterMode, SingleIn, network::egress::push_router::PushRouter},
traits::DistributedRuntimeProvider,
};
use tokio::sync::oneshot;
/// Client for a standalone KV indexer service.
///
/// Each overlap query is sent through `router` and carries `model_name` and
/// `namespace` alongside the block hashes.
pub struct RemoteIndexer {
    /// Push router used to send queries to the indexer's query endpoint.
    router: PushRouter<IndexerQueryRequest, IndexerQueryResponse>,
    /// Model name included in every query request.
    model_name: String,
    /// Name of the namespace of the component this client was created from;
    /// included in every query request.
    namespace: String,
}
impl RemoteIndexer {
async fn new(
component: &Component,
indexer_component_name: &str,
model_name: String,
) -> Result<Self> {
let namespace = component.namespace().name();
let indexer_ns = component.namespace();
let indexer_component = indexer_ns.component(indexer_component_name)?;
let endpoint = indexer_component.endpoint(KV_INDEXER_QUERY_ENDPOINT);
let client = endpoint.client().await?;
let router =
PushRouter::from_client_no_fault_detection(client, RouterMode::RoundRobin).await?;
Ok(Self {
router,
model_name,
namespace,
})
}
async fn find_matches(&self, block_hashes: Vec<LocalBlockHash>) -> Result<OverlapScores> {
let request = IndexerQueryRequest {
model_name: self.model_name.clone(),
namespace: self.namespace.clone(),
block_hashes,
};
let mut stream: ManyOut<IndexerQueryResponse> =
self.router.round_robin(SingleIn::new(request)).await?;
match stream.next().await {
Some(IndexerQueryResponse::Scores(scores)) => Ok(scores.into()),
Some(IndexerQueryResponse::Error(msg)) => {
Err(anyhow::anyhow!("Remote indexer error: {}", msg))
}
None => Err(anyhow::anyhow!("Remote indexer returned empty response")),
}
}
}
/// Backend the router uses to answer KV-cache overlap queries.
///
/// The variant is picked once by `Indexer::new` from the router configuration;
/// every method on this type dispatches on the variant.
#[derive(Clone)]
pub enum Indexer {
    /// Local indexer driven through channels: events, worker removals and worker
    /// listings are submitted via its sender handles. It is the only variant that
    /// can be built with a prune configuration (TTL, max tree size, target ratio).
    KvIndexer(KvIndexer),
    /// Thread-pool indexer over a concurrent radix tree; selected when
    /// `router_event_threads > 1` and KV events are enabled.
    Concurrent(Arc<ThreadPoolIndexer<ConcurrentRadixTree>>),
    /// Forwards overlap queries to a standalone KV indexer service. Events, worker
    /// removals and routing decisions are not applied locally for this variant.
    Remote(Arc<RemoteIndexer>),
    /// No indexer at all; selected when `overlap_score_weight` is `0.0`.
    /// Overlap queries return empty scores.
    None,
}
impl Indexer {
pub async fn new(
component: &Component,
kv_router_config: &KvRouterConfig,
block_size: u32,
model_name: Option<String>,
) -> Result<Self> {
if kv_router_config.overlap_score_weight == 0.0 {
return Ok(Self::None);
}
if let Some(ref indexer_component_name) = kv_router_config.remote_indexer_component {
let model_name = model_name.ok_or_else(|| {
anyhow::anyhow!(
"model_name is required when remote_indexer_component is configured"
)
})?;
tracing::info!(
remote_indexer_component = %indexer_component_name,
model_name,
"Using remote KV indexer"
);
let remote = RemoteIndexer::new(component, indexer_component_name, model_name).await?;
return Ok(Self::Remote(Arc::new(remote)));
}
if !kv_router_config.use_kv_events {
let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
let cancellation_token = component.drt().primary_token();
let prune_config = Some(PruneConfig {
ttl: Duration::from_secs_f64(kv_router_config.router_ttl_secs),
max_tree_size: kv_router_config.router_max_tree_size,
prune_target_ratio: kv_router_config.router_prune_target_ratio,
});
return Ok(Self::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token,
None,
block_size,
kv_indexer_metrics,
prune_config,
)));
}
if kv_router_config.router_event_threads > 1 {
return Ok(Self::Concurrent(Arc::new(ThreadPoolIndexer::new(
ConcurrentRadixTree::new(),
kv_router_config.router_event_threads as usize,
block_size,
))));
}
let kv_indexer_metrics = KvIndexerMetrics::from_component(component);
let cancellation_token = component.drt().primary_token();
Ok(Self::KvIndexer(KvIndexer::new_with_frequency(
cancellation_token,
None,
block_size,
kv_indexer_metrics,
None,
)))
}
pub(crate) async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
match self {
Self::KvIndexer(indexer) => indexer.find_matches(sequence).await,
Self::Concurrent(tpi) => tpi.find_matches(sequence).await,
Self::Remote(remote) => remote.find_matches(sequence).await.map_err(|e| {
tracing::warn!(error = %e, "Remote indexer query failed");
KvRouterError::IndexerOffline
}),
Self::None => Ok(OverlapScores::new()),
}
}
pub(crate) async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self {
Self::KvIndexer(indexer) => indexer.dump_events().await,
Self::Concurrent(tpi) => tpi.dump_events().await,
Self::Remote(_) => Ok(Vec::new()),
Self::None => {
panic!(
"Cannot dump events: indexer does not exist (is overlap_score_weight set to 0?)"
);
}
}
}
pub(crate) async fn process_routing_decision_for_request(
&self,
tokens_with_hashes: &mut TokensWithHashes,
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
match self {
Self::KvIndexer(indexer) => {
indexer
.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
Self::Concurrent(tpi) => {
tpi.process_routing_decision_for_request(tokens_with_hashes, worker)
.await
}
Self::Remote(_) | Self::None => Ok(()),
}
}
pub(crate) async fn apply_event(&self, event: RouterEvent) {
match self {
Self::KvIndexer(indexer) => {
if let Err(e) = indexer.event_sender().send(event).await {
tracing::warn!("Failed to send event to indexer: {e}");
}
}
Self::Concurrent(tpi) => tpi.apply_event(event).await,
Self::Remote(_) | Self::None => {}
}
}
pub(crate) async fn remove_worker(&self, worker_id: WorkerId) {
match self {
Self::KvIndexer(indexer) => {
if let Err(e) = indexer.remove_worker_sender().send(worker_id).await {
tracing::warn!("Failed to send worker removal for {worker_id}: {e}");
}
}
Self::Concurrent(tpi) => {
KvIndexerInterface::remove_worker(tpi.as_ref(), worker_id).await;
}
Self::Remote(_) | Self::None => {}
}
}
/// List the worker ids reported by the underlying indexer.
///
/// The local (`KvIndexer`) variant sends a request carrying a oneshot reply
/// channel and waits for the answer. If the request cannot be sent a warning is
/// logged and an empty list is returned; if the reply channel fails, the
/// default (empty) list is returned. `Remote` and `None` always return an
/// empty list.
pub(crate) async fn get_workers(&self) -> Vec<WorkerId> {
    match self {
        Self::None | Self::Remote(_) => Vec::new(),
        Self::Concurrent(shared) => shared.backend().get_workers(),
        Self::KvIndexer(local) => {
            let (reply_tx, reply_rx) = oneshot::channel();
            let request = dynamo_kv_router::indexer::GetWorkersRequest { resp: reply_tx };
            let delivery = local.get_workers_sender().send(request).await;
            match delivery {
                Ok(_) => reply_rx.await.unwrap_or_default(),
                Err(e) => {
                    tracing::warn!("Failed to send get_workers request: {e}");
                    Vec::new()
                }
            }
        }
    }
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::Result;
use tokio::sync::oneshot;
use dynamo_kv_router::config::KvRouterConfig;
use dynamo_runtime::{
component::{Client, Endpoint},
pipeline::{PushRouter, RouterMode},
protocols::annotated::Annotated,
};
use super::{InnerPrefillRouter, PrefillRouter};
use crate::{
discovery::ModelManager,
kv_router::KvPushRouter,
protocols::common::{
llm_backend::{LLMEngineOutput, PreprocessedRequest},
timing::WORKER_TYPE_PREFILL,
},
};
impl PrefillRouter {
/// Build a prefill router that will never activate (passthrough only).
///
/// No activation task is spawned, so `prefill_router` and `endpoint_id` stay
/// empty for the lifetime of the returned value.
pub fn disabled(
    model_manager: Arc<ModelManager>,
    router_mode: RouterMode,
    enforce_disagg: bool,
) -> Arc<Self> {
    let inert = Self {
        is_eagle: false,
        // The model name and namespace are not used for a disabled router.
        namespace: String::new(),
        model_name: String::new(),
        enforce_disagg,
        router_mode,
        cancel_token: tokio_util::sync::CancellationToken::new(),
        endpoint_id: std::sync::OnceLock::new(),
        model_manager,
        prefill_router: std::sync::OnceLock::new(),
    };
    Arc::new(inert)
}
/// Create a prefill router that activates once an endpoint arrives on `activation_rx`.
///
/// A background task is spawned that waits for either:
/// - an `Endpoint` on `activation_rx`, in which case the router is activated with
///   `model_manager`, `kv_cache_block_size` and `kv_router_config`; or
/// - cancellation of the router's `cancel_token`, in which case the task exits.
///
/// The task only holds a weak reference to the router. Holding a strong `Arc` here
/// would keep the router alive for as long as the task is waiting, which would
/// prevent `Drop` from ever running and cancelling the task.
///
/// Activation failures are logged and leave the router un-activated.
#[expect(clippy::too_many_arguments)]
pub fn new(
    activation_rx: oneshot::Receiver<Endpoint>,
    model_manager: Arc<ModelManager>,
    router_mode: RouterMode,
    kv_cache_block_size: u32,
    kv_router_config: Option<KvRouterConfig>,
    enforce_disagg: bool,
    model_name: String,
    namespace: String,
    is_eagle: bool,
) -> Arc<Self> {
    let prefill_router = std::sync::OnceLock::new();
    let cancel_token = tokio_util::sync::CancellationToken::new();
    let router = Arc::new(Self {
        prefill_router,
        model_manager: model_manager.clone(),
        endpoint_id: std::sync::OnceLock::new(),
        cancel_token: cancel_token.clone(),
        router_mode,
        enforce_disagg,
        model_name,
        namespace,
        is_eagle,
    });

    // Spawn background task to wait for activation. Only a weak handle is moved
    // into the task so that dropping the last external `Arc` runs `Drop`, which
    // cancels `cancel_token` and lets this task finish.
    let weak_router = Arc::downgrade(&router);
    tokio::spawn(async move {
        tokio::select! {
            result = activation_rx => {
                let Ok(endpoint) = result else {
                    tracing::debug!("Prefill router activation channel closed without receiving endpoint");
                    return;
                };
                // The router may have been dropped while we were waiting.
                let Some(router) = weak_router.upgrade() else {
                    tracing::debug!("Prefill router dropped before activation, skipping");
                    return;
                };
                if let Err(e) = router.activate(
                    endpoint,
                    model_manager,
                    kv_cache_block_size,
                    kv_router_config,
                ).await {
                    tracing::error!(error = %e, "Failed to activate prefill router");
                }
            }
            _ = cancel_token.cancelled() => {
                tracing::debug!("Prefill router activation cancelled");
            }
        }
    });

    router
}
/// Activate the prefill router with the provided endpoint
async fn activate(
&self,
endpoint: Endpoint,
model_manager: Arc<ModelManager>,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> Result<()> {
tracing::info!(
router_mode = ?self.router_mode,
"Activating prefill router"
);
// Store endpoint_id for later use in resolve_prefill_worker
let _ = self.endpoint_id.set(endpoint.id());
// Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
// This must be done before creating the router so bootstrap info is available
model_manager
.get_or_create_runtime_config_watcher(&endpoint)
.await?;
let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the endpoint (this is a prefill router)
let kv_chooser = model_manager
.kv_chooser_for(
&endpoint,
kv_cache_block_size,
kv_router_config,
WORKER_TYPE_PREFILL,
Some(self.model_name.clone()),
self.is_eagle,
)
.await?;
// Extract client from kv_chooser to ensure shared state
let client = kv_chooser.client().clone();
self.register_prefill_client(model_manager.as_ref(), &client);
// Build the PushRouter for prefill with KV mode using the shared client
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
None, // busy_threshold
None, // worker_monitor
)
.await?;
// Wrap it in KvPushRouter
InnerPrefillRouter::KvRouter(Arc::new(KvPushRouter::new(push_router, kv_chooser)))
} else {
// Create client for simple router
let client = endpoint.client().await?;
self.register_prefill_client(model_manager.as_ref(), &client);
// Create simple push router with the frontend's router mode
// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only
// available in KV routing mode where the router has actual bookkeeping.
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
self.router_mode,
None, // busy_threshold
None, // worker_monitor
)
.await?;
InnerPrefillRouter::SimpleRouter(Arc::new(push_router))
};
// Set the router (ignore error if already set)
let _ = self.prefill_router.set(inner_router);
tracing::info!(
router_mode = ?self.router_mode,
"Prefill router activated successfully"
);
Ok(())
}
/// Give the prefill `client` to the worker monitor registered for this router's
/// model name and namespace. Does nothing when no such monitor exists.
fn register_prefill_client(&self, model_manager: &ModelManager, client: &Client) {
    let Some(monitor) =
        model_manager.get_worker_monitor_for_namespace(&self.model_name, &self.namespace)
    else {
        return;
    };
    monitor.set_prefill_client(client.clone());
}
}
...@@ -2,301 +2,41 @@ ...@@ -2,301 +2,41 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::{Arc, OnceLock}; use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::{OwnedSemaphorePermit, oneshot}; use tokio::sync::OwnedSemaphorePermit;
use tokio_util::sync::CancellationToken;
use tracing::Instrument; use tracing::Instrument;
use dynamo_kv_router::{ use dynamo_kv_router::protocols::{BlockExtraInfo, WorkerId};
config::{KvRouterConfig, RouterConfigOverride},
protocols::{BlockExtraInfo, WorkerId},
};
use dynamo_runtime::{ use dynamo_runtime::{
component::Endpoint, engine::AsyncEngineContext,
pipeline::{ pipeline::{AsyncEngineContextProvider, Context, SingleIn},
AsyncEngine, AsyncEngineContextProvider, Context, ManyOut, Operator, PushRouter, protocols::maybe_error::MaybeError,
RouterMode, ServerStreamingEngine, SingleIn, async_trait,
},
protocols::{EndpointId, annotated::Annotated, maybe_error::MaybeError},
}; };
use crate::{ use super::{InnerPrefillRouter, PrefillError, PrefillResolveDecision, PrefillRouter};
discovery::ModelManager, use crate::protocols::common::{
kv_router::KvPushRouter, llm_backend::PreprocessedRequest,
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest}, preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::preprocessor::{BootstrapInfo, PrefillResult},
protocols::common::timing::{RequestPhase, RequestTracker, WORKER_TYPE_PREFILL},
}; };
/// Errors that can occur during prefill routing
#[derive(Debug, thiserror::Error)]
pub enum PrefillError {
/// Prefill router has not been activated yet
#[error("Prefill router not yet activated")]
NotActivated,
/// TODO: Separate prefill worker error from prefill router error
/// Error during prefill execution
#[error("Prefill execution failed: {0}")]
PrefillError(
String,
#[source] Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
),
/// Disaggregated params not found in prefill response
#[error("No disaggregated params in prefill response: {0}")]
NoDisaggregatedParams(String),
}
/// Result of the prefill phase in `generate()`.
enum PrefillOutcome {
/// Bootstrap optimization: prefill spawned in background, bootstrap info ready
Bootstrap(BootstrapInfo),
/// Synchronous prefill completed with result
Completed(PrefillResult),
}
/// The inner router used by PrefillRouter
#[derive(Clone)]
enum InnerPrefillRouter {
/// KV-aware routing using KvPushRouter
KvRouter(Arc<KvPushRouter>),
/// Simple routing (RoundRobin, Random, Direct)
/// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only
/// available in KV routing mode where the router has actual bookkeeping.
SimpleRouter(Arc<PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>>),
}
impl InnerPrefillRouter {
/// Generate with optional direct routing to specific worker.
/// For KvRouter, target_worker is ignored since prefill_worker_id is already set on the request.
/// For SimpleRouter, target_worker triggers direct routing via router.direct().
async fn generate_to_worker(
&self,
request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
match (self, target_worker) {
// KvRouter: prefill_worker_id already set on request, KvPushRouter::select_worker uses it
(InnerPrefillRouter::KvRouter(router), _) => router.generate(request).await,
(InnerPrefillRouter::SimpleRouter(router), Some(worker_id)) => {
router.direct(request, worker_id).await
}
(InnerPrefillRouter::SimpleRouter(router), None) => router.generate(request).await,
}
}
/// Select next worker (for non-KV modes only)
fn select_next_worker(&self) -> Option<u64> {
match self {
InnerPrefillRouter::SimpleRouter(router) => router.select_next_worker(),
InnerPrefillRouter::KvRouter(_) => None,
}
}
}
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
///
/// Modes:
/// - Query-only: `query_instance_id` annotation present → returns worker IDs without execution
/// - Pre-routed: `prefill_worker_id`/`decode_worker_id` set → routes to specified workers
/// - Normal: Worker IDs determined by router based on KV cache state
pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>,
model_manager: Arc<ModelManager>,
endpoint_id: OnceLock<EndpointId>,
cancel_token: CancellationToken,
router_mode: RouterMode,
enforce_disagg: bool,
/// Model name used to look up the worker monitor for prefill client registration
model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor
namespace: String,
is_eagle: bool,
}
impl PrefillRouter { impl PrefillRouter {
/// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled(
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
enforce_disagg: bool,
) -> Arc<Self> {
Arc::new(Self {
prefill_router: OnceLock::new(),
model_manager,
endpoint_id: OnceLock::new(),
cancel_token: CancellationToken::new(),
router_mode,
enforce_disagg,
model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router
is_eagle: false,
})
}
#[expect(clippy::too_many_arguments)]
pub fn new(
activation_rx: oneshot::Receiver<Endpoint>,
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
enforce_disagg: bool,
model_name: String,
namespace: String,
is_eagle: bool,
) -> Arc<Self> {
let prefill_router = OnceLock::new();
let cancel_token = CancellationToken::new();
let router = Arc::new(Self {
prefill_router,
model_manager: model_manager.clone(),
endpoint_id: OnceLock::new(),
cancel_token: cancel_token.clone(),
router_mode,
enforce_disagg,
model_name,
namespace,
is_eagle,
});
// Spawn background task to wait for activation
let router_clone = router.clone();
tokio::spawn(async move {
tokio::select! {
result = activation_rx => {
let Ok(endpoint) = result else {
tracing::debug!("Prefill router activation channel closed without receiving endpoint");
return;
};
if let Err(e) = router_clone.activate(
endpoint,
model_manager,
kv_cache_block_size,
kv_router_config,
).await {
tracing::error!(error = %e, "Failed to activate prefill router");
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Prefill router activation cancelled");
}
}
});
router
}
/// Activate the prefill router with the provided endpoint
async fn activate(
&self,
endpoint: Endpoint,
model_manager: Arc<ModelManager>,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> Result<()> {
tracing::info!(
router_mode = ?self.router_mode,
"Activating prefill router"
);
// Store endpoint_id for later use in resolve_prefill_worker
let _ = self.endpoint_id.set(endpoint.id());
// Start runtime config watcher for this endpoint (needed for get_disaggregated_endpoint)
// This must be done before creating the router so bootstrap info is available
model_manager
.get_or_create_runtime_config_watcher(&endpoint)
.await?;
let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the endpoint (this is a prefill router)
let kv_chooser = model_manager
.kv_chooser_for(
&endpoint,
kv_cache_block_size,
kv_router_config,
WORKER_TYPE_PREFILL,
Some(self.model_name.clone()),
self.is_eagle,
)
.await?;
// Extract client from kv_chooser to ensure shared state
let client = kv_chooser.client().clone();
// Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode
if let Some(monitor) =
model_manager.get_worker_monitor_for_namespace(&self.model_name, &self.namespace)
{
monitor.set_prefill_client(client.clone());
}
// Build the PushRouter for prefill with KV mode using the shared client
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
None, // busy_threshold
None, // worker_monitor
)
.await?;
// Wrap it in KvPushRouter
InnerPrefillRouter::KvRouter(Arc::new(KvPushRouter::new(push_router, kv_chooser)))
} else {
// Create client for simple router
let client = endpoint.client().await?;
// Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode
if let Some(monitor) =
model_manager.get_worker_monitor_for_namespace(&self.model_name, &self.namespace)
{
monitor.set_prefill_client(client.clone());
}
// Create simple push router with the frontend's router mode
// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only
// available in KV routing mode where the router has actual bookkeeping.
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
self.router_mode,
None, // busy_threshold
None, // worker_monitor
)
.await?;
InnerPrefillRouter::SimpleRouter(Arc::new(push_router))
};
// Set the router (ignore error if already set)
let _ = self.prefill_router.set(inner_router);
tracing::info!(
router_mode = ?self.router_mode,
"Prefill router activated successfully"
);
Ok(())
}
/// Select a prefill worker and resolve its bootstrap connection info. /// Select a prefill worker and resolve its bootstrap connection info.
/// If preselected_worker is provided (GAIE Stage 2), use it directly. /// If preselected_worker is provided (GAIE Stage 2), use it directly.
/// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes). /// Otherwise, query for the best worker (KV mode) or select next worker (non-KV modes).
async fn resolve_prefill_worker( pub(super) async fn resolve_prefill_worker(
&self, &self,
req: &PreprocessedRequest, req: &PreprocessedRequest,
preselected_worker: Option<u64>, preselected_worker: Option<u64>,
) -> Option<(u64, u32, BootstrapInfo)> { ) -> PrefillResolveDecision {
let endpoint_id = self.endpoint_id.get()?; let Some(endpoint_id) = self.endpoint_id.get() else {
self.prefill_router.get()?; return PrefillResolveDecision::NotActivated;
};
if self.prefill_router.get().is_none() {
return PrefillResolveDecision::NotActivated;
}
// Worker selection // Worker selection
let (worker_id, dp_rank) = if let Some(id) = preselected_worker { let (worker_id, dp_rank) = if let Some(id) = preselected_worker {
...@@ -333,16 +73,23 @@ impl PrefillRouter { ...@@ -333,16 +73,23 @@ impl PrefillRouter {
.await .await
{ {
Ok((worker_id, dp_rank)) => (worker_id, dp_rank), Ok((worker_id, dp_rank)) => (worker_id, dp_rank),
Err(_) => return None, Err(_) => return PrefillResolveDecision::Unavailable,
} }
}; };
// Get bootstrap info from ModelManager (works for ANY mode) // Get bootstrap info from ModelManager (works for ANY mode)
let endpoint = self let Some(endpoint) = self
.model_manager .model_manager
.get_disaggregated_endpoint(endpoint_id, worker_id)?; .get_disaggregated_endpoint(endpoint_id, worker_id)
let host = endpoint.bootstrap_host?; else {
let port = endpoint.bootstrap_port?; return PrefillResolveDecision::NoBootstrapEndpoint;
};
let Some(host) = endpoint.bootstrap_host else {
return PrefillResolveDecision::NoBootstrapEndpoint;
};
let Some(port) = endpoint.bootstrap_port else {
return PrefillResolveDecision::NoBootstrapEndpoint;
};
let bootstrap_room: u64 = rand::random_range(0..=i64::MAX.cast_unsigned()); let bootstrap_room: u64 = rand::random_range(0..=i64::MAX.cast_unsigned());
...@@ -356,31 +103,31 @@ impl PrefillRouter { ...@@ -356,31 +103,31 @@ impl PrefillRouter {
"Built bootstrap_info upfront before prefill" "Built bootstrap_info upfront before prefill"
); );
Some(( PrefillResolveDecision::Resolved {
worker_id, worker_id,
dp_rank, dp_rank,
BootstrapInfo { bootstrap_info: BootstrapInfo {
bootstrap_host: host, bootstrap_host: host,
bootstrap_port: port, bootstrap_port: port,
bootstrap_room, bootstrap_room,
}, },
)) }
} }
/// Execute prefill with the given router and extract structured result. /// Execute prefill with the given router and extract structured result.
/// ///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization). /// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
/// ///
/// If `phase_permit` is provided, it is dropped after the first output is received, /// If `phase_transition_permit` is provided, it is dropped immediately after routing completes,
/// allowing subsequent `set_phase` calls to proceed. This is used in the bootstrap /// allowing subsequent `set_phase` calls to proceed. This preserves the current synchronization:
/// optimization path to ensure `record_worker_full` completes before the phase changes. /// the prefill route must finish `record_worker_full` before the phase can change to Decode.
/// ///
/// Returns (PrefillResult, Option<(worker_id, dp_rank)>). /// Returns (PrefillResult, Option<(worker_id, dp_rank)>).
async fn execute_prefill( pub(super) async fn execute_prefill(
router: Option<InnerPrefillRouter>, router: Option<InnerPrefillRouter>,
request: SingleIn<PreprocessedRequest>, request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>, target_worker: Option<u64>,
phase_permit: Option<OwnedSemaphorePermit>, phase_transition_permit: Option<OwnedSemaphorePermit>,
) -> Result<(PrefillResult, Option<(u64, u32)>), PrefillError> { ) -> Result<(PrefillResult, Option<(u64, u32)>), PrefillError> {
let router = router.ok_or(PrefillError::NotActivated)?; let router = router.ok_or(PrefillError::NotActivated)?;
let mut prefill_response = router let mut prefill_response = router
...@@ -393,9 +140,9 @@ impl PrefillRouter { ...@@ -393,9 +140,9 @@ impl PrefillRouter {
) )
})?; })?;
// Drop phase permit now - routing is complete, record_worker_full was called in select_worker. // Release the phase barrier now that routing completed and record_worker_full already ran.
// This unblocks set_phase(Decode) in the main task without waiting for prefill output. // Decode may proceed without waiting for prefill output streaming to finish.
drop(phase_permit); drop(phase_transition_permit);
let Some(first_output) = prefill_response.next().await else { let Some(first_output) = prefill_response.next().await else {
return Err(PrefillError::PrefillError( return Err(PrefillError::PrefillError(
...@@ -468,13 +215,13 @@ impl PrefillRouter { ...@@ -468,13 +215,13 @@ impl PrefillRouter {
/// ///
/// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization). /// Uses direct routing to target_worker when specified (for non-KV modes with bootstrap optimization).
/// ///
/// The `phase_permit` is passed to the spawned task and dropped after the first output, /// The `phase_transition_permit` is passed to the spawned task and released after routing
/// allowing the main task's `set_phase(Decode)` to proceed. /// completes, allowing the main task's `set_phase(Decode)` to proceed.
fn spawn_prefill_task( pub(super) fn spawn_prefill_task(
&self, &self,
prefill_request: SingleIn<PreprocessedRequest>, prefill_request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>, target_worker: Option<u64>,
phase_permit: OwnedSemaphorePermit, phase_transition_permit: OwnedSemaphorePermit,
) { ) {
let router = self.prefill_router.get().cloned(); let router = self.prefill_router.get().cloned();
// Capture current span to propagate trace context to the spawned task // Capture current span to propagate trace context to the spawned task
...@@ -486,7 +233,7 @@ impl PrefillRouter { ...@@ -486,7 +233,7 @@ impl PrefillRouter {
router, router,
prefill_request, prefill_request,
target_worker, target_worker,
Some(phase_permit), Some(phase_transition_permit),
) )
.await .await
{ {
...@@ -507,13 +254,6 @@ impl PrefillRouter { ...@@ -507,13 +254,6 @@ impl PrefillRouter {
/// ///
/// This is the shared worker selection logic used by both `resolve_prefill_worker` /// This is the shared worker selection logic used by both `resolve_prefill_worker`
/// and `query_route`. /// and `query_route`.
/// Register externally-provided workers in the prefill router's slot tracker.
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
if let Some(InnerPrefillRouter::KvRouter(r)) = self.prefill_router.get() {
r.chooser.register_workers(worker_ids);
}
}
pub async fn query_prefill_worker( pub async fn query_prefill_worker(
&self, &self,
token_ids: &[u32], token_ids: &[u32],
...@@ -553,194 +293,30 @@ impl PrefillRouter { ...@@ -553,194 +293,30 @@ impl PrefillRouter {
r.peek_next_worker() r.peek_next_worker()
} }
.ok_or_else(|| anyhow::anyhow!("No workers available for prefill"))?; .ok_or_else(|| anyhow::anyhow!("No workers available for prefill"))?;
Ok((worker_id, u32::MAX)) Ok((worker_id, 0))
} }
} }
} }
/// Register externally-provided workers in the prefill router's slot tracker.
pub fn register_workers(&self, worker_ids: &HashSet<WorkerId>) {
if let Some(InnerPrefillRouter::KvRouter(r)) = self.prefill_router.get() {
r.chooser.register_workers(worker_ids);
}
}
/// Check if disaggregated mode is currently active (prefill router activated) /// Check if disaggregated mode is currently active (prefill router activated)
pub fn is_activated(&self) -> bool { pub fn is_activated(&self) -> bool {
self.prefill_router.get().is_some() self.prefill_router.get().is_some()
} }
} }
impl Drop for PrefillRouter { pub(super) fn link_child_context<T: Send + Sync + 'static>(
fn drop(&mut self) { engine_ctx: &Arc<dyn AsyncEngineContext>,
tracing::debug!("Dropping PrefillRouter, cancelling background activation task"); request: T,
self.cancel_token.cancel(); request_id: &str,
} ) -> Context<T> {
} let child_context = Context::with_id(request, request_id.to_string());
engine_ctx.link_child(child_context.context());
#[async_trait] child_context
impl
Operator<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
> for PrefillRouter
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
// Extract request data while preserving context
let (mut req, context) = request.into_parts();
let request_id = context.id().to_string();
let engine_ctx = context.context();
// Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens;
// If prefill router is not activated (no prefill workers discovered),
// this is aggregated mode — route directly to decode.
// With --enforce-disagg, fail instead of falling back.
if self.prefill_router.get().is_none() {
if self.enforce_disagg {
return Err(anyhow::anyhow!(PrefillError::NotActivated));
}
return next.generate(context.map(|_| req)).await;
}
// Ensure tracker exists for routing decisions in disaggregated mode.
// Create one if not provided by the upstream DeltaGenerator.
if req.tracker.is_none() {
req.tracker = Some(Arc::new(RequestTracker::new()));
}
let tracker = req.tracker.as_ref().unwrap();
let prefill_phase_permit = tracker.set_phase(RequestPhase::Prefill).await;
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
// Try to resolve prefill worker upfront: if we can get bootstrap info early,
// spawn prefill in background and proceed to decode immediately.
let preselected_worker = prefill_req
.routing
.as_ref()
.and_then(|r| r.prefill_worker_id);
if self.router_mode.is_direct_routing() && preselected_worker.is_none() {
return Err(anyhow::anyhow!(
"Prefill worker ID required in Direct routing mode but none found in request. \
Expected prefill_worker_id to be set via x-prefill-instance-id header by external router (e.g., EPP)."
));
}
let prefill_result = async {
if let Some((worker_id, dp_rank, bootstrap_info)) = self
.resolve_prefill_worker(&prefill_req, preselected_worker)
.await
{
// Bootstrap optimization path: spawn prefill in background
// We successfully used the peeked worker, so we must now advance the router state
// to ensure the next request gets a different worker.
if !self.router_mode.is_kv_routing()
&& let Some(router) = self.prefill_router.get()
{
router.select_next_worker();
}
let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id);
routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
// Pass phase permit to spawned task - it drops after first output (record_worker_full complete)
// This allows set_phase(Decode) below to proceed only after prefill routing is done
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_permit);
Ok(PrefillOutcome::Bootstrap(bootstrap_info))
} else {
// Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path");
// Drop the phase permit - we wait for completion
// so there's no race with set_phase(Decode) below
drop(prefill_phase_permit);
let prefill_context = Context::with_id(prefill_req, request_id.clone());
engine_ctx.link_child(prefill_context.context());
// In Direct mode, pass preselected_worker so execute_prefill uses
// router.direct() instead of router.generate() (which bails in Direct mode).
let (result, _worker_info) = Self::execute_prefill(
self.prefill_router.get().cloned(),
prefill_context,
preselected_worker,
None,
)
.await?;
Ok(PrefillOutcome::Completed(result))
}
}
.await;
// Abort if cancelled during prefill
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
tracing::debug!("Abort entering decode after context is stopped or killed");
return Err(anyhow::anyhow!(
"Context id {} is stopped or killed",
engine_ctx.id()
));
}
// Handle prefill result
match prefill_result {
Ok(outcome) => {
tracing::debug!("Prefill completed, proceeding to decode");
// Set phase to Decode for the decode request.
// In bootstrap path, this blocks until the spawned prefill task drops its permit
// (after first output / record_worker_full completes), ensuring correct phase for routing.
if let Some(ref tracker) = req.tracker {
let _decode_permit = tracker.set_phase(RequestPhase::Decode).await;
// Permit is dropped immediately - decode proceeds, no need to hold it
}
let mut decode_req = req;
match outcome {
PrefillOutcome::Bootstrap(info) => {
decode_req.bootstrap_info = Some(info);
}
PrefillOutcome::Completed(result) => {
decode_req.prefill_result = Some(result);
}
}
// Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens;
// Set router_config_override for decode:
// - overlap_score_weight = 0 (no KV cache overlap scoring for decode)
// - assume_kv_reuse = false (generate random hashes since decode workers
// may already have blocks cached from prefill transfer)
let existing_override = decode_req.router_config_override.take();
decode_req.router_config_override = Some(RouterConfigOverride {
overlap_score_weight: Some(0.0),
assume_kv_reuse: Some(false),
..existing_override.unwrap_or_default()
});
// Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
}
Err(PrefillError::NotActivated) => {
tracing::error!("Prefill router not activated, failing request");
Err(anyhow::anyhow!(PrefillError::NotActivated))
}
Err(e) => {
tracing::error!(error = %e, "Remote prefill failed, failing request");
Err(anyhow::anyhow!(e))
}
}
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::Result;
use dynamo_runtime::{
pipeline::{AsyncEngine, ManyOut, PushRouter, SingleIn},
protocols::annotated::Annotated,
};
use crate::{
kv_router::KvPushRouter,
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
};
/// The inner router used by PrefillRouter.
///
/// Exactly one variant is built during activation, depending on whether the
/// configured router mode is KV routing. Cloning is cheap: both variants only
/// hold an `Arc`.
#[derive(Clone)]
pub(super) enum InnerPrefillRouter {
/// KV-aware routing using KvPushRouter
KvRouter(Arc<KvPushRouter>),
/// Simple routing (RoundRobin, Random, Direct)
/// Note: Per-worker metrics (active_prefill_tokens, active_decode_blocks) are only
/// available in KV routing mode where the router has actual bookkeeping.
SimpleRouter(Arc<PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>>),
}
impl InnerPrefillRouter {
/// Generate with optional direct routing to specific worker.
/// For KvRouter, target_worker is ignored since prefill_worker_id is already set on the request.
/// For SimpleRouter, target_worker triggers direct routing via router.direct().
pub(super) async fn generate_to_worker(
&self,
request: SingleIn<PreprocessedRequest>,
target_worker: Option<u64>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
match (self, target_worker) {
// KvRouter: prefill_worker_id already set on request, KvPushRouter::select_worker uses it
(InnerPrefillRouter::KvRouter(router), _) => router.generate(request).await,
(InnerPrefillRouter::SimpleRouter(router), Some(worker_id)) => {
router.direct(request, worker_id).await
}
(InnerPrefillRouter::SimpleRouter(router), None) => router.generate(request).await,
}
}
/// Select next worker (for non-KV modes only)
pub(super) fn select_next_worker(&self) -> Option<u64> {
match self {
InnerPrefillRouter::SimpleRouter(router) => router.select_next_worker(),
InnerPrefillRouter::KvRouter(_) => None,
}
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::{Arc, OnceLock};
use anyhow::Result;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::{
pipeline::{
AsyncEngineContextProvider, ManyOut, Operator, RouterMode, ServerStreamingEngine, SingleIn,
async_trait,
},
protocols::{EndpointId, annotated::Annotated},
};
use crate::{
discovery::ModelManager,
protocols::common::{
llm_backend::{LLMEngineOutput, PreprocessedRequest},
timing::{RequestPhase, RequestTracker},
},
};
mod activation;
mod execution;
mod inner;
mod types;
use execution::link_child_context;
use inner::InnerPrefillRouter;
pub use types::PrefillError;
use types::{PrefillOutcome, PrefillResolveDecision, build_decode_router_override};
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
///
/// Modes:
/// - Query-only: `query_instance_id` annotation present → returns worker IDs without execution
/// - Pre-routed: `prefill_worker_id`/`decode_worker_id` set → routes to specified workers
/// - Normal: Worker IDs determined by router based on KV cache state
pub struct PrefillRouter {
/// Inner router; empty until `activate` stores it. While empty, `generate` treats
/// the request as aggregated mode (or fails when `enforce_disagg` is set).
prefill_router: OnceLock<InnerPrefillRouter>,
/// Used during activation to start the runtime config watcher, build the KV chooser
/// and look up the worker monitor.
model_manager: Arc<ModelManager>,
/// Id of the prefill endpoint, recorded by `activate` for later use in
/// `resolve_prefill_worker`.
endpoint_id: OnceLock<EndpointId>,
/// Cancelled in `Drop`; the background activation task spawned by `new` waits on a
/// clone of this token.
cancel_token: CancellationToken,
/// Decides during activation whether the inner router is KV-aware or simple; for the
/// simple router it is also the mode the push router is created with.
router_mode: RouterMode,
/// When true, `generate` returns `PrefillError::NotActivated` instead of routing
/// directly to decode while the prefill router is not activated.
enforce_disagg: bool,
/// Model name used to look up the worker monitor for prefill client registration
model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor
namespace: String,
/// Forwarded to `kv_chooser_for` when the KV chooser is created during activation.
is_eagle: bool,
}
/// Cancels `cancel_token` when the router is dropped so the background
/// activation task does not outlive it.
impl Drop for PrefillRouter {
fn drop(&mut self) {
tracing::debug!("Dropping PrefillRouter, cancelling background activation task");
self.cancel_token.cancel();
}
}
#[async_trait]
impl
Operator<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
> for PrefillRouter
{
/// Run the (optional) prefill phase for `request`, then forward it to `next` for decode.
///
/// Flow:
/// 1. If the prefill router is not activated: fail when `enforce_disagg` is set,
///    otherwise forward the request unchanged to `next` (aggregated mode).
/// 2. Otherwise clone the request with `max_tokens = 1` as the prefill request and try
///    to resolve a prefill worker upfront.
///    - Resolved: spawn prefill in the background and continue with the bootstrap info.
///    - Not resolved: execute prefill synchronously and keep its result.
/// 3. Abort if the context was stopped or killed in the meantime.
/// 4. Build the decode request (original `max_tokens`, bootstrap info or prefill result,
///    decode router override) and forward it to `next`.
///
/// # Errors
/// Returns an error when disaggregation is enforced but the router is not activated,
/// when Direct routing mode has no preselected prefill worker, when synchronous prefill
/// fails, when the context is stopped or killed, or when `next.generate` fails.
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
// Extract request data while preserving context
let (mut req, context) = request.into_parts();
let request_id = context.id().to_string();
let engine_ctx = context.context();
// Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens;
// If prefill router is not activated (no prefill workers discovered),
// this is aggregated mode — route directly to decode.
// With --enforce-disagg, fail instead of falling back.
if self.prefill_router.get().is_none() {
if self.enforce_disagg {
return Err(anyhow::anyhow!(PrefillError::NotActivated));
}
return next.generate(context.map(|_| req)).await;
}
// Ensure tracker exists for routing decisions in disaggregated mode.
// Create one if not provided by the upstream DeltaGenerator.
if req.tracker.is_none() {
req.tracker = Some(Arc::new(RequestTracker::new()));
}
// The unwrap cannot fail: the tracker was populated just above when it was missing.
let tracker = req.tracker.as_ref().unwrap();
let prefill_phase_barrier = tracker.set_phase(RequestPhase::Prefill).await;
// Prepare prefill request with max_tokens = 1 (clone after tracker is set)
let mut prefill_req = req.clone();
prefill_req.stop_conditions.max_tokens = Some(1);
// Try to resolve prefill worker upfront: if we can get bootstrap info early,
// spawn prefill in background and proceed to decode immediately.
let preselected_worker = prefill_req
.routing
.as_ref()
.and_then(|r| r.prefill_worker_id);
// Direct routing has no selection logic of its own, so the worker must come with the request.
if self.router_mode.is_direct_routing() && preselected_worker.is_none() {
return Err(anyhow::anyhow!(
"Prefill worker ID required in Direct routing mode but none found in request. \
Expected prefill_worker_id to be set via x-prefill-instance-id header by external router (e.g., EPP)."
));
}
let prefill_result = match self
.resolve_prefill_worker(&prefill_req, preselected_worker)
.await
{
PrefillResolveDecision::Resolved {
worker_id,
dp_rank,
bootstrap_info,
} => {
// Bootstrap optimization path: spawn prefill in background
// We successfully used the peeked worker, so we must now advance the router state
// to ensure the next request gets a different worker.
if !self.router_mode.is_kv_routing()
&& let Some(router) = self.prefill_router.get()
{
router.select_next_worker();
}
// Pin the prefill request to the resolved worker and attach the bootstrap info.
let routing = prefill_req.routing_mut();
routing.prefill_worker_id = Some(worker_id);
routing.dp_rank = Some(dp_rank);
prefill_req.bootstrap_info = Some(bootstrap_info.clone());
let prefill_context =
link_child_context(&engine_ctx, prefill_req, request_id.as_str());
// Pass the phase barrier to the spawned task. It is released after routing
// completes so `record_worker_full` finishes before phase changes to Decode.
self.spawn_prefill_task(prefill_context, Some(worker_id), prefill_phase_barrier);
Ok(PrefillOutcome::Bootstrap(bootstrap_info))
}
PrefillResolveDecision::Unavailable
| PrefillResolveDecision::NotActivated
| PrefillResolveDecision::NoBootstrapEndpoint => {
// Original prefill path: wait for prefill to complete
tracing::debug!("Using original prefill path");
// Drop the phase barrier because we wait for prefill completion in this task,
// so there is no race with set_phase(Decode) below.
drop(prefill_phase_barrier);
let prefill_context =
link_child_context(&engine_ctx, prefill_req, request_id.as_str());
// In Direct mode, pass preselected_worker so execute_prefill uses
// router.direct() instead of router.generate() (which bails in Direct mode).
let (result, _worker_info) = Self::execute_prefill(
self.prefill_router.get().cloned(),
prefill_context,
preselected_worker,
None,
)
.await?;
Ok(PrefillOutcome::Completed(result))
}
};
// Abort if cancelled during prefill
if engine_ctx.is_stopped() || engine_ctx.is_killed() {
tracing::debug!("Abort entering decode after context is stopped or killed");
return Err(anyhow::anyhow!(
"Context id {} is stopped or killed",
engine_ctx.id()
));
}
// Handle prefill result
match prefill_result {
Ok(outcome) => {
tracing::debug!("Prefill completed, proceeding to decode");
// Set phase to Decode for the decode request.
// In bootstrap path, this blocks until the spawned prefill task releases its
// phase barrier after routing completes, ensuring correct worker attribution.
if let Some(ref tracker) = req.tracker {
let _decode_permit = tracker.set_phase(RequestPhase::Decode).await;
// Permit is dropped immediately - decode proceeds, no need to hold it
}
let mut decode_req = req;
match outcome {
PrefillOutcome::Bootstrap(info) => {
decode_req.bootstrap_info = Some(info);
}
PrefillOutcome::Completed(result) => {
decode_req.prefill_result = Some(result);
}
}
// Restore original max_tokens for decode
decode_req.stop_conditions.max_tokens = original_max_tokens;
// Set router_config_override for decode:
// - overlap_score_weight = 0 (no KV cache overlap scoring for decode)
// - assume_kv_reuse = false (generate random hashes since decode workers
// may already have blocks cached from prefill transfer)
// - track_prefill_tokens = false (decode router should ignore prompt-side load)
let existing_override = decode_req.router_config_override.take();
decode_req.router_config_override =
Some(build_decode_router_override(existing_override));
// Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
}
// NOTE(review): both arms that build `prefill_result` above evaluate to `Ok(..)`, and
// `execute_prefill` failures already return early via `?`, so the two `Err` arms below
// look unreachable as written — confirm whether the `?` or these arms reflect the intent.
Err(PrefillError::NotActivated) => {
tracing::error!("Prefill router not activated, failing request");
Err(anyhow::anyhow!(PrefillError::NotActivated))
}
Err(e) => {
tracing::error!(error = %e, "Remote prefill failed, failing request");
Err(anyhow::anyhow!(e))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use dynamo_kv_router::config::RouterConfigOverride;
/// The decode override must force overlap scoring, KV-reuse assumption and prefill-token
/// tracking off while leaving unrelated fields (here `router_temperature`) untouched.
#[test]
fn decode_router_override_disables_overlap_and_prefill_tracking() {
let override_config = build_decode_router_override(Some(RouterConfigOverride {
router_temperature: Some(0.7),
..Default::default()
}));
// Fields forced by the decode override.
assert_eq!(override_config.overlap_score_weight, Some(0.0));
assert_eq!(override_config.assume_kv_reuse, Some(false));
assert_eq!(override_config.track_prefill_tokens, Some(false));
// Field carried over from the caller-supplied override.
assert_eq!(override_config.router_temperature, Some(0.7));
}
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_kv_router::config::RouterConfigOverride;
use crate::protocols::common::preprocessor::{BootstrapInfo, PrefillResult};
/// Errors that can occur during prefill routing.
#[derive(Debug, thiserror::Error)]
pub enum PrefillError {
/// Prefill router has not been activated yet.
#[error("Prefill router not yet activated")]
NotActivated,
// TODO: Separate prefill worker error from prefill router error
/// Error during prefill execution.
#[error("Prefill execution failed: {0}")]
PrefillError(
// Human-readable description; rendered as `{0}` in the error message.
String,
// Optional underlying cause, exposed as the error's `source`.
#[source] Option<Box<dyn std::error::Error + Send + Sync + 'static>>,
),
/// Disaggregated params not found in prefill response; the payload is rendered
/// as `{0}` in the error message.
#[error("No disaggregated params in prefill response: {0}")]
NoDisaggregatedParams(String),
}
/// Result of the prefill phase in `generate()`.
///
/// Determines what is attached to the decode request: bootstrap info or a
/// completed prefill result.
pub(super) enum PrefillOutcome {
/// Bootstrap optimization: prefill spawned in background, bootstrap info ready
Bootstrap(BootstrapInfo),
/// Synchronous prefill completed with result
Completed(PrefillResult),
}
/// Outcome of trying to resolve a prefill worker before executing prefill.
///
/// `generate()` takes the background (bootstrap) prefill path only for `Resolved`;
/// every other variant falls back to the synchronous prefill path.
pub(super) enum PrefillResolveDecision {
/// A worker was chosen upfront and its bootstrap info is available.
Resolved {
/// Prefill worker the request will be pinned to.
worker_id: u64,
/// Data-parallel rank written into the request's routing hints.
dp_rank: u32,
/// Bootstrap info attached to both the prefill and the decode request.
bootstrap_info: BootstrapInfo,
},
// NOTE(review): the three variants below are handled identically by `generate()`;
// their precise meaning is defined by `resolve_prefill_worker`, which is not visible here.
Unavailable,
NotActivated,
NoBootstrapEndpoint,
}
/// Build the router config override attached to the decode request.
///
/// Starts from `existing_override` (or the default override when `None`) and forces:
/// - `overlap_score_weight = Some(0.0)`
/// - `assume_kv_reuse = Some(false)`
/// - `track_prefill_tokens = Some(false)`
///
/// All other fields of the existing override are carried over unchanged.
pub(super) fn build_decode_router_override(
    existing_override: Option<RouterConfigOverride>,
) -> RouterConfigOverride {
    let mut decode_override = existing_override.unwrap_or_default();
    decode_override.overlap_score_weight = Some(0.0);
    decode_override.assume_kv_reuse = Some(false);
    decode_override.track_prefill_tokens = Some(false);
    decode_override
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::Result;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use dynamo_kv_router::RouterEventSink;
use dynamo_kv_router::indexer::LocalKvIndexer;
use dynamo_kv_router::protocols::*;
use dynamo_runtime::transports::event_plane::EventPublisher;
use dynamo_runtime::transports::nats::NatsQueue;
use crate::kv_router::KV_EVENT_SUBJECT;
use super::{DEFAULT_MAX_BATCH_BLOCKS, kv_publisher_metrics};
/// Accumulator for in-flight KV cache events that will be merged into a single
/// [`RouterEvent`] before being forwarded to the event sink.
#[derive(Debug)]
pub(super) struct BatchingState {
/// Block hashes accumulating for the next Removed event; `None` when nothing is pending.
pub(super) pending_removed: Option<KvCacheRemoveData>,
/// Blocks accumulating for the next Stored event; `None` when nothing is pending.
pub(super) pending_stored: Option<KvCacheStoreData>,
/// Event id used for the next published event. Starts at 1 and increments by 1 per
/// flush, independent of how many raw source events were merged into the batch.
pub(super) next_publish_id: u64,
/// dp_rank of the most recently processed event; stamped on flushed batches.
pub(super) last_dp_rank: u32,
/// When we last flushed (or initialized); the batch timeout is measured from here.
pub(super) last_flush_time: Instant,
}
impl BatchingState {
    /// Create an empty accumulator: nothing pending, publish ids starting at 1,
    /// dp_rank 0, and the flush clock started at the time of the call.
    pub(super) fn new() -> Self {
        let now = Instant::now();
        Self {
            last_flush_time: now,
            last_dp_rank: 0,
            next_publish_id: 1,
            pending_stored: None,
            pending_removed: None,
        }
    }

    /// Returns `true` when a Removed or a Stored batch is waiting to be flushed.
    pub(super) fn has_pending(&self) -> bool {
        !(self.pending_removed.is_none() && self.pending_stored.is_none())
    }

    /// Total number of blocks currently buffered across the pending Removed and Stored data.
    pub(super) fn pending_block_count(&self) -> usize {
        let removed = self
            .pending_removed
            .as_ref()
            .map_or(0, |r| r.block_hashes.len());
        let stored = self.pending_stored.as_ref().map_or(0, |s| s.blocks.len());
        removed + stored
    }

    /// Restart the batch window from the current instant.
    pub(super) fn record_flush_time(&mut self) {
        self.last_flush_time = Instant::now();
    }

    /// Time left in the current batch window of `timeout_ms` milliseconds,
    /// or zero once the window has elapsed.
    pub(super) fn remaining_timeout(&self, timeout_ms: u64) -> Duration {
        Duration::from_millis(timeout_ms).saturating_sub(self.last_flush_time.elapsed())
    }

    /// Returns `true` when no time is left in the batch window.
    pub(super) fn is_timeout_elapsed(&self, timeout_ms: u64) -> bool {
        self.remaining_timeout(timeout_ms).is_zero()
    }

    /// Emit whatever is pending (Removed first, then Stored) under the current publish id,
    /// then advance the publish id and restart the batch window.
    ///
    /// Does nothing when there is no pending data.
    async fn flush<P: RouterEventSink + Send + Sync + 'static>(
        &mut self,
        publisher: &P,
        local_indexer: &Option<Arc<LocalKvIndexer>>,
        worker_id: u64,
    ) {
        if !self.has_pending() {
            return;
        }

        let event_id = self.next_publish_id;
        let dp_rank = self.last_dp_rank;

        // Drain both slots up front; the payloads are emitted in Removed -> Stored order.
        let removed = self.pending_removed.take().map(KvCacheEventData::Removed);
        let stored = self.pending_stored.take().map(KvCacheEventData::Stored);
        for data in removed.into_iter().chain(stored) {
            emit(
                publisher,
                local_indexer,
                worker_id,
                KvCacheEvent {
                    event_id,
                    data,
                    dp_rank,
                },
            )
            .await;
        }

        self.next_publish_id += 1;
        self.record_flush_time();
    }
}
/// [`RouterEventSink`] adapter around an event-plane [`EventPublisher`].
pub(super) struct EventPlanePublisher(pub(super) EventPublisher);
impl RouterEventSink for EventPlanePublisher {
// Delegates straight to the wrapped publisher's `publish`.
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.0.publish(event)
}
}
/// [`RouterEventSink`] adapter around a [`NatsQueue`].
pub(super) struct JetStreamPublisher(pub(super) NatsQueue);
impl RouterEventSink for JetStreamPublisher {
// Publishes on the fixed `KV_EVENT_SUBJECT` subject of the wrapped queue.
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
NatsQueue::publish_event(&self.0, KV_EVENT_SUBJECT, event)
}
}
/// Wrap `event` in a [`RouterEvent`] for `worker_id`, apply it to the local indexer
/// when one is configured, then publish it to the sink.
///
/// Both steps are best-effort: failures are logged and never propagated.
async fn emit<P: RouterEventSink>(
    publisher: &P,
    local_indexer: &Option<Arc<LocalKvIndexer>>,
    worker_id: u64,
    event: KvCacheEvent,
) {
    let router_event = RouterEvent::new(worker_id, event);

    // The local indexer is updated first and independently of the publish outcome.
    if let Some(indexer) = local_indexer {
        let applied = indexer.apply_event_with_buffer(router_event.clone()).await;
        if let Err(e) = applied {
            tracing::warn!(worker_id, error = %e, "Failed to apply event to local indexer");
        }
    }

    let published = publisher.publish_event(&router_event).await;
    if let Err(e) = published {
        tracing::error!(worker_id, error = %e, "Failed to publish event");
    }
}
/// Main loop of the KV event processor: receives placement events from `rx`, batches
/// compatible ones in a [`BatchingState`], and emits the batches through `publisher`
/// (and `local_indexer`, when present).
///
/// - `timeout_ms`: batch window in milliseconds. `None` disables time-based batching, so
///   pending data is flushed right after each processed event.
/// - `max_batch_blocks`: pending data is flushed once it holds more than this many blocks.
///
/// The loop ends, after a final flush, when `cancellation_token` is cancelled or the
/// channel is closed.
pub(super) async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
timeout_ms: Option<u64>,
max_batch_blocks: usize,
) {
let mut batching_state = BatchingState::new();
// Id of the last raw input event seen; used only to detect gaps in the input stream.
let mut last_raw_input_id: Option<u64> = None;
loop {
tokio::select! {
// Shutdown requested: publish whatever is still buffered, then stop.
_ = cancellation_token.cancelled() => {
tracing::info!("KV Event source received cancellation signal");
batching_state.flush(&publisher, &local_indexer, worker_id).await;
break;
}
event = rx.recv() => {
// All senders dropped: publish whatever is still buffered, then stop.
let Some(placement_event) = event else {
tracing::debug!("Event processor channel closed.");
batching_state.flush(&publisher, &local_indexer, worker_id).await;
break;
};
// Gap detection on raw input ids: a jump of more than one means events were
// lost before reaching this loop. Only logged and counted; processing continues.
let raw_event_id = placement_event.event.event_id;
if let Some(last_id) = last_raw_input_id
&& raw_event_id > last_id + 1
{
let gap = raw_event_id - last_id - 1;
tracing::warn!(
worker_id,
last_raw_input_id = last_id,
raw_event_id,
gap,
"Input event gap detected: raw events dropped before batching"
);
if let Some(metrics) = kv_publisher_metrics() {
metrics.increment_engines_dropped_events(worker_id, gap);
} else {
tracing::warn!(
worker_id,
gap,
"Failed to record dropped events metric: metrics not initialized"
);
}
}
last_raw_input_id = Some(raw_event_id);
// Only local-GPU placements are forwarded; others are skipped after the id
// bookkeeping above, so they do not show up as gaps.
if !placement_event.placement.is_local_gpu() {
tracing::trace!(
worker_id,
?placement_event.placement,
event_id = placement_event.event.event_id,
"Skipping non-local-GPU placement event"
);
continue;
}
let event = placement_event.event;
tracing::trace!(
"Event processor for worker_id {} processing event: {:?}",
worker_id,
event.data
);
// A batch is stamped with a single dp_rank, so a rank change forces a flush first.
let dp_rank_changed =
batching_state.has_pending() && event.dp_rank != batching_state.last_dp_rank;
match event.data {
KvCacheEventData::Removed(data) => {
// Flush pending Stored data (or data from another dp_rank) before
// accumulating removals.
if batching_state.pending_stored.is_some() || dp_rank_changed {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
}
match &mut batching_state.pending_removed {
Some(pending) => pending.block_hashes.extend(data.block_hashes),
None => {
batching_state.pending_removed = Some(data);
}
}
}
KvCacheEventData::Stored(data) => {
// Stored data is merged only when it continues the pending chain, i.e. its
// parent hash equals the hash of the last pending block.
let should_flush = dp_rank_changed
|| batching_state.pending_removed.is_some()
|| batching_state.pending_stored.as_ref().is_some_and(|p| {
data.parent_hash != p.blocks.last().map(|b| b.block_hash)
});
if should_flush {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
}
match &mut batching_state.pending_stored {
Some(pending) => pending.blocks.extend(data.blocks),
None => {
batching_state.pending_stored = Some(data);
}
}
}
KvCacheEventData::Cleared => {
// Cleared is never batched: flush what is pending, then emit it immediately
// under its own publish id.
batching_state.flush(&publisher, &local_indexer, worker_id).await;
emit(
&publisher,
&local_indexer,
worker_id,
KvCacheEvent {
event_id: batching_state.next_publish_id,
data: KvCacheEventData::Cleared,
dp_rank: event.dp_rank,
},
)
.await;
batching_state.next_publish_id += 1;
}
}
batching_state.last_dp_rank = event.dp_rank;
// Flush right away when batching is disabled (`timeout_ms` is `None`), when the
// batch window has already elapsed, or when the batch exceeds `max_batch_blocks`.
if batching_state.has_pending()
&& (timeout_ms.is_none_or(|ms| batching_state.is_timeout_elapsed(ms))
|| batching_state.pending_block_count() > max_batch_blocks)
{
batching_state.flush(&publisher, &local_indexer, worker_id).await;
}
}
// Timer branch: fires when the batch window ends with data still pending. The
// 3600s value is only a placeholder for `timeout_ms == None`, in which case the
// `if` guard keeps this branch disabled.
_ = tokio::time::sleep(
timeout_ms
.map(|ms| batching_state.remaining_timeout(ms))
.unwrap_or(Duration::from_secs(3600))
), if timeout_ms.is_some() && batching_state.has_pending() => {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
}
}
}
}
/// Run the event processor loop with an arbitrary [`RouterEventSink`], using
/// `DEFAULT_MAX_BATCH_BLOCKS` as the batch size limit.
///
/// Returns when the underlying loop ends (cancellation or channel closed).
pub(super) async fn start_event_processor<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
batching_timeout_ms: Option<u64>,
) {
run_event_processor_loop(
publisher,
worker_id,
cancellation_token,
rx,
local_indexer,
batching_timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
}
/// Run the event processor loop publishing through a [`NatsQueue`], wrapped in
/// [`JetStreamPublisher`], using `DEFAULT_MAX_BATCH_BLOCKS` as the batch size limit.
///
/// Returns when the underlying loop ends (cancellation or channel closed).
pub(super) async fn start_event_processor_jetstream(
publisher: NatsQueue,
worker_id: u64,
cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
batching_timeout_ms: Option<u64>,
) {
run_event_processor_loop(
JetStreamPublisher(publisher),
worker_id,
cancellation_token,
rx,
local_indexer,
batching_timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::Result;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use dynamo_kv_router::indexer::{KvIndexerMetrics, LocalKvIndexer};
use dynamo_kv_router::protocols::*;
pub use dynamo_kv_router::zmq_wire::create_stored_blocks;
#[cfg(test)]
use dynamo_kv_router::zmq_wire::*;
use dynamo_runtime::config::environment_names::nats as env_nats;
use dynamo_runtime::metrics::MetricsHierarchy;
use dynamo_runtime::metrics::prometheus_names::kv_publisher;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{
component::Component,
transports::nats::{NatsQueue, Slug},
};
use crate::kv_router::{
KV_EVENT_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE, worker_query::start_worker_kv_query_endpoint,
};
mod event_processor;
#[cfg(test)]
mod tests;
mod worker_metrics;
mod zmq_listener;
#[cfg(test)]
use event_processor::{BatchingState, run_event_processor_loop};
use event_processor::{
EventPlanePublisher, start_event_processor, start_event_processor_jetstream,
};
pub use worker_metrics::WorkerMetricsPublisher;
use zmq_listener::start_zmq_listener;
#[cfg(test)]
use zmq_listener::{
INITIAL_BACKOFF_MS, MAX_BACKOFF_EXPONENT, MAX_BACKOFF_MS, MAX_CONSECUTIVE_ERRORS,
calculate_backoff_ms,
};
/// Upper bound (15 seconds) applied to a requested batching timeout; larger values are capped.
const MAX_BATCHING_TIMEOUT_MS: u64 = 15_000;
/// Default batching timeout: `None`, i.e. time-based batching is disabled.
pub const DEFAULT_BATCHING_TIMEOUT_MS: Option<u64> = None;
/// Batch size limit, in blocks, passed to the event processor loop by the `start_*` helpers.
const DEFAULT_MAX_BATCH_BLOCKS: usize = 128;
/// Build the KV stream name for `component` and `subject`.
///
/// The dotted name `namespace.{namespace}.component.{component}.{subject}` is slugified
/// and every underscore is then replaced by a hyphen, giving a name of the form
/// `namespace-{namespace}-component-{component}-{subject}`.
fn create_kv_stream_name(component: &Component, subject: &str) -> String {
    let dotted = format!(
        "namespace.{}.component.{}.{}",
        component.namespace().name(),
        component.name(),
        subject
    );
    let slug = Slug::slugify(&dotted).to_string();
    slug.replace('_', "-")
}
/// Metrics for the KV publisher, created via the MetricsHierarchy API.
/// This provides automatic `dynamo_namespace`, `dynamo_component`, and other
/// hierarchy labels for free.
pub(super) struct KvPublisherMetrics {
/// Total number of raw events dropped by engines before reaching publisher,
/// labelled by `worker_id`.
pub engines_dropped_events_total: prometheus::IntCounterVec,
}
/// Process-wide metrics instance, initialized at most once by
/// `KvPublisherMetrics::from_component`.
static KV_PUBLISHER_METRICS: OnceLock<Arc<KvPublisherMetrics>> = OnceLock::new();
impl KvPublisherMetrics {
    /// Return the shared metrics instance, creating it from `component` on first use.
    ///
    /// The instance is memoized in a static `OnceLock`, so only the first call's component
    /// is used. Creation goes through the MetricsHierarchy API, which auto-prepends
    /// `dynamo_component_`, injects hierarchy labels, and registers with the DRT
    /// `MetricsRegistry`. If that fails, a warning is logged and unregistered metrics
    /// are used instead.
    pub fn from_component(component: &Component) -> Arc<Self> {
        let shared = KV_PUBLISHER_METRICS.get_or_init(|| {
            let metrics = component.metrics();
            let created = metrics.create_intcountervec(
                kv_publisher::ENGINES_DROPPED_EVENTS_TOTAL,
                "Total number of raw events dropped by engines before reaching publisher (detected via event_id gaps)",
                &["worker_id"],
                &[],
            );
            let instance = created
                .map(|engines_dropped_events_total| Self {
                    engines_dropped_events_total,
                })
                .unwrap_or_else(|e| {
                    tracing::warn!("Failed to create kv_publisher metrics from component: {}. Using unregistered metrics as fallback.", e);
                    Self::new_unregistered()
                });
            Arc::new(instance)
        });
        Arc::clone(shared)
    }

    /// Creates unregistered metrics for use when the MetricsRegistry is not available.
    /// This is the fallback when metric creation fails.
    ///
    /// # Panics
    /// Panics if the counter itself cannot be constructed.
    pub fn new_unregistered() -> Self {
        let opts = prometheus::Opts::new(
            kv_publisher::ENGINES_DROPPED_EVENTS_TOTAL,
            "Total number of raw events dropped by engines before reaching publisher (detected via event_id gaps)",
        );
        let engines_dropped_events_total = prometheus::IntCounterVec::new(opts, &["worker_id"])
            .expect("failed to create engines_dropped_events_total counter");
        Self {
            engines_dropped_events_total,
        }
    }

    /// Add `count` to the dropped-events counter labelled with `worker_id`.
    pub fn increment_engines_dropped_events(&self, worker_id: u64, count: u64) {
        let worker_label = worker_id.to_string();
        self.engines_dropped_events_total
            .with_label_values(&[worker_label.as_str()])
            .inc_by(count);
    }
}
/// Return the shared KV publisher metrics, or `None` if they have not been initialized yet.
fn kv_publisher_metrics() -> Option<Arc<KvPublisherMetrics>> {
    KV_PUBLISHER_METRICS.get().map(Arc::clone)
}
/// Configure the source of KV events.
/// Currently, only ZMQ is supported.
pub enum KvEventSourceConfig {
/// Receive events from a ZMQ listener; `endpoint` and `topic` are handed to the
/// listener unchanged.
Zmq { endpoint: String, topic: String },
}
/// A running KV event source.
enum KvEventSource {
/// ZMQ-backed source.
Zmq {
/// Handle of the spawned ZMQ listener task; aborted on shutdown.
zmq_handle: tokio::task::JoinHandle<()>,
},
}
impl KvEventSource {
    /// Start the event source described by `source_config`.
    ///
    /// For ZMQ, spawns the listener on the component runtime's secondary executor; the
    /// listener receives `tx`, the cancellation token, the block size and the shared
    /// event id counter.
    fn start(
        component: Component,
        worker_id: WorkerId,
        kv_block_size: u32,
        source_config: KvEventSourceConfig,
        cancellation_token: CancellationToken,
        tx: mpsc::UnboundedSender<PlacementEvent>,
        next_event_id: Arc<AtomicU64>,
    ) -> Result<Self> {
        // ZMQ is the only supported source, so the destructuring is irrefutable.
        let KvEventSourceConfig::Zmq { endpoint, topic } = source_config;
        let zmq_handle = component
            .drt()
            .runtime()
            .secondary()
            .spawn(start_zmq_listener(
                endpoint,
                topic,
                worker_id,
                tx,
                cancellation_token.clone(),
                kv_block_size,
                next_event_id,
            ));
        Ok(KvEventSource::Zmq { zmq_handle })
    }

    /// Stop the source by aborting its listener task.
    fn shutdown(&self) {
        let KvEventSource::Zmq { zmq_handle } = self;
        zmq_handle.abort();
    }
}
/// A publisher of KV events.
///
/// Events enter through an optional external source and/or [`KvEventPublisher::publish`],
/// and are handed over a channel to a background event processor task.
pub struct KvEventPublisher {
/// The size of the KV block.
kv_block_size: u32,
/// The source of KV events.
/// Can be `None` if all events are provided through [`KvEventPublisher::publish`].
source: Option<KvEventSource>,
/// The cancellation token; cancelled on shutdown.
cancellation_token: CancellationToken,
/// The ID of the local worker emitting placement events.
worker_id: WorkerId,
/// The channel to send events to.
tx: mpsc::UnboundedSender<PlacementEvent>,
/// Internal monotonic event ID counter. Shared with the ZMQ listener if present.
next_event_id: Arc<AtomicU64>,
}
impl KvEventPublisher {
/// Create a publisher without a local indexer, with dp_rank 0 and the default
/// batching timeout.
pub fn new(
component: Component,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
) -> Result<Self> {
Self::new_with_local_indexer(
component,
kv_block_size,
source_config,
false,
0,
DEFAULT_BATCHING_TIMEOUT_MS,
)
}
/// Create a publisher and spawn its background tasks.
///
/// - `source_config`: when `Some`, an external event source is started and feeds the
///   same channel as [`KvEventPublisher::publish`].
/// - `enable_local_indexer`: when true, a `LocalKvIndexer` and its query endpoint are
///   created and events are published through the event plane; otherwise events are
///   published through a NATS queue.
/// - `dp_rank`: passed to the worker KV query endpoint.
/// - `batching_timeout_ms`: batch window; `Some(0)` is treated as `None` (disabled) and
///   values above `MAX_BATCHING_TIMEOUT_MS` are capped with a warning.
///
/// # Errors
/// Returns an error if the event source fails to start.
pub fn new_with_local_indexer(
component: Component,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
enable_local_indexer: bool,
dp_rank: DpRank,
batching_timeout_ms: Option<u64>,
) -> Result<Self> {
let cancellation_token = CancellationToken::new();
// Normalize the timeout: drop zero (disabled), warn about and cap oversized values.
let batching_timeout_ms = batching_timeout_ms
.filter(|&ms| {
if ms > MAX_BATCHING_TIMEOUT_MS {
tracing::warn!(
requested_ms = ms,
max_ms = MAX_BATCHING_TIMEOUT_MS,
"batching_timeout_ms too high, capping to 15s"
);
}
ms > 0
})
.map(|ms| ms.min(MAX_BATCHING_TIMEOUT_MS));
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let worker_id = component.drt().connection_id();
// Called for its side effect only: initializes the shared publisher metrics.
KvPublisherMetrics::from_component(&component);
let component_name = component.name();
tracing::info!(
"Initializing KvEventPublisher for worker {worker_id} in component {component_name}"
);
if enable_local_indexer {
tracing::info!(
"LocalKvIndexer enabled for worker {worker_id} in component {component_name}"
);
}
// Event ids start at 0; the counter is shared with the external source, if any.
let next_event_id = Arc::new(AtomicU64::new(0));
let mut source = None;
if let Some(config) = source_config {
source = Some(KvEventSource::start(
component.clone(),
worker_id,
kv_block_size,
config,
cancellation_token.clone(),
tx.clone(),
next_event_id.clone(),
)?);
}
let local_indexer = if enable_local_indexer {
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
Some(Arc::new(LocalKvIndexer::new(
cancellation_token.clone(),
kv_block_size,
metrics,
WORKER_KV_INDEXER_BUFFER_SIZE,
)))
} else {
None
};
// The join handle is intentionally not kept; the query endpoint task runs detached.
let _local_indexer_query_handle = local_indexer.as_ref().map(|local_indexer_ref| {
let component = component.clone();
let local_indexer = local_indexer_ref.clone();
component
.drt()
.runtime()
.secondary()
.spawn(start_worker_kv_query_endpoint(
component,
worker_id,
dp_rank,
local_indexer,
))
});
let cancellation_token_clone = cancellation_token.clone();
let local_indexer_clone = local_indexer.clone();
if enable_local_indexer {
tracing::info!("Using event plane for KV event publishing (local_indexer mode)");
let component_clone = component.clone();
component.drt().runtime().secondary().spawn(async move {
// If the publisher cannot be created the task exits and `rx` is dropped, so
// later `publish` calls will return a send error.
let event_publisher =
match dynamo_runtime::transports::event_plane::EventPublisher::for_component(
&component_clone,
KV_EVENT_SUBJECT,
)
.await
{
Ok(publisher) => publisher,
Err(e) => {
tracing::error!("Failed to create event publisher: {}", e);
return;
}
};
start_event_processor(
EventPlanePublisher(event_publisher),
worker_id,
cancellation_token_clone,
rx,
local_indexer_clone,
batching_timeout_ms,
)
.await
});
} else {
let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT);
// NATS server address comes from the environment, defaulting to localhost.
let nats_server = std::env::var(env_nats::NATS_SERVER)
.unwrap_or_else(|_| "nats://localhost:4222".to_string());
let mut nats_queue = NatsQueue::new_without_consumer(
stream_name,
nats_server,
std::time::Duration::from_secs(60),
);
component.drt().runtime().secondary().spawn(async move {
// As above: on connection failure the task exits and `rx` is dropped.
if let Err(e) = nats_queue.connect().await {
tracing::error!("Failed to connect NatsQueue: {e}");
return;
}
start_event_processor_jetstream(
nats_queue,
worker_id,
cancellation_token_clone,
rx,
local_indexer_clone,
batching_timeout_ms,
)
.await
});
}
Ok(Self {
kv_block_size,
source,
cancellation_token,
worker_id,
tx,
next_event_id,
})
}
/// Queue `event` for the background event processor as a local-GPU placement event.
///
/// # Errors
/// Returns the original event inside a `SendError` when the processor's receiving
/// end of the channel has been dropped.
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
let placement_event = PlacementEvent::local_gpu(self.worker_id, event);
match self.tx.send(placement_event) {
Ok(()) => Ok(()),
// Unwrap the placement wrapper so the caller gets back exactly what it passed in.
Err(err) => Err(mpsc::error::SendError(err.0.event)),
}
}
/// Return the current event id and advance the shared counter by one.
pub fn next_event_id(&self) -> u64 {
self.next_event_id.fetch_add(1, Ordering::SeqCst)
}
/// The size of the KV block this publisher was created with.
pub fn kv_block_size(&self) -> u32 {
self.kv_block_size
}
/// Cancel the background tasks and stop the external event source, if any.
/// Safe to call more than once: the source is taken on the first call.
pub fn shutdown(&mut self) {
if !self.cancellation_token.is_cancelled() {
self.cancellation_token.cancel();
}
if let Some(source) = self.source.take() {
source.shutdown();
}
}
}
/// Dropping the publisher shuts it down: the cancellation token is cancelled and the
/// event source, if any, is stopped.
impl Drop for KvEventPublisher {
fn drop(&mut self) {
self.shutdown();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::*;
#[allow(unused_imports)]
use bytes::Bytes;
#[allow(unused_imports)]
use dynamo_kv_router::RouterEventSink;
#[allow(unused_imports)]
use rmp_serde as rmps;
#[allow(unused_imports)]
use std::future::Future; use std::future::Future;
use std::sync::Arc; #[allow(unused_imports)]
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::{Duration, Instant}; use std::time::Duration;
#[allow(unused_imports)]
use anyhow::Result; use zeromq::{PubSocket, Socket, SocketSend, ZmqMessage};
use rmp_serde as rmps;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use dynamo_runtime::metrics::MetricsHierarchy;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventPublisher;
use dynamo_runtime::{
component::{Component, Namespace},
transports::nats::{NatsQueue, Slug},
};
/// Build the KV stream name for `component` and `subject`.
///
/// The dotted name `namespace.{namespace}.component.{component}.{subject}` is slugified
/// and every underscore is then replaced by a hyphen, giving a name of the form
/// `namespace-{namespace}-component-{component}-{subject}`.
fn create_kv_stream_name(component: &Component, subject: &str) -> String {
    let dotted = format!(
        "namespace.{}.component.{}.{}",
        component.namespace().name(),
        component.name(),
        subject
    );
    let slug = Slug::slugify(&dotted).to_string();
    slug.replace('_', "-")
}
use dynamo_kv_router::indexer::{KvIndexerMetrics, LocalKvIndexer};
use dynamo_kv_router::protocols::*;
pub use dynamo_kv_router::zmq_wire::create_stored_blocks;
use dynamo_kv_router::zmq_wire::*;
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE,
worker_query::start_worker_kv_query_endpoint,
};
use dynamo_runtime::config::environment_names::nats as env_nats;
// Error handling configuration for ZMQ operations
// NOTE(review): the code consuming these backoff constants is not visible in this chunk;
// the names suggest an exponential backoff between INITIAL_BACKOFF_MS and MAX_BACKOFF_MS — confirm.
const INITIAL_BACKOFF_MS: u64 = 10;
const MAX_BACKOFF_MS: u64 = 5000;
const MAX_CONSECUTIVE_ERRORS: u32 = 10;
const MAX_BACKOFF_EXPONENT: u32 = 8; // Cap at 2^8 = 256x multiplier to prevent overflow
// Batching configuration
const MAX_BATCHING_TIMEOUT_MS: u64 = 15_000; // 15 seconds, prevents misconfiguration
pub const DEFAULT_BATCHING_TIMEOUT_MS: Option<u64> = None; // disabled by default
const DEFAULT_MAX_BATCH_BLOCKS: usize = 128; // Max blocks to batch before flushing
// ---------------------------------------------------------------------------
// Engines dropped events metric
// ---------------------------------------------------------------------------
use std::sync::OnceLock;
use dynamo_runtime::metrics::prometheus_names::kv_publisher;
/// Metrics for the KV publisher, created via the MetricsHierarchy API.
/// This provides automatic `dynamo_namespace`, `dynamo_component`, and other
/// hierarchy labels for free.
pub struct KvPublisherMetrics {
/// Total number of raw events dropped by engines before reaching publisher,
/// labelled by `worker_id`.
pub engines_dropped_events_total: prometheus::IntCounterVec,
}
/// Process-wide metrics instance, initialized at most once by
/// `KvPublisherMetrics::from_component`.
static KV_PUBLISHER_METRICS: OnceLock<Arc<KvPublisherMetrics>> = OnceLock::new();
impl KvPublisherMetrics {
    /// Return the shared metrics instance, creating it from `component` on first use.
    ///
    /// The instance is memoized in a static `OnceLock`, so only the first call's component
    /// is used. Creation goes through the MetricsHierarchy API, which auto-prepends
    /// `dynamo_component_`, injects hierarchy labels, and registers with the DRT
    /// `MetricsRegistry`. If that fails, a warning is logged and unregistered metrics
    /// are used instead.
    pub fn from_component(component: &Component) -> Arc<Self> {
        let shared = KV_PUBLISHER_METRICS.get_or_init(|| {
            let metrics = component.metrics();
            let created = metrics.create_intcountervec(
                kv_publisher::ENGINES_DROPPED_EVENTS_TOTAL,
                "Total number of raw events dropped by engines before reaching publisher (detected via event_id gaps)",
                &["worker_id"],
                &[],
            );
            let instance = created
                .map(|engines_dropped_events_total| Self {
                    engines_dropped_events_total,
                })
                .unwrap_or_else(|e| {
                    tracing::warn!("Failed to create kv_publisher metrics from component: {}. Using unregistered metrics as fallback.", e);
                    Self::new_unregistered()
                });
            Arc::new(instance)
        });
        Arc::clone(shared)
    }

    /// Creates unregistered metrics for use when the MetricsRegistry is not available.
    /// This is the fallback when metric creation fails.
    ///
    /// # Panics
    /// Panics if the counter itself cannot be constructed.
    pub fn new_unregistered() -> Self {
        let opts = prometheus::Opts::new(
            kv_publisher::ENGINES_DROPPED_EVENTS_TOTAL,
            "Total number of raw events dropped by engines before reaching publisher (detected via event_id gaps)",
        );
        let engines_dropped_events_total = prometheus::IntCounterVec::new(opts, &["worker_id"])
            .expect("failed to create engines_dropped_events_total counter");
        Self {
            engines_dropped_events_total,
        }
    }

    /// Add `count` to the dropped-events counter labelled with `worker_id`.
    pub fn increment_engines_dropped_events(&self, worker_id: u64, count: u64) {
        let worker_label = worker_id.to_string();
        self.engines_dropped_events_total
            .with_label_values(&[worker_label.as_str()])
            .inc_by(count);
    }
}
/// Return the shared KV publisher metrics, or `None` if they have not been initialized yet.
fn kv_publisher_metrics() -> Option<Arc<KvPublisherMetrics>> {
    KV_PUBLISHER_METRICS.get().map(Arc::clone)
}
// -------------------------------------------------------------------------
// Batching State -----------------------------------------------------------
// -------------------------------------------------------------------------
/// Accumulator for in-flight KV cache events that will be merged into a single
/// [`RouterEvent`] before being forwarded to the event sink.
///
/// The processor loop flushes before switching between Removed and Stored data,
/// so in practice only one of the two pending slots is populated at a time.
#[derive(Debug)]
struct BatchingState {
/// Block hashes accumulating for the next Removed event.
/// `None` when no Removed data is pending.
pending_removed: Option<KvCacheRemoveData>,
/// Blocks accumulating for the next Stored event.
/// `None` when no Stored data is pending.
pending_stored: Option<KvCacheStoreData>,
/// Monotonic published-batch counter. Increments by 1 per flush so downstream
/// consumers always see consecutive event IDs, regardless of how many raw source
/// events were merged into the batch. Starts at 1.
next_publish_id: u64,
/// dp_rank of the events in the current pending batch.
/// A change signals that the batch must be flushed before accumulating further.
last_dp_rank: u32,
/// When we last flushed (or initialized). Used to detect stale pending data:
/// if a new event arrives after a long idle period (exceeding timeout),
/// we flush immediately for lower latency on sparse important events.
last_flush_time: Instant,
}
impl BatchingState {
    /// Builds an empty accumulator: nothing pending, publish IDs starting at 1,
    /// dp_rank 0, and the flush clock started at construction time.
    fn new() -> Self {
        let created_at = Instant::now();
        Self {
            pending_removed: None,
            pending_stored: None,
            next_publish_id: 1,
            last_dp_rank: 0,
            last_flush_time: created_at,
        }
    }

    /// `true` when either Removed or Stored data is waiting to be flushed.
    fn has_pending(&self) -> bool {
        !(self.pending_removed.is_none() && self.pending_stored.is_none())
    }

    /// Total number of blocks (removed hashes plus stored blocks) currently pending.
    fn pending_block_count(&self) -> usize {
        let removed = self
            .pending_removed
            .as_ref()
            .map_or(0, |r| r.block_hashes.len());
        let stored = self.pending_stored.as_ref().map_or(0, |s| s.blocks.len());
        removed + stored
    }

    /// Restarts the flush clock. Called after every flush so idle periods can be
    /// measured for stale-data detection.
    fn record_flush_time(&mut self) {
        self.last_flush_time = Instant::now();
    }

    /// Time left in the current batch window, clamped to zero once the window
    /// has already elapsed.
    fn remaining_timeout(&self, timeout_ms: u64) -> Duration {
        let window = Duration::from_millis(timeout_ms);
        window.saturating_sub(self.last_flush_time.elapsed())
    }

    /// `true` once the batch window has elapsed (always the case when
    /// `timeout_ms` is zero).
    fn is_timeout_elapsed(&self, timeout_ms: u64) -> bool {
        self.remaining_timeout(timeout_ms).is_zero()
    }
}
// -------------------------------------------------------------------------
// KV Event Publishers -----------------------------------------------------
// -------------------------------------------------------------------------
/// Configure the source of KV events.
/// Currently, only ZMQ is supported.
pub enum KvEventSourceConfig {
/// Listen for events on a ZMQ SUB socket: the listener connects to `endpoint`
/// and subscribes to `topic` (an empty topic subscribes to all topics).
Zmq { endpoint: String, topic: String },
}
/// The source of KV events.
/// Holds whatever is needed to stop the running source again.
enum KvEventSource {
/// A ZMQ listener running as a spawned task.
Zmq {
// Join handle of the spawned `start_zmq_listener` task; aborted on shutdown.
zmq_handle: tokio::task::JoinHandle<()>,
},
}
impl KvEventSource {
    /// Launches the background task described by `source_config`.
    ///
    /// For ZMQ this spawns `start_zmq_listener` on the component's secondary
    /// runtime; decoded events are forwarded through `tx` and take their IDs
    /// from the shared `next_event_id` counter. As written this never returns
    /// an error.
    fn start(
        component: Component,
        worker_id: WorkerId,
        kv_block_size: u32,
        source_config: KvEventSourceConfig,
        cancellation_token: CancellationToken,
        tx: mpsc::UnboundedSender<PlacementEvent>,
        next_event_id: Arc<AtomicU64>,
    ) -> Result<Self> {
        // Single-variant enum today; this destructuring stops compiling as soon
        // as another source kind is added, forcing this function to be revisited.
        let KvEventSourceConfig::Zmq { endpoint, topic } = source_config;
        let listener = start_zmq_listener(
            endpoint,
            topic,
            worker_id,
            tx,
            cancellation_token.clone(),
            kv_block_size,
            next_event_id,
        );
        let zmq_handle = component.drt().runtime().secondary().spawn(listener);
        Ok(Self::Zmq { zmq_handle })
    }

    /// Stops the source by aborting its background task.
    fn shutdown(&self) {
        let Self::Zmq { zmq_handle } = self;
        zmq_handle.abort();
    }
}
/// A publisher of KV events.
///
/// Events pushed through [`KvEventPublisher::publish`] (or produced by the optional
/// event source) travel over an unbounded channel to a background processor task
/// that is spawned at construction time.
pub struct KvEventPublisher {
/// The size of the KV block.
kv_block_size: u32,
/// The source of KV events.
/// Can be `None` if all events provided through [`KvEventPublisher::publish`].
source: Option<KvEventSource>,
/// The cancellation token.
/// Cancelled by `shutdown`; clones are handed to the event source and the processor task.
cancellation_token: CancellationToken,
/// The ID of the local worker emitting placement events.
worker_id: WorkerId,
/// The channel to send events to.
tx: mpsc::UnboundedSender<PlacementEvent>,
/// Internal monotonic event ID counter - ensures each event gets a unique, incrementing ID.
/// Shared with the ZMQ listener (if any) to maintain consistency.
next_event_id: Arc<AtomicU64>,
}
impl KvEventPublisher {
pub fn new(
component: Component,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
) -> Result<Self> {
Self::new_with_local_indexer(
component,
kv_block_size,
source_config,
false,
0,
DEFAULT_BATCHING_TIMEOUT_MS,
)
}
pub fn new_with_local_indexer(
component: Component,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
enable_local_indexer: bool,
dp_rank: DpRank,
batching_timeout_ms: Option<u64>,
) -> Result<Self> {
let cancellation_token = CancellationToken::new();
// None = disabled (flush every event); Some(0) normalised to None; Some(ms) = opt-in.
// Cap at MAX_BATCHING_TIMEOUT_MS to prevent misconfiguration.
let batching_timeout_ms = batching_timeout_ms
.filter(|&ms| {
if ms > MAX_BATCHING_TIMEOUT_MS {
tracing::warn!(
requested_ms = ms,
max_ms = MAX_BATCHING_TIMEOUT_MS,
"batching_timeout_ms too high, capping to 15s"
);
}
// if ms is 0, treat as disabled (None)
ms > 0
})
.map(|ms| ms.min(MAX_BATCHING_TIMEOUT_MS));
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
// Infer worker_id from component's connection
let worker_id = component.drt().connection_id();
// Initialize the KV publisher metrics via MetricsHierarchy API
// This provides automatic hierarchy labels (dynamo_namespace, dynamo_component, etc.)
KvPublisherMetrics::from_component(&component);
let component_name = component.name();
tracing::info!(
"Initializing KvEventPublisher for worker {worker_id} in component {component_name}"
);
if enable_local_indexer {
tracing::info!(
"LocalKvIndexer enabled for worker {worker_id} in component {component_name}"
);
}
// Internal monotonic event ID counter - shared with ZMQ listener if any
let next_event_id = Arc::new(AtomicU64::new(0));
// Create our event source (if any)
let mut source = None;
if let Some(config) = source_config {
source = Some(KvEventSource::start(
component.clone(),
worker_id,
kv_block_size,
config,
cancellation_token.clone(),
tx.clone(),
next_event_id.clone(),
)?);
}
// Create local indexer if requested
let local_indexer = if enable_local_indexer {
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
Some(Arc::new(LocalKvIndexer::new(
cancellation_token.clone(),
kv_block_size,
metrics,
WORKER_KV_INDEXER_BUFFER_SIZE,
)))
} else {
None
};
// Spawn runtime for router->local indexer comm if requested
let _local_indexer_query_handle = local_indexer.as_ref().map(|local_indexer_ref| {
let component = component.clone();
let local_indexer = local_indexer_ref.clone();
component
.drt()
.runtime()
.secondary()
.spawn(start_worker_kv_query_endpoint(
component,
worker_id,
dp_rank,
local_indexer,
))
});
let cancellation_token_clone = cancellation_token.clone();
let local_indexer_clone = local_indexer.clone();
if enable_local_indexer {
// When local indexer is enabled, use the event plane directly.
// EventPublisher handles transport selection (ZMQ or NATS) based on environment.
// Durability is provided by the local indexer's event buffer.
tracing::info!("Using event plane for KV event publishing (local_indexer mode)");
let component_clone = component.clone();
component.drt().runtime().secondary().spawn(async move {
let event_publisher =
match EventPublisher::for_component(&component_clone, KV_EVENT_SUBJECT).await {
Ok(publisher) => publisher,
Err(e) => {
tracing::error!("Failed to create event publisher: {}", e);
return;
}
};
start_event_processor(
EventPlanePublisher(event_publisher),
worker_id,
cancellation_token_clone,
rx,
local_indexer_clone,
batching_timeout_ms,
)
.await
});
} else {
// When local indexer is disabled, use JetStream (NatsQueue) for durability.
let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT);
let nats_server = std::env::var(env_nats::NATS_SERVER)
.unwrap_or_else(|_| "nats://localhost:4222".to_string());
let mut nats_queue = NatsQueue::new_without_consumer(
stream_name,
nats_server,
std::time::Duration::from_secs(60), // 1 minute timeout
);
component.drt().runtime().secondary().spawn(async move {
if let Err(e) = nats_queue.connect().await {
tracing::error!("Failed to connect NatsQueue: {e}");
return;
}
start_event_processor_jetstream(
JetStreamPublisher(nats_queue),
worker_id,
cancellation_token_clone,
rx,
local_indexer_clone,
batching_timeout_ms,
)
.await
});
}
Ok(Self {
kv_block_size,
source,
cancellation_token,
worker_id,
tx,
next_event_id,
})
}
pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
let placement_event = PlacementEvent::local_gpu(self.worker_id, event);
match self.tx.send(placement_event) {
Ok(()) => Ok(()),
Err(err) => Err(mpsc::error::SendError(err.0.event)),
}
}
/// Get and increment the next event ID atomically.
/// Use this to assign monotonically increasing event IDs to events before publishing.
pub fn next_event_id(&self) -> u64 {
self.next_event_id.fetch_add(1, Ordering::SeqCst)
}
pub fn kv_block_size(&self) -> u32 {
self.kv_block_size
}
pub fn shutdown(&mut self) {
if !self.cancellation_token.is_cancelled() {
self.cancellation_token.cancel();
}
if let Some(source) = self.source.take() {
source.shutdown();
}
}
}
/// Dropping the publisher performs the same teardown as an explicit
/// [`KvEventPublisher::shutdown`] call.
impl Drop for KvEventPublisher {
fn drop(&mut self) {
self.shutdown();
}
}
use dynamo_kv_router::RouterEventSink;
/// [`RouterEventSink`] adapter around an event-plane [`EventPublisher`].
struct EventPlanePublisher(EventPublisher);
impl RouterEventSink for EventPlanePublisher {
/// Forwards the event to the wrapped publisher's `publish`.
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.0.publish(event)
}
}
/// [`RouterEventSink`] adapter around a [`NatsQueue`].
struct JetStreamPublisher(NatsQueue);
impl RouterEventSink for JetStreamPublisher {
/// Publishes the event on the wrapped queue under `KV_EVENT_SUBJECT`.
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
NatsQueue::publish_event(&self.0, KV_EVENT_SUBJECT, event)
}
}
/// Wraps `event` in a [`RouterEvent`] for `worker_id`, applies it to the local
/// indexer when one is present, and then publishes it to the event sink.
///
/// Both steps are best-effort: failures are logged (warn for the indexer, error
/// for the sink) and never propagated, so the calling loop keeps running.
async fn emit<P: RouterEventSink>(
    publisher: &P,
    local_indexer: &Option<Arc<LocalKvIndexer>>,
    worker_id: u64,
    event: KvCacheEvent,
) {
    let router_event = RouterEvent::new(worker_id, event);

    // The indexer takes ownership of its copy; the original is still needed
    // for the sink below.
    if let Some(indexer) = local_indexer.as_ref() {
        let applied = indexer.apply_event_with_buffer(router_event.clone()).await;
        if let Err(e) = applied {
            tracing::warn!(worker_id, error = %e, "Failed to apply event to local indexer");
        }
    }

    let published = publisher.publish_event(&router_event).await;
    if let Err(e) = published {
        tracing::error!(worker_id, error = %e, "Failed to publish event");
    }
}
impl BatchingState {
    /// Emits whatever is pending and then advances the batch ID and flush clock.
    ///
    /// Pending Removed data is emitted before pending Stored data, and both use
    /// the same publish ID and the dp_rank recorded for the batch. When nothing
    /// is pending this returns immediately without touching the ID or the
    /// clock, so callers may invoke it unconditionally.
    async fn flush<P: RouterEventSink + Send + Sync + 'static>(
        &mut self,
        publisher: &P,
        local_indexer: &Option<Arc<LocalKvIndexer>>,
        worker_id: u64,
    ) {
        if !self.has_pending() {
            return;
        }

        let event_id = self.next_publish_id;
        let dp_rank = self.last_dp_rank;

        // Drain both slots, keeping the Removed-then-Stored emission order.
        let removed = self.pending_removed.take().map(KvCacheEventData::Removed);
        let stored = self.pending_stored.take().map(KvCacheEventData::Stored);
        for data in [removed, stored].into_iter().flatten() {
            let batched = KvCacheEvent {
                event_id,
                data,
                dp_rank,
            };
            emit(publisher, local_indexer, worker_id, batched).await;
        }

        // Consecutive batch IDs (1, 2, 3, …) keep downstream gap-detection happy.
        self.next_publish_id += 1;
        // Restart the window so the next event can detect a stale batch.
        self.record_flush_time();
    }
}
/// Batching loop: accumulates Removed/Stored events and flushes them as a single
/// [`RouterEvent`] when any of the following conditions are met:
/// - Event type switches (Removed ↔ Stored)
/// - `dp_rank` changes between consecutive events
/// - A `Stored` event's `parent_hash` breaks the sequential chain
/// - The batch window expires (`Some(timeout_ms)`; `None` = disabled, flush every event)
/// - The pending batch holds more than `max_batch_blocks` blocks
/// - Channel is closed or a cancellation signal is received
///
/// `Cleared` events are never batched: any pending batch is flushed first and the
/// `Cleared` event is then emitted on its own with the next publish ID.
/// Events whose placement is not local-GPU are skipped, but they still take part
/// in the raw `event_id` gap detection.
async fn run_event_processor_loop<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
timeout_ms: Option<u64>,
max_batch_blocks: usize,
) {
let mut batching_state = BatchingState::new();
// Track last raw input event_id for gap detection (dropped events before batching).
// The raw event_id is a globally monotonic counter assigned by the ZMQ listener,
// so any gap here means events were silently dropped (e.g. send error on the channel).
let mut last_raw_input_id: Option<u64> = None;
loop {
tokio::select! {
// Shutdown requested: publish whatever is still pending, then stop.
_ = cancellation_token.cancelled() => {
tracing::info!("KV Event source received cancellation signal");
batching_state.flush(&publisher, &local_indexer, worker_id).await;
break;
}
event = rx.recv() => {
// `None` means every sender was dropped; flush and stop.
let Some(placement_event) = event else {
tracing::debug!("Event processor channel closed.");
batching_state.flush(&publisher, &local_indexer, worker_id).await;
break;
};
// Warn if the raw input event_id is not consecutive — events were dropped
// (e.g. channel send error) before they reached the batching layer.
let raw_event_id = placement_event.event.event_id;
if let Some(last_id) = last_raw_input_id
&& raw_event_id > last_id + 1
{
// Number of IDs skipped between the previous event and this one.
let gap = raw_event_id - last_id - 1;
tracing::warn!(
worker_id,
last_raw_input_id = last_id,
raw_event_id,
gap,
"Input event gap detected: raw events dropped before batching"
);
// Increment Prometheus counter for dropped events (if initialized)
if let Some(metrics) = kv_publisher_metrics() {
metrics.increment_engines_dropped_events(worker_id, gap);
} else {
tracing::warn!(
worker_id,
gap,
"Failed to record dropped events metric: metrics not initialized"
);
}
}
last_raw_input_id = Some(raw_event_id);
// Only local-GPU placements are batched and published.
if !placement_event.placement.is_local_gpu() {
tracing::trace!(
worker_id,
?placement_event.placement,
event_id = placement_event.event.event_id,
"Skipping non-local-GPU placement event"
);
continue;
}
let event = placement_event.event;
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
// A rank change only matters while something is pending for the old rank.
let dp_rank_changed = batching_state.has_pending()
&& event.dp_rank != batching_state.last_dp_rank;
match event.data {
KvCacheEventData::Removed(data) => {
// Flush on a Stored -> Removed switch or a rank change.
if batching_state.pending_stored.is_some() || dp_rank_changed {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
}
match &mut batching_state.pending_removed {
Some(pending) => pending.block_hashes.extend(data.block_hashes),
None => {
batching_state.pending_removed = Some(data);
}
}
}
KvCacheEventData::Stored(data) => {
// Flush if: type switch, dp_rank change, or the chain is broken
// (new event's parent_hash doesn't continue from the last stored block).
let should_flush = dp_rank_changed
|| batching_state.pending_removed.is_some()
|| batching_state.pending_stored.as_ref().is_some_and(|p| {
data.parent_hash != p.blocks.last().map(|b| b.block_hash)
});
if should_flush {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
}
match &mut batching_state.pending_stored {
// Only extend blocks; parent_hash stays fixed from the first event.
Some(pending) => pending.blocks.extend(data.blocks),
None => {
batching_state.pending_stored = Some(data);
}
}
}
KvCacheEventData::Cleared => {
// Flush first so the Cleared event is published after everything
// that preceded it, then emit it unbatched with its own publish ID.
batching_state.flush(&publisher, &local_indexer, worker_id).await;
emit(&publisher, &local_indexer, worker_id, KvCacheEvent {
event_id: batching_state.next_publish_id,
data: KvCacheEventData::Cleared,
dp_rank: event.dp_rank,
}).await;
batching_state.next_publish_id += 1;
}
}
// Track dp_rank after the match so in-flight flushes use the old value.
batching_state.last_dp_rank = event.dp_rank;
// Flush after every event when disabled (None), or when the window has elapsed,
// or when the batch exceeds the max block count.
// The sleep arm only arms when batching is enabled; this covers the disabled path.
if batching_state.has_pending()
&& (timeout_ms.is_none_or(|ms| batching_state.is_timeout_elapsed(ms))
|| batching_state.pending_block_count() > max_batch_blocks)
{
batching_state.flush(&publisher, &local_indexer, worker_id).await;
}
}
// if has some pending and has timeout, and no new events come in, then flush when timeout elapsed to prevent stale events
// The 3600s fallback is never actually awaited: the guard below disables
// this arm whenever `timeout_ms` is `None`.
_ = tokio::time::sleep(
timeout_ms.map(|ms| batching_state.remaining_timeout(ms)).unwrap_or(Duration::from_secs(3600))
), if timeout_ms.is_some() && batching_state.has_pending() => {
batching_state.flush(&publisher, &local_indexer, worker_id).await;
}
}
}
}
/// Batched event processor for ephemeral transports (NATS Core / ZMQ).
///
/// Thin wrapper that runs `run_event_processor_loop` with
/// `DEFAULT_MAX_BATCH_BLOCKS` as the block limit and returns when that loop exits.
async fn start_event_processor<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
batching_timeout_ms: Option<u64>,
) {
run_event_processor_loop(
publisher,
worker_id,
cancellation_token,
rx,
local_indexer,
batching_timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
}
/// Batched event processor using JetStream (durable).
///
/// Thin wrapper that runs `run_event_processor_loop` with
/// `DEFAULT_MAX_BATCH_BLOCKS` as the block limit; the body is the same as
/// `start_event_processor`, only the intended publisher differs.
async fn start_event_processor_jetstream<P: RouterEventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
rx: mpsc::UnboundedReceiver<PlacementEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
batching_timeout_ms: Option<u64>,
) {
run_event_processor_loop(
publisher,
worker_id,
cancellation_token,
rx,
local_indexer,
batching_timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
}
/// Exponential backoff delay, in milliseconds, for a run of consecutive errors.
///
/// The delay is `INITIAL_BACKOFF_MS * 2^n`, where `n` is `consecutive_errors`
/// clamped to `MAX_BACKOFF_EXPONENT`; the result is capped at `MAX_BACKOFF_MS`.
fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
    let exponent = consecutive_errors.min(MAX_BACKOFF_EXPONENT);
    let uncapped_ms = INITIAL_BACKOFF_MS * 2_u64.pow(exponent);
    uncapped_ms.min(MAX_BACKOFF_MS)
}
/// Runs a ZMQ SUB listener that forwards decoded KV events to `tx`.
///
/// The socket subscribes to `zmq_topic` and connects to `zmq_endpoint`; if either
/// step fails the error is logged and the function returns. Each received message
/// must consist of exactly three frames (`[topic, seq, payload]`) with an 8-byte
/// sequence number and a msgpack-encoded `KvEventBatch` payload; malformed messages
/// are logged and skipped. Every event in a batch is assigned an ID from the shared
/// `next_event_id` counter, converted with `convert_event`, and sent on `tx`.
///
/// The loop ends when `cancellation_token` is cancelled, when the receiving side of
/// `tx` is dropped, or after `MAX_CONSECUTIVE_ERRORS` consecutive receive errors.
pub async fn start_zmq_listener(
zmq_endpoint: String,
zmq_topic: String,
worker_id: WorkerId,
tx: mpsc::UnboundedSender<PlacementEvent>,
cancellation_token: CancellationToken,
kv_block_size: u32,
next_event_id: Arc<AtomicU64>,
) {
tracing::debug!(
"KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
zmq_endpoint,
zmq_topic
);
// Shared counter handed to `convert_event` for every event.
let warning_count = Arc::new(AtomicU32::new(0));
let mut socket = SubSocket::new();
// Subscribe to the requested topic (empty string == all topics)
if let Err(e) = socket.subscribe(&zmq_topic).await {
tracing::error!("Failed to subscribe on ZMQ socket: {}", e);
return;
}
// Connect to the ZMQ endpoint. SGLang binds locally, Dynamo connects.
// In multi-node setups, each node runs dynamo.sglang alongside local SGLang ranks,
// so ZMQ connections are always local. NATS handles cross-node event distribution.
if let Err(e) = socket.connect(&zmq_endpoint).await {
tracing::error!("Failed to connect ZMQ SUB socket to {zmq_endpoint}: {e}");
return;
}
let mut consecutive_errors = 0u32;
// Every exit path overwrites this initial value before it is read, hence the
// expected `unused_assignments` lint.
#[expect(unused_assignments)]
let mut exit_reason = "unknown";
let mut messages_processed = 0u64;
'main: loop {
tokio::select! {
// `biased` polls the arms in order, so cancellation is checked before recv.
biased;
// Check for cancellation
_ = cancellation_token.cancelled() => {
tracing::debug!("ZMQ listener received cancellation signal");
exit_reason = "cancellation token cancelled";
break 'main;
}
// Receive message
msg_result = socket.recv() => {
let Ok(msg) = msg_result else {
let e = msg_result.unwrap_err();
consecutive_errors += 1;
if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
tracing::error!(
error=%e,
consecutive_errors=%consecutive_errors,
"Too many consecutive ZMQ errors, terminating listener"
);
exit_reason = "too many consecutive errors";
break 'main;
}
// Simple exponential backoff with max exponent to prevent overflow
let backoff_ms = calculate_backoff_ms(consecutive_errors);
tracing::warn!(
error=%e,
consecutive_errors=%consecutive_errors,
backoff_ms=%backoff_ms,
"Error reading from ZMQ socket, applying exponential backoff"
);
// NOTE(review): this sleep runs inside the recv arm, so the cancellation
// branch above is not polled again until the backoff has finished.
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
continue;
};
// Reset error count on successful message
consecutive_errors = 0;
// We expect multipart frames: [topic, seq, payload]
let mut frames: Vec<Vec<u8>> = msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect();
if frames.len() != 3 {
tracing::warn!("Received unexpected ZMQ frame count: expected 3, actual {}", frames.len());
continue;
}
// Extract the payload and sequence number.
// The unwraps cannot fail: the length check above guarantees three frames.
let payload = frames.pop().unwrap();
let seq_bytes = frames.pop().unwrap();
if seq_bytes.len() != 8 {
tracing::warn!("Invalid sequence number byte length: expected 8, actual {}", seq_bytes.len());
continue;
}
// Note: We extract the engine's sequence number for logging but use our own
// internal monotonic counter for event_id to ensure per-dp_rank monotonicity
// The conversion cannot fail: the length was just checked to be 8 bytes.
let engine_seq = u64::from_be_bytes(seq_bytes.try_into().unwrap());
// Decode our batch of events.
let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
let Ok(batch) = batch_result else {
let e = batch_result.unwrap_err();
tracing::warn!("Failed to decode KVEventBatch msgpack: {e}");
continue;
};
tracing::trace!(
"ZMQ listener on {} received batch with {} events (engine_seq={}, dp_rank={})",
zmq_endpoint,
batch.events.len(),
engine_seq,
batch.data_parallel_rank.unwrap_or(0)
);
// A batch without a data-parallel rank is treated as rank 0.
let dp_rank = batch.data_parallel_rank.unwrap_or(0).cast_unsigned();
for raw_event in batch.events.into_iter() {
// Use shared monotonic event_id counter instead of engine's sequence number
let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
let event = convert_event(
raw_event,
event_id,
kv_block_size,
worker,
&warning_count,
);
// A send error means the receiver is gone; nothing more can be delivered.
if tx.send(event).is_err() {
tracing::warn!("Failed to send message to channel - receiver dropped");
exit_reason = "channel receiver dropped";
break 'main;
}
messages_processed += 1;
}
}
}
}
tracing::debug!(
"ZMQ listener exiting, reason: {}, messages processed: {}",
exit_reason,
messages_processed
);
}
// -------------------------------------------------------------------------
// Metrics Publishers ------------------------------------------------------
// -------------------------------------------------------------------------
/// Metrics data passed through the channel for NATS publishing
#[derive(Debug, Clone, Default, PartialEq)]
struct WorkerMetrics {
/// Data parallel rank the metrics were reported for.
dp_rank: DpRank,
/// Number of active decode blocks reported by the worker.
active_decode_blocks: u64,
}
/// Publishes worker load metrics.
///
/// Values pushed through `publish` are stored in a `watch` channel; a background
/// task started by `create_endpoint` observes the channel and publishes changes.
pub struct WorkerMetricsPublisher {
/// Sending half of the watch channel; written by `publish`.
tx: tokio::sync::watch::Sender<WorkerMetrics>,
/// Receiving half, kept so it can be cloned into the background publishing task.
rx: tokio::sync::watch::Receiver<WorkerMetrics>,
}
impl WorkerMetricsPublisher {
/// Creates a publisher whose watch channel starts with default (zeroed) metrics.
/// As written this never returns an error.
pub fn new() -> Result<Self> {
let (tx, rx) = tokio::sync::watch::channel(WorkerMetrics::default());
Ok(WorkerMetricsPublisher { tx, rx })
}
/// Publish worker metrics for load monitoring.
///
/// # Arguments
/// * `dp_rank` - Data parallel rank of the worker (None defaults to 0)
/// * `active_decode_blocks` - Number of active KV cache blocks
///
/// # Errors
/// Returns an error if the watch channel is closed.
pub fn publish(&self, dp_rank: Option<DpRank>, active_decode_blocks: u64) -> Result<()> {
let metrics = WorkerMetrics {
dp_rank: dp_rank.unwrap_or(0),
active_decode_blocks,
};
tracing::trace!(
"Publish metrics: dp_rank={}, active_decode_blocks={}",
metrics.dp_rank,
metrics.active_decode_blocks
);
self.tx
.send(metrics)
.map_err(|_| anyhow::anyhow!("metrics channel closed"))
}
/// Starts the background metrics publishing task for `component`'s namespace,
/// using the component's runtime connection ID as the worker ID.
/// Nothing is awaited here and the call always returns `Ok(())`.
pub async fn create_endpoint(&self, component: Component) -> Result<()> {
let worker_id = component.drt().connection_id();
self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);
Ok(())
}
/// Starts a background task to publish metrics over NATS
///
/// This task monitors metric changes (specifically active_decode_blocks)
/// and publishes stable metrics to NATS after they've been unchanged for 1ms.
fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: u64) {
let nats_rx = self.rx.clone();
tokio::spawn(async move {
// Without a publisher there is nothing this task can do; log and exit.
let event_publisher =
match EventPublisher::for_namespace(&namespace, KV_METRICS_SUBJECT).await {
Ok(publisher) => publisher,
Err(e) => {
tracing::error!("Failed to create metrics publisher: {}", e);
return;
}
};
let mut rx = nats_rx;
// Last value seen on the channel, used to ignore updates that change nothing.
let mut last_metrics: Option<WorkerMetrics> = None;
// Value waiting for the debounce timer to fire.
let mut pending_publish: Option<WorkerMetrics> = None;
let mut publish_timer =
Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(0)));
publish_timer.as_mut().reset(tokio::time::Instant::now()); // Complete immediately
loop {
tokio::select! {
// Handle metrics changes
result = rx.changed() => {
if result.is_err() {
tracing::debug!(
"Metrics publisher sender dropped, stopping NATS background task"
);
break;
}
let metrics = rx.borrow_and_update().clone();
// Check if metrics have changed
let has_changed = last_metrics.as_ref() != Some(&metrics);
// If metrics changed, schedule a publish
if has_changed {
pending_publish = Some(metrics.clone());
last_metrics = Some(metrics);
// Start the 1ms timer
// (each further change restarts it, so only the latest value is sent)
publish_timer.as_mut().reset(
tokio::time::Instant::now() + tokio::time::Duration::from_millis(1)
);
}
}
// Timer expired - publish if we have pending metrics
_ = &mut publish_timer => {
if let Some(metrics) = pending_publish.take() {
let active_load = ActiveLoad {
worker_id,
dp_rank: metrics.dp_rank,
active_decode_blocks: Some(metrics.active_decode_blocks),
active_prefill_tokens: None,
};
// Publish failures are logged and otherwise ignored.
if let Err(e) = event_publisher.publish(&active_load).await {
tracing::warn!("Failed to publish metrics: {}", e);
}
}
// Reset timer to pending state to avoid tight loop
// It will be reset to 1ms when metrics actually change
publish_timer.as_mut().reset(
tokio::time::Instant::now() + tokio::time::Duration::from_secs(3600)
);
}
}
}
});
}
}
// -------------------------------------------------------------------------
// Testing -----------------------------------------------------------------
// -------------------------------------------------------------------------
#[cfg(test)] #[cfg(test)]
mod test_event_processing { mod test_event_processing {
...@@ -1459,9 +430,8 @@ mod test_event_processing { ...@@ -1459,9 +430,8 @@ mod test_event_processing {
#[cfg(test)] #[cfg(test)]
mod tests_startup_helpers { mod tests_startup_helpers {
use super::*; use super::*;
use crate::kv_router::KvIndexer;
use bytes::Bytes; use bytes::Bytes;
use dynamo_kv_router::indexer::{GetWorkersRequest, KvIndexerInterface}; use dynamo_kv_router::indexer::{GetWorkersRequest, KvIndexer, KvIndexerInterface};
use dynamo_kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash}; use dynamo_kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use zeromq::{PubSocket, Socket, SocketSend, ZmqMessage}; use zeromq::{PubSocket, Socket, SocketSend, ZmqMessage};
...@@ -2202,6 +1172,7 @@ mod test_exponential_backoff { ...@@ -2202,6 +1172,7 @@ mod test_exponential_backoff {
#[cfg(all(test, feature = "integration"))] #[cfg(all(test, feature = "integration"))]
mod test_integration_publisher { mod test_integration_publisher {
use super::*; use super::*;
use crate::kv_router::KV_METRICS_SUBJECT;
use dynamo_kv_router::protocols::ActiveLoad; use dynamo_kv_router::protocols::ActiveLoad;
use dynamo_runtime::distributed_test_utils::create_test_drt_async; use dynamo_runtime::distributed_test_utils::create_test_drt_async;
use dynamo_runtime::transports::event_plane::EventSubscriber; use dynamo_runtime::transports::event_plane::EventSubscriber;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use dynamo_kv_router::protocols::{ActiveLoad, DpRank};
use dynamo_runtime::component::{Component, Namespace};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventPublisher;
use crate::kv_router::KV_METRICS_SUBJECT;
/// Metrics snapshot carried through the watch channel to the background
/// publishing task.
#[derive(Debug, Clone, Default, PartialEq)]
struct WorkerMetrics {
/// Data parallel rank the metrics were reported for.
dp_rank: DpRank,
/// Number of active decode blocks reported by the worker.
active_decode_blocks: u64,
}
/// Publishes worker load metrics.
///
/// Values pushed through `publish` are stored in a `watch` channel; a background
/// task started by `create_endpoint` observes the channel and publishes changes.
pub struct WorkerMetricsPublisher {
/// Sending half of the watch channel; written by `publish`.
tx: tokio::sync::watch::Sender<WorkerMetrics>,
/// Receiving half, kept so it can be cloned into the background publishing task.
rx: tokio::sync::watch::Receiver<WorkerMetrics>,
}
impl WorkerMetricsPublisher {
/// Creates a publisher whose watch channel starts with default (zeroed) metrics.
/// As written this never returns an error.
pub fn new() -> Result<Self> {
let (tx, rx) = tokio::sync::watch::channel(WorkerMetrics::default());
Ok(Self { tx, rx })
}
/// Records the latest worker metrics for the background task to publish.
///
/// # Arguments
/// * `dp_rank` - Data parallel rank of the worker (`None` defaults to 0)
/// * `active_decode_blocks` - Number of active decode blocks
///
/// # Errors
/// Returns an error if the watch channel is closed.
pub fn publish(&self, dp_rank: Option<DpRank>, active_decode_blocks: u64) -> Result<()> {
let metrics = WorkerMetrics {
dp_rank: dp_rank.unwrap_or(0),
active_decode_blocks,
};
tracing::trace!(
"Publish metrics: dp_rank={}, active_decode_blocks={}",
metrics.dp_rank,
metrics.active_decode_blocks
);
self.tx
.send(metrics)
.map_err(|_| anyhow::anyhow!("metrics channel closed"))
}
/// Starts the background metrics publishing task for `component`'s namespace,
/// using the component's runtime connection ID as the worker ID.
/// Nothing is awaited here and the call always returns `Ok(())`.
pub async fn create_endpoint(&self, component: Component) -> Result<()> {
let worker_id = component.drt().connection_id();
self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);
Ok(())
}
/// Spawns a task that publishes metric changes on `KV_METRICS_SUBJECT`.
///
/// Changes observed on the watch channel are debounced: a value is published
/// once it has stayed unchanged for 1ms. The task ends when the sending half
/// of the channel is dropped, or immediately if the event publisher cannot be
/// created.
pub(super) fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: u64) {
let nats_rx = self.rx.clone();
tokio::spawn(async move {
// Without a publisher there is nothing this task can do; log and exit.
let event_publisher =
match EventPublisher::for_namespace(&namespace, KV_METRICS_SUBJECT).await {
Ok(publisher) => publisher,
Err(e) => {
tracing::error!("Failed to create metrics publisher: {}", e);
return;
}
};
let mut rx = nats_rx;
// Last value seen on the channel, used to ignore updates that change nothing.
let mut last_metrics: Option<WorkerMetrics> = None;
// Value waiting for the debounce timer to fire.
let mut pending_publish: Option<WorkerMetrics> = None;
let mut publish_timer =
Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(0)));
// Deadline is "now", so the timer arm completes on the first pass.
publish_timer.as_mut().reset(tokio::time::Instant::now());
loop {
tokio::select! {
// A new value was written to the watch channel.
result = rx.changed() => {
if result.is_err() {
tracing::debug!(
"Metrics publisher sender dropped, stopping NATS background task"
);
break;
}
let metrics = rx.borrow_and_update().clone();
let has_changed = last_metrics.as_ref() != Some(&metrics);
if has_changed {
pending_publish = Some(metrics.clone());
last_metrics = Some(metrics);
// (Re)start the 1ms debounce window; further changes restart it,
// so only the latest value is sent.
publish_timer.as_mut().reset(
tokio::time::Instant::now()
+ tokio::time::Duration::from_millis(1)
);
}
}
// Debounce window elapsed: publish the pending value, if any.
_ = &mut publish_timer => {
if let Some(metrics) = pending_publish.take() {
let active_load = ActiveLoad {
worker_id,
dp_rank: metrics.dp_rank,
active_decode_blocks: Some(metrics.active_decode_blocks),
active_prefill_tokens: None,
};
// Publish failures are logged and otherwise ignored.
if let Err(e) = event_publisher.publish(&active_load).await {
tracing::warn!("Failed to publish metrics: {}", e);
}
}
// Park the timer far in the future so this arm does not spin;
// it is re-armed to 1ms when the metrics change again.
publish_timer.as_mut().reset(
tokio::time::Instant::now()
+ tokio::time::Duration::from_secs(3600)
);
}
}
}
});
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
use rmp_serde as rmps;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use dynamo_kv_router::protocols::*;
use dynamo_kv_router::zmq_wire::*;
/// Base delay, in milliseconds, multiplied by a power of two in `calculate_backoff_ms`.
pub(super) const INITIAL_BACKOFF_MS: u64 = 10;
/// Upper bound, in milliseconds, on the delay returned by `calculate_backoff_ms`.
pub(super) const MAX_BACKOFF_MS: u64 = 5000;
/// Number of consecutive receive errors after which the ZMQ listener terminates.
pub(super) const MAX_CONSECUTIVE_ERRORS: u32 = 10;
/// Largest exponent used by `calculate_backoff_ms`; higher error counts are clamped to it.
pub(super) const MAX_BACKOFF_EXPONENT: u32 = 8;
/// Exponential backoff delay, in milliseconds, for a run of consecutive errors.
///
/// The delay is `INITIAL_BACKOFF_MS * 2^n`, where `n` is `consecutive_errors`
/// clamped to `MAX_BACKOFF_EXPONENT`; the result is capped at `MAX_BACKOFF_MS`.
pub(super) fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
    let exponent = consecutive_errors.min(MAX_BACKOFF_EXPONENT);
    let uncapped_ms = INITIAL_BACKOFF_MS * 2_u64.pow(exponent);
    uncapped_ms.min(MAX_BACKOFF_MS)
}
/// Listens on a ZMQ SUB socket and forwards every decoded KV event to `tx`.
///
/// The task subscribes to `zmq_topic`, connects to `zmq_endpoint`, and then loops until one of
/// the following happens:
/// - `cancellation_token` is cancelled,
/// - `MAX_CONSECUTIVE_ERRORS` receives fail in a row, or
/// - sending on `tx` fails because the receiver was dropped.
///
/// Each message must consist of exactly three frames. The second frame is an 8-byte big-endian
/// sequence number (used only for trace logging) and the third frame is a msgpack-encoded
/// `KvEventBatch`. The first frame is not inspected (presumably the topic; confirm against the
/// publisher). Malformed messages are logged and skipped.
///
/// Every event in a batch is tagged with the next value of `next_event_id`, converted with
/// `convert_event`, and sent on `tx`. A failure to subscribe or connect is logged and ends the
/// task immediately.
pub(super) async fn start_zmq_listener(
    zmq_endpoint: String,
    zmq_topic: String,
    worker_id: WorkerId,
    tx: mpsc::UnboundedSender<PlacementEvent>,
    cancellation_token: CancellationToken,
    kv_block_size: u32,
    next_event_id: Arc<AtomicU64>,
) {
    tracing::debug!(
        "KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
        zmq_endpoint,
        zmq_topic
    );

    // Shared counter handed to `convert_event` for every event.
    let warning_count = Arc::new(AtomicU32::new(0));

    let mut socket = SubSocket::new();
    if let Err(e) = socket.subscribe(&zmq_topic).await {
        tracing::error!("Failed to subscribe on ZMQ socket: {}", e);
        return;
    }
    if let Err(e) = socket.connect(&zmq_endpoint).await {
        tracing::error!("Failed to connect ZMQ SUB socket to {zmq_endpoint}: {e}");
        return;
    }

    let mut consecutive_errors = 0u32;
    let mut messages_processed = 0u64;

    // The loop evaluates to the reason it stopped; every exit path supplies one.
    let exit_reason = 'main: loop {
        tokio::select! {
            // Poll the cancellation branch first so shutdown wins over pending messages.
            biased;

            _ = cancellation_token.cancelled() => {
                tracing::debug!("ZMQ listener received cancellation signal");
                break 'main "cancellation token cancelled";
            }

            msg_result = socket.recv() => {
                let msg = match msg_result {
                    Ok(msg) => msg,
                    Err(e) => {
                        consecutive_errors += 1;
                        if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
                            tracing::error!(
                                error=%e,
                                consecutive_errors=%consecutive_errors,
                                "Too many consecutive ZMQ errors, terminating listener"
                            );
                            break 'main "too many consecutive errors";
                        }

                        let backoff_ms = calculate_backoff_ms(consecutive_errors);
                        tracing::warn!(
                            error=%e,
                            consecutive_errors=%consecutive_errors,
                            backoff_ms=%backoff_ms,
                            "Error reading from ZMQ socket, applying exponential backoff"
                        );
                        tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
                        continue 'main;
                    }
                };

                // A successful receive ends the current error streak.
                consecutive_errors = 0;

                let frames: Vec<Vec<u8>> =
                    msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect();

                // Exactly three frames are required; the conversion hands the Vec back on mismatch.
                let [_first_frame, seq_bytes, payload] = match <[Vec<u8>; 3]>::try_from(frames) {
                    Ok(parts) => parts,
                    Err(unexpected) => {
                        tracing::warn!(
                            "Received unexpected ZMQ frame count: expected 3, actual {}",
                            unexpected.len()
                        );
                        continue 'main;
                    }
                };

                // The sequence number is a big-endian u64, so the frame must be 8 bytes long.
                let engine_seq = match <[u8; 8]>::try_from(seq_bytes) {
                    Ok(raw) => u64::from_be_bytes(raw),
                    Err(bad_bytes) => {
                        tracing::warn!(
                            "Invalid sequence number byte length: expected 8, actual {}",
                            bad_bytes.len()
                        );
                        continue 'main;
                    }
                };

                let batch = match rmps::from_slice::<KvEventBatch>(&payload) {
                    Ok(batch) => batch,
                    Err(e) => {
                        tracing::warn!("Failed to decode KVEventBatch msgpack: {e}");
                        continue 'main;
                    }
                };

                // A missing data-parallel rank is treated as rank 0.
                let raw_dp_rank = batch.data_parallel_rank.unwrap_or(0);
                tracing::trace!(
                    "ZMQ listener on {} received batch with {} events (engine_seq={}, dp_rank={})",
                    zmq_endpoint,
                    batch.events.len(),
                    engine_seq,
                    raw_dp_rank
                );
                let dp_rank = raw_dp_rank.cast_unsigned();

                for raw_event in batch.events {
                    let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
                    let placement_event = convert_event(
                        raw_event,
                        event_id,
                        kv_block_size,
                        WorkerWithDpRank::new(worker_id, dp_rank),
                        &warning_count,
                    );
                    if tx.send(placement_event).is_err() {
                        tracing::warn!("Failed to send message to channel - receiver dropped");
                        break 'main "channel receiver dropped";
                    }
                    messages_processed += 1;
                }
            }
        }
    };

    tracing::debug!(
        "ZMQ listener exiting, reason: {}, messages processed: {}",
        exit_reason,
        messages_processed
    );
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use futures::StreamExt;
use dynamo_runtime::{
component::Component,
pipeline::{ManyOut, RouterMode, SingleIn, network::egress::push_router::PushRouter},
};
use dynamo_kv_router::{
indexer::{IndexerQueryRequest, IndexerQueryResponse, KV_INDEXER_QUERY_ENDPOINT},
protocols::{LocalBlockHash, OverlapScores},
};
/// A remote indexer that queries a standalone KV indexer via the request plane.
///
/// Used by the frontend when `remote_indexer_component` is configured. Instead of
/// maintaining a local radix tree, this forwards `find_matches` queries to the
/// standalone indexer service over the Dynamo request plane.
pub struct RemoteIndexer {
    /// Push router used by `find_matches` to send each query to the indexer endpoint.
    router: PushRouter<IndexerQueryRequest, IndexerQueryResponse>,
    /// Model name copied into every `IndexerQueryRequest`.
    model_name: String,
    /// Namespace name copied into every `IndexerQueryRequest`; taken from the component
    /// passed to `new`.
    namespace: String,
}
impl RemoteIndexer {
pub async fn new(
component: &Component,
indexer_component_name: &str,
model_name: String,
) -> Result<Self> {
let namespace = component.namespace().name();
let indexer_ns = component.namespace();
let indexer_component = indexer_ns.component(indexer_component_name)?;
let endpoint = indexer_component.endpoint(KV_INDEXER_QUERY_ENDPOINT);
let client = endpoint.client().await?;
let router =
PushRouter::from_client_no_fault_detection(client, RouterMode::RoundRobin).await?;
Ok(Self {
router,
model_name,
namespace,
})
}
pub async fn find_matches(&self, block_hashes: Vec<LocalBlockHash>) -> Result<OverlapScores> {
let request = IndexerQueryRequest {
model_name: self.model_name.clone(),
namespace: self.namespace.clone(),
block_hashes,
};
let mut stream: ManyOut<IndexerQueryResponse> =
self.router.round_robin(SingleIn::new(request)).await?;
match stream.next().await {
Some(IndexerQueryResponse::Scores(scores)) => Ok(scores.into()),
Some(IndexerQueryResponse::Error(msg)) => {
Err(anyhow::anyhow!("Remote indexer error: {}", msg))
}
None => Err(anyhow::anyhow!("Remote indexer returned empty response")),
}
}
}
...@@ -81,6 +81,7 @@ where ...@@ -81,6 +81,7 @@ where
block_size, block_size,
selector, selector,
policy, policy,
kv_router_config.router_track_prefill_tokens,
component.drt().child_token(), component.drt().child_token(),
worker_type, worker_type,
watch_worker_configs, watch_worker_configs,
...@@ -180,9 +181,10 @@ where ...@@ -180,9 +181,10 @@ where
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
isl_tokens: usize, isl_tokens: usize,
overlaps: OverlapScores, overlaps: OverlapScores,
track_prefill_tokens: bool,
) -> Vec<PotentialLoad> { ) -> Vec<PotentialLoad> {
self.inner self.inner
.get_potential_loads(token_seq, isl_tokens, overlaps) .get_potential_loads(token_seq, isl_tokens, overlaps, track_prefill_tokens)
} }
pub fn get_active_lora_counts(&self) -> HashMap<String, usize> { pub fn get_active_lora_counts(&self) -> HashMap<String, usize> {
......
...@@ -223,6 +223,7 @@ mod tests { ...@@ -223,6 +223,7 @@ mod tests {
token_sequence: Some(vec![0, 1, 2]), token_sequence: Some(vec![0, 1, 2]),
isl: 12, isl: 12,
overlap: 0, overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
worker: WorkerWithDpRank::new(0, 0), worker: WorkerWithDpRank::new(0, 0),
lora_name: None, lora_name: None,
...@@ -235,6 +236,7 @@ mod tests { ...@@ -235,6 +236,7 @@ mod tests {
token_sequence: Some(vec![3, 4]), token_sequence: Some(vec![3, 4]),
isl: 8, isl: 8,
overlap: 0, overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
worker: WorkerWithDpRank::new(0, 1), worker: WorkerWithDpRank::new(0, 1),
lora_name: None, lora_name: None,
...@@ -247,6 +249,7 @@ mod tests { ...@@ -247,6 +249,7 @@ mod tests {
token_sequence: Some(vec![0, 1, 2, 3]), token_sequence: Some(vec![0, 1, 2, 3]),
isl: 16, isl: 16,
overlap: 0, overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
worker: WorkerWithDpRank::new(1, 0), worker: WorkerWithDpRank::new(1, 0),
lora_name: None, lora_name: None,
...@@ -373,6 +376,7 @@ mod tests { ...@@ -373,6 +376,7 @@ mod tests {
token_sequence: None, token_sequence: None,
isl: 12, isl: 12,
overlap: 0, overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(0), worker: WorkerWithDpRank::from_worker_id(0),
lora_name: None, lora_name: None,
...@@ -385,6 +389,7 @@ mod tests { ...@@ -385,6 +389,7 @@ mod tests {
token_sequence: None, token_sequence: None,
isl: 8, isl: 8,
overlap: 0, overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(1), worker: WorkerWithDpRank::from_worker_id(1),
lora_name: None, lora_name: None,
...@@ -397,6 +402,7 @@ mod tests { ...@@ -397,6 +402,7 @@ mod tests {
token_sequence: None, token_sequence: None,
isl: 16, isl: 16,
overlap: 0, overlap: 0,
track_prefill_tokens: true,
expected_output_tokens: None, expected_output_tokens: None,
worker: WorkerWithDpRank::from_worker_id(2), worker: WorkerWithDpRank::from_worker_id(2),
lora_name: None, lora_name: None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment