Unverified Commit 7f40761b authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

fix(discovery): address concurrency edge cases in handle_put lifecycle (#8237)


Signed-off-by: default avatarMatej Kosec <mkosec@nvidia.com>
parent f05cabac
......@@ -283,7 +283,6 @@ impl ModelWatcher {
// await the in-flight put before attempting cleanup.
let instance_key = mcid.to_path();
let watcher = Arc::clone(&self);
let key_for_cleanup = instance_key.clone();
let handle = tokio::spawn(async move {
match watcher.handle_put(&mcid, &mut card).await {
Ok(()) => {
......@@ -303,12 +302,20 @@ impl ModelWatcher {
);
}
}
// Remove ourselves from pending_puts once complete
watcher.pending_puts.remove(&key_for_cleanup);
// Note: we intentionally do NOT remove from pending_puts here.
// Only the watch loop (on duplicate events) and handle_delete
// manage pending_puts, avoiding a race where a completed task's
// cleanup could remove a newer task's entry.
});
// If a duplicate Added event arrives while the first task is still
// in-flight, abort the old task to prevent its cleanup from
// removing the new task's handle from pending_puts.
// in-flight, abort the old task to cancel redundant work.
//
// `instance_key` is `mcid.to_path()` = "{ns}/{component}/{endpoint}/{instance_id:x}",
// so this is keyed per-worker-instance, NOT per-model. Two different workers
// registering the same model produce two different keys and run independently.
// The only case that hits this branch is the etcd watch replaying the same
// worker's Added event (reconnect or re-sync) — where cancelling the earlier
// redundant task is exactly what we want.
if let Some((_, old_handle)) = self.pending_puts.remove(&instance_key) {
old_handle.abort();
}
......@@ -358,11 +365,23 @@ impl ModelWatcher {
// to complete before we attempt cleanup. Without this, a Removed event
// arriving while handle_put is still downloading HF config would fail
// to find the model card, leaving a stale registration.
if let Some((_, handle)) = self.pending_puts.remove(&key) {
if let Some((_, mut handle)) = self.pending_puts.remove(&key) {
tracing::debug!(key = %key, "awaiting in-flight handle_put before delete");
// Ignore join errors (panic in the spawned task) — we still proceed
// with cleanup since the put may have partially registered the model.
match tokio::time::timeout(Duration::from_secs(60), &mut handle).await {
Ok(_) => {}
Err(_) => {
// Abort the timed-out task so it cannot register the model
// after we proceed with deletion.
handle.abort();
let _ = handle.await;
tracing::warn!(
key = %key,
"Timed out waiting for in-flight handle_put, aborted and proceeding with delete"
);
}
}
}
let card = match self.manager.remove_model_card(&key) {
Some(card) => card,
......@@ -454,6 +473,18 @@ impl ModelWatcher {
if let Some(model) = self.manager.get_model(&model_name)
&& model.has_worker_set(&ws_key)
{
if !model.is_checksum_compatible(&ws_key, card.mdcsum()) {
tracing::error!(
model_name = card.name(),
namespace = namespace,
new_checksum = card.mdcsum(),
"Checksum for new worker does not match existing WorkerSet's checksum. \
Drain all old workers in this namespace before deploying a new version."
);
return Err(anyhow::anyhow!(
"Checksum mismatch for worker in namespace {namespace}"
));
}
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
......@@ -540,6 +571,20 @@ impl ModelWatcher {
}
}
// If we timed out and the other task is still running, bail out rather
// than proceeding with concurrent pipeline construction.
if self.registering_worker_sets.contains(registration_key) {
// Save the model card so handle_delete can find it for cleanup.
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::warn!(
model_name = card.name(),
namespace = namespace,
"Timed out waiting for concurrent registration to complete, skipping"
);
return Ok(false);
}
// Validate checksum against the registered model
if let Some(model) = self.manager.get_model(model_name)
&& !model.is_checksum_compatible(ws_key, card.mdcsum())
......@@ -562,13 +607,27 @@ impl ModelWatcher {
.get_model(model_name)
.is_none_or(|m| !m.has_worker_set(ws_key))
{
// Only the first waiter to re-insert the key should proceed with
// registration. Other waiters return false to avoid concurrent builds.
if !self
.registering_worker_sets
.insert(registration_key.to_string())
{
// Save the model card so handle_delete can find it for cleanup.
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
model_name = card.name(),
namespace = namespace,
"Another waiter won the re-registration race, skipping"
);
return Ok(false);
}
tracing::warn!(
model_name = card.name(),
namespace = namespace,
"Concurrent registration produced no WorkerSet, retrying"
);
self.registering_worker_sets
.insert(registration_key.to_string());
return Ok(true);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment