Unverified Commit 3e8e38a9 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix(llmctl): Use ModelWatcher instead of direct etcd operations (#1150)

parent dcad8ac7
...@@ -3522,6 +3522,7 @@ dependencies = [ ...@@ -3522,6 +3522,7 @@ dependencies = [
name = "llmctl" name = "llmctl"
version = "0.2.1" version = "0.2.1"
dependencies = [ dependencies = [
"anyhow",
"clap", "clap",
"dynamo-llm", "dynamo-llm",
"dynamo-runtime", "dynamo-runtime",
......
...@@ -56,7 +56,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -56,7 +56,7 @@ async fn app(runtime: Runtime) -> Result<()> {
// the cli when operating on an `http` component will validate the namespace.component is // the cli when operating on an `http` component will validate the namespace.component is
// registered with HttpServiceComponentDefinition // registered with HttpServiceComponentDefinition
let watch_obj = ModelWatcher::new(distributed.clone(), manager, RouterMode::Random).await?; let watch_obj = ModelWatcher::new(distributed.clone(), manager, RouterMode::Random);
if let Some(etcd_client) = distributed.etcd_client() { if let Some(etcd_client) = distributed.etcd_client() {
let models_watcher: PrefixWatcher = let models_watcher: PrefixWatcher =
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
* [Vllm](#vllm) * [Vllm](#vllm)
* [TensorRT-LLM](#tensorrt-llm-engine) * [TensorRT-LLM](#tensorrt-llm-engine)
* [Echo Engines](#echo-engines) * [Echo Engines](#echo-engines)
* [Write your own engine in Python](#write-your-own-engine-in-python) * [Writing your own engine in Python](#writing-your-own-engine-in-python)
* [Batch mode](#batch-mode) * [Batch mode](#batch-mode)
* [Defaults](#defaults) * [Defaults](#defaults)
* [Extra engine arguments](#extra-engine-arguments) * [Extra engine arguments](#extra-engine-arguments)
......
...@@ -46,14 +46,11 @@ pub async fn prepare_engine( ...@@ -46,14 +46,11 @@ pub async fn prepare_engine(
anyhow::bail!("Cannot be both static mode and run with dynamic discovery."); anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
}; };
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new( let watch_obj = Arc::new(ModelWatcher::new(
ModelWatcher::new( distributed_runtime,
distributed_runtime, model_manager.clone(),
model_manager.clone(), dynamo_runtime::pipeline::RouterMode::RoundRobin,
dynamo_runtime::pipeline::RouterMode::RoundRobin, ));
)
.await?,
);
let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?; let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......
...@@ -103,7 +103,7 @@ async fn run_watcher( ...@@ -103,7 +103,7 @@ async fn run_watcher(
network_prefix: &str, network_prefix: &str,
router_mode: RouterMode, router_mode: RouterMode,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(runtime, model_manager, router_mode).await?; let watch_obj = ModelWatcher::new(runtime, model_manager, router_mode);
tracing::info!("Watching for remote model at {network_prefix}"); tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......
...@@ -5,7 +5,7 @@ use std::{future::Future, pin::Pin}; ...@@ -5,7 +5,7 @@ use std::{future::Future, pin::Pin};
use std::{io::Read, sync::Arc, time::Duration}; use std::{io::Read, sync::Arc, time::Duration};
use anyhow::Context; use anyhow::Context;
use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, LocalModel}; use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, local_model::LocalModel};
use dynamo_runtime::{CancellationToken, DistributedRuntime}; use dynamo_runtime::{CancellationToken, DistributedRuntime};
mod flags; mod flags;
......
...@@ -12,7 +12,7 @@ use regex::Regex; ...@@ -12,7 +12,7 @@ use regex::Regex;
use tokio::io::AsyncBufReadExt; use tokio::io::AsyncBufReadExt;
use dynamo_llm::engines::MultiNodeConfig; use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::LocalModel; use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::protocols::Endpoint as EndpointId; use dynamo_runtime::protocols::Endpoint as EndpointId;
pub mod sglang; pub mod sglang;
......
...@@ -23,6 +23,7 @@ homepage.workspace = true ...@@ -23,6 +23,7 @@ homepage.workspace = true
repository.workspace = true repository.workspace = true
[dependencies] [dependencies]
anyhow = { workspace = true }
dynamo-runtime = { workspace = true } dynamo-runtime = { workspace = true }
dynamo-llm = { workspace = true } dynamo-llm = { workspace = true }
...@@ -32,4 +33,4 @@ tracing = { workspace = true } ...@@ -32,4 +33,4 @@ tracing = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
tabled = { version = "0.18" } tabled = { version = "0.18" }
\ No newline at end of file
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use tracing as log;
use dynamo_llm::discovery::ModelEntry; use dynamo_llm::discovery::{ModelManager, ModelWatcher};
use dynamo_llm::local_model::{LocalModel, ModelNetworkName};
use dynamo_llm::model_type::ModelType; use dynamo_llm::model_type::ModelType;
use dynamo_runtime::component::Endpoint;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::{ use dynamo_runtime::{
distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime, distributed::DistributedConfig, logging, DistributedRuntime, Result, Runtime, Worker,
Result, Runtime, Worker,
}; };
// Macro to define model types and associated commands // Macro to define model types and associated commands
...@@ -93,12 +96,13 @@ define_type_subcommands!( ...@@ -93,12 +96,13 @@ define_type_subcommands!(
#[command( #[command(
author="NVIDIA", author="NVIDIA",
version="0.2.1", version="0.2.1",
about="LLMCTL - Control and manage Dynamo Components", about="LLMCTL - Deprecated. Do not use.",
long_about = None, long_about = None,
disable_help_subcommand = true, disable_help_subcommand = true,
)] )]
struct Cli { struct Cli {
/// Public Namespace to operate in /// Public Namespace to operate in
/// Do not use this. In fact don't use anything about this file.
#[arg(short = 'n', long)] #[arg(short = 'n', long)]
public_namespace: Option<String>, public_namespace: Option<String>,
...@@ -158,8 +162,8 @@ fn main() -> Result<()> { ...@@ -158,8 +162,8 @@ fn main() -> Result<()> {
logging::init(); logging::init();
let cli = Cli::parse(); let cli = Cli::parse();
// Default namespace to "public" if not specified // Default namespace to "dynamo" if not specified
let namespace = cli.public_namespace.unwrap_or_else(|| "public".to_string()); let namespace = cli.public_namespace.unwrap_or_else(|| "dynamo".to_string());
let worker = Worker::from_settings()?; let worker = Worker::from_settings()?;
worker.execute(|runtime| async move { handle_command(runtime, namespace, cli.command).await }) worker.execute(|runtime| async move { handle_command(runtime, namespace, cli.command).await })
...@@ -200,8 +204,8 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands) ...@@ -200,8 +204,8 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands)
} }
} }
HttpCommands::Remove { model_type } => { HttpCommands::Remove { model_type } => {
let (model_type, name) = model_type.into_parts(); let (_, name) = model_type.into_parts();
remove_model(&distributed, namespace.to_string(), model_type, &name).await?; remove_model(&distributed, &name).await?;
} }
} }
} }
...@@ -209,7 +213,6 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands) ...@@ -209,7 +213,6 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands)
Ok(()) Ok(())
} }
// Helper functions to handle the actual operations
async fn add_model( async fn add_model(
distributed: &DistributedRuntime, distributed: &DistributedRuntime,
namespace: String, namespace: String,
...@@ -217,74 +220,15 @@ async fn add_model( ...@@ -217,74 +220,15 @@ async fn add_model(
model_name: String, model_name: String,
endpoint_name: &str, endpoint_name: &str,
) -> Result<()> { ) -> Result<()> {
log::debug!( tracing::debug!("Adding model {model_name} with endpoint {endpoint_name}");
"Adding model {} with endpoint {}",
model_name,
endpoint_name
);
if model_name.starts_with('/') { if model_name.starts_with('/') {
raise!("Model name '{}' cannot start with a slash", model_name); anyhow::bail!("Model name '{model_name}' cannot start with a slash");
}
let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 {
raise!("Endpoint name '{}' is too short. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
} else if parts.len() > 3 {
raise!("Endpoint name '{}' is too long. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
} }
// create model entry let endpoint = endpoint_from_name(distributed, &namespace, endpoint_name)?;
let endpoint = Endpoint {
namespace: if parts.len() == 3 {
parts[0].to_string()
} else {
println!(
"Using the public namespace: {} for model: {}",
namespace, model_name
);
namespace.clone()
},
component: parts[parts.len() - 2].to_string(),
name: parts[parts.len() - 1].to_string(),
};
let model = ModelEntry { let mut model = LocalModel::with_name_only(&model_name);
name: model_name.to_string(), model.attach(&endpoint, model_type).await?;
endpoint,
model_type,
};
// add model to etcd
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_root(),
model_type.as_str(),
model_name
);
let etcd_client = distributed
.etcd_client()
.expect("unreachable: llmctl is only useful with dynamic workers");
// check if model already exists
let kvs = etcd_client.kv_get_prefix(&path).await?;
if !kvs.is_empty() {
println!(
"{} model {} already exists, please remove it before changing the endpoint.",
model_type.as_str(),
model_name,
);
list_single_model(distributed, namespace, model_type, model_name).await?;
} else {
etcd_client
.kv_create(path, serde_json::to_vec_pretty(&model)?, None)
.await?;
println!("Added new {} model {}", model_type.as_str(), model_name,);
list_single_model(distributed, namespace, model_type, model_name).await?;
}
Ok(()) Ok(())
} }
...@@ -303,147 +247,104 @@ struct ModelRow { ...@@ -303,147 +247,104 @@ struct ModelRow {
endpoint: String, endpoint: String,
} }
async fn list_single_model(
distributed: &DistributedRuntime,
namespace: String,
model_type: ModelType,
model_name: String,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_root(),
model_type.as_str(),
model_name
);
let mut models = Vec::new();
let etcd_client = distributed
.etcd_client()
.expect("llmctl is only useful for dynamic workers");
let kvs = etcd_client.kv_get_prefix(&path).await?;
for kv in kvs {
if let (Ok(_key), Ok(model)) = (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
models.push(ModelRow {
model_type: model_type.as_str().to_string(),
name: model_name.clone(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
}
}
if models.is_empty() {
println!("Something went wrong, no model was found.");
} else {
let table = tabled::Table::new(models);
println!("{}", table);
}
Ok(())
}
async fn list_models( async fn list_models(
distributed: &DistributedRuntime, distributed: &DistributedRuntime,
namespace: String, namespace: String,
model_type: Option<ModelType>, model_type: Option<ModelType>,
) -> Result<()> { ) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?; // We only need a ModelWatcher to call it's all_entries. llmctl is going away so no need to
// refactor for this.
let watcher = ModelWatcher::new(
distributed.clone(),
Arc::new(ModelManager::new()),
RouterMode::Random,
);
let mut models = Vec::new(); let mut models = Vec::new();
let model_types = match model_type { for entry in watcher.all_entries().await? {
Some(mt) => vec![mt], match (model_type, entry.model_type) {
None => vec![ModelType::Chat, ModelType::Completion], (None, _) => {
}; // list all
}
// TODO: Do we need the model_type in etcd key? (Some(want), got) if want == got => {
// match
for mt in model_types { }
let prefix = format!("{}/models/{}/", component.etcd_root(), mt.as_str(),); _ => {
// no match
let etcd_client = distributed continue;
.etcd_client()
.expect("llmctl is only useful with dynamic workers");
let kvs = etcd_client.kv_get_prefix(&prefix).await?;
for kv in kvs {
if let (Ok(key), Ok(model)) = (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
models.push(ModelRow {
model_type: mt.as_str().to_string(),
name: key.trim_start_matches(&prefix).to_string(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
} }
} }
models.push(ModelRow {
model_type: entry.model_type.as_str().to_string(),
name: entry.name,
namespace: entry.endpoint.namespace,
component: entry.endpoint.component,
endpoint: entry.endpoint.name,
});
} }
if models.is_empty() { if models.is_empty() {
match &model_type { match &model_type {
Some(mt) => println!( Some(mt) => println!(
"No {} models found in the public namespace: {}", "No {} models found in namespace: {}",
mt.as_str(), mt.as_str(),
namespace namespace
), ),
None => println!("No models found in the public namespace: {}", namespace), None => println!("No models found in namespace: {}", namespace),
} }
} else { } else {
let table = tabled::Table::new(models); let table = tabled::Table::new(models);
match &model_type { match &model_type {
Some(mt) => println!( Some(mt) => println!("Listing {} models in namespace: {}", mt.as_str(), namespace),
"Listing {} models in the public namespace: {}", None => println!("Listing all models in namespace: {}", namespace),
mt.as_str(),
namespace
),
None => println!("Listing all models in the public namespace: {}", namespace),
} }
println!("{}", table); println!("{}", table);
} }
Ok(()) Ok(())
} }
async fn remove_model( async fn remove_model(distributed: &DistributedRuntime, model_name: &str) -> Result<()> {
distributed: &DistributedRuntime, // We have to do this manually because normally the etcd lease system does it for us
namespace: String, let watcher = ModelWatcher::new(
model_type: ModelType, distributed.clone(),
name: &str, Arc::new(ModelManager::new()),
) -> Result<()> { RouterMode::Random,
let component = distributed.namespace(&namespace)?.component("http")?;
let prefix = format!(
"{}/models/{}/{}",
component.etcd_root(),
model_type.as_str(),
name
); );
let Some(etcd_client) = distributed.etcd_client() else {
log::debug!("deleting key: {}", prefix); anyhow::bail!("llmctl is only useful with dynamic workers");
};
// get the kvs from etcd let active_instances = watcher.entries_for_model(model_name).await?;
let mut kv_client = distributed for entry in active_instances {
.etcd_client() let network_name = ModelNetworkName::from_entry(&entry, 0);
.expect("llmctl is only useful with dynamic workers") tracing::debug!("deleting key: {network_name}");
.etcd_client() etcd_client
.kv_client(); .kv_delete(network_name.to_string(), None)
match kv_client.delete(prefix.as_bytes(), None).await { .await?;
Ok(_response) => {
println!(
"{} model {} removed from the public namespace: {}",
model_type.as_str(),
name,
namespace
);
}
Err(e) => {
log::error!("Error removing model {}: {}", name, e);
}
} }
Ok(()) Ok(())
} }
fn endpoint_from_name(
distributed: &DistributedRuntime,
namespace: &str,
endpoint_name: &str,
) -> anyhow::Result<Endpoint> {
let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 {
anyhow::bail!("Endpoint name '{}' is too short. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
} else if parts.len() > 3 {
anyhow::bail!("Endpoint name '{}' is too long. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
}
// TODO previous version sometime hardcoded this to "http", so maybe adjust
let component_name = parts[parts.len() - 2].to_string();
let endpoint_name = parts[parts.len() - 1].to_string();
let component = distributed
.namespace(namespace)?
.component(component_name)?;
Ok(component.endpoint(endpoint_name))
}
...@@ -111,9 +111,10 @@ fn register_llm<'p>( ...@@ -111,9 +111,10 @@ fn register_llm<'p>(
let model_name = model_name.map(|n| n.to_string()); let model_name = model_name.map(|n| n.to_string());
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Download from HF, load the ModelDeploymentCard // Download from HF, load the ModelDeploymentCard
let mut local_model = llm_rs::LocalModel::prepare(&inner_path, None, model_name) let mut local_model =
.await llm_rs::local_model::LocalModel::prepare(&inner_path, None, model_name)
.map_err(to_pyerr)?; .await
.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us // Advertise ourself on etcd so ingress can find us
local_model local_model
......
...@@ -42,7 +42,7 @@ use dynamo_llm::protocols::openai::{ ...@@ -42,7 +42,7 @@ use dynamo_llm::protocols::openai::{
}; };
use dynamo_llm::engines::{EngineDispatcher, StreamingEngine}; use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
use dynamo_llm::LocalModel; use dynamo_llm::local_model::LocalModel;
/// How many requests mistral will run at once in the paged attention scheduler. /// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this. /// It actually runs 1 fewer than this.
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
use std::sync::Arc; use std::sync::Arc;
use dynamo_runtime::protocols;
use dynamo_runtime::transports::etcd; use dynamo_runtime::transports::etcd;
use dynamo_runtime::{protocols, slug::Slug};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{
...@@ -28,6 +28,11 @@ pub struct ModelEntry { ...@@ -28,6 +28,11 @@ pub struct ModelEntry {
} }
impl ModelEntry { impl ModelEntry {
/// Slugified display name for use in etcd and NATS
pub fn slug(&self) -> Slug {
Slug::from_string(&self.name)
}
pub fn requires_preprocessing(&self) -> bool { pub fn requires_preprocessing(&self) -> bool {
matches!(self.model_type, ModelType::Backend) matches!(self.model_type, ModelType::Backend)
} }
...@@ -40,7 +45,7 @@ impl ModelEntry { ...@@ -40,7 +45,7 @@ impl ModelEntry {
) -> anyhow::Result<ModelDeploymentCard> { ) -> anyhow::Result<ModelDeploymentCard> {
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone())); let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let card_key = ModelDeploymentCard::service_name_slug(&self.name); let card_key = self.slug();
match card_store match card_store
.load::<ModelDeploymentCard>(model_card::ROOT_PATH, &card_key) .load::<ModelDeploymentCard>(model_card::ROOT_PATH, &card_key)
.await .await
......
...@@ -13,6 +13,7 @@ use crate::{ ...@@ -13,6 +13,7 @@ use crate::{
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
}, },
}; };
use std::collections::HashSet;
use std::sync::RwLock; use std::sync::RwLock;
use std::{ use std::{
collections::HashMap, collections::HashMap,
...@@ -62,6 +63,14 @@ impl ModelManager { ...@@ -62,6 +63,14 @@ impl ModelManager {
|| self.completion_engines.read().unwrap().contains(model) || self.completion_engines.read().unwrap().contains(model)
} }
pub fn model_display_names(&self) -> HashSet<String> {
self.list_chat_completions_models()
.into_iter()
.chain(self.list_completions_models())
.chain(self.list_embeddings_models())
.collect()
}
pub fn list_chat_completions_models(&self) -> Vec<String> { pub fn list_chat_completions_models(&self) -> Vec<String> {
self.chat_completion_engines.read().unwrap().list() self.chat_completion_engines.read().unwrap().list()
} }
......
...@@ -39,17 +39,17 @@ pub struct ModelWatcher { ...@@ -39,17 +39,17 @@ pub struct ModelWatcher {
} }
impl ModelWatcher { impl ModelWatcher {
pub async fn new( pub fn new(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_mode: RouterMode, router_mode: RouterMode,
) -> anyhow::Result<ModelWatcher> { ) -> ModelWatcher {
Ok(Self { Self {
manager: model_manager, manager: model_manager,
drt: runtime, drt: runtime,
router_mode, router_mode,
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
}) }
} }
/// Wait until we have at least one chat completions model and return it's name. /// Wait until we have at least one chat completions model and return it's name.
...@@ -93,10 +93,7 @@ impl ModelWatcher { ...@@ -93,10 +93,7 @@ impl ModelWatcher {
self.manager.save_model_entry(key, model_entry.clone()); self.manager.save_model_entry(key, model_entry.clone());
if self.manager.has_model_any(&model_entry.name) { if self.manager.has_model_any(&model_entry.name) {
tracing::trace!( tracing::trace!(name = model_entry.name, "New endpoint for existing model");
service_name = model_entry.name,
"New endpoint for existing model"
);
self.notify_on_model.notify_waiters(); self.notify_on_model.notify_waiters();
continue; continue;
} }
...@@ -300,7 +297,7 @@ impl ModelWatcher { ...@@ -300,7 +297,7 @@ impl ModelWatcher {
} }
/// All the registered ModelEntry, one per instance /// All the registered ModelEntry, one per instance
async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> { pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
let Some(etcd_client) = self.drt.etcd_client() else { let Some(etcd_client) = self.drt.etcd_client() else {
anyhow::bail!("all_entries: Missing etcd client"); anyhow::bail!("all_entries: Missing etcd client");
}; };
...@@ -326,7 +323,7 @@ impl ModelWatcher { ...@@ -326,7 +323,7 @@ impl ModelWatcher {
Ok(entries) Ok(entries)
} }
async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> { pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
let mut all = self.all_entries().await?; let mut all = self.all_entries().await?;
all.retain(|entry| entry.name == model_name); all.retain(|entry| entry.name == model_name);
Ok(all) Ok(all)
......
...@@ -357,17 +357,10 @@ async fn list_models_openai( ...@@ -357,17 +357,10 @@ async fn list_models_openai(
.as_secs(); .as_secs();
let mut data = Vec::new(); let mut data = Vec::new();
let models: HashSet<String> = state let models: HashSet<String> = state.manager().model_display_names();
.manager() for model_name in models {
.list_chat_completions_models()
.into_iter()
.chain(state.manager().list_completions_models())
.chain(state.manager().list_embeddings_models())
.collect();
for model_id in models {
data.push(ModelListing { data.push(ModelListing {
id: model_id.clone(), id: model_name.clone(),
object: "object", object: "object",
created, // Where would this come from? The GGUF? created, // Where would this come from? The GGUF?
owned_by: "nvidia".to_string(), // Get organization from GGUF owned_by: "nvidia".to_string(), // Get organization from GGUF
......
...@@ -29,6 +29,7 @@ pub mod hub; ...@@ -29,6 +29,7 @@ pub mod hub;
pub mod key_value_store; pub mod key_value_store;
pub mod kv_router; pub mod kv_router;
pub use kv_router::DEFAULT_KV_BLOCK_SIZE; pub use kv_router::DEFAULT_KV_BLOCK_SIZE;
pub mod local_model;
pub mod mocker; pub mod mocker;
pub mod model_card; pub mod model_card;
pub mod model_type; pub mod model_type;
...@@ -40,8 +41,5 @@ pub mod tokenizers; ...@@ -40,8 +41,5 @@ pub mod tokenizers;
pub mod tokens; pub mod tokens;
pub mod types; pub mod types;
mod local_model;
pub use local_model::LocalModel;
#[cfg(feature = "block-manager")] #[cfg(feature = "block-manager")]
pub mod block_manager; pub mod block_manager;
...@@ -149,7 +149,7 @@ impl LocalModel { ...@@ -149,7 +149,7 @@ impl LocalModel {
let network_name = ModelNetworkName::from_local(endpoint, etcd_client.lease_id()); let network_name = ModelNetworkName::from_local(endpoint, etcd_client.lease_id());
tracing::debug!("Registering with etcd as {network_name}"); tracing::debug!("Registering with etcd as {network_name}");
let model_registration = ModelEntry { let model_registration = ModelEntry {
name: self.service_name().to_string(), name: self.display_name().to_string(),
endpoint: endpoint.id(), endpoint: endpoint.id(),
model_type, model_type,
}; };
......
...@@ -32,6 +32,15 @@ impl ModelNetworkName { ...@@ -32,6 +32,15 @@ impl ModelNetworkName {
) )
} }
pub fn from_entry(entry: &ModelEntry, lease_id: i64) -> Self {
Self::from_parts(
&entry.endpoint.namespace,
&entry.endpoint.component,
&entry.endpoint.name,
lease_id,
)
}
/// Fetch the ModelEntry from etcd. /// Fetch the ModelEntry from etcd.
pub async fn load_entry(&self, etcd_client: &etcd::Client) -> anyhow::Result<ModelEntry> { pub async fn load_entry(&self, etcd_client: &etcd::Client) -> anyhow::Result<ModelEntry> {
let mut model_entries = etcd_client.kv_get(self.to_string(), None).await?; let mut model_entries = etcd_client.kv_get(self.to_string(), None).await?;
......
...@@ -143,13 +143,6 @@ impl ModelDeploymentCard { ...@@ -143,13 +143,6 @@ impl ModelDeploymentCard {
} }
} }
/// A URL and NATS friendly and very likely unique ID for this model.
/// Mostly human readable. a-z, 0-9, _ and - only.
/// Pass the service_name.
pub fn service_name_slug(s: &str) -> Slug {
Slug::from_string(s)
}
/// How often we should check if a model deployment card expired because it's workers are gone /// How often we should check if a model deployment card expired because it's workers are gone
pub fn expiry_check_period() -> Duration { pub fn expiry_check_period() -> Duration {
match CARD_MAX_AGE.to_std() { match CARD_MAX_AGE.to_std() {
...@@ -186,7 +179,7 @@ impl ModelDeploymentCard { ...@@ -186,7 +179,7 @@ impl ModelDeploymentCard {
} }
pub fn slug(&self) -> Slug { pub fn slug(&self) -> Slug {
ModelDeploymentCard::service_name_slug(&self.service_name) Slug::from_string(&self.display_name)
} }
/// Serialize the model deployment card to a JSON string /// Serialize the model deployment card to a JSON string
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment