// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
use anyhow::Result;
use derive_builder::Builder;
use dynamo_runtime::DistributedRuntime;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;

use super::metrics;
use super::Metrics;
use super::RouteDoc;
use crate::discovery::ModelManager;
use crate::request_template::RequestTemplate;

/// HTTP service shared state
pub struct State {
    metrics: Arc<Metrics>,
    manager: Arc<ModelManager>,
22
    runtime: Option<Arc<DistributedRuntime>>,
23
24
25
26
27
28
29
}

impl State {
    pub fn new(manager: Arc<ModelManager>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
30
31
32
33
34
35
36
37
38
            runtime: None,
        }
    }

    pub fn with_runtime(manager: Arc<ModelManager>, runtime: Arc<DistributedRuntime>) -> Self {
        Self {
            manager,
            metrics: Arc::new(Metrics::default()),
            runtime: Some(runtime),
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
        }
    }

    /// Get the Prometheus [`Metrics`] object which tracks request counts and inflight requests
    pub fn metrics_clone(&self) -> Arc<Metrics> {
        self.metrics.clone()
    }

    pub fn manager(&self) -> &ModelManager {
        Arc::as_ref(&self.manager)
    }

    pub fn manager_clone(&self) -> Arc<ModelManager> {
        self.manager.clone()
    }

55
56
57
58
59
    /// Get the DistributedRuntime if available
    pub fn runtime(&self) -> Option<&DistributedRuntime> {
        self.runtime.as_ref().map(|r| r.as_ref())
    }

60
61
62
63
64
65
    // TODO
    pub fn sse_keep_alive(&self) -> Option<Duration> {
        None
    }
}

66
67
#[derive(Clone)]
pub struct HttpService {
68
69
70
    // The state we share with every request handler
    state: Arc<State>,

71
72
    router: axum::Router,
    port: u16,
73
    host: String,
74
    route_docs: Vec<RouteDoc>,
75
76
77
}

#[derive(Clone, Builder)]
78
#[builder(pattern = "owned", build_fn(private, name = "build_internal"))]
79
80
81
82
pub struct HttpServiceConfig {
    #[builder(default = "8787")]
    port: u16,

83
84
85
    #[builder(setter(into), default = "String::from(\"0.0.0.0\")")]
    host: String,

86
87
88
89
90
91
92
    // #[builder(default)]
    // custom: Vec<axum::Router>
    #[builder(default = "true")]
    enable_chat_endpoints: bool,

    #[builder(default = "true")]
    enable_cmpl_endpoints: bool,
93

94
    #[builder(default = "true")]
95
96
    enable_embeddings_endpoints: bool,

97
98
    #[builder(default = "None")]
    request_template: Option<RequestTemplate>,
99
100
101

    #[builder(default = "None")]
    runtime: Option<Arc<DistributedRuntime>>,
102
103
104
105
106
107
108
}

impl HttpService {
    pub fn builder() -> HttpServiceConfigBuilder {
        HttpServiceConfigBuilder::default()
    }

109
110
111
112
113
114
115
116
    pub fn state_clone(&self) -> Arc<State> {
        self.state.clone()
    }

    pub fn state(&self) -> &State {
        Arc::as_ref(&self.state)
    }

117
    pub fn model_manager(&self) -> &ModelManager {
118
        self.state().manager()
119
120
    }

121
122
123
124
125
126
    pub async fn spawn(&self, cancel_token: CancellationToken) -> JoinHandle<Result<()>> {
        let this = self.clone();
        tokio::spawn(async move { this.run(cancel_token).await })
    }

    pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> {
127
        let address = format!("{}:{}", self.host, self.port);
128
129
130
131
132
133
134
135
136
        tracing::info!(address, "Starting HTTP service on: {address}");

        let listener = tokio::net::TcpListener::bind(address.as_str())
            .await
            .unwrap_or_else(|_| panic!("could not bind to address: {address}"));

        let router = self.router.clone();
        let observer = cancel_token.child_token();

137
        axum::serve(listener, router)
138
139
            .with_graceful_shutdown(observer.cancelled_owned())
            .await
140
141
142
            .inspect_err(|_| cancel_token.cancel())?;

        Ok(())
143
    }
144
145
146
147
148

    /// Documentation of exposed HTTP endpoints
    pub fn route_docs(&self) -> &[RouteDoc] {
        &self.route_docs
    }
149
150
151
152
}

impl HttpServiceConfigBuilder {
    pub fn build(self) -> Result<HttpService, anyhow::Error> {
153
        let config: HttpServiceConfig = self.build_internal()?;
154

155
        let model_manager = Arc::new(ModelManager::new());
156
157
158
159
160
        let state = if let Some(runtime) = config.runtime {
            Arc::new(State::with_runtime(model_manager, runtime))
        } else {
            Arc::new(State::new(model_manager))
        };
161
162
163

        // enable prometheus metrics
        let registry = metrics::Registry::new();
164
        state.metrics_clone().register(&registry)?;
165
166

        let mut router = axum::Router::new();
167

168
169
170
171
        let mut all_docs = Vec::new();

        let mut routes = vec![
            metrics::router(registry, None),
172
            super::openai::list_models_router(state.clone(), None),
173
            super::health::health_check_router(state.clone(), None),
174
            super::clear_kv_blocks::clear_kv_blocks_router(state.clone(), None),
175
176
177
        ];

        if config.enable_chat_endpoints {
178
            routes.push(super::openai::chat_completions_router(
179
                state.clone(),
180
                config.request_template,
181
182
183
184
185
                None,
            ));
        }

        if config.enable_cmpl_endpoints {
186
            routes.push(super::openai::completions_router(state.clone(), None));
187
188
        }

189
        if config.enable_embeddings_endpoints {
190
            routes.push(super::openai::embeddings_router(state.clone(), None));
191
192
        }

193
194
195
196
197
198
199
200
201
202
203
        // for (route_docs, route) in routes.into_iter().chain(self.routes.into_iter()) {
        //     router = router.merge(route);
        //     all_docs.extend(route_docs);
        // }

        for (route_docs, route) in routes.into_iter() {
            router = router.merge(route);
            all_docs.extend(route_docs);
        }

        Ok(HttpService {
204
            state,
205
206
            router,
            port: config.port,
207
            host: config.host,
208
            route_docs: all_docs,
209
210
        })
    }
211
212
213
214
215

    pub fn with_request_template(mut self, request_template: Option<RequestTemplate>) -> Self {
        self.request_template = Some(request_template);
        self
    }
216
}