"lib/vscode:/vscode.git/clone" did not exist on "3577b5c138465110a44dd1919d731d7fc29c1c26"
Unverified commit 6783bdca, authored by Yan Ru Pei, committed by GitHub
Browse files

chore: enable local indexers by default, and use normal event plane by default...


chore: enable local indexers by default, and use normal event plane by default (not jetstream) (#5941)
Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 3d7182b8
......@@ -126,7 +126,7 @@ impl ZmqKvEventPublisherConfig {
kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string(),
enable_local_indexer = false,
enable_local_indexer = true,
dp_rank = 0
))]
pub fn new(
......
......@@ -419,6 +419,15 @@ class ModelDeploymentCard:
A model deployment card is a collection of model information
"""
def to_json_str(self) -> str:
"""Serialize the model deployment card to a JSON string."""
...
@staticmethod
def from_json_str(json: str) -> "ModelDeploymentCard":
"""Deserialize a model deployment card from a JSON string."""
...
...
class ModelRuntimeConfig:
......@@ -737,7 +746,7 @@ class ZmqKvEventPublisherConfig:
kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = "",
enable_local_indexer: bool = False,
enable_local_indexer: bool = True,
dp_rank: int = 0
) -> None:
"""
......@@ -747,7 +756,7 @@ class ZmqKvEventPublisherConfig:
:param kv_block_size: The block size for the key-value store.
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to True.
:param dp_rank: The data parallel rank for this publisher. Defaults to 0.
"""
...
......@@ -924,10 +933,34 @@ class ModelType:
class RouterMode:
"""Router mode for load balancing requests across workers"""
# Each mode is declared as a class-level attribute annotated with the class's
# own type; the annotation is quoted because `RouterMode` is not yet bound
# while its own body is being evaluated.
RoundRobin: "RouterMode"
Random: "RouterMode"
# `RouterConfig` documents its `config` argument as "used when mode is KV".
KV: "RouterMode"
...
class RouterConfig:
"""How to route the request"""
# NOTE(review): the body below is elided with `...`, which suggests this is a
# type-stub declaration whose implementation lives elsewhere -- confirm.
def __init__(
self,
mode: RouterMode,
# `KvRouterConfig` is declared after this class in the file, so this
# unquoted annotation is a forward reference.
config: Optional[KvRouterConfig] = None,
# Every threshold defaults to None; what None means for busy detection is not
# visible here -- TODO confirm in the implementation.
active_decode_blocks_threshold: Optional[float] = None,
active_prefill_tokens_threshold: Optional[int] = None,
active_prefill_tokens_threshold_frac: Optional[float] = None,
enforce_disagg: bool = False,
) -> None:
"""
Create a RouterConfig.
Args:
mode: The router mode (RoundRobin, Random, or KV)
config: Optional KV router configuration (used when mode is KV)
active_decode_blocks_threshold: Threshold percentage (0.0-1.0) for decode blocks busy detection
active_prefill_tokens_threshold: Literal token count threshold for prefill busy detection
active_prefill_tokens_threshold_frac: Fraction of max_num_batched_tokens for busy detection
enforce_disagg: Enforce disaggregated prefill-decode mode
"""
...
class KvRouterConfig:
......@@ -938,6 +971,7 @@ class KvRouterConfig:
overlap_score_weight: float = 1.0,
router_temperature: float = 0.0,
use_kv_events: bool = True,
durable_kv_events: bool = False,
router_replica_sync: bool = False,
router_track_active_blocks: bool = True,
router_track_output_blocks: bool = False,
......@@ -955,6 +989,9 @@ class KvRouterConfig:
overlap_score_weight: Weight for overlap score in worker selection (default: 1.0)
router_temperature: Temperature for worker sampling via softmax (default: 0.0)
use_kv_events: Whether to use KV events from workers (default: True)
durable_kv_events: Enable durable KV events using NATS JetStream (default: False).
When False, uses NATS Core / generic event plane with local_indexer mode.
When True, uses JetStream for durability and multi-replica consistency.
router_replica_sync: Enable replica synchronization (default: False)
router_track_active_blocks: Track active blocks for load balancing (default: True)
router_track_output_blocks: Track output blocks during generation (default: False).
......@@ -1026,7 +1063,7 @@ class EngineConfig:
"""Holds internal configuration for a Dynamo engine."""
...
async def make_engine(args: EntrypointArgs) -> EngineConfig:
async def make_engine(distributed_runtime: DistributedRuntime, args: EntrypointArgs) -> EngineConfig:
"""Make an engine matching the args"""
...
......@@ -1435,12 +1472,63 @@ class KvPushRouter:
"""
...
class EngineType:
"""Engine type for Dynamo workers"""
# Members are class-level attributes annotated with the class's own type; the
# annotation is quoted because `EngineType` is not yet bound inside its own body.
# `EntrypointArgs.__init__` takes one of these as its required `engine_type`.
Echo: "EngineType"
Dynamic: "EngineType"
Mocker: "EngineType"
...
class EntrypointArgs:
"""
Settings to connect an input to a worker and run them.
Used by `dynamo run`.
"""
def __init__(
self,
# The only required argument; everything else has a default.
engine_type: "EngineType",
model_path: Optional[str] = None,
model_name: Optional[str] = None,
endpoint_id: Optional[str] = None,
context_length: Optional[int] = None,
template_file: Optional[str] = None,
router_config: Optional[RouterConfig] = None,
kv_cache_block_size: Optional[int] = None,
http_host: Optional[str] = None,
http_port: Optional[int] = None,
http_metrics_port: Optional[int] = None,
tls_cert_path: Optional[str] = None,
tls_key_path: Optional[str] = None,
extra_engine_args: Optional[str] = None,
namespace: Optional[str] = None,
is_prefill: bool = False,
migration_limit: int = 0,
# NOTE(review): the callback's expected signature is not visible in this
# declaration (bare `Callable`) -- confirm against the binding before relying on it.
engine_factory: Optional[Callable] = None,
) -> None:
"""
Create EntrypointArgs.
Args:
engine_type: The type of engine to use
model_path: Path to the model directory on disk
model_name: Model name or dynamo endpoint (e.g. 'dyn://namespace.component.endpoint')
endpoint_id: Optional endpoint ID
context_length: Optional context length override
template_file: Optional path to a prompt template file
router_config: Optional router configuration
kv_cache_block_size: Optional KV cache block size
http_host: HTTP host to bind to
http_port: HTTP port to bind to
http_metrics_port: HTTP metrics port (for gRPC service)
tls_cert_path: TLS certificate path (PEM format)
tls_key_path: TLS key path (PEM format)
extra_engine_args: Path to extra engine arguments file
namespace: Dynamo namespace for model discovery scoping
is_prefill: Whether this is a prefill worker (default: False)
migration_limit: Maximum number of request migrations (0=disabled, which is the default)
engine_factory: Optional Python engine factory callback
"""
...
class PlannerDecision:
......
......@@ -5,7 +5,7 @@ mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError};
pub(crate) mod runtime_configs;
pub use runtime_configs::{RuntimeConfigs, RuntimeConfigsSubscriber};
pub use runtime_configs::{RuntimeConfigWatch, runtime_config_watch};
mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher};
......
......@@ -11,7 +11,7 @@ use parking_lot::RwLock;
use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig;
use super::{KvWorkerMonitor, RuntimeConfigs};
use super::{KvWorkerMonitor, RuntimeConfigWatch, runtime_config_watch};
use dynamo_runtime::{
component::{Client, Endpoint, build_transport_type},
......@@ -77,7 +77,7 @@ pub struct ModelManager {
// Per-model monitoring: worker_monitors for load-based rejection, runtime_configs for KvScheduler
worker_monitors: DashMap<String, KvWorkerMonitor>,
runtime_configs: DashMap<EndpointId, Arc<RuntimeConfigs>>,
runtime_configs: DashMap<EndpointId, RuntimeConfigWatch>,
}
impl Default for ModelManager {
......@@ -563,12 +563,12 @@ impl ModelManager {
}
/// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task to watch for worker config changes.
/// Returns a shared RuntimeConfigs that KvScheduler can use directly.
/// Spawns a background task that joins instance availability and config discovery.
/// Returns a `watch::Receiver` with the latest `HashMap<WorkerId, ModelRuntimeConfig>`.
pub async fn get_or_create_runtime_config_watcher(
&self,
endpoint: &Endpoint,
) -> anyhow::Result<Arc<RuntimeConfigs>> {
) -> anyhow::Result<RuntimeConfigWatch> {
let endpoint_id = endpoint.id();
// Fast path: return existing if present
......@@ -576,21 +576,18 @@ impl ModelManager {
return Ok(existing.clone());
}
// Atomic get-or-insert to avoid TOCTOU race
let inner = Arc::new(RuntimeConfigs::new());
let (result, is_new) = match self.runtime_configs.entry(endpoint_id) {
Entry::Occupied(e) => (e.get().clone(), false),
// Slow path: create the watch (spawns a background task).
// If another caller raced us, the entry() below picks up the winner;
// the loser's background task stops once its receivers are dropped.
let rx = runtime_config_watch(endpoint).await?;
let result = match self.runtime_configs.entry(endpoint_id) {
Entry::Occupied(e) => e.get().clone(),
Entry::Vacant(e) => {
e.insert(inner.clone());
(inner, true)
e.insert(rx.clone());
rx
}
};
// Only spawn watcher if we were the one who inserted
if is_new {
result.start_watcher(endpoint).await?;
}
Ok(result)
}
......@@ -601,9 +598,9 @@ impl ModelManager {
endpoint_id: &EndpointId,
worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
let inner = self.runtime_configs.get(endpoint_id)?;
let config_ref = inner.configs.get(&worker_id)?;
config_ref.as_ref()?.disaggregated_endpoint.clone()
let rx = self.runtime_configs.get(endpoint_id)?;
let configs = rx.borrow();
configs.get(&worker_id)?.disaggregated_endpoint.clone()
}
/// Lists all models with worker monitors configured.
......
......@@ -2,9 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use dashmap::DashMap;
use tokio::sync::watch;
use dynamo_runtime::component::Endpoint;
......@@ -15,186 +13,72 @@ use crate::kv_router::protocols::WorkerId;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_card::ModelDeploymentCard;
/// Runtime configs for an endpoint with watch-based change notifications.
/// Call `subscribe()` to get a subscriber with its own watch receiver.
pub struct RuntimeConfigs {
pub configs: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
change_tx: watch::Sender<u64>,
}
impl RuntimeConfigs {
pub(crate) fn new() -> Self {
let (change_tx, _) = watch::channel(0u64);
Self {
configs: Arc::new(DashMap::new()),
change_tx,
}
}
/// Create a subscriber that can wait for config changes.
/// Each subscriber has its own watch receiver, so notifications are not lost.
pub fn subscribe(&self) -> RuntimeConfigsSubscriber {
RuntimeConfigsSubscriber {
configs: self.configs.clone(),
change_rx: self.change_tx.subscribe(),
}
}
/// Notify all subscribers of a change (internal use only).
fn notify_change(&self) {
// Increment counter to notify subscribers
self.change_tx.send_modify(|v| *v = v.wrapping_add(1));
}
/// Returns the number of workers in the configs.
pub fn num_workers(&self) -> usize {
self.configs.len()
}
/// Update configs with new worker instances and their configs.
/// Notifies subscribers if a config with Some value is added or a worker is removed.
pub(crate) fn update(
&self,
new_instance_ids: &[WorkerId],
new_configs: &HashMap<WorkerId, ModelRuntimeConfig>,
) {
// First, remove workers that no longer exist
let current_workers: HashSet<WorkerId> = self.configs.iter().map(|r| *r.key()).collect();
let new_workers: HashSet<WorkerId> = new_instance_ids.iter().copied().collect();
let mut worker_removed = false;
for removed_worker in current_workers.difference(&new_workers) {
self.configs.remove(removed_worker);
worker_removed = true;
}
/// Type alias for the runtime config watch receiver.
pub type RuntimeConfigWatch = watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>;
// Then, add/update workers
// Track if any config became Some (for notify)
let mut config_added = false;
for worker_id in new_instance_ids {
let config = new_configs.get(worker_id).cloned();
if config.is_some() {
let prev_config = self.configs.get(worker_id);
let was_none = prev_config
.as_ref()
.map(|r| r.value().is_none())
.unwrap_or(true);
if was_none {
tracing::info!("RuntimeConfigs: config found for worker_id: {worker_id}");
config_added = true;
}
}
self.configs.insert(*worker_id, config);
}
// Notify when a config with Some value is added OR a worker is removed
if config_added || worker_removed {
self.notify_change();
}
}
/// Spawn background task to watch runtime configs via discovery.
/// Does not block - consumers should use `subscribe().wait_for_some()` if they need workers.
pub(crate) async fn start_watcher(self: &Arc<Self>, endpoint: &Endpoint) -> anyhow::Result<()> {
/// Join instance availability and config discovery into a single watch.
///
/// Only includes workers that have BOTH an instance registration AND a runtime config.
/// Spawns a background task that recomputes the joined state whenever either source changes.
/// The returned `watch::Receiver` always contains the latest joined snapshot.
pub async fn runtime_config_watch(endpoint: &Endpoint) -> anyhow::Result<RuntimeConfigWatch> {
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
let cancel_token = component.drt().primary_token();
// Set up discovery watch for EndpointModels
// Source 1: instance availability (watches DiscoveryQuery::Endpoint)
let client = endpoint.client().await?;
let mut instance_ids_rx = client.instance_avail_watcher();
// Source 2: runtime configs from discovery (watches DiscoveryQuery::EndpointModels)
let discovery = component.drt().discovery();
let endpoint_id = endpoint.id();
let discovery_key = DiscoveryQuery::EndpointModels {
namespace: endpoint_id.namespace.clone(),
component: endpoint_id.component.clone(),
endpoint: endpoint_id.name.clone(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
let eid = endpoint.id();
let stream = discovery
.list_and_watch(
DiscoveryQuery::EndpointModels {
namespace: eid.namespace.clone(),
component: eid.component.clone(),
endpoint: eid.name.clone(),
},
Some(cancel_token.clone()),
)
.await?;
let mut configs_rx =
watch_and_extract_field(stream, |card: ModelDeploymentCard| card.runtime_config);
// Extract runtime_config from ModelDeploymentCard
let mut runtime_configs_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
card.runtime_config
});
// Also watch instance IDs
let client = endpoint.client().await?;
let mut instance_ids_rx = client.instance_avail_watcher();
let (tx, rx) = watch::channel(HashMap::new());
// Spawn background task to watch for config changes
// Note: We don't block here - consumers should wait on notify for configs they need
let inner = self.clone();
let cancel_token = cancellation_token.clone();
tokio::spawn(async move {
tracing::trace!("RuntimeConfigs watcher started");
loop {
// Wait for either instances or configs to change
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::trace!("RuntimeConfigs watcher shutting down");
break;
}
result = instance_ids_rx.changed() => {
if result.is_err() {
tracing::warn!("instance IDs watch sender shutdown");
break;
}
}
result = runtime_configs_rx.changed() => {
if result.is_err() {
tracing::warn!("runtime configs watch sender shutdown");
break;
}
}
}
// Get the latest values from both channels
let new_instance_ids = instance_ids_rx.borrow_and_update().clone();
let new_configs = runtime_configs_rx.borrow_and_update().clone();
inner.update(&new_instance_ids, &new_configs);
tracing::trace!(
"RuntimeConfigs: Updated with {} workers",
inner.configs.len()
);
_ = cancel_token.cancelled() => break,
result = instance_ids_rx.changed() => { if result.is_err() { break; } }
result = configs_rx.changed() => { if result.is_err() { break; } }
}
tracing::trace!("RuntimeConfigs watcher stopped");
});
Ok(())
}
}
/// A subscriber to runtime config changes.
/// Each subscriber has its own watch receiver, ensuring no notifications are lost.
pub struct RuntimeConfigsSubscriber {
pub configs: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
pub change_rx: watch::Receiver<u64>,
}
impl RuntimeConfigsSubscriber {
/// Wait until at least one worker has a Some config.
/// Returns the list of worker IDs that have configs.
/// This is race-safe: checks the DashMap first, only waits if empty.
/// Returns empty vec if the sender is dropped (shutdown).
pub async fn wait_for_some(&mut self) -> Vec<WorkerId> {
loop {
let ready: Vec<WorkerId> = self
.configs
let instances: HashSet<WorkerId> = instance_ids_rx
.borrow_and_update()
.iter()
.filter(|r| r.value().is_some())
.map(|r| *r.key())
.copied()
.collect();
let configs = configs_rx.borrow_and_update().clone();
if !ready.is_empty() {
return ready;
}
let ready: HashMap<WorkerId, ModelRuntimeConfig> = instances
.into_iter()
.filter_map(|id| configs.get(&id).map(|cfg| (id, cfg.clone())))
.collect();
// If sender dropped (shutdown), return empty rather than loop forever
if self.change_rx.changed().await.is_err() {
tracing::warn!("RuntimeConfigsSubscriber: sender dropped during wait_for_some");
return vec![];
// Only send if the joined result actually changed, to avoid waking
// downstream consumers (wait_for, changed) on no-op recomputations.
if *tx.borrow() == ready {
continue;
}
// Break if all receivers dropped (e.g., TOCTOU in model_manager discards a duplicate).
if tx.send(ready).is_err() {
break;
}
}
});
Ok(rx)
}
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use derive_builder::Builder;
use rand::Rng;
use serde::{Deserialize, Serialize};
use validator::Validate;
use crate::kv_router::protocols::{compute_block_hash_for_seq, compute_seq_hash_for_block};
/// Override configuration for router settings that can be specified per-request.
///
/// Both fields are optional and default to `None` (through `Default` and
/// `#[builder(default)]`).
/// NOTE(review): `None` presumably means "keep the value from `KvRouterConfig`";
/// the code that applies the override is not in this file -- confirm.
#[derive(Debug, Clone, Default, Builder, Serialize, Deserialize, Validate)]
pub struct RouterConfigOverride {
/// Per-request overlap score weight.
/// Unlike `KvRouterConfig::overlap_score_weight`, no range validation is attached here.
#[builder(default)]
pub overlap_score_weight: Option<f64>,
/// Per-request router temperature. Validation requires a value >= 0.0.
#[builder(default)]
#[validate(range(min = 0.0))]
pub router_temperature: Option<f64>,
}
/// KV Router configuration parameters.
///
/// The struct is `Copy`, so it can be passed around by value. Range constraints are
/// declared through `#[validate(...)]` attributes; default values are set in the
/// `Default` implementation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Validate)]
pub struct KvRouterConfig {
/// Weight for overlap score in worker selection (default: 1.0). Must be >= 0.0.
/// `should_subscribe_to_kv_events` returns false unless this is > 0.0.
#[validate(range(min = 0.0))]
pub overlap_score_weight: f64,
/// Temperature for worker sampling via softmax (default: 0.0). Must be >= 0.0.
#[validate(range(min = 0.0))]
pub router_temperature: f64,
/// Whether to use KV events from workers (default: true).
/// Several settings below only apply when this is false.
pub use_kv_events: bool,
/// Enable durable KV events using NATS JetStream instead of the default event plane.
/// When false (default), the router uses the event-plane subscriber and requires
/// workers to have local_indexer enabled for gap recovery.
/// When true, uses JetStream for durability and multi-replica consistency.
pub durable_kv_events: bool,
/// Enable replica synchronization (default: false)
pub router_replica_sync: bool,
/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,
/// Whether to track output blocks during generation (default: false)
/// When enabled, the router adds placeholder blocks as tokens are generated
/// and applies fractional decay based on progress toward expected_output_tokens.
pub router_track_output_blocks: bool,
/// Whether to assume KV cache reuse when tracking active blocks (default: true).
/// When true, computes actual block hashes for sequence tracking.
/// When false, generates random hashes (assuming no KV cache reuse).
pub router_assume_kv_reuse: bool,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
/// When set, validation requires a value >= 1.
#[validate(range(min = 1))]
pub router_snapshot_threshold: Option<u32>,
/// Whether to reset the router state on startup (default: false)
pub router_reset_states: bool,
/// TTL for blocks in seconds (only used when use_kv_events is false, default: 120.0)
#[validate(range(min = 0.0))]
pub router_ttl_secs: f64,
/// Maximum tree size before pruning (only used when use_kv_events is false, default: 2^20 = 1048576)
#[validate(range(min = 1))]
pub router_max_tree_size: usize,
/// Target size ratio after pruning (only used when use_kv_events is false, default: 0.8)
#[validate(range(min = 0.0, max = 1.0))]
pub router_prune_target_ratio: f64,
}
impl Default for KvRouterConfig {
    /// Build the out-of-the-box router configuration.
    ///
    /// KV events are on but non-durable (NATS Core / local indexer mode), active block
    /// tracking and KV-reuse assumption are on, and everything else is off or set to
    /// the sizing defaults below.
    fn default() -> Self {
        // 2^20 = 1048576, matches PruneConfig::default()
        let max_tree_size: usize = 1usize << 20;

        Self {
            // Scoring / sampling
            overlap_score_weight: 1.0,
            router_temperature: 0.0,

            // Event plane: default to NATS Core (local indexer mode)
            use_kv_events: true,
            durable_kv_events: false,
            router_replica_sync: false,

            // Block tracking
            router_track_active_blocks: true,
            router_track_output_blocks: false,
            router_assume_kv_reuse: true,

            // Snapshots and startup state
            router_snapshot_threshold: Some(1_000_000),
            router_reset_states: false,

            // Approximate-mode (use_kv_events = false) tree management
            router_ttl_secs: 120.0,
            router_max_tree_size: max_tree_size,
            router_prune_target_ratio: 0.8,
        }
    }
}
impl KvRouterConfig {
    /// Produce the sequence hashes used for active block tracking.
    ///
    /// Only whole blocks count: the number of blocks is `tokens.len() / block_size`
    /// (integer division).
    ///
    /// Returns:
    /// - `None` when `router_track_active_blocks` is false
    /// - `Some(empty)` when the tokens do not fill a single block
    /// - Actual sequence hashes when `router_assume_kv_reuse` is true
    /// - One random `u64` per block when `router_assume_kv_reuse` is false
    pub fn compute_seq_hashes_for_tracking(
        &self,
        tokens: &[u32],
        block_size: u32,
    ) -> Option<Vec<u64>> {
        if !self.router_track_active_blocks {
            return None;
        }

        let full_blocks = tokens.len() / block_size as usize;

        let hashes = match (full_blocks, self.router_assume_kv_reuse) {
            // Not even one complete block: nothing to track.
            (0, _) => Vec::new(),
            // KV reuse assumed: derive real block hashes, then chain them into sequence hashes.
            (_, true) => {
                let per_block = compute_block_hash_for_seq(tokens, block_size, None);
                compute_seq_hash_for_block(&per_block)
            }
            // No KV reuse assumed: one random hash per block.
            (count, false) => {
                let mut rng = rand::rng();
                let mut random_hashes = Vec::with_capacity(count);
                for _ in 0..count {
                    random_hashes.push(rng.random::<u64>());
                }
                random_hashes
            }
        };

        Some(hashes)
    }

    /// Decide whether the KV event subscription should be started.
    ///
    /// Returns false if:
    /// - KV events are disabled (`use_kv_events=false`)
    /// - Overlap scoring is disabled (`overlap_score_weight=0`)
    ///
    /// When this is false the router skips the KV event subscription entirely,
    /// so it never needs to query workers for their local indexer state.
    pub fn should_subscribe_to_kv_events(&self) -> bool {
        if !self.use_kv_events {
            return false;
        }
        self.overlap_score_weight > 0.0
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use anyhow::Result;
use dynamo_runtime::{
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait,
},
protocols::annotated::Annotated,
};
use futures::stream::{self, StreamExt};
use serde_json::json;
use crate::{
kv_router::{
KvRouter,
protocols::{TokensWithHashes, WorkerWithDpRank},
},
preprocessor::PreprocessedRequest,
protocols::common::{llm_backend::LLMEngineOutput, timing::RequestPhase},
};
/// Push router that picks a worker with a `KvRouter` before dispatching a request.
pub struct KvPushRouter {
// Underlying push router; `generate` sends the request to the chosen instance via `direct`.
inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
/// Shared KV router used for worker selection and per-request bookkeeping
/// (`find_best_match`, `add_request`, `mark_prefill_completed`, `free`, ...).
pub chooser: Arc<KvRouter>,
}
/// Result of worker selection containing instance ID, dp_rank, and overlap amount.
struct WorkerSelection {
// ID of the worker instance the request is routed to.
instance_id: u64,
// Data-parallel rank on that worker; written into the outgoing request's routing hints.
dp_rank: u32,
// Overlap reported by the KV router for the chosen worker; passed to
// `tracker.record_kv_hit` alongside the request's block count.
overlap_amount: u32,
}
impl KvPushRouter {
    /// Pair a `PushRouter` with the `KvRouter` that will choose workers for it.
    pub fn new(
        inner: PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>,
        chooser: Arc<KvRouter>,
    ) -> Self {
        Self { inner, chooser }
    }

    /// Resolve the worker that should serve `request`.
    ///
    /// A worker ID pinned in the request's routing hints takes precedence: the
    /// phase-specific ID (`prefill_worker_id` / `decode_worker_id`) is consulted first and
    /// `backend_instance_id` is the fallback (and the only candidate in the aggregated
    /// phase). Without a pinned ID the choice is delegated to `find_best_match`, which
    /// receives `!is_query_only` as its fourth argument.
    ///
    /// For a pinned worker, the request is registered with the scheduler via
    /// `add_request` only when `is_query_only` is false and `handle_local_updates` is true.
    async fn select_worker(
        &self,
        context_id: &str,
        request: &PreprocessedRequest,
        phase: RequestPhase,
        is_query_only: bool,
        handle_local_updates: bool,
    ) -> Result<WorkerSelection, Error> {
        let hints = request.routing.as_ref();
        let lora_name = hints.and_then(|r| r.lora_name.clone());
        let dp_rank = hints.and_then(|r| r.dp_rank).unwrap_or(0);
        let expected_output_tokens = hints.and_then(|r| r.expected_output_tokens);

        // Phase-specific worker ID first, backend_instance_id as the fallback.
        let pinned_worker = hints.and_then(|r| {
            let phase_specific = match phase {
                RequestPhase::Prefill => r.prefill_worker_id,
                RequestPhase::Decode => r.decode_worker_id,
                RequestPhase::Aggregated => None,
            };
            phase_specific.or(r.backend_instance_id)
        });

        let id = match pinned_worker {
            Some(id) => id,
            None => {
                // Nothing pinned: let the KV router pick the best worker.
                let (best_worker, overlap_amount) = self
                    .chooser
                    .find_best_match(
                        Some(context_id),
                        &request.token_ids,
                        request.router_config_override.as_ref(),
                        !is_query_only,
                        lora_name,
                    )
                    .await?;

                return Ok(WorkerSelection {
                    instance_id: best_worker.worker_id,
                    dp_rank: best_worker.dp_rank,
                    overlap_amount,
                });
            }
        };

        tracing::debug!(
            worker_id = id,
            dp_rank = dp_rank,
            ?phase,
            "Routing to specified worker"
        );

        let worker = WorkerWithDpRank::new(id, dp_rank);
        let overlap_blocks = self
            .chooser
            .get_overlap_blocks(&request.token_ids, worker)
            .await?;

        // Bookkeeping is skipped for pure queries and when an external caller owns it.
        let skip_registration = is_query_only || !handle_local_updates;
        if skip_registration {
            tracing::debug!(
                request_id = %context_id,
                worker_id = id,
                dp_rank = dp_rank,
                "Skipping add_request - query or handled externally"
            );
        } else {
            self.chooser
                .add_request(
                    context_id.to_string(),
                    &request.token_ids,
                    overlap_blocks,
                    expected_output_tokens,
                    worker,
                    lora_name,
                )
                .await;
        }

        Ok(WorkerSelection {
            instance_id: id,
            dp_rank,
            overlap_amount: overlap_blocks,
        })
    }
}
#[async_trait]
impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutput>>, Error>
for KvPushRouter
{
/// Generate method that handles KV-aware routing with three distinct behaviors:
///
/// 1. **If `query_instance_id` annotation is set**:
/// - Returns the best matching worker ID without routing the request
/// - Does NOT update any router local states
/// - Response includes worker_instance_id and token_data annotations
///
/// 2. **If `backend_instance_id` is set in the request**:
/// - Routes directly to the specified backend instance
/// - DOES update router states to track this request (unless query_instance_id is also set)
/// - Bypasses the normal KV matching logic
///
/// 3. **If neither are set (default behavior)**:
/// - Finds the best worker based on KV cache overlap
/// - Updates router states to track the request
/// - Routes to the selected worker
///
/// The router state updates include tracking active sequences and managing
/// prefill/completion lifecycle for proper KV cache management.
///
/// Errors from worker selection (`select_worker`) and from dispatching to the
/// worker (`inner.direct`) are returned to the caller. Failures of the bookkeeping
/// calls (`process_routing_decision_for_request`, `mark_prefill_completed`,
/// `add_output_block`, `free`) are only logged as warnings.
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
// Extract context ID for request tracking
let context_id = request.context().id().to_string();
// Simple query-only detection: presence of query_instance_id annotation means query-only mode
let is_query_only = request.get_annotation_value("query_instance_id").is_some();
// Determine if this router should handle local state updates (add_request, free, etc.)
// Default is true (router handles bookkeeping). Set to false for GAIE Stage 2 where
// an external orchestrator (e.g., EPP sidecar) handles bookkeeping via C FFI.
let handle_local_updates = request
.routing
.as_ref()
.and_then(|r| r.enable_local_updates)
.unwrap_or(true);
// Get phase from tracker (defaults to Aggregated if no tracker or phase not set)
let phase = request
.tracker
.as_ref()
.map(|t| t.phase())
.unwrap_or(RequestPhase::Aggregated);
// Converted to usize once; used for the div_ceil block arithmetic below and in the stream.
let block_size = self.chooser.block_size() as usize;
let selection = self
.select_worker(
&context_id,
&request,
phase,
is_query_only,
handle_local_updates,
)
.await?;
let WorkerSelection {
instance_id,
dp_rank,
overlap_amount,
} = selection;
// In approximate mode (use_kv_events=false), record the routing decision
// so the indexer can track cache state based on routing decisions.
// This covers both pre-selected workers and find_best_match selections.
if !is_query_only && !self.chooser.kv_router_config().use_kv_events {
let worker = WorkerWithDpRank::new(instance_id, dp_rank);
let mut tokens_with_hashes =
TokensWithHashes::new(request.token_ids.clone(), self.chooser.block_size());
// Best effort: a failure here is logged and routing continues.
if let Err(e) = self
.chooser
.indexer()
.process_routing_decision_for_request(&mut tokens_with_hashes, worker)
.await
{
tracing::warn!(
request_id = %context_id,
worker_id = instance_id,
dp_rank = dp_rank,
error = %e,
"Failed to record routing decision in approximate mode"
);
}
}
// Record metrics in tracker: KV hit rate, worker ID, and worker type based on phase.
// Worker type is stored at routing time to avoid expensive MDC lookups when
// updating Prometheus metrics (TTFT/ITL) later in the response stream.
if let Some(ref tracker) = request.tracker {
// Input sequence length in blocks, rounded up so a partial trailing block counts.
let isl_blocks = request.token_ids.len().div_ceil(block_size);
tracker.record_kv_hit(overlap_amount, isl_blocks);
tracker.record_worker_full(instance_id, dp_rank, self.chooser.worker_type());
}
// Handle query-only requests: early return with worker info
if is_query_only {
let stream_context = request.context().clone();
// Tracker is always created for query-only requests (delta generator enables tracking
// when query_instance_id annotation is present)
let worker_id_info = request.tracker.as_ref().and_then(|t| t.get_worker_info());
tracing::trace!(
?phase,
worker_id = instance_id,
?worker_id_info,
"Returning worker selection (query-only mode)"
);
// The selection is carried in `disaggregated_params` as JSON with the keys
// "worker_id" and "token_ids"; all other output fields keep their defaults.
let output = LLMEngineOutput {
disaggregated_params: Some(json!({
"worker_id": worker_id_info,
"token_ids": request.token_ids
})),
..Default::default()
};
let response = Annotated::from_data(output);
// Single-item stream: the request is never sent to a worker in this mode.
let stream = stream::iter(vec![response]);
return Ok(ResponseStream::new(Box::pin(stream), stream_context));
}
// Route to worker
// Values needed inside the stream are captured here, before `request` is consumed
// by `into_parts()` below.
let isl_tokens = request.token_ids.len();
let expected_output_tokens = request
.routing
.as_ref()
.and_then(|r| r.expected_output_tokens);
// Output block tracking requires both the config flag and local bookkeeping.
let track_output_blocks =
self.chooser.kv_router_config().router_track_output_blocks && handle_local_updates;
let (mut backend_input, context) = request.into_parts();
// Tell the backend which data-parallel rank was selected.
backend_input.routing_mut().dp_rank = Some(dp_rank);
// Re-attach the modified request to its original context.
let updated_request = context.map(|_| backend_input);
let chooser = self.chooser.clone();
let mut response_stream = self.inner.direct(updated_request, instance_id).await?;
let stream_context = response_stream.context();
let context_for_monitoring = stream_context.clone();
// Wrap stream with lifecycle management (mark_prefill_completed, free)
// Only perform these operations if handle_local_updates is true.
// When false, an external caller (e.g., GAIE sidecar) handles bookkeeping via C FFI.
let wrapped_stream = Box::pin(async_stream::stream! {
let mut prefill_marked = false;
// Output block tracking state
let mut cumulative_osl: usize = 0;
// Starts at the number of blocks covered by the input tokens (rounded up).
let mut current_total_blocks = isl_tokens.div_ceil(block_size);
loop {
tokio::select! {
// `biased` makes select! poll the branches in declaration order, so the
// cancellation branch is checked before the next response item.
biased;
_ = context_for_monitoring.stopped() => {
tracing::debug!("Request {context_id} cancelled, ending stream");
break;
}
item = response_stream.next() => {
// Upstream stream ended normally.
let Some(item) = item else {
break;
};
if handle_local_updates && !prefill_marked {
// Only mark prefill completed when we receive actual tokens,
// not empty bootstrap info (token_ids: []) from disaggregated prefill
let has_tokens = item.data.as_ref()
.map(|d| !d.token_ids.is_empty())
.unwrap_or(false);
if has_tokens {
if let Err(e) = chooser.mark_prefill_completed(&context_id).await {
tracing::warn!("Failed to mark prefill completed for request {context_id}: {e}");
}
// Set even if the call failed, so it is attempted at most once per request.
prefill_marked = true;
}
}
// Track output blocks if enabled
if track_output_blocks {
let new_tokens = item.data.as_ref()
.map(|d| d.token_ids.len())
.unwrap_or(0);
cumulative_osl += new_tokens;
let new_total_blocks = (isl_tokens + cumulative_osl).div_ceil(block_size);
if new_total_blocks > current_total_blocks {
// New block boundary crossed - add output block with decay
// Clamp eot to min 1 to avoid division by zero, and result to min 0.0
// `decay_fraction` stays None when no expected output length was provided.
let decay_fraction = expected_output_tokens.map(|eot| {
(1.0 - (cumulative_osl as f64 / eot.max(1) as f64)).max(0.0)
});
// A single call is made even if this item crossed several block boundaries.
if let Err(e) = chooser.add_output_block(&context_id, decay_fraction).await {
tracing::warn!(
"Failed to add output block for request {context_id}: {e}"
);
}
current_total_blocks = new_total_blocks;
}
}
yield item;
}
}
}
// Reached on both loop exits above: cancellation and normal end of stream.
// Only call free() if we handle local updates.
// When handle_local_updates=false, external caller handles cleanup via C FFI.
if handle_local_updates
&& let Err(e) = chooser.free(&context_id).await
{
tracing::warn!("Failed to free request {context_id}: {e}");
}
});
Ok(ResponseStream::new(wrapped_stream, stream_context))
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::discovery::RuntimeConfigs;
use crate::discovery::RuntimeConfigWatch;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::Component;
......@@ -99,7 +99,7 @@ impl KvScheduler {
pub async fn start(
component: Component,
block_size: u32,
workers_with_configs: Arc<RuntimeConfigs>,
workers_with_configs: RuntimeConfigWatch,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_id: u64,
......@@ -107,13 +107,10 @@ impl KvScheduler {
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
// Get initial workers from DashMap for slot initialization.
// Caller must ensure at least one worker is present (via wait_for_some).
let initial_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_with_configs
.configs
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
// Get initial workers from watch receiver.
// Caller must ensure at least one worker is present (via wait_for).
let initial_workers: HashMap<WorkerId, ModelRuntimeConfig> =
workers_with_configs.borrow().clone();
let slots = Arc::new(
ActiveSequencesMultiWorker::new(
......@@ -128,25 +125,21 @@ impl KvScheduler {
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?,
);
// Spawn background task to sync slots with DashMap when notified of changes.
// ModelManager's watcher updates the DashMap and notifies; we wait on watch receiver here.
// Spawn background task to sync slots when the watch value changes.
let slots_monitor = slots.clone();
let subscriber = workers_with_configs.subscribe();
let configs_monitor = subscriber.configs;
let mut change_rx = subscriber.change_rx;
let mut monitor_rx = workers_with_configs.clone();
let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move {
tracing::trace!("KvScheduler workers monitoring task started");
let mut last_workers: HashSet<WorkerId> = HashSet::new();
let mut last_workers: HashMap<WorkerId, ModelRuntimeConfig> = HashMap::new();
loop {
// Wait for notification or cancellation
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("KvScheduler workers monitoring task shutting down");
break;
}
result = change_rx.changed() => {
result = monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("KvScheduler: config watch sender dropped, shutting down");
break;
......@@ -154,25 +147,17 @@ impl KvScheduler {
}
}
// Get current workers from DashMap
let current_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> =
configs_monitor
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
let current_worker_ids: HashSet<WorkerId> =
current_workers.keys().copied().collect();
let current_workers = monitor_rx.borrow_and_update().clone();
// Only update slots if workers have changed
if current_worker_ids != last_workers {
slots_monitor.update_workers(current_workers);
last_workers = current_worker_ids;
if current_workers != last_workers {
slots_monitor.update_workers(current_workers.clone());
last_workers = current_workers;
}
}
});
let slots_clone = slots.clone();
let workers_scheduler = workers_with_configs.clone();
let scheduler_rx = workers_with_configs.clone();
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token();
let hit_rate_publisher =
......@@ -209,12 +194,8 @@ impl KvScheduler {
request.decode_blocks = decode_blocks;
request.prefill_tokens = prefill_tokens;
// Read the current workers configuration from DashMap
let workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_scheduler
.configs
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
// Read the current workers configuration from watch receiver
let workers: HashMap<WorkerId, ModelRuntimeConfig> = scheduler_rx.borrow().clone();
match selector.select_worker(&workers, &request, block_size) {
Ok(selection) => {
......@@ -511,7 +492,7 @@ impl DefaultWorkerSelector {
impl WorkerSelector for DefaultWorkerSelector {
fn select_worker(
&self,
workers: &HashMap<WorkerId, Option<ModelRuntimeConfig>>,
workers: &HashMap<WorkerId, ModelRuntimeConfig>,
request: &SchedulingRequest,
block_size: u32,
) -> Result<WorkerSelectionResult, KvSchedulerError> {
......@@ -541,11 +522,8 @@ impl WorkerSelector for DefaultWorkerSelector {
// Outer loop: iterate over all workers from runtime config
// Inner loop: iterate over all dp_ranks for each worker
for (worker_id, config) in workers.iter() {
// Get data_parallel_size from runtime config
// data_parallel_size defaults to 1 in ModelRuntimeConfig
let data_parallel_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1); // Fallback if config is None
let data_parallel_size = config.data_parallel_size;
// Iterate over all dp_ranks for this worker
for dp_rank in 0..data_parallel_size {
let worker = WorkerWithDpRank::new(*worker_id, dp_rank);
......@@ -612,7 +590,6 @@ impl WorkerSelector for DefaultWorkerSelector {
// this is a runtime config set on a per worker basis, not per dp-rank
let total_blocks_info = workers
.get(&best_worker.worker_id)
.and_then(|cfg| cfg.as_ref())
.and_then(|cfg| cfg.total_kv_blocks)
.map(|blocks| format!(", total blocks: {}", blocks))
.unwrap_or_default();
......
......@@ -424,7 +424,7 @@ impl ActiveSequencesMultiWorker {
pub async fn new(
component: Component,
block_size: usize,
workers_with_configs: HashMap<u64, Option<ModelRuntimeConfig>>,
workers_with_configs: HashMap<u64, ModelRuntimeConfig>,
replica_sync: bool,
router_id: u64,
worker_type: &'static str,
......@@ -438,7 +438,7 @@ impl ActiveSequencesMultiWorker {
// Expand workers by their dp_rank
for (worker_id, config) in workers_with_configs {
let dp_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1);
let dp_size = config.data_parallel_size;
for dp_rank in 0..dp_size {
let worker = WorkerWithDpRank::new(worker_id, dp_rank);
......@@ -710,17 +710,14 @@ impl ActiveSequencesMultiWorker {
}
/// Update the set of workers, adding and removing as needed
pub fn update_workers(
&self,
new_workers_with_configs: HashMap<u64, Option<ModelRuntimeConfig>>,
) {
pub fn update_workers(&self, new_workers_with_configs: HashMap<u64, ModelRuntimeConfig>) {
let current_workers: HashSet<WorkerWithDpRank> =
self.senders.iter().map(|entry| *entry.key()).collect();
// Expand new workers by their dp_rank
let mut new_workers: HashSet<WorkerWithDpRank> = HashSet::new();
for (worker_id, config) in &new_workers_with_configs {
let dp_size = config.as_ref().map(|c| c.data_parallel_size).unwrap_or(1);
let dp_size = config.data_parallel_size;
for dp_rank in 0..dp_size {
new_workers.insert(WorkerWithDpRank::new(*worker_id, dp_rank));
......@@ -784,10 +781,15 @@ impl ActiveSequencesMultiWorker {
worker: WorkerWithDpRank,
lora_name: Option<String>,
) -> Result<(), SequenceError> {
// Check for worker existence
if !self.senders.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
// Clone the sender upfront so we don't hold the DashMap Ref across
// the .await points below. Also eliminates the TOCTOU between
// contains_key and a later get().unwrap().
let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
// Check for duplicate request
if let Some(existing_worker) = self.request_to_worker.get(&request_id) {
......@@ -825,9 +827,7 @@ impl ActiveSequencesMultiWorker {
self.request_to_lora.insert(request_id.clone(), lora);
}
self.senders
.get(&worker)
.unwrap()
sender
.send(UpdateSequences::AddRequest {
request_id,
token_sequence,
......@@ -855,25 +855,31 @@ impl ActiveSequencesMultiWorker {
Ok(())
}
/// Free all blocks associated with a request
///
/// Note: This operation is idempotent. Calling it multiple times for the same request
/// will log a warning but not return an error (double free is allowed).
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
// Check if request exists - if not, it's already been freed (idempotent)
let Some(worker) = self.request_to_worker.get(request_id).map(|entry| *entry) else {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
};
/// Send a command to the worker assigned to a request, optionally publishing
/// a replica-sync event and cleaning up request mappings afterward.
async fn send_to_request_worker(
&self,
request_id: &RequestId,
event_data: ActiveSequenceEventData,
command_fn: impl FnOnce(RequestId) -> UpdateSequences,
remove_mapping: bool,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
// Verify worker still exists
if !self.senders.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
// Publish event only if replica_sync is enabled
if self.replica_sync {
// Look up lora_name from mapping
let lora_name = self
.request_to_lora
.get(request_id)
......@@ -882,31 +888,46 @@ impl ActiveSequencesMultiWorker {
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: ActiveSequenceEventData::Free,
data: event_data,
router_id: self.router_id,
lora_name,
};
self.event_publisher.publish(&event).await?;
}
// Update local state
self.senders
.get(&worker)
.unwrap()
.send(UpdateSequences::Free {
request_id: request_id.clone(),
})
sender
.send(command_fn(request_id.clone()))
.map_err(|_| SequenceError::WorkerChannelClosed)?;
if remove_mapping {
self.request_to_worker.remove(request_id);
self.request_to_lora.remove(request_id);
}
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(())
}
/// Free all blocks associated with a request
///
/// Note: This operation is idempotent. Calling it multiple times for the same request
/// will log a warning but not return an error (double free is allowed).
pub async fn free(&self, request_id: &RequestId) -> Result<(), SequenceError> {
if !self.request_to_worker.contains_key(request_id) {
tracing::debug!("Request {request_id} not found, already freed (idempotent)");
return Ok(());
}
self.send_to_request_worker(
request_id,
ActiveSequenceEventData::Free,
|rid| UpdateSequences::Free { request_id: rid },
true,
)
.await
}
/// Mark prefill as completed for a request
///
/// Note: Calling this multiple times for the same request is allowed and will be a no-op
......@@ -915,50 +936,13 @@ impl ActiveSequencesMultiWorker {
&self,
request_id: &RequestId,
) -> Result<(), SequenceError> {
let worker = self
.request_to_worker
.get(request_id)
.map(|entry| *entry)
.ok_or_else(|| SequenceError::RequestNotFound {
request_id: request_id.clone(),
})?;
// Verify worker still exists
if !self.senders.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
// Publish event only if replica_sync is enabled
if self.replica_sync {
// Look up lora_name from mapping
let lora_name = self
.request_to_lora
.get(request_id)
.map(|entry| entry.value().clone());
let event = ActiveSequenceEvent {
request_id: request_id.clone(),
worker,
data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: self.router_id,
lora_name,
};
self.event_publisher.publish(&event).await?;
}
// Update local state
self.senders
.get(&worker)
.unwrap()
.send(UpdateSequences::MarkPrefillCompleted {
request_id: request_id.clone(),
})
.map_err(|_| SequenceError::WorkerChannelClosed)?;
// Publish ActiveLoad metrics for this worker
self.publish_active_load_for_worker(worker).await;
Ok(())
self.send_to_request_worker(
request_id,
ActiveSequenceEventData::MarkPrefillCompleted,
|rid| UpdateSequences::MarkPrefillCompleted { request_id: rid },
false,
)
.await
}
/// Add an output block with optional fractional decay weight
......@@ -978,18 +962,19 @@ impl ActiveSequencesMultiWorker {
request_id: request_id.clone(),
})?;
// Verify worker still exists
if !self.senders.contains_key(&worker) {
return Err(SequenceError::WorkerNotFound { worker });
}
// Clone sender upfront to avoid TOCTOU between contains_key and get().unwrap()
let sender = self
.senders
.get(&worker)
.ok_or(SequenceError::WorkerNotFound { worker })?
.value()
.clone();
// Create response channel
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
// Send command to worker
self.senders
.get(&worker)
.unwrap()
sender
.send(UpdateSequences::AddOutputBlock {
request_id: request_id.clone(),
decay_fraction,
......@@ -1016,10 +1001,17 @@ impl ActiveSequencesMultiWorker {
/// Helper method to query a single worker for active blocks/tokens and publish ActiveLoad
async fn publish_active_load_for_worker(&self, worker: WorkerWithDpRank) {
let Some(sender) = self.senders.get(&worker) else {
// Clone the sender and drop the DashMap Ref immediately.
// Holding a Ref across .await points can deadlock: if the task yields
// and update_workers() needs a write lock on the same shard, the
// runtime thread blocks forever.
let sender = {
let Some(entry) = self.senders.get(&worker) else {
tracing::warn!("Worker {worker:?} not found when publishing ActiveLoad");
return;
};
entry.value().clone()
};
// Query active blocks
let (blocks_tx, blocks_rx) = tokio::sync::oneshot::channel();
......@@ -1337,11 +1329,11 @@ mod tests {
// Create runtime config for worker 0 with dp_size=2
let mut config_worker_0 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
config_worker_0.data_parallel_size = 2;
workers_with_configs.insert(0, Some(config_worker_0));
workers_with_configs.insert(0, config_worker_0);
// Create runtime config for worker 1 with dp_size=1 (default)
let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
workers_with_configs.insert(1, Some(config_worker_1));
workers_with_configs.insert(1, config_worker_1);
let seq_manager_1 = Arc::new(
ActiveSequencesMultiWorker::new(
......@@ -1509,9 +1501,18 @@ mod tests {
// Create multi-worker sequence managers with ALL workers [0, 1, 2]
// Both use the same component to ensure event synchronization works
let mut workers_with_configs = HashMap::new();
workers_with_configs.insert(0, None);
workers_with_configs.insert(1, None);
workers_with_configs.insert(2, None);
workers_with_configs.insert(
0,
crate::local_model::runtime_config::ModelRuntimeConfig::new(),
);
workers_with_configs.insert(
1,
crate::local_model::runtime_config::ModelRuntimeConfig::new(),
);
workers_with_configs.insert(
2,
crate::local_model::runtime_config::ModelRuntimeConfig::new(),
);
let seq_manager_1 = Arc::new(
ActiveSequencesMultiWorker::new(
......
......@@ -18,8 +18,8 @@ use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest},
KV_EVENT_SUBJECT, KvRouterConfig, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest, KvIndexer},
protocols::{DpRank, RouterEvent, WorkerId},
router_discovery_query,
worker_query::WorkerQueryClient,
......@@ -511,10 +511,15 @@ pub async fn start_kv_router_background(
pub async fn start_kv_router_background_event_plane(
component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken,
mut worker_query_client: WorkerQueryClient,
transport_kind: EventTransportKind,
) -> Result<()> {
// WorkerQueryClient handles its own discovery loop for lifecycle + initial recovery.
// No blocking wait — recovery happens asynchronously as endpoints are discovered.
let worker_query_client =
WorkerQueryClient::spawn(component.clone(), remove_worker_tx, kv_events_tx.clone()).await?;
// Subscribe to KV events using the selected event plane transport
let mut subscriber =
EventSubscriber::for_component_with_transport(&component, KV_EVENT_SUBJECT, transport_kind)
......@@ -542,20 +547,6 @@ pub async fn start_kv_router_background_event_plane(
}
}
// Wait for at least one worker with a known runtime config before proceeding.
// This ensures we have actual config data (including enable_local_indexer) available.
tracing::info!("KV subscriber waiting for at least one worker with runtime config...");
let ready_workers = worker_query_client.wait_for_ready().await;
tracing::info!(
"KV subscriber found {} worker(s) with runtime config, proceeding",
ready_workers.len()
);
// Recover initial state from all workers with local indexer enabled
worker_query_client
.process_and_recover_workers(&kv_events_tx, "Initial recovery")
.await;
tokio::spawn(async move {
// Track last received event ID per (worker, dp_rank) for gap detection
// Each dp_rank has its own monotonic event ID sequence
......@@ -570,18 +561,6 @@ pub async fn start_kv_router_background_event_plane(
break;
}
// Handle runtime config changes (worker add/remove, recovery for new workers)
result = worker_query_client.wait_for_config_change() => {
if result.is_err() {
tracing::warn!("Runtime config watch sender dropped");
continue;
}
worker_query_client
.process_and_recover_workers(&kv_events_tx, "DISCOVERY")
.await;
}
// Handle event consumption from event plane subscription
Some(result) = subscriber.next() => {
let (envelope, event) = match result {
......@@ -597,7 +576,6 @@ pub async fn start_kv_router_background_event_plane(
let event_id = event.event.event_id;
let event_key = (worker_id, dp_rank);
// Use envelope metadata for additional debugging
tracing::trace!(
"Received event from publisher {} (seq {})",
envelope.publisher_id,
......@@ -609,7 +587,6 @@ pub async fn start_kv_router_background_event_plane(
if let Some(&last_id) = last_event_ids.get(&event_key)
&& event_id > last_id + 1
{
// Gap detected - recover missing events before processing current
let gap_start = last_id + 1;
let gap_end = event_id - 1;
let gap_size = gap_end - gap_start + 1;
......@@ -617,22 +594,15 @@ pub async fn start_kv_router_background_event_plane(
"Event ID gap detected for worker {worker_id} dp_rank {dp_rank}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
);
// Note: While recovering, new events may queue in the subscriber's
// internal buffer. We don't explicitly buffer them here for simplicity.
// The subscriber will process them in order after recovery completes.
if let Err(e) = worker_query_client
.recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end), &kv_events_tx)
.recover_from_worker(worker_id, dp_rank, Some(gap_start), Some(gap_end))
.await
{
tracing::error!(
"Failed to recover gap events for worker {worker_id} dp_rank {dp_rank} (gap_start: {gap_start}, gap_end: {gap_end}); proceeding with current event anyway: {e}"
);
// Note: If recovery fails, we still apply the current event.
// The tree will have a gap, but it's better than dropping the event.
}
}
// First event from this (worker, dp_rank) is always valid - we accept whatever ID it has.
// This handles initial startup and worker restarts without requiring event 0.
// Update last seen event ID (use max to handle out-of-order)
last_event_ids
......@@ -657,23 +627,6 @@ pub async fn start_kv_router_background_event_plane(
Ok(())
}
/// Backwards-compatible wrapper for NATS Core local-indexer mode.
pub async fn start_kv_router_background_nats_core(
component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>,
cancellation_token: CancellationToken,
worker_query_client: WorkerQueryClient,
) -> Result<()> {
start_kv_router_background_event_plane(
component,
kv_events_tx,
cancellation_token,
worker_query_client,
EventTransportKind::Nats,
)
.await
}
/// Cleanup orphaned NATS consumers that no longer have corresponding router entries
async fn cleanup_orphaned_consumers(
nats_queue: &mut NatsQueue,
......@@ -711,3 +664,66 @@ async fn cleanup_orphaned_consumers(
}
}
}
/// Helper to decide which subscriber (JetStream or Event Plane) to start based on config.
///
/// The event transport is taken from `EventTransportKind::from_env_or_default()`,
/// and `kv_router_config.durable_kv_events` selects the mode:
/// - `durable_kv_events=false` (default): use NATS Core / the generic event plane
///   via `start_kv_router_background_event_plane` (requires workers to have
///   local_indexer enabled).
/// - `durable_kv_events=true`: use JetStream via `start_kv_router_background`
///   for durability and multi-replica consistency, with `router_id` (as a
///   string) as the consumer id.
///
/// Returns the result of whichever start function was selected.
pub async fn start_subscriber(
    component: Component,
    kv_router_config: &KvRouterConfig,
    router_id: u64,
    kv_indexer: &KvIndexer,
    cancellation_token: CancellationToken,
) -> Result<()> {
    let transport_kind = EventTransportKind::from_env_or_default();
    let is_zmq = transport_kind == EventTransportKind::Zmq;

    if kv_router_config.durable_kv_events {
        if is_zmq {
            tracing::warn!(
                "--durable-kv-events requires NATS, but ZMQ event plane is configured; falling back to JetStream anyway"
            );
        }
        tracing::info!("Using JetStream subscription (--durable-kv-events enabled)");

        let consumer_id = router_id.to_string();
        // The snapshot-related senders are only handed over when a snapshot
        // threshold is configured.
        let snapshot_threshold = kv_router_config.router_snapshot_threshold;
        return start_kv_router_background(
            component,
            consumer_id,
            kv_indexer.event_sender(),
            kv_indexer.remove_worker_sender(),
            snapshot_threshold.map(|_| kv_indexer.get_workers_sender()),
            snapshot_threshold.map(|_| kv_indexer.snapshot_event_sender()),
            cancellation_token,
            snapshot_threshold,
            kv_router_config.router_reset_states,
        )
        .await;
    }

    if is_zmq {
        if kv_router_config.router_snapshot_threshold.is_some()
            || kv_router_config.router_reset_states
        {
            tracing::warn!(
                "ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings"
            );
        }
        tracing::info!("Using ZMQ event plane subscription (local_indexer mode)");
    } else {
        tracing::info!("Using NATS Core subscription (local_indexer mode)");
    }

    // `component` is owned and not used again, so it is moved rather than cloned.
    start_kv_router_background_event_plane(
        component,
        kv_indexer.event_sender(),
        kv_indexer.remove_worker_sender(),
        cancellation_token,
        transport_kind,
    )
    .await
}
This diff is collapsed.
......@@ -239,7 +239,11 @@ impl LocalModelBuilder {
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer;
// Decode workers don't create the WorkerKvQuery endpoint (scheduler_component is None),
// so they must not advertise enable_local_indexer=true or the router will hang
// trying to query them during initial recovery.
self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer
&& mocker_engine_args.worker_type != WorkerType::Decode;
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
self.media_decoder = Some(MediaDecoder {
image: Some(ImageDecoder::default()),
......
......@@ -32,8 +32,8 @@ pub struct ModelRuntimeConfig {
#[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32,
/// Enable worker-local KV indexer for tracking this worker's own KV cache state
#[serde(default)]
/// Enable worker-local KV indexer for tracking this worker's own KV cache state (default: true)
#[serde(default = "default_local_indexer")]
pub enable_local_indexer: bool,
/// Mapping of engine-specific runtime configs
......@@ -59,6 +59,10 @@ const fn default_data_parallel_size() -> u32 {
1
}
/// Serde default for the `enable_local_indexer` field: the worker-local KV
/// indexer is treated as enabled when the field is absent from the input.
const fn default_local_indexer() -> bool {
true
}
impl Default for ModelRuntimeConfig {
fn default() -> Self {
Self {
......@@ -68,7 +72,7 @@ impl Default for ModelRuntimeConfig {
tool_call_parser: None,
reasoning_parser: None,
data_parallel_size: default_data_parallel_size(),
enable_local_indexer: false,
enable_local_indexer: true,
runtime_data: HashMap::new(),
tensor_model_config: None,
disaggregated_endpoint: None,
......
......@@ -385,7 +385,7 @@ impl Discovery for KVStoreDiscovery {
// Get bucket - if it doesn't exist, return empty list
let Some(bucket) = self.store.get_bucket(bucket_name).await? else {
tracing::info!(
tracing::debug!(
"KVStoreDiscovery::list: bucket missing for query={:?}, prefix={}, bucket={}",
query,
prefix,
......@@ -396,7 +396,7 @@ impl Discovery for KVStoreDiscovery {
// Get all entries from the bucket
let entries = bucket.entries().await?;
tracing::info!(
tracing::debug!(
"KVStoreDiscovery::list: query={:?}, prefix={}, bucket={}, entries={}",
query,
prefix,
......
......@@ -66,6 +66,12 @@ where
/// If None, busy detection is disabled
busy_threshold: Option<f64>,
/// When false, `generate_with_fault_detection` skips fault detection logic:
/// it won't call `report_instance_down` on errors, and it uses the raw discovery
/// instance list instead of the filtered avail list. Use for recovery/query paths
/// where transient failures are expected.
fault_detection_enabled: bool,
/// An internal Rust type. This says that PushRouter is generic over the T and U types,
/// which are the input and output types of its `generate` function. It allows the
/// compiler to specialize us at compile time.
......@@ -112,6 +118,28 @@ where
Self::from_client_with_threshold(client, router_mode, None, None).await
}
/// Create a new PushRouter with fault detection disabled.
///
/// Unlike `from_client`, this router will not call `report_instance_down` on
/// transient errors, and `direct()` uses the raw discovery instance list instead
/// of the filtered avail list. Use for recovery/query paths.
///
/// No busy threshold is configured for routers built this way.
///
/// # Errors
/// Returns an error if the addressed router for the client's endpoint cannot
/// be obtained.
pub async fn from_client_no_fault_detection(
    client: Client,
    router_mode: RouterMode,
) -> anyhow::Result<Self> {
    // The borrow of `client.endpoint` ends once the await completes, so the
    // owned `client` can be moved into the router without cloning it.
    let addressed = addressed_router(&client.endpoint).await?;

    Ok(PushRouter {
        client,
        addressed,
        router_mode,
        round_robin_counter: Arc::new(AtomicU64::new(0)),
        busy_threshold: None,
        fault_detection_enabled: false,
        _phantom: PhantomData,
    })
}
/// Create a new PushRouter with optional busy threshold and worker load monitor
pub async fn from_client_with_threshold(
client: Client,
......@@ -132,6 +160,7 @@ where
router_mode,
round_robin_counter: Arc::new(AtomicU64::new(0)),
busy_threshold,
fault_detection_enabled: true,
_phantom: PhantomData,
};
......@@ -185,7 +214,14 @@ where
request: SingleIn<T>,
instance_id: u64,
) -> anyhow::Result<ManyOut<U>> {
let found = self.client.instance_ids_avail().contains(&instance_id);
// When fault detection is disabled, check the raw discovery list
// (not filtered by report_instance_down) so transient failures
// don't poison the instance for subsequent retries.
let found = if self.fault_detection_enabled {
self.client.instance_ids_avail().contains(&instance_id)
} else {
self.client.instance_ids().contains(&instance_id)
};
if !found {
return Err(anyhow::anyhow!(
......@@ -271,8 +307,8 @@ where
instance_id: u64,
request: SingleIn<T>,
) -> anyhow::Result<ManyOut<U>> {
// Check if all workers are busy (only if busy threshold is set)
if self.busy_threshold.is_some() {
// Check if all workers are busy (only if busy threshold is set and fault detection enabled)
if self.fault_detection_enabled && self.busy_threshold.is_some() {
let free_instances = self.client.instance_ids_free();
if free_instances.is_empty() {
// Check if we actually have any instances at all
......@@ -332,6 +368,9 @@ where
let stream: anyhow::Result<ManyOut<U>> = self.addressed.generate(request).await;
match stream {
Ok(stream) => {
if !self.fault_detection_enabled {
return Ok(stream);
}
let engine_ctx = stream.context();
let client = self.client.clone();
let stream = stream.map(move |res| {
......@@ -349,7 +388,8 @@ where
Ok(ResponseStream::new(Box::pin(stream), engine_ctx))
}
Err(err) => {
if let Some(req_err) = err.downcast_ref::<NatsRequestError>()
if self.fault_detection_enabled
&& let Some(req_err) = err.downcast_ref::<NatsRequestError>()
&& matches!(req_err.kind(), NatsNoResponders)
{
tracing::debug!(
......
......@@ -202,7 +202,10 @@ def predownload_models(pytestconfig):
else:
# Fallback to original behavior if extraction failed
download_models()
os.environ["HF_HUB_OFFLINE"] = "1"
yield
os.environ.pop("HF_HUB_OFFLINE", None)
@pytest.fixture(scope="session")
......@@ -218,7 +221,13 @@ def predownload_tokenizers(pytestconfig):
else:
# Fallback to original behavior if extraction failed
download_models(ignore_weights=True)
# Skip redundant HuggingFace API calls in worker subprocesses since
# tokenizers are already cached. This avoids flaky timeouts from slow
# HF API responses (the RepoInfo fetch still happens even for cached models).
os.environ["HF_HUB_OFFLINE"] = "1"
yield
os.environ.pop("HF_HUB_OFFLINE", None)
@pytest.fixture(autouse=True)
......@@ -610,20 +619,20 @@ def request_plane(request):
@pytest.fixture
def use_nats_core(request):
def durable_kv_events(request):
"""
Whether to use NATS Core mode (local indexer) instead of JetStream. Defaults to False.
Whether to use durable KV events via JetStream. Defaults to False (NATS Core mode).
When True:
When False (default):
- NATS server starts without JetStream (-js flag omitted) for faster startup
- Tests should use enable_local_indexer=True in mocker_args
- Workers use local indexer mode (NATS Core / fire-and-forget events)
When False (default):
- NATS server starts with JetStream for KV event distribution
- Tests use JetStream-based indexer synchronization
When True:
- NATS server starts with JetStream for durable KV event distribution
- Workers use --durable-kv-events flag to publish to JetStream
To use NATS Core mode:
@pytest.mark.parametrize("use_nats_core", [True], indirect=True)
To use JetStream mode:
@pytest.mark.parametrize("durable_kv_events", [True], indirect=True)
def test_example(runtime_services_dynamic_ports):
...
"""
......@@ -656,7 +665,7 @@ def runtime_services(request, store_kv, request_plane):
@pytest.fixture()
def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_core):
def runtime_services_dynamic_ports(request, store_kv, request_plane, durable_kv_events):
"""Provide NATS and Etcd servers with truly dynamic ports per test.
This fixture actually allocates dynamic ports by passing port=0 to the servers.
......@@ -671,7 +680,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co
- If store_kv != "etcd", etcd is not started (returns None)
- NATS is always started when etcd is used, because KV events require NATS
regardless of the request_plane (tcp/nats only affects request transport)
- JetStream is enabled by default; disabled when use_nats_core=True for faster startup
- NATS Core mode (no JetStream) is the default; JetStream is enabled when durable_kv_events=True
Returns a tuple of (nats_process, etcd_process) where each has a .port attribute.
"""
......@@ -679,10 +688,10 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co
# Port cleanup is now handled in NatsServer and EtcdServer __exit__ methods
# Always start NATS when etcd is used - KV events require NATS regardless of request_plane
# When use_nats_core=True, disable JetStream for faster startup
# When durable_kv_events=False (default), disable JetStream for faster startup
if store_kv == "etcd":
with NatsServer(
request, port=0, disable_jetstream=use_nats_core
request, port=0, disable_jetstream=not durable_kv_events
) as nats_process:
with EtcdServer(request, port=0) as etcd_process:
# Save original env vars (may be set by session-scoped fixture)
......@@ -706,7 +715,7 @@ def runtime_services_dynamic_ports(request, store_kv, request_plane, use_nats_co
os.environ.pop("ETCD_ENDPOINTS", None)
elif request_plane == "nats":
with NatsServer(
request, port=0, disable_jetstream=use_nats_core
request, port=0, disable_jetstream=not durable_kv_events
) as nats_process:
orig_nats = os.environ.get("NATS_SERVER")
os.environ["NATS_SERVER"] = f"nats://localhost:{nats_process.port}"
......
......@@ -50,6 +50,7 @@ class KVRouterProcess(ManagedProcess):
tokens_threshold: float | None = None,
tokens_threshold_frac: float | None = None,
request_plane: str = "nats",
durable_kv_events: bool = False,
):
command = [
"python3",
......@@ -81,6 +82,9 @@ class KVRouterProcess(ManagedProcess):
["--active-prefill-tokens-threshold-frac", str(tokens_threshold_frac)]
)
if durable_kv_events:
command.append("--durable-kv-events")
env = os.environ.copy()
env["DYN_REQUEST_PLANE"] = request_plane
......@@ -1335,6 +1339,7 @@ def _test_router_indexers_sync(
request_plane: str = "nats",
test_nats_interruption: bool = False,
nats_server: Optional["NatsServer"] = None,
durable_kv_events: bool = False,
):
"""Test that two KV routers have synchronized indexer states after processing requests.
......@@ -1365,6 +1370,7 @@ def _test_router_indexers_sync(
request_plane: Request plane to use ("nats" or "tcp"). Defaults to "nats".
test_nats_interruption: If True, test NATS interruption recovery. Defaults to False.
nats_server: NatsServer instance for stop/start (required if test_nats_interruption=True).
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
Raises:
AssertionError: If router states don't synchronize correctly or snapshot is missing
......@@ -1375,7 +1381,10 @@ def _test_router_indexers_sync(
# Use async to manage the test flow
async def test_sync():
# Create KvRouterConfig with lower snapshot threshold for testing
kv_router_config = KvRouterConfig(router_snapshot_threshold=20)
kv_router_config = KvRouterConfig(
router_snapshot_threshold=20,
durable_kv_events=durable_kv_events,
)
async def send_requests_to_router(router, num_requests, router_name, endpoint):
# Now send the actual requests
......@@ -1690,6 +1699,7 @@ def _test_router_decisions_disagg(
test_payload: dict,
store_backend: str = "etcd",
request_plane: str = "nats",
durable_kv_events: bool = False,
):
"""Validate KV cache prefix reuse in disaggregated prefill-decode setup via HTTP frontend.
......@@ -1711,6 +1721,7 @@ def _test_router_decisions_disagg(
frontend_port: Port for the frontend HTTP server
test_payload: Base test payload to send to /v1/chat/completions
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
Raises:
AssertionError: If prefill_worker_ids differ across requests (prefix reuse failure)
......@@ -1730,6 +1741,7 @@ def _test_router_decisions_disagg(
store_backend,
enforce_disagg=True,
request_plane=request_plane,
durable_kv_events=durable_kv_events,
)
kv_router.__enter__()
......@@ -1909,6 +1921,7 @@ def _test_router_decisions(
test_dp_rank: bool = False,
block_size: int = BLOCK_SIZE,
use_kv_events: bool = True,
durable_kv_events: bool = False,
):
"""Validate KV cache prefix reuse and worker routing by sending requests with diverging prefixes.
......@@ -1929,6 +1942,7 @@ def _test_router_decisions(
test_dp_rank: If True, also forces and validates dp_rank routing (for data parallel setups)
use_kv_events: If True (default), uses KV events from workers. If False, uses
approximate routing with TTL-based expiration (--no-kv-events mode).
durable_kv_events: If True, use durable KV events (JetStream). Defaults to False.
Raises:
AssertionError: If routing decisions don't follow KV cache prefix reuse as expected
......@@ -1937,6 +1951,7 @@ def _test_router_decisions(
kv_router_config = KvRouterConfig(
router_snapshot_threshold=20,
use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events,
)
kv_push_router = KvPushRouter(
endpoint=endpoint,
......
......@@ -158,8 +158,9 @@ def _build_mocker_command(
command.extend(["--watermark", str(mocker_args["watermark"])])
if "dp_size" in mocker_args:
command.extend(["--data-parallel-size", str(mocker_args["dp_size"])])
if mocker_args.get("enable_local_indexer"):
command.append("--enable-local-indexer")
# Use --durable-kv-events to enable JetStream mode (local indexer disabled)
if mocker_args.get("durable_kv_events") is True:
command.append("--durable-kv-events")
if "bootstrap_ports" in mocker_args:
command.extend(["--bootstrap-ports", mocker_args["bootstrap_ports"]])
......@@ -325,14 +326,14 @@ class DisaggMockerProcess:
@pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize(
"use_nats_core", [True], indirect=True
"durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer)
def test_mocker_kv_router(
request,
runtime_services_dynamic_ports,
predownload_tokenizers,
request_plane,
use_nats_core,
durable_kv_events,
):
"""
Test KV router with multiple mocker engine instances.
......@@ -347,7 +348,7 @@ def test_mocker_kv_router(
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core,
"durable_kv_events": durable_kv_events,
}
try:
......@@ -385,7 +386,7 @@ def test_mocker_kv_router(
@pytest.mark.parametrize("store_backend", ["etcd", "file"])
@pytest.mark.parametrize(
"use_nats_core", [True], indirect=True
"durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer)
@pytest.mark.timeout(60) # ~3x average (~19.86s), rounded up
def test_mocker_two_kv_router(
......@@ -394,7 +395,7 @@ def test_mocker_two_kv_router(
predownload_tokenizers,
file_storage_backend,
store_backend,
use_nats_core,
durable_kv_events,
):
"""
Test with two KV routers and multiple mocker engine instances.
......@@ -411,7 +412,7 @@ def test_mocker_two_kv_router(
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core,
"durable_kv_events": durable_kv_events,
}
try:
......@@ -440,7 +441,7 @@ def test_mocker_two_kv_router(
test_payload=TEST_PAYLOAD,
num_requests=NUM_REQUESTS,
store_backend=store_backend,
skip_consumer_verification=use_nats_core, # Skip JetStream checks in NATS Core mode
skip_consumer_verification=not durable_kv_events, # Skip JetStream checks in NATS Core mode
)
finally:
......@@ -450,11 +451,11 @@ def test_mocker_two_kv_router(
@pytest.mark.skip(reason="Flaky, temporarily disabled")
@pytest.mark.parametrize(
"use_nats_core", [True], indirect=True
"durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer)
@pytest.mark.timeout(60) # ~3x average (~19.86s), rounded up (when enabled)
def test_mocker_kv_router_overload_503(
request, runtime_services_dynamic_ports, predownload_tokenizers, use_nats_core
request, runtime_services_dynamic_ports, predownload_tokenizers, durable_kv_events
):
"""Test that KV router returns 503 when mocker workers are overloaded."""
logger.info("Starting mocker KV router overload test for 503 status")
......@@ -463,7 +464,7 @@ def test_mocker_kv_router_overload_503(
"speedup_ratio": 10,
"block_size": 4, # Smaller block size
"num_gpu_blocks": 64, # Limited GPU blocks to exhaust quickly
"enable_local_indexer": use_nats_core,
"durable_kv_events": durable_kv_events,
}
try:
......@@ -494,14 +495,14 @@ def test_mocker_kv_router_overload_503(
@pytest.mark.timeout(22) # ~3x average (~7.10s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize(
"use_nats_core", [True], indirect=True
"durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer)
def test_kv_push_router_bindings(
request,
runtime_services_dynamic_ports,
predownload_tokenizers,
request_plane,
use_nats_core,
durable_kv_events,
):
"""Test KvPushRouter Python bindings with mocker engines."""
logger.info("Starting KvPushRouter bindings test")
......@@ -509,7 +510,7 @@ def test_kv_push_router_bindings(
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core,
"durable_kv_events": durable_kv_events,
}
try:
......@@ -545,18 +546,18 @@ def test_kv_push_router_bindings(
@pytest.mark.parametrize(
"store_backend,use_nats_core,request_plane",
"store_backend,durable_kv_events,request_plane",
[
("etcd", False, "nats"), # JetStream mode - uses JetStream (default)
("etcd", True, "tcp"), # NATS core mode (with gap detection) - no JetStream
("file", False, "nats"), # File backend - uses JetStream (default)
("etcd", True, "nats"), # JetStream mode - uses JetStream
("etcd", False, "tcp"), # NATS core mode (with gap detection) - no JetStream
("file", True, "nats"), # File backend - uses JetStream
],
ids=[
"jetstream",
"nats_core",
"file",
],
indirect=["request_plane", "use_nats_core"],
indirect=["request_plane", "durable_kv_events"],
)
@pytest.mark.timeout(90) # TODO: figure out a timeout
def test_indexers_sync(
......@@ -565,7 +566,7 @@ def test_indexers_sync(
predownload_tokenizers,
file_storage_backend,
store_backend,
use_nats_core,
durable_kv_events,
request_plane,
):
"""
......@@ -580,7 +581,7 @@ def test_indexers_sync(
"""
logger.info(
f"Starting indexers sync test: store_backend={store_backend}, "
f"use_nats_core={use_nats_core}, request_plane={request_plane}"
f"durable_kv_events={durable_kv_events}, request_plane={request_plane}"
)
# Use the dynamic-port fixture to avoid hardcoded localhost:4222/2379 in parallel runs.
......@@ -591,7 +592,7 @@ def test_indexers_sync(
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core,
"durable_kv_events": durable_kv_events,
"dp_size": 2,
}
......@@ -610,6 +611,7 @@ def test_indexers_sync(
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# When using durable_kv_events=True, use JetStream mode for the router
_test_router_indexers_sync(
engine_workers=mockers,
block_size=BLOCK_SIZE,
......@@ -617,8 +619,9 @@ def test_indexers_sync(
num_workers=NUM_MOCKERS,
store_backend=store_backend,
request_plane=request_plane,
test_nats_interruption=use_nats_core,
nats_server=nats_process if use_nats_core else None,
test_nats_interruption=not durable_kv_events,
nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events,
)
logger.info("Indexers sync test completed successfully")
......@@ -630,10 +633,10 @@ def test_indexers_sync(
@pytest.mark.timeout(42) # ~3x average (~13.80s), rounded up
@pytest.mark.parametrize(
"use_nats_core", [True], indirect=True
"durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer)
def test_query_instance_id_returns_worker_and_tokens(
request, runtime_services_dynamic_ports, predownload_tokenizers, use_nats_core
request, runtime_services_dynamic_ports, predownload_tokenizers, durable_kv_events
):
"""Test query_instance_id annotation with mocker engines."""
logger.info("Starting KV router query_instance_id annotation test")
......@@ -641,7 +644,7 @@ def test_query_instance_id_returns_worker_and_tokens(
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core,
"durable_kv_events": durable_kv_events,
}
os.makedirs(request.node.name, exist_ok=True)
......@@ -674,55 +677,46 @@ def test_query_instance_id_returns_worker_and_tokens(
@pytest.mark.timeout(29) # ~3x average (~9.55s), rounded up
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize(
"use_nats_core,use_kv_events",
"durable_kv_events,use_kv_events",
[
(False, True), # JetStream mode (default) - uses JetStream
(True, True), # NATS Core + local indexer mode - no JetStream
(False, False), # Approximate mode (--no-kv-events) - uses JetStream
(True, True), # JetStream mode with KV events
(False, True), # NATS Core mode with local indexer (default)
(False, False), # Approximate mode (--no-kv-events) - no KV events
],
ids=["jetstream", "nats_core", "no_kv_events"],
indirect=["use_nats_core"],
indirect=["durable_kv_events"],
)
def test_router_decisions(
request,
runtime_services_dynamic_ports,
predownload_tokenizers,
use_nats_core,
durable_kv_events,
use_kv_events,
request_plane,
):
"""Validate KV cache prefix reuse and dp_rank routing by sending progressive requests with overlapping prefixes.
Parameterized to test:
- JetStream mode (default): KV events via JetStream
- NATS Core mode: KV events via NATS Core with local indexer on workers
- JetStream mode: KV events via NATS JetStream (durable)
- NATS Core mode (default): KV events via NATS Core with local indexer on workers
- Approximate mode (--no-kv-events): No KV events, router predicts cache state
based on routing decisions with TTL-based expiration and pruning
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
if not use_kv_events:
mode = "Approximate (no-kv-events)"
elif use_nats_core:
mode = "NATS Core (local indexer)"
else:
mode = "JetStream"
logger.info(
f"Starting test router prefix reuse and KV events synchronization ({mode})"
f"Starting test router decisions: durable_kv_events={durable_kv_events}, use_kv_events={use_kv_events}"
)
# Create mocker args dictionary with dp_size=4
# Note: enable_local_indexer only applies when use_kv_events=True and use_nats_core=True
# durable_kv_events=True enables JetStream mode; False (default) uses NATS Core with local indexer
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"dp_size": 4,
"enable_local_indexer": use_nats_core and use_kv_events,
"durable_kv_events": durable_kv_events and use_kv_events,
}
try:
logger.info(
f"Starting 2 mocker instances with dp_size=4 each (8 total dp ranks), {mode}"
)
mockers = MockerProcess(
request,
mocker_args=mocker_args,
......@@ -748,6 +742,7 @@ def test_router_decisions(
request,
test_dp_rank=True,
use_kv_events=use_kv_events,
durable_kv_events=durable_kv_events,
)
finally:
......@@ -786,11 +781,11 @@ def test_router_decisions_disagg(
namespace_suffix = generate_random_suffix()
shared_namespace = f"test-namespace-{namespace_suffix}"
# Create mocker args - use JetStream for KV events (more reliable than NATS Core)
# Create mocker args - use NATS Core with local indexer (default mode)
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": False,
# durable_kv_events defaults to False (NATS Core mode)
}
prefill_workers = None
......@@ -877,7 +872,7 @@ def test_router_decisions_disagg(
@pytest.mark.parametrize("request_plane", ["nats", "tcp"], indirect=True)
@pytest.mark.parametrize(
"use_nats_core", [True], indirect=True
"durable_kv_events", [False], indirect=True
) # Use NATS Core (local indexer)
@pytest.mark.timeout(39) # ~3x average (~12.84s), rounded up
def test_busy_threshold_endpoint(
......@@ -885,7 +880,7 @@ def test_busy_threshold_endpoint(
runtime_services_dynamic_ports,
predownload_tokenizers,
request_plane,
use_nats_core,
durable_kv_events,
):
"""Test that the /busy_threshold endpoint can be hit and responds correctly.
......@@ -905,7 +900,7 @@ def test_busy_threshold_endpoint(
mocker_args = {
"speedup_ratio": SPEEDUP_RATIO,
"block_size": BLOCK_SIZE,
"enable_local_indexer": use_nats_core,
"durable_kv_events": durable_kv_events,
}
try:
......
......@@ -87,7 +87,7 @@ class SGLangProcess:
data_parallel_size: Optional[int] = None,
request_plane: str = "tcp",
store_backend: str = "etcd",
enable_local_indexer: bool = False,
durable_kv_events: bool = False,
):
"""Initialize SGLang workers with dynamo integration.
......@@ -104,7 +104,7 @@ class SGLangProcess:
data_parallel_size: If set, enables data parallelism with this many ranks (num_workers must equal data_parallel_size)
request_plane: Request plane to use ("nats", "tcp", or "http"). Defaults to "tcp".
store_backend: Storage backend to use ("etcd" or "file"). Defaults to "etcd".
enable_local_indexer: If True, enable worker-local KV indexer for NATS Core mode. Defaults to False.
durable_kv_events: If True, use JetStream for durable KV events. Defaults to False (NATS Core mode).
"""
# Generate unique namespace for isolation
namespace_suffix = generate_random_suffix()
......@@ -185,6 +185,10 @@ class SGLangProcess:
kv_events_config = f'{{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:{kv_events_port}"}}'
command.extend(["--kv-events-config", kv_events_config])
# Use --durable-kv-events to enable JetStream mode (local indexer disabled)
if durable_kv_events:
command.append("--durable-kv-events")
env = os.environ.copy() # Copy parent environment
env_vars = {
"CUDA_VISIBLE_DEVICES": gpu_device,
......@@ -197,10 +201,6 @@ class SGLangProcess:
if self.store_backend == "file" and "DYN_FILE_KV" in os.environ:
env_vars["DYN_FILE_KV"] = os.environ["DYN_FILE_KV"]
# Enable local indexer for NATS Core mode
if enable_local_indexer:
env_vars["DYN_LOCAL_INDEXER"] = "true"
env.update(env_vars)
# Create managed process for the worker
......@@ -475,13 +475,12 @@ def test_router_decisions_sglang_dp(
@pytest.mark.pre_merge
@pytest.mark.gpu_1
@pytest.mark.parametrize(
"store_backend,use_nats_core,request_plane",
"store_backend,durable_kv_events,request_plane",
[
("etcd", False, "nats"), # JetStream mode
# ("etcd", True, "tcp"), # nats_core mode - disabled for now
# ("file", False, "nats"), # File backend - TODO: investigate file backend support for SGLang
("etcd", False, "tcp"),
],
ids=["jetstream"],
ids=["nats_core"],
indirect=["durable_kv_events", "request_plane"],
)
@pytest.mark.timeout(150) # ~3x average (~46s/test), rounded up
def test_sglang_indexers_sync(
......@@ -491,7 +490,7 @@ def test_sglang_indexers_sync(
file_storage_backend,
set_ucx_tls_no_mm,
store_backend,
use_nats_core,
durable_kv_events,
request_plane,
):
"""
......@@ -499,15 +498,15 @@ def test_sglang_indexers_sync(
with SGLang workers. This test verifies that both routers converge to the same internal state.
Tests with configuration:
- jetstream: etcd backend, JetStream for KV events, NATS request plane (not in the current parametrization)
- tcp_nats_core: etcd backend, local indexer with NATS Core, TCP request plane
- nats_core: etcd backend, local indexer with NATS Core, TCP request plane
(includes NATS interruption/recovery testing)
"""
# runtime_services_dynamic_ports handles NATS and etcd startup
nats_process, _etcd_process = runtime_services_dynamic_ports
logger.info(
f"Starting SGLang indexers sync test: store_backend={store_backend}, "
f"use_nats_core={use_nats_core}, request_plane={request_plane}"
f"durable_kv_events={durable_kv_events}, request_plane={request_plane}"
)
N_SGLANG_WORKERS = 2
......@@ -522,13 +521,14 @@ def test_sglang_indexers_sync(
single_gpu=True, # fit workers into one GPU
request_plane=request_plane,
store_backend=store_backend,
enable_local_indexer=use_nats_core,
durable_kv_events=durable_kv_events,
)
logger.info(f"All SGLang workers using namespace: {sglang_workers.namespace}")
sglang_workers.__enter__()
# Use the common test implementation (creates its own runtimes for each router)
# Note: Consumer verification is done inside _test_router_indexers_sync while routers are alive
# When using durable_kv_events=True, use JetStream mode for the router
_test_router_indexers_sync(
engine_workers=sglang_workers,
block_size=PAGE_SIZE,
......@@ -536,8 +536,9 @@ def test_sglang_indexers_sync(
num_workers=N_SGLANG_WORKERS,
store_backend=store_backend,
request_plane=request_plane,
test_nats_interruption=use_nats_core,
nats_server=nats_process if use_nats_core else None,
test_nats_interruption=not durable_kv_events,
nats_server=nats_process if not durable_kv_events else None,
durable_kv_events=durable_kv_events,
)
logger.info("SGLang indexers sync test completed successfully")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment