Commit 488c8709 authored by Yan Ru Pei, committed by GitHub

feat: Router warm restarts via durable KV event consumers and radix snapshotting (#2756)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 3f09c395
......@@ -149,6 +149,19 @@ def parse_args():
default=False,
help="KV Router: Enable replica synchronization across multiple router instances. When true, routers will publish and subscribe to events to maintain consistent state.",
)
parser.add_argument(
"--router-snapshot-threshold",
type=int,
default=10000,
help="KV Router: Number of messages in stream before triggering a snapshot. Defaults to 10000.",
)
parser.add_argument(
"--router-persist-states",
action="store_false",
dest="router_reset_states",
default=True,
help="KV Router: Persist router state on startup. Keep existing state from stream and object store (default: reset states).",
)
parser.add_argument(
"--busy-threshold",
type=float,
......@@ -212,6 +225,8 @@ async def async_main():
router_temperature=flags.router_temperature,
use_kv_events=flags.use_kv_events,
router_replica_sync=flags.router_replica_sync,
router_snapshot_threshold=flags.router_snapshot_threshold,
router_reset_states=flags.router_reset_states,
)
elif flags.router_mode == "random":
router_mode = RouterMode.Random
......
......@@ -15,7 +15,7 @@ When KV blocks are created or removed, the engine notifies the Dynamo router, wh
To evaluate the benefits of KV-aware routing, compare your workload's performance using `--router-mode random|round-robin` against KV-aware routing.
The KV-aware routing arguments:
The main KV-aware routing arguments:
- `--kv-overlap-score-weight`: Controls the importance of prefix cache overlaps in prefill cost calculations. Higher values improve Time To First Token (TTFT) at the cost of Inter-Token Latency (ITL). When set to 0, the router ignores prefix caches and uses pure load balancing. Defaults to 1.
......@@ -23,7 +23,11 @@ The KV-aware routing arguments:
- `--use-kv-events`/`--no-kv-events`: Determines how the router tracks cached blocks. When enabled (default), uses `KvIndexer` to monitor block creation and deletion events. When disabled, uses `ApproxKvIndexer`, which estimates cache hits based on a fixed time window (120s). Disable this if your backend doesn't support KV events.
- `--router-replica-sync`: Enables NATS-based state synchronization between router replicas. When enabled, routers share their KV cache distribution and active sequence information, ensuring optimal routing decisions across multiple router instances. This improves fault tolerance and routing accuracy in distributed deployments. Disabled by default.
- `--router-replica-sync`: Enables NATS-based synchronization of local routing decisions between router replicas. When enabled, routers share their active sequence information and local predictions of block usage, improving routing consistency across instances. Note that this does not sync the radix tree or cached KV block states themselves - those are synchronized through JetStream events. Disabled by default.
- `--router-reset-states`/`--router-persist-states`: Controls whether the router state is reset on startup. When `--router-reset-states` is used (default), the router clears both the JetStream event stream and the NATS object store, starting from a fresh state. When `--router-persist-states` is used, the router retains existing state from previous runs, downloading any available snapshot from the NATS object store and continuing to consume events from where it left off. This enables routers to maintain KV cache awareness across restarts. **Note**: State persistence is only available when `--use-kv-events` is enabled (default). It is not supported with `--no-kv-events`, which uses `ApproxKvIndexer`.
- `--router-snapshot-threshold`: Sets the number of messages in the JetStream stream before triggering a snapshot. When the message count exceeds this threshold, a router attempts to purge acknowledged messages from the stream and create a snapshot of the current radix tree state in the NATS object store. Defaults to 10000. This helps manage stream size and provides faster initialization for routers that restart.
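The two persistence-related flags above can be illustrated with a minimal, standalone `argparse` sketch. It mirrors the `store_false`/`dest` pattern from the `parse_args()` hunk in this commit; the `build_parser` helper and the `frontend-sketch` program name are hypothetical and exist only for this example.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper: only the two router flags from this commit are modeled.
    parser = argparse.ArgumentParser(prog="frontend-sketch")
    parser.add_argument("--router-snapshot-threshold", type=int, default=10000)
    parser.add_argument(
        "--router-persist-states",
        action="store_false",        # passing the flag turns the reset OFF
        dest="router_reset_states",  # value is stored under the "reset" name
        default=True,                # without the flag, state is reset on startup
    )
    return parser


default_flags = build_parser().parse_args([])
persist_flags = build_parser().parse_args(
    ["--router-persist-states", "--router-snapshot-threshold", "500"]
)

print(default_flags.router_reset_states, default_flags.router_snapshot_threshold)
print(persist_flags.router_reset_states, persist_flags.router_snapshot_threshold)
```

Because the flag is stored inverted, `router_reset_states` is `True` unless `--router-persist-states` is passed, which matches the "reset by default" behavior described in the list above.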
## Architecture
......@@ -50,17 +54,26 @@ KV Cache routing uses direct routing with a special worker selection algorithm.
For improved fault tolerance, you can launch two frontend + router replicas. Since the frontend and router are currently tied together, you'll need to use two different HTTP ports for each instance.
To enable state sharing between the router replicas (which provides more accurate routing decisions), use the `--router-replica-sync` flag when starting the frontend:
To enable state sharing between the router replicas (which provides more accurate routing decisions), use the `--router-replica-sync` flag when starting the frontend. Router replicas are currently tied to a component, and state syncing and sharing can only happen within the component group. Here's an example of running multiple router replicas:
```bash
# Router replica 1
python -m dynamo.frontend --router-mode kv --port 8000 --router-replica-sync
# Router replica 2
python -m dynamo.frontend --router-mode kv --port 8001 --router-replica-sync
# Router replica 2 (can be started later, note the extra --router-persist-states arg)
python -m dynamo.frontend --router-mode kv --port 8001 --router-replica-sync --router-persist-states
```
When `--router-replica-sync` is enabled, the router replicas will communicate with each other via NATS to maintain consistent state across instances. This allows both routers to have a complete view of the KV cache distribution and make optimal routing decisions, even when requests are distributed across multiple router instances.
After these two replicas are launched, they share the same JetStream stream and snapshot state. The second replica can be started after the first has already been handling requests. As long as `--router-persist-states` is set, the new replica syncs its KV block indexer by consuming the JetStream events and/or downloading the latest snapshot, so both replicas have the same view of cached blocks across workers. Either router, or even both, can go down: the persisted state ensures continuity, bounded by the one-hour message retention configured for the stream. When a third router starts (with `--router-persist-states`), the state still persists:
```bash
# Router replica 3 (can be started even after replicas 1 and 2 have gone down)
python -m dynamo.frontend --router-mode kv --port 8002 --router-replica-sync --router-persist-states
```
> **Note:** If a router replica is launched without the `--router-persist-states` flag, the entire stream and radix snapshot will be purged. If you want to serve a separate router (targeting a different set of workers) independently without affecting the current state, consider using a new namespace/component (see [Distributed Runtime](distributed_runtime.md)) which will start a new stream and NATS object store path.
When `--router-replica-sync` is enabled, the router replicas will additionally share their local routing decisions and active sequence predictions via NATS. Active blocks information is communicated between routers in a fire-and-forget manner, but the routers will quickly become consistent as this information is tied to the request cycle. This helps maintain consistent load estimates across instances even when requests are distributed between routers.
## Understanding KV Cache
The leading Large Language Models (LLMs) today are auto-regressive and based on the [transformer architecture](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf). One key inference optimization technique is to cache the already computed keys and values and to reuse them for future tokens. This is called the [KV Cache](https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/#key-value_caching).
......@@ -182,3 +195,8 @@ In distributed deployments with multiple routers, each router maintains visibili
Each event carries a unique router ID to prevent self-event processing. This asynchronous communication system ensures optimal routing decisions by maintaining consistent KV cache state across all routers, even as they handle different request streams.
### Event Persistence and Recovery
KV cache events are persisted in NATS JetStream, allowing router replicas to maintain their global view of KV blocks across restarts. When a router starts with `--router-persist-states`, it downloads any available snapshot from the NATS object store and continues consuming events from its last acknowledged position in the stream.
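The recovery path described above (restore from a snapshot, then replay the remaining stream events) can be sketched with in-memory stand-ins. This is an illustrative model under stated assumptions, not the router's implementation: `Event`, `ToyIndexer`, and `warm_start` are hypothetical names, a plain dict of sets stands in for the radix tree, and a Python list stands in for the JetStream stream.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set


@dataclass(frozen=True)
class Event:
    worker_id: int
    op: str  # "store" or "remove"
    block_hash: int


@dataclass
class ToyIndexer:
    """Stand-in for the radix tree: worker_id -> set of cached block hashes."""

    blocks: Dict[int, Set[int]] = field(default_factory=dict)

    def apply(self, event: Event) -> None:
        cached = self.blocks.setdefault(event.worker_id, set())
        if event.op == "store":
            cached.add(event.block_hash)
        else:
            cached.discard(event.block_hash)


def warm_start(
    snapshot: Optional[List[Event]],
    stream: List[Event],
    last_acked: int,
    reset_states: bool,
) -> ToyIndexer:
    indexer = ToyIndexer()
    if reset_states:
        # Reset mode: previous snapshot and stream contents are ignored.
        return indexer
    for event in snapshot or []:
        indexer.apply(event)  # 1. rebuild state from the snapshot, if any
    for event in stream[last_acked:]:
        indexer.apply(event)  # 2. replay events not yet acknowledged
    return indexer


snapshot = [Event(1, "store", 10), Event(1, "store", 11)]
stream = [Event(1, "store", 12), Event(2, "store", 10), Event(1, "remove", 10)]

restored = warm_start(snapshot, stream, last_acked=1, reset_states=False)
fresh = warm_start(snapshot, stream, last_acked=1, reset_states=True)
print(restored.blocks)
print(fresh.blocks)
```

In this toy model the snapshot is itself a list of events, which matches how the background task in this commit downloads a `Vec<RouterEvent>` from the object store and forwards each event to the indexer.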
To manage stream growth, when the message count exceeds `--router-snapshot-threshold`, a router acquires an etcd-based distributed lock, purges acknowledged messages from the stream, and uploads the current radix tree state to the NATS object store. This snapshot serves as a checkpoint for faster initialization of future router instances.
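The threshold check, lock acquisition, purge, and snapshot upload described above can be sketched as follows. This is a simplified model under stated assumptions: `ToyStream` and `maybe_snapshot` are hypothetical names, a non-blocking `threading.Lock` stands in for the etcd distributed lock, and a dict stands in for the NATS object store. The step order follows the prose, not necessarily the exact order inside the Rust implementation.

```python
import threading
from typing import Dict, List


class ToyStream:
    """Stand-in for a JetStream stream with a count of acknowledged messages."""

    def __init__(self) -> None:
        self.messages: List[str] = []
        self.acked = 0  # number of leading messages already acknowledged

    def purge_acknowledged(self) -> int:
        purged = self.acked
        self.messages = self.messages[self.acked:]
        self.acked = 0
        return purged


def maybe_snapshot(
    stream: ToyStream,
    radix_state: List[str],
    object_store: Dict[str, List[str]],
    lock: threading.Lock,
    threshold: int,
) -> bool:
    # Guard clause: nothing to do until the stream grows past the threshold.
    if len(stream.messages) <= threshold:
        return False
    # Only one router should snapshot at a time; skip if the lock is held.
    if not lock.acquire(blocking=False):
        return False
    try:
        stream.purge_acknowledged()
        object_store["radix-state"] = list(radix_state)  # upload the checkpoint
        return True
    finally:
        lock.release()


stream = ToyStream()
stream.messages = [f"event-{i}" for i in range(5)]
stream.acked = 3
store: Dict[str, List[str]] = {}
lock = threading.Lock()
state = ["event-0", "event-1", "event-2"]

below = maybe_snapshot(stream, state, store, lock, threshold=10)
took = maybe_snapshot(stream, state, store, lock, threshold=4)
print(below, took, stream.messages)
```

Skipping the cycle when the lock is unavailable, rather than waiting for it, means a router that loses the race simply tries again on its next periodic check.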
......@@ -248,6 +248,9 @@ impl Flags {
self.use_kv_events,
self.router_replica_sync,
self.max_num_batched_tokens,
// defaulting below args (no longer maintaining new flags for dynamo-run)
None,
None,
),
)
}
......
......@@ -42,12 +42,14 @@ impl KvRouterConfig {
#[pymethods]
impl KvRouterConfig {
#[new]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false))]
#[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_snapshot_threshold=10000, router_reset_states=true))]
fn new(
overlap_score_weight: f64,
router_temperature: f64,
use_kv_events: bool,
router_replica_sync: bool,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
) -> Self {
KvRouterConfig {
inner: RsKvRouterConfig {
......@@ -55,6 +57,8 @@ impl KvRouterConfig {
router_temperature,
use_kv_events,
router_replica_sync,
router_snapshot_threshold,
router_reset_states,
..Default::default()
},
}
......
......@@ -19,6 +19,7 @@ use std::sync::atomic::AtomicU32;
use tokio_stream::StreamExt;
use super::*;
use crate::Component;
use llm_rs::kv_router::indexer::compute_block_hash_for_seq;
use llm_rs::kv_router::indexer::KvIndexerInterface;
use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics;
......@@ -405,39 +406,36 @@ pub(crate) struct KvIndexer {
#[pymethods]
impl KvIndexer {
#[new]
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
#[pyo3(signature = (component, kv_block_size, consumer_uuid=None))]
fn new(
component: Component,
kv_block_size: usize,
consumer_uuid: Option<String>,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let cancellation_token = component.inner.drt().runtime().child_token();
let inner: Arc<llm_rs::kv_router::indexer::KvIndexer> =
llm_rs::kv_router::indexer::KvIndexer::new(
component.inner.drt().runtime().child_token(),
cancellation_token.clone(),
kv_block_size as u32,
)
.into();
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different.
let mut kv_events_rx = component
.inner
.subscribe(llm_rs::kv_router::KV_EVENT_SUBJECT)
// Use the shared start_kv_router_background function for event consumption
// Pass None for snapshot_tx to skip snapshot handling in Python bindings
llm_rs::kv_router::subscriber::start_kv_router_background(
component.inner.clone(),
consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
inner.event_sender(),
None,
cancellation_token,
None,
true,
)
.await
.map_err(to_pyerr)?;
let kv_events_tx = inner.event_sender();
// [FIXME] this is the added functionality to the indexer to subscribe to kv events,
// should have been made to a trait and implemented here? i.e. AsyncEngine style
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: llm_rs::kv_router::indexer::RouterEvent =
serde_json::from_slice(&event.payload).unwrap();
tracing::debug!("received kv event: {:?}", event);
if let Err(e) = kv_events_tx.send(event).await {
tracing::trace!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
}
}
});
Ok(Self { inner })
})
}
......@@ -845,6 +843,7 @@ impl SpecDecodeStats {
#[pyclass]
pub(crate) struct KvPushRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>,
primary_token: tokio_util::sync::CancellationToken,
}
#[pymethods]
......@@ -875,12 +874,25 @@ impl KvPushRouter {
// Get component from endpoint
let component = endpoint.inner.component();
// Create KvRouter
// Get the primary token from the component's primary lease
let primary_token = component
.drt()
.primary_lease()
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Failed to get primary lease: Cannot KV route static workers",
)
})?
.primary_token();
// Create KvRouter with a unique consumer UUID
let consumer_uuid = uuid::Uuid::new_v4().to_string();
let kv_router = llm_rs::kv_router::KvRouter::new(
component.clone(),
block_size as u32,
None, // default selector
Some(kv_router_config.inner()),
consumer_uuid,
)
.await
.map_err(to_pyerr)?;
......@@ -891,6 +903,7 @@ impl KvPushRouter {
Ok(Self {
inner: Arc::new(kv_push_router),
primary_token,
})
})
}
......@@ -996,6 +1009,25 @@ impl KvPushRouter {
})
})
}
/// Dump all events from the KV router's indexer as a JSON string
fn dump_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let events = inner.dump_events().await.map_err(to_pyerr)?;
// Serialize to JSON string
let json_str = serde_json::to_string(&events).map_err(to_pyerr)?;
Ok(json_str)
})
}
}
impl Drop for KvPushRouter {
fn drop(&mut self) {
// Cancel the primary token to shut down background tasks
self.primary_token.cancel();
}
}
// Python async generator wrapper for the stream
......
......@@ -138,17 +138,38 @@ async def test_event_handler(distributed_runtime):
event_publisher.store_event(test_token, lora_id)
# wait for the event to be processed as it is sent asynchronously
await asyncio.sleep(1)
# Retry loop for CI environments where processing may take longer
for retry in range(10): # Try up to 10 times
await asyncio.sleep(0.5) # Wait 500ms between retries
scores = await indexer.find_matches_for_request(test_token, lora_id)
assert scores.scores
assert worker_id in scores.scores
assert scores.scores[worker_id] == 1
if (
scores.scores
and worker_id in scores.scores
and scores.scores[worker_id] == 1
):
break
if retry == 9: # Last iteration
# Provide detailed error message for debugging
assert scores.scores, f"No scores found after {(retry+1)*0.5}s"
assert (
worker_id in scores.scores
), f"Worker {worker_id} not in scores after {(retry+1)*0.5}s"
assert (
scores.scores[worker_id] == 1
), f"Expected score 1, got {scores.scores.get(worker_id)} after {(retry+1)*0.5}s"
# remove event
event_publisher.remove_event()
await asyncio.sleep(1)
# Retry loop for event removal verification
for retry in range(10): # Try up to 10 times
await asyncio.sleep(0.5) # Wait 500ms between retries
scores = await indexer.find_matches_for_request(test_token, lora_id)
assert not scores.scores
if not scores.scores:
break
if retry == 9: # Last iteration
assert (
not scores.scores
), f"Scores still present after {(retry+1)*0.5}s: {scores.scores}"
async def test_approx_kv_indexer(distributed_runtime):
......@@ -235,12 +256,13 @@ async def test_metrics_aggregator(distributed_runtime):
asyncio.create_task(metrics_publisher_task(kv_listener, expected_metrics))
# needs time for publisher to spawn up
for i in range(10):
await asyncio.sleep(1)
# Using shorter intervals for faster detection in normal cases
for i in range(20): # Try up to 20 times (10 seconds total)
await asyncio.sleep(0.5) # Wait 500ms between retries
metrics = await metrics_aggregator.get_metrics()
if metrics.endpoints:
break
assert metrics.endpoints
assert metrics.endpoints, f"No metrics endpoints found after {(i+1)*0.5}s"
for endpoint in metrics.endpoints:
# [TODO] not really checking id for now, can't get it as create_endpoint()
# create and serve the endpoint internally
......
......@@ -12,3 +12,6 @@ pub use watcher::{ModelUpdate, ModelWatcher};
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
/// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers";
......@@ -10,9 +10,8 @@ use parking_lot::Mutex;
use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider;
use dynamo_runtime::slug::Slug;
use crate::discovery::ModelEntry;
use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry};
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{
kv_router::KvRouter,
......@@ -218,10 +217,12 @@ impl ModelManager {
.drt()
.etcd_client()
.ok_or_else(|| anyhow::anyhow!("KV routing requires etcd (dynamic mode)"))?;
let router_uuid = uuid::Uuid::new_v4();
let router_key = format!(
"kv_routers/{}/{}",
Slug::from_string(model_name),
uuid::Uuid::new_v4()
"{}/{}/{}",
KV_ROUTERS_ROOT_PATH,
component.path(),
router_uuid
);
etcd_client
.kv_create(
......@@ -237,6 +238,7 @@ impl ModelManager {
kv_cache_block_size,
Some(selector),
kv_router_config,
router_uuid.to_string(),
)
.await?;
let new_kv_chooser = Arc::new(chooser);
......
......@@ -15,6 +15,7 @@ use dynamo_runtime::{
},
prelude::*,
protocols::annotated::Annotated,
utils::typed_prefix_watcher::{key_extractors, watch_prefix_with_extraction},
};
use futures::stream::{self, StreamExt};
use serde::{Deserialize, Serialize};
......@@ -29,6 +30,7 @@ pub mod recorder;
pub mod scheduler;
pub mod scoring;
pub mod sequence;
pub mod subscriber;
use crate::{
discovery::{MODEL_ROOT_PATH, ModelEntry},
......@@ -41,14 +43,13 @@ use crate::{
protocols::{LocalBlockHash, RouterRequest, RouterResponse, WorkerSelectionResult},
scheduler::{KvScheduler, KvSchedulerError, SchedulingRequest},
scoring::ProcessedEndpoints,
subscriber::start_kv_router_background,
},
local_model::runtime_config::ModelRuntimeConfig,
preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput,
};
use dynamo_runtime::traits::events::EventSubscriber;
// [gluo TODO] shouldn't need to be public
// this should be discovered from the component
......@@ -64,6 +65,12 @@ pub const KV_METRICS_SUBJECT: &str = "kv_metrics";
pub const PREFILL_SUBJECT: &str = "prefill_events";
pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
// for radix tree snapshot storage
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";
pub const ROUTER_SNAPSHOT_LOCK: &str = "router-snapshot-lock";
pub const ROUTER_CLEANUP_LOCK: &str = "router-cleanup-lock";
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
fn select_worker(
......@@ -98,6 +105,12 @@ pub struct KvRouterConfig {
// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
pub router_snapshot_threshold: Option<u32>,
/// Whether to reset the router state on startup (default: true)
pub router_reset_states: bool,
}
impl Default for KvRouterConfig {
......@@ -108,6 +121,8 @@ impl Default for KvRouterConfig {
use_kv_events: true,
router_replica_sync: false,
max_num_batched_tokens: 8192,
router_snapshot_threshold: Some(10000),
router_reset_states: true,
}
}
}
......@@ -121,6 +136,8 @@ impl KvRouterConfig {
use_kv_events: Option<bool>,
replica_sync: Option<bool>,
max_num_batched_tokens: Option<u32>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
) -> Self {
let default = Self::default();
Self {
......@@ -130,6 +147,9 @@ impl KvRouterConfig {
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold),
router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
}
}
}
......@@ -151,6 +171,13 @@ impl Indexer {
Indexer::ApproxKvIndexer(indexer) => indexer.find_matches(sequence).await,
}
}
async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
match self {
Indexer::KvIndexer(indexer) => indexer.dump_events().await,
Indexer::ApproxKvIndexer(indexer) => indexer.dump_events().await,
}
}
}
/// A KvRouter only decides which worker you should use. It doesn't send you there.
......@@ -170,6 +197,7 @@ impl KvRouter {
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>,
consumer_uuid: String,
) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
......@@ -196,9 +224,6 @@ impl KvRouter {
.etcd_client()
.expect("Cannot KV route without etcd client");
use dynamo_runtime::utils::typed_prefix_watcher::{
key_extractors, watch_prefix_with_extraction,
};
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
MODEL_ROOT_PATH,
......@@ -230,31 +255,20 @@ impl KvRouter {
)
.await?;
// [gluo TODO] try subscribe_with_type::<RouterEvent>,
// error checking below will be different.
// Start unified background process if using KvIndexer
if let Indexer::KvIndexer(ref kv_indexer) = indexer {
let mut kv_events_rx = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_events_tx = kv_indexer.event_sender();
tokio::spawn(async move {
while let Some(event) = kv_events_rx.next().await {
let event: RouterEvent = match serde_json::from_slice(&event.payload) {
Ok(event) => event,
Err(e) => {
tracing::warn!("Failed to deserialize RouterEvent: {:?}", e);
// Choosing warn and continue to process other events from other workers
// A bad event likely signals a problem with a worker, but potentially other workers are still healthy
continue;
}
};
if let Err(e) = kv_events_tx.send(event).await {
tracing::warn!(
"failed to send kv event to indexer; shutting down: {:?}",
e
);
}
}
});
start_kv_router_background(
component.clone(),
consumer_uuid,
kv_indexer.event_sender(),
kv_router_config
.router_snapshot_threshold
.map(|_| kv_indexer.snapshot_event_sender()),
cancellation_token.clone(),
kv_router_config.router_snapshot_threshold,
kv_router_config.router_reset_states,
)
.await?;
}
tracing::info!("KV Routing initialized");
......@@ -318,6 +332,11 @@ impl KvRouter {
pub fn block_size(&self) -> u32 {
self.block_size
}
/// Dump all events from the indexer
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.indexer.dump_events().await
}
}
// NOTE: this would not be usable for now, should deprecate
......@@ -351,6 +370,11 @@ impl KvPushRouter {
) -> Self {
KvPushRouter { inner, chooser }
}
/// Dump all events from the KV router's indexer
pub async fn dump_events(&self) -> Result<Vec<RouterEvent>, KvRouterError> {
self.chooser.dump_events().await
}
}
#[async_trait]
......
......@@ -25,10 +25,9 @@ use tokio_util::sync::CancellationToken;
use crate::tokens::{SequenceHash, TokenBlockSequence};
use crate::kv_router::RouterEvent;
use crate::kv_router::indexer::{
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, WorkerId,
compute_block_hash_for_seq,
DumpRequest, KvIndexerInterface, KvRouterError, OverlapScores, RadixTree, RouterEvent,
WorkerId, compute_block_hash_for_seq,
};
use crate::kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
......
......@@ -760,6 +760,15 @@ impl KvIndexer {
pub fn event_sender(&self) -> mpsc::Sender<RouterEvent> {
self.event_tx.clone()
}
/// Get a sender for dump requests (snapshot events).
///
/// ### Returns
///
/// A `mpsc::Sender` for `DumpRequest`s.
pub fn snapshot_event_sender(&self) -> mpsc::Sender<DumpRequest> {
self.dump_tx.clone()
}
}
#[async_trait]
......
......@@ -30,6 +30,7 @@ use dynamo_runtime::{
network::Ingress,
},
protocols::annotated::Annotated,
transports::nats::{NatsQueue, QUEUE_NAME, Slug},
};
use futures::stream;
use std::sync::{Arc, OnceLock};
......@@ -133,16 +134,27 @@ impl KvEventPublisher {
)?);
}
component
.drt()
.runtime()
.secondary()
.spawn(start_event_processor(
component,
worker_id,
cancellation_token.clone(),
rx,
));
let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
.to_string()
.replace("_", "-");
let nats_server =
std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());
// Create NatsQueue without consumer since we're only publishing
let mut nats_queue = NatsQueue::new_without_consumer(
stream_name,
nats_server,
std::time::Duration::from_secs(60), // 1 minute timeout
);
// Connect the NatsQueue before passing it to the event processor
let cancellation_token_clone = cancellation_token.clone();
component.drt().runtime().secondary().spawn(async move {
if let Err(e) = nats_queue.connect().await {
tracing::error!("Failed to connect NatsQueue: {}", e);
return;
}
start_event_processor(nats_queue, worker_id, cancellation_token_clone, rx).await
});
Ok(Self {
kv_block_size,
......@@ -198,7 +210,7 @@ async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
// Encapsulate in a router event and publish.
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
let router_event = RouterEvent::new(worker_id, event);
if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await {
if let Err(e) = publisher.publish(QUEUE_NAME, &router_event).await {
tracing::error!("Failed to publish event: {}", e);
}
}
......@@ -929,7 +941,7 @@ mod tests_startup_helpers {
let published = published.lock().unwrap();
assert_eq!(published.len(), 1);
let (subject, _) = &published[0];
assert_eq!(subject, &KV_EVENT_SUBJECT.to_string());
assert_eq!(subject, QUEUE_NAME);
}
//--------------------------------------------------------------------
......
......@@ -271,7 +271,8 @@ impl ActiveSequencesMultiWorker {
let component_clone = component.clone();
let router_id_clone = router_id;
tokio::spawn(async move {
component.drt().runtime().secondary().spawn(async move {
// NATS subscription loop
if let Err(e) = Self::subscribe_to_events(
senders_clone,
request_to_worker_clone,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Background processes for the KV Router including event consumption and snapshot uploads.
use std::time::Duration;
use anyhow::Result;
use dynamo_runtime::{
component::Component,
prelude::*,
traits::events::EventPublisher,
transports::{
etcd::WatchEvent,
nats::{NatsQueue, Slug},
},
};
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::{
discovery::KV_ROUTERS_ROOT_PATH,
kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, ROUTER_CLEANUP_LOCK,
ROUTER_SNAPSHOT_LOCK,
indexer::{DumpRequest, RouterEvent},
},
};
/// Resources required for snapshot operations
#[derive(Clone)]
struct SnapshotResources {
nats_client: dynamo_runtime::transports::nats::Client,
bucket_name: String,
etcd_client: dynamo_runtime::transports::etcd::Client,
lock_name: String,
}
impl SnapshotResources {
/// Try to acquire distributed lock for snapshot operations
/// Returns Some(lock_response) if lock acquired, None if another instance holds it
async fn lock(&self) -> Option<etcd_client::LockResponse> {
match self
.etcd_client
.lock(self.lock_name.clone(), Some(self.etcd_client.lease_id()))
.await
{
Ok(response) => {
tracing::debug!(
"Successfully acquired snapshot lock with key: {:?}",
response.key()
);
Some(response)
}
Err(e) => {
tracing::debug!("Another instance already holds the snapshot lock: {e:?}");
None
}
}
}
/// Release the distributed lock
async fn unlock(&self, lock_response: etcd_client::LockResponse) {
if let Err(e) = self.etcd_client.unlock(lock_response.key()).await {
tracing::warn!("Failed to release snapshot lock: {e:?}");
}
}
}
/// Start a unified background task for event consumption and optional snapshot management
pub async fn start_kv_router_background(
component: Component,
consumer_uuid: String,
kv_events_tx: mpsc::Sender<RouterEvent>,
snapshot_tx: Option<mpsc::Sender<DumpRequest>>,
cancellation_token: CancellationToken,
router_snapshot_threshold: Option<u32>,
router_reset_states: bool,
) -> Result<()> {
// Set up NATS connections
let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
.to_string()
.replace("_", "-");
let nats_server =
std::env::var("NATS_SERVER").unwrap_or_else(|_| "nats://localhost:4222".to_string());
// Create NatsQueue for event consumption
let mut nats_queue = NatsQueue::new_with_consumer(
stream_name.clone(),
nats_server.clone(),
std::time::Duration::from_secs(60), // 1 minute timeout
consumer_uuid,
);
nats_queue.connect_with_reset(router_reset_states).await?;
// Always create NATS client (needed for both reset and snapshots)
let client_options = dynamo_runtime::transports::nats::Client::builder()
.server(&nats_server)
.build()?;
let nats_client = client_options.connect().await?;
// Create bucket name for snapshots/state
let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", component.subject()))
.to_string()
.replace("_", "-");
// Handle initial state based on router_reset_states flag
if router_reset_states {
// Delete the bucket to reset state
tracing::info!("Resetting router state, deleting bucket: {bucket_name}");
if let Err(e) = nats_client.object_store_delete_bucket(&bucket_name).await {
tracing::warn!("Failed to delete bucket (may not exist): {e:?}");
}
} else {
// Try to download initial state from object store
let url = url::Url::parse(&format!(
"nats://{}/{bucket_name}/{RADIX_STATE_FILE}",
nats_client.addr()
))?;
match nats_client
.object_store_download_data::<Vec<RouterEvent>>(url)
.await
{
Ok(events) => {
tracing::info!(
"Successfully downloaded {} events from object store",
events.len()
);
// Send all events to the indexer
for event in events {
if let Err(e) = kv_events_tx.send(event).await {
tracing::warn!("Failed to send initial event to indexer: {e:?}");
}
}
tracing::info!("Successfully sent all initial events to indexer");
}
Err(e) => {
tracing::info!(
"Did not initialize radix state from NATs object store (likely no snapshots yet): {e:?}"
);
}
}
}
// Get etcd client (needed for both snapshots and router watching)
let etcd_client = component
.drt()
.etcd_client()
.ok_or_else(|| anyhow::anyhow!("etcd client not available"))?;
// Watch for router deletions to clean up orphaned consumers
let (_prefix_str, _watcher, mut router_replicas_rx) = etcd_client
.kv_get_and_watch_prefix(&format!("{}/", KV_ROUTERS_ROOT_PATH))
.await?
.dissolve();
let cleanup_lock_name = format!("{}/{}", ROUTER_CLEANUP_LOCK, component.subject());
// Only set up snapshot-related resources if snapshot_tx is provided and threshold is set
let snapshot_resources = if snapshot_tx.is_some() && router_snapshot_threshold.is_some() {
let lock_name = format!("{}/{}", ROUTER_SNAPSHOT_LOCK, component.subject());
Some(SnapshotResources {
nats_client,
bucket_name,
etcd_client: etcd_client.clone(),
lock_name,
})
} else {
None
};
component.drt().runtime().secondary().spawn(async move {
let mut check_interval = tokio::time::interval(Duration::from_secs(1));
check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
biased;
_ = cancellation_token.cancelled() => {
tracing::debug!("KV Router background task received cancellation signal");
// Clean up the queue and remove the durable consumer
if let Err(e) = nats_queue.shutdown(None).await {
tracing::warn!("Failed to shutdown NatsQueue: {e}");
}
break;
}
// Handle event consumption
result = nats_queue.dequeue_task(None) => {
match result {
Ok(Some(bytes)) => {
let event: RouterEvent = match serde_json::from_slice(&bytes) {
Ok(event) => event,
Err(e) => {
tracing::warn!("Failed to deserialize RouterEvent: {e:?}");
continue;
}
};
// Forward the RouterEvent to the indexer
if let Err(e) = kv_events_tx.send(event).await {
tracing::warn!(
"failed to send kv event to indexer; shutting down: {e:?}"
);
break;
}
},
Ok(None) => {
tracing::trace!("Dequeue timeout, continuing");
},
Err(e) => {
tracing::error!("Failed to dequeue task: {e:?}");
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
}
}
// Handle periodic stream checking and purging (only if snapshot_tx is provided)
_ = check_interval.tick() => {
let Some((snapshot_tx, resources)) = snapshot_tx.as_ref().zip(snapshot_resources.as_ref()) else {
continue;
};
// Check total messages in the stream
let Ok(message_count) = nats_queue.get_stream_messages().await else {
tracing::warn!("Failed to get stream message count");
continue;
};
// Guard clause: skip if message count is too low
let threshold = router_snapshot_threshold.unwrap_or(u32::MAX) as u64;
if message_count <= threshold {
continue;
}
tracing::info!("Stream has {message_count} messages, attempting to acquire lock for purge and snapshot");
// Try to acquire distributed lock
let Some(lock_response) = resources.lock().await else {
continue;
};
// Perform snapshot upload and purge
match perform_snapshot_and_purge(
&mut nats_queue,
snapshot_tx,
resources
).await {
Ok(_) => tracing::info!("Successfully performed purge and snapshot"),
Err(e) => tracing::error!("Failed to perform purge and snapshot: {e:?}"),
}
// Release the lock
resources.unlock(lock_response).await;
}
// Handle router deletion events
Some(event) = router_replicas_rx.recv() => {
let WatchEvent::Delete(kv) = event else {
// We only care about deletions for cleaning up consumers
continue;
};
let key = String::from_utf8_lossy(kv.key());
tracing::info!("Router deleted: {}", key);
// Extract the router UUID from the key (format: kv_routers/<model>/<uuid>)
let Some(router_uuid) = key.split('/').next_back() else {
tracing::warn!("Could not extract UUID from router key: {}", key);
continue;
};
// The consumer UUID is the router UUID
let consumer_to_delete = router_uuid.to_string();
tracing::info!("Attempting to delete orphaned consumer: {}", consumer_to_delete);
// Try to acquire cleanup lock before deleting consumer
match etcd_client
.lock(cleanup_lock_name.clone(), Some(etcd_client.lease_id()))
.await
{
Ok(lock_response) => {
tracing::debug!(
"Acquired cleanup lock for deleting consumer: {}",
consumer_to_delete
);
// Delete the consumer
if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await {
tracing::warn!("Failed to delete consumer {}: {}", consumer_to_delete, e);
} else {
tracing::info!("Successfully deleted orphaned consumer: {}", consumer_to_delete);
}
// Release the lock
if let Err(e) = etcd_client.unlock(lock_response.key()).await {
tracing::warn!("Failed to release cleanup lock: {e:?}");
}
}
Err(e) => {
tracing::debug!(
"Could not acquire cleanup lock for consumer {}: {e:?}",
consumer_to_delete
);
}
}
}
}
}
// Clean up the queue and remove the durable consumer
if let Err(e) = nats_queue.shutdown(None).await {
tracing::warn!("Failed to shutdown NatsQueue: {e}");
}
});
Ok(())
}
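Two decisions in the background loop above are easier to see in isolation. The sketch below pulls them out as pure functions; `should_snapshot` and `router_uuid_from_key` are hypothetical names that do not appear in this change. One thing it highlights: `str::split` always yields at least one item, so `key.split('/').next_back()` on its own never returns `None`, and the warning branch in the loop looks unreachable unless empty segments are filtered out.

```rust
/// Mirrors the guard clause in the loop: a snapshot is attempted only when the
/// stream holds strictly more messages than the threshold. A missing threshold
/// falls back to u32::MAX, as in the original code.
fn should_snapshot(message_count: u64, threshold: Option<u32>) -> bool {
    let threshold = threshold.unwrap_or(u32::MAX) as u64;
    message_count > threshold
}

/// Hypothetical stricter variant of the UUID extraction (an assumption, not
/// part of this change). The last path segment is taken as the router UUID,
/// and an empty segment is rejected so that a malformed key yields `None`.
fn router_uuid_from_key(key: &str) -> Option<&str> {
    key.split('/')
        .next_back()
        .filter(|segment| !segment.is_empty())
}
```

With the default threshold of 10000, `should_snapshot(10_000, Some(10_000))` is false and `should_snapshot(10_001, Some(10_000))` is true. A key with a trailing slash such as `kv_routers/model/` yields `None` from `router_uuid_from_key`, whereas the unfiltered `next_back()` would yield `Some("")`.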
/// Perform snapshot upload and purge operations
async fn perform_snapshot_and_purge(
nats_queue: &mut NatsQueue,
snapshot_tx: &mpsc::Sender<DumpRequest>,
resources: &SnapshotResources,
) -> anyhow::Result<()> {
// Snapshot before purge ensures we capture the current state before removing any messages.
// This guarantees the snapshot matches what has been acknowledged up to this point.
// First, request a snapshot from the indexer
let (resp_tx, resp_rx) = oneshot::channel();
let dump_req = DumpRequest { resp: resp_tx };
snapshot_tx
.send(dump_req)
.await
.map_err(|e| anyhow::anyhow!("Failed to send dump request: {e:?}"))?;
// Wait for the dump response
let events = resp_rx
.await
.map_err(|e| anyhow::anyhow!("Failed to receive dump response: {e:?}"))?;
// Upload the snapshot to NATS object store
let url = url::Url::parse(&format!(
"nats://{}/{}/{RADIX_STATE_FILE}",
resources.nats_client.addr(),
resources.bucket_name
))?;
resources
.nats_client
.object_store_upload_data(&events, url)
.await
.map_err(|e| anyhow::anyhow!("Failed to upload snapshot: {e:?}"))?;
tracing::info!(
"Successfully uploaded radix tree snapshot with {} events to bucket {}",
events.len(),
resources.bucket_name
);
// Now purge acknowledged messages from the stream
nats_queue.purge_acknowledged().await?;
Ok(())
}
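The snapshot-then-purge ordering in `perform_snapshot_and_purge` can be illustrated with a toy in-memory model. This is only a sketch under simplifying assumptions: `ToyStream`, `purge_up_to`, and `restore` are hypothetical stand-ins, and payloads are plain strings rather than `RouterEvent`s. It assumes a restart replays the snapshot first and then whatever remains in the stream.

```rust
use std::collections::VecDeque;

/// Toy stand-in for a JetStream stream: (sequence, payload) pairs in order.
struct ToyStream {
    messages: VecDeque<(u64, String)>,
}

impl ToyStream {
    /// Remove every message whose sequence is strictly below `sequence`
    /// (the bound is exclusive).
    fn purge_up_to(&mut self, sequence: u64) {
        while matches!(self.messages.front(), Some((seq, _)) if *seq < sequence) {
            self.messages.pop_front();
        }
    }
}

/// Rebuild state on restart: snapshot contents first, then the messages that
/// are still in the stream.
fn restore(snapshot: &[String], stream: &ToyStream) -> Vec<String> {
    let mut state = snapshot.to_vec();
    state.extend(stream.messages.iter().map(|(_, payload)| payload.clone()));
    state
}
```

Suppose the stream holds sequences 1 to 4 and the indexer has applied 1 and 2. Taking the snapshot (`["a", "b"]`) and then calling `purge_up_to(3)` leaves sequences 3 and 4 in the stream, so `restore` recovers all four payloads. If the purge ran without a successful snapshot, the first two payloads would be unrecoverable, which is why the upload error returns early before `purge_acknowledged` is called.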
......@@ -180,12 +180,12 @@ impl MockVllmEngine {
component: Option<Component>,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::info!("Creating metrics publisher");
tracing::debug!("Creating metrics publisher");
let metrics_publisher = Arc::new(WorkerMetricsPublisher::new()?);
tracing::info!("Metrics publisher created");
tracing::debug!("Metrics publisher created");
if let Some(comp) = component {
tracing::info!("Creating metrics endpoint");
tracing::debug!("Creating metrics endpoint");
tokio::spawn({
let publisher = metrics_publisher.clone();
async move {
......@@ -197,10 +197,10 @@ impl MockVllmEngine {
// Give it a moment to start
tokio::time::sleep(Duration::from_millis(100)).await;
tracing::info!("Metrics endpoint started (background)");
tracing::debug!("Metrics endpoint started (background)");
}
tracing::info!("Starting metrics background tasks");
tracing::debug!("Starting metrics background tasks");
for (dp_rank, scheduler) in schedulers.iter().enumerate() {
let mut metrics_rx = scheduler.metrics_receiver();
let publisher = metrics_publisher.clone();
......@@ -223,7 +223,7 @@ impl MockVllmEngine {
}
}
_ = cancel_token.cancelled() => {
tracing::info!("Metrics publishing cancelled for DP rank {dp_rank}");
tracing::debug!("Metrics publishing cancelled for DP rank {dp_rank}");
break;
}
}
......@@ -241,14 +241,14 @@ impl MockVllmEngine {
block_size: usize,
cancel_token: CancellationToken,
) -> Result<()> {
tracing::info!("Starting KV events publishing");
tracing::debug!("Starting KV events publishing");
// Only start KV events publishing if we have a component
let Some(comp) = component else {
tracing::warn!("No component provided, skipping KV events publishing");
return Ok(());
};
tracing::info!("Component found for KV events publishing");
tracing::debug!("Component found for KV events publishing");
tracing::debug!("Getting worker_id");
let worker_id = comp
......@@ -259,16 +259,16 @@ impl MockVllmEngine {
// let worker_id = 0;
tracing::debug!("Worker_id set to: {worker_id}");
tracing::info!("Creating KV event publisher");
tracing::debug!("Creating KV event publisher");
let kv_event_publisher = Arc::new(KvEventPublisher::new(
comp.clone(),
worker_id,
block_size as u32,
None,
)?);
tracing::info!("KV event publisher created");
tracing::debug!("KV event publisher created");
tracing::info!(
tracing::debug!(
"Starting KV event background tasks for {} receivers",
kv_event_receivers.len()
);
......@@ -298,7 +298,7 @@ impl MockVllmEngine {
}
}
_ = cancel_token.cancelled() => {
tracing::info!("KV events publishing cancelled for DP rank {dp_rank}");
tracing::debug!("KV events publishing cancelled for DP rank {dp_rank}");
break;
}
}
......@@ -476,7 +476,7 @@ impl AnnotatedMockEngine {
continue;
}
tracing::info!("Component service is now available, starting mocker engine");
tracing::debug!("Component service is now available, starting mocker engine");
// Start the engine with the component
if let Err(e) = inner_clone.start(component).await {
......@@ -515,7 +515,7 @@ pub async fn make_mocker_engine(
args: MockEngineArgs,
) -> Result<crate::backend::ExecutionContext, Error> {
// Create the mocker engine
tracing::info!("Creating mocker engine with config: {args:?}");
tracing::debug!("Creating mocker engine with config: {args:?}");
let annotated_engine =
AnnotatedMockEngine::new(MockVllmEngine::new(args), distributed_runtime, endpoint_id);
......
......@@ -25,8 +25,9 @@ use tokio::sync::{RwLock, mpsc};
use validator::Validate;
use etcd_client::{
Certificate, Compare, CompareOp, DeleteOptions, GetOptions, Identity, PutOptions, PutResponse,
TlsOptions, Txn, TxnOp, TxnOpResponse, WatchOptions, Watcher,
Certificate, Compare, CompareOp, DeleteOptions, GetOptions, Identity, LockClient, LockOptions,
LockResponse, PutOptions, PutResponse, TlsOptions, Txn, TxnOp, TxnOpResponse, WatchOptions,
Watcher,
};
pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
use tokio::time::{Duration, interval};
......@@ -306,6 +307,32 @@ impl Client {
Ok(get_response.take_kvs())
}
/// Acquire a distributed lock using etcd's native lock mechanism
/// Returns a LockResponse that can be used to unlock later
pub async fn lock(
&self,
key: impl Into<Vec<u8>>,
lease_id: Option<i64>,
) -> Result<LockResponse> {
let mut lock_client = self.client.lock_client();
let id = lease_id.unwrap_or(self.lease_id());
let options = LockOptions::new().with_lease(id);
lock_client
.lock(key, Some(options))
.await
.map_err(|err| err.into())
}
/// Release a distributed lock using the key from the LockResponse
pub async fn unlock(&self, lock_key: impl Into<Vec<u8>>) -> Result<()> {
let mut lock_client = self.client.lock_client();
lock_client
.unlock(lock_key)
.await
.map_err(|err: etcd_client::Error| anyhow::anyhow!(err))?;
Ok(())
}
pub async fn kv_get_and_watch_prefix(
&self,
prefix: impl AsRef<str> + std::fmt::Display,
......
......@@ -28,10 +28,12 @@
//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
//!
//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together.
use crate::traits::events::EventPublisher;
use crate::{Result, metrics::MetricsRegistry};
use async_nats::connection::State;
use async_nats::{Subscriber, client, jetstream};
use async_trait::async_trait;
use bytes::Bytes;
use derive_builder::Builder;
use futures::{StreamExt, TryStreamExt};
......@@ -429,6 +431,9 @@ pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
Ok((bucket.to_string(), key.to_string()))
}
/// Default queue name for publishing events
pub const QUEUE_NAME: &str = "queue";
/// A queue implementation using NATS JetStream
pub struct NatsQueue {
/// The name of the stream to use for the queue
......@@ -448,12 +453,32 @@ pub struct NatsQueue {
}
impl NatsQueue {
/// Create a new NatsQueue with the given configuration
/// Create a new NatsQueue with the default "worker-group" consumer
pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
// Sanitize stream name to remove path separators (like in Python version)
let sanitized_stream_name = stream_name.replace(['/', '\\'], "_");
// rupei: are we sure NATs stream name accepts '_'?
let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
let subject = format!("{sanitized_stream_name}.*");
let subject = format!("{}.*", sanitized_stream_name);
Self {
stream_name: sanitized_stream_name,
nats_server,
dequeue_timeout,
client: None,
subject,
subscriber: None,
consumer_name: Some("worker-group".to_string()),
}
}
/// Create a new NatsQueue without a consumer (publisher-only mode)
pub fn new_without_consumer(
stream_name: String,
nats_server: String,
dequeue_timeout: time::Duration,
) -> Self {
let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
let subject = format!("{sanitized_stream_name}.*");
Self {
stream_name: sanitized_stream_name,
......@@ -474,8 +499,8 @@ impl NatsQueue {
dequeue_timeout: time::Duration,
consumer_name: String,
) -> Self {
let sanitized_stream_name = stream_name.replace(['/', '\\'], "_");
let subject = format!("{}.*", sanitized_stream_name);
let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
let subject = format!("{sanitized_stream_name}.*");
Self {
stream_name: sanitized_stream_name,
......@@ -490,39 +515,71 @@ impl NatsQueue {
/// Connect to the NATS server and set up the stream and consumer
pub async fn connect(&mut self) -> Result<()> {
self.connect_with_reset(false).await
}
/// Connect to the NATS server and set up the stream and consumer, optionally resetting the stream
pub async fn connect_with_reset(&mut self, reset_stream: bool) -> Result<()> {
if self.client.is_none() {
// Create a new client
let client_options = Client::builder().server(self.nats_server.clone()).build()?;
let client = client_options.connect().await?;
// Check if stream exists, if not create it
let streams = client.list_streams().await?;
if !streams.contains(&self.stream_name) {
log::debug!("Creating NATS stream {}", self.stream_name);
// If reset_stream is true, delete the stream first
if reset_stream {
match client.jetstream().delete_stream(&self.stream_name).await {
Ok(_) => {
log::debug!(
"Successfully deleted NATS stream {} for reset",
self.stream_name
);
}
Err(e) => {
log::debug!(
"Failed to delete NATS stream '{}' (may not exist): {}",
self.stream_name,
e
);
}
}
}
// Always try to create the stream (removes the race condition)
let stream_config = jetstream::stream::Config {
name: self.stream_name.clone(),
subjects: vec![self.subject.clone()],
max_age: time::Duration::from_secs(60 * 10), // 10 min
// messages older than an hour in the stream will be automatically purged
max_age: time::Duration::from_secs(60 * 60),
..Default::default()
};
client.jetstream().create_stream(stream_config).await?;
match client.jetstream().create_stream(stream_config).await {
Ok(_) => {
log::debug!("Successfully created NATS stream {}", self.stream_name);
}
Err(e) => {
// Log warning but continue - stream likely already exists
log::warn!(
"Failed to create NATS stream '{}': {}. Stream likely already exists, continuing...",
self.stream_name,
e
);
}
}
// Create persistent subscriber
// Create persistent subscriber only if consumer_name is set
if let Some(ref consumer_name) = self.consumer_name {
let consumer_config = jetstream::consumer::pull::Config {
durable_name: Some(
self.consumer_name
.clone()
.unwrap_or_else(|| "worker-group".to_string()),
),
durable_name: Some(consumer_name.clone()),
..Default::default()
};
let stream = client.jetstream().get_stream(&self.stream_name).await?;
let subscriber = stream.create_consumer(consumer_config).await?;
self.subscriber = Some(subscriber);
}
self.client = Some(client);
}
......@@ -546,28 +603,52 @@ impl NatsQueue {
/// Shutdown the consumer by deleting it from the stream and closing the connection
/// This permanently removes the consumer from the server
pub async fn shutdown(&mut self) -> Result<()> {
if let (Some(client), Some(consumer_name)) = (&self.client, &self.consumer_name) {
///
/// If `consumer_name` is provided, that specific consumer will be deleted instead of the
/// current consumer. This allows deletion of other consumers on the same stream.
pub async fn shutdown(&mut self, consumer_name: Option<String>) -> Result<()> {
// Determine which consumer to delete
let target_consumer = consumer_name.as_ref().or(self.consumer_name.as_ref());
// Warn if deleting our own consumer via explicit parameter
if let Some(ref passed_name) = consumer_name
&& self.consumer_name.as_ref() == Some(passed_name)
{
log::warn!(
"Deleting our own consumer '{}' via explicit consumer_name parameter. \
Consider calling shutdown without arguments instead.",
passed_name
);
}
if let (Some(client), Some(consumer_to_delete)) = (&self.client, target_consumer) {
// Get the stream and delete the consumer
let stream = client.jetstream().get_stream(&self.stream_name).await?;
stream.delete_consumer(consumer_name).await.map_err(|e| {
anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_name, e)
stream
.delete_consumer(consumer_to_delete)
.await
.map_err(|e| {
anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_to_delete, e)
})?;
log::debug!(
"Deleted consumer {} from stream {}",
consumer_name,
consumer_to_delete,
self.stream_name
);
} else {
log::warn!(
"Cannot shutdown consumer: client or consumer_name is None (client: {:?}, consumer_name: {:?})",
log::debug!(
"Cannot shutdown consumer: client or target consumer is None (client: {:?}, target_consumer: {:?})",
self.client.is_some(),
self.consumer_name.is_some()
target_consumer.is_some()
);
}
// Then close the connection
// Only close the connection if we deleted our own consumer
if consumer_name.is_none() {
self.close().await
} else {
Ok(())
}
}
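The new `shutdown(consumer_name)` signature combines three decisions: which consumer to delete, whether to warn, and whether to close the connection. The sketch below restates them as a pure function for review purposes. `ShutdownPlan` and `plan_shutdown` are hypothetical names, and the sketch ignores the additional requirement that the client be connected.

```rust
/// What a `shutdown(consumer_name)` call would act on, as read from the diff.
#[derive(Debug, PartialEq)]
struct ShutdownPlan<'a> {
    /// Consumer to delete, if any.
    target: Option<&'a str>,
    /// Whether the connection is closed afterwards.
    close_connection: bool,
    /// Whether the "deleting our own consumer" warning is logged.
    warns_self_delete: bool,
}

/// `own` is the queue's configured consumer; `requested` is the argument
/// passed to `shutdown`.
fn plan_shutdown<'a>(own: Option<&'a str>, requested: Option<&'a str>) -> ShutdownPlan<'a> {
    ShutdownPlan {
        // An explicit argument takes precedence over the queue's own consumer.
        target: requested.or(own),
        // The connection is closed only when no explicit consumer was named.
        close_connection: requested.is_none(),
        // Naming our own consumer explicitly triggers the warning.
        warns_self_delete: requested.is_some() && requested == own,
    }
}
```

One consequence worth noting: passing the queue's own consumer name explicitly deletes that consumer but leaves the connection open, which is presumably why the change logs a warning in that case.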
/// Count the number of consumers for the stream
......@@ -648,6 +729,19 @@ impl NatsQueue {
}
}
/// Get the total number of messages currently in the stream
pub async fn get_stream_messages(&mut self) -> Result<u64> {
self.ensure_connection().await?;
if let Some(client) = &self.client {
let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
let info = stream.info().await?;
Ok(info.state.messages)
} else {
Err(anyhow::anyhow!("Client not connected"))
}
}
/// Purge messages from the stream up to (but not including) the specified sequence number
/// This permanently removes messages and affects all consumers of the stream
pub async fn purge_up_to_sequence(&self, sequence: u64) -> Result<()> {
......@@ -727,7 +821,7 @@ impl NatsQueue {
self.purge_up_to_sequence(purge_sequence).await?;
log::info!(
log::debug!(
"Purged stream {} up to acknowledged sequence {} (purged up to sequence {})",
self.stream_name,
min_ack_sequence,
......@@ -745,6 +839,49 @@ impl NatsQueue {
}
}
#[async_trait]
impl EventPublisher for NatsQueue {
fn subject(&self) -> String {
self.stream_name.clone()
}
async fn publish(
&self,
event_name: impl AsRef<str> + Send + Sync,
event: &(impl Serialize + Send + Sync),
) -> Result<()> {
let bytes = serde_json::to_vec(event)?;
self.publish_bytes(event_name, bytes).await
}
async fn publish_bytes(
&self,
event_name: impl AsRef<str> + Send + Sync,
bytes: Vec<u8>,
) -> Result<()> {
// We expect the stream to be always suffixed with "queue"
// This suffix itself is nothing special, just a repo standard
if event_name.as_ref() != QUEUE_NAME {
tracing::warn!(
"Expected event_name to be '{}', but got '{}'",
QUEUE_NAME,
event_name.as_ref()
);
}
let subject = format!("{}.{}", self.subject(), event_name.as_ref());
// Note: enqueue_task requires &mut self, but EventPublisher requires &self
// We need to ensure the client is connected and use it directly
if let Some(client) = &self.client {
client.jetstream().publish(subject, bytes.into()).await?;
Ok(())
} else {
Err(anyhow::anyhow!("Client not connected"))
}
}
}
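The stream is created with the subject filter `<stream_name>.*`, and `publish_bytes` targets `<stream_name>.<event_name>`. In NATS, `*` matches exactly one subject token, so a single-token event name such as `queue` lands in the stream while a dotted event name would not. The matcher below is a minimal sketch of that rule; `subject_matches` and `publish_subject` are hypothetical helpers, and only the `*` wildcard is handled.

```rust
/// Minimal NATS-style subject matcher: tokens are separated by '.', and a `*`
/// in the filter matches exactly one token. The `>` wildcard is not handled.
fn subject_matches(filter: &str, subject: &str) -> bool {
    let mut filter_tokens = filter.split('.');
    let mut subject_tokens = subject.split('.');
    loop {
        match (filter_tokens.next(), subject_tokens.next()) {
            // Wildcard consumes one token of any value.
            (Some("*"), Some(_)) => {}
            // Literal tokens must be identical.
            (Some(f), Some(s)) if f == s => {}
            // Both exhausted at the same time: full match.
            (None, None) => return true,
            // Length mismatch or differing literal token.
            _ => return false,
        }
    }
}

/// Subject a publish call targets: `<stream_name>.<event_name>`.
fn publish_subject(stream_name: &str, event_name: &str) -> String {
    format!("{stream_name}.{event_name}")
}
```

For a stream named `kv-events` (an illustrative name), the filter `kv-events.*` matches `publish_subject("kv-events", "queue")` but not `kv-events.queue.extra`.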
/// Prometheus metrics that mirror the NATS client statistics (in primitive types)
/// to be used for the System Status Server.
///
......@@ -966,6 +1103,20 @@ mod tests {
let nats_server = "nats://localhost:4222".to_string();
let timeout = time::Duration::from_secs(0);
// Connect to NATS client first to delete stream if it exists
let client_options = Client::builder()
.server(nats_server.clone())
.build()
.expect("Failed to build client options");
let client = client_options
.connect()
.await
.expect("Failed to connect to NATS");
// Delete the stream if it exists (to ensure clean start)
let _ = client.jetstream().delete_stream(&stream_name).await;
// Create two consumers with different names for the same stream
let consumer1_name = format!("consumer-{}", Uuid::new_v4());
let consumer2_name = format!("consumer-{}", Uuid::new_v4());
......@@ -977,46 +1128,49 @@ mod tests {
consumer1_name,
);
let mut queue2 = NatsQueue::new_with_consumer(
stream_name.clone(),
nats_server.clone(),
timeout,
consumer2_name,
);
// Connect both queues (first one creates the stream, second one reuses it)
// Connect queue1 first (it will create the stream)
queue1.connect().await.expect("Failed to connect queue1");
queue2.connect().await.expect("Failed to connect queue2");
// Send 4 messages
let messages = vec![
Bytes::from("message1"),
Bytes::from("message2"),
Bytes::from("message3"),
Bytes::from("message4"),
// Send 4 messages using the EventPublisher trait
let message_strings = [
"message1".to_string(),
"message2".to_string(),
"message3".to_string(),
"message4".to_string(),
];
for msg in &messages {
// Using the EventPublisher trait to publish messages
for (idx, msg) in message_strings.iter().enumerate() {
queue1
.enqueue_task(msg.clone())
.publish("queue", msg)
.await
.expect("Failed to enqueue message");
.unwrap_or_else(|_| panic!("Failed to publish message {}", idx + 1));
}
// Convert messages to JSON-serialized Bytes for comparison
let messages: Vec<Bytes> = message_strings
.iter()
.map(|s| Bytes::from(serde_json::to_vec(s).unwrap()))
.collect();
// Give JetStream a moment to persist the messages
tokio::time::sleep(time::Duration::from_millis(100)).await;
// Get stream info to find the sequence numbers
// We need to know the sequence of message 2 to purge up to it
let client_options = Client::builder()
.server(nats_server.clone())
.build()
.expect("Failed to build client options");
// Now create and connect queue2 and queue3 AFTER messages are published (to test persistence)
let mut queue2 = NatsQueue::new_with_consumer(
stream_name.clone(),
nats_server.clone(),
timeout,
consumer2_name,
);
let client = client_options
.connect()
.await
.expect("Failed to connect to NATS");
// Create a third queue without consumer (publisher-only)
let mut queue3 =
NatsQueue::new_without_consumer(stream_name.clone(), nats_server.clone(), timeout);
// Connect queue2 and queue3 after messages are already published
queue2.connect().await.expect("Failed to connect queue2");
queue3.connect().await.expect("Failed to connect queue3");
// Purge the first two messages (sequence 1 and 2)
// Note: JetStream sequences start at 1, and purge is exclusive of the sequence number
......@@ -1127,7 +1281,10 @@ mod tests {
queue1.connect().await.expect("Failed to reconnect queue1");
// Shutdown consumer 1 and verify via consumer 2 that there is only one consumer left
queue1.shutdown().await.expect("Failed to shutdown queue1");
queue1
.shutdown(None)
.await
.expect("Failed to shutdown queue1");
let consumer_count = queue2
.count_consumers()
......