Unverified Commit 21ba0c4b authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix(backend): Reject conflicting model registrations on the same endpoint (#7686)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 8886d184
...@@ -89,7 +89,7 @@ impl Discovery for KubeDiscoveryClient { ...@@ -89,7 +89,7 @@ impl Discovery for KubeDiscoveryClient {
self.instance_id self.instance_id
} }
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> { async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let instance_id = self.instance_id(); let instance_id = self.instance_id();
let instance = spec.with_instance_id(instance_id); let instance = spec.with_instance_id(instance_id);
......
...@@ -156,7 +156,7 @@ impl Discovery for KVStoreDiscovery { ...@@ -156,7 +156,7 @@ impl Discovery for KVStoreDiscovery {
self.store.connection_id() self.store.connection_id()
} }
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> { async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let instance_id = self.instance_id(); let instance_id = self.instance_id();
let instance = spec.with_instance_id(instance_id); let instance = spec.with_instance_id(instance_id);
......
...@@ -157,7 +157,7 @@ impl Discovery for MockDiscovery { ...@@ -157,7 +157,7 @@ impl Discovery for MockDiscovery {
self.instance_id self.instance_id
} }
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> { async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let instance = spec.with_instance_id(self.instance_id); let instance = spec.with_instance_id(self.instance_id);
self.registry self.registry
...@@ -241,6 +241,46 @@ mod tests { ...@@ -241,6 +241,46 @@ mod tests {
use super::*; use super::*;
use futures::StreamExt; use futures::StreamExt;
fn model_spec(
namespace: &str,
component: &str,
endpoint: &str,
model_name: &str,
) -> DiscoverySpec {
DiscoverySpec::Model {
namespace: namespace.to_string(),
component: component.to_string(),
endpoint: endpoint.to_string(),
card_json: serde_json::json!({
"display_name": model_name,
}),
model_suffix: None,
}
}
fn lora_model_spec(
namespace: &str,
component: &str,
endpoint: &str,
model_name: &str,
source_path: &str,
lora_name: &str,
) -> DiscoverySpec {
DiscoverySpec::Model {
namespace: namespace.to_string(),
component: component.to_string(),
endpoint: endpoint.to_string(),
card_json: serde_json::json!({
"display_name": model_name,
"source_path": source_path,
"lora": {
"name": lora_name,
},
}),
model_suffix: Some(lora_name.to_string()),
}
}
#[tokio::test] #[tokio::test]
async fn test_mock_discovery_add_and_remove() { async fn test_mock_discovery_add_and_remove() {
let registry = SharedMockRegistry::new(); let registry = SharedMockRegistry::new();
...@@ -301,4 +341,142 @@ mod tests { ...@@ -301,4 +341,142 @@ mod tests {
_ => panic!("Expected Removed event for instance-1"), _ => panic!("Expected Removed event for instance-1"),
} }
} }
#[tokio::test]
async fn register_allows_same_model_name_on_same_endpoint() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
let spec = model_spec("ns", "comp", "generate", "model-a");
discovery1.register(spec.clone()).await.unwrap();
discovery2.register(spec).await.unwrap();
let instances = discovery1
.list(DiscoveryQuery::EndpointModels {
namespace: "ns".to_string(),
component: "comp".to_string(),
endpoint: "generate".to_string(),
})
.await
.unwrap();
assert_eq!(instances.len(), 2);
}
#[tokio::test]
async fn register_rejects_different_model_name_on_same_endpoint() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
discovery1
.register(model_spec("ns", "comp", "generate", "model-a"))
.await
.unwrap();
let err = discovery2
.register(model_spec("ns", "comp", "generate", "model-b"))
.await
.unwrap_err();
assert!(err.to_string().contains(
"Cannot register model 'model-b' on endpoint 'ns/comp/generate': a different model 'model-a' is already registered there"
));
let instances = discovery1
.list(DiscoveryQuery::EndpointModels {
namespace: "ns".to_string(),
component: "comp".to_string(),
endpoint: "generate".to_string(),
})
.await
.unwrap();
assert_eq!(instances.len(), 1);
}
#[tokio::test]
async fn register_allows_different_model_names_on_different_endpoints() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
discovery1
.register(model_spec("ns", "comp", "generate-a", "model-a"))
.await
.unwrap();
discovery2
.register(model_spec("ns", "comp", "generate-b", "model-b"))
.await
.unwrap();
}
#[tokio::test]
async fn register_allows_lora_adapter_on_same_endpoint() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
discovery1
.register(DiscoverySpec::Model {
namespace: "ns".to_string(),
component: "comp".to_string(),
endpoint: "generate".to_string(),
card_json: serde_json::json!({
"display_name": "base-model",
"source_path": "base-repo",
}),
model_suffix: None,
})
.await
.unwrap();
discovery2
.register(lora_model_spec(
"ns",
"comp",
"generate",
"adapter-a",
"base-repo",
"adapter-a",
))
.await
.unwrap();
}
#[tokio::test]
async fn register_rejects_lora_adapter_for_different_base_model() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
discovery1
.register(DiscoverySpec::Model {
namespace: "ns".to_string(),
component: "comp".to_string(),
endpoint: "generate".to_string(),
card_json: serde_json::json!({
"display_name": "base-model",
"source_path": "base-repo",
}),
model_suffix: None,
})
.await
.unwrap();
let err = discovery2
.register(lora_model_spec(
"ns",
"comp",
"generate",
"adapter-a",
"other-base-repo",
"adapter-a",
))
.await
.unwrap_err();
assert!(err.to_string().contains(
"Cannot register model 'adapter-a' on endpoint 'ns/comp/generate': a different model 'base-model' is already registered there"
));
}
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use anyhow::Result; use anyhow::{Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use futures::Stream; use futures::Stream;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
...@@ -683,6 +683,74 @@ pub enum DiscoveryEvent { ...@@ -683,6 +683,74 @@ pub enum DiscoveryEvent {
/// Stream type for discovery events /// Stream type for discovery events
pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>; pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>;
#[derive(Clone, Debug, PartialEq, Eq)]
struct ModelRegistrationIdentity {
display_name: String,
source_path: Option<String>,
is_lora: bool,
}
impl ModelRegistrationIdentity {
fn base_identity(&self) -> &str {
self.source_path.as_deref().unwrap_or(&self.display_name)
}
fn is_compatible_with(&self, other: &Self) -> bool {
if self.is_lora || other.is_lora {
self.base_identity() == other.base_identity()
} else {
self.display_name == other.display_name
}
}
}
fn extract_model_registration_identity(
card_json: &serde_json::Value,
model_suffix: Option<&str>,
) -> Result<ModelRegistrationIdentity> {
let display_name = card_json
.get("display_name")
.and_then(serde_json::Value::as_str)
.map(str::to_owned)
.ok_or_else(|| {
anyhow::anyhow!("failed to deserialize model display_name from card_json")
})?;
let source_path = card_json
.get("source_path")
.and_then(serde_json::Value::as_str)
.map(str::to_owned);
let is_lora =
model_suffix.is_some() || card_json.get("lora").is_some_and(|value| !value.is_null());
Ok(ModelRegistrationIdentity {
display_name,
source_path,
is_lora,
})
}
fn find_conflicting_model_name(
instances: &[DiscoveryInstance],
requested_identity: &ModelRegistrationIdentity,
) -> Result<Option<String>> {
for instance in instances {
if let DiscoveryInstance::Model {
card_json,
model_suffix,
..
} = instance
{
let existing_identity =
extract_model_registration_identity(card_json, model_suffix.as_deref())?;
if !requested_identity.is_compatible_with(&existing_identity) {
return Ok(Some(existing_identity.display_name));
}
}
}
Ok(None)
}
/// Discovery trait for service discovery across different backends /// Discovery trait for service discovery across different backends
#[async_trait] #[async_trait]
pub trait Discovery: Send + Sync { pub trait Discovery: Send + Sync {
...@@ -691,7 +759,65 @@ pub trait Discovery: Send + Sync { ...@@ -691,7 +759,65 @@ pub trait Discovery: Send + Sync {
fn instance_id(&self) -> u64; fn instance_id(&self) -> u64;
/// Registers an object in the discovery plane with the instance id /// Registers an object in the discovery plane with the instance id
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>; async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let (namespace, component, endpoint, requested_identity) = match &spec {
DiscoverySpec::Model {
namespace,
component,
endpoint,
card_json,
model_suffix,
..
} => (
namespace.clone(),
component.clone(),
endpoint.clone(),
extract_model_registration_identity(card_json, model_suffix.as_deref())?,
),
_ => return self.register_internal(spec).await,
};
let query = DiscoveryQuery::EndpointModels {
namespace: namespace.clone(),
component: component.clone(),
endpoint: endpoint.clone(),
};
if let Some(conflicting_name) =
find_conflicting_model_name(&self.list(query.clone()).await?, &requested_identity)?
{
let requested_name = &requested_identity.display_name;
anyhow::bail!(
"Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
);
}
let instance = self.register_internal(spec).await?;
if let Some(conflicting_name) =
find_conflicting_model_name(&self.list(query).await?, &requested_identity)?
{
let requested_name = &requested_identity.display_name;
if let Err(unregister_err) = self.unregister(instance.clone()).await {
return Err(anyhow::anyhow!(
"Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
))
.context(format!(
"failed to roll back conflicting model registration for instance {instance_id}: {unregister_err}",
instance_id = instance.instance_id()
));
}
anyhow::bail!(
"Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
);
}
Ok(instance)
}
/// Backend-specific raw registration implementation.
async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
/// Unregisters an instance from the discovery plane /// Unregisters an instance from the discovery plane
async fn unregister(&self, instance: DiscoveryInstance) -> Result<()>; async fn unregister(&self, instance: DiscoveryInstance) -> Result<()>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment