entrypoint.rs 7.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::fmt::Display;
use std::path::PathBuf;

use pyo3::{exceptions::PyException, prelude::*};

use dynamo_llm::entrypoint::input::Input;
use dynamo_llm::entrypoint::EngineConfig as RsEngineConfig;
11
12
use dynamo_llm::entrypoint::RouterConfig as RsRouterConfig;
use dynamo_llm::kv_router::KvRouterConfig as RsKvRouterConfig;
13
use dynamo_llm::local_model::{LocalModel, LocalModelBuilder};
14
use dynamo_llm::mocker::protocols::MockEngineArgs;
15
16
use dynamo_runtime::protocols::Endpoint as EndpointId;

17
18
use crate::RouterMode;

19
20
21
22
23
#[pyclass(eq, eq_int)]
#[derive(Clone, Debug, PartialEq)]
#[repr(i32)]
pub enum EngineType {
    Echo = 1,
24
25
    Dynamic = 2,
    Mocker = 3,
26
    Static = 4,
27
28
}

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#[pyclass]
#[derive(Default, Clone, Debug, Copy)]
pub struct KvRouterConfig {
    inner: RsKvRouterConfig,
}

#[pymethods]
impl KvRouterConfig {
    #[new]
    #[pyo3(signature = (overlap_score_weight=1.0, router_temperature=0.0, use_kv_events=true))]
    fn new(overlap_score_weight: f64, router_temperature: f64, use_kv_events: bool) -> Self {
        KvRouterConfig {
            inner: RsKvRouterConfig {
                overlap_score_weight,
                router_temperature,
                use_kv_events,
                ..Default::default()
            },
        }
    }
}

#[pyclass]
#[derive(Clone, Debug)]
pub struct RouterConfig {
    router_mode: RouterMode,
    kv_router_config: KvRouterConfig,
}

#[pymethods]
impl RouterConfig {
    #[new]
    #[pyo3(signature = (mode, config=None))]
    pub fn new(mode: RouterMode, config: Option<KvRouterConfig>) -> Self {
        Self {
            router_mode: mode,
            kv_router_config: config.unwrap_or_default(),
        }
    }
}

impl From<RouterConfig> for RsRouterConfig {
    fn from(rc: RouterConfig) -> RsRouterConfig {
        RsRouterConfig {
            router_mode: rc.router_mode.into(),
            kv_router_config: rc.kv_router_config.inner,
        }
    }
}

79
80
81
82
83
84
85
86
87
88
#[pyclass]
#[derive(Clone, Debug)]
pub(crate) struct EntrypointArgs {
    engine_type: EngineType,
    model_path: Option<PathBuf>,
    model_name: Option<String>,
    model_config: Option<PathBuf>,
    endpoint_id: Option<EndpointId>,
    context_length: Option<u32>,
    template_file: Option<PathBuf>,
89
    router_config: Option<RouterConfig>,
90
91
    kv_cache_block_size: Option<u32>,
    http_port: Option<u16>,
92
    extra_engine_args: Option<PathBuf>,
93
94
95
96
97
98
}

#[pymethods]
impl EntrypointArgs {
    #[allow(clippy::too_many_arguments)]
    #[new]
99
    #[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_port=None, extra_engine_args=None))]
100
101
102
103
104
105
106
107
    pub fn new(
        engine_type: EngineType,
        model_path: Option<PathBuf>,
        model_name: Option<String>, // e.g. "dyn://namespace.component.endpoint"
        model_config: Option<PathBuf>,
        endpoint_id: Option<String>,
        context_length: Option<u32>,
        template_file: Option<PathBuf>,
108
        router_config: Option<RouterConfig>,
109
110
        kv_cache_block_size: Option<u32>,
        http_port: Option<u16>,
111
        extra_engine_args: Option<PathBuf>,
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    ) -> PyResult<Self> {
        let endpoint_id_obj: Option<EndpointId> = match endpoint_id {
            Some(eid) => Some(eid.parse().map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
                    "Invalid endpoint_id format: {eid}"
                ))
            })?),
            None => None,
        };
        Ok(EntrypointArgs {
            engine_type,
            model_path,
            model_name,
            model_config,
            endpoint_id: endpoint_id_obj,
            context_length,
            template_file,
129
            router_config,
130
131
            kv_cache_block_size,
            http_port,
132
            extra_engine_args,
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
        })
    }
}

/// Handle to a built engine configuration.
///
/// [`make_engine`] produces it for Python, and [`run_input`] consumes it.
#[pyclass]
#[derive(Clone)]
pub(crate) struct EngineConfig {
    /// The underlying `dynamo_llm` engine configuration.
    inner: RsEngineConfig,
}

#[pyfunction]
#[pyo3(signature = (distributed_runtime, args))]
pub fn make_engine<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    args: EntrypointArgs,
) -> PyResult<Bound<'p, PyAny>> {
    let mut builder = LocalModelBuilder::default();
    builder
152
153
154
155
        .model_path(args.model_path.clone())
        .model_name(args.model_name.clone())
        .model_config(args.model_config.clone())
        .endpoint_id(args.endpoint_id.clone())
156
        .context_length(args.context_length)
157
        .request_template(args.template_file.clone())
158
        .kv_cache_block_size(args.kv_cache_block_size)
159
        .router_config(args.router_config.clone().map(|rc| rc.into()))
160
161
        .http_port(args.http_port)
        .is_mocker(matches!(args.engine_type, EngineType::Mocker));
162
163
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        let local_model = builder.build().await.map_err(to_pyerr)?;
164
        let inner = select_engine(distributed_runtime, args, local_model)
165
166
167
168
169
170
171
172
            .await
            .map_err(to_pyerr)?;
        Ok(EngineConfig { inner })
    })
}

async fn select_engine(
    #[allow(unused_variables)] distributed_runtime: super::DistributedRuntime,
173
    args: EntrypointArgs,
174
175
    local_model: LocalModel,
) -> anyhow::Result<RsEngineConfig> {
176
    let inner = match args.engine_type {
177
178
179
180
181
        EngineType::Echo => {
            // There is no validation for the echo engine
            RsEngineConfig::StaticFull {
                model: Box::new(local_model),
                engine: dynamo_llm::engines::make_engine_full(),
182
                is_static: false,
183
184
185
            }
        }
        EngineType::Dynamic => RsEngineConfig::Dynamic(Box::new(local_model)),
186
        EngineType::Static => RsEngineConfig::StaticRemote(Box::new(local_model)),
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
        EngineType::Mocker => {
            let mocker_args = if let Some(extra_args_path) = args.extra_engine_args {
                MockEngineArgs::from_json_file(&extra_args_path).map_err(|e| {
                    anyhow::anyhow!(
                        "Failed to load mocker args from {:?}: {}",
                        extra_args_path,
                        e
                    )
                })?
            } else {
                tracing::warn!(
                    "No extra_engine_args specified for mocker engine. Using default mocker args."
                );
                MockEngineArgs::default()
            };

            let endpoint = local_model.endpoint_id().clone();

            let engine = dynamo_llm::mocker::engine::make_mocker_engine(
                distributed_runtime.inner,
                endpoint,
                mocker_args,
            )
            .await?;

            RsEngineConfig::StaticCore {
                engine,
                model: Box::new(local_model),
215
                is_static: false,
216
217
            }
        }
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
    };

    Ok(inner)
}

/// Run the input mode named by `input` against a built engine.
///
/// `input` is parsed into [`Input`] before the awaitable is created, so an
/// unrecognized value raises immediately rather than when the result is awaited.
///
/// # Errors
/// Raises `Exception` if `input` does not parse; the returned awaitable raises
/// `Exception` if `dynamo_llm::entrypoint::input::run_input` fails.
#[pyfunction]
#[pyo3(signature = (distributed_runtime, input, engine_config))]
pub fn run_input<'p>(
    py: Python<'p>,
    distributed_runtime: super::DistributedRuntime,
    input: &str,
    engine_config: EngineConfig,
) -> PyResult<Bound<'p, PyAny>> {
    let input_enum: Input = input.parse().map_err(to_pyerr)?;
    pyo3_async_runtimes::tokio::future_into_py(py, async move {
        // `distributed_runtime` is owned by this future, so its inner runtime can be
        // moved out directly; no clone is needed.
        dynamo_llm::entrypoint::input::run_input(
            either::Either::Right(distributed_runtime.inner),
            input_enum,
            engine_config.inner,
        )
        .await
        .map_err(to_pyerr)?;
        Ok(())
    })
}

/// Convert any displayable error into a generic Python `Exception`.
///
/// The exception message is the error's `Display` output.
pub fn to_pyerr<E>(err: E) -> PyErr
where
    E: Display,
{
    let message = err.to_string();
    PyException::new_err(message)
}