"examples/vscode:/vscode.git/clone" did not exist on "0024f39a3224326a9f871919cf16a06c58edfdad"
Unverified Commit 488c8709 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Router warm restarts via durable KV event consumers and radix snapshotting (#2756)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 3f09c395
...@@ -149,6 +149,19 @@ def parse_args(): ...@@ -149,6 +149,19 @@ def parse_args():
default=False, default=False,
help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.", help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.",
) )
parser.add_argument(
"--router-snapshot-threshold",
type=int,
default=10000,
help="KV Router: Number of messages in stream before triggering a snapshot. Defaults to 10000.",
)
parser.add_argument(
"--router-persist-states",
action="store_false",
dest="router_reset_states",
default=True,
help="KV Router: Persist router state on startup. Keep existing state from stream and object store (default: reset states).",
)
parser.add_argument( parser.add_argument(
"--busy-threshold", "--busy-threshold",
type=float, type=float,
...@@ -212,6 +225,8 @@ async def async_main(): ...@@ -212,6 +225,8 @@ async def async_main():
router_temperature=flags.router_temperature, router_temperature=flags.router_temperature,
use_kv_events=flags.use_kv_events, use_kv_events=flags.use_kv_events,
router_replica_sync=flags.router_replica_sync, router_replica_sync=flags.router_replica_sync,
router_snapshot_threshold=flags.router_snapshot_threshold,
router_reset_states=flags.router_reset_states,
) )
elif flags.router_mode == "random": elif flags.router_mode == "random":
router_mode = RouterMode.Random router_mode = RouterMode.Random
......
...@@ -15,7 +15,7 @@ When KV blocks are created or removed, the engine notifies the Dynamo router, wh ...@@ -15,7 +15,7 @@ When KV blocks are created or removed, the engine notifies the Dynamo router, wh
To evaluate the benefits of KV-aware routing, compare your workload's performance using `--router-mode random|round-robin` against KV-aware routing. To evaluate the benefits of KV-aware routing, compare your workload's performance using `--router-mode random|round-robin` against KV-aware routing.
The KV-aware routing arguments: The main KV-aware routing arguments:
- `--kv-overlap-score-weight`: Controls the importance of prefix cache overlaps in prefill cost calculations. Higher values improve Time To First Token (TTFT) at the cost of Inter-Token Latency (ITL). When set to 0, the router ignores prefix caches and uses pure load balancing. Defaults to 1. - `--kv-overlap-score-weight`: Controls the importance of prefix cache overlaps in prefill cost calculations. Higher values improve Time To First Token (TTFT) at the cost of Inter-Token Latency (ITL). When set to 0, the router ignores prefix caches and uses pure load balancing. Defaults to 1.
...@@ -23,7 +23,11 @@ The KV-aware routing arguments: ...@@ -23,7 +23,11 @@ The KV-aware routing arguments:
- `--use-kv-events`/`--no-kv-events`: Determines how the router tracks cached blocks. When enabled (default), uses `KvIndexer` to monitor block creation and deletion events. When disabled, uses `ApproxKvIndexer`, which estimates cache hits based on a fixed time window (120s). Disable this if your backend doesn't support KV events. - `--use-kv-events`/`--no-kv-events`: Determines how the router tracks cached blocks. When enabled (default), uses `KvIndexer` to monitor block creation and deletion events. When disabled, uses `ApproxKvIndexer`, which estimates cache hits based on a fixed time window (120s). Disable this if your backend doesn't support KV events.
- `--router-replica-sync`: Enables NATS-based state synchronization between router replicas. When enabled, routers share their KV cache distribution and active sequence information, ensuring optimal routing decisions across multiple router instances. This improves fault tolerance and routing accuracy in distributed deployments. Disabled by default. - `--router-replica-sync`: Enables NATS-based synchronization of local routing decisions between router replicas. When enabled, routers share their active sequence information and local predictions of block usage, improving routing consistency across instances. Note that this does not sync the radix tree or cached KV block states themselves - those are synchronized through JetStream events. Disabled by default.
- `--router-reset-states`/`--router-persist-states`: Controls whether the router state is reset on startup. When `--router-reset-states` is used (default), the router clears both the JetStream event stream and the NATS object store, starting with a fresh state. When `--router-persist-states` is used, the router retains existing state from previous runs, downloading any available snapshot from the NATS object store and continuing to consume events from where it left off. This enables routers to maintain KV cache awareness across restarts. **Note**: State persistence is only available when `--use-kv-events` is enabled (default). When using `--no-kv-events` with `ApproxKvIndexer`, state persistence is not supported.
- `--router-snapshot-threshold`: Sets the number of messages in the JetStream stream before triggering a snapshot. When the message count exceeds this threshold, a router will attempt to purge acknowledged messages from the stream and create a snapshot of the current radix tree state in the NATS object store. Defaults to 10000. This helps manage stream size and provides faster initialization for routers that restart.
## Architecture ## Architecture
...@@ -50,17 +54,26 @@ KV Cache routing uses direct routing with a special worker selection algorithm. ...@@ -50,17 +54,26 @@ KV Cache routing uses direct routing with a special worker selection algorithm.
For improved fault tolerance, you can launch two frontend + router replicas. Since the frontend and router are currently tied together, you'll need to use two different HTTP ports for each instance. For improved fault tolerance, you can launch two frontend + router replicas. Since the frontend and router are currently tied together, you'll need to use two different HTTP ports for each instance.
To enable state sharing between the router replicas (which provides more accurate routing decisions), use the `--router-replica-sync` flag when starting the frontend: To enable state sharing between the router replicas (which provides more accurate routing decisions), use the `--router-replica-sync` flag when starting the frontend. Router replicas are currently tied to a component, and state syncing and sharing can only happen within the component group. Here's an example of running multiple router replicas:
```bash ```bash
# Router replica 1 # Router replica 1
python -m dynamo.frontend --router-mode kv --port 8000 --router-replica-sync python -m dynamo.frontend --router-mode kv --port 8000 --router-replica-sync
# Router replica 2 # Router replica 2 (can be started later, note the extra --router-persist-states arg)
python -m dynamo.frontend --router-mode kv --port 8001 --router-replica-sync python -m dynamo.frontend --router-mode kv --port 8001 --router-replica-sync --router-persist-states
``` ```
When `--router-replica-sync` is enabled, the router replicas will communicate with each other via NATS to maintain consistent state across instances. This allows both routers to have a complete view of the KV cache distribution and make optimal routing decisions, even when requests are distributed across multiple router instances. After these two replicas are launched, they will share the same JetStream and snapshot state. The second replica can be started after the first has already been handling requests. As long as `--router-persist-states` is set, the new replica will sync its KV block indexer by consuming the JetStream events and/or downloading the latest snapshot, ensuring both replicas have the same view of cached blocks across workers. It's okay for one router to go down, or even both to go down - the state persistence ensures continuity (up to the message retention of an hour we set for the stream). When a third router starts (with `--router-persist-states`), the states will still persist:
```bash
# Router replica 3 (can be started even after replicas 1 and 2 have gone down)
python -m dynamo.frontend --router-mode kv --port 8002 --router-replica-sync --router-persist-states
```
> **Note:** If a router replica is launched without the `--router-persist-states` flag, the entire stream and radix snapshot will be purged. If you want to serve a separate router (targeting a different set of workers) independently without affecting the current state, consider using a new namespace/component (see [Distributed Runtime](distributed_runtime.md)) which will start a new stream and NATS object store path.
When `--router-replica-sync` is enabled, the router replicas will additionally share their local routing decisions and active sequence predictions via NATS. Active blocks information is communicated between routers in a fire-and-forget manner, but the routers will quickly become consistent as this information is tied to the request cycle. This helps maintain consistent load estimates across instances even when requests are distributed between routers.
## Understanding KV Cache ## Understanding KV Cache
The leading Large Language Models (LLMs) today are auto-regressive and based off of the [transformer architecture](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf). One key inference optimization technique is to cache the already computed keys and values and to reuse them for the future tokens. This is called the [KV Cache](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/#key-value_caching). The leading Large Language Models (LLMs) today are auto-regressive and based off of the [transformer architecture](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf). One key inference optimization technique is to cache the already computed keys and values and to reuse them for the future tokens. This is called the [KV Cache](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/#key-value_caching).
...@@ -182,3 +195,8 @@ In distributed deployments with multiple routers, each router maintains visibili ...@@ -182,3 +195,8 @@ In distributed deployments with multiple routers, each router maintains visibili
Each event carries a unique router ID to prevent self-event processing. This asynchronous communication system ensures optimal routing decisions by maintaining consistent KV cache state across all routers, even as they handle different request streams. Each event carries a unique router ID to prevent self-event processing. This asynchronous communication system ensures optimal routing decisions by maintaining consistent KV cache state across all routers, even as they handle different request streams.
### Event Persistence and Recovery
KV cache events are persisted in NATS JetStream, allowing router replicas to maintain their global view of KV blocks across restarts. When a router starts with `--router-persist-states`, it downloads any available snapshot from the NATS object store and continues consuming events from its last acknowledged position in the stream.
To manage stream growth, when the message count exceeds `--router-snapshot-threshold`, a router acquires an etcd-based distributed lock, purges acknowledged messages from the stream, and uploads the current radix tree state to the NATS object store. This snapshot serves as a checkpoint for faster initialization of future router instances.
...@@ -248,6 +248,9 @@ impl Flags { ...@@ -248,6 +248,9 @@ impl Flags {
self.use_kv_events, self.use_kv_events,
self.router_replica_sync, self.router_replica_sync,
self.max_num_batched_tokens, self.max_num_batched_tokens,
// defaulting below args (no longer maintaining new flags for dynamo-run)
None,
None,
), ),
) )
} }
......
...@@ -42,12 +42,14 @@ impl KvRouterConfig { ...@@ -42,12 +42,14 @@ impl KvRouterConfig {
#[pymethods] #[pymethods]
impl KvRouterConfig { impl KvRouterConfig {
#[new] #[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false))] #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_snapshot_threshold=10000, router_reset_states=true))]
fn new( fn new(
overlap_score_weight: f64, overlap_score_weight: f64,
router_temperature: f64, router_temperature: f64,
use_kv_events: bool, use_kv_events: bool,
router_replica_sync: bool, router_replica_sync: bool,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
) -> Self { ) -> Self {
KvRouterConfig { KvRouterConfig {
inner: RsKvRouterConfig { inner: RsKvRouterConfig {
...@@ -55,6 +57,8 @@ impl KvRouterConfig { ...@@ -55,6 +57,8 @@ impl KvRouterConfig {
router_temperature, router_temperature,
use_kv_events, use_kv_events,
router_replica_sync, router_replica_sync,
router_snapshot_threshold,
router_reset_states,
..Default::default() ..Default::default()
}, },
} }
......
...@@ -19,6 +19,7 @@ use std::sync::atomic::AtomicU32; ...@@ -19,6 +19,7 @@ use std::sync::atomic::AtomicU32;
use tokio_stream::StreamExt; use tokio_stream::StreamExt;
use super::*; use super::*;
use crate::Component;
use llm_rs::kv_router::indexer::compute_block_hash_for_seq; use llm_rs::kv_router::indexer::compute_block_hash_for_seq;
use llm_rs::kv_router::indexer::KvIndexerInterface; use llm_rs::kv_router::indexer::KvIndexerInterface;
use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics; use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics;
...@@ -405,39 +406,36 @@ pub(crate) struct KvIndexer { ...@@ -405,39 +406,36 @@ pub(crate) struct KvIndexer {
#[pymethods] #[pymethods]
impl KvIndexer { impl KvIndexer {
#[new] #[new]
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> { #[pyo3(signature = (component, kv_block_size, consumer_uuid=None))]
fn new(
component: Component,
kv_block_size: usize,
consumer_uuid: Option<String>,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime(); let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async { runtime.block_on(async {
let cancellation_token = component.inner.drt().runtime().child_token();
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> = let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new( llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(), cancellation_token.clone(),
kv_block_size as u32, kv_block_size as u32,
) )
.into(); .into();
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different. // Use the shared start_kv_router_background function for event consumption
let mut kv_events_rx = component // Pass None for snapshot_tx to skip snapshot handling in Python bindings
.inner llm_rs::kv_router::subscriber::start_kv_router_background(
.subscribe(llm_rs::kv_router::KV_EVENT_SUBJECT) component.inner.clone(),
consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
inner.event_sender(),
None,
cancellation_token,
None,
true,
)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
let kv_events_tx = inner.event_sender();
// [FIXME] this is the added functionality to the indexer to subscribe to kv events,
// should have been made to a trait and implemented here? i.e. AsyncEngine style
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: llm_rs::kv_router::indexer::RouterEvent =
serde_json::from_slice(&event.payload).unwrap();
tracing::debug!("received kv event: {:?}", event);
if let Err(e) = kv_events_tx.send(event).await {
tracing::trace!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
}
}
});
Ok(Self { inner }) Ok(Self { inner })
}) })
} }
...@@ -845,6 +843,7 @@ impl SpecDecodeStats { ...@@ -845,6 +843,7 @@ impl SpecDecodeStats {
#[pyclass] #[pyclass]
pub(crate) struct KvPushRouter { pub(crate) struct KvPushRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>, inner: Arc<llm_rs::kv_router::KvPushRouter>,
primary_token: tokio_util::sync::CancellationToken,
} }
#[pymethods] #[pymethods]
...@@ -875,12 +874,25 @@ impl KvPushRouter { ...@@ -875,12 +874,25 @@ impl KvPushRouter {
// Get component from endpoint // Get component from endpoint
let component = endpoint.inner.component(); let component = endpoint.inner.component();
// Create KvRouter // Get the primary token from the component's primary lease
let primary_token = component
.drt()
.primary_lease()
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Failed to get primary lease: Cannot KV route static workers",
)
})?
.primary_token();
// Create KvRouter with a unique consumer UUID
let consumer_uuid = uuid::Uuid::new_v4().to_string();
let kv_router = llm_rs::kv_router::KvRouter::new( let kv_router = llm_rs::kv_router::KvRouter::new(
component.clone(), component.clone(),
block_size as u32, block_size as u32,
None, // default selector None, // default selector
Some(kv_router_config.inner()), Some(kv_router_config.inner()),
consumer_uuid,
) )
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
...@@ -891,6 +903,7 @@ impl KvPushRouter { ...@@ -891,6 +903,7 @@ impl KvPushRouter {
Ok(Self { Ok(Self {
inner: Arc::new(kv_push_router), inner: Arc::new(kv_push_router),
primary_token,
}) })
}) })
} }
...@@ -996,6 +1009,25 @@ impl KvPushRouter { ...@@ -996,6 +1009,25 @@ impl KvPushRouter {
}) })
}) })
} }
/// Dump all events from the KV router's indexer as a JSON string
fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let events = inner.dump_events().await.map_err(to_pyerr)?;
// Serialize to JSON string
let json_str = serde_json::to_string(&events).map_err(to_pyerr)?;
Ok(json_str)
})
}
}
impl Drop for KvPushRouter {
fn drop(&mut self) {
// Cancel the primary token to shut down background tasks
self.primary_token.cancel();
}
} }
// Python async generator wrapper for the stream // Python async generator wrapper for the stream
......
...@@ -138,17 +138,38 @@ async def test_event_handler(distributed_runtime): ...@@ -138,17 +138,38 @@ async def test_event_handler(distributed_runtime):
event_publisher.store_event(test_token, lora_id) event_publisher.store_event(test_token, lora_id)
# wait for the event to be processed as it is sent asynchronously # wait for the event to be processed as it is sent asynchronously
await asyncio.sleep(1) # Retry loop for CI environments where processing may take longer
for retry in range(10): # Try up to 10 times
await asyncio.sleep(0.5) # Wait 500ms between retries
scores = await indexer.find_matches_for_request(test_token, lora_id) scores = await indexer.find_matches_for_request(test_token, lora_id)
assert scores.scores if (
assert worker_id in scores.scores scores.scores
assert scores.scores[worker_id] == 1 and worker_id in scores.scores
and scores.scores[worker_id] == 1
):
break
if retry == 9: # Last iteration
# Provide detailed error message for debugging
assert scores.scores, f"No scores found after {(retry+1)*0.5}s"
assert (
worker_id in scores.scores
), f"Worker {worker_id} not in scores after {(retry+1)*0.5}s"
assert (
scores.scores[worker_id] == 1
), f"Expected score 1, got {scores.scores.get(worker_id)} after {(retry+1)*0.5}s"
# remove event # remove event
event_publisher.remove_event() event_publisher.remove_event()
await asyncio.sleep(1) # Retry loop for event removal verification
for retry in range(10): # Try up to 10 times
await asyncio.sleep(0.5) # Wait 500ms between retries
scores = await indexer.find_matches_for_request(test_token, lora_id) scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores if not scores.scores:
break
if retry == 9: # Last iteration
assert (
not scores.scores
), f"Scores still present after {(retry+1)*0.5}s: {scores.scores}"
async def test_approx_kv_indexer(distributed_runtime): async def test_approx_kv_indexer(distributed_runtime):
...@@ -235,12 +256,13 @@ async def test_metrics_aggregator(distributed_runtime): ...@@ -235,12 +256,13 @@ async def test_metrics_aggregator(distributed_runtime):
asyncio.create_task(metrics_publisher_task(kv_listener, expected_metrics)) asyncio.create_task(metrics_publisher_task(kv_listener, expected_metrics))
# needs time for publisher to spawn up # needs time for publisher to spawn up
for i in range(10): # Using shorter intervals for faster detection in normal cases
await asyncio.sleep(1) for i in range(20): # Try up to 20 times (10 seconds total)
await asyncio.sleep(0.5) # Wait 500ms between retries
metrics = await metrics_aggregator.get_metrics() metrics = await metrics_aggregator.get_metrics()
if metrics.endpoints: if metrics.endpoints:
break break
assert metrics.endpoints assert metrics.endpoints, f"No metrics endpoints found after {(i+1)*0.5}s"
for endpoint in metrics.endpoints: for endpoint in metrics.endpoints:
# [TODO] not really checking id for now, can't get it as create_endpoint() # [TODO] not really checking id for now, can't get it as create_endpoint()
# create and serve the endpoint internally # create and serve the endpoint internally
......
...@@ -12,3 +12,6 @@ pub use watcher::{ModelUpdate, ModelWatcher}; ...@@ -12,3 +12,6 @@ pub use watcher::{ModelUpdate, ModelWatcher};
/// The root etcd path for ModelEntry /// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models"; pub const MODEL_ROOT_PATH: &str = "models";
/// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers";
...@@ -10,9 +10,8 @@ use parking_lot::Mutex; ...@@ -10,9 +10,8 @@ use parking_lot::Mutex;
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider; use dynamo_runtime::prelude::DistributedRuntimeProvider;
use dynamo_runtime::slug::Slug;
use crate::discovery::ModelEntry; use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry};
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector}; use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{ use crate::{
kv_router::KvRouter, kv_router::KvRouter,
...@@ -218,10 +217,12 @@ impl ModelManager { ...@@ -218,10 +217,12 @@ impl ModelManager {
.drt() .drt()
.etcd_client() .etcd_client()
.ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?; .ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?;
let router_uuid = uuid::Uuid::new_v4();
let router_key = format!( let router_key = format!(
"kv_routers/{}/{}", "{}/{}/{}",
Slug::from_string(model_name), KV_ROUTERS_ROOT_PATH,
uuid::Uuid::new_v4() component.path(),
router_uuid
); );
etcd_client etcd_client
.kv_create( .kv_create(
...@@ -237,6 +238,7 @@ impl ModelManager { ...@@ -237,6 +238,7 @@ impl ModelManager {
kv_cache_block_size, kv_cache_block_size,
Some(selector), Some(selector),
kv_router_config, kv_router_config,
router_uuid.to_string(),
) )
.await?; .await?;
let new_kv_chooser = Arc::new(chooser); let new_kv_chooser = Arc::new(chooser);
......
...@@ -15,6 +15,7 @@ use dynamo_runtime::{ ...@@ -15,6 +15,7 @@ use dynamo_runtime::{
}, },
prelude::*, prelude::*,
protocols::annotated::Annotated, protocols::annotated::Annotated,
utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction},
}; };
use futures::stream::{self, StreamExt}; use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -29,6 +30,7 @@ pub mod recorder; ...@@ -29,6 +30,7 @@ pub mod recorder;
pub mod scheduler; pub mod scheduler;
pub mod scoring; pub mod scoring;
pub mod sequence; pub mod sequence;
pub mod subscriber;
use crate::{ use crate::{
discovery::{MODEL_ROOT_PATH, ModelEntry}, discovery::{MODEL_ROOT_PATH, ModelEntry},
...@@ -41,14 +43,13 @@ use crate::{ ...@@ -41,14 +43,13 @@ use crate::{
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult}, protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest}, scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints, scoring::ProcessedEndpoints,
subscriber::start_kv_router_background,
}, },
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
}; };
use dynamo_runtime::traits::events::EventSubscriber;
// [gluo TODO] shouldn't need to be public // [gluo TODO] shouldn't need to be public
// this should be discovered from the component // this should be discovered from the component
...@@ -64,6 +65,12 @@ pub const KV_METRICS_SUBJECT: &str = "kv_metrics"; ...@@ -64,6 +65,12 @@ pub const KV_METRICS_SUBJECT: &str = "kv_metrics";
pub const PREFILL_SUBJECT: &str = "prefill_events"; pub const PREFILL_SUBJECT: &str = "prefill_events";
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
// for radix tree snapshot storage
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";
pub const ROUTER_SNAPSHOT_LOCK: &str = "router-snapshot-lock";
pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";
/// A trait that users can implement to define custom selection logic /// A trait that users can implement to define custom selection logic
pub trait WorkerSelector { pub trait WorkerSelector {
fn select_worker( fn select_worker(
...@@ -98,6 +105,12 @@ pub struct KvRouterConfig { ...@@ -98,6 +105,12 @@ pub struct KvRouterConfig {
// TODO: this is not actually used for now // TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting // Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32, pub max_num_batched_tokens: u32,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
pub router_snapshot_threshold: Option<u32>,
/// Whether to reset the router state on startup (default: true)
pub router_reset_states: bool,
} }
impl Default for KvRouterConfig { impl Default for KvRouterConfig {
...@@ -108,6 +121,8 @@ impl Default for KvRouterConfig { ...@@ -108,6 +121,8 @@ impl Default for KvRouterConfig {
use_kv_events: true, use_kv_events: true,
router_replica_sync: false, router_replica_sync: false,
max_num_batched_tokens: 8192, max_num_batched_tokens: 8192,
router_snapshot_threshold: Some(10000),
router_reset_states: true,
} }
} }
} }
...@@ -121,6 +136,8 @@ impl KvRouterConfig { ...@@ -121,6 +136,8 @@ impl KvRouterConfig {
use_kv_events: Option<bool>, use_kv_events: Option<bool>,
replica_sync: Option<bool>, replica_sync: Option<bool>,
max_num_batched_tokens: Option<u32>, max_num_batched_tokens: Option<u32>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
) -> Self { ) -> Self {
let default = Self::default(); let default = Self::default();
Self { Self {
...@@ -130,6 +147,9 @@ impl KvRouterConfig { ...@@ -130,6 +147,9 @@ impl KvRouterConfig {
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync), router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
max_num_batched_tokens: max_num_batched_tokens max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens), .unwrap_or(default.max_num_batched_tokens),
router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold),
router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
} }
} }
} }
...@@ -151,6 +171,13 @@ impl Indexer { ...@@ -151,6 +171,13 @@ impl Indexer {
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await, Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
} }
} }
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
}
}
} }
/// A KvRouter only decides which worker you should use. It doesn't send you there. /// A KvRouter only decides which worker you should use. It doesn't send you there.
...@@ -170,6 +197,7 @@ impl KvRouter { ...@@ -170,6 +197,7 @@ impl KvRouter {
block_size: u32, block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
consumer_uuid: String,
) -> Result<Self> { ) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default(); let kv_router_config = kv_router_config.unwrap_or_default();
...@@ -196,9 +224,6 @@ impl KvRouter { ...@@ -196,9 +224,6 @@ impl KvRouter {
.etcd_client() .etcd_client()
.expect("Cannot KV route without etcd client"); .expect("Cannot KV route without etcd client");
use dynamo_runtime::utils::typed_prefix_watcher::{
key_extractors, watch_prefix_with_extraction,
};
let runtime_configs_watcher = watch_prefix_with_extraction( let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client, etcd_client,
MODEL_ROOT_PATH, MODEL_ROOT_PATH,
...@@ -230,31 +255,20 @@ impl KvRouter { ...@@ -230,31 +255,20 @@ impl KvRouter {
) )
.await?; .await?;
// [gluo TODO] try subscribe_with_type::<RouterEvent>, // Start unified background process if using KvIndexer
// error checking below will be different.
if let Indexer::KvIndexer(ref kv_indexer) = indexer { if let Indexer::KvIndexer(ref kv_indexer) = indexer {
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?; start_kv_router_background(
let kv_events_tx = kv_indexer.event_sender(); component.clone(),
consumer_uuid,
tokio::spawn(async move { kv_indexer.event_sender(),
while let Some(event) = kv_events_rx.next().await { kv_router_config
let event: RouterEvent = match serde_json::from_slice(&event.payload) { .router_snapshot_threshold
Ok(event) => event, .map(|_| kv_indexer.snapshot_event_sender()),
Err(e) => { cancellation_token.clone(),
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e); kv_router_config.router_snapshot_threshold,
// Choosing warn and continue to process other events from other workers kv_router_config.router_reset_states,
// A bad event likely signals a problem with a worker, but potentially other workers are still healthy )
continue; .await?;
}
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::warn!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
}
}
});
} }
tracing::info!("KV Routing initialized"); tracing::info!("KV Routing initialized");
...@@ -318,6 +332,11 @@ impl KvRouter { ...@@ -318,6 +332,11 @@ impl KvRouter {
pub fn block_size(&self) -> u32 { pub fn block_size(&self) -> u32 {
self.block_size self.block_size
} }
/// Dump all events from the indexer
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
}
} }
// NOTE: this would not be usable for now, should deprecate // NOTE: this would not be usable for now, should deprecate
...@@ -351,6 +370,11 @@ impl KvPushRouter { ...@@ -351,6 +370,11 @@ impl KvPushRouter {
) -> Self { ) -> Self {
KvPushRouter { inner, chooser } KvPushRouter { inner, chooser }
} }
/// Dump all events from the KV router's indexer
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.chooser.dump_events().await
}
} }
#[async_trait] #[async_trait]
......
...@@ -25,10 +25,9 @@ use tokio_util::sync::CancellationToken; ...@@ -25,10 +25,9 @@ use tokio_util::sync::CancellationToken;
use crate::tokens::{SequenceHash, TokenBlockSequence}; use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::RouterEvent;
use crate::kv_router::indexer::{ use crate::kv_router::indexer::{
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, WorkerId, DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
compute_block_hash_for_seq, WorkerId, compute_block_hash_for_seq,
}; };
use crate::kv_router::protocols::{ use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData, ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
......
...@@ -760,6 +760,15 @@ impl KvIndexer { ...@@ -760,6 +760,15 @@ impl KvIndexer {
pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> { pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> {
self.event_tx.clone() self.event_tx.clone()
} }
/// Get a sender for dump requests (snapshot events).
///
/// ### Returns
///
/// A `mpsc::Sender` for `DumpRequest`s.
pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> {
self.dump_tx.clone()
}
} }
#[async_trait] #[async_trait]
......
...@@ -30,6 +30,7 @@ use dynamo_runtime::{ ...@@ -30,6 +30,7 @@ use dynamo_runtime::{
network::Ingress, network::Ingress,
}, },
protocols::annotated::Annotated, protocols::annotated::Annotated,
transports::nats::{NatsQueue, QUEUE_NAME, Slug},
}; };
use futures::stream; use futures::stream;
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
...@@ -133,16 +134,27 @@ impl KvEventPublisher { ...@@ -133,16 +134,27 @@ impl KvEventPublisher {
)?); )?);
} }
component let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
.drt() .to_string()
.runtime() .replace("_", "-");
.secondary() let nats_server =
.spawn(start_event_processor( std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());
component, // Create NatsQueue without consumer since we're only publishing
worker_id, let mut nats_queue = NatsQueue::new_without_consumer(
cancellation_token.clone(), stream_name,
rx, nats_server,
)); std::time::Duration::from_secs(60), // 1 minute timeout
);
// Connect the NatsQueue before passing it to the event processor
let cancellation_token_clone = cancellation_token.clone();
component.drt().runtime().secondary().spawn(async move {
if let Err(e) = nats_queue.connect().await {
tracing::error!("Failed to connect NatsQueue: {}", e);
return;
}
start_event_processor(nats_queue, worker_id, cancellation_token_clone, rx).await
});
Ok(Self { Ok(Self {
kv_block_size, kv_block_size,
...@@ -198,7 +210,7 @@ async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>( ...@@ -198,7 +210,7 @@ async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
// Encapsulate in a router event and publish. // Encapsulate in a router event and publish.
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data); tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
let router_event = RouterEvent::new(worker_id, event); let router_event = RouterEvent::new(worker_id, event);
if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await { if let Err(e) = publisher.publish(QUEUE_NAME, &router_event).await {
tracing::error!("Failed to publish event: {}", e); tracing::error!("Failed to publish event: {}", e);
} }
} }
...@@ -929,7 +941,7 @@ mod tests_startup_helpers { ...@@ -929,7 +941,7 @@ mod tests_startup_helpers {
let published = published.lock().unwrap(); let published = published.lock().unwrap();
assert_eq!(published.len(), 1); assert_eq!(published.len(), 1);
let (subject, _) = &published[0]; let (subject, _) = &published[0];
assert_eq!(subject, &KV_EVENT_SUBJECT.to_string()); assert_eq!(subject, QUEUE_NAME);
} }
//-------------------------------------------------------------------- //--------------------------------------------------------------------
......
...@@ -271,7 +271,8 @@ impl ActiveSequencesMultiWorker { ...@@ -271,7 +271,8 @@ impl ActiveSequencesMultiWorker {
let component_clone = component.clone(); let component_clone = component.clone();
let router_id_clone = router_id; let router_id_clone = router_id;
tokio::spawn(async move { component.drt().runtime().secondary().spawn(async move {
// NATS subscription loop
if let Err(e) = Self::subscribe_to_events( if let Err(e) = Self::subscribe_to_events(
senders_clone, senders_clone,
request_to_worker_clone, request_to_worker_clone,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Background processes for the KV Router including event consumption and snapshot uploads.
use std::time::Duration;
use anyhow::Result;
use dynamo_runtime::{
component::Component,
prelude::*,
traits::events::EventPublisher,
transports::{
etcd::WatchEvent,
nats::{NatsQueue, Slug},
},
};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::{
discovery::KV_ROUTERS_ROOT_PATH,
kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK,
ROUTER_SNAPSHOT_LOCK,
indexer::{DumpRequest, RouterEvent},
},
};
/// Resources required for snapshot operations.
///
/// Bundles the clients and names that the snapshot path needs so they can be
/// moved into the background task together: the NATS client and bucket used to
/// upload the radix snapshot, and the etcd client and lock name used to
/// serialize snapshot/purge across instances.
#[derive(Clone)]
struct SnapshotResources {
// NATS client used to build the object-store URL (via `addr()`) and to upload the snapshot.
nats_client: dynamo_runtime::transports::nats::Client,
// Name of the object-store bucket the snapshot file is written to.
bucket_name: String,
// etcd client used to take and release the snapshot lock; its lease id is attached to the lock.
etcd_client: dynamo_runtime::transports::etcd::Client,
// etcd key passed to `lock` when acquiring the snapshot lock.
lock_name: String,
}
impl SnapshotResources {
    /// Request the snapshot lock from etcd, bound to this client's lease.
    ///
    /// Returns `Some(lock_response)` when the lock call succeeds; the response
    /// carries the key that must later be handed to [`Self::unlock`]. Returns
    /// `None` when the lock call returns an error, which is logged at debug
    /// level as another instance holding the lock.
    ///
    /// NOTE(review): etcd's Lock RPC normally waits for the lock instead of
    /// failing fast — confirm whether a timeout applies to this call.
    async fn lock(&self) -> Option<etcd_client::LockResponse> {
        let lease_id = self.etcd_client.lease_id();
        let attempt = self
            .etcd_client
            .lock(self.lock_name.clone(), Some(lease_id))
            .await;

        // Guard clause: any error is reported as "someone else has it".
        let response = match attempt {
            Ok(response) => response,
            Err(e) => {
                tracing::debug!("Another instance already holds the snapshot lock: {e:?}");
                return None;
            }
        };

        tracing::debug!(
            "Successfully acquired snapshot lock with key: {:?}",
            response.key()
        );
        Some(response)
    }

    /// Release a lock previously returned by [`Self::lock`].
    ///
    /// Failures are logged as warnings and otherwise ignored, so callers can
    /// treat this as best-effort cleanup.
    async fn unlock(&self, lock_response: etcd_client::LockResponse) {
        let released = self.etcd_client.unlock(lock_response.key()).await;
        if let Err(e) = released {
            tracing::warn!("Failed to release snapshot lock: {e:?}");
        }
    }
}
/// Start a unified background task for event consumption and optional snapshot management.
///
/// Setup (performed before this function returns):
/// 1. Connects a `NatsQueue` with a durable consumer named `consumer_uuid` to the
///    stream derived from the component subject and `KV_EVENT_SUBJECT`. The NATS
///    server address comes from `NATS_SERVER` (default `nats://localhost:4222`).
/// 2. If `router_reset_states` is true, deletes the snapshot bucket; otherwise tries
///    to download the stored snapshot and replays its events into `kv_events_tx`.
/// 3. Starts an etcd watch on the router registry prefix.
///
/// The spawned task then multiplexes, in priority order:
/// - cancellation of `cancellation_token` (exits the loop),
/// - events dequeued from NATS, forwarded to `kv_events_tx`,
/// - a one-second tick that snapshots and purges the stream once it holds more than
///   `router_snapshot_threshold` messages (only when both `snapshot_tx` and the
///   threshold are provided),
/// - router deletion events, which delete the durable consumer named after the
///   deleted router.
///
/// On every exit path the task shuts the queue down exactly once, which removes this
/// router's durable consumer.
///
/// # Errors
///
/// Returns an error if the NATS queue or client cannot connect, the snapshot URL
/// cannot be parsed, the etcd client is unavailable, or the etcd watch cannot be
/// established. Failures inside the spawned task are logged, not returned.
pub async fn start_kv_router_background(
    component: Component,
    consumer_uuid: String,
    kv_events_tx: mpsc::Sender<RouterEvent>,
    snapshot_tx: Option<mpsc::Sender<DumpRequest>>,
    cancellation_token: CancellationToken,
    router_snapshot_threshold: Option<u32>,
    router_reset_states: bool,
) -> Result<()> {
    // Set up NATS connections
    let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
        .to_string()
        .replace("_", "-");
    let nats_server =
        std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());

    // Create NatsQueue for event consumption
    let mut nats_queue = NatsQueue::new_with_consumer(
        stream_name.clone(),
        nats_server.clone(),
        std::time::Duration::from_secs(60), // 1 minute timeout
        consumer_uuid,
    );
    nats_queue.connect_with_reset(router_reset_states).await?;

    // Always create NATS client (needed for both reset and snapshots)
    let client_options = dynamo_runtime::transports::nats::Client::builder()
        .server(&nats_server)
        .build()?;
    let nats_client = client_options.connect().await?;

    // Create bucket name for snapshots/state
    let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", component.subject()))
        .to_string()
        .replace("_", "-");

    // Handle initial state based on router_reset_states flag
    if router_reset_states {
        // Delete the bucket to reset state
        tracing::info!("Resetting router state, deleting bucket: {bucket_name}");
        if let Err(e) = nats_client.object_store_delete_bucket(&bucket_name).await {
            tracing::warn!("Failed to delete bucket (may not exist): {e:?}");
        }
    } else {
        // Try to download initial state from object store
        let url = url::Url::parse(&format!(
            "nats://{}/{bucket_name}/{RADIX_STATE_FILE}",
            nats_client.addr()
        ))?;
        match nats_client
            .object_store_download_data::<Vec<RouterEvent>>(url)
            .await
        {
            Ok(events) => {
                let total = events.len();
                tracing::info!(
                    "Successfully downloaded {} events from object store",
                    total
                );

                // Replay the snapshot into the indexer. A failed send means the
                // receiving side is gone, so every later send would fail too:
                // stop at the first failure instead of warning once per event.
                let mut replayed = 0usize;
                for event in events {
                    if let Err(e) = kv_events_tx.send(event).await {
                        tracing::warn!("Failed to send initial event to indexer: {e:?}");
                        break;
                    }
                    replayed += 1;
                }

                if replayed == total {
                    tracing::info!("Successfully sent all initial events to indexer");
                } else {
                    tracing::warn!(
                        "Indexer channel closed after replaying {replayed} of {total} initial events"
                    );
                }
            }
            Err(e) => {
                tracing::info!(
                    "Did not initialize radix state from NATs object store (likely no snapshots yet): {e:?}"
                );
            }
        }
    }

    // Get etcd client (needed for both snapshots and router watching)
    let etcd_client = component
        .drt()
        .etcd_client()
        .ok_or_else(|| anyhow::anyhow!("etcd client not available"))?;

    // Watch for router deletions to clean up orphaned consumers
    let (_prefix_str, _watcher, mut router_replicas_rx) = etcd_client
        .kv_get_and_watch_prefix(&format!("{}/", KV_ROUTERS_ROOT_PATH))
        .await?
        .dissolve();

    let cleanup_lock_name = format!("{}/{}", ROUTER_CLEANUP_LOCK, component.subject());

    // Only set up snapshot-related resources if snapshot_tx is provided and threshold is set
    let snapshot_resources = if snapshot_tx.is_some() && router_snapshot_threshold.is_some() {
        let lock_name = format!("{}/{}", ROUTER_SNAPSHOT_LOCK, component.subject());
        Some(SnapshotResources {
            nats_client,
            bucket_name,
            etcd_client: etcd_client.clone(),
            lock_name,
        })
    } else {
        None
    };

    component.drt().runtime().secondary().spawn(async move {
        let mut check_interval = tokio::time::interval(Duration::from_secs(1));
        check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            tokio::select! {
                // `biased` polls the arms in source order, so cancellation wins
                // over any pending work.
                biased;

                _ = cancellation_token.cancelled() => {
                    tracing::debug!("KV Router background task received cancellation signal");
                    // The queue is shut down once, after the loop; doing it here
                    // as well would run shutdown twice on the same queue.
                    break;
                }

                // Handle event consumption
                result = nats_queue.dequeue_task(None) => {
                    match result {
                        Ok(Some(bytes)) => {
                            let event: RouterEvent = match serde_json::from_slice(&bytes) {
                                Ok(event) => event,
                                Err(e) => {
                                    tracing::warn!("Failed to deserialize RouterEvent: {e:?}");
                                    continue;
                                }
                            };

                            // Forward the RouterEvent to the indexer
                            if let Err(e) = kv_events_tx.send(event).await {
                                tracing::warn!(
                                    "failed to send kv event to indexer; shutting down: {e:?}"
                                );
                                break;
                            }
                        },
                        Ok(None) => {
                            tracing::trace!("Dequeue timeout, continuing");
                        },
                        Err(e) => {
                            tracing::error!("Failed to dequeue task: {e:?}");
                            // Brief pause so a persistent dequeue error does not spin the loop.
                            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                        }
                    }
                }

                // Handle periodic stream checking and purging (only if snapshot_tx is provided)
                _ = check_interval.tick() => {
                    let Some((snapshot_tx, resources)) = snapshot_tx.as_ref().zip(snapshot_resources.as_ref()) else {
                        continue;
                    };

                    // Check total messages in the stream
                    let Ok(message_count) = nats_queue.get_stream_messages().await else {
                        tracing::warn!("Failed to get stream message count");
                        continue;
                    };

                    // Guard clause: skip if message count is too low
                    let threshold = router_snapshot_threshold.unwrap_or(u32::MAX) as u64;
                    if message_count <= threshold {
                        continue;
                    }

                    tracing::info!("Stream has {message_count} messages, attempting to acquire lock for purge and snapshot");

                    // Try to acquire distributed lock
                    let Some(lock_response) = resources.lock().await else {
                        continue;
                    };

                    // Perform snapshot upload and purge
                    match perform_snapshot_and_purge(
                        &mut nats_queue,
                        snapshot_tx,
                        resources
                    ).await {
                        Ok(_) => tracing::info!("Successfully performed purge and snapshot"),
                        Err(e) => tracing::error!("Failed to perform purge and snapshot: {e:?}"),
                    }

                    // Release the lock
                    resources.unlock(lock_response).await;
                }

                // Handle router deletion events
                Some(event) = router_replicas_rx.recv() => {
                    let WatchEvent::Delete(kv) = event else {
                        // We only care about deletions for cleaning up consumers
                        continue;
                    };

                    let key = String::from_utf8_lossy(kv.key());
                    tracing::info!("Router deleted: {}", key);

                    // Extract the router UUID from the key (format: kv_routers/<model>/<uuid>).
                    // `split` always yields at least one segment, so an empty last
                    // segment (e.g. a key ending in '/') is what marks a missing UUID.
                    let Some(router_uuid) = key
                        .split('/')
                        .next_back()
                        .filter(|uuid| !uuid.is_empty())
                    else {
                        tracing::warn!("Could not extract UUID from router key: {}", key);
                        continue;
                    };

                    // The consumer UUID is the router UUID
                    let consumer_to_delete = router_uuid.to_string();

                    tracing::info!("Attempting to delete orphaned consumer: {}", consumer_to_delete);

                    // Try to acquire cleanup lock before deleting consumer
                    match etcd_client
                        .lock(cleanup_lock_name.clone(), Some(etcd_client.lease_id()))
                        .await
                    {
                        Ok(lock_response) => {
                            tracing::debug!(
                                "Acquired cleanup lock for deleting consumer: {}",
                                consumer_to_delete
                            );

                            // Delete the consumer
                            if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await {
                                tracing::warn!("Failed to delete consumer {}: {}", consumer_to_delete, e);
                            } else {
                                tracing::info!("Successfully deleted orphaned consumer: {}", consumer_to_delete);
                            }

                            // Release the lock
                            if let Err(e) = etcd_client.unlock(lock_response.key()).await {
                                tracing::warn!("Failed to release cleanup lock: {e:?}");
                            }
                        }
                        Err(e) => {
                            tracing::debug!(
                                "Could not acquire cleanup lock for consumer {}: {e:?}",
                                consumer_to_delete
                            );
                        }
                    }
                }
            }
        }

        // Single exit point: clean up the queue and remove the durable consumer,
        // whether the loop ended by cancellation or by a closed indexer channel.
        if let Err(e) = nats_queue.shutdown(None).await {
            tracing::warn!("Failed to shutdown NatsQueue: {e}");
        }
    });

    Ok(())
}
/// Dump the indexer state, upload it as a snapshot, then purge the stream.
///
/// The order matters: the dump is requested and uploaded first, and only then
/// are acknowledged messages purged, so the stored snapshot is taken before any
/// message is removed from the stream.
///
/// # Errors
///
/// Returns an error if the dump request cannot be sent, the dump reply channel
/// is dropped, the snapshot URL cannot be parsed, the upload fails, or the purge
/// fails. Any error before the purge step leaves the stream untouched.
async fn perform_snapshot_and_purge(
    nats_queue: &mut NatsQueue,
    snapshot_tx: &mpsc::Sender<DumpRequest>,
    resources: &SnapshotResources,
) -> anyhow::Result<()> {
    // Ask the indexer for a dump; the reply comes back on a one-shot channel.
    let (reply_tx, reply_rx) = oneshot::channel();
    if let Err(e) = snapshot_tx.send(DumpRequest { resp: reply_tx }).await {
        return Err(anyhow::anyhow!("Failed to send dump request: {e:?}"));
    }

    let events = match reply_rx.await {
        Ok(dumped) => dumped,
        Err(e) => {
            return Err(anyhow::anyhow!("Failed to receive dump response: {e:?}"));
        }
    };

    // Upload the snapshot to the NATS object store.
    let client = &resources.nats_client;
    let bucket = &resources.bucket_name;
    let raw_url = format!(
        "nats://{}/{}/{RADIX_STATE_FILE}",
        client.addr(),
        bucket
    );
    let url = url::Url::parse(&raw_url)?;

    if let Err(e) = client.object_store_upload_data(&events, url).await {
        return Err(anyhow::anyhow!("Failed to upload snapshot: {e:?}"));
    }

    tracing::info!(
        "Successfully uploaded radix tree snapshot with {} events to bucket {}",
        events.len(),
        resources.bucket_name
    );

    // Only after a successful upload: drop acknowledged messages from the stream.
    nats_queue.purge_acknowledged().await?;

    Ok(())
}
...@@ -180,12 +180,12 @@ impl MockVllmEngine { ...@@ -180,12 +180,12 @@ impl MockVllmEngine {
component: Option<Component>, component: Option<Component>,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
tracing::info!("Creating metrics publisher"); tracing::debug!("Creating metrics publisher");
let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?); let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
tracing::info!("Metrics publisher created"); tracing::debug!("Metrics publisher created");
if let Some(comp) = component { if let Some(comp) = component {
tracing::info!("Creating metrics endpoint"); tracing::debug!("Creating metrics endpoint");
tokio::spawn({ tokio::spawn({
let publisher = metrics_publisher.clone(); let publisher = metrics_publisher.clone();
async move { async move {
...@@ -197,10 +197,10 @@ impl MockVllmEngine { ...@@ -197,10 +197,10 @@ impl MockVllmEngine {
// Give it a moment to start // Give it a moment to start
tokio::time::sleep(Duration::from_millis(100)).await; tokio::time::sleep(Duration::from_millis(100)).await;
tracing::info!("Metrics endpoint started (background)"); tracing::debug!("Metrics endpoint started (background)");
} }
tracing::info!("Starting metrics background tasks"); tracing::debug!("Starting metrics background tasks");
for (dp_rank, scheduler) in schedulers.iter().enumerate() { for (dp_rank, scheduler) in schedulers.iter().enumerate() {
let mut metrics_rx = scheduler.metrics_receiver(); let mut metrics_rx = scheduler.metrics_receiver();
let publisher = metrics_publisher.clone(); let publisher = metrics_publisher.clone();
...@@ -223,7 +223,7 @@ impl MockVllmEngine { ...@@ -223,7 +223,7 @@ impl MockVllmEngine {
} }
} }
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
tracing::info!("Metrics publishing cancelled for DP rank {dp_rank}"); tracing::debug!("Metrics publishing cancelled for DP rank {dp_rank}");
break; break;
} }
} }
...@@ -241,14 +241,14 @@ impl MockVllmEngine { ...@@ -241,14 +241,14 @@ impl MockVllmEngine {
block_size: usize, block_size: usize,
cancel_token: CancellationToken, cancel_token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
tracing::info!("Starting KV events publishing"); tracing::debug!("Starting KV events publishing");
// Only start KV events publishing if we have a component // Only start KV events publishing if we have a component
let Some(comp) = component else { let Some(comp) = component else {
tracing::warn!("No component provided, skipping KV events publishing"); tracing::warn!("No component provided, skipping KV events publishing");
return Ok(()); return Ok(());
}; };
tracing::info!("Component found for KV events publishing"); tracing::debug!("Component found for KV events publishing");
tracing::debug!("Getting worker_id"); tracing::debug!("Getting worker_id");
let worker_id = comp let worker_id = comp
...@@ -259,16 +259,16 @@ impl MockVllmEngine { ...@@ -259,16 +259,16 @@ impl MockVllmEngine {
// let worker_id = 0; // let worker_id = 0;
tracing::debug!("Worker_id set to: {worker_id}"); tracing::debug!("Worker_id set to: {worker_id}");
tracing::info!("Creating KV event publisher"); tracing::debug!("Creating KV event publisher");
let kv_event_publisher = Arc::new(KvEventPublisher::new( let kv_event_publisher = Arc::new(KvEventPublisher::new(
comp.clone(), comp.clone(),
worker_id, worker_id,
block_size as u32, block_size as u32,
None, None,
)?); )?);
tracing::info!("KV event publisher created"); tracing::debug!("KV event publisher created");
tracing::info!( tracing::debug!(
"Starting KV event background tasks for {} receivers", "Starting KV event background tasks for {} receivers",
kv_event_receivers.len() kv_event_receivers.len()
); );
...@@ -298,7 +298,7 @@ impl MockVllmEngine { ...@@ -298,7 +298,7 @@ impl MockVllmEngine {
} }
} }
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
tracing::info!("KV events publishing cancelled for DP rank {dp_rank}"); tracing::debug!("KV events publishing cancelled for DP rank {dp_rank}");
break; break;
} }
} }
...@@ -476,7 +476,7 @@ impl AnnotatedMockEngine { ...@@ -476,7 +476,7 @@ impl AnnotatedMockEngine {
continue; continue;
} }
tracing::info!("Component service is now available, starting mocker engine"); tracing::debug!("Component service is now available, starting mocker engine");
// Start the engine with the component // Start the engine with the component
if let Err(e) = inner_clone.start(component).await { if let Err(e) = inner_clone.start(component).await {
...@@ -515,7 +515,7 @@ pub async fn make_mocker_engine( ...@@ -515,7 +515,7 @@ pub async fn make_mocker_engine(
args: MockEngineArgs, args: MockEngineArgs,
) -> Result<crate::backend::ExecutionContext, Error> { ) -> Result<crate::backend::ExecutionContext, Error> {
// Create the mocker engine // Create the mocker engine
tracing::info!("Creating mocker engine with config: {args:?}"); tracing::debug!("Creating mocker engine with config: {args:?}");
let annotated_engine = let annotated_engine =
AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id); AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id);
......
...@@ -25,8 +25,9 @@ use tokio::sync::{RwLock, mpsc}; ...@@ -25,8 +25,9 @@ use tokio::sync::{RwLock, mpsc};
use validator::Validate; use validator::Validate;
use etcd_client::{ use etcd_client::{
Certificate, Compare, CompareOp, DeleteOptions, GetOptions, Identity, PutOptions, PutResponse, Certificate, Compare, CompareOp, DeleteOptions, GetOptions, Identity, LockClient, LockOptions,
TlsOptions, Txn, TxnOp, TxnOpResponse, WatchOptions, Watcher, LockResponse, PutOptions, PutResponse, TlsOptions, Txn, TxnOp, TxnOpResponse, WatchOptions,
Watcher,
}; };
pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient}; pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
use tokio::time::{Duration, interval}; use tokio::time::{Duration, interval};
...@@ -306,6 +307,32 @@ impl Client { ...@@ -306,6 +307,32 @@ impl Client {
Ok(get_response.take_kvs()) Ok(get_response.take_kvs())
} }
/// Acquire a distributed lock using etcd's native lock mechanism
/// Returns a LockResponse that can be used to unlock later
pub async fn lock(
&self,
key: impl Into<Vec<u8>>,
lease_id: Option<i64>,
) -> Result<LockResponse> {
let mut lock_client = self.client.lock_client();
let id = lease_id.unwrap_or(self.lease_id());
let options = LockOptions::new().with_lease(id);
lock_client
.lock(key, Some(options))
.await
.map_err(|err| err.into())
}
/// Release a distributed lock using the key from the LockResponse
pub async fn unlock(&self, lock_key: impl Into<Vec<u8>>) -> Result<()> {
let mut lock_client = self.client.lock_client();
lock_client
.unlock(lock_key)
.await
.map_err(|err: etcd_client::Error| anyhow::anyhow!(err))?;
Ok(())
}
pub async fn kv_get_and_watch_prefix( pub async fn kv_get_and_watch_prefix(
&self, &self,
prefix: impl AsRef<str> + std::fmt::Display, prefix: impl AsRef<str> + std::fmt::Display,
......
...@@ -28,10 +28,12 @@ ...@@ -28,10 +28,12 @@
//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file //! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
//! //!
//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together. //! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together.
use crate::traits::events::EventPublisher;
use crate::{Result, metrics::MetricsRegistry}; use crate::{Result, metrics::MetricsRegistry};
use async_nats::connection::State; use async_nats::connection::State;
use async_nats::{Subscriber, client, jetstream}; use async_nats::{Subscriber, client, jetstream};
use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use derive_builder::Builder; use derive_builder::Builder;
use futures::{StreamExt, TryStreamExt}; use futures::{StreamExt, TryStreamExt};
...@@ -429,6 +431,9 @@ pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> { ...@@ -429,6 +431,9 @@ pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
Ok((bucket.to_string(), key.to_string())) Ok((bucket.to_string(), key.to_string()))
} }
/// Default queue name for publishing events
pub const QUEUE_NAME: &str = "queue";
/// A queue implementation using NATS JetStream /// A queue implementation using NATS JetStream
pub struct NatsQueue { pub struct NatsQueue {
/// The name of the stream to use for the queue /// The name of the stream to use for the queue
...@@ -448,12 +453,32 @@ pub struct NatsQueue { ...@@ -448,12 +453,32 @@ pub struct NatsQueue {
} }
impl NatsQueue { impl NatsQueue {
/// Create a new NatsQueue with the given configuration /// Create a new NatsQueue with the default "worker-group" consumer
pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self { pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
// Sanitize stream name to remove path separators (like in Python version) // Sanitize stream name to remove path separators (like in Python version)
let sanitized_stream_name = stream_name.replace(['/', '\\'], "_"); // rupei: are we sure NATs stream name accepts '_'?
let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
let subject = format!("{sanitized_stream_name}.*");
let subject = format!("{}.*", sanitized_stream_name); Self {
stream_name: sanitized_stream_name,
nats_server,
dequeue_timeout,
client: None,
subject,
subscriber: None,
consumer_name: Some("worker-group".to_string()),
}
}
/// Create a new NatsQueue without a consumer (publisher-only mode)
pub fn new_without_consumer(
stream_name: String,
nats_server: String,
dequeue_timeout: time::Duration,
) -> Self {
let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
let subject = format!("{sanitized_stream_name}.*");
Self { Self {
stream_name: sanitized_stream_name, stream_name: sanitized_stream_name,
...@@ -474,8 +499,8 @@ impl NatsQueue { ...@@ -474,8 +499,8 @@ impl NatsQueue {
dequeue_timeout: time::Duration, dequeue_timeout: time::Duration,
consumer_name: String, consumer_name: String,
) -> Self { ) -> Self {
let sanitized_stream_name = stream_name.replace(['/', '\\'], "_"); let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
let subject = format!("{}.*", sanitized_stream_name); let subject = format!("{sanitized_stream_name}.*");
Self { Self {
stream_name: sanitized_stream_name, stream_name: sanitized_stream_name,
...@@ -490,39 +515,71 @@ impl NatsQueue { ...@@ -490,39 +515,71 @@ impl NatsQueue {
/// Connect to the NATS server and set up the stream and consumer /// Connect to the NATS server and set up the stream and consumer
pub async fn connect(&mut self) -> Result<()> { pub async fn connect(&mut self) -> Result<()> {
self.connect_with_reset(false).await
}
/// Connect to the NATS server and set up the stream and consumer, optionally resetting the stream
pub async fn connect_with_reset(&mut self, reset_stream: bool) -> Result<()> {
if self.client.is_none() { if self.client.is_none() {
// Create a new client // Create a new client
let client_options = Client::builder().server(self.nats_server.clone()).build()?; let client_options = Client::builder().server(self.nats_server.clone()).build()?;
let client = client_options.connect().await?; let client = client_options.connect().await?;
// Check if stream exists, if not create it // If reset_stream is true, delete the stream first
let streams = client.list_streams().await?; if reset_stream {
if !streams.contains(&self.stream_name) { match client.jetstream().delete_stream(&self.stream_name).await {
log::debug!("Creating NATS stream {}", self.stream_name); Ok(_) => {
log::debug!(
"Successfully deleted NATS stream {} for reset",
self.stream_name
);
}
Err(e) => {
log::debug!(
"Failed to delete NATS stream '{}' (may not exist): {}",
self.stream_name,
e
);
}
}
}
// Always try to create the stream (removes the race condition)
let stream_config = jetstream::stream::Config { let stream_config = jetstream::stream::Config {
name: self.stream_name.clone(), name: self.stream_name.clone(),
subjects: vec![self.subject.clone()], subjects: vec![self.subject.clone()],
max_age: time::Duration::from_secs(60 * 10), // 10 min // messages older than a hour in the stream will be automatically purged
max_age: time::Duration::from_secs(60 * 60),
..Default::default() ..Default::default()
}; };
client.jetstream().create_stream(stream_config).await?;
match client.jetstream().create_stream(stream_config).await {
Ok(_) => {
log::debug!("Successfully created NATS stream {}", self.stream_name);
}
Err(e) => {
// Log warning but continue - stream likely already exists
log::warn!(
"Failed to create NATS stream '{}': {}. Stream likely already exists, continuing...",
self.stream_name,
e
);
}
} }
// Create persistent subscriber // Create persistent subscriber only if consumer_name is set
if let Some(ref consumer_name) = self.consumer_name {
let consumer_config = jetstream::consumer::pull::Config { let consumer_config = jetstream::consumer::pull::Config {
durable_name: Some( durable_name: Some(consumer_name.clone()),
self.consumer_name
.clone()
.unwrap_or_else(|| "worker-group".to_string()),
),
..Default::default() ..Default::default()
}; };
let stream = client.jetstream().get_stream(&self.stream_name).await?; let stream = client.jetstream().get_stream(&self.stream_name).await?;
let subscriber = stream.create_consumer(consumer_config).await?; let subscriber = stream.create_consumer(consumer_config).await?;
self.subscriber = Some(subscriber); self.subscriber = Some(subscriber);
}
self.client = Some(client); self.client = Some(client);
} }
...@@ -546,28 +603,52 @@ impl NatsQueue { ...@@ -546,28 +603,52 @@ impl NatsQueue {
/// Shutdown the consumer by deleting it from the stream and closing the connection /// Shutdown the consumer by deleting it from the stream and closing the connection
/// This permanently removes the consumer from the server /// This permanently removes the consumer from the server
pub async fn shutdown(&mut self) -> Result<()> { ///
if let (Some(client), Some(consumer_name)) = (&self.client, &self.consumer_name) { /// If `consumer_name` is provided, that specific consumer will be deleted instead of the
/// current consumer. This allows deletion of other consumers on the same stream.
pub async fn shutdown(&mut self, consumer_name: Option<String>) -> Result<()> {
// Determine which consumer to delete
let target_consumer = consumer_name.as_ref().or(self.consumer_name.as_ref());
// Warn if deleting our own consumer via explicit parameter
if let Some(ref passed_name) = consumer_name
&& self.consumer_name.as_ref() == Some(passed_name)
{
log::warn!(
"Deleting our own consumer '{}' via explicit consumer_name parameter. \
Consider calling shutdown without arguments instead.",
passed_name
);
}
if let (Some(client), Some(consumer_to_delete)) = (&self.client, target_consumer) {
// Get the stream and delete the consumer // Get the stream and delete the consumer
let stream = client.jetstream().get_stream(&self.stream_name).await?; let stream = client.jetstream().get_stream(&self.stream_name).await?;
stream.delete_consumer(consumer_name).await.map_err(|e| { stream
anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_name, e) .delete_consumer(consumer_to_delete)
.await
.map_err(|e| {
anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_to_delete, e)
})?; })?;
log::debug!( log::debug!(
"Deleted consumer {} from stream {}", "Deleted consumer {} from stream {}",
consumer_name, consumer_to_delete,
self.stream_name self.stream_name
); );
} else { } else {
log::warn!( log::debug!(
"Cannot shutdown consumer: client or consumer_name is None (client: {:?}, consumer_name: {:?})", "Cannot shutdown consumer: client or target consumer is None (client: {:?}, target_consumer: {:?})",
self.client.is_some(), self.client.is_some(),
self.consumer_name.is_some() target_consumer.is_some()
); );
} }
// Then close the connection // Only close the connection if we deleted our own consumer
if consumer_name.is_none() {
self.close().await self.close().await
} else {
Ok(())
}
} }
/// Count the number of consumers for the stream /// Count the number of consumers for the stream
...@@ -648,6 +729,19 @@ impl NatsQueue { ...@@ -648,6 +729,19 @@ impl NatsQueue {
} }
} }
/// Get the total number of messages currently in the stream
pub async fn get_stream_messages(&mut self) -> Result<u64> {
self.ensure_connection().await?;
if let Some(client) = &self.client {
let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
let info = stream.info().await?;
Ok(info.state.messages)
} else {
Err(anyhow::anyhow!("Client not connected"))
}
}
/// Purge messages from the stream up to (but not including) the specified sequence number /// Purge messages from the stream up to (but not including) the specified sequence number
/// This permanently removes messages and affects all consumers of the stream /// This permanently removes messages and affects all consumers of the stream
pub async fn purge_up_to_sequence(&self, sequence: u64) -> Result<()> { pub async fn purge_up_to_sequence(&self, sequence: u64) -> Result<()> {
...@@ -727,7 +821,7 @@ impl NatsQueue { ...@@ -727,7 +821,7 @@ impl NatsQueue {
self.purge_up_to_sequence(purge_sequence).await?; self.purge_up_to_sequence(purge_sequence).await?;
log::info!( log::debug!(
"Purged stream {} up to acknowledged sequence {} (purged up to sequence {})", "Purged stream {} up to acknowledged sequence {} (purged up to sequence {})",
self.stream_name, self.stream_name,
min_ack_sequence, min_ack_sequence,
...@@ -745,6 +839,49 @@ impl NatsQueue { ...@@ -745,6 +839,49 @@ impl NatsQueue {
} }
} }
#[async_trait]
impl EventPublisher for NatsQueue {
    /// Base subject for published events: a copy of the stream name.
    fn subject(&self) -> String {
        self.stream_name.clone()
    }

    /// Serialize `event` to JSON and hand the bytes to `publish_bytes`.
    ///
    /// # Errors
    /// Returns an error if JSON serialization fails or if `publish_bytes` fails.
    async fn publish(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
        event: &(impl Serialize + Send + Sync),
    ) -> Result<()> {
        let payload = serde_json::to_vec(event)?;
        self.publish_bytes(event_name, payload).await
    }

    /// Publish raw bytes on the subject `"<stream_name>.<event_name>"`.
    ///
    /// A warning is logged (but publishing still proceeds) when `event_name`
    /// differs from `QUEUE_NAME`.
    ///
    /// # Errors
    /// Returns an error if no client is set ("Client not connected") or if the
    /// JetStream publish call fails.
    async fn publish_bytes(
        &self,
        event_name: impl AsRef<str> + Send + Sync,
        bytes: Vec<u8>,
    ) -> Result<()> {
        let name = event_name.as_ref();

        // We expect the stream to be always suffixed with "queue"
        // This suffix itself is nothing special, just a repo standard
        if name != QUEUE_NAME {
            tracing::warn!(
                "Expected event_name to be '{}', but got '{}'",
                QUEUE_NAME,
                name
            );
        }
        let subject = format!("{}.{}", self.subject(), name);

        // Note: enqueue_task requires &mut self, but EventPublisher requires &self,
        // so this path cannot (re)connect on demand; it uses the existing client
        // directly and fails when there is none.
        let client = self
            .client
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Client not connected"))?;
        client.jetstream().publish(subject, bytes.into()).await?;
        Ok(())
    }
}
/// Prometheus metrics that mirror the NATS client statistics (in primitive types) /// Prometheus metrics that mirror the NATS client statistics (in primitive types)
/// to be used for the System Status Server. /// to be used for the System Status Server.
/// ///
...@@ -966,6 +1103,20 @@ mod tests { ...@@ -966,6 +1103,20 @@ mod tests {
let nats_server = "nats://localhost:4222".to_string(); let nats_server = "nats://localhost:4222".to_string();
let timeout = time::Duration::from_secs(0); let timeout = time::Duration::from_secs(0);
// Connect to NATS client first to delete stream if it exists
let client_options = Client::builder()
.server(nats_server.clone())
.build()
.expect("Failed to build client options");
let client = client_options
.connect()
.await
.expect("Failed to connect to NATS");
// Delete the stream if it exists (to ensure clean start)
let _ = client.jetstream().delete_stream(&stream_name).await;
// Create two consumers with different names for the same stream // Create two consumers with different names for the same stream
let consumer1_name = format!("consumer-{}", Uuid::new_v4()); let consumer1_name = format!("consumer-{}", Uuid::new_v4());
let consumer2_name = format!("consumer-{}", Uuid::new_v4()); let consumer2_name = format!("consumer-{}", Uuid::new_v4());
...@@ -977,46 +1128,49 @@ mod tests { ...@@ -977,46 +1128,49 @@ mod tests {
consumer1_name, consumer1_name,
); );
let mut queue2 = NatsQueue::new_with_consumer( // Connect queue1 first (it will create the stream)
stream_name.clone(),
nats_server.clone(),
timeout,
consumer2_name,
);
// Connect both queues (first one creates the stream, second one reuses it)
queue1.connect().await.expect("Failed to connect queue1"); queue1.connect().await.expect("Failed to connect queue1");
queue2.connect().await.expect("Failed to connect queue2");
// Send 4 messages // Send 4 messages using the EventPublisher trait
let messages = vec![ let message_strings = [
Bytes::from("message1"), "message1".to_string(),
Bytes::from("message2"), "message2".to_string(),
Bytes::from("message3"), "message3".to_string(),
Bytes::from("message4"), "message4".to_string(),
]; ];
for msg in &messages { // Using the EventPublisher trait to publish messages
for (idx, msg) in message_strings.iter().enumerate() {
queue1 queue1
.enqueue_task(msg.clone()) .publish("queue", msg)
.await .await
.expect("Failed to enqueue message"); .unwrap_or_else(|_| panic!("Failed to publish message {}", idx + 1));
} }
// Convert messages to JSON-serialized Bytes for comparison
let messages: Vec<Bytes> = message_strings
.iter()
.map(|s| Bytes::from(serde_json::to_vec(s).unwrap()))
.collect();
// Give JetStream a moment to persist the messages // Give JetStream a moment to persist the messages
tokio::time::sleep(time::Duration::from_millis(100)).await; tokio::time::sleep(time::Duration::from_millis(100)).await;
// Get stream info to find the sequence numbers // Now create and connect queue2 and queue3 AFTER messages are published (to test persistence)
// We need to know the sequence of message 2 to purge up to it let mut queue2 = NatsQueue::new_with_consumer(
let client_options = Client::builder() stream_name.clone(),
.server(nats_server.clone()) nats_server.clone(),
.build() timeout,
.expect("Failed to build client options"); consumer2_name,
);
let client = client_options // Create a third queue without consumer (publisher-only)
.connect() let mut queue3 =
.await NatsQueue::new_without_consumer(stream_name.clone(), nats_server.clone(), timeout);
.expect("Failed to connect to NATS");
// Connect queue2 and queue3 after messages are already published
queue2.connect().await.expect("Failed to connect queue2");
queue3.connect().await.expect("Failed to connect queue3");
// Purge the first two messages (sequence 1 and 2) // Purge the first two messages (sequence 1 and 2)
// Note: JetStream sequences start at 1, and purge is exclusive of the sequence number // Note: JetStream sequences start at 1, and purge is exclusive of the sequence number
...@@ -1127,7 +1281,10 @@ mod tests { ...@@ -1127,7 +1281,10 @@ mod tests {
queue1.connect().await.expect("Failed to reconnect queue1"); queue1.connect().await.expect("Failed to reconnect queue1");
// Shutdown consumer 1 and verify via consumer 2 that there is only one consumer left // Shutdown consumer 1 and verify via consumer 2 that there is only one consumer left
queue1.shutdown().await.expect("Failed to shutdown queue1"); queue1
.shutdown(None)
.await
.expect("Failed to shutdown queue1");
let consumer_count = queue2 let consumer_count = queue2
.count_consumers() .count_consumers()
......
...@@ -6,7 +6,8 @@ import json ...@@ -6,7 +6,8 @@ import json
import logging import logging
import os import os
import random import random
from typing import Any, Dict import string
from typing import Any, Dict, Optional
import aiohttp import aiohttp
import pytest import pytest
...@@ -25,6 +26,12 @@ SPEEDUP_RATIO = 10.0 ...@@ -25,6 +26,12 @@ SPEEDUP_RATIO = 10.0
NUM_REQUESTS = 100 NUM_REQUESTS = 100
PORT = 8090 # Starting port for mocker instances PORT = 8090 # Starting port for mocker instances
def generate_random_suffix(length: int = 10) -> str:
    """Generate a random lowercase alphabetic suffix for namespace isolation.

    Args:
        length: Number of characters to generate. Defaults to 10, which is
            what callers that pass no argument have always received.

    Returns:
        A string of ``length`` characters drawn (with replacement) from
        ``string.ascii_lowercase``.
    """
    return "".join(random.choices(string.ascii_lowercase, k=length))
# Shared test payload for all tests # Shared test payload for all tests
TEST_PAYLOAD: Dict[str, Any] = { TEST_PAYLOAD: Dict[str, Any] = {
"model": MODEL_NAME, "model": MODEL_NAME,
...@@ -39,10 +46,19 @@ TEST_PAYLOAD: Dict[str, Any] = { ...@@ -39,10 +46,19 @@ TEST_PAYLOAD: Dict[str, Any] = {
} }
class MockerProcess(ManagedProcess): class MockerProcess:
"""Manages a single mocker engine instance""" """Manages multiple mocker engine instances with the same namespace"""
def __init__(self, request, endpoint: str, mocker_args_file: str): def __init__(self, request, mocker_args_file: str, num_mockers: int = 1):
# Generate a unique namespace suffix shared by all mockers
namespace_suffix = generate_random_suffix()
self.namespace = f"test-namespace-{namespace_suffix}"
self.endpoint = f"dyn://{self.namespace}.mocker.generate"
self.num_mockers = num_mockers
self.mocker_processes = []
# Create multiple mocker processes with the same namespace
for i in range(num_mockers):
command = [ command = [
"python", "python",
"-m", "-m",
...@@ -52,10 +68,10 @@ class MockerProcess(ManagedProcess): ...@@ -52,10 +68,10 @@ class MockerProcess(ManagedProcess):
"--extra-engine-args", "--extra-engine-args",
mocker_args_file, mocker_args_file,
"--endpoint", "--endpoint",
endpoint, self.endpoint,
] ]
super().__init__( process = ManagedProcess(
command=command, command=command,
timeout=60, timeout=60,
display_output=True, display_output=True,
...@@ -64,7 +80,21 @@ class MockerProcess(ManagedProcess): ...@@ -64,7 +80,21 @@ class MockerProcess(ManagedProcess):
log_dir=request.node.name, log_dir=request.node.name,
terminate_existing=False, terminate_existing=False,
) )
self.endpoint = endpoint self.mocker_processes.append(process)
logger.info(f"Created mocker instance {i} with endpoint: {self.endpoint}")
def __enter__(self):
"""Start all mocker processes"""
for i, process in enumerate(self.mocker_processes):
logger.info(f"Starting mocker instance {i}")
process.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Stop all mocker processes"""
for i, process in enumerate(self.mocker_processes):
logger.info(f"Stopping mocker instance {i}")
process.__exit__(exc_type, exc_val, exc_tb)
class KVRouterProcess(ManagedProcess): class KVRouterProcess(ManagedProcess):
...@@ -151,11 +181,14 @@ def get_runtime(): ...@@ -151,11 +181,14 @@ def get_runtime():
return _runtime_instance return _runtime_instance
async def check_registration_in_etcd(expected_count: int): async def check_registration_in_etcd(
expected_count: int, endpoint: Optional[str] = None
):
"""Check that the expected number of KV routers are registered in etcd. """Check that the expected number of KV routers are registered in etcd.
Args: Args:
expected_count: The number of KV routers expected to be registered expected_count: The number of KV routers expected to be registered
endpoint: The endpoint string to extract component path from (e.g., "dyn://namespace.component.generate")
Returns: Returns:
List of registered KV router entries from etcd List of registered KV router entries from etcd
...@@ -163,10 +196,27 @@ async def check_registration_in_etcd(expected_count: int): ...@@ -163,10 +196,27 @@ async def check_registration_in_etcd(expected_count: int):
runtime = get_runtime() runtime = get_runtime()
etcd = runtime.etcd_client() etcd = runtime.etcd_client()
# Extract component path from endpoint if provided
prefix = "kv_routers/"
if endpoint:
# Parse endpoint format: dyn://namespace.component.endpoint_suffix
# Extract namespace and component, ignoring the endpoint suffix (e.g., "generate")
endpoint_parts = endpoint.replace("dyn://", "").split(".")
if len(endpoint_parts) >= 2:
namespace = endpoint_parts[0]
component = endpoint_parts[1]
component_path = f"{namespace}/{component}"
prefix = f"kv_routers/{component_path}/"
logger.info(
f"Checking for KV routers with component path: {component_path}"
)
# Check for kv_routers in etcd # Check for kv_routers in etcd
# The KV router registers itself with key format: kv_routers/{model_name}/{uuid} # The KV router registers itself with key format: kv_routers/{component_path}/{uuid}
kv_routers = await etcd.kv_get_prefix("kv_routers/") kv_routers = await etcd.kv_get_prefix(prefix)
logger.info(f"Found {len(kv_routers)} KV router(s) registered in etcd") logger.info(
f"Found {len(kv_routers)} KV router(s) registered in etcd under prefix: {prefix}"
)
# Assert we have the expected number of KV routers registered # Assert we have the expected number of KV routers registered
assert ( assert (
...@@ -248,9 +298,6 @@ def test_mocker_kv_router(request, runtime_services): ...@@ -248,9 +298,6 @@ def test_mocker_kv_router(request, runtime_services):
with open(mocker_args_file, "w") as f: with open(mocker_args_file, "w") as f:
json.dump(mocker_args, f) json.dump(mocker_args, f)
# Start mocker instances
mocker_processes = []
try: try:
# Start KV router (frontend) # Start KV router (frontend)
frontend_port = PORT frontend_port = PORT
...@@ -259,17 +306,11 @@ def test_mocker_kv_router(request, runtime_services): ...@@ -259,17 +306,11 @@ def test_mocker_kv_router(request, runtime_services):
kv_router = KVRouterProcess(request, frontend_port) kv_router = KVRouterProcess(request, frontend_port)
kv_router.__enter__() kv_router.__enter__()
for i in range(NUM_MOCKERS): # Start mocker instances
# Use unique endpoints for each mocker logger.info(f"Starting {NUM_MOCKERS} mocker instances")
endpoint = "dyn://test-namespace.mocker.generate" mockers = MockerProcess(request, mocker_args_file, num_mockers=NUM_MOCKERS)
logger.info(f"Starting mocker instance {i} on endpoint {endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
mocker = MockerProcess(request, endpoint, mocker_args_file)
mocker_processes.append(mocker)
# Start all mockers
for mocker in mocker_processes:
mocker.__enter__()
# Use async to send requests concurrently for better performance # Use async to send requests concurrently for better performance
asyncio.run( asyncio.run(
...@@ -285,15 +326,18 @@ def test_mocker_kv_router(request, runtime_services): ...@@ -285,15 +326,18 @@ def test_mocker_kv_router(request, runtime_services):
logger.info(f"Successfully completed {NUM_REQUESTS} requests") logger.info(f"Successfully completed {NUM_REQUESTS} requests")
# Check etcd registration - expect 1 KV router # Check etcd registration - expect 1 KV router
asyncio.run(check_registration_in_etcd(expected_count=1)) # Use the mockers' endpoint since all mockers share the same component path
asyncio.run(
check_registration_in_etcd(expected_count=1, endpoint=mockers.endpoint)
)
finally: finally:
# Clean up # Clean up
if "kv_router" in locals(): if "kv_router" in locals():
kv_router.__exit__(None, None, None) kv_router.__exit__(None, None, None)
for mocker in mocker_processes: if "mockers" in locals():
mocker.__exit__(None, None, None) mockers.__exit__(None, None, None)
if os.path.exists(mocker_args_file): if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file) os.unlink(mocker_args_file)
...@@ -316,8 +360,6 @@ def test_mocker_two_kv_router(request, runtime_services): ...@@ -316,8 +360,6 @@ def test_mocker_two_kv_router(request, runtime_services):
with open(mocker_args_file, "w") as f: with open(mocker_args_file, "w") as f:
json.dump(mocker_args, f) json.dump(mocker_args, f)
# Start mocker instances
mocker_processes = []
kv_routers = [] kv_routers = []
try: try:
...@@ -330,17 +372,11 @@ def test_mocker_two_kv_router(request, runtime_services): ...@@ -330,17 +372,11 @@ def test_mocker_two_kv_router(request, runtime_services):
kv_router.__enter__() kv_router.__enter__()
kv_routers.append(kv_router) kv_routers.append(kv_router)
for i in range(NUM_MOCKERS): # Start mocker instances
# Use unique endpoints for each mocker logger.info(f"Starting {NUM_MOCKERS} mocker instances")
endpoint = "dyn://test-namespace.mocker.generate" mockers = MockerProcess(request, mocker_args_file, num_mockers=NUM_MOCKERS)
logger.info(f"Starting mocker instance {i} on endpoint {endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mockers.__enter__()
mocker = MockerProcess(request, endpoint, mocker_args_file)
mocker_processes.append(mocker)
# Start all mockers
for mocker in mocker_processes:
mocker.__enter__()
# Build URLs for both routers # Build URLs for both routers
router_urls = [ router_urls = [
...@@ -361,7 +397,10 @@ def test_mocker_two_kv_router(request, runtime_services): ...@@ -361,7 +397,10 @@ def test_mocker_two_kv_router(request, runtime_services):
) )
# Check etcd registration - expect 2 KV routers # Check etcd registration - expect 2 KV routers
asyncio.run(check_registration_in_etcd(expected_count=2)) # Use the mockers' endpoint since all mockers share the same component path
asyncio.run(
check_registration_in_etcd(expected_count=2, endpoint=mockers.endpoint)
)
finally: finally:
# Clean up routers # Clean up routers
...@@ -369,8 +408,8 @@ def test_mocker_two_kv_router(request, runtime_services): ...@@ -369,8 +408,8 @@ def test_mocker_two_kv_router(request, runtime_services):
kv_router.__exit__(None, None, None) kv_router.__exit__(None, None, None)
# Clean up mockers # Clean up mockers
for mocker in mocker_processes: if "mockers" in locals():
mocker.__exit__(None, None, None) mockers.__exit__(None, None, None)
if os.path.exists(mocker_args_file): if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file) os.unlink(mocker_args_file)
...@@ -437,13 +476,10 @@ def test_mocker_kv_router_overload_503(request, runtime_services): ...@@ -437,13 +476,10 @@ def test_mocker_kv_router_overload_503(request, runtime_services):
kv_router.__enter__() kv_router.__enter__()
# Start single mocker instance with limited resources # Start single mocker instance with limited resources
endpoint = "dyn://test-namespace.mocker.generate" logger.info("Starting single mocker instance with limited resources")
logger.info( mockers = MockerProcess(request, mocker_args_file, num_mockers=1)
f"Starting single mocker instance with limited resources on endpoint {endpoint}" logger.info(f"Mocker using endpoint: {mockers.endpoint}")
) mockers.__enter__()
mocker = MockerProcess(request, endpoint, mocker_args_file)
mocker.__enter__()
url = f"http://localhost:{frontend_port}/v1/chat/completions" url = f"http://localhost:{frontend_port}/v1/chat/completions"
...@@ -545,8 +581,8 @@ def test_mocker_kv_router_overload_503(request, runtime_services): ...@@ -545,8 +581,8 @@ def test_mocker_kv_router_overload_503(request, runtime_services):
if "kv_router" in locals(): if "kv_router" in locals():
kv_router.__exit__(None, None, None) kv_router.__exit__(None, None, None)
if "mocker" in locals(): if "mockers" in locals():
mocker.__exit__(None, None, None) mockers.__exit__(None, None, None)
if os.path.exists(mocker_args_file): if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file) os.unlink(mocker_args_file)
...@@ -570,28 +606,19 @@ def test_kv_push_router_bindings(request, runtime_services): ...@@ -570,28 +606,19 @@ def test_kv_push_router_bindings(request, runtime_services):
with open(mocker_args_file, "w") as f: with open(mocker_args_file, "w") as f:
json.dump(mocker_args, f) json.dump(mocker_args, f)
# Start mocker instances
mocker_processes = []
try: try:
# Start mockers # Start mocker instances
for i in range(NUM_MOCKERS): logger.info(f"Starting {NUM_MOCKERS} mocker instances")
# Use unique endpoints for each mocker mockers = MockerProcess(request, mocker_args_file, num_mockers=NUM_MOCKERS)
endpoint = "dyn://test-namespace.mocker.generate" logger.info(f"All mockers using endpoint: {mockers.endpoint}")
logger.info(f"Starting mocker instance {i} on endpoint {endpoint}") mockers.__enter__()
mocker = MockerProcess(request, endpoint, mocker_args_file)
mocker_processes.append(mocker)
# Start all mockers
for mocker in mocker_processes:
mocker.__enter__()
# Wait for mockers to be ready by sending a dummy request with retry # Wait for mockers to be ready by sending a dummy request with retry
async def wait_for_mockers_ready(): async def wait_for_mockers_ready():
"""Send a dummy request to ensure mockers are ready""" """Send a dummy request to ensure mockers are ready"""
runtime = get_runtime() runtime = get_runtime()
namespace = runtime.namespace("test-namespace") # Use the namespace from the mockers
namespace = runtime.namespace(mockers.namespace)
component = namespace.component("mocker") component = namespace.component("mocker")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
...@@ -653,7 +680,8 @@ def test_kv_push_router_bindings(request, runtime_services): ...@@ -653,7 +680,8 @@ def test_kv_push_router_bindings(request, runtime_services):
async def test_kv_push_router(): async def test_kv_push_router():
# Get runtime and create endpoint # Get runtime and create endpoint
runtime = get_runtime() runtime = get_runtime()
namespace = runtime.namespace("test-namespace") # Use the namespace from the mockers
namespace = runtime.namespace(mockers.namespace)
component = namespace.component("mocker") component = namespace.component("mocker")
endpoint = component.endpoint("generate") endpoint = component.endpoint("generate")
...@@ -785,8 +813,245 @@ def test_kv_push_router_bindings(request, runtime_services): ...@@ -785,8 +813,245 @@ def test_kv_push_router_bindings(request, runtime_services):
finally: finally:
# Clean up mockers # Clean up mockers
for mocker in mocker_processes: if "mockers" in locals():
mocker.__exit__(None, None, None) mockers.__exit__(None, None, None)
if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file)
@pytest.mark.pre_merge
def test_indexers_sync(request, runtime_services):
    """
    Test that two KV routers have synchronized indexer states after processing requests.
    This test verifies that both routers converge to the same internal state.

    Flow:
      1. Start NUM_MOCKERS mockers sharing one namespace.
      2. Create router 1 (router_reset_states=True) and send it 25 requests.
      3. Create router 2 (router_reset_states=False) and send it 25 requests.
      4. Dump the events of both routers and require them to be equal,
         ignoring ordering and the per-event ``event_id``.
    """
    # runtime_services starts etcd and nats
    logger.info("Starting indexers sync test")

    # Create mocker args file
    mocker_args = {"speedup_ratio": SPEEDUP_RATIO, "block_size": BLOCK_SIZE}
    mocker_args_file = os.path.join(request.node.name, "mocker_args.json")
    with open(mocker_args_file, "w") as f:
        json.dump(mocker_args, f)

    try:
        # Start mocker instances
        logger.info(f"Starting {NUM_MOCKERS} mocker instances")
        mockers = MockerProcess(request, mocker_args_file, num_mockers=NUM_MOCKERS)
        logger.info(f"All mockers using endpoint: {mockers.endpoint}")
        # Entered manually (not via `with`) so the `finally` below still calls
        # __exit__ even if __enter__ fails part-way through.
        mockers.__enter__()

        # Run the async test
        async def test_sync():
            # Get runtime and create endpoint
            runtime = get_runtime()
            # Use the namespace from the mockers
            namespace = runtime.namespace(mockers.namespace)
            component = namespace.component("mocker")
            endpoint = component.endpoint("generate")

            # Create first KV router
            from dynamo._core import KvPushRouter, KvRouterConfig

            # First router with default router_reset_states=True
            kv_router_config = KvRouterConfig(
                router_snapshot_threshold=20, router_reset_states=True
            )

            async def send_requests_to_router(router, num_requests, router_name):
                """Wait until `router` answers, then send `num_requests` concurrently.

                Returns the number of requests that completed without raising.
                """
                # First, send a test request with retry to ensure router is ready
                max_retries = 8
                wait_time = 1

                for attempt in range(max_retries + 1):
                    try:
                        logger.info(
                            f"Testing {router_name} readiness (attempt {attempt + 1})"
                        )
                        # Generate small test token IDs
                        test_token_ids = [random.randint(1, 10000) for _ in range(10)]
                        stream = await router.generate(
                            token_ids=test_token_ids,  # Small test
                            model=MODEL_NAME,
                            stop_conditions={"max_tokens": 1},
                        )
                        # Just consume the stream to verify it works
                        async for _ in stream:
                            pass
                        logger.info(f"{router_name} is ready!")
                        break
                    except Exception as e:
                        logger.warning(
                            f"{router_name} attempt {attempt + 1} failed: {e}"
                        )
                        if attempt < max_retries:
                            # Exponential backoff between readiness probes
                            await asyncio.sleep(wait_time)
                            wait_time *= 2
                        else:
                            # Keep the last failure as the cause for debugging
                            raise RuntimeError(
                                f"Failed to connect to {router_name} after retries"
                            ) from e

                async def single_request(req_id, tokens):
                    """Send one request; return True on success, False on failure."""
                    try:
                        stream = await router.generate(
                            token_ids=tokens,
                            model=MODEL_NAME,
                            stop_conditions={"max_tokens": 10},
                        )
                        # Consume the stream
                        async for _ in stream:
                            pass
                        return True
                    except Exception as e:
                        logger.error(f"Request {req_id} to {router_name} failed: {e}")
                        return False

                # Now send the actual requests
                tasks = []
                for i in range(num_requests):
                    # Generate random token IDs for each request
                    request_tokens = [random.randint(1, 10000) for _ in range(30)]
                    tasks.append(asyncio.create_task(single_request(i, request_tokens)))

                results = await asyncio.gather(*tasks)
                successful = sum(1 for r in results if r)
                logger.info(
                    f"Completed {successful}/{num_requests} requests for {router_name}"
                )
                return successful

            logger.info("Creating first KV router")
            kv_push_router1 = KvPushRouter(
                endpoint=endpoint,
                block_size=BLOCK_SIZE,
                kv_router_config=kv_router_config,
            )

            # Send 25 requests to first router with initial retry loop
            logger.info("Sending 25 requests to first router")
            successful1 = await send_requests_to_router(kv_push_router1, 25, "Router 1")
            assert (
                successful1 == 25
            ), f"Expected 25 successful requests to router 1, got {successful1}"

            # Wait for a second before creating the second router
            logger.info("Waiting for 1 second before creating second router")
            await asyncio.sleep(1)

            # Launch second router with router_reset_states=False
            logger.info("Creating second KV router with router_reset_states=False")
            kv_router_config2 = KvRouterConfig(
                router_snapshot_threshold=20, router_reset_states=False
            )
            kv_push_router2 = KvPushRouter(
                endpoint=endpoint,
                block_size=BLOCK_SIZE,
                kv_router_config=kv_router_config2,
            )

            # Send 25 requests to second router with initial retry loop
            logger.info("Sending 25 requests to second router")
            successful2 = await send_requests_to_router(kv_push_router2, 25, "Router 2")
            assert (
                successful2 == 25
            ), f"Expected 25 successful requests to router 2, got {successful2}"

            # Wait for all requests to complete (they should already be complete from gather)
            # Wait another 1 second for internal synchronization
            logger.info("Waiting for final synchronization")
            await asyncio.sleep(1)

            # Dump states from both routers
            logger.info("Dumping states from both routers")
            state1_json = await kv_push_router1.dump_events()
            state2_json = await kv_push_router2.dump_events()

            # Parse JSON strings for comparison
            state1 = json.loads(state1_json)
            state2 = json.loads(state2_json)

            # Sort both states for comparison (order might differ due to HashMap iteration and sharding)
            def sort_key(event):
                # NOTE(review): assumes every dumped event is a "stored" event
                # with at least one block -- confirm against dump_events().
                data = event["event"]["data"]["stored"]
                blocks = data["blocks"]
                first_block = blocks[0]
                return (
                    event["worker_id"],
                    first_block["tokens_hash"],
                    data["parent_hash"],
                )

            def without_event_id(item):
                """Return `item` with `event_id` dropped from its nested event.

                Builds new dicts instead of deleting in place, so the dumped
                state is left intact for the failure report below.
                """
                event = item.get("event")
                if not isinstance(event, dict) or "event_id" not in event:
                    return item
                stripped = dict(item)
                stripped["event"] = {
                    key: value for key, value in event.items() if key != "event_id"
                }
                return stripped

            sorted_state1 = sorted(state1, key=sort_key)
            sorted_state2 = sorted(state2, key=sort_key)

            # Verify they are equal
            logger.info(f"Router 1 has {len(sorted_state1)} events")
            logger.info(f"Router 2 has {len(sorted_state2)} events")

            # Compare states one by one and only show differences
            if len(sorted_state1) != len(sorted_state2):
                logger.error(
                    f"Router 1 has {len(sorted_state1)} events, Router 2 has {len(sorted_state2)} events"
                )
                # Raised explicitly: `assert False` would vanish under `python -O`.
                raise AssertionError("Router states have different numbers of events")

            differences = [
                {
                    "index": i,
                    "router1_state": state1_item,
                    "router2_state": state2_item,
                }
                for i, (state1_item, state2_item) in enumerate(
                    zip(sorted_state1, sorted_state2)
                )
                if without_event_id(state1_item) != without_event_id(state2_item)
            ]

            if differences:
                report = [
                    f"Router states are not equal. Found {len(differences)} differences:\n"
                ]
                for diff in differences:
                    report.append(f"\nDifference at index {diff['index']}:\n")
                    report.append(
                        f"Router 1: {json.dumps(diff['router1_state'], indent=2)}\n"
                    )
                    report.append(
                        f"Router 2: {json.dumps(diff['router2_state'], indent=2)}\n"
                    )
                    report.append("-" * 80 + "\n")
                raise AssertionError("".join(report))

            logger.info("Successfully verified that both router states are equal")

        # Run the async test
        asyncio.run(test_sync())

        logger.info("Indexers sync test completed successfully")

    finally:
        # Clean up mockers
        if "mockers" in locals():
            mockers.__exit__(None, None, None)

        if os.path.exists(mocker_args_file):
            os.unlink(mocker_args_file)
...@@ -821,8 +1086,6 @@ def test_query_instance_id_returns_worker_and_tokens(request, runtime_services): ...@@ -821,8 +1086,6 @@ def test_query_instance_id_returns_worker_and_tokens(request, runtime_services):
with open(mocker_args_file, "w") as f: with open(mocker_args_file, "w") as f:
json.dump(mocker_args, f) json.dump(mocker_args, f)
mocker_processes = []
try: try:
# Start KV router (frontend) # Start KV router (frontend)
frontend_port = PORT + 30 # Use unique port to avoid conflicts frontend_port = PORT + 30 # Use unique port to avoid conflicts
...@@ -831,14 +1094,10 @@ def test_query_instance_id_returns_worker_and_tokens(request, runtime_services): ...@@ -831,14 +1094,10 @@ def test_query_instance_id_returns_worker_and_tokens(request, runtime_services):
kv_router.__enter__() kv_router.__enter__()
# Start multiple mocker engines to ensure worker selection logic # Start multiple mocker engines to ensure worker selection logic
endpoint = "dyn://test-namespace.mocker.generate" logger.info(f"Starting {NUM_MOCKERS} mocker instances")
for i in range(NUM_MOCKERS): mockers = MockerProcess(request, mocker_args_file, num_mockers=NUM_MOCKERS)
logger.info(f"Starting mocker instance {i} on endpoint {endpoint}") logger.info(f"All mockers using endpoint: {mockers.endpoint}")
mocker = MockerProcess(request, endpoint, mocker_args_file) mockers.__enter__()
mocker_processes.append(mocker)
for mocker in mocker_processes:
mocker.__enter__()
url = f"http://localhost:{frontend_port}/v1/chat/completions" url = f"http://localhost:{frontend_port}/v1/chat/completions"
...@@ -981,7 +1240,7 @@ def test_query_instance_id_returns_worker_and_tokens(request, runtime_services): ...@@ -981,7 +1240,7 @@ def test_query_instance_id_returns_worker_and_tokens(request, runtime_services):
finally: finally:
if "kv_router" in locals(): if "kv_router" in locals():
kv_router.__exit__(None, None, None) kv_router.__exit__(None, None, None)
for mocker in mocker_processes: if "mockers" in locals():
mocker.__exit__(None, None, None) mockers.__exit__(None, None, None)
if os.path.exists(mocker_args_file): if os.path.exists(mocker_args_file):
os.unlink(mocker_args_file) os.unlink(mocker_args_file)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment