Unverified commit da0f2fb8, authored by Graham King, committed by GitHub
Browse files

feat(frontend): First part of Python request handling (#4999)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 91423c45
...@@ -30,12 +30,15 @@ from dynamo.llm import ( ...@@ -30,12 +30,15 @@ from dynamo.llm import (
EngineType, EngineType,
EntrypointArgs, EntrypointArgs,
KvRouterConfig, KvRouterConfig,
ModelDeploymentCard,
PythonAsyncEngine,
RouterConfig, RouterConfig,
RouterMode, RouterMode,
make_engine, make_engine,
run_input, run_input,
) )
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from . import __version__ from . import __version__
...@@ -45,9 +48,25 @@ CUSTOM_BACKEND_METRICS_POLLING_INTERVAL_ENV_VAR = ( ...@@ -45,9 +48,25 @@ CUSTOM_BACKEND_METRICS_POLLING_INTERVAL_ENV_VAR = (
) )
CUSTOM_BACKEND_ENDPOINT_ENV_VAR = "CUSTOM_BACKEND_ENDPOINT" CUSTOM_BACKEND_ENDPOINT_ENV_VAR = "CUSTOM_BACKEND_ENDPOINT"
configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def _dummy_generator(request):
"""Minimal generator that yields nothing. Work in progress."""
return
yield # Makes this an async generator
async def engine_factory(mdc: ModelDeploymentCard) -> PythonAsyncEngine:
    """Build the engine for a discovered model. Called by Rust when a model is discovered.

    Logs the first 100 characters of the card's JSON form, then returns a
    ``PythonAsyncEngine`` wrapping ``_dummy_generator`` bound to the running event loop.
    """
    running_loop = asyncio.get_running_loop()
    mdc_preview = mdc.to_json_str()[:100]
    logger.info(f"Engine_factory called with MDC: {mdc_preview}...")
    return PythonAsyncEngine(_dummy_generator, running_loop)
def validate_model_name(value): def validate_model_name(value):
"""Validate that model-name is a non-empty string.""" """Validate that model-name is a non-empty string."""
if not value or not isinstance(value, str) or len(value.strip()) == 0: if not value or not isinstance(value, str) or len(value.strip()) == 0:
...@@ -254,6 +273,12 @@ def parse_args(): ...@@ -254,6 +273,12 @@ def parse_args():
default=os.environ.get("DYN_REQUEST_PLANE", "nats"), default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]", help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
) )
parser.add_argument(
"--exp-python-factory",
action="store_true",
default=False,
help="[EXPERIMENTAL] Enable Python-based engine factory. When set, engines will be created via a Python callback instead of the default Rust pipeline.",
)
flags = parser.parse_args() flags = parser.parse_args()
...@@ -356,6 +381,9 @@ async def async_main(): ...@@ -356,6 +381,9 @@ async def async_main():
"custom_backend_metrics_polling_interval" "custom_backend_metrics_polling_interval"
] = flags.custom_backend_metrics_polling_interval ] = flags.custom_backend_metrics_polling_interval
if flags.exp_python_factory:
kwargs["engine_factory"] = engine_factory
e = EntrypointArgs(EngineType.Dynamic, **kwargs) e = EntrypointArgs(EngineType.Dynamic, **kwargs)
engine = await make_engine(runtime, e) engine = await make_engine(runtime, e)
......
...@@ -146,7 +146,10 @@ async fn engine_for( ...@@ -146,7 +146,10 @@ async fn engine_for(
match out_opt { match out_opt {
Output::Auto => { Output::Auto => {
// Auto-discover backends // Auto-discover backends
Ok(EngineConfig::Dynamic(Box::new(local_model))) Ok(EngineConfig::Dynamic {
model: Box::new(local_model),
engine_factory: None,
})
} }
Output::Echo => Ok(EngineConfig::InProcessText { Output::Echo => Ok(EngineConfig::InProcessText {
model: Box::new(local_model), model: Box::new(local_model),
......
...@@ -975,6 +975,7 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -975,6 +975,7 @@ pub async fn create_worker_selection_pipeline_chat(
component.drt().clone(), component.drt().clone(),
model_manager.clone(), model_manager.clone(),
router_config, router_config,
None,
); );
let cards = watcher let cards = watcher
.cards_for_model(model_name, Some(namespace), false) .cards_for_model(model_name, Some(namespace), false)
......
...@@ -2,20 +2,29 @@ ...@@ -2,20 +2,29 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::fmt::Display; use std::fmt::Display;
use std::future::Future;
use std::path::PathBuf; use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*}; use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::TaskLocals;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig; use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
use dynamo_llm::entrypoint::EngineFactoryCallback;
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig; use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::entrypoint::input::Input; use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig; use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::local_model::DEFAULT_HTTP_PORT; use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder}; use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_llm::mocker::protocols::MockEngineArgs; use dynamo_llm::mocker::protocols::MockEngineArgs;
use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
use super::model_card::ModelDeploymentCard;
use crate::RouterMode; use crate::RouterMode;
use crate::engine::PythonAsyncEngine;
#[pyclass(eq, eq_int)] #[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
...@@ -117,6 +126,21 @@ impl From<RouterConfig> for RsRouterConfig { ...@@ -117,6 +126,21 @@ impl From<RouterConfig> for RsRouterConfig {
} }
} }
/// Wrapper to hold Python callback and its TaskLocals for async execution
///
/// Both fields are behind `Arc`, so the derived `Clone` shares them rather than copying.
#[derive(Clone)]
struct PyEngineFactory {
/// The Python object passed as `engine_factory`; it is called with a model deployment card.
callback: Arc<PyObject>,
/// Task locals obtained when the factory was registered; used to turn the callback's
/// coroutine into a Rust future.
locals: Arc<TaskLocals>,
}
impl std::fmt::Debug for PyEngineFactory {
    /// Writes a diagnostic form that shows a fixed `"<PyObject>"` placeholder for `callback`.
    /// The `locals` field is not part of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("PyEngineFactory");
        builder.field("callback", &"<PyObject>");
        builder.finish()
    }
}
#[pyclass] #[pyclass]
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs { pub(crate) struct EntrypointArgs {
...@@ -137,14 +161,16 @@ pub(crate) struct EntrypointArgs { ...@@ -137,14 +161,16 @@ pub(crate) struct EntrypointArgs {
custom_backend_metrics_endpoint: Option<String>, custom_backend_metrics_endpoint: Option<String>,
custom_backend_metrics_polling_interval: Option<f64>, custom_backend_metrics_polling_interval: Option<f64>,
is_prefill: bool, is_prefill: bool,
engine_factory: Option<PyEngineFactory>,
} }
#[pymethods] #[pymethods]
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None, is_prefill=false))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None, is_prefill=false, engine_factory=None))]
pub fn new( pub fn new(
py: Python<'_>,
engine_type: EngineType, engine_type: EngineType,
model_path: Option<PathBuf>, model_path: Option<PathBuf>,
model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint" model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
...@@ -162,6 +188,7 @@ impl EntrypointArgs { ...@@ -162,6 +188,7 @@ impl EntrypointArgs {
custom_backend_metrics_endpoint: Option<String>, custom_backend_metrics_endpoint: Option<String>,
custom_backend_metrics_polling_interval: Option<f64>, custom_backend_metrics_polling_interval: Option<f64>,
is_prefill: bool, is_prefill: bool,
engine_factory: Option<PyObject>,
) -> PyResult<Self> { ) -> PyResult<Self> {
let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from); let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
if (tls_cert_path.is_some() && tls_key_path.is_none()) if (tls_cert_path.is_some() && tls_key_path.is_none())
...@@ -171,6 +198,23 @@ impl EntrypointArgs { ...@@ -171,6 +198,23 @@ impl EntrypointArgs {
"tls_cert_path and tls_key_path must be provided together", "tls_cert_path and tls_key_path must be provided together",
)); ));
} }
// Capture TaskLocals at registration time for the engine factory callback
let engine_factory = engine_factory
.map(|callback| {
let locals = pyo3_async_runtimes::tokio::get_current_locals(py).map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to get TaskLocals for engine_factory: {}",
e
))
})?;
Ok::<_, PyErr>(PyEngineFactory {
callback: Arc::new(callback),
locals: Arc::new(locals),
})
})
.transpose()?;
Ok(EntrypointArgs { Ok(EntrypointArgs {
engine_type, engine_type,
model_path, model_path,
...@@ -189,6 +233,7 @@ impl EntrypointArgs { ...@@ -189,6 +233,7 @@ impl EntrypointArgs {
custom_backend_metrics_endpoint, custom_backend_metrics_endpoint,
custom_backend_metrics_polling_interval, custom_backend_metrics_polling_interval,
is_prefill, is_prefill,
engine_factory,
}) })
} }
} }
...@@ -251,6 +296,57 @@ pub fn make_engine<'p>( ...@@ -251,6 +296,57 @@ pub fn make_engine<'p>(
}) })
} }
/// Convert a PyEngineFactory to a Rust EngineFactoryCallback
fn py_engine_factory_to_callback(factory: PyEngineFactory) -> EngineFactoryCallback {
let callback = factory.callback;
let locals = factory.locals;
Arc::new(
move |card: RsModelDeploymentCard| -> Pin<
Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
> {
let callback = callback.clone();
let locals = locals.clone();
Box::pin(async move {
// Acquire GIL to call Python callback and convert coroutine to future
let py_future = Python::with_gil(|py| {
// Create Python ModelDeploymentCard wrapper
let py_card = ModelDeploymentCard { inner: card };
let py_card_obj = Py::new(py, py_card)
.map_err(|e| anyhow::anyhow!("Failed to create Python MDC: {}", e))?;
// Call Python async function to get a coroutine
let coroutine = callback
.call1(py, (py_card_obj,))
.map_err(|e| anyhow::anyhow!("Failed to call engine_factory: {}", e))?;
// Use the TaskLocals captured at registration time
pyo3_async_runtimes::into_future_with_locals(&locals, coroutine.into_bound(py))
.map_err(|e| {
anyhow::anyhow!("Failed to convert coroutine to future: {}", e)
})
})?;
// Await the Python coroutine (GIL is released during await)
let py_result = py_future
.await
.map_err(|e| anyhow::anyhow!("engine_factory callback failed: {}", e))?;
// Extract PythonAsyncEngine from the Python result and wrap in Arc
let engine: OpenAIChatCompletionsStreamingEngine = Python::with_gil(|py| {
let engine: PythonAsyncEngine = py_result.extract(py).map_err(|e| {
anyhow::anyhow!("Failed to extract PythonAsyncEngine: {}", e)
})?;
Ok::<_, anyhow::Error>(Arc::new(engine))
})?;
Ok(engine)
})
},
)
}
async fn select_engine( async fn select_engine(
#[allow(unused_variables)] distributed_runtime: super::DistributedRuntime, #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
args: EntrypointArgs, args: EntrypointArgs,
...@@ -264,7 +360,14 @@ async fn select_engine( ...@@ -264,7 +360,14 @@ async fn select_engine(
engine: dynamo_llm::engines::make_echo_engine(), engine: dynamo_llm::engines::make_echo_engine(),
} }
} }
EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)), EngineType::Dynamic => {
// Convert Python engine factory to Rust callback
let engine_factory = args.engine_factory.map(py_engine_factory_to_callback);
RsEngineConfig::Dynamic {
model: Box::new(local_model),
engine_factory,
}
}
EngineType::Mocker => { EngineType::Mocker => {
let mocker_args = if let Some(extra_args_path) = args.extra_engine_args { let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| { MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
......
...@@ -21,6 +21,7 @@ from dynamo._core import KvStats as KvStats ...@@ -21,6 +21,7 @@ from dynamo._core import KvStats as KvStats
from dynamo._core import LoRADownloader as LoRADownloader from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder from dynamo._core import MediaDecoder as MediaDecoder
from dynamo._core import MediaFetcher as MediaFetcher from dynamo._core import MediaFetcher as MediaFetcher
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import ModelInput as ModelInput from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
from dynamo._core import ModelType as ModelType from dynamo._core import ModelType as ModelType
......
...@@ -20,7 +20,7 @@ use dynamo_runtime::{ ...@@ -20,7 +20,7 @@ use dynamo_runtime::{
use crate::{ use crate::{
backend::Backend, backend::Backend,
entrypoint::{self, RouterConfig}, entrypoint::{self, EngineFactoryCallback, RouterConfig},
kv_router::PrefillRouter, kv_router::PrefillRouter,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
model_type::{ModelInput, ModelType}, model_type::{ModelInput, ModelType},
...@@ -53,6 +53,7 @@ pub struct ModelWatcher { ...@@ -53,6 +53,7 @@ pub struct ModelWatcher {
router_config: RouterConfig, router_config: RouterConfig,
notify_on_model: Notify, notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
engine_factory: Option<EngineFactoryCallback>,
} }
const ALL_MODEL_TYPES: &[ModelType] = &[ const ALL_MODEL_TYPES: &[ModelType] = &[
...@@ -68,6 +69,7 @@ impl ModelWatcher { ...@@ -68,6 +69,7 @@ impl ModelWatcher {
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
engine_factory: Option<EngineFactoryCallback>,
) -> ModelWatcher { ) -> ModelWatcher {
Self { Self {
manager: model_manager, manager: model_manager,
...@@ -75,6 +77,7 @@ impl ModelWatcher { ...@@ -75,6 +77,7 @@ impl ModelWatcher {
router_config, router_config,
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
model_update_tx: None, model_update_tx: None,
engine_factory,
} }
} }
...@@ -355,6 +358,7 @@ impl ModelWatcher { ...@@ -355,6 +358,7 @@ impl ModelWatcher {
if let Some(tx) = &self.model_update_tx { if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(card.clone())).await.ok(); tx.send(ModelUpdate::Added(card.clone())).await.ok();
} }
let checksum = card.mdcsum(); let checksum = card.mdcsum();
if card.model_input == ModelInput::Tokens if card.model_input == ModelInput::Tokens
...@@ -429,7 +433,13 @@ impl ModelWatcher { ...@@ -429,7 +433,13 @@ impl ModelWatcher {
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if card.model_type.supports_chat() { if card.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::< // Work in progress. This will allow creating a chat_engine from Python.
let chat_engine = if let Some(ref factory) = self.engine_factory {
factory(card.clone())
.await
.context("python engine_factory")?
} else {
entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>( >(
...@@ -443,7 +453,8 @@ impl ModelWatcher { ...@@ -443,7 +453,8 @@ impl ModelWatcher {
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
) )
.await .await
.context("build_routed_pipeline")?; .context("build_routed_pipeline")?
};
self.manager self.manager
.add_chat_completions_model(card.name(), checksum, chat_engine) .add_chat_completions_model(card.name(), checksum, chat_engine)
.context("add_chat_completions_model")?; .context("add_chat_completions_model")?;
......
...@@ -8,15 +8,28 @@ ...@@ -8,15 +8,28 @@
pub mod input; pub mod input;
pub use input::{build_routed_pipeline, build_routed_pipeline_with_preprocessor}; pub use input::{build_routed_pipeline, build_routed_pipeline_with_preprocessor};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
use crate::{ use crate::{
backend::ExecutionContext, engines::StreamingEngine, kv_router::KvRouterConfig, backend::ExecutionContext, engines::StreamingEngine, kv_router::KvRouterConfig,
local_model::LocalModel, local_model::LocalModel, model_card::ModelDeploymentCard,
types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine,
}; };
/// Callback type for engine factory (async)
///
/// A shared (`Arc`), `Send + Sync` closure that takes a `ModelDeploymentCard` by value and
/// returns a boxed, pinned, `Send` future resolving to
/// `anyhow::Result<OpenAIChatCompletionsStreamingEngine>`.
pub type EngineFactoryCallback = Arc<
dyn Fn(
ModelDeploymentCard,
) -> Pin<
Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
> + Send
+ Sync,
>;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct RouterConfig { pub struct RouterConfig {
pub router_mode: RouterMode, pub router_mode: RouterMode,
...@@ -58,7 +71,10 @@ impl RouterConfig { ...@@ -58,7 +71,10 @@ impl RouterConfig {
#[derive(Clone)] #[derive(Clone)]
pub enum EngineConfig { pub enum EngineConfig {
/// Remote networked engines that we discover via etcd /// Remote networked engines that we discover via etcd
Dynamic(Box<LocalModel>), Dynamic {
model: Box<LocalModel>,
engine_factory: Option<EngineFactoryCallback>,
},
/// A Text engine receives text, does its own tokenization and prompt formatting. /// A Text engine receives text, does its own tokenization and prompt formatting.
InProcessText { InProcessText {
...@@ -75,12 +91,19 @@ pub enum EngineConfig { ...@@ -75,12 +91,19 @@ pub enum EngineConfig {
} }
impl EngineConfig { impl EngineConfig {
fn local_model(&self) -> &LocalModel { pub fn local_model(&self) -> &LocalModel {
use EngineConfig::*; use EngineConfig::*;
match self { match self {
Dynamic(lm) => lm, Dynamic { model, .. } => model,
InProcessText { model, .. } => model, InProcessText { model, .. } => model,
InProcessTokens { model, .. } => model, InProcessTokens { model, .. } => model,
} }
} }
/// Returns the engine factory callback carried by a `Dynamic` config, if one was set.
///
/// Every other variant yields `None`.
pub fn engine_factory(&self) -> Option<&EngineFactoryCallback> {
    if let EngineConfig::Dynamic { engine_factory, .. } = self {
        engine_factory.as_ref()
    } else {
        None
    }
}
} }
...@@ -58,12 +58,15 @@ pub async fn prepare_engine( ...@@ -58,12 +58,15 @@ pub async fn prepare_engine(
engine_config: EngineConfig, engine_config: EngineConfig,
) -> anyhow::Result<PreparedEngine> { ) -> anyhow::Result<PreparedEngine> {
match engine_config { match engine_config {
EngineConfig::Dynamic(local_model) => { EngineConfig::Dynamic {
model: local_model, ..
} => {
let model_manager = Arc::new(ModelManager::new()); let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new( let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime.clone(), distributed_runtime.clone(),
model_manager.clone(), model_manager.clone(),
RouterConfig::default(), RouterConfig::default(),
None,
)); ));
let discovery = distributed_runtime.discovery(); let discovery = distributed_runtime.discovery();
let discovery_stream = discovery let discovery_stream = discovery
......
...@@ -82,7 +82,7 @@ pub async fn run( ...@@ -82,7 +82,7 @@ pub async fn run(
let fut = endpoint.endpoint_builder().handler(ingress).start(); let fut = endpoint.endpoint_builder().handler(ingress).start();
Box::pin(fut) Box::pin(fut)
} }
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic { .. } => {
unreachable!("An endpoint input will never have a Dynamic engine"); unreachable!("An endpoint input will never have a Dynamic engine");
} }
}; };
......
...@@ -26,11 +26,11 @@ pub async fn run( ...@@ -26,11 +26,11 @@ pub async fn run(
.with_request_template(engine_config.local_model().request_template()); .with_request_template(engine_config.local_model().request_template());
let grpc_service = match engine_config { let grpc_service = match engine_config {
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic { ref model, .. } => {
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let router_config = engine_config.local_model().router_config(); let router_config = model.router_config();
// Listen for models registering themselves, add them to gRPC service // Listen for models registering themselves, add them to gRPC service
let namespace = engine_config.local_model().namespace().unwrap_or(""); let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) { let target_namespace = if is_global_namespace(namespace) {
None None
} else { } else {
...@@ -105,7 +105,7 @@ async fn run_watcher( ...@@ -105,7 +105,7 @@ async fn run_watcher(
router_config: RouterConfig, router_config: RouterConfig,
target_namespace: Option<String>, target_namespace: Option<String>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config); let watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config, None);
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
let discovery = runtime.discovery(); let discovery = runtime.discovery();
let discovery_stream = discovery let discovery_stream = discovery
......
...@@ -7,7 +7,7 @@ use crate::{ ...@@ -7,7 +7,7 @@ use crate::{
discovery::{ModelManager, ModelUpdate, ModelWatcher}, discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType, endpoint_type::EndpointType,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{EngineConfig, RouterConfig, input::common}, entrypoint::{EngineConfig, EngineFactoryCallback, RouterConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
...@@ -61,16 +61,19 @@ pub async fn run( ...@@ -61,16 +61,19 @@ pub async fn run(
); );
let http_service = match engine_config { let http_service = match engine_config {
EngineConfig::Dynamic(_) => { EngineConfig::Dynamic {
ref model,
ref engine_factory,
} => {
// This allows the /health endpoint to query store for active instances // This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone()); http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let router_config = engine_config.local_model().router_config(); let router_config = model.router_config();
// Listen for models registering themselves, add them to HTTP service // Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace) // Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set // Get namespace from the model, fallback to endpoint_id namespace if not set
let namespace = engine_config.local_model().namespace().unwrap_or(""); let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) { let target_namespace = if is_global_namespace(namespace) {
None None
} else { } else {
...@@ -83,6 +86,7 @@ pub async fn run( ...@@ -83,6 +86,7 @@ pub async fn run(
target_namespace, target_namespace,
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
engine_factory.clone(),
) )
.await?; .await?;
http_service http_service
...@@ -194,8 +198,14 @@ async fn run_watcher( ...@@ -194,8 +198,14 @@ async fn run_watcher(
target_namespace: Option<String>, target_namespace: Option<String>,
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
engine_factory: Option<EngineFactoryCallback>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config); let mut watch_obj = ModelWatcher::new(
runtime.clone(),
model_manager,
router_config,
engine_factory,
);
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
let discovery = runtime.discovery(); let discovery = runtime.discovery();
let discovery_stream = discovery let discovery_stream = discovery
......
...@@ -339,6 +339,7 @@ mod integration_tests { ...@@ -339,6 +339,7 @@ mod integration_tests {
distributed_runtime.clone(), distributed_runtime.clone(),
service.state().manager_clone(), service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(), dynamo_llm::entrypoint::RouterConfig::default(),
None,
); );
// Start watching for model registrations via discovery interface // Start watching for model registrations via discovery interface
let discovery = distributed_runtime.discovery(); let discovery = distributed_runtime.discovery();
...@@ -510,6 +511,7 @@ mod integration_tests { ...@@ -510,6 +511,7 @@ mod integration_tests {
distributed_runtime.clone(), distributed_runtime.clone(),
service.state().manager_clone(), service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(), dynamo_llm::entrypoint::RouterConfig::default(),
None,
); );
// Get all model entries for our test model // Get all model entries for our test model
......
...@@ -961,7 +961,7 @@ mod integration_tests { ...@@ -961,7 +961,7 @@ mod integration_tests {
// Now create a namespace, component, and endpoint to make the system healthy // Now create a namespace, component, and endpoint to make the system healthy
let namespace = drt.namespace("ns1234").unwrap(); let namespace = drt.namespace("ns1234").unwrap();
let mut component = namespace.component("comp1234").unwrap(); let component = namespace.component("comp1234").unwrap();
// Create a simple test handler // Create a simple test handler
use crate::pipeline::{async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, SingleIn}; use crate::pipeline::{async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, SingleIn};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment