// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

// Python-facing (pyo3) wrappers for building an engine configuration
// (`make_engine`) and running it against an input (`run_input`).
//
// NOTE(review): in this copy of the file the generic arguments appear to be
// missing (bare `Option`, `PyResult`, `anyhow::Result`, `impl From for ...`,
// `to_pyerr(err: E)` without a declared `E`). Confirm against the upstream
// source before building; the code below is reproduced as-is.

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_llm::mocker::protocols::MockEngineArgs;
use dynamo_runtime::protocols::EndpointId;

use crate::RouterMode;

/// Which engine `select_engine` should construct.
///
/// * `Echo`    - wraps `dynamo_llm::engines::make_echo_engine()` in
///   `RsEngineConfig::StaticFull`.
/// * `Dynamic` - produces `RsEngineConfig::Dynamic` from the local model.
/// * `Mocker`  - builds a mocker engine and wraps it in
///   `RsEngineConfig::StaticCore`.
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
    Dynamic = 2,
    Mocker = 3,
}

/// Python-visible wrapper around the Rust-side `RsKvRouterConfig`.
#[pyclass]
#[derive(Default, Clone, Debug, Copy)]
pub struct KvRouterConfig {
    /// The wrapped Rust configuration value.
    inner: RsKvRouterConfig,
}

impl KvRouterConfig {
    /// Returns the wrapped `RsKvRouterConfig` by value.
    pub fn inner(&self) -> RsKvRouterConfig {
        self.inner
    }
}

#[pymethods]
impl KvRouterConfig {
    /// Builds a `KvRouterConfig` by forwarding every argument, unchanged, into
    /// the corresponding `RsKvRouterConfig` field.
    ///
    /// Python-side defaults are the ones listed in the `signature` attribute
    /// below (e.g. `overlap_score_weight=1.0`, `router_snapshot_threshold=1000000`).
    #[new]
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false))]
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
        router_replica_sync: bool,
        router_track_active_blocks: bool,
        router_snapshot_threshold: Option,
        router_reset_states: bool,
    ) -> Self {
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
                router_replica_sync,
                router_track_active_blocks,
                router_snapshot_threshold,
                router_reset_states,
            },
        }
    }
}

/// Python-visible router settings; converted into `RsRouterConfig` through the
/// `From` impl that follows.
#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
    /// Routing mode; converted with `.into()` when building `RsRouterConfig`.
    router_mode: RouterMode,
    /// KV router settings; only its `inner` value is forwarded.
    kv_router_config: KvRouterConfig,
    /// Forwarded unchanged to `RsRouterConfig::busy_threshold`.
    busy_threshold: Option,
    /// Forwarded unchanged to `RsRouterConfig::enforce_disagg`.
    enforce_disagg: bool,
}

#[pymethods]
impl RouterConfig {
    /// Creates a `RouterConfig`.
    ///
    /// When `config` is `None`, `KvRouterConfig::default()` is used
    /// (via `unwrap_or_default`). The other arguments are stored as given.
    #[new]
    #[pyo3(signature = (mode, config=None, busy_threshold=None, enforce_disagg=false))]
    pub fn new(
        mode: RouterMode,
        config: Option,
        busy_threshold: Option,
        enforce_disagg: bool,
    ) -> Self {
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
            busy_threshold,
            enforce_disagg,
        }
    }
}

/// Field-by-field conversion from the Python-facing `RouterConfig` into the
/// Rust-side `RsRouterConfig`.
impl From for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
            busy_threshold: rc.busy_threshold,
            enforce_disagg: rc.enforce_disagg,
        }
    }
}

/// Arguments collected from Python and consumed by `make_engine` /
/// `select_engine`.
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    /// Selects the engine built by `select_engine`.
    engine_type: EngineType,
    /// Used directly when the path exists on disk; otherwise handed to
    /// `LocalModel::fetch` (see `make_engine`).
    model_path: Option,
    /// Explicit model name; `make_engine` falls back to the displayed
    /// `model_path` when this is `None`.
    model_name: Option,
    /// Endpoint parsed from the string given to `new` via `EndpointId::from`.
    endpoint_id: Option,
    context_length: Option,
    /// Forwarded to `LocalModelBuilder::request_template`.
    template_file: Option,
    router_config: Option,
    kv_cache_block_size: Option,
    http_host: Option,
    /// Always populated: `new` substitutes `DEFAULT_HTTP_PORT` when omitted.
    http_port: u16,
    /// `new` rejects the case where only one of the two TLS paths is set.
    tls_cert_path: Option,
    tls_key_path: Option,
    /// For the mocker engine this is the JSON file read by
    /// `MockEngineArgs::from_json_file`.
    extra_engine_args: Option,
    namespace: Option,
    custom_backend_metrics_endpoint: Option,
    custom_backend_metrics_polling_interval: Option,
    /// Forwarded to `RsEngineConfig::StaticCore::is_prefill` for the mocker engine.
    is_prefill: bool,
}

#[pymethods]
impl EntrypointArgs {
    /// Validates and stores the entrypoint arguments.
    ///
    /// * `endpoint_id`, when given, is converted with `EndpointId::from`.
    /// * `http_port` falls back to `DEFAULT_HTTP_PORT` when `None`.
    ///
    /// Raises `ValueError` when exactly one of `tls_cert_path` /
    /// `tls_key_path` is provided.
    #[allow(clippy::too_many_arguments)]
    #[new]
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None, is_prefill=false))]
    pub fn new(
        engine_type: EngineType,
        model_path: Option,
        model_name: Option,
        // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option,
        context_length: Option,
        template_file: Option,
        router_config: Option,
        kv_cache_block_size: Option,
        http_host: Option,
        http_port: Option,
        tls_cert_path: Option,
        tls_key_path: Option,
        extra_engine_args: Option,
        namespace: Option,
        custom_backend_metrics_endpoint: Option,
        custom_backend_metrics_polling_interval: Option,
        is_prefill: bool,
    ) -> PyResult {
        let endpoint_id_obj: Option = endpoint_id.as_deref().map(EndpointId::from);
        // The certificate and key are only meaningful as a pair: reject the
        // two "exactly one is set" combinations.
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
            router_config,
            kv_cache_block_size,
            http_host,
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
            tls_cert_path,
            tls_key_path,
            extra_engine_args,
            namespace,
            custom_backend_metrics_endpoint,
            custom_backend_metrics_polling_interval,
            is_prefill,
        })
    }
}

/// Python-visible handle around the `RsEngineConfig` produced by
/// `make_engine` and consumed by `run_input`.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    inner: RsEngineConfig,
}

/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
///
/// The `LocalModelBuilder` is configured up front from `args`; the model name
/// falls back to the displayed `model_path` when `model_name` is not set.
/// The remaining work runs inside the future handed to
/// `pyo3_async_runtimes::tokio::future_into_py`:
///
/// 1. If `model_path` is set and exists on disk it is used as-is; otherwise
///    it is passed to `LocalModel::fetch`.
/// 2. The builder is built into a `LocalModel`.
/// 3. `select_engine` picks the engine for `args.engine_type`.
///
/// Any error from those steps is converted with `to_pyerr`, so it is raised
/// as a generic Python `Exception`.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult> {
    let mut builder = LocalModelBuilder::default();
    // Fields are cloned rather than moved because `args` is moved whole into
    // the async block below and later into `select_engine`.
    builder
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
        .endpoint_id(args.endpoint_id.clone())
        .context_length(args.context_length)
        .request_template(args.template_file.clone())
        .kv_cache_block_size(args.kv_cache_block_size)
        .router_config(args.router_config.clone().map(|rc| rc.into()))
        .http_host(args.http_host.clone())
        .http_port(args.http_port)
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
        .extra_engine_args(args.extra_engine_args.clone())
        .namespace(args.namespace.clone())
        .custom_backend_metrics_endpoint(args.custom_backend_metrics_endpoint.clone())
        .custom_backend_metrics_polling_interval(args.custom_backend_metrics_polling_interval);
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        if let Some(model_path) = args.model_path.clone() {
            // An existing local path wins; anything else is handed to
            // `LocalModel::fetch` as a string.
            let local_path = if model_path.exists() {
                model_path
            } else {
                // Mocker only needs tokenizer, not weights
                let ignore_weights = matches!(args.engine_type, EngineType::Mocker);
                LocalModel::fetch(&model_path.display().to_string(), ignore_weights)
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }
        let local_model = builder.build().await.map_err(to_pyerr)?;
        let inner = select_engine(distributed_runtime, args, local_model)
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

/// Builds the `RsEngineConfig` variant that matches `args.engine_type`.
///
/// * `Echo`    - `StaticFull` with the echo engine; no validation is done.
/// * `Dynamic` - `Dynamic` wrapping the boxed `local_model`.
/// * `Mocker`  - loads `MockEngineArgs` from `args.extra_engine_args` (or uses
///   the defaults, with a warning, when unset), creates the mocker engine on
///   `distributed_runtime.inner` for the model's endpoint, and returns
///   `StaticCore`.
///
/// Errors: a failure to read the mocker args file is reported with the file
/// path in the message; a failure of `make_mocker_engine` is propagated as-is.
async fn select_engine(
    // NOTE(review): `distributed_runtime` is used by the Mocker arm in this
    // view, so this `allow` may be a leftover - confirm before removing.
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
    local_model: LocalModel,
) -> anyhow::Result {
    let inner = match args.engine_type {
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
                engine: dynamo_llm::engines::make_echo_engine(),
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };
            // Clone the endpoint first: `local_model` is moved into the
            // returned config further down.
            let endpoint = local_model.endpoint_id().clone();
            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;
            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
                is_prefill: args.is_prefill,
            }
        }
    };
    Ok(inner)
}

/// Runs `engine_config` against the given `input`.
///
/// `input` is parsed into an `Input` with `str::parse` before the future is
/// created, so a parse failure is returned immediately from this call. The
/// actual run (`dynamo_llm::entrypoint::input::run_input`) happens inside the
/// future handed to `pyo3_async_runtimes::tokio::future_into_py`; its errors
/// are converted with `to_pyerr`.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        dynamo_llm::entrypoint::input::run_input(
            distributed_runtime.inner.clone(),
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Converts any `Display`-able error into a generic Python `Exception` whose
/// message is the error's `Display` output.
pub fn to_pyerr(err: E) -> PyErr
where
    E: Display,
{
    // NOTE(review): `format!("{}", err)` yields the same text as `err.to_string()`.
    PyException::new_err(format!("{}", err))
}