Unverified Commit 470bb48f authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

fix(discovery): spawn handle_put tasks concurrently to prevent replay wedging (#7931)


Signed-off-by: default avatarMatej Kosec <mkosec@nvidia.com>
parent e232bec0
...@@ -2,11 +2,13 @@ ...@@ -2,11 +2,13 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify; use tokio::sync::Notify;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use tokio::task::JoinHandle;
use anyhow::Context as _; use anyhow::Context as _;
use dashmap::DashSet; use dashmap::{DashMap, DashSet};
use dynamo_kv_router::PrefillLoadEstimator; use dynamo_kv_router::PrefillLoadEstimator;
use futures::StreamExt; use futures::StreamExt;
...@@ -80,6 +82,9 @@ pub struct ModelWatcher { ...@@ -80,6 +82,9 @@ pub struct ModelWatcher {
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
/// Guards against concurrent pipeline construction for the same (model, namespace). /// Guards against concurrent pipeline construction for the same (model, namespace).
registering_worker_sets: DashSet<String>, registering_worker_sets: DashSet<String>,
/// Tracks in-flight `handle_put` tasks by instance path so that `handle_delete`
/// can await a racing put before proceeding with cleanup.
pending_puts: DashMap<String, JoinHandle<()>>,
} }
const ALL_MODEL_TYPES: &[ModelType] = &[ const ALL_MODEL_TYPES: &[ModelType] = &[
...@@ -114,6 +119,20 @@ fn is_model_type_list_empty(manager: &ModelManager, model_type: ModelType) -> bo ...@@ -114,6 +119,20 @@ fn is_model_type_list_empty(manager: &ModelManager, model_type: ModelType) -> bo
} }
} }
/// RAII guard that removes a key from a `DashSet` on drop.
/// Ensures `registering_worker_sets` is cleaned up even if the registration
/// task panics, preventing permanent poisoning of the registration key.
struct RegistrationGuard<'a> {
set: &'a DashSet<String>,
key: String,
}
impl Drop for RegistrationGuard<'_> {
fn drop(&mut self) {
self.set.remove(&self.key);
}
}
impl ModelWatcher { impl ModelWatcher {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
...@@ -138,6 +157,7 @@ impl ModelWatcher { ...@@ -138,6 +157,7 @@ impl ModelWatcher {
prefill_load_estimator, prefill_load_estimator,
metrics, metrics,
registering_worker_sets: DashSet::new(), registering_worker_sets: DashSet::new(),
pending_puts: DashMap::new(),
} }
} }
...@@ -156,9 +176,13 @@ impl ModelWatcher { ...@@ -156,9 +176,13 @@ impl ModelWatcher {
} }
} }
/// Common watch logic with optional namespace filtering /// Common watch logic with optional namespace filtering.
///
/// Takes `Arc<Self>` so that each `handle_put` call can be spawned into its own
/// tokio task, preventing a slow HuggingFace config download for one model from
/// blocking discovery events for all subsequent models.
pub async fn watch( pub async fn watch(
&self, self: Arc<Self>,
mut discovery_stream: DiscoveryStream, mut discovery_stream: DiscoveryStream,
namespace_filter: NamespaceFilter, namespace_filter: NamespaceFilter,
) { ) {
...@@ -241,14 +265,25 @@ impl ModelWatcher { ...@@ -241,14 +265,25 @@ impl ModelWatcher {
continue; continue;
} }
match self.handle_put(&mcid, &mut card).await { // Spawn each handle_put into its own task so that a slow
// HuggingFace config download for one model cannot block
// discovery events for all subsequent models.
//
// The JoinHandle is stored in `pending_puts` so that a
// subsequent `handle_delete` for the same instance can
// await the in-flight put before attempting cleanup.
let instance_key = mcid.to_path();
let watcher = Arc::clone(&self);
let key_for_cleanup = instance_key.clone();
let handle = tokio::spawn(async move {
match watcher.handle_put(&mcid, &mut card).await {
Ok(()) => { Ok(()) => {
tracing::info!( tracing::info!(
model_name = card.name(), model_name = card.name(),
namespace = mcid.namespace, namespace = mcid.namespace,
"added model" "added model"
); );
self.notify_on_model.notify_waiters(); watcher.notify_on_model.notify_waiters();
} }
Err(err) => { Err(err) => {
tracing::error!( tracing::error!(
...@@ -259,6 +294,16 @@ impl ModelWatcher { ...@@ -259,6 +294,16 @@ impl ModelWatcher {
); );
} }
} }
// Remove ourselves from pending_puts once complete
watcher.pending_puts.remove(&key_for_cleanup);
});
// If a duplicate Added event arrives while the first task is still
// in-flight, abort the old task to prevent its cleanup from
// removing the new task's handle from pending_puts.
if let Some((_, old_handle)) = self.pending_puts.remove(&instance_key) {
old_handle.abort();
}
self.pending_puts.insert(instance_key, handle);
} }
DiscoveryEvent::Removed(id) => { DiscoveryEvent::Removed(id) => {
// Extract ModelCardInstanceId from the removal event // Extract ModelCardInstanceId from the removal event
...@@ -299,6 +344,17 @@ impl ModelWatcher { ...@@ -299,6 +344,17 @@ impl ModelWatcher {
namespace_filter: &NamespaceFilter, namespace_filter: &NamespaceFilter,
) -> anyhow::Result<Option<String>> { ) -> anyhow::Result<Option<String>> {
let key = mcid.to_path(); let key = mcid.to_path();
// If there is an in-flight handle_put for this instance, wait for it
// to complete before we attempt cleanup. Without this, a Removed event
// arriving while handle_put is still downloading HF config would fail
// to find the model card, leaving a stale registration.
if let Some((_, handle)) = self.pending_puts.remove(&key) {
tracing::debug!(key = %key, "awaiting in-flight handle_put before delete");
// Ignore join errors (panic in the spawned task) — we still proceed
// with cleanup since the put may have partially registered the model.
let _ = handle.await;
}
let card = match self.manager.remove_model_card(&key) { let card = match self.manager.remove_model_card(&key) {
Some(card) => card, Some(card) => card,
None => { None => {
...@@ -404,23 +460,94 @@ impl ModelWatcher { ...@@ -404,23 +460,94 @@ impl ModelWatcher {
if !self if !self
.registering_worker_sets .registering_worker_sets
.insert(registration_key.clone()) .insert(registration_key.clone())
&& !self
.recover_concurrent_registration(
mcid,
card,
&model_name,
&namespace,
&ws_key,
&registration_key,
)
.await?
{ {
self.manager return Ok(());
.save_model_card(&mcid.to_path(), card.clone())?; }
tracing::debug!(
// RAII guard ensures the registration key is removed even if
// do_worker_set_registration panics, preventing permanent poisoning.
let _guard = RegistrationGuard {
set: &self.registering_worker_sets,
key: registration_key,
};
self.do_worker_set_registration(mcid, card).await
}
/// Handle the case where another task is already building the pipeline for this
/// (model, namespace, type). This is a recovery path — it waits for the in-flight
/// registration to finish, then either joins the resulting WorkerSet or retries.
///
/// Returns `true` if the caller should proceed with its own registration
/// (i.e. the other task failed), `false` if the worker was handled (joined or rejected).
async fn recover_concurrent_registration(
&self,
mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard,
model_name: &str,
namespace: &str,
ws_key: &str,
registration_key: &str,
) -> anyhow::Result<bool> {
// Wait for the in-flight registration to complete so we can validate
// the new worker's checksum. Without this, a concurrent worker with a
// mismatched checksum could sneak past the early check in `watch`.
let mut attempts = 0;
while self.registering_worker_sets.contains(registration_key) && attempts < 300 {
tokio::time::sleep(Duration::from_millis(100)).await;
attempts += 1;
}
// Validate checksum against the registered model
if let Some(model) = self.manager.get_model(model_name)
&& !model.is_checksum_compatible(ws_key, card.mdcsum())
{
tracing::error!(
model_name = card.name(), model_name = card.name(),
namespace = namespace, namespace = namespace,
"WorkerSet registration in progress, skipping" new_checksum = card.mdcsum(),
"Checksum for new worker does not match existing WorkerSet's checksum. \
Drain all old workers in this namespace before deploying a new version."
); );
return Ok(()); return Ok(false);
} }
let result = self.do_worker_set_registration(mcid, card).await; // If the first registration failed or timed out, no WorkerSet exists.
// Fall through to do_worker_set_registration instead of becoming a ghost
// Always remove from registering set // worker (registered in cards but with no serving pipeline).
self.registering_worker_sets.remove(&registration_key); if self
.manager
.get_model(model_name)
.is_none_or(|m| !m.has_worker_set(ws_key))
{
tracing::warn!(
model_name = card.name(),
namespace = namespace,
"Concurrent registration produced no WorkerSet, retrying"
);
self.registering_worker_sets
.insert(registration_key.to_string());
return Ok(true);
}
result self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
model_name = card.name(),
namespace = namespace,
"Worker joined existing WorkerSet, skipping pipeline build"
);
Ok(false)
} }
/// Build a complete WorkerSet with all engines for this (model, namespace) /// Build a complete WorkerSet with all engines for this (model, namespace)
......
...@@ -147,6 +147,7 @@ async fn run_watcher( ...@@ -147,6 +147,7 @@ async fn run_watcher(
// only has one kind of inference endpoint. // only has one kind of inference endpoint.
// Pass the discovery stream to the watcher // Pass the discovery stream to the watcher
let watch_obj = Arc::new(watch_obj);
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
watch_obj.watch(discovery_stream, namespace_filter).await; watch_obj.watch(discovery_stream, namespace_filter).await;
}); });
......
...@@ -196,6 +196,7 @@ async fn run_watcher( ...@@ -196,6 +196,7 @@ async fn run_watcher(
// Create a channel to receive model type updates // Create a channel to receive model type updates
let (tx, mut rx) = tokio::sync::mpsc::channel(32); let (tx, mut rx) = tokio::sync::mpsc::channel(32);
watch_obj.set_notify_on_model_update(tx); watch_obj.set_notify_on_model_update(tx);
let watch_obj = Arc::new(watch_obj);
// Spawn a task to watch for model type changes and update HTTP service endpoints and metrics // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
let _endpoint_enabler_task = tokio::spawn(async move { let _endpoint_enabler_task = tokio::spawn(async move {
......
...@@ -383,7 +383,7 @@ mod integration_tests { ...@@ -383,7 +383,7 @@ mod integration_tests {
// Spawn watcher task to discover models // Spawn watcher task to discover models
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
model_watcher Arc::new(model_watcher)
.watch(discovery_stream, NamespaceFilter::Global) .watch(discovery_stream, NamespaceFilter::Global)
.await; .await;
}); });
......
...@@ -324,7 +324,10 @@ impl Manager { ...@@ -324,7 +324,10 @@ impl Manager {
tokio::sync::mpsc::Receiver<WatchEvent>, tokio::sync::mpsc::Receiver<WatchEvent>,
) { ) {
let bucket_name = bucket_name.to_string(); let bucket_name = bucket_name.to_string();
let (tx, rx) = tokio::sync::mpsc::channel(1024); // Use a larger channel capacity to reduce the likelihood that a slow consumer
// during the initial KV-store replay phase triggers send timeouts. Events may
// still be dropped if the consumer cannot keep up within `WATCH_SEND_TIMEOUT`.
let (tx, rx) = tokio::sync::mpsc::channel(16384);
let watch_task = tokio::spawn(async move { let watch_task = tokio::spawn(async move {
// Start listening for changes but don't poll this yet // Start listening for changes but don't poll this yet
let bucket = self let bucket = self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment