entrypoint.rs 8.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
10
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
11
use dynamo_llm::entrypoint::input::Input;
12
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
Graham King's avatar
Graham King committed
13
use dynamo_llm::local_model::DEFAULT_HTTP_PORT;
14
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
15
use dynamo_llm::mocker::protocols::MockEngineArgs;
16
use dynamo_runtime::protocols::EndpointId;
17

18
19
use crate::RouterMode;

20
21
22
23
24
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
25
26
    Dynamic = 2,
    Mocker = 3,
27
    Static = 4,
28
29
}

30
31
32
33
34
35
/// Python-visible wrapper around the Rust-side [`RsKvRouterConfig`].
#[pyclass]
#[derive(Clone, Copy, Debug, Default)]
pub struct KvRouterConfig {
    /// The wrapped Rust configuration value.
    inner: RsKvRouterConfig,
}

36
37
38
39
40
41
impl KvRouterConfig {
    /// Returns a copy of the wrapped [`RsKvRouterConfig`].
    pub fn inner(&self) -> RsKvRouterConfig {
        let Self { inner } = *self;
        inner
    }
}

42
43
44
#[pymethods]
impl KvRouterConfig {
    #[new]
45
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true, router_replica_sync=false, router_track_active_blocks=true, router_snapshot_threshold=1000000, router_reset_states=false))]
46
47
48
49
50
    fn new(
        overlap_score_weight: f64,
        router_temperature: f64,
        use_kv_events: bool,
        router_replica_sync: bool,
51
        router_track_active_blocks: bool,
52
53
        router_snapshot_threshold: Option<u32>,
        router_reset_states: bool,
54
    ) -> Self {
55
56
57
58
59
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
60
                router_replica_sync,
61
                router_track_active_blocks,
62
63
                router_snapshot_threshold,
                router_reset_states,
64
65
66
67
68
69
70
71
72
73
            },
        }
    }
}

/// Python-visible router settings: the routing mode, the KV-router
/// configuration, and an optional busy threshold.
#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
    /// Routing mode selected by the caller.
    router_mode: RouterMode,
    /// KV-router settings; `RouterConfig::new` falls back to the default when none is given.
    kv_router_config: KvRouterConfig,
    /// Optional busy threshold, passed through unchanged to `RsRouterConfig`.
    busy_threshold: Option<f64>,
}

#[pymethods]
impl RouterConfig {
    #[new]
80
81
82
83
84
85
    #[pyo3(signature = (mode, config=None, busy_threshold=None))]
    pub fn new(
        mode: RouterMode,
        config: Option<KvRouterConfig>,
        busy_threshold: Option<f64>,
    ) -> Self {
86
87
88
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
89
            busy_threshold,
90
91
92
93
94
95
96
97
98
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
99
            busy_threshold: rc.busy_threshold,
100
101
102
103
        }
    }
}

104
105
106
107
108
109
110
111
112
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
113
    router_config: Option<RouterConfig>,
114
    kv_cache_block_size: Option<u32>,
115
    http_host: Option<String>,
Graham King's avatar
Graham King committed
116
117
118
    http_port: u16,
    tls_cert_path: Option<PathBuf>,
    tls_key_path: Option<PathBuf>,
119
    extra_engine_args: Option<PathBuf>,
120
    namespace: Option<String>,
121
122
123
124
125
126
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
127
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None))]
128
129
130
131
132
133
134
    pub fn new(
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
135
        router_config: Option<RouterConfig>,
136
        kv_cache_block_size: Option<u32>,
137
        http_host: Option<String>,
138
        http_port: Option<u16>,
Graham King's avatar
Graham King committed
139
140
        tls_cert_path: Option<PathBuf>,
        tls_key_path: Option<PathBuf>,
141
        extra_engine_args: Option<PathBuf>,
142
        namespace: Option<String>,
143
    ) -> PyResult<Self> {
144
        let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
Graham King's avatar
Graham King committed
145
146
147
148
149
150
151
        if (tls_cert_path.is_some() && tls_key_path.is_none())
            || (tls_cert_path.is_none() && tls_key_path.is_some())
        {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "tls_cert_path and tls_key_path must be provided together",
            ));
        }
152
153
154
155
156
157
158
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
159
            router_config,
160
            kv_cache_block_size,
161
            http_host,
Graham King's avatar
Graham King committed
162
163
164
            http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT),
            tls_cert_path,
            tls_key_path,
165
            extra_engine_args,
166
            namespace,
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
        })
    }
}

/// Python-visible handle around a Rust-side [`RsEngineConfig`].
///
/// Produced by `make_engine` and consumed by `run_input`.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// The wrapped engine configuration.
    inner: RsEngineConfig,
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
186
187
188
        .model_path(args.model_path.clone())
        .model_name(args.model_name.clone())
        .endpoint_id(args.endpoint_id.clone())
189
        .context_length(args.context_length)
190
        .request_template(args.template_file.clone())
191
        .kv_cache_block_size(args.kv_cache_block_size)
192
        .router_config(args.router_config.clone().map(|rc| rc.into()))
193
        .http_host(args.http_host.clone())
194
        .http_port(args.http_port)
Graham King's avatar
Graham King committed
195
196
        .tls_cert_path(args.tls_cert_path.clone())
        .tls_key_path(args.tls_key_path.clone())
197
        .is_mocker(matches!(args.engine_type, EngineType::Mocker))
198
199
        .extra_engine_args(args.extra_engine_args.clone())
        .namespace(args.namespace.clone());
200
201
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let local_model = builder.build().await.map_err(to_pyerr)?;
202
        let inner = select_engine(distributed_runtime, args, local_model)
203
204
205
206
207
208
209
210
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
211
    args: EntrypointArgs,
212
213
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
214
    let inner = match args.engine_type {
215
216
217
218
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
219
                engine: dynamo_llm::engines::make_echo_engine(),
220
                is_static: false,
221
222
223
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
224
        EngineType::Static => RsEngineConfig::StaticRemote(Box::new(local_model)),
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
253
                is_static: false,
254
255
            }
        }
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
    };

    Ok(inner)
}

/// Parses `input` into an [`Input`] and runs it against `engine_config` via
/// `dynamo_llm::entrypoint::input::run_input`.
///
/// Parsing happens before the future is created, so an unparsable `input` is
/// reported immediately. Errors from the run itself are converted with
/// `to_pyerr` when the future completes.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let parsed_input = input.parse::<Input>().map_err(to_pyerr)?;
    let run = async move {
        dynamo_llm::entrypoint::input::run_input(
            either::Either::Right(distributed_runtime.inner.clone()),
            parsed_input,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    };
    pyo3_async_runtimes::tokio::future_into_py(py, run)
}

/// Converts any [`Display`]-able error into a Python `Exception` whose message
/// is the error's `Display` output.
pub fn to_pyerr<E: Display>(err: E) -> PyErr {
    let message = err.to_string();
    PyException::new_err(message)
}