"...controller/dynamocomponentdeployment_controller_test.go" did not exist on "33d9ae7811e70db24ea15a39f475307587265a48"
entrypoint.rs 16.5 KB
Newer Older
1
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
5
use std::future::Future;
6
use std::path::PathBuf;
7
8
use std::pin::Pin;
use std::sync::Arc;
9
10

use pyo3::{exceptions::PyException, prelude::*};
11
use pyo3_async_runtimes::TaskLocals;
12

13
use dynamo_llm::discovery::LoadThresholdConfig as RsLoadThresholdConfig;
14
use dynamo_llm::entrypoint::ChatEngineFactoryCallback;
15
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
16
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
17
use dynamo_llm::entrypoint::input::Input;
18
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
19
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
20
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
21
use dynamo_llm::mocker::protocols::MockEngineArgs;
22
23
use dynamo_llm::model_card::ModelDeploymentCard as RsModelDeploymentCard;
use dynamo_llm::types::openai::chat_completions::OpenAIChatCompletionsStreamingEngine;
24
use dynamo_runtime::discovery::ModelCardInstanceId as RsModelCardInstanceId;
25
use dynamo_runtime::protocols::EndpointId;
26

27
use super::model_card::ModelDeploymentCard;
28
use crate::RouterMode;
29
use crate::engine::PythonAsyncEngine;
30

31
32
33
34
35
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
36
37
    Dynamic = 2,
    Mocker = 3,
38
39
}

40
41
42
43
44
45
/// Python-facing wrapper around the Rust [`RsKvRouterConfig`].
///
/// The wrapper is `Copy` because the wrapped config is copied out by value
/// (see `inner()` and the `RouterConfig` conversion).
#[pyclass]
#[derive(Clone, Copy, Debug, Default)]
pub struct KvRouterConfig {
    /// The wrapped Rust KV router configuration.
    inner: RsKvRouterConfig,
}

46
47
48
49
50
51
impl KvRouterConfig {
    /// Returns a by-value copy of the wrapped Rust KV router configuration.
    pub fn inner(&self) -> RsKvRouterConfig {
        let Self { inner } = *self;
        inner
    }
}

52
53
54
#[pymethods]
impl KvRouterConfig {
    #[new]
55
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, durable_kv_events=false, router_replica_sync=false, router_track_active_blocks=true, router_track_output_blocks=false, router_assume_kv_reuse=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1048576, router_prune_target_ratio=0.8, router_queue_threshold=None, router_event_threads=1))]
56
    #[allow(clippy::too_many_arguments)]
57
58
59
60
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
61
        durable_kv_events: bool,
62
        router_replica_sync: bool,
63
        router_track_active_blocks: bool,
64
        router_track_output_blocks: bool,
65
        router_assume_kv_reuse: bool,
66
67
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
68
69
70
        router_ttl_secs: f64,
        router_max_tree_size: usize,
        router_prune_target_ratio: f64,
71
        router_queue_threshold: Option<f64>,
Yan Ru Pei's avatar
Yan Ru Pei committed
72
        router_event_threads: u32,
73
    ) -> Self {
74
75
76
77
78
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
79
                durable_kv_events,
80
                router_replica_sync,
81
                router_track_active_blocks,
82
                router_track_output_blocks,
83
                router_assume_kv_reuse,
84
85
                router_snapshot_threshold,
                router_reset_states,
86
87
88
                router_ttl_secs,
                router_max_tree_size,
                router_prune_target_ratio,
89
                router_queue_threshold,
Yan Ru Pei's avatar
Yan Ru Pei committed
90
                router_event_threads,
91
92
93
94
95
96
97
98
            },
        }
    }
}

#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
99
100
101
102
103
104
    #[pyo3(get, set)]
    pub router_mode: RouterMode,

    #[pyo3(get, set)]
    pub kv_router_config: KvRouterConfig,

105
106
107
108
    /// Threshold for active decode blocks utilization (0.0-1.0)
    active_decode_blocks_threshold: Option<f64>,
    /// Threshold for active prefill tokens utilization (literal token count)
    active_prefill_tokens_threshold: Option<u64>,
109
110
    /// Threshold for active prefill tokens as fraction of max_num_batched_tokens
    active_prefill_tokens_threshold_frac: Option<f64>,
111
    enforce_disagg: bool,
112
113
114
115
116
}

#[pymethods]
impl RouterConfig {
    #[new]
117
    #[pyo3(signature = (mode, config=None, active_decode_blocks_threshold=None, active_prefill_tokens_threshold=None, active_prefill_tokens_threshold_frac=None, enforce_disagg=false))]
118
119
120
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
121
122
        active_decode_blocks_threshold: Option<f64>,
        active_prefill_tokens_threshold: Option<u64>,
123
        active_prefill_tokens_threshold_frac: Option<f64>,
124
        enforce_disagg: bool,
125
    ) -> Self {
126
127
128
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
129
130
            active_decode_blocks_threshold,
            active_prefill_tokens_threshold,
131
            active_prefill_tokens_threshold_frac,
132
            enforce_disagg,
133
134
135
136
137
138
139
140
141
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
142
143
144
145
146
            load_threshold_config: RsLoadThresholdConfig {
                active_decode_blocks_threshold: rc.active_decode_blocks_threshold,
                active_prefill_tokens_threshold: rc.active_prefill_tokens_threshold,
                active_prefill_tokens_threshold_frac: rc.active_prefill_tokens_threshold_frac,
            },
147
            enforce_disagg: rc.enforce_disagg,
148
149
150
151
        }
    }
}

152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
/// Pairs a Python chat engine factory callback with the asyncio `TaskLocals`
/// captured when it was registered, so the coroutine it returns can later be
/// converted into a Rust future.
#[derive(Clone)]
struct PyEngineFactory {
    /// Python callable, invoked with `(ModelCardInstanceId, ModelDeploymentCard)`.
    callback: Arc<PyObject>,
    /// Locals obtained via `get_current_locals` in `EntrypointArgs::new`.
    locals: Arc<TaskLocals>,
}

impl std::fmt::Debug for PyEngineFactory {
    /// Formats the factory with a placeholder for the Python callback.
    ///
    /// The remaining field (`locals`) is not printed; `finish_non_exhaustive`
    /// marks the omission with `..` instead of hiding it silently.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PyEngineFactory")
            .field("callback", &"<PyObject>")
            .finish_non_exhaustive()
    }
}

167
168
169
170
171
172
173
174
175
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
176
    router_config: Option<RouterConfig>,
177
    kv_cache_block_size: Option<u32>,
178
    http_host: Option<String>,
Graham King's avatar
Graham King committed
179
    http_port: u16,
180
    http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
181
182
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
183
    extra_engine_args: Option<PathBuf>,
184
    namespace: Option<String>,
185
    is_prefill: bool,
186
    migration_limit: u32,
187
    chat_engine_factory: Option<PyEngineFactory>,
188
189
190
191
192
193
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
194
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
195
    pub fn new(
196
        py: Python<'_>,
197
198
199
200
201
202
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
203
        router_config: Option<RouterConfig>,
204
        kv_cache_block_size: Option<u32>,
205
        http_host: Option<String>,
206
        http_port: Option<u16>,
207
        http_metrics_port: Option<u16>,
Graham King's avatar
Graham King committed
208
209
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
210
        extra_engine_args: Option<PathBuf>,
211
        namespace: Option<String>,
212
        is_prefill: bool,
213
        migration_limit: u32,
214
        chat_engine_factory: Option<PyObject>,
215
    ) -> PyResult<Self> {
216
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
217
218
219
220
221
222
223
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
224

225
226
        // Capture TaskLocals at registration time for the chat engine factory callback
        let chat_engine_factory = chat_engine_factory
227
228
229
            .map(|callback| {
                let locals = pyo3_async_runtimes::tokio::get_current_locals(py).map_err(|e| {
                    pyo3::exceptions::PyRuntimeError::new_err(format!(
230
                        "Failed to get TaskLocals for chat_engine_factory: {}",
231
232
233
234
235
236
237
238
239
240
                        e
                    ))
                })?;
                Ok::<_, PyErr>(PyEngineFactory {
                    callback: Arc::new(callback),
                    locals: Arc::new(locals),
                })
            })
            .transpose()?;

241
242
243
244
245
246
247
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
248
            router_config,
249
            kv_cache_block_size,
250
            http_host,
Graham King's avatar
Graham King committed
251
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
252
            http_metrics_port,
Graham King's avatar
Graham King committed
253
254
            tls_cert_path,
            tls_key_path,
255
            extra_engine_args,
256
            namespace,
257
            is_prefill,
258
            migration_limit,
259
            chat_engine_factory,
260
261
262
263
264
265
266
267
268
269
        })
    }
}

/// Opaque Python handle around a built Rust engine configuration.
///
/// Produced by `make_engine` and consumed by `run_input`.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// The Rust engine configuration handed to `run_input`.
    inner: RsEngineConfig,
}

270
271
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
272
273
274
275
276
277
278
279
280
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
281
282
283
284
285
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
286
        .endpoint_id(args.endpoint_id.clone())
287
        .context_length(args.context_length)
288
        .request_template(args.template_file.clone())
289
        .kv_cache_block_size(args.kv_cache_block_size)
290
        .router_config(args.router_config.clone().map(|rc| rc.into()))
291
        .migration_limit(Some(args.migration_limit))
292
        .http_host(args.http_host.clone())
293
        .http_port(args.http_port)
294
        .http_metrics_port(args.http_metrics_port)
Graham King's avatar
Graham King committed
295
296
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
297
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
298
        .extra_engine_args(args.extra_engine_args.clone())
299
        .namespace(args.namespace.clone());
300
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
301
302
303
304
        if let Some(model_path) = args.model_path.clone() {
            let local_path = if model_path.exists() {
                model_path
            } else {
305
306
307
                // Mocker only needs tokenizer, not weights
                let ignore_weights = matches!(args.engine_type, EngineType::Mocker);
                LocalModel::fetch(&model_path.display().to_string(), ignore_weights)
308
309
310
311
312
313
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }

314
        let local_model = builder.build().await.map_err(to_pyerr)?;
315
        let inner = select_engine(distributed_runtime, args, local_model)
316
317
318
319
320
321
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

322
323
/// Convert a PyEngineFactory to a Rust ChatEngineFactoryCallback
fn py_engine_factory_to_callback(factory: PyEngineFactory) -> ChatEngineFactoryCallback {
324
325
326
327
    let callback = factory.callback;
    let locals = factory.locals;

    Arc::new(
328
329
330
        move |instance_id: RsModelCardInstanceId,
              card: RsModelDeploymentCard|
              -> Pin<
331
332
333
334
335
336
337
338
            Box<dyn Future<Output = anyhow::Result<OpenAIChatCompletionsStreamingEngine>> + Send>,
        > {
            let callback = callback.clone();
            let locals = locals.clone();

            Box::pin(async move {
                // Acquire GIL to call Python callback and convert coroutine to future
                let py_future = Python::with_gil(|py| {
339
340
341
342
                    let py_instance_id =
                        Py::new(py, crate::ModelCardInstanceId { inner: instance_id }).map_err(
                            |e| anyhow::anyhow!("Failed to create Python ModelCardInstanceId: {e}"),
                        )?;
343
344
345
                    // Create Python ModelDeploymentCard wrapper
                    let py_card = ModelDeploymentCard { inner: card };
                    let py_card_obj = Py::new(py, py_card)
346
                        .map_err(|e| anyhow::anyhow!("Failed to create Python MDC: {e}"))?;
347
348
349

                    // Call Python async function to get a coroutine
                    let coroutine = callback
350
351
                        .call1(py, (py_instance_id, py_card_obj))
                        .map_err(|e| anyhow::anyhow!("Failed to call chat_engine_factory: {e}"))?;
352
353
354

                    // Use the TaskLocals captured at registration time
                    pyo3_async_runtimes::into_future_with_locals(&locals, coroutine.into_bound(py))
355
                        .map_err(|e| anyhow::anyhow!("Failed to convert coroutine to future: {e}"))
356
357
358
359
360
                })?;

                // Await the Python coroutine (GIL is released during await)
                let py_result = py_future
                    .await
361
                    .map_err(|e| anyhow::anyhow!("chat_engine_factory callback failed: {}", e))?;
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376

                // Extract PythonAsyncEngine from the Python result and wrap in Arc
                let engine: OpenAIChatCompletionsStreamingEngine = Python::with_gil(|py| {
                    let engine: PythonAsyncEngine = py_result.extract(py).map_err(|e| {
                        anyhow::anyhow!("Failed to extract PythonAsyncEngine: {}", e)
                    })?;
                    Ok::<_, anyhow::Error>(Arc::new(engine))
                })?;

                Ok(engine)
            })
        },
    )
}

377
378
async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
379
    args: EntrypointArgs,
380
381
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
382
    let inner = match args.engine_type {
383
384
        EngineType::Echo => {
            // There is no validation for the echo engine
385
            RsEngineConfig::InProcessText {
386
                model: Box::new(local_model),
387
                engine: dynamo_llm::engines::make_echo_engine(),
388
389
            }
        }
390
        EngineType::Dynamic => {
391
392
            //  Convert Python chat engine factory to Rust callback
            let chat_engine_factory = args.chat_engine_factory.map(py_engine_factory_to_callback);
393
394
            RsEngineConfig::Dynamic {
                model: Box::new(local_model),
395
                chat_engine_factory,
396
397
            }
        }
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

416
            let engine = dynamo_llm::mocker::make_mocker_engine(
417
418
419
420
421
422
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

423
            RsEngineConfig::InProcessTokens {
424
425
                engine,
                model: Box::new(local_model),
426
                is_prefill: args.is_prefill,
427
428
            }
        }
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
    };

    Ok(inner)
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        dynamo_llm::entrypoint::input::run_input(
445
            distributed_runtime.inner.clone(),
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Convert any displayable error into a Python `Exception` whose message is
/// the error's `Display` text.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    PyException::new_err(err.to_string())
}