entrypoint.rs 8.68 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
11
12
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
13
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
14
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
15
use dynamo_llm::mocker::protocols::MockEngineArgs;
16
use dynamo_runtime::protocols::EndpointId;
17

18
19
use crate::RouterMode;

20
21
22
23
24
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
25
26
    Dynamic = 2,
    Mocker = 3,
27
    Static = 4,
28
29
}

30
31
32
33
34
35
#[pyclass]
#[derive(Default, Clone, Debug, Copy)]
pub struct KvRouterConfig {
    inner: RsKvRouterConfig,
}

36
37
38
39
40
41
impl KvRouterConfig {
    pub fn inner(&self) -> RsKvRouterConfig {
        self.inner
    }
}

42
43
44
#[pymethods]
impl KvRouterConfig {
    #[new]
45
46
47
48
49
50
51
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false))]
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
        router_replica_sync: bool,
    ) -> Self {
52
53
54
55
56
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
57
                router_replica_sync,
58
59
60
61
62
63
64
65
66
67
68
                ..Default::default()
            },
        }
    }
}

#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
    router_mode: RouterMode,
    kv_router_config: KvRouterConfig,
69
    busy_threshold: Option<f64>,
70
71
72
73
74
}

#[pymethods]
impl RouterConfig {
    #[new]
75
76
77
78
79
80
    #[pyo3(signature = (mode, config=None, busy_threshold=None))]
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
    ) -> Self {
81
82
83
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
84
            busy_threshold,
85
86
87
88
89
90
91
92
93
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
94
            busy_threshold: rc.busy_threshold,
95
96
97
98
        }
    }
}

99
100
101
102
103
104
105
106
107
108
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    model_config: Option<PathBuf>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
109
    router_config: Option<RouterConfig>,
110
    kv_cache_block_size: Option<u32>,
111
    http_host: Option<String>,
Graham King's avatar
Graham King committed
112
113
114
    http_port: u16,
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
115
    extra_engine_args: Option<PathBuf>,
116
117
118
119
120
121
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
122
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None))]
123
124
125
126
127
128
129
130
    pub fn new(
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        model_config: Option<PathBuf>,
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
131
        router_config: Option<RouterConfig>,
132
        kv_cache_block_size: Option<u32>,
133
        http_host: Option<String>,
134
        http_port: Option<u16>,
Graham King's avatar
Graham King committed
135
136
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
137
        extra_engine_args: Option<PathBuf>,
138
    ) -> PyResult<Self> {
139
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
140
141
142
143
144
145
146
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
147
148
149
150
151
152
153
154
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            model_config,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
155
            router_config,
156
            kv_cache_block_size,
157
            http_host,
Graham King's avatar
Graham King committed
158
159
160
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
            tls_cert_path,
            tls_key_path,
161
            extra_engine_args,
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
        })
    }
}

#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    inner: RsEngineConfig,
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
181
182
183
184
        .model_path(args.model_path.clone())
        .model_name(args.model_name.clone())
        .model_config(args.model_config.clone())
        .endpoint_id(args.endpoint_id.clone())
185
        .context_length(args.context_length)
186
        .request_template(args.template_file.clone())
187
        .kv_cache_block_size(args.kv_cache_block_size)
188
        .router_config(args.router_config.clone().map(|rc| rc.into()))
189
        .http_host(args.http_host.clone())
190
        .http_port(args.http_port)
Graham King's avatar
Graham King committed
191
192
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
193
194
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
        .extra_engine_args(args.extra_engine_args.clone());
195
196
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let local_model = builder.build().await.map_err(to_pyerr)?;
197
        let inner = select_engine(distributed_runtime, args, local_model)
198
199
200
201
202
203
204
205
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
206
    args: EntrypointArgs,
207
208
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
209
    let inner = match args.engine_type {
210
211
212
213
214
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
                engine: dynamo_llm::engines::make_engine_full(),
215
                is_static: false,
216
217
218
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
219
        EngineType::Static => RsEngineConfig::StaticRemote(Box::new(local_model)),
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
248
                is_static: false,
249
250
            }
        }
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
    };

    Ok(inner)
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        dynamo_llm::entrypoint::input::run_input(
            either::Either::Right(distributed_runtime.inner.clone()),
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    PyException::new_err(format!("{}", err))
}