Commit 9f53922a authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix: dynemo-run model discovery working again (#52)

There are two etcd keys:
- The service
- The model

The second one is the interesting one for us. Previously we confused the two.
parent aacc5d76
......@@ -41,7 +41,7 @@ pub async fn run(
let distributed = DistributedRuntime::from_settings(runtime.clone()).await?;
let cancel_token = runtime.primary_token().clone();
let endpoint: Endpoint = path.parse()?;
let endpoint_id: Endpoint = path.parse()?;
let etcd_client = distributed.etcd_client();
......@@ -83,28 +83,29 @@ pub async fn run(
let model_registration = ModelEntry {
name: service_name.to_string(),
endpoint: endpoint.clone(),
endpoint: endpoint_id.clone(),
model_type: ModelType::Chat,
};
let component = distributed
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?;
let endpoint = component
.service_builder()
.create()
.await?
.endpoint(endpoint_id.name);
let network_name = endpoint.subject();
tracing::debug!("Registering with etcd as {network_name}");
etcd_client
.kv_create(
path.clone(),
network_name.clone(),
serde_json::to_vec_pretty(&model_registration)?,
None,
)
.await?;
let rt_fut = distributed
.namespace(endpoint.namespace)?
.component(endpoint.component)?
.service_builder()
.create()
.await?
.endpoint(endpoint.name)
.endpoint_builder()
.handler(ingress)
.start();
let rt_fut = endpoint.endpoint_builder().handler(ingress).start();
tokio::select! {
_ = rt_fut => {
tracing::debug!("Endpoint ingress ended");
......
......@@ -46,18 +46,25 @@ pub async fn run(
.enable_cmpl_endpoints(true)
.build()?;
match engine_config {
EngineConfig::Dynamic(client) => {
let service_name = client.path();
EngineConfig::Dynamic(endpoint) => {
// This will attempt to connect to NATS and etcd
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let component = distributed_runtime
.namespace(endpoint.namespace)?
.component(endpoint.component)?;
let network_prefix = component.service_name();
// Listen for models registering themselves in etcd, add them to HTTP service
let state = Arc::new(discovery::ModelWatchState {
prefix: service_name.clone(),
model_type: ModelType::Chat, // Tio currently supports only chat models
prefix: network_prefix.clone(),
model_type: ModelType::Chat,
manager: http_service.model_manager().clone(),
drt: distributed_runtime.clone(),
});
tracing::info!("Waiting for remote model at {network_prefix}");
let etcd_client = distributed_runtime.etcd_client();
let models_watcher = etcd_client.kv_get_and_watch_prefix(service_name).await?;
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let _watcher_task = tokio::spawn(discovery::model_watcher(state, receiver));
}
......
......@@ -27,6 +27,7 @@ use dynemo_llm::{
use dynemo_runtime::{
pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
runtime::CancellationToken,
DistributedRuntime, Runtime,
};
use futures::StreamExt;
use std::{
......@@ -43,6 +44,7 @@ const MAX_TOKENS: u32 = 8192;
const IS_A_TTY: i32 = 1;
pub async fn run(
runtime: Runtime,
cancel_token: CancellationToken,
engine_config: EngineConfig,
) -> anyhow::Result<()> {
......@@ -51,11 +53,22 @@ pub async fn run(
OpenAIChatCompletionsStreamingEngine,
bool,
) = match engine_config {
EngineConfig::Dynamic(client) => {
EngineConfig::Dynamic(endpoint_id) => {
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let endpoint = distributed_runtime
.namespace(endpoint_id.namespace)?
.component(endpoint_id.component)?
.endpoint(endpoint_id.name);
let client = endpoint.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>().await?;
tracing::info!("Waiting for remote model..");
client.wait_for_endpoints().await?;
tracing::info!("Model discovered");
// The service_name isn't used for text chat outside of logs,
// so use the path. That avoids having to listen on etcd for model registration.
let service_name = client.path();
tracing::info!("Model: {service_name}");
let service_name = endpoint.subject();
(service_name, Arc::new(client), false)
}
EngineConfig::StaticFull {
......
......@@ -17,17 +17,10 @@
use std::{future::Future, pin::Pin};
use dynemo_llm::{
backend::ExecutionContext,
model_card::model::ModelDeploymentCard,
types::{
openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
OpenAIChatCompletionsStreamingEngine,
},
Annotated,
},
backend::ExecutionContext, model_card::model::ModelDeploymentCard,
types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine,
};
use dynemo_runtime::{component::Client, protocols::Endpoint, DistributedRuntime};
use dynemo_runtime::protocols::Endpoint;
mod flags;
pub use flags::Flags;
......@@ -53,8 +46,7 @@ const PYTHON_TOK_SCHEME: &str = "pytok:";
pub enum EngineConfig {
/// A remote networked engine we don't know about yet
/// We don't have the pre-processor yet so this is only text requests. Type will change later.
Dynamic(Client<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>),
Dynamic(Endpoint),
/// A Full service engine does its own tokenization and prompt formatting.
StaticFull {
......@@ -148,28 +140,7 @@ pub async fn run(
}
Output::Endpoint(path) => {
let endpoint: Endpoint = path.parse()?;
// This will attempt to connect to NATS and etcd
let distributed_runtime = DistributedRuntime::from_settings(runtime.clone()).await?;
let client = distributed_runtime
.namespace(endpoint.namespace)?
.component(endpoint.component)?
.endpoint(endpoint.name)
.client::<NvCreateChatCompletionRequest, Annotated<NvCreateChatCompletionStreamResponse>>()
.await?;
tracing::info!("Waiting for remote {}...", client.path());
tokio::select! {
_ = cancel_token.cancelled() => {
return Ok(());
}
r = client.wait_for_endpoints() => {
r?;
}
}
EngineConfig::Dynamic(client)
EngineConfig::Dynamic(endpoint)
}
#[cfg(feature = "mistralrs")]
Output::MistralRs => {
......@@ -379,7 +350,7 @@ pub async fn run(
crate::input::http::run(runtime.clone(), flags.http_port, engine_config).await?;
}
Input::Text => {
crate::input::text::run(cancel_token.clone(), engine_config).await?;
crate::input::text::run(runtime.clone(), cancel_token.clone(), engine_config).await?;
}
Input::Endpoint(path) => {
crate::input::endpoint::run(runtime.clone(), path, engine_config).await?;
......
......@@ -55,11 +55,9 @@ pub struct ModelWatchState {
pub drt: DistributedRuntime,
}
pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<WatchEvent>) {
pub async fn model_watcher(state: Arc<ModelWatchState>, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started");
let mut events_rx = events_rx;
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => match handle_put(&kv, state.clone()).await {
......@@ -80,15 +78,11 @@ pub async fn model_watcher(state: Arc<ModelWatchState>, events_rx: Receiver<Watc
},
}
}
tracing::debug!("model watcher stopped");
}
async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
tracing::debug!("removing model");
let key = kv.key_str()?;
tracing::debug!("key: {}", key);
tracing::debug!(key, "removing model");
let model_name = key.trim_start_matches(&state.prefix);
......@@ -104,22 +98,14 @@ async fn handle_delete(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&s
// models.
//
// If this method errors, for the near term, we will delete the offending key.
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str, ModelType)> {
tracing::debug!("adding model");
async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(String, ModelType)> {
let key = kv.key_str()?;
tracing::debug!("key: {}", key);
tracing::debug!(key, "adding model");
let model_name = key.trim_start_matches(&state.prefix);
// model_entry.name is the service name (e.g. "Llama-3.2-3B-Instruct")
let model_entry = serde_json::from_slice::<ModelEntry>(kv.value())?;
let service_name = model_entry.name.clone();
if model_entry.name != model_name {
raise!(
"model name mismatch: {} != {}",
model_entry.name,
model_name
);
}
if model_entry.model_type != state.model_type {
raise!(
"model type mismatch: {} != {}",
......@@ -139,7 +125,7 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str,
.await?;
state
.manager
.add_chat_completions_model(model_name, Arc::new(client))?;
.add_chat_completions_model(&service_name, Arc::new(client))?;
}
ModelType::Completion => {
let client = state
......@@ -151,9 +137,9 @@ async fn handle_put(kv: &KeyValue, state: Arc<ModelWatchState>) -> Result<(&str,
.await?;
state
.manager
.add_completions_model(model_name, Arc::new(client))?;
.add_completions_model(&service_name, Arc::new(client))?;
}
}
Ok((model_name, state.model_type))
Ok((service_name, state.model_type))
}
......@@ -43,6 +43,8 @@ pub struct Component {
/// - **name**
///
/// Example format: `"namespace/component/endpoint"`
///
/// TODO: There is also an Endpoint in runtime/src/component.rs
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct Endpoint {
pub namespace: String,
......
......@@ -20,7 +20,6 @@ use derive_builder::Builder;
use derive_getters::Dissolve;
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing as log;
use validator::Validate;
use etcd_client::{
......@@ -167,9 +166,13 @@ impl Client {
// Execute the transaction
let result = self.client.kv_client().txn(txn).await?;
match result.succeeded() {
true => Ok(()),
false => Err(error!("failed to create key")),
if result.succeeded() {
Ok(())
} else {
for resp in result.op_responses() {
tracing::warn!("kv_create etcd op response: {resp:?}");
}
Err(error!("failed to create key"))
}
}
......@@ -247,7 +250,10 @@ impl Client {
Ok(get_response.take_kvs())
}
pub async fn kv_get_and_watch_prefix(&self, prefix: impl AsRef<str>) -> Result<PrefixWatcher> {
pub async fn kv_get_and_watch_prefix(
&self,
prefix: impl AsRef<str> + std::fmt::Display,
) -> Result<PrefixWatcher> {
let mut kv_client = self.client.kv_client();
let mut watch_client = self.client.watch_client();
......@@ -260,7 +266,7 @@ impl Client {
.ok_or(error!("missing header; unable to get revision"))?
.revision();
log::trace!("start_revision: {}", start_revision);
tracing::trace!("{prefix}: start_revision: {start_revision}");
let start_revision = start_revision + 1;
let (watcher, mut watch_stream) = watch_client
......@@ -276,7 +282,7 @@ impl Client {
.await?;
let kvs = get_response.take_kvs();
log::trace!("initial kv count: {:?}", kvs.len());
tracing::trace!("initial kv count: {:?}", kvs.len());
let (tx, rx) = mpsc::channel(32);
......@@ -293,7 +299,10 @@ impl Client {
match event.event_type() {
etcd_client::EventType::Put => {
if let Some(kv) = event.kv() {
if tx.send(WatchEvent::Put(kv.clone())).await.is_err() {
if let Err(err) = tx.send(WatchEvent::Put(kv.clone())).await {
tracing::error!(
"kv watcher error forwarding WatchEvent::Put: {err}"
);
// receiver is closed
break;
}
......@@ -311,7 +320,6 @@ impl Client {
}
}
});
Ok(PrefixWatcher {
prefix: prefix.as_ref().to_string(),
watcher,
......@@ -327,6 +335,7 @@ pub struct PrefixWatcher {
rx: mpsc::Receiver<WatchEvent>,
}
#[derive(Debug)]
pub enum WatchEvent {
Put(KeyValue),
Delete(KeyValue),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment