Unverified Commit 72ec5f5c authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Allow an endpoint to serve multiple models (#2418)

parent ebc84d6c
......@@ -127,7 +127,7 @@ async fn app(runtime: Runtime) -> Result<()> {
tracing::debug!("Creating unique instance of Count at {key}");
drt.etcd_client()
.expect("Unreachable because of DistributedRuntime::from_settings above")
.kv_create(key, serde_json::to_vec_pretty(&config)?, None)
.kv_create(&key, serde_json::to_vec_pretty(&config)?, None)
.await
.context("Unable to create unique instance of Count; possibly one already exists")?;
......
......@@ -557,7 +557,7 @@ impl EtcdClient {
let client = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
client
.kv_create(key, value, lease_id)
.kv_create(&key, value, lease_id)
.await
.map_err(to_pyerr)?;
Ok(())
......
......@@ -10,7 +10,7 @@ use dynamo_runtime::protocols::Endpoint as EndpointId;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{
component::{Component, Endpoint},
component::Endpoint,
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
};
......@@ -302,8 +302,6 @@ impl LocalModel {
let Some(etcd_client) = endpoint.drt().etcd_client() else {
anyhow::bail!("Cannot attach to static endpoint");
};
self.ensure_unique(endpoint.component(), self.display_name())
.await?;
// Store model config files in NATS object store
let nats_client = endpoint.drt().nats_client();
......@@ -319,7 +317,7 @@ impl LocalModel {
// Publish our ModelEntry to etcd. This allows ingress to find the model card.
// (Why don't we put the model card directly under this key?)
let network_name = ModelNetworkName::from_local(endpoint, etcd_client.lease_id());
let network_name = ModelNetworkName::new();
tracing::debug!("Registering with etcd as {network_name}");
let model_registration = ModelEntry {
name: self.display_name().to_string(),
......@@ -328,35 +326,12 @@ impl LocalModel {
};
etcd_client
.kv_create(
network_name.to_string(),
&network_name,
serde_json::to_vec_pretty(&model_registration)?,
None, // use primary lease
)
.await
}
/// Ensure that each component serves only one model.
/// We can have multiple instances of the same model running using the same component name
/// (they get load balanced, and are differentiated in etcd by their lease_id).
/// We cannot have multiple models with the same component name.
///
/// Returns an error if there is already a component by this name serving a different model.
async fn ensure_unique(&self, component: &Component, model_name: &str) -> anyhow::Result<()> {
let Some(etcd_client) = component.drt().etcd_client() else {
// A static component is necessarily unique, it cannot register
return Ok(());
};
for endpoint_info in component.list_instances().await? {
let network_name: ModelNetworkName = (&endpoint_info).into();
if let Ok(entry) = network_name.load_entry(&etcd_client).await {
if entry.name != model_name {
anyhow::bail!("Duplicate component. Attempt to register model {model_name} at {component}, which is already used by {network_name} running model {}.", entry.name);
}
}
}
Ok(())
}
}
/// A random endpoint to use for internal communication
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Context as _;
use crate::discovery::{ModelEntry, MODEL_ROOT_PATH};
use dynamo_runtime::component::{self, Instance};
use dynamo_runtime::slug::Slug;
use dynamo_runtime::transports::etcd;
use crate::discovery::MODEL_ROOT_PATH;
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
/// Key to store this model entry in networked key-value store (etcd).
///
/// It looks like this:
/// ns.cp.ep-694d967ca5efd804
fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
let model_root = MODEL_ROOT_PATH;
let slug = Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}"));
ModelNetworkName(format!("{model_root}/{slug}"))
}
// We can't do From<&component::Endpoint> here because we also need the lease_id
pub fn from_local(endpoint: &component::Endpoint, lease_id: i64) -> Self {
Self::from_parts(
&endpoint.component().namespace().to_string(),
&endpoint.component().name(),
endpoint.name(),
lease_id,
)
pub fn new() -> Self {
ModelNetworkName(format!("{MODEL_ROOT_PATH}/{}", uuid::Uuid::new_v4()))
}
}
pub fn from_entry(entry: &ModelEntry, lease_id: i64) -> Self {
Self::from_parts(
&entry.endpoint.namespace,
&entry.endpoint.component,
&entry.endpoint.name,
lease_id,
)
impl Default for ModelNetworkName {
fn default() -> Self {
Self::new()
}
}
/// Fetch the ModelEntry from etcd.
pub async fn load_entry(&self, etcd_client: &etcd::Client) -> anyhow::Result<ModelEntry> {
let mut model_entries = etcd_client.kv_get(self.to_string(), None).await?;
if model_entries.is_empty() {
anyhow::bail!("No ModelEntry in etcd for key {self}");
}
let model_entry = model_entries.remove(0);
serde_json::from_slice(model_entry.value()).with_context(|| {
format!(
"Error deserializing JSON. Key={self}. JSON={}",
model_entry.value_str().unwrap_or("INVALID UTF-8")
)
})
impl std::fmt::Display for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<&Instance> for ModelNetworkName {
fn from(cei: &Instance) -> Self {
Self::from_parts(
&cei.namespace,
&cei.component,
&cei.endpoint,
cei.instance_id,
)
impl AsRef<str> for ModelNetworkName {
fn as_ref(&self) -> &str {
&self.0
}
}
impl std::fmt::Display for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
impl std::ops::Deref for ModelNetworkName {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
......@@ -182,6 +182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [
"axum-core",
"axum-macros",
"bytes",
"form_urlencoded",
"futures-util",
......@@ -229,6 +230,17 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-macros"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "backtrace"
version = "0.3.74"
......@@ -673,6 +685,7 @@ dependencies = [
"tokio",
"tokio-stream",
"tokio-util",
"tower-http",
"tracing",
"tracing-subscriber",
"url",
......@@ -3013,6 +3026,7 @@ dependencies = [
"tower",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
......
......@@ -144,7 +144,7 @@ impl EndpointConfigBuilder {
if let Some(etcd_client) = &endpoint.component.drt.etcd_client {
if let Err(e) = etcd_client
.kv_create(
endpoint.etcd_path_with_lease_id(lease_id),
&endpoint.etcd_path_with_lease_id(lease_id),
info,
Some(lease_id),
)
......
......@@ -170,20 +170,15 @@ impl Client {
.await?
}
pub async fn kv_create(
&self,
key: String,
value: Vec<u8>,
lease_id: Option<i64>,
) -> Result<()> {
pub async fn kv_create(&self, key: &str, value: Vec<u8>, lease_id: Option<i64>) -> Result<()> {
let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id);
// Build the transaction
let txn = Txn::new()
.when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Ensure the lock does not exist
.when(vec![Compare::version(key, CompareOp::Equal, 0)]) // Ensure the lock does not exist
.and_then(vec![
TxnOp::put(key.as_str(), value, Some(put_options)), // Create the object
TxnOp::put(key, value, Some(put_options)), // Create the object
]);
// Execute the transaction
......
......@@ -95,7 +95,7 @@ fn handle_watch_event<T: DeserializeOwned>(
/// Creates a key-value pair in etcd, returning a specific error if the key already exists
async fn create_barrier_key<T: Serialize>(
client: &Client,
key: String,
key: &str,
data: T,
lease_id: Option<i64>,
) -> Result<(), LeaderWorkerBarrierError> {
......@@ -193,7 +193,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali
lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_DATA);
create_barrier_key(client, key, data, Some(lease_id)).await
create_barrier_key(client, &key, data, Some(lease_id)).await
}
async fn wait_for_workers(
......@@ -216,10 +216,10 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali
let workers = worker_result.keys().collect::<HashSet<_>>();
create_barrier_key(client, key, workers, Some(lease_id)).await?;
create_barrier_key(client, &key, workers, Some(lease_id)).await?;
} else {
let key = barrier_key(&self.barrier_id, BARRIER_ABORT);
create_barrier_key(client, key, (), Some(lease_id)).await?;
create_barrier_key(client, &key, (), Some(lease_id)).await?;
}
Ok(())
......@@ -302,7 +302,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali
&self.barrier_id,
&format!("{}/{}", BARRIER_WORKER, self.worker_id),
);
create_barrier_key(client, key.clone(), data, Some(lease_id)).await?;
create_barrier_key(client, &key, data, Some(lease_id)).await?;
Ok(key)
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment