entrypoint.rs 10.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
10
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
11
use dynamo_llm::entrypoint::input::Input;
12
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
13
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
14
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
15
use dynamo_llm::mocker::protocols::MockEngineArgs;
16
use dynamo_runtime::protocols::EndpointId;
17

18
19
use crate::RouterMode;

20
21
22
23
24
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
25
26
    Dynamic = 2,
    Mocker = 3,
27
    Static = 4,
28
29
}

30
31
32
33
34
35
#[pyclass]
#[derive(Default, Clone, Debug, Copy)]
pub struct KvRouterConfig {
    inner: RsKvRouterConfig,
}

36
37
38
39
40
41
impl KvRouterConfig {
    pub fn inner(&self) -> RsKvRouterConfig {
        self.inner
    }
}

42
43
44
#[pymethods]
impl KvRouterConfig {
    #[new]
45
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false))]
46
47
48
49
50
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
        router_replica_sync: bool,
51
        router_track_active_blocks: bool,
52
53
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
54
    ) -> Self {
55
56
57
58
59
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
60
                router_replica_sync,
61
                router_track_active_blocks,
62
63
                router_snapshot_threshold,
                router_reset_states,
64
65
66
67
68
69
70
71
72
73
            },
        }
    }
}

#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
    router_mode: RouterMode,
    kv_router_config: KvRouterConfig,
74
    busy_threshold: Option<f64>,
75
76
77
78
79
}

#[pymethods]
impl RouterConfig {
    #[new]
80
81
82
83
84
85
    #[pyo3(signature = (mode, config=None, busy_threshold=None))]
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
    ) -> Self {
86
87
88
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
89
            busy_threshold,
90
91
92
93
94
95
96
97
98
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
99
            busy_threshold: rc.busy_threshold,
100
101
102
103
        }
    }
}

104
105
106
107
108
109
110
111
112
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
113
    router_config: Option<RouterConfig>,
114
    kv_cache_block_size: Option<u32>,
115
    http_host: Option<String>,
Graham King's avatar
Graham King committed
116
117
118
    http_port: u16,
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
119
    extra_engine_args: Option<PathBuf>,
120
    namespace: Option<String>,
121
122
    custom_backend_metrics_endpoint: Option<String>,
    custom_backend_metrics_polling_interval: Option<f64>,
123
    is_prefill: bool,
124
125
126
127
128
129
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
130
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None, is_prefill=false))]
131
132
133
134
135
136
137
    pub fn new(
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
138
        router_config: Option<RouterConfig>,
139
        kv_cache_block_size: Option<u32>,
140
        http_host: Option<String>,
141
        http_port: Option<u16>,
Graham King's avatar
Graham King committed
142
143
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
144
        extra_engine_args: Option<PathBuf>,
145
        namespace: Option<String>,
146
147
        custom_backend_metrics_endpoint: Option<String>,
        custom_backend_metrics_polling_interval: Option<f64>,
148
        is_prefill: bool,
149
    ) -> PyResult<Self> {
150
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
151
152
153
154
155
156
157
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
158
159
160
161
162
163
164
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
165
            router_config,
166
            kv_cache_block_size,
167
            http_host,
Graham King's avatar
Graham King committed
168
169
170
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
            tls_cert_path,
            tls_key_path,
171
            extra_engine_args,
172
            namespace,
173
174
            custom_backend_metrics_endpoint,
            custom_backend_metrics_polling_interval,
175
            is_prefill,
176
177
178
179
180
181
182
183
184
185
        })
    }
}

/// Handle to a configured engine.
///
/// Produced by `make_engine` (as the result of its awaitable) and consumed by `run_input`.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// The `dynamo_llm` engine configuration this handle wraps.
    inner: RsEngineConfig,
}

186
187
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
188
189
190
191
192
193
194
195
196
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
197
198
199
200
201
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
202
        .endpoint_id(args.endpoint_id.clone())
203
        .context_length(args.context_length)
204
        .request_template(args.template_file.clone())
205
        .kv_cache_block_size(args.kv_cache_block_size)
206
        .router_config(args.router_config.clone().map(|rc| rc.into()))
207
        .http_host(args.http_host.clone())
208
        .http_port(args.http_port)
Graham King's avatar
Graham King committed
209
210
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
211
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
212
        .extra_engine_args(args.extra_engine_args.clone())
213
214
215
        .namespace(args.namespace.clone())
        .custom_backend_metrics_endpoint(args.custom_backend_metrics_endpoint.clone())
        .custom_backend_metrics_polling_interval(args.custom_backend_metrics_polling_interval);
216
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
217
218
219
220
        if let Some(model_path) = args.model_path.clone() {
            let local_path = if model_path.exists() {
                model_path
            } else {
221
222
223
                // Mocker only needs tokenizer, not weights
                let ignore_weights = matches!(args.engine_type, EngineType::Mocker);
                LocalModel::fetch(&model_path.display().to_string(), ignore_weights)
224
225
226
227
228
229
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }

230
        let local_model = builder.build().await.map_err(to_pyerr)?;
231
        let inner = select_engine(distributed_runtime, args, local_model)
232
233
234
235
236
237
238
239
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
240
    args: EntrypointArgs,
241
242
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
243
    let inner = match args.engine_type {
244
245
246
247
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
248
                engine: dynamo_llm::engines::make_echo_engine(),
249
                is_static: false,
250
251
252
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
253
        EngineType::Static => RsEngineConfig::StaticRemote(Box::new(local_model)),
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
282
                is_static: false,
283
                is_prefill: args.is_prefill,
284
285
            }
        }
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
    };

    Ok(inner)
}

/// Run the input mode named by `input` against `engine_config`, returned to Python as an
/// awaitable.
///
/// `input` is parsed before the awaitable is created, so a malformed value raises
/// immediately rather than when the awaitable is awaited.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let parsed = input.parse::<Input>().map_err(to_pyerr)?;

    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let runtime = either::Either::Right(distributed_runtime.inner.clone());
        let engine = engine_config.inner;
        dynamo_llm::entrypoint::input::run_input(runtime, parsed, engine)
            .await
            .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Convert any displayable error into a generic Python `Exception`.
///
/// The exception message is the error's `Display` text. NOTE(review): for `anyhow::Error`
/// plain `Display` shows only the outermost context; switch to `{:#}` if the full cause
/// chain should reach Python.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    PyException::new_err(err.to_string())
}