Unverified Commit d346782c authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Publish Model Deployment Card to NATS (#799)

This will allow an ingress-side pre-processor to see it without needing a model checkout.

Currently pre-processing is done in the worker, which has access to the model deployment card ("MDC") files (`config.json`, `tokenizer.json` and `tokenizer_config.json`) locally. We want to move the pre-processor to the ingress side to support KV routing. That requires the ingress side (i.e. the HTTP server), which runs on a different machine than the worker, to be able to see those three files.

To support that, this PR makes the worker upload the contents of those files to the NATS object store, and publishes the MDC with those NATS URLs to the key-value store.

The key-value store has an interface so any store (nats, etcd, redis, etc) can be supported. Implementations for memory and NATS are provided.

Fetching the MDC from the store, doing pre-processing ingress side, and publishing a card backed by a GGUF, are all for a later commit.

Part of #743 
parent 16310b26
......@@ -36,8 +36,6 @@ impl Slug {
Slug::slugify_unique(s.as_ref())
}
/* Not currently used but leave it for now
*
/// Turn the string into a valid slug, replacing any not-web-or-nats-safe characters with '-'
pub fn slugify(s: &str) -> Slug {
let out = s
......@@ -54,7 +52,6 @@ impl Slug {
.collect::<String>();
Slug::new(out)
}
*/
/// Like slugify but also add a four byte hash on the end, in case two different strings slug
/// to the same thing.
......
......@@ -34,13 +34,17 @@ use async_nats::{client, jetstream, Subscriber};
use bytes::Bytes;
use derive_builder::Builder;
use futures::{StreamExt, TryStreamExt};
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use tokio::fs::File as TokioFile;
use tokio::time;
use url::Url;
use validator::{Validate, ValidationError};
pub use crate::slug::Slug;
use tracing as log;
pub const URL_PREFIX: &str = "nats://";
#[derive(Clone)]
pub struct Client {
client: client::Client,
......@@ -63,6 +67,12 @@ impl Client {
&self.js_ctx
}
/// Returns the NATS server address as `host:port`, read from the client's server info.
pub fn addr(&self) -> String {
    let server = self.client.server_info();
    let (host, port) = (&server.host, server.port);
    format!("{host}:{port}")
}
/// fetch the list of streams
pub async fn list_streams(&self) -> Result<Vec<String>> {
let names = self.js_ctx.stream_names();
......@@ -108,148 +118,57 @@ impl Client {
Ok(subscription)
}
// /// create a new stream
// async fn get_or_create_work_queue_stream(
// &self,
// name: &super::Namespace,
// ) -> Result<jetstream::stream::Stream> {
// let stream = self
// .js_ctx
// .get_or_create_stream(async_nats::jetstream::stream::Config {
// name: name.to_string(),
// retention: async_nats::jetstream::stream::RetentionPolicy::WorkQueue,
// subjects: vec![format!("{name}.>")],
// ..Default::default()
// })
// .await?;
// Ok(stream)
// }
// // get work queue
// pub async fn get_or_create_work_queue(
// &self,
// namespace: &super::Namespace,
// queue_name: &Slug,
// ) -> Result<WorkQueue> {
// let stream = self.get_or_create_work_queue_stream(namespace).await?;
// let consumer_name = single_name(namespace, queue_name);
// let subject_name = subject_name(namespace, queue_name);
// let subject_name = format!("{}.*", subject_name);
// tracing::trace!(
// durable_name = consumer_name,
// filter_subject = subject_name,
// "get_or_create_work_queue"
// );
// let consumer = stream
// .get_or_create_consumer(
// &consumer_name,
// jetstream::consumer::pull::Config {
// durable_name: Some(consumer_name.clone()),
// filter_subject: subject_name,
// ack_policy: jetstream::consumer::AckPolicy::Explicit,
// ..Default::default()
// },
// )
// .await?;
// Ok(WorkQueue::new(consumer))
// }
// pub async fn get_or_create_work_queue_publisher(
// &self,
// namespace: &super::Namespace,
// queue_name: &Slug,
// ) -> Result<WorkQueuePublisher> {
// let _stream = self.get_or_create_work_queue_stream(namespace).await?;
// let _subject = subject_name(namespace, queue_name);
// Ok(WorkQueuePublisher {
// client: self.clone(),
// namespace: namespace.clone(),
// queue_name: queue_name.clone(),
// })
// }
// pub async fn list_work_queues(
// &self,
// namespace: &super::Namespace,
// ) -> Result<Vec<String>> {
// let stream = self.get_stream(namespace.as_ref()).await?;
// let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
// Ok(consumers)
// }
// /// remove a work queue
// pub async fn remove_work_queue(
// &self,
// namespace: &super::Namespace,
// queue_name: &Slug,
// ) -> Result<()> {
// let stream = self.get_stream(namespace.as_ref()).await?;
// let consumer_name = single_name(namespace, queue_name);
// let consumers = self.list_consumers(namespace.as_ref()).await?;
// if consumers.contains(&consumer_name) {
// stream.delete_consumer(&consumer_name).await?;
// }
// Ok(())
// }
// /// publish a message to a subject
// pub async fn publish(&self, subject: String, msg: Vec<u8>) -> Result<()> {
// self.client.publish(subject, msg.into()).await?;
// Ok(())
// }
// /// subscribe to a subject
// pub async fn subscribe(
// &self,
// subject: String,
// ) -> Result<async_nats::Subscriber> {
// let sub = self.client.subscribe(subject).await?;
// Ok(sub)
// }
// pub async fn enqueue(
// &self,
// namespace: &super::Namespace,
// queue_name: &Slug,
// payload: Bytes,
// ) -> Result<String> {
// // let mut headers = HeaderMap::new();
// let subject = subject_name(namespace, queue_name);
// let request_id = uuid::Uuid::new_v4().to_string();
// let subject = format!("{}.{}", subject, request_id);
// self.client.publish(subject, payload).await?;
// // self.client
// // .publish_with_headers(subject, headers, payload.into())
// // .await?;
// Ok(request_id)
// }
// pub async fn enqueue_with_id(
// &self,
// namespace: &super::Namespace,
// queue_name: &Slug,
// request_id: &str,
// payload: Vec<u8>,
// ) -> Result<()> {
// let subject = subject_name(namespace, queue_name);
// let subject = format!("{}.{}", subject, request_id);
// self.client.publish(subject, payload.into()).await?;
// Ok(())
// }
// pub fn frontend_client(&self, request_id: String) -> SpecializedClient {
// SpecializedClient::new(self.client.clone(), ClientKind::Frontend, request_id)
// }
// pub fn backend_client(&self, request_id: String) -> SpecializedClient {
// SpecializedClient::new(self.client.clone(), ClientKind::Backend, request_id)
// }
/// Upload the file at `filepath` to the NATS object store location named by `nats_url`.
///
/// The bucket and key are extracted from the URL with `url_to_bucket_and_key`. If looking
/// up the bucket fails with an error whose text contains "stream not found", the bucket is
/// created with default settings before uploading.
///
/// # Errors
///
/// Fails if the file cannot be opened, the URL has no bucket / key, the bucket cannot be
/// fetched or created, or the upload itself fails.
pub async fn object_store_upload(&self, filepath: &Path, nats_url: Url) -> anyhow::Result<()> {
    let mut disk_file = TokioFile::open(filepath).await?;

    let (bucket_name, key) = url_to_bucket_and_key(&nats_url)?;
    let context = self.jetstream();

    let lookup = context.get_object_store(&bucket_name).await;
    let bucket = match lookup {
        Ok(existing) => existing,
        Err(err) => {
            // The 404 is buried several layers down: err.source() is a GetStreamError,
            // whose kind() is GetStreamErrorKind::JetStream, which wraps a
            // jetstream::Error carrying the code. Matching on the message text for now.
            if !err.to_string().contains("stream not found") {
                anyhow::bail!("NATS get_object_store error: {err}");
            }
            tracing::debug!("Creating NATS bucket {bucket_name}");
            let bucket_config = jetstream::object_store::Config {
                bucket: bucket_name.to_string(),
                ..Default::default()
            };
            context
                .create_object_store(bucket_config)
                .await
                .map_err(|e| anyhow::anyhow!("Failed creating bucket / object store: {e}"))?
        }
    };

    let key_meta = jetstream::object_store::ObjectMetadata {
        name: key.to_string(),
        ..Default::default()
    };
    if let Err(e) = bucket.put(key_meta, &mut disk_file).await {
        anyhow::bail!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}");
    }
    Ok(())
}
/// Delete a bucket and all it's contents from the NATS object store
pub async fn object_store_delete_bucket(&self, bucket_name: &str) -> anyhow::Result<()> {
let context = self.jetstream();
match context.delete_object_store(&bucket_name).await {
Ok(_) => Ok(()),
Err(err) if err.to_string().contains("stream not found") => {
tracing::trace!(bucket_name, "NATS bucket already gone");
Ok(())
}
Err(err) => Err(anyhow::anyhow!("NATS get_object_store error: {err}")),
}
}
}
/// NATS client options
......@@ -366,6 +285,27 @@ impl Default for NatsAuth {
}
}
/// Does this file name / URL point into the NATS object store?
///
/// This is a purely textual check for the `URL_PREFIX` scheme prefix; the store itself is
/// never contacted.
pub fn is_nats_url(s: &str) -> bool {
    s.strip_prefix(URL_PREFIX).is_some()
}
/// Extract NATS bucket and key from a nats URL of the form:
/// nats://host[:port]/bucket/key
pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
let Some(mut path_segments) = url.path_segments() else {
anyhow::bail!("No path in NATS URL: {url}");
};
let Some(bucket) = path_segments.next() else {
anyhow::bail!("No bucket in NATS URL: {url}");
};
let Some(key) = path_segments.next() else {
anyhow::bail!("No key in NATS URL: {url}");
};
Ok((bucket.to_string(), key.to_string()))
}
#[cfg(test)]
mod tests {
......@@ -416,416 +356,4 @@ mod tests {
Ok(())
});
}
// const TEST_STREAM: &str = "test_async_nats_stream";
// #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
// struct Request {
// id: String,
// }
// async fn nats_client() -> Result<Client> {
// Client::builder()
// .server("nats://localhost:4222")
// .username("user")
// .password("user")
// .build()
// .await
// }
// #[tokio::test]
// async fn test_list_streams() {
// let client = match nats_client().await.ok() {
// Some(client) => client,
// None => {
// println!("Failed to create client; skipping nats tests");
// return;
// }
// };
// let streams = client.list_streams().await.expect("failed to list streams");
// for stream in streams {
// let info = client
// .stream_info(&stream)
// .await
// .expect("failed to get stream info");
// assert_eq!(info.messages, 0, "stream {} not empty", stream);
// }
// }
// #[tokio::test]
// async fn test_workq_pull_and_response_stream() {
// let ns: Namespace = TEST_STREAM.try_into().unwrap();
// let _client = match nats_client().await.ok() {
// Some(client) => client,
// None => {
// println!("Failed to create client; skipping nats tests");
// return;
// }
// };
// let client = Client::builder()
// .server("nats://localhost:4222")
// .username("user")
// .password("user")
// .build()
// .await
// .expect("failed to create client");
// let _streams = client.list_streams().await.expect("failed to list streams");
// // assert!(!streams.contains(&TEST_STREAM.to_string()));
// let _stream = client
// .get_or_create_work_queue_stream(&ns)
// .await
// .expect("failed to create stream");
// let model_name: Slug = "foo".try_into().unwrap();
// let request_id = "bar";
// let request = Request {
// id: request_id.to_string(),
// };
// let request_payload = serde_json::to_vec(&request).expect("failed to serialize request");
// // let request = CompletionRequest {
// // prompt: CompletionContext::from_prompt("deep learning is".to_string()).into(),
// // stop_conditions: None,
// // sampling_options: None,
// // };
// // remove work queue if it exists
// client
// .remove_work_queue(&ns, &model_name)
// .await
// .expect("remove work queue does not fail if queue does not exist");
// // get the count of the work queues
// let initial_work_queue_count = client
// .list_work_queues(&ns)
// .await
// .expect("failed to list work queues")
// .len();
// // create work queue
// let workq = client
// .get_or_create_work_queue(&ns, &model_name)
// .await
// .expect("failed to get work queue");
// // new work queue count
// let work_queue_count = client
// .list_work_queues(&ns)
// .await
// .expect("failed to list work queues")
// .len();
// assert_eq!(initial_work_queue_count, work_queue_count - 1);
// client
// .enqueue(&ns, &model_name, request_payload.into())
// .await
// .expect("failed to enqueue completion request");
// let mut messages = workq
// .pull(1, std::time::Duration::from_secs(1))
// .await
// .expect("failed to pull messages from work queue");
// assert_eq!(1, messages.len());
// let msg = messages.pop().expect("no message received");
// msg.ack().await.expect("failed to ack");
// let request: Request =
// serde_json::from_slice(&msg.payload).expect("failed to deserialize message");
// assert_eq!(request.id, request_id);
// // clean up and delete nats work queue and stream
// client
// .remove_work_queue(&ns, &model_name)
// .await
// .expect("failed to remove work queue");
// // client
// // .delete_stream(TEST_STREAM)
// // .await
// // .expect("failed to delete stream");
// }
}
// let frontend_client = client.frontend_client("test".to_string());
// // the represents the frontend response subscription
// let mut frontend_sub = frontend_client
// .subscribe()
// .await
// .expect("failed to subscribe");
// let backend_client = client.backend_client("test".to_string());
// let mut backend_sub = backend_client
// .subscribe()
// .await
// .expect("failed to subscribe");
// let msg = messages[0].clone();
// let req = serde_json::from_slice::<CompletionRequest>(&msg.payload)
// .expect("failed to deserialize message");
// msg.ack().await.expect("failed to ack");
// assert_eq!(req.prompt, request.prompt);
// // ping pong message between backend and frontend
// // backend publishes to frontend
// backend_client
// .publish(&MessageKind::Initialize(Prologue {
// formatted_prompt: None,
// input_token_ids: None,
// }))
// .await
// .expect("failed to publish");
// // frontend receives initialize message
// let msg = frontend_sub.next().await.expect("msg not received");
// let msg = serde_json::from_slice::<MessageKind>(&msg.payload)
// .expect("failed to deserialize message");
// match msg {
// MessageKind::Initialize(_) => {}
// _ => panic!("unexpected message"),
// }
// // frontend publishes to backend
// frontend_client
// .publish(&MessageKind::Finalize(Epilogue {}))
// .await
// .expect("failed to publish");
// // backend receives finalize message
// let msg = backend_sub.next().await.expect("msg not received");
// let msg = serde_json::from_slice::<MessageKind>(&msg.payload)
// .expect("failed to deserialize message");
// match &msg {
// MessageKind::Finalize(_) => {}
// _ => panic!("unexpected message"),
// }
// // delete the work queue
// client
// .remove_work_queue(model_name, TEST_STREAM)
// .await
// .expect("failed to remove work queue");
// // new work queue count
// let work_queue_count = client
// .list_work_queues(TEST_STREAM)
// .await
// .expect("failed to list work queues")
// .len();
// // compare against the initial work queue count
// assert_eq!(initial_work_queue_count, work_queue_count);
// }
// async fn connect(config: Arc<Config>) -> Result<NatsClient> {
// let client = ClientOptions::builder()
// .server(config.nats_address.clone())
// .build()
// .await
// .context("Creating NATS Client")?;
// Ok(client)
// }
// async fn create_service(
// nats: NatsClient,
// config: Arc<Config>,
// observer: ServiceObserver,
// ) -> Result<NatsService> {
// let service = nats
// .client()
// .service_builder()
// .description(config.service_description.as_str())
// .stats_handler(move |_name, _stats| {
// let stats = InstanceStats {
// stage: observer.stage(),
// };
// serde_json::to_value(&stats).unwrap()
// })
// .start(
// config.service_name.as_str(),
// config.service_version.as_str(),
// )
// .await
// .map_err(|e| anyhow::anyhow!("Failed to start service: {e}"))?;
// Ok(service)
// }
// async fn create_endpoint(
// endpoint_name: impl Into<String>,
// service: &NatsService,
// ) -> Result<Endpoint> {
// let info = service.info().await;
// let group_name = format!("{}-{}", info.name, info.id);
// let group = service.group(group_name);
// let endpoint = group
// .endpoint(endpoint_name.into())
// .await
// .map_err(|e| anyhow::anyhow!("Failed to start endpoint: {e}"))?;
// Ok(endpoint)
// }
// async fn shutdown_endpoint_handler(
// controller: ServiceController,
// endpoint: Endpoint,
// ) -> Result<()> {
// let mut endpoint = endpoint;
// // note: this is a child cancellation token, canceling it will not cancel the parent
// // but the parent will cancel the child -- we only use this to observe if another
// // controller has cancelled the service
// let cancellation_token = controller.cancel_token();
// loop {
// let req = tokio::select! {
// _ = cancellation_token.cancelled() => {
// // tracing::trace!(worker_id, "Shutting down service {}", self.endpoint.name);
// return Ok(());
// }
// // await on service request
// req = endpoint.next() => {
// req
// }
// };
// if let Some(req) = req {
// let response = "DONE".to_string();
// if let Err(e) = req.respond(Ok(response.into())).await {
// tracing::warn!("Failed to respond to the shutdown request: {:?}", e);
// }
// controller.set_stage(ServiceStage::ShuttingDown);
// }
// }
// }
// #[derive(Debug, Clone, Builder)]
// pub struct Config {
// /// The NATS server address
// #[builder(default = "String::from(\"nats://localhost:4222\")")]
// pub nats_address: String,
// #[builder(setter(into), default = "String::from(SERVICE_NAME)")]
// pub service_name: String,
// #[builder(setter(into), default = "String::from(SERVICE_VERSION)")]
// pub service_version: String,
// #[builder(setter(into), default = "String::from(SERVICE_DESCRIPTION)")]
// pub service_description: String,
// }
// impl Config {
// pub fn new() -> Result<Config> {
// Ok(ConfigBuilder::default().build()?)
// }
// /// Create a new [`ConfigBuilder`]
// pub fn builder() -> ConfigBuilder {
// ConfigBuilder::default()
// }
// }
// // todo: move to icp - transports
// #[derive(Clone, Debug)]
// pub struct NatsClient {
// client: Client,
// js_ctx: jetstream::Context,
// }
// impl NatsClient {
// pub fn client(&self) -> &Client {
// &self.client
// }
// pub fn jetstream(&self) -> &jetstream::Context {
// &self.js_ctx
// }
// pub fn service_builder(&self) -> NatsServiceBuilder {
// self.client.service_builder()
// }
// }
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct ServiceInfo {
// pub name: String,
// pub id: String,
// pub version: String,
// pub started: String,
// pub endpoints: Vec<EndpointInfo>,
// }
// #[derive(Debug, Clone, Serialize, Deserialize)]
// pub struct EndpointInfo {
// pub name: String,
// pub subject: String,
// pub data: serde_json::Value,
// }
// impl EndpointInfo {
// pub fn get<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
// serde_json::from_value(self.data.clone()).map_err(Into::into)
// }
// }
// #[derive(Clone, Debug, Builder)]
// #[builder(build_fn(private, name = "build_internal"))]
// pub struct ClientOptions {
// #[builder(setter(into))]
// server: String,
// #[builder(setter(into, strip_option), default)]
// username: Option<String>,
// #[builder(setter(into, strip_option), default)]
// password: Option<String>,
// }
// #[allow(dead_code)]
// impl ClientOptions {
// pub fn builder() -> ClientOptionsBuilder {
// ClientOptionsBuilder::default()
// }
// }
// impl ClientOptionsBuilder {
// pub async fn build(&self) -> Result<NatsClient> {
// let opts = self.build_internal()?;
// // Create an unauthenticated connection to NATS.
// let client = async_nats::ConnectOptions::new();
// let client = if let (Some(username), Some(password)) = (opts.username, opts.password) {
// client.user_and_password(username, password)
// } else {
// client
// };
// let client = client.connect(&opts.server).await?;
// let js_ctx = jetstream::new(client.clone());
// Ok(NatsClient { client, js_ctx })
// }
// }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment