Unverified Commit 33249945 authored by Karen Chung's avatar Karen Chung Committed by GitHub
Browse files

feat: worker-local KvIndexer in KvEventPublisher (#4519)


Co-authored-by: default avatarYan Ru Pei <yanrpei@gmail.com>
parent 10b01b45
......@@ -2663,7 +2663,7 @@ dependencies = [
"bytes",
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"chrono",
"clap 4.5.52",
"clap 4.5.53",
"criterion 0.3.6",
"cudarc",
"dashmap 5.5.3",
......@@ -4065,8 +4065,8 @@ checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
dependencies = [
"dirs",
"futures",
"indicatif 0.17.11",
"http 1.4.0",
"indicatif 0.17.11",
"libc",
"log",
"num_cpus",
......
......@@ -113,6 +113,7 @@ def create_temp_engine_args_file(args) -> Path:
else None,
"is_prefill": getattr(args, "is_prefill_worker", None),
"is_decode": getattr(args, "is_decode_worker", None),
"enable_local_indexer": getattr(args, "enable_local_indexer", None),
}
# Remove None values to only include explicitly set arguments
......@@ -284,6 +285,12 @@ def parse_args():
default=False,
help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)",
)
parser.add_argument(
"--enable-local-indexer",
action="store_true",
default=False,
help="Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)",
)
parser.add_argument(
"--store-kv",
type=str,
......
......@@ -40,6 +40,7 @@ class Config:
custom_jinja_template: Optional[str] = None
store_kv: str
request_plane: str
enable_local_indexer: bool = False
# mirror vLLM
model: str
......@@ -204,6 +205,13 @@ def parse_args() -> Config:
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
parser.add_argument(
"--enable-local-indexer",
type=str,
choices=["true", "false"],
default=os.environ.get("DYN_LOCAL_INDEXER", "false"),
help="Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).",
)
parser.add_argument(
"--use-vllm-tokenizer",
action="store_true",
......@@ -214,6 +222,7 @@ def parse_args() -> Config:
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.enable_local_indexer = str(args.enable_local_indexer).lower() == "true"
engine_args = AsyncEngineArgs.from_cli_args(args)
# Workaround for vLLM GIL contention bug with NIXL connector when using UniProcExecutor.
......@@ -312,6 +321,7 @@ def parse_args() -> Config:
config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv
config.request_plane = args.request_plane
config.enable_local_indexer = args.enable_local_indexer
config.use_vllm_tokenizer = args.use_vllm_tokenizer
# Validate custom Jinja template file exists if provided
......
......@@ -224,6 +224,7 @@ def setup_kv_event_publisher(
worker_id=generate_endpoint.connection_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
enable_local_indexer=config.enable_local_indexer,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
kv_publishers.append(kv_publisher)
......@@ -336,6 +337,7 @@ async def register_vllm_model(
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
runtime_config.enable_local_indexer = config.enable_local_indexer
# Add tool/reasoning parsers for decode models
if model_type != ModelType.Prefill:
......
......@@ -21,7 +21,7 @@ use rs::traits::events::EventSubscriber;
use tracing;
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener};
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
#[pyfunction]
......@@ -106,6 +106,9 @@ pub struct ZmqKvEventPublisherConfig {
pub zmq_endpoint: String,
#[pyo3(get, set)]
pub zmq_topic: String,
#[pyo3(get, set)]
pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to
// both global and worker-local KvIndexers
}
#[pymethods]
......@@ -115,19 +118,22 @@ impl ZmqKvEventPublisherConfig {
worker_id,
kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string()
zmq_topic = "".to_string(),
enable_local_indexer = false
))]
pub fn new(
worker_id: WorkerId,
kv_block_size: usize,
zmq_endpoint: String,
zmq_topic: String,
enable_local_indexer: bool,
) -> Self {
Self {
worker_id,
kv_block_size,
zmq_endpoint,
zmq_topic,
enable_local_indexer,
}
}
}
......@@ -141,13 +147,14 @@ pub(crate) struct ZmqKvEventPublisher {
impl ZmqKvEventPublisher {
#[new]
fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
component.inner,
config.kv_block_size as u32,
Some(KvEventSourceConfig::Zmq {
endpoint: config.zmq_endpoint,
topic: config.zmq_topic,
}),
config.enable_local_indexer,
)
.map_err(to_pyerr)?;
Ok(Self { inner })
......@@ -179,7 +186,7 @@ impl ZmqKvEventListener {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<KvCacheEvent>();
let shutdown_token = tokio_util::sync::CancellationToken::new();
tokio::spawn(llm_rs::kv_router::publisher::start_zmq_listener(
tokio::spawn(start_zmq_listener(
zmq_endpoint,
zmq_topic,
tx,
......
......@@ -49,6 +49,11 @@ impl ModelRuntimeConfig {
self.inner.data_parallel_size = data_parallel_size;
}
/// Property setter: stores the worker-local indexer flag on the wrapped
/// runtime config (`self.inner`).
#[setter]
fn set_enable_local_indexer(&mut self, enable_local_indexer: bool) {
    self.inner.enable_local_indexer = enable_local_indexer;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
......@@ -103,6 +108,11 @@ impl ModelRuntimeConfig {
self.inner.reasoning_parser.clone()
}
/// Property getter: returns the worker-local indexer flag currently stored on
/// the wrapped runtime config (`self.inner`).
#[getter]
fn enable_local_indexer(&self) -> bool {
    self.inner.enable_local_indexer
}
#[getter]
fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py);
......
......@@ -460,6 +460,7 @@ class ModelRuntimeConfig:
max_num_batched_tokens: int | None
tool_call_parser: str | None
reasoning_parser: str | None
enable_local_indexer: bool
runtime_data: dict[str, Any]
tensor_model_config: Any | None
......@@ -843,7 +844,8 @@ class ZmqKvEventPublisherConfig:
worker_id: int,
kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = ""
zmq_topic: str = "",
enable_local_indexer: bool = False
) -> None:
"""
Configuration for the ZmqKvEventPublisher.
......@@ -852,6 +854,7 @@ class ZmqKvEventPublisherConfig:
:param kv_block_size: The block size for the key-value store.
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False.
"""
...
......
......@@ -34,8 +34,11 @@ pub mod scheduler;
pub mod scoring;
pub mod sequence;
pub mod subscriber;
pub mod worker_query;
use indexer::WorkerKvQueryResponse;
pub use prefill_router::PrefillRouter;
use worker_query::WorkerQueryClient;
use crate::{
kv_router::{
......@@ -45,11 +48,12 @@ use crate::{
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
protocols::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult,
WorkerWithDpRank,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError,
subscriber::start_kv_router_background,
subscriber::{recover_from_all_workers, start_kv_router_background},
},
local_model::runtime_config::ModelRuntimeConfig,
model_card::ModelDeploymentCard,
......@@ -77,6 +81,10 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";
// for worker-local kvindexer query
pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query";
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer
// for router discovery registration
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
pub const KV_ROUTER_ENDPOINT: &str = "generate";
......@@ -270,6 +278,8 @@ pub struct KvRouter {
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
worker_query_client: Option<WorkerQueryClient>,
}
impl KvRouter {
......@@ -296,7 +306,7 @@ impl KvRouter {
endpoint: endpoint_id.name.clone(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
.list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
.await?;
let runtime_configs_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
......@@ -333,13 +343,19 @@ impl KvRouter {
component.clone(),
block_size,
instance_ids_rx,
runtime_configs_rx,
runtime_configs_rx.clone(),
selector,
kv_router_config.router_replica_sync,
consumer_id.clone(),
)
.await?;
// Initialize worker query client using namespace abstraction
// (created before background task so we can use it for startup recovery)
let worker_query_client =
worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
tracing::info!("Worker query client initialized");
// Start KV event subscriber background process (only when use_kv_events is enabled)
if kv_router_config.use_kv_events
&& let Indexer::KvIndexer(ref kv_indexer) = indexer
......@@ -360,6 +376,47 @@ impl KvRouter {
kv_router_config.router_reset_states,
)
.await?;
// Perform startup recovery from workers with local indexers
// This catches up on any events missed while the router was offline
let last_event_ids = kv_indexer
.get_last_received_event_ids()
.await
.unwrap_or_default();
let instances = client.instance_source.as_ref().borrow().clone();
let worker_ids: Vec<WorkerId> = instances.iter().map(|i| i.instance_id).collect();
if !worker_ids.is_empty() {
tracing::info!(
worker_count = worker_ids.len(),
"Starting recovery from workers with local indexers"
);
// NOTE: recover_from_all_workers() is a no-op if
// Worker with worker_id is not associated with a
// local indexer instance.
let recovered = recover_from_all_workers(
&worker_query_client,
&last_event_ids,
&worker_ids,
&kv_indexer.event_sender(),
)
.await;
if recovered > 0 {
tracing::info!(
recovered_events = recovered,
"KV Router startup: Recovered {} KV events from workers {:?}",
recovered,
worker_ids
);
} else {
tracing::info!(
"KV Router startup: No KV events recovered from workers {:?}",
worker_ids
);
}
}
}
tracing::info!("KV Routing initialized");
......@@ -370,6 +427,7 @@ impl KvRouter {
kv_router_config,
cancellation_token,
client,
worker_query_client: Some(worker_query_client),
})
}
......@@ -502,6 +560,62 @@ impl KvRouter {
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
}
/// Query a specific worker's local KV indexer for its events
/// (See docstring for `WorkerQueryClient.query_worker()`)
pub async fn query_worker_local_kv(
&self,
worker_id: WorkerId,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
let query_client = self
.worker_query_client
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;
query_client
.query_worker(worker_id, start_event_id, end_event_id)
.await
}
/// Recover missed KV events from a specific worker.
///
/// Queries the worker's local KV indexer for events starting from
/// `start_event_id` and applies them to the router's indexer.
///
/// # Arguments
///
/// * `worker_id` - The worker to recover from
/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
pub async fn recover_from_worker(
&self,
worker_id: WorkerId,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<usize> {
let query_client = self
.worker_query_client
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;
let event_tx = match &self.indexer {
Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
Indexer::None => {
anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
}
};
subscriber::recover_from_worker(
query_client,
worker_id,
start_event_id,
end_event_id,
&event_tx,
)
.await
}
}
// NOTE: KVRouter works like a PushRouter,
......
......@@ -44,7 +44,7 @@ use std::{
collections::{HashMap, VecDeque},
iter,
rc::Rc,
sync::{Arc, OnceLock},
sync::{Arc, Mutex, OnceLock},
thread::JoinHandle,
time::{Duration, Instant},
};
......@@ -199,6 +199,31 @@ impl RouterEvent {
}
}
// -------
// Distributed router - Worker KV Query types
// -------
/// Request to query a worker's local KV indexer.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WorkerKvQueryRequest {
    /// The worker ID of the worker to query.
    pub worker_id: WorkerId,
    /// Lower bound of the event-ID range to return.
    ///
    /// If neither bound is specified, the worker dumps all events.
    /// A missing `start_event_id` stands for the oldest logged event and a
    /// missing `end_event_id` for the newest logged event.
    ///
    /// NOTE(review): `LocalKvIndexer::get_events_in_id_range` documents and
    /// tests both bounds as inclusive (`[start, end]`); confirm the worker-side
    /// handler for this request applies the same convention.
    pub start_event_id: Option<u64>,
    /// Upper bound of the event-ID range to return (see `start_event_id`).
    pub end_event_id: Option<u64>,
}
/// Response from a worker's local KV indexer.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WorkerKvQueryResponse {
    /// The events from the worker local KvIndexer.
    ///
    /// NOTE(review): if these come from `LocalKvIndexer::get_events_in_id_range`,
    /// they may be a full tree dump with synthetic event IDs rather than the
    /// originally published events — confirm against the worker-side handler.
    pub events: Vec<RouterEvent>,
}
/// A block in the Radix Tree.
#[derive(Debug)]
struct RadixBlock {
......@@ -781,6 +806,13 @@ pub struct GetWorkersRequest {
pub resp: oneshot::Sender<Vec<WorkerId>>,
}
/// A request to get the last received event ID per worker.
/// Used for fault tolerance recovery to determine which events to request from workers.
///
/// The indexer's background loop answers with a copy of its tracking map, so
/// the reply only contains workers for which an event has been seen.
pub struct GetLastReceivedEventIdsRequest {
    /// Channel to send the last received event IDs per worker
    pub resp: oneshot::Sender<HashMap<WorkerId, u64>>,
}
#[async_trait]
pub trait KvIndexerInterface {
/// Find matches for a given sequence of `LocalBlockHash`es.
......@@ -885,6 +917,8 @@ pub struct KvIndexer {
dump_tx: mpsc::Sender<DumpRequest>,
/// A sender for routing decision requests.
routing_tx: mpsc::Sender<RoutingDecisionRequest>,
/// A sender for getting last received event IDs (for fault tolerance recovery).
last_event_ids_tx: mpsc::Sender<GetLastReceivedEventIdsRequest>,
/// A handle to the background task managing the KV store.
task: OnceLock<std::thread::JoinHandle<()>>,
/// The size of the KV block this indexer can handle.
......@@ -918,6 +952,9 @@ impl KvIndexer {
let (dump_tx, dump_rx) = mpsc::channel::<DumpRequest>(16);
let (routing_tx, mut routing_rx) = mpsc::channel::<RoutingDecisionRequest>(2048);
let (prune_tx, mut prune_rx) = mpsc::channel::<()>(1);
let (last_event_ids_tx, mut last_event_ids_rx) =
mpsc::channel::<GetLastReceivedEventIdsRequest>(16);
let cancel_clone = token.clone();
let task = std::thread::spawn(move || {
......@@ -942,6 +979,10 @@ impl KvIndexer {
});
let mut event_id_counter = 0u64;
// Track last received event ID per worker (for fault tolerance recovery)
// Only used when enable_event_tracking is true
let mut last_received_event_id: HashMap<WorkerId, u64> = HashMap::new();
loop {
// Create a future that sleeps until the next expiration time
let expiry_fut = if let Some(ref pm) = prune_manager
......@@ -968,6 +1009,10 @@ impl KvIndexer {
let _ = get_workers_req.resp.send(workers);
}
Some(req) = last_event_ids_rx.recv() => {
let _ = req.resp.send(last_received_event_id.clone());
}
Some(_) = prune_rx.recv() => {
// Tree size-based pruning triggered
let Some(ref mut pm) = prune_manager else { continue };
......@@ -990,6 +1035,33 @@ impl KvIndexer {
}
Some(event) = event_rx.recv() => {
// Track last received event ID per worker
// Check for gaps before updating the last received ID
// TODO should this trigger a recovery event?
let last_id = *last_received_event_id.get(&event.worker_id).unwrap_or(&0);
let incoming_id = event.event.event_id;
// Detect gap: if incoming ID is more than 1 greater than last received
if incoming_id > last_id + 1 && last_id > 0 {
let gap_start = last_id + 1;
let gap_end = incoming_id - 1;
tracing::warn!(
worker_id = event.worker_id,
gap_start,
gap_end,
gap_size = gap_end - gap_start + 1,
"Event ID gap detected! Missed events [{}, {}]. \
If this is a global KvIndexer, within a KvRouter context,
consider calling KvRouter::query_worker_local_kv() to potentially recover worker-stored events.",
gap_start,
gap_end,
);
}
// Update last received event ID (use max to handle out-of-order events)
let entry = last_received_event_id.entry(event.worker_id).or_insert(0);
*entry = (*entry).max(event.event.event_id);
let event_type = KvIndexerMetrics::get_event_type(&event.event.data);
let result = trie.apply_event(event.clone());
let result_is_ok = result.is_ok();
......@@ -1121,6 +1193,7 @@ impl KvIndexer {
get_workers_tx,
dump_tx,
routing_tx,
last_event_ids_tx,
task: once,
kv_block_size,
}
......@@ -1173,6 +1246,48 @@ impl KvIndexer {
pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> {
self.get_workers_tx.clone()
}
/// Get a sender for last received event IDs requests.
///
/// Each request sent on this channel is answered on its own `resp` channel
/// with a copy of the per-worker map of last received event IDs.
///
/// ### Returns
///
/// A `mpsc::Sender` for `GetLastReceivedEventIdsRequest`s.
pub fn last_event_ids_sender(&self) -> mpsc::Sender<GetLastReceivedEventIdsRequest> {
    self.last_event_ids_tx.clone()
}
/// Ask the indexer task for the highest event ID seen so far from each worker.
///
/// This supports **fault tolerance recovery**: after a disconnect the router
/// can use the returned IDs to query each worker for events starting at
/// `last_id + 1` and so catch up on state it missed.
///
/// **Note**: This method is intended for the global `KvIndexer` used by
/// routers, not for `LocalKvIndexer` (worker-side) or `KvIndexerSharded`.
///
/// ### Returns
///
/// A `HashMap` mapping worker IDs to their last received event ID.
///
/// ### Errors
///
/// * `KvRouterError::IndexerOffline` - the request could not be sent.
/// * `KvRouterError::IndexerDroppedRequest` - the reply channel was closed
///   before an answer arrived.
pub async fn get_last_received_event_ids(
    &self,
) -> Result<HashMap<WorkerId, u64>, KvRouterError> {
    let (reply_tx, reply_rx) = oneshot::channel();

    self.last_event_ids_tx
        .send(GetLastReceivedEventIdsRequest { resp: reply_tx })
        .await
        .map_err(|send_err| {
            tracing::error!(
                "Failed to send last event IDs request: {:?}; the indexer maybe offline",
                send_err
            );
            KvRouterError::IndexerOffline
        })?;

    reply_rx
        .await
        .map_err(|_| KvRouterError::IndexerDroppedRequest)
}
}
#[async_trait]
......@@ -1285,6 +1400,574 @@ impl Drop for KvIndexer {
}
}
// -------------------------------------------------
// Decentralized router: LocalKvIndexer for workers
// -------------------------------------------------
/// A thin wrapper around KvIndexer that buffers recent events
/// (e.g. so they can be queried by a router upon startup).
///
/// Events applied through `apply_event_with_buffer` are appended to the
/// buffer and forwarded to the wrapped indexer; once the buffer exceeds
/// `max_buffer_size` the oldest entries are dropped.
pub struct LocalKvIndexer {
    /// The underlying indexer
    indexer: KvIndexer,
    /// Circular buffer of recent events (oldest at the front, newest at the back)
    event_buffer: Mutex<VecDeque<RouterEvent>>,
    /// Maximum number of events to keep in buffer
    max_buffer_size: usize, // Router sets this to WORKER_KV_INDEXER_BUFFER_SIZE
}
impl LocalKvIndexer {
/// Build a `LocalKvIndexer` around a freshly created [`KvIndexer`].
///
/// `token`, `kv_block_size` and `metrics` are passed straight to
/// `KvIndexer::new`; `max_buffer_size` caps how many recent events the
/// buffer retains and is also used as its initial capacity.
pub fn new(
    token: CancellationToken,
    kv_block_size: u32,
    metrics: Arc<KvIndexerMetrics>,
    max_buffer_size: usize,
) -> Self {
    let indexer = KvIndexer::new(token, kv_block_size, metrics);
    let event_buffer = Mutex::new(VecDeque::with_capacity(max_buffer_size));

    Self {
        indexer,
        event_buffer,
        max_buffer_size,
    }
}
/// Return a copy of every buffered event, ordered oldest to newest.
///
/// # Panics
///
/// Panics if the buffer mutex is poisoned.
pub fn get_all_events_in_buffer(&self) -> Vec<RouterEvent> {
    self.event_buffer.lock().unwrap().iter().cloned().collect()
}
/// Query events by ID range, returning events in `[start_id, end_id]` (both inclusive).
///
/// This method attempts to serve the request from the in-memory event buffer when possible.
/// If the requested range extends beyond what's available in the buffer, a full tree dump
/// is performed instead.
///
/// ### Arguments
///
/// * `start_id` - Starting event ID (inclusive). If `None`, returns from oldest available.
/// * `end_id` - Ending event ID (inclusive). If `None`, returns up to newest available.
///
/// ### Behavior
///
/// - **Buffer path**: If `start_id >= first_buffered_id`, events are retrieved directly
/// from the buffer with their original event IDs.
///
/// - **Tree dump path**: If the range extends before the buffer or no range is specified,
/// a full tree dump is performed. **Note**: Tree dumps generate synthetic 0-indexed
/// event IDs that do NOT correspond to the original event IDs. The entire tree state
/// is returned regardless of the requested range.
///
/// ### Returns
///
/// A vector of `RouterEvent`s. When served from buffer, events have their original IDs.
/// When served from tree dump, events have synthetic sequential IDs starting from 0.
pub async fn get_events_in_id_range(
&self,
start_id: Option<u64>,
end_id: Option<u64>,
) -> Vec<RouterEvent> {
// Validate range if both specified
if let (Some(s), Some(e)) = (start_id, end_id)
&& s > e
{
tracing::warn!(
start_id = s,
end_id = e,
"Requested start_id > end_id; returning empty result."
);
return Vec::new();
}
// Check if we can serve from buffer
let buffer_range = {
let buffer = self.event_buffer.lock().unwrap();
if buffer.is_empty() {
None
} else {
Some((
buffer.front().unwrap().event.event_id,
buffer.back().unwrap().event.event_id,
))
}
};
// Determine if request can be served from buffer
let can_use_buffer = match (start_id, buffer_range) {
// No start specified means we need everything from the beginning -> tree dump
(None, _) => false,
// Buffer is empty -> tree dump
(_, None) => false,
// start_id is within or after buffer range -> can use buffer
(Some(s), Some((first_buffered, _))) => s >= first_buffered,
};
if can_use_buffer {
// Serve from buffer - these have real event IDs
self.get_buffer_events_in_id_range(start_id, end_id)
} else {
// Must dump entire tree
if let (Some(s), Some(e)) = (start_id, end_id) {
tracing::warn!(
requested_start_id = s,
requested_end_id = e,
buffer_range = ?buffer_range,
"Requested event ID range extends before buffer; dumping entire tree. \
Note: Tree dump returns synthetic 0-indexed event IDs, not original IDs."
);
} else if start_id.is_some() || end_id.is_some() {
tracing::warn!(
requested_start_id = ?start_id,
requested_end_id = ?end_id,
buffer_range = ?buffer_range,
"Partial range specified but cannot serve from buffer; dumping entire tree. \
Note: Tree dump returns synthetic 0-indexed event IDs, not original IDs."
);
}
// Return full tree dump - no filtering since IDs are synthetic
self.dump_events().await.unwrap_or_default()
}
}
/// Slice the in-memory buffer by event ID; both `start_id` and `end_id` are inclusive.
///
/// A missing `start_id` defaults to the oldest buffered ID and a missing
/// `end_id` to the newest. Bounds outside the buffered range are clamped,
/// except that a start past the newest ID or an end before the oldest ID
/// yields an empty result. The lookup uses binary search keyed on event ID,
/// so it relies on the buffer being ordered by event ID.
///
/// Only the buffer is consulted; the indexer tree is never dumped here.
///
/// # Panics
///
/// Panics if the buffer mutex is poisoned.
pub fn get_buffer_events_in_id_range(
    &self,
    start_id: Option<u64>,
    end_id: Option<u64>,
) -> Vec<RouterEvent> {
    let buffer = self.event_buffer.lock().unwrap();

    let (Some(oldest), Some(newest)) = (buffer.front(), buffer.back()) else {
        tracing::warn!("No events in buffer yet; returning empty result.");
        return Vec::new();
    };
    let first_id = oldest.event.event_id;
    let last_id = newest.event.event_id;

    let start_id = start_id.unwrap_or(first_id);
    let end_id = end_id.unwrap_or(last_id);
    if start_id > end_id {
        tracing::warn!(
            start_id,
            end_id,
            "Requested start_id > end_id; returning empty result."
        );
        return Vec::new();
    }

    let locate = |id: u64| buffer.binary_search_by_key(&id, |e| e.event.event_id);

    // Index of the first event to return.
    let start_idx = match locate(start_id) {
        Ok(found) => found,
        Err(insertion_point) => {
            if start_id < first_id {
                tracing::warn!(
                    start_id,
                    first_id,
                    "Requested start_id precedes buffer; clamping to oldest."
                );
                0
            } else if start_id > last_id {
                tracing::error!(
                    start_id,
                    last_id,
                    "Requested start_id is newer than buffer; returning empty."
                );
                return Vec::new();
            } else {
                insertion_point
            }
        }
    };

    // Exclusive end index; an exact hit is bumped by one to keep the bound inclusive.
    let end_idx = match locate(end_id) {
        Ok(found) => found + 1,
        Err(insertion_point) => {
            if end_id < first_id {
                return Vec::new();
            }
            if end_id > last_id {
                tracing::warn!(
                    end_id,
                    last_id,
                    "Requested end_id exceeds buffer; clamping to newest."
                );
                buffer.len()
            } else {
                insertion_point
            }
        }
    };

    let count = end_idx.saturating_sub(start_idx);
    buffer.iter().skip(start_idx).take(count).cloned().collect()
}
/// Record an event in the buffer.
///
/// The event is appended at the back; if the buffer then exceeds
/// `max_buffer_size`, the oldest entries are dropped from the front.
/// A non-consecutive event ID is logged as an error but the event is still
/// recorded.
///
/// # Panics
///
/// Panics if the buffer mutex is poisoned.
fn record_event(&self, event: RouterEvent) {
    let mut buffer = self.event_buffer.lock().unwrap();

    // Copy the identifiers out now: `event` is moved into the buffer below.
    let worker_id = event.worker_id;
    let event_id = event.event.event_id;

    // Check that event id is consecutive to last one.
    // `wrapping_add` keeps this check from overflowing at `u64::MAX`.
    if let Some(last_event) = buffer.back() {
        let expected = last_event.event.event_id.wrapping_add(1);
        if event_id != expected {
            tracing::error!(
                worker_id,
                expected,
                got = event_id,
                "Non-consecutive KV event id; buffer may have gaps"
            );
        }
    }

    // Add to back
    buffer.push_back(event);

    // Remove from front if over capacity (circular buffer behavior)
    while buffer.len() > self.max_buffer_size {
        buffer.pop_front();
    }

    // Logged after the push/eviction so the reported size is the real one, and
    // at debug level with identifiers only: this runs once per KV event.
    tracing::debug!(
        worker_id,
        event_id,
        buffer_len = buffer.len(),
        "Recorded KV event in buffer"
    );
}
/// Buffer an event and then hand it to the wrapped indexer.
///
/// A copy of the event is recorded in the buffer first; the original is then
/// sent on the indexer's event channel.
///
/// # Errors
///
/// Returns `KvRouterError::IndexerOffline` if the send fails. The event has
/// already been recorded in the buffer at that point.
pub async fn apply_event_with_buffer(&self, event: RouterEvent) -> Result<(), KvRouterError> {
    let buffered_copy = event.clone();
    self.record_event(buffered_copy);

    let indexer_tx = self.indexer.event_sender();
    indexer_tx
        .send(event)
        .await
        .map_err(|_| KvRouterError::IndexerOffline)
}
/// Drop every buffered event; the wrapped indexer is left untouched.
///
/// # Panics
///
/// Panics if the buffer mutex is poisoned.
pub fn clear_buffer(&self) {
    self.event_buffer.lock().unwrap().clear();
}
/// Number of events currently held in the buffer.
///
/// # Panics
///
/// Panics if the buffer mutex is poisoned.
pub fn buffer_len(&self) -> usize {
    self.event_buffer.lock().unwrap().len()
}
// Delegation methods to underlying KvIndexer
/// Get a sender for `RouterEvent`s.
///
/// This is the wrapped indexer's own event channel: events sent through it
/// reach the indexer directly and are NOT recorded in the event buffer.
/// Use `apply_event_with_buffer` when the event should also be buffered.
pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> {
    self.indexer.event_sender()
}
/// Get a sender for dump requests (snapshot events).
///
/// Delegates to the wrapped indexer's `snapshot_event_sender`.
pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> {
    self.indexer.snapshot_event_sender()
}
/// Get a sender for worker removal requests.
///
/// Delegates to the wrapped indexer's `remove_worker_sender`.
pub fn remove_worker_sender(&self) -> mpsc::Sender<WorkerId> {
    self.indexer.remove_worker_sender()
}
/// Get a sender for get workers requests.
///
/// Delegates to the wrapped indexer's `get_workers_sender`.
pub fn get_workers_sender(&self) -> mpsc::Sender<GetWorkersRequest> {
    self.indexer.get_workers_sender()
}
/// Get the KV block size.
///
/// Delegates to the wrapped indexer's `block_size`.
pub fn block_size(&self) -> u32 {
    self.indexer.block_size()
}
}
#[cfg(test)]
mod local_kv_indexer_tests {
use super::*;
/// Build a `LocalKvIndexer` (block size 4, buffer capacity 32) whose buffer is
/// pre-seeded with one `Cleared` event per ID in `ids`, in the given order.
/// The events are written straight into the buffer and never reach the
/// wrapped indexer.
fn make_indexer_with_events(ids: &[u64]) -> LocalKvIndexer {
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let indexer = LocalKvIndexer::new(CancellationToken::new(), 4, metrics, 32);

    let seeded = ids.iter().map(|&event_id| {
        RouterEvent::new(
            0,
            KvCacheEvent {
                event_id,
                data: KvCacheEventData::Cleared,
                dp_rank: 0,
            },
        )
    });
    indexer.event_buffer.lock().unwrap().extend(seeded);

    indexer
}
/// Buffer-only range queries: inclusive bounds, clamping at both ends, and an
/// empty result for an inverted range.
#[test]
fn returns_slice_within_range() {
    let indexer = make_indexer_with_events(&[1, 2, 3, 4, 5]);

    // Run one `[start, end]` query against the buffer and keep only the IDs.
    let ids_in = |start: u64, end: u64| -> Vec<u64> {
        indexer
            .get_buffer_events_in_id_range(Some(start), Some(end))
            .iter()
            .map(|router_event| router_event.event.event_id)
            .collect()
    };

    // inclusive range [2, 4]
    assert_eq!(ids_in(2, 4), vec![2, 3, 4]);
    // clamp end to buffer max
    assert_eq!(ids_in(2, 6), vec![2, 3, 4, 5]);
    // clamp start to buffer min, inclusive end
    assert_eq!(ids_in(0, 4), vec![1, 2, 3, 4]);
    // single element when start == end
    assert_eq!(ids_in(3, 3), vec![3]);
    // return empty when start > end
    assert!(ids_in(5, 2).is_empty());
}
/// End-to-end check of `get_events_in_id_range`: ten events are applied to an
/// indexer whose buffer holds only five, then queries are issued that hit the
/// buffer path, the tree-dump path, and the edge cases.
#[tokio::test]
async fn test_get_events_in_id_range_all_cases() {
    use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
    // Create indexer with small buffer (5 events max)
    // This way older events will only be in the tree, not the buffer
    let indexer = LocalKvIndexer::new(
        CancellationToken::new(),
        4, // block_size
        Arc::new(KvIndexerMetrics::new_unregistered()),
        5, // max_buffer_size - only keeps 5 most recent events
    );
    // Helper to create a test event
    // Each event stores a single block whose hashes are derived from the ID,
    // so every event contributes a distinct block.
    let make_event = |id: u64| {
        RouterEvent::new(
            0, // worker_id
            KvCacheEvent {
                event_id: id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: None,
                    blocks: vec![KvCacheStoredBlockData {
                        block_hash: ExternalSequenceBlockHash(id * 100),
                        tokens_hash: LocalBlockHash(id * 200),
                    }],
                }),
                dp_rank: 0,
            },
        )
    };
    // Add 10 events (IDs 5-14)
    // Buffer will only keep the last 5: events 10-14
    // Tree will have all blocks
    for id in 5..15 {
        indexer
            .apply_event_with_buffer(make_event(id))
            .await
            .unwrap();
    }
    // Wait for events to be processed by the tree
    // NOTE(review): this is a fixed delay, not a synchronization point; the
    // tree-dump assertions below assume all 10 events are applied by now.
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    // Helper to extract event IDs from result
    let get_ids = |events: Vec<RouterEvent>| -> Vec<u64> {
        events.iter().map(|e| e.event.event_id).collect()
    };
    // Verify buffer state: should have events 10-14 (last 5)
    let buffer_events = indexer.get_all_events_in_buffer();
    assert_eq!(
        get_ids(buffer_events),
        vec![10, 11, 12, 13, 14],
        "Buffer should have events 10-14"
    );
    // ========== BUFFER PATH TESTS (start_id >= first_buffered) ==========
    // Range is [start, end] inclusive
    // Test: start_id within buffer, no end
    let result = indexer.get_events_in_id_range(Some(11), None).await;
    assert_eq!(
        get_ids(result),
        vec![11, 12, 13, 14],
        "start_id=11 (in buffer) should return [11, 14]"
    );
    // Test: start_id at buffer boundary
    let result = indexer.get_events_in_id_range(Some(10), None).await;
    assert_eq!(
        get_ids(result),
        vec![10, 11, 12, 13, 14],
        "start_id=10 (buffer start) should return [10, 14]"
    );
    // Test: both start and end within buffer (inclusive)
    let result = indexer.get_events_in_id_range(Some(11), Some(13)).await;
    assert_eq!(
        get_ids(result),
        vec![11, 12, 13],
        "range [11, 13] inclusive should return 3 events"
    );
    let result = indexer.get_events_in_id_range(Some(10), Some(14)).await;
    assert_eq!(
        get_ids(result),
        vec![10, 11, 12, 13, 14],
        "range [10, 14] should return all buffer events"
    );
    // ========== TREE DUMP PATH TESTS (range extends before buffer) ==========
    // Note: Tree dumps return synthetic 0-indexed event IDs, so we just check
    // that we get events back (the IDs won't match original IDs)
    // Test: (None, None) dumps entire tree
    let result = indexer.get_events_in_id_range(None, None).await;
    assert_eq!(
        result.len(),
        10,
        "(None, None) should dump entire tree (10 events)"
    );
    // Test: (None, Some(_)) dumps entire tree
    let result = indexer.get_events_in_id_range(None, Some(8)).await;
    assert_eq!(
        result.len(),
        10,
        "(None, Some(_)) dumps entire tree - end_id is ignored for tree dumps"
    );
    // Test: start_id before buffer triggers tree dump
    let result = indexer.get_events_in_id_range(Some(7), None).await;
    assert_eq!(
        result.len(),
        10,
        "start_id=7 (before buffer) should dump entire tree"
    );
    let result = indexer.get_events_in_id_range(Some(5), Some(12)).await;
    assert_eq!(
        result.len(),
        10,
        "range [5, 12] extending before buffer should dump entire tree"
    );
    // ========== EDGE CASES ==========
    // Single element when start == end (inclusive range)
    let result = indexer.get_events_in_id_range(Some(12), Some(12)).await;
    assert_eq!(
        get_ids(result),
        vec![12],
        "start == end should return single event"
    );
    // Empty when start > end
    let result = indexer.get_events_in_id_range(Some(15), Some(10)).await;
    assert!(result.is_empty(), "start > end should return empty");
    // Request beyond buffer but valid range -> buffer returns what it has
    let result = indexer.get_events_in_id_range(Some(12), Some(100)).await;
    assert_eq!(
        get_ids(result),
        vec![12, 13, 14],
        "range with end beyond buffer should return available buffer events"
    );
}
}
// Implement KvIndexerInterface by delegating to the underlying indexer
#[async_trait]
impl KvIndexerInterface for LocalKvIndexer {
async fn find_matches(
&self,
sequence: Vec<LocalBlockHash>,
) -> Result<OverlapScores, KvRouterError> {
self.indexer.find_matches(sequence).await
}
async fn find_matches_for_request(
&self,
tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> {
self.indexer.find_matches_for_request(tokens).await
}
async fn apply_event(&mut self, event: RouterEvent) {
// Use the buffering version
let _ = self.apply_event_with_buffer(event).await;
}
async fn remove_worker(&mut self, worker: WorkerId) {
let _ = self.indexer.remove_worker_sender().send(worker).await;
}
fn shutdown(&mut self) {
// Note: Since indexer is Arc<KvIndexer>, we can't call mutable methods directly.
// The indexer will be shut down when the CancellationToken is cancelled
// or when the last Arc reference is dropped.
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
}
async fn process_routing_decision(
&self,
worker: WorkerWithDpRank,
local_hashes: Vec<LocalBlockHash>,
sequence_hashes: Vec<SequenceHash>,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision(worker, local_hashes, sequence_hashes)
.await
}
async fn process_routing_decision_for_request(
&self,
tokens: &[u32],
worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> {
// TODO I guess the local kvindexers have little use for this method?
// Keeping it here now to implement the trait fully
self.indexer
.process_routing_decision_for_request(tokens, worker)
.await
}
}
#[derive(Debug, Clone)]
pub struct ShardedMatchRequest {
sequence: Vec<LocalBlockHash>,
......@@ -2978,3 +3661,158 @@ mod tests {
assert!(result.contains_key(&WorkerWithDpRank::from_worker_id(worker_2)));
}
}
#[cfg(test)]
mod tests_local_indexer {
    use super::*;
    use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
    use tokio::time;
    use tokio_util::sync::CancellationToken;

    /// Calls `dynamo_runtime::logging::init()` for tests that want log output.
    fn setup() {
        dynamo_runtime::logging::init();
    }

    /// Builds one stored block per entry of `hashes`: the value itself becomes the
    /// `tokens_hash` and `value * 100` becomes the `block_hash`.
    fn make_blocks(hashes: Vec<u64>) -> Vec<KvCacheStoredBlockData> {
        hashes
            .into_iter()
            .map(|hash| KvCacheStoredBlockData {
                tokens_hash: LocalBlockHash(hash),
                block_hash: ExternalSequenceBlockHash(hash * 100),
            })
            .collect()
    }

    /// Builds a `Stored` router event for `worker_id` with the given event id, block hashes
    /// (see `make_blocks`) and optional parent hash; `dp_rank` is always 0.
    fn create_store_event(
        worker_id: WorkerId,
        event_id: u64,
        hashes: Vec<u64>,
        parent: Option<ExternalSequenceBlockHash>,
    ) -> RouterEvent {
        RouterEvent {
            worker_id,
            event: KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: parent,
                    blocks: make_blocks(hashes),
                }),
                dp_rank: 0,
            },
        }
    }

    /// Tests components of the LocalKvIndexer query path without using NATS: an applied event
    /// must show up in the buffer and survive a JSON round-trip inside a
    /// `WorkerKvQueryResponse`.
    #[tokio::test]
    async fn test_local_indexer_buffer_and_serialization() {
        let worker_id = 42u64;

        // Create a local indexer
        let token = CancellationToken::new();
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));

        // Add an event to the local indexer's buffer
        let test_event_1 = RouterEvent::new(
            worker_id,
            KvCacheEvent {
                event_id: 1,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: None,
                    blocks: vec![KvCacheStoredBlockData {
                        block_hash: ExternalSequenceBlockHash(100),
                        tokens_hash: LocalBlockHash(200),
                    }],
                }),
                dp_rank: 0,
            },
        );
        local_indexer
            .apply_event_with_buffer(test_event_1)
            .await
            .unwrap();

        // Wait for events to be processed
        time::sleep(time::Duration::from_millis(50)).await;

        // Get buffered events (what the query service would return)
        let buffered_events = local_indexer.get_all_events_in_buffer();
        assert_eq!(buffered_events.len(), 1, "Buffer should have 1 event");
        assert_eq!(buffered_events[0].worker_id, worker_id);
        assert_eq!(buffered_events[0].event.event_id, 1);

        // Build the response that would be sent; the buffer snapshot is not needed afterwards,
        // so it is moved instead of cloned.
        let response = WorkerKvQueryResponse {
            events: buffered_events,
        };

        // Serialization/deserialization (simulating a NATS round-trip)
        let serialized = serde_json::to_vec(&response).unwrap();
        let deserialized: WorkerKvQueryResponse = serde_json::from_slice(&serialized).unwrap();

        assert_eq!(deserialized.events.len(), 1);
        assert_eq!(deserialized.events[0].worker_id, worker_id);
        assert_eq!(deserialized.events[0].event.event_id, 1);
        match &deserialized.events[0].event.data {
            KvCacheEventData::Stored(store_data) => {
                assert_eq!(store_data.blocks.len(), 1);
                assert_eq!(store_data.blocks[0].block_hash.0, 100);
                assert_eq!(store_data.blocks[0].tokens_hash.0, 200);
            }
            _ => panic!("Expected Stored event"),
        }

        // Cleanup, consistent with the other tests in this module
        token.cancel();
    }

    /// Checks that the last received event id is tracked per worker, including for a worker
    /// whose event ids contain a gap.
    #[tokio::test]
    async fn test_gap_detection_per_worker() {
        setup();
        let token = CancellationToken::new();
        let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
        let indexer = KvIndexer::new(token.clone(), 4, metrics);
        let worker_a: WorkerId = 100;
        let worker_b: WorkerId = 200;
        let event_tx = indexer.event_sender();

        // Worker A: events 1, 2, 3 (no gap)
        for id in 1..=3 {
            let event = create_store_event(worker_a, id, vec![id], None);
            event_tx.send(event).await.unwrap();
        }

        // Worker B: events 1, then 5 (gap of 2, 3, 4)
        let event_b1 = create_store_event(worker_b, 1, vec![10], None);
        event_tx.send(event_b1).await.unwrap();
        let event_b5 = create_store_event(worker_b, 5, vec![50], None);
        event_tx.send(event_b5).await.unwrap();

        // Poll (bounded) until both workers report the expected ids instead of relying on a
        // single fixed sleep, which makes the test timing-dependent on a loaded machine.
        let mut last_ids = indexer.get_last_received_event_ids().await.unwrap();
        for _ in 0..50 {
            if last_ids.get(&worker_a) == Some(&3) && last_ids.get(&worker_b) == Some(&5) {
                break;
            }
            time::sleep(Duration::from_millis(10)).await;
            last_ids = indexer.get_last_received_event_ids().await.unwrap();
        }

        // Verify each worker has the correct last_received_event_id
        assert_eq!(
            last_ids.get(&worker_a),
            Some(&3),
            "Worker A should have last_id = 3 (no gap)"
        );
        assert_eq!(
            last_ids.get(&worker_b),
            Some(&5),
            "Worker B should have last_id = 5 (despite gap)"
        );

        // Cleanup
        token.cancel();
    }
}
......@@ -330,6 +330,9 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
}
}
// ------
// Tests
// ------
#[cfg(test)]
mod tests {
use super::*;
......
......@@ -16,15 +16,22 @@ use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use dynamo_runtime::metrics::{MetricsHierarchy, prometheus_names::kvstats};
use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher};
use dynamo_runtime::traits::{
DistributedRuntimeProvider, events::EventPublisher, events::EventSubscriber,
};
use dynamo_runtime::{
component::{Component, Namespace},
transports::nats::{NatsQueue, QUEUE_NAME, Slug},
};
use futures::StreamExt;
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT,
indexer::{RouterEvent, compute_block_hash_for_seq},
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE,
WORKER_KV_INDEXER_QUERY_SUBJECT,
indexer::{
KvIndexerInterface, KvIndexerMetrics, LocalKvIndexer, RouterEvent, WorkerKvQueryRequest,
WorkerKvQueryResponse, compute_block_hash_for_seq,
},
protocols::*,
};
use dynamo_runtime::config::environment_names::nats as env_nats;
......@@ -101,6 +108,15 @@ impl KvEventPublisher {
component: Component,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
) -> Result<Self> {
Self::new_with_local_indexer(component, kv_block_size, source_config, false)
}
pub fn new_with_local_indexer(
component: Component,
kv_block_size: u32,
source_config: Option<KvEventSourceConfig>,
enable_local_indexer: bool,
) -> Result<Self> {
let cancellation_token = CancellationToken::new();
......@@ -109,6 +125,18 @@ impl KvEventPublisher {
// Infer worker_id from component's connection
let worker_id = component.drt().connection_id();
tracing::info!(
worker_id,
component = component.name(),
"Initializing KvEventPublisher for worker {worker_id} in component {component}"
);
if enable_local_indexer {
tracing::info!(
"LocalKvIndexer enabled for worker {worker_id} in component {component}"
);
}
// Create our event source (if any)
let mut source = None;
if let Some(config) = source_config {
......@@ -121,6 +149,36 @@ impl KvEventPublisher {
)?);
}
// Create local indexer if requested
let local_indexer = if enable_local_indexer {
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
Some(Arc::new(LocalKvIndexer::new(
cancellation_token.clone(),
kv_block_size,
metrics,
WORKER_KV_INDEXER_BUFFER_SIZE,
)))
} else {
None
};
// Spawn runtime for router->local indexer comm if requested
let _local_indexer_query_handle = local_indexer.as_ref().map(|local_indexer_ref| {
let component = component.clone();
let local_indexer = local_indexer_ref.clone();
component
.drt()
.runtime()
.secondary()
.spawn(start_worker_kv_query_service(
component,
worker_id,
local_indexer,
cancellation_token.clone(),
))
});
let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
.to_string()
.replace("_", "-");
......@@ -135,12 +193,20 @@ impl KvEventPublisher {
// Connect the NatsQueue before passing it to the event processor
let cancellation_token_clone = cancellation_token.clone();
let local_indexer_clone = local_indexer.clone();
component.drt().runtime().secondary().spawn(async move {
if let Err(e) = nats_queue.connect().await {
tracing::error!("Failed to connect NatsQueue: {}", e);
return;
}
start_event_processor(nats_queue, worker_id, cancellation_token_clone, rx).await
start_event_processor(
nats_queue,
worker_id,
cancellation_token_clone,
rx,
local_indexer_clone,
)
.await
});
Ok(Self {
......@@ -181,6 +247,7 @@ async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
worker_id: u64,
cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
) {
loop {
tokio::select! {
......@@ -194,17 +261,129 @@ async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
break;
};
// Encapsulate in a router event and publish.
// Encapsulate in a router event.
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
let router_event = RouterEvent::new(worker_id, event);
// Apply to local indexer first (if present)
if let Some(indexer) = &local_indexer {
// Adds event into local indexer, and logs it into internal buffer
if let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await {
tracing::warn!(
"Failed to send event to local indexer for worker {}: {}",
worker_id,
e
);
}
}
// Then publish to NATS for global distribution
if let Err(e) = publisher.publish(QUEUE_NAME, &router_event).await {
tracing::error!("Failed to publish event: {}", e);
tracing::error!("Failed to publish event to NATS: {}", e);
}
}
}
}
}
/// Query service answering Router -> LocalKvIndexer requests for one worker.
///
/// Subscribes on `component` to `"{WORKER_KV_INDEXER_QUERY_SUBJECT}.{worker_id}"` and, for each
/// JSON-encoded [`WorkerKvQueryRequest`], publishes a JSON-encoded [`WorkerKvQueryResponse`] to
/// the message's reply subject:
///
/// * no start and no end id: the result of `dump_events()`, falling back to the buffered
///   events if the dump fails;
/// * otherwise: the result of `get_events_in_id_range(start, end)`.
///
/// The function returns when the subscription cannot be created, when the subscriber stream
/// ends, or when `cancellation_token` is cancelled. Malformed requests and requests without a
/// reply subject are logged and skipped.
async fn start_worker_kv_query_service(
    component: Component,
    worker_id: u64,
    local_indexer: Arc<LocalKvIndexer>,
    cancellation_token: CancellationToken,
) {
    // Create NATS subscriber on a subject specific to this worker's id
    let subject = format!("{}.{}", WORKER_KV_INDEXER_QUERY_SUBJECT, worker_id);
    let mut subscriber = match component.subscribe(&subject).await {
        Ok(sub) => sub,
        Err(e) => {
            tracing::error!("Failed to subscribe to {}: {}", subject, e);
            return; // No ? because function doesn't return Result
        }
    };
    tracing::debug!(
        "Query service on worker {} listening on NATS subject: {}",
        worker_id,
        subject
    );

    // Receive query request from router, retrieve event(s) from LocalKvIndexer, return response
    loop {
        tokio::select! {
            _ = cancellation_token.cancelled() => {
                tracing::info!("Router-Worker communication channel received cancellation signal");
                break;
            }
            msg = subscriber.next() => {
                let Some(msg) = msg else {
                    tracing::debug!("Router-Worker stream ended.");
                    break;
                };

                // Deserialize the request from the message payload
                let request: WorkerKvQueryRequest = match serde_json::from_slice(&msg.payload) {
                    Ok(request) => request,
                    Err(e) => {
                        tracing::error!("Failed to deserialize WorkerKvQueryRequest: {}", e);
                        continue;
                    }
                };
                tracing::debug!("Received WorkerKvQueryRequest: {:?}", request);

                // Without a reply subject the response has nowhere to go, so skip the event
                // lookup (possibly a full dump) instead of computing it and discarding it.
                let Some(reply_subject) = msg.reply else {
                    tracing::debug!(
                        "WorkerKvQueryRequest on {} carries no reply subject; skipping",
                        subject
                    );
                    continue;
                };

                // Resolve which events to return based on optional start/end ids
                let events = match (request.start_event_id, request.end_event_id) {
                    (None, None) => {
                        match local_indexer.dump_events().await {
                            Ok(events) => events,
                            Err(err) => {
                                tracing::error!(
                                    error = %err,
                                    worker_id,
                                    "Failed to dump events for WorkerKvQueryRequest; returning buffered events instead"
                                );
                                local_indexer.get_all_events_in_buffer()
                            }
                        }
                    }
                    _ => {
                        local_indexer
                            .get_events_in_id_range(request.start_event_id, request.end_event_id)
                            .await
                    }
                };

                // Build and serialize the WorkerKvQueryResponse
                let response = WorkerKvQueryResponse { events };
                let payload = match serde_json::to_vec(&response) {
                    Ok(p) => p,
                    Err(e) => {
                        tracing::error!("Failed to serialize response: {}", e);
                        continue;
                    }
                };

                // Publish through DRT/NATS directly instead of namespace (adds a prefix)
                if let Err(e) = component
                    .drt()
                    .kv_router_nats_publish(reply_subject.to_string(), payload.into())
                    .await
                {
                    tracing::error!("Failed to send reply: {}", e);
                }
            }
        }
    }
}
// Error handling configuration for ZMQ operations
// Both values are in milliseconds (per the `_MS` suffix).
// NOTE(review): presumably the first and the largest delay of a retry backoff; the code that
// consumes these constants is outside this view, so confirm against its use.
const INITIAL_BACKOFF_MS: u64 = 10;
const MAX_BACKOFF_MS: u64 = 5000;
......@@ -1009,7 +1188,9 @@ mod test_event_processing {
#[cfg(test)]
mod tests_startup_helpers {
use super::*;
use crate::kv_router::protocols::ExternalSequenceBlockHash;
use crate::kv_router::KvIndexer;
use crate::kv_router::indexer::KvIndexerInterface;
use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use async_trait;
use bytes::Bytes;
use std::sync::{Arc, Mutex};
......@@ -1090,7 +1271,7 @@ mod tests_startup_helpers {
tx.send(event).unwrap();
drop(tx);
let handle = tokio::spawn(start_event_processor(component, 1, token, rx));
let handle = tokio::spawn(start_event_processor(component, 1, token, rx, None));
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
......@@ -1103,6 +1284,300 @@ mod tests_startup_helpers {
assert_eq!(subject, QUEUE_NAME);
}
//--------------------------------------------------------------------
// Test start_event_processor with local indexer
//--------------------------------------------------------------------
/// Sends one `Stored` event through `start_event_processor` with a local indexer attached and
/// checks that the event is published and that worker 1 becomes visible in the local indexer.
#[tokio::test]
async fn test_start_event_processor_with_local_indexer() {
    let (component, published) = MockComponent::new();

    // Local indexer under test
    let token = CancellationToken::new();
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));

    // BlockStored event carrying two blocks
    let stored_block = |block: u64, tokens: u64| KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash(block),
        tokens_hash: LocalBlockHash(tokens),
    };
    let event = KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks: vec![stored_block(100, 200), stored_block(101, 201)],
        }),
        dp_rank: 0,
    };

    // Queue the event and close the channel so the processor terminates on its own
    let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
    tx.send(event).unwrap();
    drop(tx);

    // Run the event processor with the local indexer attached
    let processor = tokio::spawn(start_event_processor(
        component,
        1,
        token.clone(),
        rx,
        Some(local_indexer.clone()),
    ));
    tokio::time::timeout(tokio::time::Duration::from_secs(1), processor)
        .await
        .unwrap()
        .unwrap();

    // The event must have been published (same check as test_start_event_processor).
    // Scoped so the lock is released before the awaits below.
    {
        let publish_log = published.lock().unwrap();
        assert_eq!(publish_log.len(), 1);
        assert_eq!(publish_log[0].0, QUEUE_NAME);
    }

    // The event must also have reached the local indexer: ask it which workers it knows,
    // retrying up to 20 times with 10 ms pauses (200 ms total).
    let get_workers_tx = local_indexer.get_workers_sender();
    let mut attempts_left = 20;
    let found = loop {
        if attempts_left == 0 {
            break false;
        }
        attempts_left -= 1;

        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        get_workers_tx
            .send(crate::kv_router::indexer::GetWorkersRequest { resp: resp_tx })
            .await
            .unwrap();
        let workers: Vec<u64> = resp_rx.await.unwrap();
        if workers.contains(&1) {
            break true;
        }

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    };

    // worker_id 1 is the id handed to the event processor above
    assert!(
        found,
        "Worker 1 was not found in the indexer after processing"
    );

    // Cleanup
    token.cancel();
}
//--------------------------------------------------------------------
// Test BlockRemoved event with local indexer
//--------------------------------------------------------------------
/// Stores a block and then removes it through `start_event_processor`; afterwards the local
/// indexer must report no match for the block and both events must have been published.
#[tokio::test]
async fn test_event_processor_block_removed_with_local_indexer() {
    let (component, published) = MockComponent::new();
    let token = CancellationToken::new();
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));

    let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();

    // Event 1: store a single block
    tx.send(KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks: vec![KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(100),
                tokens_hash: LocalBlockHash(200),
            }],
        }),
        dp_rank: 0,
    })
    .unwrap();

    // Start the event processor with the local indexer attached
    let processor = tokio::spawn(start_event_processor(
        component,
        1,
        token.clone(),
        rx,
        Some(local_indexer.clone()),
    ));

    // Event 2: remove the block stored by event 1, then close the channel
    tx.send(KvCacheEvent {
        event_id: 2,
        data: KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: vec![ExternalSequenceBlockHash(100)],
        }),
        dp_rank: 0,
    })
    .unwrap();
    drop(tx);

    tokio::time::timeout(tokio::time::Duration::from_secs(1), processor)
        .await
        .unwrap()
        .unwrap();

    // The local indexer should no longer match the block; retry up to 20 times with 10 ms
    // pauses (200 ms total).
    let mut attempts_left = 20;
    let no_blocks = loop {
        if attempts_left == 0 {
            break false;
        }
        attempts_left -= 1;

        let overlap = local_indexer
            .find_matches(vec![LocalBlockHash(200)])
            .await
            .unwrap();
        if overlap.scores.is_empty() {
            break true;
        }

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    };
    assert!(no_blocks, "worker should have no blocks after removal");

    // Both events (store and remove) should have been published for the global indexer
    let publish_log = published.lock().unwrap();
    assert_eq!(
        publish_log.len(),
        2,
        "expected 2 published events, found {}",
        publish_log.len()
    );

    token.cancel();
}
//--------------------------------------------------------------------
// Test AllBlocksCleared event with local indexer
//--------------------------------------------------------------------
/// Stores a block and then sends a `Cleared` event through `start_event_processor`; afterwards
/// the local indexer must report no match for the block and both events must have been
/// published.
#[tokio::test]
async fn test_event_processor_all_blocks_cleared_with_local_indexer() {
    let (component, published) = MockComponent::new();
    let token = CancellationToken::new();
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));

    let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();

    // Event 1: store a single block
    tx.send(KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks: vec![KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(100),
                tokens_hash: LocalBlockHash(200),
            }],
        }),
        dp_rank: 0,
    })
    .unwrap();

    // Event 2: clear all blocks, then close the channel
    tx.send(KvCacheEvent {
        event_id: 2,
        data: KvCacheEventData::Cleared,
        dp_rank: 0,
    })
    .unwrap();
    drop(tx);

    // Both events are already queued when the processor starts; wait for it to drain them
    let processor = tokio::spawn(start_event_processor(
        component,
        1,
        token.clone(),
        rx,
        Some(local_indexer.clone()),
    ));
    tokio::time::timeout(tokio::time::Duration::from_secs(1), processor)
        .await
        .unwrap()
        .unwrap();

    // The local indexer should no longer match the block; retry up to 20 times with 10 ms
    // pauses (200 ms total).
    let mut attempts_left = 20;
    let no_blocks = loop {
        if attempts_left == 0 {
            break false;
        }
        attempts_left -= 1;

        let overlap = local_indexer
            .find_matches(vec![LocalBlockHash(200)])
            .await
            .unwrap();
        if overlap.scores.is_empty() {
            break true;
        }

        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    };
    assert!(no_blocks, "worker should have no blocks after clearing");

    // Both events (store and clear) should have been published for the global indexer
    let publish_log = published.lock().unwrap();
    assert_eq!(
        publish_log.len(),
        2,
        "expected 2 published events, found {}",
        publish_log.len()
    );

    token.cancel();
}
//--------------------------------------------------------------------
// Test that local indexer failure doesn't break NATS publishing
//--------------------------------------------------------------------
/// A local indexer whose token was cancelled up front must not stop the event processor from
/// publishing: the single queued event still has to appear in the publish log.
#[tokio::test]
async fn test_event_processor_local_indexer_failure_continues() {
    let (component, published) = MockComponent::new();

    // Build the local indexer, then cancel its token right away to simulate a failed indexer
    let indexer_token = CancellationToken::new();
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let local_indexer = Arc::new(LocalKvIndexer::new(indexer_token.clone(), 4, metrics, 100));
    indexer_token.cancel();

    let removal = KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: vec![ExternalSequenceBlockHash(1)],
        }),
        dp_rank: 0,
    };

    // The processor gets its own, un-cancelled token
    let processor_token = CancellationToken::new();
    let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
    tx.send(removal).unwrap();
    drop(tx);

    // Despite the local indexer being cancelled, the event processor should run to completion
    let processor = tokio::spawn(start_event_processor(
        component,
        1,
        processor_token,
        rx,
        Some(local_indexer),
    ));
    tokio::time::timeout(tokio::time::Duration::from_secs(1), processor)
        .await
        .unwrap()
        .unwrap();

    // The event was still published despite the local indexer failure
    let publish_log = published.lock().unwrap();
    assert_eq!(publish_log.len(), 1);
}
//--------------------------------------------------------------------
// Test start_zmq_listener without a real socket
// (feed it frames through a ZMQ PAIR tcp socket)
......@@ -1186,6 +1661,215 @@ mod tests_startup_helpers {
token.cancel();
let _ = listener_handle.await;
}
//--------------------------------------------------------------------
// Test distributed recovery: Router queries worker's LocalKvIndexer after outage
//--------------------------------------------------------------------
// Scenario test, no real NATS involved: a worker-side event processor (MockComponent publisher
// plus LocalKvIndexer) and a router-side KvIndexer are wired together by hand.
//   1. event_1 is published by the worker and forwarded manually to the router.
//   2. event_2 is published by the worker but NOT forwarded (the simulated outage).
//   3. The router reads its last received event id for the worker, fetches the later events
//      from the worker's local indexer, applies them, and is checked to have caught up.
// The steps are separated by fixed sleeps, so the statement order matters.
#[tokio::test]
async fn test_distributed_kvindexer_recovery_from_outage() {
let worker_1_id = 1u64;
let block_size = 4u32;
let token = CancellationToken::new();
// === SETUP: Worker Components ===
// `worker_published` is the log of (subject, payload) pairs the mock publisher received.
let (worker_component, worker_published) = MockComponent::new();
let local_indexer_1 = Arc::new(LocalKvIndexer::new(
token.clone(),
block_size,
Arc::new(KvIndexerMetrics::new_unregistered()),
100, // buffer size
));
let (worker_tx, worker_rx) = mpsc::unbounded_channel::<KvCacheEvent>();
// Start worker's event processor (left running; `worker_tx` stays open for later events)
tokio::spawn(start_event_processor(
worker_component,
worker_1_id,
token.clone(),
worker_rx,
Some(local_indexer_1.clone()),
));
// === SETUP: Router Components ===
let router_indexer = Arc::new(KvIndexer::new(
token.clone(),
block_size,
Arc::new(KvIndexerMetrics::new_unregistered()),
));
// === STEP 1: Normal Operation ===
// event_1 stores two blocks (tokens hashes 200 and 201).
let event_1 = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
},
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(101),
tokens_hash: LocalBlockHash(201),
},
],
}),
dp_rank: 0,
};
worker_tx.send(event_1.clone()).unwrap();
// Give the event processor time to handle event_1
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Simulate JetStream: forward worker's published event to router
let (subject, bytes) = {
let published = worker_published.lock().unwrap();
assert_eq!(published.len(), 1, "Worker should have published 1 event");
(published[0].0.clone(), published[0].1.clone())
}; // drop worker_published before await
assert_eq!(subject, QUEUE_NAME);
// The published payload is decoded with rmp_serde back into a RouterEvent
let router_event: RouterEvent = rmp_serde::from_slice(&bytes).unwrap();
router_indexer
.event_sender()
.send(router_event)
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// assert: Router's indexer has event (polled up to 20 times, 10 ms apart)
let get_workers_tx = router_indexer.get_workers_sender();
let mut router_has_worker = false;
for _ in 0..20 {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
get_workers_tx
.send(crate::kv_router::indexer::GetWorkersRequest { resp: resp_tx })
.await
.unwrap();
let workers: Vec<u64> = resp_rx.await.unwrap();
if workers.contains(&worker_1_id) {
router_has_worker = true;
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
assert!(
router_has_worker,
"Router should see worker 1 after normal operation"
);
// assert: Worker's local indexer buffered event
let buffered = local_indexer_1.get_all_events_in_buffer();
assert_eq!(buffered.len(), 1, "Local indexer should buffer 1 event");
// === STEP 2 & 3: Simulate Outage - Stop forwarding to router ===
// event_2 re-stores the block with tokens hash 200 and adds a new one (202).
let event_2 = KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100), // Shared prefix
tokens_hash: LocalBlockHash(200),
},
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(102), // New block
tokens_hash: LocalBlockHash(202),
},
],
}),
dp_rank: 0,
};
worker_tx.send(event_2.clone()).unwrap(); // send to worker but not to router
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// assert: Worker published event_2 to "NATS" (MockComponent)
{
let published = worker_published.lock().unwrap();
assert_eq!(
published.len(),
2,
"Worker should have published 2 events total"
);
}
// assert: Worker's local indexer has both events
let buffered = local_indexer_1.get_all_events_in_buffer();
assert_eq!(
buffered.len(),
2,
"Local indexer should have both events during outage"
);
// assert: Router DOESN'T have event_2
let block_hashes_2 = vec![LocalBlockHash(200), LocalBlockHash(202)];
let overlap = router_indexer
.find_matches(block_hashes_2.clone())
.await
.unwrap();
// A missing score entry for the worker is treated as an overlap of 0
let router_overlap = overlap
.scores
.get(&crate::kv_router::protocols::WorkerWithDpRank::from_worker_id(worker_1_id))
.copied()
.unwrap_or(0);
assert_eq!(
router_overlap, 1,
"Router should only see 1 shared block (not the new block from event_2)"
);
// === STEP 4 & 5: Recovery - Query last received event IDs and fetch missed events ===
// Step 4a: Router queries its last received event ID per worker
let last_ids = router_indexer.get_last_received_event_ids().await.unwrap();
let last_known_id = last_ids.get(&worker_1_id).copied().unwrap_or(0);
assert_eq!(
last_known_id, 1,
"Router should have last_received_event_id = 1 for worker (only event_1 was forwarded)"
);
// Step 4b: Query worker's local indexer for events after last_known_id
let missed_events = local_indexer_1
.get_events_in_id_range(Some(last_known_id + 1), None)
.await;
assert_eq!(
missed_events.len(),
1,
"Should get 1 missed event (event_2 with id=2)"
);
// Step 5: Apply missed events to router
for router_event in missed_events {
router_indexer
.event_sender()
.send(router_event)
.await
.unwrap();
}
// Give the router indexer time to apply the replayed events
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// assert: Router now has complete state
let overlap = router_indexer.find_matches(block_hashes_2).await.unwrap();
let router_overlap_after = overlap
.scores
.get(&crate::kv_router::protocols::WorkerWithDpRank::from_worker_id(worker_1_id))
.copied()
.unwrap_or(0);
assert_eq!(
router_overlap_after, 2,
"Router should now see both blocks after recovery"
);
// assert: Router's last_received_event_id is updated after recovery
let last_ids_after = router_indexer.get_last_received_event_ids().await.unwrap();
assert_eq!(
last_ids_after.get(&worker_1_id),
Some(&2),
"Router should have last_received_event_id = 2 after recovery"
);
// Cleanup: stop the background tasks sharing this token
token.cancel();
}
}
#[cfg(test)]
......@@ -1431,3 +2115,402 @@ mod test_integration_publisher {
);
}
}
#[cfg(all(test, feature = "integration"))]
mod test_integration_publisher_with_kvindexer {
use super::*;
use crate::kv_router::scheduler::DefaultWorkerSelector;
use crate::kv_router::{KvPushRouter, KvRouter, KvRouterConfig};
use crate::local_model::LocalModelBuilder;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::mocker::engine::{MOCKER_COMPONENT, MockVllmEngine};
use crate::mocker::protocols::MockEngineArgs;
use crate::protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest};
use crate::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::distributed_test_utils::create_test_shared_drt_async;
use dynamo_runtime::engine::AsyncEngine;
use dynamo_runtime::pipeline::{Context, PushRouter, RouterMode, network::Ingress};
use dynamo_runtime::protocols::annotated::Annotated;
/// Integration test: KvPushRouter end-to-end routing with mock engines.
#[tokio::test(flavor = "multi_thread")]
#[ignore] // Requires NATS/etcd. Run with: cargo test --package dynamo-llm --lib --features integration test_distributed_kvindexer_e2e -- --ignored --nocapture
async fn test_distributed_kvindexer_e2e() -> anyhow::Result<()> {
const BLOCK_SIZE: u32 = 4;
const NUM_REQUESTS: usize = 4;
dynamo_runtime::logging::init();
// === SETUP: Distributed runtimes and namespaces ===
let shared_store_dir = tempfile::tempdir()?;
let shared_store_path = shared_store_dir.path().to_path_buf();
// Make both runtimes point at the same file-backed storage backend so worker
// registrations and heartbeats remain visible to every DRT instance.
let distributed1 = create_test_shared_drt_async(&shared_store_path).await;
let distributed2 = create_test_shared_drt_async(&shared_store_path).await;
let component1 = distributed1
.namespace("test_e2e_router")?
.component(MOCKER_COMPONENT)?;
let component2 = distributed2
.namespace("test_e2e_router")?
.component(MOCKER_COMPONENT)?;
// === SETUP: Start mocker workers ===
let mocker_args = MockEngineArgs::builder()
.block_size(BLOCK_SIZE as usize)
.dp_size(1) // single worker per runtime
.enable_prefix_caching(true)
.enable_local_indexer(true) // affects scheduler/publisher args
.build()?;
let worker_components = vec![component1.clone(), component2.clone()];
let mut server_handles = Vec::new();
let mut worker_ids = Vec::new();
for comp in worker_components {
let engine = Arc::new(MockVllmEngine::new(mocker_args.clone()));
engine.start(comp.clone()).await?;
tracing::info!("MockVllmEngine started for {:?}", comp);
// Register MDC with runtime_config so router can discover enable_local_indexer.
// (Without this step, the MDC-based assert in query_worker() in worker_query.rs will fail.)
// This inlines code which in the Python path would be performed by:
// - local_model.rs: LocalModelBuilder::build() sets runtime_config from MockEngineArgs
// - entrypoint/input/endpoint.rs: LocalModel::attach() registers MDC via discovery
let endpoint = comp.endpoint("generate");
let runtime_config = ModelRuntimeConfig {
enable_local_indexer: true,
..Default::default()
};
let mut builder = LocalModelBuilder::default();
builder
.model_name(Some("mock".to_string()))
.kv_cache_block_size(Some(BLOCK_SIZE))
.runtime_config(runtime_config);
let mut local_model = builder.build().await?;
local_model
.attach(
&endpoint,
crate::model_type::ModelType::Chat,
crate::model_type::ModelInput::Tokens,
None,
)
.await?;
let ingress = Ingress::for_engine(engine.clone())?;
let endpoint_component = comp.clone();
let handle = tokio::spawn(async move {
if let Err(e) = endpoint_component
.endpoint("generate")
.endpoint_builder()
.handler(ingress)
.start()
.await
{
tracing::error!("Generate endpoint failed: {e}");
}
});
server_handles.push(handle);
worker_ids.push(comp.drt().connection_id());
}
tracing::info!("Generate endpoint servers launched");
tokio::time::sleep(Duration::from_millis(500)).await;
// === SETUP: Build KvPushRouter ===
let router_distributed = create_test_shared_drt_async(&shared_store_path).await;
let router_namespace = router_distributed.namespace("test_e2e_router")?;
let backend_component = router_namespace.component(MOCKER_COMPONENT)?;
let backend_endpoint = backend_component.endpoint("generate");
let client = backend_endpoint.client().await?;
let kv_router_config = KvRouterConfig::default();
let selector = Box::new(DefaultWorkerSelector::new(Some(kv_router_config)));
let consumer_id = format!("test-router-{}", router_distributed.connection_id());
let kv_router: Arc<KvRouter> = Arc::new(
KvRouter::new(
backend_endpoint.clone(),
client.clone(),
BLOCK_SIZE,
Some(selector),
Some(kv_router_config),
consumer_id,
)
.await?,
);
let push_router =
PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
None,
None,
)
.await?;
let kv_push_router = KvPushRouter::new(push_router, kv_router.clone());
// ===== TEST PART 1: ROUTE & SEND REQUESTS TO WORKERS (ROUTER -> WORKER) =====
let create_request = |tokens: Vec<u32>| {
PreprocessedRequest::builder()
.model("mock".to_string())
.token_ids(tokens)
.stop_conditions(StopConditions {
max_tokens: Some(10),
..Default::default()
})
.sampling_options(SamplingOptions::default())
.output_options(OutputOptions::default())
.build()
.unwrap()
}; // from mocker/engine.rs
for i in 0..NUM_REQUESTS {
tracing::info!("Sending routed request {}", i + 1);
let tokens = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, i as u32];
let request = create_request(tokens.clone());
let response_stream = kv_push_router.generate(Context::new(request)).await?;
let responses: Vec<Annotated<LLMEngineOutput>> = response_stream.collect().await;
assert!(
!responses.is_empty(),
"Request {} should produce at least one response",
i + 1
);
}
tracing::info!("KvPushRouter generate() succeeded for {NUM_REQUESTS} requests");
// ===== TEST PART 2: QUERY WORKER-LOCAL KVINDEXERS DIRECTLY =====
// TODO: This could be refactored as router function (e.g. router.refresh_from_worker(worker_id))
// (which should also update the global kvIndexer with the buffer from the local kvIndexer)
let mut best_worker_info: Option<(u64, usize)> = None;
// Exactly one worker should have been routed requests. Find that worker
for &worker_id in &worker_ids {
let response = kv_router
.query_worker_local_kv(worker_id, None, None)
.await?;
if response.events.is_empty() {
continue;
}
let event_count = response.events.len();
tracing::info!(
worker_id,
events = event_count,
"Worker query on worker {worker_id} returned buffered KV events"
);
best_worker_info = Some((worker_id, event_count));
break;
}
// Verify that only one worker has KV events in buffer
let (best_worker_id, best_worker_event_count) =
best_worker_info.expect("At least one worker should have buffered KV events");
tracing::info!(
"Best worker is {best_worker_id} with {best_worker_event_count} buffered KV events"
);
for &worker_id in &worker_ids {
if worker_id == best_worker_id {
continue;
}
let response = kv_router
.query_worker_local_kv(worker_id, None, None)
.await?;
assert!(
response.events.is_empty(),
"Worker {worker_id} should not report buffered KV events; best worker {best_worker_id} reported {best_worker_event_count}"
);
}
// === Cleanup ===
for handle in server_handles {
handle.abort();
}
distributed1.shutdown();
distributed2.shutdown();
router_distributed.shutdown();
Ok(())
}
/// End-to-end startup test: a `KvRouter` created *after* workers have already
/// produced KV events must catch up from the workers' local indexers.
///
/// Flow:
/// 1. Start two mocker workers (one per runtime) with `enable_local_indexer`.
/// 2. Send one request directly to each worker so each buffers KV events.
/// 3. Build a fresh `KvRouter` and assert that its global indexer already
///    holds the events of both workers.
///
/// Ignored by default; the runtimes are created with NATS client options, so a
/// reachable NATS server is presumably required (TODO confirm).
#[tokio::test(flavor = "multi_thread")]
#[ignore]
async fn test_distributed_kvindexer_e2e_startup() -> anyhow::Result<()> {
    const BLOCK_SIZE: u32 = 4;
    dynamo_runtime::logging::init();

    // === SETUP: Distributed runtimes and namespaces ===
    let shared_store_dir = tempfile::tempdir()?;
    let shared_store_path = shared_store_dir.path().to_path_buf();
    // Use a unique namespace per test run for full isolation
    let test_namespace = format!("test_e2e_{}", uuid::Uuid::new_v4().simple());
    // Make both runtimes point at the same file-backed storage backend so worker
    // registrations and heartbeats remain visible to every DRT instance.
    let distributed1 = create_test_shared_drt_async(&shared_store_path).await;
    let distributed2 = create_test_shared_drt_async(&shared_store_path).await;
    let component1 = distributed1
        .namespace(&test_namespace)?
        .component(MOCKER_COMPONENT)?;
    let component2 = distributed2
        .namespace(&test_namespace)?
        .component(MOCKER_COMPONENT)?;

    // === SETUP: Start mocker workers ===
    let mocker_args = MockEngineArgs::builder()
        .block_size(BLOCK_SIZE as usize)
        .dp_size(1) // single worker per runtime
        .enable_prefix_caching(true)
        .enable_local_indexer(true) // affects scheduler/publisher args
        .build()?;
    let worker_components = vec![component1.clone(), component2.clone()];
    let mut server_handles = Vec::new();
    let mut worker_ids = Vec::new();
    for comp in worker_components {
        let engine: Arc<MockVllmEngine> = Arc::new(MockVllmEngine::new(mocker_args.clone()));
        engine.start(comp.clone()).await?;
        tracing::info!("MockVllmEngine started for {:?}", comp);

        // Register MDC with runtime_config so router can discover enable_local_indexer.
        // (Without this step, the MDC-based assert in query_worker() in worker_query.rs will fail.)
        // This inlines code which in the Python path would be performed by:
        // - local_model.rs: LocalModelBuilder::build() sets runtime_config from MockEngineArgs
        // - entrypoint/input/endpoint.rs: LocalModel::attach() registers MDC via discovery
        let endpoint = comp.endpoint("generate");
        let runtime_config = ModelRuntimeConfig {
            enable_local_indexer: true,
            ..Default::default()
        };
        let mut builder = LocalModelBuilder::default();
        builder
            .model_name(Some("mock".to_string()))
            .kv_cache_block_size(Some(BLOCK_SIZE))
            .runtime_config(runtime_config);
        let mut local_model = builder.build().await?;
        local_model
            .attach(
                &endpoint,
                crate::model_type::ModelType::Chat,
                crate::model_type::ModelInput::Tokens,
                None,
            )
            .await?;

        // Serve the "generate" endpoint in the background for this worker.
        let ingress = Ingress::for_engine(engine.clone())?;
        let endpoint_component = comp.clone();
        let handle = tokio::spawn(async move {
            if let Err(e) = endpoint_component
                .endpoint("generate")
                .endpoint_builder()
                .handler(ingress)
                .start()
                .await
            {
                tracing::error!("Generate endpoint failed: {e}");
            }
        });
        server_handles.push(handle);
        worker_ids.push(comp.drt().connection_id());
    }
    tracing::info!("Generate endpoint servers launched");
    // Give the endpoints time to come up before sending requests.
    tokio::time::sleep(Duration::from_millis(500)).await;

    // === STEP 1: Send one request directly to each worker to populate its local indexer ===
    // This simulates a situation where KvPushRouter is initialized
    // to route to workers which already have KV events
    let pre_router_distributed = create_test_shared_drt_async(&shared_store_path).await;
    let pre_backend_endpoint = pre_router_distributed
        .namespace(&test_namespace)?
        .component(MOCKER_COMPONENT)?
        .endpoint("generate");
    let pre_client = pre_backend_endpoint.client().await?;
    // Create a PushRouter to send requests directly to a specific worker
    let pre_push_router =
        PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
            pre_client,
            RouterMode::Random, // We'll use direct() so mode doesn't matter
            None,
            None,
        )
        .await?;
    // Force-send exactly one request to each of the two workers
    for &worker_id in &worker_ids {
        let tokens: Vec<u32> = vec![0, 1, 2, 3];
        let request = PreprocessedRequest::builder()
            .model("mock".to_string())
            .token_ids(tokens.clone())
            .sampling_options(SamplingOptions::default())
            .output_options(OutputOptions::default())
            .stop_conditions(StopConditions {
                max_tokens: Some(5),
                ..Default::default()
            })
            .build()?;
        let response_stream = pre_push_router
            .direct(Context::new(request), worker_id)
            .await?;
        // Consume the stream to complete the request
        let _responses: Vec<_> = response_stream.collect().await;
        tracing::debug!(
            "Sent request {:?} directly to worker {} to populate its local indexer",
            tokens,
            worker_id
        );
    }
    // Let the workers finish publishing/buffering their KV events.
    tokio::time::sleep(Duration::from_millis(1000)).await;

    // === SETUP: Build KvPushRouter ===
    let router_distributed = create_test_shared_drt_async(&shared_store_path).await;
    let router_namespace = router_distributed.namespace(&test_namespace)?;
    let backend_component = router_namespace.component(MOCKER_COMPONENT)?;
    let backend_endpoint = backend_component.endpoint("generate");
    let client = backend_endpoint.client().await?;
    let kv_router_config = KvRouterConfig::default();
    let selector = Box::new(DefaultWorkerSelector::new(Some(kv_router_config)));
    let consumer_id = format!("test-router-{}", router_distributed.connection_id());
    let kv_router: Arc<KvRouter> = Arc::new(
        KvRouter::new(
            backend_endpoint.clone(),
            client.clone(),
            BLOCK_SIZE,
            Some(selector),
            Some(kv_router_config),
            consumer_id,
        )
        .await?,
    );

    // At this point kvrouter's indexer should already have the
    // events stored in the workers, due to the catch-up built into KvRouter::new.
    // Each request generates 2 events: input block (parent_hash: None) + output block (parent_hash: Some)
    // With 2 workers, that's 4 events total.
    let global_kv_events = kv_router.indexer.dump_events().await?;
    tracing::debug!("Global KV events: {:?}", global_kv_events);
    assert_eq!(global_kv_events.len(), 4); // 2 workers × 2 events per request (input + output)

    // === Cleanup ===
    for handle in server_handles {
        handle.abort();
    }
    distributed1.shutdown();
    distributed2.shutdown();
    // The runtime used for the pre-population requests must be shut down too,
    // otherwise it is the only runtime in this test left running.
    pre_router_distributed.shutdown();
    router_distributed.shutdown();
    Ok(())
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Background processes for the KV Router including event consumption and snapshot uploads.
use std::{collections::HashSet, time::Duration};
use std::{collections::HashMap, collections::HashSet, time::Duration};
use anyhow::Result;
use dynamo_runtime::{
......@@ -24,6 +22,7 @@ use crate::kv_router::{
indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
protocols::WorkerId,
router_discovery_query,
worker_query::WorkerQueryClient,
};
/// Delay between snapshot reads to verify stability
......@@ -33,6 +32,163 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1);
const CHECK_INTERVAL_JITTER_MS: i64 = 100;
// ============================================================================
// Local KvIndexer-based Recovery
// ============================================================================
/// Recover missed events from all workers with local indexers.
///
/// This function should be called on router startup to catch up on any events
/// that were missed while the router was offline. Workers without a local
/// indexer are skipped. A failure on one worker is logged and counted, and
/// recovery continues with the remaining workers.
///
/// # Arguments
///
/// * `worker_query_client` - Client for querying worker local indexers
/// * `last_received_event_ids` - Map of worker ID to last received event ID
/// * `worker_ids` - Worker IDs to recover from
/// * `event_tx` - Channel to send recovered events to the indexer
///
/// # Returns
///
/// Total number of events recovered across all workers
pub async fn recover_from_all_workers(
    worker_query_client: &WorkerQueryClient,
    last_received_event_ids: &HashMap<WorkerId, u64>,
    worker_ids: &[WorkerId],
    event_tx: &mpsc::Sender<RouterEvent>,
) -> usize {
    let mut total_recovered = 0;
    // Counts only workers that actually yielded at least one event.
    let mut successful_workers = 0;
    let mut failed_workers = 0;

    for &worker_id in worker_ids {
        // Skip workers without local indexer
        if !worker_query_client.has_local_indexer(worker_id) {
            tracing::debug!(
                worker_id,
                "Skipping recovery - worker does not have local indexer enabled"
            );
            continue;
        }

        // If we haven't seen any events from this worker, start from beginning (None)
        // If we've seen events, start from last_known_id + 1
        let start_event_id = last_received_event_ids
            .get(&worker_id)
            .map(|&last_id| last_id + 1);

        match recover_from_worker(
            worker_query_client,
            worker_id,
            start_event_id,
            None, // Get all events after start_event_id
            event_tx,
        )
        .await
        {
            Ok(count) => {
                total_recovered += count;
                if count > 0 {
                    successful_workers += 1;
                }
            }
            Err(e) => {
                failed_workers += 1;
                // Surface the cause instead of silently dropping it; a failed
                // worker leaves the router's view of that worker stale.
                tracing::warn!(
                    worker_id,
                    error = %e,
                    "Failed to recover events from worker"
                );
            }
        }
    }

    // Log summary
    if total_recovered > 0 || failed_workers > 0 {
        tracing::info!(
            total_recovered,
            successful_workers,
            failed_workers,
            "Startup recovery completed"
        );
    }

    total_recovered
}
/// Fetch the KV events a single worker has buffered in its local indexer and
/// forward them to the router's indexer channel.
///
/// # Arguments
///
/// * `worker_query_client` - Client for querying worker local indexers
/// * `worker_id` - The worker to recover from
/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
/// * `event_tx` - Channel to send recovered events to the indexer
///
/// # Returns
///
/// Number of events forwarded. `Ok(0)` when the worker has no local indexer
/// or returned no events.
///
/// # Errors
///
/// Fails if the worker query fails or if an event cannot be sent on `event_tx`.
pub async fn recover_from_worker(
    worker_query_client: &WorkerQueryClient,
    worker_id: WorkerId,
    start_event_id: Option<u64>,
    end_event_id: Option<u64>,
    event_tx: &mpsc::Sender<RouterEvent>,
) -> Result<usize> {
    // Guard: nothing to recover from a worker that keeps no local indexer.
    if !worker_query_client.has_local_indexer(worker_id) {
        tracing::warn!(
            "Worker {} does not have local indexer enabled, skipping recovery",
            worker_id
        );
        return Ok(0);
    }
    tracing::debug!(
        worker_id,
        start_event_id = ?start_event_id,
        end_event_id = ?end_event_id,
        "Attempting recovery from worker"
    );

    // Ask the worker for every buffered event in the requested range.
    let recovered_events = worker_query_client
        .query_worker(worker_id, start_event_id, end_event_id)
        .await?
        .events;

    let events_count = recovered_events.len();
    if events_count == 0 {
        tracing::debug!(
            worker_id,
            start_event_id = ?start_event_id,
            "No missed events to recover from worker"
        );
        return Ok(0);
    }

    tracing::info!(
        worker_id,
        start_event_id = ?start_event_id,
        events_count,
        "Recovered {} missed events from worker",
        events_count
    );

    // Hand the events to the indexer in order; abort on the first send failure.
    for event in recovered_events {
        event_tx.send(event).await.map_err(|e| {
            tracing::error!(
                worker_id,
                error = %e,
                "Failed to send recovered event to indexer"
            );
            anyhow::anyhow!("Failed to send recovered event: {}", e)
        })?;
    }

    Ok(events_count)
}
// ============================================================================
// Snapshot Management
// ============================================================================
/// Download a stable snapshot from object store and send events to the indexer.
/// Retries until two consecutive reads match or max attempts is reached.
async fn download_stable_snapshot(
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use anyhow::{Context, Result};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use tokio::sync::watch;
use crate::kv_router::WORKER_KV_INDEXER_QUERY_SUBJECT;
use crate::kv_router::indexer::{WorkerKvQueryRequest, WorkerKvQueryResponse};
use crate::kv_router::protocols::WorkerId;
use crate::local_model::runtime_config::ModelRuntimeConfig;
/// Router-side client for querying worker-local KV indexers.
///
/// Performs request/reply communication with workers via NATS.
/// Only workers whose `ModelRuntimeConfig` has `enable_local_indexer == true`
/// in the watched config map are queried; all others are rejected up front.
/// NOTE(review): the original comment says the client is spawned by KvRouter and
/// watches the same discovery stream as the router — not visible here, confirm.
pub struct WorkerQueryClient {
/// Component used to build the NATS query subject and to reach the
/// distributed runtime that sends the request.
component: Component,
/// Watch receiver holding the latest runtime config per worker; read to
/// determine each worker's `enable_local_indexer` state.
model_runtime_config_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
}
impl WorkerQueryClient {
    /// Build a client from the target component and a watch receiver that
    /// carries the per-worker runtime configs.
    pub fn new(
        component: Component,
        model_runtime_config_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
    ) -> Self {
        Self {
            component,
            model_runtime_config_rx,
        }
    }

    /// Report whether `worker_id` currently advertises a local indexer.
    ///
    /// Workers missing from the watched config map count as "not enabled".
    pub fn has_local_indexer(&self, worker_id: WorkerId) -> bool {
        let configs = self.model_runtime_config_rx.borrow();
        configs
            .get(&worker_id)
            .is_some_and(|config| config.enable_local_indexer)
    }

    /// Request the buffered KV events of one worker's local indexer.
    ///
    /// `start_event_id` / `end_event_id` are forwarded unchanged in the request.
    ///
    /// # Errors
    ///
    /// Fails if the worker does not have `enable_local_indexer=true`, if the
    /// request cannot be serialized or sent (the request uses a one-second
    /// timeout), or if the reply cannot be deserialized.
    pub async fn query_worker(
        &self,
        worker_id: WorkerId,
        start_event_id: Option<u64>,
        end_event_id: Option<u64>,
    ) -> Result<WorkerKvQueryResponse> {
        // Refuse to query workers that do not run a local indexer.
        anyhow::ensure!(
            self.has_local_indexer(worker_id),
            "Worker {} does not have local indexer enabled (enable_local_indexer=false or not set in MDC user_data)",
            worker_id
        );

        // Subject layout must match the worker's subscription:
        // <component subject>.<query subject>.<worker id>
        // (see publisher.rs/start_worker_kv_query_service()).
        let subject = format!(
            "{}.{}.{}",
            self.component.subject(),
            WORKER_KV_INDEXER_QUERY_SUBJECT,
            worker_id
        );
        tracing::debug!(
            "Router sending query request to worker {} on NATS subject: {}",
            worker_id,
            subject
        );

        // Serialize the query.
        let payload = serde_json::to_vec(&WorkerKvQueryRequest {
            worker_id,
            start_event_id,
            end_event_id,
        })
        .context("Failed to serialize WorkerKvQueryRequest")?;

        // Send the NATS request through the DRT helper, bounded by a timeout.
        let request_timeout = tokio::time::Duration::from_secs(1);
        let reply = self
            .component
            .drt()
            .kv_router_nats_request(subject.clone(), payload.into(), request_timeout)
            .await
            .with_context(|| {
                format!(
                    "Failed to send request to worker {} on subject {}",
                    worker_id, subject
                )
            })?;

        // Decode the worker's answer.
        serde_json::from_slice::<WorkerKvQueryResponse>(&reply.payload)
            .context("Failed to deserialize WorkerKvQueryResponse")
    }
}
......@@ -234,6 +234,7 @@ impl LocalModelBuilder {
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer;
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
self.media_decoder = Some(MediaDecoder::default());
self.media_fetcher = Some(MediaFetcher::default());
......
......@@ -23,6 +23,10 @@ pub struct ModelRuntimeConfig {
#[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32,
/// Enable worker-local KV indexer for tracking this worker's own KV cache state
#[serde(default)]
pub enable_local_indexer: bool,
/// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>,
......@@ -51,6 +55,7 @@ impl Default for ModelRuntimeConfig {
tool_call_parser: None,
reasoning_parser: None,
data_parallel_size: default_data_parallel_size(),
enable_local_indexer: false,
runtime_data: HashMap::new(),
tensor_model_config: None,
}
......
......@@ -72,7 +72,7 @@ pub struct KvManager {
impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self {
Self::new_with_publisher(max_capacity, block_size, None, 0)
Self::new_with_publisher(max_capacity, block_size, None, 0, false)
}
pub fn new_with_publisher(
......@@ -80,6 +80,7 @@ impl KvManager {
block_size: usize,
component: Option<Component>,
dp_rank: u32,
enable_local_indexer: bool,
) -> Self {
let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default();
......@@ -87,10 +88,10 @@ impl KvManager {
let kv_event_publisher = component.map(|comp| {
tracing::info!(
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}"
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}"
);
Arc::new(
KvEventPublisher::new(comp, block_size as u32, None)
KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer)
.expect("Failed to create KV event publisher"),
)
});
......
......@@ -120,6 +120,10 @@ pub struct MockEngineArgs {
#[serde(skip)]
#[builder(default = "Arc::new(PerfModel::default())")]
pub perf_model: Arc<PerfModel>,
/// Enable worker-local KV indexer for tracking this worker's own KV cache state
#[builder(default = "false")]
pub enable_local_indexer: bool,
}
impl Default for MockEngineArgs {
......@@ -158,6 +162,7 @@ impl MockEngineArgs {
"is_prefill",
"is_decode",
"planner_profile_data",
"enable_local_indexer",
]
.iter()
.cloned()
......@@ -239,6 +244,12 @@ impl MockEngineArgs {
builder = builder.startup_time(Some(num));
}
if let Some(value) = extra_args.get("enable_local_indexer")
&& let Some(enabled) = value.as_bool()
{
builder = builder.enable_local_indexer(enabled);
}
// Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args
.get("is_prefill")
......
......@@ -275,6 +275,7 @@ impl Scheduler {
args.block_size,
component,
dp_rank,
args.enable_local_indexer,
);
let mut hit_rates = RunningMean::new(1000);
......
......@@ -397,7 +397,7 @@ impl DistributedRuntime {
/// TODO: This is a temporary KV router measure for component/component.rs EventPublisher impl for
/// Component, to allow it to publish to NATS. KV Router is the only user.
pub(crate) async fn kv_router_nats_publish(
pub async fn kv_router_nats_publish(
&self,
subject: String,
payload: bytes::Bytes,
......@@ -420,6 +420,25 @@ impl DistributedRuntime {
Ok(nats_client.client().subscribe(subject).await?)
}
/// TODO (karenc): This is a temporary KV router measure for worker query requests.
/// Allows KV Router to perform request/reply with workers. (versus the pub/sub pattern above)
/// KV Router is the only user, made public for use in dynamo-llm crate
pub async fn kv_router_nats_request(
&self,
subject: String,
payload: bytes::Bytes,
timeout: std::time::Duration,
) -> anyhow::Result<async_nats::Message> {
let Some(nats_client) = self.nats_client.as_ref() else {
anyhow::bail!("KV router's request requires NATS");
};
let response =
tokio::time::timeout(timeout, nats_client.client().request(subject, payload))
.await
.map_err(|_| anyhow::anyhow!("Request timed out after {:?}", timeout))??;
Ok(response)
}
/// DEPRECATED: This method exists only for NATS request plane support.
/// Once everything uses the TCP request plane, this can be removed along with
/// the NATS service registration infrastructure.
......@@ -633,6 +652,26 @@ pub mod distributed_test_utils {
};
super::DistributedRuntime::new(rt, config).await.unwrap()
}
/// Helper function to create a DRT instance which points at
/// a (shared) file-backed KV store and ephemeral NATS transport so that
/// multiple DRT instances may observe the same registration state.
/// NOTE: This gets around the fact that create_test_drt_async() is
/// hardcoded to spin up a memory-backed discovery store
/// which means we can't share discovery state across runtimes.
pub async fn create_test_shared_drt_async(
store_path: &std::path::Path,
) -> super::DistributedRuntime {
use crate::{storage::kv, transports::nats};
let rt = crate::Runtime::from_current().unwrap();
let config = super::DistributedConfig {
store_backend: kv::Selector::File(store_path.to_path_buf()),
nats_config: Some(nats::ClientOptions::default()),
request_plane: crate::distributed::RequestPlaneMode::default(),
};
super::DistributedRuntime::new(rt, config).await.unwrap()
}
}
#[cfg(all(test, feature = "integration"))]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment