Unverified Commit 174389e6 authored by Michael Feil's avatar Michael Feil Committed by GitHub
Browse files

fix: Httpengine sync-enable-endpoint (#2591)

parent 770d63cc
...@@ -19,8 +19,8 @@ use pyo3::{exceptions::PyException, prelude::*}; ...@@ -19,8 +19,8 @@ use pyo3::{exceptions::PyException, prelude::*};
use crate::{engine::*, to_pyerr, CancellationToken}; use crate::{engine::*, to_pyerr, CancellationToken};
pub use dynamo_llm::endpoint_type::EndpointType;
pub use dynamo_llm::http::service::{error as http_error, service_v2}; pub use dynamo_llm::http::service::{error as http_error, service_v2};
pub use dynamo_runtime::{ pub use dynamo_runtime::{
error, error,
pipeline::{async_trait, AsyncEngine, Data, ManyOut, SingleIn}, pipeline::{async_trait, AsyncEngine, Data, ManyOut, SingleIn},
...@@ -92,6 +92,27 @@ impl HttpService { ...@@ -92,6 +92,27 @@ impl HttpService {
Ok(()) Ok(())
}) })
} }
fn enable_endpoint(&self, endpoint_type: String, enabled: bool) -> PyResult<()> {
let endpoint_type = EndpointType::all()
.iter()
.find(|&&ep_type| ep_type.as_str().to_lowercase() == endpoint_type.to_lowercase())
.copied()
.ok_or_else(|| {
let valid_types = EndpointType::all()
.iter()
.map(|&ep_type| ep_type.as_str().to_string())
.collect::<Vec<_>>()
.join(", ");
to_pyerr(format!(
"Invalid endpoint type: '{}'. Valid types are: {}",
endpoint_type, valid_types
))
})?;
self.inner.enable_model_endpoint(endpoint_type, enabled);
Ok(())
}
} }
/// Python Exception for HTTP errors /// Python Exception for HTTP errors
......
...@@ -125,9 +125,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -125,9 +125,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(local_model.display_name(), completions_engine)?;
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
http_service http_service.enable_model_endpoint(endpoint_type, true);
.enable_model_endpoint(endpoint_type, true)
.await;
} }
http_service http_service
...@@ -141,9 +139,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -141,9 +139,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
// Enable all endpoints // Enable all endpoints
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
http_service http_service.enable_model_endpoint(endpoint_type, true);
.enable_model_endpoint(endpoint_type, true)
.await;
} }
http_service http_service
} }
...@@ -170,9 +166,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -170,9 +166,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
// Enable all endpoints // Enable all endpoints
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
http_service http_service.enable_model_endpoint(endpoint_type, true);
.enable_model_endpoint(endpoint_type, true)
.await;
} }
http_service http_service
} }
...@@ -223,7 +217,7 @@ async fn run_watcher( ...@@ -223,7 +217,7 @@ async fn run_watcher(
let _endpoint_enabler_task = tokio::spawn(async move { let _endpoint_enabler_task = tokio::spawn(async move {
while let Some(model_type) = rx.recv().await { while let Some(model_type) = rx.recv().await {
tracing::debug!("Received model type update: {:?}", model_type); tracing::debug!("Received model type update: {:?}", model_type);
update_http_endpoints(http_service.clone(), model_type).await; update_http_endpoints(http_service.clone(), model_type);
} }
}); });
...@@ -236,7 +230,7 @@ async fn run_watcher( ...@@ -236,7 +230,7 @@ async fn run_watcher(
} }
/// Updates HTTP service endpoints based on available model types /// Updates HTTP service endpoints based on available model types
async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) { fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
tracing::debug!( tracing::debug!(
"Updating HTTP service endpoints for model type: {:?}", "Updating HTTP service endpoints for model type: {:?}",
model_type model_type
...@@ -244,32 +238,20 @@ async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdat ...@@ -244,32 +238,20 @@ async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdat
match model_type { match model_type {
ModelUpdate::Added(model_type) => match model_type { ModelUpdate::Added(model_type) => match model_type {
ModelType::Backend => { ModelType::Backend => {
service service.enable_model_endpoint(EndpointType::Chat, true);
.enable_model_endpoint(EndpointType::Chat, true) service.enable_model_endpoint(EndpointType::Completion, true);
.await;
service
.enable_model_endpoint(EndpointType::Completion, true)
.await;
} }
_ => { _ => {
service service.enable_model_endpoint(model_type.as_endpoint_type(), true);
.enable_model_endpoint(model_type.as_endpoint_type(), true)
.await;
} }
}, },
ModelUpdate::Removed(model_type) => match model_type { ModelUpdate::Removed(model_type) => match model_type {
ModelType::Backend => { ModelType::Backend => {
service service.enable_model_endpoint(EndpointType::Chat, false);
.enable_model_endpoint(EndpointType::Chat, false) service.enable_model_endpoint(EndpointType::Completion, false);
.await;
service
.enable_model_endpoint(EndpointType::Completion, false)
.await;
} }
_ => { _ => {
service service.enable_model_endpoint(model_type.as_endpoint_type(), false);
.enable_model_endpoint(model_type.as_endpoint_type(), false)
.await;
} }
}, },
} }
......
...@@ -262,7 +262,7 @@ impl HttpService { ...@@ -262,7 +262,7 @@ impl HttpService {
&self.route_docs &self.route_docs
} }
pub async fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) { pub fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
self.state.flags.set(&endpoint_type, enable); self.state.flags.set(&endpoint_type, enable);
tracing::info!( tracing::info!(
"{} endpoints {}", "{} endpoints {}",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment