Unverified commit 9c5018e5, authored by Yan Ru Pei and committed by GitHub
Browse files

chore: make router use the discovery pattern instead of manual kv watch (#4597)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 445b4bd9
...@@ -9,6 +9,3 @@ pub use watcher::{ModelUpdate, ModelWatcher}; ...@@ -9,6 +9,3 @@ pub use watcher::{ModelUpdate, ModelWatcher};
mod worker_monitor; mod worker_monitor;
pub use worker_monitor::{KvWorkerMonitor, WorkerLoadState}; pub use worker_monitor::{KvWorkerMonitor, WorkerLoadState};
/// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "v1/kv-routers";
...@@ -9,12 +9,16 @@ use std::{ ...@@ -9,12 +9,16 @@ use std::{
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot; use tokio::sync::oneshot;
use dynamo_runtime::{component::Endpoint, storage::key_value_store::Key}; use dynamo_runtime::{
use dynamo_runtime::{prelude::DistributedRuntimeProvider, protocols::EndpointId}; component::{Endpoint, TransportType},
discovery::DiscoverySpec,
prelude::DistributedRuntimeProvider,
protocols::EndpointId,
transports::nats,
};
use crate::{ use crate::{
discovery::KV_ROUTERS_ROOT_PATH, kv_router::{KvRouter, KvRouterConfig, router_endpoint_id, scheduler::DefaultWorkerSelector},
kv_router::{KvRouter, KvRouterConfig, scheduler::DefaultWorkerSelector},
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
model_type::ModelType, model_type::ModelType,
types::{ types::{
...@@ -310,23 +314,28 @@ impl ModelManager { ...@@ -310,23 +314,28 @@ impl ModelManager {
} }
let client = endpoint.client().await?; let client = endpoint.client().await?;
let store = endpoint.component().drt().store();
let router_bucket = store // Register router via discovery mechanism
.get_or_create_bucket(KV_ROUTERS_ROOT_PATH, None) let discovery = endpoint.component().drt().discovery();
.await?; let instance_id = discovery.instance_id();
let router_uuid = uuid::Uuid::new_v4();
// In lib/llm/src/kv_router/subscriber.rs we filter on component.service_name() so this // Build NATS transport subject for the router endpoint
// must have that prefix. // Use KV_ROUTER_COMPONENT as the component name to distinguish from the generate endpoint's component
let router_key = Key::new(format!( let router_endpoint_id = router_endpoint_id(endpoint.id().namespace);
"{}/{}/{}", // Placeholder subject - router is not callable, only registered for lifecycle coordination
endpoint.component().service_name(), let nats_subject = nats::instance_subject(&router_endpoint_id, instance_id);
endpoint.name(),
router_uuid, let discovery_spec = DiscoverySpec::Endpoint {
)); namespace: router_endpoint_id.namespace.clone(),
let json_router_config = serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?; component: router_endpoint_id.component.clone(),
router_bucket endpoint: router_endpoint_id.name.clone(),
.insert(&router_key, json_router_config.into(), 0) transport: TransportType::Nats(nats_subject),
.await?; };
discovery.register(discovery_spec).await?;
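// Use instance_id as the consumer ID for NATS consumer coordination.
// NOTE: `to_string()` on the u64 instance_id yields its decimal form, not the
// `{:x}` hex form that `nats::instance_subject` embeds in the subject.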
let consumer_id = instance_id.to_string();
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new( let chooser = KvRouter::new(
...@@ -335,7 +344,7 @@ impl ModelManager { ...@@ -335,7 +344,7 @@ impl ModelManager {
kv_cache_block_size, kv_cache_block_size,
Some(selector), Some(selector),
kv_router_config, kv_router_config,
router_uuid.to_string(), consumer_id,
) )
.await?; .await?;
let new_kv_chooser = Arc::new(chooser); let new_kv_chooser = Arc::new(chooser);
......
...@@ -14,6 +14,7 @@ use dynamo_runtime::{ ...@@ -14,6 +14,7 @@ use dynamo_runtime::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait, SingleIn, async_trait,
}, },
protocols::EndpointId,
protocols::annotated::Annotated, protocols::annotated::Annotated,
traits::DistributedRuntimeProvider, traits::DistributedRuntimeProvider,
}; };
...@@ -74,6 +75,28 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events"; ...@@ -74,6 +75,28 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
pub const RADIX_STATE_BUCKET: &str = "radix-bucket"; pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state"; pub const RADIX_STATE_FILE: &str = "radix-state";
// Identifiers used for router discovery registration.
/// Component name placed in the KV router's `EndpointId` and `DiscoveryQuery`
/// (see `router_endpoint_id` and `router_discovery_query`).
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
/// Endpoint name placed in the KV router's `EndpointId` and `DiscoveryQuery`
/// (see `router_endpoint_id` and `router_discovery_query`).
pub const KV_ROUTER_ENDPOINT: &str = "generate";
/// Builds the [`EndpointId`] that identifies the KV router inside `namespace`.
///
/// The component is always [`KV_ROUTER_COMPONENT`] and the endpoint name is
/// always [`KV_ROUTER_ENDPOINT`]; only the namespace varies between callers.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
    let component = String::from(KV_ROUTER_COMPONENT);
    let name = String::from(KV_ROUTER_ENDPOINT);
    EndpointId {
        namespace,
        component,
        name,
    }
}
/// Builds the [`DiscoveryQuery::Endpoint`] that matches KV router instances in
/// `namespace`.
///
/// The query uses the fixed [`KV_ROUTER_COMPONENT`] / [`KV_ROUTER_ENDPOINT`]
/// pair, mirroring the identifiers produced by [`router_endpoint_id`].
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
    let component = String::from(KV_ROUTER_COMPONENT);
    let endpoint = String::from(KV_ROUTER_ENDPOINT);
    DiscoveryQuery::Endpoint {
        namespace,
        component,
        endpoint,
    }
}
/// A trait that users can implement to define custom selection logic /// A trait that users can implement to define custom selection logic
pub trait WorkerSelector { pub trait WorkerSelector {
fn select_worker( fn select_worker(
...@@ -254,7 +277,7 @@ impl KvRouter { ...@@ -254,7 +277,7 @@ impl KvRouter {
block_size: u32, block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>, selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
consumer_uuid: String, consumer_id: String,
) -> Result<Self> { ) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default(); let kv_router_config = kv_router_config.unwrap_or_default();
let component = endpoint.component(); let component = endpoint.component();
...@@ -311,7 +334,7 @@ impl KvRouter { ...@@ -311,7 +334,7 @@ impl KvRouter {
runtime_configs_rx, runtime_configs_rx,
selector, selector,
kv_router_config.router_replica_sync, kv_router_config.router_replica_sync,
consumer_uuid.clone(), consumer_id.clone(),
) )
.await?; .await?;
...@@ -321,7 +344,7 @@ impl KvRouter { ...@@ -321,7 +344,7 @@ impl KvRouter {
{ {
start_kv_router_background( start_kv_router_background(
component.clone(), component.clone(),
consumer_uuid, consumer_id,
kv_indexer.event_sender(), kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(), kv_indexer.remove_worker_sender(),
kv_router_config kv_router_config
......
...@@ -3,15 +3,14 @@ ...@@ -3,15 +3,14 @@
//! Background processes for the KV Router including event consumption and snapshot uploads. //! Background processes for the KV Router including event consumption and snapshot uploads.
use std::{collections::HashSet, sync::Arc, time::Duration}; use std::{collections::HashSet, time::Duration};
use anyhow::Result; use anyhow::Result;
use dynamo_runtime::{ use dynamo_runtime::{
component::Component, component::Component,
config::environment_names::nats as env_nats, config::environment_names::nats as env_nats,
discovery::DiscoveryQuery, discovery::{DiscoveryEvent, DiscoveryQuery},
prelude::*, prelude::*,
storage::key_value_store::WatchEvent,
traits::events::EventPublisher, traits::events::EventPublisher,
transports::nats::{NatsQueue, Slug}, transports::nats::{NatsQueue, Slug},
}; };
...@@ -20,13 +19,11 @@ use rand::Rng; ...@@ -20,13 +19,11 @@ use rand::Rng;
use tokio::sync::{mpsc, oneshot}; use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::{ use crate::kv_router::{
discovery::KV_ROUTERS_ROOT_PATH, KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
kv_router::{ indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE, protocols::WorkerId,
indexer::{DumpRequest, GetWorkersRequest, RouterEvent}, router_discovery_query,
protocols::WorkerId,
},
}; };
/// Delay between snapshot reads to verify stability /// Delay between snapshot reads to verify stability
...@@ -217,7 +214,7 @@ impl SnapshotResources { ...@@ -217,7 +214,7 @@ impl SnapshotResources {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub async fn start_kv_router_background( pub async fn start_kv_router_background(
component: Component, component: Component,
consumer_uuid: String, consumer_id: String,
kv_events_tx: mpsc::Sender<RouterEvent>, kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>, remove_worker_tx: mpsc::Sender<WorkerId>,
maybe_get_workers_tx: Option<mpsc::Sender<GetWorkersRequest>>, maybe_get_workers_tx: Option<mpsc::Sender<GetWorkersRequest>>,
...@@ -238,7 +235,7 @@ pub async fn start_kv_router_background( ...@@ -238,7 +235,7 @@ pub async fn start_kv_router_background(
stream_name.clone(), stream_name.clone(),
nats_server.clone(), nats_server.clone(),
std::time::Duration::from_secs(60), // 1 minute timeout std::time::Duration::from_secs(60), // 1 minute timeout
consumer_uuid.clone(), consumer_id.clone(),
); );
nats_queue.connect_with_reset(router_reset_states).await?; nats_queue.connect_with_reset(router_reset_states).await?;
...@@ -266,23 +263,24 @@ pub async fn start_kv_router_background( ...@@ -266,23 +263,24 @@ pub async fn start_kv_router_background(
} }
// Cleanup orphaned consumers on startup // Cleanup orphaned consumers on startup
cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_uuid).await; cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_id).await;
// Watch for router deletions to clean up orphaned consumers
let store = component.drt().store();
let (_watch_handle, mut router_replicas_rx) =
Arc::new(store.clone()).watch(KV_ROUTERS_ROOT_PATH, None, cancellation_token.clone());
// Get the generate endpoint and watch for instance deletions // Get the generate endpoint and watch for instance deletions
let generate_endpoint = component.endpoint("generate"); let generate_endpoint = component.endpoint("generate");
let discovery_client = component.drt().discovery(); let discovery_client = component.drt().discovery();
let discovery_key = DiscoveryQuery::Endpoint { let generate_discovery_key = DiscoveryQuery::Endpoint {
namespace: component.namespace().name().to_string(), namespace: component.namespace().name().to_string(),
component: component.name().to_string(), component: component.name().to_string(),
endpoint: "generate".to_string(), endpoint: "generate".to_string(),
}; };
let mut instance_event_stream = discovery_client let mut instance_event_stream = discovery_client
.list_and_watch(discovery_key, Some(cancellation_token.clone())) .list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
.await?;
// Watch for router deletions to clean up orphaned consumers via discovery
let router_discovery_key = router_discovery_query(component.namespace().name());
let mut router_event_stream = discovery_client
.list_and_watch(router_discovery_key, Some(cancellation_token.clone()))
.await?; .await?;
// Get instances_rx for tracking current workers // Get instances_rx for tracking current workers
...@@ -336,7 +334,7 @@ pub async fn start_kv_router_background( ...@@ -336,7 +334,7 @@ pub async fn start_kv_router_background(
continue; continue;
}; };
let dynamo_runtime::discovery::DiscoveryEvent::Removed(worker_id) = discovery_event else { let DiscoveryEvent::Removed(worker_id) = discovery_event else {
continue; continue;
}; };
...@@ -409,36 +407,24 @@ pub async fn start_kv_router_background( ...@@ -409,36 +407,24 @@ pub async fn start_kv_router_background(
} }
} }
// Handle router deletion events // Handle router deletion events via discovery
Some(event) = router_replicas_rx.recv() => { Some(router_event_result) = router_event_stream.next() => {
let WatchEvent::Delete(kv) = event else { let Ok(router_event) = router_event_result else {
// We only care about deletions for cleaning up consumers
continue; continue;
}; };
let key = kv.as_ref(); let DiscoveryEvent::Removed(router_instance_id) = router_event else {
tracing::info!("Detected router replica deletion: {key}"); // We only care about removals for cleaning up consumers
// Only process deletions for routers on the same component
// Must match model_manager.rs kv_chooser_for
if !key.contains(&component.service_name()) {
tracing::trace!(
"Skipping router deletion from different component (key: {key}, subscriber component: {})",
component.service_name()
);
continue;
}
// Extract the router UUID from the key
let Some(router_uuid) = key.split('/').next_back() else {
tracing::warn!("Could not extract UUID from router key: {key}");
continue; continue;
}; };
// The consumer UUID is the router UUID // The consumer UUID is the instance_id in hex format
let consumer_to_delete = router_uuid.to_string(); let consumer_to_delete = router_instance_id.to_string();
tracing::info!("Attempting to delete orphaned consumer: {consumer_to_delete}"); tracing::info!(
router_instance_id = router_instance_id,
"DISCOVERY: Router instance removed, attempting to delete orphaned consumer: {consumer_to_delete}"
);
// Delete the consumer (allow race condition if multiple routers try to delete) // Delete the consumer (allow race condition if multiple routers try to delete)
if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await { if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await {
...@@ -463,43 +449,34 @@ pub async fn start_kv_router_background( ...@@ -463,43 +449,34 @@ pub async fn start_kv_router_background(
async fn cleanup_orphaned_consumers( async fn cleanup_orphaned_consumers(
nats_queue: &mut NatsQueue, nats_queue: &mut NatsQueue,
component: &Component, component: &Component,
consumer_uuid: &str, consumer_id: &str,
) { ) {
let Ok(consumers) = nats_queue.list_consumers().await else { let Ok(consumers) = nats_queue.list_consumers().await else {
return; return;
}; };
// Get active routers from store // Get active routers from discovery
let store = component.drt().store(); let discovery = component.drt().discovery();
let Ok(Some(router_bucket)) = store.get_bucket(KV_ROUTERS_ROOT_PATH).await else { let Ok(router_instances) = discovery
tracing::debug!("No router bucket found, skipping cleanup"); .list(router_discovery_query(component.namespace().name()))
return; .await
}; else {
tracing::debug!("Failed to list router instances from discovery, skipping cleanup");
let Ok(entries) = router_bucket.entries().await else {
return; return;
}; };
// Filter to only routers for this component // Build set of active router instance IDs
let component_path = component.service_name(); let active_instance_ids: HashSet<String> = router_instances
let active_uuids: HashSet<String> = entries
.iter() .iter()
.filter_map(|(key, _)| { .map(|instance| instance.instance_id().to_string())
// Check if key contains this component's path
if !key.as_ref().contains(&component_path) {
return None;
}
// Extract the last part (should be the UUID)
key.as_ref().split('/').next_back().map(str::to_string)
})
.collect(); .collect();
for consumer in consumers { for consumer in consumers {
if consumer == consumer_uuid { if consumer == consumer_id {
// Never delete myself (extra/redundant safeguard) // Never delete myself (extra/redundant safeguard)
continue; continue;
} }
if !active_uuids.contains(&consumer) { if !active_instance_ids.contains(&consumer) {
tracing::info!("Cleaning up orphaned consumer: {consumer}"); tracing::info!("Cleaning up orphaned consumer: {consumer}");
let _ = nats_queue.shutdown(Some(consumer)).await; let _ = nats_queue.shutdown(Some(consumer)).await;
} }
......
...@@ -993,7 +993,7 @@ impl DRTNatsClientPrometheusMetrics { ...@@ -993,7 +993,7 @@ impl DRTNatsClientPrometheusMetrics {
/// The NATS subject / inbox to talk to an instance on. /// The NATS subject / inbox to talk to an instance on.
/// TODO: Do we need to sanitize the names? /// TODO: Do we need to sanitize the names?
pub(crate) fn instance_subject(endpoint_id: &EndpointId, instance_id: u64) -> String { pub fn instance_subject(endpoint_id: &EndpointId, instance_id: u64) -> String {
format!( format!(
"{}_{}.{}-{:x}", "{}_{}.{}-{:x}",
endpoint_id.namespace, endpoint_id.component, endpoint_id.name, instance_id, endpoint_id.namespace, endpoint_id.component, endpoint_id.name, instance_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment