Unverified Commit 33249945 authored by Karen Chung's avatar Karen Chung Committed by GitHub
Browse files

feat: worker-local KvIndexer in KvEventPublisher (#4519)


Co-authored-by: default avatarYan Ru Pei <yanrpei@gmail.com>
parent 10b01b45
......@@ -2663,7 +2663,7 @@ dependencies = [
"bytes",
"candle-core 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
"chrono",
"clap 4.5.52",
"clap 4.5.53",
"criterion 0.3.6",
"cudarc",
"dashmap 5.5.3",
......@@ -4065,8 +4065,8 @@ checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97"
dependencies = [
"dirs",
"futures",
"indicatif 0.17.11",
"http 1.4.0",
"indicatif 0.17.11",
"libc",
"log",
"num_cpus",
......
......@@ -113,6 +113,7 @@ def create_temp_engine_args_file(args) -> Path:
else None,
"is_prefill": getattr(args, "is_prefill_worker", None),
"is_decode": getattr(args, "is_decode_worker", None),
"enable_local_indexer": getattr(args, "enable_local_indexer", None),
}
# Remove None values to only include explicitly set arguments
......@@ -284,6 +285,12 @@ def parse_args():
default=False,
help="Mark this as a decode worker which does not publish KV events and skips prefill cost estimation (default: False)",
)
parser.add_argument(
"--enable-local-indexer",
action="store_true",
default=False,
help="Enable worker-local KV indexer for tracking this worker's own KV cache state (default: False)",
)
parser.add_argument(
"--store-kv",
type=str,
......
......@@ -40,6 +40,7 @@ class Config:
custom_jinja_template: Optional[str] = None
store_kv: str
request_plane: str
enable_local_indexer: bool = False
# mirror vLLM
model: str
......@@ -204,6 +205,13 @@ def parse_args() -> Config:
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
parser.add_argument(
"--enable-local-indexer",
type=str,
choices=["true", "false"],
default=os.environ.get("DYN_LOCAL_INDEXER", "false"),
help="Enable worker-local KV indexer for tracking this worker's own KV cache state (can also be toggled with env var DYN_LOCAL_INDEXER).",
)
parser.add_argument(
"--use-vllm-tokenizer",
action="store_true",
......@@ -214,6 +222,7 @@ def parse_args() -> Config:
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.enable_local_indexer = str(args.enable_local_indexer).lower() == "true"
engine_args = AsyncEngineArgs.from_cli_args(args)
# Workaround for vLLM GIL contention bug with NIXL connector when using UniProcExecutor.
......@@ -312,6 +321,7 @@ def parse_args() -> Config:
config.mm_prompt_template = args.mm_prompt_template
config.store_kv = args.store_kv
config.request_plane = args.request_plane
config.enable_local_indexer = args.enable_local_indexer
config.use_vllm_tokenizer = args.use_vllm_tokenizer
# Validate custom Jinja template file exists if provided
......
......@@ -224,6 +224,7 @@ def setup_kv_event_publisher(
worker_id=generate_endpoint.connection_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
enable_local_indexer=config.enable_local_indexer,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
kv_publishers.append(kv_publisher)
......@@ -336,6 +337,7 @@ async def register_vllm_model(
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
runtime_config.enable_local_indexer = config.enable_local_indexer
# Add tool/reasoning parsers for decode models
if model_type != ModelType.Prefill:
......
......@@ -21,7 +21,7 @@ use rs::traits::events::EventSubscriber;
use tracing;
use llm_rs::kv_router::protocols::*;
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks};
use llm_rs::kv_router::publisher::{KvEventSourceConfig, create_stored_blocks, start_zmq_listener};
use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
#[pyfunction]
......@@ -106,6 +106,9 @@ pub struct ZmqKvEventPublisherConfig {
pub zmq_endpoint: String,
#[pyo3(get, set)]
pub zmq_topic: String,
#[pyo3(get, set)]
pub enable_local_indexer: bool, // whether the underlying KvEventPublisher publishes to
// both global and worker-local KvIndexers
}
#[pymethods]
......@@ -115,19 +118,22 @@ impl ZmqKvEventPublisherConfig {
worker_id,
kv_block_size,
zmq_endpoint = "tcp://127.0.0.1:5557".to_string(),
zmq_topic = "".to_string()
zmq_topic = "".to_string(),
enable_local_indexer = false
))]
pub fn new(
worker_id: WorkerId,
kv_block_size: usize,
zmq_endpoint: String,
zmq_topic: String,
enable_local_indexer: bool,
) -> Self {
Self {
worker_id,
kv_block_size,
zmq_endpoint,
zmq_topic,
enable_local_indexer,
}
}
}
......@@ -141,13 +147,14 @@ pub(crate) struct ZmqKvEventPublisher {
impl ZmqKvEventPublisher {
#[new]
fn new(component: Component, config: ZmqKvEventPublisherConfig) -> PyResult<Self> {
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new_with_local_indexer(
component.inner,
config.kv_block_size as u32,
Some(KvEventSourceConfig::Zmq {
endpoint: config.zmq_endpoint,
topic: config.zmq_topic,
}),
config.enable_local_indexer,
)
.map_err(to_pyerr)?;
Ok(Self { inner })
......@@ -179,7 +186,7 @@ impl ZmqKvEventListener {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<KvCacheEvent>();
let shutdown_token = tokio_util::sync::CancellationToken::new();
tokio::spawn(llm_rs::kv_router::publisher::start_zmq_listener(
tokio::spawn(start_zmq_listener(
zmq_endpoint,
zmq_topic,
tx,
......
......@@ -49,6 +49,11 @@ impl ModelRuntimeConfig {
self.inner.data_parallel_size = data_parallel_size;
}
#[setter]
fn set_enable_local_indexer(&mut self, enable_local_indexer: bool) {
self.inner.enable_local_indexer = enable_local_indexer;
}
fn set_engine_specific(&mut self, key: &str, value: String) -> PyResult<()> {
let value: serde_json::Value = serde_json::from_str(&value).map_err(to_pyerr)?;
self.inner
......@@ -103,6 +108,11 @@ impl ModelRuntimeConfig {
self.inner.reasoning_parser.clone()
}
#[getter]
fn enable_local_indexer(&self) -> bool {
self.inner.enable_local_indexer
}
#[getter]
fn runtime_data(&self, py: Python<'_>) -> PyResult<PyObject> {
let dict = PyDict::new(py);
......
......@@ -460,6 +460,7 @@ class ModelRuntimeConfig:
max_num_batched_tokens: int | None
tool_call_parser: str | None
reasoning_parser: str | None
enable_local_indexer: bool
runtime_data: dict[str, Any]
tensor_model_config: Any | None
......@@ -843,7 +844,8 @@ class ZmqKvEventPublisherConfig:
worker_id: int,
kv_block_size: int,
zmq_endpoint: str = "tcp://127.0.0.1:5557",
zmq_topic: str = ""
zmq_topic: str = "",
enable_local_indexer: bool = False
) -> None:
"""
Configuration for the ZmqKvEventPublisher.
......@@ -852,6 +854,7 @@ class ZmqKvEventPublisherConfig:
:param kv_block_size: The block size for the key-value store.
:param zmq_endpoint: The ZeroMQ endpoint. Defaults to "tcp://127.0.0.1:5557".
:param zmq_topic: The ZeroMQ topic to subscribe to. Defaults to an empty string.
:param enable_local_indexer: Whether to enable the worker-local KV indexer. Defaults to False.
"""
...
......
......@@ -34,8 +34,11 @@ pub mod scheduler;
pub mod scoring;
pub mod sequence;
pub mod subscriber;
pub mod worker_query;
use indexer::WorkerKvQueryResponse;
pub use prefill_router::PrefillRouter;
use worker_query::WorkerQueryClient;
use crate::{
kv_router::{
......@@ -45,11 +48,12 @@ use crate::{
compute_block_hash_for_seq, compute_seq_hash_for_block,
},
protocols::{
LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult, WorkerWithDpRank,
LocalBlockHash, RouterRequest, RouterResponse, WorkerId, WorkerSelectionResult,
WorkerWithDpRank,
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError,
subscriber::start_kv_router_background,
subscriber::{recover_from_all_workers, start_kv_router_background},
},
local_model::runtime_config::ModelRuntimeConfig,
model_card::ModelDeploymentCard,
......@@ -77,6 +81,10 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";
// for worker-local kvindexer query
pub const WORKER_KV_INDEXER_QUERY_SUBJECT: &str = "worker_kv_indexer_query";
pub const WORKER_KV_INDEXER_BUFFER_SIZE: usize = 1024; // store 1024 most recent events in worker buffer
// for router discovery registration
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
pub const KV_ROUTER_ENDPOINT: &str = "generate";
......@@ -270,6 +278,8 @@ pub struct KvRouter {
cancellation_token: tokio_util::sync::CancellationToken,
client: Client,
worker_query_client: Option<WorkerQueryClient>,
}
impl KvRouter {
......@@ -296,7 +306,7 @@ impl KvRouter {
endpoint: endpoint_id.name.clone(),
};
let discovery_stream = discovery
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
.list_and_watch(discovery_key.clone(), Some(cancellation_token.clone()))
.await?;
let runtime_configs_rx =
watch_and_extract_field(discovery_stream, |card: ModelDeploymentCard| {
......@@ -333,13 +343,19 @@ impl KvRouter {
component.clone(),
block_size,
instance_ids_rx,
runtime_configs_rx,
runtime_configs_rx.clone(),
selector,
kv_router_config.router_replica_sync,
consumer_id.clone(),
)
.await?;
// Initialize worker query client using namespace abstraction
// (created before background task so we can use it for startup recovery)
let worker_query_client =
worker_query::WorkerQueryClient::new(component.clone(), runtime_configs_rx.clone());
tracing::info!("Worker query client initialized");
// Start KV event subscriber background process (only when use_kv_events is enabled)
if kv_router_config.use_kv_events
&& let Indexer::KvIndexer(ref kv_indexer) = indexer
......@@ -360,6 +376,47 @@ impl KvRouter {
kv_router_config.router_reset_states,
)
.await?;
// Perform startup recovery from workers with local indexers
// This catches up on any events missed while the router was offline
let last_event_ids = kv_indexer
.get_last_received_event_ids()
.await
.unwrap_or_default();
let instances = client.instance_source.as_ref().borrow().clone();
let worker_ids: Vec<WorkerId> = instances.iter().map(|i| i.instance_id).collect();
if !worker_ids.is_empty() {
tracing::info!(
worker_count = worker_ids.len(),
"Starting recovery from workers with local indexers"
);
// NOTE: recover_from_all_workers() is a no-op if
// Worker with worker_id is not associated with a
// local indexer instance.
let recovered = recover_from_all_workers(
&worker_query_client,
&last_event_ids,
&worker_ids,
&kv_indexer.event_sender(),
)
.await;
if recovered > 0 {
tracing::info!(
recovered_events = recovered,
"KV Router startup: Recovered {} KV events from workers {:?}",
recovered,
worker_ids
);
} else {
tracing::info!(
"KV Router startup: No KV events recovered from workers {:?}",
worker_ids
);
}
}
}
tracing::info!("KV Routing initialized");
......@@ -370,6 +427,7 @@ impl KvRouter {
kv_router_config,
cancellation_token,
client,
worker_query_client: Some(worker_query_client),
})
}
......@@ -502,6 +560,62 @@ impl KvRouter {
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
}
/// Query a specific worker's local KV indexer for its events
/// (See docstring for `WorkerQueryClient.query_worker()`)
pub async fn query_worker_local_kv(
&self,
worker_id: WorkerId,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
let query_client = self
.worker_query_client
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Worker query client not available (NATS required)"))?;
query_client
.query_worker(worker_id, start_event_id, end_event_id)
.await
}
/// Recover missed KV events from a specific worker.
///
/// Queries the worker's local KV indexer for events starting from
/// `start_event_id` and applies them to the router's indexer.
///
/// # Arguments
///
/// * `worker_id` - The worker to recover from
/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
pub async fn recover_from_worker(
&self,
worker_id: WorkerId,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<usize> {
let query_client = self
.worker_query_client
.as_ref()
.ok_or_else(|| anyhow::anyhow!("Worker query client not available"))?;
let event_tx = match &self.indexer {
Indexer::KvIndexer(kv_indexer) => kv_indexer.event_sender(),
Indexer::None => {
anyhow::bail!("Cannot recover: indexer is disabled (--overlap_score_weight is 0)")
}
};
subscriber::recover_from_worker(
query_client,
worker_id,
start_event_id,
end_event_id,
&event_tx,
)
.await
}
}
// NOTE: KVRouter works like a PushRouter,
......
This diff is collapsed.
......@@ -330,6 +330,9 @@ impl<'de> Deserialize<'de> for ExternalSequenceBlockHash {
}
}
// ------
// Tests
// ------
#[cfg(test)]
mod tests {
use super::*;
......
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Background processes for the KV Router including event consumption and snapshot uploads.
use std::{collections::HashSet, time::Duration};
use std::{collections::HashMap, collections::HashSet, time::Duration};
use anyhow::Result;
use dynamo_runtime::{
......@@ -24,6 +22,7 @@ use crate::kv_router::{
indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
protocols::WorkerId,
router_discovery_query,
worker_query::WorkerQueryClient,
};
/// Delay between snapshot reads to verify stability
......@@ -33,6 +32,163 @@ const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
const CHECK_INTERVAL_BASE: Duration = Duration::from_secs(1);
const CHECK_INTERVAL_JITTER_MS: i64 = 100;
// ============================================================================
// Local KvIndexer-based Recovery
// ============================================================================
/// Recover missed events from all workers with local indexers.
///
/// This function should be called on router startup to catch up on any events
/// that were missed while the router was offline.
///
/// # Arguments
///
/// * `worker_query_client` - Client for querying worker local indexers
/// * `last_received_event_ids` - Map of worker ID to last received event ID
/// * `worker_ids` - List of worker IDs to recover from
/// * `event_tx` - Channel to send recovered events to the indexer
///
/// # Returns
///
/// Total number of events recovered across all workers
pub async fn recover_from_all_workers(
worker_query_client: &WorkerQueryClient,
last_received_event_ids: &HashMap<WorkerId, u64>,
worker_ids: &Vec<WorkerId>,
event_tx: &mpsc::Sender<RouterEvent>,
) -> usize {
let mut total_recovered = 0;
let mut successful_workers = 0;
let mut failed_workers = 0;
for &worker_id in worker_ids {
// Skip workers without local indexer
if !worker_query_client.has_local_indexer(worker_id) {
tracing::debug!(
worker_id,
"Skipping recovery - worker does not have local indexer enabled"
);
continue;
}
// If we haven't seen any events from this worker, start from beginning (None)
// If we've seen events, start from last_known_id + 1
let start_event_id = last_received_event_ids
.get(&worker_id)
.map(|&last_id| last_id + 1);
match recover_from_worker(
worker_query_client,
worker_id,
start_event_id,
None, // Get all events after start_event_id
event_tx,
)
.await
{
Ok(count) => {
total_recovered += count;
if count > 0 {
successful_workers += 1;
}
}
Err(_) => {
failed_workers += 1;
}
}
}
// Log summary
if total_recovered > 0 || failed_workers > 0 {
tracing::info!(
total_recovered,
successful_workers,
failed_workers,
"Startup recovery completed"
);
}
total_recovered
}
/// Recover missed KV events from a specific worker.
///
/// # Arguments
///
/// * `worker_query_client` - Client for querying worker local indexers
/// * `worker_id` - The worker to recover from
/// * `start_event_id` - First event ID to fetch (inclusive), or None to start from beginning
/// * `end_event_id` - Last event ID to fetch (inclusive), or None for all
/// * `event_tx` - Channel to send recovered events to the indexer
///
/// # Returns
///
/// Number of events recovered, or error if recovery failed
pub async fn recover_from_worker(
worker_query_client: &WorkerQueryClient,
worker_id: WorkerId,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
event_tx: &mpsc::Sender<RouterEvent>,
) -> Result<usize> {
if worker_query_client.has_local_indexer(worker_id) {
tracing::debug!(
worker_id,
start_event_id = ?start_event_id,
end_event_id = ?end_event_id,
"Attempting recovery from worker"
);
} else {
tracing::warn!(
"Worker {} does not have local indexer enabled, skipping recovery",
worker_id
);
return Ok(0);
}
// Query worker for events in range
let response = worker_query_client
.query_worker(worker_id, start_event_id, end_event_id)
.await?;
let events_count = response.events.len();
if events_count == 0 {
tracing::debug!(
worker_id,
start_event_id = ?start_event_id,
"No missed events to recover from worker"
);
return Ok(0);
}
tracing::info!(
worker_id,
start_event_id = ?start_event_id,
events_count,
"Recovered {} missed events from worker",
events_count
);
// Apply recovered events to the indexer
for event in response.events {
if let Err(e) = event_tx.send(event).await {
tracing::error!(
worker_id,
error = %e,
"Failed to send recovered event to indexer"
);
anyhow::bail!("Failed to send recovered event: {}", e);
}
}
Ok(events_count)
}
// ============================================================================
// Snapshot Management
// ============================================================================
/// Download a stable snapshot from object store and send events to the indexer.
/// Retries until two consecutive reads match or max attempts is reached.
async fn download_stable_snapshot(
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use anyhow::{Context, Result};
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use tokio::sync::watch;
use crate::kv_router::WORKER_KV_INDEXER_QUERY_SUBJECT;
use crate::kv_router::indexer::{WorkerKvQueryRequest, WorkerKvQueryResponse};
use crate::kv_router::protocols::WorkerId;
use crate::local_model::runtime_config::ModelRuntimeConfig;
/// Router-side client for querying worker local KV indexers
///
/// Performs request/reply communication with workers via NATS.
/// (Only queries workers that have `enable_local_indexer=true` in their MDC user_data)
/// The client is spawned by KvRouter; it watches same discovery stream as the router.
pub struct WorkerQueryClient {
component: Component,
/// Watch receiver for enable_local_indexer state per worker
model_runtime_config_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
}
impl WorkerQueryClient {
/// Create a new WorkerQueryClient with a watch receiver for local indexer states
pub fn new(
component: Component,
model_runtime_config_rx: watch::Receiver<HashMap<WorkerId, ModelRuntimeConfig>>,
) -> Self {
Self {
component,
model_runtime_config_rx,
}
}
/// Check if a worker has local indexer enabled
pub fn has_local_indexer(&self, worker_id: WorkerId) -> bool {
self.model_runtime_config_rx
.borrow()
.get(&worker_id)
.map(|config| config.enable_local_indexer)
.unwrap_or(false)
}
/// Query a specific worker's local KV indexer and return its buffered events.
/// Returns an error if the worker does not have enable_local_indexer=true.
pub async fn query_worker(
&self,
worker_id: WorkerId,
start_event_id: Option<u64>,
end_event_id: Option<u64>,
) -> Result<WorkerKvQueryResponse> {
// Check if worker has local indexer enabled
if !self.has_local_indexer(worker_id) {
anyhow::bail!(
"Worker {} does not have local indexer enabled (enable_local_indexer=false or not set in MDC user_data)",
worker_id
);
}
// Match worker's subscribe format
let subject_str = format!("{}.{}", WORKER_KV_INDEXER_QUERY_SUBJECT, worker_id); // see publisher.rs/start_worker_kv_query_service()
let subject = format!("{}.{}", self.component.subject(), subject_str);
tracing::debug!(
"Router sending query request to worker {} on NATS subject: {}",
worker_id,
subject
);
// Create and serialize request
let request = WorkerKvQueryRequest {
worker_id,
start_event_id,
end_event_id,
};
let request_bytes =
serde_json::to_vec(&request).context("Failed to serialize WorkerKvQueryRequest")?;
// Send NATS request with timeout using DRT helper
let timeout = tokio::time::Duration::from_secs(1);
let response_msg = self
.component
.drt()
.kv_router_nats_request(subject.clone(), request_bytes.into(), timeout)
.await
.with_context(|| {
format!(
"Failed to send request to worker {} on subject {}",
worker_id, subject
)
})?;
// Deserialize response
let response: WorkerKvQueryResponse = serde_json::from_slice(&response_msg.payload)
.context("Failed to deserialize WorkerKvQueryResponse")?;
Ok(response)
}
}
......@@ -234,6 +234,7 @@ impl LocalModelBuilder {
self.runtime_config.max_num_seqs = mocker_engine_args.max_num_seqs.map(|v| v as u64);
self.runtime_config.max_num_batched_tokens =
mocker_engine_args.max_num_batched_tokens.map(|v| v as u64);
self.runtime_config.enable_local_indexer = mocker_engine_args.enable_local_indexer;
self.runtime_config.data_parallel_size = mocker_engine_args.dp_size;
self.media_decoder = Some(MediaDecoder::default());
self.media_fetcher = Some(MediaFetcher::default());
......
......@@ -23,6 +23,10 @@ pub struct ModelRuntimeConfig {
#[serde(default = "default_data_parallel_size")]
pub data_parallel_size: u32,
/// Enable worker-local KV indexer for tracking this worker's own KV cache state
#[serde(default)]
pub enable_local_indexer: bool,
/// Mapping of engine-specific runtime configs
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub runtime_data: HashMap<String, serde_json::Value>,
......@@ -51,6 +55,7 @@ impl Default for ModelRuntimeConfig {
tool_call_parser: None,
reasoning_parser: None,
data_parallel_size: default_data_parallel_size(),
enable_local_indexer: false,
runtime_data: HashMap::new(),
tensor_model_config: None,
}
......
......@@ -72,7 +72,7 @@ pub struct KvManager {
impl KvManager {
pub fn new(max_capacity: usize, block_size: usize) -> Self {
Self::new_with_publisher(max_capacity, block_size, None, 0)
Self::new_with_publisher(max_capacity, block_size, None, 0, false)
}
pub fn new_with_publisher(
......@@ -80,6 +80,7 @@ impl KvManager {
block_size: usize,
component: Option<Component>,
dp_rank: u32,
enable_local_indexer: bool,
) -> Self {
let active_blocks = HashMap::new();
let inactive_blocks = LRUEvictor::default();
......@@ -87,10 +88,10 @@ impl KvManager {
let kv_event_publisher = component.map(|comp| {
tracing::info!(
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}"
"Initializing KV event publisher for DP rank {dp_rank} with block_size {block_size}, enable_local_indexer={enable_local_indexer}"
);
Arc::new(
KvEventPublisher::new(comp, block_size as u32, None)
KvEventPublisher::new_with_local_indexer(comp, block_size as u32, None, enable_local_indexer)
.expect("Failed to create KV event publisher"),
)
});
......
......@@ -120,6 +120,10 @@ pub struct MockEngineArgs {
#[serde(skip)]
#[builder(default = "Arc::new(PerfModel::default())")]
pub perf_model: Arc<PerfModel>,
/// Enable worker-local KV indexer for tracking this worker's own KV cache state
#[builder(default = "false")]
pub enable_local_indexer: bool,
}
impl Default for MockEngineArgs {
......@@ -158,6 +162,7 @@ impl MockEngineArgs {
"is_prefill",
"is_decode",
"planner_profile_data",
"enable_local_indexer",
]
.iter()
.cloned()
......@@ -239,6 +244,12 @@ impl MockEngineArgs {
builder = builder.startup_time(Some(num));
}
if let Some(value) = extra_args.get("enable_local_indexer")
&& let Some(enabled) = value.as_bool()
{
builder = builder.enable_local_indexer(enabled);
}
// Parse worker type from is_prefill and is_decode flags
let is_prefill = extra_args
.get("is_prefill")
......
......@@ -275,6 +275,7 @@ impl Scheduler {
args.block_size,
component,
dp_rank,
args.enable_local_indexer,
);
let mut hit_rates = RunningMean::new(1000);
......
......@@ -397,7 +397,7 @@ impl DistributedRuntime {
/// TODO: This is a temporary KV router measure for component/component.rs EventPublisher impl for
/// Component, to allow it to publish to NATS. KV Router is the only user.
pub(crate) async fn kv_router_nats_publish(
pub async fn kv_router_nats_publish(
&self,
subject: String,
payload: bytes::Bytes,
......@@ -420,6 +420,25 @@ impl DistributedRuntime {
Ok(nats_client.client().subscribe(subject).await?)
}
/// TODO (karenc): This is a temporary KV router measure for worker query requests.
/// Allows KV Router to perform request/reply with workers. (versus the pub/sub pattern above)
/// KV Router is the only user, made public for use in dynamo-llm crate
pub async fn kv_router_nats_request(
&self,
subject: String,
payload: bytes::Bytes,
timeout: std::time::Duration,
) -> anyhow::Result<async_nats::Message> {
let Some(nats_client) = self.nats_client.as_ref() else {
anyhow::bail!("KV router's request requires NATS");
};
let response =
tokio::time::timeout(timeout, nats_client.client().request(subject, payload))
.await
.map_err(|_| anyhow::anyhow!("Request timed out after {:?}", timeout))??;
Ok(response)
}
/// DEPRECATED: This method exists only for NATS request plane support.
/// Once everything uses the TCP request plane, this can be removed along with
/// the NATS service registration infrastructure.
......@@ -633,6 +652,26 @@ pub mod distributed_test_utils {
};
super::DistributedRuntime::new(rt, config).await.unwrap()
}
/// Helper function to create a DRT instance which points at
/// a (shared) file-backed KV store and ephemeral NATS transport so that
/// multiple DRT instances may observe the same registration state.
/// NOTE: This gets around the fact that create_test_drt_async() is
/// hardcoded to spin up a memory-backed discovery store
/// which means we can't share discovery state across runtimes.
pub async fn create_test_shared_drt_async(
store_path: &std::path::Path,
) -> super::DistributedRuntime {
use crate::{storage::kv, transports::nats};
let rt = crate::Runtime::from_current().unwrap();
let config = super::DistributedConfig {
store_backend: kv::Selector::File(store_path.to_path_buf()),
nats_config: Some(nats::ClientOptions::default()),
request_plane: crate::distributed::RequestPlaneMode::default(),
};
super::DistributedRuntime::new(rt, config).await.unwrap()
}
}
#[cfg(all(test, feature = "integration"))]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment