//! Helpers for building test Axum applications and `AppContext`s.
use std::sync::{Arc, OnceLock};

use axum::Router;
use reqwest::Client;
use sglang_router_rs::{
    app_context::AppContext,
    config::RouterConfig,
    core::{LoadMonitor, WorkerRegistry},
    data_connector::{
        MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
    },
    mcp::{McpConfig, McpManager},
    middleware::{AuthConfig, TokenBucket},
    policies::PolicyRegistry,
    routers::RouterTrait,
    server::{build_app, AppState},
};

/// Create a test Axum application using the actual server's build_app function
20
#[allow(dead_code)]
21
22
23
24
25
pub fn create_test_app(
    router: Arc<dyn RouterTrait>,
    client: Client,
    router_config: &RouterConfig,
) -> Router {
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
    // Initialize rate limiter
    let rate_limiter = match router_config.max_concurrent_requests {
        n if n <= 0 => None,
        n => {
            let rate_limit_tokens = router_config
                .rate_limit_tokens_per_second
                .filter(|&t| t > 0)
                .unwrap_or(n);
            Some(Arc::new(TokenBucket::new(
                n as usize,
                rate_limit_tokens as usize,
            )))
        }
    };

    // Initialize registries
    let worker_registry = Arc::new(WorkerRegistry::new());
    let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));

    // Initialize storage backends
    let response_storage = Arc::new(MemoryResponseStorage::new());
    let conversation_storage = Arc::new(MemoryConversationStorage::new());
    let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());

    // Initialize load monitor
    let load_monitor = Some(Arc::new(LoadMonitor::new(
        worker_registry.clone(),
        policy_registry.clone(),
        client.clone(),
        router_config.worker_startup_check_interval_secs,
    )));

58
    // Create empty OnceLock for worker job queue and workflow engine
59
    let worker_job_queue = Arc::new(OnceLock::new());
60
    let workflow_engine = Arc::new(OnceLock::new());
61

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
    // Create AppContext using builder pattern
    let app_context = Arc::new(
        AppContext::builder()
            .router_config(router_config.clone())
            .client(client)
            .rate_limiter(rate_limiter)
            .tokenizer(None) // tokenizer
            .reasoning_parser_factory(None) // reasoning_parser_factory
            .tool_parser_factory(None) // tool_parser_factory
            .worker_registry(worker_registry)
            .policy_registry(policy_registry)
            .response_storage(response_storage)
            .conversation_storage(conversation_storage)
            .conversation_item_storage(conversation_item_storage)
            .load_monitor(load_monitor)
            .worker_job_queue(worker_job_queue)
            .workflow_engine(workflow_engine)
            .build()
            .unwrap(),
    );
82
83

    // Create AppState with the test router and context
84
85
    let app_state = Arc::new(AppState {
        router,
86
        context: app_context,
87
88
        concurrency_queue_tx: None,
        router_manager: None,
89
90
91
92
93
94
95
96
97
98
99
100
    });

    // Configure request ID headers (use defaults if not specified)
    let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

101
102
103
104
105
    // Create auth config from router config
    let auth_config = AuthConfig {
        api_key: router_config.api_key.clone(),
    };

106
107
108
    // Use the actual server's build_app function
    build_app(
        app_state,
109
        auth_config,
110
111
112
113
114
        router_config.max_payload_size,
        request_id_headers,
        router_config.cors_allowed_origins.clone(),
    )
}
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142

/// Create a test Axum application with an existing AppContext
#[allow(dead_code)]
pub fn create_test_app_with_context(
    router: Arc<dyn RouterTrait>,
    app_context: Arc<AppContext>,
) -> Router {
    // Create AppState with the test router and context
    let app_state = Arc::new(AppState {
        router,
        context: app_context.clone(),
        concurrency_queue_tx: None,
        router_manager: None,
    });

    // Get config from the context
    let router_config = &app_context.router_config;

    // Configure request ID headers (use defaults if not specified)
    let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

143
144
145
146
147
    // Create auth config from router config
    let auth_config = AuthConfig {
        api_key: router_config.api_key.clone(),
    };

148
149
150
    // Use the actual server's build_app function
    build_app(
        app_state,
151
        auth_config,
152
153
154
155
156
        router_config.max_payload_size,
        request_id_headers,
        router_config.cors_allowed_origins.clone(),
    )
}
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211

/// Create a minimal test [`AppContext`] for unit tests.
///
/// The context contains:
/// - `RouterConfig::default()` and a fresh `reqwest::Client`;
/// - in-memory storage backends and new worker/policy registries;
/// - no rate limiter, tokenizer, parser factories or load monitor;
/// - empty worker-job-queue and workflow-engine locks;
/// - an [`McpManager`] built from an MCP config with no servers.
///
/// # Panics
///
/// Panics if the MCP manager cannot be created or the `AppContext` builder returns an error.
#[allow(dead_code)] // NOTE(review): presumably shared by test crates that don't all call it
pub async fn create_test_app_context() -> Arc<AppContext> {
    let router_config = RouterConfig::default();
    let client = Client::new();

    // Worker job queue and workflow engine start out unset.
    let worker_job_queue = Arc::new(OnceLock::new());
    let workflow_engine = Arc::new(OnceLock::new());

    // MCP config with no servers, no proxy and no warmup entries.
    let empty_config = McpConfig {
        servers: vec![],
        pool: Default::default(),
        proxy: None,
        warmup: vec![],
        inventory: Default::default(),
    };
    let mcp_manager = McpManager::with_defaults(empty_config)
        .await
        .expect("Failed to create MCP manager");
    // Build the lock already populated. Unlike `OnceLock::set(..).ok()`, there is no
    // `Result` to discard.
    let mcp_manager_lock = Arc::new(OnceLock::from(Arc::new(mcp_manager)));

    let worker_registry = Arc::new(WorkerRegistry::new());
    let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));

    let response_storage = Arc::new(MemoryResponseStorage::new());
    let conversation_storage = Arc::new(MemoryConversationStorage::new());
    let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());

    Arc::new(
        AppContext::builder()
            .router_config(router_config)
            .client(client)
            .rate_limiter(None)
            .tokenizer(None)
            .reasoning_parser_factory(None)
            .tool_parser_factory(None)
            .worker_registry(worker_registry)
            .policy_registry(policy_registry)
            .response_storage(response_storage)
            .conversation_storage(conversation_storage)
            .conversation_item_storage(conversation_item_storage)
            .load_monitor(None)
            .worker_job_queue(worker_job_queue)
            .workflow_engine(workflow_engine)
            .mcp_manager(mcp_manager_lock)
            .build()
            .expect("failed to build test AppContext"),
    )
}