entrypoint.rs 16.3 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
5
use std::future::Future;
6
use std::path::PathBuf;
7
8
use std::pin::Pin;
use std::sync::Arc;
9
10

use pyo3::{exceptions::PyException, prelude::*};
11
use pyo3_async_runtimes::TaskLocals;
12

13
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
14
use dynamo_llm::entrypoint::ChatEngineFactoryCallback;
15
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
16
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
17
use dynamo_llm::entrypoint::input::Input;
18
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
19
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
20
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
21
use dynamo_llm::mocker::protocols::MockEngineArgs;
22
23
use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
24
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
25
use dynamo_runtime::protocols::EndpointId;
26

27
use super::model_card::ModelDeploymentCard;
28
use crate::RouterMode;
29
use crate::engine::PythonAsyncEngine;
30

31
32
33
34
35
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
36
37
    Dynamic = 2,
    Mocker = 3,
38
39
}

40
41
42
43
44
45
#[pyclass]
#[derive(Default, Clone, Debug, Copy)]
pub struct KvRouterConfig {
    inner: RsKvRouterConfig,
}

46
47
48
49
50
51
impl KvRouterConfig {
    pub fn inner(&self) -> RsKvRouterConfig {
        self.inner
    }
}

52
53
54
#[pymethods]
impl KvRouterConfig {
    #[new]
55
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8))]
56
    #[allow(clippy::too_many_arguments)]
57
58
59
60
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
61
        durable_kv_events: bool,
62
        router_replica_sync: bool,
63
        router_track_active_blocks: bool,
64
        router_track_output_blocks: bool,
65
        router_assume_kv_reuse: bool,
66
67
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
68
69
70
        router_ttl_secs: f64,
        router_max_tree_size: usize,
        router_prune_target_ratio: f64,
71
    ) -> Self {
72
73
74
75
76
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
77
                durable_kv_events,
78
                router_replica_sync,
79
                router_track_active_blocks,
80
                router_track_output_blocks,
81
                router_assume_kv_reuse,
82
83
                router_snapshot_threshold,
                router_reset_states,
84
85
86
                router_ttl_secs,
                router_max_tree_size,
                router_prune_target_ratio,
87
88
89
90
91
92
93
94
            },
        }
    }
}

#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
95
96
97
98
99
100
    #[pyo3(get, set)]
    pub router_mode: RouterMode,

    #[pyo3(get, set)]
    pub kv_router_config: KvRouterConfig,

101
102
103
104
    /// Threshold for active decode blocks utilization (0.0-1.0)
    active_decode_blocks_threshold: Option<f64>,
    /// Threshold for active prefill tokens utilization (literal token count)
    active_prefill_tokens_threshold: Option<u64>,
105
106
    /// Threshold for active prefill tokens as fraction of max_num_batched_tokens
    active_prefill_tokens_threshold_frac: Option<f64>,
107
    enforce_disagg: bool,
108
109
110
111
112
}

#[pymethods]
impl RouterConfig {
    #[new]
113
    #[pyo3(signature = (mode, config=None, active_decode_blocks_threshold=None, active_prefill_tokens_threshold=None, active_prefill_tokens_threshold_frac=None, enforce_disagg=false))]
114
115
116
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
117
118
        active_decode_blocks_threshold: Option<f64>,
        active_prefill_tokens_threshold: Option<u64>,
119
        active_prefill_tokens_threshold_frac: Option<f64>,
120
        enforce_disagg: bool,
121
    ) -> Self {
122
123
124
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
125
126
            active_decode_blocks_threshold,
            active_prefill_tokens_threshold,
127
            active_prefill_tokens_threshold_frac,
128
            enforce_disagg,
129
130
131
132
133
134
135
136
137
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
138
139
140
141
142
            load_threshold_config: RsLoadThresholdConfig {
                active_decode_blocks_threshold: rc.active_decode_blocks_threshold,
                active_prefill_tokens_threshold: rc.active_prefill_tokens_threshold,
                active_prefill_tokens_threshold_frac: rc.active_prefill_tokens_threshold_frac,
            },
143
            enforce_disagg: rc.enforce_disagg,
144
145
146
147
        }
    }
}

148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
/// Wrapper to hold Python callback and its TaskLocals for async execution
///
/// Both fields are `Arc`-wrapped, so cloning the factory only bumps reference counts.
#[derive(Clone)]
struct PyEngineFactory {
    /// Python callable invoked as `callback(instance_id, card)`; its result is turned into
    /// a Rust future, so it is expected to return an awaitable.
    callback: Arc<PyObject>,
    /// Task locals captured when the factory was registered; used to drive the awaitable.
    locals: Arc<TaskLocals>,
}

impl std::fmt::Debug for PyEngineFactory {
    /// Opaque debug output: the Python objects are not rendered, and
    /// `finish_non_exhaustive` marks that the `locals` field was omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PyEngineFactory")
            .field("callback", &"<PyObject>")
            .finish_non_exhaustive()
    }
}

163
164
165
166
167
168
169
170
171
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
172
    router_config: Option<RouterConfig>,
173
    kv_cache_block_size: Option<u32>,
174
    http_host: Option<String>,
Graham King's avatar
Graham King committed
175
    http_port: u16,
176
    http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
177
178
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
179
    extra_engine_args: Option<PathBuf>,
180
    namespace: Option<String>,
181
    is_prefill: bool,
182
    migration_limit: u32,
183
    chat_engine_factory: Option<PyEngineFactory>,
184
185
186
187
188
189
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
190
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
191
    pub fn new(
192
        py: Python<'_>,
193
194
195
196
197
198
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
199
        router_config: Option<RouterConfig>,
200
        kv_cache_block_size: Option<u32>,
201
        http_host: Option<String>,
202
        http_port: Option<u16>,
203
        http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
204
205
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
206
        extra_engine_args: Option<PathBuf>,
207
        namespace: Option<String>,
208
        is_prefill: bool,
209
        migration_limit: u32,
210
        chat_engine_factory: Option<PyObject>,
211
    ) -> PyResult<Self> {
212
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
213
214
215
216
217
218
219
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
220

221
222
        // Capture TaskLocals at registration time for the chat engine factory callback
        let chat_engine_factory = chat_engine_factory
223
224
225
            .map(|callback| {
                let locals = pyo3_async_runtimes::tokio::get_current_locals(py).map_err(|e| {
                    pyo3::exceptions::PyRuntimeError::new_err(format!(
226
                        "Failed to get TaskLocals for chat_engine_factory: {}",
227
228
229
230
231
232
233
234
235
236
                        e
                    ))
                })?;
                Ok::<_, PyErr>(PyEngineFactory {
                    callback: Arc::new(callback),
                    locals: Arc::new(locals),
                })
            })
            .transpose()?;

237
238
239
240
241
242
243
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
244
            router_config,
245
            kv_cache_block_size,
246
            http_host,
Graham King's avatar
Graham King committed
247
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
248
            http_metrics_port,
Graham King's avatar
Graham King committed
249
250
            tls_cert_path,
            tls_key_path,
251
            extra_engine_args,
252
            namespace,
253
            is_prefill,
254
            migration_limit,
255
            chat_engine_factory,
256
257
258
259
260
261
262
263
264
265
        })
    }
}

/// Opaque Python handle to a fully built Rust [`RsEngineConfig`].
///
/// Produced by `make_engine` and consumed by `run_input`.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// The engine configuration handed to `dynamo_llm::entrypoint::input::run_input`.
    inner: RsEngineConfig,
}

266
267
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
268
269
270
271
272
273
274
275
276
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
277
278
279
280
281
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
282
        .endpoint_id(args.endpoint_id.clone())
283
        .context_length(args.context_length)
284
        .request_template(args.template_file.clone())
285
        .kv_cache_block_size(args.kv_cache_block_size)
286
        .router_config(args.router_config.clone().map(|rc| rc.into()))
287
        .migration_limit(Some(args.migration_limit))
288
        .http_host(args.http_host.clone())
289
        .http_port(args.http_port)
290
        .http_metrics_port(args.http_metrics_port)
Graham King's avatar
Graham King committed
291
292
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
293
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
294
        .extra_engine_args(args.extra_engine_args.clone())
295
        .namespace(args.namespace.clone());
296
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
297
298
299
300
        if let Some(model_path) = args.model_path.clone() {
            let local_path = if model_path.exists() {
                model_path
            } else {
301
302
303
                // Mocker only needs tokenizer, not weights
                let ignore_weights = matches!(args.engine_type, EngineType::Mocker);
                LocalModel::fetch(&model_path.display().to_string(), ignore_weights)
304
305
306
307
308
309
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }

310
        let local_model = builder.build().await.map_err(to_pyerr)?;
311
        let inner = select_engine(distributed_runtime, args, local_model)
312
313
314
315
316
317
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

318
319
/// Convert a PyEngineFactory to a Rust ChatEngineFactoryCallback
fn py_engine_factory_to_callback(factory: PyEngineFactory) -> ChatEngineFactoryCallback {
320
321
322
323
    let callback = factory.callback;
    let locals = factory.locals;

    Arc::new(
324
325
326
        move |instance_id: RsModelCardInstanceId,
              card: RsModelDeploymentCard|
              -> Pin<
327
328
329
330
331
332
333
334
            Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
        > {
            let callback = callback.clone();
            let locals = locals.clone();

            Box::pin(async move {
                // Acquire GIL to call Python callback and convert coroutine to future
                let py_future = Python::with_gil(|py| {
335
336
337
338
                    let py_instance_id =
                        Py::new(py, crate::ModelCardInstanceId { inner: instance_id }).map_err(
                            |e| anyhow::anyhow!("Failed to create Python ModelCardInstanceId: {e}"),
                        )?;
339
340
341
                    // Create Python ModelDeploymentCard wrapper
                    let py_card = ModelDeploymentCard { inner: card };
                    let py_card_obj = Py::new(py, py_card)
342
                        .map_err(|e| anyhow::anyhow!("Failed to create Python MDC: {e}"))?;
343
344
345

                    // Call Python async function to get a coroutine
                    let coroutine = callback
346
347
                        .call1(py, (py_instance_id, py_card_obj))
                        .map_err(|e| anyhow::anyhow!("Failed to call chat_engine_factory: {e}"))?;
348
349
350

                    // Use the TaskLocals captured at registration time
                    pyo3_async_runtimes::into_future_with_locals(&locals, coroutine.into_bound(py))
351
                        .map_err(|e| anyhow::anyhow!("Failed to convert coroutine to future: {e}"))
352
353
354
355
356
                })?;

                // Await the Python coroutine (GIL is released during await)
                let py_result = py_future
                    .await
357
                    .map_err(|e| anyhow::anyhow!("chat_engine_factory callback failed: {}", e))?;
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372

                // Extract PythonAsyncEngine from the Python result and wrap in Arc
                let engine: OpenAIChatCompletionsStreamingEngine = Python::with_gil(|py| {
                    let engine: PythonAsyncEngine = py_result.extract(py).map_err(|e| {
                        anyhow::anyhow!("Failed to extract PythonAsyncEngine: {}", e)
                    })?;
                    Ok::<_, anyhow::Error>(Arc::new(engine))
                })?;

                Ok(engine)
            })
        },
    )
}

373
374
async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
375
    args: EntrypointArgs,
376
377
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
378
    let inner = match args.engine_type {
379
380
        EngineType::Echo => {
            // There is no validation for the echo engine
381
            RsEngineConfig::InProcessText {
382
                model: Box::new(local_model),
383
                engine: dynamo_llm::engines::make_echo_engine(),
384
385
            }
        }
386
        EngineType::Dynamic => {
387
388
            //  Convert Python chat engine factory to Rust callback
            let chat_engine_factory = args.chat_engine_factory.map(py_engine_factory_to_callback);
389
390
            RsEngineConfig::Dynamic {
                model: Box::new(local_model),
391
                chat_engine_factory,
392
393
            }
        }
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

412
            let engine = dynamo_llm::mocker::make_mocker_engine(
413
414
415
416
417
418
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

419
            RsEngineConfig::InProcessTokens {
420
421
                engine,
                model: Box::new(local_model),
422
                is_prefill: args.is_prefill,
423
424
            }
        }
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
    };

    Ok(inner)
}

/// Run the input mode named by `input` against a previously built engine.
///
/// `input` is parsed into an [`Input`] immediately. The returned Python awaitable resolves
/// to `None` once `dynamo_llm::entrypoint::input::run_input` completes.
///
/// # Errors
///
/// Raises `Exception` (via [`to_pyerr`]) synchronously if `input` does not parse. When the
/// awaitable is awaited, it raises `Exception` if running the input fails.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        // The future owns `distributed_runtime`, so its inner handle is moved, not cloned.
        dynamo_llm::entrypoint::input::run_input(
            distributed_runtime.inner,
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Convert any displayable error into a generic Python `Exception` carrying its
/// `Display` message.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    PyException::new_err(err.to_string())
}