entrypoint.rs 9.16 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
11
12
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
13
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
14
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
15
use dynamo_llm::mocker::protocols::MockEngineArgs;
16
use dynamo_runtime::protocols::EndpointId;
17

18
19
use crate::RouterMode;

20
21
22
23
24
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
25
26
    Dynamic = 2,
    Mocker = 3,
27
    Static = 4,
28
29
}

30
31
32
33
34
35
// Python-visible wrapper around `dynamo_llm::kv_router::KvRouterConfig`.
// The wrapped value is `Copy`, so this wrapper is cheap to pass by value.
#[pyclass]
#[derive(Clone, Copy, Debug, Default)]
pub struct KvRouterConfig {
    // The underlying Rust-side router configuration.
    inner: RsKvRouterConfig,
}

36
37
38
39
40
41
impl KvRouterConfig {
    pub fn inner(&self) -> RsKvRouterConfig {
        self.inner
    }
}

42
43
44
#[pymethods]
impl KvRouterConfig {
    #[new]
45
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=10000, router_reset_states=false))]
46
47
48
49
50
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
        router_replica_sync: bool,
51
        router_track_active_blocks: bool,
52
53
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
54
    ) -> Self {
55
56
57
58
59
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
60
                router_replica_sync,
61
                router_track_active_blocks,
62
63
                router_snapshot_threshold,
                router_reset_states,
64
65
66
67
68
69
70
71
72
73
74
                ..Default::default()
            },
        }
    }
}

// Python-visible router settings; converted into
// `dynamo_llm::entrypoint::RouterConfig` through the `From` impl in this file.
#[pyclass]
#[derive(Debug, Clone)]
pub struct RouterConfig {
    // Routing strategy selected by the caller.
    router_mode: RouterMode,
    // KV-router tuning; the constructor falls back to the default when omitted.
    kv_router_config: KvRouterConfig,
    // Optional threshold forwarded unchanged to the Rust router config.
    busy_threshold: Option<f64>,
}

#[pymethods]
impl RouterConfig {
    #[new]
81
82
83
84
85
86
    #[pyo3(signature = (mode, config=None, busy_threshold=None))]
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
    ) -> Self {
87
88
89
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
90
            busy_threshold,
91
92
93
94
95
96
97
98
99
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
100
            busy_threshold: rc.busy_threshold,
101
102
103
104
        }
    }
}

105
106
107
108
109
110
111
112
113
114
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    model_config: Option<PathBuf>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
115
    router_config: Option<RouterConfig>,
116
    kv_cache_block_size: Option<u32>,
117
    http_host: Option<String>,
Graham King's avatar
Graham King committed
118
119
120
    http_port: u16,
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
121
    extra_engine_args: Option<PathBuf>,
122
    namespace: Option<String>,
123
124
125
126
127
128
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
129
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None))]
130
131
132
133
134
135
136
137
    pub fn new(
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        model_config: Option<PathBuf>,
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
138
        router_config: Option<RouterConfig>,
139
        kv_cache_block_size: Option<u32>,
140
        http_host: Option<String>,
141
        http_port: Option<u16>,
Graham King's avatar
Graham King committed
142
143
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
144
        extra_engine_args: Option<PathBuf>,
145
        namespace: Option<String>,
146
    ) -> PyResult<Self> {
147
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
148
149
150
151
152
153
154
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
155
156
157
158
159
160
161
162
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            model_config,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
163
            router_config,
164
            kv_cache_block_size,
165
            http_host,
Graham King's avatar
Graham King committed
166
167
168
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
            tls_cert_path,
            tls_key_path,
169
            extra_engine_args,
170
            namespace,
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
        })
    }
}

// Opaque Python handle around a resolved `dynamo_llm` engine configuration.
// Produced by `make_engine` and consumed by `run_input`; it exposes no
// Python-visible methods in this file.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    // The Rust-side engine configuration selected by `select_engine`.
    inner: RsEngineConfig,
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
190
191
192
193
        .model_path(args.model_path.clone())
        .model_name(args.model_name.clone())
        .model_config(args.model_config.clone())
        .endpoint_id(args.endpoint_id.clone())
194
        .context_length(args.context_length)
195
        .request_template(args.template_file.clone())
196
        .kv_cache_block_size(args.kv_cache_block_size)
197
        .router_config(args.router_config.clone().map(|rc| rc.into()))
198
        .http_host(args.http_host.clone())
199
        .http_port(args.http_port)
Graham King's avatar
Graham King committed
200
201
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
202
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
203
204
        .extra_engine_args(args.extra_engine_args.clone())
        .namespace(args.namespace.clone());
205
206
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let local_model = builder.build().await.map_err(to_pyerr)?;
207
        let inner = select_engine(distributed_runtime, args, local_model)
208
209
210
211
212
213
214
215
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
216
    args: EntrypointArgs,
217
218
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
219
    let inner = match args.engine_type {
220
221
222
223
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
224
                engine: dynamo_llm::engines::make_echo_engine(),
225
                is_static: false,
226
227
228
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
229
        EngineType::Static => RsEngineConfig::StaticRemote(Box::new(local_model)),
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
258
                is_static: false,
259
260
            }
        }
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
    };

    Ok(inner)
}

// Runs the given input source against a previously built engine config.
//
// `input` is parsed into `dynamo_llm::entrypoint::input::Input` up front, so a
// malformed value fails before any future is created. The actual run happens
// in a future handed to `pyo3_async_runtimes::tokio::future_into_py`; errors
// from parsing or from the run are converted with `to_pyerr`.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        dynamo_llm::entrypoint::input::run_input(
            // `distributed_runtime` is owned here, so its inner runtime is
            // moved rather than cloned (as `select_engine` also does).
            either::Either::Right(distributed_runtime.inner),
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Converts any displayable error into a Python `Exception` whose message is
/// the error's `Display` output.
pub fn to_pyerr<E: Display>(err: E) -> PyErr {
    let message = err.to_string();
    PyException::new_err(message)
}