Unverified commit 6783bdca, authored by Yan Ru Pei and committed by GitHub
Browse files

chore: enable local indexers by default, and use normal event plane by default...


chore: enable local indexers by default, and use normal event plane by default (not jetstream) (#5941)
Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 3d7182b8
...@@ -126,7 +126,7 @@ impl ZmqKvEventPublisherConfig { ...@@ -126,7 +126,7 @@ impl ZmqKvEventPublisherConfig {
kv_block_size, kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(), zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string(), zmq_topic = "".to_string(),
enable_local_indexer = false, enable_local_indexer = true,
dp_rank = 0 dp_rank = 0
))] ))]
pub fn new( pub fn new(
......
...@@ -419,6 +419,15 @@ class ModelDeploymentCard: ...@@ -419,6 +419,15 @@ class ModelDeploymentCard:
A model deployment card is a collection of model information A model deployment card is a collection of model information
""" """
def to_json_str(self) -> str:
"""Serialize the model deployment card to a JSON string."""
...
@staticmethod
def from_json_str(json: str) -> "ModelDeploymentCard":
"""Deserialize a model deployment card from a JSON string."""
...
... ...
class ModelRuntimeConfig: class ModelRuntimeConfig:
...@@ -737,7 +746,7 @@ class ZmqKvEventPublisherConfig: ...@@ -737,7 +746,7 @@ class ZmqKvEventPublisherConfig:
kv_block_size: int, kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557", zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = "", zmq_topic: str = "",
enable_local_indexer: bool = False, enable_local_indexer: bool = True,
dp_rank: int = 0 dp_rank: int = 0
) -> None: ) -> None:
""" """
...@@ -747,7 +756,7 @@ class ZmqKvEventPublisherConfig: ...@@ -747,7 +756,7 @@ class ZmqKvEventPublisherConfig:
:param kv_block_size: The block size for the key-value store. :param kv_block_size: The block size for the key-value store.
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557". :param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string. :param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False. :param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to True.
:param dp_rank: The data parallel rank for this publisher. Defaults to 0. :param dp_rank: The data parallel rank for this publisher. Defaults to 0.
""" """
... ...
...@@ -924,10 +933,34 @@ class ModelType: ...@@ -924,10 +933,34 @@ class ModelType:
class RouterMode: class RouterMode:
"""Router mode for load balancing requests across workers""" """Router mode for load balancing requests across workers"""
RoundRobin: "RouterMode"
Random: "RouterMode"
KV: "RouterMode"
... ...
class RouterConfig: class RouterConfig:
"""How to route the request""" """How to route the request"""
def __init__(
self,
mode: RouterMode,
config: Optional[KvRouterConfig] = None,
active_decode_blocks_threshold: Optional[float] = None,
active_prefill_tokens_threshold: Optional[int] = None,
active_prefill_tokens_threshold_frac: Optional[float] = None,
enforce_disagg: bool = False,
) -> None:
"""
Create a RouterConfig.
Args:
mode: The router mode (RoundRobin, Random, or KV)
config: Optional KV router configuration (used when mode is KV)
active_decode_blocks_threshold: Threshold percentage (0.0-1.0) for decode blocks busy detection
active_prefill_tokens_threshold: Literal token count threshold for prefill busy detection
active_prefill_tokens_threshold_frac: Fraction of max_num_batched_tokens for busy detection
enforce_disagg: Enforce disaggregated prefill-decode mode
"""
... ...
class KvRouterConfig: class KvRouterConfig:
...@@ -938,6 +971,7 @@ class KvRouterConfig: ...@@ -938,6 +971,7 @@ class KvRouterConfig:
overlap_score_weight: float = 1.0, overlap_score_weight: float = 1.0,
router_temperature: float = 0.0, router_temperature: float = 0.0,
use_kv_events: bool = True, use_kv_events: bool = True,
durable_kv_events: bool = False,
router_replica_sync: bool = False, router_replica_sync: bool = False,
router_track_active_blocks: bool = True, router_track_active_blocks: bool = True,
router_track_output_blocks: bool = False, router_track_output_blocks: bool = False,
...@@ -955,6 +989,9 @@ class KvRouterConfig: ...@@ -955,6 +989,9 @@ class KvRouterConfig:
overlap_score_weight: Weight for overlap score in worker selection (default: 1.0) overlap_score_weight: Weight for overlap score in worker selection (default: 1.0)
router_temperature: Temperature for worker sampling via softmax (default: 0.0) router_temperature: Temperature for worker sampling via softmax (default: 0.0)
use_kv_events: Whether to use KV events from workers (default: True) use_kv_events: Whether to use KV events from workers (default: True)
durable_kv_events: Enable durable KV events using NATS JetStream (default: False).
When False, uses NATS Core / generic event plane with local_indexer mode.
When True, uses JetStream for durability and multi-replica consistency.
router_replica_sync: Enable replica synchronization (default: False) router_replica_sync: Enable replica synchronization (default: False)
router_track_active_blocks: Track active blocks for load balancing (default: True) router_track_active_blocks: Track active blocks for load balancing (default: True)
router_track_output_blocks: Track output blocks during generation (default: False). router_track_output_blocks: Track output blocks during generation (default: False).
...@@ -1026,7 +1063,7 @@ class EngineConfig: ...@@ -1026,7 +1063,7 @@ class EngineConfig:
"""Holds internal configuration for a Dynamo engine.""" """Holds internal configuration for a Dynamo engine."""
... ...
async def make_engine(args: EntrypointArgs) -> EngineConfig: async def make_engine(distributed_runtime: DistributedRuntime, args: EntrypointArgs) -> EngineConfig:
"""Make an engine matching the args""" """Make an engine matching the args"""
... ...
...@@ -1435,12 +1472,63 @@ class KvPushRouter: ...@@ -1435,12 +1472,63 @@ class KvPushRouter:
""" """
... ...
class EngineType:
"""Engine type for Dynamo workers"""
Echo: "EngineType"
Dynamic: "EngineType"
Mocker: "EngineType"
...
class EntrypointArgs: class EntrypointArgs:
""" """
Settings to connect an input to a worker and run them. Settings to connect an input to a worker and run them.
Use by `dynamo run`. Use by `dynamo run`.
""" """
def __init__(
self,
engine_type: "EngineType",
model_path: Optional[str] = None,
model_name: Optional[str] = None,
endpoint_id: Optional[str] = None,
context_length: Optional[int] = None,
template_file: Optional[str] = None,
router_config: Optional[RouterConfig] = None,
kv_cache_block_size: Optional[int] = None,
http_host: Optional[str] = None,
http_port: Optional[int] = None,
http_metrics_port: Optional[int] = None,
tls_cert_path: Optional[str] = None,
tls_key_path: Optional[str] = None,
extra_engine_args: Optional[str] = None,
namespace: Optional[str] = None,
is_prefill: bool = False,
migration_limit: int = 0,
engine_factory: Optional[Callable] = None,
) -> None:
"""
Create EntrypointArgs.
Args:
engine_type: The type of engine to use
model_path: Path to the model directory on disk
model_name: Model name or dynamo endpoint (e.g. 'dyn://namespace.component.endpoint')
endpoint_id: Optional endpoint ID
context_length: Optional context length override
template_file: Optional path to a prompt template file
router_config: Optional router configuration
kv_cache_block_size: Optional KV cache block size
http_host: HTTP host to bind to
http_port: HTTP port to bind to
http_metrics_port: HTTP metrics port (for gRPC service)
tls_cert_path: TLS certificate path (PEM format)
tls_key_path: TLS key path (PEM format)
extra_engine_args: Path to extra engine arguments file
namespace: Dynamo namespace for model discovery scoping
is_prefill: Whether this is a prefill worker
migration_limit: Maximum number of request migrations (0=disabled)
engine_factory: Optional Python engine factory callback
"""
... ...
class PlannerDecision: class PlannerDecision:
......
...@@ -5,7 +5,7 @@ mod model_manager; ...@@ -5,7 +5,7 @@ mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError}; pub use model_manager::{ModelManager, ModelManagerError};
pub(crate) mod runtime_configs; pub(crate) mod runtime_configs;
pub use runtime_configs::{RuntimeConfigs, RuntimeConfigsSubscriber}; pub use runtime_configs::{RuntimeConfigWatch, runtime_config_watch};
mod watcher; mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher}; pub use watcher::{ModelUpdate, ModelWatcher};
......
...@@ -11,7 +11,7 @@ use parking_lot::RwLock; ...@@ -11,7 +11,7 @@ use parking_lot::RwLock;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig; use super::worker_monitor::LoadThresholdConfig;
use super::{KvWorkerMonitor, RuntimeConfigs}; use super::{KvWorkerMonitor, RuntimeConfigWatch, runtime_config_watch};
use dynamo_runtime::{ use dynamo_runtime::{
component::{Client, Endpoint, build_transport_type}, component::{Client, Endpoint, build_transport_type},
...@@ -77,7 +77,7 @@ pub struct ModelManager { ...@@ -77,7 +77,7 @@ pub struct ModelManager {
// Per-model monitoring: worker_monitors for load-based rejection, runtime_configs for KvScheduler // Per-model monitoring: worker_monitors for load-based rejection, runtime_configs for KvScheduler
worker_monitors: DashMap<String, KvWorkerMonitor>, worker_monitors: DashMap<String, KvWorkerMonitor>,
runtime_configs: DashMap<EndpointId, Arc<RuntimeConfigs>>, runtime_configs: DashMap<EndpointId, RuntimeConfigWatch>,
} }
impl Default for ModelManager { impl Default for ModelManager {
...@@ -563,12 +563,12 @@ impl ModelManager { ...@@ -563,12 +563,12 @@ impl ModelManager {
} }
/// Get or create a runtime config watcher for an endpoint. /// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task to watch for worker config changes. /// Spawns a background task that joins instance availability and config discovery.
/// Returns a shared RuntimeConfigs that KvScheduler can use directly. /// Returns a `watch::Receiver` with the latest `HashMap<WorkerId, ModelRuntimeConfig>`.
pub async fn get_or_create_runtime_config_watcher( pub async fn get_or_create_runtime_config_watcher(
&self, &self,
endpoint: &Endpoint, endpoint: &Endpoint,
) -> anyhow::Result<Arc<RuntimeConfigs>> { ) -> anyhow::Result<RuntimeConfigWatch> {
let endpoint_id = endpoint.id(); let endpoint_id = endpoint.id();
// Fast path: return existing if present // Fast path: return existing if present
...@@ -576,21 +576,18 @@ impl ModelManager { ...@@ -576,21 +576,18 @@ impl ModelManager {
return Ok(existing.clone()); return Ok(existing.clone());
} }
// Atomic get-or-insert to avoid TOCTOU race // Slow path: create the watch (spawns a background task).
let inner = Arc::new(RuntimeConfigs::new()); // If another caller raced us, the entry() below picks up the winner;
let (result, is_new) = match self.runtime_configs.entry(endpoint_id) { // the loser's background task stops once its receivers are dropped.
Entry::Occupied(e) => (e.get().clone(), false), let rx = runtime_config_watch(endpoint).await?;
let result = match self.runtime_configs.entry(endpoint_id) {
Entry::Occupied(e) => e.get().clone(),
Entry::Vacant(e) => { Entry::Vacant(e) => {
e.insert(inner.clone()); e.insert(rx.clone());
(inner, true) rx
} }
}; };
// Only spawn watcher if we were the one who inserted
if is_new {
result.start_watcher(endpoint).await?;
}
Ok(result) Ok(result)
} }
...@@ -601,9 +598,9 @@ impl ModelManager { ...@@ -601,9 +598,9 @@ impl ModelManager {
endpoint_id: &EndpointId, endpoint_id: &EndpointId,
worker_id: WorkerId, worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> { ) -> Option<DisaggregatedEndpoint> {
let inner = self.runtime_configs.get(endpoint_id)?; let rx = self.runtime_configs.get(endpoint_id)?;
let config_ref = inner.configs.get(&worker_id)?; let configs = rx.borrow();
config_ref.as_ref()?.disaggregated_endpoint.clone() configs.get(&worker_id)?.disaggregated_endpoint.clone()
} }
/// Lists all models with worker monitors configured. /// Lists all models with worker monitors configured.
......
...@@ -2,9 +2,7 @@ ...@@ -2,9 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::watch; use tokio::sync::watch;
use dynamo_runtime::component::Endpoint; use dynamo_runtime::component::Endpoint;
...@@ -15,186 +13,72 @@ use crate::kv_router::protocols::WorkerId; ...@@ -15,186 +13,72 @@ use crate::kv_router::protocols::WorkerId;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_card::ModelDeploymentCard; use crate::model_card::ModelDeploymentCard;
/// Runtime configs for an endpoint with watch-based change notifications. /// Type alias for the runtime config watch receiver.
/// Call `subscribe()` to get a subscriber with its own watch receiver. pub type RuntimeConfigWatch = watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>;
pub struct RuntimeConfigs {
pub configs: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
change_tx: watch::Sender<u64>,
}
impl RuntimeConfigs {
pub(crate) fn new() -> Self {
let (change_tx, _) = watch::channel(0u64);
Self {
configs: Arc::new(DashMap::new()),
change_tx,
}
}
/// Create a subscriber that can wait for config changes.
/// Each subscriber has its own watch receiver, so notifications are not lost.
pub fn subscribe(&self) -> RuntimeConfigsSubscriber {
RuntimeConfigsSubscriber {
configs: self.configs.clone(),
change_rx: self.change_tx.subscribe(),
}
}
/// Notify all subscribers of a change (internal use only).
fn notify_change(&self) {
// Increment counter to notify subscribers
self.change_tx.send_modify(|v| *v = v.wrapping_add(1));
}
/// Returns the number of workers in the configs.
pub fn num_workers(&self) -> usize {
self.configs.len()
}
/// Update configs with new worker instances and their configs.
/// Notifies subscribers if a config with Some value is added or a worker is removed.
pub(crate) fn update(
&self,
new_instance_ids: &[WorkerId],
new_configs: &HashMap<WorkerId, ModelRuntimeConfig>,
) {
// First, remove workers that no longer exist
let current_workers: HashSet<WorkerId> = self.configs.iter().map(|r| *r.key()).collect();
let new_workers: HashSet<WorkerId> = new_instance_ids.iter().copied().collect();
let mut worker_removed = false;
for removed_worker in current_workers.difference(&new_workers) {
self.configs.remove(removed_worker);
worker_removed = true;
}
// Then, add/update workers /// Join instance availability and config discovery into a single watch.
// Track if any config became Some (for notify) ///
let mut config_added = false; /// Only includes workers that have BOTH an instance registration AND a runtime config.
for worker_id in new_instance_ids { /// Spawns a background task that recomputes the joined state whenever either source changes.
let config = new_configs.get(worker_id).cloned(); /// The returned `watch::Receiver` always contains the latest joined snapshot.
if config.is_some() { pub async fn runtime_config_watch(endpoint: &Endpoint) -> anyhow::Result<RuntimeConfigWatch> {
let prev_config = self.configs.get(worker_id);
let was_none = prev_config
.as_ref()
.map(|r| r.value().is_none())
.unwrap_or(true);
if was_none {
tracing::info!("RuntimeConfigs: config found for worker_id: {worker_id}");
config_added = true;
}
}
self.configs.insert(*worker_id, config);
}
// Notify when a config with Some value is added OR a worker is removed
if config_added || worker_removed {
self.notify_change();
}
}
/// Spawn background task to watch runtime configs via discovery.
/// Does not block - consumers should use `subscribe().wait_for_some()` if they need workers.
pub(crate) async fn start_watcher(self: &Arc<Self>, endpoint: &Endpoint) -> anyhow::Result<()> {
let component = endpoint.component(); let component = endpoint.component();
let cancellation_token = component.drt().primary_token(); let cancel_token = component.drt().primary_token();
// Set up discovery watch for EndpointModels // Source 1: instance availability (watches DiscoveryQuery::Endpoint)
let client = endpoint.client().await?;
let mut instance_ids_rx = client.instance_avail_watcher();
// Source 2: runtime configs from discovery (watches DiscoveryQuery::EndpointModels)
let discovery = component.drt().discovery(); let discovery = component.drt().discovery();
let endpoint_id = endpoint.id(); let eid = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels { let stream = discovery
namespace: endpoint_id.namespace.clone(), .list_and_watch(
component: endpoint_id.component.clone(), DiscoveryQuery::EndpointModels {
endpoint: endpoint_id.name.clone(), namespace: eid.namespace.clone(),
}; component: eid.component.clone(),
let discovery_stream = discovery endpoint: eid.name.clone(),
.list_and_watch(discovery_key.clone(), Some(cancellation_token.clone())) },
Some(cancel_token.clone()),
)
.await?; .await?;
let mut configs_rx =
watch_and_extract_field(stream, |card: ModelDeploymentCard| card.runtime_config);
// Extract runtime_config from ModelDeploymentCard let (tx, rx) = watch::channel(HashMap::new());
let mut runtime_configs_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
card.runtime_config
});
// Also watch instance IDs
let client = endpoint.client().await?;
let mut instance_ids_rx = client.instance_avail_watcher();
// Spawn background task to watch for config changes
// Note: We don't block here - consumers should wait on notify for configs they need
let inner = self.clone();
let cancel_token = cancellation_token.clone();
tokio::spawn(async move { tokio::spawn(async move {
tracing::trace!("RuntimeConfigs watcher started");
loop { loop {
// Wait for either instances or configs to change
tokio::select! { tokio::select! {
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => break,
tracing::trace!("RuntimeConfigs watcher shutting down"); result = instance_ids_rx.changed() => { if result.is_err() { break; } }
break; result = configs_rx.changed() => { if result.is_err() { break; } }
}
result = instance_ids_rx.changed() => {
if result.is_err() {
tracing::warn!("instance IDs watch sender shutdown");
break;
}
}
result = runtime_configs_rx.changed() => {
if result.is_err() {
tracing::warn!("runtime configs watch sender shutdown");
break;
}
}
}
// Get the latest values from both channels
let new_instance_ids = instance_ids_rx.borrow_and_update().clone();
let new_configs = runtime_configs_rx.borrow_and_update().clone();
inner.update(&new_instance_ids, &new_configs);
tracing::trace!(
"RuntimeConfigs: Updated with {} workers",
inner.configs.len()
);
} }
tracing::trace!("RuntimeConfigs watcher stopped");
});
Ok(())
}
}
/// A subscriber to runtime config changes. let instances: HashSet<WorkerId> = instance_ids_rx
/// Each subscriber has its own watch receiver, ensuring no notifications are lost. .borrow_and_update()
pub struct RuntimeConfigsSubscriber {
pub configs: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
pub change_rx: watch::Receiver<u64>,
}
impl RuntimeConfigsSubscriber {
/// Wait until at least one worker has a Some config.
/// Returns the list of worker IDs that have configs.
/// This is race-safe: checks the DashMap first, only waits if empty.
/// Returns empty vec if the sender is dropped (shutdown).
pub async fn wait_for_some(&mut self) -> Vec<WorkerId> {
loop {
let ready: Vec<WorkerId> = self
.configs
.iter() .iter()
.filter(|r| r.value().is_some()) .copied()
.map(|r| *r.key())
.collect(); .collect();
let configs = configs_rx.borrow_and_update().clone();
if !ready.is_empty() { let ready: HashMap<WorkerId, ModelRuntimeConfig> = instances
return ready; .into_iter()
} .filter_map(|id| configs.get(&id).map(|cfg| (id, cfg.clone())))
.collect();
// If sender dropped (shutdown), return empty rather than loop forever // Only send if the joined result actually changed, to avoid waking
if self.change_rx.changed().await.is_err() { // downstream consumers (wait_for, changed) on no-op recomputations.
tracing::warn!("RuntimeConfigsSubscriber: sender dropped during wait_for_some"); if *tx.borrow() == ready {
return vec![]; continue;
} }
// Break if all receivers dropped (e.g., TOCTOU in model_manager discards a duplicate).
if tx.send(ready).is_err() {
break;
} }
} }
});
Ok(rx)
} }
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use rand::Rng;
use serde::{Deserialize, Serialize};
use validator::Validate;
use crate::kv_router::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
/// Override configuration for router settings that can be specified per-request
///
/// Every field is optional. NOTE(review): `None` presumably means "keep the
/// router-wide `KvRouterConfig` value" — confirm against the code that
/// consumes this override.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)]
pub struct RouterConfigOverride {
/// Per-request value for the overlap score weight.
/// Unlike `KvRouterConfig::overlap_score_weight`, no range validation is
/// attached to this field.
#[builder(default)]
pub overlap_score_weight: Option<f64>,
/// Per-request value for the router temperature; validated to be >= 0.0.
#[builder(default)]
#[validate(range(min = 0.0))]
pub router_temperature: Option<f64>,
}
/// KV Router configuration parameters
///
/// The "default" values quoted on the fields below are the ones produced by
/// this type's `Default` implementation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)]
pub struct KvRouterConfig {
/// Weight for overlap score in worker selection (default: 1.0).
/// Must be >= 0.0. A weight that is not strictly positive makes
/// `should_subscribe_to_kv_events` return false.
#[validate(range(min = 0.0))]
pub overlap_score_weight: f64,
/// Temperature for worker sampling via softmax (default: 0.0). Must be >= 0.0.
#[validate(range(min = 0.0))]
pub router_temperature: f64,
/// Whether to use KV events from workers (default: true).
/// Several settings below only apply when this is false.
pub use_kv_events: bool,
/// Enable durable KV events using NATS JetStream instead of the default event plane.
/// When false (default), the router uses the event-plane subscriber and requires
/// workers to have local_indexer enabled for gap recovery.
/// When true, uses JetStream for durability and multi-replica consistency.
pub durable_kv_events: bool,
/// Enable replica synchronization (default: false)
pub router_replica_sync: bool,
/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,
/// Whether to track output blocks during generation (default: false)
/// When enabled, the router adds placeholder blocks as tokens are generated
/// and applies fractional decay based on progress toward expected_output_tokens.
pub router_track_output_blocks: bool,
/// Whether to assume KV cache reuse when tracking active blocks (default: true).
/// When true, computes actual block hashes for sequence tracking.
/// When false, generates random hashes (assuming no KV cache reuse).
pub router_assume_kv_reuse: bool,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
/// Must be >= 1 when set (default: Some(1000000)).
#[validate(range(min = 1))]
pub router_snapshot_threshold: Option<u32>,
/// Whether to reset the router state on startup (default: false)
pub router_reset_states: bool,
/// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
#[validate(range(min = 0.0))]
pub router_ttl_secs: f64,
/// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
#[validate(range(min = 1))]
pub router_max_tree_size: usize,
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
#[validate(range(min = 0.0, max = 1.0))]
pub router_prune_target_ratio: f64,
}
impl Default for KvRouterConfig {
    /// Baseline router settings: KV events enabled over the non-durable event
    /// plane, active-block tracking on, output-block tracking off.
    fn default() -> Self {
        Self {
            // Worker scoring and sampling.
            overlap_score_weight: 1.0,
            router_temperature: 0.0,

            // KV event transport: NATS Core (local indexer mode) unless
            // durability is explicitly requested.
            use_kv_events: true,
            durable_kv_events: false,
            router_replica_sync: false,

            // Block tracking.
            router_track_active_blocks: true,
            router_track_output_blocks: false,
            router_assume_kv_reuse: true,

            // Snapshot / reset behaviour.
            router_snapshot_threshold: Some(1_000_000),
            router_reset_states: false,

            // Settings documented as applying only when `use_kv_events` is false.
            router_ttl_secs: 120.0,
            // 2^20 = 1048576, matches PruneConfig::default()
            router_max_tree_size: 1 << 20,
            router_prune_target_ratio: 0.8,
        }
    }
}
impl KvRouterConfig {
    /// Produce the per-block sequence hashes used for active block tracking.
    ///
    /// Only complete blocks are counted: `tokens.len() / block_size`.
    ///
    /// Returns:
    /// - `None` when `router_track_active_blocks` is false
    /// - `Some(vec![])` when the token slice holds no complete block
    /// - Real sequence hashes when `router_assume_kv_reuse` is true
    /// - One random `u64` per complete block otherwise
    ///
    /// NOTE(review): `block_size == 0` makes the block-count division panic;
    /// callers are expected to pass a non-zero block size.
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        let full_blocks = tokens.len() / block_size as usize;

        let hashes: Vec<u64> = if full_blocks == 0 {
            Vec::new()
        } else if self.router_assume_kv_reuse {
            // Hash each block, then fold the block hashes into sequence hashes.
            let per_block = compute_block_hash_for_seq(tokens, block_size, None);
            compute_seq_hash_for_block(&per_block)
        } else {
            // No reuse assumed: every block gets an unrelated random hash.
            let mut rng = rand::rng();
            std::iter::repeat_with(|| rng.random::<u64>())
                .take(full_blocks)
                .collect()
        };

        Some(hashes)
    }

    /// Decide whether the KV event subscription should be started.
    ///
    /// The answer is `false` when either:
    /// - KV events are disabled (`use_kv_events=false`), or
    /// - overlap scoring is disabled (`overlap_score_weight` is not > 0)
    ///
    /// In that case the router skips starting the KV event subscription
    /// entirely, avoiding the need to query workers for their local indexer
    /// state.
    pub fn should_subscribe_to_kv_events(&self) -> bool {
        if !self.use_kv_events {
            return false;
        }
        self.overlap_score_weight > 0.0
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::Result;
use dynamo_runtime::{
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait,
},
protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
use serde_json::json;
use crate::{
kv_router::{
KvRouter,
protocols::{TokensWithHashes, WorkerWithDpRank},
},
preprocessor::PreprocessedRequest,
protocols::common::{llm_backend::LLMEngineOutput, timing::RequestPhase},
};
/// Push router that picks a target worker through a `KvRouter` and then
/// sends the request to that worker over the wrapped `PushRouter`.
pub struct KvPushRouter {
// Underlying transport router; requests are dispatched with `inner.direct(..)`
// once a worker instance has been chosen.
inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
/// Shared KV router used for worker selection (`find_best_match`,
/// `get_overlap_blocks`) and request bookkeeping (`add_request`).
pub chooser: Arc<KvRouter>,
}
/// Result of worker selection containing instance ID, dp_rank, and overlap amount.
struct WorkerSelection {
// ID of the worker instance the request is routed to.
instance_id: u64,
// dp_rank to use on that worker (written into the outgoing request's routing hints).
dp_rank: u32,
// Overlap reported for the chosen worker, taken from `find_best_match` or
// `get_overlap_blocks`. Presumably a block count: it is recorded against
// the request's block count as the KV hit — TODO confirm unit.
overlap_amount: u32,
}
impl KvPushRouter {
    /// Wrap a transport-level `PushRouter` together with the `KvRouter` that
    /// will choose workers for it.
    pub fn new(
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
        chooser: Arc<KvRouter>,
    ) -> Self {
        Self { inner, chooser }
    }

    /// Pick the worker that should serve `request`.
    ///
    /// Two paths exist:
    ///
    /// 1. **No worker pinned in the routing hints** — the decision is
    ///    delegated to `KvRouter::find_best_match`, which receives
    ///    `!is_query_only` as its fourth argument.
    ///    NOTE(review): `handle_local_updates` is not consulted on this
    ///    path — confirm that is intended.
    /// 2. **A worker is pinned** — the phase-specific ID
    ///    (`prefill_worker_id` / `decode_worker_id`) wins, with
    ///    `backend_instance_id` as the fallback; in the aggregated phase only
    ///    `backend_instance_id` is considered. The overlap for that worker is
    ///    looked up, and the request is registered via `add_request` only
    ///    when `is_query_only` is false and `handle_local_updates` is true.
    ///
    /// Errors from `find_best_match` and `get_overlap_blocks` are propagated.
    async fn select_worker(
        &self,
        context_id: &str,
        request: &PreprocessedRequest,
        phase: RequestPhase,
        is_query_only: bool,
        handle_local_updates: bool,
    ) -> Result<WorkerSelection, Error> {
        let hints = request.routing.as_ref();
        let lora_name = hints.and_then(|h| h.lora_name.clone());

        // Phase-specific pin first, generic backend_instance_id as fallback.
        let pinned_worker = hints.and_then(|h| match phase {
            RequestPhase::Prefill => h.prefill_worker_id.or(h.backend_instance_id),
            RequestPhase::Decode => h.decode_worker_id.or(h.backend_instance_id),
            RequestPhase::Aggregated => h.backend_instance_id,
        });

        let worker_id = match pinned_worker {
            Some(worker_id) => worker_id,
            None => {
                // Nothing pinned: let the KV router choose.
                let (best_worker, overlap_amount) = self
                    .chooser
                    .find_best_match(
                        Some(context_id),
                        &request.token_ids,
                        request.router_config_override.as_ref(),
                        !is_query_only,
                        lora_name,
                    )
                    .await?;

                return Ok(WorkerSelection {
                    instance_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
                    overlap_amount,
                });
            }
        };

        // Pinned path: dp_rank comes from the hints, defaulting to rank 0.
        let dp_rank = hints.and_then(|h| h.dp_rank).unwrap_or(0);

        tracing::debug!(
            worker_id = worker_id,
            dp_rank = dp_rank,
            ?phase,
            "Routing to specified worker"
        );

        let worker = WorkerWithDpRank::new(worker_id, dp_rank);
        let overlap_blocks = self
            .chooser
            .get_overlap_blocks(&request.token_ids, worker)
            .await?;

        let skip_bookkeeping = is_query_only || !handle_local_updates;
        if skip_bookkeeping {
            tracing::debug!(
                request_id = %context_id,
                worker_id = worker_id,
                dp_rank = dp_rank,
                "Skipping add_request - query or handled externally"
            );
        } else {
            let expected_output_tokens = hints.and_then(|h| h.expected_output_tokens);
            self.chooser
                .add_request(
                    context_id.to_string(),
                    &request.token_ids,
                    overlap_blocks,
                    expected_output_tokens,
                    worker,
                    lora_name,
                )
                .await;
        }

        Ok(WorkerSelection {
            instance_id: worker_id,
            dp_rank,
            overlap_amount: overlap_blocks,
        })
    }
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for KvPushRouter
{
/// Generate method that handles KV-aware routing with three distinct behaviors:
///
/// 1. **If `query_instance_id` annotation is set**:
/// - Returns the best matching worker ID without routing the request
/// - Does NOT update any router local states
/// - Response includes worker_instance_id and token_data annotations
///
/// 2. **If `backend_instance_id` is set in the request**:
/// - Routes directly to the specified backend instance
/// - DOES update router states to track this request (unless query_instance_id is also set)
/// - Bypasses the normal KV matching logic
///
/// 3. **If neither are set (default behavior)**:
/// - Finds the best worker based on KV cache overlap
/// - Updates router states to track the request
/// - Routes to the selected worker
///
/// The router state updates include tracking active sequences and managing
/// prefill/completion lifecycle for proper KV cache management.
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
// Simple query-only detection: presence of query_instance_id annotation means query-only mode
let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Determine if this router should handle local state updates (add_request, free, etc.)
// Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
// an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
let handle_local_updates = request
.routing
.as_ref()
.and_then(|r| r.enable_local_updates)
.unwrap_or(true);
// Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
let phase = request
.tracker
.as_ref()
.map(|t| t.phase())
.unwrap_or(RequestPhase::Aggregated);
let block_size = self.chooser.block_size() as usize;
let selection = self
.select_worker(
&context_id,
&request,
phase,
is_query_only,
handle_local_updates,
)
.await?;
let WorkerSelection {
instance_id,
dp_rank,
overlap_amount,
} = selection;
// In approximate mode (use_kv_events=false), record the routing decision
// so the indexer can track cache state based on routing decisions.
// This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config().use_kv_events {
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size());
if let Err(e) = self
.chooser
.indexer()
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
{
tracing::warn!(
request_id = %context_id,
worker_id = instance_id,
dp_rank = dp_rank,
error = %e,
"Failed to record routing decision in approximate mode"
);
}
}
// Record metrics in tracker: KV hit rate, worker ID, and worker type based on phase.
// Worker type is stored at routing time to avoid expensive MDC lookups when
// updating Prometheus metrics (TTFT/ITL) later in the response stream.
if let Some(ref tracker) = request.tracker {
let isl_blocks = request.token_ids.len().div_ceil(block_size);
tracker.record_kv_hit(overlap_amount, isl_blocks);
tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type());
}
// Handle query-only requests: early return with worker info
if is_query_only {
let stream_context = request.context().clone();
// Tracker is always created for query-only requests (delta generator enables tracking
// when query_instance_id annotation is present)
let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());
tracing::trace!(
?phase,
worker_id = instance_id,
?worker_id_info,
"Returning worker selection (query-only mode)"
);
let output = LLMEngineOutput {
disaggregated_params: Some(json!({
"worker_id": worker_id_info,
"token_ids": request.token_ids
})),
..Default::default()
};
let response = Annotated::from_data(output);
let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
// Route to worker
let isl_tokens = request.token_ids.len();
let expected_output_tokens = request
.routing
.as_ref()
.and_then(|r| r.expected_output_tokens);
let track_output_blocks =
self.chooser.kv_router_config().router_track_output_blocks && handle_local_updates;
let (mut backend_input, context) = request.into_parts();
backend_input.routing_mut().dp_rank = Some(dp_rank);
let updated_request = context.map(|_| backend_input);
let chooser = self.chooser.clone();
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context();
let context_for_monitoring = stream_context.clone();
// Wrap stream with lifecycle management (mark_prefill_completed, free)
// Only perform these operations if handle_local_updates is true.
// When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false;
// Output block tracking state
let mut cumulative_osl: usize = 0;
let mut current_total_blocks = isl_tokens.div_ceil(block_size);
loop {
tokio::select! {
biased;
_ = context_for_monitoring.stopped() => {
tracing::debug!("Request {context_id} cancelled, ending stream");
break;
}
item = response_stream.next() => {
let Some(item) = item else {
break;
};
if handle_local_updates && !prefill_marked {
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref()
.map(|d| !d.token_ids.is_empty())
.unwrap_or(false);
if has_tokens {
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
}
prefill_marked = true;
}
}
// Track output blocks if enabled
if track_output_blocks {
let new_tokens = item.data.as_ref()
.map(|d| d.token_ids.len())
.unwrap_or(0);
cumulative_osl += new_tokens;
let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
if new_total_blocks > current_total_blocks {
// New block boundary crossed - add output block with decay
// Clamp eot to min 1 to avoid division by zero, and result to min 0.0
let decay_fraction = expected_output_tokens.map(|eot| {
(1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
});
if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
tracing::warn!(
"Failed to add output block for request {context_id}: {e}"
);
}
current_total_blocks = new_total_blocks;
}
}
yield item;
}
}
}
// Only call free() if we handle local updates.
// When handle_local_updates=false, external caller handles cleanup via C FFI.
if handle_local_updates
&& let Err(e) = chooser.free(&context_id).await
{
tracing::warn!("Failed to free request {context_id}: {e}");
}
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use crate::discovery::RuntimeConfigs; use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result; use anyhow::Result;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
...@@ -99,7 +99,7 @@ impl KvScheduler { ...@@ -99,7 +99,7 @@ impl KvScheduler {
pub async fn start( pub async fn start(
component: Component, component: Component,
block_size: u32, block_size: u32,
workers_with_configs: Arc<RuntimeConfigs>, workers_with_configs: RuntimeConfigWatch,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool, replica_sync: bool,
router_id: u64, router_id: u64,
...@@ -107,13 +107,10 @@ impl KvScheduler { ...@@ -107,13 +107,10 @@ impl KvScheduler {
) -> Result<Self, KvSchedulerError> { ) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default())); let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
// Get initial workers from DashMap for slot initialization. // Get initial workers from watch receiver.
// Caller must ensure at least one worker is present (via wait_for_some). // Caller must ensure at least one worker is present (via wait_for).
let initial_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_with_configs let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
.configs workers_with_configs.borrow().clone();
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
let slots = Arc::new( let slots = Arc::new(
ActiveSequencesMultiWorker::new( ActiveSequencesMultiWorker::new(
...@@ -128,25 +125,21 @@ impl KvScheduler { ...@@ -128,25 +125,21 @@ impl KvScheduler {
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?, .map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?,
); );
// Spawn background task to sync slots with DashMap when notified of changes. // Spawn background task to sync slots when the watch value changes.
// ModelManager's watcher updates the DashMap and notifies; we wait on watch receiver here.
let slots_monitor = slots.clone(); let slots_monitor = slots.clone();
let subscriber = workers_with_configs.subscribe(); let mut monitor_rx = workers_with_configs.clone();
let configs_monitor = subscriber.configs;
let mut change_rx = subscriber.change_rx;
let monitor_cancel_token = component.drt().child_token(); let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move { tokio::spawn(async move {
tracing::trace!("KvScheduler workers monitoring task started"); tracing::trace!("KvScheduler workers monitoring task started");
let mut last_workers: HashSet<WorkerId> = HashSet::new(); let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
loop { loop {
// Wait for notification or cancellation
tokio::select! { tokio::select! {
_ = monitor_cancel_token.cancelled() => { _ = monitor_cancel_token.cancelled() => {
tracing::trace!("KvScheduler workers monitoring task shutting down"); tracing::trace!("KvScheduler workers monitoring task shutting down");
break; break;
} }
result = change_rx.changed() => { result = monitor_rx.changed() => {
if result.is_err() { if result.is_err() {
tracing::warn!("KvScheduler: config watch sender dropped, shutting down"); tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
break; break;
...@@ -154,25 +147,17 @@ impl KvScheduler { ...@@ -154,25 +147,17 @@ impl KvScheduler {
} }
} }
// Get current workers from DashMap let current_workers = monitor_rx.borrow_and_update().clone();
let current_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> =
configs_monitor
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
let current_worker_ids: HashSet<WorkerId> =
current_workers.keys().copied().collect();
// Only update slots if workers have changed if current_workers != last_workers {
if current_worker_ids != last_workers { slots_monitor.update_workers(current_workers.clone());
slots_monitor.update_workers(current_workers); last_workers = current_workers;
last_workers = current_worker_ids;
} }
} }
}); });
let slots_clone = slots.clone(); let slots_clone = slots.clone();
let workers_scheduler = workers_with_configs.clone(); let scheduler_rx = workers_with_configs.clone();
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024); let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token(); let scheduler_cancel_token = component.drt().primary_token();
let hit_rate_publisher = let hit_rate_publisher =
...@@ -209,12 +194,8 @@ impl KvScheduler { ...@@ -209,12 +194,8 @@ impl KvScheduler {
request.decode_blocks = decode_blocks; request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens; request.prefill_tokens = prefill_tokens;
// Read the current workers configuration from DashMap // Read the current workers configuration from watch receiver
let workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_scheduler let workers: HashMap<WorkerId, ModelRuntimeConfig> = scheduler_rx.borrow().clone();
.configs
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
match selector.select_worker(&workers, &request, block_size) { match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => { Ok(selection) => {
...@@ -511,7 +492,7 @@ impl DefaultWorkerSelector { ...@@ -511,7 +492,7 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector { impl WorkerSelector for DefaultWorkerSelector {
fn select_worker( fn select_worker(
&self, &self,
workers: &HashMap<WorkerId, Option<ModelRuntimeConfig>>, workers: &HashMap<WorkerId, ModelRuntimeConfig>,
request: &SchedulingRequest, request: &SchedulingRequest,
block_size: u32, block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> { ) -> Result<WorkerSelectionResult, KvSchedulerError> {
...@@ -541,11 +522,8 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -541,11 +522,8 @@ impl WorkerSelector for DefaultWorkerSelector {
// Outer loop: iterate over all workers from runtime config // Outer loop: iterate over all workers from runtime config
// Inner loop: iterate over all dp_ranks for each worker // Inner loop: iterate over all dp_ranks for each worker
for (worker_id, config) in workers.iter() { for (worker_id, config) in workers.iter() {
// Get data_parallel_size from runtime config let data_parallel_size = config.data_parallel_size;
// data_parallel_size defaults to 1 in ModelRuntimeConfig
let data_parallel_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1); // Fallback if config is None
// Iterate over all dp_ranks for this worker
for dp_rank in 0..data_parallel_size { for dp_rank in 0..data_parallel_size {
let worker = WorkerWithDpRank::new(*worker_id, dp_rank); let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
...@@ -612,7 +590,6 @@ impl WorkerSelector for DefaultWorkerSelector { ...@@ -612,7 +590,6 @@ impl WorkerSelector for DefaultWorkerSelector {
// this is a runtime config set on a per worker basis, not per dp-rank // this is a runtime config set on a per worker basis, not per dp-rank
let total_blocks_info = workers let total_blocks_info = workers
.get(&best_worker.worker_id) .get(&best_worker.worker_id)
.and_then(|cfg| cfg.as_ref())
.and_then(|cfg| cfg.total_kv_blocks) .and_then(|cfg| cfg.total_kv_blocks)
.map(|blocks| format!(", total blocks: {}", blocks)) .map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default(); .unwrap_or_default();
......
...@@ -424,7 +424,7 @@ impl ActiveSequencesMultiWorker { ...@@ -424,7 +424,7 @@ impl ActiveSequencesMultiWorker {
pub async fn new( pub async fn new(
component: Component, component: Component,
block_size: usize, block_size: usize,
workers_with_configs: HashMap<u64, Option<ModelRuntimeConfig>>, workers_with_configs: HashMap<u64, ModelRuntimeConfig>,
replica_sync: bool, replica_sync: bool,
router_id: u64, router_id: u64,
worker_type: &'static str, worker_type: &'static str,
...@@ -438,7 +438,7 @@ impl ActiveSequencesMultiWorker { ...@@ -438,7 +438,7 @@ impl ActiveSequencesMultiWorker {
// Expand workers by their dp_rank // Expand workers by their dp_rank
for (worker_id, config) in workers_with_configs { for (worker_id, config) in workers_with_configs {
let dp_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1); let dp_size = config.data_parallel_size;
for dp_rank in 0..dp_size { for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank); let worker = WorkerWithDpRank::new(worker_id, dp_rank);
...@@ -710,17 +710,14 @@ impl ActiveSequencesMultiWorker { ...@@ -710,17 +710,14 @@ impl ActiveSequencesMultiWorker {
} }
/// Update the set of workers, adding and removing as needed /// Update the set of workers, adding and removing as needed
pub fn update_workers( pub fn update_workers(&self, new_workers_with_configs: HashMap<u64, ModelRuntimeConfig>) {
&self,
new_workers_with_configs: HashMap<u64, Option<ModelRuntimeConfig>>,
) {
let current_workers: HashSet<WorkerWithDpRank> = let current_workers: HashSet<WorkerWithDpRank> =
self.senders.iter().map(|entry| *entry.key()).collect(); self.senders.iter().map(|entry| *entry.key()).collect();
// Expand new workers by their dp_rank // Expand new workers by their dp_rank
let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new(); let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (worker_id, config) in &new_workers_with_configs { for (worker_id, config) in &new_workers_with_configs {
let dp_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1); let dp_size = config.data_parallel_size;
for dp_rank in 0..dp_size { for dp_rank in 0..dp_size {
new_workers.insert(WorkerWithDpRank::new(*worker_id, dp_rank)); new_workers.insert(WorkerWithDpRank::new(*worker_id, dp_rank));
...@@ -784,10 +781,15 @@ impl ActiveSequencesMultiWorker { ...@@ -784,10 +781,15 @@ impl ActiveSequencesMultiWorker {
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
lora_name: Option<String>, lora_name: Option<String>,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
// Check for worker existence // Clone the sender upfront so we don't hold the DashMap Ref across
if !self.senders.contains_key(&worker) { // the .await points below. Also eliminates the TOCTOU between
return Err(SequenceError::WorkerNotFound { worker }); // contains_key and a later get().unwrap().
} let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
// Check for duplicate request // Check for duplicate request
if let Some(existing_worker) = self.request_to_worker.get(&request_id) { if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
...@@ -825,9 +827,7 @@ impl ActiveSequencesMultiWorker { ...@@ -825,9 +827,7 @@ impl ActiveSequencesMultiWorker {
self.request_to_lora.insert(request_id.clone(), lora); self.request_to_lora.insert(request_id.clone(), lora);
} }
self.senders sender
.get(&worker)
.unwrap()
.send(UpdateSequences::AddRequest { .send(UpdateSequences::AddRequest {
request_id, request_id,
token_sequence, token_sequence,
...@@ -855,25 +855,31 @@ impl ActiveSequencesMultiWorker { ...@@ -855,25 +855,31 @@ impl ActiveSequencesMultiWorker {
Ok(()) Ok(())
} }
/// Free all blocks associated with a request /// Send a command to the worker assigned to a request, optionally publishing
/// /// a replica-sync event and cleaning up request mappings afterward.
/// Note: This operation is idempotent. Calling it multiple times for the same request async fn send_to_request_worker(
/// will log a warning but not return an error (double free is allowed). &self,
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> { request_id: &RequestId,
// Check if request exists - if not, it's already been freed (idempotent) event_data: ActiveSequenceEventData,
let Some(worker) = self.request_to_worker.get(request_id).map(|entry| *entry) else { command_fn: impl FnOnce(RequestId) -> UpdateSequences,
tracing::debug!("Request {request_id} not found, already freed (idempotent)"); remove_mapping: bool,
return Ok(()); ) -> Result<(), SequenceError> {
}; let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
// Verify worker still exists let sender = self
if !self.senders.contains_key(&worker) { .senders
return Err(SequenceError::WorkerNotFound { worker }); .get(&worker)
} .ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
// Publish event only if replica_sync is enabled
if self.replica_sync { if self.replica_sync {
// Look up lora_name from mapping
let lora_name = self let lora_name = self
.request_to_lora .request_to_lora
.get(request_id) .get(request_id)
...@@ -882,31 +888,46 @@ impl ActiveSequencesMultiWorker { ...@@ -882,31 +888,46 @@ impl ActiveSequencesMultiWorker {
let event = ActiveSequenceEvent { let event = ActiveSequenceEvent {
request_id: request_id.clone(), request_id: request_id.clone(),
worker, worker,
data: ActiveSequenceEventData::Free, data: event_data,
router_id: self.router_id, router_id: self.router_id,
lora_name, lora_name,
}; };
self.event_publisher.publish(&event).await?; self.event_publisher.publish(&event).await?;
} }
// Update local state sender
self.senders .send(command_fn(request_id.clone()))
.get(&worker)
.unwrap()
.send(UpdateSequences::Free {
request_id: request_id.clone(),
})
.map_err(|_| SequenceError::WorkerChannelClosed)?; .map_err(|_| SequenceError::WorkerChannelClosed)?;
if remove_mapping {
self.request_to_worker.remove(request_id); self.request_to_worker.remove(request_id);
self.request_to_lora.remove(request_id); self.request_to_lora.remove(request_id);
}
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await; self.publish_active_load_for_worker(worker).await;
Ok(()) Ok(())
} }
/// Free all blocks associated with a request
///
/// Note: This operation is idempotent. Calling it multiple times for the same request
/// will log a warning but not return an error (double free is allowed).
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
}
self.send_to_request_worker(
request_id,
ActiveSequenceEventData::Free,
|rid| UpdateSequences::Free { request_id: rid },
true,
)
.await
}
/// Mark prefill as completed for a request /// Mark prefill as completed for a request
/// ///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op /// Note: Calling this multiple times for the same request is allowed and will be a no-op
...@@ -915,50 +936,13 @@ impl ActiveSequencesMultiWorker { ...@@ -915,50 +936,13 @@ impl ActiveSequencesMultiWorker {
&self, &self,
request_id: &RequestId, request_id: &RequestId,
) -> Result<(), SequenceError> { ) -> Result<(), SequenceError> {
let worker = self self.send_to_request_worker(
.request_to_worker request_id,
.get(request_id) ActiveSequenceEventData::MarkPrefillCompleted,
.map(|entry| *entry) |rid| UpdateSequences::MarkPrefillCompleted { request_id: rid },
.ok_or_else(|| SequenceError::RequestNotFound { false,
request_id: request_id.clone(), )
})?; .await
// Verify worker still exists
if !self.senders.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
// Publish event only if replica_sync is enabled
if self.replica_sync {
// Look up lora_name from mapping
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: self.router_id,
lora_name,
};
self.event_publisher.publish(&event).await?;
}
// Update local state
self.senders
.get(&worker)
.unwrap()
.send(UpdateSequences::MarkPrefillCompleted {
request_id: request_id.clone(),
})
.map_err(|_| SequenceError::WorkerChannelClosed)?;
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(())
} }
/// Add an output block with optional fractional decay weight /// Add an output block with optional fractional decay weight
...@@ -978,18 +962,19 @@ impl ActiveSequencesMultiWorker { ...@@ -978,18 +962,19 @@ impl ActiveSequencesMultiWorker {
request_id: request_id.clone(), request_id: request_id.clone(),
})?; })?;
// Verify worker still exists // Clone sender upfront to avoid TOCTOU between contains_key and get().unwrap()
if !self.senders.contains_key(&worker) { let sender = self
return Err(SequenceError::WorkerNotFound { worker }); .senders
} .get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
// Create response channel // Create response channel
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
// Send command to worker // Send command to worker
self.senders sender
.get(&worker)
.unwrap()
.send(UpdateSequences::AddOutputBlock { .send(UpdateSequences::AddOutputBlock {
request_id: request_id.clone(), request_id: request_id.clone(),
decay_fraction, decay_fraction,
...@@ -1016,10 +1001,17 @@ impl ActiveSequencesMultiWorker { ...@@ -1016,10 +1001,17 @@ impl ActiveSequencesMultiWorker {
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad /// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad
async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) { async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let Some(sender) = self.senders.get(&worker) else { // Clone the sender and drop the DashMap Ref immediately.
// Holding a Ref across .await points can deadlock: if the task yields
// and update_workers() needs a write lock on the same shard, the
// runtime thread blocks forever.
let sender = {
let Some(entry) = self.senders.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad"); tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
return; return;
}; };
entry.value().clone()
};
// Query active blocks // Query active blocks
let (blocks_tx, blocks_rx) = tokio::sync::oneshot::channel(); let (blocks_tx, blocks_rx) = tokio::sync::oneshot::channel();
...@@ -1337,11 +1329,11 @@ mod tests { ...@@ -1337,11 +1329,11 @@ mod tests {
// Create runtime config for worker 0 with dp_size=2 // Create runtime config for worker 0 with dp_size=2
let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new(); let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
config_worker_0.data_parallel_size = 2; config_worker_0.data_parallel_size = 2;
workers_with_configs.insert(0, Some(config_worker_0)); workers_with_configs.insert(0, config_worker_0);
// Create runtime config for worker 1 with dp_size=1 (default) // Create runtime config for worker 1 with dp_size=1 (default)
let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new(); let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
workers_with_configs.insert(1, Some(config_worker_1)); workers_with_configs.insert(1, config_worker_1);
let seq_manager_1 = Arc::new( let seq_manager_1 = Arc::new(
ActiveSequencesMultiWorker::new( ActiveSequencesMultiWorker::new(
...@@ -1509,9 +1501,18 @@ mod tests { ...@@ -1509,9 +1501,18 @@ mod tests {
// Create multi-worker sequence managers with ALL workers [0, 1, 2] // Create multi-worker sequence managers with ALL workers [0, 1, 2]
// Both use the same component to ensure event synchronization works // Both use the same component to ensure event synchronization works
let mut workers_with_configs = HashMap::new(); let mut workers_with_configs = HashMap::new();
workers_with_configs.insert(0, None); workers_with_configs.insert(
workers_with_configs.insert(1, None); 0,
workers_with_configs.insert(2, None); crate::local_model::runtime_config::ModelRuntimeConfig::new(),
);
workers_with_configs.insert(
1,
crate::local_model::runtime_config::ModelRuntimeConfig::new(),
);
workers_with_configs.insert(
2,
crate::local_model::runtime_config::ModelRuntimeConfig::new(),
);
let seq_manager_1 = Arc::new( let seq_manager_1 = Arc::new(
ActiveSequencesMultiWorker::new( ActiveSequencesMultiWorker::new(
......
...@@ -18,8 +18,8 @@ use tokio::sync::{mpsc, oneshot}; ...@@ -18,8 +18,8 @@ use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::kv_router::{ use crate::kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, KV_EVENT_SUBJECT, KvRouterConfig, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest}, indexer::{DumpRequest, GetWorkersRequest, KvIndexer},
protocols::{DpRank, RouterEvent, WorkerId}, protocols::{DpRank, RouterEvent, WorkerId},
router_discovery_query, router_discovery_query,
worker_query::WorkerQueryClient, worker_query::WorkerQueryClient,
...@@ -511,10 +511,15 @@ pub async fn start_kv_router_background( ...@@ -511,10 +511,15 @@ pub async fn start_kv_router_background(
pub async fn start_kv_router_background_event_plane( pub async fn start_kv_router_background_event_plane(
component: Component, component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>, kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
mut worker_query_client: WorkerQueryClient,
transport_kind: EventTransportKind, transport_kind: EventTransportKind,
) -> Result<()> { ) -> Result<()> {
// WorkerQueryClient handles its own discovery loop for lifecycle + initial recovery.
// No blocking wait — recovery happens asynchronously as endpoints are discovered.
let worker_query_client =
WorkerQueryClient::spawn(component.clone(), remove_worker_tx, kv_events_tx.clone()).await?;
// Subscribe to KV events using the selected event plane transport // Subscribe to KV events using the selected event plane transport
let mut subscriber = let mut subscriber =
EventSubscriber::for_component_with_transport(&component, KV_EVENT_SUBJECT, transport_kind) EventSubscriber::for_component_with_transport(&component, KV_EVENT_SUBJECT, transport_kind)
...@@ -542,20 +547,6 @@ pub async fn start_kv_router_background_event_plane( ...@@ -542,20 +547,6 @@ pub async fn start_kv_router_background_event_plane(
} }
} }
// Wait for at least one worker with a known runtime config before proceeding.
// This ensures we have actual config data (including enable_local_indexer) available.
tracing::info!("KV subscriber waiting for at least one worker with runtime config...");
let ready_workers = worker_query_client.wait_for_ready().await;
tracing::info!(
"KV subscriber found {} worker(s) with runtime config, proceeding",
ready_workers.len()
);
// Recover initial state from all workers with local indexer enabled
worker_query_client
.process_and_recover_workers(&kv_events_tx, "Initial recovery")
.await;
tokio::spawn(async move { tokio::spawn(async move {
// Track last received event ID per (worker, dp_rank) for gap detection // Track last received event ID per (worker, dp_rank) for gap detection
// Each dp_rank has its own monotonic event ID sequence // Each dp_rank has its own monotonic event ID sequence
...@@ -570,18 +561,6 @@ pub async fn start_kv_router_background_event_plane( ...@@ -570,18 +561,6 @@ pub async fn start_kv_router_background_event_plane(
break; break;
} }
// Handle runtime config changes (worker add/remove, recovery for new workers)
result = worker_query_client.wait_for_config_change() => {
if result.is_err() {
tracing::warn!("Runtime config watch sender dropped");
continue;
}
worker_query_client
.process_and_recover_workers(&kv_events_tx, "DISCOVERY")
.await;
}
// Handle event consumption from event plane subscription // Handle event consumption from event plane subscription
Some(result) = subscriber.next() => { Some(result) = subscriber.next() => {
let (envelope, event) = match result { let (envelope, event) = match result {
...@@ -597,7 +576,6 @@ pub async fn start_kv_router_background_event_plane( ...@@ -597,7 +576,6 @@ pub async fn start_kv_router_background_event_plane(
let event_id = event.event.event_id; let event_id = event.event.event_id;
let event_key = (worker_id, dp_rank); let event_key = (worker_id, dp_rank);
// Use envelope metadata for additional debugging
tracing::trace!( tracing::trace!(
"Received event from publisher {} (seq {})", "Received event from publisher {} (seq {})",
envelope.publisher_id, envelope.publisher_id,
...@@ -609,7 +587,6 @@ pub async fn start_kv_router_background_event_plane( ...@@ -609,7 +587,6 @@ pub async fn start_kv_router_background_event_plane(
if let Some(&last_id) = last_event_ids.get(&event_key) if let Some(&last_id) = last_event_ids.get(&event_key)
&& event_id > last_id + 1 && event_id > last_id + 1
{ {
// Gap detected - recover missing events before processing current
let gap_start = last_id + 1; let gap_start = last_id + 1;
let gap_end = event_id - 1; let gap_end = event_id - 1;
let gap_size = gap_end - gap_start + 1; let gap_size = gap_end - gap_start + 1;
...@@ -617,22 +594,15 @@ pub async fn start_kv_router_background_event_plane( ...@@ -617,22 +594,15 @@ pub async fn start_kv_router_background_event_plane(
"Event ID gap detected for worker {worker_id} dp_rank {dp_rank}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}" "Event ID gap detected for worker {worker_id} dp_rank {dp_rank}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
); );
// Note: While recovering, new events may queue in the subscriber's
// internal buffer. We don't explicitly buffer them here for simplicity.
// The subscriber will process them in order after recovery completes.
if let Err(e) = worker_query_client if let Err(e) = worker_query_client
.recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end), &kv_events_tx) .recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end))
.await .await
{ {
tracing::error!( tracing::error!(
"Failed to recover gap events for worker {worker_id} dp_rank {dp_rank} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}" "Failed to recover gap events for worker {worker_id} dp_rank {dp_rank} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
); );
// Note: If recovery fails, we still apply the current event.
// The tree will have a gap, but it's better than dropping the event.
} }
} }
// First event from this (worker, dp_rank) is always valid - we accept whatever ID it has.
// This handles initial startup and worker restarts without requiring event 0.
// Update last seen event ID (use max to handle out-of-order) // Update last seen event ID (use max to handle out-of-order)
last_event_ids last_event_ids
...@@ -657,23 +627,6 @@ pub async fn start_kv_router_background_event_plane( ...@@ -657,23 +627,6 @@ pub async fn start_kv_router_background_event_plane(
Ok(()) Ok(())
} }
/// Backwards-compatible wrapper for NATS Core local-indexer mode.
///
/// Forwards `component`, `kv_events_tx`, `cancellation_token` and
/// `worker_query_client` unchanged to
/// `start_kv_router_background_event_plane`, pinning the transport to
/// `EventTransportKind::Nats`, and returns that call's result as-is.
pub async fn start_kv_router_background_nats_core(
component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>,
cancellation_token: CancellationToken,
worker_query_client: WorkerQueryClient,
) -> Result<()> {
start_kv_router_background_event_plane(
component,
kv_events_tx,
cancellation_token,
worker_query_client,
// The only value this wrapper adds: the transport is fixed to NATS.
EventTransportKind::Nats,
)
.await
}
/// Cleanup orphaned NATS consumers that no longer have corresponding router entries /// Cleanup orphaned NATS consumers that no longer have corresponding router entries
async fn cleanup_orphaned_consumers( async fn cleanup_orphaned_consumers(
nats_queue: &mut NatsQueue, nats_queue: &mut NatsQueue,
...@@ -711,3 +664,66 @@ async fn cleanup_orphaned_consumers( ...@@ -711,3 +664,66 @@ async fn cleanup_orphaned_consumers(
} }
} }
} }
/// Helper to decide which subscriber (JetStream or Event Plane) to start based on config.
///
/// The event transport is obtained from `EventTransportKind::from_env_or_default()`;
/// `kv_router_config.durable_kv_events` then selects the subscription mode:
///
/// - `true`: JetStream subscription via `start_kv_router_background`, using
///   `router_id` (rendered as a string) as the consumer id. The workers-query and
///   snapshot-event senders are passed only when `router_snapshot_threshold` is set.
///   A warning is logged if the configured transport is ZMQ.
/// - `false`: event-plane subscription via `start_kv_router_background_event_plane`
///   on the detected transport. With ZMQ, a warning is logged when snapshot or
///   reset settings are present, since they are not forwarded on this path.
///
/// Returns the result of whichever background starter was invoked.
pub async fn start_subscriber(
    component: Component,
    kv_router_config: &KvRouterConfig,
    router_id: u64,
    kv_indexer: &KvIndexer,
    cancellation_token: CancellationToken,
) -> Result<()> {
    let transport_kind = EventTransportKind::from_env_or_default();

    // Start subscriber - durable_kv_events flag determines the mode:
    // - durable_kv_events=false (default): Use NATS Core / generic event plane (requires workers to have local_indexer enabled)
    // - durable_kv_events=true: Use JetStream for durability and multi-replica consistency
    if kv_router_config.durable_kv_events {
        if transport_kind == EventTransportKind::Zmq {
            tracing::warn!(
                "--durable-kv-events requires NATS, but ZMQ event plane is configured; falling back to JetStream anyway"
            );
        }
        tracing::info!("Using JetStream subscription (--durable-kv-events enabled)");

        let consumer_id = router_id.to_string();
        start_kv_router_background(
            component,
            consumer_id,
            kv_indexer.event_sender(),
            kv_indexer.remove_worker_sender(),
            // Snapshot plumbing is only wired up when a threshold is configured.
            kv_router_config
                .router_snapshot_threshold
                .map(|_| kv_indexer.get_workers_sender()),
            kv_router_config
                .router_snapshot_threshold
                .map(|_| kv_indexer.snapshot_event_sender()),
            cancellation_token,
            kv_router_config.router_snapshot_threshold,
            kv_router_config.router_reset_states,
        )
        .await
    } else {
        if transport_kind == EventTransportKind::Zmq {
            if kv_router_config.router_snapshot_threshold.is_some()
                || kv_router_config.router_reset_states
            {
                tracing::warn!(
                    "ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings"
                );
            }
            tracing::info!("Using ZMQ event plane subscription (local_indexer mode)");
        } else {
            tracing::info!("Using NATS Core subscription (local_indexer mode)");
        }

        // `component` is owned and only one branch ever runs, so it can be moved
        // here directly instead of being cloned.
        start_kv_router_background_event_plane(
            component,
            kv_indexer.event_sender(),
            kv_indexer.remove_worker_sender(),
            cancellation_token,
            transport_kind,
        )
        .await
    }
}
This diff is collapsed.
...@@ -239,7 +239,11 @@ impl LocalModelBuilder { ...@@ -239,7 +239,11 @@ impl LocalModelBuilder {
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64); self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens = self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64); mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer; // Decode workers don't create the WorkerKvQuery endpoint (scheduler_component is None),
// so they must not advertise enable_local_indexer=true or the router will hang
// trying to query them during initial recovery.
self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer
&& mocker_engine_args.worker_type != WorkerType::Decode;
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size; self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
self.media_decoder = Some(MediaDecoder { self.media_decoder = Some(MediaDecoder {
image: Some(ImageDecoder::default()), image: Some(ImageDecoder::default()),
......
...@@ -32,8 +32,8 @@ pub struct ModelRuntimeConfig { ...@@ -32,8 +32,8 @@ pub struct ModelRuntimeConfig {
#[serde(default = "default_data_parallel_size")] #[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32, pub data_parallel_size: u32,
/// Enable worker-local KV indexer for tracking this worker's own KV cache state /// Enable worker-local KV indexer for tracking this worker's own KV cache state (default: true)
#[serde(default)] #[serde(default = "default_local_indexer")]
pub enable_local_indexer: bool, pub enable_local_indexer: bool,
/// Mapping of engine-specific runtime configs /// Mapping of engine-specific runtime configs
...@@ -59,6 +59,10 @@ const fn default_data_parallel_size() -> u32 { ...@@ -59,6 +59,10 @@ const fn default_data_parallel_size() -> u32 {
1 1
} }
/// Default-value provider named by the `#[serde(default = "default_local_indexer")]`
/// attribute on `enable_local_indexer`; always returns `true`.
const fn default_local_indexer() -> bool {
    true
}
impl Default for ModelRuntimeConfig { impl Default for ModelRuntimeConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
...@@ -68,7 +72,7 @@ impl Default for ModelRuntimeConfig { ...@@ -68,7 +72,7 @@ impl Default for ModelRuntimeConfig {
tool_call_parser: None, tool_call_parser: None,
reasoning_parser: None, reasoning_parser: None,
data_parallel_size: default_data_parallel_size(), data_parallel_size: default_data_parallel_size(),
enable_local_indexer: false, enable_local_indexer: true,
runtime_data: HashMap::new(), runtime_data: HashMap::new(),
tensor_model_config: None, tensor_model_config: None,
disaggregated_endpoint: None, disaggregated_endpoint: None,
......
...@@ -385,7 +385,7 @@ impl Discovery for KVStoreDiscovery { ...@@ -385,7 +385,7 @@ impl Discovery for KVStoreDiscovery {
// Get bucket - if it doesn't exist, return empty list // Get bucket - if it doesn't exist, return empty list
let Some(bucket) = self.store.get_bucket(bucket_name).await? else { let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
tracing::info!( tracing::debug!(
"KVStoreDiscovery::list: bucket missing for query={:?}, prefix={}, bucket={}", "KVStoreDiscovery::list: bucket missing for query={:?}, prefix={}, bucket={}",
query, query,
prefix, prefix,
...@@ -396,7 +396,7 @@ impl Discovery for KVStoreDiscovery { ...@@ -396,7 +396,7 @@ impl Discovery for KVStoreDiscovery {
// Get all entries from the bucket // Get all entries from the bucket
let entries = bucket.entries().await?; let entries = bucket.entries().await?;
tracing::info!( tracing::debug!(
"KVStoreDiscovery::list: query={:?}, prefix={}, bucket={}, entries={}", "KVStoreDiscovery::list: query={:?}, prefix={}, bucket={}, entries={}",
query, query,
prefix, prefix,
......
...@@ -66,6 +66,12 @@ where ...@@ -66,6 +66,12 @@ where
/// If None, busy detection is disabled /// If None, busy detection is disabled
busy_threshold: Option<f64>, busy_threshold: Option<f64>,
/// When false, `generate_with_fault_detection` skips fault detection logic:
/// it won't call `report_instance_down` on errors, and it uses the raw discovery
/// instance list instead of the filtered avail list. Use for recovery/query paths
/// where transient failures are expected.
fault_detection_enabled: bool,
/// An internal Rust type. This says that PushRouter is generic over the T and U types, /// An internal Rust type. This says that PushRouter is generic over the T and U types,
/// which are the input and output types of its `generate` function. It allows the /// which are the input and output types of its `generate` function. It allows the
/// compiler to specialize us at compile time. /// compiler to specialize us at compile time.
...@@ -112,6 +118,28 @@ where ...@@ -112,6 +118,28 @@ where
Self::from_client_with_threshold(client, router_mode, None, None).await Self::from_client_with_threshold(client, router_mode, None, None).await
} }
/// Create a new PushRouter with fault detection disabled.
///
/// Unlike `from_client`, this router will not call `report_instance_down` on
/// transient errors, and `direct()` uses the raw discovery instance list instead
/// of the filtered avail list. Use for recovery/query paths.
///
/// The router is built with `busy_threshold: None`, a round-robin counter
/// starting at zero, and `fault_detection_enabled: false`.
///
/// # Errors
///
/// Propagates any error returned by `addressed_router` for the client's endpoint.
pub async fn from_client_no_fault_detection(
    client: Client,
    router_mode: RouterMode,
) -> anyhow::Result<Self> {
    // The endpoint is only borrowed for this call, so `client` can be moved
    // into the router afterwards without cloning it.
    let addressed = addressed_router(&client.endpoint).await?;

    Ok(PushRouter {
        client,
        addressed,
        router_mode,
        round_robin_counter: Arc::new(AtomicU64::new(0)),
        busy_threshold: None,
        fault_detection_enabled: false,
        _phantom: PhantomData,
    })
}
/// Create a new PushRouter with optional busy threshold and worker load monitor /// Create a new PushRouter with optional busy threshold and worker load monitor
pub async fn from_client_with_threshold( pub async fn from_client_with_threshold(
client: Client, client: Client,
...@@ -132,6 +160,7 @@ where ...@@ -132,6 +160,7 @@ where
router_mode, router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)), round_robin_counter: Arc::new(AtomicU64::new(0)),
busy_threshold, busy_threshold,
fault_detection_enabled: true,
_phantom: PhantomData, _phantom: PhantomData,
}; };
...@@ -185,7 +214,14 @@ where ...@@ -185,7 +214,14 @@ where
request: SingleIn<T>, request: SingleIn<T>,
instance_id: u64, instance_id: u64,
) -> anyhow::Result<ManyOut<U>> { ) -> anyhow::Result<ManyOut<U>> {
let found = self.client.instance_ids_avail().contains(&instance_id); // When fault detection is disabled, check the raw discovery list
// (not filtered by report_instance_down) so transient failures
// don't poison the instance for subsequent retries.
let found = if self.fault_detection_enabled {
self.client.instance_ids_avail().contains(&instance_id)
} else {
self.client.instance_ids().contains(&instance_id)
};
if !found { if !found {
return Err(anyhow::anyhow!( return Err(anyhow::anyhow!(
...@@ -271,8 +307,8 @@ where ...@@ -271,8 +307,8 @@ where
instance_id: u64, instance_id: u64,
request: SingleIn<T>, request: SingleIn<T>,
) -> anyhow::Result<ManyOut<U>> { ) -> anyhow::Result<ManyOut<U>> {
// Check if all workers are busy (only if busy threshold is set) // Check if all workers are busy (only if busy threshold is set and fault detection enabled)
if self.busy_threshold.is_some() { if self.fault_detection_enabled && self.busy_threshold.is_some() {
let free_instances = self.client.instance_ids_free(); let free_instances = self.client.instance_ids_free();
if free_instances.is_empty() { if free_instances.is_empty() {
// Check if we actually have any instances at all // Check if we actually have any instances at all
...@@ -332,6 +368,9 @@ where ...@@ -332,6 +368,9 @@ where
let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await; let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
match stream { match stream {
Ok(stream) => { Ok(stream) => {
if !self.fault_detection_enabled {
return Ok(stream);
}
let engine_ctx = stream.context(); let engine_ctx = stream.context();
let client = self.client.clone(); let client = self.client.clone();
let stream = stream.map(move |res| { let stream = stream.map(move |res| {
...@@ -349,7 +388,8 @@ where ...@@ -349,7 +388,8 @@ where
Ok(ResponseStream::new(Box::pin(stream), engine_ctx)) Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
} }
Err(err) => { Err(err) => {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>() if self.fault_detection_enabled
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>()
&& matches!(req_err.kind(), NatsNoResponders) && matches!(req_err.kind(), NatsNoResponders)
{ {
tracing::debug!( tracing::debug!(
......
...@@ -202,7 +202,10 @@ def predownload_models(pytestconfig): ...@@ -202,7 +202,10 @@ def predownload_models(pytestconfig):
else: else:
# Fallback to original behavior if extraction failed # Fallback to original behavior if extraction failed
download_models() download_models()
os.environ["HF_HUB_OFFLINE"] = "1"
yield yield
os.environ.pop("HF_HUB_OFFLINE", None)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
...@@ -218,7 +221,13 @@ def predownload_tokenizers(pytestconfig): ...@@ -218,7 +221,13 @@ def predownload_tokenizers(pytestconfig):
else: else:
# Fallback to original behavior if extraction failed # Fallback to original behavior if extraction failed
download_models(ignore_weights=True) download_models(ignore_weights=True)
# Skip redundant HuggingFace API calls in worker subprocesses since
# tokenizers are already cached. This avoids flaky timeouts from slow
# HF API responses (the RepoInfo fetch still happens even for cached models).
os.environ["HF_HUB_OFFLINE"] = "1"
yield yield
os.environ.pop("HF_HUB_OFFLINE", None)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
...@@ -610,20 +619,20 @@ def request_plane(request): ...@@ -610,20 +619,20 @@ def request_plane(request):
@pytest.fixture @pytest.fixture
def use_nats_core(request): def durable_kv_events(request):
""" """
Whether to use NATS Core mode (local indexer) instead of JetStream. Defaults to False. Whether to use durable KV events via JetStream. Defaults to False (NATS Core mode).
When True: When False (default):
- NATS server starts without JetStream (-js flag omitted) for faster startup - NATS server starts without JetStream (-js flag omitted) for faster startup
- Tests should use enable_local_indexer=True in mocker_args - Workers use local indexer mode (NATS Core / fire-and-forget events)
When False (default): When True:
- NATS server starts with JetStream for KV event distribution - NATS server starts with JetStream for durable KV event distribution
- Tests use JetStream-based indexer synchronization - Workers use --durable-kv-events flag to publish to JetStream
To use NATS Core mode: To use JetStream mode:
@pytest.mark.parametrize("use_nats_core", [True], indirect=True) @pytest.mark.parametrize("durable_kv_events", [True], indirect=True)
def test_example(runtime_services_dynamic_ports): def test_example(runtime_services_dynamic_ports):
... ...
""" """
...@@ -656,7 +665,7 @@ def runtime_services(request, store_kv, request_plane): ...@@ -656,7 +665,7 @@ def runtime_services(request, store_kv, request_plane):
@pytest.fixture() @pytest.fixture()
def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_core): def runtime_services_dynamic_ports(request, store_kv, request_plane, durable_kv_events):
"""Provide NATS and Etcd servers with truly dynamic ports per test. """Provide NATS and Etcd servers with truly dynamic ports per test.
This fixture actually allocates dynamic ports by passing port=0 to the servers. This fixture actually allocates dynamic ports by passing port=0 to the servers.
...@@ -671,7 +680,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co ...@@ -671,7 +680,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co
- If store_kv != "etcd", etcd is not started (returns None) - If store_kv != "etcd", etcd is not started (returns None)
- NATS is always started when etcd is used, because KV events require NATS - NATS is always started when etcd is used, because KV events require NATS
regardless of the request_plane (tcp/nats only affects request transport) regardless of the request_plane (tcp/nats only affects request transport)
- JetStream is enabled by default; disabled when use_nats_core=True for faster startup - NATS Core mode (no JetStream) is the default; JetStream is enabled when durable_kv_events=True
Returns a tuple of (nats_process, etcd_process) where each has a .port attribute. Returns a tuple of (nats_process, etcd_process) where each has a .port attribute.
""" """
...@@ -679,10 +688,10 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co ...@@ -679,10 +688,10 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co
# Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods # Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods
# Always start NATS when etcd is used - KV events require NATS regardless of request_plane # Always start NATS when etcd is used - KV events require NATS regardless of request_plane
# When use_nats_core=True, disable JetStream for faster startup # When durable_kv_events=False (default), disable JetStream for faster startup
if store_kv == "etcd": if store_kv == "etcd":
with NatsServer( with NatsServer(
request, port=0, disable_jetstream=use_nats_core request, port=0, disable_jetstream=not durable_kv_events
) as nats_process: ) as nats_process:
with EtcdServer(request, port=0) as etcd_process: with EtcdServer(request, port=0) as etcd_process:
# Save original env vars (may be set by session-scoped fixture) # Save original env vars (may be set by session-scoped fixture)
...@@ -706,7 +715,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co ...@@ -706,7 +715,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co
os.environ.pop("ETCD_ENDPOINTS", None) os.environ.pop("ETCD_ENDPOINTS", None)
elif request_plane == "nats": elif request_plane == "nats":
with NatsServer( with NatsServer(
request, port=0, disable_jetstream=use_nats_core request, port=0, disable_jetstream=not durable_kv_events
) as nats_process: ) as nats_process:
orig_nats = os.environ.get("NATS_SERVER") orig_nats = os.environ.get("NATS_SERVER")
os.environ["NATS_SERVER"] = f"nats://localhost:{nats_process.port}" os.environ["NATS_SERVER"] = f"nats://localhost:{nats_process.port}"
......
...@@ -50,6 +50,7 @@ class KVRouterProcess(ManagedProcess): ...@@ -50,6 +50,7 @@ class KVRouterProcess(ManagedProcess):
tokens_threshold: float | None = None, tokens_threshold: float | None = None,
tokens_threshold_frac: float | None = None, tokens_threshold_frac: float | None = None,
request_plane: str = "nats", request_plane: str = "nats",
durable_kv_events: bool = False,
): ):
command = [ command = [
"python3", "python3",
...@@ -81,6 +82,9 @@ class KVRouterProcess(ManagedProcess): ...@@ -81,6 +82,9 @@ class KVRouterProcess(ManagedProcess):
["--active-prefill-tokens-threshold-frac", str(tokens_threshold_frac)] ["--active-prefill-tokens-threshold-frac", str(tokens_threshold_frac)]
) )
if durable_kv_events:
command.append("--durable-kv-events")
env = os.environ.copy() env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane env["DYN_REQUEST_PLANE"] = request_plane
...@@ -1335,6 +1339,7 @@ def _test_router_indexers_sync( ...@@ -1335,6 +1339,7 @@ def _test_router_indexers_sync(
request_plane: str = "nats", request_plane: str = "nats",
test_nats_interruption: bool = False, test_nats_interruption: bool = False,
nats_server: Optional["NatsServer"] = None, nats_server: Optional["NatsServer"] = None,
durable_kv_events: bool = False,
): ):
"""Test that two KV routers have synchronized indexer states after processing requests. """Test that two KV routers have synchronized indexer states after processing requests.
...@@ -1365,6 +1370,7 @@ def _test_router_indexers_sync( ...@@ -1365,6 +1370,7 @@ def _test_router_indexers_sync(
request_plane: Request plane to use ("nats" or "tcp"). Defaults to "nats". request_plane: Request plane to use ("nats" or "tcp"). Defaults to "nats".
test_nats_interruption: If True, test NATS interruption recovery. Defaults to False. test_nats_interruption: If True, test NATS interruption recovery. Defaults to False.
nats_server: NatsServer instance for stop/start (required if test_nats_interruption=True). nats_server: NatsServer instance for stop/start (required if test_nats_interruption=True).
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
Raises: Raises:
AssertionError: If router states don't synchronize correctly or snapshot is missing AssertionError: If router states don't synchronize correctly or snapshot is missing
...@@ -1375,7 +1381,10 @@ def _test_router_indexers_sync( ...@@ -1375,7 +1381,10 @@ def _test_router_indexers_sync(
# Use async to manage the test flow # Use async to manage the test flow
async def test_sync(): async def test_sync():
# Create KvRouterConfig with lower snapshot threshold for testing # Create KvRouterConfig with lower snapshot threshold for testing
kv_router_config = KvRouterConfig(router_snapshot_threshold=20) kv_router_config = KvRouterConfig(
router_snapshot_threshold=20,
durable_kv_events=durable_kv_events,
)
async def send_requests_to_router(router, num_requests, router_name, endpoint): async def send_requests_to_router(router, num_requests, router_name, endpoint):
# Now send the actual requests # Now send the actual requests
...@@ -1690,6 +1699,7 @@ def _test_router_decisions_disagg( ...@@ -1690,6 +1699,7 @@ def _test_router_decisions_disagg(
test_payload: dict, test_payload: dict,
store_backend: str = "etcd", store_backend: str = "etcd",
request_plane: str = "nats", request_plane: str = "nats",
durable_kv_events: bool = False,
): ):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend. """Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend.
...@@ -1711,6 +1721,7 @@ def _test_router_decisions_disagg( ...@@ -1711,6 +1721,7 @@ def _test_router_decisions_disagg(
frontend_port: Port for the frontend HTTP server frontend_port: Port for the frontend HTTP server
test_payload: Base test payload to send to /v1/chat/completions test_payload: Base test payload to send to /v1/chat/completions
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd". store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
Raises: Raises:
AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure) AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
...@@ -1730,6 +1741,7 @@ def _test_router_decisions_disagg( ...@@ -1730,6 +1741,7 @@ def _test_router_decisions_disagg(
store_backend, store_backend,
enforce_disagg=True, enforce_disagg=True,
request_plane=request_plane, request_plane=request_plane,
durable_kv_events=durable_kv_events,
) )
kv_router.__enter__() kv_router.__enter__()
...@@ -1909,6 +1921,7 @@ def _test_router_decisions( ...@@ -1909,6 +1921,7 @@ def _test_router_decisions(
test_dp_rank: bool = False, test_dp_rank: bool = False,
block_size: int = BLOCK_SIZE, block_size: int = BLOCK_SIZE,
use_kv_events: bool = True, use_kv_events: bool = True,
durable_kv_events: bool = False,
): ):
"""Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes. """Validate KV cache prefix reuse and worker routing by sending requests diverging prefixes.
...@@ -1929,6 +1942,7 @@ def _test_router_decisions( ...@@ -1929,6 +1942,7 @@ def _test_router_decisions(
test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups) test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups)
use_kv_events: If True (default), uses KV events from workers. If False, uses use_kv_events: If True (default), uses KV events from workers. If False, uses
approximate routing with TTL-based expiration (--no-kv-events mode). approximate routing with TTL-based expiration (--no-kv-events mode).
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
Raises: Raises:
AssertionError: If routing decisions don't follow KV cache prefix reuse as expected AssertionError: If routing decisions don't follow KV cache prefix reuse as expected
...@@ -1937,6 +1951,7 @@ def _test_router_decisions( ...@@ -1937,6 +1951,7 @@ def _test_router_decisions(
kv_router_config = KvRouterConfig( kv_router_config = KvRouterConfig(
router_snapshot_threshold=20, router_snapshot_threshold=20,
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events,
) )
kv_push_router = KvPushRouter( kv_push_router = KvPushRouter(
endpoint=endpoint, endpoint=endpoint,
......
...@@ -158,8 +158,9 @@ def _build_mocker_command( ...@@ -158,8 +158,9 @@ def _build_mocker_command(
command.extend(["--watermark", str(mocker_args["watermark"])]) command.extend(["--watermark", str(mocker_args["watermark"])])
if "dp_size" in mocker_args: if "dp_size" in mocker_args:
command.extend(["--data-parallel-size", str(mocker_args["dp_size"])]) command.extend(["--data-parallel-size", str(mocker_args["dp_size"])])
if mocker_args.get("enable_local_indexer"): # Use --durable-kv-events to enable JetStream mode (local indexer disabled)
command.append("--enable-local-indexer") if mocker_args.get("durable_kv_events") is True:
command.append("--durable-kv-events")
if "bootstrap_ports" in mocker_args: if "bootstrap_ports" in mocker_args:
command.extend(["--bootstrap-ports", mocker_args["bootstrap_ports"]]) command.extend(["--bootstrap-ports", mocker_args["bootstrap_ports"]])
...@@ -325,14 +326,14 @@ class DisaggMockerProcess: ...@@ -325,14 +326,14 @@ class DisaggMockerProcess:
@pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up @pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_nats_core", [True], indirect=True "durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer) ) # Use NATS Core (local indexer)
def test_mocker_kv_router( def test_mocker_kv_router(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
request_plane, request_plane,
use_nats_core, durable_kv_events,
): ):
""" """
Test KV router with multiple mocker engine instances. Test KV router with multiple mocker engine instances.
...@@ -347,7 +348,7 @@ def test_mocker_kv_router( ...@@ -347,7 +348,7 @@ def test_mocker_kv_router(
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core, "durable_kv_events": durable_kv_events,
} }
try: try:
...@@ -385,7 +386,7 @@ def test_mocker_kv_router( ...@@ -385,7 +386,7 @@ def test_mocker_kv_router(
@pytest.mark.parametrize("store_backend", ["etcd", "file"]) @pytest.mark.parametrize("store_backend", ["etcd", "file"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_nats_core", [True], indirect=True "durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer) ) # Use NATS Core (local indexer)
@pytest.mark.timeout(60) # ~3x average (~19.86s), rounded up @pytest.mark.timeout(60) # ~3x average (~19.86s), rounded up
def test_mocker_two_kv_router( def test_mocker_two_kv_router(
...@@ -394,7 +395,7 @@ def test_mocker_two_kv_router( ...@@ -394,7 +395,7 @@ def test_mocker_two_kv_router(
predownload_tokenizers, predownload_tokenizers,
file_storage_backend, file_storage_backend,
store_backend, store_backend,
use_nats_core, durable_kv_events,
): ):
""" """
Test with two KV routers and multiple mocker engine instances. Test with two KV routers and multiple mocker engine instances.
...@@ -411,7 +412,7 @@ def test_mocker_two_kv_router( ...@@ -411,7 +412,7 @@ def test_mocker_two_kv_router(
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core, "durable_kv_events": durable_kv_events,
} }
try: try:
...@@ -440,7 +441,7 @@ def test_mocker_two_kv_router( ...@@ -440,7 +441,7 @@ def test_mocker_two_kv_router(
test_payload=TEST_PAYLOAD, test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS, num_requests=NUM_REQUESTS,
store_backend=store_backend, store_backend=store_backend,
skip_consumer_verification=use_nats_core, # Skip JetStream checks in NATS Core mode skip_consumer_verification=not durable_kv_events, # Skip JetStream checks in NATS Core mode
) )
finally: finally:
...@@ -450,11 +451,11 @@ def test_mocker_two_kv_router( ...@@ -450,11 +451,11 @@ def test_mocker_two_kv_router(
@pytest.mark.skip(reason="Flaky, temporarily disabled") @pytest.mark.skip(reason="Flaky, temporarily disabled")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_nats_core", [True], indirect=True "durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer) ) # Use NATS Core (local indexer)
@pytest.mark.timeout(60) # ~3x average (~19.86s), rounded up (when enabled) @pytest.mark.timeout(60) # ~3x average (~19.86s), rounded up (when enabled)
def test_mocker_kv_router_overload_503( def test_mocker_kv_router_overload_503(
request, runtime_services_dynamic_ports, predownload_tokenizers, use_nats_core request, runtime_services_dynamic_ports, predownload_tokenizers, durable_kv_events
): ):
"""Test that KV router returns 503 when mocker workers are overloaded.""" """Test that KV router returns 503 when mocker workers are overloaded."""
logger.info("Starting mocker KV router overload test for 503 status") logger.info("Starting mocker KV router overload test for 503 status")
...@@ -463,7 +464,7 @@ def test_mocker_kv_router_overload_503( ...@@ -463,7 +464,7 @@ def test_mocker_kv_router_overload_503(
"speedup_ratio": 10, "speedup_ratio": 10,
"block_size": 4, # Smaller block size "block_size": 4, # Smaller block size
"num_gpu_blocks": 64, # Limited GPU blocks to exhaust quickly "num_gpu_blocks": 64, # Limited GPU blocks to exhaust quickly
"enable_local_indexer": use_nats_core, "durable_kv_events": durable_kv_events,
} }
try: try:
...@@ -494,14 +495,14 @@ def test_mocker_kv_router_overload_503( ...@@ -494,14 +495,14 @@ def test_mocker_kv_router_overload_503(
@pytest.mark.timeout(22) # ~3x average (~7.10s), rounded up @pytest.mark.timeout(22) # ~3x average (~7.10s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_nats_core", [True], indirect=True "durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer) ) # Use NATS Core (local indexer)
def test_kv_push_router_bindings( def test_kv_push_router_bindings(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
request_plane, request_plane,
use_nats_core, durable_kv_events,
): ):
"""Test KvPushRouter Python bindings with mocker engines.""" """Test KvPushRouter Python bindings with mocker engines."""
logger.info("Starting KvPushRouter bindings test") logger.info("Starting KvPushRouter bindings test")
...@@ -509,7 +510,7 @@ def test_kv_push_router_bindings( ...@@ -509,7 +510,7 @@ def test_kv_push_router_bindings(
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core, "durable_kv_events": durable_kv_events,
} }
try: try:
...@@ -545,18 +546,18 @@ def test_kv_push_router_bindings( ...@@ -545,18 +546,18 @@ def test_kv_push_router_bindings(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"store_backend,use_nats_core,request_plane", "store_backend,durable_kv_events,request_plane",
[ [
("etcd", False, "nats"), # JetStream mode - uses JetStream (default) ("etcd", True, "nats"), # JetStream mode - uses JetStream
("etcd", True, "tcp"), # NATS core mode (with gap detection) - no JetStream ("etcd", False, "tcp"), # NATS core mode (with gap detection) - no JetStream
("file", False, "nats"), # File backend - uses JetStream (default) ("file", True, "nats"), # File backend - uses JetStream
], ],
ids=[ ids=[
"jetstream", "jetstream",
"nats_core", "nats_core",
"file", "file",
], ],
indirect=["request_plane", "use_nats_core"], indirect=["request_plane", "durable_kv_events"],
) )
@pytest.mark.timeout(90) # TODO: figure out a timeout @pytest.mark.timeout(90) # TODO: figure out a timeout
def test_indexers_sync( def test_indexers_sync(
...@@ -565,7 +566,7 @@ def test_indexers_sync( ...@@ -565,7 +566,7 @@ def test_indexers_sync(
predownload_tokenizers, predownload_tokenizers,
file_storage_backend, file_storage_backend,
store_backend, store_backend,
use_nats_core, durable_kv_events,
request_plane, request_plane,
): ):
""" """
...@@ -580,7 +581,7 @@ def test_indexers_sync( ...@@ -580,7 +581,7 @@ def test_indexers_sync(
""" """
logger.info( logger.info(
f"Starting indexers sync test: store_backend={store_backend}, " f"Starting indexers sync test: store_backend={store_backend}, "
f"use_nats_core={use_nats_core}, request_plane={request_plane}" f"durable_kv_events={durable_kv_events}, request_plane={request_plane}"
) )
# Use the dynamic-port fixture to avoid hardcoded localhost:4222/2379 in parallel runs. # Use the dynamic-port fixture to avoid hardcoded localhost:4222/2379 in parallel runs.
...@@ -591,7 +592,7 @@ def test_indexers_sync( ...@@ -591,7 +592,7 @@ def test_indexers_sync(
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core, "durable_kv_events": durable_kv_events,
"dp_size": 2, "dp_size": 2,
} }
...@@ -610,6 +611,7 @@ def test_indexers_sync( ...@@ -610,6 +611,7 @@ def test_indexers_sync(
# Use the common test implementation (creates its own runtimes for each router) # Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# When using durable_kv_events=True, use JetStream mode for the router
_test_router_indexers_sync( _test_router_indexers_sync(
engine_workers=mockers, engine_workers=mockers,
block_size=BLOCK_SIZE, block_size=BLOCK_SIZE,
...@@ -617,8 +619,9 @@ def test_indexers_sync( ...@@ -617,8 +619,9 @@ def test_indexers_sync(
num_workers=NUM_MOCKERS, num_workers=NUM_MOCKERS,
store_backend=store_backend, store_backend=store_backend,
request_plane=request_plane, request_plane=request_plane,
test_nats_interruption=use_nats_core, test_nats_interruption=not durable_kv_events,
nats_server=nats_process if use_nats_core else None, nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events,
) )
logger.info("Indexers sync test completed successfully") logger.info("Indexers sync test completed successfully")
...@@ -630,10 +633,10 @@ def test_indexers_sync( ...@@ -630,10 +633,10 @@ def test_indexers_sync(
@pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up @pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_nats_core", [True], indirect=True "durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer) ) # Use NATS Core (local indexer)
def test_query_instance_id_returns_worker_and_tokens( def test_query_instance_id_returns_worker_and_tokens(
request, runtime_services_dynamic_ports, predownload_tokenizers, use_nats_core request, runtime_services_dynamic_ports, predownload_tokenizers, durable_kv_events
): ):
"""Test query_instance_id annotation with mocker engines.""" """Test query_instance_id annotation with mocker engines."""
logger.info("Starting KV router query_instance_id annotation test") logger.info("Starting KV router query_instance_id annotation test")
...@@ -641,7 +644,7 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -641,7 +644,7 @@ def test_query_instance_id_returns_worker_and_tokens(
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core, "durable_kv_events": durable_kv_events,
} }
os.makedirs(request.node.name, exist_ok=True) os.makedirs(request.node.name, exist_ok=True)
...@@ -674,55 +677,46 @@ def test_query_instance_id_returns_worker_and_tokens( ...@@ -674,55 +677,46 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up @pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_nats_core,use_kv_events", "durable_kv_events,use_kv_events",
[ [
(False, True), # JetStream mode (default) - uses JetStream (True, True), # JetStream mode with KV events
(True, True), # NATS Core + local indexer mode - no JetStream (False, True), # NATS Core mode with local indexer (default)
(False, False), # Approximate mode (--no-kv-events) - uses JetStream (False, False), # Approximate mode (--no-kv-events) - no KV events
], ],
ids=["jetstream", "nats_core", "no_kv_events"], ids=["jetstream", "nats_core", "no_kv_events"],
indirect=["use_nats_core"], indirect=["durable_kv_events"],
) )
def test_router_decisions( def test_router_decisions(
request, request,
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
use_nats_core, durable_kv_events,
use_kv_events, use_kv_events,
request_plane, request_plane,
): ):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes. """Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
Parameterized to test: Parameterized to test:
- JetStream mode (default): KV events via JetStream - JetStream mode: KV events via NATS JetStream (durable)
- NATS Core mode: KV events via NATS Core with local indexer on workers - NATS Core mode (default): KV events via NATS Core with local indexer on workers
- Approximate mode (--no-kv-events): No KV events, router predicts cache state - Approximate mode (--no-kv-events): No KV events, router predicts cache state
based on routing decisions with TTL-based expiration and pruning based on routing decisions with TTL-based expiration and pruning
""" """
# runtime_services_dynamic_ports handles NATS and etcd startup # runtime_services_dynamic_ports handles NATS and etcd startup
if not use_kv_events:
mode = "Approximate (no-kv-events)"
elif use_nats_core:
mode = "NATS Core (local indexer)"
else:
mode = "JetStream"
logger.info( logger.info(
f"Starting test router prefix reuse and KV events synchronization ({mode})" f"Starting test router decisions: durable_kv_events={durable_kv_events}, use_kv_events={use_kv_events}"
) )
# Create mocker args dictionary with dp_size=4 # Create mocker args dictionary with dp_size=4
# Note: enable_local_indexer only applies when use_kv_events=True and use_nats_core=True # durable_kv_events=True enables JetStream mode; False (default) uses NATS Core with local indexer
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"dp_size": 4, "dp_size": 4,
"enable_local_indexer": use_nats_core and use_kv_events, "durable_kv_events": durable_kv_events and use_kv_events,
} }
try: try:
logger.info(
f"Starting 2 mocker instances with dp_size=4 each (8 total dp ranks), {mode}"
)
mockers = MockerProcess( mockers = MockerProcess(
request, request,
mocker_args=mocker_args, mocker_args=mocker_args,
...@@ -748,6 +742,7 @@ def test_router_decisions( ...@@ -748,6 +742,7 @@ def test_router_decisions(
request, request,
test_dp_rank=True, test_dp_rank=True,
use_kv_events=use_kv_events, use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events,
) )
finally: finally:
...@@ -786,11 +781,11 @@ def test_router_decisions_disagg( ...@@ -786,11 +781,11 @@ def test_router_decisions_disagg(
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
shared_namespace = f"test-namespace-{namespace_suffix}" shared_namespace = f"test-namespace-{namespace_suffix}"
# Create mocker args - use JetStream for KV events (more reliable than NATS Core) # Create mocker args - use NATS Core with local indexer (default mode)
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"enable_local_indexer": False, # durable_kv_events defaults to False (NATS Core mode)
} }
prefill_workers = None prefill_workers = None
...@@ -877,7 +872,7 @@ def test_router_decisions_disagg( ...@@ -877,7 +872,7 @@ def test_router_decisions_disagg(
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True) @pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"use_nats_core", [True], indirect=True "durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer) ) # Use NATS Core (local indexer)
@pytest.mark.timeout(39) # ~3x average (~12.84s), rounded up @pytest.mark.timeout(39) # ~3x average (~12.84s), rounded up
def test_busy_threshold_endpoint( def test_busy_threshold_endpoint(
...@@ -885,7 +880,7 @@ def test_busy_threshold_endpoint( ...@@ -885,7 +880,7 @@ def test_busy_threshold_endpoint(
runtime_services_dynamic_ports, runtime_services_dynamic_ports,
predownload_tokenizers, predownload_tokenizers,
request_plane, request_plane,
use_nats_core, durable_kv_events,
): ):
"""Test that the /busy_threshold endpoint can be hit and responds correctly. """Test that the /busy_threshold endpoint can be hit and responds correctly.
...@@ -905,7 +900,7 @@ def test_busy_threshold_endpoint( ...@@ -905,7 +900,7 @@ def test_busy_threshold_endpoint(
mocker_args = { mocker_args = {
"speedup_ratio": SPEEDUP_RATIO, "speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE, "block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core, "durable_kv_events": durable_kv_events,
} }
try: try:
......
...@@ -87,7 +87,7 @@ class SGLangProcess: ...@@ -87,7 +87,7 @@ class SGLangProcess:
data_parallel_size: Optional[int] = None, data_parallel_size: Optional[int] = None,
request_plane: str = "tcp", request_plane: str = "tcp",
store_backend: str = "etcd", store_backend: str = "etcd",
enable_local_indexer: bool = False, durable_kv_events: bool = False,
): ):
"""Initialize SGLang workers with dynamo integration. """Initialize SGLang workers with dynamo integration.
...@@ -104,7 +104,7 @@ class SGLangProcess: ...@@ -104,7 +104,7 @@ class SGLangProcess:
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size) data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp". request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd". store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False. durable_kv_events: If True, use JetStream for durable KV events. Defaults to False (NATS Core mode).
""" """
# Generate unique namespace for isolation # Generate unique namespace for isolation
namespace_suffix = generate_random_suffix() namespace_suffix = generate_random_suffix()
...@@ -185,6 +185,10 @@ class SGLangProcess: ...@@ -185,6 +185,10 @@ class SGLangProcess:
kv_events_config = f'{{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:{kv_events_port}"}}' kv_events_config = f'{{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:{kv_events_port}"}}'
command.extend(["--kv-events-config", kv_events_config]) command.extend(["--kv-events-config", kv_events_config])
# Use --durable-kv-events to enable JetStream mode (local indexer disabled)
if durable_kv_events:
command.append("--durable-kv-events")
env = os.environ.copy() # Copy parent environment env = os.environ.copy() # Copy parent environment
env_vars = { env_vars = {
"CUDA_VISIBLE_DEVICES": gpu_device, "CUDA_VISIBLE_DEVICES": gpu_device,
...@@ -197,10 +201,6 @@ class SGLangProcess: ...@@ -197,10 +201,6 @@ class SGLangProcess:
if self.store_backend == "file" and "DYN_FILE_KV" in os.environ: if self.store_backend == "file" and "DYN_FILE_KV" in os.environ:
env_vars["DYN_FILE_KV"] = os.environ["DYN_FILE_KV"] env_vars["DYN_FILE_KV"] = os.environ["DYN_FILE_KV"]
# Enable local indexer for NATS Core mode
if enable_local_indexer:
env_vars["DYN_LOCAL_INDEXER"] = "true"
env.update(env_vars) env.update(env_vars)
# Create managed process for the worker # Create managed process for the worker
...@@ -475,13 +475,12 @@ def test_router_decisions_sglang_dp( ...@@ -475,13 +475,12 @@ def test_router_decisions_sglang_dp(
@pytest.mark.pre_merge @pytest.mark.pre_merge
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.parametrize( @pytest.mark.parametrize(
"store_backend,use_nats_core,request_plane", "store_backend,durable_kv_events,request_plane",
[ [
("etcd", False, "nats"), # JetStream mode ("etcd", False, "tcp"),
# ("etcd", True, "tcp"), # nats_core mode - disabled for now
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for SGLang
], ],
ids=["jetstream"], ids=["nats_core"],
indirect=["durable_kv_events", "request_plane"],
) )
@pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up @pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up
def test_sglang_indexers_sync( def test_sglang_indexers_sync(
...@@ -491,7 +490,7 @@ def test_sglang_indexers_sync( ...@@ -491,7 +490,7 @@ def test_sglang_indexers_sync(
file_storage_backend, file_storage_backend,
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
store_backend, store_backend,
use_nats_core, durable_kv_events,
request_plane, request_plane,
): ):
""" """
...@@ -499,15 +498,15 @@ def test_sglang_indexers_sync( ...@@ -499,15 +498,15 @@ def test_sglang_indexers_sync(
with SGLang workers. This test verifies that both routers converge to the same internal state. with SGLang workers. This test verifies that both routers converge to the same internal state.
Tests with configuration: Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane - nats_core: etcd backend, local indexer with NATS Core, TCP request plane
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane (includes NATS interruption/recovery testing)
""" """
# runtime_services_dynamic_ports handles NATS and etcd startup # runtime_services_dynamic_ports handles NATS and etcd startup
nats_process, _etcd_process = runtime_services_dynamic_ports nats_process, _etcd_process = runtime_services_dynamic_ports
logger.info( logger.info(
f"Starting SGLang indexers sync test: store_backend={store_backend}, " f"Starting SGLang indexers sync test: store_backend={store_backend}, "
f"use_nats_core={use_nats_core}, request_plane={request_plane}" f"durable_kv_events={durable_kv_events}, request_plane={request_plane}"
) )
N_SGLANG_WORKERS = 2 N_SGLANG_WORKERS = 2
...@@ -522,13 +521,14 @@ def test_sglang_indexers_sync( ...@@ -522,13 +521,14 @@ def test_sglang_indexers_sync(
single_gpu=True, # fit workers into one GPU single_gpu=True, # fit workers into one GPU
request_plane=request_plane, request_plane=request_plane,
store_backend=store_backend, store_backend=store_backend,
enable_local_indexer=use_nats_core, durable_kv_events=durable_kv_events,
) )
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}") logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__() sglang_workers.__enter__()
# Use the common test implementation (creates its own runtimes for each router) # Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive # Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# When using durable_kv_events=True, use JetStream mode for the router
_test_router_indexers_sync( _test_router_indexers_sync(
engine_workers=sglang_workers, engine_workers=sglang_workers,
block_size=PAGE_SIZE, block_size=PAGE_SIZE,
...@@ -536,8 +536,9 @@ def test_sglang_indexers_sync( ...@@ -536,8 +536,9 @@ def test_sglang_indexers_sync(
num_workers=N_SGLANG_WORKERS, num_workers=N_SGLANG_WORKERS,
store_backend=store_backend, store_backend=store_backend,
request_plane=request_plane, request_plane=request_plane,
test_nats_interruption=use_nats_core, test_nats_interruption=not durable_kv_events,
nats_server=nats_process if use_nats_core else None, nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events,
) )
logger.info("SGLang indexers sync test completed successfully") logger.info("SGLang indexers sync test completed successfully")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment