Unverified Commit b88cb59b authored by Janelle Cai's avatar Janelle Cai Committed by GitHub
Browse files

chore: move HF config download and prevent duplicate concurrent model registration (#5767)

parent 95383fd6
...@@ -6,6 +6,7 @@ use tokio::sync::Notify; ...@@ -6,6 +6,7 @@ use tokio::sync::Notify;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use anyhow::Context as _; use anyhow::Context as _;
use dashmap::DashSet;
use futures::StreamExt; use futures::StreamExt;
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -59,6 +60,7 @@ pub struct ModelWatcher { ...@@ -59,6 +60,7 @@ pub struct ModelWatcher {
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
engine_factory: Option<EngineFactoryCallback>, engine_factory: Option<EngineFactoryCallback>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
registering_models: DashSet<String>,
} }
const ALL_MODEL_TYPES: &[ModelType] = &[ const ALL_MODEL_TYPES: &[ModelType] = &[
...@@ -85,6 +87,7 @@ impl ModelWatcher { ...@@ -85,6 +87,7 @@ impl ModelWatcher {
model_update_tx: None, model_update_tx: None,
engine_factory, engine_factory,
metrics, metrics,
registering_models: DashSet::new(),
} }
} }
...@@ -340,19 +343,9 @@ impl ModelWatcher { ...@@ -340,19 +343,9 @@ impl ModelWatcher {
mcid: &ModelCardInstanceId, mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard, card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
card.download_config().await?; // Check if model is already registered before downloading config.
// This prevents duplicate HuggingFace API calls when multiple workers register
let component = self // the same model.
.drt
.namespace(&mcid.namespace)?
.component(&mcid.component)?;
let endpoint = component.endpoint(&mcid.endpoint);
let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model");
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
// Skip duplicate registrations based on model type.
// Prefill and decode models are tracked separately, so registering one // Prefill and decode models are tracked separately, so registering one
// doesn't block the other (they can arrive in any order). // doesn't block the other (they can arrive in any order).
let already_registered = if card.model_type.supports_prefill() { let already_registered = if card.model_type.supports_prefill() {
...@@ -362,15 +355,58 @@ impl ModelWatcher { ...@@ -362,15 +355,58 @@ impl ModelWatcher {
}; };
if already_registered { if already_registered {
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!( tracing::debug!(
model_name = card.name(), model_name = card.name(),
namespace = mcid.namespace, namespace = mcid.namespace,
model_type = %card.model_type, model_type = %card.model_type,
"Model already registered, skipping" "Model already registered, skipping config download"
);
return Ok(());
}
// Use registering_models set to prevent concurrent registrations.
let model_key = card.name().to_string();
if !self.registering_models.insert(model_key.clone()) {
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
model_name = card.name(),
namespace = mcid.namespace,
"Model registration in progress by another worker, skipping"
); );
return Ok(()); return Ok(());
} }
// We acquired the registration lock. Use a helper to ensure cleanup on all exit paths.
let result = self.do_model_registration(mcid, card).await;
// Always remove from registering set, whether success or failure
self.registering_models.remove(&model_key);
result
}
/// Inner function that performs the actual model registration.
/// Called by handle_put after acquiring the registration lock.
async fn do_model_registration(
&self,
mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> {
card.download_config().await?;
let component = self
.drt
.namespace(&mcid.namespace)?
.component(&mcid.component)?;
let endpoint = component.endpoint(&mcid.endpoint);
let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model");
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
if let Some(tx) = &self.model_update_tx { if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(card.clone())).await.ok(); tx.send(ModelUpdate::Added(card.clone())).await.ok();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment