Unverified Commit a1a10365 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Split PushRouter from Client (#817)

In a distributed system we don't know if the remote workers need pre-processing done ingress-side or not. Previously Client required us to decide this before discovering the remote endpoints, which was fine because pre-processing was worker-side.

As part of moving pre-processing back to ingress-side we need to split this into two steps:
- Client discovers the endpoints, and (later PR) will fetch their Model Deployment Card.
- PushRouter will use the Model Deployment Card to decide if they need pre-processing or not, which affects the types of the generic parameters.

Part of #743
parent 97bf8184
...@@ -1619,6 +1619,7 @@ dependencies = [ ...@@ -1619,6 +1619,7 @@ dependencies = [
"dynamo-runtime", "dynamo-runtime",
"either", "either",
"erased-serde", "erased-serde",
"etcd-client",
"futures", "futures",
"galil-seiferas", "galil-seiferas",
"ggus", "ggus",
......
...@@ -55,6 +55,7 @@ chrono = { version = "0.4", default-features = false, features = ["alloc", "std" ...@@ -55,6 +55,7 @@ chrono = { version = "0.4", default-features = false, features = ["alloc", "std"
derive_builder = { version = "0.20" } derive_builder = { version = "0.20" }
derive-getters = { version = "0.5" } derive-getters = { version = "0.5" }
either = { version = "1.13", features = ["serde"] } either = { version = "1.13", features = ["serde"] }
etcd-client = { version = "0.14" }
futures = { version = "0.3" } futures = { version = "0.3" }
hf-hub = { version = "0.4.2", default-features = false, features = ["tokio", "rustls-tls"] } hf-hub = { version = "0.4.2", default-features = false, features = ["tokio", "rustls-tls"] }
humantime = { version = "2.2.0" } humantime = { version = "2.2.0" }
......
...@@ -127,11 +127,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -127,11 +127,7 @@ async fn app(runtime: Runtime) -> Result<()> {
tracing::debug!("Creating unique instance of Count at {key}"); tracing::debug!("Creating unique instance of Count at {key}");
drt.etcd_client() drt.etcd_client()
.expect("Unreachable because of DistributedRuntime::from_settings above") .expect("Unreachable because of DistributedRuntime::from_settings above")
.kv_create( .kv_create(key, serde_json::to_vec_pretty(&config)?, None)
key,
serde_json::to_vec_pretty(&config)?,
Some(drt.primary_lease().unwrap().id()),
)
.await .await
.context("Unable to create unique instance of Count; possibly one already exists")?; .context("Unable to create unique instance of Count; possibly one already exists")?;
......
...@@ -18,7 +18,7 @@ use std::path::PathBuf; ...@@ -18,7 +18,7 @@ use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use clap::ValueEnum; use clap::ValueEnum;
use dynamo_runtime::component::RouterMode as RuntimeRouterMode; use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode;
/// Required options depend on the in and out choices /// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)] #[derive(clap::Parser, Debug, Clone)]
......
...@@ -15,12 +15,11 @@ ...@@ -15,12 +15,11 @@
use std::pin::Pin; use std::pin::Pin;
use crate::{flags::RouterMode, EngineConfig, Flags};
use dynamo_llm::{ use dynamo_llm::{
backend::Backend, backend::{Backend, ExecutionContext},
backend::ExecutionContext,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
model_card::model::ModelDeploymentCard, http::service::discovery::ModelNetworkName,
model_card::ModelDeploymentCard,
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
protocols::common::llm_backend::{BackendInput, BackendOutput}, protocols::common::llm_backend::{BackendInput, BackendOutput},
types::{ types::{
...@@ -33,11 +32,15 @@ use dynamo_llm::{ ...@@ -33,11 +32,15 @@ use dynamo_llm::{
}; };
use dynamo_runtime::{ use dynamo_runtime::{
engine::{AsyncEngineStream, Data}, engine::{AsyncEngineStream, Data},
pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source}, pipeline::{
Context, ManyOut, Operator, PushRouter, ServiceBackend, ServiceFrontend, SingleIn, Source,
},
DistributedRuntime, Runtime, DistributedRuntime, Runtime,
}; };
use std::sync::Arc; use std::sync::Arc;
use crate::{flags::RouterMode, EngineConfig, Flags};
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine. /// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
pub async fn prepare_engine( pub async fn prepare_engine(
runtime: Runtime, runtime: Runtime,
...@@ -53,22 +56,40 @@ pub async fn prepare_engine( ...@@ -53,22 +56,40 @@ pub async fn prepare_engine(
.component(endpoint_id.component.clone())? .component(endpoint_id.component.clone())?
.endpoint(endpoint_id.name.clone()); .endpoint(endpoint_id.name.clone());
let mut client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?; let client = endpoint.client().await?;
let router = match &flags.router_mode {
match &flags.router_mode {
RouterMode::Random | RouterMode::RoundRobin => { RouterMode::Random | RouterMode::RoundRobin => {
client.set_router_mode(flags.router_mode.into());
tracing::info!("Waiting for remote model.."); tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered"); // We then use the ModelDeploymentCard's `requires_preprocessing`
// field to decide what kind of PushRouter to make.
let remote_endpoints = client.wait_for_endpoints().await?;
debug_assert!(!remote_endpoints.is_empty());
tracing::info!(count = remote_endpoints.len(), "Model(s) discovered");
let network_name: ModelNetworkName = (&remote_endpoints[0]).into();
let Some(etcd_client) = distributed_runtime.etcd_client() else {
anyhow::bail!("Cannot run distributed components without etcd");
};
let mdc = network_name.load_mdc(endpoint_id, etcd_client).await?;
if mdc.requires_preprocessing {
// Note requires_preprocessing is never true in our code right now
todo!("Ingress-side pre-processing not supported yet");
} else {
PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, flags.router_mode.into())
.await?
}
} }
RouterMode::KV => todo!(), RouterMode::KV => todo!(),
} };
// The service_name isn't used for text chat outside of logs, // The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration. // so use the path. That avoids having to listen on etcd for model registration.
let service_name = endpoint.subject(); let service_name = endpoint.subject();
Ok((service_name, Arc::new(client), false)) Ok((service_name, Arc::new(router), false))
} }
EngineConfig::StaticFull { EngineConfig::StaticFull {
service_name, service_name,
......
...@@ -18,9 +18,9 @@ use std::sync::Arc; ...@@ -18,9 +18,9 @@ use std::sync::Arc;
use dynamo_llm::{ use dynamo_llm::{
backend::Backend, backend::Backend,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
http::service::discovery::ModelEntry, http::service::discovery::{ModelEntry, ModelNetworkName},
key_value_store::{KeyValueStore, KeyValueStoreManager, NATSStorage}, key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
model_card::{BUCKET_NAME, BUCKET_TTL}, model_card,
model_type::ModelType, model_type::ModelType,
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
types::{ types::{
...@@ -49,14 +49,14 @@ pub async fn run( ...@@ -49,14 +49,14 @@ pub async fn run(
let etcd_client = distributed_runtime.etcd_client(); let etcd_client = distributed_runtime.etcd_client();
let (ingress, service_name, mut card) = match engine_config { let (ingress, service_name, mut card, requires_preprocessing) = match engine_config {
EngineConfig::StaticFull { EngineConfig::StaticFull {
service_name, service_name,
engine, engine,
card, card,
} => { } => {
let engine = Arc::new(StreamingEngineAdapter::new(engine)); let engine = Arc::new(StreamingEngineAdapter::new(engine));
(Ingress::for_engine(engine)?, service_name, card) (Ingress::for_engine(engine)?, service_name, card, false)
} }
EngineConfig::StaticCore { EngineConfig::StaticCore {
service_name, service_name,
...@@ -81,7 +81,8 @@ pub async fn run( ...@@ -81,7 +81,8 @@ pub async fn run(
.link(preprocessor.backward_edge())? .link(preprocessor.backward_edge())?
.link(frontend)?; .link(frontend)?;
(Ingress::for_pipeline(pipeline)?, service_name, card) // TODO: switch last 'false' to 'true' once we have ingress-side pre-processing
(Ingress::for_pipeline(pipeline)?, service_name, card, false)
} }
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic(_) => {
anyhow::bail!("Cannot use endpoint for both in and out"); anyhow::bail!("Cannot use endpoint for both in and out");
...@@ -104,30 +105,30 @@ pub async fn run( ...@@ -104,30 +105,30 @@ pub async fn run(
.await? .await?
.endpoint(&endpoint_id.name); .endpoint(&endpoint_id.name);
let nats_client = distributed_runtime.nats_client();
card.move_to_nats(nats_client.clone()).await?;
let kvstore: Box<dyn KeyValueStore> =
Box::new(NATSStorage::new(nats_client.clone(), endpoint_id));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
card.requires_preprocessing = false;
card_store.publish_until_cancelled(
cancel_token.clone(),
BUCKET_NAME.to_string(),
Some(BUCKET_TTL),
BUCKET_TTL / 2,
card.slug().to_string(),
*card.clone(),
);
if let Some(etcd_client) = etcd_client { if let Some(etcd_client) = etcd_client {
let network_name = endpoint.subject_to(etcd_client.lease_id()); // Store model config files in NATS object store
let nats_client = distributed_runtime.nats_client();
card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to etcd
let kvstore: Box<dyn KeyValueStore> =
Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
card.requires_preprocessing = requires_preprocessing; // Not used yet. Soon.
let key = card.slug().to_string();
card_store
.publish(model_card::BUCKET_NAME, None, &key, &mut *card.clone())
.await?;
// Publish our ModelEntry to etcd. This allows ingress to find the model card.
// (Why don't we put the model card directly under this key?)
let network_name = ModelNetworkName::from_local(&endpoint, etcd_client.lease_id());
tracing::debug!("Registering with etcd as {network_name}"); tracing::debug!("Registering with etcd as {network_name}");
etcd_client etcd_client
.kv_create( .kv_create(
network_name.clone(), network_name.to_string(),
serde_json::to_vec_pretty(&model_registration)?, serde_json::to_vec_pretty(&model_registration)?,
Some(etcd_client.lease_id()), None, // use primary lease
) )
.await?; .await?;
} }
...@@ -140,8 +141,12 @@ pub async fn run( ...@@ -140,8 +141,12 @@ pub async fn run(
_ = cancel_token.cancelled() => { _ = cancel_token.cancelled() => {
} }
} }
// Cleanup on shutdown // Cleanup on shutdown
if let Err(err) = card.delete_from_nats(nats_client).await { if let Err(err) = card
.delete_from_nats(distributed_runtime.nats_client())
.await
{
tracing::error!(%err, "delete_from_nats error on shutdown"); tracing::error!(%err, "delete_from_nats error on shutdown");
} }
Ok(()) Ok(())
......
...@@ -180,9 +180,19 @@ pub async fn run( ...@@ -180,9 +180,19 @@ pub async fn run(
} }
}; };
// If we are in a distributed system, we need to know our component upfront
let dyn_input = match &in_opt { let dyn_input = match &in_opt {
Input::Endpoint(endpoint_path) => { Input::Endpoint(endpoint_path) => {
if model_path.as_ref().map(|mp| mp.is_file()).unwrap_or(false)
&& flags.model_config.is_none()
{
// TODO We need to convert tokenizer extract from GGUF file into something we can
// publish to NATS. Ideally `tokenizer.json` directly, but otherwise an
// intermediate format.
tracing::error!("Serving GGUF files in a distributed system requires `--model-config <hf-repo-dir>` so that we can find the tokenzier config");
return Ok(());
}
// If we are in a distributed system, we need to know our component upfront
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint_id: Endpoint = endpoint_path.parse()?; let endpoint_id: Endpoint = endpoint_path.parse()?;
Some(DynInput { Some(DynInput {
...@@ -216,7 +226,10 @@ pub async fn run( ...@@ -216,7 +226,10 @@ pub async fn run(
"out=echo_core need to find the tokenizer. Pass flag --model-path <path>" "out=echo_core need to find the tokenizer. Pass flag --model-path <path>"
); );
}; };
card.requires_preprocessing = true;
// TODO: Switch to `true` once pre-processing moves ingress side
card.requires_preprocessing = false;
EngineConfig::StaticCore { EngineConfig::StaticCore {
service_name: card.service_name.clone(), service_name: card.service_name.clone(),
engine: dynamo_llm::engines::make_engine_core(), engine: dynamo_llm::engines::make_engine_core(),
......
...@@ -1049,6 +1049,7 @@ dependencies = [ ...@@ -1049,6 +1049,7 @@ dependencies = [
"dynamo-runtime", "dynamo-runtime",
"either", "either",
"erased-serde", "erased-serde",
"etcd-client",
"futures", "futures",
"galil-seiferas", "galil-seiferas",
"ggus", "ggus",
......
...@@ -146,7 +146,7 @@ struct Endpoint { ...@@ -146,7 +146,7 @@ struct Endpoint {
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
struct Client { struct Client {
inner: rs::component::Client<serde_json::Value, serde_json::Value>, router: rs::pipeline::PushRouter<serde_json::Value, serde_json::Value>,
} }
#[pyclass] #[pyclass]
...@@ -445,11 +445,17 @@ impl Endpoint { ...@@ -445,11 +445,17 @@ impl Endpoint {
fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> { fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone(); let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let client = inner let client = inner.client().await.map_err(to_pyerr)?;
.client::<serde_json::Value, serde_json::Value>() let push_router =
rs::pipeline::PushRouter::<serde_json::Value, serde_json::Value>::from_client(
client,
Default::default(),
)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(Client { inner: client }) Ok(Client {
router: push_router,
})
}) })
} }
...@@ -552,13 +558,17 @@ impl EtcdClient { ...@@ -552,13 +558,17 @@ impl EtcdClient {
impl Client { impl Client {
/// Get list of current endpoints /// Get list of current endpoints
fn endpoint_ids(&self) -> Vec<i64> { fn endpoint_ids(&self) -> Vec<i64> {
self.inner.endpoint_ids() self.router.client.endpoint_ids()
} }
fn wait_for_endpoints<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> { fn wait_for_endpoints<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone(); let inner = self.router.client.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.wait_for_endpoints().await.map_err(to_pyerr) inner
.wait_for_endpoints()
.await
.map(|v| v.into_iter().map(|cei| cei.id()).collect::<Vec<i64>>())
.map_err(to_pyerr)
}) })
} }
...@@ -570,7 +580,7 @@ impl Client { ...@@ -570,7 +580,7 @@ impl Client {
request: PyObject, request: PyObject,
annotated: Option<bool>, annotated: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
if self.inner.is_static() { if self.router.client.is_static() {
self.r#static(py, request, annotated) self.r#static(py, request, annotated)
} else { } else {
self.random(py, request, annotated) self.random(py, request, annotated)
...@@ -589,7 +599,7 @@ impl Client { ...@@ -589,7 +599,7 @@ impl Client {
let annotated = annotated.unwrap_or(false); let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32); let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone(); let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.round_robin(request.into()).await.map_err(to_pyerr)?; let stream = client.round_robin(request.into()).await.map_err(to_pyerr)?;
...@@ -613,7 +623,7 @@ impl Client { ...@@ -613,7 +623,7 @@ impl Client {
let annotated = annotated.unwrap_or(false); let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32); let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone(); let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.random(request.into()).await.map_err(to_pyerr)?; let stream = client.random(request.into()).await.map_err(to_pyerr)?;
...@@ -638,7 +648,7 @@ impl Client { ...@@ -638,7 +648,7 @@ impl Client {
let annotated = annotated.unwrap_or(false); let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32); let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone(); let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client let stream = client
...@@ -667,7 +677,7 @@ impl Client { ...@@ -667,7 +677,7 @@ impl Client {
let annotated = annotated.unwrap_or(false); let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32); let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone(); let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.r#static(request.into()).await.map_err(to_pyerr)?; let stream = client.r#static(request.into()).await.map_err(to_pyerr)?;
......
...@@ -27,9 +27,9 @@ use llm_rs::{ ...@@ -27,9 +27,9 @@ use llm_rs::{
}, },
}; };
use dynamo_runtime::pipeline::{Operator, ServiceFrontend, Source}; use dynamo_runtime::pipeline::{
ManyOut, Operator, PushRouter, SegmentSink, ServiceFrontend, SingleIn, Source,
use dynamo_runtime::pipeline::{ManyOut, SegmentSink, SingleIn}; };
#[pyclass] #[pyclass]
pub(crate) struct OAIChatPreprocessor { pub(crate) struct OAIChatPreprocessor {
...@@ -76,13 +76,14 @@ impl OAIChatPreprocessor { ...@@ -76,13 +76,14 @@ impl OAIChatPreprocessor {
let builder = self.current.inner.endpoint_builder().handler(ingress); let builder = self.current.inner.endpoint_builder().handler(ingress);
let endpoint = Arc::new(self.next.inner.clone()); let endpoint = Arc::new(self.next.inner.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let client = Arc::new( let client = endpoint.client().await.map_err(to_pyerr)?;
endpoint let router = PushRouter::<BackendInput, Annotated<BackendOutput>>::from_client(
.client::<BackendInput, Annotated<BackendOutput>>() client,
.await Default::default(),
.map_err(to_pyerr)?, )
); .await
network.attach(client).map_err(to_pyerr)?; .map_err(to_pyerr)?;
network.attach(Arc::new(router)).map_err(to_pyerr)?;
builder.start().await.map_err(to_pyerr)?; builder.start().await.map_err(to_pyerr)?;
Ok(()) Ok(())
}) })
......
...@@ -44,6 +44,7 @@ bytes = { workspace = true } ...@@ -44,6 +44,7 @@ bytes = { workspace = true }
chrono = { workspace = true } chrono = { workspace = true }
derive_builder = {workspace = true } derive_builder = {workspace = true }
either = { workspace = true } either = { workspace = true }
etcd-client = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
prometheus = { workspace = true } prometheus = { workspace = true }
......
...@@ -15,13 +15,17 @@ ...@@ -15,13 +15,17 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::Context as _;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver; use tokio::sync::mpsc::Receiver;
use dynamo_runtime::{ use dynamo_runtime::{
component::{self, ComponentEndpointInfo},
pipeline::network::egress::push_router::PushRouter,
protocols::{self, annotated::Annotated}, protocols::{self, annotated::Annotated},
raise, raise,
transports::etcd::{KeyValue, WatchEvent}, slug::Slug,
transports::etcd::{self, KeyValue, WatchEvent},
DistributedRuntime, DistributedRuntime,
}; };
...@@ -31,7 +35,12 @@ use crate::protocols::openai::chat_completions::{ ...@@ -31,7 +35,12 @@ use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}; };
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse}; use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use crate::{
key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
model_card::{self, ModelDeploymentCard},
};
use tracing; use tracing;
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models /// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster. /// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
...@@ -48,6 +57,90 @@ pub struct ModelEntry { ...@@ -48,6 +57,90 @@ pub struct ModelEntry {
pub model_type: ModelType, pub model_type: ModelType,
} }
impl ModelEntry {
pub async fn load_mdc(
&self,
endpoint_id: protocols::Endpoint,
etcd_client: etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> {
let kvstore: Box<dyn KeyValueStore> =
Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let card_key = ModelDeploymentCard::service_name_slug(&self.name);
match card_store
.load::<ModelDeploymentCard>(model_card::BUCKET_NAME, &card_key)
.await
{
Ok(Some(mdc)) => Ok(mdc),
Ok(None) => {
anyhow::bail!("Missing ModelDeploymentCard in etcd under key {card_key}");
}
Err(err) => {
anyhow::bail!(
"Error fetching ModelDeploymentCard from etcd under key {card_key}. {err}"
);
}
}
}
}
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
/// Key to store this model entry in networked key-value store (etcd).
///
/// It looks like this:
/// ns.cp.ep-694d967ca5efd804
fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
ModelNetworkName(
Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}")).to_string(),
)
}
// We can't do From<&component::Endpoint> here because we also need the lease_id
pub fn from_local(endpoint: &component::Endpoint, lease_id: i64) -> Self {
Self::from_parts(
&endpoint.component().namespace().to_string(),
&endpoint.component().name(),
endpoint.name(),
lease_id,
)
}
pub async fn load_mdc(
&self,
endpoint_id: protocols::Endpoint,
etcd_client: etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> {
let network_name = self;
let model_entries = etcd_client.kv_get(network_name.to_string(), None).await?;
if model_entries.is_empty() {
anyhow::bail!("No ModelEntry in etcd for key {network_name}");
}
let entry: ModelEntry =
serde_json::from_slice(model_entries[0].value()).with_context(|| {
format!(
"Error deserializing JSON. Key={network_name}. JSON={}",
model_entries[0].value_str().unwrap_or("INVALID UTF-8")
)
})?;
entry.load_mdc(endpoint_id, etcd_client).await
}
}
impl From<&ComponentEndpointInfo> for ModelNetworkName {
fn from(cei: &ComponentEndpointInfo) -> Self {
Self::from_parts(&cei.namespace, &cei.component, &cei.endpoint, cei.lease_id)
}
}
impl std::fmt::Display for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
pub struct ModelWatchState { pub struct ModelWatchState {
pub prefix: String, pub prefix: String,
pub model_type: ModelType, pub model_type: ModelType,
...@@ -142,16 +235,33 @@ async fn handle_put( ...@@ -142,16 +235,33 @@ async fn handle_put(
match state.model_type { match state.model_type {
ModelType::Chat => { ModelType::Chat => {
let endpoint_id = model_entry.endpoint.clone();
let client = state let client = state
.drt .drt
.namespace(model_entry.endpoint.namespace)? .namespace(&endpoint_id.namespace)?
.component(model_entry.endpoint.component)? .component(&endpoint_id.component)?
.endpoint(model_entry.endpoint.name) .endpoint(&endpoint_id.name)
.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>() .client()
.await?; .await?;
state
.manager let Some(etcd_client) = state.drt.etcd_client() else {
.add_chat_completions_model(&model_entry.name, Arc::new(client))?; // Should be impossible because we only get here on an etcd event
anyhow::bail!("Missing etcd_client");
};
let mdc = model_entry.load_mdc(endpoint_id, etcd_client).await?;
if mdc.requires_preprocessing {
// Note requires_preprocessing is never true in our code right now
todo!("Ingress-side pre-processing not supported yet");
} else {
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, Default::default())
.await?;
state
.manager
.add_chat_completions_model(&model_entry.name, Arc::new(push_router))?;
}
} }
ModelType::Completion => { ModelType::Completion => {
let client = state let client = state
...@@ -159,11 +269,20 @@ async fn handle_put( ...@@ -159,11 +269,20 @@ async fn handle_put(
.namespace(model_entry.endpoint.namespace)? .namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)? .component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name) .endpoint(model_entry.endpoint.name)
.client::<CompletionRequest, Annotated<CompletionResponse>>() .client()
.await?;
// TODO: Handle pre-processing once it moves ingress-side
let push_router =
PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
client,
Default::default(),
)
.await?; .await?;
state state
.manager .manager
.add_completions_model(&model_entry.name, Arc::new(client))?; .add_completions_model(&model_entry.name, Arc::new(push_router))?;
} }
} }
......
...@@ -32,6 +32,8 @@ mod mem; ...@@ -32,6 +32,8 @@ mod mem;
pub use mem::MemoryStorage; pub use mem::MemoryStorage;
mod nats; mod nats;
pub use nats::NATSStorage; pub use nats::NATSStorage;
mod etcd;
pub use etcd::EtcdStorage;
#[async_trait] #[async_trait]
pub trait KeyValueStore: Send + Sync { pub trait KeyValueStore: Send + Sync {
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;
use async_stream::stream;
use async_trait::async_trait;
use dynamo_runtime::{protocols::Endpoint, slug::Slug, transports::etcd::Client};
use etcd_client::{EventType, PutOptions, WatchOptions};
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
/// A [KeyValueStore] backed by etcd.
///
/// A "bucket" is modelled as a key-path prefix (see `make_key`), so buckets
/// are purely logical: creating or fetching one performs no network calls.
#[derive(Clone)]
pub struct EtcdStorage {
    // Connected etcd client; cloned into each EtcdBucket handed out.
    client: Client,
    // Supplies the namespace segment prefixed onto every key we read/write.
    endpoint: Endpoint,
}
impl EtcdStorage {
pub fn new(client: Client, endpoint: Endpoint) -> Self {
Self { client, endpoint }
}
}
#[async_trait]
impl KeyValueStore for EtcdStorage {
/// A "bucket" in etcd is a path prefix
async fn get_or_create_bucket(
&self,
bucket_name: &str,
_ttl: Option<Duration>, // TODO ttl not used yet
) -> Result<Box<dyn KeyValueBucket>, StorageError> {
Ok(self.get_bucket(bucket_name).await?.unwrap())
}
/// A "bucket" in etcd is a path prefix. This creates an EtcdBucket object without doing
/// any network calls.
async fn get_bucket(
&self,
bucket_name: &str,
) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
Ok(Some(Box::new(EtcdBucket {
client: self.client.clone(),
endpoint: self.endpoint.clone(),
bucket_name: bucket_name.to_string(),
})))
}
}
/// A single logical bucket inside etcd: all of its keys share the prefix
/// `<namespace>/<slugified bucket_name>/` (see `make_key`).
pub struct EtcdBucket {
    // Owned clone of the store's etcd client.
    client: Client,
    // Provides the namespace segment of every key in this bucket.
    endpoint: Endpoint,
    // Raw (un-slugified) bucket name; slugified on each key construction.
    bucket_name: String,
}
#[async_trait]
impl KeyValueBucket for EtcdBucket {
    /// Create (`revision == 0`) or optimistically update (`revision != 0`)
    /// `key` with `value`.
    async fn insert(
        &self,
        key: String,
        value: String,
        // "version" in etcd speak. revision is a global cluster-wide value
        revision: u64,
    ) -> Result<StorageOutcome, StorageError> {
        let version = revision;
        if version == 0 {
            self.create(&key, &value).await
        } else {
            self.update(&key, &value, version).await
        }
    }

    /// Fetch a single value, or `None` if the key does not exist.
    async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
        let k = make_key(&self.endpoint, &self.bucket_name, key);
        tracing::trace!("etcd get: {k}");
        let mut kvs = self
            .client
            .kv_get(k, None)
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Ok(None);
        }
        // A non-prefix get returns at most one kv; take it without shifting.
        let (_, val) = kvs.swap_remove(0).into_key_value();
        Ok(Some(val.into()))
    }

    /// Delete a single key. Succeeds even if the key did not exist.
    async fn delete(&self, key: &str) -> Result<(), StorageError> {
        // Bug fix: the key must be namespaced through make_key exactly like
        // every other operation in this impl. Previously the raw `key` was
        // passed to kv_delete, which is never the key insert/get actually
        // use, so deletes silently removed nothing.
        let k = make_key(&self.endpoint, &self.bucket_name, key);
        tracing::trace!("etcd delete: {k}");
        let _ = self
            .client
            .kv_delete(k, None)
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        Ok(())
    }

    /// Watch for Put events under this bucket's prefix, yielding each new
    /// value as raw bytes. Delete events are ignored, and the stream ends
    /// quietly on any transport error (best-effort semantics).
    async fn watch(
        &self,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
    {
        let k = make_key(&self.endpoint, &self.bucket_name, "");
        tracing::trace!("etcd watch: {k}");
        let (watcher, mut watch_stream) = self
            .client
            .etcd_client()
            .clone()
            .watch(k.as_bytes(), Some(WatchOptions::new().with_prefix()))
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        let output = stream! {
            // Keep the watcher handle alive for the lifetime of the stream;
            // NOTE(review): dropping it early may cancel the server-side
            // watch in etcd-client — confirm against the crate docs.
            let _watcher = watcher;
            while let Ok(Some(resp)) = watch_stream.message().await {
                for e in resp.events() {
                    if matches!(e.event_type(), EventType::Put) && e.kv().is_some() {
                        let b: bytes::Bytes = e.kv().unwrap().value().to_vec().into();
                        yield b;
                    }
                }
            }
        };
        Ok(Box::pin(output))
    }

    /// Return every key/value pair currently stored under this bucket's
    /// prefix. Keys are the full etcd keys (prefix included), lossily
    /// decoded as UTF-8.
    async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
        let k = make_key(&self.endpoint, &self.bucket_name, "");
        tracing::trace!("etcd entries: {k}");
        let resp = self
            .client
            .kv_get_prefix(k)
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        let out: HashMap<String, bytes::Bytes> = resp
            .into_iter()
            .map(|kv| {
                let (k, v) = kv.into_key_value();
                (String::from_utf8_lossy(&k).to_string(), v.into())
            })
            .collect();
        Ok(out)
    }
}
impl EtcdBucket {
    /// Create `key` with `value`, failing softly if it already exists.
    ///
    /// Returns `Exists(version)` if the key is already present,
    /// `Created(1)` on success (a brand-new etcd key always has version 1),
    /// or `Err(StorageError::Retry)` if another writer created the key
    /// between our existence check and our put.
    async fn create(&self, key: &str, value: &str) -> Result<StorageOutcome, StorageError> {
        let k = make_key(&self.endpoint, &self.bucket_name, key);
        tracing::trace!("etcd create: {k}");
        // Does it already exist? For 'create' it shouldn't.
        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        if !kvs.is_empty() {
            let version = kvs.first().unwrap().version();
            return Ok(StorageOutcome::Exists(version as u64));
        }
        // Write it. with_prev_key makes the response tell us if we raced.
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(PutOptions::new().with_prev_key()))
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        // Check if we overwrote something
        if put_resp.take_prev_key().is_some() {
            // Key created between our get and put
            return Err(StorageError::Retry);
        }
        // version of a new key is always 1
        Ok(StorageOutcome::Created(1))
    }

    /// Update `key` to `value`, expecting it to currently be at `revision`
    /// (etcd per-key "version"). Mirrors the NATS backend's behavior: on a
    /// version mismatch it logs a warning but overwrites anyway rather than
    /// failing (last-writer-wins resync).
    async fn update(
        &self,
        key: &str,
        value: &str,
        revision: u64,
    ) -> Result<StorageOutcome, StorageError> {
        let version = revision;
        let k = make_key(&self.endpoint, &self.bucket_name, key);
        tracing::trace!("etcd update: {k}");
        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StorageError::MissingKey(key.to_string()));
        }
        let current_version = kvs.first().unwrap().version() as u64;
        // NOTE(review): this treats the caller's `version` as the version the
        // key will have *after* this put (current + 1) — confirm that callers
        // pass the next version rather than the version they last observed.
        if current_version != version + 1 {
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                key,
                "update: Wrong revision"
            );
            // NATS does a resync_update, overwriting the key anyway and getting the new revision.
            // So we do too in etcd.
        }
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(PutOptions::new().with_prev_key()))
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        Ok(match put_resp.take_prev_key() {
            // Should this be an error?
            // The key was deleted between our get and put. We re-created it.
            // Version of new key is always 1.
            // <https://etcd.io/docs/v3.5/learning/data_model/>
            None => StorageOutcome::Created(1),
            // Expected case, success
            Some(kv) if kv.version() as u64 == version + 1 => StorageOutcome::Created(version),
            // Should this be an error? Something updated the version between our get and put
            Some(kv) => StorageOutcome::Created(kv.version() as u64 + 1),
        })
    }
}
/// Build the full etcd key path `"{namespace}/{bucket}/{key}"`.
///
/// The bucket name and key are slugified so they are safe to embed in an
/// etcd key; the namespace is used verbatim from the endpoint.
fn make_key(endpoint: &Endpoint, bucket_name: &str, key: &str) -> String {
    format!(
        "{}/{}/{}",
        endpoint.namespace,
        Slug::slugify(bucket_name),
        Slug::slugify(key)
    )
}
...@@ -17,6 +17,7 @@ use std::time::Duration; ...@@ -17,6 +17,7 @@ use std::time::Duration;
pub mod create; pub mod create;
pub mod model; pub mod model;
pub use model::ModelDeploymentCard;
// TODO: Do these network/publish related model deployment card values belong here or in a // TODO: Do these network/publish related model deployment card values belong here or in a
// network module? // network module?
......
...@@ -81,7 +81,7 @@ impl ModelDeploymentCard { ...@@ -81,7 +81,7 @@ impl ModelDeploymentCard {
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
revision: 0, revision: 0,
last_published: None, last_published: None,
requires_preprocessing: true, requires_preprocessing: false,
}) })
} }
...@@ -103,7 +103,7 @@ impl ModelDeploymentCard { ...@@ -103,7 +103,7 @@ impl ModelDeploymentCard {
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
revision: 0, revision: 0,
last_published: None, last_published: None,
requires_preprocessing: true, requires_preprocessing: false,
}) })
} }
} }
......
...@@ -146,7 +146,7 @@ impl ModelDeploymentCard { ...@@ -146,7 +146,7 @@ impl ModelDeploymentCard {
pub fn with_name_only(name: &str) -> ModelDeploymentCard { pub fn with_name_only(name: &str) -> ModelDeploymentCard {
ModelDeploymentCard { ModelDeploymentCard {
display_name: name.to_string(), display_name: name.to_string(),
service_name: Slug::from_string(name).to_string(), service_name: Slug::slugify(name).to_string(),
..Default::default() ..Default::default()
} }
} }
...@@ -238,7 +238,7 @@ impl ModelDeploymentCard { ...@@ -238,7 +238,7 @@ impl ModelDeploymentCard {
tracing::debug!( tracing::debug!(
nats_addr, nats_addr,
%bucket_name, %bucket_name,
"Uploading model deployment card to NATS" "Uploading model deployment card fields to NATS"
); );
if let Some(ModelInfoType::HfConfigJson(ref src_file)) = self.model_info { if let Some(ModelInfoType::HfConfigJson(ref src_file)) = self.model_info {
......
...@@ -41,6 +41,7 @@ chrono = { workspace = true } ...@@ -41,6 +41,7 @@ chrono = { workspace = true }
derive_builder = { workspace = true } derive_builder = { workspace = true }
derive-getters = { workspace = true } derive-getters = { workspace = true }
either = { workspace = true } either = { workspace = true }
etcd-client = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
humantime = { workspace = true } humantime = { workspace = true }
prometheus = { workspace = true } prometheus = { workspace = true }
...@@ -60,7 +61,6 @@ xxhash-rust = { workspace = true } ...@@ -60,7 +61,6 @@ xxhash-rust = { workspace = true }
async-once-cell = { version = "0.5.4" } async-once-cell = { version = "0.5.4" }
educe = { version = "0.6.0" } educe = { version = "0.6.0" }
etcd-client = { version = "0.14" }
figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] } figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] }
local-ip-address = { version = "0.6.3" } local-ip-address = { version = "0.6.3" }
log = { version = "0.4" } log = { version = "0.4" }
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
// limitations under the License. // limitations under the License.
use dynamo_runtime::{ use dynamo_runtime::{
logging, protocols::annotated::Annotated, stream::StreamExt, DistributedRuntime, Result, logging, pipeline::PushRouter, protocols::annotated::Annotated, stream::StreamExt,
Runtime, Worker, DistributedRuntime, Result, Runtime, Worker,
}; };
use hello_world::DEFAULT_NAMESPACE; use hello_world::DEFAULT_NAMESPACE;
...@@ -32,12 +32,13 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -32,12 +32,13 @@ async fn app(runtime: Runtime) -> Result<()> {
.namespace(DEFAULT_NAMESPACE)? .namespace(DEFAULT_NAMESPACE)?
.component("backend")? .component("backend")?
.endpoint("generate") .endpoint("generate")
.client::<String, Annotated<String>>() .client()
.await?; .await?;
client.wait_for_endpoints().await?; client.wait_for_endpoints().await?;
let router =
PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?;
let mut stream = client.random("hello world".to_string().into()).await?; let mut stream = router.random("hello world".to_string().into()).await?;
while let Some(resp) = stream.next().await { while let Some(resp) = stream.next().await {
println!("{:?}", resp); println!("{:?}", resp);
......
...@@ -17,8 +17,8 @@ use futures::StreamExt; ...@@ -17,8 +17,8 @@ use futures::StreamExt;
use service_metrics::DEFAULT_NAMESPACE; use service_metrics::DEFAULT_NAMESPACE;
use dynamo_runtime::{ use dynamo_runtime::{
logging, protocols::annotated::Annotated, utils::Duration, DistributedRuntime, Result, Runtime, logging, pipeline::PushRouter, protocols::annotated::Annotated, utils::Duration,
Worker, DistributedRuntime, Result, Runtime, Worker,
}; };
fn main() -> Result<()> { fn main() -> Result<()> {
...@@ -33,14 +33,13 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -33,14 +33,13 @@ async fn app(runtime: Runtime) -> Result<()> {
let namespace = distributed.namespace(DEFAULT_NAMESPACE)?; let namespace = distributed.namespace(DEFAULT_NAMESPACE)?;
let component = namespace.component("backend")?; let component = namespace.component("backend")?;
let client = component let client = component.endpoint("generate").client().await?;
.endpoint("generate")
.client::<String, Annotated<String>>()
.await?;
client.wait_for_endpoints().await?; client.wait_for_endpoints().await?;
let router =
PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?;
let mut stream = client.random("hello world".to_string().into()).await?; let mut stream = router.random("hello world".to_string().into()).await?;
while let Some(resp) = stream.next().await { while let Some(resp) = stream.next().await {
println!("{:?}", resp); println!("{:?}", resp);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment