Unverified Commit a1a10365 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Split PushRouter from Client (#817)

In a distributed system we don't know if the remote workers need pre-processing done ingress-side or not. Previously Client required us to decide this before discovering the remote endpoints, which was fine because pre-processing was worker-side.

As part of moving pre-processing back to ingress-side we need to split this into two steps:
- Client discovers the endpoints, and (later PR) will fetch their Model Deployment Card.
- PushRouter will use the Model Deployment Card to decide if they need pre-processing or not, which affects the types of the generic parameters.

Part of #743
parent 97bf8184
......@@ -1619,6 +1619,7 @@ dependencies = [
"dynamo-runtime",
"either",
"erased-serde",
"etcd-client",
"futures",
"galil-seiferas",
"ggus",
......
......@@ -55,6 +55,7 @@ chrono = { version = "0.4", default-features = false, features = ["alloc", "std"
derive_builder = { version = "0.20" }
derive-getters = { version = "0.5" }
either = { version = "1.13", features = ["serde"] }
etcd-client = { version = "0.14" }
futures = { version = "0.3" }
hf-hub = { version = "0.4.2", default-features = false, features = ["tokio", "rustls-tls"] }
humantime = { version = "2.2.0" }
......
......@@ -127,11 +127,7 @@ async fn app(runtime: Runtime) -> Result<()> {
tracing::debug!("Creating unique instance of Count at {key}");
drt.etcd_client()
.expect("Unreachable because of DistributedRuntime::from_settings above")
.kv_create(
key,
serde_json::to_vec_pretty(&config)?,
Some(drt.primary_lease().unwrap().id()),
)
.kv_create(key, serde_json::to_vec_pretty(&config)?, None)
.await
.context("Unable to create unique instance of Count; possibly one already exists")?;
......
......@@ -18,7 +18,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use clap::ValueEnum;
use dynamo_runtime::component::RouterMode as RuntimeRouterMode;
use dynamo_runtime::pipeline::RouterMode as RuntimeRouterMode;
/// Required options depend on the in and out choices
#[derive(clap::Parser, Debug, Clone)]
......
......@@ -15,12 +15,11 @@
use std::pin::Pin;
use crate::{flags::RouterMode, EngineConfig, Flags};
use dynamo_llm::{
backend::Backend,
backend::ExecutionContext,
backend::{Backend, ExecutionContext},
engines::StreamingEngineAdapter,
model_card::model::ModelDeploymentCard,
http::service::discovery::ModelNetworkName,
model_card::ModelDeploymentCard,
preprocessor::OpenAIPreprocessor,
protocols::common::llm_backend::{BackendInput, BackendOutput},
types::{
......@@ -33,11 +32,15 @@ use dynamo_llm::{
};
use dynamo_runtime::{
engine::{AsyncEngineStream, Data},
pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
pipeline::{
Context, ManyOut, Operator, PushRouter, ServiceBackend, ServiceFrontend, SingleIn, Source,
},
DistributedRuntime, Runtime,
};
use std::sync::Arc;
use crate::{flags::RouterMode, EngineConfig, Flags};
/// Turns an EngineConfig into an OpenAI chat-completions and completions supported StreamingEngine.
pub async fn prepare_engine(
runtime: Runtime,
......@@ -53,22 +56,40 @@ pub async fn prepare_engine(
.component(endpoint_id.component.clone())?
.endpoint(endpoint_id.name.clone());
let mut client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
match &flags.router_mode {
let client = endpoint.client().await?;
let router = match &flags.router_mode {
RouterMode::Random | RouterMode::RoundRobin => {
client.set_router_mode(flags.router_mode.into());
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
// We then use the ModelDeploymentCard's `requires_preprocessing`
// field to decide what kind of PushRouter to make.
let remote_endpoints = client.wait_for_endpoints().await?;
debug_assert!(!remote_endpoints.is_empty());
tracing::info!(count = remote_endpoints.len(), "Model(s) discovered");
let network_name: ModelNetworkName = (&remote_endpoints[0]).into();
let Some(etcd_client) = distributed_runtime.etcd_client() else {
anyhow::bail!("Cannot run distributed components without etcd");
};
let mdc = network_name.load_mdc(endpoint_id, etcd_client).await?;
if mdc.requires_preprocessing {
// Note requires_preprocessing is never true in our code right now
todo!("Ingress-side pre-processing not supported yet");
} else {
PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, flags.router_mode.into())
.await?
}
}
RouterMode::KV => todo!(),
}
};
// The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration.
let service_name = endpoint.subject();
Ok((service_name, Arc::new(client), false))
Ok((service_name, Arc::new(router), false))
}
EngineConfig::StaticFull {
service_name,
......
......@@ -18,9 +18,9 @@ use std::sync::Arc;
use dynamo_llm::{
backend::Backend,
engines::StreamingEngineAdapter,
http::service::discovery::ModelEntry,
key_value_store::{KeyValueStore, KeyValueStoreManager, NATSStorage},
model_card::{BUCKET_NAME, BUCKET_TTL},
http::service::discovery::{ModelEntry, ModelNetworkName},
key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
model_card,
model_type::ModelType,
preprocessor::OpenAIPreprocessor,
types::{
......@@ -49,14 +49,14 @@ pub async fn run(
let etcd_client = distributed_runtime.etcd_client();
let (ingress, service_name, mut card) = match engine_config {
let (ingress, service_name, mut card, requires_preprocessing) = match engine_config {
EngineConfig::StaticFull {
service_name,
engine,
card,
} => {
let engine = Arc::new(StreamingEngineAdapter::new(engine));
(Ingress::for_engine(engine)?, service_name, card)
(Ingress::for_engine(engine)?, service_name, card, false)
}
EngineConfig::StaticCore {
service_name,
......@@ -81,7 +81,8 @@ pub async fn run(
.link(preprocessor.backward_edge())?
.link(frontend)?;
(Ingress::for_pipeline(pipeline)?, service_name, card)
// TODO: switch last 'false' to 'true' once we have ingress-side pre-processing
(Ingress::for_pipeline(pipeline)?, service_name, card, false)
}
EngineConfig::Dynamic(_) => {
anyhow::bail!("Cannot use endpoint for both in and out");
......@@ -104,30 +105,30 @@ pub async fn run(
.await?
.endpoint(&endpoint_id.name);
let nats_client = distributed_runtime.nats_client();
card.move_to_nats(nats_client.clone()).await?;
let kvstore: Box<dyn KeyValueStore> =
Box::new(NATSStorage::new(nats_client.clone(), endpoint_id));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
card.requires_preprocessing = false;
card_store.publish_until_cancelled(
cancel_token.clone(),
BUCKET_NAME.to_string(),
Some(BUCKET_TTL),
BUCKET_TTL / 2,
card.slug().to_string(),
*card.clone(),
);
if let Some(etcd_client) = etcd_client {
let network_name = endpoint.subject_to(etcd_client.lease_id());
// Store model config files in NATS object store
let nats_client = distributed_runtime.nats_client();
card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to etcd
let kvstore: Box<dyn KeyValueStore> =
Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
card.requires_preprocessing = requires_preprocessing; // Not used yet. Soon.
let key = card.slug().to_string();
card_store
.publish(model_card::BUCKET_NAME, None, &key, &mut *card.clone())
.await?;
// Publish our ModelEntry to etcd. This allows ingress to find the model card.
// (Why don't we put the model card directly under this key?)
let network_name = ModelNetworkName::from_local(&endpoint, etcd_client.lease_id());
tracing::debug!("Registering with etcd as {network_name}");
etcd_client
.kv_create(
network_name.clone(),
network_name.to_string(),
serde_json::to_vec_pretty(&model_registration)?,
Some(etcd_client.lease_id()),
None, // use primary lease
)
.await?;
}
......@@ -140,8 +141,12 @@ pub async fn run(
_ = cancel_token.cancelled() => {
}
}
// Cleanup on shutdown
if let Err(err) = card.delete_from_nats(nats_client).await {
if let Err(err) = card
.delete_from_nats(distributed_runtime.nats_client())
.await
{
tracing::error!(%err, "delete_from_nats error on shutdown");
}
Ok(())
......
......@@ -180,9 +180,19 @@ pub async fn run(
}
};
// If we are in a distributed system, we need to know our component upfront
let dyn_input = match &in_opt {
Input::Endpoint(endpoint_path) => {
if model_path.as_ref().map(|mp| mp.is_file()).unwrap_or(false)
&& flags.model_config.is_none()
{
// TODO We need to convert tokenizer extract from GGUF file into something we can
// publish to NATS. Ideally `tokenizer.json` directly, but otherwise an
// intermediate format.
tracing::error!("Serving GGUF files in a distributed system requires `--model-config <hf-repo-dir>` so that we can find the tokenzier config");
return Ok(());
}
// If we are in a distributed system, we need to know our component upfront
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint_id: Endpoint = endpoint_path.parse()?;
Some(DynInput {
......@@ -216,7 +226,10 @@ pub async fn run(
"out=echo_core need to find the tokenizer. Pass flag --model-path <path>"
);
};
card.requires_preprocessing = true;
// TODO: Switch to `true` once pre-processing moves ingress side
card.requires_preprocessing = false;
EngineConfig::StaticCore {
service_name: card.service_name.clone(),
engine: dynamo_llm::engines::make_engine_core(),
......
......@@ -1049,6 +1049,7 @@ dependencies = [
"dynamo-runtime",
"either",
"erased-serde",
"etcd-client",
"futures",
"galil-seiferas",
"ggus",
......
......@@ -146,7 +146,7 @@ struct Endpoint {
#[pyclass]
#[derive(Clone)]
struct Client {
inner: rs::component::Client<serde_json::Value, serde_json::Value>,
router: rs::pipeline::PushRouter<serde_json::Value, serde_json::Value>,
}
#[pyclass]
......@@ -445,11 +445,17 @@ impl Endpoint {
fn client<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let client = inner
.client::<serde_json::Value, serde_json::Value>()
let client = inner.client().await.map_err(to_pyerr)?;
let push_router =
rs::pipeline::PushRouter::<serde_json::Value, serde_json::Value>::from_client(
client,
Default::default(),
)
.await
.map_err(to_pyerr)?;
Ok(Client { inner: client })
Ok(Client {
router: push_router,
})
})
}
......@@ -552,13 +558,17 @@ impl EtcdClient {
impl Client {
/// Get list of current endpoints
fn endpoint_ids(&self) -> Vec<i64> {
self.inner.endpoint_ids()
self.router.client.endpoint_ids()
}
fn wait_for_endpoints<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
let inner = self.router.client.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.wait_for_endpoints().await.map_err(to_pyerr)
inner
.wait_for_endpoints()
.await
.map(|v| v.into_iter().map(|cei| cei.id()).collect::<Vec<i64>>())
.map_err(to_pyerr)
})
}
......@@ -570,7 +580,7 @@ impl Client {
request: PyObject,
annotated: Option<bool>,
) -> PyResult<Bound<'p, PyAny>> {
if self.inner.is_static() {
if self.router.client.is_static() {
self.r#static(py, request, annotated)
} else {
self.random(py, request, annotated)
......@@ -589,7 +599,7 @@ impl Client {
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone();
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.round_robin(request.into()).await.map_err(to_pyerr)?;
......@@ -613,7 +623,7 @@ impl Client {
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone();
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.random(request.into()).await.map_err(to_pyerr)?;
......@@ -638,7 +648,7 @@ impl Client {
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone();
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client
......@@ -667,7 +677,7 @@ impl Client {
let annotated = annotated.unwrap_or(false);
let (tx, rx) = tokio::sync::mpsc::channel(32);
let client = self.inner.clone();
let client = self.router.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let stream = client.r#static(request.into()).await.map_err(to_pyerr)?;
......
......@@ -27,9 +27,9 @@ use llm_rs::{
},
};
use dynamo_runtime::pipeline::{Operator, ServiceFrontend, Source};
use dynamo_runtime::pipeline::{ManyOut, SegmentSink, SingleIn};
use dynamo_runtime::pipeline::{
ManyOut, Operator, PushRouter, SegmentSink, ServiceFrontend, SingleIn, Source,
};
#[pyclass]
pub(crate) struct OAIChatPreprocessor {
......@@ -76,13 +76,14 @@ impl OAIChatPreprocessor {
let builder = self.current.inner.endpoint_builder().handler(ingress);
let endpoint = Arc::new(self.next.inner.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let client = Arc::new(
endpoint
.client::<BackendInput, Annotated<BackendOutput>>()
.await
.map_err(to_pyerr)?,
);
network.attach(client).map_err(to_pyerr)?;
let client = endpoint.client().await.map_err(to_pyerr)?;
let router = PushRouter::<BackendInput, Annotated<BackendOutput>>::from_client(
client,
Default::default(),
)
.await
.map_err(to_pyerr)?;
network.attach(Arc::new(router)).map_err(to_pyerr)?;
builder.start().await.map_err(to_pyerr)?;
Ok(())
})
......
......@@ -44,6 +44,7 @@ bytes = { workspace = true }
chrono = { workspace = true }
derive_builder = {workspace = true }
either = { workspace = true }
etcd-client = { workspace = true }
futures = { workspace = true }
rand = { workspace = true }
prometheus = { workspace = true }
......
......@@ -15,13 +15,17 @@
use std::sync::Arc;
use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver;
use dynamo_runtime::{
component::{self, ComponentEndpointInfo},
pipeline::network::egress::push_router::PushRouter,
protocols::{self, annotated::Annotated},
raise,
transports::etcd::{KeyValue, WatchEvent},
slug::Slug,
transports::etcd::{self, KeyValue, WatchEvent},
DistributedRuntime,
};
......@@ -31,7 +35,12 @@ use crate::protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
};
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use crate::{
key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
model_card::{self, ModelDeploymentCard},
};
use tracing;
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
......@@ -48,6 +57,90 @@ pub struct ModelEntry {
pub model_type: ModelType,
}
impl ModelEntry {
    /// Fetch this model's [`ModelDeploymentCard`] from etcd.
    ///
    /// The card is looked up in bucket `model_card::BUCKET_NAME` under a key
    /// derived from the model's name (`ModelDeploymentCard::service_name_slug`).
    /// The worker side publishes the card there via a `KeyValueStoreManager`
    /// backed by `EtcdStorage`.
    ///
    /// # Errors
    /// Fails if the card is absent from etcd under the expected key, or if the
    /// store returns an error (e.g. deserialization or network failure).
    pub async fn load_mdc(
        &self,
        endpoint_id: protocols::Endpoint,
        etcd_client: etcd::Client,
    ) -> anyhow::Result<ModelDeploymentCard> {
        let kvstore: Box<dyn KeyValueStore> =
            Box::new(EtcdStorage::new(etcd_client.clone(), endpoint_id));
        let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
        let card_key = ModelDeploymentCard::service_name_slug(&self.name);
        match card_store
            .load::<ModelDeploymentCard>(model_card::BUCKET_NAME, &card_key)
            .await
        {
            Ok(Some(mdc)) => Ok(mdc),
            // Key missing is a distinct, clearer failure than a store error.
            Ok(None) => {
                anyhow::bail!("Missing ModelDeploymentCard in etcd under key {card_key}");
            }
            Err(err) => {
                anyhow::bail!(
                    "Error fetching ModelDeploymentCard from etcd under key {card_key}. {err}"
                );
            }
        }
    }
}
/// Newtype over the slugified etcd key under which a [`ModelEntry`] is
/// registered for network discovery, e.g. `ns.cp.ep-694d967ca5efd804`.
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);

impl ModelNetworkName {
    /// Key to store this model entry in networked key-value store (etcd).
    ///
    /// It looks like this:
    /// ns.cp.ep-694d967ca5efd804
    fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
        // lease_id is rendered in hex ({lease_id:x}) and the whole string slugified.
        ModelNetworkName(
            Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}")).to_string(),
        )
    }

    // We can't do From<&component::Endpoint> here because we also need the lease_id
    /// Build the network name for a locally-hosted endpoint plus its etcd lease.
    pub fn from_local(endpoint: &component::Endpoint, lease_id: i64) -> Self {
        Self::from_parts(
            &endpoint.component().namespace().to_string(),
            &endpoint.component().name(),
            endpoint.name(),
            lease_id,
        )
    }

    /// Resolve this network name to a [`ModelDeploymentCard`].
    ///
    /// Two-step lookup: first fetch the [`ModelEntry`] JSON stored under this
    /// key in etcd, then delegate to [`ModelEntry::load_mdc`] to fetch the card
    /// itself.
    ///
    /// # Errors
    /// Fails if no entry exists under this key, the entry is not valid JSON,
    /// or the card lookup fails.
    pub async fn load_mdc(
        &self,
        endpoint_id: protocols::Endpoint,
        etcd_client: etcd::Client,
    ) -> anyhow::Result<ModelDeploymentCard> {
        let network_name = self;
        let model_entries = etcd_client.kv_get(network_name.to_string(), None).await?;
        if model_entries.is_empty() {
            anyhow::bail!("No ModelEntry in etcd for key {network_name}");
        }
        // Only the first entry is used; keys include the lease id so collisions
        // are not expected. TODO confirm multiple entries cannot occur.
        let entry: ModelEntry =
            serde_json::from_slice(model_entries[0].value()).with_context(|| {
                format!(
                    "Error deserializing JSON. Key={network_name}. JSON={}",
                    model_entries[0].value_str().unwrap_or("INVALID UTF-8")
                )
            })?;
        entry.load_mdc(endpoint_id, etcd_client).await
    }
}

/// Remote endpoints discovered via `Client::wait_for_endpoints` carry the
/// parts (namespace/component/endpoint/lease) needed to derive their key.
impl From<&ComponentEndpointInfo> for ModelNetworkName {
    fn from(cei: &ComponentEndpointInfo) -> Self {
        Self::from_parts(&cei.namespace, &cei.component, &cei.endpoint, cei.lease_id)
    }
}

impl std::fmt::Display for ModelNetworkName {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.0)
    }
}
pub struct ModelWatchState {
pub prefix: String,
pub model_type: ModelType,
......@@ -142,16 +235,33 @@ async fn handle_put(
match state.model_type {
ModelType::Chat => {
let endpoint_id = model_entry.endpoint.clone();
let client = state
.drt
.namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name)
.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?
.endpoint(&endpoint_id.name)
.client()
.await?;
state
.manager
.add_chat_completions_model(&model_entry.name, Arc::new(client))?;
let Some(etcd_client) = state.drt.etcd_client() else {
// Should be impossible because we only get here on an etcd event
anyhow::bail!("Missing etcd_client");
};
let mdc = model_entry.load_mdc(endpoint_id, etcd_client).await?;
if mdc.requires_preprocessing {
// Note requires_preprocessing is never true in our code right now
todo!("Ingress-side pre-processing not supported yet");
} else {
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client(client, Default::default())
.await?;
state
.manager
.add_chat_completions_model(&model_entry.name, Arc::new(push_router))?;
}
}
ModelType::Completion => {
let client = state
......@@ -159,11 +269,20 @@ async fn handle_put(
.namespace(model_entry.endpoint.namespace)?
.component(model_entry.endpoint.component)?
.endpoint(model_entry.endpoint.name)
.client::<CompletionRequest, Annotated<CompletionResponse>>()
.client()
.await?;
// TODO: Handle pre-processing once it moves ingress-side
let push_router =
PushRouter::<CompletionRequest, Annotated<CompletionResponse>>::from_client(
client,
Default::default(),
)
.await?;
state
.manager
.add_completions_model(&model_entry.name, Arc::new(client))?;
.add_completions_model(&model_entry.name, Arc::new(push_router))?;
}
}
......
......@@ -32,6 +32,8 @@ mod mem;
pub use mem::MemoryStorage;
mod nats;
pub use nats::NATSStorage;
mod etcd;
pub use etcd::EtcdStorage;
#[async_trait]
pub trait KeyValueStore: Send + Sync {
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use std::pin::Pin;
use std::time::Duration;
use async_stream::stream;
use async_trait::async_trait;
use dynamo_runtime::{protocols::Endpoint, slug::Slug, transports::etcd::Client};
use etcd_client::{EventType, PutOptions, WatchOptions};
use super::{KeyValueBucket, KeyValueStore, StorageError, StorageOutcome};
/// etcd-backed implementation of [`KeyValueStore`].
///
/// A "bucket" maps onto an etcd key prefix, so constructing the store (and its
/// buckets) needs no network round-trips.
#[derive(Clone)]
pub struct EtcdStorage {
    client: Client,
    endpoint: Endpoint,
}

impl EtcdStorage {
    /// Build a store whose keys are namespaced by `endpoint` (see `make_key`).
    pub fn new(client: Client, endpoint: Endpoint) -> Self {
        EtcdStorage { client, endpoint }
    }
}
#[async_trait]
impl KeyValueStore for EtcdStorage {
    /// A "bucket" in etcd is a path prefix, so "create" is a no-op and this
    /// never needs a network call.
    ///
    /// The TTL is currently ignored (etcd leases are not wired up yet).
    async fn get_or_create_bucket(
        &self,
        bucket_name: &str,
        _ttl: Option<Duration>, // TODO ttl not used yet
    ) -> Result<Box<dyn KeyValueBucket>, StorageError> {
        // `get_bucket` below unconditionally returns `Ok(Some(..))`, so this
        // cannot panic; `expect` records that invariant instead of a bare unwrap.
        Ok(self
            .get_bucket(bucket_name)
            .await?
            .expect("EtcdStorage::get_bucket always returns Some"))
    }

    /// A "bucket" in etcd is a path prefix. This creates an EtcdBucket object without doing
    /// any network calls.
    async fn get_bucket(
        &self,
        bucket_name: &str,
    ) -> Result<Option<Box<dyn KeyValueBucket>>, StorageError> {
        Ok(Some(Box::new(EtcdBucket {
            client: self.client.clone(),
            endpoint: self.endpoint.clone(),
            bucket_name: bucket_name.to_string(),
        })))
    }
}
/// A handle to one key prefix ("bucket") in etcd.
///
/// Created by [`EtcdStorage::get_bucket`]; holds no server-side state.
pub struct EtcdBucket {
    client: Client,
    // Provides the namespace segment of the full etcd key (see `make_key`).
    endpoint: Endpoint,
    // Slugified into the middle segment of the full etcd key.
    bucket_name: String,
}
#[async_trait]
impl KeyValueBucket for EtcdBucket {
async fn insert(
&self,
key: String,
value: String,
// "version" in etcd speak. revision is a global cluster-wide value
revision: u64,
) -> Result<StorageOutcome, StorageError> {
let version = revision;
if version == 0 {
self.create(&key, &value).await
} else {
self.update(&key, &value, version).await
}
}
async fn get(&self, key: &str) -> Result<Option<bytes::Bytes>, StorageError> {
let k = make_key(&self.endpoint, &self.bucket_name, key);
tracing::trace!("etcd get: {k}");
let mut kvs = self
.client
.kv_get(k, None)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
if kvs.is_empty() {
return Ok(None);
}
let (_, val) = kvs.swap_remove(0).into_key_value();
Ok(Some(val.into()))
}
async fn delete(&self, key: &str) -> Result<(), StorageError> {
let _ = self
.client
.kv_delete(key, None)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
Ok(())
}
async fn watch(
&self,
) -> Result<Pin<Box<dyn futures::Stream<Item = bytes::Bytes> + Send + 'life0>>, StorageError>
{
let k = make_key(&self.endpoint, &self.bucket_name, "");
tracing::trace!("etcd watch: {k}");
let (_watcher, mut watch_stream) = self
.client
.etcd_client()
.clone()
.watch(k.as_bytes(), Some(WatchOptions::new().with_prefix()))
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
let output = stream! {
while let Ok(Some(resp)) = watch_stream.message().await {
for e in resp.events() {
if matches!(e.event_type(), EventType::Put) && e.kv().is_some() {
let b: bytes::Bytes = e.kv().unwrap().value().to_vec().into();
yield b;
}
}
}
};
Ok(Box::pin(output))
}
async fn entries(&self) -> Result<HashMap<String, bytes::Bytes>, StorageError> {
let k = make_key(&self.endpoint, &self.bucket_name, "");
tracing::trace!("etcd entries: {k}");
let resp = self
.client
.kv_get_prefix(k)
.await
.map_err(|e| StorageError::EtcdError(e.to_string()))?;
let out: HashMap<String, bytes::Bytes> = resp
.into_iter()
.map(|kv| {
let (k, v) = kv.into_key_value();
(String::from_utf8_lossy(&k).to_string(), v.into())
})
.collect();
Ok(out)
}
}
impl EtcdBucket {
    /// Insert `key` for the first time.
    ///
    /// Returns `Exists(version)` if the key is already present, `Created(1)`
    /// on success, and `Err(Retry)` if the key appeared between our get and
    /// put (optimistic concurrency — no etcd transaction is used).
    async fn create(&self, key: &str, value: &str) -> Result<StorageOutcome, StorageError> {
        let k = make_key(&self.endpoint, &self.bucket_name, key);
        tracing::trace!("etcd create: {k}");
        // Does it already exists? For 'create' it shouldn't.
        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        if !kvs.is_empty() {
            let version = kvs.first().unwrap().version();
            return Ok(StorageOutcome::Exists(version as u64));
        }
        // Write it
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(PutOptions::new().with_prev_key()))
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        // Check if we overwrite something
        if put_resp.take_prev_key().is_some() {
            // Key created between our get and put
            return Err(StorageError::Retry);
        }
        // version of a new key is always 1
        Ok(StorageOutcome::Created(1))
    }

    /// Overwrite an existing `key`, expecting its current version to be
    /// `revision`.
    ///
    /// A version mismatch is logged but NOT rejected — the write proceeds
    /// anyway, mirroring the NATS backend's resync behaviour (see comment
    /// below). Errors only if the key is missing at the initial read or etcd
    /// calls fail.
    async fn update(
        &self,
        key: &str,
        value: &str,
        revision: u64,
    ) -> Result<StorageOutcome, StorageError> {
        let version = revision;
        let k = make_key(&self.endpoint, &self.bucket_name, key);
        tracing::trace!("etcd update: {k}");
        let kvs = self
            .client
            .kv_get(k.clone(), None)
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        if kvs.is_empty() {
            return Err(StorageError::MissingKey(key.to_string()));
        }
        let current_version = kvs.first().unwrap().version() as u64;
        if current_version != version + 1 {
            tracing::warn!(
                current_version,
                attempted_next_version = version,
                key,
                "update: Wrong revision"
            );
            // NATS does a resync_update, overwriting the key anyway and getting the new revision.
            // So we do too in etcd.
        }
        let mut put_resp = self
            .client
            .kv_put_with_options(k, value, Some(PutOptions::new().with_prev_key()))
            .await
            .map_err(|e| StorageError::EtcdError(e.to_string()))?;
        // The put returns the previous key-value (with_prev_key), which tells
        // us what version we actually replaced.
        Ok(match put_resp.take_prev_key() {
            // Should this be an error?
            // The key was deleted between our get and put. We re-created it.
            // Version of new key is always 1.
            // <https://etcd.io/docs/v3.5/learning/data_model/>
            None => StorageOutcome::Created(1),
            // Expected case, success
            Some(kv) if kv.version() as u64 == version + 1 => StorageOutcome::Created(version),
            // Should this be an error? Something updated the version between our get and put
            Some(kv) => StorageOutcome::Created(kv.version() as u64 + 1),
        })
    }
}
/// Build the full etcd key for an entry: `<namespace>/<bucket-slug>/<key-slug>`.
/// An empty `key` yields the bucket prefix (used by `watch` and `entries`).
fn make_key(endpoint: &Endpoint, bucket_name: &str, key: &str) -> String {
    format!(
        "{}/{}/{}",
        endpoint.namespace,
        Slug::slugify(bucket_name),
        Slug::slugify(key)
    )
}
......@@ -17,6 +17,7 @@ use std::time::Duration;
pub mod create;
pub mod model;
pub use model::ModelDeploymentCard;
// TODO: Do these network/publish related model deployment card values belong here or in a
// network module?
......
......@@ -81,7 +81,7 @@ impl ModelDeploymentCard {
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
requires_preprocessing: true,
requires_preprocessing: false,
})
}
......@@ -103,7 +103,7 @@ impl ModelDeploymentCard {
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
last_published: None,
requires_preprocessing: true,
requires_preprocessing: false,
})
}
}
......
......@@ -146,7 +146,7 @@ impl ModelDeploymentCard {
pub fn with_name_only(name: &str) -> ModelDeploymentCard {
ModelDeploymentCard {
display_name: name.to_string(),
service_name: Slug::from_string(name).to_string(),
service_name: Slug::slugify(name).to_string(),
..Default::default()
}
}
......@@ -238,7 +238,7 @@ impl ModelDeploymentCard {
tracing::debug!(
nats_addr,
%bucket_name,
"Uploading model deployment card to NATS"
"Uploading model deployment card fields to NATS"
);
if let Some(ModelInfoType::HfConfigJson(ref src_file)) = self.model_info {
......
......@@ -41,6 +41,7 @@ chrono = { workspace = true }
derive_builder = { workspace = true }
derive-getters = { workspace = true }
either = { workspace = true }
etcd-client = { workspace = true }
futures = { workspace = true }
humantime = { workspace = true }
prometheus = { workspace = true }
......@@ -60,7 +61,6 @@ xxhash-rust = { workspace = true }
async-once-cell = { version = "0.5.4" }
educe = { version = "0.6.0" }
etcd-client = { version = "0.14" }
figment = { version = "0.10.19", features = ["env", "json", "toml", "test"] }
local-ip-address = { version = "0.6.3" }
log = { version = "0.4" }
......
......@@ -14,8 +14,8 @@
// limitations under the License.
use dynamo_runtime::{
logging, protocols::annotated::Annotated, stream::StreamExt, DistributedRuntime, Result,
Runtime, Worker,
logging, pipeline::PushRouter, protocols::annotated::Annotated, stream::StreamExt,
DistributedRuntime, Result, Runtime, Worker,
};
use hello_world::DEFAULT_NAMESPACE;
......@@ -32,12 +32,13 @@ async fn app(runtime: Runtime) -> Result<()> {
.namespace(DEFAULT_NAMESPACE)?
.component("backend")?
.endpoint("generate")
.client::<String, Annotated<String>>()
.client()
.await?;
client.wait_for_endpoints().await?;
let router =
PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?;
let mut stream = client.random("hello world".to_string().into()).await?;
let mut stream = router.random("hello world".to_string().into()).await?;
while let Some(resp) = stream.next().await {
println!("{:?}", resp);
......
......@@ -17,8 +17,8 @@ use futures::StreamExt;
use service_metrics::DEFAULT_NAMESPACE;
use dynamo_runtime::{
logging, protocols::annotated::Annotated, utils::Duration, DistributedRuntime, Result, Runtime,
Worker,
logging, pipeline::PushRouter, protocols::annotated::Annotated, utils::Duration,
DistributedRuntime, Result, Runtime, Worker,
};
fn main() -> Result<()> {
......@@ -33,14 +33,13 @@ async fn app(runtime: Runtime) -> Result<()> {
let namespace = distributed.namespace(DEFAULT_NAMESPACE)?;
let component = namespace.component("backend")?;
let client = component
.endpoint("generate")
.client::<String, Annotated<String>>()
.await?;
let client = component.endpoint("generate").client().await?;
client.wait_for_endpoints().await?;
let router =
PushRouter::<String, Annotated<String>>::from_client(client, Default::default()).await?;
let mut stream = client.random("hello world".to_string().into()).await?;
let mut stream = router.random("hello world".to_string().into()).await?;
while let Some(resp) = stream.next().await {
println!("{:?}", resp);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment