service_v2.rs 12.3 KB
Newer Older
1
2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3

4
use std::collections::HashMap;
5
use std::env::var;
6
7
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
8
9
use std::sync::Arc;
use std::time::Duration;
10
11

use super::metrics;
12
use super::Metrics;
13
use super::RouteDoc;
14
use crate::discovery::ModelManager;
15
use crate::endpoint_type::EndpointType;
16
use crate::request_template::RequestTemplate;
17
use anyhow::Result;
18
use derive_builder::Builder;
19
use dynamo_runtime::logging::make_request_span;
20
use dynamo_runtime::transports::etcd;
21
use tokio::task::JoinHandle;
22
use tokio_util::sync::CancellationToken;
23
use tower_http::trace::TraceLayer;
24

25
/// HTTP service shared state
26
#[derive(Default)]
27
28
29
pub struct State {
    metrics: Arc<Metrics>,
    manager: Arc<ModelManager>,
30
    etcd_client: Option<etcd::Client>,
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
    flags: StateFlags,
}

/// Runtime on/off switches for the model-serving endpoint groups.
///
/// Each flag is an [`AtomicBool`], so it can be flipped through a shared reference while
/// requests are being served. The derived `Default` leaves every flag `false`.
#[derive(Default, Debug)]
struct StateFlags {
    /// Switch consulted for `EndpointType::Chat`
    chat_endpoints_enabled: AtomicBool,
    /// Switch consulted for `EndpointType::Completion`
    cmpl_endpoints_enabled: AtomicBool,
    /// Switch consulted for `EndpointType::Embedding`
    embeddings_endpoints_enabled: AtomicBool,
    /// Switch consulted for `EndpointType::Responses`
    responses_endpoints_enabled: AtomicBool,
}

impl StateFlags {
    /// Returns the current value of the switch that belongs to `endpoint_type`.
    ///
    /// The value is read with `Ordering::Relaxed`.
    pub fn get(&self, endpoint_type: &EndpointType) -> bool {
        // Select the flag first so the atomic access is written only once.
        let flag = match endpoint_type {
            EndpointType::Chat => &self.chat_endpoints_enabled,
            EndpointType::Completion => &self.cmpl_endpoints_enabled,
            EndpointType::Embedding => &self.embeddings_endpoints_enabled,
            EndpointType::Responses => &self.responses_endpoints_enabled,
        };
        flag.load(Ordering::Relaxed)
    }

    /// Stores `enabled` into the switch that belongs to `endpoint_type`.
    ///
    /// The value is written with `Ordering::Relaxed`.
    pub fn set(&self, endpoint_type: &EndpointType, enabled: bool) {
        // Select the flag first so the atomic access is written only once.
        let flag = match endpoint_type {
            EndpointType::Chat => &self.chat_endpoints_enabled,
            EndpointType::Completion => &self.cmpl_endpoints_enabled,
            EndpointType::Embedding => &self.embeddings_endpoints_enabled,
            EndpointType::Responses => &self.responses_endpoints_enabled,
        };
        flag.store(enabled, Ordering::Relaxed);
    }
68
69
70
71
72
73
74
}

impl State {
    pub fn new(manager: Arc<ModelManager>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
75
            etcd_client: None,
76
77
78
79
80
81
            flags: StateFlags {
                chat_endpoints_enabled: AtomicBool::new(false),
                cmpl_endpoints_enabled: AtomicBool::new(false),
                embeddings_endpoints_enabled: AtomicBool::new(false),
                responses_endpoints_enabled: AtomicBool::new(false),
            },
82
83
84
85
86
87
88
89
        }
    }

    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
            etcd_client,
90
91
92
93
94
95
            flags: StateFlags {
                chat_endpoints_enabled: AtomicBool::new(false),
                cmpl_endpoints_enabled: AtomicBool::new(false),
                embeddings_endpoints_enabled: AtomicBool::new(false),
                responses_endpoints_enabled: AtomicBool::new(false),
            },
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
        }
    }
    /// Returns a new `Arc` handle to the Prometheus [`Metrics`] object, which tracks request
    /// counts and inflight requests.
    pub fn metrics_clone(&self) -> Arc<Metrics> {
        Arc::clone(&self.metrics)
    }

    pub fn manager(&self) -> &ModelManager {
        Arc::as_ref(&self.manager)
    }

    /// Returns a new `Arc` handle to the [`ModelManager`] held by this state.
    pub fn manager_clone(&self) -> Arc<ModelManager> {
        Arc::clone(&self.manager)
    }

111
112
113
114
    /// Borrows the etcd client, or returns `None` when the state was created without one.
    pub fn etcd_client(&self) -> Option<&etcd::Client> {
        Option::as_ref(&self.etcd_client)
    }

115
116
117
118
119
120
    /// Placeholder accessor that currently always returns `None`.
    ///
    /// NOTE(review): judging by the name this is meant to be the keep-alive interval for
    /// server-sent-event responses — confirm against the callers, which are outside this file.
    // TODO: return a real, configurable value instead of the hard-coded `None`.
    pub fn sse_keep_alive(&self) -> Option<Duration> {
        None
    }
}

121
122
#[derive(Clone)]
pub struct HttpService {
123
124
125
    // The state we share with every request handler
    state: Arc<State>,

126
127
    router: axum::Router,
    port: u16,
128
    host: String,
129
    route_docs: Vec<RouteDoc>,
130
131
132
}

#[derive(Clone, Builder)]
133
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
134
135
136
137
pub struct HttpServiceConfig {
    #[builder(default = "8787")]
    port: u16,

138
139
140
    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
    host: String,

141
142
    // #[builder(default)]
    // custom: Vec<axum::Router>
143
    #[builder(default = "false")]
144
145
    enable_chat_endpoints: bool,

146
    #[builder(default = "false")]
147
    enable_cmpl_endpoints: bool,
148

149
    #[builder(default = "true")]
150
151
    enable_embeddings_endpoints: bool,

152
153
154
    #[builder(default = "true")]
    enable_responses_endpoints: bool,

155
156
    #[builder(default = "None")]
    request_template: Option<RequestTemplate>,
157
158
159

    #[builder(default = "None")]
    etcd_client: Option<etcd::Client>,
160
161
162
163
164
165
166
}

impl HttpService {
    pub fn builder() -> HttpServiceConfigBuilder {
        HttpServiceConfigBuilder::default()
    }

167
168
169
170
171
172
173
174
    /// Returns a new `Arc` handle to the state shared with the request handlers.
    pub fn state_clone(&self) -> Arc<State> {
        Arc::clone(&self.state)
    }

    pub fn state(&self) -> &State {
        Arc::as_ref(&self.state)
    }

175
    /// Borrows the [`ModelManager`] owned by the shared state.
    pub fn model_manager(&self) -> &ModelManager {
        self.state().manager()
    }

179
180
181
182
183
184
    /// Runs the service on a background tokio task.
    ///
    /// A clone of this service is moved into the task, which drives `run` with `cancel_token`.
    /// The returned handle resolves to whatever `run` returns.
    pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
        let service = self.clone();
        let serve = async move { service.run(cancel_token).await };
        tokio::spawn(serve)
    }

    pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
185
        let address = format!("{}:{}", self.host, self.port);
186
187
188
189
190
191
192
193
194
        tracing::info!(address, "Starting HTTP service on: {address}");

        let listener = tokio::net::TcpListener::bind(address.as_str())
            .await
            .unwrap_or_else(|_| panic!("could not bind to address: {address}"));

        let router = self.router.clone();
        let observer = cancel_token.child_token();

195
        axum::serve(listener, router)
196
197
            .with_graceful_shutdown(observer.cancelled_owned())
            .await
198
199
200
            .inspect_err(|_| cancel_token.cancel())?;

        Ok(())
201
    }
202
203
204
205
206

    /// Documentation entries for the HTTP endpoints this service exposes
    pub fn route_docs(&self) -> &[RouteDoc] {
        self.route_docs.as_slice()
    }
207
208
209
210
211
212
213
214
215

    /// Flips the runtime switch for `endpoint_type` to `enable` and logs the new status at
    /// info level.
    pub async fn enable_model_endpoint(&self, endpoint_type: EndpointType, enable: bool) {
        let status = if enable { "enabled" } else { "disabled" };
        self.state.flags.set(&endpoint_type, enable);
        tracing::info!("{} endpoints {}", endpoint_type.as_str(), status);
    }
216
217
}

218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
// Names of the environment variables that override route paths. Each one is looked up with
// `std::env::var(..).ok()` while the service is being built, so an unset or unreadable variable
// is passed to the matching router constructor as `None`.

/// Environment variable to set the metrics endpoint path (default: `/metrics`)
static HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH";
/// Environment variable to set the models endpoint path (default: `/v1/models`)
static HTTP_SVC_MODELS_PATH_ENV: &str = "DYN_HTTP_SVC_MODELS_PATH";
/// Environment variable to set the health endpoint path (default: `/health`)
static HTTP_SVC_HEALTH_PATH_ENV: &str = "DYN_HTTP_SVC_HEALTH_PATH";
/// Environment variable to set the live endpoint path (default: `/live`)
static HTTP_SVC_LIVE_PATH_ENV: &str = "DYN_HTTP_SVC_LIVE_PATH";
/// Environment variable to set the chat completions endpoint path (default: `/v1/chat/completions`)
static HTTP_SVC_CHAT_PATH_ENV: &str = "DYN_HTTP_SVC_CHAT_PATH";
/// Environment variable to set the completions endpoint path (default: `/v1/completions`)
static HTTP_SVC_CMP_PATH_ENV: &str = "DYN_HTTP_SVC_CMP_PATH";
/// Environment variable to set the embeddings endpoint path (default: `/v1/embeddings`)
static HTTP_SVC_EMB_PATH_ENV: &str = "DYN_HTTP_SVC_EMB_PATH";
/// Environment variable to set the responses endpoint path (default: `/v1/responses`)
static HTTP_SVC_RESPONSES_PATH_ENV: &str = "DYN_HTTP_SVC_RESPONSES_PATH";

235
236
impl HttpServiceConfigBuilder {
    pub fn build(self) -> Result<HttpService, anyhow::Error> {
237
        let config: HttpServiceConfig = self.build_internal()?;
238

239
        let model_manager = Arc::new(ModelManager::new());
240
        let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
241

242
243
244
245
246
247
248
249
250
251
252
253
254
        state
            .flags
            .set(&EndpointType::Chat, config.enable_chat_endpoints);
        state
            .flags
            .set(&EndpointType::Completion, config.enable_cmpl_endpoints);
        state
            .flags
            .set(&EndpointType::Embedding, config.enable_embeddings_endpoints);
        state
            .flags
            .set(&EndpointType::Responses, config.enable_responses_endpoints);

255
256
        // enable prometheus metrics
        let registry = metrics::Registry::new();
257
        state.metrics_clone().register(&registry)?;
258
259

        let mut router = axum::Router::new();
260

261
262
263
        let mut all_docs = Vec::new();

        let mut routes = vec![
264
265
266
267
            metrics::router(registry, var(HTTP_SVC_METRICS_PATH_ENV).ok()),
            super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
            super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
            super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
268
269
        ];

270
271
272
273
        let endpoint_routes =
            HttpServiceConfigBuilder::get_endpoints_router(state.clone(), &config.request_template);
        routes.extend(endpoint_routes);
        for (route_docs, route) in routes {
274
275
276
277
            router = router.merge(route);
            all_docs.extend(route_docs);
        }

278
279
280
        // Add span for tracing
        router = router.layer(TraceLayer::new_for_http().make_span_with(make_request_span));

281
        Ok(HttpService {
282
            state,
283
284
            router,
            port: config.port,
285
            host: config.host,
286
            route_docs: all_docs,
287
288
        })
    }
289
290
291
292
293

    pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
        self.request_template = Some(request_template);
        self
    }
294
295
296
297
298

    pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
        self.etcd_client = Some(etcd_client);
        self
    }
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352

    /// Builds the model-serving routers (chat, completions, embeddings, responses) and gates
    /// each one behind its runtime switch.
    ///
    /// Every router gets a route layer that checks `state.flags` on each request: when the
    /// switch for that endpoint type is on, the request is forwarded; otherwise it is rejected
    /// with `503 Service Unavailable`. The routes are therefore always registered and can be
    /// switched on and off while the service is running.
    ///
    /// Routers are returned in the order given by `EndpointType::all()`, each paired with its
    /// route docs. An endpoint type for which no router was built here is skipped.
    fn get_endpoints_router(
        state: Arc<State>,
        request_template: &Option<RequestTemplate>,
    ) -> Vec<(Vec<RouteDoc>, axum::Router)> {
        // Construct every router up front, keyed by the endpoint type it serves
        let mut endpoint_routes = HashMap::from([
            (
                EndpointType::Chat,
                super::openai::chat_completions_router(
                    state.clone(),
                    request_template.clone(),
                    var(HTTP_SVC_CHAT_PATH_ENV).ok(),
                ),
            ),
            (
                EndpointType::Completion,
                super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok()),
            ),
            (
                EndpointType::Embedding,
                super::openai::embeddings_router(state.clone(), var(HTTP_SVC_EMB_PATH_ENV).ok()),
            ),
            (
                EndpointType::Responses,
                super::openai::responses_router(
                    state.clone(),
                    request_template.clone(),
                    var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
                ),
            ),
        ]);

        let mut routes = Vec::with_capacity(endpoint_routes.len());
        for endpoint_type in EndpointType::all() {
            // Take ownership of the entry: one lookup, no router clone, no `unwrap`
            let Some((docs, route)) = endpoint_routes.remove(&endpoint_type) else {
                tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
                continue;
            };

            let state_route = Arc::clone(&state);
            let route = route.route_layer(axum::middleware::from_fn(
                move |req: axum::http::Request<axum::body::Body>, next: axum::middleware::Next| {
                    let state: Arc<State> = Arc::clone(&state_route);
                    async move {
                        // Check if the endpoint is enabled
                        if state.flags.get(&endpoint_type) {
                            Ok(next.run(req).await)
                        } else {
                            tracing::debug!("{} endpoints are disabled", endpoint_type.as_str());
                            Err(axum::http::StatusCode::SERVICE_UNAVAILABLE)
                        }
                    }
                },
            ));
            routes.push((docs, route));
        }
        routes
    }
353
}