// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
use dynamo_llm::mocker::protocols::MockEngineArgs;
use dynamo_runtime::protocols::EndpointId;

use crate::RouterMode;

20
21
22
23
24
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
25
26
    Dynamic = 2,
    Mocker = 3,
27
    Static = 4,
28
29
}

30
31
32
33
34
35
#[pyclass]
#[derive(Default, Clone, Debug, Copy)]
pub struct KvRouterConfig {
    inner: RsKvRouterConfig,
}

36
37
38
39
40
41
impl KvRouterConfig {
    pub fn inner(&self) -> RsKvRouterConfig {
        self.inner
    }
}

42
43
44
#[pymethods]
impl KvRouterConfig {
    #[new]
45
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false))]
46
47
48
49
50
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
        router_replica_sync: bool,
51
        router_track_active_blocks: bool,
52
53
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
54
    ) -> Self {
55
56
57
58
59
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
60
                router_replica_sync,
61
                router_track_active_blocks,
62
63
                router_snapshot_threshold,
                router_reset_states,
64
65
66
67
68
69
70
71
72
73
            },
        }
    }
}

/// Python-visible routing configuration: a router mode plus KV-router settings.
///
/// Converted into `dynamo_llm`'s `RouterConfig` through
/// `From<RouterConfig> for RsRouterConfig`.
#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
    router_mode: RouterMode,
    kv_router_config: KvRouterConfig,
    /// Forwarded unchanged to `RsRouterConfig::busy_threshold` (`None` stays `None`).
    busy_threshold: Option<f64>,
}

#[pymethods]
impl RouterConfig {
    #[new]
80
81
82
83
84
85
    #[pyo3(signature = (mode, config=None, busy_threshold=None))]
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
    ) -> Self {
86
87
88
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
89
            busy_threshold,
90
91
92
93
94
95
96
97
98
        }
    }
}

/// Unwrap the Python-facing config into `dynamo_llm`'s routing config.
///
/// The router mode is converted with its own `Into` impl; the KV-router
/// settings and busy threshold are copied across as-is.
impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
            busy_threshold: rc.busy_threshold,
        }
    }
}

104
105
106
107
108
109
110
111
112
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
113
    router_config: Option<RouterConfig>,
114
    kv_cache_block_size: Option<u32>,
115
    http_host: Option<String>,
Graham King's avatar
Graham King committed
116
117
118
    http_port: u16,
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
119
    extra_engine_args: Option<PathBuf>,
120
    namespace: Option<String>,
121
122
    custom_backend_metrics_endpoint: Option<String>,
    custom_backend_metrics_polling_interval: Option<f64>,
123
124
125
126
127
128
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
129
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None))]
130
131
132
133
134
135
136
    pub fn new(
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
137
        router_config: Option<RouterConfig>,
138
        kv_cache_block_size: Option<u32>,
139
        http_host: Option<String>,
140
        http_port: Option<u16>,
Graham King's avatar
Graham King committed
141
142
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
143
        extra_engine_args: Option<PathBuf>,
144
        namespace: Option<String>,
145
146
        custom_backend_metrics_endpoint: Option<String>,
        custom_backend_metrics_polling_interval: Option<f64>,
147
    ) -> PyResult<Self> {
148
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
149
150
151
152
153
154
155
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
156
157
158
159
160
161
162
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
163
            router_config,
164
            kv_cache_block_size,
165
            http_host,
Graham King's avatar
Graham King committed
166
167
168
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
            tls_cert_path,
            tls_key_path,
169
            extra_engine_args,
170
            namespace,
171
172
            custom_backend_metrics_endpoint,
            custom_backend_metrics_polling_interval,
173
174
175
176
177
178
179
180
181
182
        })
    }
}

/// Handle to a fully selected engine configuration.
///
/// Produced by `make_engine` and consumed by `run_input` in this module.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// The `dynamo_llm` engine configuration chosen by `select_engine`.
    inner: RsEngineConfig,
}

183
184
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
185
186
187
188
189
190
191
192
193
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
194
195
196
197
198
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
199
        .endpoint_id(args.endpoint_id.clone())
200
        .context_length(args.context_length)
201
        .request_template(args.template_file.clone())
202
        .kv_cache_block_size(args.kv_cache_block_size)
203
        .router_config(args.router_config.clone().map(|rc| rc.into()))
204
        .http_host(args.http_host.clone())
205
        .http_port(args.http_port)
Graham King's avatar
Graham King committed
206
207
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
208
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
209
        .extra_engine_args(args.extra_engine_args.clone())
210
211
212
        .namespace(args.namespace.clone())
        .custom_backend_metrics_endpoint(args.custom_backend_metrics_endpoint.clone())
        .custom_backend_metrics_polling_interval(args.custom_backend_metrics_polling_interval);
213
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
214
215
216
217
218
219
220
221
222
223
224
        if let Some(model_path) = args.model_path.clone() {
            let local_path = if model_path.exists() {
                model_path
            } else {
                LocalModel::fetch(&model_path.display().to_string(), false)
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }

225
        let local_model = builder.build().await.map_err(to_pyerr)?;
226
        let inner = select_engine(distributed_runtime, args, local_model)
227
228
229
230
231
232
233
234
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
235
    args: EntrypointArgs,
236
237
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
238
    let inner = match args.engine_type {
239
240
241
242
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
243
                engine: dynamo_llm::engines::make_echo_engine(),
244
                is_static: false,
245
246
247
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
248
        EngineType::Static => RsEngineConfig::StaticRemote(Box::new(local_model)),
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
277
                is_static: false,
278
279
            }
        }
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
    };

    Ok(inner)
}

/// Run `engine_config` behind the input front-end named by `input`.
///
/// `input` is parsed into a `dynamo_llm` `Input` before any future is created,
/// so an unrecognised value raises immediately rather than from the awaitable.
/// The returned awaitable resolves to `None` once
/// `dynamo_llm::entrypoint::input::run_input` completes, or raises `Exception`
/// if it fails.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let mode = input.parse::<Input>().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let runtime = either::Either::Right(distributed_runtime.inner.clone());
        let engine = engine_config.inner;
        let outcome = dynamo_llm::entrypoint::input::run_input(runtime, mode, engine).await;
        outcome.map_err(to_pyerr)?;
        Ok(())
    })
}

/// Convert any displayable error into a Python `Exception`.
///
/// The exception message is exactly the error's `Display` output.
pub fn to_pyerr<E: Display>(err: E) -> PyErr {
    let message = err.to_string();
    PyException::new_err(message)
}