"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "af770b8e7bf77539fcbc81a9ff1974f12cdf87ff"
Unverified Commit 537759f1 authored by Abrar Shivani's avatar Abrar Shivani Committed by GitHub
Browse files

feat: Dynamic Endpoint Exposure Based on Model Type (#1447)

parent a4e06895
...@@ -8,7 +8,7 @@ mod model_entry; ...@@ -8,7 +8,7 @@ mod model_entry;
pub use model_entry::ModelEntry; pub use model_entry::ModelEntry;
mod watcher; mod watcher;
pub use watcher::ModelWatcher; pub use watcher::{ModelUpdate, ModelWatcher};
/// The root etcd path for ModelEntry /// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models"; pub const MODEL_ROOT_PATH: &str = "models";
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use anyhow::Context as _; use anyhow::Context as _;
use tokio::sync::{mpsc::Receiver, Notify}; use tokio::sync::{mpsc::Receiver, Notify};
...@@ -36,14 +37,24 @@ use crate::{ ...@@ -36,14 +37,24 @@ use crate::{
use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH}; use super::{ModelEntry, ModelManager, MODEL_ROOT_PATH};
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelUpdate {
Added(ModelType),
Removed(ModelType),
}
pub struct ModelWatcher { pub struct ModelWatcher {
manager: Arc<ModelManager>, manager: Arc<ModelManager>,
drt: DistributedRuntime, drt: DistributedRuntime,
router_mode: RouterMode, router_mode: RouterMode,
notify_on_model: Notify, notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
} }
const ALL_MODEL_TYPES: &[ModelType] =
&[ModelType::Chat, ModelType::Completion, ModelType::Embedding];
impl ModelWatcher { impl ModelWatcher {
pub fn new( pub fn new(
runtime: DistributedRuntime, runtime: DistributedRuntime,
...@@ -56,10 +67,15 @@ impl ModelWatcher { ...@@ -56,10 +67,15 @@ impl ModelWatcher {
drt: runtime, drt: runtime,
router_mode, router_mode,
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
model_update_tx: None,
kv_router_config, kv_router_config,
} }
} }
pub fn set_notify_on_model_update(&mut self, tx: Sender<ModelUpdate>) {
self.model_update_tx = Some(tx);
}
/// Wait until we have at least one chat completions model and return it's name. /// Wait until we have at least one chat completions model and return it's name.
pub async fn wait_for_chat_model(&self) -> String { pub async fn wait_for_chat_model(&self) -> String {
// Loop in case it gets added and immediately deleted // Loop in case it gets added and immediately deleted
...@@ -100,6 +116,12 @@ impl ModelWatcher { ...@@ -100,6 +116,12 @@ impl ModelWatcher {
}; };
self.manager.save_model_entry(key, model_entry.clone()); self.manager.save_model_entry(key, model_entry.clone());
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(model_entry.model_type))
.await
.ok();
}
if self.manager.has_model_any(&model_entry.name) { if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(name = model_entry.name, "New endpoint for existing model"); tracing::trace!(name = model_entry.name, "New endpoint for existing model");
self.notify_on_model.notify_waiters(); self.notify_on_model.notify_waiters();
...@@ -151,13 +173,91 @@ impl ModelWatcher { ...@@ -151,13 +173,91 @@ impl ModelWatcher {
.await .await
.with_context(|| model_name.clone())?; .with_context(|| model_name.clone())?;
if !active_instances.is_empty() { if !active_instances.is_empty() {
let mut update_tx = true;
let mut model_type: ModelType = model_entry.model_type;
if model_entry.model_type == ModelType::Chat
&& self.manager.list_chat_completions_models().is_empty()
{
self.manager.remove_chat_completions_model(&model_name).ok();
model_type = ModelType::Chat;
} else if model_entry.model_type == ModelType::Completion
&& self.manager.list_completions_models().is_empty()
{
self.manager.remove_completions_model(&model_name).ok();
model_type = ModelType::Completion;
} else if model_entry.model_type == ModelType::Embedding
&& self.manager.list_embeddings_models().is_empty()
{
self.manager.remove_embeddings_model(&model_name).ok();
model_type = ModelType::Embedding;
} else if model_entry.model_type == ModelType::Backend {
if self.manager.list_chat_completions_models().is_empty() {
self.manager.remove_chat_completions_model(&model_name).ok();
model_type = ModelType::Chat;
}
if self.manager.list_completions_models().is_empty() {
self.manager.remove_completions_model(&model_name).ok();
if model_type == ModelType::Chat {
model_type = ModelType::Backend;
} else {
model_type = ModelType::Completion;
}
}
} else {
tracing::debug!(
"Model {} is still active in other instances, not removing",
model_name
);
update_tx = false;
}
if update_tx {
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Removed(model_type)).await.ok();
}
}
return Ok(None); return Ok(None);
} }
// Ignore the errors because model could be either type // Ignore the errors because model could be either type
let _ = self.manager.remove_chat_completions_model(&model_name); let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
let _ = self.manager.remove_completions_model(&model_name); let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let _ = self.manager.remove_embeddings_model(&model_name); let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let mut chat_model_removed = false;
let mut completions_model_removed = false;
let mut embeddings_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true;
}
if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
{
completions_model_removed = true;
}
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
embeddings_model_removed = true;
}
if !chat_model_removed && !completions_model_removed && !embeddings_model_removed {
tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}",
model_name,
chat_model_removed,
completions_model_removed,
embeddings_model_removed
);
} else {
for model_type in ALL_MODEL_TYPES {
if (chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completion)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)
{
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Removed(*model_type)).await.ok();
}
}
}
}
Ok(Some(model_name)) Ok(Some(model_name))
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use serde::{Deserialize, Serialize};
use strum::Display;
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq, Hash)]
pub enum EndpointType {
// Chat Completions API
Chat,
/// Older completions API
Completion,
/// Embeddings API
Embedding,
/// Responses API
Responses,
}
impl EndpointType {
pub fn as_str(&self) -> &str {
match self {
Self::Chat => "chat",
Self::Completion => "completion",
Self::Embedding => "embedding",
Self::Responses => "responses",
}
}
pub fn all() -> Vec<Self> {
vec![
Self::Chat,
Self::Completion,
Self::Embedding,
Self::Responses,
]
}
}
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
discovery::{ModelManager, ModelWatcher, MODEL_ROOT_PATH}, discovery::{ModelManager, ModelUpdate, ModelWatcher, MODEL_ROOT_PATH},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, input::common, EngineConfig}, entrypoint::{self, input::common, EngineConfig},
http::service::service_v2, http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_type::ModelType,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
...@@ -22,9 +24,6 @@ use dynamo_runtime::{DistributedRuntime, Runtime}; ...@@ -22,9 +24,6 @@ use dynamo_runtime::{DistributedRuntime, Runtime};
pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> { pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Result<()> {
let mut http_service_builder = service_v2::HttpService::builder() let mut http_service_builder = service_v2::HttpService::builder()
.port(engine_config.local_model().http_port()) .port(engine_config.local_model().http_port())
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.enable_embeddings_endpoints(true)
.with_request_template(engine_config.local_model().request_template()); .with_request_template(engine_config.local_model().request_template());
let http_service = match engine_config { let http_service = match engine_config {
...@@ -45,6 +44,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -45,6 +44,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
MODEL_ROOT_PATH, MODEL_ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
Arc::new(http_service.clone()),
) )
.await?; .await?;
} }
...@@ -98,6 +98,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -98,6 +98,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
.await?; .await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(local_model.display_name(), completions_engine)?;
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}
http_service http_service
} }
EngineConfig::StaticFull { engine, model, .. } => { EngineConfig::StaticFull { engine, model, .. } => {
...@@ -106,6 +112,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -106,6 +112,13 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let manager = http_service.model_manager(); let manager = http_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?; manager.add_completions_model(model.service_name(), engine.clone())?;
manager.add_chat_completions_model(model.service_name(), engine)?; manager.add_chat_completions_model(model.service_name(), engine)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}
http_service http_service
} }
EngineConfig::StaticCore { EngineConfig::StaticCore {
...@@ -129,6 +142,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -129,6 +142,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
>(model.card(), inner_engine) >(model.card(), inner_engine)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
}
http_service http_service
} }
}; };
...@@ -154,13 +173,70 @@ async fn run_watcher( ...@@ -154,13 +173,70 @@ async fn run_watcher(
network_prefix: &str, network_prefix: &str,
router_mode: RouterMode, router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
http_service: Arc<HttpService>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config); let mut watch_obj = ModelWatcher::new(runtime, model_manager, router_mode, kv_router_config);
tracing::info!("Watching for remote model at {network_prefix}"); tracing::info!("Watching for remote model at {network_prefix}");
let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?; let models_watcher = etcd_client.kv_get_and_watch_prefix(network_prefix).await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
// Create a channel to receive model type updates
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
watch_obj.set_notify_on_model_update(tx);
// Spawn a task to watch for model type changes and update HTTP service endpoints
let _endpoint_enabler_task = tokio::spawn(async move {
while let Some(model_type) = rx.recv().await {
tracing::debug!("Received model type update: {:?}", model_type);
update_http_endpoints(http_service.clone(), model_type).await;
}
});
// Pass the sender to the watcher
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver).await; watch_obj.watch(receiver).await;
}); });
Ok(()) Ok(())
} }
/// Updates HTTP service endpoints based on available model types
async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
tracing::debug!(
"Updating HTTP service endpoints for model type: {:?}",
model_type
);
match model_type {
ModelUpdate::Added(model_type) => match model_type {
ModelType::Backend => {
service
.enable_model_endpoint(EndpointType::Chat, true)
.await;
service
.enable_model_endpoint(EndpointType::Completion, true)
.await;
}
_ => {
service
.enable_model_endpoint(model_type.as_endpoint_type(), true)
.await;
}
},
ModelUpdate::Removed(model_type) => match model_type {
ModelType::Backend => {
service
.enable_model_endpoint(EndpointType::Chat, false)
.await;
service
.enable_model_endpoint(EndpointType::Completion, false)
.await;
}
_ => {
service
.enable_model_endpoint(model_type.as_endpoint_type(), false)
.await;
}
},
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::collections::HashMap;
use std::env::var; use std::env::var;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
...@@ -9,6 +12,7 @@ use super::metrics; ...@@ -9,6 +12,7 @@ use super::metrics;
use super::Metrics; use super::Metrics;
use super::RouteDoc; use super::RouteDoc;
use crate::discovery::ModelManager; use crate::discovery::ModelManager;
use crate::endpoint_type::EndpointType;
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
use anyhow::Result; use anyhow::Result;
use derive_builder::Builder; use derive_builder::Builder;
...@@ -19,10 +23,48 @@ use tokio_util::sync::CancellationToken; ...@@ -19,10 +23,48 @@ use tokio_util::sync::CancellationToken;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
/// HTTP service shared state /// HTTP service shared state
#[derive(Default)]
pub struct State { pub struct State {
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
manager: Arc<ModelManager>, manager: Arc<ModelManager>,
etcd_client: Option<etcd::Client>, etcd_client: Option<etcd::Client>,
flags: StateFlags,
}
#[derive(Default, Debug)]
struct StateFlags {
chat_endpoints_enabled: AtomicBool,
cmpl_endpoints_enabled: AtomicBool,
embeddings_endpoints_enabled: AtomicBool,
responses_endpoints_enabled: AtomicBool,
}
impl StateFlags {
pub fn get(&self, endpoint_type: &EndpointType) -> bool {
match endpoint_type {
EndpointType::Chat => self.chat_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Completion => self.cmpl_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Embedding => self.embeddings_endpoints_enabled.load(Ordering::Relaxed),
EndpointType::Responses => self.responses_endpoints_enabled.load(Ordering::Relaxed),
}
}
pub fn set(&self, endpoint_type: &EndpointType, enabled: bool) {
match endpoint_type {
EndpointType::Chat => self
.chat_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Completion => self
.cmpl_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Embedding => self
.embeddings_endpoints_enabled
.store(enabled, Ordering::Relaxed),
EndpointType::Responses => self
.responses_endpoints_enabled
.store(enabled, Ordering::Relaxed),
}
}
} }
impl State { impl State {
...@@ -31,6 +73,12 @@ impl State { ...@@ -31,6 +73,12 @@ impl State {
manager, manager,
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
etcd_client: None, etcd_client: None,
flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
embeddings_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false),
},
} }
} }
...@@ -39,9 +87,14 @@ impl State { ...@@ -39,9 +87,14 @@ impl State {
manager, manager,
metrics: Arc::new(Metrics::default()), metrics: Arc::new(Metrics::default()),
etcd_client, etcd_client,
flags: StateFlags {
chat_endpoints_enabled: AtomicBool::new(false),
cmpl_endpoints_enabled: AtomicBool::new(false),
embeddings_endpoints_enabled: AtomicBool::new(false),
responses_endpoints_enabled: AtomicBool::new(false),
},
} }
} }
/// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
pub fn metrics_clone(&self) -> Arc<Metrics> { pub fn metrics_clone(&self) -> Arc<Metrics> {
self.metrics.clone() self.metrics.clone()
...@@ -87,10 +140,10 @@ pub struct HttpServiceConfig { ...@@ -87,10 +140,10 @@ pub struct HttpServiceConfig {
// #[builder(default)] // #[builder(default)]
// custom: Vec<axum::Router> // custom: Vec<axum::Router>
#[builder(default = "true")] #[builder(default = "false")]
enable_chat_endpoints: bool, enable_chat_endpoints: bool,
#[builder(default = "true")] #[builder(default = "false")]
enable_cmpl_endpoints: bool, enable_cmpl_endpoints: bool,
#[builder(default = "true")] #[builder(default = "true")]
...@@ -151,6 +204,15 @@ impl HttpService { ...@@ -151,6 +204,15 @@ impl HttpService {
pub fn route_docs(&self) -> &[RouteDoc] { pub fn route_docs(&self) -> &[RouteDoc] {
&self.route_docs &self.route_docs
} }
pub async fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
self.state.flags.set(&endpoint_type, enable);
tracing::info!(
"{} endpoints {}",
endpoint_type.as_str(),
if enable { "enabled" } else { "disabled" }
);
}
} }
/// Environment variable to set the metrics endpoint path (default: `/metrics`) /// Environment variable to set the metrics endpoint path (default: `/metrics`)
...@@ -177,6 +239,19 @@ impl HttpServiceConfigBuilder { ...@@ -177,6 +239,19 @@ impl HttpServiceConfigBuilder {
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client)); let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
state
.flags
.set(&EndpointType::Chat, config.enable_chat_endpoints);
state
.flags
.set(&EndpointType::Completion, config.enable_cmpl_endpoints);
state
.flags
.set(&EndpointType::Embedding, config.enable_embeddings_endpoints);
state
.flags
.set(&EndpointType::Responses, config.enable_responses_endpoints);
// enable prometheus metrics // enable prometheus metrics
let registry = metrics::Registry::new(); let registry = metrics::Registry::new();
state.metrics_clone().register(&registry)?; state.metrics_clone().register(&registry)?;
...@@ -192,42 +267,10 @@ impl HttpServiceConfigBuilder { ...@@ -192,42 +267,10 @@ impl HttpServiceConfigBuilder {
super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
]; ];
if config.enable_chat_endpoints { let endpoint_routes =
routes.push(super::openai::chat_completions_router( HttpServiceConfigBuilder::get_endpoints_router(state.clone(), &config.request_template);
state.clone(), routes.extend(endpoint_routes);
config.request_template.clone(), // TODO clone()? reference? for (route_docs, route) in routes {
var(HTTP_SVC_CHAT_PATH_ENV).ok(),
));
}
if config.enable_cmpl_endpoints {
routes.push(super::openai::completions_router(
state.clone(),
var(HTTP_SVC_CMP_PATH_ENV).ok(),
));
}
if config.enable_embeddings_endpoints {
routes.push(super::openai::embeddings_router(
state.clone(),
var(HTTP_SVC_EMB_PATH_ENV).ok(),
));
}
if config.enable_responses_endpoints {
routes.push(super::openai::responses_router(
state.clone(),
config.request_template,
var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
));
}
// for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
// router = router.merge(route);
// all_docs.extend(route_docs);
// }
for (route_docs, route) in routes.into_iter() {
router = router.merge(route); router = router.merge(route);
all_docs.extend(route_docs); all_docs.extend(route_docs);
} }
...@@ -253,4 +296,58 @@ impl HttpServiceConfigBuilder { ...@@ -253,4 +296,58 @@ impl HttpServiceConfigBuilder {
self.etcd_client = Some(etcd_client); self.etcd_client = Some(etcd_client);
self self
} }
fn get_endpoints_router(
state: Arc<State>,
request_template: &Option<RequestTemplate>,
) -> Vec<(Vec<RouteDoc>, axum::Router)> {
let mut routes = Vec::new();
// Add chat completions route with conditional middleware
let (chat_docs, chat_route) = super::openai::chat_completions_router(
state.clone(),
request_template.clone(),
var(HTTP_SVC_CHAT_PATH_ENV).ok(),
);
let (cmpl_docs, cmpl_route) =
super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok());
let (embed_docs, embed_route) =
super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok());
let (responses_docs, responses_route) = super::openai::responses_router(
state.clone(),
request_template.clone(),
var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
);
let mut endpoint_routes = HashMap::new();
endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route));
endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route));
endpoint_routes.insert(EndpointType::Embedding, (embed_docs, embed_route));
endpoint_routes.insert(EndpointType::Responses, (responses_docs, responses_route));
for endpoint_type in EndpointType::all() {
let state_route = state.clone();
if !endpoint_routes.contains_key(&endpoint_type) {
tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
continue;
}
let (docs, route) = endpoint_routes.get(&endpoint_type).cloned().unwrap();
let route = route.route_layer(axum::middleware::from_fn(
move |req: axum::http::Request<axum::body::Body>, next: axum::middleware::Next| {
let state: Arc<State> = state_route.clone();
async move {
// Check if the endpoint is enabled
let enabled = state.flags.get(&endpoint_type);
if enabled {
Ok(next.run(req).await)
} else {
tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
Err(axum::http::StatusCode::SERVICE_UNAVAILABLE)
}
}
},
));
routes.push((docs, route));
}
routes
}
} }
...@@ -14,6 +14,7 @@ pub mod backend; ...@@ -14,6 +14,7 @@ pub mod backend;
pub mod common; pub mod common;
pub mod disagg_router; pub mod disagg_router;
pub mod discovery; pub mod discovery;
pub mod endpoint_type;
pub mod engines; pub mod engines;
pub mod entrypoint; pub mod entrypoint;
pub mod gguf; pub mod gguf;
......
...@@ -41,4 +41,13 @@ impl ModelType { ...@@ -41,4 +41,13 @@ impl ModelType {
pub fn all() -> Vec<Self> { pub fn all() -> Vec<Self> {
vec![Self::Chat, Self::Completion, Self::Embedding, Self::Backend] vec![Self::Chat, Self::Completion, Self::Embedding, Self::Backend]
} }
pub fn as_endpoint_type(&self) -> crate::endpoint_type::EndpointType {
match self {
Self::Chat => crate::endpoint_type::EndpointType::Chat,
Self::Completion => crate::endpoint_type::EndpointType::Completion,
Self::Embedding => crate::endpoint_type::EndpointType::Embedding,
Self::Backend => panic!("Backend model type does not map to an endpoint type"),
}
}
} }
...@@ -270,7 +270,12 @@ fn inc_counter( ...@@ -270,7 +270,12 @@ fn inc_counter(
#[allow(deprecated)] #[allow(deprecated)]
#[tokio::test] #[tokio::test]
async fn test_http_service() { async fn test_http_service() {
let service = HttpService::builder().port(8989).build().unwrap(); let service = HttpService::builder()
.port(8989)
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.build()
.unwrap();
let state = service.state_clone(); let state = service.state_clone();
let manager = state.manager(); let manager = state.manager();
...@@ -572,7 +577,12 @@ async fn wait_for_service_ready(port: u16) { ...@@ -572,7 +577,12 @@ async fn wait_for_service_ready(port: u16) {
fn service_with_engines( fn service_with_engines(
#[default(8990)] port: u16, #[default(8990)] port: u16,
) -> (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>) { ) -> (HttpService, Arc<CounterEngine>, Arc<AlwaysFailEngine>) {
let service = HttpService::builder().port(port).build().unwrap(); let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(port)
.build()
.unwrap();
let manager = service.model_manager(); let manager = service.model_manager();
let counter = Arc::new(CounterEngine {}); let counter = Arc::new(CounterEngine {});
...@@ -958,7 +968,12 @@ async fn test_generic_byot_client( ...@@ -958,7 +968,12 @@ async fn test_generic_byot_client(
#[rstest] #[rstest]
#[tokio::test] #[tokio::test]
async fn test_client_disconnect_cancellation_unary() { async fn test_client_disconnect_cancellation_unary() {
let service = HttpService::builder().port(8993).build().unwrap(); let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(8993)
.build()
.unwrap();
let state = service.state_clone(); let state = service.state_clone();
let manager = state.manager(); let manager = state.manager();
...@@ -1044,7 +1059,12 @@ async fn test_client_disconnect_cancellation_unary() { ...@@ -1044,7 +1059,12 @@ async fn test_client_disconnect_cancellation_unary() {
async fn test_client_disconnect_cancellation_streaming() { async fn test_client_disconnect_cancellation_streaming() {
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
let service = HttpService::builder().port(8994).build().unwrap(); let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(8994)
.build()
.unwrap();
let state = service.state_clone(); let state = service.state_clone();
let manager = state.manager(); let manager = state.manager();
...@@ -1137,7 +1157,12 @@ async fn test_request_id_annotation() { ...@@ -1137,7 +1157,12 @@ async fn test_request_id_annotation() {
// TODO(ryan): make better fixtures, this is too much to test sometime so simple // TODO(ryan): make better fixtures, this is too much to test sometime so simple
dynamo_runtime::logging::init(); dynamo_runtime::logging::init();
let service = HttpService::builder().port(8995).build().unwrap(); let service = HttpService::builder()
.enable_chat_endpoints(true)
.enable_cmpl_endpoints(true)
.port(8995)
.build()
.unwrap();
let state = service.state_clone(); let state = service.state_clone();
let manager = state.manager(); let manager = state.manager();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment