// test_app.rs
use axum::Router;
use reqwest::Client;
use sglang_router_rs::{
    config::RouterConfig,
5
    middleware::AuthConfig,
6
    routers::RouterTrait,
7
    server::{build_app, AppContext, AppState},
8
9
10
11
};
use std::sync::Arc;

/// Create a test Axum application using the actual server's build_app function
12
#[allow(dead_code)]
13
14
15
16
17
pub fn create_test_app(
    router: Arc<dyn RouterTrait>,
    client: Client,
    router_config: &RouterConfig,
) -> Router {
18
    // Create AppContext
19
20
21
22
23
24
25
26
27
    let app_context = Arc::new(
        AppContext::new(
            router_config.clone(),
            client,
            router_config.max_concurrent_requests,
            router_config.rate_limit_tokens_per_second,
        )
        .expect("Failed to create AppContext in test"),
    );
28
29

    // Create AppState with the test router and context
30
31
    let app_state = Arc::new(AppState {
        router,
32
        context: app_context,
33
34
        concurrency_queue_tx: None,
        router_manager: None,
35
36
37
38
39
40
41
42
43
44
45
46
    });

    // Configure request ID headers (use defaults if not specified)
    let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

47
48
49
50
51
    // Create auth config from router config
    let auth_config = AuthConfig {
        api_key: router_config.api_key.clone(),
    };

52
53
54
    // Use the actual server's build_app function
    build_app(
        app_state,
55
        auth_config,
56
57
58
59
60
        router_config.max_payload_size,
        request_id_headers,
        router_config.cors_allowed_origins.clone(),
    )
}
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

/// Create a test Axum application with an existing [`AppContext`].
///
/// Wraps `router` and the caller-supplied `app_context` in an [`AppState`] that has no
/// concurrency-queue sender and no router manager. It then calls the real server's `build_app`
/// with the auth, payload-size, request-ID and CORS settings read from
/// `app_context.router_config`.
///
/// If `request_id_headers` is `None` in that config, the request-ID headers default to
/// `x-request-id`, `x-correlation-id`, `x-trace-id` and `request-id`.
// NOTE(review): `dead_code` is allowed presumably because not every test target that includes
// this shared helper module calls this function — confirm.
#[allow(dead_code)]
pub fn create_test_app_with_context(
    router: Arc<dyn RouterTrait>,
    app_context: Arc<AppContext>,
) -> Router {
    // Share the context with the app state (refcount bump); we keep our own handle so we can
    // keep reading its config below.
    let app_state = Arc::new(AppState {
        router,
        context: Arc::clone(&app_context),
        concurrency_queue_tx: None,
        router_manager: None,
    });

    // Get config from the context
    let router_config = &app_context.router_config;

    // Configure request ID headers (use defaults if not specified)
    let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    // Create auth config from router config
    let auth_config = AuthConfig {
        api_key: router_config.api_key.clone(),
    };

    // Use the actual server's build_app function so tests see the same app wiring.
    build_app(
        app_state,
        auth_config,
        router_config.max_payload_size,
        request_id_headers,
        router_config.cors_allowed_origins.clone(),
    )
}