// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::{Arc, OnceLock};

use anyhow::{Error, Result, anyhow as error};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;

use crate::{CancellationToken, DistributedRuntime, engine::*, to_pyerr};

pub use dynamo_llm::endpoint_type::EndpointType;
pub use dynamo_llm::http::service::{error as http_error, service_v2};
pub use dynamo_runtime::{
    pipeline::{AsyncEngine, Data, ManyOut, SingleIn, async_trait},
    protocols::annotated::Annotated,
};

#[pyclass]
pub struct HttpService {
    inner: service_v2::HttpService,
21
22
    // CancellationToken is already Send + Sync + Clone, no Mutex needed
    cancel_token: Arc<OnceLock<CancellationToken>>,
23
24
25
26
27
28
29
30
31
}

#[pymethods]
impl HttpService {
    #[new]
    #[pyo3(signature = (port=None))]
    pub fn new(port: Option<u16>) -> PyResult<Self> {
        let builder = service_v2::HttpService::builder().port(port.unwrap_or(8080));
        let inner = builder.build().map_err(to_pyerr)?;
32
33
34
35
        Ok(Self {
            inner,
            cancel_token: Arc::new(OnceLock::new()),
        })
36
37
    }

38
39
40
41
42
43
    pub fn add_completions_model(
        &self,
        model: String,
        checksum: String,
        engine: HttpAsyncEngine,
    ) -> PyResult<()> {
44
45
46
        let engine = Arc::new(engine);
        self.inner
            .model_manager()
47
            .add_completions_model(&model, &checksum, engine)
48
49
50
51
52
53
            .map_err(to_pyerr)
    }

    pub fn add_chat_completions_model(
        &self,
        model: String,
54
        checksum: String,
55
56
57
58
59
        engine: HttpAsyncEngine,
    ) -> PyResult<()> {
        let engine = Arc::new(engine);
        self.inner
            .model_manager()
60
            .add_chat_completions_model(&model, &checksum, engine)
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
            .map_err(to_pyerr)
    }

    pub fn remove_completions_model(&self, model: String) -> PyResult<()> {
        self.inner
            .model_manager()
            .remove_completions_model(&model)
            .map_err(to_pyerr)
    }

    pub fn remove_chat_completions_model(&self, model: String) -> PyResult<()> {
        self.inner
            .model_manager()
            .remove_chat_completions_model(&model)
            .map_err(to_pyerr)
    }

    pub fn list_chat_completions_models(&self) -> PyResult<Vec<String>> {
        Ok(self.inner.model_manager().list_chat_completions_models())
    }

    pub fn list_completions_models(&self) -> PyResult<Vec<String>> {
        Ok(self.inner.model_manager().list_completions_models())
    }

86
87
88
89
90
91
92
93
    fn run<'p>(&self, py: Python<'p>, runtime: &DistributedRuntime) -> PyResult<Bound<'p, PyAny>> {
        // Check if run() was already called to avoid creating unnecessary token
        if self.cancel_token.get().is_some() {
            return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "HttpService.run() has already been called on this instance",
            ));
        }

94
        let service = self.inner.clone();
95
96
97
98
99
100
101
102
103
104
105
106
107
108
        // Only create token if we passed the check above
        let token = runtime.inner().child_token();

        // Store the token for shutdown - should always succeed after the check above
        self.cancel_token
            .set(CancellationToken {
                inner: token.clone(),
            })
            .map_err(|_| {
                PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                    "Race condition detected in HttpService.run()",
                )
            })?;

109
        pyo3_async_runtimes::tokio::future_into_py(py, async move {
110
            service.run(token).await.map_err(to_pyerr)?;
111
112
113
            Ok(())
        })
    }
114

115
116
117
118
119
120
121
    fn shutdown(&self) {
        // CancellationToken.cancel() is thread-safe, no lock needed
        if let Some(token) = self.cancel_token.get() {
            token.inner.cancel();
        }
    }

122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
    fn enable_endpoint(&self, endpoint_type: String, enabled: bool) -> PyResult<()> {
        let endpoint_type = EndpointType::all()
            .iter()
            .find(|&&ep_type| ep_type.as_str().to_lowercase() == endpoint_type.to_lowercase())
            .copied()
            .ok_or_else(|| {
                let valid_types = EndpointType::all()
                    .iter()
                    .map(|&ep_type| ep_type.as_str().to_string())
                    .collect::<Vec<_>>()
                    .join(", ");
                to_pyerr(format!(
                    "Invalid endpoint type: '{}'. Valid types are: {}",
                    endpoint_type, valid_types
                ))
            })?;

        self.inner.enable_model_endpoint(endpoint_type, enabled);
        Ok(())
    }
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
}

/// Python-visible engine that can be registered with an `HttpService`.
///
/// A newtype over [`PythonAsyncEngine`]; its `AsyncEngine` implementation in this file
/// translates Python `HttpError`-shaped exceptions into HTTP errors.
#[pyclass]
#[derive(Clone)]
pub struct HttpAsyncEngine(pub PythonAsyncEngine);

impl From<PythonAsyncEngine> for HttpAsyncEngine {
    /// Wrap an existing [`PythonAsyncEngine`] as-is.
    fn from(inner: PythonAsyncEngine) -> Self {
        HttpAsyncEngine(inner)
    }
}

#[pymethods]
impl HttpAsyncEngine {
    /// Create a new instance of the HttpAsyncEngine
    /// This is a simple extension of the PythonAsyncEngine that handles HttpError
    /// exceptions from Python and converts them to the Rust version of HttpError
    ///
    /// # Arguments
    /// - `generator`: a Python async generator that will be used to generate responses
    /// - `event_loop`: the Python event loop that will be used to run the generator
    ///
    /// Note: In Rust land, the request and the response are both concrete; however, in
    /// Python land, the request and response are not strongly typed, meaning the generator
    /// could accept a different type of request or return a different type of response
    /// and we would not know until runtime.
    #[new]
    pub fn new(generator: PyObject, event_loop: PyObject) -> PyResult<Self> {
        Ok(PythonAsyncEngine::new(generator, event_loop)?.into())
    }
}

174
175
176
177
178
179
#[derive(FromPyObject)]
struct HttpError {
    code: u16,
    message: String,
}

180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
#[async_trait]
impl<Req, Resp> AsyncEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>, Error> for HttpAsyncEngine
where
    Req: Data + Serialize,
    Resp: Data + for<'de> Deserialize<'de>,
{
    async fn generate(&self, request: SingleIn<Req>) -> Result<ManyOut<Annotated<Resp>>, Error> {
        match self.0.generate(request).await {
            Ok(res) => Ok(res),

            // Inspect the error - if it was an HttpError from Python, extract the code and message
            // and return the rust version of HttpError
            Err(e) => {
                if let Some(py_err) = e.downcast_ref::<PyErr>() {
                    Python::with_gil(|py| {
195
196
197
198
                        // With the Stable ABI, we can't subclass Python's built-in exceptions in PyO3, so instead we
                        // implement the exception in Python and assume that it's an HttpError if the code and message
                        // are present.
                        if let Ok(HttpError { code, message }) = py_err.value(py).extract() {
199
200
201
                            // SSE panics if there are carriage returns or newlines
                            let message = message.replace(['\r', '\n'], "");
                            return Err(http_error::HttpError { code, message })?;
202
                        }
203
                        Err(error!("Python Error: {}", py_err))
204
205
206
207
208
209
210
211
                    })
                } else {
                    Err(e)
                }
            }
        }
    }
}