Unverified Commit 3e8e38a9 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

fix(llmctl): Use ModelWatcher instead of direct etcd operations (#1150)

parent dcad8ac7
......@@ -3522,6 +3522,7 @@ dependencies = [
name = "llmctl"
version = "0.2.1"
dependencies = [
"anyhow",
"clap",
"dynamo-llm",
"dynamo-runtime",
......
......@@ -56,7 +56,7 @@ async fn app(runtime: Runtime) -> Result<()> {
// the cli when operating on an `http` component will validate the namespace.component is
// registered with HttpServiceComponentDefinition
let watch_obj = ModelWatcher::new(distributed.clone(), manager, RouterMode::Random).await?;
let watch_obj = ModelWatcher::new(distributed.clone(), manager, RouterMode::Random);
if let Some(etcd_client) = distributed.etcd_client() {
let models_watcher: PrefixWatcher =
......
......@@ -14,7 +14,7 @@
* [Vllm](#vllm)
* [TensorRT-LLM](#tensorrt-llm-engine)
* [Echo Engines](#echo-engines)
* [Write your own engine in Python](#write-your-own-engine-in-python)
* [Writing your own engine in Python](#writing-your-own-engine-in-python)
* [Batch mode](#batch-mode)
* [Defaults](#defaults)
* [Extra engine arguments](#extra-engine-arguments)
......
......@@ -46,14 +46,11 @@ pub async fn prepare_engine(
anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
};
let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(
ModelWatcher::new(
distributed_runtime,
model_manager.clone(),
dynamo_runtime::pipeline::RouterMode::RoundRobin,
)
.await?,
);
let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime,
model_manager.clone(),
dynamo_runtime::pipeline::RouterMode::RoundRobin,
));
let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......
......@@ -103,7 +103,7 @@ async fn run_watcher(
network_prefix: &str,
router_mode: RouterMode,
) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(runtime, model_manager, router_mode).await?;
let watch_obj = ModelWatcher::new(runtime, model_manager, router_mode);
tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......
......@@ -5,7 +5,7 @@ use std::{future::Future, pin::Pin};
use std::{io::Read, sync::Arc, time::Duration};
use anyhow::Context;
use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, LocalModel};
use dynamo_llm::{backend::ExecutionContext, engines::StreamingEngine, local_model::LocalModel};
use dynamo_runtime::{CancellationToken, DistributedRuntime};
mod flags;
......
......@@ -12,7 +12,7 @@ use regex::Regex;
use tokio::io::AsyncBufReadExt;
use dynamo_llm::engines::MultiNodeConfig;
use dynamo_llm::LocalModel;
use dynamo_llm::local_model::LocalModel;
use dynamo_runtime::protocols::Endpoint as EndpointId;
pub mod sglang;
......
......@@ -23,6 +23,7 @@ homepage.workspace = true
repository.workspace = true
[dependencies]
anyhow = { workspace = true }
dynamo-runtime = { workspace = true }
dynamo-llm = { workspace = true }
......@@ -32,4 +33,4 @@ tracing = { workspace = true }
tokio = { workspace = true }
clap = { version = "4.5", features = ["derive"] }
tabled = { version = "0.18" }
\ No newline at end of file
tabled = { version = "0.18" }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use clap::{Parser, Subcommand};
use tracing as log;
use dynamo_llm::discovery::ModelEntry;
use dynamo_llm::discovery::{ModelManager, ModelWatcher};
use dynamo_llm::local_model::{LocalModel, ModelNetworkName};
use dynamo_llm::model_type::ModelType;
use dynamo_runtime::component::Endpoint;
use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::{
distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime,
Result, Runtime, Worker,
distributed::DistributedConfig, logging, DistributedRuntime, Result, Runtime, Worker,
};
// Macro to define model types and associated commands
......@@ -93,12 +96,13 @@ define_type_subcommands!(
#[command(
author="NVIDIA",
version="0.2.1",
about="LLMCTL - Control and manage Dynamo Components",
about="LLMCTL - Deprecated. Do not use.",
long_about = None,
disable_help_subcommand = true,
)]
struct Cli {
/// Public Namespace to operate in
/// Do not use this. In fact don't use anything about this file.
#[arg(short = 'n', long)]
public_namespace: Option<String>,
......@@ -158,8 +162,8 @@ fn main() -> Result<()> {
logging::init();
let cli = Cli::parse();
// Default namespace to "public" if not specified
let namespace = cli.public_namespace.unwrap_or_else(|| "public".to_string());
// Default namespace to "dynamo" if not specified
let namespace = cli.public_namespace.unwrap_or_else(|| "dynamo".to_string());
let worker = Worker::from_settings()?;
worker.execute(|runtime| async move { handle_command(runtime, namespace, cli.command).await })
......@@ -200,8 +204,8 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands)
}
}
HttpCommands::Remove { model_type } => {
let (model_type, name) = model_type.into_parts();
remove_model(&distributed, namespace.to_string(), model_type, &name).await?;
let (_, name) = model_type.into_parts();
remove_model(&distributed, &name).await?;
}
}
}
......@@ -209,7 +213,6 @@ async fn handle_command(runtime: Runtime, namespace: String, command: Commands)
Ok(())
}
// Helper functions to handle the actual operations
async fn add_model(
distributed: &DistributedRuntime,
namespace: String,
......@@ -217,74 +220,15 @@ async fn add_model(
model_name: String,
endpoint_name: &str,
) -> Result<()> {
log::debug!(
"Adding model {} with endpoint {}",
model_name,
endpoint_name
);
tracing::debug!("Adding model {model_name} with endpoint {endpoint_name}");
if model_name.starts_with('/') {
raise!("Model name '{}' cannot start with a slash", model_name);
}
let parts: Vec<&str> = endpoint_name.split('.').collect();
if parts.len() < 2 {
raise!("Endpoint name '{}' is too short. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
} else if parts.len() > 3 {
raise!("Endpoint name '{}' is too long. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name);
anyhow::bail!("Model name '{model_name}' cannot start with a slash");
}
// create model entry
let endpoint = Endpoint {
namespace: if parts.len() == 3 {
parts[0].to_string()
} else {
println!(
"Using the public namespace: {} for model: {}",
namespace, model_name
);
namespace.clone()
},
component: parts[parts.len() - 2].to_string(),
name: parts[parts.len() - 1].to_string(),
};
let endpoint = endpoint_from_name(distributed, &namespace, endpoint_name)?;
let model = ModelEntry {
name: model_name.to_string(),
endpoint,
model_type,
};
// add model to etcd
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_root(),
model_type.as_str(),
model_name
);
let etcd_client = distributed
.etcd_client()
.expect("unreachable: llmctl is only useful with dynamic workers");
// check if model already exists
let kvs = etcd_client.kv_get_prefix(&path).await?;
if !kvs.is_empty() {
println!(
"{} model {} already exists, please remove it before changing the endpoint.",
model_type.as_str(),
model_name,
);
list_single_model(distributed, namespace, model_type, model_name).await?;
} else {
etcd_client
.kv_create(path, serde_json::to_vec_pretty(&model)?, None)
.await?;
println!("Added new {} model {}", model_type.as_str(), model_name,);
list_single_model(distributed, namespace, model_type, model_name).await?;
}
let mut model = LocalModel::with_name_only(&model_name);
model.attach(&endpoint, model_type).await?;
Ok(())
}
......@@ -303,147 +247,104 @@ struct ModelRow {
endpoint: String,
}
async fn list_single_model(
distributed: &DistributedRuntime,
namespace: String,
model_type: ModelType,
model_name: String,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let path = format!(
"{}/models/{}/{}",
component.etcd_root(),
model_type.as_str(),
model_name
);
let mut models = Vec::new();
let etcd_client = distributed
.etcd_client()
.expect("llmctl is only useful for dynamic workers");
let kvs = etcd_client.kv_get_prefix(&path).await?;
for kv in kvs {
if let (Ok(_key), Ok(model)) = (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
models.push(ModelRow {
model_type: model_type.as_str().to_string(),
name: model_name.clone(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
}
}
if models.is_empty() {
println!("Something went wrong, no model was found.");
} else {
let table = tabled::Table::new(models);
println!("{}", table);
}
Ok(())
}
async fn list_models(
distributed: &DistributedRuntime,
namespace: String,
model_type: Option<ModelType>,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
// We only need a ModelWatcher to call its all_entries. llmctl is going away so no need to
// refactor for this.
let watcher = ModelWatcher::new(
distributed.clone(),
Arc::new(ModelManager::new()),
RouterMode::Random,
);
let mut models = Vec::new();
let model_types = match model_type {
Some(mt) => vec![mt],
None => vec![ModelType::Chat, ModelType::Completion],
};
// TODO: Do we need the model_type in etcd key?
for mt in model_types {
let prefix = format!("{}/models/{}/", component.etcd_root(), mt.as_str(),);
let etcd_client = distributed
.etcd_client()
.expect("llmctl is only useful with dynamic workers");
let kvs = etcd_client.kv_get_prefix(&prefix).await?;
for kv in kvs {
if let (Ok(key), Ok(model)) = (
kv.key_str(),
serde_json::from_slice::<ModelEntry>(kv.value()),
) {
models.push(ModelRow {
model_type: mt.as_str().to_string(),
name: key.trim_start_matches(&prefix).to_string(),
namespace: model.endpoint.namespace,
component: model.endpoint.component,
endpoint: model.endpoint.name,
});
for entry in watcher.all_entries().await? {
match (model_type, entry.model_type) {
(None, _) => {
// list all
}
(Some(want), got) if want == got => {
// match
}
_ => {
// no match
continue;
}
}
models.push(ModelRow {
model_type: entry.model_type.as_str().to_string(),
name: entry.name,
namespace: entry.endpoint.namespace,
component: entry.endpoint.component,
endpoint: entry.endpoint.name,
});
}
if models.is_empty() {
match &model_type {
Some(mt) => println!(
"No {} models found in the public namespace: {}",
"No {} models found in namespace: {}",
mt.as_str(),
namespace
),
None => println!("No models found in the public namespace: {}", namespace),
None => println!("No models found in namespace: {}", namespace),
}
} else {
let table = tabled::Table::new(models);
match &model_type {
Some(mt) => println!(
"Listing {} models in the public namespace: {}",
mt.as_str(),
namespace
),
None => println!("Listing all models in the public namespace: {}", namespace),
Some(mt) => println!("Listing {} models in namespace: {}", mt.as_str(), namespace),
None => println!("Listing all models in namespace: {}", namespace),
}
println!("{}", table);
}
Ok(())
}
async fn remove_model(
distributed: &DistributedRuntime,
namespace: String,
model_type: ModelType,
name: &str,
) -> Result<()> {
let component = distributed.namespace(&namespace)?.component("http")?;
let prefix = format!(
"{}/models/{}/{}",
component.etcd_root(),
model_type.as_str(),
name
async fn remove_model(distributed: &DistributedRuntime, model_name: &str) -> Result<()> {
// We have to do this manually because normally the etcd lease system does it for us
let watcher = ModelWatcher::new(
distributed.clone(),
Arc::new(ModelManager::new()),
RouterMode::Random,
);
log::debug!("deleting key: {}", prefix);
// get the kvs from etcd
let mut kv_client = distributed
.etcd_client()
.expect("llmctl is only useful with dynamic workers")
.etcd_client()
.kv_client();
match kv_client.delete(prefix.as_bytes(), None).await {
Ok(_response) => {
println!(
"{} model {} removed from the public namespace: {}",
model_type.as_str(),
name,
namespace
);
}
Err(e) => {
log::error!("Error removing model {}: {}", name, e);
}
let Some(etcd_client) = distributed.etcd_client() else {
anyhow::bail!("llmctl is only useful with dynamic workers");
};
let active_instances = watcher.entries_for_model(model_name).await?;
for entry in active_instances {
let network_name = ModelNetworkName::from_entry(&entry, 0);
tracing::debug!("deleting key: {network_name}");
etcd_client
.kv_delete(network_name.to_string(), None)
.await?;
}
Ok(())
}
/// Resolve a dotted endpoint name into a runtime `Endpoint`.
///
/// Accepted formats are `component.endpoint` and `namespace.component.endpoint`. With the
/// two-part form the component is looked up in the `namespace` argument. With the three-part
/// form the namespace embedded in `endpoint_name` takes precedence over the argument.
///
/// # Errors
///
/// Returns an error if `endpoint_name` has fewer than two or more than three dot-separated
/// parts, or if the runtime fails to produce the namespace or component.
fn endpoint_from_name(
    distributed: &DistributedRuntime,
    namespace: &str,
    endpoint_name: &str,
) -> anyhow::Result<Endpoint> {
    let parts: Vec<&str> = endpoint_name.split('.').collect();
    let (namespace, component_name, endpoint_part) = match parts.as_slice() {
        [component, endpoint] => (namespace, *component, *endpoint),
        // An explicit namespace in the endpoint name overrides the default one, as the
        // documented 'namespace.component.endpoint' format promises.
        [explicit_namespace, component, endpoint] => (*explicit_namespace, *component, *endpoint),
        [] | [_] => {
            anyhow::bail!("Endpoint name '{}' is too short. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name)
        }
        _ => {
            anyhow::bail!("Endpoint name '{}' is too long. Format should be 'component.endpoint' or 'namespace.component.endpoint'", endpoint_name)
        }
    };

    // TODO: the previous version sometimes hardcoded the component to "http", so maybe adjust.
    let component = distributed
        .namespace(namespace)?
        .component(component_name.to_string())?;
    Ok(component.endpoint(endpoint_part.to_string()))
}
......@@ -111,9 +111,10 @@ fn register_llm<'p>(
let model_name = model_name.map(|n| n.to_string());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
// Download from HF, load the ModelDeploymentCard
let mut local_model = llm_rs::LocalModel::prepare(&inner_path, None, model_name)
.await
.map_err(to_pyerr)?;
let mut local_model =
llm_rs::local_model::LocalModel::prepare(&inner_path, None, model_name)
.await
.map_err(to_pyerr)?;
// Advertise ourself on etcd so ingress can find us
local_model
......
......@@ -42,7 +42,7 @@ use dynamo_llm::protocols::openai::{
};
use dynamo_llm::engines::{EngineDispatcher, StreamingEngine};
use dynamo_llm::LocalModel;
use dynamo_llm::local_model::LocalModel;
/// How many requests mistral will run at once in the paged attention scheduler.
/// It actually runs 1 fewer than this.
......
......@@ -3,8 +3,8 @@
use std::sync::Arc;
use dynamo_runtime::protocols;
use dynamo_runtime::transports::etcd;
use dynamo_runtime::{protocols, slug::Slug};
use serde::{Deserialize, Serialize};
use crate::{
......@@ -28,6 +28,11 @@ pub struct ModelEntry {
}
impl ModelEntry {
/// Slugified display name for use in etcd and NATS.
///
/// Built from this entry's `name` field via `Slug::from_string`.
pub fn slug(&self) -> Slug {
Slug::from_string(&self.name)
}
pub fn requires_preprocessing(&self) -> bool {
matches!(self.model_type, ModelType::Backend)
}
......@@ -40,7 +45,7 @@ impl ModelEntry {
) -> anyhow::Result<ModelDeploymentCard> {
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let card_key = ModelDeploymentCard::service_name_slug(&self.name);
let card_key = self.slug();
match card_store
.load::<ModelDeploymentCard>(model_card::ROOT_PATH, &card_key)
.await
......
......@@ -13,6 +13,7 @@ use crate::{
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
};
use std::collections::HashSet;
use std::sync::RwLock;
use std::{
collections::HashMap,
......@@ -62,6 +63,14 @@ impl ModelManager {
|| self.completion_engines.read().unwrap().contains(model)
}
/// Names of every model known to this manager, de-duplicated.
///
/// Merges the chat completions, completions and embeddings model lists into one set, so a
/// model served under more than one of those lists appears once.
pub fn model_display_names(&self) -> HashSet<String> {
    let mut names: HashSet<String> = HashSet::new();
    names.extend(self.list_chat_completions_models());
    names.extend(self.list_completions_models());
    names.extend(self.list_embeddings_models());
    names
}
pub fn list_chat_completions_models(&self) -> Vec<String> {
self.chat_completion_engines.read().unwrap().list()
}
......
......@@ -39,17 +39,17 @@ pub struct ModelWatcher {
}
impl ModelWatcher {
pub async fn new(
pub fn new(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
) -> anyhow::Result<ModelWatcher> {
Ok(Self {
) -> ModelWatcher {
Self {
manager: model_manager,
drt: runtime,
router_mode,
notify_on_model: Notify::new(),
})
}
}
/// Wait until we have at least one chat completions model and return its name.
......@@ -93,10 +93,7 @@ impl ModelWatcher {
self.manager.save_model_entry(key, model_entry.clone());
if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(
service_name = model_entry.name,
"New endpoint for existing model"
);
tracing::trace!(name = model_entry.name, "New endpoint for existing model");
self.notify_on_model.notify_waiters();
continue;
}
......@@ -300,7 +297,7 @@ impl ModelWatcher {
}
/// All the registered ModelEntry, one per instance
async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
let Some(etcd_client) = self.drt.etcd_client() else {
anyhow::bail!("all_entries: Missing etcd client");
};
......@@ -326,7 +323,7 @@ impl ModelWatcher {
Ok(entries)
}
async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
pub async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
let mut all = self.all_entries().await?;
all.retain(|entry| entry.name == model_name);
Ok(all)
......
......@@ -357,17 +357,10 @@ async fn list_models_openai(
.as_secs();
let mut data = Vec::new();
let models: HashSet<String> = state
.manager()
.list_chat_completions_models()
.into_iter()
.chain(state.manager().list_completions_models())
.chain(state.manager().list_embeddings_models())
.collect();
for model_id in models {
let models: HashSet<String> = state.manager().model_display_names();
for model_name in models {
data.push(ModelListing {
id: model_id.clone(),
id: model_name.clone(),
object: "object",
created, // Where would this come from? The GGUF?
owned_by: "nvidia".to_string(), // Get organization from GGUF
......
......@@ -29,6 +29,7 @@ pub mod hub;
pub mod key_value_store;
pub mod kv_router;
pub use kv_router::DEFAULT_KV_BLOCK_SIZE;
pub mod local_model;
pub mod mocker;
pub mod model_card;
pub mod model_type;
......@@ -40,8 +41,5 @@ pub mod tokenizers;
pub mod tokens;
pub mod types;
mod local_model;
pub use local_model::LocalModel;
#[cfg(feature = "block-manager")]
pub mod block_manager;
......@@ -149,7 +149,7 @@ impl LocalModel {
let network_name = ModelNetworkName::from_local(endpoint, etcd_client.lease_id());
tracing::debug!("Registering with etcd as {network_name}");
let model_registration = ModelEntry {
name: self.service_name().to_string(),
name: self.display_name().to_string(),
endpoint: endpoint.id(),
model_type,
};
......
......@@ -32,6 +32,15 @@ impl ModelNetworkName {
)
}
/// Build the network name for a registered `ModelEntry`.
///
/// Uses the namespace, component and endpoint name recorded in `entry.endpoint`, combined
/// with the given `lease_id`.
pub fn from_entry(entry: &ModelEntry, lease_id: i64) -> Self {
    let endpoint = &entry.endpoint;
    Self::from_parts(
        &endpoint.namespace,
        &endpoint.component,
        &endpoint.name,
        lease_id,
    )
}
/// Fetch the ModelEntry from etcd.
pub async fn load_entry(&self, etcd_client: &etcd::Client) -> anyhow::Result<ModelEntry> {
let mut model_entries = etcd_client.kv_get(self.to_string(), None).await?;
......
......@@ -143,13 +143,6 @@ impl ModelDeploymentCard {
}
}
/// Build a URL- and NATS-friendly, very likely unique ID for a model.
///
/// The result is mostly human readable: a-z, 0-9, `_` and `-` only.
/// `s` is expected to be the model's service name; it is passed unchanged to
/// `Slug::from_string`.
pub fn service_name_slug(s: &str) -> Slug {
Slug::from_string(s)
}
/// How often we should check if a model deployment card expired because its workers are gone
pub fn expiry_check_period() -> Duration {
match CARD_MAX_AGE.to_std() {
......@@ -186,7 +179,7 @@ impl ModelDeploymentCard {
}
pub fn slug(&self) -> Slug {
ModelDeploymentCard::service_name_slug(&self.service_name)
Slug::from_string(&self.display_name)
}
/// Serialize the model deployment card to a JSON string
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment