http.rs 6.32 KB
Newer Older
1
2
3
4
5
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;

6
use anyhow::{Error, Result, anyhow as error};
7
use pyo3::prelude::*;
8

9
use crate::{CancellationToken, engine::*, to_pyerr};
10

11
pub use dynamo_llm::endpoint_type::EndpointType;
12
13
pub use dynamo_llm::http::service::{error as http_error, service_v2};
pub use dynamo_runtime::{
14
    pipeline::{AsyncEngine, Data, ManyOut, SingleIn, async_trait},
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
    protocols::annotated::Annotated,
};

/// Python-visible wrapper around `service_v2::HttpService` (re-exported from `dynamo_llm`).
///
/// Model registration/removal, endpoint toggling and running the server are exposed
/// to Python through the `#[pymethods]` defined for this type.
#[pyclass]
pub struct HttpService {
    /// The underlying Rust HTTP service every Python-facing method delegates to.
    inner: service_v2::HttpService,
}

#[pymethods]
impl HttpService {
    #[new]
    #[pyo3(signature = (port=None))]
    pub fn new(port: Option<u16>) -> PyResult<Self> {
        let builder = service_v2::HttpService::builder().port(port.unwrap_or(8080));
        let inner = builder.build().map_err(to_pyerr)?;
        Ok(Self { inner })
    }

33
34
35
36
37
38
    pub fn add_completions_model(
        &self,
        model: String,
        checksum: String,
        engine: HttpAsyncEngine,
    ) -> PyResult<()> {
39
40
41
        let engine = Arc::new(engine);
        self.inner
            .model_manager()
42
            .add_completions_model(&model, &checksum, engine)
43
44
45
46
47
48
            .map_err(to_pyerr)
    }

    pub fn add_chat_completions_model(
        &self,
        model: String,
49
        checksum: String,
50
51
52
53
54
        engine: HttpAsyncEngine,
    ) -> PyResult<()> {
        let engine = Arc::new(engine);
        self.inner
            .model_manager()
55
            .add_chat_completions_model(&model, &checksum, engine)
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
            .map_err(to_pyerr)
    }

    pub fn remove_completions_model(&self, model: String) -> PyResult<()> {
        self.inner
            .model_manager()
            .remove_completions_model(&model)
            .map_err(to_pyerr)
    }

    pub fn remove_chat_completions_model(&self, model: String) -> PyResult<()> {
        self.inner
            .model_manager()
            .remove_chat_completions_model(&model)
            .map_err(to_pyerr)
    }

    pub fn list_chat_completions_models(&self) -> PyResult<Vec<String>> {
        Ok(self.inner.model_manager().list_chat_completions_models())
    }

    pub fn list_completions_models(&self) -> PyResult<Vec<String>> {
        Ok(self.inner.model_manager().list_completions_models())
    }

    fn run<'p>(&self, py: Python<'p>, token: CancellationToken) -> PyResult<Bound<'p, PyAny>> {
        let service = self.inner.clone();
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
            service.run(token.inner).await.map_err(to_pyerr)?;
            Ok(())
        })
    }
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

    fn enable_endpoint(&self, endpoint_type: String, enabled: bool) -> PyResult<()> {
        let endpoint_type = EndpointType::all()
            .iter()
            .find(|&&ep_type| ep_type.as_str().to_lowercase() == endpoint_type.to_lowercase())
            .copied()
            .ok_or_else(|| {
                let valid_types = EndpointType::all()
                    .iter()
                    .map(|&ep_type| ep_type.as_str().to_string())
                    .collect::<Vec<_>>()
                    .join(", ");
                to_pyerr(format!(
                    "Invalid endpoint type: '{}'. Valid types are: {}",
                    endpoint_type, valid_types
                ))
            })?;

        self.inner.enable_model_endpoint(endpoint_type, enabled);
        Ok(())
    }
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
}

/// Python-visible engine that forwards requests to the wrapped [`PythonAsyncEngine`].
///
/// Its `AsyncEngine` implementation additionally converts Python `HttpError`
/// exceptions raised by the generator into the Rust `HttpError` type.
#[pyclass]
#[derive(Clone)]
pub struct HttpAsyncEngine(pub PythonAsyncEngine);

impl From<PythonAsyncEngine> for HttpAsyncEngine {
    fn from(engine: PythonAsyncEngine) -> Self {
        Self(engine)
    }
}

#[pymethods]
impl HttpAsyncEngine {
    /// Build an `HttpAsyncEngine` around a Python async generator.
    ///
    /// This is a thin layer over [`PythonAsyncEngine`]. The only difference is error
    /// handling: Python `HttpError` exceptions raised by the generator are turned into
    /// the Rust `HttpError` type.
    ///
    /// # Arguments
    /// - `generator`: the Python async generator that produces responses
    /// - `event_loop`: the Python event loop on which the generator is driven
    ///
    /// Note: on the Rust side the request and response types are concrete, but Python
    /// does not enforce them. A generator that takes or yields a different shape is
    /// only discovered at runtime.
    #[new]
    pub fn new(generator: PyObject, event_loop: PyObject) -> PyResult<Self> {
        let engine = PythonAsyncEngine::new(generator, event_loop)?;
        Ok(Self::from(engine))
    }
}

#[async_trait]
impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> for HttpAsyncEngine
where
    Req: Data + Serialize,
    Resp: Data + for<'de> Deserialize<'de>,
{
    async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> {
        match self.0.generate(request).await {
            Ok(res) => Ok(res),

            // Inspect the error - if it was an HttpError from Python, extract the code and message
            // and return the rust version of HttpError
            Err(e) => {
                if let Some(py_err) = e.downcast_ref::<PyErr>() {
                    Python::with_gil(|py| {
156
157
158
159
160
161
162
163
164
165
166
                        let err_val = py_err.clone_ref(py).into_value(py);
                        let bound_err = err_val.bind(py);

                        // check: Py03 exceptions cannot be cross-compiled, so we duck-type by name
                        // and fields.
                        if let Ok(type_name) = bound_err.get_type().name()
                            && type_name.to_string().contains("HttpError")
                            && let (Ok(code), Ok(message)) =
                                (bound_err.getattr("code"), bound_err.getattr("message"))
                            && let (Ok(code), Ok(message)) =
                                (code.extract::<u16>(), message.extract::<String>())
167
                        {
168
169
170
                            // SSE panics if there are carriage returns or newlines
                            let message = message.replace(['\r', '\n'], "");
                            return Err(http_error::HttpError { code, message })?;
171
                        }
172
                        Err(error!("Python Error: {}", py_err.to_string()))
173
174
175
176
177
178
179
180
                    })
                } else {
                    Err(e)
                }
            }
        }
    }
}