"docs/vscode:/vscode.git/clone" did not exist on "5103efdb64c778feee206defcf9930235c157dd6"
Unverified Commit 8fc70f20 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: add ZMQ transport and its discovery (#5625)

parent 2dbfed89
......@@ -131,7 +131,7 @@ version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc"
dependencies = [
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
......@@ -142,7 +142,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d"
dependencies = [
"anstyle",
"once_cell_polyfill",
"windows-sys 0.60.2",
"windows-sys 0.61.2",
]
[[package]]
......@@ -703,7 +703,7 @@ dependencies = [
"bitflags 2.10.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.11.0",
"proc-macro2",
"quote",
"regex",
......@@ -721,7 +721,7 @@ dependencies = [
"bitflags 2.10.0",
"cexpr",
"clang-sys",
"itertools 0.10.5",
"itertools 0.11.0",
"log",
"prettyplease",
"proc-macro2",
......@@ -2132,7 +2132,7 @@ dependencies = [
"libc",
"option-ext",
"redox_users",
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -2502,6 +2502,7 @@ dependencies = [
"libc",
"local-ip-address",
"log",
"lru",
"nid",
"nix 0.29.0",
"notify",
......@@ -2526,6 +2527,7 @@ dependencies = [
"temp-env",
"tempfile",
"thiserror 2.0.17",
"tmq",
"tokio",
"tokio-rayon",
"tokio-stream",
......@@ -2539,6 +2541,7 @@ dependencies = [
"uuid 1.18.1",
"validator",
"xxhash-rust",
"zmq",
]
[[package]]
......@@ -2723,7 +2726,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4317,7 +4320,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46"
dependencies = [
"hermit-abi 0.5.2",
"libc",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4380,7 +4383,7 @@ dependencies = [
"portable-atomic",
"portable-atomic-util",
"serde_core",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -4842,6 +4845,15 @@ dependencies = [
"vob",
]
[[package]]
name = "lru"
version = "0.12.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38"
dependencies = [
"hashbrown 0.15.5",
]
[[package]]
name = "lru-slab"
version = "0.1.2"
......@@ -5782,7 +5794,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.59.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -7881,7 +7893,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -9150,7 +9162,7 @@ dependencies = [
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.52.0",
"windows-sys 0.61.2",
]
[[package]]
......@@ -10774,7 +10786,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.61.2",
]
[[package]]
......
......@@ -102,6 +102,8 @@ serde_json = { version = "1" }
strum = { version = "0.27", features = ["derive"] }
tempfile = "3"
thiserror = { version = "2.0.17" }
tmq = { version = "0.5.0" }
zmq = { version = "0.10" }
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1" }
tokio-util = { version = "0.7", features = ["codec", "net", "rt"] }
......
......@@ -69,7 +69,7 @@ serde_json = { workspace = true }
strum = { workspace = true }
tempfile = { workspace = true }
thiserror = { workspace = true }
tmq = "0.5.0"
tmq = { workspace = true }
tokio = { workspace = true }
tokio-stream = { workspace = true }
tokio-util = { workspace = true }
......
......@@ -8,11 +8,10 @@ use dynamo_runtime::component::Client;
use dynamo_runtime::discovery::{DiscoveryQuery, watch_and_extract_field};
use dynamo_runtime::pipeline::{WorkerLoadMonitor, async_trait};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::transports::event_plane::EventSubscriber;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use tokio_stream::StreamExt;
/// Scale factor for storing f64 thresholds as u32 (10000 = 4 decimal places)
const THRESHOLD_SCALE: u32 = 10000;
......@@ -194,8 +193,11 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
card.runtime_config
});
// Subscribe to KV metrics events
let mut kv_metrics_rx = component.namespace().subscribe(KV_METRICS_SUBJECT).await?;
// Subscribe to KV metrics events using EventSubscriber (Msgpack payloads)
let mut kv_metrics_rx =
EventSubscriber::for_namespace(component.namespace(), KV_METRICS_SUBJECT)
.await?
.typed::<ActiveLoad>();
let worker_load_states = self.worker_load_states.clone();
let client = self.client.clone();
......@@ -235,12 +237,13 @@ impl WorkerLoadMonitor for KvWorkerMonitor {
// Handle KV metrics updates (ActiveLoad)
kv_event = kv_metrics_rx.next() => {
let Some(event) = kv_event else {
let Some(event_result) = kv_event else {
tracing::debug!("KV metrics stream closed");
break;
};
let Ok(active_load) = serde_json::from_slice::<ActiveLoad>(&event.payload) else {
let Ok((_envelope, active_load)) = event_result else {
tracing::error!("Error receiving KV metrics event: {event_result:?}");
continue;
};
......
......@@ -9,7 +9,7 @@ use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::{
component::{Client, Endpoint},
discovery::{DiscoveryQuery, watch_and_extract_field},
discovery::{DiscoveryQuery, EventTransportKind, watch_and_extract_field},
pipeline::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait,
......@@ -50,7 +50,7 @@ use crate::{
},
scheduler::{KvScheduler, KvSchedulerError, PotentialLoad, SchedulingRequest},
sequence::SequenceError,
subscriber::{start_kv_router_background, start_kv_router_background_nats_core},
subscriber::{start_kv_router_background, start_kv_router_background_event_plane},
},
local_model::runtime_config::ModelRuntimeConfig,
model_card::ModelDeploymentCard,
......@@ -420,13 +420,25 @@ impl KvRouter {
tracing::info!("Found {count} worker(s), starting KV event subscriber");
let transport_kind = EventTransportKind::from_env_or_default();
// Start subscriber - setup runs synchronously, then spawns background loop internally
if all_local_indexer {
if transport_kind == EventTransportKind::Zmq {
if kv_router_config.router_snapshot_threshold.is_some()
|| kv_router_config.router_reset_states
{
tracing::warn!(
"ZMQ event plane does not support KV snapshots or state reset; ignoring snapshot/reset settings"
);
}
} else {
tracing::info!(
"All {count} workers have local_indexer enabled, using NATS Core subscription"
);
}
start_kv_router_background_nats_core(
start_kv_router_background_event_plane(
component.clone(),
kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(),
......@@ -435,9 +447,15 @@ impl KvRouter {
component.clone(),
runtime_configs_rx.clone(),
),
transport_kind,
)
.await?;
} else {
if transport_kind == EventTransportKind::Zmq {
tracing::warn!(
"Not all workers have local_indexer enabled; falling back to JetStream for durability"
);
}
tracing::info!(
"Not all workers have local_indexer enabled, using JetStream subscription"
);
......@@ -596,7 +614,7 @@ impl KvRouter {
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let overlap_scores = self.indexer.find_matches(block_hashes).await?;
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
let maybe_seq_hashes = self
.kv_router_config
......
......@@ -7,6 +7,7 @@ use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use rmp_serde as rmps;
use serde::Deserialize;
use serde::Serialize;
......@@ -15,12 +16,30 @@ use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use dynamo_runtime::traits::{DistributedRuntimeProvider, events::EventPublisher};
use dynamo_runtime::traits::{
DistributedRuntimeProvider, events::EventPublisher as EventPublisherTrait,
};
use dynamo_runtime::transports::event_plane::EventPublisher;
use dynamo_runtime::{
component::{Component, Namespace},
transports::nats::{NatsQueue, Slug},
};
/// Builds the stream name used for KV events of `component` on `subject`.
///
/// The dotted path `namespace.{namespace}.component.{component}.{subject}` is
/// passed through [`Slug::slugify`], and every underscore in the resulting
/// string is then replaced with a hyphen.
fn create_kv_stream_name(component: &Component, subject: &str) -> String {
    // Assemble the fully qualified, dot-separated path first.
    let dotted_path = format!(
        "namespace.{}.component.{}.{}",
        component.namespace().name(),
        component.name(),
        subject
    );

    // Slugify, then swap underscores for hyphens in the final name.
    let slug = Slug::slugify(&dotted_path).to_string();
    slug.replace("_", "-")
}
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE,
indexer::{KvIndexerMetrics, LocalKvIndexer, RouterEvent},
......@@ -176,19 +195,27 @@ impl KvEventPublisher {
))
});
// Connect the NatsQueue before passing it to the event processor
let cancellation_token_clone = cancellation_token.clone();
let local_indexer_clone = local_indexer.clone();
if enable_local_indexer {
// When local indexer is enabled, use NATS Core (Component) for publishing.
// This is simpler and doesn't require JetStream durability since recovery
// is handled via the local indexer's event buffer.
tracing::info!("Using NATS Core for KV event publishing (local_indexer mode)");
// When local indexer is enabled, use the event plane directly.
// EventPublisher handles transport selection (ZMQ or NATS) based on environment.
// Durability is provided by the local indexer's event buffer.
tracing::info!("Using event plane for KV event publishing (local_indexer mode)");
let component_clone = component.clone();
component.drt().runtime().secondary().spawn(async move {
let event_publisher =
match EventPublisher::for_component(&component_clone, KV_EVENT_SUBJECT).await {
Ok(publisher) => publisher,
Err(e) => {
tracing::error!("Failed to create event publisher: {}", e);
return;
}
};
start_event_processor(
component_clone,
event_publisher,
worker_id,
cancellation_token_clone,
rx,
......@@ -198,10 +225,7 @@ impl KvEventPublisher {
});
} else {
// When local indexer is disabled, use JetStream (NatsQueue) for durability.
let stream_name =
Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
.to_string()
.replace("_", "-");
let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT);
let nats_server = std::env::var(env_nats::NATS_SERVER)
.unwrap_or_else(|_| "nats://localhost:4222".to_string());
let mut nats_queue = NatsQueue::new_without_consumer(
......@@ -215,7 +239,7 @@ impl KvEventPublisher {
tracing::error!("Failed to connect NatsQueue: {e}");
return;
}
start_event_processor(
start_event_processor_jetstream(
nats_queue,
worker_id,
cancellation_token_clone,
......@@ -259,7 +283,27 @@ impl Drop for KvEventPublisher {
}
}
async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
/// Destination that a [`RouterEvent`] can be published to.
///
/// `start_event_processor` is generic over this trait, so the same processing
/// loop can be driven by any implementor in this file.
#[async_trait]
trait EventSink: Send + Sync {
/// Publishes a single router event, returning an error if publishing fails.
async fn publish_event(&self, event: &RouterEvent) -> Result<()>;
}
/// Event-plane sink: forwards the event to the publisher's own `publish`.
///
/// No subject is passed at this call site; only the event itself is handed over.
#[async_trait]
impl EventSink for EventPublisher {
async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
self.publish(event).await
}
}
/// `NatsQueue` sink: publishes the event under `KV_EVENT_SUBJECT`.
// NOTE(review): `publish` here presumably resolves to the `EventPublisherTrait`
// import (the aliased `traits::events::EventPublisher`) — confirm.
#[async_trait]
impl EventSink for NatsQueue {
async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
self.publish(KV_EVENT_SUBJECT, event).await
}
}
/// Event processor for ephemeral transports (NATS Core / ZMQ).
async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
publisher: P,
worker_id: u64,
cancellation_token: CancellationToken,
......@@ -294,11 +338,55 @@ async fn start_event_processor<P: EventPublisher + Send + Sync + 'static>(
}
}
// Then publish to NATS for global distribution
// Use KV_EVENT_SUBJECT so both JetStream and NATS Core subscribers
// can receive events on the expected subject.
if let Err(e) = publisher.publish(KV_EVENT_SUBJECT, &router_event).await {
tracing::error!("Failed to publish event to NATS: {}", e);
// Then publish to event plane for global distribution.
if let Err(e) = publisher.publish_event(&router_event).await {
tracing::error!("Failed to publish event: {}", e);
}
}
}
}
}
/// Event processor using JetStream (durable).
async fn start_event_processor_jetstream(
publisher: NatsQueue,
worker_id: u64,
cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
) {
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::info!("KV Event source received cancellation signal");
break;
}
event = rx.recv() => {
let Some(event) = event else {
tracing::debug!("Event processor channel closed.");
break;
};
// Encapsulate in a router event.
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
let router_event = RouterEvent::new(worker_id, event);
// Apply to local indexer first (if present)
if let Some(indexer) = &local_indexer {
// Adds event into local indexer, and logs it into internal buffer
if let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await {
tracing::warn!(
"Failed to send event to local indexer for worker {}: {}",
worker_id,
e
);
}
}
// Then publish to JetStream (via the NatsQueue) for global distribution
if let Err(e) = publisher.publish_event(&router_event).await {
tracing::error!("Failed to publish event to event plane: {}", e);
}
}
......@@ -886,6 +974,15 @@ impl WorkerMetricsPublisher {
let nats_rx = self.rx.clone();
tokio::spawn(async move {
let event_publisher =
match EventPublisher::for_namespace(&namespace, KV_METRICS_SUBJECT).await {
Ok(publisher) => publisher,
Err(e) => {
tracing::error!("Failed to create metrics publisher: {}", e);
return;
}
};
let mut rx = nats_rx;
let mut last_active_decode_blocks: Option<u64> = Some(0);
let mut pending_publish: Option<WorkerMetrics> = None;
......@@ -933,10 +1030,8 @@ impl WorkerMetricsPublisher {
active_prefill_tokens: None,
};
if let Err(e) =
namespace.publish(KV_METRICS_SUBJECT, &active_load).await
{
tracing::warn!("Failed to publish metrics over NATS: {}", e);
if let Err(e) = event_publisher.publish(&active_load).await {
tracing::warn!("Failed to publish metrics: {}", e);
}
}
......@@ -1076,7 +1171,6 @@ mod tests_startup_helpers {
use crate::kv_router::KvIndexer;
use crate::kv_router::indexer::KvIndexerInterface;
use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use async_trait;
use bytes::Bytes;
use std::sync::{Arc, Mutex};
use zeromq::{PubSocket, Socket, SocketSend, ZmqMessage};
......@@ -1105,35 +1199,15 @@ mod tests_startup_helpers {
}
#[async_trait::async_trait]
impl EventPublisher for MockComponent {
async fn publish(
&self,
event_name: impl AsRef<str> + Send + Sync,
event: &(impl serde::Serialize + Send + Sync),
) -> anyhow::Result<()> {
impl EventSink for MockComponent {
async fn publish_event(&self, event: &RouterEvent) -> anyhow::Result<()> {
let bytes = rmp_serde::to_vec(event).unwrap();
self.published
.lock()
.unwrap()
.push((event_name.as_ref().to_string(), bytes));
Ok(())
}
async fn publish_bytes(
&self,
event_name: impl AsRef<str> + Send + Sync,
bytes: Vec<u8>,
) -> anyhow::Result<()> {
self.published
.lock()
.unwrap()
.push((event_name.as_ref().to_string(), bytes));
.push((KV_EVENT_SUBJECT.to_string(), bytes));
Ok(())
}
fn subject(&self) -> String {
"mock.subject".into()
}
}
//--------------------------------------------------------------------
......@@ -1330,7 +1404,7 @@ mod tests_startup_helpers {
}
assert!(no_blocks, "worker should have no blocks after removal");
// Global kvindexer should have received two events (create/remove)
// Global kvindexer should have received two events (create/remove)
let published = published.lock().unwrap();
assert_eq!(
published.len(),
......@@ -1409,7 +1483,7 @@ mod tests_startup_helpers {
}
assert!(no_blocks, "worker should have no blocks after clearing");
// Global kvindexer should have received two events (create/remove)
// Global kvindexer should have received two events (create/remove)
let published = published.lock().unwrap();
assert_eq!(
published.len(),
......@@ -1815,7 +1889,7 @@ mod test_integration_publisher {
use super::*;
use crate::kv_router::protocols::ActiveLoad;
use dynamo_runtime::distributed_test_utils::create_test_drt_async;
use dynamo_runtime::traits::events::EventSubscriber;
use dynamo_runtime::transports::event_plane::EventSubscriber;
use futures::StreamExt;
#[tokio::test]
......@@ -1825,11 +1899,11 @@ mod test_integration_publisher {
let drt = create_test_drt_async().await;
let namespace = drt.namespace("ns2001".to_string())?;
// Create a subscriber for the metrics events using subscribe_with_type
let mut subscriber = namespace
.subscribe_with_type::<ActiveLoad>(KV_METRICS_SUBJECT)
// Create a subscriber for the metrics events
let mut subscriber = EventSubscriber::for_namespace(&namespace, KV_METRICS_SUBJECT)
.await
.unwrap();
.unwrap()
.typed::<ActiveLoad>();
// Create WorkerMetricsPublisher
let publisher = WorkerMetricsPublisher::new().unwrap();
......@@ -1857,7 +1931,7 @@ mod test_integration_publisher {
.await
.unwrap();
let event = result.unwrap().unwrap(); // Unwrap the Option and the Result
let (_envelope, event) = result.unwrap().unwrap(); // Unwrap the Option and the Result
assert_eq!(event.worker_id, worker_id);
assert_eq!(event.active_decode_blocks, Some(900)); // Last value: 9 * 100
assert_eq!(event.active_prefill_tokens, None); // Worker doesn't publish prefill tokens
......
......@@ -6,7 +6,7 @@ use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
use dynamo_runtime::transports::event_plane::EventPublisher;
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
......@@ -50,6 +50,9 @@ pub enum KvSchedulerError {
#[error("endpoint subscriber shutdown")]
SubscriberShutdown,
#[error("failed to initialize event publisher: {0}")]
InitFailed(String),
}
#[derive(Debug)]
......@@ -111,13 +114,17 @@ impl KvScheduler {
.map(|r| (*r.key(), r.value().clone()))
.collect();
let slots = Arc::new(ActiveSequencesMultiWorker::new(
let slots = Arc::new(
ActiveSequencesMultiWorker::new(
component.clone(),
block_size as usize,
initial_workers,
replica_sync,
router_id,
));
)
.await
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?,
);
// Spawn background task to sync slots with DashMap when notified of changes.
// ModelManager's watcher updates the DashMap and notifies; we wait on notify here.
......@@ -160,7 +167,10 @@ impl KvScheduler {
let workers_scheduler = workers_with_configs.clone();
let (request_tx, request_rx) = tokio::sync::mpsc::channel::<SchedulingRequest>(1024);
let scheduler_cancel_token = component.drt().primary_token();
let ns_clone = component.namespace().clone();
let hit_rate_publisher =
EventPublisher::for_namespace(component.namespace(), KV_HIT_RATE_SUBJECT)
.await
.map_err(|e| KvSchedulerError::InitFailed(e.to_string()))?;
// Background task to handle scheduling requests
tokio::spawn(async move {
......@@ -206,7 +216,7 @@ impl KvScheduler {
isl_blocks: selection.required_blocks as usize,
overlap_blocks: selection.overlap_blocks,
};
if let Err(e) = ns_clone.publish(KV_HIT_RATE_SUBJECT, &event).await {
if let Err(e) = hit_rate_publisher.publish(&event).await {
tracing::warn!("Failed to publish KV hit rate event: {:?}", e);
}
......
......@@ -29,8 +29,7 @@ use dashmap::DashMap;
use derive_getters::Getters;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::{EventPublisher, EventSubscriber};
use futures::StreamExt;
use dynamo_runtime::transports::event_plane::{EventPublisher, EventSubscriber};
use std::collections::{HashMap, HashSet};
use std::rc::{Rc, Weak};
use std::sync::Arc;
......@@ -410,17 +409,21 @@ pub struct ActiveSequencesMultiWorker {
block_size: usize,
component: Component,
router_id: u64,
/// Publisher for sequence events
event_publisher: EventPublisher,
/// Publisher for metrics (namespace-scoped)
metrics_publisher: EventPublisher,
replica_sync: bool,
}
impl ActiveSequencesMultiWorker {
pub fn new(
pub async fn new(
component: Component,
block_size: usize,
workers_with_configs: HashMap<u64, Option<ModelRuntimeConfig>>,
replica_sync: bool,
router_id: u64,
) -> Self {
) -> Result<Self> {
assert!(block_size > 1, "block_size must be greater than 1");
let senders = Arc::new(DashMap::new());
......@@ -441,12 +444,19 @@ impl ActiveSequencesMultiWorker {
}
}
let event_publisher =
EventPublisher::for_component(&component, ACTIVE_SEQUENCES_SUBJECT).await?;
let metrics_publisher =
EventPublisher::for_namespace(component.namespace(), KV_METRICS_SUBJECT).await?;
let multi_worker = Self {
senders: senders.clone(),
request_to_worker: request_to_worker.clone(),
handles,
block_size,
component: component.clone(),
event_publisher,
metrics_publisher,
router_id,
replica_sync,
};
......@@ -475,7 +485,7 @@ impl ActiveSequencesMultiWorker {
});
}
multi_worker
Ok(multi_worker)
}
/// Helper method to start a worker task
......@@ -597,9 +607,9 @@ impl ActiveSequencesMultiWorker {
router_id: u64,
cancel_token: CancellationToken,
) -> Result<()> {
let mut subscriber = component
.subscribe_with_type::<ActiveSequenceEvent>(ACTIVE_SEQUENCES_SUBJECT)
.await?;
let mut subscriber = EventSubscriber::for_component(&component, ACTIVE_SEQUENCES_SUBJECT)
.await?
.typed::<ActiveSequenceEvent>();
loop {
tokio::select! {
......@@ -610,7 +620,7 @@ impl ActiveSequencesMultiWorker {
break;
};
let Ok(event) = result else {
let Ok((_envelope, event)) = result else {
tracing::error!(
"Error receiving active sequence event: {}",
result.unwrap_err()
......@@ -770,9 +780,7 @@ impl ActiveSequencesMultiWorker {
},
router_id: self.router_id,
};
self.component
.publish(ACTIVE_SEQUENCES_SUBJECT, &event)
.await?;
self.event_publisher.publish(&event).await?;
}
// Update local state with full WorkerWithDpRank
......@@ -831,9 +839,7 @@ impl ActiveSequencesMultiWorker {
data: ActiveSequenceEventData::Free,
router_id: self.router_id,
};
self.component
.publish(ACTIVE_SEQUENCES_SUBJECT, &event)
.await?;
self.event_publisher.publish(&event).await?;
}
// Update local state
......@@ -882,9 +888,7 @@ impl ActiveSequencesMultiWorker {
data: ActiveSequenceEventData::MarkPrefillCompleted,
router_id: self.router_id,
};
self.component
.publish(ACTIVE_SEQUENCES_SUBJECT, &event)
.await?;
self.event_publisher.publish(&event).await?;
}
// Update local state
......@@ -999,12 +1003,7 @@ impl ActiveSequencesMultiWorker {
active_prefill_tokens: Some(active_tokens as u64),
};
if let Err(e) = self
.component
.namespace()
.publish(KV_METRICS_SUBJECT, &active_load)
.await
{
if let Err(e) = self.metrics_publisher.publish(&active_load).await {
tracing::warn!("Failed to publish ActiveLoad for worker {worker:?}: {e:?}");
}
}
......@@ -1236,20 +1235,20 @@ mod tests {
let config_worker_1 = crate::local_model::runtime_config::ModelRuntimeConfig::new();
workers_with_configs.insert(1, Some(config_worker_1));
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
let seq_manager_1 = Arc::new(
ActiveSequencesMultiWorker::new(
component.clone(),
block_size,
workers_with_configs.clone(),
true,
1,
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
workers_with_configs,
true,
2,
));
)
.await?,
);
let seq_manager_2 = Arc::new(
ActiveSequencesMultiWorker::new(component, block_size, workers_with_configs, true, 2)
.await?,
);
// Give some time for the subscription loops to start
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
......@@ -1395,20 +1394,20 @@ mod tests {
workers_with_configs.insert(1, None);
workers_with_configs.insert(2, None);
let seq_manager_1 = Arc::new(ActiveSequencesMultiWorker::new(
let seq_manager_1 = Arc::new(
ActiveSequencesMultiWorker::new(
component.clone(),
block_size,
workers_with_configs.clone(),
true,
1,
));
let seq_manager_2 = Arc::new(ActiveSequencesMultiWorker::new(
component,
block_size,
workers_with_configs,
true,
2,
));
)
.await?,
);
let seq_manager_2 = Arc::new(
ActiveSequencesMultiWorker::new(component, block_size, workers_with_configs, true, 2)
.await?,
);
// Give some time for the subscription loops to start
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
......
......@@ -7,9 +7,9 @@ use anyhow::Result;
use dynamo_runtime::{
component::Component,
config::environment_names::nats as env_nats,
discovery::{DiscoveryEvent, DiscoveryQuery},
discovery::{DiscoveryEvent, DiscoveryQuery, EventTransportKind},
prelude::*,
traits::events::{EventPublisher, EventSubscriber},
transports::event_plane::EventSubscriber,
transports::nats::{NatsQueue, Slug},
};
use futures::StreamExt;
......@@ -25,6 +25,21 @@ use crate::kv_router::{
worker_query::WorkerQueryClient,
};
/// Derives the KV stream name for a component/subject pair.
///
/// Formats `namespace.{namespace}.component.{component}.{subject}`, runs the
/// result through [`Slug::slugify`], and finally rewrites each `_` in the slug
/// as `-`.
fn create_kv_stream_name(component: &Component, subject: &str) -> String {
    let qualified_subject = format!(
        "namespace.{}.component.{}.{}",
        component.namespace().name(),
        component.name(),
        subject
    );
    let slugified = Slug::slugify(&qualified_subject).to_string();

    // Underscores are swapped for hyphens in the final stream name.
    slugified.replace("_", "-")
}
/// Delay between snapshot reads to verify stability
const SNAPSHOT_STABILITY_DELAY: Duration = Duration::from_millis(100);
const MAX_SNAPSHOT_STABILITY_ATTEMPTS: usize = 10;
......@@ -463,9 +478,7 @@ pub async fn start_kv_router_background(
router_reset_states: bool,
) -> Result<()> {
// Set up NATS connections
let stream_name = Slug::slugify(&format!("{}.{}", component.subject(), KV_EVENT_SUBJECT))
.to_string()
.replace("_", "-");
let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT);
let nats_server = std::env::var(env_nats::NATS_SERVER)
.unwrap_or_else(|_| "nats://localhost:4222".to_string());
......@@ -485,7 +498,12 @@ pub async fn start_kv_router_background(
let nats_client = client_options.connect().await?;
// Create bucket name for snapshots/state
let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", component.subject()))
let event_plane_subject = format!(
"namespace.{}.component.{}",
component.namespace().name(),
component.name()
);
let bucket_name = Slug::slugify(&format!("{}-{RADIX_STATE_BUCKET}", event_plane_subject))
.to_string()
.replace("_", "-");
......@@ -726,11 +744,11 @@ async fn handle_worker_discovery(
}
}
/// Start a simplified background task for event consumption using NATS Core.
/// Start a simplified background task for event consumption using the event plane.
///
/// This is used when local indexer mode is enabled. Unlike `start_kv_router_background`,
/// this function:
/// - Uses NATS Core pub/sub instead of JetStream
/// - Uses the event plane (NATS Core or ZMQ) instead of JetStream
/// - Does not support snapshots, purging, or durable consumers
/// - On worker Added: dumps worker's local indexer into router
/// - On worker Removed: removes worker from router indexer
......@@ -739,21 +757,40 @@ async fn handle_worker_discovery(
/// spawning the background task, ensuring the router is ready before returning.
///
/// This is appropriate when workers have local indexers enabled.
pub async fn start_kv_router_background_nats_core(
pub async fn start_kv_router_background_event_plane(
component: Component,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
cancellation_token: CancellationToken,
worker_query_client: WorkerQueryClient,
transport_kind: EventTransportKind,
) -> Result<()> {
// Subscribe to KV events using NATS Core
let mut subscriber = component.subscribe(KV_EVENT_SUBJECT).await?;
let kv_event_subject = format!("{}.{}", component.subject(), KV_EVENT_SUBJECT);
// Subscribe to KV events using the selected event plane transport
let mut subscriber =
EventSubscriber::for_component_with_transport(&component, KV_EVENT_SUBJECT, transport_kind)
.await?
.typed::<RouterEvent>();
let kv_event_subject = format!(
"namespace.{}.component.{}.{}",
component.namespace().name(),
component.name(),
KV_EVENT_SUBJECT
);
match transport_kind {
EventTransportKind::Nats => {
tracing::info!(
subject = %kv_event_subject,
"KV Router using NATS Core subscription (local_indexer mode)"
);
}
EventTransportKind::Zmq => {
tracing::info!(
subject = %kv_event_subject,
"KV Router using ZMQ event plane subscription (local_indexer mode)"
);
}
}
// Wait for at least one worker instance before proceeding
let mut instance_event_stream =
......@@ -802,7 +839,7 @@ pub async fn start_kv_router_background_nats_core(
biased;
_ = cancellation_token.cancelled() => {
tracing::debug!("KV Router NATS Core background task received cancellation signal");
tracing::debug!("KV Router event plane background task received cancellation signal");
break;
}
......@@ -821,12 +858,12 @@ pub async fn start_kv_router_background_nats_core(
.await;
}
// Handle event consumption from NATS Core subscription
Some(msg) = subscriber.next() => {
let event: RouterEvent = match serde_json::from_slice(&msg.payload) {
Ok(event) => event,
// Handle event consumption from event plane subscription
Some(result) = subscriber.next() => {
let (envelope, event) = match result {
Ok((envelope, event)) => (envelope, event),
Err(e) => {
tracing::warn!("Failed to deserialize RouterEvent from NATS Core: {e:?}");
tracing::warn!("Failed to receive RouterEvent from event plane: {e:?}");
continue;
}
};
......@@ -834,6 +871,13 @@ pub async fn start_kv_router_background_nats_core(
let worker_id = event.worker_id;
let event_id = event.event.event_id;
// Use envelope metadata for additional debugging
tracing::trace!(
"Received event from publisher {} (seq {})",
envelope.publisher_id,
envelope.sequence
);
// Gap detection: check if event ID is monotonically increasing per worker
// Note: event_id <= last_id is duplicate/out-of-order, apply anyway (idempotent)
if let Some(&last_id) = last_event_ids.get(&worker_id)
......@@ -847,7 +891,7 @@ pub async fn start_kv_router_background_nats_core(
"Event ID gap detected for worker {worker_id}, recovering events [{gap_start}, {gap_end}], gap_size: {gap_size}"
);
// Note: While recovering, new events may queue in the NATS subscriber's
// Note: While recovering, new events may queue in the subscriber's
// internal buffer. We don't explicitly buffer them here for simplicity.
// The subscriber will process them in order after recovery completes.
if let Err(e) = recover_from_worker(
......@@ -884,12 +928,31 @@ pub async fn start_kv_router_background_nats_core(
}
}
tracing::debug!("KV Router NATS Core background task exiting");
tracing::debug!("KV Router event plane background task exiting");
});
Ok(())
}
/// Backwards-compatible entry point for NATS Core local-indexer mode.
///
/// Delegates to [`start_kv_router_background_event_plane`] with the transport
/// fixed to [`EventTransportKind::Nats`]; every other argument is passed
/// through unchanged and the delegate's result is returned as-is.
pub async fn start_kv_router_background_nats_core(
    component: Component,
    kv_events_tx: mpsc::Sender<RouterEvent>,
    remove_worker_tx: mpsc::Sender<WorkerId>,
    cancellation_token: CancellationToken,
    worker_query_client: WorkerQueryClient,
) -> Result<()> {
    // This wrapper always selects the NATS transport.
    let transport_kind = EventTransportKind::Nats;

    let startup = start_kv_router_background_event_plane(
        component,
        kv_events_tx,
        remove_worker_tx,
        cancellation_token,
        worker_query_client,
        transport_kind,
    );
    startup.await
}
/// Cleanup orphaned NATS consumers that no longer have corresponding router entries
async fn cleanup_orphaned_consumers(
nats_queue: &mut NatsQueue,
......
......@@ -27,6 +27,9 @@ async-nats = { workspace = true }
async-stream = { workspace = true }
async-trait = { workspace = true }
async_zmq = { workspace = true }
tmq = { workspace = true }
zmq = { workspace = true }
lru = { version = "0.12" }
axum = { workspace = true }
blake3 = { workspace = true }
bytes = { workspace = true }
......
......@@ -313,6 +313,26 @@ pub mod event_plane {
pub const DYN_EVENT_PLANE_CODEC: &str = "DYN_EVENT_PLANE_CODEC";
}
/// Names of the environment variables that configure the ZMQ broker.
pub mod zmq_broker {
    /// Explicit ZMQ broker URL; when set it takes precedence over discovery.
    ///
    /// Format: "xsub=<url1>[;<url2>...] , xpub=<url1>[;<url2>...]"
    /// Example: "xsub=tcp://broker:5555 , xpub=tcp://broker:5556"
    pub const DYN_ZMQ_BROKER_URL: &str = "DYN_ZMQ_BROKER_URL";

    /// Turns on ZMQ broker discovery mode.
    pub const DYN_ZMQ_BROKER_ENABLED: &str = "DYN_ZMQ_BROKER_ENABLED";

    /// Address the XSUB socket binds to (broker binary only).
    pub const ZMQ_BROKER_XSUB_BIND: &str = "ZMQ_BROKER_XSUB_BIND";

    /// Address the XPUB socket binds to (broker binary only).
    pub const ZMQ_BROKER_XPUB_BIND: &str = "ZMQ_BROKER_XPUB_BIND";

    /// Namespace under which the broker registers itself for discovery.
    pub const ZMQ_BROKER_NAMESPACE: &str = "ZMQ_BROKER_NAMESPACE";
}
/// CUDA and GPU environment variables
pub mod cuda {
/// Path to custom CUDA fatbin file
......@@ -419,6 +439,12 @@ mod tests {
// Event Plane
event_plane::DYN_EVENT_PLANE,
event_plane::DYN_EVENT_PLANE_CODEC,
// ZMQ Broker
zmq_broker::DYN_ZMQ_BROKER_URL,
zmq_broker::DYN_ZMQ_BROKER_ENABLED,
zmq_broker::ZMQ_BROKER_XSUB_BIND,
zmq_broker::ZMQ_BROKER_XPUB_BIND,
zmq_broker::ZMQ_BROKER_NAMESPACE,
// CUDA
cuda::DYNAMO_FATBIN_PATH,
// Build
......
......@@ -133,11 +133,18 @@ pub enum EventTransport {
/// Subject prefix (e.g., "namespace.dynamo.component.backend")
subject_prefix: String,
},
/// ZMQ pub/sub - endpoint address
/// ZMQ pub/sub - endpoint address (direct mode)
Zmq {
/// ZMQ endpoint (e.g., "tcp://host:port")
endpoint: String,
},
/// ZMQ broker endpoints (broker mode) - for discovery of brokers
ZmqBroker {
/// XSUB endpoints (publishers connect here)
xsub_endpoints: Vec<String>,
/// XPUB endpoints (subscribers connect here)
xpub_endpoints: Vec<String>,
},
}
impl EventTransport {
......@@ -145,7 +152,7 @@ impl EventTransport {
/// Return the transport family of this variant.
///
/// Both ZMQ flavours (direct `Zmq` and `ZmqBroker`) report
/// `EventTransportKind::Zmq`.
pub fn kind(&self) -> EventTransportKind {
    match self {
        Self::Nats { .. } => EventTransportKind::Nats,
        // A single arm covers both ZMQ variants; a separate `Self::Zmq` arm
        // ahead of this one would make the `Self::Zmq` alternative here
        // unreachable.
        Self::Zmq { .. } | Self::ZmqBroker { .. } => EventTransportKind::Zmq,
    }
}
......@@ -164,10 +171,14 @@ impl EventTransport {
}
/// Get the subject prefix (NATS) or endpoint (ZMQ).
///
/// For `ZmqBroker` this is the first XSUB endpoint, or the empty string when
/// the XSUB endpoint list is empty.
pub fn address(&self) -> &str {
    match self {
        Self::Nats { subject_prefix } => subject_prefix.as_str(),
        Self::Zmq { endpoint } => endpoint.as_str(),
        Self::ZmqBroker { xsub_endpoints, .. } => {
            xsub_endpoints.first().map_or("", String::as_str)
        }
    }
}
}
......
......@@ -9,6 +9,57 @@ use serde::{Serialize, de::DeserializeOwned};
use super::EventEnvelope;
/// Codec for serializing and deserializing event envelopes and payloads.
///
/// Currently only supports MessagePack for all transports.
///
/// Every method on `Codec` dispatches to the codec value carried by the
/// active variant.
#[derive(Debug, Clone, Copy)]
pub enum Codec {
/// MessagePack encoding, implemented by `MsgpackCodec`.
Msgpack(MsgpackCodec),
}
/// The default codec is the MessagePack variant.
impl Default for Codec {
    fn default() -> Self {
        Self::Msgpack(MsgpackCodec)
    }
}
impl Codec {
    /// Serialize an `EventEnvelope` into its wire bytes.
    pub fn encode_envelope(&self, envelope: &EventEnvelope) -> Result<Bytes> {
        match self {
            Self::Msgpack(inner) => inner.encode_envelope(envelope),
        }
    }

    /// Parse wire bytes back into an `EventEnvelope`.
    pub fn decode_envelope(&self, bytes: &Bytes) -> Result<EventEnvelope> {
        match self {
            Self::Msgpack(inner) => inner.decode_envelope(bytes),
        }
    }

    /// Serialize a typed payload to bytes (for embedding in an envelope).
    pub fn encode_payload<T: Serialize>(&self, payload: &T) -> Result<Bytes> {
        match self {
            Self::Msgpack(inner) => inner.encode_payload(payload),
        }
    }

    /// Parse payload bytes into a typed value.
    pub fn decode_payload<T: DeserializeOwned>(&self, bytes: &Bytes) -> Result<T> {
        match self {
            Self::Msgpack(inner) => inner.decode_payload(bytes),
        }
    }

    /// Name of the active codec, intended for debugging output.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Msgpack(inner) => inner.name(),
        }
    }
}
/// Zero-sized marker type for the MessagePack codec carried by `Codec::Msgpack`.
#[derive(Clone, Copy, Debug, Default)]
pub struct MsgpackCodec;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Dynamic subscriber that watches discovery and manages connections to multiple publishers.
//!
//! This module enables automatic discovery and connection to new publishers as they come online,
//! and cleanup of disconnected publishers.
use anyhow::Result;
use bytes::Bytes;
use futures::stream::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, mpsc};
use tokio_util::sync::CancellationToken;
use super::transport::{EventTransportRx, WireStream};
use super::zmq_transport::ZmqSubTransport;
use crate::discovery::{
Discovery, DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId, DiscoveryQuery,
EventTransport,
};
/// Manages dynamic subscriptions to multiple publishers.
///
/// `start_zmq` watches `discovery` for instances matching `query`, opens one
/// ZMQ subscription per discovered publisher endpoint, and merges everything
/// received into a single stream.
pub struct DynamicSubscriber {
/// Discovery backend that is listed and watched for publisher instances.
discovery: Arc<dyn Discovery>,
/// Query handed to `Discovery::list_and_watch` to select the instances to follow.
query: DiscoveryQuery,
/// Topic passed to each ZMQ subscription (used for ZMQ-side filtering).
topic: String,
/// Subscriber-wide cancellation token; cancelled by `cancel()` and on drop.
cancel_token: CancellationToken,
}
impl DynamicSubscriber {
/// Create a subscriber for `topic` that will follow the instances matched by
/// `query` on `discovery`.
///
/// Nothing is started here; a fresh, un-cancelled token is created and the
/// arguments are stored as-is.
pub fn new(discovery: Arc<dyn Discovery>, query: DiscoveryQuery, topic: String) -> Self {
    let cancel_token = CancellationToken::new();
    Self {
        discovery,
        query,
        topic,
        cancel_token,
    }
}
/// Start watching discovery and create a merged stream of events.
///
/// A background task lists and watches `self.query` on the discovery backend.
/// For every `Added` instance that carries a direct ZMQ endpoint it spawns a
/// per-endpoint task that forwards received frames into one shared channel;
/// a `Removed` event cancels the matching per-endpoint task. The returned
/// stream yields everything arriving on that shared channel and keeps the
/// `Arc<Self>` alive for as long as the stream exists.
///
/// The background task stops when the subscriber's token is cancelled
/// (`cancel()` or drop), when the discovery stream ends, or on the first
/// discovery error; on exit it cancels all per-endpoint tasks.
///
/// # Errors
/// The current implementation always returns `Ok`; failures to start the
/// discovery watch are logged inside the background task.
pub async fn start_zmq(self: Arc<Self>) -> Result<WireStream> {
    let (event_tx, event_rx) = mpsc::unbounded_channel::<Bytes>();

    // Track active endpoint connections: instance ID -> (endpoint, cancel token)
    let active_endpoints: Arc<RwLock<HashMap<String, (String, CancellationToken)>>> =
        Arc::new(RwLock::new(HashMap::new()));

    // Held by the returned stream so the subscriber outlives its consumers
    let subscriber_clone = Arc::clone(&self);

    // State moved into the background discovery task
    let discovery = Arc::clone(&self.discovery);
    let query = self.query.clone();
    // Use the actual topic for ZMQ native filtering (avoids decoding irrelevant messages)
    let zmq_topic = self.topic.clone();
    let cancel_token = self.cancel_token.clone();
    let endpoints = Arc::clone(&active_endpoints);

    tokio::spawn(async move {
        tracing::debug!(
            ?query,
            cancel_token_cancelled = cancel_token.is_cancelled(),
            "Attempting to start discovery watch"
        );

        // Don't pass the cancel token to list_and_watch - we'll handle cancellation ourselves
        let mut watch_stream = match discovery.list_and_watch(query.clone(), None).await {
            Ok(stream) => {
                tracing::debug!("Successfully obtained discovery watch stream");
                stream
            }
            Err(e) => {
                tracing::error!(error = %e, "Failed to start discovery watch");
                return;
            }
        };

        tracing::info!(?query, "Started dynamic discovery watch for ZMQ publishers");

        loop {
            // Race the subscriber-wide token against the discovery stream so that
            // cancellation takes effect immediately instead of only after the next
            // discovery event happens to arrive.
            let event_result = tokio::select! {
                _ = cancel_token.cancelled() => {
                    tracing::info!("Dynamic subscriber cancelled, stopping watch");
                    break;
                }
                next = watch_stream.next() => match next {
                    Some(event_result) => event_result,
                    None => break,
                },
            };

            tracing::debug!("Received discovery event: {:?}", event_result);

            match event_result {
                Ok(DiscoveryEvent::Added(instance)) => {
                    tracing::info!(instance = ?instance, "Discovery Added event received");
                    let instance_id = instance.instance_id().to_string();

                    // Extract ZMQ endpoint from the instance
                    if let Some(endpoint) = Self::extract_zmq_endpoint(&instance) {
                        let mut endpoints_guard = endpoints.write().await;

                        // Skip if instance already tracked
                        if endpoints_guard.contains_key(&instance_id) {
                            tracing::debug!(endpoint = %endpoint, instance_id = %instance_id, "Already connected to ZMQ publisher");
                            continue;
                        }

                        tracing::info!(endpoint = %endpoint, instance_id = %instance_id, "Connecting to new ZMQ publisher");

                        // Create cancellation token for this endpoint's stream
                        let endpoint_cancel = CancellationToken::new();
                        endpoints_guard.insert(
                            instance_id.clone(),
                            (endpoint.clone(), endpoint_cancel.clone()),
                        );
                        drop(endpoints_guard);

                        // Spawn task to handle this endpoint's stream
                        let event_tx_clone = event_tx.clone();
                        let zmq_topic_clone = zmq_topic.clone();
                        let endpoints_clone = Arc::clone(&endpoints);
                        tokio::spawn(async move {
                            if let Err(e) = Self::consume_endpoint_stream(
                                &endpoint,
                                &zmq_topic_clone,
                                event_tx_clone,
                                endpoint_cancel.clone(),
                            )
                            .await
                            {
                                tracing::warn!(
                                    endpoint = %endpoint,
                                    error = %e,
                                    "Error consuming ZMQ endpoint stream"
                                );
                            }

                            // Clean up on stream termination. If our token was
                            // cancelled, whoever cancelled it (Removed handler or
                            // shutdown) already owns the map entry, and the
                            // instance id may have been re-added since - removing
                            // by id would then drop the *new* connection's entry
                            // and make it impossible to cancel. Check under the
                            // lock so the decision cannot race with a re-add.
                            let mut endpoints_guard = endpoints_clone.write().await;
                            if !endpoint_cancel.is_cancelled() {
                                endpoints_guard.remove(&instance_id);
                            }
                        });
                    } else {
                        tracing::warn!(
                            instance = ?instance,
                            "Discovery Added event did not contain a ZMQ endpoint"
                        );
                    }
                }
                Ok(DiscoveryEvent::Removed(instance_id)) => {
                    let id_str = instance_id.instance_id().to_string();
                    tracing::info!(
                        instance_id = %id_str,
                        "ZMQ publisher removed from discovery, cancelling endpoint stream"
                    );

                    // Cancel the endpoint's stream via its CancellationToken
                    let removed = endpoints.write().await.remove(&id_str);
                    if let Some((_endpoint, cancel)) = removed {
                        cancel.cancel();
                        tracing::info!(instance_id = %id_str, "Cancelled endpoint stream");
                    } else {
                        tracing::warn!(instance_id = %id_str, "No active endpoint found for removed stream instance");
                    }
                }
                Err(e) => {
                    tracing::error!(error = %e, "Discovery watch error");
                    break;
                }
            }
        }

        // Cancel all active endpoints on shutdown; drain so the map does not keep
        // entries for tasks that skip their own cleanup once cancelled.
        for (_id, (_endpoint, cancel)) in endpoints.write().await.drain() {
            cancel.cancel();
        }

        tracing::info!("Discovery watch stream ended");
    });

    // Return a stream that reads from the merged channel
    let stream = async_stream::stream! {
        // Keep subscriber_clone alive by capturing it in the stream
        let _subscriber = subscriber_clone;
        let mut rx = event_rx;

        while let Some(bytes) = rx.recv().await {
            yield Ok(bytes);
        }
    };

    Ok(Box::pin(stream))
}
/// Extract the direct ZMQ endpoint from a discovery instance.
///
/// Returns `Some(endpoint)` only for an `EventChannel` instance whose
/// transport is `EventTransport::Zmq`; every other instance or transport
/// yields `None`.
fn extract_zmq_endpoint(instance: &DiscoveryInstance) -> Option<String> {
    match instance {
        DiscoveryInstance::EventChannel {
            transport: EventTransport::Zmq { endpoint },
            ..
        } => Some(endpoint.clone()),
        _ => None,
    }
}
/// Consume events from a single endpoint and forward to the merged channel.
async fn consume_endpoint_stream(
endpoint: &str,
zmq_topic: &str,
event_tx: mpsc::UnboundedSender<Bytes>,
cancel_token: CancellationToken,
) -> Result<()> {
// Connect to the endpoint
let sub_transport = ZmqSubTransport::connect(endpoint, zmq_topic).await?;
let mut stream = sub_transport.subscribe(zmq_topic).await?;
tracing::info!(endpoint = %endpoint, topic = %zmq_topic, "Started consuming ZMQ endpoint stream");
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
tracing::info!(endpoint = %endpoint, "Endpoint stream cancelled");
break;
}
event = stream.next() => {
match event {
Some(Ok(bytes)) => {
if event_tx.send(bytes).is_err() {
tracing::warn!(endpoint = %endpoint, "Event channel closed, stopping endpoint stream");
break;
}
}
Some(Err(e)) => {
tracing::error!(
endpoint = %endpoint,
error = %e,
"Error receiving from ZMQ endpoint"
);
break;
}
None => {
tracing::info!(endpoint = %endpoint, "ZMQ endpoint stream ended");
break;
}
}
}
}
}
Ok(())
}
/// Stop watching and disconnect from all endpoints.
///
/// This only cancels the subscriber-wide token. The background task spawned by
/// `start_zmq` reacts to that token by leaving its watch loop and cancelling
/// every per-endpoint token.
pub fn cancel(&self) {
self.cancel_token.cancel();
}
}
/// Dropping the subscriber cancels its token, the same action as `cancel()`.
///
/// NOTE(review): the stream returned by `start_zmq` holds an `Arc` to the
/// subscriber, so this runs only after that stream (and any other `Arc`
/// clones) have been dropped.
impl Drop for DynamicSubscriber {
fn drop(&mut self) {
self.cancel_token.cancel();
}
}
This diff is collapsed.
......@@ -7,30 +7,20 @@ use crate::pipeline::network::tcp::server::{DefaultIpResolver, IpResolver};
use local_ip_address::Error;
use std::net::IpAddr;
/// Get the local IP address for HTTP RPC host binding, using IpResolver with fallback to 127.0.0.1
///
/// This function attempts to resolve the local IP address using the provided resolver.
/// If resolution fails, it falls back to 127.0.0.1 (localhost).
///
/// IPv6 addresses are wrapped with brackets for safe URL construction (e.g., `[::1]`).
///
/// # Arguments
/// * `resolver` - An implementation of IpResolver trait for getting local IP addresses
///
/// # Returns
/// A string representation of the resolved IP address (IPv6 addresses are bracketed)
pub fn get_http_rpc_host_with_resolver<R: IpResolver>(resolver: R) -> String {
fn resolve_local_ip_with_resolver<R: IpResolver>(resolver: R) -> IpAddr {
let resolved_ip = resolver.local_ip().or_else(|err| match err {
Error::LocalIpAddressNotFound => resolver.local_ipv6(),
_ => Err(err),
});
let addr = match resolved_ip {
match resolved_ip {
Ok(addr) => addr,
Err(Error::LocalIpAddressNotFound) => IpAddr::from([127, 0, 0, 1]),
Err(_) => IpAddr::from([127, 0, 0, 1]), // Fallback for any other error
};
}
}
fn format_ip_for_url(addr: IpAddr) -> String {
// Wrap IPv6 addresses with brackets for safe URL construction
// e.g., "2001:db8::1" becomes "[2001:db8::1]" so that "{host}:{port}" is valid
match addr {
......@@ -39,6 +29,32 @@ pub fn get_http_rpc_host_with_resolver<R: IpResolver>(resolver: R) -> String {
}
}
/// Resolve the local IP address to advertise to peers, as a URL-safe string.
///
/// The address is obtained from `resolver` via
/// `resolve_local_ip_with_resolver` (which falls back to 127.0.0.1 when
/// resolution fails) and then rendered by `format_ip_for_url`, so IPv6
/// addresses come back wrapped in brackets (e.g., `[::1]`).
///
/// # Arguments
/// * `resolver` - An implementation of IpResolver trait for getting local IP addresses
///
/// # Returns
/// A string representation of the resolved IP address (IPv6 addresses are bracketed)
pub fn get_local_ip_for_advertise_with_resolver<R: IpResolver>(resolver: R) -> String {
    let local_ip = resolve_local_ip_with_resolver(resolver);
    format_ip_for_url(local_ip)
}
/// Get the local IP address for advertising endpoints using the default resolver.
///
/// Convenience form of `get_local_ip_for_advertise_with_resolver` that
/// supplies `DefaultIpResolver`.
pub fn get_local_ip_for_advertise() -> String {
    let resolver = DefaultIpResolver;
    get_local_ip_for_advertise_with_resolver(resolver)
}
/// Get the local IP address for HTTP RPC host binding, using IpResolver with fallback to 127.0.0.1
///
/// This is a thin alias: it returns exactly what
/// `get_local_ip_for_advertise_with_resolver` returns for the same resolver.
///
/// # Arguments
/// * `resolver` - An implementation of IpResolver trait for getting local IP addresses
pub fn get_http_rpc_host_with_resolver<R: IpResolver>(resolver: R) -> String {
get_local_ip_for_advertise_with_resolver(resolver)
}
/// Get the local IP address for HTTP RPC host binding using the default resolver
///
/// This is a convenience function that uses the DefaultIpResolver.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment