Unverified Commit 8fc70f20 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: add ZMQ transport and its discovery (#5625)

parent 2dbfed89
......@@ -131,7 +131,7 @@ version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
......@@ -142,7 +142,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
......@@ -703,7 +703,7 @@ dependencies = [
"bitflags 2.10.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.11.0",
"proc-macro2",
"quote",
"regex",
......@@ -721,7 +721,7 @@ dependencies = [
"bitflags 2.10.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.11.0",
"log",
"prettyplease",
"proc-macro2",
......@@ -2132,7 +2132,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -2502,6 +2502,7 @@ dependencies = [
"libc",
"local-ip-address",
"log",
"lru",
"nid",
"nix 0.29.0",
"notify",
......@@ -2526,6 +2527,7 @@ dependencies = [
"temp-env",
"tempfile",
"thiserror 2.0.17",
"tmq",
"tokio",
"tokio-rayon",
"tokio-stream",
......@@ -2539,6 +2541,7 @@ dependencies = [
"uuid 1.18.1",
"validator",
"xxhash-rust",
"zmq",
]
[[package]]
......@@ -2723,7 +2726,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4317,7 +4320,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [
"hermit-abi 0.5.2",
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4380,7 +4383,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"serde_core",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4842,6 +4845,15 @@ dependencies = [
"vob",
]
[[package]]
name = "lru"
version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
dependencies = [
"hashbrown 0.15.5",
]
[[package]]
name = "lru-slab"
version = "0.1.2"
......@@ -5782,7 +5794,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -7881,7 +7893,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -9150,7 +9162,7 @@ dependencies = [
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -10774,7 +10786,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.61.2",
]
[[package]]
......
......@@ -102,6 +102,8 @@ serde_json = { version = "1" }
strum = { version = "0.27", features = ["derive"] }
tempfile = "3"
thiserror = { version = "2.0.17" }
tmq = { version = "0.5.0" }
zmq = { version = "0.10" }
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1" }
tokio-util = { version = "0.7", features = ["codec", "net", "rt"] }
......
......@@ -69,7 +69,7 @@ serde_json = { workspace = true }
strum = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tmq = "0.5.0"
tmq = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }
tokio-util = { workspace = true }
......
......@@ -8,11 +8,10 @@ use dynamo_runtime::component::Client;
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::transports::event_plane::EventSubscriber;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt;
/// Scale factor for storing f64 thresholds as u32 (10000 = 4 decimal places)
const THRESHOLD_SCALE: u32 = 10000;
......@@ -194,8 +193,11 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
card.runtime_config
});
// Subscribe to KV metrics events
let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
// Subscribe to KV metrics events using EventSubscriber (Msgpack payloads)
let mut kv_metrics_rx =
EventSubscriber::for_namespace(component.namespace(), KV_METRICS_SUBJECT)
.await?
.typed::<ActiveLoad>();
let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone();
......@@ -235,12 +237,13 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
// Handle KV metrics updates (ActiveLoad)
kv_event = kv_metrics_rx.next() => {
let Some(event) = kv_event else {
let Some(event_result) = kv_event else {
tracing::debug!("KV metrics stream closed");
break;
};
let Ok(active_load) = serde_json::from_slice::<ActiveLoad>(&event.payload) else {
let Ok((_envelope, active_load)) = event_result else {
tracing::error!("Error receiving KV metrics event: {event_result:?}");
continue;
};
......
......@@ -9,7 +9,7 @@ use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::{
component::{Client, Endpoint},
discovery::{DiscoveryQuery, watch_and_extract_field},
discovery::{DiscoveryQuery, EventTransportKind, watch_and_extract_field},
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait,
......@@ -50,7 +50,7 @@ use crate::{
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError,
subscriber::{start_kv_router_background, start_kv_router_background_nats_core},
subscriber::{start_kv_router_background, start_kv_router_background_event_plane},
},
local_model::runtime_config::ModelRuntimeConfig,
model_card::ModelDeploymentCard,
......@@ -420,13 +420,25 @@ impl KvRouter {
tracing::info!("Found {count} worker(s), starting KV event subscriber");
let transport_kind = EventTransportKind::from_env_or_default();
// Start subscriber - setup runs synchronously, then spawns background loop internally
if all_local_indexer {
if transport_kind == EventTransportKind::Zmq {
if kv_router_config.router_snapshot_threshold.is_some()
|| kv_router_config.router_reset_states
{
tracing::warn!(
"ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings"
);
}
} else {
tracing::info!(
"All {count} workers have local_indexer enabled, using NATS Core subscription"
);
}
start_kv_router_background_nats_core(
start_kv_router_background_event_plane(
component.clone(),
kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(),
......@@ -435,9 +447,15 @@ impl KvRouter {
component.clone(),
runtime_configs_rx.clone(),
),
transport_kind,
)
.await?;
} else {
if transport_kind == EventTransportKind::Zmq {
tracing::warn!(
"Not all workers have local_indexer enabled; falling back to JetStream for durability"
);
}
tracing::info!(
"Not all workers have local_indexer enabled, using JetStream subscription"
);
......@@ -596,7 +614,7 @@ impl KvRouter {
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self
.kv_router_config
......
......@@ -7,6 +7,7 @@ use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use rmp_serde as rmps;
use serde::Deserialize;
use serde::Serialize;
......@@ -15,12 +16,30 @@ use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher};
use dynamo_runtime::traits::{
DistributedRuntimeProvider, events::EventPublisher as EventPublisherTrait,
};
use dynamo_runtime::transports::event_plane::EventPublisher;
use dynamo_runtime::{
component::{Component, Namespace},
transports::nats::{NatsQueue, Slug},
};
/// Builds the KV stream name for `subject` under the given component.
///
/// The dotted path `namespace.{namespace}.component.{component}.{subject}` is run
/// through [`Slug::slugify`], and every `_` in the resulting slug is then replaced
/// with `-`. The intended shape is
/// `namespace-{namespace}-component-{component}-{subject}` (the exact separators
/// depend on what `Slug::slugify` emits).
fn create_kv_stream_name(component: &Component, subject: &str) -> String {
    let namespace_name = component.namespace().name();
    let component_name = component.name();
    let dotted_path = format!(
        "namespace.{}.component.{}.{}",
        namespace_name, component_name, subject
    );
    let slug = Slug::slugify(&dotted_path).to_string();
    slug.replace("_", "-")
}
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE,
indexer::{KvIndexerMetrics, LocalKvIndexer, RouterEvent},
......@@ -176,19 +195,27 @@ impl KvEventPublisher {
))
});
// Connect the NatsQueue before passing it to the event processor
let cancellation_token_clone = cancellation_token.clone();
let local_indexer_clone = local_indexer.clone();
if enable_local_indexer {
// When local indexer is enabled, use NATS Core (Component) for publishing.
// This is simpler and doesn't require JetStream durability since recovery
// is handled via the local indexer's event buffer.
tracing::info!("Using NATS Core for KV event publishing (local_indexer mode)");
// When local indexer is enabled, use the event plane directly.
// EventPublisher handles transport selection (ZMQ or NATS) based on environment.
// Durability is provided by the local indexer's event buffer.
tracing::info!("Using event plane for KV event publishing (local_indexer mode)");
let component_clone = component.clone();
component.drt().runtime().secondary().spawn(async move {
let event_publisher =
match EventPublisher::for_component(&component_clone, KV_EVENT_SUBJECT).await {
Ok(publisher) => publisher,
Err(e) => {
tracing::error!("Failed to create event publisher: {}", e);
return;
}
};
start_event_processor(
component_clone,
event_publisher,
worker_id,
cancellation_token_clone,
rx,
......@@ -198,10 +225,7 @@ impl KvEventPublisher {
});
} else {
// When local indexer is disabled, use JetStream (NatsQueue) for durability.
let stream_name =
Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
.to_string()
.replace("_", "-");
let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT);
let nats_server = std::env::var(env_nats::NATS_SERVER)
.unwrap_or_else(|_| "nats://localhost:4222".to_string());
let mut nats_queue = NatsQueue::new_without_consumer(
......@@ -215,7 +239,7 @@ impl KvEventPublisher {
tracing::error!("Failed to connect NatsQueue: {e}");
return;
}
start_event_processor(
start_event_processor_jetstream(
nats_queue,
worker_id,
cancellation_token_clone,
......@@ -259,7 +283,27 @@ impl Drop for KvEventPublisher {
}
}
async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
/// Minimal publishing interface consumed by the KV event processors in this file.
///
/// An implementor publishes one [`RouterEvent`] at a time. The caller does not pass
/// a subject: each implementor decides where the event goes.
#[async_trait]
trait EventSink: Send + Sync {
/// Publishes a single router event; returns an error if the underlying publish fails.
async fn publish_event(&self, event: &RouterEvent) -> Result<()>;
}
/// Event-plane implementation of [`EventSink`].
///
/// No subject is supplied here; the publisher is built elsewhere together with its
/// subject (e.g. `EventPublisher::for_component(.., KV_EVENT_SUBJECT)`).
#[async_trait]
impl EventSink for EventPublisher {
async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
// Straight delegation to the publisher's own `publish`.
self.publish(event).await
}
}
/// JetStream-backed implementation of [`EventSink`].
///
/// Every event is published on the fixed `KV_EVENT_SUBJECT`.
#[async_trait]
impl EventSink for NatsQueue {
async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
// NOTE(review): `publish` here presumably resolves to the `EventPublisherTrait`
// method imported at the top of the file — confirm `NatsQueue` has no inherent
// `publish` with a different meaning.
self.publish(KV_EVENT_SUBJECT, event).await
}
}
/// Event processor for ephemeral transports (NATS Core / ZMQ).
async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
......@@ -294,11 +338,55 @@ async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
}
}
// Then publish to NATS for global distribution
// Use KV_EVENT_SUBJECT so both JetStream and NATS Core subscribers
// can receive events on the expected subject.
if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await {
tracing::error!("Failed to publish event to NATS: {}", e);
// Then publish to event plane for global distribution.
if let Err(e) = publisher.publish_event(&router_event).await {
tracing::error!("Failed to publish event: {}", e);
}
}
}
}
}
/// Event processor using JetStream (durable).
async fn start_event_processor_jetstream(
publisher: NatsQueue,
worker_id: u64,
cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
) {
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::info!("KV Event source received cancellation signal");
break;
}
event = rx.recv() => {
let Some(event) = event else {
tracing::debug!("Event processor channel closed.");
break;
};
// Encapsulate in a router event.
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
let router_event = RouterEvent::new(worker_id, event);
// Apply to local indexer first (if present)
if let Some(indexer) = &local_indexer {
// Adds event into local indexer, and logs it into internal buffer
if let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await {
tracing::warn!(
"Failed to send event to local indexer for worker {}: {}",
worker_id,
e
);
}
}
// Then publish to event plane for global distribution
if let Err(e) = publisher.publish_event(&router_event).await {
tracing::error!("Failed to publish event to event plane: {}", e);
}
}
......@@ -886,6 +974,15 @@ impl WorkerMetricsPublisher {
let nats_rx = self.rx.clone();
tokio::spawn(async move {
let event_publisher =
match EventPublisher::for_namespace(&namespace, KV_METRICS_SUBJECT).await {
Ok(publisher) => publisher,
Err(e) => {
tracing::error!("Failed to create metrics publisher: {}", e);
return;
}
};
let mut rx = nats_rx;
let mut last_active_decode_blocks: Option<u64> = Some(0);
let mut pending_publish: Option<WorkerMetrics> = None;
......@@ -933,10 +1030,8 @@ impl WorkerMetricsPublisher {
active_prefill_tokens: None,
};
if let Err(e) =
namespace.publish(KV_METRICS_SUBJECT, &active_load).await
{
tracing::warn!("Failed to publish metrics over NATS: {}", e);
if let Err(e) = event_publisher.publish(&active_load).await {
tracing::warn!("Failed to publish metrics: {}", e);
}
}
......@@ -1076,7 +1171,6 @@ mod tests_startup_helpers {
use crate::kv_router::KvIndexer;
use crate::kv_router::indexer::KvIndexerInterface;
use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use async_trait;
use bytes::Bytes;
use std::sync::{Arc, Mutex};
use zeromq::{PubSocket, Socket, SocketSend, ZmqMessage};
......@@ -1105,35 +1199,15 @@ mod tests_startup_helpers {
}
#[async_trait::async_trait]
impl EventPublisher for MockComponent {
async fn publish(
&self,
event_name: impl AsRef<str> + Send + Sync,
event: &(impl serde::Serialize + Send + Sync),
) -> anyhow::Result<()> {
impl EventSink for MockComponent {
async fn publish_event(&self, event: &RouterEvent) -> anyhow::Result<()> {
let bytes = rmp_serde::to_vec(event).unwrap();
self.published
.lock()
.unwrap()
.push((event_name.as_ref().to_string(), bytes));
Ok(())
}
async fn publish_bytes(
&self,
event_name: impl AsRef<str> + Send + Sync,
bytes: Vec<u8>,
) -> anyhow::Result<()> {
self.published
.lock()
.unwrap()
.push((event_name.as_ref().to_string(), bytes));
.push((KV_EVENT_SUBJECT.to_string(), bytes));
Ok(())
}
fn subject(&self) -> String {
"mock.subject".into()
}
}
//--------------------------------------------------------------------
......@@ -1330,7 +1404,7 @@ mod tests_startup_helpers {
}
assert!(no_blocks, "worker should have no blocks after removal");
// Global kvindexer should have received two events (create/remove)
// Global kvindexer should have received two events (create/remove)
let published = published.lock().unwrap();
assert_eq!(
published.len(),
......@@ -1409,7 +1483,7 @@ mod tests_startup_helpers {
}
assert!(no_blocks, "worker should have no blocks after clearing");
// Global kvindexer should have received two events (create/remove)
// Global kvindexer should have received two events (create/remove)
let published = published.lock().unwrap();
assert_eq!(
published.len(),
......@@ -1815,7 +1889,7 @@ mod test_integration_publisher {
use super::*;
use crate::kv_router::protocols::ActiveLoad;
use dynamo_runtime::distributed_test_utils::create_test_drt_async;
use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::transports::event_plane::EventSubscriber;
use futures::StreamExt;
#[tokio::test]
......@@ -1825,11 +1899,11 @@ mod test_integration_publisher {
let drt = create_test_drt_async().await;
let namespace = drt.namespace("ns2001".to_string())?;
// Create a subscriber for the metrics events using subscribe_with_type
let mut subscriber = namespace
.subscribe_with_type::<ActiveLoad>(KV_METRICS_SUBJECT)
// Create a subscriber for the metrics events
let mut subscriber = EventSubscriber::for_namespace(&namespace, KV_METRICS_SUBJECT)
.await
.unwrap();
.unwrap()
.typed::<ActiveLoad>();
// Create WorkerMetricsPublisher
let publisher = WorkerMetricsPublisher::new().unwrap();
......@@ -1857,7 +1931,7 @@ mod test_integration_publisher {
.await
.unwrap();
let event = result.unwrap().unwrap(); // Unwrap the Option and the Result
let (_envelope, event) = result.unwrap().unwrap(); // Unwrap the Option and the Result
assert_eq!(event.worker_id, worker_id);
assert_eq!(event.active_decode_blocks, Some(900)); // Last value: 9 * 100
assert_eq!(event.active_prefill_tokens, None); // Worker doesn't publish prefill tokens
......
......@@ -6,7 +6,7 @@ use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use dynamo_runtime::transports::event_plane::EventPublisher;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
......@@ -50,6 +50,9 @@ pub enum KvSchedulerError {
#[error("endpoint subscriber shutdown")]
SubscriberShutdown,
#[error("failed to initialize event publisher: {0}")]
InitFailed(String),
}
#[derive(Debug)]
......@@ -111,13 +114,17 @@ impl KvScheduler {
.map(|r| (*r.key(), r.value().clone()))
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new(
let slots = Arc::new(
ActiveSequencesMultiWorker::new(
component.clone(),
block_size as usize,
initial_workers,
replica_sync,
router_id,
));
)
.await
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?,
);
// Spawn background task to sync slots with DashMap when notified of changes.
// ModelManager's watcher updates the DashMap and notifies; we wait on notify here.
......@@ -160,7 +167,10 @@ impl KvScheduler {
let workers_scheduler = workers_with_configs.clone();
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token();
let ns_clone = component.namespace().clone();
let hit_rate_publisher =
EventPublisher::for_namespace(component.namespace(), KV_HIT_RATE_SUBJECT)
.await
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
// Background task to handle scheduling requests
tokio::spawn(async move {
......@@ -206,7 +216,7 @@ impl KvScheduler {
isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks,
};
if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await {
if let Err(e) = hit_rate_publisher.publish(&event).await {
tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
}
......
......@@ -29,8 +29,7 @@ use dashmap::DashMap;
use derive_getters::Getters;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber};
use std::collections::{HashMap, HashSet};
use std::rc::{Rc, Weak};
use std::sync::Arc;
......@@ -410,17 +409,21 @@ pub struct ActiveSequencesMultiWorker {
block_size: usize,
component: Component,
router_id: u64,
/// Publisher for sequence events
event_publisher: EventPublisher,
/// Publisher for metrics (namespace-scoped)
metrics_publisher: EventPublisher,
replica_sync: bool,
}
impl ActiveSequencesMultiWorker {
pub fn new(
pub async fn new(
component: Component,
block_size: usize,
workers_with_configs: HashMap<u64, Option<ModelRuntimeConfig>>,
replica_sync: bool,
router_id: u64,
) -> Self {
) -> Result<Self> {
assert!(block_size > 1, "block_size must be greater than 1");
let senders = Arc::new(DashMap::new());
......@@ -441,12 +444,19 @@ impl ActiveSequencesMultiWorker {
}
}
let event_publisher =
EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?;
let metrics_publisher =
EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?;
let multi_worker = Self {
senders: senders.clone(),
request_to_worker: request_to_worker.clone(),
handles,
block_size,
component: component.clone(),
event_publisher,
metrics_publisher,
router_id,
replica_sync,
};
......@@ -475,7 +485,7 @@ impl ActiveSequencesMultiWorker {
});
}
multi_worker
Ok(multi_worker)
}
/// Helper method to start a worker task
......@@ -597,9 +607,9 @@ impl ActiveSequencesMultiWorker {
router_id: u64,
cancel_token: CancellationToken,
) -> Result<()> {
let mut subscriber = component
.subscribe_with_type::<ActiveSequenceEvent>(ACTIVE_SEQUENCES_SUBJECT)
.await?;
let mut subscriber = EventSubscriber::for_component(&component, ACTIVE_SEQUENCES_SUBJECT)
.await?
.typed::<ActiveSequenceEvent>();
loop {
tokio::select! {
......@@ -610,7 +620,7 @@ impl ActiveSequencesMultiWorker {
break;
};
let Ok(event) = result else {
let Ok((_envelope, event)) = result else {
tracing::error!(
"Error receiving active sequence event: {}",
result.unwrap_err()
......@@ -770,9 +780,7 @@ impl ActiveSequencesMultiWorker {
},
router_id: self.router_id,
};
self.component
.publish(ACTIVE_SEQUENCES_SUBJECT, &event)
.await?;
self.event_publisher.publish(&event).await?;
}
// Update local state with full WorkerWithDpRank
......@@ -831,9 +839,7 @@ impl ActiveSequencesMultiWorker {
data: ActiveSequenceEventData::Free,
router_id: self.router_id,
};
self.component
.publish(ACTIVE_SEQUENCES_SUBJECT, &event)
.await?;
self.event_publisher.publish(&event).await?;
}
// Update local state
......@@ -882,9 +888,7 @@ impl ActiveSequencesMultiWorker {
data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: self.router_id,
};
self.component
.publish(ACTIVE_SEQUENCES_SUBJECT, &event)
.await?;
self.event_publisher.publish(&event).await?;
}
// Update local state
......@@ -999,12 +1003,7 @@ impl ActiveSequencesMultiWorker {
active_prefill_tokens: Some(active_tokens as u64),
};
if let Err(e) = self
.component
.namespace()
.publish(KV_METRICS_SUBJECT, &active_load)
.await
{
if let Err(e) = self.metrics_publisher.publish(&active_load).await {
tracing::warn!("Failed to publish ActiveLoad for worker {worker:?}: {e:?}");
}
}
......@@ -1236,20 +1235,20 @@ mod tests {
let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
workers_with_configs.insert(1, Some(config_worker_1));
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
let seq_manager_1 = Arc::new(
ActiveSequencesMultiWorker::new(
component.clone(),
block_size,
workers_with_configs.clone(),
true,
1,
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
workers_with_configs,
true,
2,
));
)
.await?,
);
let seq_manager_2 = Arc::new(
ActiveSequencesMultiWorker::new(component, block_size, workers_with_configs, true, 2)
.await?,
);
// Give some time for the subscription loops to start
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
......@@ -1395,20 +1394,20 @@ mod tests {
workers_with_configs.insert(1, None);
workers_with_configs.insert(2, None);
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
let seq_manager_1 = Arc::new(
ActiveSequencesMultiWorker::new(
component.clone(),
block_size,
workers_with_configs.clone(),
true,
1,
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
workers_with_configs,
true,
2,
));
)
.await?,
);
let seq_manager_2 = Arc::new(
ActiveSequencesMultiWorker::new(component, block_size, workers_with_configs, true, 2)
.await?,
);
// Give some time for the subscription loops to start
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
......
......@@ -7,9 +7,9 @@ use anyhow::Result;
use dynamo_runtime::{
component::Component,
config::environment_names::nats as env_nats,
discovery::{DiscoveryEvent, DiscoveryQuery},
discovery::{DiscoveryEvent, DiscoveryQuery, EventTransportKind},
prelude::*,
traits::events::{EventPublisher, EventSubscriber},
transports::event_plane::EventSubscriber,
transports::nats::{NatsQueue, Slug},
};
use futures::StreamExt;
......@@ -25,6 +25,21 @@ use crate::kv_router::{
worker_query::WorkerQueryClient,
};
/// Derives the KV stream name for `subject` scoped to `component`.
///
/// Formats `namespace.{namespace}.component.{component}.{subject}`, slugifies it via
/// [`Slug::slugify`], and finally swaps every `_` in the slug for `-`. The intended
/// result looks like `namespace-{namespace}-component-{component}-{subject}` (exact
/// separators depend on `Slug::slugify`).
fn create_kv_stream_name(component: &Component, subject: &str) -> String {
    let namespace_name = component.namespace().name();
    let component_name = component.name();
    let dotted_path = format!(
        "namespace.{}.component.{}.{}",
        namespace_name, component_name, subject
    );
    let slug = Slug::slugify(&dotted_path).to_string();
    slug.replace("_", "-")
}
/// Delay between snapshot reads to verify stability
const SNAPSHOT_STABILITY_DELAY: Duration = Duration::from_millis(100);
const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
......@@ -463,9 +478,7 @@ pub async fn start_kv_router_background(
router_reset_states: bool,
) -> Result<()> {
// Set up NATS connections
let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
.to_string()
.replace("_", "-");
let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT);
let nats_server = std::env::var(env_nats::NATS_SERVER)
.unwrap_or_else(|_| "nats://localhost:4222".to_string());
......@@ -485,7 +498,12 @@ pub async fn start_kv_router_background(
let nats_client = client_options.connect().await?;
// Create bucket name for snapshots/state
let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", component.subject()))
let event_plane_subject = format!(
"namespace.{}.component.{}",
component.namespace().name(),
component.name()
);
let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", event_plane_subject))
.to_string()
.replace("_", "-");
......@@ -726,11 +744,11 @@ async fn handle_worker_discovery(
}
}
/// Start a simplified background task for event consumption using NATS Core.
/// Start a simplified background task for event consumption using the event plane.
///
/// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`,
/// this function:
/// - Uses NATS Core pub/sub instead of JetStream
/// - Uses the event plane (NATS Core or ZMQ) instead of JetStream
/// - Does not support snapshots, purging, or durable consumers
/// - On worker Added: dumps worker's local indexer into router
/// - On worker Removed: removes worker from router indexer
......@@ -739,21 +757,40 @@ async fn handle_worker_discovery(
/// spawning the background task, ensuring the router is ready before returning.
///
/// This is appropriate when workers have local indexers enabled.
pub async fn start_kv_router_background_nats_core(
pub async fn start_kv_router_background_event_plane(
component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken,
worker_query_client: WorkerQueryClient,
transport_kind: EventTransportKind,
) -> Result<()> {
// Subscribe to KV events using NATS Core
let mut subscriber = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_event_subject = format!("{}.{}", component.subject(), KV_EVENT_SUBJECT);
// Subscribe to KV events using the selected event plane transport
let mut subscriber =
EventSubscriber::for_component_with_transport(&component, KV_EVENT_SUBJECT, transport_kind)
.await?
.typed::<RouterEvent>();
let kv_event_subject = format!(
"namespace.{}.component.{}.{}",
component.namespace().name(),
component.name(),
KV_EVENT_SUBJECT
);
match transport_kind {
EventTransportKind::Nats => {
tracing::info!(
subject = %kv_event_subject,
"KV Router using NATS Core subscription (local_indexer mode)"
);
}
EventTransportKind::Zmq => {
tracing::info!(
subject = %kv_event_subject,
"KV Router using ZMQ event plane subscription (local_indexer mode)"
);
}
}
// Wait for at least one worker instance before proceeding
let mut instance_event_stream =
......@@ -802,7 +839,7 @@ pub async fn start_kv_router_background_nats_core(
biased;
_ = cancellation_token.cancelled() => {
tracing::debug!("KV Router NATS Core background task received cancellation signal");
tracing::debug!("KV Router event plane background task received cancellation signal");
break;
}
......@@ -821,12 +858,12 @@ pub async fn start_kv_router_background_nats_core(
.await;
}
// Handle event consumption from NATS Core subscription
Some(msg) = subscriber.next() => {
let event: RouterEvent = match serde_json::from_slice(&msg.payload) {
Ok(event) => event,
// Handle event consumption from event plane subscription
Some(result) = subscriber.next() => {
let (envelope, event) = match result {
Ok((envelope, event)) => (envelope, event),
Err(e) => {
tracing::warn!("Failed to deserialize RouterEvent from NATS Core: {e:?}");
tracing::warn!("Failed to receive RouterEvent from event plane: {e:?}");
continue;
}
};
......@@ -834,6 +871,13 @@ pub async fn start_kv_router_background_nats_core(
let worker_id = event.worker_id;
let event_id = event.event.event_id;
// Use envelope metadata for additional debugging
tracing::trace!(
"Received event from publisher {} (seq {})",
envelope.publisher_id,
envelope.sequence
);
// Gap detection: check if event ID is monotonically increasing per worker
// Note: event_id <= last_id is duplicate/out-of-order, apply anyway (idempotent)
if let Some(&last_id) = last_event_ids.get(&worker_id)
......@@ -847,7 +891,7 @@ pub async fn start_kv_router_background_nats_core(
"Event ID gap detected for worker {worker_id}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
);
// Note: While recovering, new events may queue in the NATS subscriber's
// Note: While recovering, new events may queue in the subscriber's
// internal buffer. We don't explicitly buffer them here for simplicity.
// The subscriber will process them in order after recovery completes.
if let Err(e) = recover_from_worker(
......@@ -884,12 +928,31 @@ pub async fn start_kv_router_background_nats_core(
}
}
tracing::debug!("KV Router NATS Core background task exiting");
tracing::debug!("KV Router event plane background task exiting");
});
Ok(())
}
/// Backwards-compatible entry point for the NATS Core local-indexer mode.
///
/// Delegates to [`start_kv_router_background_event_plane`] with the transport pinned
/// to [`EventTransportKind::Nats`]; every other argument is forwarded unchanged and
/// the callee's result is returned as-is.
pub async fn start_kv_router_background_nats_core(
    component: Component,
    kv_events_tx: mpsc::Sender<RouterEvent>,
    remove_worker_tx: mpsc::Sender<WorkerId>,
    cancellation_token: CancellationToken,
    worker_query_client: WorkerQueryClient,
) -> Result<()> {
    // The legacy name always selects the NATS transport.
    let transport_kind = EventTransportKind::Nats;
    let startup = start_kv_router_background_event_plane(
        component,
        kv_events_tx,
        remove_worker_tx,
        cancellation_token,
        worker_query_client,
        transport_kind,
    );
    startup.await
}
/// Cleanup orphaned NATS consumers that no longer have corresponding router entries
async fn cleanup_orphaned_consumers(
nats_queue: &mut NatsQueue,
......
......@@ -27,6 +27,9 @@ async-nats = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
async_zmq = { workspace = true }
tmq = { workspace = true }
zmq = { workspace = true }
lru = { version = "0.12" }
axum = { workspace = true }
blake3 = { workspace = true }
bytes = { workspace = true }
......
......@@ -313,6 +313,26 @@ pub mod event_plane {
pub const DYN_EVENT_PLANE_CODEC: &str = "DYN_EVENT_PLANE_CODEC";
}
/// ZMQ Broker environment variables
///
/// Every constant holds the *name* of an environment variable (identical to
/// the constant's own identifier); reading and interpreting the value is left
/// to the code that consumes it.
pub mod zmq_broker {
    /// Explicit ZMQ broker URL (takes precedence over discovery)
    /// Format: "xsub=<url1>[;<url2>...] , xpub=<url1>[;<url2>...]"
    /// Example: "xsub=tcp://broker:5555 , xpub=tcp://broker:5556"
    pub const DYN_ZMQ_BROKER_URL: &str = "DYN_ZMQ_BROKER_URL";

    /// Enable ZMQ broker discovery mode
    pub const DYN_ZMQ_BROKER_ENABLED: &str = "DYN_ZMQ_BROKER_ENABLED";

    /// XSUB bind address (broker binary only)
    pub const ZMQ_BROKER_XSUB_BIND: &str = "ZMQ_BROKER_XSUB_BIND";

    /// XPUB bind address (broker binary only)
    pub const ZMQ_BROKER_XPUB_BIND: &str = "ZMQ_BROKER_XPUB_BIND";

    /// Namespace for broker discovery registration
    pub const ZMQ_BROKER_NAMESPACE: &str = "ZMQ_BROKER_NAMESPACE";
}
/// CUDA and GPU environment variables
pub mod cuda {
/// Path to custom CUDA fatbin file
......@@ -419,6 +439,12 @@ mod tests {
// Event Plane
event_plane::DYN_EVENT_PLANE,
event_plane::DYN_EVENT_PLANE_CODEC,
// ZMQ Broker
zmq_broker::DYN_ZMQ_BROKER_URL,
zmq_broker::DYN_ZMQ_BROKER_ENABLED,
zmq_broker::ZMQ_BROKER_XSUB_BIND,
zmq_broker::ZMQ_BROKER_XPUB_BIND,
zmq_broker::ZMQ_BROKER_NAMESPACE,
// CUDA
cuda::DYNAMO_FATBIN_PATH,
// Build
......
......@@ -133,11 +133,18 @@ pub enum EventTransport {
/// Subject prefix (e.g., "namespace.dynamo.component.backend")
subject_prefix: String,
},
/// ZMQ pub/sub - endpoint address
/// ZMQ pub/sub - endpoint address (direct mode)
Zmq {
/// ZMQ endpoint (e.g., "tcp://host:port")
endpoint: String,
},
/// ZMQ broker endpoints (broker mode) - for discovery of brokers
ZmqBroker {
/// XSUB endpoints (publishers connect here)
xsub_endpoints: Vec<String>,
/// XPUB endpoints (subscribers connect here)
xpub_endpoints: Vec<String>,
},
}
impl EventTransport {
......@@ -145,7 +152,7 @@ impl EventTransport {
pub fn kind(&self) -> EventTransportKind {
    // Direct-mode (`Zmq`) and broker-mode (`ZmqBroker`) transports both report
    // the ZMQ kind, so they share one arm. A separate `Self::Zmq` arm placed
    // ahead of the combined arm would make its `Self::Zmq` alternative
    // unreachable, so only the combined arm is kept.
    match self {
        Self::Nats { .. } => EventTransportKind::Nats,
        Self::Zmq { .. } | Self::ZmqBroker { .. } => EventTransportKind::Zmq,
    }
}
......@@ -164,10 +171,14 @@ impl EventTransport {
}
/// Return the transport's address string.
///
/// * `Nats` - the subject prefix.
/// * `Zmq` - the endpoint.
/// * `ZmqBroker` - the first XSUB endpoint, or `""` when the list is empty.
pub fn address(&self) -> &str {
    match self {
        Self::Nats { subject_prefix } => subject_prefix.as_str(),
        Self::Zmq { endpoint } => endpoint.as_str(),
        Self::ZmqBroker { xsub_endpoints, .. } => {
            xsub_endpoints.first().map_or("", String::as_str)
        }
    }
}
}
......
......@@ -9,6 +9,57 @@ use serde::{Serialize, de::DeserializeOwned};
use super::EventEnvelope;
/// Wire codec for event envelopes and their payloads.
///
/// MessagePack is currently the only supported encoding, for all transports.
#[derive(Debug, Clone, Copy)]
pub enum Codec {
    Msgpack(MsgpackCodec),
}

impl Default for Codec {
    fn default() -> Self {
        Self::Msgpack(MsgpackCodec)
    }
}

// Each method delegates to the wrapped codec. The irrefutable `let` stops
// compiling as soon as a second variant is added, which forces every method to
// be revisited at that point.
impl Codec {
    /// Serialize an `EventEnvelope` into wire bytes.
    pub fn encode_envelope(&self, envelope: &EventEnvelope) -> Result<Bytes> {
        let Self::Msgpack(inner) = self;
        inner.encode_envelope(envelope)
    }

    /// Deserialize wire bytes back into an `EventEnvelope`.
    pub fn decode_envelope(&self, bytes: &Bytes) -> Result<EventEnvelope> {
        let Self::Msgpack(inner) = self;
        inner.decode_envelope(bytes)
    }

    /// Serialize a typed payload into bytes suitable for embedding in an envelope.
    pub fn encode_payload<T: Serialize>(&self, payload: &T) -> Result<Bytes> {
        let Self::Msgpack(inner) = self;
        inner.encode_payload(payload)
    }

    /// Deserialize payload bytes into a typed value.
    pub fn decode_payload<T: DeserializeOwned>(&self, bytes: &Bytes) -> Result<T> {
        let Self::Msgpack(inner) = self;
        inner.decode_payload(bytes)
    }

    /// Name of the active codec, for debugging output.
    pub fn name(&self) -> &'static str {
        let Self::Msgpack(inner) = self;
        inner.name()
    }
}
/// MessagePack codec. A unit struct with no fields, wrapped by `Codec::Msgpack`.
#[derive(Debug, Clone, Copy, Default)]
pub struct MsgpackCodec;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Dynamic subscriber that watches discovery and manages connections to multiple publishers.
//!
//! This module enables automatic discovery and connection to new publishers as they come online,
//! and cleanup of disconnected publishers.
use anyhow::Result;
use bytes::Bytes;
use futures::stream::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tokio_util::sync::CancellationToken;
use super::transport::{EventTransportRx, WireStream};
use super::zmq_transport::ZmqSubTransport;
use crate::discovery::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
EventTransport,
};
/// Manages dynamic subscriptions to multiple publishers.
pub struct DynamicSubscriber {
    /// Discovery backend that is listed and watched for publisher instances.
    discovery: Arc<dyn Discovery>,
    /// Query handed to the discovery watch to select instances.
    query: DiscoveryQuery,
    /// Topic passed to every ZMQ subscription (used for ZMQ-side filtering).
    topic: String,
    /// Cancelled by `cancel()` and on drop; the watch task holds a clone.
    cancel_token: CancellationToken,
}
impl DynamicSubscriber {
/// Build a subscriber for `topic` that will watch `discovery` using `query`.
///
/// Only stores its arguments and creates a fresh cancellation token; no task
/// is spawned until `start_zmq` is called.
pub fn new(discovery: Arc<dyn Discovery>, query: DiscoveryQuery, topic: String) -> Self {
    let cancel_token = CancellationToken::new();
    Self {
        discovery,
        query,
        topic,
        cancel_token,
    }
}
/// Start watching discovery and create a merged stream of events.
///
/// Spawns a background task that lists and watches `self.query` on the
/// discovery backend. Every `Added` instance that advertises a direct-mode ZMQ
/// endpoint gets its own consumer task, which forwards raw message bytes into
/// one shared unbounded channel. A `Removed` event cancels the matching
/// consumer.
///
/// The returned stream yields those bytes and owns a clone of `self`, so the
/// subscriber lives as long as the stream does.
///
/// The watch task waits on the cancellation token and on the discovery stream
/// at the same time. `cancel()` (or dropping the subscriber) therefore stops
/// the watch and cancels all endpoint consumers without having to wait for
/// another discovery event to arrive.
///
/// # Errors
///
/// Currently always returns `Ok`. A failure to start the discovery watch is
/// logged inside the background task, which then exits and thereby ends the
/// returned stream.
pub async fn start_zmq(self: Arc<Self>) -> Result<WireStream> {
    let (event_tx, event_rx) = mpsc::unbounded_channel::<Bytes>();

    // Track active endpoint connections: instance ID -> (endpoint, cancel token)
    let active_endpoints: Arc<RwLock<HashMap<String, (String, CancellationToken)>>> =
        Arc::new(RwLock::new(HashMap::new()));

    // Clone self for the returned stream, which keeps the subscriber alive
    let subscriber_clone = Arc::clone(&self);

    // State moved into the background discovery-watch task
    let discovery = Arc::clone(&self.discovery);
    let query = self.query.clone();
    // Use the actual topic for ZMQ native filtering (avoids decoding irrelevant messages)
    let zmq_topic = self.topic.clone();
    let cancel_token = self.cancel_token.clone();
    let endpoints = Arc::clone(&active_endpoints);

    tokio::spawn(async move {
        tracing::debug!(
            ?query,
            cancel_token_cancelled = cancel_token.is_cancelled(),
            "Attempting to start discovery watch"
        );

        // Don't pass the cancel token to list_and_watch - we'll handle cancellation ourselves
        let mut watch_stream = match discovery.list_and_watch(query.clone(), None).await {
            Ok(stream) => {
                tracing::debug!("Successfully obtained discovery watch stream");
                stream
            }
            Err(e) => {
                tracing::error!(error = %e, "Failed to start discovery watch");
                return;
            }
        };

        tracing::info!(?query, "Started dynamic discovery watch for ZMQ publishers");

        loop {
            // Race the discovery stream against cancellation so that shutdown does
            // not depend on another discovery event showing up. `biased` makes a
            // pending cancellation win over an event that is ready at the same time.
            let event_result = tokio::select! {
                biased;
                _ = cancel_token.cancelled() => {
                    tracing::info!("Dynamic subscriber cancelled, stopping watch");
                    break;
                }
                next_event = watch_stream.next() => match next_event {
                    Some(event_result) => event_result,
                    // Discovery closed the watch stream.
                    None => break,
                },
            };

            tracing::debug!("Received discovery event: {:?}", event_result);

            match event_result {
                Ok(DiscoveryEvent::Added(instance)) => {
                    tracing::info!(instance = ?instance, "Discovery Added event received");
                    let instance_id = instance.instance_id().to_string();

                    // Extract ZMQ endpoint from the instance
                    if let Some(endpoint) = Self::extract_zmq_endpoint(&instance) {
                        let mut endpoints_guard = endpoints.write().await;

                        // Skip if instance already tracked
                        if endpoints_guard.contains_key(&instance_id) {
                            tracing::debug!(endpoint = %endpoint, instance_id = %instance_id, "Already connected to ZMQ publisher");
                            continue;
                        }

                        tracing::info!(endpoint = %endpoint, instance_id = %instance_id, "Connecting to new ZMQ publisher");

                        // Create cancellation token for this endpoint's stream
                        let endpoint_cancel = CancellationToken::new();
                        endpoints_guard.insert(
                            instance_id.clone(),
                            (endpoint.clone(), endpoint_cancel.clone()),
                        );
                        // Release the lock before spawning; the consumer task takes
                        // it again when it cleans up after itself.
                        drop(endpoints_guard);

                        // Spawn task to handle this endpoint's stream
                        let event_tx_clone = event_tx.clone();
                        let zmq_topic_clone = zmq_topic.clone();
                        let endpoint_clone = endpoint.clone();
                        let endpoints_clone = Arc::clone(&endpoints);
                        let instance_id_clone = instance_id.clone();

                        tokio::spawn(async move {
                            if let Err(e) = Self::consume_endpoint_stream(
                                &endpoint_clone,
                                &zmq_topic_clone,
                                event_tx_clone,
                                endpoint_cancel,
                            )
                            .await
                            {
                                tracing::warn!(
                                    endpoint = %endpoint_clone,
                                    error = %e,
                                    "Error consuming ZMQ endpoint stream"
                                );
                            }

                            // Clean up on stream termination
                            endpoints_clone.write().await.remove(&instance_id_clone);
                        });
                    } else {
                        tracing::warn!(
                            instance = ?instance,
                            "Discovery Added event did not contain a ZMQ endpoint"
                        );
                    }
                }
                Ok(DiscoveryEvent::Removed(instance_id)) => {
                    let id_str = instance_id.instance_id().to_string();
                    tracing::info!(
                        instance_id = %id_str,
                        "ZMQ publisher removed from discovery, cancelling endpoint stream"
                    );

                    // Cancel the endpoint's stream via its CancellationToken
                    if let Some((_endpoint, cancel)) = endpoints.write().await.remove(&id_str) {
                        cancel.cancel();
                        tracing::info!(instance_id = %id_str, "Cancelled endpoint stream");
                    } else {
                        tracing::warn!(instance_id = %id_str, "No active endpoint found for removed stream instance");
                    }
                }
                Err(e) => {
                    tracing::error!(error = %e, "Discovery watch error");
                    break;
                }
            }
        }

        // Cancel all active endpoints on shutdown
        let endpoints_guard = endpoints.write().await;
        for (_id, (_endpoint, cancel)) in endpoints_guard.iter() {
            cancel.cancel();
        }

        tracing::info!("Discovery watch stream ended");
    });

    // Return a stream that reads from the merged channel
    let stream = async_stream::stream! {
        // Keep subscriber_clone alive by capturing it in the stream
        let _subscriber = subscriber_clone;
        let mut rx = event_rx;

        while let Some(bytes) = rx.recv().await {
            yield Ok(bytes);
        }
    };

    Ok(Box::pin(stream))
}
/// Return the ZMQ endpoint advertised by `instance`, if it has one.
///
/// Only an `EventChannel` instance whose transport is `EventTransport::Zmq`
/// yields `Some`; anything else yields `None`.
fn extract_zmq_endpoint(instance: &DiscoveryInstance) -> Option<String> {
    match instance {
        DiscoveryInstance::EventChannel {
            transport: EventTransport::Zmq { endpoint },
            ..
        } => Some(endpoint.clone()),
        _ => None,
    }
}
/// Consume events from a single endpoint and forward to the merged channel.
async fn consume_endpoint_stream(
endpoint: &str,
zmq_topic: &str,
event_tx: mpsc::UnboundedSender<Bytes>,
cancel_token: CancellationToken,
) -> Result<()> {
// Connect to the endpoint
let sub_transport = ZmqSubTransport::connect(endpoint, zmq_topic).await?;
let mut stream = sub_transport.subscribe(zmq_topic).await?;
tracing::info!(endpoint = %endpoint, topic = %zmq_topic, "Started consuming ZMQ endpoint stream");
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::info!(endpoint = %endpoint, "Endpoint stream cancelled");
break;
}
event = stream.next() => {
match event {
Some(Ok(bytes)) => {
if event_tx.send(bytes).is_err() {
tracing::warn!(endpoint = %endpoint, "Event channel closed, stopping endpoint stream");
break;
}
}
Some(Err(e)) => {
tracing::error!(
endpoint = %endpoint,
error = %e,
"Error receiving from ZMQ endpoint"
);
break;
}
None => {
tracing::info!(endpoint = %endpoint, "ZMQ endpoint stream ended");
break;
}
}
}
}
}
Ok(())
}
/// Stop watching and disconnect from all endpoints.
///
/// This only signals the subscriber's cancellation token. The background
/// watch task spawned by `start_zmq` holds a clone of that token and performs
/// the actual shutdown when it observes the cancellation.
pub fn cancel(&self) {
    self.cancel_token.cancel();
}
}
// Dropping the subscriber signals the same token as `cancel()`, so clones of
// the token held by background tasks observe the cancellation.
impl Drop for DynamicSubscriber {
    fn drop(&mut self) {
        self.cancel_token.cancel();
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Event Plane: Transport-agnostic pub/sub communication layer.
//! Generic Event Plane for transport-agnostic pub/sub communication.
mod codec;
mod dynamic_subscriber;
mod frame;
mod nats_transport;
mod traits;
mod transport;
pub mod zmq_transport;
pub use codec::MsgpackCodec;
pub use frame::{Frame, FrameError, FrameHeader};
pub use nats_transport::NatsTransport;
pub use codec::{Codec, MsgpackCodec};
pub use dynamic_subscriber::DynamicSubscriber;
pub use frame::{FRAME_HEADER_SIZE, FRAME_VERSION, Frame, FrameError, FrameHeader};
pub use traits::{EventEnvelope, EventStream, TypedEventStream};
pub use transport::{EventTransportRx, EventTransportTx, WireStream};
pub use zmq_transport::{ZmqPubTransport, ZmqSubTransport};
// Re-export transport kind from discovery for convenience
pub use crate::discovery::EventTransportKind;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use anyhow::Result;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use lru::LruCache;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::DistributedRuntime;
use crate::component::{Component, Namespace};
use crate::discovery::{
Discovery, DiscoveryInstance, DiscoveryQuery, DiscoverySpec, EventChannelQuery, EventTransport,
};
use crate::traits::DistributedRuntimeProvider;
use crate::utils::ip_resolver::get_local_ip_for_advertise;
/// Scope of an event-plane channel; it fixes the subject prefix used for pub/sub.
#[derive(Debug, Clone)]
pub enum EventScope {
    /// Namespace-wide scope, prefix `namespace.{name}`.
    Namespace { name: String },
    /// Single-component scope, prefix `namespace.{namespace}.component.{component}`.
    Component {
        namespace: String,
        component: String,
    },
}

impl EventScope {
    /// Build the subject prefix for this scope.
    pub fn subject_prefix(&self) -> String {
        match self {
            Self::Namespace { name } => format!("namespace.{name}"),
            Self::Component {
                namespace,
                component,
            } => format!("namespace.{namespace}.component.{component}"),
        }
    }

    /// Namespace name; present for both scope kinds.
    pub fn namespace(&self) -> &str {
        // Both variants carry a namespace string, so one or-pattern binds it.
        let (Self::Namespace { name } | Self::Component { namespace: name, .. }) = self;
        name.as_str()
    }

    /// Component name, or `None` for a namespace-level scope.
    pub fn component(&self) -> Option<&str> {
        if let Self::Component { component, .. } = self {
            Some(component.as_str())
        } else {
            None
        }
    }
}
// ============================================================================
// Broker Resolution Logic
// ============================================================================
/// Broker endpoints for ZMQ broker mode
#[derive(Debug, Clone)]
struct BrokerEndpoints {
    /// Broker XSUB endpoints; publishers connect to these.
    xsub_endpoints: Vec<String>,
    /// Broker XPUB endpoints; subscribers connect to these.
    xpub_endpoints: Vec<String>,
}
/// Resolve ZMQ broker endpoints from environment or discovery.
///
/// Resolution order:
/// 1. `DYN_ZMQ_BROKER_URL` - an explicit URL, parsed by `parse_broker_url`.
/// 2. `DYN_ZMQ_BROKER_ENABLED` equal to exactly `"true"` - the endpoints of all
///    `zmq_broker` event channels registered in the scope's namespace are
///    collected from discovery.
///
/// Returns `Ok(None)` if broker mode is not configured (direct mode).
///
/// # Errors
///
/// Fails when the explicit URL is malformed, when the discovery lookup fails,
/// or when discovery mode is enabled but the brokers found provide no XSUB or
/// no XPUB endpoints. Both lists are required, matching the validation applied
/// to the explicit URL.
async fn resolve_zmq_broker(
    drt: &DistributedRuntime,
    scope: &EventScope,
) -> Result<Option<BrokerEndpoints>> {
    use crate::config::environment_names::zmq_broker as env_names;

    // Priority 1: Explicit URL from DYN_ZMQ_BROKER_URL
    if let Ok(broker_url) = std::env::var(env_names::DYN_ZMQ_BROKER_URL) {
        let (xsub_endpoints, xpub_endpoints) = parse_broker_url(&broker_url)?;
        tracing::info!(
            num_xsub = xsub_endpoints.len(),
            num_xpub = xpub_endpoints.len(),
            "Using explicit ZMQ broker URL"
        );
        return Ok(Some(BrokerEndpoints {
            xsub_endpoints,
            xpub_endpoints,
        }));
    }

    // Priority 2: Discovery-based lookup if DYN_ZMQ_BROKER_ENABLED=true
    let discovery_enabled =
        std::env::var(env_names::DYN_ZMQ_BROKER_ENABLED).is_ok_and(|value| value == "true");
    if !discovery_enabled {
        // No broker configured - use direct mode
        return Ok(None);
    }

    let query = DiscoveryQuery::EventChannels(EventChannelQuery::component(
        scope.namespace().to_string(),
        "zmq_broker".to_string(),
    ));
    let instances = drt.discovery().list(query).await?;

    // Collect all broker instances (multiple brokers for HA)
    let mut num_brokers = 0usize;
    let mut xsub_endpoints = Vec::new();
    let mut xpub_endpoints = Vec::new();
    for instance in instances {
        if let DiscoveryInstance::EventChannel {
            transport:
                EventTransport::ZmqBroker {
                    xsub_endpoints: xsubs,
                    xpub_endpoints: xpubs,
                },
            ..
        } = instance
        {
            num_brokers += 1;
            xsub_endpoints.extend(xsubs);
            xpub_endpoints.extend(xpubs);
        }
    }

    if xsub_endpoints.is_empty() {
        anyhow::bail!(
            "DYN_ZMQ_BROKER_ENABLED=true but no broker found in discovery for namespace '{}'",
            scope.namespace()
        );
    }
    // Subscribers connect to XPUB, so a broker set without any XPUB endpoint is
    // just as unusable as one without XSUB.
    if xpub_endpoints.is_empty() {
        anyhow::bail!(
            "DYN_ZMQ_BROKER_ENABLED=true but brokers in discovery for namespace '{}' advertise no XPUB endpoint",
            scope.namespace()
        );
    }

    tracing::info!(
        num_brokers,
        num_xsub = xsub_endpoints.len(),
        num_xpub = xpub_endpoints.len(),
        "Discovered ZMQ brokers from discovery plane"
    );

    Ok(Some(BrokerEndpoints {
        xsub_endpoints,
        xpub_endpoints,
    }))
}
/// Parse a broker URL of the form
/// `"xsub=tcp://host1:5555;tcp://host2:5555 , xpub=tcp://host1:5556;tcp://host2:5556"`.
///
/// The input must consist of exactly two comma-separated parts, each starting
/// with `xsub=` or `xpub=`. Endpoints inside a part are separated by `;`.
/// Surrounding whitespace is trimmed and blank entries are dropped.
///
/// Returns `(xsub_endpoints, xpub_endpoints)`.
///
/// # Errors
///
/// Fails when there are not exactly two parts, when a part has neither prefix,
/// or when either resulting endpoint list is empty.
fn parse_broker_url(url: &str) -> Result<(Vec<String>, Vec<String>)> {
    // Splits one ';'-separated endpoint list, discarding blank entries.
    fn split_endpoints(list: &str) -> Vec<String> {
        list.split(';')
            .map(str::trim)
            .filter(|endpoint| !endpoint.is_empty())
            .map(String::from)
            .collect()
    }

    let parts: Vec<&str> = url.split(',').map(str::trim).collect();
    if parts.len() != 2 {
        anyhow::bail!(
            "Invalid broker URL format. Expected 'xsub=<urls> , xpub=<urls>', got: {}",
            url
        );
    }

    let mut xsub_endpoints = Vec::new();
    let mut xpub_endpoints = Vec::new();

    for part in parts {
        match (part.strip_prefix("xsub="), part.strip_prefix("xpub=")) {
            (Some(list), _) => xsub_endpoints = split_endpoints(list),
            (None, Some(list)) => xpub_endpoints = split_endpoints(list),
            (None, None) => anyhow::bail!(
                "Invalid broker URL part. Expected 'xsub=' or 'xpub=' prefix, got: {}",
                part
            ),
        }
    }

    // Also rejects inputs that name the same side twice, since the other side
    // is then left empty.
    if xsub_endpoints.is_empty() || xpub_endpoints.is_empty() {
        anyhow::bail!(
            "Broker URL must contain at least one xsub and one xpub endpoint. Got xsub={:?}, xpub={:?}",
            xsub_endpoints,
            xpub_endpoints
        );
    }

    Ok((xsub_endpoints, xpub_endpoints))
}
/// Stream adapter that drops events whose `(publisher_id, sequence)` pair was
/// already seen.
///
/// Needed when a subscriber is connected to several brokers in HA mode.
struct DeduplicatingStream {
    /// Upstream wire stream being filtered.
    inner: WireStream,
    /// Codec used to decode each envelope far enough to read its identity.
    codec: Arc<Codec>,
    /// Bounded set of recently seen `(publisher_id, sequence)` keys.
    seen_events: LruCache<(u64, u64), ()>,
}

impl DeduplicatingStream {
    /// Wrap `inner`, remembering at most `cache_size` event keys.
    ///
    /// # Panics
    ///
    /// Panics if `cache_size` is zero.
    fn new(inner: WireStream, codec: Arc<Codec>, cache_size: usize) -> Self {
        let capacity = NonZeroUsize::new(cache_size).expect("cache_size must be non-zero");
        Self {
            inner,
            codec,
            seen_events: LruCache::new(capacity),
        }
    }
}

impl Stream for DeduplicatingStream {
    type Item = Result<Bytes>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Anything other than a successfully received message (an error, end
            // of stream, or pending) is passed straight through.
            let bytes = match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Ready(Some(Ok(bytes))) => bytes,
                passthrough => return passthrough,
            };

            // Decode envelope to extract publisher_id and sequence
            let envelope = match self.codec.decode_envelope(&bytes) {
                Ok(envelope) => envelope,
                Err(e) => {
                    tracing::warn!(error = %e, "Failed to decode envelope for deduplication");
                    return Poll::Ready(Some(Err(e)));
                }
            };

            let key = (envelope.publisher_id, envelope.sequence);
            if !self.seen_events.contains(&key) {
                // First sighting: remember the key and hand the original bytes on.
                self.seen_events.put(key, ());
                return Poll::Ready(Some(Ok(bytes)));
            }

            // Duplicate: skip it and poll the inner stream again.
            tracing::debug!(
                publisher_id = envelope.publisher_id,
                sequence = envelope.sequence,
                "Filtered duplicate event from multi-broker setup"
            );
        }
    }
}
/// Event publisher for a specific topic.
pub struct EventPublisher {
    /// Transport kind this publisher was created with.
    transport_kind: EventTransportKind,
    /// Scope whose subject prefix is combined with the topic when publishing.
    scope: EventScope,
    /// Topic written into every envelope.
    topic: String,
    /// Identifier written into every envelope; taken from the discovery instance id.
    publisher_id: u64,
    /// Sequence counter; starts at 0 and is incremented once per published envelope.
    sequence: AtomicU64,
    /// Transport used to send encoded envelopes.
    tx: Arc<dyn EventTransportTx>,
    /// Codec used for both payload and envelope encoding.
    codec: Arc<Codec>,
    /// Discovery client and registered instance for unregistration on drop
    discovery_client: Option<Arc<dyn Discovery>>,
    /// `None` when the publisher did not register with discovery (ZMQ broker mode).
    discovery_instance: Option<crate::discovery::DiscoveryInstance>,
}
impl EventPublisher {
/// Create a publisher for a component-scoped topic.
///
/// The transport kind comes from `EventTransportKind::from_env_or_default()`.
pub async fn for_component(comp: &Component, topic: impl Into<String>) -> Result<Self> {
    let transport_kind = EventTransportKind::from_env_or_default();
    Self::for_component_with_transport(comp, topic, transport_kind).await
}
/// Create a component-scoped publisher on an explicitly chosen transport.
///
/// The scope is built from the component's namespace name and its own name.
pub async fn for_component_with_transport(
    comp: &Component,
    topic: impl Into<String>,
    transport_kind: EventTransportKind,
) -> Result<Self> {
    let drt = comp.drt();
    let namespace = comp.namespace().name();
    let component = comp.name().to_string();
    let scope = EventScope::Component {
        namespace,
        component,
    };
    Self::new_internal(drt, scope, topic.into(), transport_kind).await
}
/// Create a publisher for a namespace-scoped topic.
///
/// The transport kind comes from `EventTransportKind::from_env_or_default()`.
pub async fn for_namespace(ns: &Namespace, topic: impl Into<String>) -> Result<Self> {
    let transport_kind = EventTransportKind::from_env_or_default();
    Self::for_namespace_with_transport(ns, topic, transport_kind).await
}
/// Create a namespace-scoped publisher on an explicitly chosen transport.
pub async fn for_namespace_with_transport(
    ns: &Namespace,
    topic: impl Into<String>,
    transport_kind: EventTransportKind,
) -> Result<Self> {
    let drt = ns.drt();
    let name = ns.name();
    let scope = EventScope::Namespace { name };
    Self::new_internal(drt, scope, topic.into(), transport_kind).await
}
async fn new_internal(
drt: &DistributedRuntime,
scope: EventScope,
topic: String,
transport_kind: EventTransportKind,
) -> Result<Self> {
let publisher_id = drt.discovery().instance_id();
let discovery = Some(drt.discovery());
// Use Msgpack codec for all transports
enum TransportSetup {
Nats(Arc<dyn EventTransportTx>, Arc<Codec>),
ZmqDirect(Arc<dyn EventTransportTx>, Arc<Codec>, String), // includes public endpoint
ZmqBroker(Arc<dyn EventTransportTx>, Arc<Codec>),
}
let transport_setup = match transport_kind {
EventTransportKind::Nats => {
let transport = Arc::new(nats_transport::NatsTransport::new(drt.clone()));
let codec = Arc::new(Codec::Msgpack(MsgpackCodec));
TransportSetup::Nats(transport as Arc<dyn EventTransportTx>, codec)
}
EventTransportKind::Zmq => {
// Check for broker mode
if let Some(broker) = resolve_zmq_broker(drt, &scope).await? {
// BROKER MODE: Connect to broker (single or multiple endpoints)
let pub_transport = if broker.xsub_endpoints.len() == 1 {
zmq_transport::ZmqPubTransport::connect(&broker.xsub_endpoints[0], &topic)
.await?
} else {
zmq_transport::ZmqPubTransport::connect_multiple(
&broker.xsub_endpoints,
&topic,
)
.await?
};
let codec = Arc::new(Codec::Msgpack(MsgpackCodec));
TransportSetup::ZmqBroker(
Arc::new(pub_transport) as Arc<dyn EventTransportTx>,
codec,
)
} else {
// DIRECT MODE: Bind PUB socket
let (pub_transport, actual_bind_endpoint) = std::thread::spawn({
let topic = topic.clone();
move || {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create Tokio runtime for ZMQ");
rt.block_on(async move {
zmq_transport::ZmqPubTransport::bind("tcp://0.0.0.0:0", &topic)
.await
.expect("Failed to bind ZMQ publisher")
})
}
})
.join()
.expect("Failed to join ZMQ initialization thread");
// Get local IP for public endpoint
let actual_port: u16 = actual_bind_endpoint
.rsplit(':')
.next()
.and_then(|s| s.parse().ok())
.expect("Failed to parse port from bind endpoint");
let local_ip = get_local_ip_for_advertise();
let public_endpoint = format!("tcp://{}:{}", local_ip, actual_port);
let codec = Arc::new(Codec::Msgpack(MsgpackCodec));
TransportSetup::ZmqDirect(
Arc::new(pub_transport) as Arc<dyn EventTransportTx>,
codec,
public_endpoint,
)
}
}
};
// Extract transport and codec, and register if needed
let (tx, codec, discovery_instance) = match transport_setup {
TransportSetup::Nats(tx, codec) => {
let transport_config = EventTransport::nats(scope.subject_prefix());
let spec = DiscoverySpec::EventChannel {
namespace: scope.namespace().to_string(),
component: scope.component().unwrap_or("").to_string(),
topic: topic.clone(),
transport: transport_config,
};
let registered_instance = drt.discovery().register(spec).await?;
tracing::info!(
topic = %topic,
transport = ?transport_kind,
instance_id = %registered_instance.instance_id(),
"EventPublisher registered with discovery"
);
(tx, codec, Some(registered_instance))
}
TransportSetup::ZmqDirect(tx, codec, public_endpoint) => {
let transport_config = EventTransport::zmq(public_endpoint);
let spec = DiscoverySpec::EventChannel {
namespace: scope.namespace().to_string(),
component: scope.component().unwrap_or("").to_string(),
topic: topic.clone(),
transport: transport_config,
};
let registered_instance = drt.discovery().register(spec).await?;
tracing::info!(
topic = %topic,
transport = ?transport_kind,
instance_id = %registered_instance.instance_id(),
"EventPublisher registered with discovery (direct mode)"
);
(tx, codec, Some(registered_instance))
}
TransportSetup::ZmqBroker(tx, codec) => {
tracing::info!(
topic = %topic,
transport = ?transport_kind,
"EventPublisher in broker mode - skipping discovery registration"
);
(tx, codec, None)
}
};
Ok(Self {
transport_kind,
scope,
topic,
publisher_id,
sequence: AtomicU64::new(0),
tx,
codec,
discovery_client: discovery,
discovery_instance,
})
}
/// Publish a serializable event.
///
/// The event is encoded with the publisher's codec and the resulting bytes are
/// handed to `publish_bytes`, which wraps them in an envelope and sends them.
///
/// # Errors
///
/// Returns an error if payload encoding fails or if `publish_bytes` fails.
pub async fn publish<T: Serialize + Send + Sync>(&self, event: &T) -> Result<()> {
    let payload = self.codec.encode_payload(event)?;
    // `publish_bytes` takes an owned `Vec<u8>`, hence the conversion from `Bytes`.
    self.publish_bytes(payload.to_vec()).await
}
/// Publish raw bytes as the payload of a new envelope.
///
/// The envelope carries this publisher's id, the next sequence number, the
/// current timestamp in milliseconds, and the topic. It is encoded with the
/// publisher's codec and sent on the subject `{scope prefix}.{topic}`.
///
/// # Errors
///
/// Returns an error if envelope encoding or the transport publish fails. The
/// sequence number has already been consumed by then.
pub async fn publish_bytes(&self, bytes: Vec<u8>) -> Result<()> {
    let sequence = self.sequence.fetch_add(1, Ordering::SeqCst);
    let published_at = current_timestamp_ms();
    let envelope = EventEnvelope {
        publisher_id: self.publisher_id,
        sequence,
        published_at,
        topic: self.topic.clone(),
        payload: Bytes::from(bytes),
    };

    let wire_bytes = self.codec.encode_envelope(&envelope)?;

    let prefix = self.scope.subject_prefix();
    let subject = format!("{}.{}", prefix, self.topic);
    self.tx.publish(&subject, wire_bytes).await
}
/// Get the publisher ID.
///
/// This is the value written into the `publisher_id` field of every envelope
/// this publisher sends.
pub fn publisher_id(&self) -> u64 {
    self.publisher_id
}
/// Get the topic.
///
/// This is the bare topic name, without the scope's subject prefix.
pub fn topic(&self) -> &str {
    &self.topic
}
/// Get the transport kind.
///
/// This is the kind the publisher was constructed with.
pub fn transport_kind(&self) -> EventTransportKind {
    self.transport_kind
}
}
// Best-effort unregistration from discovery when the publisher goes away.
//
// `Drop` is synchronous, so the async `unregister` call runs on a spawned
// task. If the publisher is dropped where no Tokio runtime is current, the
// task cannot be spawned (`tokio::spawn` would panic there); unregistration is
// then skipped with a warning instead.
impl Drop for EventPublisher {
    fn drop(&mut self) {
        // Nothing to do unless the publisher actually registered (broker mode
        // does not). Both options are taken either way.
        let (Some(discovery), Some(instance)) =
            (self.discovery_client.take(), self.discovery_instance.take())
        else {
            return;
        };

        let topic = self.topic.clone();
        let instance_id = instance.instance_id();

        let Ok(runtime) = tokio::runtime::Handle::try_current() else {
            tracing::warn!(
                topic = %topic,
                instance_id = %instance_id,
                "EventPublisher dropped outside a Tokio runtime; skipping discovery unregistration"
            );
            return;
        };

        // Spawn background task for async unregister since Drop is sync
        runtime.spawn(async move {
            match discovery.unregister(instance).await {
                Ok(()) => {
                    tracing::info!(
                        topic = %topic,
                        instance_id = %instance_id,
                        "EventPublisher unregistered from discovery"
                    );
                }
                Err(e) => {
                    tracing::warn!(
                        topic = %topic,
                        instance_id = %instance_id,
                        error = %e,
                        "Failed to unregister EventPublisher from discovery"
                    );
                }
            }
        });
    }
}
/// Event subscriber for a specific topic.
pub struct EventSubscriber {
    /// Stream of decoded envelopes, already filtered to this subscriber's topic.
    stream: EventStream,
    /// Scope the subscriber was created for (stored but not read).
    #[allow(dead_code)]
    scope: EventScope,
    /// Topic the subscriber was created for (stored but not read).
    #[allow(dead_code)]
    topic: String,
    /// Codec handed on to `TypedEventSubscriber` for payload decoding.
    codec: Arc<Codec>,
}
impl EventSubscriber {
/// Create a subscriber for a component-scoped topic.
///
/// The transport kind comes from `EventTransportKind::from_env_or_default()`.
pub async fn for_component(comp: &Component, topic: impl Into<String>) -> Result<Self> {
    let transport_kind = EventTransportKind::from_env_or_default();
    Self::for_component_with_transport(comp, topic, transport_kind).await
}
/// Create a component-scoped subscriber on an explicitly chosen transport.
///
/// The scope is built from the component's namespace name and its own name.
pub async fn for_component_with_transport(
    comp: &Component,
    topic: impl Into<String>,
    transport_kind: EventTransportKind,
) -> Result<Self> {
    let drt = comp.drt();
    let namespace = comp.namespace().name();
    let component = comp.name().to_string();
    let scope = EventScope::Component {
        namespace,
        component,
    };
    Self::new_internal(drt, scope, topic.into(), transport_kind).await
}
/// Create a subscriber for a namespace-scoped topic.
///
/// The transport kind comes from `EventTransportKind::from_env_or_default()`.
pub async fn for_namespace(ns: &Namespace, topic: impl Into<String>) -> Result<Self> {
    let transport_kind = EventTransportKind::from_env_or_default();
    Self::for_namespace_with_transport(ns, topic, transport_kind).await
}
/// Create a namespace-scoped subscriber on an explicitly chosen transport.
pub async fn for_namespace_with_transport(
    ns: &Namespace,
    topic: impl Into<String>,
    transport_kind: EventTransportKind,
) -> Result<Self> {
    let drt = ns.drt();
    let name = ns.name();
    let scope = EventScope::Namespace { name };
    Self::new_internal(drt, scope, topic.into(), transport_kind).await
}
/// Build a subscriber for `topic` within `scope` on the requested transport.
///
/// * NATS: subscribes to the subject `{scope prefix}.{topic}`.
/// * ZMQ broker mode (when `resolve_zmq_broker` yields endpoints): connects to
///   the broker's XPUB endpoint(s); with more than one endpoint the stream is
///   wrapped in a `DeduplicatingStream`.
/// * ZMQ direct mode: starts a `DynamicSubscriber` that follows publishers
///   through discovery.
///
/// Whatever the transport, the wire stream is then decoded into envelopes and
/// filtered so that only envelopes whose `topic` equals `topic` are yielded.
///
/// # Errors
///
/// Returns an error if broker resolution, connecting, or subscribing fails.
async fn new_internal(
    drt: &DistributedRuntime,
    scope: EventScope,
    topic: String,
    transport_kind: EventTransportKind,
) -> Result<Self> {
    let discovery = drt.discovery();

    // Use Msgpack codec for all transports
    let (wire_stream, codec): (WireStream, Arc<Codec>) = match transport_kind {
        EventTransportKind::Nats => {
            let transport = nats_transport::NatsTransport::new(drt.clone());
            let subject = format!("{}.{}", scope.subject_prefix(), topic);
            let stream = transport.subscribe(&subject).await?;
            let codec = Arc::new(Codec::Msgpack(MsgpackCodec));
            (stream, codec)
        }
        EventTransportKind::Zmq => {
            // Check for broker mode
            if let Some(broker) = resolve_zmq_broker(drt, &scope).await? {
                // BROKER MODE: Connect to broker's XPUB (single or multiple endpoints)
                let codec = Arc::new(Codec::Msgpack(MsgpackCodec));

                let stream: WireStream = if broker.xpub_endpoints.len() == 1 {
                    // Single broker - no deduplication needed
                    let sub_transport = zmq_transport::ZmqSubTransport::connect_broker(
                        &broker.xpub_endpoints[0],
                        &topic,
                    )
                    .await?;
                    sub_transport.subscribe(&topic).await?
                } else {
                    // Multiple brokers - need deduplication
                    let sub_transport =
                        zmq_transport::ZmqSubTransport::connect_broker_multiple(
                            &broker.xpub_endpoints,
                            &topic,
                        )
                        .await?;
                    let inner_stream = sub_transport.subscribe(&topic).await?;

                    // Wrap with deduplication (default cache size: 100,000 entries)
                    Box::pin(DeduplicatingStream::new(
                        inner_stream,
                        codec.clone(),
                        100_000,
                    ))
                };

                (stream, codec)
            } else {
                // DIRECT MODE: Use dynamic subscriber to discover and connect to publishers
                // A namespace scope queries the whole namespace; a component scope
                // queries that component's channel for this exact topic.
                let query = match &scope {
                    EventScope::Namespace { name } => {
                        crate::discovery::DiscoveryQuery::EventChannels(
                            crate::discovery::EventChannelQuery::namespace(name.clone()),
                        )
                    }
                    EventScope::Component {
                        namespace,
                        component,
                    } => crate::discovery::DiscoveryQuery::EventChannels(
                        crate::discovery::EventChannelQuery::topic(
                            namespace.clone(),
                            component.clone(),
                            topic.clone(),
                        ),
                    ),
                };

                let subscriber =
                    Arc::new(DynamicSubscriber::new(discovery, query, topic.clone()));
                let stream = subscriber.start_zmq().await?;
                let codec = Arc::new(Codec::Msgpack(MsgpackCodec));
                (stream, codec)
            }
        }
    };

    // Filter by topic and decode envelopes
    let topic_filter = topic.clone();
    let codec_for_stream = codec.clone();
    let stream = wire_stream.filter_map(move |result| {
        // Fresh clones per item: the async block below must own what it uses.
        let codec = codec_for_stream.clone();
        let topic_filter = topic_filter.clone();
        async move {
            match result {
                Ok(bytes) => match codec.decode_envelope(&bytes) {
                    Ok(envelope) => {
                        // Filter by topic for transports that don't support native filtering
                        if envelope.topic == topic_filter {
                            Some(Ok(envelope))
                        } else {
                            None
                        }
                    }
                    // Decode failures are surfaced to the consumer, not dropped.
                    Err(e) => Some(Err(e)),
                },
                Err(e) => Some(Err(e)),
            }
        }
    });

    tracing::info!(
        topic = %topic,
        transport = ?transport_kind,
        "EventSubscriber created"
    );

    Ok(Self {
        stream: Box::pin(stream),
        scope,
        topic,
        codec,
    })
}
/// Get the next event envelope.
///
/// Returns `None` once the underlying stream has ended. An `Err` item is
/// passed through as it was produced by the stream.
pub async fn next(&mut self) -> Option<Result<EventEnvelope>> {
    self.stream.next().await
}
/// Subscribe with automatic deserialization.
pub fn typed<T: DeserializeOwned + Send + 'static>(self) -> TypedEventSubscriber<T> {
TypedEventSubscriber {
stream: self.stream,
codec: self.codec,
_marker: std::marker::PhantomData,
}
}
}
/// Typed event subscriber that deserializes payloads.
pub struct TypedEventSubscriber<T> {
    /// Envelope stream taken over from the originating `EventSubscriber`.
    stream: EventStream,
    /// Codec used to decode each envelope's payload into `T`.
    codec: Arc<Codec>,
    /// Ties the subscriber to `T` without storing a value of that type.
    _marker: std::marker::PhantomData<T>,
}
impl<T: DeserializeOwned + Send + 'static> TypedEventSubscriber<T> {
    /// Receive the next event as `(envelope, decoded payload)`.
    ///
    /// Returns `None` once the underlying stream has ended. Yields `Err` if the
    /// stream produced an error or if the payload cannot be decoded into `T`.
    pub async fn next(&mut self) -> Option<Result<(EventEnvelope, T)>> {
        let received = self.stream.next().await?;
        let decoded = received.and_then(|env| {
            let typed = self.codec.decode_payload::<T>(&env.payload)?;
            Ok((env, typed))
        });
        Some(decoded)
    }
}
/// Milliseconds elapsed since the Unix epoch according to the system clock.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn current_timestamp_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the pure helpers in this module: scope formatting, scope
    // accessors, timestamp generation, and envelope (de)serialization.
    // The former `event_plane as env` import was unused and has been dropped.
    use super::*;

    /// The subject prefix follows the documented format for both scope kinds.
    #[test]
    fn test_event_scope_subject_prefix() {
        let ns_scope = EventScope::Namespace {
            name: "test-ns".to_string(),
        };
        assert_eq!(ns_scope.subject_prefix(), "namespace.test-ns");

        let comp_scope = EventScope::Component {
            namespace: "test-ns".to_string(),
            component: "test-comp".to_string(),
        };
        assert_eq!(
            comp_scope.subject_prefix(),
            "namespace.test-ns.component.test-comp"
        );
    }

    /// `namespace()` works for both kinds; `component()` only for component scope.
    #[test]
    fn test_event_scope_accessors() {
        let ns_scope = EventScope::Namespace {
            name: "my-ns".to_string(),
        };
        assert_eq!(ns_scope.namespace(), "my-ns");
        assert_eq!(ns_scope.component(), None);

        let comp_scope = EventScope::Component {
            namespace: "my-ns".to_string(),
            component: "my-comp".to_string(),
        };
        assert_eq!(comp_scope.namespace(), "my-ns");
        assert_eq!(comp_scope.component(), Some("my-comp"));
    }

    /// Sanity check: the timestamp lies in a plausible wall-clock range.
    #[test]
    fn test_timestamp_generation() {
        let ts = current_timestamp_ms();
        // Should be after Jan 1, 2020 (1577836800000) and before Jan 1, 2100 (4102444800000)
        assert!(ts > 1577836800000, "Timestamp should be after 2020");
        assert!(ts < 4102444800000, "Timestamp should be before 2100");
    }

    /// An envelope survives a serde round trip (through JSON here).
    #[test]
    fn test_event_envelope_serde() {
        let envelope = EventEnvelope {
            publisher_id: 42,
            sequence: 10,
            published_at: 1700000000000,
            topic: "test-topic".to_string(),
            payload: Bytes::from("test data"),
        };

        let json = serde_json::to_string(&envelope).expect("serialize");
        let deserialized: EventEnvelope = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(deserialized.publisher_id, 42);
        assert_eq!(deserialized.sequence, 10);
        assert_eq!(deserialized.published_at, 1700000000000);
        assert_eq!(deserialized.topic, "test-topic");
        assert_eq!(deserialized.payload, Bytes::from("test data"));
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Uses ZMQ PUB/SUB pattern for one-way event broadcasting:
//! - Publishers bind to endpoints and broadcast events
//! - Subscribers connect to endpoints and receive events
//! - Topic-based filtering at socket level for efficiency
//!
//! ## Message Format
//!
//! ZMQ multipart message:
//! - Frame 0: Topic (string) - for ZMQ subscription filtering
//! - Frame 1: publisher_id (8 bytes, u64 big-endian) - for fast deduplication
//! - Frame 2: sequence (8 bytes, u64 big-endian) - for fast deduplication
//! - Frame 3: Binary frame (5-byte header + EventEnvelope payload)
use anyhow::Result;
use async_stream::stream;
use async_trait::async_trait;
use bytes::Bytes;
use std::sync::{Arc, Mutex};
/// Send-side High Water Mark (HWM) applied to PUB sockets via `set_sndhwm`.
///
/// The HWM controls the maximum number of messages that can be queued.
/// Default ZMQ HWM is 1000, which limits scalability.
const ZMQ_SNDHWM: i32 = 100_000; // Send buffer: 100K messages
/// Receive-side High Water Mark (HWM) applied to SUB sockets via `set_rcvhwm`.
const ZMQ_RCVHWM: i32 = 100_000; // Receive buffer: 100K messages
use super::codec::MsgpackCodec;
use super::frame::Frame;
use super::transport::{EventTransportRx, EventTransportTx, WireStream};
use crate::discovery::EventTransportKind;
/// ZMQ PUB transport for publishing events.
///
/// Uses raw zmq::Socket with configured HWM for better scalability.
pub struct ZmqPubTransport {
// PUB socket; locked inside a blocking task for the duration of each multipart send.
socket: Arc<Mutex<zmq::Socket>>,
// Topic sent as the first frame of every message (used for ZMQ subscription filtering).
topic: String,
}
impl ZmqPubTransport {
    /// Create a PUB socket with the transport-wide send options applied.
    ///
    /// Makes blocking ZMQ calls, so callers run it inside `spawn_blocking`.
    fn new_pub_socket() -> Result<zmq::Socket> {
        let ctx = zmq::Context::new();
        let socket = ctx.socket(zmq::PUB)?;
        // Configure High Water Mark for better scalability
        socket.set_sndhwm(ZMQ_SNDHWM)?;
        // Set send timeout to 0 (non-blocking)
        socket.set_sndtimeo(0)?;
        Ok(socket)
    }

    /// Wrap a ready socket and its topic into a transport.
    fn from_socket(socket: zmq::Socket, topic: String) -> Self {
        Self {
            socket: Arc::new(Mutex::new(socket)),
            topic,
        }
    }

    /// Create a new ZMQ publisher by binding to an endpoint.
    ///
    /// If `endpoint` ends with `:0`, ZMQ is asked to bind a system-assigned port
    /// on `tcp://0.0.0.0` and the endpoint it actually bound is read back from
    /// the socket. Letting ZMQ pick the port avoids the window in which a port
    /// probed with a temporary listener could be taken by another process
    /// before ZMQ binds it. As before, the host part of `endpoint` is ignored
    /// in this mode. Any other endpoint is bound exactly as given.
    ///
    /// # Returns
    /// The transport and the actual bound endpoint.
    ///
    /// # Errors
    /// Fails if the socket cannot be created, configured or bound, or if the
    /// blocking setup task cannot be joined.
    pub async fn bind(endpoint: &str, topic: &str) -> Result<(Self, String)> {
        let wants_ephemeral_port = endpoint.ends_with(":0");
        let requested_endpoint = endpoint.to_string();
        let (socket, actual_endpoint) =
            tokio::task::spawn_blocking(move || -> Result<(zmq::Socket, String)> {
                let socket = Self::new_pub_socket()?;
                if !wants_ephemeral_port {
                    socket.bind(&requested_endpoint)?;
                    return Ok((socket, requested_endpoint));
                }
                // `*` asks ZMQ for an ephemeral port; the chosen port is then
                // reported through the socket's last-endpoint option.
                socket.bind("tcp://0.0.0.0:*")?;
                let bound = socket.get_last_endpoint()?.map_err(|raw| {
                    anyhow::anyhow!("ZMQ reported a non-UTF-8 bound endpoint: {:?}", raw)
                })?;
                Ok((socket, bound))
            })
            .await
            .map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
        tracing::info!(
            endpoint = %actual_endpoint,
            topic = %topic,
            sndhwm = ZMQ_SNDHWM,
            "ZMQ PUB transport bound with configured HWM"
        );
        Ok((Self::from_socket(socket, topic.to_string()), actual_endpoint))
    }

    /// Topic this transport publishes under.
    pub fn topic(&self) -> &str {
        &self.topic
    }

    /// Connect to single broker XSUB endpoint (broker mode)
    ///
    /// # Errors
    /// Fails if the socket cannot be created, configured or connected, or if
    /// the blocking setup task cannot be joined.
    pub async fn connect(xsub_endpoint: &str, topic: &str) -> Result<Self> {
        let endpoint_owned = xsub_endpoint.to_string();
        let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
            let socket = Self::new_pub_socket()?;
            // Connect (not bind) to broker's XSUB
            socket.connect(&endpoint_owned)?;
            Ok(socket)
        })
        .await
        .map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
        tracing::info!(
            endpoint = %xsub_endpoint,
            topic = %topic,
            sndhwm = ZMQ_SNDHWM,
            "ZMQ PUB transport connected to broker XSUB"
        );
        Ok(Self::from_socket(socket, topic.to_string()))
    }

    /// Connect to multiple broker XSUB endpoints (HA mode)
    ///
    /// # Errors
    /// Fails if `xsub_endpoints` is empty, if the socket cannot be created or
    /// configured, if any endpoint cannot be connected, or if the blocking
    /// setup task cannot be joined.
    pub async fn connect_multiple(xsub_endpoints: &[String], topic: &str) -> Result<Self> {
        if xsub_endpoints.is_empty() {
            anyhow::bail!("Cannot connect to zero endpoints");
        }
        let endpoints_owned = xsub_endpoints.to_vec();
        let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
            let socket = Self::new_pub_socket()?;
            // Connect to every XSUB endpoint. A PUB socket sends each message to
            // all connected peers (fan-out); it does not load-balance.
            for endpoint in &endpoints_owned {
                socket.connect(endpoint)?;
                tracing::debug!(endpoint = %endpoint, "ZMQ PUB connected to broker XSUB");
            }
            Ok(socket)
        })
        .await
        .map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
        tracing::info!(
            num_endpoints = xsub_endpoints.len(),
            topic = %topic,
            sndhwm = ZMQ_SNDHWM,
            "ZMQ PUB transport connected to multiple broker XSUBs with configured HWM"
        );
        Ok(Self::from_socket(socket, topic.to_string()))
    }
}
#[async_trait]
impl EventTransportTx for ZmqPubTransport {
    /// Publish one encoded envelope as a four-part ZMQ message.
    ///
    /// The message layout is `[topic, publisher_id, sequence, frame]`. The
    /// subject argument is not used; the topic stored on the transport is sent
    /// as the first frame instead.
    ///
    /// # Errors
    /// Fails if the envelope cannot be decoded, if any frame cannot be sent, or
    /// if the blocking send task cannot be joined.
    async fn publish(&self, _subject: &str, envelope_bytes: Bytes) -> Result<()> {
        // Decode the envelope so publisher_id and sequence can travel as their own frames
        let decoded = MsgpackCodec.decode_envelope(&envelope_bytes)?;
        // Leading frames, all sent with SNDMORE: topic, publisher_id, sequence
        let leading_frames: [Vec<u8>; 3] = [
            self.topic.as_bytes().to_vec(),
            decoded.publisher_id.to_be_bytes().to_vec(),
            decoded.sequence.to_be_bytes().to_vec(),
        ];
        // Final frame: the binary frame wrapping the complete envelope
        let data_frame = Frame::new(envelope_bytes).encode().to_vec();

        let shared_socket = Arc::clone(&self.socket);
        let send_task = tokio::task::spawn_blocking(move || -> Result<()> {
            let guard = shared_socket.lock().unwrap();
            for part in &leading_frames {
                guard.send(part, zmq::SNDMORE)?;
            }
            guard.send(&data_frame, 0)?;
            Ok(())
        });
        send_task
            .await
            .map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
        Ok(())
    }

    /// Identify this transport as the ZMQ kind.
    fn kind(&self) -> EventTransportKind {
        EventTransportKind::Zmq
    }
}
/// ZMQ SUB transport for subscribing to events.
///
/// Uses a background socket pump to avoid holding the socket lock across stream lifetimes.
/// Multiple subscribers can receive events concurrently via broadcast channel.
pub struct ZmqSubTransport {
// SUB socket, shared with the background pump task that reads from it.
socket: Arc<Mutex<zmq::Socket>>,
// Sender half of the broadcast channel; `subscribe` creates receivers from it.
broadcast_tx: tokio::sync::broadcast::Sender<Bytes>,
// Handle of the background pump task spawned when the transport is created.
_socket_pump_handle: tokio::task::JoinHandle<()>,
}
impl ZmqSubTransport {
/// Create a new ZMQ subscriber by connecting to a single endpoint.
pub async fn connect(endpoint: &str, topic: &str) -> Result<Self> {
let endpoint_owned = endpoint.to_string();
let topic_owned = topic.to_string();
let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
let ctx = zmq::Context::new();
let socket = ctx.socket(zmq::SUB)?;
// Configure High Water Mark for better scalability
socket.set_rcvhwm(ZMQ_RCVHWM)?;
// Set receive timeout to -1 (blocking)
socket.set_rcvtimeo(-1)?;
// Connect to endpoint
socket.connect(&endpoint_owned)?;
// Subscribe to topic
socket.set_subscribe(topic_owned.as_bytes())?;
Ok(socket)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
tracing::info!(
endpoint = %endpoint,
topic = %topic,
rcvhwm = ZMQ_RCVHWM,
"ZMQ SUB transport connected with configured HWM"
);
let socket = Arc::new(Mutex::new(socket));
// Create broadcast channel for multiple subscribers
let (broadcast_tx, _) = tokio::sync::broadcast::channel(1024);
// Start background socket pump
let pump_handle = Self::start_socket_pump(Arc::clone(&socket), broadcast_tx.clone());
Ok(Self {
socket,
broadcast_tx,
_socket_pump_handle: pump_handle,
})
}
/// Connect to broker's XPUB endpoint (broker mode)
pub async fn connect_broker(xpub_endpoint: &str, topic: &str) -> Result<Self> {
Self::connect(xpub_endpoint, topic).await
}
/// Connect to multiple broker XPUB endpoints (HA mode)
/// Reuses existing connect_multiple implementation
pub async fn connect_broker_multiple(xpub_endpoints: &[String], topic: &str) -> Result<Self> {
Self::connect_multiple(xpub_endpoints, topic).await
}
/// Create a new ZMQ subscriber by connecting to multiple endpoints (fan-in).
pub async fn connect_multiple(endpoints: &[String], topic: &str) -> Result<Self> {
if endpoints.is_empty() {
anyhow::bail!("Cannot connect to zero endpoints");
}
let endpoints_owned = endpoints.to_vec();
let topic_owned = topic.to_string();
let socket = tokio::task::spawn_blocking(move || -> Result<zmq::Socket> {
let ctx = zmq::Context::new();
let socket = ctx.socket(zmq::SUB)?;
// Configure High Water Mark for better scalability
socket.set_rcvhwm(ZMQ_RCVHWM)?;
// Set receive timeout to -1 (blocking)
socket.set_rcvtimeo(-1)?;
// Connect to all endpoints
for endpoint in &endpoints_owned {
socket.connect(endpoint)?;
tracing::debug!(endpoint = %endpoint, "ZMQ SUB connected to endpoint");
}
// Subscribe to topic
socket.set_subscribe(topic_owned.as_bytes())?;
Ok(socket)
})
.await
.map_err(|e| anyhow::anyhow!("Task join error: {}", e))??;
tracing::info!(
num_endpoints = endpoints.len(),
topic = %topic,
rcvhwm = ZMQ_RCVHWM,
"ZMQ SUB transport connected to multiple endpoints with configured HWM"
);
let socket = Arc::new(Mutex::new(socket));
// Create broadcast channel for multiple subscribers
let (broadcast_tx, _) = tokio::sync::broadcast::channel(1024);
// Start background socket pump
let pump_handle = Self::start_socket_pump(Arc::clone(&socket), broadcast_tx.clone());
Ok(Self {
socket,
broadcast_tx,
_socket_pump_handle: pump_handle,
})
}
/// Background task that reads from socket and broadcasts to all subscribers.
///
/// This task holds the socket lock only briefly during each recv operation,
/// allowing multiple subscribers to receive concurrently via broadcast channel.
fn start_socket_pump(
socket: Arc<Mutex<zmq::Socket>>,
broadcast_tx: tokio::sync::broadcast::Sender<Bytes>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
loop {
// Receive multipart message in blocking task: [topic, publisher_id, sequence, frame_bytes]
let socket_clone = Arc::clone(&socket);
let result =
tokio::task::spawn_blocking(move || -> Result<(Vec<u8>, u64, u64, Vec<u8>)> {
let socket = socket_clone.lock().unwrap();
// Receive topic frame
let topic = socket.recv_bytes(0)?;
// Receive publisher_id frame (8 bytes, u64 big-endian)
let publisher_id_bytes = socket.recv_bytes(0)?;
if publisher_id_bytes.len() != 8 {
anyhow::bail!(
"Invalid publisher_id frame: expected 8 bytes, got {}",
publisher_id_bytes.len()
);
}
let publisher_id =
u64::from_be_bytes(publisher_id_bytes.try_into().unwrap());
// Receive sequence frame (8 bytes, u64 big-endian)
let sequence_bytes = socket.recv_bytes(0)?;
if sequence_bytes.len() != 8 {
anyhow::bail!(
"Invalid sequence frame: expected 8 bytes, got {}",
sequence_bytes.len()
);
}
let sequence = u64::from_be_bytes(sequence_bytes.try_into().unwrap());
// Receive data frame
let data = socket.recv_bytes(0)?;
Ok((topic, publisher_id, sequence, data))
})
.await;
match result {
Ok(Ok((_topic, publisher_id, sequence, frame_bytes))) => {
// Log dedup metadata for debugging
tracing::trace!(
publisher_id = publisher_id,
sequence = sequence,
"Socket pump received ZMQ message"
);
// Parse binary frame
let frame_bytes = Bytes::from(frame_bytes);
match Frame::decode(frame_bytes) {
Ok(frame) => {
// Broadcast payload to all subscribers
// Ignore send errors (no receivers or lagging receivers)
let _ = broadcast_tx.send(frame.payload);
}
Err(e) => {
tracing::warn!(error = %e, "Failed to decode ZMQ frame in socket pump");
continue;
}
}
}
Ok(Err(e)) => {
tracing::error!(error = %e, "ZMQ receive error in socket pump");
break;
}
Err(e) => {
tracing::error!(error = %e, "Task join error in socket pump");
break;
}
}
}
tracing::info!("ZMQ socket pump task terminated");
})
}
}
#[async_trait]
impl EventTransportRx for ZmqSubTransport {
async fn subscribe(&self, _subject: &str) -> Result<WireStream> {
// Subscribe to broadcast channel (does not hold socket lock)
let mut receiver = self.broadcast_tx.subscribe();
let stream = stream! {
loop {
match receiver.recv().await {
Ok(payload) => {
yield Ok(payload);
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
tracing::warn!(
skipped = skipped,
"Subscriber lagged behind, skipped messages"
);
// Continue receiving, don't break the stream
continue;
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
tracing::info!("Broadcast channel closed");
break;
}
}
}
};
Ok(Box::pin(stream))
}
fn kind(&self) -> EventTransportKind {
EventTransportKind::Zmq
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transports::event_plane::{EventEnvelope, MsgpackCodec};
    use futures::StreamExt;
    use tokio::time::{Duration, timeout};

    /// Bind a publisher on an OS-assigned port and return it with a loopback
    /// endpoint subscribers can connect to.
    ///
    /// Using port 0 instead of a fixed port keeps the tests from failing when
    /// that port is already in use on the machine running them.
    async fn bind_on_free_port(topic: &str) -> (ZmqPubTransport, String) {
        let (publisher, bound_endpoint) = ZmqPubTransport::bind("tcp://0.0.0.0:0", topic)
            .await
            .expect("Failed to create publisher");
        // The publisher listens on all interfaces; connect through loopback.
        let connect_endpoint = bound_endpoint.replace("0.0.0.0", "127.0.0.1");
        (publisher, connect_endpoint)
    }

    /// A single published envelope reaches a subscriber on the same topic.
    #[tokio::test]
    async fn test_zmq_pubsub_basic() {
        let topic = "test-topic";
        let (publisher, endpoint) = bind_on_free_port(topic).await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let subscriber = ZmqSubTransport::connect(&endpoint, topic)
            .await
            .expect("Failed to create subscriber");
        let mut stream = subscriber
            .subscribe(topic)
            .await
            .expect("Failed to create subscription");
        // Give the SUB socket time to connect before publishing
        tokio::time::sleep(Duration::from_millis(100)).await;
        let codec = MsgpackCodec;
        let envelope = EventEnvelope {
            publisher_id: 12345,
            sequence: 1,
            published_at: 1700000000000,
            topic: topic.to_string(),
            payload: Bytes::from("test payload"),
        };
        let envelope_bytes = codec.encode_envelope(&envelope).unwrap();
        publisher.publish(topic, envelope_bytes).await.unwrap();
        let result = timeout(Duration::from_secs(2), stream.next()).await;
        assert!(result.is_ok(), "Timeout waiting for message");
        let received_bytes = result.unwrap().unwrap().unwrap();
        let decoded = codec.decode_envelope(&received_bytes).unwrap();
        assert_eq!(decoded.publisher_id, 12345);
        assert_eq!(decoded.sequence, 1);
        assert_eq!(decoded.topic, topic);
    }

    /// Several envelopes published back to back arrive in order.
    #[tokio::test]
    async fn test_zmq_multiple_messages() {
        let topic = "multi-test";
        let (publisher, endpoint) = bind_on_free_port(topic).await;
        tokio::time::sleep(Duration::from_millis(100)).await;
        let subscriber = ZmqSubTransport::connect(&endpoint, topic).await.unwrap();
        let mut stream = subscriber.subscribe(topic).await.unwrap();
        // Give the SUB socket time to connect before publishing
        tokio::time::sleep(Duration::from_millis(100)).await;
        let codec = MsgpackCodec;
        for i in 0..5 {
            let envelope = EventEnvelope {
                publisher_id: 99999,
                sequence: i,
                published_at: 1700000000000 + i,
                topic: topic.to_string(),
                payload: Bytes::from(format!("message {}", i)),
            };
            let bytes = codec.encode_envelope(&envelope).unwrap();
            publisher.publish(topic, bytes).await.unwrap();
        }
        for i in 0..5 {
            let result = timeout(Duration::from_secs(2), stream.next()).await;
            assert!(result.is_ok(), "Timeout on message {}", i);
            let received = result.unwrap().unwrap().unwrap();
            let decoded = codec.decode_envelope(&received).unwrap();
            assert_eq!(decoded.sequence, i);
            assert_eq!(decoded.topic, topic);
        }
    }
}
......@@ -7,30 +7,20 @@ use crate::pipeline::network::tcp::server::{DefaultIpResolver, IpResolver};
use local_ip_address::Error;
use std::net::IpAddr;
/// Get the local IP address for HTTP RPC host binding, using IpResolver with fallback to 127.0.0.1
///
/// This function attempts to resolve the local IP address using the provided resolver.
/// If resolution fails, it falls back to 127.0.0.1 (localhost).
///
/// IPv6 addresses are wrapped with brackets for safe URL construction (e.g., `[::1]`).
///
/// # Arguments
/// * `resolver` - An implementation of IpResolver trait for getting local IP addresses
///
/// # Returns
/// A string representation of the resolved IP address (IPv6 addresses are bracketed)
pub fn get_http_rpc_host_with_resolver<R: IpResolver>(resolver: R) -> String {
fn resolve_local_ip_with_resolver<R: IpResolver>(resolver: R) -> IpAddr {
let resolved_ip = resolver.local_ip().or_else(|err| match err {
Error::LocalIpAddressNotFound => resolver.local_ipv6(),
_ => Err(err),
});
let addr = match resolved_ip {
match resolved_ip {
Ok(addr) => addr,
Err(Error::LocalIpAddressNotFound) => IpAddr::from([127, 0, 0, 1]),
Err(_) => IpAddr::from([127, 0, 0, 1]), // Fallback for any other error
};
}
}
fn format_ip_for_url(addr: IpAddr) -> String {
// Wrap IPv6 addresses with brackets for safe URL construction
// e.g., "2001:db8::1" becomes "[2001:db8::1]" so that "{host}:{port}" is valid
match addr {
......@@ -39,6 +29,32 @@ pub fn get_http_rpc_host_with_resolver<R: IpResolver>(resolver: R) -> String {
}
}
/// Resolve the local IP address to advertise and format it as a URL host.
///
/// The address comes from `resolve_local_ip_with_resolver` and is then passed
/// through `format_ip_for_url`, which wraps IPv6 addresses in brackets
/// (e.g. `[::1]`) so that `{host}:{port}` stays valid.
///
/// # Arguments
/// * `resolver` - An implementation of IpResolver trait for getting local IP addresses
///
/// # Returns
/// A string representation of the resolved IP address (IPv6 addresses are bracketed)
pub fn get_local_ip_for_advertise_with_resolver<R: IpResolver>(resolver: R) -> String {
    let resolved_addr = resolve_local_ip_with_resolver(resolver);
    format_ip_for_url(resolved_addr)
}
/// Get the local IP address for advertising endpoints using the default resolver.
///
/// Convenience wrapper that passes `DefaultIpResolver` to
/// `get_local_ip_for_advertise_with_resolver` and returns its result unchanged.
pub fn get_local_ip_for_advertise() -> String {
get_local_ip_for_advertise_with_resolver(DefaultIpResolver)
}
/// Get the local IP address for HTTP RPC host binding, using IpResolver with fallback to 127.0.0.1
///
/// Delegates to `get_local_ip_for_advertise_with_resolver`, so the returned
/// string is exactly what that function produces for the same resolver.
pub fn get_http_rpc_host_with_resolver<R: IpResolver>(resolver: R) -> String {
get_local_ip_for_advertise_with_resolver(resolver)
}
/// Get the local IP address for HTTP RPC host binding using the default resolver
///
/// This is a convenience function that uses the DefaultIpResolver.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment