Unverified Commit 7f40761b authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

fix(discovery): address concurrency edge cases in handle_put lifecycle (#8237)


Signed-off-by: default avatarMatej Kosec <mkosec@nvidia.com>
parent f05cabac
...@@ -283,7 +283,6 @@ impl ModelWatcher { ...@@ -283,7 +283,6 @@ impl ModelWatcher {
// await the in-flight put before attempting cleanup. // await the in-flight put before attempting cleanup.
let instance_key = mcid.to_path(); let instance_key = mcid.to_path();
let watcher = Arc::clone(&self); let watcher = Arc::clone(&self);
let key_for_cleanup = instance_key.clone();
let handle = tokio::spawn(async move { let handle = tokio::spawn(async move {
match watcher.handle_put(&mcid, &mut card).await { match watcher.handle_put(&mcid, &mut card).await {
Ok(()) => { Ok(()) => {
...@@ -303,12 +302,20 @@ impl ModelWatcher { ...@@ -303,12 +302,20 @@ impl ModelWatcher {
); );
} }
} }
// Remove ourselves from pending_puts once complete // Note: we intentionally do NOT remove from pending_puts here.
watcher.pending_puts.remove(&key_for_cleanup); // Only the watch loop (on duplicate events) and handle_delete
// manage pending_puts, avoiding a race where a completed task's
// cleanup could remove a newer task's entry.
}); });
// If a duplicate Added event arrives while the first task is still // If a duplicate Added event arrives while the first task is still
// in-flight, abort the old task to prevent its cleanup from // in-flight, abort the old task to cancel redundant work.
// removing the new task's handle from pending_puts. //
// `instance_key` is `mcid.to_path()` = "{ns}/{component}/{endpoint}/{instance_id:x}",
// so this is keyed per-worker-instance, NOT per-model. Two different workers
// registering the same model produce two different keys and run independently.
// The only case that hits this branch is the etcd watch replaying the same
// worker's Added event (reconnect or re-sync) — where cancelling the earlier
// redundant task is exactly what we want.
if let Some((_, old_handle)) = self.pending_puts.remove(&instance_key) { if let Some((_, old_handle)) = self.pending_puts.remove(&instance_key) {
old_handle.abort(); old_handle.abort();
} }
...@@ -358,11 +365,23 @@ impl ModelWatcher { ...@@ -358,11 +365,23 @@ impl ModelWatcher {
// to complete before we attempt cleanup. Without this, a Removed event // to complete before we attempt cleanup. Without this, a Removed event
// arriving while handle_put is still downloading HF config would fail // arriving while handle_put is still downloading HF config would fail
// to find the model card, leaving a stale registration. // to find the model card, leaving a stale registration.
if let Some((_, handle)) = self.pending_puts.remove(&key) { if let Some((_, mut handle)) = self.pending_puts.remove(&key) {
tracing::debug!(key = %key, "awaiting in-flight handle_put before delete"); tracing::debug!(key = %key, "awaiting in-flight handle_put before delete");
// Ignore join errors (panic in the spawned task) — we still proceed // Ignore join errors (panic in the spawned task) — we still proceed
// with cleanup since the put may have partially registered the model. // with cleanup since the put may have partially registered the model.
let _ = handle.await; match tokio::time::timeout(Duration::from_secs(60), &mut handle).await {
Ok(_) => {}
Err(_) => {
// Abort the timed-out task so it cannot register the model
// after we proceed with deletion.
handle.abort();
let _ = handle.await;
tracing::warn!(
key = %key,
"Timed out waiting for in-flight handle_put, aborted and proceeding with delete"
);
}
}
} }
let card = match self.manager.remove_model_card(&key) { let card = match self.manager.remove_model_card(&key) {
Some(card) => card, Some(card) => card,
...@@ -454,6 +473,18 @@ impl ModelWatcher { ...@@ -454,6 +473,18 @@ impl ModelWatcher {
if let Some(model) = self.manager.get_model(&model_name) if let Some(model) = self.manager.get_model(&model_name)
&& model.has_worker_set(&ws_key) && model.has_worker_set(&ws_key)
{ {
if !model.is_checksum_compatible(&ws_key, card.mdcsum()) {
tracing::error!(
model_name = card.name(),
namespace = namespace,
new_checksum = card.mdcsum(),
"Checksum for new worker does not match existing WorkerSet's checksum. \
Drain all old workers in this namespace before deploying a new version."
);
return Err(anyhow::anyhow!(
"Checksum mismatch for worker in namespace {namespace}"
));
}
self.manager self.manager
.save_model_card(&mcid.to_path(), card.clone())?; .save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!( tracing::debug!(
...@@ -540,6 +571,20 @@ impl ModelWatcher { ...@@ -540,6 +571,20 @@ impl ModelWatcher {
} }
} }
// If we timed out and the other task is still running, bail out rather
// than proceeding with concurrent pipeline construction.
if self.registering_worker_sets.contains(registration_key) {
// Save the model card so handle_delete can find it for cleanup.
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::warn!(
model_name = card.name(),
namespace = namespace,
"Timed out waiting for concurrent registration to complete, skipping"
);
return Ok(false);
}
// Validate checksum against the registered model // Validate checksum against the registered model
if let Some(model) = self.manager.get_model(model_name) if let Some(model) = self.manager.get_model(model_name)
&& !model.is_checksum_compatible(ws_key, card.mdcsum()) && !model.is_checksum_compatible(ws_key, card.mdcsum())
...@@ -562,13 +607,27 @@ impl ModelWatcher { ...@@ -562,13 +607,27 @@ impl ModelWatcher {
.get_model(model_name) .get_model(model_name)
.is_none_or(|m| !m.has_worker_set(ws_key)) .is_none_or(|m| !m.has_worker_set(ws_key))
{ {
// Only the first waiter to re-insert the key should proceed with
// registration. Other waiters return false to avoid concurrent builds.
if !self
.registering_worker_sets
.insert(registration_key.to_string())
{
// Save the model card so handle_delete can find it for cleanup.
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
model_name = card.name(),
namespace = namespace,
"Another waiter won the re-registration race, skipping"
);
return Ok(false);
}
tracing::warn!( tracing::warn!(
model_name = card.name(), model_name = card.name(),
namespace = namespace, namespace = namespace,
"Concurrent registration produced no WorkerSet, retrying" "Concurrent registration produced no WorkerSet, retrying"
); );
self.registering_worker_sets
.insert(registration_key.to_string());
return Ok(true); return Ok(true);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment