Unverified commit 72ec5f5c, authored by Graham King, committed by GitHub
Browse files

feat: Allow an endpoint to serve multiple models (#2418)

parent ebc84d6c
...@@ -127,7 +127,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -127,7 +127,7 @@ async fn app(runtime: Runtime) -> Result<()> {
tracing::debug!("Creating unique instance of Count at {key}"); tracing::debug!("Creating unique instance of Count at {key}");
drt.etcd_client() drt.etcd_client()
.expect("Unreachable because of DistributedRuntime::from_settings above") .expect("Unreachable because of DistributedRuntime::from_settings above")
.kv_create(key, serde_json::to_vec_pretty(&config)?, None) .kv_create(&key, serde_json::to_vec_pretty(&config)?, None)
.await .await
.context("Unable to create unique instance of Count; possibly one already exists")?; .context("Unable to create unique instance of Count; possibly one already exists")?;
......
...@@ -557,7 +557,7 @@ impl EtcdClient { ...@@ -557,7 +557,7 @@ impl EtcdClient {
let client = self.inner.clone(); let client = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
client client
.kv_create(key, value, lease_id) .kv_create(&key, value, lease_id)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(()) Ok(())
......
...@@ -10,7 +10,7 @@ use dynamo_runtime::protocols::Endpoint as EndpointId; ...@@ -10,7 +10,7 @@ use dynamo_runtime::protocols::Endpoint as EndpointId;
use dynamo_runtime::slug::Slug; use dynamo_runtime::slug::Slug;
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Component, Endpoint}, component::Endpoint,
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager}, storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
}; };
...@@ -302,8 +302,6 @@ impl LocalModel { ...@@ -302,8 +302,6 @@ impl LocalModel {
let Some(etcd_client) = endpoint.drt().etcd_client() else { let Some(etcd_client) = endpoint.drt().etcd_client() else {
anyhow::bail!("Cannot attach to static endpoint"); anyhow::bail!("Cannot attach to static endpoint");
}; };
self.ensure_unique(endpoint.component(), self.display_name())
.await?;
// Store model config files in NATS object store // Store model config files in NATS object store
let nats_client = endpoint.drt().nats_client(); let nats_client = endpoint.drt().nats_client();
...@@ -319,7 +317,7 @@ impl LocalModel { ...@@ -319,7 +317,7 @@ impl LocalModel {
// Publish our ModelEntry to etcd. This allows ingress to find the model card. // Publish our ModelEntry to etcd. This allows ingress to find the model card.
// (Why don't we put the model card directly under this key?) // (Why don't we put the model card directly under this key?)
let network_name = ModelNetworkName::from_local(endpoint, etcd_client.lease_id()); let network_name = ModelNetworkName::new();
tracing::debug!("Registering with etcd as {network_name}"); tracing::debug!("Registering with etcd as {network_name}");
let model_registration = ModelEntry { let model_registration = ModelEntry {
name: self.display_name().to_string(), name: self.display_name().to_string(),
...@@ -328,35 +326,12 @@ impl LocalModel { ...@@ -328,35 +326,12 @@ impl LocalModel {
}; };
etcd_client etcd_client
.kv_create( .kv_create(
network_name.to_string(), &network_name,
serde_json::to_vec_pretty(&model_registration)?, serde_json::to_vec_pretty(&model_registration)?,
None, // use primary lease None, // use primary lease
) )
.await .await
} }
/// Ensure that each component serves only one model.
/// We can have multiple instances of the same model running using the same component name
/// (they get load balanced, and are differentiated in etcd by their lease_id).
/// We cannot have multiple models with the same component name.
///
/// Returns an error if there is already a component by this name serving a different model.
async fn ensure_unique(&self, component: &Component, model_name: &str) -> anyhow::Result<()> {
let Some(etcd_client) = component.drt().etcd_client() else {
// A static component is necessarily unique, it cannot register
return Ok(());
};
for endpoint_info in component.list_instances().await? {
let network_name: ModelNetworkName = (&endpoint_info).into();
if let Ok(entry) = network_name.load_entry(&etcd_client).await {
if entry.name != model_name {
anyhow::bail!("Duplicate component. Attempt to register model {model_name} at {component}, which is already used by {network_name} running model {}.", entry.name);
}
}
}
Ok(())
}
} }
/// A random endpoint to use for internal communication /// A random endpoint to use for internal communication
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use anyhow::Context as _; use crate::discovery::MODEL_ROOT_PATH;
use crate::discovery::{ModelEntry, MODEL_ROOT_PATH};
use dynamo_runtime::component::{self, Instance};
use dynamo_runtime::slug::Slug;
use dynamo_runtime::transports::etcd;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ModelNetworkName(String); pub struct ModelNetworkName(String);
impl ModelNetworkName { impl ModelNetworkName {
/// Key to store this model entry in networked key-value store (etcd). pub fn new() -> Self {
/// ModelNetworkName(format!("{MODEL_ROOT_PATH}/{}", uuid::Uuid::new_v4()))
/// It looks like this:
/// ns.cp.ep-694d967ca5efd804
fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
let model_root = MODEL_ROOT_PATH;
let slug = Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}"));
ModelNetworkName(format!("{model_root}/{slug}"))
}
// We can't do From<&component::Endpoint> here because we also need the lease_id
pub fn from_local(endpoint: &component::Endpoint, lease_id: i64) -> Self {
Self::from_parts(
&endpoint.component().namespace().to_string(),
&endpoint.component().name(),
endpoint.name(),
lease_id,
)
} }
}
pub fn from_entry(entry: &ModelEntry, lease_id: i64) -> Self { impl Default for ModelNetworkName {
Self::from_parts( fn default() -> Self {
&entry.endpoint.namespace, Self::new()
&entry.endpoint.component,
&entry.endpoint.name,
lease_id,
)
} }
}
/// Fetch the ModelEntry from etcd. impl std::fmt::Display for ModelNetworkName {
pub async fn load_entry(&self, etcd_client: &etcd::Client) -> anyhow::Result<ModelEntry> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut model_entries = etcd_client.kv_get(self.to_string(), None).await?; write!(f, "{}", self.0)
if model_entries.is_empty() {
anyhow::bail!("No ModelEntry in etcd for key {self}");
}
let model_entry = model_entries.remove(0);
serde_json::from_slice(model_entry.value()).with_context(|| {
format!(
"Error deserializing JSON. Key={self}. JSON={}",
model_entry.value_str().unwrap_or("INVALID UTF-8")
)
})
} }
} }
impl From<&Instance> for ModelNetworkName { impl AsRef<str> for ModelNetworkName {
fn from(cei: &Instance) -> Self { fn as_ref(&self) -> &str {
Self::from_parts( &self.0
&cei.namespace,
&cei.component,
&cei.endpoint,
cei.instance_id,
)
} }
} }
impl std::fmt::Display for ModelNetworkName { impl std::ops::Deref for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { type Target = str;
write!(f, "{}", self.0) fn deref(&self) -> &Self::Target {
&self.0
} }
} }
...@@ -182,6 +182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" ...@@ -182,6 +182,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5"
dependencies = [ dependencies = [
"axum-core", "axum-core",
"axum-macros",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
...@@ -229,6 +230,17 @@ dependencies = [ ...@@ -229,6 +230,17 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum-macros"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]] [[package]]
name = "backtrace" name = "backtrace"
version = "0.3.74" version = "0.3.74"
...@@ -673,6 +685,7 @@ dependencies = [ ...@@ -673,6 +685,7 @@ dependencies = [
"tokio", "tokio",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"url", "url",
...@@ -3013,6 +3026,7 @@ dependencies = [ ...@@ -3013,6 +3026,7 @@ dependencies = [
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
"tracing",
] ]
[[package]] [[package]]
......
...@@ -144,7 +144,7 @@ impl EndpointConfigBuilder { ...@@ -144,7 +144,7 @@ impl EndpointConfigBuilder {
if let Some(etcd_client) = &endpoint.component.drt.etcd_client { if let Some(etcd_client) = &endpoint.component.drt.etcd_client {
if let Err(e) = etcd_client if let Err(e) = etcd_client
.kv_create( .kv_create(
endpoint.etcd_path_with_lease_id(lease_id), &endpoint.etcd_path_with_lease_id(lease_id),
info, info,
Some(lease_id), Some(lease_id),
) )
......
...@@ -170,20 +170,15 @@ impl Client { ...@@ -170,20 +170,15 @@ impl Client {
.await? .await?
} }
pub async fn kv_create( pub async fn kv_create(&self, key: &str, value: Vec<u8>, lease_id: Option<i64>) -> Result<()> {
&self,
key: String,
value: Vec<u8>,
lease_id: Option<i64>,
) -> Result<()> {
let id = lease_id.unwrap_or(self.lease_id()); let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id); let put_options = PutOptions::new().with_lease(id);
// Build the transaction // Build the transaction
let txn = Txn::new() let txn = Txn::new()
.when(vec![Compare::version(key.as_str(), CompareOp::Equal, 0)]) // Ensure the lock does not exist .when(vec![Compare::version(key, CompareOp::Equal, 0)]) // Ensure the lock does not exist
.and_then(vec![ .and_then(vec![
TxnOp::put(key.as_str(), value, Some(put_options)), // Create the object TxnOp::put(key, value, Some(put_options)), // Create the object
]); ]);
// Execute the transaction // Execute the transaction
......
...@@ -95,7 +95,7 @@ fn handle_watch_event<T: DeserializeOwned>( ...@@ -95,7 +95,7 @@ fn handle_watch_event<T: DeserializeOwned>(
/// Creates a key-value pair in etcd, returning a specific error if the key already exists /// Creates a key-value pair in etcd, returning a specific error if the key already exists
async fn create_barrier_key<T: Serialize>( async fn create_barrier_key<T: Serialize>(
client: &Client, client: &Client,
key: String, key: &str,
data: T, data: T,
lease_id: Option<i64>, lease_id: Option<i64>,
) -> Result<(), LeaderWorkerBarrierError> { ) -> Result<(), LeaderWorkerBarrierError> {
...@@ -193,7 +193,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali ...@@ -193,7 +193,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali
lease_id: i64, lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> { ) -> Result<(), LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_DATA); let key = barrier_key(&self.barrier_id, BARRIER_DATA);
create_barrier_key(client, key, data, Some(lease_id)).await create_barrier_key(client, &key, data, Some(lease_id)).await
} }
async fn wait_for_workers( async fn wait_for_workers(
...@@ -216,10 +216,10 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali ...@@ -216,10 +216,10 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali
let workers = worker_result.keys().collect::<HashSet<_>>(); let workers = worker_result.keys().collect::<HashSet<_>>();
create_barrier_key(client, key, workers, Some(lease_id)).await?; create_barrier_key(client, &key, workers, Some(lease_id)).await?;
} else { } else {
let key = barrier_key(&self.barrier_id, BARRIER_ABORT); let key = barrier_key(&self.barrier_id, BARRIER_ABORT);
create_barrier_key(client, key, (), Some(lease_id)).await?; create_barrier_key(client, &key, (), Some(lease_id)).await?;
} }
Ok(()) Ok(())
...@@ -302,7 +302,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali ...@@ -302,7 +302,7 @@ impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + Deseriali
&self.barrier_id, &self.barrier_id,
&format!("{}/{}", BARRIER_WORKER, self.worker_id), &format!("{}/{}", BARRIER_WORKER, self.worker_id),
); );
create_barrier_key(client, key.clone(), data, Some(lease_id)).await?; create_barrier_key(client, &key, data, Some(lease_id)).await?;
Ok(key) Ok(key)
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment