Unverified Commit 9c5018e5 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: make router use the discovery pattern instead of manual kv watch (#4597)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 445b4bd9
......@@ -9,6 +9,3 @@ pub use watcher::{ModelUpdate, ModelWatcher};
mod worker_monitor;
pub use worker_monitor::{KvWorkerMonitor, WorkerLoadState};
/// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "v1/kv-routers";
......@@ -9,12 +9,16 @@ use std::{
use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use dynamo_runtime::{component::Endpoint, storage::key_value_store::Key};
use dynamo_runtime::{prelude::DistributedRuntimeProvider, protocols::EndpointId};
use dynamo_runtime::{
component::{Endpoint, TransportType},
discovery::DiscoverySpec,
prelude::DistributedRuntimeProvider,
protocols::EndpointId,
transports::nats,
};
use crate::{
discovery::KV_ROUTERS_ROOT_PATH,
kv_router::{KvRouter, KvRouterConfig, scheduler::DefaultWorkerSelector},
kv_router::{KvRouter, KvRouterConfig, router_endpoint_id, scheduler::DefaultWorkerSelector},
model_card::ModelDeploymentCard,
model_type::ModelType,
types::{
......@@ -310,23 +314,28 @@ impl ModelManager {
}
let client = endpoint.client().await?;
let store = endpoint.component().drt().store();
let router_bucket = store
.get_or_create_bucket(KV_ROUTERS_ROOT_PATH, None)
.await?;
let router_uuid = uuid::Uuid::new_v4();
// In lib/llm/src/kv_router/subscriber.rs we filter on component.service_name() so this
// must have that prefix.
let router_key = Key::new(format!(
"{}/{}/{}",
endpoint.component().service_name(),
endpoint.name(),
router_uuid,
));
let json_router_config = serde_json::to_vec_pretty(&kv_router_config.unwrap_or_default())?;
router_bucket
.insert(&router_key, json_router_config.into(), 0)
.await?;
// Register router via discovery mechanism
let discovery = endpoint.component().drt().discovery();
let instance_id = discovery.instance_id();
// Build NATS transport subject for the router endpoint
// Use KV_ROUTER_COMPONENT as the component name to distinguish from the generate endpoint's component
let router_endpoint_id = router_endpoint_id(endpoint.id().namespace);
// Placeholder subject - router is not callable, only registered for lifecycle coordination
let nats_subject = nats::instance_subject(&router_endpoint_id, instance_id);
let discovery_spec = DiscoverySpec::Endpoint {
namespace: router_endpoint_id.namespace.clone(),
component: router_endpoint_id.component.clone(),
endpoint: router_endpoint_id.name.clone(),
transport: TransportType::Nats(nats_subject),
};
discovery.register(discovery_spec).await?;
// Use instance_id (hex) as the consumer ID for NATS consumer coordination
let consumer_id = instance_id.to_string();
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
let chooser = KvRouter::new(
......@@ -335,7 +344,7 @@ impl ModelManager {
kv_cache_block_size,
Some(selector),
kv_router_config,
router_uuid.to_string(),
consumer_id,
)
.await?;
let new_kv_chooser = Arc::new(chooser);
......
......@@ -14,6 +14,7 @@ use dynamo_runtime::{
AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, ResponseStream,
SingleIn, async_trait,
},
protocols::EndpointId,
protocols::annotated::Annotated,
traits::DistributedRuntimeProvider,
};
......@@ -74,6 +75,28 @@ pub const ACTIVE_SEQUENCES_SUBJECT: &str = "active_sequences_events";
pub const RADIX_STATE_BUCKET: &str = "radix-bucket";
pub const RADIX_STATE_FILE: &str = "radix-state";
// for router discovery registration
pub const KV_ROUTER_COMPONENT: &str = "kv-router";
pub const KV_ROUTER_ENDPOINT: &str = "generate";
/// Creates an EndpointId for the KV router in the given namespace.
pub fn router_endpoint_id(namespace: String) -> EndpointId {
EndpointId {
namespace,
component: KV_ROUTER_COMPONENT.to_string(),
name: KV_ROUTER_ENDPOINT.to_string(),
}
}
/// Creates a DiscoveryQuery for the KV router in the given namespace.
pub fn router_discovery_query(namespace: String) -> DiscoveryQuery {
DiscoveryQuery::Endpoint {
namespace,
component: KV_ROUTER_COMPONENT.to_string(),
endpoint: KV_ROUTER_ENDPOINT.to_string(),
}
}
/// A trait that users can implement to define custom selection logic
pub trait WorkerSelector {
fn select_worker(
......@@ -254,7 +277,7 @@ impl KvRouter {
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>,
consumer_uuid: String,
consumer_id: String,
) -> Result<Self> {
let kv_router_config = kv_router_config.unwrap_or_default();
let component = endpoint.component();
......@@ -311,7 +334,7 @@ impl KvRouter {
runtime_configs_rx,
selector,
kv_router_config.router_replica_sync,
consumer_uuid.clone(),
consumer_id.clone(),
)
.await?;
......@@ -321,7 +344,7 @@ impl KvRouter {
{
start_kv_router_background(
component.clone(),
consumer_uuid,
consumer_id,
kv_indexer.event_sender(),
kv_indexer.remove_worker_sender(),
kv_router_config
......
......@@ -3,15 +3,14 @@
//! Background processes for the KV Router including event consumption and snapshot uploads.
use std::{collections::HashSet, sync::Arc, time::Duration};
use std::{collections::HashSet, time::Duration};
use anyhow::Result;
use dynamo_runtime::{
component::Component,
config::environment_names::nats as env_nats,
discovery::DiscoveryQuery,
discovery::{DiscoveryEvent, DiscoveryQuery},
prelude::*,
storage::key_value_store::WatchEvent,
traits::events::EventPublisher,
transports::nats::{NatsQueue, Slug},
};
......@@ -20,13 +19,11 @@ use rand::Rng;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use crate::{
discovery::KV_ROUTERS_ROOT_PATH,
kv_router::{
use crate::kv_router::{
KV_EVENT_SUBJECT, RADIX_STATE_BUCKET, RADIX_STATE_FILE,
indexer::{DumpRequest, GetWorkersRequest, RouterEvent},
protocols::WorkerId,
},
router_discovery_query,
};
/// Delay between snapshot reads to verify stability
......@@ -217,7 +214,7 @@ impl SnapshotResources {
#[allow(clippy::too_many_arguments)]
pub async fn start_kv_router_background(
component: Component,
consumer_uuid: String,
consumer_id: String,
kv_events_tx: mpsc::Sender<RouterEvent>,
remove_worker_tx: mpsc::Sender<WorkerId>,
maybe_get_workers_tx: Option<mpsc::Sender<GetWorkersRequest>>,
......@@ -238,7 +235,7 @@ pub async fn start_kv_router_background(
stream_name.clone(),
nats_server.clone(),
std::time::Duration::from_secs(60), // 1 minute timeout
consumer_uuid.clone(),
consumer_id.clone(),
);
nats_queue.connect_with_reset(router_reset_states).await?;
......@@ -266,23 +263,24 @@ pub async fn start_kv_router_background(
}
// Cleanup orphaned consumers on startup
cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_uuid).await;
// Watch for router deletions to clean up orphaned consumers
let store = component.drt().store();
let (_watch_handle, mut router_replicas_rx) =
Arc::new(store.clone()).watch(KV_ROUTERS_ROOT_PATH, None, cancellation_token.clone());
cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_id).await;
// Get the generate endpoint and watch for instance deletions
let generate_endpoint = component.endpoint("generate");
let discovery_client = component.drt().discovery();
let discovery_key = DiscoveryQuery::Endpoint {
let generate_discovery_key = DiscoveryQuery::Endpoint {
namespace: component.namespace().name().to_string(),
component: component.name().to_string(),
endpoint: "generate".to_string(),
};
let mut instance_event_stream = discovery_client
.list_and_watch(discovery_key, Some(cancellation_token.clone()))
.list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
.await?;
// Watch for router deletions to clean up orphaned consumers via discovery
let router_discovery_key = router_discovery_query(component.namespace().name());
let mut router_event_stream = discovery_client
.list_and_watch(router_discovery_key, Some(cancellation_token.clone()))
.await?;
// Get instances_rx for tracking current workers
......@@ -336,7 +334,7 @@ pub async fn start_kv_router_background(
continue;
};
let dynamo_runtime::discovery::DiscoveryEvent::Removed(worker_id) = discovery_event else {
let DiscoveryEvent::Removed(worker_id) = discovery_event else {
continue;
};
......@@ -409,36 +407,24 @@ pub async fn start_kv_router_background(
}
}
// Handle router deletion events
Some(event) = router_replicas_rx.recv() => {
let WatchEvent::Delete(kv) = event else {
// We only care about deletions for cleaning up consumers
// Handle router deletion events via discovery
Some(router_event_result) = router_event_stream.next() => {
let Ok(router_event) = router_event_result else {
continue;
};
let key = kv.as_ref();
tracing::info!("Detected router replica deletion: {key}");
// Only process deletions for routers on the same component
// Must match model_manager.rs kv_chooser_for
if !key.contains(&component.service_name()) {
tracing::trace!(
"Skipping router deletion from different component (key: {key}, subscriber component: {})",
component.service_name()
);
continue;
}
// Extract the router UUID from the key
let Some(router_uuid) = key.split('/').next_back() else {
tracing::warn!("Could not extract UUID from router key: {key}");
let DiscoveryEvent::Removed(router_instance_id) = router_event else {
// We only care about removals for cleaning up consumers
continue;
};
// The consumer UUID is the router UUID
let consumer_to_delete = router_uuid.to_string();
// The consumer UUID is the instance_id in hex format
let consumer_to_delete = router_instance_id.to_string();
tracing::info!("Attempting to delete orphaned consumer: {consumer_to_delete}");
tracing::info!(
router_instance_id = router_instance_id,
"DISCOVERY: Router instance removed, attempting to delete orphaned consumer: {consumer_to_delete}"
);
// Delete the consumer (allow race condition if multiple routers try to delete)
if let Err(e) = nats_queue.shutdown(Some(consumer_to_delete.clone())).await {
......@@ -463,43 +449,34 @@ pub async fn start_kv_router_background(
async fn cleanup_orphaned_consumers(
nats_queue: &mut NatsQueue,
component: &Component,
consumer_uuid: &str,
consumer_id: &str,
) {
let Ok(consumers) = nats_queue.list_consumers().await else {
return;
};
// Get active routers from store
let store = component.drt().store();
let Ok(Some(router_bucket)) = store.get_bucket(KV_ROUTERS_ROOT_PATH).await else {
tracing::debug!("No router bucket found, skipping cleanup");
return;
};
let Ok(entries) = router_bucket.entries().await else {
// Get active routers from discovery
let discovery = component.drt().discovery();
let Ok(router_instances) = discovery
.list(router_discovery_query(component.namespace().name()))
.await
else {
tracing::debug!("Failed to list router instances from discovery, skipping cleanup");
return;
};
// Filter to only routers for this component
let component_path = component.service_name();
let active_uuids: HashSet<String> = entries
// Build set of active router instance IDs
let active_instance_ids: HashSet<String> = router_instances
.iter()
.filter_map(|(key, _)| {
// Check if key contains this component's path
if !key.as_ref().contains(&component_path) {
return None;
}
// Extract the last part (should be the UUID)
key.as_ref().split('/').next_back().map(str::to_string)
})
.map(|instance| instance.instance_id().to_string())
.collect();
for consumer in consumers {
if consumer == consumer_uuid {
if consumer == consumer_id {
// Never delete myself (extra/redundant safeguard)
continue;
}
if !active_uuids.contains(&consumer) {
if !active_instance_ids.contains(&consumer) {
tracing::info!("Cleaning up orphaned consumer: {consumer}");
let _ = nats_queue.shutdown(Some(consumer)).await;
}
......
......@@ -993,7 +993,7 @@ impl DRTNatsClientPrometheusMetrics {
/// The NATS subject / inbox to talk to an instance on.
/// TODO: Do we need to sanitize the names?
pub(crate) fn instance_subject(endpoint_id: &EndpointId, instance_id: u64) -> String {
pub fn instance_subject(endpoint_id: &EndpointId, instance_id: u64) -> String {
format!(
"{}_{}.{}-{:x}",
endpoint_id.namespace, endpoint_id.component, endpoint_id.name, instance_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment