Unverified Commit aeb79e62 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Support multiple models on single ingress node (#1127)

We can now do this:

- Node 1:

```
dynamo-run in=http out=dyn
```

- Node 2 and 3, two instances of component 'backend' in the nemotron_ultra pipeline:

```
dynamo-run in=dyn://nemotron_ultra.backend.generate out=vllm /data/models/NemotronUltra
```

- Node 4 and 5, two instances of the 'backend' component in nemotron_super pipeline:

```
dynamo-run in=dyn://nemotron_super.backend.generate out=vllm /data/models/NemotronSuper
```

The ingress node will discover all four instances and route correctly. We have been planning for this for a long time now.

As part of this auto-discovery is now always `out=dyn`, with no extra URL parts. Previously it could only route to a single pipeline.

Also:
- Refactor endpoint / instance naming now that I understand them
- Fix removing models when their instance stops.
parent 74221fd7
...@@ -87,8 +87,8 @@ pub enum Output { ...@@ -87,8 +87,8 @@ pub enum Output {
/// Accept preprocessed requests, echo the tokens back as the response /// Accept preprocessed requests, echo the tokens back as the response
EchoCore, EchoCore,
/// Publish requests to a namespace/component/endpoint path. /// Listen for models on nats/etcd, add/remove dynamically
Endpoint(String), Dynamic,
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
/// Run inference on a model in a GGUF file using mistralrs w/ candle /// Run inference on a model in a GGUF file using mistralrs w/ candle
...@@ -130,9 +130,15 @@ impl TryFrom<&str> for Output { ...@@ -130,9 +130,15 @@ impl TryFrom<&str> for Output {
"echo_full" => Ok(Output::EchoFull), "echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore), "echo_core" => Ok(Output::EchoCore),
"dyn" => Ok(Output::Dynamic),
// Deprecated, should only use `out=dyn`
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => { endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap(); tracing::warn!(
Ok(Output::Endpoint(path.to_string())) "out=dyn://<path> is deprecated, the path is not used. Please use 'out=dyn'"
);
//let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Output::Dynamic)
} }
#[cfg(feature = "python")] #[cfg(feature = "python")]
...@@ -163,7 +169,7 @@ impl fmt::Display for Output { ...@@ -163,7 +169,7 @@ impl fmt::Display for Output {
Output::EchoFull => "echo_full", Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core", Output::EchoCore => "echo_core",
Output::Endpoint(path) => path, Output::Dynamic => "dyn",
#[cfg(feature = "python")] #[cfg(feature = "python")]
Output::PythonStr(_) => "pystr", Output::PythonStr(_) => "pystr",
......
...@@ -271,7 +271,7 @@ async fn add_model( ...@@ -271,7 +271,7 @@ async fn add_model(
let component = distributed.namespace(&namespace)?.component("http")?; let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!( let path = format!(
"{}/models/{}/{}", "{}/models/{}/{}",
component.etcd_path(), component.etcd_root(),
model_type.as_str(), model_type.as_str(),
model_name model_name
); );
...@@ -323,7 +323,7 @@ async fn list_single_model( ...@@ -323,7 +323,7 @@ async fn list_single_model(
let component = distributed.namespace(&namespace)?.component("http")?; let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!( let path = format!(
"{}/models/{}/{}", "{}/models/{}/{}",
component.etcd_path(), component.etcd_root(),
model_type.as_str(), model_type.as_str(),
model_name model_name
); );
...@@ -374,7 +374,7 @@ async fn list_models( ...@@ -374,7 +374,7 @@ async fn list_models(
// TODO: Do we need the model_type in etcd key? // TODO: Do we need the model_type in etcd key?
for mt in model_types { for mt in model_types {
let prefix = format!("{}/models/{}/", component.etcd_path(), mt.as_str(),); let prefix = format!("{}/models/{}/", component.etcd_root(), mt.as_str(),);
let etcd_client = distributed let etcd_client = distributed
.etcd_client() .etcd_client()
...@@ -430,7 +430,7 @@ async fn remove_model( ...@@ -430,7 +430,7 @@ async fn remove_model(
let component = distributed.namespace(&namespace)?.component("http")?; let component = distributed.namespace(&namespace)?.component("http")?;
let prefix = format!( let prefix = format!(
"{}/models/{}/{}", "{}/models/{}/{}",
component.etcd_path(), component.etcd_root(),
model_type.as_str(), model_type.as_str(),
name name
); );
......
...@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str): ...@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str):
client = await endpoint.client() client = await endpoint.client()
# wait for an endpoint to be ready # wait for an endpoint to be ready
await client.wait_for_endpoints() await client.wait_for_instances()
# issue request # issue request
stream = await client.generate("hello world") stream = await client.generate("hello world")
......
...@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str): ...@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str):
client = await endpoint.client() client = await endpoint.client()
# wait for an endpoint to be ready # wait for an endpoint to be ready
await client.wait_for_endpoints() await client.wait_for_instances()
# issue request # issue request
stream = await client.generate("hello world") stream = await client.generate("hello world")
......
...@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str): ...@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str):
client = await endpoint.client() client = await endpoint.client()
# wait for an endpoint to be ready # wait for an endpoint to be ready
await client.wait_for_endpoints() await client.wait_for_instances()
# issue request # issue request
stream = await client.generate("hello world") stream = await client.generate("hello world")
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# - nats-server -js # - nats-server -js
# #
# Window 1: `python server_sglang.py`. Wait for log "Starting endpoint". # Window 1: `python server_sglang.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate` # Window 2: `dynamo-run out=dyn
import argparse import argparse
import asyncio import asyncio
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# - nats-server -js # - nats-server -js
# #
# Window 1: `python server_sglang.py`. Wait for log "Starting endpoint". # Window 1: `python server_sglang.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate` # Window 2: `dynamo-run out=dyn
import argparse import argparse
import asyncio import asyncio
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
# - nats-server -js # - nats-server -js
# #
# Window 1: `python server_vllm.py`. Wait for log "Starting endpoint". # Window 1: `python server_vllm.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate` # Window 2: `dynamo-run out=dyn
import argparse import argparse
import asyncio import asyncio
......
...@@ -31,8 +31,8 @@ async def worker(runtime: DistributedRuntime): ...@@ -31,8 +31,8 @@ async def worker(runtime: DistributedRuntime):
# create client # create client
client = await endpoint.client() client = await endpoint.client()
# list the endpoints # list the endpoint instances
print(client.endpoint_ids()) print(client.instance_ids())
# issue request # issue request
stream = await client.generate(Request(data="hello world").model_dump_json()) stream = await client.generate(Request(data="hello world").model_dump_json())
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::StreamExt; use futures::StreamExt;
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
...@@ -589,16 +577,19 @@ impl EtcdClient { ...@@ -589,16 +577,19 @@ impl EtcdClient {
#[pymethods] #[pymethods]
impl Client { impl Client {
/// Get list of current endpoints /// Get list of current instances.
fn endpoint_ids(&self) -> Vec<i64> { /// Replaces endpoint_ids.
self.router.client.endpoint_ids() fn instance_ids(&self) -> Vec<i64> {
self.router.client.instance_ids()
} }
fn wait_for_endpoints<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> { /// Wait for an instance to be available for work.
/// Replaces wait_for_endpoints.
fn wait_for_instances<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.router.client.clone(); let inner = self.router.client.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner inner
.wait_for_endpoints() .wait_for_instances()
.await .await
.map(|v| v.into_iter().map(|cei| cei.id()).collect::<Vec<i64>>()) .map(|v| v.into_iter().map(|cei| cei.id()).collect::<Vec<i64>>())
.map_err(to_pyerr) .map_err(to_pyerr)
...@@ -669,12 +660,12 @@ impl Client { ...@@ -669,12 +660,12 @@ impl Client {
} }
/// Directly send a request to a specific endpoint. /// Directly send a request to a specific endpoint.
#[pyo3(signature = (request, endpoint_id, annotated=DEFAULT_ANNOTATED_SETTING))] #[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING))]
fn direct<'p>( fn direct<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
request: PyObject, request: PyObject,
endpoint_id: i64, instance_id: i64,
annotated: Option<bool>, annotated: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?; let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
...@@ -685,7 +676,7 @@ impl Client { ...@@ -685,7 +676,7 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client let stream = client
.direct(request.into(), endpoint_id) .direct(request.into(), instance_id)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -42,12 +42,16 @@ pub mod service_v2; ...@@ -42,12 +42,16 @@ pub mod service_v2;
pub use async_trait::async_trait; pub use async_trait::async_trait;
pub use axum; pub use axum;
use discovery::ModelEntry;
pub use error::ServiceHttpError; pub use error::ServiceHttpError;
pub use metrics::Metrics; pub use metrics::Metrics;
use crate::types::openai::{ use crate::{
chat_completions::OpenAIChatCompletionsStreamingEngine, kv_router::KvRouter,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, types::openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
}; };
use std::{ use std::{
collections::HashMap, collections::HashMap,
...@@ -208,6 +212,8 @@ pub struct DeploymentState { ...@@ -208,6 +212,8 @@ pub struct DeploymentState {
embeddings_engines: Arc<Mutex<ModelEngines<OpenAIEmbeddingsStreamingEngine>>>, embeddings_engines: Arc<Mutex<ModelEngines<OpenAIEmbeddingsStreamingEngine>>>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
sse_keep_alive: Option<Duration>, sse_keep_alive: Option<Duration>,
entries: Arc<Mutex<HashMap<String, ModelEntry>>>,
kv_choosers: Arc<Mutex<HashMap<String, Arc<KvRouter>>>>,
} }
impl DeploymentState { impl DeploymentState {
...@@ -218,6 +224,8 @@ impl DeploymentState { ...@@ -218,6 +224,8 @@ impl DeploymentState {
embeddings_engines: Arc::new(Mutex::new(ModelEngines::default())), embeddings_engines: Arc::new(Mutex::new(ModelEngines::default())),
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
sse_keep_alive: None, sse_keep_alive: None,
entries: Arc::new(Mutex::new(HashMap::new())),
kv_choosers: Arc::new(Mutex::new(HashMap::new())),
} }
} }
...@@ -235,7 +243,7 @@ impl DeploymentState { ...@@ -235,7 +243,7 @@ impl DeploymentState {
.ok_or(ServiceHttpError::ModelNotFound(model.to_string())) .ok_or(ServiceHttpError::ModelNotFound(model.to_string()))
} }
fn get_completions_engine( pub fn get_completions_engine(
&self, &self,
model: &str, model: &str,
) -> Result<OpenAICompletionsStreamingEngine, ServiceHttpError> { ) -> Result<OpenAICompletionsStreamingEngine, ServiceHttpError> {
...@@ -247,7 +255,7 @@ impl DeploymentState { ...@@ -247,7 +255,7 @@ impl DeploymentState {
.ok_or(ServiceHttpError::ModelNotFound(model.to_string())) .ok_or(ServiceHttpError::ModelNotFound(model.to_string()))
} }
fn get_chat_completions_engine( pub fn get_chat_completions_engine(
&self, &self,
model: &str, model: &str,
) -> Result<OpenAIChatCompletionsStreamingEngine, ServiceHttpError> { ) -> Result<OpenAIChatCompletionsStreamingEngine, ServiceHttpError> {
......
...@@ -5,17 +5,16 @@ use std::sync::Arc; ...@@ -5,17 +5,16 @@ use std::sync::Arc;
use anyhow::Context as _; use anyhow::Context as _;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver; use tokio::sync::{mpsc::Receiver, Notify};
use dynamo_runtime::{ use dynamo_runtime::{
component::{self, Component, ComponentEndpointInfo}, component::{self, Component, Instance},
pipeline::{ pipeline::{
network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource, network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
ServiceBackend, SingleIn, Source, ServiceBackend, SingleIn, Source,
}, },
protocols::{self, annotated::Annotated}, protocols::{self, annotated::Annotated},
slug::Slug, slug::Slug,
traits::DistributedRuntimeProvider as _,
transports::etcd::{self, KeyValue, WatchEvent}, transports::etcd::{self, KeyValue, WatchEvent},
DistributedRuntime, DistributedRuntime,
}; };
...@@ -66,15 +65,13 @@ impl ModelEntry { ...@@ -66,15 +65,13 @@ impl ModelEntry {
/// This does not touch it's fields so you may need to call move_from_nats on it. /// This does not touch it's fields so you may need to call move_from_nats on it.
pub async fn load_mdc( pub async fn load_mdc(
&self, &self,
endpoint_id: protocols::Endpoint,
etcd_client: &etcd::Client, etcd_client: &etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> { ) -> anyhow::Result<ModelDeploymentCard> {
let kvstore: Box<dyn KeyValueStore> = let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let card_key = ModelDeploymentCard::service_name_slug(&self.name); let card_key = ModelDeploymentCard::service_name_slug(&self.name);
match card_store match card_store
.load::<ModelDeploymentCard>(model_card::BUCKET_NAME, &card_key) .load::<ModelDeploymentCard>(model_card::ROOT_PATH, &card_key)
.await .await
{ {
Ok(Some(mdc)) => Ok(mdc), Ok(Some(mdc)) => Ok(mdc),
...@@ -99,9 +96,9 @@ impl ModelNetworkName { ...@@ -99,9 +96,9 @@ impl ModelNetworkName {
/// It looks like this: /// It looks like this:
/// ns.cp.ep-694d967ca5efd804 /// ns.cp.ep-694d967ca5efd804
fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self { fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
ModelNetworkName( let model_root = component::MODEL_ROOT_PATH;
Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}")).to_string(), let slug = Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}"));
) ModelNetworkName(format!("{model_root}/{slug}"))
} }
// We can't do From<&component::Endpoint> here because we also need the lease_id // We can't do From<&component::Endpoint> here because we also need the lease_id
...@@ -134,17 +131,21 @@ impl ModelNetworkName { ...@@ -134,17 +131,21 @@ impl ModelNetworkName {
/// TODO We have potentially two for each endpoint, one Chat and one Completion. /// TODO We have potentially two for each endpoint, one Chat and one Completion.
pub async fn load_mdc( pub async fn load_mdc(
&self, &self,
endpoint_id: protocols::Endpoint,
etcd_client: &etcd::Client, etcd_client: &etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> { ) -> anyhow::Result<ModelDeploymentCard> {
let entry = self.load_entry(etcd_client).await?; let entry = self.load_entry(etcd_client).await?;
entry.load_mdc(endpoint_id, etcd_client).await entry.load_mdc(etcd_client).await
} }
} }
impl From<&ComponentEndpointInfo> for ModelNetworkName { impl From<&Instance> for ModelNetworkName {
fn from(cei: &ComponentEndpointInfo) -> Self { fn from(cei: &Instance) -> Self {
Self::from_parts(&cei.namespace, &cei.component, &cei.endpoint, cei.lease_id) Self::from_parts(
&cei.namespace,
&cei.component,
&cei.endpoint,
cei.instance_id,
)
} }
} }
...@@ -155,41 +156,37 @@ impl std::fmt::Display for ModelNetworkName { ...@@ -155,41 +156,37 @@ impl std::fmt::Display for ModelNetworkName {
} }
pub struct ModelWatcher { pub struct ModelWatcher {
prefix: String,
manager: ModelManager, manager: ModelManager,
drt: DistributedRuntime, drt: DistributedRuntime,
router_mode: RouterMode, router_mode: RouterMode,
kv_chooser: Option<Arc<KvRouter>>, notify_on_model: Notify,
} }
impl ModelWatcher { impl ModelWatcher {
pub async fn new( pub async fn new(
component: Component, runtime: DistributedRuntime,
model_manager: ModelManager, model_manager: ModelManager,
network_prefix: &str,
router_mode: RouterMode, router_mode: RouterMode,
) -> anyhow::Result<ModelWatcher> { ) -> anyhow::Result<ModelWatcher> {
let kv_chooser = if router_mode.is_kv_routing() {
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
component.clone(),
crate::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
.await?;
Some(Arc::new(chooser))
} else {
None
};
Ok(Self { Ok(Self {
prefix: network_prefix.to_string(),
manager: model_manager, manager: model_manager,
drt: component.drt().clone(), drt: runtime,
router_mode, router_mode,
kv_chooser, notify_on_model: Notify::new(),
}) })
} }
/// Wait until we have at least one chat completions model and return it's name.
pub async fn wait_for_chat_model(&self) -> String {
// Loop in case it gets added and immediately deleted
loop {
if let Some(model_name) = self.manager.list_chat_completions_models().first() {
return model_name.to_owned();
}
self.notify_on_model.notified().await
}
}
pub async fn watch(self: Arc<Self>, mut events_rx: Receiver<WatchEvent>) { pub async fn watch(self: Arc<Self>, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started"); tracing::debug!("model watcher started");
...@@ -199,7 +196,14 @@ impl ModelWatcher { ...@@ -199,7 +196,14 @@ impl ModelWatcher {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) { let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry, Ok(model_entry) => model_entry,
Err(err) => { Err(err) => {
tracing::error!(%err, ?kv, "Invalid JSON in model entry"); match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model entry")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
}
}
continue; continue;
} }
}; };
...@@ -208,12 +212,14 @@ impl ModelWatcher { ...@@ -208,12 +212,14 @@ impl ModelWatcher {
service_name = model_entry.name, service_name = model_entry.name,
"New endpoint for existing model" "New endpoint for existing model"
); );
self.notify_on_model.notify_waiters();
continue; continue;
} }
match self.clone().handle_put(&model_entry).await { match self.clone().handle_put(&kv, &model_entry).await {
Ok(()) => { Ok(()) => {
tracing::info!(model_name = model_entry.name, "added model"); tracing::info!(model_name = model_entry.name, "added model");
self.notify_on_model.notify_waiters();
} }
Err(e) => { Err(e) => {
tracing::error!(%e, "error adding model {}", model_entry.name); tracing::error!(%e, "error adding model {}", model_entry.name);
...@@ -232,39 +238,49 @@ impl ModelWatcher { ...@@ -232,39 +238,49 @@ impl ModelWatcher {
} }
} }
async fn handle_delete(self: Arc<Self>, kv: &KeyValue) -> anyhow::Result<&str> { /// Returns the name of the model we just deleted
async fn handle_delete(self: Arc<ModelWatcher>, kv: &KeyValue) -> anyhow::Result<String> {
let key = kv.key_str()?; let key = kv.key_str()?;
tracing::debug!(key, "removing model"); let model_entry = match self.manager.state.entries.lock().unwrap().remove(key) {
Some(entry) => entry,
let model_name = key.trim_start_matches(&self.prefix); None => {
anyhow::bail!("Missing ModelEntry for {key}");
}
};
let model_name = &model_entry.name;
tracing::debug!(model_name, "removing model");
// Ignore the errors because model could be either type // Ignore the errors because model could be either type
let _ = self.manager.remove_chat_completions_model(model_name); let _ = self.manager.remove_chat_completions_model(model_name);
let _ = self.manager.remove_completions_model(model_name); let _ = self.manager.remove_completions_model(model_name);
let _ = self.manager.remove_embeddings_model(model_name); let _ = self.manager.remove_embeddings_model(model_name);
Ok(model_name) // We own model_entry now so take ownership of the name
Ok(model_entry.name)
} }
// Handles a PUT event from etcd, this usually means adding a new model to the list of served // Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models. // models.
// //
// If this method errors, for the near term, we will delete the offending key. // If this method errors, for the near term, we will delete the offending key.
async fn handle_put(self: Arc<ModelWatcher>, model_entry: &ModelEntry) -> anyhow::Result<()> { async fn handle_put(
self: Arc<ModelWatcher>,
kv: &KeyValue,
model_entry: &ModelEntry,
) -> anyhow::Result<()> {
let key = kv.key_str()?;
let endpoint_id = model_entry.endpoint.clone(); let endpoint_id = model_entry.endpoint.clone();
let client = self let component = self
.drt .drt
.namespace(&endpoint_id.namespace)? .namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)? .component(&endpoint_id.component)?;
.endpoint(&endpoint_id.name) let client = component.endpoint(&endpoint_id.name).client().await?;
.client()
.await?;
let Some(etcd_client) = self.drt.etcd_client() else { let Some(etcd_client) = self.drt.etcd_client() else {
// Should be impossible because we only get here on an etcd event // Should be impossible because we only get here on an etcd event
anyhow::bail!("Missing etcd_client"); anyhow::bail!("Missing etcd_client");
}; };
let card = match model_entry.load_mdc(endpoint_id, &etcd_client).await { let card = match model_entry.load_mdc(&etcd_client).await {
Ok(card) => { Ok(card) => {
tracing::debug!(card.display_name, "adding model"); tracing::debug!(card.display_name, "adding model");
Some(card) Some(card)
...@@ -275,6 +291,14 @@ impl ModelWatcher { ...@@ -275,6 +291,14 @@ impl ModelWatcher {
None None
} }
}; };
// We need to save the entry to know what the model is called when we delete it
self.manager
.state
.entries
.lock()
.unwrap()
.insert(key.to_string(), model_entry.clone());
match model_entry.model_type { match model_entry.model_type {
ModelType::Backend => { ModelType::Backend => {
// A Backend model expects pre-processed requests meaning it's up to us whether we // A Backend model expects pre-processed requests meaning it's up to us whether we
...@@ -305,10 +329,8 @@ impl ModelWatcher { ...@@ -305,10 +329,8 @@ impl ModelWatcher {
ServiceBackend::from_engine(Arc::new(router)) ServiceBackend::from_engine(Arc::new(router))
} }
RouterMode::KV => { RouterMode::KV => {
let Some(kv_chooser) = self.kv_chooser.clone() else { let chooser = self.kv_chooser_for(&model_entry.name, &component).await?;
anyhow::bail!("KV routing mode with no chooser, should be unreachable"); let kv_push_router = KvPushRouter::new(router, chooser);
};
let kv_push_router = KvPushRouter::new(router, kv_chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router)) ServiceBackend::from_engine(Arc::new(kv_push_router))
} }
}; };
...@@ -339,10 +361,8 @@ impl ModelWatcher { ...@@ -339,10 +361,8 @@ impl ModelWatcher {
ServiceBackend::from_engine(Arc::new(router)) ServiceBackend::from_engine(Arc::new(router))
} }
RouterMode::KV => { RouterMode::KV => {
let Some(kv_chooser) = self.kv_chooser.clone() else { let chooser = self.kv_chooser_for(&model_entry.name, &component).await?;
anyhow::bail!("KV routing mode with no chooser, should be unreachable"); let kv_push_router = KvPushRouter::new(router, chooser);
};
let kv_push_router = KvPushRouter::new(router, kv_chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router)) ServiceBackend::from_engine(Arc::new(kv_push_router))
} }
}; };
...@@ -392,4 +412,37 @@ impl ModelWatcher { ...@@ -392,4 +412,37 @@ impl ModelWatcher {
Ok(()) Ok(())
} }
async fn kv_chooser_for(
&self,
model_name: &str,
component: &Component,
) -> anyhow::Result<Arc<KvRouter>> {
if let Some(kv_chooser) = self
.manager
.state
.kv_choosers
.lock()
.unwrap()
.get(model_name)
{
// Return early to avoid holding the lock during the await later
return Ok(Arc::clone(kv_chooser));
}
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
component.clone(),
crate::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
.await?;
let new_kv_chooser = Arc::new(chooser);
self.manager
.state
.kv_choosers
.lock()
.unwrap()
.insert(model_name.to_string(), new_kv_chooser.clone());
Ok(new_kv_chooser)
}
} }
...@@ -19,7 +19,7 @@ use std::time::Duration; ...@@ -19,7 +19,7 @@ use std::time::Duration;
use async_stream::stream; use async_stream::stream;
use async_trait::async_trait; use async_trait::async_trait;
use dynamo_runtime::{protocols::Endpoint, slug::Slug, transports::etcd::Client}; use dynamo_runtime::{slug::Slug, transports::etcd::Client};
use etcd_client::{EventType, PutOptions, WatchOptions}; use etcd_client::{EventType, PutOptions, WatchOptions};
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome}; use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
...@@ -27,12 +27,11 @@ use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome}; ...@@ -27,12 +27,11 @@ use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
#[derive(Clone)] #[derive(Clone)]
pub struct EtcdStorage { pub struct EtcdStorage {
client: Client, client: Client,
endpoint: Endpoint,
} }
impl EtcdStorage { impl EtcdStorage {
pub fn new(client: Client, endpoint: Endpoint) -> Self { pub fn new(client: Client) -> Self {
Self { client, endpoint } Self { client }
} }
} }
...@@ -55,7 +54,6 @@ impl KeyValueStore for EtcdStorage { ...@@ -55,7 +54,6 @@ impl KeyValueStore for EtcdStorage {
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> { ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
Ok(Some(Box::new(EtcdBucket { Ok(Some(Box::new(EtcdBucket {
client: self.client.clone(), client: self.client.clone(),
endpoint: self.endpoint.clone(),
bucket_name: bucket_name.to_string(), bucket_name: bucket_name.to_string(),
}))) })))
} }
...@@ -63,7 +61,6 @@ impl KeyValueStore for EtcdStorage { ...@@ -63,7 +61,6 @@ impl KeyValueStore for EtcdStorage {
pub struct EtcdBucket { pub struct EtcdBucket {
client: Client, client: Client,
endpoint: Endpoint,
bucket_name: String, bucket_name: String,
} }
...@@ -85,7 +82,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -85,7 +82,7 @@ impl KeyValueBucket for EtcdBucket {
} }
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> { async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
let k = make_key(&self.endpoint, &self.bucket_name, key); let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd get: {k}"); tracing::trace!("etcd get: {k}");
let mut kvs = self let mut kvs = self
...@@ -113,7 +110,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -113,7 +110,7 @@ impl KeyValueBucket for EtcdBucket {
&self, &self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError> ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
{ {
let k = make_key(&self.endpoint, &self.bucket_name, ""); let k = make_key(&self.bucket_name, "");
tracing::trace!("etcd watch: {k}"); tracing::trace!("etcd watch: {k}");
let (_watcher, mut watch_stream) = self let (_watcher, mut watch_stream) = self
.client .client
...@@ -136,7 +133,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -136,7 +133,7 @@ impl KeyValueBucket for EtcdBucket {
} }
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> { async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
let k = make_key(&self.endpoint, &self.bucket_name, ""); let k = make_key(&self.bucket_name, "");
tracing::trace!("etcd entries: {k}"); tracing::trace!("etcd entries: {k}");
let resp = self let resp = self
...@@ -158,7 +155,7 @@ impl KeyValueBucket for EtcdBucket { ...@@ -158,7 +155,7 @@ impl KeyValueBucket for EtcdBucket {
impl EtcdBucket { impl EtcdBucket {
async fn create(&self, key: &str, value: &str) -> Result<StorageOutcome, StorageError> { async fn create(&self, key: &str, value: &str) -> Result<StorageOutcome, StorageError> {
let k = make_key(&self.endpoint, &self.bucket_name, key); let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd create: {k}"); tracing::trace!("etcd create: {k}");
// Does it already exists? For 'create' it shouldn't. // Does it already exists? For 'create' it shouldn't.
...@@ -195,7 +192,7 @@ impl EtcdBucket { ...@@ -195,7 +192,7 @@ impl EtcdBucket {
revision: u64, revision: u64,
) -> Result<StorageOutcome, StorageError> { ) -> Result<StorageOutcome, StorageError> {
let version = revision; let version = revision;
let k = make_key(&self.endpoint, &self.bucket_name, key); let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd update: {k}"); tracing::trace!("etcd update: {k}");
let kvs = self let kvs = self
...@@ -237,9 +234,8 @@ impl EtcdBucket { ...@@ -237,9 +234,8 @@ impl EtcdBucket {
} }
} }
fn make_key(endpoint: &Endpoint, bucket_name: &str, key: &str) -> String { fn make_key(bucket_name: &str, key: &str) -> String {
[ [
endpoint.namespace.to_string(),
Slug::slugify(bucket_name).to_string(), Slug::slugify(bucket_name).to_string(),
Slug::slugify(key).to_string(), Slug::slugify(key).to_string(),
] ]
......
...@@ -17,7 +17,7 @@ use std::sync::Arc; ...@@ -17,7 +17,7 @@ use std::sync::Arc;
use anyhow::Result; use anyhow::Result;
use dynamo_runtime::{ use dynamo_runtime::{
component::{Component, EndpointSource}, component::{Component, InstanceSource},
pipeline::{ pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter, async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
ResponseStream, SingleIn, ResponseStream, SingleIn,
...@@ -199,11 +199,11 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er ...@@ -199,11 +199,11 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er
&self, &self,
request: SingleIn<BackendInput>, request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> { ) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
match &self.inner.client.endpoints { match &self.inner.client.instances {
EndpointSource::Static => self.inner.r#static(request).await, InstanceSource::Static => self.inner.r#static(request).await,
EndpointSource::Dynamic(_) => { InstanceSource::Dynamic(_) => {
let worker_id = self.chooser.find_best_match(&request.token_ids).await?; let instance_id = self.chooser.find_best_match(&request.token_ids).await?;
self.inner.direct(request, worker_id).await self.inner.direct(request, instance_id).await
} }
} }
} }
......
...@@ -134,13 +134,11 @@ impl LocalModel { ...@@ -134,13 +134,11 @@ impl LocalModel {
self.card.move_to_nats(nats_client.clone()).await?; self.card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to etcd // Publish the Model Deployment Card to etcd
let endpoint_id = endpoint.id(); let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let kvstore: Box<dyn KeyValueStore> =
Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string(); let key = self.card.slug().to_string();
card_store card_store
.publish(model_card::BUCKET_NAME, None, &key, &mut self.card) .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?; .await?;
// Publish our ModelEntry to etcd. This allows ingress to find the model card. // Publish our ModelEntry to etcd. This allows ingress to find the model card.
...@@ -149,7 +147,7 @@ impl LocalModel { ...@@ -149,7 +147,7 @@ impl LocalModel {
tracing::debug!("Registering with etcd as {network_name}"); tracing::debug!("Registering with etcd as {network_name}");
let model_registration = ModelEntry { let model_registration = ModelEntry {
name: self.service_name().to_string(), name: self.service_name().to_string(),
endpoint: endpoint_id.clone(), endpoint: endpoint.id(),
model_type, model_type,
}; };
etcd_client etcd_client
...@@ -172,7 +170,7 @@ impl LocalModel { ...@@ -172,7 +170,7 @@ impl LocalModel {
// A static component is necessarily unique, it cannot register // A static component is necessarily unique, it cannot register
return Ok(()); return Ok(());
}; };
for endpoint_info in component.list_endpoints().await? { for endpoint_info in component.list_instances().await? {
let network_name: ModelNetworkName = (&endpoint_info).into(); let network_name: ModelNetworkName = (&endpoint_info).into();
let entry = network_name.load_entry(&etcd_client).await?; let entry = network_name.load_entry(&etcd_client).await?;
if entry.name != model_name { if entry.name != model_name {
......
...@@ -23,7 +23,7 @@ pub use model::ModelDeploymentCard; ...@@ -23,7 +23,7 @@ pub use model::ModelDeploymentCard;
// network module? // network module?
/// Identify model deployment cards in the key-value store /// Identify model deployment cards in the key-value store
pub const BUCKET_NAME: &str = "mdc"; pub const ROOT_PATH: &str = "mdc";
/// Delete model deployment cards that haven't been re-published after this long. /// Delete model deployment cards that haven't been re-published after this long.
/// Cleans up if the worker stopped. /// Cleans up if the worker stopped.
......
...@@ -44,8 +44,6 @@ use crate::gguf::{Content, ContentConfig, ModelConfigLike}; ...@@ -44,8 +44,6 @@ use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::key_value_store::Versioned; use crate::key_value_store::Versioned;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
pub const BUCKET_NAME: &str = "mdc";
/// Delete model deployment cards that haven't been re-published after this long. /// Delete model deployment cards that haven't been re-published after this long.
/// Cleans up if the worker stopped. /// Cleans up if the worker stopped.
pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60); pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60);
......
...@@ -34,7 +34,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -34,7 +34,7 @@ async fn app(runtime: Runtime) -> Result<()> {
.endpoint("generate") .endpoint("generate")
.client() .client()
.await?; .await?;
client.wait_for_endpoints().await?; client.wait_for_instances().await?;
let router = let router =
PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?; PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?;
......
...@@ -35,7 +35,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -35,7 +35,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let client = component.endpoint("generate").client().await?; let client = component.endpoint("generate").client().await?;
client.wait_for_endpoints().await?; client.wait_for_instances().await?;
let router = let router =
PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?; PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The [Component] module defines the top-level API for building distributed applications. //! The [Component] module defines the top-level API for building distributed applications.
//! //!
...@@ -69,7 +57,14 @@ mod namespace; ...@@ -69,7 +57,14 @@ mod namespace;
mod registry; mod registry;
pub mod service; pub mod service;
pub use client::{Client, EndpointSource}; pub use client::{Client, InstanceSource};
/// The root etcd path where each instance registers itself in etcd.
/// An instance is namespace+component+endpoint+lease_id and must be unique.
pub const INSTANCE_ROOT_PATH: &str = "instances";
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
...@@ -89,17 +84,17 @@ pub struct Registry { ...@@ -89,17 +84,17 @@ pub struct Registry {
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentEndpointInfo { pub struct Instance {
pub component: String, pub component: String,
pub endpoint: String, pub endpoint: String,
pub namespace: String, pub namespace: String,
pub lease_id: i64, pub instance_id: i64,
pub transport: TransportType, pub transport: TransportType,
} }
impl ComponentEndpointInfo { impl Instance {
pub fn id(&self) -> i64 { pub fn id(&self) -> i64 {
self.lease_id self.instance_id
} }
} }
...@@ -150,8 +145,11 @@ impl RuntimeProvider for Component { ...@@ -150,8 +145,11 @@ impl RuntimeProvider for Component {
} }
impl Component { impl Component {
pub fn etcd_path(&self) -> String { /// The component part of an instance path in etcd.
format!("{}/components/{}", self.namespace.name(), self.name) pub fn etcd_root(&self) -> String {
let ns = self.namespace.name();
let cp = &self.name;
format!("{INSTANCE_ROOT_PATH}/{ns}/{cp}")
} }
pub fn service_name(&self) -> String { pub fn service_name(&self) -> String {
...@@ -179,21 +177,21 @@ impl Component { ...@@ -179,21 +177,21 @@ impl Component {
} }
} }
pub async fn list_endpoints(&self) -> anyhow::Result<Vec<ComponentEndpointInfo>> { pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> {
let Some(etcd_client) = self.drt.etcd_client() else { let Some(etcd_client) = self.drt.etcd_client() else {
return Ok(vec![]); return Ok(vec![]);
}; };
let mut out = vec![]; let mut out = vec![];
// The extra slash is important to only list exact component matches, not substrings. // The extra slash is important to only list exact component matches, not substrings.
for kv in etcd_client for kv in etcd_client
.kv_get_prefix(format!("{}/", self.etcd_path())) .kv_get_prefix(format!("{}/", self.etcd_root()))
.await? .await?
{ {
let val = match serde_json::from_slice::<ComponentEndpointInfo>(kv.value()) { let val = match serde_json::from_slice::<Instance>(kv.value()) {
Ok(val) => val, Ok(val) => val,
Err(err) => { Err(err) => {
anyhow::bail!( anyhow::bail!(
"Error converting etcd response to ComponentEndpointInfo: {err}. {}", "Error converting etcd response to Instance: {err}. {}",
kv.value_str()? kv.value_str()?
); );
} }
...@@ -276,15 +274,20 @@ impl Endpoint { ...@@ -276,15 +274,20 @@ impl Endpoint {
format!("{}/{}", self.component.path(), self.name) format!("{}/{}", self.component.path(), self.name)
} }
pub fn etcd_path(&self) -> String { /// The endpoint part of an instance path in etcd
format!("{}/{}", self.component.etcd_path(), self.name) pub fn etcd_root(&self) -> String {
let component_path = self.component.etcd_root();
let endpoint_name = &self.name;
format!("{component_path}/{endpoint_name}")
} }
pub fn etcd_path_with_id(&self, lease_id: i64) -> String { /// The fully path of an instance in etcd
pub fn etcd_path(&self, lease_id: i64) -> String {
let endpoint_root = self.etcd_root();
if self.is_static { if self.is_static {
self.etcd_path() endpoint_root
} else { } else {
format!("{}:{:x}", self.etcd_path(), lease_id) format!("{endpoint_root}:{lease_id:x}")
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment