"git@developer.sourcefind.cn:OpenDAS/openfold.git" did not exist on "f33faf84a7d24343079626935b7219ae96e2c3e5"
Commit 9f53922a authored by Graham King, committed by GitHub
Browse files

fix: dynemo-run model discovery working again (#52)

There are two etcd keys:
- The service
- The model

The second one is the interesting one for us. Previously we confused the two.
parent aacc5d76
...@@ -41,7 +41,7 @@ pub async fn run( ...@@ -41,7 +41,7 @@ pub async fn run(
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let cancel_token = runtime.primary_token().clone(); let cancel_token = runtime.primary_token().clone();
let endpoint: Endpoint = path.parse()?; let endpoint_id: Endpoint = path.parse()?;
let etcd_client = distributed.etcd_client(); let etcd_client = distributed.etcd_client();
...@@ -83,28 +83,29 @@ pub async fn run( ...@@ -83,28 +83,29 @@ pub async fn run(
let model_registration = ModelEntry { let model_registration = ModelEntry {
name: service_name.to_string(), name: service_name.to_string(),
endpoint: endpoint.clone(), endpoint: endpoint_id.clone(),
model_type: ModelType::Chat, model_type: ModelType::Chat,
}; };
let component = distributed
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?;
let endpoint = component
.service_builder()
.create()
.await?
.endpoint(endpoint_id.name);
let network_name = endpoint.subject();
tracing::debug!("Registering with etcd as {network_name}");
etcd_client etcd_client
.kv_create( .kv_create(
path.clone(), network_name.clone(),
serde_json::to_vec_pretty(&model_registration)?, serde_json::to_vec_pretty(&model_registration)?,
None, None,
) )
.await?; .await?;
let rt_fut = distributed let rt_fut = endpoint.endpoint_builder().handler(ingress).start();
.namespace(endpoint.namespace)?
.component(endpoint.component)?
.service_builder()
.create()
.await?
.endpoint(endpoint.name)
.endpoint_builder()
.handler(ingress)
.start();
tokio::select! { tokio::select! {
_ = rt_fut => { _ = rt_fut => {
tracing::debug!("Endpoint ingress ended"); tracing::debug!("Endpoint ingress ended");
......
...@@ -46,18 +46,25 @@ pub async fn run( ...@@ -46,18 +46,25 @@ pub async fn run(
.enable_cmpl_endpoints(true) .enable_cmpl_endpoints(true)
.build()?; .build()?;
match engine_config { match engine_config {
EngineConfig::Dynamic(client) => { EngineConfig::Dynamic(endpoint) => {
let service_name = client.path(); // This will attempt to connect to NATS and etcd
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?; let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let component = distributed_runtime
.namespace(endpoint.namespace)?
.component(endpoint.component)?;
let network_prefix = component.service_name();
// Listen for models registering themselves in etcd, add them to HTTP service // Listen for models registering themselves in etcd, add them to HTTP service
let state = Arc::new(discovery::ModelWatchState { let state = Arc::new(discovery::ModelWatchState {
prefix: service_name.clone(), prefix: network_prefix.clone(),
model_type: ModelType::Chat, // Tio currently supports only chat models model_type: ModelType::Chat,
manager: http_service.model_manager().clone(), manager: http_service.model_manager().clone(),
drt: distributed_runtime.clone(), drt: distributed_runtime.clone(),
}); });
tracing::info!("Waiting for remote model at {network_prefix}");
let etcd_client = distributed_runtime.etcd_client(); let etcd_client = distributed_runtime.etcd_client();
let models_watcher = etcd_client.kv_get_and_watch_prefix(service_name).await?; let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let _watcher_task = tokio::spawn(discovery::model_watcher(state, receiver)); let _watcher_task = tokio::spawn(discovery::model_watcher(state, receiver));
} }
......
...@@ -27,6 +27,7 @@ use dynemo_llm::{ ...@@ -27,6 +27,7 @@ use dynemo_llm::{
use dynemo_runtime::{ use dynemo_runtime::{
pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source}, pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
runtime::CancellationToken, runtime::CancellationToken,
DistributedRuntime, Runtime,
}; };
use futures::StreamExt; use futures::StreamExt;
use std::{ use std::{
...@@ -43,6 +44,7 @@ const MAX_TOKENS: u32 = 8192; ...@@ -43,6 +44,7 @@ const MAX_TOKENS: u32 = 8192;
const IS_A_TTY: i32 = 1; const IS_A_TTY: i32 = 1;
pub async fn run( pub async fn run(
runtime: Runtime,
cancel_token: CancellationToken, cancel_token: CancellationToken,
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
...@@ -51,11 +53,22 @@ pub async fn run( ...@@ -51,11 +53,22 @@ pub async fn run(
OpenAIChatCompletionsStreamingEngine, OpenAIChatCompletionsStreamingEngine,
bool, bool,
) = match engine_config { ) = match engine_config {
EngineConfig::Dynamic(client) => { EngineConfig::Dynamic(endpoint_id) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint = distributed_runtime
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?
.endpoint(endpoint_id.name);
let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
// The service_name isn't used for text chat outside of logs, // The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration. // so use the path. That avoids having to listen on etcd for model registration.
let service_name = client.path(); let service_name = endpoint.subject();
tracing::info!("Model: {service_name}");
(service_name, Arc::new(client), false) (service_name, Arc::new(client), false)
} }
EngineConfig::StaticFull { EngineConfig::StaticFull {
......
...@@ -17,17 +17,10 @@ ...@@ -17,17 +17,10 @@
use std::{future::Future, pin::Pin}; use std::{future::Future, pin::Pin};
use dynemo_llm::{ use dynemo_llm::{
backend::ExecutionContext, backend::ExecutionContext, model_card::model::ModelDeploymentCard,
model_card::model::ModelDeploymentCard, types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine,
types::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
OpenAIChatCompletionsStreamingEngine,
},
Annotated,
},
}; };
use dynemo_runtime::{component::Client, protocols::Endpoint, DistributedRuntime}; use dynemo_runtime::protocols::Endpoint;
mod flags; mod flags;
pub use flags::Flags; pub use flags::Flags;
...@@ -53,8 +46,7 @@ const PYTHON_TOK_SCHEME: &str = "pytok:"; ...@@ -53,8 +46,7 @@ const PYTHON_TOK_SCHEME: &str = "pytok:";
pub enum EngineConfig { pub enum EngineConfig {
/// A remote networked engine we don't know about yet /// A remote networked engine we don't know about yet
/// We don't have the pre-processor yet so this is only text requests. Type will change later. Dynamic(Endpoint),
Dynamic(Client<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>),
/// A Full service engine does its own tokenization and prompt formatting. /// A Full service engine does its own tokenization and prompt formatting.
StaticFull { StaticFull {
...@@ -148,28 +140,7 @@ pub async fn run( ...@@ -148,28 +140,7 @@ pub async fn run(
} }
Output::Endpoint(path) => { Output::Endpoint(path) => {
let endpoint: Endpoint = path.parse()?; let endpoint: Endpoint = path.parse()?;
EngineConfig::Dynamic(endpoint)
// This will attempt to connect to NATS and etcd
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let client = distributed_runtime
.namespace(endpoint.namespace)?
.component(endpoint.component)?
.endpoint(endpoint.name)
.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
.await?;
tracing::info!("Waiting for remote {}...", client.path());
tokio::select! {
_ = cancel_token.cancelled() => {
return Ok(());
}
r = client.wait_for_endpoints() => {
r?;
}
}
EngineConfig::Dynamic(client)
} }
#[cfg(feature = "mistralrs")] #[cfg(feature = "mistralrs")]
Output::MistralRs => { Output::MistralRs => {
...@@ -379,7 +350,7 @@ pub async fn run( ...@@ -379,7 +350,7 @@ pub async fn run(
crate::input::http::run(runtime.clone(), flags.http_port, engine_config).await?; crate::input::http::run(runtime.clone(), flags.http_port, engine_config).await?;
} }
Input::Text => { Input::Text => {
crate::input::text::run(cancel_token.clone(), engine_config).await?; crate::input::text::run(runtime.clone(), cancel_token.clone(), engine_config).await?;
} }
Input::Endpoint(path) => { Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?; crate::input::endpoint::run(runtime.clone(), path, engine_config).await?;
......
...@@ -55,11 +55,9 @@ pub struct ModelWatchState { ...@@ -55,11 +55,9 @@ pub struct ModelWatchState {
pub drt: DistributedRuntime, pub drt: DistributedRuntime,
} }
pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<WatchEvent>) { pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started"); tracing::debug!("model watcher started");
let mut events_rx = events_rx;
while let Some(event) = events_rx.recv().await { while let Some(event) = events_rx.recv().await {
match event { match event {
WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await { WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await {
...@@ -80,15 +78,11 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc ...@@ -80,15 +78,11 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc
}, },
} }
} }
tracing::debug!("model watcher stopped");
} }
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> { async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
tracing::debug!("removing model");
let key = kv.key_str()?; let key = kv.key_str()?;
tracing::debug!("key: {}", key); tracing::debug!(key, "removing model");
let model_name = key.trim_start_matches(&state.prefix); let model_name = key.trim_start_matches(&state.prefix);
...@@ -104,22 +98,14 @@ async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&s ...@@ -104,22 +98,14 @@ async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&s
// models. // models.
// //
// If this method errors, for the near term, we will delete the offending key. // If this method errors, for the near term, we will delete the offending key.
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> { async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(String, ModelType)> {
tracing::debug!("adding model");
let key = kv.key_str()?; let key = kv.key_str()?;
tracing::debug!("key: {}", key); tracing::debug!(key, "adding model");
let model_name = key.trim_start_matches(&state.prefix); // model_entry.name is the service name (e.g. "Llama-3.2-3B-Instruct")
let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?; let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;
let service_name = model_entry.name.clone();
if model_entry.name != model_name {
raise!(
"model name mismatch: {} != {}",
model_entry.name,
model_name
);
}
if model_entry.model_type != state.model_type { if model_entry.model_type != state.model_type {
raise!( raise!(
"model type mismatch: {} != {}", "model type mismatch: {} != {}",
...@@ -139,7 +125,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ...@@ -139,7 +125,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str,
.await?; .await?;
state state
.manager .manager
.add_chat_completions_model(model_name, Arc::new(client))?; .add_chat_completions_model(&service_name, Arc::new(client))?;
} }
ModelType::Completion => { ModelType::Completion => {
let client = state let client = state
...@@ -151,9 +137,9 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ...@@ -151,9 +137,9 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str,
.await?; .await?;
state state
.manager .manager
.add_completions_model(model_name, Arc::new(client))?; .add_completions_model(&service_name, Arc::new(client))?;
} }
} }
Ok((model_name, state.model_type)) Ok((service_name, state.model_type))
} }
...@@ -43,6 +43,8 @@ pub struct Component { ...@@ -43,6 +43,8 @@ pub struct Component {
/// - **name** /// - **name**
/// ///
/// Example format: `"namespace/component/endpoint"` /// Example format: `"namespace/component/endpoint"`
///
/// TODO: There is also an Endpoint in runtime/src/component.rs
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Endpoint { pub struct Endpoint {
pub namespace: String, pub namespace: String,
......
...@@ -20,7 +20,6 @@ use derive_builder::Builder; ...@@ -20,7 +20,6 @@ use derive_builder::Builder;
use derive_getters::Dissolve; use derive_getters::Dissolve;
use futures::StreamExt; use futures::StreamExt;
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tracing as log;
use validator::Validate; use validator::Validate;
use etcd_client::{ use etcd_client::{
...@@ -167,9 +166,13 @@ impl Client { ...@@ -167,9 +166,13 @@ impl Client {
// Execute the transaction // Execute the transaction
let result = self.client.kv_client().txn(txn).await?; let result = self.client.kv_client().txn(txn).await?;
match result.succeeded() { if result.succeeded() {
true => Ok(()), Ok(())
false => Err(error!("failed to create key")), } else {
for resp in result.op_responses() {
tracing::warn!("kv_create etcd op response: {resp:?}");
}
Err(error!("failed to create key"))
} }
} }
...@@ -247,7 +250,10 @@ impl Client { ...@@ -247,7 +250,10 @@ impl Client {
Ok(get_response.take_kvs()) Ok(get_response.take_kvs())
} }
pub async fn kv_get_and_watch_prefix(&self, prefix: impl AsRef<str>) -> Result<PrefixWatcher> { pub async fn kv_get_and_watch_prefix(
&self,
prefix: impl AsRef<str> + std::fmt::Display,
) -> Result<PrefixWatcher> {
let mut kv_client = self.client.kv_client(); let mut kv_client = self.client.kv_client();
let mut watch_client = self.client.watch_client(); let mut watch_client = self.client.watch_client();
...@@ -260,7 +266,7 @@ impl Client { ...@@ -260,7 +266,7 @@ impl Client {
.ok_or(error!("missing header; unable to get revision"))? .ok_or(error!("missing header; unable to get revision"))?
.revision(); .revision();
log::trace!("start_revision: {}", start_revision); tracing::trace!("{prefix}: start_revision: {start_revision}");
let start_revision = start_revision + 1; let start_revision = start_revision + 1;
let (watcher, mut watch_stream) = watch_client let (watcher, mut watch_stream) = watch_client
...@@ -276,7 +282,7 @@ impl Client { ...@@ -276,7 +282,7 @@ impl Client {
.await?; .await?;
let kvs = get_response.take_kvs(); let kvs = get_response.take_kvs();
log::trace!("initial kv count: {:?}", kvs.len()); tracing::trace!("initial kv count: {:?}", kvs.len());
let (tx, rx) = mpsc::channel(32); let (tx, rx) = mpsc::channel(32);
...@@ -293,7 +299,10 @@ impl Client { ...@@ -293,7 +299,10 @@ impl Client {
match event.event_type() { match event.event_type() {
etcd_client::EventType::Put => { etcd_client::EventType::Put => {
if let Some(kv) = event.kv() { if let Some(kv) = event.kv() {
if tx.send(WatchEvent::Put(kv.clone())).await.is_err() { if let Err(err) = tx.send(WatchEvent::Put(kv.clone())).await {
tracing::error!(
"kv watcher error forwarding WatchEvent::Put: {err}"
);
// receiver is closed // receiver is closed
break; break;
} }
...@@ -311,7 +320,6 @@ impl Client { ...@@ -311,7 +320,6 @@ impl Client {
} }
} }
}); });
Ok(PrefixWatcher { Ok(PrefixWatcher {
prefix: prefix.as_ref().to_string(), prefix: prefix.as_ref().to_string(),
watcher, watcher,
...@@ -327,6 +335,7 @@ pub struct PrefixWatcher { ...@@ -327,6 +335,7 @@ pub struct PrefixWatcher {
rx: mpsc::Receiver<WatchEvent>, rx: mpsc::Receiver<WatchEvent>,
} }
#[derive(Debug)]
pub enum WatchEvent { pub enum WatchEvent {
Put(KeyValue), Put(KeyValue),
Delete(KeyValue), Delete(KeyValue),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment