Unverified Commit aeb79e62 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Support multiple models on single ingress node (#1127)

We can now do this:

- Node 1:

```
dynamo-run in=http out=dyn
```

- Node 2 and 3, two instances of component 'backend' in the nemotron_ultra pipeline:

```
dynamo-run in=dyn://nemotron_ultra.backend.generate out=vllm /data/models/NemotronUltra
```

- Node 4 and 5, two instances of the 'backend' component in nemotron_super pipeline:

```
dynamo-run in=dyn://nemotron_super.backend.generate out=vllm /data/models/NemotronSuper
```

The ingress node will discover all four instances and route correctly. We have been planning for this for a long time now.

As part of this auto-discovery is now always `out=dyn`, with no extra URL parts. Previously it could only route to a single pipeline.

Also:
- Refactor endpoint / instance naming now that I understand them
- Fix removing models when their instance stops.
parent 74221fd7
......@@ -87,8 +87,8 @@ pub enum Output {
/// Accept preprocessed requests, echo the tokens back as the response
EchoCore,
/// Publish requests to a namespace/component/endpoint path.
Endpoint(String),
/// Listen for models on nats/etcd, add/remove dynamically
Dynamic,
#[cfg(feature = "mistralrs")]
/// Run inference on a model in a GGUF file using mistralrs w/ candle
......@@ -130,9 +130,15 @@ impl TryFrom<&str> for Output {
"echo_full" => Ok(Output::EchoFull),
"echo_core" => Ok(Output::EchoCore),
"dyn" => Ok(Output::Dynamic),
// Deprecated, should only use `out=dyn`
endpoint_path if endpoint_path.starts_with(ENDPOINT_SCHEME) => {
let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Output::Endpoint(path.to_string()))
tracing::warn!(
"out=dyn://<path> is deprecated, the path is not used. Please use 'out=dyn'"
);
//let path = endpoint_path.strip_prefix(ENDPOINT_SCHEME).unwrap();
Ok(Output::Dynamic)
}
#[cfg(feature = "python")]
......@@ -163,7 +169,7 @@ impl fmt::Display for Output {
Output::EchoFull => "echo_full",
Output::EchoCore => "echo_core",
Output::Endpoint(path) => path,
Output::Dynamic => "dyn",
#[cfg(feature = "python")]
Output::PythonStr(_) => "pystr",
......
......@@ -271,7 +271,7 @@ async fn add_model(
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_path(),
component.etcd_root(),
model_type.as_str(),
model_name
);
......@@ -323,7 +323,7 @@ async fn list_single_model(
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_path(),
component.etcd_root(),
model_type.as_str(),
model_name
);
......@@ -374,7 +374,7 @@ async fn list_models(
// TODO: Do we need the model_type in etcd key?
for mt in model_types {
let prefix = format!("{}/models/{}/", component.etcd_path(), mt.as_str(),);
let prefix = format!("{}/models/{}/", component.etcd_root(), mt.as_str(),);
let etcd_client = distributed
.etcd_client()
......@@ -430,7 +430,7 @@ async fn remove_model(
let component = distributed.namespace(&namespace)?.component("http")?;
let prefix = format!(
"{}/models/{}/{}",
component.etcd_path(),
component.etcd_root(),
model_type.as_str(),
name
);
......
......@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str):
client = await endpoint.client()
# wait for an endpoint to be ready
await client.wait_for_endpoints()
await client.wait_for_instances()
# issue request
stream = await client.generate("hello world")
......
......@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str):
client = await endpoint.client()
# wait for an endpoint to be ready
await client.wait_for_endpoints()
await client.wait_for_instances()
# issue request
stream = await client.generate("hello world")
......
......@@ -36,7 +36,7 @@ async def init(runtime: DistributedRuntime, ns: str):
client = await endpoint.client()
# wait for an endpoint to be ready
await client.wait_for_endpoints()
await client.wait_for_instances()
# issue request
stream = await client.generate("hello world")
......
......@@ -17,7 +17,7 @@
# - nats-server -js
#
# Window 1: `python server_sglang.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate`
# Window 2: `dynamo-run out=dyn
import argparse
import asyncio
......
......@@ -18,7 +18,7 @@
# - nats-server -js
#
# Window 1: `python server_sglang.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate`
# Window 2: `dynamo-run out=dyn
import argparse
import asyncio
......
......@@ -25,7 +25,7 @@
# - nats-server -js
#
# Window 1: `python server_vllm.py`. Wait for log "Starting endpoint".
# Window 2: `dynamo-run out=dyn://dynamo.backend.generate`
# Window 2: `dynamo-run out=dyn
import argparse
import asyncio
......
......@@ -31,8 +31,8 @@ async def worker(runtime: DistributedRuntime):
# create client
client = await endpoint.client()
# list the endpoints
print(client.endpoint_ids())
# list the endpoint instances
print(client.instance_ids())
# issue request
stream = await client.generate(Request(data="hello world").model_dump_json())
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use futures::StreamExt;
use once_cell::sync::OnceCell;
......@@ -589,16 +577,19 @@ impl EtcdClient {
#[pymethods]
impl Client {
/// Get list of current endpoints
fn endpoint_ids(&self) -> Vec<i64> {
self.router.client.endpoint_ids()
/// Get list of current instances.
/// Replaces endpoint_ids.
fn instance_ids(&self) -> Vec<i64> {
self.router.client.instance_ids()
}
fn wait_for_endpoints<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
/// Wait for an instance to be available for work.
/// Replaces wait_for_endpoints.
fn wait_for_instances<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.router.client.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.wait_for_endpoints()
.wait_for_instances()
.await
.map(|v| v.into_iter().map(|cei| cei.id()).collect::<Vec<i64>>())
.map_err(to_pyerr)
......@@ -669,12 +660,12 @@ impl Client {
}
/// Directly send a request to a specific endpoint.
#[pyo3(signature = (request, endpoint_id, annotated=DEFAULT_ANNOTATED_SETTING))]
#[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING))]
fn direct<'p>(
&self,
py: Python<'p>,
request: PyObject,
endpoint_id: i64,
instance_id: i64,
annotated: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> {
let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?;
......@@ -685,7 +676,7 @@ impl Client {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client
.direct(request.into(), endpoint_id)
.direct(request.into(), instance_id)
.await
.map_err(to_pyerr)?;
......
......@@ -42,12 +42,16 @@ pub mod service_v2;
pub use async_trait::async_trait;
pub use axum;
use discovery::ModelEntry;
pub use error::ServiceHttpError;
pub use metrics::Metrics;
use crate::types::openai::{
use crate::{
kv_router::KvRouter,
types::openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
};
use std::{
collections::HashMap,
......@@ -208,6 +212,8 @@ pub struct DeploymentState {
embeddings_engines: Arc<Mutex<ModelEngines<OpenAIEmbeddingsStreamingEngine>>>,
metrics: Arc<Metrics>,
sse_keep_alive: Option<Duration>,
entries: Arc<Mutex<HashMap<String, ModelEntry>>>,
kv_choosers: Arc<Mutex<HashMap<String, Arc<KvRouter>>>>,
}
impl DeploymentState {
......@@ -218,6 +224,8 @@ impl DeploymentState {
embeddings_engines: Arc::new(Mutex::new(ModelEngines::default())),
metrics: Arc::new(Metrics::default()),
sse_keep_alive: None,
entries: Arc::new(Mutex::new(HashMap::new())),
kv_choosers: Arc::new(Mutex::new(HashMap::new())),
}
}
......@@ -235,7 +243,7 @@ impl DeploymentState {
.ok_or(ServiceHttpError::ModelNotFound(model.to_string()))
}
fn get_completions_engine(
pub fn get_completions_engine(
&self,
model: &str,
) -> Result<OpenAICompletionsStreamingEngine, ServiceHttpError> {
......@@ -247,7 +255,7 @@ impl DeploymentState {
.ok_or(ServiceHttpError::ModelNotFound(model.to_string()))
}
fn get_chat_completions_engine(
pub fn get_chat_completions_engine(
&self,
model: &str,
) -> Result<OpenAIChatCompletionsStreamingEngine, ServiceHttpError> {
......
......@@ -5,17 +5,16 @@ use std::sync::Arc;
use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver;
use tokio::sync::{mpsc::Receiver, Notify};
use dynamo_runtime::{
component::{self, Component, ComponentEndpointInfo},
component::{self, Component, Instance},
pipeline::{
network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
ServiceBackend, SingleIn, Source,
},
protocols::{self, annotated::Annotated},
slug::Slug,
traits::DistributedRuntimeProvider as _,
transports::etcd::{self, KeyValue, WatchEvent},
DistributedRuntime,
};
......@@ -66,15 +65,13 @@ impl ModelEntry {
/// This does not touch it's fields so you may need to call move_from_nats on it.
pub async fn load_mdc(
&self,
endpoint_id: protocols::Endpoint,
etcd_client: &etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> {
let kvstore: Box<dyn KeyValueStore> =
Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id));
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let card_key = ModelDeploymentCard::service_name_slug(&self.name);
match card_store
.load::<ModelDeploymentCard>(model_card::BUCKET_NAME, &card_key)
.load::<ModelDeploymentCard>(model_card::ROOT_PATH, &card_key)
.await
{
Ok(Some(mdc)) => Ok(mdc),
......@@ -99,9 +96,9 @@ impl ModelNetworkName {
/// It looks like this:
/// ns.cp.ep-694d967ca5efd804
fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
ModelNetworkName(
Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}")).to_string(),
)
let model_root = component::MODEL_ROOT_PATH;
let slug = Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}"));
ModelNetworkName(format!("{model_root}/{slug}"))
}
// We can't do From<&component::Endpoint> here because we also need the lease_id
......@@ -134,17 +131,21 @@ impl ModelNetworkName {
/// TODO We have potentially two for each endpoint, one Chat and one Completion.
pub async fn load_mdc(
&self,
endpoint_id: protocols::Endpoint,
etcd_client: &etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> {
let entry = self.load_entry(etcd_client).await?;
entry.load_mdc(endpoint_id, etcd_client).await
entry.load_mdc(etcd_client).await
}
}
impl From<&ComponentEndpointInfo> for ModelNetworkName {
fn from(cei: &ComponentEndpointInfo) -> Self {
Self::from_parts(&cei.namespace, &cei.component, &cei.endpoint, cei.lease_id)
impl From<&Instance> for ModelNetworkName {
fn from(cei: &Instance) -> Self {
Self::from_parts(
&cei.namespace,
&cei.component,
&cei.endpoint,
cei.instance_id,
)
}
}
......@@ -155,41 +156,37 @@ impl std::fmt::Display for ModelNetworkName {
}
pub struct ModelWatcher {
prefix: String,
manager: ModelManager,
drt: DistributedRuntime,
router_mode: RouterMode,
kv_chooser: Option<Arc<KvRouter>>,
notify_on_model: Notify,
}
impl ModelWatcher {
pub async fn new(
component: Component,
runtime: DistributedRuntime,
model_manager: ModelManager,
network_prefix: &str,
router_mode: RouterMode,
) -> anyhow::Result<ModelWatcher> {
let kv_chooser = if router_mode.is_kv_routing() {
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
component.clone(),
crate::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
.await?;
Some(Arc::new(chooser))
} else {
None
};
Ok(Self {
prefix: network_prefix.to_string(),
manager: model_manager,
drt: component.drt().clone(),
drt: runtime,
router_mode,
kv_chooser,
notify_on_model: Notify::new(),
})
}
/// Wait until we have at least one chat completions model and return it's name.
pub async fn wait_for_chat_model(&self) -> String {
// Loop in case it gets added and immediately deleted
loop {
if let Some(model_name) = self.manager.list_chat_completions_models().first() {
return model_name.to_owned();
}
self.notify_on_model.notified().await
}
}
pub async fn watch(self: Arc<Self>, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started");
......@@ -199,7 +196,14 @@ impl ModelWatcher {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid JSON in model entry");
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model entry")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
}
}
continue;
}
};
......@@ -208,12 +212,14 @@ impl ModelWatcher {
service_name = model_entry.name,
"New endpoint for existing model"
);
self.notify_on_model.notify_waiters();
continue;
}
match self.clone().handle_put(&model_entry).await {
match self.clone().handle_put(&kv, &model_entry).await {
Ok(()) => {
tracing::info!(model_name = model_entry.name, "added model");
self.notify_on_model.notify_waiters();
}
Err(e) => {
tracing::error!(%e, "error adding model {}", model_entry.name);
......@@ -232,39 +238,49 @@ impl ModelWatcher {
}
}
async fn handle_delete(self: Arc<Self>, kv: &KeyValue) -> anyhow::Result<&str> {
/// Returns the name of the model we just deleted
async fn handle_delete(self: Arc<ModelWatcher>, kv: &KeyValue) -> anyhow::Result<String> {
let key = kv.key_str()?;
tracing::debug!(key, "removing model");
let model_name = key.trim_start_matches(&self.prefix);
let model_entry = match self.manager.state.entries.lock().unwrap().remove(key) {
Some(entry) => entry,
None => {
anyhow::bail!("Missing ModelEntry for {key}");
}
};
let model_name = &model_entry.name;
tracing::debug!(model_name, "removing model");
// Ignore the errors because model could be either type
let _ = self.manager.remove_chat_completions_model(model_name);
let _ = self.manager.remove_completions_model(model_name);
let _ = self.manager.remove_embeddings_model(model_name);
Ok(model_name)
// We own model_entry now so take ownership of the name
Ok(model_entry.name)
}
// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models.
//
// If this method errors, for the near term, we will delete the offending key.
async fn handle_put(self: Arc<ModelWatcher>, model_entry: &ModelEntry) -> anyhow::Result<()> {
async fn handle_put(
self: Arc<ModelWatcher>,
kv: &KeyValue,
model_entry: &ModelEntry,
) -> anyhow::Result<()> {
let key = kv.key_str()?;
let endpoint_id = model_entry.endpoint.clone();
let client = self
let component = self
.drt
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?
.endpoint(&endpoint_id.name)
.client()
.await?;
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let Some(etcd_client) = self.drt.etcd_client() else {
// Should be impossible because we only get here on an etcd event
anyhow::bail!("Missing etcd_client");
};
let card = match model_entry.load_mdc(endpoint_id, &etcd_client).await {
let card = match model_entry.load_mdc(&etcd_client).await {
Ok(card) => {
tracing::debug!(card.display_name, "adding model");
Some(card)
......@@ -275,6 +291,14 @@ impl ModelWatcher {
None
}
};
// We need to save the entry to know what the model is called when we delete it
self.manager
.state
.entries
.lock()
.unwrap()
.insert(key.to_string(), model_entry.clone());
match model_entry.model_type {
ModelType::Backend => {
// A Backend model expects pre-processed requests meaning it's up to us whether we
......@@ -305,10 +329,8 @@ impl ModelWatcher {
ServiceBackend::from_engine(Arc::new(router))
}
RouterMode::KV => {
let Some(kv_chooser) = self.kv_chooser.clone() else {
anyhow::bail!("KV routing mode with no chooser, should be unreachable");
};
let kv_push_router = KvPushRouter::new(router, kv_chooser);
let chooser = self.kv_chooser_for(&model_entry.name, &component).await?;
let kv_push_router = KvPushRouter::new(router, chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router))
}
};
......@@ -339,10 +361,8 @@ impl ModelWatcher {
ServiceBackend::from_engine(Arc::new(router))
}
RouterMode::KV => {
let Some(kv_chooser) = self.kv_chooser.clone() else {
anyhow::bail!("KV routing mode with no chooser, should be unreachable");
};
let kv_push_router = KvPushRouter::new(router, kv_chooser);
let chooser = self.kv_chooser_for(&model_entry.name, &component).await?;
let kv_push_router = KvPushRouter::new(router, chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router))
}
};
......@@ -392,4 +412,37 @@ impl ModelWatcher {
Ok(())
}
async fn kv_chooser_for(
&self,
model_name: &str,
component: &Component,
) -> anyhow::Result<Arc<KvRouter>> {
if let Some(kv_chooser) = self
.manager
.state
.kv_choosers
.lock()
.unwrap()
.get(model_name)
{
// Return early to avoid holding the lock during the await later
return Ok(Arc::clone(kv_chooser));
}
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
component.clone(),
crate::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
.await?;
let new_kv_chooser = Arc::new(chooser);
self.manager
.state
.kv_choosers
.lock()
.unwrap()
.insert(model_name.to_string(), new_kv_chooser.clone());
Ok(new_kv_chooser)
}
}
......@@ -19,7 +19,7 @@ use std::time::Duration;
use async_stream::stream;
use async_trait::async_trait;
use dynamo_runtime::{protocols::Endpoint, slug::Slug, transports::etcd::Client};
use dynamo_runtime::{slug::Slug, transports::etcd::Client};
use etcd_client::{EventType, PutOptions, WatchOptions};
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
......@@ -27,12 +27,11 @@ use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
#[derive(Clone)]
pub struct EtcdStorage {
client: Client,
endpoint: Endpoint,
}
impl EtcdStorage {
pub fn new(client: Client, endpoint: Endpoint) -> Self {
Self { client, endpoint }
pub fn new(client: Client) -> Self {
Self { client }
}
}
......@@ -55,7 +54,6 @@ impl KeyValueStore for EtcdStorage {
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
Ok(Some(Box::new(EtcdBucket {
client: self.client.clone(),
endpoint: self.endpoint.clone(),
bucket_name: bucket_name.to_string(),
})))
}
......@@ -63,7 +61,6 @@ impl KeyValueStore for EtcdStorage {
pub struct EtcdBucket {
client: Client,
endpoint: Endpoint,
bucket_name: String,
}
......@@ -85,7 +82,7 @@ impl KeyValueBucket for EtcdBucket {
}
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
let k = make_key(&self.endpoint, &self.bucket_name, key);
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd get: {k}");
let mut kvs = self
......@@ -113,7 +110,7 @@ impl KeyValueBucket for EtcdBucket {
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
{
let k = make_key(&self.endpoint, &self.bucket_name, "");
let k = make_key(&self.bucket_name, "");
tracing::trace!("etcd watch: {k}");
let (_watcher, mut watch_stream) = self
.client
......@@ -136,7 +133,7 @@ impl KeyValueBucket for EtcdBucket {
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
let k = make_key(&self.endpoint, &self.bucket_name, "");
let k = make_key(&self.bucket_name, "");
tracing::trace!("etcd entries: {k}");
let resp = self
......@@ -158,7 +155,7 @@ impl KeyValueBucket for EtcdBucket {
impl EtcdBucket {
async fn create(&self, key: &str, value: &str) -> Result<StorageOutcome, StorageError> {
let k = make_key(&self.endpoint, &self.bucket_name, key);
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd create: {k}");
// Does it already exists? For 'create' it shouldn't.
......@@ -195,7 +192,7 @@ impl EtcdBucket {
revision: u64,
) -> Result<StorageOutcome, StorageError> {
let version = revision;
let k = make_key(&self.endpoint, &self.bucket_name, key);
let k = make_key(&self.bucket_name, key);
tracing::trace!("etcd update: {k}");
let kvs = self
......@@ -237,9 +234,8 @@ impl EtcdBucket {
}
}
fn make_key(endpoint: &Endpoint, bucket_name: &str, key: &str) -> String {
fn make_key(bucket_name: &str, key: &str) -> String {
[
endpoint.namespace.to_string(),
Slug::slugify(bucket_name).to_string(),
Slug::slugify(key).to_string(),
]
......
......@@ -17,7 +17,7 @@ use std::sync::Arc;
use anyhow::Result;
use dynamo_runtime::{
component::{Component, EndpointSource},
component::{Component, InstanceSource},
pipeline::{
async_trait, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, PushRouter,
ResponseStream, SingleIn,
......@@ -199,11 +199,11 @@ impl AsyncEngine<SingleIn<BackendInput>, ManyOut<Annotated<LLMEngineOutput>>, Er
&self,
request: SingleIn<BackendInput>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>, Error> {
match &self.inner.client.endpoints {
EndpointSource::Static => self.inner.r#static(request).await,
EndpointSource::Dynamic(_) => {
let worker_id = self.chooser.find_best_match(&request.token_ids).await?;
self.inner.direct(request, worker_id).await
match &self.inner.client.instances {
InstanceSource::Static => self.inner.r#static(request).await,
InstanceSource::Dynamic(_) => {
let instance_id = self.chooser.find_best_match(&request.token_ids).await?;
self.inner.direct(request, instance_id).await
}
}
}
......
......@@ -134,13 +134,11 @@ impl LocalModel {
self.card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to etcd
let endpoint_id = endpoint.id();
let kvstore: Box<dyn KeyValueStore> =
Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id.clone()));
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string();
card_store
.publish(model_card::BUCKET_NAME, None, &key, &mut self.card)
.publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?;
// Publish our ModelEntry to etcd. This allows ingress to find the model card.
......@@ -149,7 +147,7 @@ impl LocalModel {
tracing::debug!("Registering with etcd as {network_name}");
let model_registration = ModelEntry {
name: self.service_name().to_string(),
endpoint: endpoint_id.clone(),
endpoint: endpoint.id(),
model_type,
};
etcd_client
......@@ -172,7 +170,7 @@ impl LocalModel {
// A static component is necessarily unique, it cannot register
return Ok(());
};
for endpoint_info in component.list_endpoints().await? {
for endpoint_info in component.list_instances().await? {
let network_name: ModelNetworkName = (&endpoint_info).into();
let entry = network_name.load_entry(&etcd_client).await?;
if entry.name != model_name {
......
......@@ -23,7 +23,7 @@ pub use model::ModelDeploymentCard;
// network module?
/// Identify model deployment cards in the key-value store
pub const BUCKET_NAME: &str = "mdc";
pub const ROOT_PATH: &str = "mdc";
/// Delete model deployment cards that haven't been re-published after this long.
/// Cleans up if the worker stopped.
......
......@@ -44,8 +44,6 @@ use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::key_value_store::Versioned;
use crate::protocols::TokenIdType;
pub const BUCKET_NAME: &str = "mdc";
/// Delete model deployment cards that haven't been re-published after this long.
/// Cleans up if the worker stopped.
pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60);
......
......@@ -34,7 +34,7 @@ async fn app(runtime: Runtime) -> Result<()> {
.endpoint("generate")
.client()
.await?;
client.wait_for_endpoints().await?;
client.wait_for_instances().await?;
let router =
PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?;
......
......@@ -35,7 +35,7 @@ async fn app(runtime: Runtime) -> Result<()> {
let client = component.endpoint("generate").client().await?;
client.wait_for_endpoints().await?;
client.wait_for_instances().await?;
let router =
PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?;
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! The [Component] module defines the top-level API for building distributed applications.
//!
......@@ -69,7 +57,14 @@ mod namespace;
mod registry;
pub mod service;
pub use client::{Client, EndpointSource};
pub use client::{Client, InstanceSource};
/// The root etcd path where each instance registers itself in etcd.
/// An instance is namespace+component+endpoint+lease_id and must be unique.
pub const INSTANCE_ROOT_PATH: &str = "instances";
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
......@@ -89,17 +84,17 @@ pub struct Registry {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentEndpointInfo {
pub struct Instance {
pub component: String,
pub endpoint: String,
pub namespace: String,
pub lease_id: i64,
pub instance_id: i64,
pub transport: TransportType,
}
impl ComponentEndpointInfo {
impl Instance {
pub fn id(&self) -> i64 {
self.lease_id
self.instance_id
}
}
......@@ -150,8 +145,11 @@ impl RuntimeProvider for Component {
}
impl Component {
pub fn etcd_path(&self) -> String {
format!("{}/components/{}", self.namespace.name(), self.name)
/// The component part of an instance path in etcd.
pub fn etcd_root(&self) -> String {
let ns = self.namespace.name();
let cp = &self.name;
format!("{INSTANCE_ROOT_PATH}/{ns}/{cp}")
}
pub fn service_name(&self) -> String {
......@@ -179,21 +177,21 @@ impl Component {
}
}
pub async fn list_endpoints(&self) -> anyhow::Result<Vec<ComponentEndpointInfo>> {
pub async fn list_instances(&self) -> anyhow::Result<Vec<Instance>> {
let Some(etcd_client) = self.drt.etcd_client() else {
return Ok(vec![]);
};
let mut out = vec![];
// The extra slash is important to only list exact component matches, not substrings.
for kv in etcd_client
.kv_get_prefix(format!("{}/", self.etcd_path()))
.kv_get_prefix(format!("{}/", self.etcd_root()))
.await?
{
let val = match serde_json::from_slice::<ComponentEndpointInfo>(kv.value()) {
let val = match serde_json::from_slice::<Instance>(kv.value()) {
Ok(val) => val,
Err(err) => {
anyhow::bail!(
"Error converting etcd response to ComponentEndpointInfo: {err}. {}",
"Error converting etcd response to Instance: {err}. {}",
kv.value_str()?
);
}
......@@ -276,15 +274,20 @@ impl Endpoint {
format!("{}/{}", self.component.path(), self.name)
}
pub fn etcd_path(&self) -> String {
format!("{}/{}", self.component.etcd_path(), self.name)
/// The endpoint part of an instance path in etcd
pub fn etcd_root(&self) -> String {
let component_path = self.component.etcd_root();
let endpoint_name = &self.name;
format!("{component_path}/{endpoint_name}")
}
pub fn etcd_path_with_id(&self, lease_id: i64) -> String {
/// The fully path of an instance in etcd
pub fn etcd_path(&self, lease_id: i64) -> String {
let endpoint_root = self.etcd_root();
if self.is_static {
self.etcd_path()
endpoint_root
} else {
format!("{}:{:x}", self.etcd_path(), lease_id)
format!("{endpoint_root}:{lease_id:x}")
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment