"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "cf433e6825d83f41905da47d69ca5ee30d4eb1ba"
Unverified Commit b88cb59b authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

chore: move HF config download and prevent duplicate concurrent model registration (#5767)

parent 95383fd6
......@@ -6,6 +6,7 @@ use tokio::sync::Notify;
use tokio::sync::mpsc::Sender;
use anyhow::Context as _;
use dashmap::DashSet;
use futures::StreamExt;
use dynamo_runtime::{
......@@ -59,6 +60,7 @@ pub struct ModelWatcher {
model_update_tx: Option<Sender<ModelUpdate>>,
engine_factory: Option<EngineFactoryCallback>,
metrics: Arc<Metrics>,
registering_models: DashSet<String>,
}
const ALL_MODEL_TYPES: &[ModelType] = &[
......@@ -85,6 +87,7 @@ impl ModelWatcher {
model_update_tx: None,
engine_factory,
metrics,
registering_models: DashSet::new(),
}
}
......@@ -340,19 +343,9 @@ impl ModelWatcher {
mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> {
card.download_config().await?;
let component = self
.drt
.namespace(&mcid.namespace)?
.component(&mcid.component)?;
let endpoint = component.endpoint(&mcid.endpoint);
let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model");
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
// Skip duplicate registrations based on model type.
// Check if model is already registered before downloading config.
// This prevents duplicate HuggingFace API calls when multiple workers register
// the same model.
// Prefill and decode models are tracked separately, so registering one
// doesn't block the other (they can arrive in any order).
let already_registered = if card.model_type.supports_prefill() {
......@@ -362,15 +355,58 @@ impl ModelWatcher {
};
if already_registered {
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
model_name = card.name(),
namespace = mcid.namespace,
model_type = %card.model_type,
"Model already registered, skipping"
"Model already registered, skipping config download"
);
return Ok(());
}
// Use registering_models set to prevent concurrent registrations.
let model_key = card.name().to_string();
if !self.registering_models.insert(model_key.clone()) {
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
model_name = card.name(),
namespace = mcid.namespace,
"Model registration in progress by another worker, skipping"
);
return Ok(());
}
// We acquired the registration lock. Use a helper to ensure cleanup on all exit paths.
let result = self.do_model_registration(mcid, card).await;
// Always remove from registering set, whether success or failure
self.registering_models.remove(&model_key);
result
}
/// Inner function that performs the actual model registration.
/// Called by handle_put after acquiring the registration lock.
async fn do_model_registration(
&self,
mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> {
card.download_config().await?;
let component = self
.drt
.namespace(&mcid.namespace)?
.component(&mcid.component)?;
let endpoint = component.endpoint(&mcid.endpoint);
let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model");
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(card.clone())).await.ok();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment