// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
5
use std::future::Future;
6
use std::path::PathBuf;
7
8
use std::pin::Pin;
use std::sync::Arc;
9
10

use pyo3::{exceptions::PyException, prelude::*};
11
use pyo3_async_runtimes::TaskLocals;
12

13
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
14
use dynamo_llm::entrypoint::ChatEngineFactoryCallback;
15
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
16
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
17
use dynamo_llm::entrypoint::input::Input;
18
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
19
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
20
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
21
use dynamo_llm::mocker::make_mocker_engine;
22
23
use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
24
use dynamo_mocker::common::protocols::MockEngineArgs;
25
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
26
use dynamo_runtime::protocols::EndpointId;
27

28
use super::local_model::ModelRuntimeConfig;
29
use super::model_card::ModelDeploymentCard;
30
use crate::RouterMode;
31
use crate::engine::PythonAsyncEngine;
32

33
34
35
36
37
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
38
39
    Dynamic = 2,
    Mocker = 3,
40
41
}

42
#[pyclass]
43
#[derive(Default, Clone, Debug)]
44
45
46
47
pub struct KvRouterConfig {
    inner: RsKvRouterConfig,
}

48
49
impl KvRouterConfig {
    pub fn inner(&self) -> RsKvRouterConfig {
50
        self.inner.clone()
51
52
53
    }
}

54
55
56
#[pymethods]
impl KvRouterConfig {
    #[new]
57
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(2.0), router_event_threads=4, router_enable_cache_control=false, router_queue_policy="fcfs", remote_indexer_component=None))]
58
    #[allow(clippy::too_many_arguments)]
59
60
61
62
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
63
        durable_kv_events: bool,
64
        router_replica_sync: bool,
65
        router_track_active_blocks: bool,
66
        router_track_output_blocks: bool,
67
        router_assume_kv_reuse: bool,
68
69
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
70
71
72
        router_ttl_secs: f64,
        router_max_tree_size: usize,
        router_prune_target_ratio: f64,
73
        router_queue_threshold: Option<f64>,
Yan Ru Pei's avatar
Yan Ru Pei committed
74
        router_event_threads: u32,
75
        router_enable_cache_control: bool,
76
        router_queue_policy: &str,
77
        remote_indexer_component: Option<String>,
78
    ) -> Self {
79
80
81
82
83
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
84
                durable_kv_events,
85
                router_replica_sync,
86
                router_track_active_blocks,
87
                router_track_output_blocks,
88
                router_assume_kv_reuse,
89
90
                router_snapshot_threshold,
                router_reset_states,
91
92
93
                router_ttl_secs,
                router_max_tree_size,
                router_prune_target_ratio,
94
                router_queue_threshold,
Yan Ru Pei's avatar
Yan Ru Pei committed
95
                router_event_threads,
96
                router_enable_cache_control,
97
                skip_initial_worker_wait: false,
98
99
100
                router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
                    panic!("invalid router_queue_policy: {router_queue_policy:?}")
                }),
101
                remote_indexer_component,
102
103
104
105
106
107
108
109
            },
        }
    }
}

#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
110
111
112
113
114
115
    #[pyo3(get, set)]
    pub router_mode: RouterMode,

    #[pyo3(get, set)]
    pub kv_router_config: KvRouterConfig,

116
117
118
119
    /// Threshold for active decode blocks utilization (0.0-1.0)
    active_decode_blocks_threshold: Option<f64>,
    /// Threshold for active prefill tokens utilization (literal token count)
    active_prefill_tokens_threshold: Option<u64>,
120
121
    /// Threshold for active prefill tokens as fraction of max_num_batched_tokens
    active_prefill_tokens_threshold_frac: Option<f64>,
122
    enforce_disagg: bool,
123
124
125
126
127
}

#[pymethods]
impl RouterConfig {
    #[new]
128
    #[pyo3(signature = (mode, config=None, active_decode_blocks_threshold=None, active_prefill_tokens_threshold=None, active_prefill_tokens_threshold_frac=None, enforce_disagg=false))]
129
130
131
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
132
133
        active_decode_blocks_threshold: Option<f64>,
        active_prefill_tokens_threshold: Option<u64>,
134
        active_prefill_tokens_threshold_frac: Option<f64>,
135
        enforce_disagg: bool,
136
    ) -> Self {
137
138
139
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
140
141
            active_decode_blocks_threshold,
            active_prefill_tokens_threshold,
142
            active_prefill_tokens_threshold_frac,
143
            enforce_disagg,
144
145
146
147
148
149
150
151
152
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
153
154
155
156
157
            load_threshold_config: RsLoadThresholdConfig {
                active_decode_blocks_threshold: rc.active_decode_blocks_threshold,
                active_prefill_tokens_threshold: rc.active_prefill_tokens_threshold,
                active_prefill_tokens_threshold_frac: rc.active_prefill_tokens_threshold_frac,
            },
158
            enforce_disagg: rc.enforce_disagg,
159
160
161
162
        }
    }
}

163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
/// A Python `chat_engine_factory` callable paired with the `TaskLocals`
/// captured when it was registered.
///
/// The coroutine the callable returns can later be driven from Rust. Both
/// fields sit behind `Arc`, so cloning the wrapper only bumps reference counts.
#[derive(Clone)]
struct PyEngineFactory {
    callback: Arc<PyObject>,
    locals: Arc<TaskLocals>,
}

impl std::fmt::Debug for PyEngineFactory {
    /// Prints a placeholder for the Python callable instead of inspecting it.
    /// `locals` is left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PyEngineFactory");
        out.field("callback", &"<PyObject>");
        out.finish()
    }
}

178
179
180
181
182
183
184
185
186
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
187
    router_config: Option<RouterConfig>,
188
    kv_cache_block_size: Option<u32>,
189
    http_host: Option<String>,
Graham King's avatar
Graham King committed
190
    http_port: u16,
191
    http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
192
193
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
194
    extra_engine_args: Option<PathBuf>,
195
    runtime_config: Option<ModelRuntimeConfig>,
196
    namespace: Option<String>,
197
    namespace_prefix: Option<String>,
198
    is_prefill: bool,
199
    migration_limit: u32,
200
    chat_engine_factory: Option<PyEngineFactory>,
201
202
203
204
205
206
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
207
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
208
    pub fn new(
209
        py: Python<'_>,
210
211
212
213
214
215
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
216
        router_config: Option<RouterConfig>,
217
        kv_cache_block_size: Option<u32>,
218
        http_host: Option<String>,
219
        http_port: Option<u16>,
220
        http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
221
222
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
223
        extra_engine_args: Option<PathBuf>,
224
        runtime_config: Option<ModelRuntimeConfig>,
225
        namespace: Option<String>,
226
        namespace_prefix: Option<String>,
227
        is_prefill: bool,
228
        migration_limit: u32,
229
        chat_engine_factory: Option<PyObject>,
230
    ) -> PyResult<Self> {
231
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
232
233
234
235
236
237
238
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
239

240
241
        // Capture TaskLocals at registration time for the chat engine factory callback
        let chat_engine_factory = chat_engine_factory
242
243
244
            .map(|callback| {
                let locals = pyo3_async_runtimes::tokio::get_current_locals(py).map_err(|e| {
                    pyo3::exceptions::PyRuntimeError::new_err(format!(
245
                        "Failed to get TaskLocals for chat_engine_factory: {}",
246
247
248
249
250
251
252
253
254
255
                        e
                    ))
                })?;
                Ok::<_, PyErr>(PyEngineFactory {
                    callback: Arc::new(callback),
                    locals: Arc::new(locals),
                })
            })
            .transpose()?;

256
257
258
259
260
261
262
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
263
            router_config,
264
            kv_cache_block_size,
265
            http_host,
Graham King's avatar
Graham King committed
266
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
267
            http_metrics_port,
Graham King's avatar
Graham King committed
268
269
            tls_cert_path,
            tls_key_path,
270
            extra_engine_args,
271
            runtime_config,
272
            namespace,
273
            namespace_prefix,
274
            is_prefill,
275
            migration_limit,
276
            chat_engine_factory,
277
278
279
280
281
282
283
284
285
286
        })
    }
}

/// Opaque handle to a fully built engine configuration.
///
/// `make_engine` returns it to Python, and `run_input` consumes it.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    inner: RsEngineConfig,
}

287
288
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
289
290
291
292
293
294
295
296
297
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
298
299
300
301
302
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
303
        .endpoint_id(args.endpoint_id.clone())
304
        .context_length(args.context_length)
305
        .request_template(args.template_file.clone())
306
        .kv_cache_block_size(args.kv_cache_block_size)
307
        .router_config(args.router_config.clone().map(|rc| rc.into()))
308
        .migration_limit(Some(args.migration_limit))
309
        .http_host(args.http_host.clone())
310
        .http_port(args.http_port)
311
        .http_metrics_port(args.http_metrics_port)
Graham King's avatar
Graham King committed
312
313
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
314
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
315
        .extra_engine_args(args.extra_engine_args.clone())
316
        .runtime_config(args.runtime_config.clone().unwrap_or_default().inner)
317
318
        .namespace(args.namespace.clone())
        .namespace_prefix(args.namespace_prefix.clone());
319
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
320
321
322
323
        if let Some(model_path) = args.model_path.clone() {
            let local_path = if model_path.exists() {
                model_path
            } else {
324
325
326
                // Mocker only needs tokenizer, not weights
                let ignore_weights = matches!(args.engine_type, EngineType::Mocker);
                LocalModel::fetch(&model_path.display().to_string(), ignore_weights)
327
328
329
330
331
332
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }

333
        let local_model = builder.build().await.map_err(to_pyerr)?;
334
        let inner = select_engine(distributed_runtime, args, local_model)
335
336
337
338
339
340
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

341
342
/// Convert a PyEngineFactory to a Rust ChatEngineFactoryCallback
fn py_engine_factory_to_callback(factory: PyEngineFactory) -> ChatEngineFactoryCallback {
343
344
345
346
    let callback = factory.callback;
    let locals = factory.locals;

    Arc::new(
347
348
349
        move |instance_id: RsModelCardInstanceId,
              card: RsModelDeploymentCard|
              -> Pin<
350
351
352
353
354
355
356
357
            Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
        > {
            let callback = callback.clone();
            let locals = locals.clone();

            Box::pin(async move {
                // Acquire GIL to call Python callback and convert coroutine to future
                let py_future = Python::with_gil(|py| {
358
359
360
361
                    let py_instance_id =
                        Py::new(py, crate::ModelCardInstanceId { inner: instance_id }).map_err(
                            |e| anyhow::anyhow!("Failed to create Python ModelCardInstanceId: {e}"),
                        )?;
362
363
364
                    // Create Python ModelDeploymentCard wrapper
                    let py_card = ModelDeploymentCard { inner: card };
                    let py_card_obj = Py::new(py, py_card)
365
                        .map_err(|e| anyhow::anyhow!("Failed to create Python MDC: {e}"))?;
366
367
368

                    // Call Python async function to get a coroutine
                    let coroutine = callback
369
370
                        .call1(py, (py_instance_id, py_card_obj))
                        .map_err(|e| anyhow::anyhow!("Failed to call chat_engine_factory: {e}"))?;
371
372
373

                    // Use the TaskLocals captured at registration time
                    pyo3_async_runtimes::into_future_with_locals(&locals, coroutine.into_bound(py))
374
                        .map_err(|e| anyhow::anyhow!("Failed to convert coroutine to future: {e}"))
375
376
377
378
379
                })?;

                // Await the Python coroutine (GIL is released during await)
                let py_result = py_future
                    .await
380
                    .map_err(|e| anyhow::anyhow!("chat_engine_factory callback failed: {}", e))?;
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395

                // Extract PythonAsyncEngine from the Python result and wrap in Arc
                let engine: OpenAIChatCompletionsStreamingEngine = Python::with_gil(|py| {
                    let engine: PythonAsyncEngine = py_result.extract(py).map_err(|e| {
                        anyhow::anyhow!("Failed to extract PythonAsyncEngine: {}", e)
                    })?;
                    Ok::<_, anyhow::Error>(Arc::new(engine))
                })?;

                Ok(engine)
            })
        },
    )
}

396
397
async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
398
    args: EntrypointArgs,
399
400
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
401
    let inner = match args.engine_type {
402
403
        EngineType::Echo => {
            // There is no validation for the echo engine
404
            RsEngineConfig::InProcessText {
405
                model: Box::new(local_model),
406
                engine: dynamo_llm::engines::make_echo_engine(),
407
408
            }
        }
409
        EngineType::Dynamic => {
410
411
            //  Convert Python chat engine factory to Rust callback
            let chat_engine_factory = args.chat_engine_factory.map(py_engine_factory_to_callback);
412
413
            RsEngineConfig::Dynamic {
                model: Box::new(local_model),
414
                chat_engine_factory,
415
416
            }
        }
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

435
            let engine =
436
                make_mocker_engine(distributed_runtime.inner, endpoint, mocker_args).await?;
437

438
            RsEngineConfig::InProcessTokens {
439
440
                engine,
                model: Box::new(local_model),
441
                is_prefill: args.is_prefill,
442
443
            }
        }
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
    };

    Ok(inner)
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        dynamo_llm::entrypoint::input::run_input(
460
            distributed_runtime.inner.clone(),
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Convert any displayable error into a generic Python `Exception`.
///
/// The error's `Display` text becomes the exception message.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    let message = err.to_string();
    PyException::new_err(message)
}