"vscode:/vscode.git/clone" did not exist on "85df8afdae72866f5142d674a916ec6a0879b9e8"
Unverified commit cdeda221, authored by Yan Ru Pei, committed by GitHub
Browse files

fix: block on notification of at least one runtime config (#5191)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent e7918716
......@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError};
pub use model_manager::{ModelManager, ModelManagerError, RuntimeConfigsWithNotify};
mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher};
......
......@@ -8,7 +8,7 @@ use std::{
use dashmap::{DashMap, mapref::entry::Entry};
use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use tokio::sync::{Notify, oneshot};
use crate::discovery::KvWorkerMonitor;
......@@ -81,8 +81,14 @@ pub struct ModelManager {
/// Runtime configs per endpoint using DashMap for lock-free access.
/// Outer DashMap: keyed by EndpointId
/// Inner Arc<DashMap>: keyed by WorkerId, shared with KvScheduler
runtime_configs: DashMap<EndpointId, Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>>,
/// Inner RuntimeConfigsWithNotify: shared with KvScheduler
runtime_configs: DashMap<EndpointId, Arc<RuntimeConfigsWithNotify>>,
}
/// Runtime configs for a single endpoint, bundled with a change notifier.
///
/// The configs and the notifier live in one struct so that it can be shared
/// behind a single `Arc`: the writer updates `configs` and then signals
/// `notify`, and readers wait on `notify` before re-reading `configs`.
///
/// `Default` yields an empty map and a fresh notifier with no pending wake-ups.
#[derive(Default)]
pub struct RuntimeConfigsWithNotify {
    /// Per-worker runtime config, keyed by worker id.
    ///
    /// A value of `None` means the worker instance is known but no runtime
    /// config has been found for it.
    pub configs: DashMap<WorkerId, Option<ModelRuntimeConfig>>,
    /// Signalled after `configs` has been updated.
    ///
    /// NOTE: when the writer uses `Notify::notify_waiters`, only tasks that are
    /// already waiting are woken and no permit is stored. Waiters should
    /// therefore re-check `configs` rather than treat a wake-up (or its
    /// absence) as the sole source of truth.
    pub notify: Notify,
}
impl Default for ModelManager {
......@@ -619,11 +625,11 @@ impl ModelManager {
/// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task to watch DiscoveryQuery::EndpointModels.
/// Returns a shared Arc<DashMap> that KvScheduler can use directly.
/// Returns a shared RuntimeConfigsWithNotify that KvScheduler can use directly.
pub async fn get_or_create_runtime_config_watcher(
&self,
endpoint: &Endpoint,
) -> anyhow::Result<Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>> {
) -> anyhow::Result<Arc<RuntimeConfigsWithNotify>> {
let endpoint_id = endpoint.id();
// Fast path: return existing if present
......@@ -632,22 +638,25 @@ impl ModelManager {
}
// Atomic get-or-insert to avoid TOCTOU race
let inner_map = Arc::new(DashMap::new());
let (map, is_new) = match self.runtime_configs.entry(endpoint_id) {
let inner = Arc::new(RuntimeConfigsWithNotify {
configs: DashMap::new(),
notify: Notify::new(),
});
let (result, is_new) = match self.runtime_configs.entry(endpoint_id) {
Entry::Occupied(e) => (e.get().clone(), false),
Entry::Vacant(e) => {
e.insert(inner_map.clone());
(inner_map, true)
e.insert(inner.clone());
(inner, true)
}
};
// Only spawn watcher if we were the one who inserted
if is_new {
self.spawn_runtime_config_watcher(endpoint, map.clone())
self.spawn_runtime_config_watcher(endpoint, result.clone())
.await?;
}
Ok(map)
Ok(result)
}
/// Get disaggregated endpoint for a specific worker.
......@@ -657,16 +666,17 @@ impl ModelManager {
endpoint_id: &EndpointId,
worker_id: WorkerId,
) -> Option<DisaggregatedEndpoint> {
let inner_map = self.runtime_configs.get(endpoint_id)?;
let config_ref = inner_map.get(&worker_id)?;
let inner = self.runtime_configs.get(endpoint_id)?;
let config_ref = inner.configs.get(&worker_id)?;
config_ref.as_ref()?.disaggregated_endpoint.clone()
}
/// Spawn background task to watch runtime configs via discovery.
/// Blocks until at least one worker with a runtime config is available.
async fn spawn_runtime_config_watcher(
&self,
endpoint: &Endpoint,
inner_map: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
inner: Arc<RuntimeConfigsWithNotify>,
) -> anyhow::Result<()> {
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
......@@ -693,7 +703,29 @@ impl ModelManager {
let client = endpoint.client().await?;
let mut instance_ids_rx = client.instance_avail_watcher();
// Spawn background task to update inner_map
// Wait for at least one worker with runtime config before proceeding.
// This ensures the DashMap is populated before KvScheduler starts.
tracing::info!("ModelManager: Waiting for at least one worker with runtime config...");
runtime_configs_rx
.changed()
.await
.map_err(|_| anyhow::anyhow!("runtime configs watch sender shutdown while waiting"))?;
// Populate initial state
{
let instance_ids = instance_ids_rx.borrow();
let configs = runtime_configs_rx.borrow();
for worker_id in instance_ids.iter() {
let config = configs.get(worker_id).cloned();
inner.configs.insert(*worker_id, config);
}
tracing::info!(
"ModelManager: Found {} workers, proceeding",
inner.configs.len()
);
}
// Spawn background task to update configs for future changes
let cancel_token = cancellation_token.clone();
tokio::spawn(async move {
tracing::trace!("ModelManager runtime config watcher started");
......@@ -725,30 +757,32 @@ impl ModelManager {
// Update the DashMap
// First, remove workers that no longer exist
let current_workers: HashSet<WorkerId> =
inner_map.iter().map(|r| *r.key()).collect();
inner.configs.iter().map(|r| *r.key()).collect();
let new_workers: HashSet<WorkerId> = new_instance_ids.iter().copied().collect();
for removed_worker in current_workers.difference(&new_workers) {
inner_map.remove(removed_worker);
inner.configs.remove(removed_worker);
}
// Then, add/update workers
for worker_id in &new_instance_ids {
let config = new_configs.get(worker_id).cloned();
if config.is_some() {
let prev_config = inner_map.get(worker_id);
let prev_config = inner.configs.get(worker_id);
if prev_config.as_ref().map(|r| r.value()) != Some(&config) {
tracing::info!(
"ModelManager: Runtime config found for worker_id: {}",
worker_id
"ModelManager: Runtime config found for worker_id: {worker_id}"
);
}
}
inner_map.insert(*worker_id, config);
inner.configs.insert(*worker_id, config);
}
// Notify waiters that configs have changed
inner.notify.notify_waiters();
tracing::trace!(
"ModelManager: Updated runtime_configs with {} workers",
inner_map.len()
inner.configs.len()
);
}
tracing::trace!("ModelManager runtime config watcher shutting down");
......
......@@ -6,7 +6,6 @@ use std::sync::Arc;
use std::time::Duration;
use anyhow::Result;
use dashmap::DashMap;
use derive_builder::Builder;
use dynamo_runtime::{
component::{Client, Endpoint},
......@@ -39,6 +38,7 @@ pub use prefill_router::PrefillRouter;
use worker_query::WorkerQueryClient;
use crate::{
discovery::RuntimeConfigsWithNotify,
kv_router::{
approx::PruneConfig,
indexer::{KvIndexer, KvIndexerInterface, KvRouterError, OverlapScores, RouterEvent},
......@@ -281,7 +281,7 @@ impl KvRouter {
pub async fn new(
endpoint: Endpoint,
client: Client,
workers_with_configs: Arc<DashMap<protocols::WorkerId, Option<ModelRuntimeConfig>>>,
workers_with_configs: Arc<RuntimeConfigsWithNotify>,
block_size: u32,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
kv_router_config: Option<KvRouterConfig>,
......@@ -291,8 +291,6 @@ impl KvRouter {
let component = endpoint.component();
let cancellation_token = component.drt().primary_token();
let instance_ids_rx = client.instance_avail_watcher();
// Watch for runtime config updates via discovery interface
// (still needed for WorkerQueryClient and background tasks)
let discovery = component.drt().discovery();
......@@ -339,8 +337,7 @@ impl KvRouter {
let scheduler = KvScheduler::start(
component.clone(),
block_size,
instance_ids_rx,
workers_with_configs,
workers_with_configs.clone(),
selector,
kv_router_config.router_replica_sync,
consumer_id.clone(),
......@@ -354,39 +351,25 @@ impl KvRouter {
tracing::info!("Worker query client initialized");
// Start KV event subscriber background process (only when use_kv_events is enabled)
// We block here until at least one worker runtime config is registered,
// then spawn the subscriber. This ensures the router is ready before accepting requests.
// model_manager.get_or_create_runtime_config_watcher() guarantees at least one worker exists.
if kv_router_config.use_kv_events
&& let Indexer::KvIndexer(ref kv_indexer) = indexer
{
let mut runtime_configs_rx_clone = runtime_configs_rx.clone();
// Wait for at least one worker runtime config to be registered
tracing::info!("Waiting for at least one worker runtime config to be registered...");
let (all_local_indexer, count) = loop {
{
let configs = runtime_configs_rx_clone.borrow();
if !configs.is_empty() {
let all_local_indexer = configs.values().all(|c| c.enable_local_indexer);
break (all_local_indexer, configs.len());
}
// model_manager guarantees workers_with_configs is populated
// Wait for at least one worker before starting the subscriber
while workers_with_configs.configs.is_empty() {
tracing::info!("KV router waiting for at least one worker...");
workers_with_configs.notify.notified().await;
}
// Wait for changes to runtime_configs
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::debug!("KvRouter startup cancelled while waiting for workers");
anyhow::bail!("KvRouter startup cancelled");
}
result = runtime_configs_rx_clone.changed() => {
if result.is_err() {
tracing::debug!("Runtime configs channel closed");
anyhow::bail!("Runtime configs channel closed before any workers registered");
}
}
}
};
tracing::info!("Found {count} worker runtime config(s), starting KV event subscriber");
let count = workers_with_configs.configs.len();
let all_local_indexer = workers_with_configs
.configs
.iter()
.filter_map(|r| r.value().as_ref().map(|c| c.enable_local_indexer))
.all(|b| b);
tracing::info!("Found {count} worker(s), starting KV event subscriber");
// Start subscriber - setup runs synchronously, then spawns background loop internally
if all_local_indexer {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::discovery::RuntimeConfigsWithNotify;
use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::Result;
use dashmap::DashMap;
use dynamo_runtime::component::Component;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::traits::events::EventPublisher;
......@@ -12,7 +12,6 @@ use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use super::KV_HIT_RATE_SUBJECT;
use super::KvRouterConfig;
......@@ -97,16 +96,17 @@ impl KvScheduler {
pub async fn start(
component: Component,
block_size: u32,
instance_ids_rx: watch::Receiver<Vec<u64>>,
workers_with_configs: Arc<DashMap<WorkerId, Option<ModelRuntimeConfig>>>,
workers_with_configs: Arc<RuntimeConfigsWithNotify>,
selector: Option<Box<dyn WorkerSelector + Send + Sync>>,
replica_sync: bool,
router_uuid: String,
) -> Result<Self, KvSchedulerError> {
let selector = selector.unwrap_or(Box::new(DefaultWorkerSelector::default()));
// Get initial workers from DashMap for slot initialization
// Get initial workers from DashMap for slot initialization.
// ModelManager guarantees at least one worker is present before KvRouter::new() is called.
let initial_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_with_configs
.configs
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
......@@ -119,33 +119,29 @@ impl KvScheduler {
router_uuid,
));
// Spawn background task to monitor workers_with_configs changes and update slots
// Spawn background task to sync slots with DashMap when notified of changes.
// ModelManager's watcher updates the DashMap and notifies; we wait on notify here.
let slots_monitor = slots.clone();
let workers_monitor = workers_with_configs.clone();
let mut instance_ids_monitor_rx = instance_ids_rx.clone();
let monitor_cancel_token = component.drt().child_token();
tokio::spawn(async move {
tracing::trace!("KvScheduler workers monitoring task started");
let mut last_workers: HashSet<WorkerId> = HashSet::new();
loop {
// Wait for instance changes (ModelManager handles config updates to the DashMap)
// Wait for notification or cancellation
tokio::select! {
_ = monitor_cancel_token.cancelled() => {
tracing::trace!("KvScheduler workers monitoring task shutting down");
break;
}
result = instance_ids_monitor_rx.changed() => {
if result.is_err() {
tracing::warn!("instance IDs watch sender shutdown in KvScheduler monitor");
break;
}
}
_ = workers_monitor.notify.notified() => {}
}
// Get current workers from DashMap
let current_workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> =
workers_monitor
.configs
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
......@@ -156,13 +152,8 @@ impl KvScheduler {
if current_worker_ids != last_workers {
slots_monitor.update_workers(current_workers);
last_workers = current_worker_ids;
tracing::trace!(
"KvScheduler: Updated slots with {} workers",
last_workers.len()
);
}
}
tracing::trace!("KvScheduler workers monitoring task shutting down");
});
let slots_clone = slots.clone();
......@@ -202,6 +193,7 @@ impl KvScheduler {
// Read the current workers configuration from DashMap
let workers: HashMap<WorkerId, Option<ModelRuntimeConfig>> = workers_scheduler
.configs
.iter()
.map(|r| (*r.key(), r.value().clone()))
.collect();
......
......@@ -36,6 +36,34 @@ const CHECK_INTERVAL_JITTER_MS: i64 = 100;
const WORKER_QUERY_MAX_RETRIES: u32 = 8;
const WORKER_QUERY_INITIAL_BACKOFF_MS: u64 = 200;
// ============================================================================
// Discovery Helpers
// ============================================================================
/// Subscribe to discovery events for this component's `generate` endpoint and
/// wait until the subscription yields its first item.
///
/// The first item is only peeked, not consumed, so the caller still receives
/// it as the first element of the returned stream. The item itself is not
/// inspected: it may be any discovery event or an error.
///
/// If the stream terminates before yielding anything (presumably possible when
/// `cancellation_token` is cancelled — confirm against `list_and_watch`), a
/// warning is logged and the already-finished stream is returned, so the
/// caller's own loop observes the termination.
///
/// # Errors
///
/// Returns an error if the discovery `list_and_watch` call itself fails.
async fn wait_for_worker_instance(
    component: &Component,
    cancellation_token: &CancellationToken,
) -> Result<std::pin::Pin<Box<dyn futures::Stream<Item = Result<DiscoveryEvent>> + Send>>> {
    let discovery_client = component.drt().discovery();
    let generate_discovery_key = DiscoveryQuery::Endpoint {
        namespace: component.namespace().name().to_string(),
        component: component.name().to_string(),
        endpoint: "generate".to_string(),
    };

    let mut stream = discovery_client
        .list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
        .await?
        .peekable();

    tracing::info!("KV subscriber waiting for at least one worker instance...");

    // `peek` suspends until the first item is available without removing it
    // from the stream. `None` means the stream ended with nothing to offer.
    if std::pin::Pin::new(&mut stream).peek().await.is_none() {
        tracing::warn!(
            "KV subscriber: discovery stream ended before any worker instance event was observed"
        );
    }

    Ok(Box::pin(stream))
}
// ============================================================================
// Local KvIndexer-based Recovery
// ============================================================================
......@@ -473,19 +501,13 @@ pub async fn start_kv_router_background(
// Cleanup orphaned consumers on startup
cleanup_orphaned_consumers(&mut nats_queue, &component, &consumer_id).await;
// Get the generate endpoint and watch for instance deletions
let generate_endpoint = component.endpoint("generate");
let discovery_client = component.drt().discovery();
let generate_discovery_key = DiscoveryQuery::Endpoint {
namespace: component.namespace().name().to_string(),
component: component.name().to_string(),
endpoint: "generate".to_string(),
};
let mut instance_event_stream = discovery_client
.list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
.await?;
// Wait for at least one worker instance before proceeding
let mut instance_event_stream =
wait_for_worker_instance(&component, &cancellation_token).await?;
// Watch for router deletions to clean up orphaned consumers via discovery
let generate_endpoint = component.endpoint("generate");
let discovery_client = component.drt().discovery();
let router_discovery_key = router_discovery_query(component.namespace().name());
let mut router_event_stream = discovery_client
.list_and_watch(router_discovery_key, Some(cancellation_token.clone()))
......@@ -725,16 +747,9 @@ pub async fn start_kv_router_background_nats_core(
"KV Router using NATS Core subscription (local_indexer mode)"
);
// Get the generate endpoint and watch for instance events (add/remove)
let discovery_client = component.drt().discovery();
let generate_discovery_key = DiscoveryQuery::Endpoint {
namespace: component.namespace().name().to_string(),
component: component.name().to_string(),
endpoint: "generate".to_string(),
};
let mut instance_event_stream = discovery_client
.list_and_watch(generate_discovery_key, Some(cancellation_token.clone()))
.await?;
// Wait for at least one worker instance before proceeding
let mut instance_event_stream =
wait_for_worker_instance(&component, &cancellation_token).await?;
// Drain and process all existing workers before spawning the background loop.
// list_and_watch returns existing instances first, so we poll with a short timeout
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment