entrypoint.rs 10.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
10
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
11
use dynamo_llm::entrypoint::input::Input;
12
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
13
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
14
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
15
use dynamo_llm::mocker::protocols::MockEngineArgs;
16
use dynamo_runtime::protocols::EndpointId;
17

18
19
use crate::RouterMode;

20
21
22
23
24
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
25
26
    Dynamic = 2,
    Mocker = 3,
27
28
}

29
30
31
32
33
34
/// Python-visible handle for the KV-router tuning knobs.
///
/// A thin `Copy` wrapper around `RsKvRouterConfig`; call `inner()` to get the
/// native value back out.
///
/// NOTE(review): the derived `Default` delegates to `RsKvRouterConfig::default()`,
/// which is what `RouterConfig::new` uses when no config is passed. Confirm those
/// values match the keyword defaults in the Python constructor signature.
#[pyclass]
#[derive(Clone, Copy, Debug, Default)]
pub struct KvRouterConfig {
    inner: RsKvRouterConfig,
}

35
36
37
38
39
40
impl KvRouterConfig {
    /// Returns a copy of the wrapped native router configuration.
    pub fn inner(&self) -> RsKvRouterConfig {
        let Self { inner } = *self;
        inner
    }
}

41
42
43
#[pymethods]
impl KvRouterConfig {
    #[new]
44
45
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false, router_ttl_secs=120.0, router_max_tree_size=1024, router_prune_target_ratio=0.8))]
    #[allow(clippy::too_many_arguments)]
46
47
48
49
50
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
        router_replica_sync: bool,
51
        router_track_active_blocks: bool,
52
53
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
54
55
56
        router_ttl_secs: f64,
        router_max_tree_size: usize,
        router_prune_target_ratio: f64,
57
    ) -> Self {
58
59
60
61
62
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
63
                router_replica_sync,
64
                router_track_active_blocks,
65
66
                router_snapshot_threshold,
                router_reset_states,
67
68
69
                router_ttl_secs,
                router_max_tree_size,
                router_prune_target_ratio,
70
71
72
73
74
75
76
77
78
79
            },
        }
    }
}

/// Python-facing router settings; converted into `RsRouterConfig` via `From`.
#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
    /// Routing mode, converted with `Into` when building `RsRouterConfig`.
    router_mode: RouterMode,
    /// KV-router tuning; `KvRouterConfig::default()` when none was supplied.
    kv_router_config: KvRouterConfig,
    /// Forwarded unchanged to `RsRouterConfig::busy_threshold`.
    busy_threshold: Option<f64>,
    /// Forwarded unchanged to `RsRouterConfig::enforce_disagg`.
    enforce_disagg: bool,
}

#[pymethods]
impl RouterConfig {
    #[new]
87
    #[pyo3(signature = (mode, config=None, busy_threshold=None, enforce_disagg=false))]
88
89
90
91
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
92
        enforce_disagg: bool,
93
    ) -> Self {
94
95
96
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
97
            busy_threshold,
98
            enforce_disagg,
99
100
101
102
103
104
105
106
107
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
108
            busy_threshold: rc.busy_threshold,
109
            enforce_disagg: rc.enforce_disagg,
110
111
112
113
        }
    }
}

114
115
116
117
118
119
120
121
122
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
123
    router_config: Option<RouterConfig>,
124
    kv_cache_block_size: Option<u32>,
125
    http_host: Option<String>,
Graham King's avatar
Graham King committed
126
127
128
    http_port: u16,
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
129
    extra_engine_args: Option<PathBuf>,
130
    namespace: Option<String>,
131
132
    custom_backend_metrics_endpoint: Option<String>,
    custom_backend_metrics_polling_interval: Option<f64>,
133
    is_prefill: bool,
134
135
136
137
138
139
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
140
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None, is_prefill=false))]
141
142
143
144
145
146
147
    pub fn new(
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
148
        router_config: Option<RouterConfig>,
149
        kv_cache_block_size: Option<u32>,
150
        http_host: Option<String>,
151
        http_port: Option<u16>,
Graham King's avatar
Graham King committed
152
153
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
154
        extra_engine_args: Option<PathBuf>,
155
        namespace: Option<String>,
156
157
        custom_backend_metrics_endpoint: Option<String>,
        custom_backend_metrics_polling_interval: Option<f64>,
158
        is_prefill: bool,
159
    ) -> PyResult<Self> {
160
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
161
162
163
164
165
166
167
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
168
169
170
171
172
173
174
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
175
            router_config,
176
            kv_cache_block_size,
177
            http_host,
Graham King's avatar
Graham King committed
178
179
180
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
            tls_cert_path,
            tls_key_path,
181
            extra_engine_args,
182
            namespace,
183
184
            custom_backend_metrics_endpoint,
            custom_backend_metrics_polling_interval,
185
            is_prefill,
186
187
188
189
190
191
192
193
194
195
        })
    }
}

/// Opaque engine configuration returned to Python by `make_engine` and handed
/// back to `run_input`.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// Native configuration chosen by `select_engine`.
    inner: RsEngineConfig,
}

196
197
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
198
199
200
201
202
203
204
205
206
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
207
208
209
210
211
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
212
        .endpoint_id(args.endpoint_id.clone())
213
        .context_length(args.context_length)
214
        .request_template(args.template_file.clone())
215
        .kv_cache_block_size(args.kv_cache_block_size)
216
        .router_config(args.router_config.clone().map(|rc| rc.into()))
217
        .http_host(args.http_host.clone())
218
        .http_port(args.http_port)
Graham King's avatar
Graham King committed
219
220
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
221
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
222
        .extra_engine_args(args.extra_engine_args.clone())
223
224
225
        .namespace(args.namespace.clone())
        .custom_backend_metrics_endpoint(args.custom_backend_metrics_endpoint.clone())
        .custom_backend_metrics_polling_interval(args.custom_backend_metrics_polling_interval);
226
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
227
228
229
230
        if let Some(model_path) = args.model_path.clone() {
            let local_path = if model_path.exists() {
                model_path
            } else {
231
232
233
                // Mocker only needs tokenizer, not weights
                let ignore_weights = matches!(args.engine_type, EngineType::Mocker);
                LocalModel::fetch(&model_path.display().to_string(), ignore_weights)
234
235
236
237
238
239
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }

240
        let local_model = builder.build().await.map_err(to_pyerr)?;
241
        let inner = select_engine(distributed_runtime, args, local_model)
242
243
244
245
246
247
248
249
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
250
    args: EntrypointArgs,
251
252
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
253
    let inner = match args.engine_type {
254
255
256
257
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
258
                engine: dynamo_llm::engines::make_echo_engine(),
259
260
261
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
290
                is_prefill: args.is_prefill,
291
292
            }
        }
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
    };

    Ok(inner)
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        dynamo_llm::entrypoint::input::run_input(
309
            distributed_runtime.inner.clone(),
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Convert any displayable error into a Python `Exception`.
///
/// The exception message is the error's `Display` output; the original Rust
/// error type is not preserved.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    PyException::new_err(err.to_string())
}