Unverified Commit da0f2fb8 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(frontend): First part of Python request handling (#4999)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 91423c45
......@@ -30,12 +30,15 @@ from dynamo.llm import (
EngineType,
EntrypointArgs,
KvRouterConfig,
ModelDeploymentCard,
PythonAsyncEngine,
RouterConfig,
RouterMode,
make_engine,
run_input,
)
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
from . import __version__
......@@ -45,9 +48,25 @@ CUSTOM_BACKEND_METRICS_POLLING_INTERVAL_ENV_VAR = (
)
CUSTOM_BACKEND_ENDPOINT_ENV_VAR = "CUSTOM_BACKEND_ENDPOINT"
configure_dynamo_logging()
logger = logging.getLogger(__name__)
async def _dummy_generator(request):
"""Minimal generator that yields nothing. Work in progress."""
return
yield # Makes this an async generator
async def engine_factory(mdc: ModelDeploymentCard) -> PythonAsyncEngine:
"""
Called by Rust when a model is discovered.
"""
loop = asyncio.get_running_loop()
logger.info(f"Engine_factory called with MDC: {mdc.to_json_str()[:100]}...")
return PythonAsyncEngine(_dummy_generator, loop)
def validate_model_name(value):
"""Validate that model-name is a non-empty string."""
if not value or not isinstance(value, str) or len(value.strip()) == 0:
......@@ -254,6 +273,12 @@ def parse_args():
default=os.environ.get("DYN_REQUEST_PLANE", "nats"),
help="Determines how requests are distributed from routers to workers. 'tcp' is fastest [nats|http|tcp]",
)
parser.add_argument(
"--exp-python-factory",
action="store_true",
default=False,
help="[EXPERIMENTAL] Enable Python-based engine factory. When set, engines will be created via a Python callback instead of the default Rust pipeline.",
)
flags = parser.parse_args()
......@@ -356,6 +381,9 @@ async def async_main():
"custom_backend_metrics_polling_interval"
] = flags.custom_backend_metrics_polling_interval
if flags.exp_python_factory:
kwargs["engine_factory"] = engine_factory
e = EntrypointArgs(EngineType.Dynamic, **kwargs)
engine = await make_engine(runtime, e)
......
......@@ -146,7 +146,10 @@ async fn engine_for(
match out_opt {
Output::Auto => {
// Auto-discover backends
Ok(EngineConfig::Dynamic(Box::new(local_model)))
Ok(EngineConfig::Dynamic {
model: Box::new(local_model),
engine_factory: None,
})
}
Output::Echo => Ok(EngineConfig::InProcessText {
model: Box::new(local_model),
......
......@@ -975,6 +975,7 @@ pub async fn create_worker_selection_pipeline_chat(
component.drt().clone(),
model_manager.clone(),
router_config,
None,
);
let cards = watcher
.cards_for_model(model_name, Some(namespace), false)
......
......@@ -2,20 +2,29 @@
// SPDX-License-Identifier: Apache-2.0
use std::fmt::Display;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*};
use pyo3_async_runtimes::TaskLocals;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
use dynamo_llm::entrypoint::EngineFactoryCallback;
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_llm::mocker::protocols::MockEngineArgs;
use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
use dynamo_runtime::protocols::EndpointId;
use super::model_card::ModelDeploymentCard;
use crate::RouterMode;
use crate::engine::PythonAsyncEngine;
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
......@@ -117,6 +126,21 @@ impl From<RouterConfig> for RsRouterConfig {
}
}
/// Wrapper to hold Python callback and its TaskLocals for async execution
#[derive(Clone)]
struct PyEngineFactory {
callback: Arc<PyObject>,
locals: Arc<TaskLocals>,
}
impl std::fmt::Debug for PyEngineFactory {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PyEngineFactory")
.field("callback", &"<PyObject>")
.finish()
}
}
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
......@@ -137,14 +161,16 @@ pub(crate) struct EntrypointArgs {
custom_backend_metrics_endpoint: Option<String>,
custom_backend_metrics_polling_interval: Option<f64>,
is_prefill: bool,
engine_factory: Option<PyEngineFactory>,
}
#[pymethods]
impl EntrypointArgs {
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None, is_prefill=false))]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None, is_prefill=false, engine_factory=None))]
pub fn new(
py: Python<'_>,
engine_type: EngineType,
model_path: Option<PathBuf>,
model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
......@@ -162,6 +188,7 @@ impl EntrypointArgs {
custom_backend_metrics_endpoint: Option<String>,
custom_backend_metrics_polling_interval: Option<f64>,
is_prefill: bool,
engine_factory: Option<PyObject>,
) -> PyResult<Self> {
let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
if (tls_cert_path.is_some() && tls_key_path.is_none())
......@@ -171,6 +198,23 @@ impl EntrypointArgs {
"tls_cert_path and tls_key_path must be provided together",
));
}
// Capture TaskLocals at registration time for the engine factory callback
let engine_factory = engine_factory
.map(|callback| {
let locals = pyo3_async_runtimes::tokio::get_current_locals(py).map_err(|e| {
pyo3::exceptions::PyRuntimeError::new_err(format!(
"Failed to get TaskLocals for engine_factory: {}",
e
))
})?;
Ok::<_, PyErr>(PyEngineFactory {
callback: Arc::new(callback),
locals: Arc::new(locals),
})
})
.transpose()?;
Ok(EntrypointArgs {
engine_type,
model_path,
......@@ -189,6 +233,7 @@ impl EntrypointArgs {
custom_backend_metrics_endpoint,
custom_backend_metrics_polling_interval,
is_prefill,
engine_factory,
})
}
}
......@@ -251,6 +296,57 @@ pub fn make_engine<'p>(
})
}
/// Convert a PyEngineFactory to a Rust EngineFactoryCallback
fn py_engine_factory_to_callback(factory: PyEngineFactory) -> EngineFactoryCallback {
let callback = factory.callback;
let locals = factory.locals;
Arc::new(
move |card: RsModelDeploymentCard| -> Pin<
Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
> {
let callback = callback.clone();
let locals = locals.clone();
Box::pin(async move {
// Acquire GIL to call Python callback and convert coroutine to future
let py_future = Python::with_gil(|py| {
// Create Python ModelDeploymentCard wrapper
let py_card = ModelDeploymentCard { inner: card };
let py_card_obj = Py::new(py, py_card)
.map_err(|e| anyhow::anyhow!("Failed to create Python MDC: {}", e))?;
// Call Python async function to get a coroutine
let coroutine = callback
.call1(py, (py_card_obj,))
.map_err(|e| anyhow::anyhow!("Failed to call engine_factory: {}", e))?;
// Use the TaskLocals captured at registration time
pyo3_async_runtimes::into_future_with_locals(&locals, coroutine.into_bound(py))
.map_err(|e| {
anyhow::anyhow!("Failed to convert coroutine to future: {}", e)
})
})?;
// Await the Python coroutine (GIL is released during await)
let py_result = py_future
.await
.map_err(|e| anyhow::anyhow!("engine_factory callback failed: {}", e))?;
// Extract PythonAsyncEngine from the Python result and wrap in Arc
let engine: OpenAIChatCompletionsStreamingEngine = Python::with_gil(|py| {
let engine: PythonAsyncEngine = py_result.extract(py).map_err(|e| {
anyhow::anyhow!("Failed to extract PythonAsyncEngine: {}", e)
})?;
Ok::<_, anyhow::Error>(Arc::new(engine))
})?;
Ok(engine)
})
},
)
}
async fn select_engine(
#[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
args: EntrypointArgs,
......@@ -264,7 +360,14 @@ async fn select_engine(
engine: dynamo_llm::engines::make_echo_engine(),
}
}
EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
EngineType::Dynamic => {
// Convert Python engine factory to Rust callback
let engine_factory = args.engine_factory.map(py_engine_factory_to_callback);
RsEngineConfig::Dynamic {
model: Box::new(local_model),
engine_factory,
}
}
EngineType::Mocker => {
let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
......
......@@ -21,6 +21,7 @@ from dynamo._core import KvStats as KvStats
from dynamo._core import LoRADownloader as LoRADownloader
from dynamo._core import MediaDecoder as MediaDecoder
from dynamo._core import MediaFetcher as MediaFetcher
from dynamo._core import ModelDeploymentCard as ModelDeploymentCard
from dynamo._core import ModelInput as ModelInput
from dynamo._core import ModelRuntimeConfig as ModelRuntimeConfig
from dynamo._core import ModelType as ModelType
......
......@@ -20,7 +20,7 @@ use dynamo_runtime::{
use crate::{
backend::Backend,
entrypoint::{self, RouterConfig},
entrypoint::{self, EngineFactoryCallback, RouterConfig},
kv_router::PrefillRouter,
model_card::ModelDeploymentCard,
model_type::{ModelInput, ModelType},
......@@ -53,6 +53,7 @@ pub struct ModelWatcher {
router_config: RouterConfig,
notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>,
engine_factory: Option<EngineFactoryCallback>,
}
const ALL_MODEL_TYPES: &[ModelType] = &[
......@@ -68,6 +69,7 @@ impl ModelWatcher {
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
router_config: RouterConfig,
engine_factory: Option<EngineFactoryCallback>,
) -> ModelWatcher {
Self {
manager: model_manager,
......@@ -75,6 +77,7 @@ impl ModelWatcher {
router_config,
notify_on_model: Notify::new(),
model_update_tx: None,
engine_factory,
}
}
......@@ -355,6 +358,7 @@ impl ModelWatcher {
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(card.clone())).await.ok();
}
let checksum = card.mdcsum();
if card.model_input == ModelInput::Tokens
......@@ -429,21 +433,28 @@ impl ModelWatcher {
// Add chat engine only if the model supports chat
if card.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
card,
&client,
self.router_config.router_mode,
worker_monitor.clone(),
kv_chooser.clone(),
tokenizer_hf.clone(),
prefill_chooser.clone(),
self.router_config.enforce_disagg,
)
.await
.context("build_routed_pipeline")?;
// Work in progress. This will allow creating a chat_engine from Python.
let chat_engine = if let Some(ref factory) = self.engine_factory {
factory(card.clone())
.await
.context("python engine_factory")?
} else {
entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
card,
&client,
self.router_config.router_mode,
worker_monitor.clone(),
kv_chooser.clone(),
tokenizer_hf.clone(),
prefill_chooser.clone(),
self.router_config.enforce_disagg,
)
.await
.context("build_routed_pipeline")?
};
self.manager
.add_chat_completions_model(card.name(), checksum, chat_engine)
.context("add_chat_completions_model")?;
......
......@@ -8,15 +8,28 @@
pub mod input;
pub use input::{build_routed_pipeline, build_routed_pipeline_with_preprocessor};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use dynamo_runtime::pipeline::RouterMode;
use crate::{
backend::ExecutionContext, engines::StreamingEngine, kv_router::KvRouterConfig,
local_model::LocalModel,
local_model::LocalModel, model_card::ModelDeploymentCard,
types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine,
};
/// Callback type for engine factory (async)
pub type EngineFactoryCallback = Arc<
dyn Fn(
ModelDeploymentCard,
) -> Pin<
Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
> + Send
+ Sync,
>;
#[derive(Debug, Clone, Default)]
pub struct RouterConfig {
pub router_mode: RouterMode,
......@@ -58,7 +71,10 @@ impl RouterConfig {
#[derive(Clone)]
pub enum EngineConfig {
/// Remote networked engines that we discover via etcd
Dynamic(Box<LocalModel>),
Dynamic {
model: Box<LocalModel>,
engine_factory: Option<EngineFactoryCallback>,
},
/// A Text engine receives text, does it's own tokenization and prompt formatting.
InProcessText {
......@@ -75,12 +91,19 @@ pub enum EngineConfig {
}
impl EngineConfig {
fn local_model(&self) -> &LocalModel {
pub fn local_model(&self) -> &LocalModel {
use EngineConfig::*;
match self {
Dynamic(lm) => lm,
Dynamic { model, .. } => model,
InProcessText { model, .. } => model,
InProcessTokens { model, .. } => model,
}
}
pub fn engine_factory(&self) -> Option<&EngineFactoryCallback> {
match self {
EngineConfig::Dynamic { engine_factory, .. } => engine_factory.as_ref(),
_ => None,
}
}
}
......@@ -58,12 +58,15 @@ pub async fn prepare_engine(
engine_config: EngineConfig,
) -> anyhow::Result<PreparedEngine> {
match engine_config {
EngineConfig::Dynamic(local_model) => {
EngineConfig::Dynamic {
model: local_model, ..
} => {
let model_manager = Arc::new(ModelManager::new());
let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime.clone(),
model_manager.clone(),
RouterConfig::default(),
None,
));
let discovery = distributed_runtime.discovery();
let discovery_stream = discovery
......
......@@ -82,7 +82,7 @@ pub async fn run(
let fut = endpoint.endpoint_builder().handler(ingress).start();
Box::pin(fut)
}
EngineConfig::Dynamic(_) => {
EngineConfig::Dynamic { .. } => {
unreachable!("An endpoint input will never have a Dynamic engine");
}
};
......
......@@ -26,11 +26,11 @@ pub async fn run(
.with_request_template(engine_config.local_model().request_template());
let grpc_service = match engine_config {
EngineConfig::Dynamic(_) => {
EngineConfig::Dynamic { ref model, .. } => {
let grpc_service = grpc_service_builder.build()?;
let router_config = engine_config.local_model().router_config();
let router_config = model.router_config();
// Listen for models registering themselves, add them to gRPC service
let namespace = engine_config.local_model().namespace().unwrap_or("");
let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
} else {
......@@ -105,7 +105,7 @@ async fn run_watcher(
router_config: RouterConfig,
target_namespace: Option<String>,
) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config);
let watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config, None);
tracing::debug!("Waiting for remote model");
let discovery = runtime.discovery();
let discovery_stream = discovery
......
......@@ -7,7 +7,7 @@ use crate::{
discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter,
entrypoint::{EngineConfig, RouterConfig, input::common},
entrypoint::{EngineConfig, EngineFactoryCallback, RouterConfig, input::common},
http::service::service_v2::{self, HttpService},
namespace::is_global_namespace,
types::openai::{
......@@ -61,16 +61,19 @@ pub async fn run(
);
let http_service = match engine_config {
EngineConfig::Dynamic(_) => {
EngineConfig::Dynamic {
ref model,
ref engine_factory,
} => {
// This allows the /health endpoint to query store for active instances
http_service_builder = http_service_builder.store(distributed_runtime.store().clone());
let http_service = http_service_builder.build()?;
let router_config = engine_config.local_model().router_config();
let router_config = model.router_config();
// Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set
let namespace = engine_config.local_model().namespace().unwrap_or("");
let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
} else {
......@@ -83,6 +86,7 @@ pub async fn run(
target_namespace,
Arc::new(http_service.clone()),
http_service.state().metrics_clone(),
engine_factory.clone(),
)
.await?;
http_service
......@@ -194,8 +198,14 @@ async fn run_watcher(
target_namespace: Option<String>,
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
engine_factory: Option<EngineFactoryCallback>,
) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config);
let mut watch_obj = ModelWatcher::new(
runtime.clone(),
model_manager,
router_config,
engine_factory,
);
tracing::debug!("Waiting for remote model");
let discovery = runtime.discovery();
let discovery_stream = discovery
......
......@@ -339,6 +339,7 @@ mod integration_tests {
distributed_runtime.clone(),
service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(),
None,
);
// Start watching for model registrations via discovery interface
let discovery = distributed_runtime.discovery();
......@@ -510,6 +511,7 @@ mod integration_tests {
distributed_runtime.clone(),
service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(),
None,
);
// Get all model entries for our test model
......
......@@ -961,7 +961,7 @@ mod integration_tests {
// Now create a namespace, component, and endpoint to make the system healthy
let namespace = drt.namespace("ns1234").unwrap();
let mut component = namespace.component("comp1234").unwrap();
let component = namespace.component("comp1234").unwrap();
// Create a simple test handler
use crate::pipeline::{async_trait, network::Ingress, AsyncEngine, AsyncEngineContextProvider, Error, ManyOut, SingleIn};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment