"tests/global_planner/vscode:/vscode.git/clone" did not exist on "5d5fd243da84beea19ecb66cbce0fdfa449f5a33"
Unverified Commit 21ba0c4b authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix(backend): Reject conflicting model registrations on the same endpoint (#7686)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 8886d184
......@@ -89,7 +89,7 @@ impl Discovery for KubeDiscoveryClient {
self.instance_id
}
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let instance_id = self.instance_id();
let instance = spec.with_instance_id(instance_id);
......
......@@ -156,7 +156,7 @@ impl Discovery for KVStoreDiscovery {
self.store.connection_id()
}
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let instance_id = self.instance_id();
let instance = spec.with_instance_id(instance_id);
......
......@@ -157,7 +157,7 @@ impl Discovery for MockDiscovery {
self.instance_id
}
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let instance = spec.with_instance_id(self.instance_id);
self.registry
......@@ -241,6 +241,46 @@ mod tests {
use super::*;
use futures::StreamExt;
fn model_spec(
namespace: &str,
component: &str,
endpoint: &str,
model_name: &str,
) -> DiscoverySpec {
DiscoverySpec::Model {
namespace: namespace.to_string(),
component: component.to_string(),
endpoint: endpoint.to_string(),
card_json: serde_json::json!({
"display_name": model_name,
}),
model_suffix: None,
}
}
fn lora_model_spec(
namespace: &str,
component: &str,
endpoint: &str,
model_name: &str,
source_path: &str,
lora_name: &str,
) -> DiscoverySpec {
DiscoverySpec::Model {
namespace: namespace.to_string(),
component: component.to_string(),
endpoint: endpoint.to_string(),
card_json: serde_json::json!({
"display_name": model_name,
"source_path": source_path,
"lora": {
"name": lora_name,
},
}),
model_suffix: Some(lora_name.to_string()),
}
}
#[tokio::test]
async fn test_mock_discovery_add_and_remove() {
let registry = SharedMockRegistry::new();
......@@ -301,4 +341,142 @@ mod tests {
_ => panic!("Expected Removed event for instance-1"),
}
}
#[tokio::test]
async fn register_allows_same_model_name_on_same_endpoint() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
let spec = model_spec("ns", "comp", "generate", "model-a");
discovery1.register(spec.clone()).await.unwrap();
discovery2.register(spec).await.unwrap();
let instances = discovery1
.list(DiscoveryQuery::EndpointModels {
namespace: "ns".to_string(),
component: "comp".to_string(),
endpoint: "generate".to_string(),
})
.await
.unwrap();
assert_eq!(instances.len(), 2);
}
#[tokio::test]
async fn register_rejects_different_model_name_on_same_endpoint() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
discovery1
.register(model_spec("ns", "comp", "generate", "model-a"))
.await
.unwrap();
let err = discovery2
.register(model_spec("ns", "comp", "generate", "model-b"))
.await
.unwrap_err();
assert!(err.to_string().contains(
"Cannot register model 'model-b' on endpoint 'ns/comp/generate': a different model 'model-a' is already registered there"
));
let instances = discovery1
.list(DiscoveryQuery::EndpointModels {
namespace: "ns".to_string(),
component: "comp".to_string(),
endpoint: "generate".to_string(),
})
.await
.unwrap();
assert_eq!(instances.len(), 1);
}
#[tokio::test]
async fn register_allows_different_model_names_on_different_endpoints() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
discovery1
.register(model_spec("ns", "comp", "generate-a", "model-a"))
.await
.unwrap();
discovery2
.register(model_spec("ns", "comp", "generate-b", "model-b"))
.await
.unwrap();
}
#[tokio::test]
async fn register_allows_lora_adapter_on_same_endpoint() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
discovery1
.register(DiscoverySpec::Model {
namespace: "ns".to_string(),
component: "comp".to_string(),
endpoint: "generate".to_string(),
card_json: serde_json::json!({
"display_name": "base-model",
"source_path": "base-repo",
}),
model_suffix: None,
})
.await
.unwrap();
discovery2
.register(lora_model_spec(
"ns",
"comp",
"generate",
"adapter-a",
"base-repo",
"adapter-a",
))
.await
.unwrap();
}
#[tokio::test]
async fn register_rejects_lora_adapter_for_different_base_model() {
let registry = SharedMockRegistry::new();
let discovery1 = MockDiscovery::new(Some(1), registry.clone());
let discovery2 = MockDiscovery::new(Some(2), registry);
discovery1
.register(DiscoverySpec::Model {
namespace: "ns".to_string(),
component: "comp".to_string(),
endpoint: "generate".to_string(),
card_json: serde_json::json!({
"display_name": "base-model",
"source_path": "base-repo",
}),
model_suffix: None,
})
.await
.unwrap();
let err = discovery2
.register(lora_model_spec(
"ns",
"comp",
"generate",
"adapter-a",
"other-base-repo",
"adapter-a",
))
.await
.unwrap_err();
assert!(err.to_string().contains(
"Cannot register model 'adapter-a' on endpoint 'ns/comp/generate': a different model 'base-model' is already registered there"
));
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Result;
use anyhow::{Context, Result};
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
......@@ -683,6 +683,74 @@ pub enum DiscoveryEvent {
/// Stream type for discovery events
pub type DiscoveryStream = Pin<Box<dyn Stream<Item = Result<DiscoveryEvent>> + Send>>;
#[derive(Clone, Debug, PartialEq, Eq)]
struct ModelRegistrationIdentity {
display_name: String,
source_path: Option<String>,
is_lora: bool,
}
impl ModelRegistrationIdentity {
fn base_identity(&self) -> &str {
self.source_path.as_deref().unwrap_or(&self.display_name)
}
fn is_compatible_with(&self, other: &Self) -> bool {
if self.is_lora || other.is_lora {
self.base_identity() == other.base_identity()
} else {
self.display_name == other.display_name
}
}
}
fn extract_model_registration_identity(
card_json: &serde_json::Value,
model_suffix: Option<&str>,
) -> Result<ModelRegistrationIdentity> {
let display_name = card_json
.get("display_name")
.and_then(serde_json::Value::as_str)
.map(str::to_owned)
.ok_or_else(|| {
anyhow::anyhow!("failed to deserialize model display_name from card_json")
})?;
let source_path = card_json
.get("source_path")
.and_then(serde_json::Value::as_str)
.map(str::to_owned);
let is_lora =
model_suffix.is_some() || card_json.get("lora").is_some_and(|value| !value.is_null());
Ok(ModelRegistrationIdentity {
display_name,
source_path,
is_lora,
})
}
fn find_conflicting_model_name(
instances: &[DiscoveryInstance],
requested_identity: &ModelRegistrationIdentity,
) -> Result<Option<String>> {
for instance in instances {
if let DiscoveryInstance::Model {
card_json,
model_suffix,
..
} = instance
{
let existing_identity =
extract_model_registration_identity(card_json, model_suffix.as_deref())?;
if !requested_identity.is_compatible_with(&existing_identity) {
return Ok(Some(existing_identity.display_name));
}
}
}
Ok(None)
}
/// Discovery trait for service discovery across different backends
#[async_trait]
pub trait Discovery: Send + Sync {
......@@ -691,7 +759,65 @@ pub trait Discovery: Send + Sync {
fn instance_id(&self) -> u64;
/// Registers an object in the discovery plane with the instance id
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
async fn register(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance> {
let (namespace, component, endpoint, requested_identity) = match &spec {
DiscoverySpec::Model {
namespace,
component,
endpoint,
card_json,
model_suffix,
..
} => (
namespace.clone(),
component.clone(),
endpoint.clone(),
extract_model_registration_identity(card_json, model_suffix.as_deref())?,
),
_ => return self.register_internal(spec).await,
};
let query = DiscoveryQuery::EndpointModels {
namespace: namespace.clone(),
component: component.clone(),
endpoint: endpoint.clone(),
};
if let Some(conflicting_name) =
find_conflicting_model_name(&self.list(query.clone()).await?, &requested_identity)?
{
let requested_name = &requested_identity.display_name;
anyhow::bail!(
"Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
);
}
let instance = self.register_internal(spec).await?;
if let Some(conflicting_name) =
find_conflicting_model_name(&self.list(query).await?, &requested_identity)?
{
let requested_name = &requested_identity.display_name;
if let Err(unregister_err) = self.unregister(instance.clone()).await {
return Err(anyhow::anyhow!(
"Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
))
.context(format!(
"failed to roll back conflicting model registration for instance {instance_id}: {unregister_err}",
instance_id = instance.instance_id()
));
}
anyhow::bail!(
"Cannot register model '{requested_name}' on endpoint '{namespace}/{component}/{endpoint}': a different model '{conflicting_name}' is already registered there"
);
}
Ok(instance)
}
/// Backend-specific raw registration implementation.
async fn register_internal(&self, spec: DiscoverySpec) -> Result<DiscoveryInstance>;
/// Unregisters an instance from the discovery plane
async fn unregister(&self, instance: DiscoveryInstance) -> Result<()>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment