service_v2.rs 7.97 KB
Newer Older
1
2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3

4
use std::env::var;
5
6
use std::sync::Arc;
use std::time::Duration;
7
8

use super::metrics;
9
use super::Metrics;
10
use super::RouteDoc;
11
use crate::discovery::ModelManager;
12
use crate::request_template::RequestTemplate;
13
use anyhow::Result;
14
use derive_builder::Builder;
15
use dynamo_runtime::logging::make_request_span;
16
use dynamo_runtime::transports::etcd;
17
use tokio::task::JoinHandle;
18
use tokio_util::sync::CancellationToken;
19
use tower_http::trace::TraceLayer;
20

21
22
23
24
/// HTTP service shared state
pub struct State {
    metrics: Arc<Metrics>,
    manager: Arc<ModelManager>,
25
    etcd_client: Option<etcd::Client>,
26
27
28
29
30
31
32
}

impl State {
    pub fn new(manager: Arc<ModelManager>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
33
34
35
36
37
38
39
40
41
            etcd_client: None,
        }
    }

    pub fn new_with_etcd(manager: Arc<ModelManager>, etcd_client: Option<etcd::Client>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
            etcd_client,
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
        }
    }

    /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
    pub fn metrics_clone(&self) -> Arc<Metrics> {
        self.metrics.clone()
    }

    pub fn manager(&self) -> &ModelManager {
        Arc::as_ref(&self.manager)
    }

    pub fn manager_clone(&self) -> Arc<ModelManager> {
        self.manager.clone()
    }

58
59
60
61
    pub fn etcd_client(&self) -> Option<&etcd::Client> {
        self.etcd_client.as_ref()
    }

62
63
64
65
66
67
    // TODO
    pub fn sse_keep_alive(&self) -> Option<Duration> {
        None
    }
}

68
69
#[derive(Clone)]
pub struct HttpService {
70
71
72
    // The state we share with every request handler
    state: Arc<State>,

73
74
    router: axum::Router,
    port: u16,
75
    host: String,
76
    route_docs: Vec<RouteDoc>,
77
78
79
}

#[derive(Clone, Builder)]
80
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
81
82
83
84
pub struct HttpServiceConfig {
    #[builder(default = "8787")]
    port: u16,

85
86
87
    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
    host: String,

88
89
90
91
92
93
94
    // #[builder(default)]
    // custom: Vec<axum::Router>
    #[builder(default = "true")]
    enable_chat_endpoints: bool,

    #[builder(default = "true")]
    enable_cmpl_endpoints: bool,
95

96
    #[builder(default = "true")]
97
98
    enable_embeddings_endpoints: bool,

99
100
101
    #[builder(default = "true")]
    enable_responses_endpoints: bool,

102
103
    #[builder(default = "None")]
    request_template: Option<RequestTemplate>,
104
105
106

    #[builder(default = "None")]
    etcd_client: Option<etcd::Client>,
107
108
109
110
111
112
113
}

impl HttpService {
    pub fn builder() -> HttpServiceConfigBuilder {
        HttpServiceConfigBuilder::default()
    }

114
115
116
117
118
119
120
121
    pub fn state_clone(&self) -> Arc<State> {
        self.state.clone()
    }

    pub fn state(&self) -> &State {
        Arc::as_ref(&self.state)
    }

122
    pub fn model_manager(&self) -> &ModelManager {
123
        self.state().manager()
124
125
    }

126
127
128
129
130
131
    pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
        let this = self.clone();
        tokio::spawn(async move { this.run(cancel_token).await })
    }

    pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
132
        let address = format!("{}:{}", self.host, self.port);
133
134
135
136
137
138
139
140
141
        tracing::info!(address, "Starting HTTP service on: {address}");

        let listener = tokio::net::TcpListener::bind(address.as_str())
            .await
            .unwrap_or_else(|_| panic!("could not bind to address: {address}"));

        let router = self.router.clone();
        let observer = cancel_token.child_token();

142
        axum::serve(listener, router)
143
144
            .with_graceful_shutdown(observer.cancelled_owned())
            .await
145
146
147
            .inspect_err(|_| cancel_token.cancel())?;

        Ok(())
148
    }
149
150
151
152
153

    /// Documentation of exposed HTTP endpoints
    pub fn route_docs(&self) -> &[RouteDoc] {
        &self.route_docs
    }
154
155
}

/// Environment variable to set the metrics endpoint path (default: `/metrics`)
const HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH";
/// Environment variable to set the models endpoint path (default: `/v1/models`)
const HTTP_SVC_MODELS_PATH_ENV: &str = "DYN_HTTP_SVC_MODELS_PATH";
/// Environment variable to set the health endpoint path (default: `/health`)
const HTTP_SVC_HEALTH_PATH_ENV: &str = "DYN_HTTP_SVC_HEALTH_PATH";
/// Environment variable to set the live endpoint path (default: `/live`)
const HTTP_SVC_LIVE_PATH_ENV: &str = "DYN_HTTP_SVC_LIVE_PATH";
/// Environment variable to set the chat completions endpoint path (default: `/v1/chat/completions`)
const HTTP_SVC_CHAT_PATH_ENV: &str = "DYN_HTTP_SVC_CHAT_PATH";
/// Environment variable to set the completions endpoint path (default: `/v1/completions`)
const HTTP_SVC_CMP_PATH_ENV: &str = "DYN_HTTP_SVC_CMP_PATH";
/// Environment variable to set the embeddings endpoint path (default: `/v1/embeddings`)
const HTTP_SVC_EMB_PATH_ENV: &str = "DYN_HTTP_SVC_EMB_PATH";
/// Environment variable to set the responses endpoint path (default: `/v1/responses`)
const HTTP_SVC_RESPONSES_PATH_ENV: &str = "DYN_HTTP_SVC_RESPONSES_PATH";

173
174
impl HttpServiceConfigBuilder {
    pub fn build(self) -> Result<HttpService, anyhow::Error> {
175
        let config: HttpServiceConfig = self.build_internal()?;
176

177
        let model_manager = Arc::new(ModelManager::new());
178
        let state = Arc::new(State::new_with_etcd(model_manager, config.etcd_client));
179
180
181

        // enable prometheus metrics
        let registry = metrics::Registry::new();
182
        state.metrics_clone().register(&registry)?;
183
184

        let mut router = axum::Router::new();
185

186
187
188
        let mut all_docs = Vec::new();

        let mut routes = vec![
189
190
191
192
            metrics::router(registry, var(HTTP_SVC_METRICS_PATH_ENV).ok()),
            super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
            super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
            super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
193
194
195
        ];

        if config.enable_chat_endpoints {
196
            routes.push(super::openai::chat_completions_router(
197
                state.clone(),
198
                config.request_template.clone(), // TODO clone()? reference?
199
                var(HTTP_SVC_CHAT_PATH_ENV).ok(),
200
201
202
203
            ));
        }

        if config.enable_cmpl_endpoints {
204
205
206
207
            routes.push(super::openai::completions_router(
                state.clone(),
                var(HTTP_SVC_CMP_PATH_ENV).ok(),
            ));
208
209
        }

210
        if config.enable_embeddings_endpoints {
211
212
213
214
            routes.push(super::openai::embeddings_router(
                state.clone(),
                var(HTTP_SVC_EMB_PATH_ENV).ok(),
            ));
215
216
        }

217
218
219
220
        if config.enable_responses_endpoints {
            routes.push(super::openai::responses_router(
                state.clone(),
                config.request_template,
221
                var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
222
223
224
            ));
        }

225
226
227
228
229
230
231
232
233
234
        // for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
        //     router = router.merge(route);
        //     all_docs.extend(route_docs);
        // }

        for (route_docs, route) in routes.into_iter() {
            router = router.merge(route);
            all_docs.extend(route_docs);
        }

235
236
237
        // Add span for tracing
        router = router.layer(TraceLayer::new_for_http().make_span_with(make_request_span));

238
        Ok(HttpService {
239
            state,
240
241
            router,
            port: config.port,
242
            host: config.host,
243
            route_docs: all_docs,
244
245
        })
    }
246
247
248
249
250

    pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
        self.request_template = Some(request_template);
        self
    }
251
252
253
254
255

    pub fn with_etcd_client(mut self, etcd_client: Option<etcd::Client>) -> Self {
        self.etcd_client = Some(etcd_client);
        self
    }
256
}