// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

16
17
use crate::metrics::MetricsRegistry;
use crate::traits::DistributedRuntimeProvider;
18
19
use axum::{body, http::StatusCode, response::IntoResponse, routing::get, Router};
use std::sync::Arc;
20
21
use std::sync::OnceLock;
use std::time::Instant;
22
23
24
25
use tokio::net::TcpListener;
use tokio_util::sync::CancellationToken;
use tracing;

26
27
pub struct HttpMetricsRegistry {
    pub drt: Arc<crate::DistributedRuntime>,
28
29
}

30
31
32
33
34
impl crate::traits::DistributedRuntimeProvider for HttpMetricsRegistry {
    fn drt(&self) -> &crate::DistributedRuntime {
        &self.drt
    }
}
35

36
37
38
impl MetricsRegistry for HttpMetricsRegistry {
    /// Name of this registry's own level in the metrics hierarchy.
    fn basename(&self) -> String {
        "http_server".to_string()
    }

    /// Ancestors of this registry: the runtime's own ancestors followed by the
    /// runtime's basename.
    fn parent_hierarchy(&self) -> Vec<String> {
        let runtime = self.drt();
        // Extend the runtime's hierarchy in place instead of concatenating two
        // vectors, which would clone every element a second time.
        let mut hierarchy = runtime.parent_hierarchy();
        hierarchy.push(runtime.basename());
        hierarchy
    }
}

46
/// HTTP server state containing metrics and uptime tracking
47
pub struct HttpServerState {
48
49
50
51
    // global drt registry is for printing out the entire Prometheus format output
    root_drt: Arc<crate::DistributedRuntime>,
    start_time: OnceLock<Instant>,
    uptime_gauge: Arc<prometheus::Gauge>,
52
53
54
}

impl HttpServerState {
55
    /// Create new HTTP server state with the provided metrics registry
56
    pub fn new(drt: Arc<crate::DistributedRuntime>) -> anyhow::Result<Self> {
57
58
59
60
61
62
63
64
65
66
67
68
69
        let http_metrics_registry = Arc::new(HttpMetricsRegistry { drt: drt.clone() });
        let uptime_gauge = http_metrics_registry.as_ref().create_gauge(
            "uptime_seconds",
            "Total uptime of the DistributedRuntime in seconds",
            &[],
        )?;
        let state = Self {
            root_drt: drt,
            start_time: OnceLock::new(),
            uptime_gauge,
        };
        Ok(state)
    }
70

71
72
73
74
75
76
    /// Initialize the start time (can only be called once)
    pub fn initialize_start_time(&self) -> Result<(), &'static str> {
        self.start_time
            .set(Instant::now())
            .map_err(|_| "Start time already initialized")
    }
77

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
    pub fn uptime(&self) -> Result<std::time::Duration, &'static str> {
        self.start_time
            .get()
            .ok_or("Start time not initialized")
            .map(|start_time| start_time.elapsed())
    }

    /// Get a reference to the distributed runtime
    pub fn drt(&self) -> &crate::DistributedRuntime {
        &self.root_drt
    }

    /// Update the uptime gauge with current value
    pub fn update_uptime_gauge(&self) {
        if let Ok(uptime) = self.uptime() {
            let uptime_seconds = uptime.as_secs_f64();
            self.uptime_gauge.set(uptime_seconds);
        } else {
            tracing::warn!("Failed to update uptime gauge: start time not initialized");
        }
98
99
100
    }
}

101
/// Start HTTP server with metrics support
102
pub async fn spawn_http_server(
103
104
105
106
    host: &str,
    port: u16,
    cancel_token: CancellationToken,
    drt: Arc<crate::DistributedRuntime>,
107
) -> anyhow::Result<(std::net::SocketAddr, tokio::task::JoinHandle<()>)> {
108
    // Create HTTP server state with the provided metrics registry
109
110
    let server_state = Arc::new(HttpServerState::new(drt)?);

111
112
113
114
115
    // Initialize the start time
    server_state
        .initialize_start_time()
        .map_err(|e| anyhow::anyhow!("Failed to initialize start time: {}", e))?;

116
    let app = Router::new()
117
118
119
120
121
122
123
124
125
126
127
128
129
130
        .route(
            "/health",
            get({
                let state = Arc::clone(&server_state);
                move || health_handler(state.clone())
            }),
        )
        .route(
            "/live",
            get({
                let state = Arc::clone(&server_state);
                move || health_handler(state)
            }),
        )
131
132
133
134
135
136
        .route(
            "/metrics",
            get({
                let state = Arc::clone(&server_state);
                move || metrics_handler(state)
            }),
137
138
139
140
141
        )
        .fallback(|| async {
            tracing::info!("[fallback handler] called");
            (StatusCode::NOT_FOUND, "Route not found").into_response()
        });
142
143

    let address = format!("{}:{}", host, port);
144
    tracing::info!("[spawn_http_server] binding to: {}", address);
145
146
147
148
149

    let listener = match TcpListener::bind(&address).await {
        Ok(listener) => {
            // get the actual address and port, print in debug level
            let actual_address = listener.local_addr()?;
150
151
152
153
154
            tracing::info!(
                "[spawn_http_server] HTTP server bound to: {}",
                actual_address
            );
            (listener, actual_address)
155
156
157
158
159
160
        }
        Err(e) => {
            tracing::error!("Failed to bind to address {}: {}", address, e);
            return Err(anyhow::anyhow!("Failed to bind to address: {}", e));
        }
    };
161
    let (listener, actual_address) = listener;
162
163

    let observer = cancel_token.child_token();
164
165
166
167
168
169
170
171
172
173
    // Spawn the server in the background and return the handle
    let handle = tokio::spawn(async move {
        if let Err(e) = axum::serve(listener, app)
            .with_graceful_shutdown(observer.cancelled_owned())
            .await
        {
            tracing::error!("HTTP server error: {}", e);
        }
    });
    Ok((actual_address, handle))
174
175
}

176
177
/// Health handler
///
/// Responds `200` with `OK` and the uptime in whole seconds, or `500` when the uptime
/// is unavailable because the start time was never initialized.
async fn health_handler(state: Arc<HttpServerState>) -> impl IntoResponse {
    let uptime = match state.uptime() {
        Ok(uptime) => uptime,
        Err(e) => {
            tracing::error!("Failed to get uptime: {}", e);
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to get uptime".to_string(),
            );
        }
    };

    (
        StatusCode::OK,
        format!("OK\nUptime: {} seconds\n", uptime.as_secs()),
    )
}
192
193
194
195

/// Metrics handler with DistributedRuntime uptime
async fn metrics_handler(state: Arc<HttpServerState>) -> impl IntoResponse {
    // Update the uptime gauge with current value
196
    state.update_uptime_gauge();
197

198
199
200
    // Get metrics from the registry
    match state.drt().prometheus_metrics_fmt() {
        Ok(response) => (StatusCode::OK, response),
201
        Err(e) => {
202
            tracing::error!("Failed to get metrics from registry: {}", e);
203
204
            (
                StatusCode::INTERNAL_SERVER_ERROR,
205
                "Failed to get metrics".to_string(),
206
207
208
209
210
            )
        }
    }
}

211
212
213
214
215
216
217
218
219
220
221
222
223
// Regular tests: cargo test http_server --lib
// Integration tests: cargo test http_server --lib --features integration

/// Builds a `DistributedRuntime` for async tests.
///
/// Goes through the discovery-free constructor on top of the current runtime and
/// panics if either step fails.
#[cfg(test)]
async fn create_test_drt_async() -> crate::DistributedRuntime {
    let runtime = crate::Runtime::from_current().unwrap();
    let drt = crate::DistributedRuntime::from_settings_without_discovery(runtime).await;
    drt.unwrap()
}

224
225
226
#[cfg(test)]
mod tests {
    use super::*;
227
228
    use crate::metrics::MetricsRegistry;
    use std::sync::Arc;
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
    use tokio::time::{sleep, Duration};

    /// A bare axum server must stop once its cancellation token fires, and the task
    /// running it must finish without panicking.
    #[tokio::test]
    async fn test_http_server_lifecycle() {
        let cancel_token = CancellationToken::new();
        let cancel_token_for_server = cancel_token.clone();

        // Test basic HTTP server lifecycle without DistributedRuntime
        let app = Router::new().route("/test", get(|| async { (StatusCode::OK, "test") }));

        // start HTTP server
        let server_handle = tokio::spawn(async move {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let _ = axum::serve(listener, app)
                .with_graceful_shutdown(cancel_token_for_server.cancelled_owned())
                .await;
        });

        // wait for a while to let the server start
        sleep(Duration::from_millis(100)).await;

        // cancel token
        cancel_token.cancel();

        // wait for the server to shut down
        let joined = tokio::time::timeout(Duration::from_secs(5), server_handle).await;
        assert!(
            joined.is_ok(),
            "HTTP server should shut down when cancel token is cancelled"
        );
        // The timeout only proves the task ended; a panic inside it (e.g. a failed bind)
        // surfaces as an inner `Err` and must fail the test as well.
        assert!(
            matches!(joined, Ok(Ok(()))),
            "HTTP server task should finish without panicking"
        );
    }

261
    /// The uptime gauge must show up in the runtime's Prometheus output under the
    /// `http_server` namespace.
    #[cfg(feature = "integration")]
    #[tokio::test]
    async fn test_runtime_metrics_initialization_and_namespace() {
        // Test that metrics have correct namespace
        let drt = create_test_drt_async().await;
        let runtime_metrics = HttpServerState::new(Arc::new(drt)).unwrap();

        // Initialize start time
        runtime_metrics.initialize_start_time().unwrap();

        // Pin the gauge to a fixed value so the rendered output is deterministic.
        runtime_metrics.uptime_gauge.set(42.0);

        let response = runtime_metrics.drt().prometheus_metrics_fmt().unwrap();
        println!("Full metrics response:\n{}", response);

        let expected = "\
# HELP uptime_seconds Total uptime of the DistributedRuntime in seconds
# TYPE uptime_seconds gauge
uptime_seconds{namespace=\"http_server\"} 42
";
        assert_eq!(response, expected);
    }

284
    /// The start time can be set exactly once, and uptime is available afterwards.
    #[cfg(feature = "integration")]
    #[tokio::test]
    async fn test_start_time_initialization() {
        // Test that start time can only be initialized once
        let drt = create_test_drt_async().await;
        let runtime_metrics = HttpServerState::new(Arc::new(drt)).unwrap();

        // First initialization should succeed
        assert!(runtime_metrics.initialize_start_time().is_ok());

        // Second initialization should fail
        assert!(runtime_metrics.initialize_start_time().is_err());

        // Uptime should work after initialization
        assert!(
            runtime_metrics.uptime().is_ok(),
            "uptime should be available once the start time is initialized"
        );
    }
301

302
303
304
305
306
307
308
309
310
311
312
313
314
315
    /// Asking for the uptime before the start time is set must yield the
    /// "not initialized" error.
    #[cfg(feature = "integration")]
    #[tokio::test]
    async fn test_uptime_without_initialization() {
        let runtime = Arc::new(create_test_drt_async().await);
        let state = HttpServerState::new(runtime).unwrap();

        // No call to `initialize_start_time`, so the lookup has to fail.
        let outcome = state.uptime();
        assert!(outcome.is_err());
        assert_eq!(outcome.err(), Some("Start time not initialized"));
    }

    /// End-to-end check of the routes served by `spawn_http_server`: the health
    /// endpoints answer `200 OK`, unknown paths answer `404`, and the server task shuts
    /// down cleanly on cancellation.
    #[cfg(feature = "integration")]
    #[tokio::test]
    async fn test_spawn_http_server_endpoints() {
        // use reqwest for HTTP requests
        let cancel_token = CancellationToken::new();
        let drt = create_test_drt_async().await;
        let (addr, server_handle) =
            spawn_http_server("127.0.0.1", 0, cancel_token.clone(), Arc::new(drt))
                .await
                .unwrap();

        // No start-up sleep is needed: `spawn_http_server` returns only after the
        // listener is bound, so connections are accepted from this point on.
        println!("[test] Server is bound to {}, starting requests...", addr);
        let client = reqwest::Client::new();
        for (path, expected_status, expect_body) in [
            ("/health", 200u16, "OK"),
            ("/live", 200u16, "OK"),
            ("/someRandomPathNotFoundHere", 404u16, "Route not found"),
        ] {
            println!("[test] Sending request to {}", path);
            let url = format!("http://{}{}", addr, path);
            let response = client.get(&url).send().await.unwrap();
            let status = response.status();
            let body = response.text().await.unwrap();
            println!(
                "[test] Response for {}: status={}, body={:?}",
                path, status, body
            );
            assert_eq!(
                status.as_u16(),
                expected_status,
                "Response: status={}, body={:?}",
                status,
                body
            );
            assert!(
                body.contains(expect_body),
                "Response: status={}, body={:?}",
                status,
                body
            );
        }

        cancel_token.cancel();
        // A join error means the server task panicked or was aborted; that must fail
        // the test rather than merely being printed.
        server_handle
            .await
            .expect("HTTP server task should shut down without panicking");
        println!("[test] Server shut down normally");
    }
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398

    /// A bare axum server (no etcd required) must stop once its cancellation token
    /// fires, and the task running it must finish without panicking.
    #[cfg(feature = "integration")]
    #[tokio::test]
    async fn test_http_server_basic_functionality() {
        // Test basic HTTP server functionality without requiring etcd
        let cancel_token = CancellationToken::new();
        let cancel_token_for_server = cancel_token.clone();

        // Test basic HTTP server lifecycle
        let app = Router::new().route("/test", get(|| async { (StatusCode::OK, "test") }));

        // start HTTP server
        let server_handle = tokio::spawn(async move {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let _ = axum::serve(listener, app)
                .with_graceful_shutdown(cancel_token_for_server.cancelled_owned())
                .await;
        });

        // wait for a while to let the server start
        sleep(Duration::from_millis(100)).await;

        // cancel token
        cancel_token.cancel();

        // wait for the server to shut down
        let joined = tokio::time::timeout(Duration::from_secs(5), server_handle).await;
        assert!(
            joined.is_ok(),
            "HTTP server should shut down when cancel token is cancelled"
        );
        // The timeout only proves the task ended; a panic inside it (e.g. a failed bind)
        // surfaces as an inner `Err` and must fail the test as well.
        assert!(
            matches!(joined, Ok(Ok(()))),
            "HTTP server task should finish without panicking"
        );
    }
399
}