Unverified Commit d6136f4a authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat(runtime): scope MDC checksum validation to per-WorkerSet (#7368)


Co-authored-by: default avatarClaude Opus 4.6 (1M context) <noreply@anthropic.com>
parent a565f105
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
//! //!
//! Requests are routed to a WorkerSet selected by weighted random (proportional to worker count). //! Requests are routed to a WorkerSet selected by weighted random (proportional to worker count).
use std::sync::{Arc, OnceLock}; use std::sync::Arc;
use dashmap::DashMap; use dashmap::DashMap;
use rand::Rng; use rand::Rng;
...@@ -29,10 +29,6 @@ use crate::types::{ ...@@ -29,10 +29,6 @@ use crate::types::{
pub struct Model { pub struct Model {
name: String, name: String,
worker_sets: DashMap<String, Arc<WorkerSet>>, worker_sets: DashMap<String, Arc<WorkerSet>>,
/// The canonical MDC checksum for this model. Set by the first WorkerSet registered;
/// all subsequent WorkerSets must match. Naturally cleared when the Model is dropped
/// (last WorkerSet removed), allowing a new version to register.
canonical_checksum: OnceLock<String>,
} }
impl Model { impl Model {
...@@ -40,7 +36,6 @@ impl Model { ...@@ -40,7 +36,6 @@ impl Model {
Self { Self {
name, name,
worker_sets: DashMap::new(), worker_sets: DashMap::new(),
canonical_checksum: OnceLock::new(),
} }
} }
...@@ -48,21 +43,23 @@ impl Model { ...@@ -48,21 +43,23 @@ impl Model {
&self.name &self.name
} }
/// Add a WorkerSet to this model. Returns `Err` if the WorkerSet's checksum /// Add a WorkerSet to this model.
/// doesn't match the model's canonical checksum (set by the first WorkerSet). pub fn add_worker_set(&self, namespace: String, worker_set: Arc<WorkerSet>) {
pub fn add_worker_set(
&self,
namespace: String,
worker_set: Arc<WorkerSet>,
) -> Result<(), ModelManagerError> {
self.set_canonical_checksum(worker_set.mdcsum())?;
tracing::info!( tracing::info!(
model = %self.name, model = %self.name,
namespace = %namespace, namespace = %namespace,
"Adding worker set to model" "Adding worker set to model"
); );
self.worker_sets.insert(namespace, worker_set); self.worker_sets.insert(namespace, worker_set);
Ok(()) }
/// Check whether a candidate checksum is compatible with an existing WorkerSet
/// identified by `ws_key`.
pub fn is_checksum_compatible(&self, ws_key: &str, candidate_checksum: &str) -> bool {
match self.worker_sets.get(ws_key) {
Some(existing_ws) => existing_ws.mdcsum() == candidate_checksum,
None => true,
}
} }
pub fn remove_worker_set(&self, namespace: &str) -> Option<Arc<WorkerSet>> { pub fn remove_worker_set(&self, namespace: &str) -> Option<Arc<WorkerSet>> {
...@@ -177,36 +174,6 @@ impl Model { ...@@ -177,36 +174,6 @@ impl Model {
}) })
} }
/// Check if a candidate checksum is valid for this model.
/// Returns `Some(true)` if it matches the canonical checksum, `Some(false)` if it
/// doesn't match, or `None` if no canonical checksum has been set yet (no WorkerSets).
pub fn is_valid_checksum(&self, candidate: &str) -> Option<bool> {
let canonical = self.canonical_checksum.get()?;
Some(canonical == candidate)
}
/// Set the canonical checksum for this model. The first caller wins (OnceLock).
/// Returns `Err` if a different checksum was already set.
fn set_canonical_checksum(&self, checksum: &str) -> Result<(), ModelManagerError> {
// Try to set; if already set, verify it matches.
match self.canonical_checksum.set(checksum.to_string()) {
Ok(()) => Ok(()),
Err(_) => {
// OnceLock was already set — check if the value matches
let canonical = self.canonical_checksum.get().unwrap();
if canonical == checksum {
Ok(())
} else {
Err(ModelManagerError::ChecksumMismatch {
model: self.name.clone(),
expected: canonical.clone(),
got: checksum.to_string(),
})
}
}
}
}
// -- Engine accessors: select a WorkerSet, return its engine -- // -- Engine accessors: select a WorkerSet, return its engine --
pub fn get_chat_engine( pub fn get_chat_engine(
...@@ -410,7 +377,7 @@ mod tests { ...@@ -410,7 +377,7 @@ mod tests {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
let ws = make_worker_set("ns1", "abc"); let ws = make_worker_set("ns1", "abc");
model.add_worker_set("ns1".to_string(), ws).unwrap(); model.add_worker_set("ns1".to_string(), ws);
assert!(!model.is_empty()); assert!(!model.is_empty());
assert_eq!(model.worker_set_count(), 1); assert_eq!(model.worker_set_count(), 1);
assert!(model.has_worker_set("ns1")); assert!(model.has_worker_set("ns1"));
...@@ -428,7 +395,7 @@ mod tests { ...@@ -428,7 +395,7 @@ mod tests {
fn test_get_worker_set() { fn test_get_worker_set() {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
let ws = make_worker_set("ns1", "abc"); let ws = make_worker_set("ns1", "abc");
model.add_worker_set("ns1".to_string(), ws).unwrap(); model.add_worker_set("ns1".to_string(), ws);
let retrieved = model.get_worker_set("ns1"); let retrieved = model.get_worker_set("ns1");
assert!(retrieved.is_some()); assert!(retrieved.is_some());
...@@ -440,12 +407,8 @@ mod tests { ...@@ -440,12 +407,8 @@ mod tests {
#[test] #[test]
fn test_multiple_worker_sets_same_checksum() { fn test_multiple_worker_sets_same_checksum() {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
model model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc")) model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"));
.unwrap();
model
.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"))
.unwrap();
assert_eq!(model.worker_set_count(), 2); assert_eq!(model.worker_set_count(), 2);
assert!(model.has_worker_set("ns1")); assert!(model.has_worker_set("ns1"));
...@@ -458,41 +421,57 @@ mod tests { ...@@ -458,41 +421,57 @@ mod tests {
} }
#[test] #[test]
fn test_add_worker_set_rejects_checksum_mismatch() { fn test_multiple_worker_sets_different_checksums() {
// Different namespaces are allowed to have different checksums
let model = Model::new("llama".to_string());
model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "def"));
assert_eq!(model.worker_set_count(), 2);
assert!(model.has_worker_set("ns1"));
assert!(model.has_worker_set("ns2"));
}
#[test]
fn test_is_checksum_compatible_no_existing_worker_set() {
let model = Model::new("llama".to_string());
// No WorkerSet exists yet — any checksum is compatible
assert!(model.is_checksum_compatible("ns1", "abc"));
assert!(model.is_checksum_compatible("ns1", "xyz"));
}
#[test]
fn test_is_checksum_compatible_matching_checksum() {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
model model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
// Different checksum from a different namespace should be rejected // Same ws_key, same checksum → compatible
let result = model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "def")); assert!(model.is_checksum_compatible("ns1", "abc"));
assert!(result.is_err());
assert_eq!(model.worker_set_count(), 1); // ns2 was not added
} }
#[test] #[test]
fn test_is_valid_checksum() { fn test_is_checksum_compatible_mismatched_checksum() {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
// No canonical set yet // Same ws_key, different checksum → incompatible
assert_eq!(model.is_valid_checksum("abc123"), None); assert!(!model.is_checksum_compatible("ns1", "def"));
}
model #[test]
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc123")) fn test_is_checksum_compatible_different_ws_key() {
.unwrap(); let model = Model::new("llama".to_string());
model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
// Matches canonical // Different ws_key — no existing WorkerSet for "ns2", so any checksum is fine
assert_eq!(model.is_valid_checksum("abc123"), Some(true)); assert!(model.is_checksum_compatible("ns2", "def"));
// Does not match canonical assert!(model.is_checksum_compatible("ns2", "abc"));
assert_eq!(model.is_valid_checksum("wrong"), Some(false));
} }
#[test] #[test]
fn test_no_engines_means_prefill() { fn test_no_engines_means_prefill() {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
model model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
// WorkerSets with no engines are treated as prefill sets // WorkerSets with no engines are treated as prefill sets
assert!(model.has_prefill()); assert!(model.has_prefill());
...@@ -507,9 +486,7 @@ mod tests { ...@@ -507,9 +486,7 @@ mod tests {
#[test] #[test]
fn test_get_engine_returns_error_without_engines() { fn test_get_engine_returns_error_without_engines() {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
model model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
assert!(model.get_chat_engine().is_err()); assert!(model.get_chat_engine().is_err());
assert!(model.get_completions_engine().is_err()); assert!(model.get_completions_engine().is_err());
...@@ -530,15 +507,11 @@ mod tests { ...@@ -530,15 +507,11 @@ mod tests {
assert!(model.get_chat_engine().is_err()); assert!(model.get_chat_engine().is_err());
// Single set (fast path) // Single set (fast path)
model model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
assert!(model.get_chat_engine().is_err()); // No engine → filtered out assert!(model.get_chat_engine().is_err()); // No engine → filtered out
// Multiple sets (weighted path) // Multiple sets (weighted path)
model model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"));
.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"))
.unwrap();
assert!(model.get_chat_engine().is_err()); // Still no engines → all filtered out assert!(model.get_chat_engine().is_err()); // Still no engines → all filtered out
} }
...@@ -548,14 +521,10 @@ mod tests { ...@@ -548,14 +521,10 @@ mod tests {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
assert_eq!(model.total_workers(), 0); // empty model assert_eq!(model.total_workers(), 0); // empty model
model model.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"));
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
assert_eq!(model.total_workers(), 1); assert_eq!(model.total_workers(), 1);
model model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"));
.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"))
.unwrap();
assert_eq!(model.total_workers(), 2); assert_eq!(model.total_workers(), 2);
} }
...@@ -565,8 +534,8 @@ mod tests { ...@@ -565,8 +534,8 @@ mod tests {
let (ws1, _tx1) = make_worker_set_with_count("ns1", "abc", vec![1, 2, 3]); let (ws1, _tx1) = make_worker_set_with_count("ns1", "abc", vec![1, 2, 3]);
let (ws2, _tx2) = make_worker_set_with_count("ns2", "abc", vec![10, 20]); let (ws2, _tx2) = make_worker_set_with_count("ns2", "abc", vec![10, 20]);
model.add_worker_set("ns1".to_string(), ws1).unwrap(); model.add_worker_set("ns1".to_string(), ws1);
model.add_worker_set("ns2".to_string(), ws2).unwrap(); model.add_worker_set("ns2".to_string(), ws2);
assert_eq!(model.total_workers(), 5); // 3 + 2 assert_eq!(model.total_workers(), 5); // 3 + 2
} }
...@@ -576,7 +545,7 @@ mod tests { ...@@ -576,7 +545,7 @@ mod tests {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
let (ws1, tx1) = make_worker_set_with_count("ns1", "abc", vec![1, 2]); let (ws1, tx1) = make_worker_set_with_count("ns1", "abc", vec![1, 2]);
model.add_worker_set("ns1".to_string(), ws1).unwrap(); model.add_worker_set("ns1".to_string(), ws1);
assert_eq!(model.total_workers(), 2); assert_eq!(model.total_workers(), 2);
// Workers leave // Workers leave
...@@ -597,7 +566,7 @@ mod tests { ...@@ -597,7 +566,7 @@ mod tests {
let model = Model::new("llama".to_string()); let model = Model::new("llama".to_string());
let (ws, _tx) = make_worker_set_with_count("ns1", "abc", vec![]); let (ws, _tx) = make_worker_set_with_count("ns1", "abc", vec![]);
model.add_worker_set("ns1".to_string(), ws).unwrap(); model.add_worker_set("ns1".to_string(), ws);
// WorkerSet exists but has 0 workers → selection filtered out → Err // WorkerSet exists but has 0 workers → selection filtered out → Err
assert!(model.get_chat_engine().is_err()); assert!(model.get_chat_engine().is_err());
...@@ -611,8 +580,8 @@ mod tests { ...@@ -611,8 +580,8 @@ mod tests {
let (ws1, _tx1) = make_worker_set_with_count("ns1", "abc", vec![]); let (ws1, _tx1) = make_worker_set_with_count("ns1", "abc", vec![]);
let (ws2, _tx2) = make_worker_set_with_count("ns2", "abc", vec![]); let (ws2, _tx2) = make_worker_set_with_count("ns2", "abc", vec![]);
model.add_worker_set("ns1".to_string(), ws1).unwrap(); model.add_worker_set("ns1".to_string(), ws1);
model.add_worker_set("ns2".to_string(), ws2).unwrap(); model.add_worker_set("ns2".to_string(), ws2);
// Both have 0 workers → all filtered → Err // Both have 0 workers → all filtered → Err
assert!(model.get_chat_engine().is_err()); assert!(model.get_chat_engine().is_err());
......
...@@ -47,15 +47,6 @@ pub enum ModelManagerError { ...@@ -47,15 +47,6 @@ pub enum ModelManagerError {
#[error("Model already exists: {0}")] #[error("Model already exists: {0}")]
ModelAlreadyExists(String), ModelAlreadyExists(String),
#[error(
"Checksum mismatch for model {model}: expected {expected}, got {got}. All WorkerSets of a model must share the same checksum. Drain all old workers before deploying a new version."
)]
ChecksumMismatch {
model: String,
expected: String,
got: String,
},
} }
/// Central manager for model engines, routing, and configuration. /// Central manager for model engines, routing, and configuration.
...@@ -124,15 +115,9 @@ impl ModelManager { ...@@ -124,15 +115,9 @@ impl ModelManager {
} }
/// Add a WorkerSet to a Model. Creates the Model if it doesn't exist. /// Add a WorkerSet to a Model. Creates the Model if it doesn't exist.
/// Returns `Err` if the WorkerSet's checksum doesn't match the model's canonical checksum. pub fn add_worker_set(&self, model_name: &str, namespace: &str, worker_set: WorkerSet) {
pub fn add_worker_set(
&self,
model_name: &str,
namespace: &str,
worker_set: WorkerSet,
) -> Result<(), ModelManagerError> {
let model = self.get_or_create_model(model_name); let model = self.get_or_create_model(model_name);
model.add_worker_set(namespace.to_string(), Arc::new(worker_set)) model.add_worker_set(namespace.to_string(), Arc::new(worker_set));
} }
/// Remove a WorkerSet from a Model. Removes the Model if it becomes empty. /// Remove a WorkerSet from a Model. Removes the Model if it becomes empty.
...@@ -144,16 +129,6 @@ impl ModelManager { ...@@ -144,16 +129,6 @@ impl ModelManager {
removed removed
} }
// -- Checksum validation --
/// Check if a candidate checksum is valid for a model.
/// Returns `Some(true)` if it matches the model's canonical checksum, `Some(false)` if it
/// doesn't match, or `None` if the model doesn't exist or has no canonical checksum yet.
pub fn is_valid_checksum(&self, model_name: &str, candidate_checksum: &str) -> Option<bool> {
let model = self.models.get(model_name)?;
model.is_valid_checksum(candidate_checksum)
}
// -- Model cards -- // -- Model cards --
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> { pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
...@@ -372,7 +347,7 @@ impl ModelManager { ...@@ -372,7 +347,7 @@ impl ModelManager {
ModelDeploymentCard::default(), ModelDeploymentCard::default(),
); );
ws.chat_engine = Some(engine); ws.chat_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?; model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(()) Ok(())
} }
...@@ -393,7 +368,7 @@ impl ModelManager { ...@@ -393,7 +368,7 @@ impl ModelManager {
ModelDeploymentCard::default(), ModelDeploymentCard::default(),
); );
ws.completions_engine = Some(engine); ws.completions_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?; model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(()) Ok(())
} }
...@@ -414,7 +389,7 @@ impl ModelManager { ...@@ -414,7 +389,7 @@ impl ModelManager {
ModelDeploymentCard::default(), ModelDeploymentCard::default(),
); );
ws.embeddings_engine = Some(engine); ws.embeddings_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?; model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(()) Ok(())
} }
...@@ -435,7 +410,7 @@ impl ModelManager { ...@@ -435,7 +410,7 @@ impl ModelManager {
ModelDeploymentCard::default(), ModelDeploymentCard::default(),
); );
ws.tensor_engine = Some(engine); ws.tensor_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?; model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(()) Ok(())
} }
...@@ -456,7 +431,7 @@ impl ModelManager { ...@@ -456,7 +431,7 @@ impl ModelManager {
ModelDeploymentCard::default(), ModelDeploymentCard::default(),
); );
ws.images_engine = Some(engine); ws.images_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?; model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(()) Ok(())
} }
...@@ -477,7 +452,7 @@ impl ModelManager { ...@@ -477,7 +452,7 @@ impl ModelManager {
ModelDeploymentCard::default(), ModelDeploymentCard::default(),
); );
ws.videos_engine = Some(engine); ws.videos_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?; model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(()) Ok(())
} }
...@@ -496,7 +471,7 @@ impl ModelManager { ...@@ -496,7 +471,7 @@ impl ModelManager {
card_checksum.to_string(), card_checksum.to_string(),
ModelDeploymentCard::default(), ModelDeploymentCard::default(),
); );
model_entry.add_worker_set(namespace, Arc::new(ws))?; model_entry.add_worker_set(namespace, Arc::new(ws));
Ok(()) Ok(())
} }
...@@ -821,7 +796,7 @@ mod tests { ...@@ -821,7 +796,7 @@ mod tests {
fn test_add_and_get_worker_set() { fn test_add_and_get_worker_set() {
let mm = ModelManager::new(); let mm = ModelManager::new();
let ws = make_worker_set("ns1", "abc"); let ws = make_worker_set("ns1", "abc");
mm.add_worker_set("llama", "ns1", ws).unwrap(); mm.add_worker_set("llama", "ns1", ws);
let model = mm.get_model("llama"); let model = mm.get_model("llama");
assert!(model.is_some()); assert!(model.is_some());
...@@ -835,16 +810,14 @@ mod tests { ...@@ -835,16 +810,14 @@ mod tests {
let mm = ModelManager::new(); let mm = ModelManager::new();
assert!(mm.get_model("llama").is_none()); assert!(mm.get_model("llama").is_none());
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap();
assert!(mm.get_model("llama").is_some()); assert!(mm.get_model("llama").is_some());
} }
#[test] #[test]
fn test_remove_worker_set_removes_empty_model() { fn test_remove_worker_set_removes_empty_model() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap();
assert!(mm.get_model("llama").is_some()); assert!(mm.get_model("llama").is_some());
let removed = mm.remove_worker_set("llama", "ns1"); let removed = mm.remove_worker_set("llama", "ns1");
...@@ -858,10 +831,8 @@ mod tests { ...@@ -858,10 +831,8 @@ mod tests {
#[test] #[test]
fn test_remove_worker_set_keeps_model_with_remaining() { fn test_remove_worker_set_keeps_model_with_remaining() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap(); mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"));
mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"))
.unwrap();
mm.remove_worker_set("llama", "ns1"); mm.remove_worker_set("llama", "ns1");
...@@ -881,8 +852,7 @@ mod tests { ...@@ -881,8 +852,7 @@ mod tests {
#[test] #[test]
fn test_remove_worker_set_nonexistent_namespace() { fn test_remove_worker_set_nonexistent_namespace() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap();
assert!(mm.remove_worker_set("llama", "ns2").is_none()); assert!(mm.remove_worker_set("llama", "ns2").is_none());
// Model should still exist (ns1 still there) // Model should still exist (ns1 still there)
...@@ -892,8 +862,7 @@ mod tests { ...@@ -892,8 +862,7 @@ mod tests {
#[test] #[test]
fn test_remove_model_if_empty_noop_when_not_empty() { fn test_remove_model_if_empty_noop_when_not_empty() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap();
mm.remove_model_if_empty("llama"); mm.remove_model_if_empty("llama");
assert!(mm.get_model("llama").is_some()); // Still has ns1 assert!(mm.get_model("llama").is_some()); // Still has ns1
...@@ -908,10 +877,8 @@ mod tests { ...@@ -908,10 +877,8 @@ mod tests {
#[test] #[test]
fn test_remove_model() { fn test_remove_model() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap(); mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"));
mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"))
.unwrap();
let removed = mm.remove_model("llama"); let removed = mm.remove_model("llama");
assert!(removed.is_some()); assert!(removed.is_some());
...@@ -927,56 +894,6 @@ mod tests { ...@@ -927,56 +894,6 @@ mod tests {
assert!(Arc::ptr_eq(&m1, &m2)); assert!(Arc::ptr_eq(&m1, &m2));
} }
// -- Checksum validation tests --
#[test]
fn test_is_valid_checksum_match() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc123"))
.unwrap();
assert_eq!(mm.is_valid_checksum("llama", "abc123"), Some(true));
}
#[test]
fn test_is_valid_checksum_mismatch() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc123"))
.unwrap();
assert_eq!(mm.is_valid_checksum("llama", "wrong"), Some(false));
}
#[test]
fn test_is_valid_checksum_no_canonical_yet() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc123"))
.unwrap();
// Canonical is set, so even for a "new namespace" scenario the checksum is checked
assert_eq!(mm.is_valid_checksum("llama", "abc123"), Some(true));
assert_eq!(mm.is_valid_checksum("llama", "xyz"), Some(false));
}
#[test]
fn test_is_valid_checksum_missing_model() {
let mm = ModelManager::new();
assert_eq!(mm.is_valid_checksum("nonexistent", "abc"), None);
}
#[test]
fn test_is_valid_checksum_cross_namespace_enforcement() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "checksum_a"))
.unwrap();
// A different namespace with a different checksum should be rejected at the model level
assert_eq!(mm.is_valid_checksum("llama", "checksum_b"), Some(false));
// Same checksum is accepted
assert_eq!(mm.is_valid_checksum("llama", "checksum_a"), Some(true));
}
// -- Model listing and filtering tests -- // -- Model listing and filtering tests --
#[test] #[test]
...@@ -987,8 +904,7 @@ mod tests { ...@@ -987,8 +904,7 @@ mod tests {
assert!(!mm.has_decode_model("llama")); assert!(!mm.has_decode_model("llama"));
// Prefill-only set (no engines) → false // Prefill-only set (no engines) → false
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap();
assert!(!mm.has_decode_model("llama")); assert!(!mm.has_decode_model("llama"));
} }
...@@ -997,8 +913,7 @@ mod tests { ...@@ -997,8 +913,7 @@ mod tests {
let mm = ModelManager::new(); let mm = ModelManager::new();
// Prefill set = no engines // Prefill set = no engines
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap();
assert!(mm.has_prefill_model("llama")); assert!(mm.has_prefill_model("llama"));
} }
...@@ -1007,16 +922,14 @@ mod tests { ...@@ -1007,16 +922,14 @@ mod tests {
let mm = ModelManager::new(); let mm = ModelManager::new();
assert!(!mm.has_model_any("llama")); assert!(!mm.has_model_any("llama"));
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap();
assert!(mm.has_model_any("llama")); // has prefill assert!(mm.has_model_any("llama")); // has prefill
} }
#[test] #[test]
fn test_model_display_names_includes_prefill() { fn test_model_display_names_includes_prefill() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap();
let names = mm.model_display_names(); let names = mm.model_display_names();
assert!(names.contains("llama")); assert!(names.contains("llama"));
...@@ -1031,10 +944,8 @@ mod tests { ...@@ -1031,10 +944,8 @@ mod tests {
#[test] #[test]
fn test_list_prefill_models() { fn test_list_prefill_models() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")) mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"));
.unwrap(); mm.add_worker_set("gpt", "ns1", make_worker_set("ns1", "def"));
mm.add_worker_set("gpt", "ns1", make_worker_set("ns1", "def"))
.unwrap();
let prefill = mm.list_prefill_models(); let prefill = mm.list_prefill_models();
assert_eq!(prefill.len(), 2); assert_eq!(prefill.len(), 2);
......
...@@ -208,19 +208,20 @@ impl ModelWatcher { ...@@ -208,19 +208,20 @@ impl ModelWatcher {
continue; continue;
} }
// If we already have a WorkerSet for this model and the checksums // If a WorkerSet already exists for this (model, namespace, type),
// don't match, reject the new worker. All WorkerSets of a model // validate that the new worker's checksum matches. Different
// must share the same checksum. // WorkerSets (different namespaces) are allowed to have different checksums to support rolling updates.
let can_add = self.manager.is_valid_checksum(card.name(), card.mdcsum()); let ws_key = worker_set_key(&mcid.namespace, card.model_type);
if can_add.is_some_and(|is_valid| !is_valid) { if let Some(model) = self.manager.get_model(card.name())
&& !model.is_checksum_compatible(&ws_key, card.mdcsum())
{
tracing::error!( tracing::error!(
model_name = card.name(), model_name = card.name(),
namespace = mcid.namespace, namespace = mcid.namespace,
"Checksum for new worker does not match model's canonical checksum. \ new_checksum = card.mdcsum(),
All WorkerSets must share the same checksum. \ "Checksum for new worker does not match existing WorkerSet's checksum. \
Drain all old workers before deploying a new version." Drain all old workers in this namespace before deploying a new version."
); );
// TODO: mark that instance down in clients // TODO: mark that instance down in clients
// Not obvious how to do that given the current design // Not obvious how to do that given the current design
// Instances come from an `InstanceSource` in a `Client` in a `PushRouter`. // Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
...@@ -728,7 +729,7 @@ impl ModelWatcher { ...@@ -728,7 +729,7 @@ impl ModelWatcher {
// Prefill sets have no engines — we add the WorkerSet first for tracking, // Prefill sets have no engines — we add the WorkerSet first for tracking,
// then activate the prefill router. // then activate the prefill router.
self.manager self.manager
.add_worker_set(card.name(), &ws_key, worker_set)?; .add_worker_set(card.name(), &ws_key, worker_set);
// Note: activate_prefill_router is keyed by deployment namespace (not ws_key) // Note: activate_prefill_router is keyed by deployment namespace (not ws_key)
// because it coordinates between decode and prefill WorkerSets that share // because it coordinates between decode and prefill WorkerSets that share
...@@ -762,7 +763,7 @@ impl ModelWatcher { ...@@ -762,7 +763,7 @@ impl ModelWatcher {
// Add the completed WorkerSet to the Model // Add the completed WorkerSet to the Model
self.manager self.manager
.add_worker_set(card.name(), &ws_key, worker_set)?; .add_worker_set(card.name(), &ws_key, worker_set);
Ok(()) Ok(())
} }
...@@ -865,8 +866,7 @@ mod tests { ...@@ -865,8 +866,7 @@ mod tests {
fn test_is_model_type_list_empty_prefill_present() { fn test_is_model_type_list_empty_prefill_present() {
let mm = ModelManager::new(); let mm = ModelManager::new();
// A WorkerSet with no engines is treated as a prefill set // A WorkerSet with no engines is treated as a prefill set
mm.add_worker_set("model-a", "ns1", make_worker_set("ns1")) mm.add_worker_set("model-a", "ns1", make_worker_set("ns1"));
.unwrap();
assert!(!is_model_type_list_empty(&mm, ModelType::Prefill)); assert!(!is_model_type_list_empty(&mm, ModelType::Prefill));
// Other types should still be empty since the WorkerSet has no engines // Other types should still be empty since the WorkerSet has no engines
...@@ -881,8 +881,7 @@ mod tests { ...@@ -881,8 +881,7 @@ mod tests {
#[test] #[test]
fn test_is_model_type_list_empty_after_removal() { fn test_is_model_type_list_empty_after_removal() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("model-a", "ns1", make_worker_set("ns1")) mm.add_worker_set("model-a", "ns1", make_worker_set("ns1"));
.unwrap();
assert!(!is_model_type_list_empty(&mm, ModelType::Prefill)); assert!(!is_model_type_list_empty(&mm, ModelType::Prefill));
mm.remove_model("model-a"); mm.remove_model("model-a");
...@@ -892,10 +891,8 @@ mod tests { ...@@ -892,10 +891,8 @@ mod tests {
#[test] #[test]
fn test_is_model_type_list_not_empty_when_other_model_remains() { fn test_is_model_type_list_not_empty_when_other_model_remains() {
let mm = ModelManager::new(); let mm = ModelManager::new();
mm.add_worker_set("model-a", "ns1", make_worker_set("ns1")) mm.add_worker_set("model-a", "ns1", make_worker_set("ns1"));
.unwrap(); mm.add_worker_set("model-b", "ns1", make_worker_set("ns1"));
mm.add_worker_set("model-b", "ns1", make_worker_set("ns1"))
.unwrap();
// Remove one model — other still provides prefill // Remove one model — other still provides prefill
mm.remove_model("model-a"); mm.remove_model("model-a");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment