entrypoint.rs 10.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
10
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
11
use dynamo_llm::entrypoint::input::Input;
12
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
13
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
14
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
15
use dynamo_llm::mocker::protocols::MockEngineArgs;
16
use dynamo_runtime::protocols::EndpointId;
17

18
19
use crate::RouterMode;

20
21
22
23
24
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
25
26
    Dynamic = 2,
    Mocker = 3,
27
28
}

29
30
31
32
33
34
/// Python-visible wrapper around the Rust-side [`RsKvRouterConfig`].
#[pyclass]
#[derive(Clone, Copy, Debug, Default)]
pub struct KvRouterConfig {
    /// The wrapped KV router configuration.
    inner: RsKvRouterConfig,
}

35
36
37
38
39
40
impl KvRouterConfig {
    /// Returns a copy of the wrapped [`RsKvRouterConfig`].
    pub fn inner(&self) -> RsKvRouterConfig {
        let Self { inner } = *self;
        inner
    }
}

41
42
43
#[pymethods]
impl KvRouterConfig {
    #[new]
44
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false))]
45
46
47
48
49
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
        router_replica_sync: bool,
50
        router_track_active_blocks: bool,
51
52
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
53
    ) -> Self {
54
55
56
57
58
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
59
                router_replica_sync,
60
                router_track_active_blocks,
61
62
                router_snapshot_threshold,
                router_reset_states,
63
64
65
66
67
68
69
70
71
72
            },
        }
    }
}

/// Python-visible router settings; converted into [`RsRouterConfig`] through
/// the `From<RouterConfig>` implementation in this file.
#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
    /// Routing mode, converted with `.into()` when building [`RsRouterConfig`].
    router_mode: RouterMode,
    /// KV router settings; the constructor falls back to the default value
    /// when none are supplied.
    kv_router_config: KvRouterConfig,
    /// Optional threshold, passed through unchanged to [`RsRouterConfig`].
    busy_threshold: Option<f64>,
}

#[pymethods]
impl RouterConfig {
    #[new]
79
80
81
82
83
84
    #[pyo3(signature = (mode, config=None, busy_threshold=None))]
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
    ) -> Self {
85
86
87
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
88
            busy_threshold,
89
90
91
92
93
94
95
96
97
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
98
            busy_threshold: rc.busy_threshold,
99
100
101
102
        }
    }
}

103
104
105
106
107
108
109
110
111
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
112
    router_config: Option<RouterConfig>,
113
    kv_cache_block_size: Option<u32>,
114
    http_host: Option<String>,
Graham King's avatar
Graham King committed
115
116
117
    http_port: u16,
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
118
    extra_engine_args: Option<PathBuf>,
119
    namespace: Option<String>,
120
121
    custom_backend_metrics_endpoint: Option<String>,
    custom_backend_metrics_polling_interval: Option<f64>,
122
    is_prefill: bool,
123
124
125
126
127
128
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
129
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, custom_backend_metrics_endpoint=None, custom_backend_metrics_polling_interval=None, is_prefill=false))]
130
131
132
133
134
135
136
    pub fn new(
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
137
        router_config: Option<RouterConfig>,
138
        kv_cache_block_size: Option<u32>,
139
        http_host: Option<String>,
140
        http_port: Option<u16>,
Graham King's avatar
Graham King committed
141
142
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
143
        extra_engine_args: Option<PathBuf>,
144
        namespace: Option<String>,
145
146
        custom_backend_metrics_endpoint: Option<String>,
        custom_backend_metrics_polling_interval: Option<f64>,
147
        is_prefill: bool,
148
    ) -> PyResult<Self> {
149
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
150
151
152
153
154
155
156
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
157
158
159
160
161
162
163
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
164
            router_config,
165
            kv_cache_block_size,
166
            http_host,
Graham King's avatar
Graham King committed
167
168
169
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
            tls_cert_path,
            tls_key_path,
170
            extra_engine_args,
171
            namespace,
172
173
            custom_backend_metrics_endpoint,
            custom_backend_metrics_polling_interval,
174
            is_prefill,
175
176
177
178
179
180
181
182
183
184
        })
    }
}

/// Python-visible handle around an [`RsEngineConfig`]; produced by
/// `make_engine` and consumed by `run_input`.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// The wrapped Rust engine configuration.
    inner: RsEngineConfig,
}

185
186
/// Create the backend engine wrapper to run the model.
/// Download the model if necessary.
187
188
189
190
191
192
193
194
195
#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
196
197
198
199
200
        .model_name(
            args.model_name
                .clone()
                .or_else(|| args.model_path.clone().map(|p| p.display().to_string())),
        )
201
        .endpoint_id(args.endpoint_id.clone())
202
        .context_length(args.context_length)
203
        .request_template(args.template_file.clone())
204
        .kv_cache_block_size(args.kv_cache_block_size)
205
        .router_config(args.router_config.clone().map(|rc| rc.into()))
206
        .http_host(args.http_host.clone())
207
        .http_port(args.http_port)
Graham King's avatar
Graham King committed
208
209
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
210
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
211
        .extra_engine_args(args.extra_engine_args.clone())
212
213
214
        .namespace(args.namespace.clone())
        .custom_backend_metrics_endpoint(args.custom_backend_metrics_endpoint.clone())
        .custom_backend_metrics_polling_interval(args.custom_backend_metrics_polling_interval);
215
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
216
217
218
219
        if let Some(model_path) = args.model_path.clone() {
            let local_path = if model_path.exists() {
                model_path
            } else {
220
221
222
                // Mocker only needs tokenizer, not weights
                let ignore_weights = matches!(args.engine_type, EngineType::Mocker);
                LocalModel::fetch(&model_path.display().to_string(), ignore_weights)
223
224
225
226
227
228
                    .await
                    .map_err(to_pyerr)?
            };
            builder.model_path(local_path);
        }

229
        let local_model = builder.build().await.map_err(to_pyerr)?;
230
        let inner = select_engine(distributed_runtime, args, local_model)
231
232
233
234
235
236
237
238
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
239
    args: EntrypointArgs,
240
241
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
242
    let inner = match args.engine_type {
243
244
245
246
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
247
                engine: dynamo_llm::engines::make_echo_engine(),
248
249
250
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
279
                is_prefill: args.is_prefill,
280
281
            }
        }
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
    };

    Ok(inner)
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        dynamo_llm::entrypoint::input::run_input(
298
            distributed_runtime.inner.clone(),
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Converts any displayable error into a generic Python exception whose
/// message is the error's `Display` output.
pub fn to_pyerr<E: Display>(err: E) -> PyErr {
    let message = err.to_string();
    PyException::new_err(message)
}