Unverified Commit b520bf44 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Fix model removal on instance stop, refactor discovery (#1142)

- Stop advertising a model when its last instance stops. Previously this happened when any instance stopped.
- Faster locks on model manager.
- Move discovery code out of http, as it is used by all inputs.
parent 03c160af
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use clap::Parser; use clap::Parser;
use std::sync::Arc;
use dynamo_llm::http::service::{discovery::ModelWatcher, service_v2::HttpService}; use dynamo_llm::discovery::{ModelWatcher, MODEL_ROOT_PATH};
use dynamo_llm::http::service::service_v2::HttpService;
use dynamo_runtime::{ use dynamo_runtime::{
component, logging, pipeline::RouterMode, transports::etcd::PrefixWatcher, DistributedRuntime, logging, pipeline::RouterMode, transports::etcd::PrefixWatcher, DistributedRuntime, Result,
Result, Runtime, Worker, Runtime, Worker,
}; };
#[derive(Parser)] #[derive(Parser)]
...@@ -45,7 +45,7 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -45,7 +45,7 @@ async fn app(runtime: Runtime) -> Result<()> {
.port(args.port) .port(args.port)
.host(args.host) .host(args.host)
.build()?; .build()?;
let manager = http_service.model_manager().clone(); let manager = http_service.state().manager_clone();
// todo - use the IntoComponent trait to register the component // todo - use the IntoComponent trait to register the component
// todo - start a service // todo - start a service
...@@ -56,17 +56,16 @@ async fn app(runtime: Runtime) -> Result<()> { ...@@ -56,17 +56,16 @@ async fn app(runtime: Runtime) -> Result<()> {
// the cli when operating on an `http` component will validate the namespace.component is // the cli when operating on an `http` component will validate the namespace.component is
// registered with HttpServiceComponentDefinition // registered with HttpServiceComponentDefinition
let watch_obj = Arc::new( let watch_obj = ModelWatcher::new(distributed.clone(), manager, RouterMode::Random).await?;
ModelWatcher::new(distributed.clone(), manager.clone(), RouterMode::Random).await?,
);
if let Some(etcd_client) = distributed.etcd_client() { if let Some(etcd_client) = distributed.etcd_client() {
let models_watcher: PrefixWatcher = etcd_client let models_watcher: PrefixWatcher =
.kv_get_and_watch_prefix(component::MODEL_ROOT_PATH) etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
tokio::spawn(watch_obj.watch(receiver)); tokio::spawn(async move {
watch_obj.watch(receiver).await;
});
} }
// Run the service // Run the service
......
...@@ -5,8 +5,8 @@ use std::pin::Pin; ...@@ -5,8 +5,8 @@ use std::pin::Pin;
use dynamo_llm::{ use dynamo_llm::{
backend::{Backend, ExecutionContext}, backend::{Backend, ExecutionContext},
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
http::service::{discovery::ModelWatcher, ModelManager},
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
preprocessor::OpenAIPreprocessor, preprocessor::OpenAIPreprocessor,
protocols::common::llm_backend::{BackendInput, BackendOutput}, protocols::common::llm_backend::{BackendInput, BackendOutput},
...@@ -19,7 +19,6 @@ use dynamo_llm::{ ...@@ -19,7 +19,6 @@ use dynamo_llm::{
}, },
}; };
use dynamo_runtime::{ use dynamo_runtime::{
component,
engine::{AsyncEngineStream, Data}, engine::{AsyncEngineStream, Data},
pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source}, pipeline::{Context, ManyOut, Operator, ServiceBackend, ServiceFrontend, SingleIn, Source},
DistributedRuntime, Runtime, DistributedRuntime, Runtime,
...@@ -46,7 +45,7 @@ pub async fn prepare_engine( ...@@ -46,7 +45,7 @@ pub async fn prepare_engine(
let Some(etcd_client) = distributed_runtime.etcd_client() else { let Some(etcd_client) = distributed_runtime.etcd_client() else {
anyhow::bail!("Cannot be both static mode and run with dynamic discovery."); anyhow::bail!("Cannot be both static mode and run with dynamic discovery.");
}; };
let model_manager = ModelManager::new(); let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new( let watch_obj = Arc::new(
ModelWatcher::new( ModelWatcher::new(
distributed_runtime, distributed_runtime,
...@@ -55,13 +54,13 @@ pub async fn prepare_engine( ...@@ -55,13 +54,13 @@ pub async fn prepare_engine(
) )
.await?, .await?,
); );
let models_watcher = etcd_client let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
.kv_get_and_watch_prefix(component::MODEL_ROOT_PATH)
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let inner_watch_obj = watch_obj.clone(); let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(inner_watch_obj.watch(receiver)); let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver).await;
});
tracing::info!("Waiting for remote model.."); tracing::info!("Waiting for remote model..");
// TODO: We use the first model to appear, usually we have only one // TODO: We use the first model to appear, usually we have only one
...@@ -69,9 +68,7 @@ pub async fn prepare_engine( ...@@ -69,9 +68,7 @@ pub async fn prepare_engine(
// '/models` to list, and notifications when models are added / removed. // '/models` to list, and notifications when models are added / removed.
let model_service_name = watch_obj.wait_for_chat_model().await; let model_service_name = watch_obj.wait_for_chat_model().await;
let engine = model_manager let engine = model_manager.get_chat_completions_engine(&model_service_name)?;
.state()
.get_chat_completions_engine(&model_service_name)?;
Ok(PreparedEngine { Ok(PreparedEngine {
service_name: model_service_name, service_name: model_service_name,
engine, engine,
......
...@@ -5,10 +5,10 @@ use std::sync::Arc; ...@@ -5,10 +5,10 @@ use std::sync::Arc;
use crate::input::common; use crate::input::common;
use crate::{EngineConfig, Flags}; use crate::{EngineConfig, Flags};
use dynamo_llm::http::service::ModelManager;
use dynamo_llm::{ use dynamo_llm::{
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
http::service::{discovery, service_v2}, http::service::service_v2,
request_template::RequestTemplate, request_template::RequestTemplate,
types::{ types::{
openai::chat_completions::{ openai::chat_completions::{
...@@ -17,7 +17,6 @@ use dynamo_llm::{ ...@@ -17,7 +17,6 @@ use dynamo_llm::{
openai::completions::{CompletionRequest, CompletionResponse}, openai::completions::{CompletionRequest, CompletionResponse},
}, },
}; };
use dynamo_runtime::component;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
use dynamo_runtime::transports::etcd; use dynamo_runtime::transports::etcd;
use dynamo_runtime::{DistributedRuntime, Runtime}; use dynamo_runtime::{DistributedRuntime, Runtime};
...@@ -43,9 +42,9 @@ pub async fn run( ...@@ -43,9 +42,9 @@ pub async fn run(
// Listen for models registering themselves in etcd, add them to HTTP service // Listen for models registering themselves in etcd, add them to HTTP service
run_watcher( run_watcher(
distributed_runtime, distributed_runtime,
http_service.model_manager().clone(), http_service.state().manager_clone(),
etcd_client.clone(), etcd_client.clone(),
component::MODEL_ROOT_PATH, MODEL_ROOT_PATH,
flags.router_mode.into(), flags.router_mode.into(),
) )
.await?; .await?;
...@@ -99,16 +98,17 @@ pub async fn run( ...@@ -99,16 +98,17 @@ pub async fn run(
/// and registers them with the ModelManager so that the HTTP service can use them. /// and registers them with the ModelManager so that the HTTP service can use them.
async fn run_watcher( async fn run_watcher(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: ModelManager, model_manager: Arc<ModelManager>,
etcd_client: etcd::Client, etcd_client: etcd::Client,
network_prefix: &str, network_prefix: &str,
router_mode: RouterMode, router_mode: RouterMode,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let watch_obj = let watch_obj = ModelWatcher::new(runtime, model_manager, router_mode).await?;
Arc::new(discovery::ModelWatcher::new(runtime, model_manager, router_mode).await?);
tracing::info!("Watching for remote model at {network_prefix}"); tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let _watcher_task = tokio::spawn(watch_obj.watch(receiver)); let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver).await;
});
Ok(()) Ok(())
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use tracing as log; use tracing as log;
use dynamo_llm::{http::service::discovery::ModelEntry, model_type::ModelType}; use dynamo_llm::discovery::ModelEntry;
use dynamo_llm::model_type::ModelType;
use dynamo_runtime::{ use dynamo_runtime::{
distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime, distributed::DistributedConfig, logging, protocols::Endpoint, raise, DistributedRuntime,
Result, Runtime, Worker, Result, Runtime, Worker,
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
// Registry of engines, saved model entries and KV choosers.
mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError};
// The per-instance record used to discover a model.
mod model_entry;
pub use model_entry::ModelEntry;
// Reacts to etcd watch events and updates the ModelManager accordingly.
mod watcher;
pub use watcher::ModelWatcher;
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use dynamo_runtime::protocols;
use dynamo_runtime::transports::etcd;
use serde::{Deserialize, Serialize};
use crate::{
key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
model_card::{self, ModelDeploymentCard},
model_type::ModelType,
};
/// [ModelEntry] contains the information to discover models from the etcd cluster.
///
/// One entry describes one running instance of a model: its public name, where to
/// reach it, and what kind of requests it serves.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
    /// Public name of the model.
    /// This is used to identify the model in the HTTP service, matched against the
    /// model name given in an OpenAI ChatRequest.
    pub name: String,
    /// How to address this on the network
    pub endpoint: protocols::Endpoint,
    /// Specifies whether the model is a chat, completions, etc model.
    pub model_type: ModelType,
}
impl ModelEntry {
    /// Returns `true` only when this entry's model type is [`ModelType::Backend`].
    pub fn requires_preprocessing(&self) -> bool {
        matches!(&self.model_type, ModelType::Backend)
    }

    /// Fetch this model's [`ModelDeploymentCard`] from etcd.
    ///
    /// The card is looked up under `model_card::ROOT_PATH` using the slug derived from
    /// this entry's `name`. The card is returned exactly as loaded; its fields are not
    /// touched, so you may need to call `move_from_nats` on it.
    ///
    /// # Errors
    ///
    /// Fails if the store lookup itself fails, or if no card exists under the key.
    pub async fn load_mdc(
        &self,
        etcd_client: &etcd::Client,
    ) -> anyhow::Result<ModelDeploymentCard> {
        let card_key = ModelDeploymentCard::service_name_slug(&self.name);
        let storage: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
        let store_manager = Arc::new(KeyValueStoreManager::new(storage));

        store_manager
            .load::<ModelDeploymentCard>(model_card::ROOT_PATH, &card_key)
            .await
            .map_err(|err| {
                anyhow::anyhow!(
                    "Error fetching ModelDeploymentCard from etcd under key {card_key}. {err}"
                )
            })?
            .ok_or_else(|| {
                anyhow::anyhow!("Missing ModelDeploymentCard in etcd under key {card_key}")
            })
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::component::Component;
use crate::discovery::ModelEntry;
use crate::kv_router::scheduler::DefaultWorkerSelector;
use crate::{
kv_router::KvRouter,
types::openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
};
use std::sync::RwLock;
use std::{
collections::HashMap,
sync::{Arc, Mutex},
};
/// Errors returned when adding, removing or looking up a model's engine.
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
    /// No engine is registered under the given model name (carried in the variant).
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    /// An engine is already registered under the given model name (carried in the variant).
    #[error("Model already exists: {0}")]
    ModelAlreadyExists(String),
}
/// Holds the engines registered for each model, the saved per-instance [`ModelEntry`]
/// records, and the KV choosers created per model. All state is behind locks, so
/// methods take `&self`.
// Don't implement Clone for this, put it in an Arc instead.
pub struct ModelManager {
    // We read a lot and write rarely, so these three are RwLock
    completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
    chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
    embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
    // These two are Mutex because we read and write rarely and equally
    // `entries` is keyed by an instance's etcd key.
    entries: Mutex<HashMap<String, ModelEntry>>,
    // `kv_choosers` is keyed by model name.
    kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}
impl Default for ModelManager {
fn default() -> Self {
Self::new()
}
}
impl ModelManager {
pub fn new() -> Self {
Self {
completion_engines: RwLock::new(ModelEngines::default()),
chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()),
entries: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()),
}
}
pub fn has_model_any(&self, model: &str) -> bool {
self.chat_completion_engines.read().unwrap().contains(model)
|| self.completion_engines.read().unwrap().contains(model)
}
pub fn list_chat_completions_models(&self) -> Vec<String> {
self.chat_completion_engines.read().unwrap().list()
}
pub fn list_completions_models(&self) -> Vec<String> {
self.completion_engines.read().unwrap().list()
}
pub fn list_embeddings_models(&self) -> Vec<String> {
self.embeddings_engines.read().unwrap().list()
}
pub fn add_completions_model(
&self,
model: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write().unwrap();
clients.add(model, engine)
}
pub fn add_chat_completions_model(
&self,
model: &str,
engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write().unwrap();
clients.add(model, engine)
}
pub fn add_embeddings_model(
&self,
model: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write().unwrap();
clients.add(model, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write().unwrap();
clients.remove(model)
}
pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write().unwrap();
clients.remove(model)
}
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write().unwrap();
clients.remove(model)
}
// TODO: Remove this allow once `embeddings` is implemented in lib/llm/src/http/service/openai.rs
#[allow(dead_code)]
fn get_embeddings_engine(
&self,
model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.embeddings_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
pub fn get_completions_engine(
&self,
model: &str,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.completion_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
pub fn get_chat_completions_engine(
&self,
model: &str,
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.chat_completion_engines
.read()
.unwrap()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is
/// deleted from etcd.
pub fn save_model_entry(&self, key: &str, entry: ModelEntry) {
self.entries.lock().unwrap().insert(key.to_string(), entry);
}
/// Remove and return model entry for this instance's etcd key. We do this when the instance stops.
pub fn remove_model_entry(&self, key: &str) -> Option<ModelEntry> {
self.entries.lock().unwrap().remove(key)
}
pub async fn kv_chooser_for(
&self,
model_name: &str,
component: &Component,
) -> anyhow::Result<Arc<KvRouter>> {
if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
return Ok(kv_chooser);
}
self.create_kv_chooser(model_name, component).await
}
fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
self.kv_choosers.lock().unwrap().get(model_name).cloned()
}
/// Create and return a KV chooser for this component and model
async fn create_kv_chooser(
&self,
model_name: &str,
component: &Component,
) -> anyhow::Result<Arc<KvRouter>> {
let selector = Box::new(DefaultWorkerSelector {});
let chooser = KvRouter::new(
component.clone(),
crate::DEFAULT_KV_BLOCK_SIZE,
Some(selector),
)
.await?;
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
.unwrap()
.insert(model_name.to_string(), new_kv_chooser.clone());
Ok(new_kv_chooser)
}
}
/// A collection of engines of one kind `E`, keyed by model name, with an optional
/// default model name.
pub struct ModelEngines<E> {
    /// Optional default model name
    default: Option<String>,
    /// Engines keyed by model name.
    engines: HashMap<String, E>,
}
impl<E> Default for ModelEngines<E> {
fn default() -> Self {
Self {
default: None,
engines: HashMap::new(),
}
}
}
impl<E> ModelEngines<E> {
    /// Record `model` as the default model name.
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(model.to_owned());
    }

    /// Forget the default model name, if one was set.
    #[allow(dead_code)]
    fn clear_default(&mut self) {
        self.default = None;
    }

    /// Register `engine` under `model`.
    ///
    /// Fails with `ModelAlreadyExists` (leaving the existing engine in place) when an
    /// engine is already registered under that name.
    fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> {
        if self.engines.contains_key(model) {
            Err(ModelManagerError::ModelAlreadyExists(model.to_owned()))
        } else {
            self.engines.insert(model.to_owned(), engine);
            Ok(())
        }
    }

    /// Drop the engine registered under `model`.
    ///
    /// Fails with `ModelNotFound` when no engine is registered under that name.
    fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
        match self.engines.remove(model) {
            Some(_engine) => Ok(()),
            None => Err(ModelManagerError::ModelNotFound(model.to_owned())),
        }
    }

    /// Borrow the engine registered under `model`, if any.
    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    /// Whether an engine is registered under `model`.
    fn contains(&self, model: &str) -> bool {
        self.engines.contains_key(model)
    }

    /// Names of all registered models, in the map's (unspecified) iteration order.
    pub fn list(&self) -> Vec<String> {
        self.engines.keys().cloned().collect()
    }
}
...@@ -4,159 +4,35 @@ ...@@ -4,159 +4,35 @@
use std::sync::Arc; use std::sync::Arc;
use anyhow::Context as _; use anyhow::Context as _;
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc::Receiver, Notify}; use tokio::sync::{mpsc::Receiver, Notify};
use dynamo_runtime::{ use dynamo_runtime::{
component::{self, Component, Instance},
pipeline::{ pipeline::{
network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource, network::egress::push_router::PushRouter, ManyOut, Operator, RouterMode, SegmentSource,
ServiceBackend, SingleIn, Source, ServiceBackend, SingleIn, Source,
}, },
protocols::{self, annotated::Annotated}, protocols::annotated::Annotated,
slug::Slug, transports::etcd::{KeyValue, WatchEvent},
transports::etcd::{self, KeyValue, WatchEvent},
DistributedRuntime, DistributedRuntime,
}; };
use super::ModelManager;
use crate::protocols::openai::completions::{CompletionRequest, CompletionResponse};
use crate::{ use crate::{
backend::Backend, backend::Backend,
kv_router::KvPushRouter,
model_type::ModelType, model_type::ModelType,
preprocessor::{BackendInput, OpenAIPreprocessor}, preprocessor::{BackendInput, OpenAIPreprocessor},
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
};
use crate::{
key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
model_card::{self, ModelDeploymentCard},
};
use crate::{
kv_router::{scheduler::DefaultWorkerSelector, KvPushRouter, KvRouter},
protocols::openai::chat_completions::{ protocols::openai::chat_completions::{
NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse, NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse,
}, },
protocols::openai::completions::{CompletionRequest, CompletionResponse},
protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse}, protocols::openai::embeddings::{NvCreateEmbeddingRequest, NvCreateEmbeddingResponse},
}; };
use tracing;
/// [ModelEntry] is a struct that contains the information for the HTTP service to discover models
/// from the etcd cluster.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
/// Public name of the model
/// This will be used to identify the model in the HTTP service and the value used in an
/// an [OAI ChatRequest][crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest].
pub name: String,
/// Component of the endpoint.
pub endpoint: protocols::Endpoint,
/// Specifies whether the model is a chat or completion model.s
pub model_type: ModelType,
}
impl ModelEntry {
pub fn requires_preprocessing(&self) -> bool {
matches!(self.model_type, ModelType::Backend)
}
/// Fetch the ModelDeploymentCard from NATS.
/// This does not touch it's fields so you may need to call move_from_nats on it.
pub async fn load_mdc(
&self,
etcd_client: &etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> {
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let card_key = ModelDeploymentCard::service_name_slug(&self.name);
match card_store
.load::<ModelDeploymentCard>(model_card::ROOT_PATH, &card_key)
.await
{
Ok(Some(mdc)) => Ok(mdc),
Ok(None) => {
anyhow::bail!("Missing ModelDeploymentCard in etcd under key {card_key}");
}
Err(err) => {
anyhow::bail!(
"Error fetching ModelDeploymentCard from etcd under key {card_key}. {err}"
);
}
}
}
}
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
/// Key to store this model entry in networked key-value store (etcd).
///
/// It looks like this:
/// ns.cp.ep-694d967ca5efd804
fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
let model_root = component::MODEL_ROOT_PATH;
let slug = Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}"));
ModelNetworkName(format!("{model_root}/{slug}"))
}
// We can't do From<&component::Endpoint> here because we also need the lease_id
pub fn from_local(endpoint: &component::Endpoint, lease_id: i64) -> Self {
Self::from_parts(
&endpoint.component().namespace().to_string(),
&endpoint.component().name(),
endpoint.name(),
lease_id,
)
}
/// Fetch the ModelEntry from etcd.
pub async fn load_entry(&self, etcd_client: &etcd::Client) -> anyhow::Result<ModelEntry> {
let mut model_entries = etcd_client.kv_get(self.to_string(), None).await?;
if model_entries.is_empty() {
anyhow::bail!("No ModelEntry in etcd for key {self}");
}
let model_entry = model_entries.remove(0);
serde_json::from_slice(model_entry.value()).with_context(|| {
format!(
"Error deserializing JSON. Key={self}. JSON={}",
model_entry.value_str().unwrap_or("INVALID UTF-8")
)
})
}
/// Fetch the ModelDeploymentCard from NATS. use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
/// This does not touch it's fields so you may need to call move_from_nats on it.
/// TODO We have potentially two for each endpoint, one Chat and one Completion.
pub async fn load_mdc(
&self,
etcd_client: &etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> {
let entry = self.load_entry(etcd_client).await?;
entry.load_mdc(etcd_client).await
}
}
impl From<&Instance> for ModelNetworkName {
fn from(cei: &Instance) -> Self {
Self::from_parts(
&cei.namespace,
&cei.component,
&cei.endpoint,
cei.instance_id,
)
}
}
impl std::fmt::Display for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
pub struct ModelWatcher { pub struct ModelWatcher {
manager: ModelManager, manager: Arc<ModelManager>,
drt: DistributedRuntime, drt: DistributedRuntime,
router_mode: RouterMode, router_mode: RouterMode,
notify_on_model: Notify, notify_on_model: Notify,
...@@ -165,7 +41,7 @@ pub struct ModelWatcher { ...@@ -165,7 +41,7 @@ pub struct ModelWatcher {
impl ModelWatcher { impl ModelWatcher {
pub async fn new( pub async fn new(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: ModelManager, model_manager: Arc<ModelManager>,
router_mode: RouterMode, router_mode: RouterMode,
) -> anyhow::Result<ModelWatcher> { ) -> anyhow::Result<ModelWatcher> {
Ok(Self { Ok(Self {
...@@ -187,7 +63,7 @@ impl ModelWatcher { ...@@ -187,7 +63,7 @@ impl ModelWatcher {
} }
} }
pub async fn watch(self: Arc<Self>, mut events_rx: Receiver<WatchEvent>) { pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started"); tracing::debug!("model watcher started");
while let Some(event) = events_rx.recv().await { while let Some(event) = events_rx.recv().await {
...@@ -207,6 +83,15 @@ impl ModelWatcher { ...@@ -207,6 +83,15 @@ impl ModelWatcher {
continue; continue;
} }
}; };
let key = match kv.key_str() {
Ok(k) => k,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
continue;
}
};
self.manager.save_model_entry(key, model_entry.clone());
if self.manager.has_model_any(&model_entry.name) { if self.manager.has_model_any(&model_entry.name) {
tracing::trace!( tracing::trace!(
service_name = model_entry.name, service_name = model_entry.name,
...@@ -216,7 +101,7 @@ impl ModelWatcher { ...@@ -216,7 +101,7 @@ impl ModelWatcher {
continue; continue;
} }
match self.clone().handle_put(&kv, &model_entry).await { match self.handle_put(&model_entry).await {
Ok(()) => { Ok(()) => {
tracing::info!(model_name = model_entry.name, "added model"); tracing::info!(model_name = model_entry.name, "added model");
self.notify_on_model.notify_waiters(); self.notify_on_model.notify_waiters();
...@@ -226,10 +111,13 @@ impl ModelWatcher { ...@@ -226,10 +111,13 @@ impl ModelWatcher {
} }
} }
} }
WatchEvent::Delete(kv) => match self.clone().handle_delete(&kv).await { WatchEvent::Delete(kv) => match self.handle_delete(&kv).await {
Ok(model_name) => { Ok(Some(model_name)) => {
tracing::info!("removed model {}", model_name); tracing::info!("removed model {}", model_name);
} }
Ok(None) => {
// There are other instances running this model, nothing to do
}
Err(e) => { Err(e) => {
tracing::error!("error removing model: {}", e); tracing::error!("error removing model: {}", e);
} }
...@@ -238,37 +126,36 @@ impl ModelWatcher { ...@@ -238,37 +126,36 @@ impl ModelWatcher {
} }
} }
/// Returns the name of the model we just deleted /// If the last instance running this model has gone delete it.
async fn handle_delete(self: Arc<ModelWatcher>, kv: &KeyValue) -> anyhow::Result<String> { /// Returns the name of the model we just deleted, if any.
async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
let key = kv.key_str()?; let key = kv.key_str()?;
let model_entry = match self.manager.state.entries.lock().unwrap().remove(key) { let model_entry = match self.manager.remove_model_entry(key) {
Some(entry) => entry, Some(entry) => entry,
None => { None => {
anyhow::bail!("Missing ModelEntry for {key}"); anyhow::bail!("Missing ModelEntry for {key}");
} }
}; };
let model_name = &model_entry.name; let model_name = model_entry.name;
tracing::debug!(model_name, "removing model"); let active_instances = self
.entries_for_model(&model_name)
.await
.with_context(|| model_name.clone())?;
if !active_instances.is_empty() {
return Ok(None);
}
// Ignore the errors because model could be either type // Ignore the errors because model could be either type
let _ = self.manager.remove_chat_completions_model(model_name); let _ = self.manager.remove_chat_completions_model(&model_name);
let _ = self.manager.remove_completions_model(model_name); let _ = self.manager.remove_completions_model(&model_name);
let _ = self.manager.remove_embeddings_model(model_name); let _ = self.manager.remove_embeddings_model(&model_name);
// We own model_entry now so take ownership of the name Ok(Some(model_name))
Ok(model_entry.name)
} }
// Handles a PUT event from etcd, this usually means adding a new model to the list of served // Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models. // models.
// async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> {
// If this method errors, for the near term, we will delete the offending key.
async fn handle_put(
self: Arc<ModelWatcher>,
kv: &KeyValue,
model_entry: &ModelEntry,
) -> anyhow::Result<()> {
let key = kv.key_str()?;
let endpoint_id = model_entry.endpoint.clone(); let endpoint_id = model_entry.endpoint.clone();
let component = self let component = self
.drt .drt
...@@ -291,13 +178,6 @@ impl ModelWatcher { ...@@ -291,13 +178,6 @@ impl ModelWatcher {
None None
} }
}; };
// We need to save the entry to know what the model is called when we delete it
self.manager
.state
.entries
.lock()
.unwrap()
.insert(key.to_string(), model_entry.clone());
match model_entry.model_type { match model_entry.model_type {
ModelType::Backend => { ModelType::Backend => {
...@@ -329,7 +209,10 @@ impl ModelWatcher { ...@@ -329,7 +209,10 @@ impl ModelWatcher {
ServiceBackend::from_engine(Arc::new(router)) ServiceBackend::from_engine(Arc::new(router))
} }
RouterMode::KV => { RouterMode::KV => {
let chooser = self.kv_chooser_for(&model_entry.name, &component).await?; let chooser = self
.manager
.kv_chooser_for(&model_entry.name, &component)
.await?;
let kv_push_router = KvPushRouter::new(router, chooser); let kv_push_router = KvPushRouter::new(router, chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router)) ServiceBackend::from_engine(Arc::new(kv_push_router))
} }
...@@ -361,7 +244,10 @@ impl ModelWatcher { ...@@ -361,7 +244,10 @@ impl ModelWatcher {
ServiceBackend::from_engine(Arc::new(router)) ServiceBackend::from_engine(Arc::new(router))
} }
RouterMode::KV => { RouterMode::KV => {
let chooser = self.kv_chooser_for(&model_entry.name, &component).await?; let chooser = self
.manager
.kv_chooser_for(&model_entry.name, &component)
.await?;
let kv_push_router = KvPushRouter::new(router, chooser); let kv_push_router = KvPushRouter::new(router, chooser);
ServiceBackend::from_engine(Arc::new(kv_push_router)) ServiceBackend::from_engine(Arc::new(kv_push_router))
} }
...@@ -413,36 +299,36 @@ impl ModelWatcher { ...@@ -413,36 +299,36 @@ impl ModelWatcher {
Ok(()) Ok(())
} }
async fn kv_chooser_for( /// All the registered ModelEntry, one per instance
&self, async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
model_name: &str, let Some(etcd_client) = self.drt.etcd_client() else {
component: &Component, anyhow::bail!("all_entries: Missing etcd client");
) -> anyhow::Result<Arc<KvRouter>> { };
if let Some(kv_chooser) = self let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
.manager let mut entries = Vec::with_capacity(kvs.len());
.state for kv in kvs {
.kv_choosers let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
.lock() Ok(model_entry) => model_entry,
.unwrap() Err(err) => {
.get(model_name) match kv.value_str() {
{ Ok(value) => {
// Return early to avoid holding the lock during the await later tracing::error!(%err, value, "Invalid JSON in model entry")
return Ok(Arc::clone(kv_chooser)); }
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
}
}
continue;
}
};
entries.push(model_entry);
} }
let selector = Box::new(DefaultWorkerSelector {}); Ok(entries)
let chooser = KvRouter::new( }
component.clone(),
crate::DEFAULT_KV_BLOCK_SIZE, async fn entries_for_model(&self, model_name: &str) -> anyhow::Result<Vec<ModelEntry>> {
Some(selector), let mut all = self.all_entries().await?;
) all.retain(|entry| entry.name == model_name);
.await?; Ok(all)
let new_kv_chooser = Arc::new(chooser);
self.manager
.state
.kv_choosers
.lock()
.unwrap()
.insert(model_name.to_string(), new_kv_chooser.clone());
Ok(new_kv_chooser)
} }
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HTTP Service for Nova LLM //! HTTP Service for Dynamo LLM
//! //!
//! The primary purpose of this crate is to service the nova-llm-protocols via OpenAI compatible HTTP endpoints. This component //! The primary purpose of this crate is to service the dynamo-llm protocols via OpenAI compatible HTTP endpoints. This component
//! is meant to be a gateway/ingress into the Nova LLM Distributed Runtime. //! is meant to be a gateway/ingress into the Dynamo LLM Distributed Runtime.
//! //!
//! In order to create a common pattern, the HttpService forwards the incoming OAI Chat Request or OAI Completion Request //! In order to create a common pattern, the HttpService forwards the incoming OAI Chat Request or OAI Completion Request
//! to a model-specific engine. The engines can be attached and detached dynamically using the [`ModelManager`]. //! to a model-specific engine. The engines can be attached and detached dynamically using the [`ModelManager`].
...@@ -32,242 +20,13 @@ ...@@ -32,242 +20,13 @@
mod openai; mod openai;
pub mod discovery;
pub mod error; pub mod error;
pub mod metrics; pub mod metrics;
pub mod service_v2; pub mod service_v2;
// #[cfg(feature = "py3")]
// pub mod py3;
pub use async_trait::async_trait;
pub use axum; pub use axum;
use discovery::ModelEntry;
pub use error::ServiceHttpError;
pub use metrics::Metrics; pub use metrics::Metrics;
use crate::{
kv_router::KvRouter,
types::openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
};
use std::{
collections::HashMap,
sync::{Arc, Mutex},
time::Duration,
};
/// Registry of the engines that serve each model.
///
/// Cloning a `ModelManager` only clones the inner `Arc`, so every clone refers to the
/// same [`DeploymentState`].
///
/// Every method that touches an engine table locks a `std::sync::Mutex` and unwraps the
/// result, so it panics if that mutex has been poisoned.
#[derive(Clone)]
pub struct ModelManager {
    state: Arc<DeploymentState>,
}

impl Default for ModelManager {
    fn default() -> Self {
        ModelManager::new()
    }
}

impl ModelManager {
    /// Create a manager with no registered engines.
    pub fn new() -> Self {
        Self {
            state: Arc::new(DeploymentState::new()),
        }
    }

    /// Another handle to the shared state.
    pub fn state(&self) -> Arc<DeploymentState> {
        Arc::clone(&self.state)
    }

    /// Whether `model` has a chat completions engine or a completions engine.
    ///
    /// Embeddings engines are not consulted.
    pub fn has_model_any(&self, model: &str) -> bool {
        let state = &self.state;
        state.chat_completion_engines.lock().unwrap().contains(model)
            || state.completion_engines.lock().unwrap().contains(model)
    }

    /// Names of all models that have a chat completions engine, in no particular order.
    pub fn list_chat_completions_models(&self) -> Vec<String> {
        let engines = self.state.chat_completion_engines.lock().unwrap();
        engines.list()
    }

    /// Names of all models that have a completions engine, in no particular order.
    pub fn list_completions_models(&self) -> Vec<String> {
        let engines = self.state.completion_engines.lock().unwrap();
        engines.list()
    }

    /// Register the completions `engine` for `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelAlreadyExists`] if `model` already has a completions engine.
    pub fn add_completions_model(
        &self,
        model: &str,
        engine: OpenAICompletionsStreamingEngine,
    ) -> Result<(), ServiceHttpError> {
        self.state
            .completion_engines
            .lock()
            .unwrap()
            .add(model, engine)
    }

    /// Register the chat completions `engine` for `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelAlreadyExists`] if `model` already has a chat completions engine.
    pub fn add_chat_completions_model(
        &self,
        model: &str,
        engine: OpenAIChatCompletionsStreamingEngine,
    ) -> Result<(), ServiceHttpError> {
        self.state
            .chat_completion_engines
            .lock()
            .unwrap()
            .add(model, engine)
    }

    /// Register the embeddings `engine` for `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelAlreadyExists`] if `model` already has an embeddings engine.
    pub fn add_embeddings_model(
        &self,
        model: &str,
        engine: OpenAIEmbeddingsStreamingEngine,
    ) -> Result<(), ServiceHttpError> {
        self.state
            .embeddings_engines
            .lock()
            .unwrap()
            .add(model, engine)
    }

    /// Drop the completions engine of `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelNotFound`] if `model` has no completions engine.
    pub fn remove_completions_model(&self, model: &str) -> Result<(), ServiceHttpError> {
        self.state.completion_engines.lock().unwrap().remove(model)
    }

    /// Drop the chat completions engine of `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelNotFound`] if `model` has no chat completions engine.
    pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ServiceHttpError> {
        self.state
            .chat_completion_engines
            .lock()
            .unwrap()
            .remove(model)
    }

    /// Drop the embeddings engine of `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelNotFound`] if `model` has no embeddings engine.
    pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ServiceHttpError> {
        self.state.embeddings_engines.lock().unwrap().remove(model)
    }

    /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
    pub fn metrics(&self) -> Arc<Metrics> {
        Arc::clone(&self.state.metrics)
    }
}
/// The engines of one kind, keyed by model name.
struct ModelEngines<E> {
    /// Optional default model name
    default: Option<String>,
    /// Engine registered for each model name.
    engines: HashMap<String, E>,
}

impl<E> Default for ModelEngines<E> {
    /// An empty table with no default model.
    fn default() -> Self {
        ModelEngines {
            engines: HashMap::new(),
            default: None,
        }
    }
}

impl<E> ModelEngines<E> {
    /// Record `model` as the default model name.
    #[allow(dead_code)]
    fn set_default(&mut self, model: &str) {
        self.default = Some(model.to_owned());
    }

    /// Forget the default model name.
    #[allow(dead_code)]
    fn clear_default(&mut self) {
        self.default = None;
    }

    /// Register `engine` under `model`.
    ///
    /// An existing registration is never overwritten: if `model` is already present the
    /// table is left untouched and [`ServiceHttpError::ModelAlreadyExists`] is returned.
    fn add(&mut self, model: &str, engine: E) -> Result<(), ServiceHttpError> {
        if self.engines.contains_key(model) {
            Err(ServiceHttpError::ModelAlreadyExists(model.to_string()))
        } else {
            self.engines.insert(model.to_string(), engine);
            Ok(())
        }
    }

    /// Unregister `model`, or return [`ServiceHttpError::ModelNotFound`] if it was not registered.
    fn remove(&mut self, model: &str) -> Result<(), ServiceHttpError> {
        match self.engines.remove(model) {
            Some(_) => Ok(()),
            None => Err(ServiceHttpError::ModelNotFound(model.to_string())),
        }
    }

    /// Borrow the engine registered under `model`, if any.
    fn get(&self, model: &str) -> Option<&E> {
        self.engines.get(model)
    }

    /// Whether an engine is registered under `model`.
    fn contains(&self, model: &str) -> bool {
        self.engines.contains_key(model)
    }

    /// Names of all registered models, in the map's (unspecified) iteration order.
    fn list(&self) -> Vec<String> {
        self.engines.keys().cloned().collect()
    }
}
/// Global state shared across all the workers: the engines known for each model, plus
/// the service metrics.
///
/// Each engine table sits behind its own `std::sync::Mutex`. The lookup methods lock the
/// table and unwrap the result, so they panic if that mutex has been poisoned.
pub struct DeploymentState {
    /// Completions engines, by model name.
    completion_engines: Arc<Mutex<ModelEngines<OpenAICompletionsStreamingEngine>>>,
    /// Chat completions engines, by model name.
    chat_completion_engines: Arc<Mutex<ModelEngines<OpenAIChatCompletionsStreamingEngine>>>,
    /// Embeddings engines, by model name.
    embeddings_engines: Arc<Mutex<ModelEngines<OpenAIEmbeddingsStreamingEngine>>>,
    /// Prometheus metrics for the service.
    metrics: Arc<Metrics>,
    /// Keep-alive interval for SSE responses; `new()` leaves it unset.
    sse_keep_alive: Option<Duration>,
    // NOTE(review): presumably keyed by the entry's etcd key or model name - confirm against the watcher.
    entries: Arc<Mutex<HashMap<String, ModelEntry>>>,
    /// KV routers, by model name.
    kv_choosers: Arc<Mutex<HashMap<String, Arc<KvRouter>>>>,
}

impl DeploymentState {
    /// Create a state with no engines, no entries, no KV routers and default metrics.
    fn new() -> Self {
        Self {
            completion_engines: Arc::new(Mutex::new(ModelEngines::default())),
            chat_completion_engines: Arc::new(Mutex::new(ModelEngines::default())),
            embeddings_engines: Arc::new(Mutex::new(ModelEngines::default())),
            metrics: Arc::new(Metrics::default()),
            sse_keep_alive: None,
            entries: Arc::new(Mutex::new(HashMap::new())),
            kv_choosers: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Clone of the embeddings engine registered for `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelNotFound`] if `model` has no embeddings engine.
    // TODO: Remove this allow once `embeddings` is implemented in lib/llm/src/http/service/openai.rs
    #[allow(dead_code)]
    fn get_embeddings_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIEmbeddingsStreamingEngine, ServiceHttpError> {
        self.embeddings_engines
            .lock()
            .unwrap()
            .get(model)
            .cloned()
            // Lazy: only build the error (and allocate its string) when the lookup misses.
            .ok_or_else(|| ServiceHttpError::ModelNotFound(model.to_string()))
    }

    /// Clone of the completions engine registered for `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelNotFound`] if `model` has no completions engine.
    pub fn get_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAICompletionsStreamingEngine, ServiceHttpError> {
        self.completion_engines
            .lock()
            .unwrap()
            .get(model)
            .cloned()
            // Lazy: only build the error (and allocate its string) when the lookup misses.
            .ok_or_else(|| ServiceHttpError::ModelNotFound(model.to_string()))
    }

    /// Clone of the chat completions engine registered for `model`.
    ///
    /// # Errors
    /// [`ServiceHttpError::ModelNotFound`] if `model` has no chat completions engine.
    pub fn get_chat_completions_engine(
        &self,
        model: &str,
    ) -> Result<OpenAIChatCompletionsStreamingEngine, ServiceHttpError> {
        self.chat_completion_engines
            .lock()
            .unwrap()
            .get(model)
            .cloned()
            // Lazy: only build the error (and allocate its string) when the lookup misses.
            .ok_or_else(|| ServiceHttpError::ModelNotFound(model.to_string()))
    }
}
/// Documentation for a route /// Documentation for a route
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RouteDoc { pub struct RouteDoc {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use thiserror::Error; use thiserror::Error;
/// Errors returned when registering, removing or looking up the engine of a model.
/// The payload of each variant is the model name involved.
#[derive(Debug, Error)]
pub enum ServiceHttpError {
/// No engine is registered under this model name.
#[error("Model not found: {0}")]
ModelNotFound(String),
/// An engine is already registered under this model name.
#[error("Model already exists: {0}")]
ModelAlreadyExists(String),
}
/// Implementation of the Completion Engines served by the HTTP service should /// Implementation of the Completion Engines served by the HTTP service should
/// map their custom errors to to this error type if they wish to return error /// map their custom errors to to this error type if they wish to return error
/// codes besides 500. /// codes besides 500.
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Router}; use axum::{extract::State, http::StatusCode, response::IntoResponse, routing::get, Router};
use prometheus::{Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts}; use prometheus::{Encoder, HistogramOpts, HistogramVec, IntCounterVec, IntGaugeVec, Opts};
...@@ -19,7 +7,7 @@ use std::{sync::Arc, time::Instant}; ...@@ -19,7 +7,7 @@ use std::{sync::Arc, time::Instant};
pub use prometheus::Registry; pub use prometheus::Registry;
use super::{DeploymentState, RouteDoc}; use super::RouteDoc;
/// Value for the `status` label in the request counter for successful requests /// Value for the `status` label in the request counter for successful requests
pub const REQUEST_STATUS_SUCCESS: &str = "success"; pub const REQUEST_STATUS_SUCCESS: &str = "success";
...@@ -193,16 +181,14 @@ impl Metrics { ...@@ -193,16 +181,14 @@ impl Metrics {
registry.register(Box::new(self.request_duration.clone()))?; registry.register(Box::new(self.request_duration.clone()))?;
Ok(()) Ok(())
} }
}
impl DeploymentState {
/// Create a new [`InflightGuard`] for the given model and annotate if it's a streaming request, /// Create a new [`InflightGuard`] for the given model and annotate if it's a streaming request,
/// and the kind of endpoint that was hit /// and the kind of endpoint that was hit
/// ///
/// The [`InflightGuard`] is an RAII object that will handle incrementing the inflight gauge and /// The [`InflightGuard`] is an RAII object that will handle incrementing the inflight gauge and
/// request counters. /// request counters.
pub fn create_inflight_guard( pub fn create_inflight_guard(
&self, self: Arc<Self>,
model: &str, model: &str,
endpoint: Endpoint, endpoint: Endpoint,
streaming: bool, streaming: bool,
...@@ -213,12 +199,7 @@ impl DeploymentState { ...@@ -213,12 +199,7 @@ impl DeploymentState {
RequestType::Unary RequestType::Unary
}; };
InflightGuard::new( InflightGuard::new(self.clone(), model.to_string(), endpoint, request_type)
self.metrics.clone(),
model.to_string(),
endpoint,
request_type,
)
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use axum::{ use axum::{
extract::State, extract::State,
...@@ -26,18 +14,17 @@ use axum::{ ...@@ -26,18 +14,17 @@ use axum::{
use futures::{Stream, StreamExt}; use futures::{Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{ use std::{
collections::{HashMap, HashSet}, collections::HashSet,
pin::Pin, pin::Pin,
sync::Arc, sync::Arc,
time::{SystemTime, UNIX_EPOCH}, time::{SystemTime, UNIX_EPOCH},
}; };
use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::ReceiverStream;
use super::DeploymentState;
use super::{ use super::{
error::HttpError, error::HttpError,
metrics::{Endpoint, InflightGuard}, metrics::{Endpoint, InflightGuard},
RouteDoc, service_v2, RouteDoc,
}; };
use crate::protocols::openai::embeddings::NvCreateEmbeddingRequest; use crate::protocols::openai::embeddings::NvCreateEmbeddingRequest;
...@@ -132,7 +119,7 @@ impl From<HttpError> for ErrorResponse { ...@@ -132,7 +119,7 @@ impl From<HttpError> for ErrorResponse {
/// non-streaming requests, we will fold the stream into a single response as part of this handler. /// non-streaming requests, we will fold the stream into a single response as part of this handler.
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn completions( async fn completions(
State(state): State<Arc<DeploymentState>>, State(state): State<Arc<service_v2::State>>,
Json(request): Json<CompletionRequest>, Json(request): Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
...@@ -161,11 +148,15 @@ async fn completions( ...@@ -161,11 +148,15 @@ async fn completions(
// todo - error handling should be more robust // todo - error handling should be more robust
let engine = state let engine = state
.manager()
.get_completions_engine(model) .get_completions_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?; .map_err(|_| ErrorResponse::model_not_found())?;
// this will increment the inflight gauge for the model // this will increment the inflight gauge for the model
let mut inflight = state.create_inflight_guard(model, Endpoint::Completions, streaming); let mut inflight =
state
.metrics_clone()
.create_inflight_guard(model, Endpoint::Completions, streaming);
// setup context // setup context
// todo - inherit request_id from distributed trace details // todo - inherit request_id from distributed trace details
...@@ -189,7 +180,7 @@ async fn completions( ...@@ -189,7 +180,7 @@ async fn completions(
let mut sse_stream = Sse::new(stream); let mut sse_stream = Sse::new(stream);
if let Some(keep_alive) = state.sse_keep_alive { if let Some(keep_alive) = state.sse_keep_alive() {
sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive)); sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
} }
...@@ -213,7 +204,7 @@ async fn completions( ...@@ -213,7 +204,7 @@ async fn completions(
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn embeddings( async fn embeddings(
State(_state): State<Arc<DeploymentState>>, State(_state): State<Arc<service_v2::State>>,
Json(_request): Json<NvCreateEmbeddingRequest>, Json(_request): Json<NvCreateEmbeddingRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
unimplemented!("embeddings are not supported yet"); unimplemented!("embeddings are not supported yet");
...@@ -229,7 +220,7 @@ async fn embeddings( ...@@ -229,7 +220,7 @@ async fn embeddings(
/// non-streaming requests, we will fold the stream into a single response as part of this handler. /// non-streaming requests, we will fold the stream into a single response as part of this handler.
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
async fn chat_completions( async fn chat_completions(
State((state, template)): State<(Arc<DeploymentState>, Option<RequestTemplate>)>, State((state, template)): State<(Arc<service_v2::State>, Option<RequestTemplate>)>,
Json(mut request): Json<NvCreateChatCompletionRequest>, Json(mut request): Json<NvCreateChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
// return a 503 if the service is not ready // return a 503 if the service is not ready
...@@ -274,11 +265,15 @@ async fn chat_completions( ...@@ -274,11 +265,15 @@ async fn chat_completions(
tracing::trace!("Getting chat completions engine for model: {}", model); tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state let engine = state
.manager()
.get_chat_completions_engine(model) .get_chat_completions_engine(model)
.map_err(|_| ErrorResponse::model_not_found())?; .map_err(|_| ErrorResponse::model_not_found())?;
// this will increment the inflight gauge for the model // this will increment the inflight gauge for the model
let mut inflight = state.create_inflight_guard(model, Endpoint::ChatCompletions, streaming); let mut inflight =
state
.metrics_clone()
.create_inflight_guard(model, Endpoint::ChatCompletions, streaming);
// setup context // setup context
// todo - inherit request_id from distributed trace details // todo - inherit request_id from distributed trace details
...@@ -304,7 +299,7 @@ async fn chat_completions( ...@@ -304,7 +299,7 @@ async fn chat_completions(
let mut sse_stream = Sse::new(stream); let mut sse_stream = Sse::new(stream);
if let Some(keep_alive) = state.sse_keep_alive { if let Some(keep_alive) = state.sse_keep_alive() {
sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive)); sse_stream = sse_stream.keep_alive(KeepAlive::default().interval(keep_alive));
} }
...@@ -331,44 +326,13 @@ async fn chat_completions( ...@@ -331,44 +326,13 @@ async fn chat_completions(
// todo - abstract this to the top level lib.rs to be reused // todo - abstract this to the top level lib.rs to be reused
// todo - move the service_observer to its own state/arc // todo - move the service_observer to its own state/arc
fn check_ready(_state: &Arc<DeploymentState>) -> Result<(), (StatusCode, Json<ErrorResponse>)> { fn check_ready(_state: &Arc<service_v2::State>) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
// if state.service_observer.stage() != ServiceStage::Ready { // if state.service_observer.stage() != ServiceStage::Ready {
// return Err(ErrorResponse::service_unavailable()); // return Err(ErrorResponse::service_unavailable());
// } // }
Ok(()) Ok(())
} }
/// Handler for the non-standard list-models route.
///
/// Responds with a JSON object holding two arrays of model names:
/// `chat_completion_models` and `completion_models`.
async fn list_models_custom(
    State(state): State<Arc<DeploymentState>>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
    check_ready(&state)?;

    // Each lock guard is a temporary of its own statement, so the two tables are never
    // locked at the same time.
    let chat_models: Vec<String> = state
        .chat_completion_engines
        .lock()
        .unwrap()
        .engines
        .keys()
        .cloned()
        .collect();
    let completion_models: Vec<String> = state
        .completion_engines
        .lock()
        .unwrap()
        .engines
        .keys()
        .cloned()
        .collect();

    let models = HashMap::from([
        ("chat_completion_models", chat_models),
        ("completion_models", completion_models),
    ]);
    Ok(Json(models).into_response())
}
/// openai compatible format /// openai compatible format
/// Example: /// Example:
/// { /// {
...@@ -383,7 +347,7 @@ async fn list_models_custom( ...@@ -383,7 +347,7 @@ async fn list_models_custom(
/// ] /// ]
/// } /// }
async fn list_models_openai( async fn list_models_openai(
State(state): State<Arc<DeploymentState>>, State(state): State<Arc<service_v2::State>>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> { ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
check_ready(&state)?; check_ready(&state)?;
...@@ -394,14 +358,11 @@ async fn list_models_openai( ...@@ -394,14 +358,11 @@ async fn list_models_openai(
let mut data = Vec::new(); let mut data = Vec::new();
let models: HashSet<String> = state let models: HashSet<String> = state
.chat_completion_engines .manager()
.lock() .list_chat_completions_models()
.unwrap() .into_iter()
.engines .chain(state.manager().list_completions_models())
.keys() .chain(state.manager().list_embeddings_models())
.chain(state.completion_engines.lock().unwrap().engines.keys())
.chain(state.embeddings_engines.lock().unwrap().engines.keys())
.cloned()
.collect(); .collect();
for model_id in models { for model_id in models {
...@@ -522,7 +483,7 @@ impl<T: Serialize> TryFrom<EventConverter<T>> for Event { ...@@ -522,7 +483,7 @@ impl<T: Serialize> TryFrom<EventConverter<T>> for Event {
/// Create an Axum [`Router`] for the OpenAI API Completions endpoint /// Create an Axum [`Router`] for the OpenAI API Completions endpoint
/// If no path is provided, the default path is `/v1/completions` /// If no path is provided, the default path is `/v1/completions`
pub fn completions_router( pub fn completions_router(
state: Arc<DeploymentState>, state: Arc<service_v2::State>,
path: Option<String>, path: Option<String>,
) -> (Vec<RouteDoc>, Router) { ) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/completions".to_string()); let path = path.unwrap_or("/v1/completions".to_string());
...@@ -536,7 +497,7 @@ pub fn completions_router( ...@@ -536,7 +497,7 @@ pub fn completions_router(
/// Create an Axum [`Router`] for the OpenAI API Chat Completions endpoint /// Create an Axum [`Router`] for the OpenAI API Chat Completions endpoint
/// If no path is provided, the default path is `/v1/chat/completions` /// If no path is provided, the default path is `/v1/chat/completions`
pub fn chat_completions_router( pub fn chat_completions_router(
state: Arc<DeploymentState>, state: Arc<service_v2::State>,
template: Option<RequestTemplate>, template: Option<RequestTemplate>,
path: Option<String>, path: Option<String>,
) -> (Vec<RouteDoc>, Router) { ) -> (Vec<RouteDoc>, Router) {
...@@ -551,7 +512,7 @@ pub fn chat_completions_router( ...@@ -551,7 +512,7 @@ pub fn chat_completions_router(
/// Create an Axum [`Router`] for the OpenAI API Embeddings endpoint /// Create an Axum [`Router`] for the OpenAI API Embeddings endpoint
/// If no path is provided, the default path is `/v1/embeddings` /// If no path is provided, the default path is `/v1/embeddings`
pub fn embeddings_router( pub fn embeddings_router(
state: Arc<DeploymentState>, state: Arc<service_v2::State>,
path: Option<String>, path: Option<String>,
) -> (Vec<RouteDoc>, Router) { ) -> (Vec<RouteDoc>, Router) {
let path = path.unwrap_or("/v1/embeddings".to_string()); let path = path.unwrap_or("/v1/embeddings".to_string());
...@@ -564,28 +525,23 @@ pub fn embeddings_router( ...@@ -564,28 +525,23 @@ pub fn embeddings_router(
/// List Models /// List Models
pub fn list_models_router( pub fn list_models_router(
state: Arc<DeploymentState>, state: Arc<service_v2::State>,
path: Option<String>, path: Option<String>,
) -> (Vec<RouteDoc>, Router) { ) -> (Vec<RouteDoc>, Router) {
// TODO: Why do we have this endpoint?
let custom_path = path.unwrap_or("/dynamo/alpha/list-models".to_string());
let doc_for_custom = RouteDoc::new(axum::http::Method::GET, &custom_path);
// Standard OpenAI compatible list models endpoint // Standard OpenAI compatible list models endpoint
let openai_path = "/v1/models".to_string(); let openai_path = path.unwrap_or("/v1/models".to_string());
let doc_for_openai = RouteDoc::new(axum::http::Method::GET, &openai_path); let doc_for_openai = RouteDoc::new(axum::http::Method::GET, &openai_path);
let router = Router::new() let router = Router::new()
.route(&custom_path, get(list_models_custom))
.route(&openai_path, get(list_models_openai)) .route(&openai_path, get(list_models_openai))
.with_state(state); .with_state(state);
(vec![doc_for_custom, doc_for_openai], router) (vec![doc_for_openai], router)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::super::ServiceHttpError; use crate::discovery::ModelManagerError;
use super::*; use super::*;
...@@ -599,7 +555,7 @@ mod tests { ...@@ -599,7 +555,7 @@ mod tests {
} }
fn other_error_from_engine() -> Result<(), anyhow::Error> { fn other_error_from_engine() -> Result<(), anyhow::Error> {
Err(ServiceHttpError::ModelNotFound("foo".to_string()))? Err(ModelManagerError::ModelNotFound("foo".to_string()))?
} }
#[test] #[test]
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License"); use std::sync::Arc;
// you may not use this file except in compliance with the License. use std::time::Duration;
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use super::metrics; use super::metrics;
use super::ModelManager; use super::Metrics;
use super::RouteDoc; use super::RouteDoc;
use crate::discovery::ModelManager;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
/// State shared by the HTTP service's request handlers: the Prometheus metrics and the
/// model manager.
pub struct State {
    metrics: Arc<Metrics>,
    manager: Arc<ModelManager>,
}

impl State {
    /// Build the shared state around `manager`, with a default [`Metrics`] object.
    pub fn new(manager: Arc<ModelManager>) -> Self {
        let metrics = Arc::new(Metrics::default());
        Self { metrics, manager }
    }

    /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
    pub fn metrics_clone(&self) -> Arc<Metrics> {
        Arc::clone(&self.metrics)
    }

    /// Borrow the model manager.
    pub fn manager(&self) -> &ModelManager {
        &self.manager
    }

    /// A new owning handle to the model manager.
    pub fn manager_clone(&self) -> Arc<ModelManager> {
        Arc::clone(&self.manager)
    }

    /// Keep-alive interval for SSE responses. Not configurable yet: always `None`.
    // TODO
    pub fn sse_keep_alive(&self) -> Option<Duration> {
        None
    }
}
#[derive(Clone)] #[derive(Clone)]
pub struct HttpService { pub struct HttpService {
models: ModelManager, // The state we share with every request handler
state: Arc<State>,
router: axum::Router, router: axum::Router,
port: u16, port: u16,
host: String, host: String,
...@@ -60,8 +87,16 @@ impl HttpService { ...@@ -60,8 +87,16 @@ impl HttpService {
HttpServiceConfigBuilder::default() HttpServiceConfigBuilder::default()
} }
/// A new owning handle to the state shared with every request handler.
pub fn state_clone(&self) -> Arc<State> {
    Arc::clone(&self.state)
}
pub fn state(&self) -> &State {
Arc::as_ref(&self.state)
}
pub fn model_manager(&self) -> &ModelManager { pub fn model_manager(&self) -> &ModelManager {
&self.models self.state().manager()
} }
pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> { pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
...@@ -98,11 +133,12 @@ impl HttpServiceConfigBuilder { ...@@ -98,11 +133,12 @@ impl HttpServiceConfigBuilder {
pub fn build(self) -> Result<HttpService, anyhow::Error> { pub fn build(self) -> Result<HttpService, anyhow::Error> {
let config: HttpServiceConfig = self.build_internal()?; let config: HttpServiceConfig = self.build_internal()?;
let model_manager = ModelManager::new(); let model_manager = Arc::new(ModelManager::new());
let state = Arc::new(State::new(model_manager));
// enable prometheus metrics // enable prometheus metrics
let registry = metrics::Registry::new(); let registry = metrics::Registry::new();
model_manager.metrics().register(&registry)?; state.metrics_clone().register(&registry)?;
let mut router = axum::Router::new(); let mut router = axum::Router::new();
...@@ -110,29 +146,23 @@ impl HttpServiceConfigBuilder { ...@@ -110,29 +146,23 @@ impl HttpServiceConfigBuilder {
let mut routes = vec![ let mut routes = vec![
metrics::router(registry, None), metrics::router(registry, None),
super::openai::list_models_router(model_manager.state(), None), super::openai::list_models_router(state.clone(), None),
]; ];
if config.enable_chat_endpoints { if config.enable_chat_endpoints {
routes.push(super::openai::chat_completions_router( routes.push(super::openai::chat_completions_router(
model_manager.state(), state.clone(),
config.request_template, config.request_template,
None, None,
)); ));
} }
if config.enable_cmpl_endpoints { if config.enable_cmpl_endpoints {
routes.push(super::openai::completions_router( routes.push(super::openai::completions_router(state.clone(), None));
model_manager.state(),
None,
));
} }
if config.enable_embeddings_endpoints { if config.enable_embeddings_endpoints {
routes.push(super::openai::embeddings_router( routes.push(super::openai::embeddings_router(state.clone(), None));
model_manager.state(),
None,
));
} }
// for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) { // for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
...@@ -146,7 +176,7 @@ impl HttpServiceConfigBuilder { ...@@ -146,7 +176,7 @@ impl HttpServiceConfigBuilder {
} }
Ok(HttpService { Ok(HttpService {
models: model_manager, state,
router, router,
port: config.port, port: config.port,
host: config.host, host: config.host,
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
pub mod backend; pub mod backend;
pub mod common; pub mod common;
pub mod disagg_router; pub mod disagg_router;
pub mod discovery;
pub mod engines; pub mod engines;
pub mod gguf; pub mod gguf;
pub mod http; pub mod http;
......
...@@ -8,11 +8,14 @@ use std::sync::Arc; ...@@ -8,11 +8,14 @@ use std::sync::Arc;
use dynamo_runtime::component::{Component, Endpoint}; use dynamo_runtime::component::{Component, Endpoint};
use dynamo_runtime::traits::DistributedRuntimeProvider; use dynamo_runtime::traits::DistributedRuntimeProvider;
use crate::http::service::discovery::{ModelEntry, ModelNetworkName}; use crate::discovery::ModelEntry;
use crate::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager}; use crate::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager};
use crate::model_card::{self, ModelDeploymentCard}; use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::ModelType; use crate::model_type::ModelType;
mod network_name;
pub use network_name::ModelNetworkName;
/// Prefix for Hugging Face model repository /// Prefix for Hugging Face model repository
const HF_SCHEME: &str = "hf://"; const HF_SCHEME: &str = "hf://";
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use anyhow::Context as _;
use crate::discovery::{ModelEntry, MODEL_ROOT_PATH};
use dynamo_runtime::component::{self, Instance};
use dynamo_runtime::slug::Slug;
use dynamo_runtime::transports::etcd;
/// Key under which a model's entry is stored in the networked key-value store (etcd).
///
/// This is a newtype over the full key string. Values are built by the
/// constructors in this module (from an endpoint plus lease id, or from an
/// `Instance`) and rendered back to the raw key through `Display`.
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
/// Key to store this model entry in networked key-value store (etcd).
///
/// It looks like this:
/// ns.cp.ep-694d967ca5efd804
fn from_parts(namespace: &str, component: &str, endpoint: &str, lease_id: i64) -> Self {
let model_root = MODEL_ROOT_PATH;
let slug = Slug::slugify(&format!("{namespace}.{component}.{endpoint}-{lease_id:x}"));
ModelNetworkName(format!("{model_root}/{slug}"))
}
// We can't do From<&component::Endpoint> here because we also need the lease_id
pub fn from_local(endpoint: &component::Endpoint, lease_id: i64) -> Self {
Self::from_parts(
&endpoint.component().namespace().to_string(),
&endpoint.component().name(),
endpoint.name(),
lease_id,
)
}
/// Fetch the ModelEntry from etcd.
pub async fn load_entry(&self, etcd_client: &etcd::Client) -> anyhow::Result<ModelEntry> {
let mut model_entries = etcd_client.kv_get(self.to_string(), None).await?;
if model_entries.is_empty() {
anyhow::bail!("No ModelEntry in etcd for key {self}");
}
let model_entry = model_entries.remove(0);
serde_json::from_slice(model_entry.value()).with_context(|| {
format!(
"Error deserializing JSON. Key={self}. JSON={}",
model_entry.value_str().unwrap_or("INVALID UTF-8")
)
})
}
}
impl From<&Instance> for ModelNetworkName {
fn from(cei: &Instance) -> Self {
Self::from_parts(
&cei.namespace,
&cei.component,
&cei.endpoint,
cei.instance_id,
)
}
}
impl std::fmt::Display for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
...@@ -116,7 +116,7 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon ...@@ -116,7 +116,7 @@ impl AsyncEngine<SingleIn<CompletionRequest>, ManyOut<Annotated<CompletionRespon
} }
fn compare_counter( fn compare_counter(
metrics: Arc<Metrics>, metrics: &Metrics,
model: &str, model: &str,
endpoint: &Endpoint, endpoint: &Endpoint,
request_type: &RequestType, request_type: &RequestType,
...@@ -154,13 +154,13 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu ...@@ -154,13 +154,13 @@ fn compute_index(endpoint: &Endpoint, request_type: &RequestType, status: &Statu
endpoint * 4 + request_type * 2 + status endpoint * 4 + request_type * 2 + status
} }
fn compare_counters(metrics: Arc<Metrics>, model: &str, expected: &[u64; 8]) { fn compare_counters(metrics: &Metrics, model: &str, expected: &[u64; 8]) {
for endpoint in &[Endpoint::Completions, Endpoint::ChatCompletions] { for endpoint in &[Endpoint::Completions, Endpoint::ChatCompletions] {
for request_type in &[RequestType::Unary, RequestType::Stream] { for request_type in &[RequestType::Unary, RequestType::Stream] {
for status in &[Status::Success, Status::Error] { for status in &[Status::Success, Status::Error] {
let index = compute_index(endpoint, request_type, status); let index = compute_index(endpoint, request_type, status);
compare_counter( compare_counter(
metrics.clone(), metrics,
model, model,
endpoint, endpoint,
request_type, request_type,
...@@ -186,7 +186,8 @@ fn inc_counter( ...@@ -186,7 +186,8 @@ fn inc_counter(
#[tokio::test] #[tokio::test]
async fn test_http_service() { async fn test_http_service() {
let service = HttpService::builder().port(8989).build().unwrap(); let service = HttpService::builder().port(8989).build().unwrap();
let manager = service.model_manager().clone(); let state = service.state_clone();
let manager = state.manager();
let token = CancellationToken::new(); let token = CancellationToken::new();
let cancel_token = token.clone(); let cancel_token = token.clone();
...@@ -205,14 +206,14 @@ async fn test_http_service() { ...@@ -205,14 +206,14 @@ async fn test_http_service() {
let result = manager.add_completions_model("bar", failure); let result = manager.add_completions_model("bar", failure);
assert!(result.is_ok()); assert!(result.is_ok());
let metrics = manager.metrics(); let metrics = state.metrics_clone();
metrics.register(&registry).unwrap(); metrics.register(&registry).unwrap();
let mut foo_counters = [0u64; 8]; let mut foo_counters = [0u64; 8];
let mut bar_counters = [0u64; 8]; let mut bar_counters = [0u64; 8];
compare_counters(metrics.clone(), "foo", &foo_counters); compare_counters(&metrics, "foo", &foo_counters);
compare_counters(metrics.clone(), "bar", &bar_counters); compare_counters(&metrics, "bar", &bar_counters);
let client = reqwest::Client::new(); let client = reqwest::Client::new();
...@@ -264,8 +265,8 @@ async fn test_http_service() { ...@@ -264,8 +265,8 @@ async fn test_http_service() {
Status::Success, Status::Success,
&mut foo_counters, &mut foo_counters,
); );
compare_counters(metrics.clone(), "foo", &foo_counters); compare_counters(&metrics, "foo", &foo_counters);
compare_counters(metrics.clone(), "bar", &bar_counters); compare_counters(&metrics, "bar", &bar_counters);
// check registry and look for the request duration histogram // check registry and look for the request duration histogram
let families = registry.gather(); let families = registry.gather();
...@@ -337,8 +338,8 @@ async fn test_http_service() { ...@@ -337,8 +338,8 @@ async fn test_http_service() {
Status::Success, Status::Success,
&mut foo_counters, &mut foo_counters,
); );
compare_counters(metrics.clone(), "foo", &foo_counters); compare_counters(&metrics, "foo", &foo_counters);
compare_counters(metrics.clone(), "bar", &bar_counters); compare_counters(&metrics, "bar", &bar_counters);
// ==== ChatCompletions / Unary / Success ==== // ==== ChatCompletions / Unary / Success ====
// ==== ChatCompletions / Stream / Error ==== // ==== ChatCompletions / Stream / Error ====
...@@ -362,8 +363,8 @@ async fn test_http_service() { ...@@ -362,8 +363,8 @@ async fn test_http_service() {
Status::Error, Status::Error,
&mut bar_counters, &mut bar_counters,
); );
compare_counters(metrics.clone(), "foo", &foo_counters); compare_counters(&metrics, "foo", &foo_counters);
compare_counters(metrics.clone(), "bar", &bar_counters); compare_counters(&metrics, "bar", &bar_counters);
// ==== ChatCompletions / Stream / Error ==== // ==== ChatCompletions / Stream / Error ====
// ==== ChatCompletions / Unary / Error ==== // ==== ChatCompletions / Unary / Error ====
...@@ -383,8 +384,8 @@ async fn test_http_service() { ...@@ -383,8 +384,8 @@ async fn test_http_service() {
Status::Error, Status::Error,
&mut bar_counters, &mut bar_counters,
); );
compare_counters(metrics.clone(), "foo", &foo_counters); compare_counters(&metrics, "foo", &foo_counters);
compare_counters(metrics.clone(), "bar", &bar_counters); compare_counters(&metrics, "bar", &bar_counters);
// ==== ChatCompletions / Unary / Error ==== // ==== ChatCompletions / Unary / Error ====
// ==== Completions / Unary / Error ==== // ==== Completions / Unary / Error ====
...@@ -408,8 +409,8 @@ async fn test_http_service() { ...@@ -408,8 +409,8 @@ async fn test_http_service() {
Status::Error, Status::Error,
&mut bar_counters, &mut bar_counters,
); );
compare_counters(metrics.clone(), "foo", &foo_counters); compare_counters(&metrics, "foo", &foo_counters);
compare_counters(metrics.clone(), "bar", &bar_counters); compare_counters(&metrics, "bar", &bar_counters);
// ==== Completions / Unary / Error ==== // ==== Completions / Unary / Error ====
// ==== Completions / Stream / Error ==== // ==== Completions / Stream / Error ====
...@@ -429,8 +430,8 @@ async fn test_http_service() { ...@@ -429,8 +430,8 @@ async fn test_http_service() {
Status::Error, Status::Error,
&mut bar_counters, &mut bar_counters,
); );
compare_counters(metrics.clone(), "foo", &foo_counters); compare_counters(&metrics, "foo", &foo_counters);
compare_counters(metrics.clone(), "bar", &bar_counters); compare_counters(&metrics, "bar", &bar_counters);
// ==== Completions / Stream / Error ==== // ==== Completions / Stream / Error ====
// =========== Test Invalid Request =========== // =========== Test Invalid Request ===========
......
...@@ -63,9 +63,6 @@ pub use client::{Client, InstanceSource}; ...@@ -63,9 +63,6 @@ pub use client::{Client, InstanceSource};
/// An instance is namespace+component+endpoint+lease_id and must be unique. /// An instance is namespace+component+endpoint+lease_id and must be unique.
pub const INSTANCE_ROOT_PATH: &str = "instances"; pub const INSTANCE_ROOT_PATH: &str = "instances";
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TransportType { pub enum TransportType {
......
...@@ -97,14 +97,14 @@ impl Client { ...@@ -97,14 +97,14 @@ impl Client {
loop { loop {
let kv_event = tokio::select! { let kv_event = tokio::select! {
_ = watch_tx.closed() => { _ = watch_tx.closed() => {
tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix); tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {prefix}");
break; break;
} }
kv_event = kv_event_rx.recv() => { kv_event = kv_event_rx.recv() => {
match kv_event { match kv_event {
Some(kv_event) => kv_event, Some(kv_event) => kv_event,
None => { None => {
tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix); tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {prefix}");
break; break;
} }
} }
...@@ -118,7 +118,7 @@ impl Client { ...@@ -118,7 +118,7 @@ impl Client {
if let (Ok(key), Ok(val)) = (key, val) { if let (Ok(key), Ok(val)) = (key, val) {
map.insert(key.clone(), val); map.insert(key.clone(), val);
} else { } else {
tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix); tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {prefix}");
break; break;
} }
} }
......
...@@ -355,6 +355,10 @@ impl Client { ...@@ -355,6 +355,10 @@ impl Client {
match event.event_type() { match event.event_type() {
etcd_client::EventType::Put => { etcd_client::EventType::Put => {
if let Some(kv) = event.kv() { if let Some(kv) = event.kv() {
if tx.is_closed() {
// Receiver no longer interested, expected.
break;
}
if let Err(err) = tx.send(WatchEvent::Put(kv.clone())).await { if let Err(err) = tx.send(WatchEvent::Put(kv.clone())).await {
tracing::error!( tracing::error!(
"kv watcher error forwarding WatchEvent::Put: {err}" "kv watcher error forwarding WatchEvent::Put: {err}"
...@@ -366,6 +370,9 @@ impl Client { ...@@ -366,6 +370,9 @@ impl Client {
} }
etcd_client::EventType::Delete => { etcd_client::EventType::Delete => {
if let Some(kv) = event.kv() { if let Some(kv) = event.kv() {
if tx.is_closed() {
break;
}
if tx.send(WatchEvent::Delete(kv.clone())).await.is_err() { if tx.send(WatchEvent::Delete(kv.clone())).await.is_err() {
// receiver is closed // receiver is closed
break; break;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment