test_app.rs 4.9 KB
Newer Older
1
2
use std::sync::{Arc, OnceLock};

3
4
5
use axum::Router;
use reqwest::Client;
use sglang_router_rs::{
6
    app_context::AppContext,
7
    config::RouterConfig,
8
9
10
11
12
13
    core::{LoadMonitor, WorkerRegistry},
    data_connector::{
        MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
    },
    middleware::{AuthConfig, TokenBucket},
    policies::PolicyRegistry,
14
    routers::RouterTrait,
15
    server::{build_app, AppState},
16
17
18
};

/// Create a test Axum application using the actual server's build_app function
19
#[allow(dead_code)]
20
21
22
23
24
pub fn create_test_app(
    router: Arc<dyn RouterTrait>,
    client: Client,
    router_config: &RouterConfig,
) -> Router {
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
    // Initialize rate limiter
    let rate_limiter = match router_config.max_concurrent_requests {
        n if n <= 0 => None,
        n => {
            let rate_limit_tokens = router_config
                .rate_limit_tokens_per_second
                .filter(|&t| t > 0)
                .unwrap_or(n);
            Some(Arc::new(TokenBucket::new(
                n as usize,
                rate_limit_tokens as usize,
            )))
        }
    };

    // Initialize registries
    let worker_registry = Arc::new(WorkerRegistry::new());
    let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));

    // Initialize storage backends
    let response_storage = Arc::new(MemoryResponseStorage::new());
    let conversation_storage = Arc::new(MemoryConversationStorage::new());
    let conversation_item_storage = Arc::new(MemoryConversationItemStorage::new());

    // Initialize load monitor
    let load_monitor = Some(Arc::new(LoadMonitor::new(
        worker_registry.clone(),
        policy_registry.clone(),
        client.clone(),
        router_config.worker_startup_check_interval_secs,
    )));

57
    // Create empty OnceLock for worker job queue and workflow engine
58
    let worker_job_queue = Arc::new(OnceLock::new());
59
    let workflow_engine = Arc::new(OnceLock::new());
60

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
    // Create AppContext using builder pattern
    let app_context = Arc::new(
        AppContext::builder()
            .router_config(router_config.clone())
            .client(client)
            .rate_limiter(rate_limiter)
            .tokenizer(None) // tokenizer
            .reasoning_parser_factory(None) // reasoning_parser_factory
            .tool_parser_factory(None) // tool_parser_factory
            .worker_registry(worker_registry)
            .policy_registry(policy_registry)
            .response_storage(response_storage)
            .conversation_storage(conversation_storage)
            .conversation_item_storage(conversation_item_storage)
            .load_monitor(load_monitor)
            .worker_job_queue(worker_job_queue)
            .workflow_engine(workflow_engine)
            .build()
            .unwrap(),
    );
81
82

    // Create AppState with the test router and context
83
84
    let app_state = Arc::new(AppState {
        router,
85
        context: app_context,
86
87
        concurrency_queue_tx: None,
        router_manager: None,
88
89
90
91
92
93
94
95
96
97
98
99
    });

    // Configure request ID headers (use defaults if not specified)
    let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

100
101
102
103
104
    // Create auth config from router config
    let auth_config = AuthConfig {
        api_key: router_config.api_key.clone(),
    };

105
106
107
    // Use the actual server's build_app function
    build_app(
        app_state,
108
        auth_config,
109
110
111
112
113
        router_config.max_payload_size,
        request_id_headers,
        router_config.cors_allowed_origins.clone(),
    )
}
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141

/// Create a test Axum application with an existing [`AppContext`].
///
/// The values handed to `build_app` are all read from `app_context.router_config`:
/// - the request-ID header names, falling back to a built-in default list when none are
///   configured;
/// - the API key, wrapped in an [`AuthConfig`];
/// - the maximum payload size;
/// - the allowed CORS origins.
///
/// The test `router` and the context are wrapped in an [`AppState`] with no concurrency
/// queue and no router manager.
// NOTE(review): presumably allowed because this helper module is shared by several
// integration-test crates and not every crate calls every helper — confirm.
#[allow(dead_code)]
pub fn create_test_app_with_context(
    router: Arc<dyn RouterTrait>,
    app_context: Arc<AppContext>,
) -> Router {
    // Header names used for request-ID propagation when the config does not list any.
    const DEFAULT_REQUEST_ID_HEADERS: [&str; 4] =
        ["x-request-id", "x-correlation-id", "x-trace-id", "request-id"];

    // Extract everything `build_app` needs from the config before `app_context` is moved
    // into the AppState, so no extra `Arc` handle is required.
    let config = &app_context.router_config;
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        DEFAULT_REQUEST_ID_HEADERS
            .iter()
            .map(|&header| header.to_owned())
            .collect()
    });
    let auth_config = AuthConfig {
        api_key: config.api_key.clone(),
    };
    let max_payload_size = config.max_payload_size;
    let cors_allowed_origins = config.cors_allowed_origins.clone();

    let app_state = Arc::new(AppState {
        router,
        context: app_context,
        concurrency_queue_tx: None,
        router_manager: None,
    });

    // Use the actual server's build_app function so tests exercise the real middleware stack.
    build_app(
        app_state,
        auth_config,
        max_payload_size,
        request_id_headers,
        cors_allowed_origins,
    )
}