"vscode:/vscode.git/clone" did not exist on "b9793e6a8c30bc42f35d2a1eac919284aea27f76"
entrypoint.rs 21.7 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
5
use std::future::Future;
6
use std::path::PathBuf;
7
8
use std::pin::Pin;
use std::sync::Arc;
9

10
use pyo3::{exceptions::PyException, exceptions::PyValueError, prelude::*};
11
use pyo3_async_runtimes::TaskLocals;
12

13
use dynamo_kv_router::config::KvRouterConfig as RsKvRouterConfig;
14
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
15
use dynamo_llm::entrypoint::ChatEngineFactoryCallback;
16
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
17
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
18
use dynamo_llm::entrypoint::input::Input;
Graham King's avatar
Graham King committed
19
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
20
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
21
use dynamo_llm::mocker::make_mocker_engine;
22
23
use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
24
25
26
use dynamo_mocker::common::perf_model::PerfModel;

use super::aic_callback::create_aic_callback;
27
28
use super::replay::MockEngineArgs as PyMockEngineArgs;
use dynamo_mocker::common::protocols::MockEngineArgs as RsMockEngineArgs;
29
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
30
use dynamo_runtime::protocols::EndpointId;
31

32
use super::local_model::ModelRuntimeConfig;
33
use super::model_card::ModelDeploymentCard;
34
use crate::RouterMode;
35
use crate::engine::PythonAsyncEngine;
36

37
38
39
40
41
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
42
43
    Dynamic = 2,
    Mocker = 3,
44
45
}

46
#[pyclass]
47
#[derive(Default, Clone, Debug)]
48
49
50
51
pub struct KvRouterConfig {
    inner: RsKvRouterConfig,
}

52
53
impl KvRouterConfig {
    pub fn inner(&self) -> RsKvRouterConfig {
54
        self.inner.clone()
55
56
57
    }
}

58
59
60
#[pymethods]
impl KvRouterConfig {
    #[new]
61
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_track_prefill_tokens=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=Some(4.0), router_event_threads=4, router_queue_policy="fcfs", remote_indexer_component=None))]
62
    #[allow(clippy::too_many_arguments)]
63
64
65
66
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
67
        durable_kv_events: bool,
68
        router_replica_sync: bool,
69
        router_track_active_blocks: bool,
70
        router_track_output_blocks: bool,
71
        router_assume_kv_reuse: bool,
72
        router_track_prefill_tokens: bool,
73
74
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
75
76
77
        router_ttl_secs: f64,
        router_max_tree_size: usize,
        router_prune_target_ratio: f64,
78
        router_queue_threshold: Option<f64>,
Yan Ru Pei's avatar
Yan Ru Pei committed
79
        router_event_threads: u32,
80
        router_queue_policy: &str,
81
        remote_indexer_component: Option<String>,
82
    ) -> Self {
83
84
85
86
87
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
88
                durable_kv_events,
89
                router_replica_sync,
90
                router_track_active_blocks,
91
                router_track_output_blocks,
92
                router_assume_kv_reuse,
93
                router_track_prefill_tokens,
94
95
                router_snapshot_threshold,
                router_reset_states,
96
97
98
                router_ttl_secs,
                router_max_tree_size,
                router_prune_target_ratio,
99
                router_queue_threshold,
Yan Ru Pei's avatar
Yan Ru Pei committed
100
                router_event_threads,
101
                skip_initial_worker_wait: false,
102
103
104
                router_queue_policy: router_queue_policy.parse().unwrap_or_else(|_| {
                    panic!("invalid router_queue_policy: {router_queue_policy:?}")
                }),
105
                remote_indexer_component,
106
107
108
            },
        }
    }
109
110
111
112
113
114
115

    #[staticmethod]
    fn from_json(config_json: &str) -> PyResult<Self> {
        serde_json::from_str::<RsKvRouterConfig>(config_json)
            .map(|inner| KvRouterConfig { inner })
            .map_err(|e| PyException::new_err(format!("Failed to parse KvRouterConfig JSON: {e}")))
    }
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154

    fn dump_json(&self) -> PyResult<String> {
        serde_json::to_string(&self.inner)
            .map_err(|e| PyException::new_err(format!("Failed to serialize KvRouterConfig: {e}")))
    }

    fn copy(&self) -> Self {
        self.clone()
    }

    #[getter]
    fn overlap_score_weight(&self) -> f64 {
        self.inner.overlap_score_weight
    }

    #[setter]
    fn set_overlap_score_weight(&mut self, value: f64) -> PyResult<()> {
        if value < 0.0 {
            return Err(PyValueError::new_err(
                "overlap_score_weight must be non-negative",
            ));
        }
        self.inner.overlap_score_weight = value;
        Ok(())
    }

    #[pyo3(signature = (overlap_score_weight=None))]
    fn with_overrides(&self, overlap_score_weight: Option<f64>) -> PyResult<Self> {
        let mut inner = self.inner.clone();
        if let Some(weight) = overlap_score_weight {
            if weight < 0.0 {
                return Err(PyValueError::new_err(
                    "overlap_score_weight must be non-negative",
                ));
            }
            inner.overlap_score_weight = weight;
        }
        Ok(Self { inner })
    }
155
156
157
158
159
}

#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
160
161
162
163
164
165
    #[pyo3(get, set)]
    pub router_mode: RouterMode,

    #[pyo3(get, set)]
    pub kv_router_config: KvRouterConfig,

166
167
168
169
    /// Threshold for active decode blocks utilization (0.0-1.0)
    active_decode_blocks_threshold: Option<f64>,
    /// Threshold for active prefill tokens utilization (literal token count)
    active_prefill_tokens_threshold: Option<u64>,
170
171
    /// Threshold for active prefill tokens as fraction of max_num_batched_tokens
    active_prefill_tokens_threshold_frac: Option<f64>,
172
    enforce_disagg: bool,
173
174
175
176
177
}

#[pymethods]
impl RouterConfig {
    #[new]
178
    #[pyo3(signature = (mode, config=None, active_decode_blocks_threshold=None, active_prefill_tokens_threshold=None, active_prefill_tokens_threshold_frac=None, enforce_disagg=false))]
179
180
181
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
182
183
        active_decode_blocks_threshold: Option<f64>,
        active_prefill_tokens_threshold: Option<u64>,
184
        active_prefill_tokens_threshold_frac: Option<f64>,
185
        enforce_disagg: bool,
186
    ) -> Self {
187
188
189
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
190
191
            active_decode_blocks_threshold,
            active_prefill_tokens_threshold,
192
            active_prefill_tokens_threshold_frac,
193
            enforce_disagg,
194
195
196
197
198
199
200
201
202
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
203
204
205
206
207
            load_threshold_config: RsLoadThresholdConfig {
                active_decode_blocks_threshold: rc.active_decode_blocks_threshold,
                active_prefill_tokens_threshold: rc.active_prefill_tokens_threshold,
                active_prefill_tokens_threshold_frac: rc.active_prefill_tokens_threshold_frac,
            },
208
            enforce_disagg: rc.enforce_disagg,
209
210
211
212
        }
    }
}

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
/// Wrapper to hold Python callback and its TaskLocals for async execution
#[derive(Clone)]
struct PyEngineFactory {
    callback: Arc<PyObject>,
    locals: Arc<TaskLocals>,
}

impl std::fmt::Debug for PyEngineFactory {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PyEngineFactory")
            .field("callback", &"<PyObject>")
            .finish()
    }
}

228
229
230
231
232
233
234
235
236
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
237
    router_config: Option<RouterConfig>,
238
    kv_cache_block_size: Option<u32>,
239
    http_host: Option<String>,
Graham King's avatar
Graham King committed
240
    http_port: u16,
241
    http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
242
243
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
244
    extra_engine_args: Option<PathBuf>,
245
    mocker_engine_args: Option<PyMockEngineArgs>,
246
    runtime_config: Option<ModelRuntimeConfig>,
247
    namespace: Option<String>,
248
    namespace_prefix: Option<String>,
249
    is_prefill: bool,
250
    migration_limit: u32,
251
    chat_engine_factory: Option<PyEngineFactory>,
252
253
254
255
256
257
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
258
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
259
    pub fn new(
260
        py: Python<'_>,
261
262
263
264
265
266
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
267
        router_config: Option<RouterConfig>,
268
        kv_cache_block_size: Option<u32>,
269
        http_host: Option<String>,
270
        http_port: Option<u16>,
271
        http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
272
273
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
274
        extra_engine_args: Option<PathBuf>,
275
        mocker_engine_args: Option<PyMockEngineArgs>,
276
        runtime_config: Option<ModelRuntimeConfig>,
277
        namespace: Option<String>,
278
        namespace_prefix: Option<String>,
279
        is_prefill: bool,
280
        migration_limit: u32,
281
        chat_engine_factory: Option<PyObject>,
282
    ) -> PyResult<Self> {
283
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
284
285
286
287
288
289
290
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
291

292
293
        // Capture TaskLocals at registration time for the chat engine factory callback
        let chat_engine_factory = chat_engine_factory
294
295
296
            .map(|callback| {
                let locals = pyo3_async_runtimes::tokio::get_current_locals(py).map_err(|e| {
                    pyo3::exceptions::PyRuntimeError::new_err(format!(
297
                        "Failed to get TaskLocals for chat_engine_factory: {}",
298
299
300
301
302
303
304
305
306
307
                        e
                    ))
                })?;
                Ok::<_, PyErr>(PyEngineFactory {
                    callback: Arc::new(callback),
                    locals: Arc::new(locals),
                })
            })
            .transpose()?;

308
309
310
311
312
313
314
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
315
            router_config,
316
            kv_cache_block_size,
317
            http_host,
Graham King's avatar
Graham King committed
318
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
319
            http_metrics_port,
Graham King's avatar
Graham King committed
320
321
            tls_cert_path,
            tls_key_path,
322
            extra_engine_args,
323
            mocker_engine_args,
324
            runtime_config,
325
            namespace,
326
            namespace_prefix,
327
            is_prefill,
328
            migration_limit,
329
            chat_engine_factory,
330
331
332
333
334
335
336
337
338
339
        })
    }
}

/// Opaque Python handle for a fully built engine configuration.
///
/// `make_engine` produces it (through its awaitable) and `run_input` consumes it.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// The Rust-side engine configuration wrapped by this handle.
    inner: RsEngineConfig,
}

340
341
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
342
343
344
345
346
347
348
349
350
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
351
352
353
354
355
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
356
        .endpoint_id(args.endpoint_id.clone())
357
        .context_length(args.context_length)
358
        .request_template(args.template_file.clone())
359
        .kv_cache_block_size(args.kv_cache_block_size)
360
        .router_config(args.router_config.clone().map(|rc| rc.into()))
361
        .migration_limit(Some(args.migration_limit))
362
        .http_host(args.http_host.clone())
363
        .http_port(args.http_port)
364
        .http_metrics_port(args.http_metrics_port)
Graham King's avatar
Graham King committed
365
366
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
367
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
368
        .extra_engine_args(args.extra_engine_args.clone())
369
        .runtime_config(args.runtime_config.clone().unwrap_or_default().inner)
370
371
        .namespace(args.namespace.clone())
        .namespace_prefix(args.namespace_prefix.clone());
372
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
373
374
375
376
        if let Some(model_path) = args.model_path.clone() {
            let local_path = if model_path.exists() {
                model_path
            } else {
377
378
                // Mocker only needs tokenizer, not weights
                let ignore_weights = matches!(args.engine_type, EngineType::Mocker);
379
380
381
382
383
                // Preserve the original HF model ID as source_path so the
                // frontend can resolve model metadata even when the served
                // model name differs (e.g., --model-name model-1 --model-path
                // Qwen/Qwen3-0.6B).
                builder.source_path(model_path.clone());
384
                LocalModel::fetch(&model_path.display().to_string(), ignore_weights)
385
386
387
388
389
390
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }

391
        let local_model = builder.build().await.map_err(to_pyerr)?;
392
        let inner = select_engine(distributed_runtime, args, local_model)
393
394
395
396
397
398
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

399
400
/// Convert a PyEngineFactory to a Rust ChatEngineFactoryCallback
fn py_engine_factory_to_callback(factory: PyEngineFactory) -> ChatEngineFactoryCallback {
401
402
403
404
    let callback = factory.callback;
    let locals = factory.locals;

    Arc::new(
405
406
407
        move |instance_id: RsModelCardInstanceId,
              card: RsModelDeploymentCard|
              -> Pin<
408
409
410
411
412
413
414
415
            Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
        > {
            let callback = callback.clone();
            let locals = locals.clone();

            Box::pin(async move {
                // Acquire GIL to call Python callback and convert coroutine to future
                let py_future = Python::with_gil(|py| {
416
417
418
419
                    let py_instance_id =
                        Py::new(py, crate::ModelCardInstanceId { inner: instance_id }).map_err(
                            |e| anyhow::anyhow!("Failed to create Python ModelCardInstanceId: {e}"),
                        )?;
420
421
422
                    // Create Python ModelDeploymentCard wrapper
                    let py_card = ModelDeploymentCard { inner: card };
                    let py_card_obj = Py::new(py, py_card)
423
                        .map_err(|e| anyhow::anyhow!("Failed to create Python MDC: {e}"))?;
424
425
426

                    // Call Python async function to get a coroutine
                    let coroutine = callback
427
428
                        .call1(py, (py_instance_id, py_card_obj))
                        .map_err(|e| anyhow::anyhow!("Failed to call chat_engine_factory: {e}"))?;
429
430
431

                    // Use the TaskLocals captured at registration time
                    pyo3_async_runtimes::into_future_with_locals(&locals, coroutine.into_bound(py))
432
                        .map_err(|e| anyhow::anyhow!("Failed to convert coroutine to future: {e}"))
433
434
435
436
437
                })?;

                // Await the Python coroutine (GIL is released during await)
                let py_result = py_future
                    .await
438
                    .map_err(|e| anyhow::anyhow!("chat_engine_factory callback failed: {}", e))?;
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453

                // Extract PythonAsyncEngine from the Python result and wrap in Arc
                let engine: OpenAIChatCompletionsStreamingEngine = Python::with_gil(|py| {
                    let engine: PythonAsyncEngine = py_result.extract(py).map_err(|e| {
                        anyhow::anyhow!("Failed to extract PythonAsyncEngine: {}", e)
                    })?;
                    Ok::<_, anyhow::Error>(Arc::new(engine))
                })?;

                Ok(engine)
            })
        },
    )
}

454
455
async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
456
    args: EntrypointArgs,
457
458
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
459
    let inner = match args.engine_type {
460
461
        EngineType::Echo => {
            // There is no validation for the echo engine
462
            RsEngineConfig::InProcessText {
463
                model: Box::new(local_model),
464
                engine: dynamo_llm::engines::make_echo_engine(),
465
466
            }
        }
467
        EngineType::Dynamic => {
468
469
            //  Convert Python chat engine factory to Rust callback
            let chat_engine_factory = args.chat_engine_factory.map(py_engine_factory_to_callback);
470
471
            RsEngineConfig::Dynamic {
                model: Box::new(local_model),
472
                chat_engine_factory,
473
474
            }
        }
475
        EngineType::Mocker => {
476
477
478
479
            let mut mocker_args = if let Some(mocker_engine_args) = args.mocker_engine_args {
                mocker_engine_args.inner()
            } else if let Some(extra_args_path) = args.extra_engine_args {
                RsMockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
480
481
482
483
484
485
486
487
488
489
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
490
                RsMockEngineArgs::default()
491
492
            };

493
494
495
496
497
498
499
500
501
502
            // If aic_backend is set, create Python AIC callback and override perf_model
            if let Some(ref backend_name) = mocker_args.aic_backend {
                let backend = backend_name.clone();
                let system = mocker_args.aic_system.as_deref().unwrap_or("h200_sxm");
                let model_name = mocker_args
                    .aic_model_path
                    .as_deref()
                    .unwrap_or_else(|| local_model.card().source_path());
                let backend_version = mocker_args.aic_backend_version.as_deref();
                let tp_size = mocker_args.aic_tp_size.unwrap_or(1);
503
504
505
                let moe_tp_size = mocker_args.aic_moe_tp_size;
                let moe_ep_size = mocker_args.aic_moe_ep_size;
                let attention_dp_size = mocker_args.aic_attention_dp_size;
506
                match Python::with_gil(|py| {
507
508
509
510
511
512
513
514
515
516
517
                    create_aic_callback(
                        py,
                        &backend,
                        system,
                        model_name,
                        tp_size,
                        backend_version,
                        moe_tp_size,
                        moe_ep_size,
                        attention_dp_size,
                    )
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
                }) {
                    Ok(callback) => {
                        tracing::info!(
                            "AIC perf model: backend={}, gpu={}, model={}, version={:?}",
                            backend,
                            system,
                            model_name,
                            backend_version
                        );
                        mocker_args.perf_model = Arc::new(PerfModel::from_aic_callback(callback));
                    }
                    Err(e) => {
                        return Err(anyhow::anyhow!(
                            "Failed to create AIC callback (--aic-perf-model was requested): {}",
                            e
                        ));
                    }
                }
            }

538
539
            let endpoint = local_model.endpoint_id().clone();

540
            let engine =
541
                make_mocker_engine(distributed_runtime.inner, endpoint, mocker_args).await?;
542

543
            RsEngineConfig::InProcessTokens {
544
545
                engine,
                model: Box::new(local_model),
546
                is_prefill: args.is_prefill,
547
548
            }
        }
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
    };

    Ok(inner)
}

/// Runs the input mode parsed from `input` against a previously built `EngineConfig`.
///
/// Returns a Python awaitable that completes when `dynamo_llm::entrypoint::input::run_input`
/// finishes.
///
/// # Errors
/// - Raises `Exception` immediately if `input` cannot be parsed into an `Input`.
/// - The awaitable raises `Exception` if running the input fails.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        dynamo_llm::entrypoint::input::run_input(
            distributed_runtime.inner.clone(),
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Converts any displayable error into a generic Python `Exception` carrying its message.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    PyException::new_err(err.to_string())
}