Unverified commit 174389e6, authored by Michael Feil and committed by GitHub
Browse files

fix: Httpengine sync-enable-endpoint (#2591)

parent 770d63cc
......@@ -19,8 +19,8 @@ use pyo3::{exceptions::PyException, prelude::*};
use crate::{engine::*, to_pyerr, CancellationToken};
pub use dynamo_llm::endpoint_type::EndpointType;
pub use dynamo_llm::http::service::{error as http_error, service_v2};
pub use dynamo_runtime::{
error,
pipeline::{async_trait, AsyncEngine, Data, ManyOut, SingleIn},
......@@ -92,6 +92,27 @@ impl HttpService {
Ok(())
})
}
/// Enables or disables one HTTP endpoint family, selected by name.
///
/// `endpoint_type` is compared case-insensitively against the `as_str()`
/// name of every entry returned by `EndpointType::all()`. On a match, the
/// call is forwarded to the wrapped service's `enable_model_endpoint` with
/// the given `enabled` flag.
///
/// # Errors
///
/// When no entry matches, returns the error produced by `to_pyerr`, whose
/// message echoes the rejected value and lists every valid endpoint name.
fn enable_endpoint(&self, endpoint_type: String, enabled: bool) -> PyResult<()> {
    // Lower-case the caller's value once; it is invariant across candidates,
    // so recomputing it inside the `find` closure only adds allocations.
    let requested = endpoint_type.to_lowercase();

    let resolved = EndpointType::all()
        .iter()
        .find(|&&ep_type| ep_type.as_str().to_lowercase() == requested)
        .copied()
        .ok_or_else(|| {
            // Only built on the failure path: the list of accepted names.
            let valid_types = EndpointType::all()
                .iter()
                .map(|&ep_type| ep_type.as_str().to_string())
                .collect::<Vec<_>>()
                .join(", ");
            // Report the value exactly as the caller passed it, not the
            // lower-cased form used for matching.
            to_pyerr(format!(
                "Invalid endpoint type: '{}'. Valid types are: {}",
                endpoint_type, valid_types
            ))
        })?;

    self.inner.enable_model_endpoint(resolved, enabled);
    Ok(())
}
}
/// Python Exception for HTTP errors
......
......@@ -125,9 +125,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
manager.add_completions_model(local_model.display_name(), completions_engine)?;
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
http_service.enable_model_endpoint(endpoint_type, true);
}
http_service
......@@ -141,9 +139,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
http_service.enable_model_endpoint(endpoint_type, true);
}
http_service
}
......@@ -170,9 +166,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service
.enable_model_endpoint(endpoint_type, true)
.await;
http_service.enable_model_endpoint(endpoint_type, true);
}
http_service
}
......@@ -223,7 +217,7 @@ async fn run_watcher(
let _endpoint_enabler_task = tokio::spawn(async move {
while let Some(model_type) = rx.recv().await {
tracing::debug!("Received model type update: {:?}", model_type);
update_http_endpoints(http_service.clone(), model_type).await;
update_http_endpoints(http_service.clone(), model_type);
}
});
......@@ -236,7 +230,7 @@ async fn run_watcher(
}
/// Updates HTTP service endpoints based on available model types
async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
tracing::debug!(
"Updating HTTP service endpoints for model type: {:?}",
model_type
......@@ -244,32 +238,20 @@ async fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdat
match model_type {
ModelUpdate::Added(model_type) => match model_type {
ModelType::Backend => {
service
.enable_model_endpoint(EndpointType::Chat, true)
.await;
service
.enable_model_endpoint(EndpointType::Completion, true)
.await;
service.enable_model_endpoint(EndpointType::Chat, true);
service.enable_model_endpoint(EndpointType::Completion, true);
}
_ => {
service
.enable_model_endpoint(model_type.as_endpoint_type(), true)
.await;
service.enable_model_endpoint(model_type.as_endpoint_type(), true);
}
},
ModelUpdate::Removed(model_type) => match model_type {
ModelType::Backend => {
service
.enable_model_endpoint(EndpointType::Chat, false)
.await;
service
.enable_model_endpoint(EndpointType::Completion, false)
.await;
service.enable_model_endpoint(EndpointType::Chat, false);
service.enable_model_endpoint(EndpointType::Completion, false);
}
_ => {
service
.enable_model_endpoint(model_type.as_endpoint_type(), false)
.await;
service.enable_model_endpoint(model_type.as_endpoint_type(), false);
}
},
}
......
......@@ -262,7 +262,7 @@ impl HttpService {
&self.route_docs
}
pub async fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
pub fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
self.state.flags.set(&endpoint_type, enable);
tracing::info!(
"{} endpoints {}",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment