//! `test_app.rs`: helpers for building test Axum applications through the
//! server's actual `build_app` function.
use axum::Router;
use reqwest::Client;
use sglang_router_rs::{
    config::RouterConfig,
    core::{LoadMonitor, WorkerRegistry},
    data_connector::{
        MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
    },
    middleware::{AuthConfig, TokenBucket},
    policies::PolicyRegistry,
    routers::RouterTrait,
    server::{build_app, AppContext, AppState},
};
use std::sync::{Arc, OnceLock};

/// Create a test Axum application using the actual server's build_app function
17
#[allow(dead_code)]
18
19
20
21
22
pub fn create_test_app(
    router: Arc<dyn RouterTrait>,
    client: Client,
    router_config: &RouterConfig,
) -> Router {
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
    // Initialize rate limiter
    let rate_limiter = match router_config.max_concurrent_requests {
        n if n <= 0 => None,
        n => {
            let rate_limit_tokens = router_config
                .rate_limit_tokens_per_second
                .filter(|&t| t > 0)
                .unwrap_or(n);
            Some(Arc::new(TokenBucket::new(
                n as usize,
                rate_limit_tokens as usize,
            )))
        }
    };

    // Initialize registries
    let worker_registry = Arc::new(WorkerRegistry::new());
    let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));

    // Initialize storage backends
    let response_storage = Arc::new(MemoryResponseStorage::new());
    let conversation_storage = Arc::new(MemoryConversationStorage::new());
    let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());

    // Initialize load monitor
    let load_monitor = Some(Arc::new(LoadMonitor::new(
        worker_registry.clone(),
        policy_registry.clone(),
        client.clone(),
        router_config.worker_startup_check_interval_secs,
    )));

    // Create empty OnceLock for worker job queue
    let worker_job_queue = Arc::new(OnceLock::new());

58
    // Create AppContext
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
    let app_context = Arc::new(AppContext::new(
        router_config.clone(),
        client,
        rate_limiter,
        None, // tokenizer
        None, // reasoning_parser_factory
        None, // tool_parser_factory
        worker_registry,
        policy_registry,
        response_storage,
        conversation_storage,
        conversation_item_storage,
        load_monitor,
        worker_job_queue,
    ));
74
75

    // Create AppState with the test router and context
76
77
    let app_state = Arc::new(AppState {
        router,
78
        context: app_context,
79
80
        concurrency_queue_tx: None,
        router_manager: None,
81
82
83
84
85
86
87
88
89
90
91
92
    });

    // Configure request ID headers (use defaults if not specified)
    let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

93
94
95
96
97
    // Create auth config from router config
    let auth_config = AuthConfig {
        api_key: router_config.api_key.clone(),
    };

98
99
100
    // Use the actual server's build_app function
    build_app(
        app_state,
101
        auth_config,
102
103
104
105
106
        router_config.max_payload_size,
        request_id_headers,
        router_config.cors_allowed_origins.clone(),
    )
}
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134

/// Create a test Axum application from an existing [`AppContext`].
///
/// Unlike `create_test_app`, this helper does not build a context. It reuses
/// `app_context` as given and reads the `build_app` settings from
/// `app_context.router_config`:
/// - auth API key,
/// - maximum payload size,
/// - request-ID headers,
/// - CORS origins.
///
/// # Arguments
/// * `router` - Router implementation under test.
/// * `app_context` - Pre-built context. It is moved into the resulting `AppState`.
// NOTE(review): presumably allowed because not every test binary that includes this
// helper module calls this function — confirm.
#[allow(dead_code)]
pub fn create_test_app_with_context(
    router: Arc<dyn RouterTrait>,
    app_context: Arc<AppContext>,
) -> Router {
    // Extract everything build_app needs from the config first. The borrow of
    // `app_context` then ends, and the context can be moved into AppState
    // without an extra Arc clone.
    let router_config = &app_context.router_config;

    // Request-ID headers fall back to this default list when none are configured.
    let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });
    let auth_config = AuthConfig {
        api_key: router_config.api_key.clone(),
    };
    let max_payload_size = router_config.max_payload_size;
    let cors_allowed_origins = router_config.cors_allowed_origins.clone();

    // Bundle the router under test with its context; no concurrency queue or router manager.
    let app_state = Arc::new(AppState {
        router,
        context: app_context,
        concurrency_queue_tx: None,
        router_manager: None,
    });

    build_app(
        app_state,
        auth_config,
        max_payload_size,
        request_id_headers,
        cors_allowed_origins,
    )
}