Unverified commit 537759f1, authored by Abrar Shivani, committed by GitHub
Browse files

feat: Dynamic Endpoint Exposure Based on Model Type (#1447)

parent a4e06895
......@@ -8,7 +8,7 @@ mod model_entry;
pub use model_entry::ModelEntry;
mod watcher;
pub use watcher::ModelWatcher;
pub use watcher::{ModelUpdate, ModelWatcher};
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use anyhow::Context as _;
use tokio::sync::{mpsc::Receiver, Notify};
......@@ -36,14 +37,24 @@ use crate::{
use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
/// Notification sent by [`ModelWatcher`] over its optional update channel when a model
/// type becomes available or stops being available.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelUpdate {
/// A model entry of the given type was saved by the watcher.
Added(ModelType),
/// The watcher found no remaining models of the given type after a removal.
Removed(ModelType),
}
/// Watches model entries, keeps a [`ModelManager`] in sync with them, and can report
/// [`ModelUpdate`]s to an interested listener.
pub struct ModelWatcher {
// Registry that model entries are saved to and removed from.
manager: Arc<ModelManager>,
drt: DistributedRuntime,
router_mode: RouterMode,
// Signalled via `notify_waiters` when an endpoint for a model shows up.
notify_on_model: Notify,
// Optional channel for `ModelUpdate` notifications; stays `None` until
// `set_notify_on_model_update` is called, in which case no updates are sent.
model_update_tx: Option<Sender<ModelUpdate>>,
kv_router_config: Option<KvRouterConfig>,
}
/// Model types iterated over when emitting `ModelUpdate::Removed` notifications.
/// `ModelType::Backend` is intentionally not listed here.
const ALL_MODEL_TYPES: &[ModelType] =
&[ModelType::Chat, ModelType::Completion, ModelType::Embedding];
impl ModelWatcher {
pub fn new(
runtime: DistributedRuntime,
......@@ -56,10 +67,15 @@ impl ModelWatcher {
drt: runtime,
router_mode,
notify_on_model: Notify::new(),
model_update_tx: None,
kv_router_config,
}
}
/// Registers the channel on which this watcher sends [`ModelUpdate`] notifications.
///
/// Replaces any previously registered sender. Send failures are ignored by the watcher.
pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
self.model_update_tx = Some(tx);
}
/// Wait until we have at least one chat completions model and return its name.
pub async fn wait_for_chat_model(&self) -> String {
// Loop in case it gets added and immediately deleted
......@@ -100,6 +116,12 @@ impl ModelWatcher {
};
self.manager.save_model_entry(key, model_entry.clone());
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(model_entry.model_type))
.await
.ok();
}
if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(name = model_entry.name, "New endpoint for existing model");
self.notify_on_model.notify_waiters();
......@@ -151,13 +173,91 @@ impl ModelWatcher {
.await
.with_context(|| model_name.clone())?;
if !active_instances.is_empty() {
let mut update_tx = true;
let mut model_type: ModelType = model_entry.model_type;
if model_entry.model_type == ModelType::Chat
&& self.manager.list_chat_completions_models().is_empty()
{
self.manager.remove_chat_completions_model(&model_name).ok();
model_type = ModelType::Chat;
} else if model_entry.model_type == ModelType::Completion
&& self.manager.list_completions_models().is_empty()
{
self.manager.remove_completions_model(&model_name).ok();
model_type = ModelType::Completion;
} else if model_entry.model_type == ModelType::Embedding
&& self.manager.list_embeddings_models().is_empty()
{
self.manager.remove_embeddings_model(&model_name).ok();
model_type = ModelType::Embedding;
} else if model_entry.model_type == ModelType::Backend {
if self.manager.list_chat_completions_models().is_empty() {
self.manager.remove_chat_completions_model(&model_name).ok();
model_type = ModelType::Chat;
}
if self.manager.list_completions_models().is_empty() {
self.manager.remove_completions_model(&model_name).ok();
if model_type == ModelType::Chat {
model_type = ModelType::Backend;
} else {
model_type = ModelType::Completion;
}
}
} else {
tracing::debug!(
"Model {} is still active in other instances, not removing",
model_name
);
update_tx = false;
}
if update_tx {
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Removed(model_type)).await.ok();
}
}
return Ok(None);
}
// Ignore the errors because model could be either type
let _ = self.manager.remove_chat_completions_model(&model_name);
let _ = self.manager.remove_completions_model(&model_name);
let _ = self.manager.remove_embeddings_model(&model_name);
let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let mut chat_model_removed = false;
let mut completions_model_removed = false;
let mut embeddings_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true;
}
if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
{
completions_model_removed = true;
}
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
embeddings_model_removed = true;
}
if !chat_model_removed && !completions_model_removed && !embeddings_model_removed {
tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}",
model_name,
chat_model_removed,
completions_model_removed,
embeddings_model_removed
);
} else {
for model_type in ALL_MODEL_TYPES {
if (chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completion)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)
{
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Removed(*model_type)).await.ok();
}
}
}
}
Ok(Some(model_name))
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use strum::Display;
/// The families of HTTP endpoints that can be switched on or off individually.
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum EndpointType {
/// Chat Completions API
Chat,
/// Older completions API
Completion,
/// Embeddings API
Embedding,
/// Responses API
Responses,
}
impl EndpointType {
    /// Returns the lowercase, singular name of this endpoint type.
    ///
    /// Every name is a string literal, so the result is `'static` and may outlive the
    /// `EndpointType` value it was obtained from.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Chat => "chat",
            Self::Completion => "completion",
            Self::Embedding => "embedding",
            Self::Responses => "responses",
        }
    }

    /// Returns every endpoint type, in declaration order.
    pub fn all() -> Vec<Self> {
        vec![
            Self::Chat,
            Self::Completion,
            Self::Embedding,
            Self::Responses,
        ]
    }
}
......@@ -4,11 +4,13 @@
use std::sync::Arc;
use crate::{
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH},
discovery::{ModelManager, ModelUpdate, ModelWatcher, MODEL_ROOT_PATH},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter,
entrypoint::{self, input::common, EngineConfig},
http::service::service_v2,
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
model_type::ModelType,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -22,9 +24,6 @@ use dynamo_runtime::{DistributedRuntime, Runtime};
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
let mut http_service_builder = service_v2::HttpService::builder()
.port(engine_config.local_model().http_port())
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.enable_embeddings_endpoints(true)
.with_request_template(engine_config.local_model().request_template());
let http_service = match engine_config {
......@@ -45,6 +44,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
MODEL_ROOT_PATH,
router_config.router_mode,
Some(router_config.kv_router_config),
Arc::new(http_service.clone()),
)
.await?;
}
......@@ -98,6 +98,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}
http_service
}
EngineConfig::StaticFull { engine, model, .. } => {
......@@ -106,6 +112,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let manager = http_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?;
manager.add_chat_completions_model(model.service_name(), engine)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}
http_service
}
EngineConfig::StaticCore {
......@@ -129,6 +142,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
>(model.card(), inner_engine)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}
http_service
}
};
......@@ -154,13 +173,70 @@ async fn run_watcher(
network_prefix: &str,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
http_service: Arc<HttpService>,
) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config);
let mut watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config);
tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
// Create a channel to receive model type updates
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
watch_obj.set_notify_on_model_update(tx);
// Spawn a task to watch for model type changes and update HTTP service endpoints
let _endpoint_enabler_task = tokio::spawn(async move {
while let Some(model_type) = rx.recv().await {
tracing::debug!("Received model type update: {:?}", model_type);
update_http_endpoints(http_service.clone(), model_type).await;
}
});
// Run the watcher in its own task; it already owns the update sender set above
let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver).await;
});
Ok(())
}
/// Applies a single [`ModelUpdate`] to the HTTP service's endpoint switches.
///
/// `Added` turns endpoints on and `Removed` turns them off. A `ModelType::Backend`
/// update toggles both the chat and the completion endpoints; every other model type
/// toggles the one endpoint returned by `ModelType::as_endpoint_type`.
async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
    tracing::debug!(
        "Updating HTTP service endpoints for model type: {:?}",
        model_type
    );

    // Split the update into "which model type" and "on or off" so that both
    // directions share the same dispatch below.
    let (kind, enable) = match model_type {
        ModelUpdate::Added(kind) => (kind, true),
        ModelUpdate::Removed(kind) => (kind, false),
    };

    if kind == ModelType::Backend {
        // Backend has no endpoint of its own; it drives chat and completion.
        for endpoint in [EndpointType::Chat, EndpointType::Completion] {
            service.enable_model_endpoint(endpoint, enable).await;
        }
    } else {
        service
            .enable_model_endpoint(kind.as_endpoint_type(), enable)
            .await;
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::env::var;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::time::Duration;
......@@ -9,6 +12,7 @@ use super::metrics;
use super::Metrics;
use super::RouteDoc;
use crate::discovery::ModelManager;
use crate::endpoint_type::EndpointType;
use crate::request_template::RequestTemplate;
use anyhow::Result;
use derive_builder::Builder;
......@@ -19,10 +23,48 @@ use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer;
/// HTTP service shared state
#[derive(Default)]
pub struct State {
// Prometheus metrics tracking request counts and inflight requests.
metrics: Arc<Metrics>,
// Registry of the models served by this HTTP service.
manager: Arc<ModelManager>,
// Optional etcd client; `None` when the state is created without etcd.
etcd_client: Option<etcd::Client>,
// Per-endpoint-type switches read by the request middleware and toggled at runtime.
flags: StateFlags,
}
/// One atomic on/off switch per [`EndpointType`].
///
/// All switches start out `false` under `Default`. Reads and writes use
/// `Ordering::Relaxed`.
#[derive(Default, Debug)]
struct StateFlags {
    chat_endpoints_enabled: AtomicBool,
    cmpl_endpoints_enabled: AtomicBool,
    embeddings_endpoints_enabled: AtomicBool,
    responses_endpoints_enabled: AtomicBool,
}

impl StateFlags {
    /// Selects the switch that backs `endpoint_type`.
    fn flag(&self, endpoint_type: &EndpointType) -> &AtomicBool {
        match endpoint_type {
            EndpointType::Chat => &self.chat_endpoints_enabled,
            EndpointType::Completion => &self.cmpl_endpoints_enabled,
            EndpointType::Embedding => &self.embeddings_endpoints_enabled,
            EndpointType::Responses => &self.responses_endpoints_enabled,
        }
    }

    /// Returns whether the given endpoint type is currently enabled.
    pub fn get(&self, endpoint_type: &EndpointType) -> bool {
        self.flag(endpoint_type).load(Ordering::Relaxed)
    }

    /// Enables or disables the given endpoint type.
    pub fn set(&self, endpoint_type: &EndpointType, enabled: bool) {
        self.flag(endpoint_type).store(enabled, Ordering::Relaxed);
    }
}
impl State {
......@@ -31,6 +73,12 @@ impl State {
manager,
metrics: Arc::new(Metrics::default()),
etcd_client: None,
flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
embeddings_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false),
},
}
}
......@@ -39,9 +87,14 @@ impl State {
manager,
metrics: Arc::new(Metrics::default()),
etcd_client,
flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
embeddings_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false),
},
}
}
/// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
pub fn metrics_clone(&self) -> Arc<Metrics> {
self.metrics.clone()
......@@ -87,10 +140,10 @@ pub struct HttpServiceConfig {
// #[builder(default)]
// custom: Vec<axum::Router>
#[builder(default = "true")]
#[builder(default = "false")]
enable_chat_endpoints: bool,
#[builder(default = "true")]
#[builder(default = "false")]
enable_cmpl_endpoints: bool,
#[builder(default = "true")]
......@@ -151,6 +204,15 @@ impl HttpService {
/// Returns the route documentation entries held by this service.
pub fn route_docs(&self) -> &[RouteDoc] {
&self.route_docs
}
/// Switches the given endpoint type on or off in the shared state flags and logs the
/// new status at info level.
pub async fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
    let status = if enable { "enabled" } else { "disabled" };
    self.state.flags.set(&endpoint_type, enable);
    tracing::info!("{} endpoints {}", endpoint_type.as_str(), status);
}
}
/// Environment variable to set the metrics endpoint path (default: `/metrics`)
......@@ -177,6 +239,19 @@ impl HttpServiceConfigBuilder {
let model_manager = Arc::new(ModelManager::new());
let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
state
.flags
.set(&EndpointType::Chat, config.enable_chat_endpoints);
state
.flags
.set(&EndpointType::Completion, config.enable_cmpl_endpoints);
state
.flags
.set(&EndpointType::Embedding, config.enable_embeddings_endpoints);
state
.flags
.set(&EndpointType::Responses, config.enable_responses_endpoints);
// enable prometheus metrics
let registry = metrics::Registry::new();
state.metrics_clone().register(&registry)?;
......@@ -192,42 +267,10 @@ impl HttpServiceConfigBuilder {
super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
];
if config.enable_chat_endpoints {
routes.push(super::openai::chat_completions_router(
state.clone(),
config.request_template.clone(), // TODO clone()? reference?
var(HTTP_SVC_CHAT_PATH_ENV).ok(),
));
}
if config.enable_cmpl_endpoints {
routes.push(super::openai::completions_router(
state.clone(),
var(HTTP_SVC_CMP_PATH_ENV).ok(),
));
}
if config.enable_embeddings_endpoints {
routes.push(super::openai::embeddings_router(
state.clone(),
var(HTTP_SVC_EMB_PATH_ENV).ok(),
));
}
if config.enable_responses_endpoints {
routes.push(super::openai::responses_router(
state.clone(),
config.request_template,
var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
));
}
// for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
// router = router.merge(route);
// all_docs.extend(route_docs);
// }
for (route_docs, route) in routes.into_iter() {
let endpoint_routes =
HttpServiceConfigBuilder::get_endpoints_router(state.clone(), &config.request_template);
routes.extend(endpoint_routes);
for (route_docs, route) in routes {
router = router.merge(route);
all_docs.extend(route_docs);
}
......@@ -253,4 +296,58 @@ impl HttpServiceConfigBuilder {
self.etcd_client = Some(etcd_client);
self
}
/// Builds the chat, completions, embeddings and responses routers, each gated by the
/// runtime flag for its [`EndpointType`].
///
/// All four routers are always created. Each is wrapped in a middleware that reads
/// `state.flags` on every request: when the flag is set the request is forwarded,
/// otherwise the middleware answers `503 Service Unavailable`. This lets endpoints be
/// toggled after the service has been built.
///
/// `request_template` is cloned for the chat and responses routers. The routers are
/// returned in the order given by `EndpointType::all()`.
fn get_endpoints_router(
    state: Arc<State>,
    request_template: &Option<RequestTemplate>,
) -> Vec<(Vec<RouteDoc>, axum::Router)> {
    // Build every router up front, keyed by the endpoint type it serves.
    let mut endpoint_routes = HashMap::new();
    endpoint_routes.insert(
        EndpointType::Chat,
        super::openai::chat_completions_router(
            state.clone(),
            request_template.clone(),
            var(HTTP_SVC_CHAT_PATH_ENV).ok(),
        ),
    );
    endpoint_routes.insert(
        EndpointType::Completion,
        super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok()),
    );
    endpoint_routes.insert(
        EndpointType::Embedding,
        super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok()),
    );
    endpoint_routes.insert(
        EndpointType::Responses,
        super::openai::responses_router(
            state.clone(),
            request_template.clone(),
            var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
        ),
    );

    let mut routes = Vec::new();
    for endpoint_type in EndpointType::all() {
        // Take ownership of the entry: one lookup, no clone of the router or its docs,
        // and no `unwrap` that could panic if a type were ever left unregistered.
        let (docs, route) = match endpoint_routes.remove(&endpoint_type) {
            Some(entry) => entry,
            None => {
                tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
                continue;
            }
        };

        let gate_state = state.clone();
        let route = route.route_layer(axum::middleware::from_fn(
            move |req: axum::http::Request<axum::body::Body>, next: axum::middleware::Next| {
                // The middleware runs once per request, so each call needs its own
                // handle to the shared state.
                let state: Arc<State> = gate_state.clone();
                async move {
                    // Check if the endpoint is enabled
                    if state.flags.get(&endpoint_type) {
                        Ok(next.run(req).await)
                    } else {
                        tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
                        Err(axum::http::StatusCode::SERVICE_UNAVAILABLE)
                    }
                }
            },
        ));
        routes.push((docs, route));
    }
    routes
}
}
......@@ -14,6 +14,7 @@ pub mod backend;
pub mod common;
pub mod disagg_router;
pub mod discovery;
pub mod endpoint_type;
pub mod engines;
pub mod entrypoint;
pub mod gguf;
......
......@@ -41,4 +41,13 @@ impl ModelType {
/// Returns every model type, including `Backend`.
pub fn all() -> Vec<Self> {
vec![Self::Chat, Self::Completion, Self::Embedding, Self::Backend]
}
pub fn as_endpoint_type(&self) -> crate::endpoint_type::EndpointType {
match self {
Self::Chat => crate::endpoint_type::EndpointType::Chat,
Self::Completion => crate::endpoint_type::EndpointType::Completion,
Self::Embedding => crate::endpoint_type::EndpointType::Embedding,
Self::Backend => panic!("Backend model type does not map to an endpoint type"),
}
}
}
......@@ -270,7 +270,12 @@ fn inc_counter(
#[allow(deprecated)]
#[tokio::test]
async fn test_http_service() {
let service = HttpService::builder().port(8989).build().unwrap();
let service = HttpService::builder()
.port(8989)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.build()
.unwrap();
let state = service.state_clone();
let manager = state.manager();
......@@ -572,7 +577,12 @@ async fn wait_for_service_ready(port: u16) {
fn service_with_engines(
#[default(8990)] port: u16,
) -> (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>) {
let service = HttpService::builder().port(port).build().unwrap();
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(port)
.build()
.unwrap();
let manager = service.model_manager();
let counter = Arc::new(CounterEngine {});
......@@ -958,7 +968,12 @@ async fn test_generic_byot_client(
#[rstest]
#[tokio::test]
async fn test_client_disconnect_cancellation_unary() {
let service = HttpService::builder().port(8993).build().unwrap();
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(8993)
.build()
.unwrap();
let state = service.state_clone();
let manager = state.manager();
......@@ -1044,7 +1059,12 @@ async fn test_client_disconnect_cancellation_unary() {
async fn test_client_disconnect_cancellation_streaming() {
dynamo_runtime::logging::init();
let service = HttpService::builder().port(8994).build().unwrap();
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(8994)
.build()
.unwrap();
let state = service.state_clone();
let manager = state.manager();
......@@ -1137,7 +1157,12 @@ async fn test_request_id_annotation() {
// TODO(ryan): make better fixtures, this is too much to test sometime so simple
dynamo_runtime::logging::init();
let service = HttpService::builder().port(8995).build().unwrap();
let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(8995)
.build()
.unwrap();
let state = service.state_clone();
let manager = state.manager();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment