service_v2.rs 7.09 KB
Newer Older
1
2
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3

4
use std::env::var;
5
6
use std::sync::Arc;
use std::time::Duration;
7
8

use super::metrics;
9
use super::Metrics;
10
use super::RouteDoc;
11
use crate::discovery::ModelManager;
12
use crate::request_template::RequestTemplate;
13
use anyhow::Result;
14
use derive_builder::Builder;
15
use tokio::task::JoinHandle;
16
17
use tokio_util::sync::CancellationToken;

18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/// State shared by the HTTP service and its request handlers.
///
/// Bundles the [`ModelManager`] with the service-wide [`Metrics`] instance; both are held
/// behind an [`Arc`] so handles can be handed out cheaply.
pub struct State {
    metrics: Arc<Metrics>,
    manager: Arc<ModelManager>,
}

impl State {
    /// Creates the shared state around `manager`, paired with a default-constructed
    /// [`Metrics`] instance.
    pub fn new(manager: Arc<ModelManager>) -> Self {
        let metrics = Arc::new(Metrics::default());
        Self { metrics, manager }
    }

    /// Returns a new shared handle to the Prometheus [`Metrics`] object used by the service.
    pub fn metrics_clone(&self) -> Arc<Metrics> {
        Arc::clone(&self.metrics)
    }

    /// Borrows the [`ModelManager`] without touching the reference count.
    pub fn manager(&self) -> &ModelManager {
        &self.manager
    }

    /// Returns a new shared handle to the [`ModelManager`].
    pub fn manager_clone(&self) -> Arc<ModelManager> {
        Arc::clone(&self.manager)
    }

    /// Keep-alive interval for server-sent-event streams.
    ///
    /// TODO: not configurable yet; this currently always returns `None`.
    pub fn sse_keep_alive(&self) -> Option<Duration> {
        None
    }
}

51
52
#[derive(Clone)]
pub struct HttpService {
53
54
55
    // The state we share with every request handler
    state: Arc<State>,

56
57
    router: axum::Router,
    port: u16,
58
    host: String,
59
    route_docs: Vec<RouteDoc>,
60
61
62
}

#[derive(Clone, Builder)]
63
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
64
65
66
67
pub struct HttpServiceConfig {
    #[builder(default = "8787")]
    port: u16,

68
69
70
    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
    host: String,

71
72
73
74
75
76
77
    // #[builder(default)]
    // custom: Vec<axum::Router>
    #[builder(default = "true")]
    enable_chat_endpoints: bool,

    #[builder(default = "true")]
    enable_cmpl_endpoints: bool,
78

79
    #[builder(default = "true")]
80
81
    enable_embeddings_endpoints: bool,

82
83
84
    #[builder(default = "true")]
    enable_responses_endpoints: bool,

85
86
    #[builder(default = "None")]
    request_template: Option<RequestTemplate>,
87
88
89
90
91
92
93
}

impl HttpService {
    pub fn builder() -> HttpServiceConfigBuilder {
        HttpServiceConfigBuilder::default()
    }

94
95
96
97
98
99
100
101
    pub fn state_clone(&self) -> Arc<State> {
        self.state.clone()
    }

    pub fn state(&self) -> &State {
        Arc::as_ref(&self.state)
    }

102
    pub fn model_manager(&self) -> &ModelManager {
103
        self.state().manager()
104
105
    }

106
107
108
109
110
111
    pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
        let this = self.clone();
        tokio::spawn(async move { this.run(cancel_token).await })
    }

    pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
112
        let address = format!("{}:{}", self.host, self.port);
113
114
115
116
117
118
119
120
121
        tracing::info!(address, "Starting HTTP service on: {address}");

        let listener = tokio::net::TcpListener::bind(address.as_str())
            .await
            .unwrap_or_else(|_| panic!("could not bind to address: {address}"));

        let router = self.router.clone();
        let observer = cancel_token.child_token();

122
        axum::serve(listener, router)
123
124
            .with_graceful_shutdown(observer.cancelled_owned())
            .await
125
126
127
            .inspect_err(|_| cancel_token.cancel())?;

        Ok(())
128
    }
129
130
131
132
133

    /// Documentation of exposed HTTP endpoints
    pub fn route_docs(&self) -> &[RouteDoc] {
        &self.route_docs
    }
134
135
}

136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// Names of the environment variables that override the HTTP route paths.
//
// `HttpServiceConfigBuilder::build` reads each one with `std::env::var(..).ok()` and hands the
// resulting `Option<String>` to the matching router constructor, so a variable that is unset
// (or not valid Unicode) is passed along as `None`.

/// Environment variable to set the metrics endpoint path (default: `/metrics`)
static HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH";
/// Environment variable to set the models endpoint path (default: `/v1/models`)
static HTTP_SVC_MODELS_PATH_ENV: &str = "DYN_HTTP_SVC_MODELS_PATH";
/// Environment variable to set the health endpoint path (default: `/health`)
static HTTP_SVC_HEALTH_PATH_ENV: &str = "DYN_HTTP_SVC_HEALTH_PATH";
/// Environment variable to set the live endpoint path (default: `/live`)
static HTTP_SVC_LIVE_PATH_ENV: &str = "DYN_HTTP_SVC_LIVE_PATH";
/// Environment variable to set the chat completions endpoint path (default: `/v1/chat/completions`)
static HTTP_SVC_CHAT_PATH_ENV: &str = "DYN_HTTP_SVC_CHAT_PATH";
/// Environment variable to set the completions endpoint path (default: `/v1/completions`)
static HTTP_SVC_CMP_PATH_ENV: &str = "DYN_HTTP_SVC_CMP_PATH";
/// Environment variable to set the embeddings endpoint path (default: `/v1/embeddings`)
static HTTP_SVC_EMB_PATH_ENV: &str = "DYN_HTTP_SVC_EMB_PATH";
/// Environment variable to set the responses endpoint path (default: `/v1/responses`)
static HTTP_SVC_RESPONSES_PATH_ENV: &str = "DYN_HTTP_SVC_RESPONSES_PATH";

153
154
impl HttpServiceConfigBuilder {
    pub fn build(self) -> Result<HttpService, anyhow::Error> {
155
        let config: HttpServiceConfig = self.build_internal()?;
156

157
        let model_manager = Arc::new(ModelManager::new());
158
        let state = Arc::new(State::new(model_manager));
159
160
161

        // enable prometheus metrics
        let registry = metrics::Registry::new();
162
        state.metrics_clone().register(&registry)?;
163
164

        let mut router = axum::Router::new();
165

166
167
168
        let mut all_docs = Vec::new();

        let mut routes = vec![
169
170
171
172
            metrics::router(registry, var(HTTP_SVC_METRICS_PATH_ENV).ok()),
            super::openai::list_models_router(state.clone(), var(HTTP_SVC_MODELS_PATH_ENV).ok()),
            super::health::health_check_router(state.clone(), var(HTTP_SVC_HEALTH_PATH_ENV).ok()),
            super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()),
173
174
175
        ];

        if config.enable_chat_endpoints {
176
            routes.push(super::openai::chat_completions_router(
177
                state.clone(),
178
                config.request_template.clone(), // TODO clone()? reference?
179
                var(HTTP_SVC_CHAT_PATH_ENV).ok(),
180
181
182
183
            ));
        }

        if config.enable_cmpl_endpoints {
184
185
186
187
            routes.push(super::openai::completions_router(
                state.clone(),
                var(HTTP_SVC_CMP_PATH_ENV).ok(),
            ));
188
189
        }

190
        if config.enable_embeddings_endpoints {
191
192
193
194
            routes.push(super::openai::embeddings_router(
                state.clone(),
                var(HTTP_SVC_EMB_PATH_ENV).ok(),
            ));
195
196
        }

197
198
199
200
        if config.enable_responses_endpoints {
            routes.push(super::openai::responses_router(
                state.clone(),
                config.request_template,
201
                var(HTTP_SVC_RESPONSES_PATH_ENV).ok(),
202
203
204
            ));
        }

205
206
207
208
209
210
211
212
213
214
215
        // for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
        //     router = router.merge(route);
        //     all_docs.extend(route_docs);
        // }

        for (route_docs, route) in routes.into_iter() {
            router = router.merge(route);
            all_docs.extend(route_docs);
        }

        Ok(HttpService {
216
            state,
217
218
            router,
            port: config.port,
219
            host: config.host,
220
            route_docs: all_docs,
221
222
        })
    }
223
224
225
226
227

    pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
        self.request_template = Some(request_template);
        self
    }
228
}