server.rs 30.9 KB
Newer Older
1
use crate::{
2
    config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
3
    core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType},
4
    data_connector::{
5
6
7
        MemoryConversationStorage, MemoryResponseStorage, NoOpConversationStorage,
        NoOpResponseStorage, OracleConversationStorage, OracleResponseStorage,
        SharedConversationStorage, SharedResponseStorage,
8
    },
9
10
    logging::{self, LoggingConfig},
    metrics::{self, PrometheusConfig},
11
    middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
12
13
14
15
    policies::PolicyRegistry,
    protocols::{
        spec::{
            ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
16
            RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput,
17
18
19
        },
        worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
    },
20
    reasoning_parser::ReasoningParserFactory,
21
    routers::{router_manager::RouterManager, RouterTrait},
22
23
    service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
    tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
24
    tool_parser::ToolParserFactory,
25
};
26
use axum::{
27
    extract::{Path, Query, Request, State},
28
29
    http::StatusCode,
    response::{IntoResponse, Response},
30
    routing::{delete, get, post},
31
    serve, Json, Router,
32
};
33
use reqwest::Client;
34
use serde::Deserialize;
35
use serde_json::{json, Value};
36
37
38
39
40
41
use std::{
    sync::atomic::{AtomicBool, Ordering},
    sync::Arc,
    time::Duration,
};
use tokio::{net::TcpListener, signal, spawn};
42
use tracing::{error, info, warn, Level};
43

44
45
//

46
/// Shared, cheaply-cloneable application context handed to routers,
/// middleware, and background tasks.
#[derive(Clone)]
pub struct AppContext {
    /// Shared HTTP client used for all outbound worker requests.
    pub client: Client,
    /// Full router configuration this server was started with.
    pub router_config: RouterConfig,
    /// Token bucket enforcing the global concurrency/rate limit.
    pub rate_limiter: Arc<TokenBucket>,
    /// Tokenizer; only populated in gRPC connection mode.
    pub tokenizer: Option<Arc<dyn Tokenizer>>,
    /// Reasoning parser factory; only populated in gRPC connection mode.
    pub reasoning_parser_factory: Option<ReasoningParserFactory>,
    /// Tool-call parser factory; only populated in gRPC connection mode.
    pub tool_parser_factory: Option<ToolParserFactory>,
    /// Registry of all known workers (regular, prefill, decode).
    pub worker_registry: Arc<WorkerRegistry>,
    /// Registry of routing policies, seeded from `router_config.policy`.
    pub policy_registry: Arc<PolicyRegistry>,
    /// Router manager; `None` at construction time (wired up later).
    pub router_manager: Option<Arc<RouterManager>>,
    /// Storage backend for response history (memory / no-op / Oracle).
    pub response_storage: SharedResponseStorage,
    /// Storage backend for conversation history (memory / no-op / Oracle).
    pub conversation_storage: SharedConversationStorage,
    /// Background load monitor (started in `startup`, used by PowerOfTwo policies).
    pub load_monitor: Option<Arc<LoadMonitor>>,
    /// Reasoning parser name explicitly configured, if any.
    pub configured_reasoning_parser: Option<String>,
    /// Tool-call parser name explicitly configured, if any.
    pub configured_tool_parser: Option<String>,
}

64
impl AppContext {
    /// Builds the shared application context from the router configuration.
    ///
    /// In gRPC connection mode this also loads the tokenizer and the
    /// reasoning/tool-call parser factories; in other modes those stay `None`.
    ///
    /// # Errors
    /// Returns a descriptive message when required configuration is missing
    /// (tokenizer/model path in gRPC mode, Oracle settings when
    /// `history_backend=oracle`) or when a storage backend fails to initialize.
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
        rate_limit_tokens_per_second: Option<usize>,
    ) -> Result<Self, String> {
        // Default the refill rate to the bucket capacity when not given.
        let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
        let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));

        let (tokenizer, reasoning_parser_factory, tool_parser_factory) =
            if router_config.connection_mode == ConnectionMode::Grpc {
                // Prefer the explicit tokenizer path, fall back to the model path.
                let tokenizer_path = router_config
                    .tokenizer_path
                    .clone()
                    .or_else(|| router_config.model_path.clone())
                    .ok_or_else(|| {
                        "gRPC mode requires either --tokenizer-path or --model-path to be specified"
                            .to_string()
                    })?;

                let tokenizer = Some(
                    tokenizer_factory::create_tokenizer(&tokenizer_path)
                        .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
                );
                let reasoning_parser_factory = Some(ReasoningParserFactory::new());
                let tool_parser_factory = Some(ToolParserFactory::new());

                (tokenizer, reasoning_parser_factory, tool_parser_factory)
            } else {
                (None, None, None)
            };

        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));

        // Not available at construction time; wired up during startup.
        let router_manager = None;

        // Choose the history storage backends according to configuration.
        let (response_storage, conversation_storage): (
            SharedResponseStorage,
            SharedConversationStorage,
        ) = match router_config.history_backend {
            HistoryBackend::Memory => (
                Arc::new(MemoryResponseStorage::new()),
                Arc::new(MemoryConversationStorage::new()),
            ),
            HistoryBackend::None => (
                Arc::new(NoOpResponseStorage::new()),
                Arc::new(NoOpConversationStorage::new()),
            ),
            HistoryBackend::Oracle => {
                // Oracle backend is the only one that needs extra config.
                let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
                    "oracle configuration is required when history_backend=oracle".to_string()
                })?;

                let response_storage =
                    OracleResponseStorage::new(oracle_cfg.clone()).map_err(|err| {
                        format!("failed to initialize Oracle response storage: {err}")
                    })?;

                let conversation_storage =
                    OracleConversationStorage::new(oracle_cfg).map_err(|err| {
                        format!("failed to initialize Oracle conversation storage: {err}")
                    })?;

                (Arc::new(response_storage), Arc::new(conversation_storage))
            }
        };

        // Created eagerly but only started later (see `startup`).
        let load_monitor = Some(Arc::new(LoadMonitor::new(
            worker_registry.clone(),
            policy_registry.clone(),
            client.clone(),
            router_config.worker_startup_check_interval_secs,
        )));

        let configured_reasoning_parser = router_config.reasoning_parser.clone();
        let configured_tool_parser = router_config.tool_call_parser.clone();

        Ok(Self {
            client,
            router_config,
            rate_limiter,
            tokenizer,
            reasoning_parser_factory,
            tool_parser_factory,
            worker_registry,
            policy_registry,
            router_manager,
            response_storage,
            conversation_storage,
            load_monitor,
            configured_reasoning_parser,
            configured_tool_parser,
        })
    }
}

162
163
164
165
/// Per-server state handed to every axum handler.
#[derive(Clone)]
pub struct AppState {
    /// Active router implementation all requests are dispatched through.
    pub router: Arc<dyn RouterTrait>,
    /// Shared application context (registries, config, storage, ...).
    pub context: Arc<AppContext>,
    /// Sender used to queue requests when the concurrency limit is reached;
    /// `None` when no queue processor is running.
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>,
    /// Router manager; populated during `startup`.
    pub router_manager: Option<Arc<RouterManager>>,
}

170
171
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
172
173
}

174
175
async fn liveness() -> Response {
    (StatusCode::OK, "OK").into_response()
176
177
}

178
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
    let workers = state.context.worker_registry.get_all();
    let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();

    let is_ready = if state.context.router_config.enable_igw {
        !healthy_workers.is_empty()
    } else {
        match &state.context.router_config.mode {
            RoutingMode::PrefillDecode { .. } => {
                let has_prefill = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }));
                let has_decode = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Decode));
                has_prefill && has_decode
            }
            RoutingMode::Regular { .. } => !healthy_workers.is_empty(),
            RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(),
        }
    };

    if is_ready {
        (
            StatusCode::OK,
            Json(json!({
                "status": "ready",
                "healthy_workers": healthy_workers.len(),
                "total_workers": workers.len()
            })),
        )
            .into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({
                "status": "not ready",
                "reason": "insufficient healthy workers"
            })),
        )
            .into_response()
    }
220
221
}

222
223
async fn health(_state: State<Arc<AppState>>) -> Response {
    liveness().await
224
225
}

226
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
227
    state.router.health_generate(req).await
228
229
}

230
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
231
    state.router.get_server_info(req).await
232
233
}

234
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
235
    state.router.get_models(req).await
236
237
}

238
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
239
    state.router.get_model_info(req).await
240
}
241

242
async fn generate(
243
244
245
246
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
247
248
249
250
    state
        .router
        .route_generate(Some(&headers), &body, None)
        .await
251
252
253
}

async fn v1_chat_completions(
254
255
256
257
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
258
    state.router.route_chat(Some(&headers), &body, None).await
259
260
261
}

async fn v1_completions(
262
263
264
265
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
266
267
268
269
    state
        .router
        .route_completion(Some(&headers), &body, None)
        .await
270
271
}

272
273
274
275
276
async fn rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<RerankRequest>,
) -> Response {
277
    state.router.route_rerank(Some(&headers), &body, None).await
278
279
280
281
282
283
284
285
286
}

async fn v1_rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<V1RerankReqInput>,
) -> Response {
    state
        .router
287
        .route_rerank(Some(&headers), &body.into(), None)
288
289
290
        .await
}

291
292
293
294
295
/// POST /v1/responses — forwards the responses request to the active router.
async fn v1_responses(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ResponsesRequest>,
) -> Response {
    let router = &state.router;
    router.route_responses(Some(&headers), &body, None).await
}

302
303
304
305
306
307
308
309
310
311
312
/// POST /v1/embeddings — forwards the embedding request to the active router.
async fn v1_embeddings(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<EmbeddingRequest>,
) -> Response {
    let router = &state.router;
    router.route_embeddings(Some(&headers), &body, None).await
}

313
314
315
316
async fn v1_responses_get(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
317
    Query(params): Query<ResponsesGetParams>,
318
319
320
) -> Response {
    state
        .router
321
        .get_response(Some(&headers), &response_id, &params)
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
        .await
}

async fn v1_responses_cancel(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    state
        .router
        .cancel_response(Some(&headers), &response_id)
        .await
}

async fn v1_responses_delete(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    state
        .router
        .delete_response(Some(&headers), &response_id)
        .await
}

async fn v1_responses_list_input_items(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    state
        .router
        .list_response_input_items(Some(&headers), &response_id)
        .await
}

358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
/// POST /v1/conversations — creates a new conversation.
async fn v1_conversations_create(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router.create_conversation(Some(&headers), &body).await
}

/// GET /v1/conversations/{conversation_id} — fetches a conversation.
async fn v1_conversations_get(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .get_conversation(Some(&headers), &conversation_id)
        .await
}

/// POST /v1/conversations/{conversation_id} — updates a conversation.
async fn v1_conversations_update(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .update_conversation(Some(&headers), &conversation_id, &body)
        .await
}

/// DELETE /v1/conversations/{conversation_id} — deletes a conversation.
async fn v1_conversations_delete(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation(Some(&headers), &conversation_id)
        .await
}

403
/// Query parameters accepted by the /add_worker and /remove_worker endpoints.
#[derive(Deserialize)]
struct AddWorkerQuery {
    // Base URL of the worker to add or remove.
    url: String,
    // Optional API key used when the worker requires authentication.
    api_key: Option<String>,
}

409
async fn add_worker(
410
    State(state): State<Arc<AppState>>,
411
    Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
412
) -> Response {
413
414
415
416
417
418
419
420
421
422
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            url
        );
    }

423
424
425
    let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;

    match result {
426
427
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
428
    }
429
430
}

431
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
432
433
    let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
    Json(json!({ "urls": worker_list })).into_response()
434
435
}

436
async fn remove_worker(
437
    State(state): State<Arc<AppState>>,
438
    Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
439
) -> Response {
440
441
442
443
444
445
    let result = WorkerManager::remove_worker(&url, &state.context);

    match result {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
    }
446
447
}

448
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
    match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
        .await
    {
        Ok(result) => {
            if result.failed.is_empty() {
                (
                    StatusCode::OK,
                    Json(json!({
                        "status": "success",
                        "message": result.message,
                        "workers_flushed": result.successful.len(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            } else {
                (
                    StatusCode::PARTIAL_CONTENT,
                    Json(json!({
                        "status": "partial_success",
                        "message": result.message,
                        "successful": result.successful,
                        "failed": result.failed.into_iter().map(|(url, err)| json!({
                            "worker": url,
                            "error": err
                        })).collect::<Vec<_>>(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            }
        }
        Err(e) => {
            error!("Failed to flush cache: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({
                    "status": "error",
                    "message": format!("Failed to flush cache: {}", e)
                })),
            )
                .into_response()
        }
    }
495
496
}

497
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
    let result =
        WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
            .await;

    let loads: Vec<Value> = result
        .loads
        .iter()
        .map(|info| {
            json!({
                "worker": &info.worker,
                "load": info.load
            })
        })
        .collect();

    (
        StatusCode::OK,
        Json(json!({
            "workers": loads
        })),
    )
        .into_response()
520
521
}

522
523
524
525
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
526
527
528
529
530
531
532
533
534
535
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            config.url
        );
    }

536
537
538
539
540
541
542
543
544
545
    let result = WorkerManager::add_worker_from_config(&config, &state.context).await;

    match result {
        Ok(message) => {
            let response = WorkerApiResponse {
                success: true,
                message,
                worker: None,
            };
            (StatusCode::OK, Json(response)).into_response()
546
        }
547
548
549
550
551
552
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
                code: "ADD_WORKER_FAILED".to_string(),
            };
            (StatusCode::BAD_REQUEST, Json(error_response)).into_response()
553
554
555
556
557
        }
    }
}

async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
    let workers = state.context.worker_registry.get_all();
    let response = serde_json::json!({
        "workers": workers.iter().map(|worker| {
            let mut worker_info = serde_json::json!({
                "url": worker.url(),
                "model_id": worker.model_id(),
                "worker_type": match worker.worker_type() {
                    WorkerType::Regular => "regular",
                    WorkerType::Prefill { .. } => "prefill",
                    WorkerType::Decode => "decode",
                },
                "is_healthy": worker.is_healthy(),
                "load": worker.load(),
                "connection_mode": format!("{:?}", worker.connection_mode()),
                "priority": worker.priority(),
                "cost": worker.cost(),
            });

            if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
                worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
578
            }
579
580
581
582
583
584
585
586
587
588
589

            worker_info
        }).collect::<Vec<_>>(),
        "total": workers.len(),
        "stats": {
            "prefill_count": state.context.worker_registry.get_prefill_workers().len(),
            "decode_count": state.context.worker_registry.get_decode_workers().len(),
            "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(),
        }
    });
    Json(response).into_response()
590
591
}

592
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
593
594
595
596
597
598
599
600
    let workers = WorkerManager::get_worker_urls(&state.context.worker_registry);
    if workers.contains(&url) {
        Json(json!({
            "url": url,
            "model_id": "unknown",
            "is_healthy": true
        }))
        .into_response()
601
    } else {
602
603
604
605
606
        let error = WorkerErrorResponse {
            error: format!("Worker {url} not found"),
            code: "WORKER_NOT_FOUND".to_string(),
        };
        (StatusCode::NOT_FOUND, Json(error)).into_response()
607
608
609
    }
}

610
/// DELETE /workers/{url} — unregisters a worker (RESTful counterpart of
/// /remove_worker).
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
    match WorkerManager::remove_worker(&url, &state.context) {
        Ok(message) => {
            let body = WorkerApiResponse {
                success: true,
                message,
                worker: None,
            };
            (StatusCode::OK, Json(body)).into_response()
        }
        Err(error) => {
            let body = WorkerErrorResponse {
                error,
                code: "REMOVE_WORKER_FAILED".to_string(),
            };
            (StatusCode::BAD_REQUEST, Json(body)).into_response()
        }
    }
}

632
633
634
/// Top-level configuration for launching the router HTTP server.
pub struct ServerConfig {
    /// Interface to bind (IPv4 or IPv6 address, or hostname).
    pub host: String,
    /// TCP port to bind.
    pub port: u16,
    /// Router configuration, forwarded into the `AppContext`.
    pub router_config: RouterConfig,
    /// Maximum accepted request body size, in bytes.
    pub max_payload_size: usize,
    /// Directory for log files; when `None`, no log file is configured.
    pub log_dir: Option<String>,
    /// Log level name (parsed case-insensitively); defaults to INFO when
    /// unset or invalid.
    pub log_level: Option<String>,
    /// Service discovery settings; only started when present and enabled.
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    /// Prometheus exporter settings; metrics are only started when present.
    pub prometheus_config: Option<PrometheusConfig>,
    /// Overall timeout applied to outbound HTTP requests, in seconds.
    pub request_timeout_secs: u64,
    /// Header names to source request IDs from; a default list is used
    /// when `None`.
    pub request_id_headers: Option<Vec<String>>,
}

645
646
/// Assembles the axum application.
///
/// Route groups:
/// - protected inference routes (generate, chat, completions, rerank,
///   responses, embeddings, conversations) behind the concurrency-limit and
///   auth middleware;
/// - public health/metadata routes with no middleware;
/// - admin worker-management routes behind auth;
/// - RESTful /workers routes behind auth.
///
/// Body-size limits, request logging, request-id propagation and CORS are
/// applied as outer layers; unmatched paths fall through to `sink_handler`.
pub fn build_app(
    app_state: Arc<AppState>,
    auth_config: AuthConfig,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions))
        .route("/rerank", post(rerank))
        .route("/v1/rerank", post(v1_rerank))
        .route("/v1/responses", post(v1_responses))
        .route("/v1/embeddings", post(v1_embeddings))
        .route("/v1/responses/{response_id}", get(v1_responses_get))
        .route(
            "/v1/responses/{response_id}/cancel",
            post(v1_responses_cancel),
        )
        .route("/v1/responses/{response_id}", delete(v1_responses_delete))
        .route(
            "/v1/responses/{response_id}/input",
            get(v1_responses_list_input_items),
        )
        .route("/v1/conversations", post(v1_conversations_create))
        .route(
            "/v1/conversations/{conversation_id}",
            get(v1_conversations_get)
                .post(v1_conversations_update)
                .delete(v1_conversations_delete),
        )
        // Layer order matters: the auth layer is added last, so it wraps the
        // concurrency limiter and runs first on incoming requests.
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
            middleware::concurrency_limit_middleware,
        ))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    // Health and metadata endpoints stay unauthenticated and unthrottled so
    // probes keep working under load.
    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    // RESTful worker CRUD, also behind auth.
    let worker_routes = Router::new()
        .route("/workers", post(create_worker))
        .route("/workers", get(list_workers_rest))
        .route("/workers/{url}", get(get_worker))
        .route("/workers/{url}", delete(delete_worker))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        .merge(worker_routes)
        // Two body-size guards: axum's extractor limit plus tower-http's
        // request-body limit.
        .layer(axum::extract::DefaultBodyLimit::max(max_payload_size))
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        .layer(middleware::create_logging_layer())
        .layer(middleware::RequestIdLayer::new(request_id_headers))
        .layer(create_cors_layer(cors_allowed_origins))
        .fallback(sink_handler)
        .with_state(app_state)
}

/// Boots the router server: initializes logging and metrics, builds the
/// shared context, registers workers, starts background tasks (health
/// checker, load monitor, request queue, service discovery), then binds the
/// listener and serves until a shutdown signal arrives.
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
    // Guard so logging is initialized at most once per process, even if
    // `startup` is invoked multiple times.
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        // NOTE(review): this warn! fires before init_logging
                        // returns — confirm it is not silently dropped.
                        warn!("Invalid log level string: '{s}'. Defaulting to INFO.");
                        None
                    }
                })
                .unwrap_or(Level::INFO),
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };

    if let Some(prometheus_config) = &config.prometheus_config {
        metrics::start_prometheus(prometheus_config.clone());
    }

    info!(
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
        config.max_payload_size / (1024 * 1024)
    );

    // Shared outbound HTTP client with pooling and keepalive tuned for
    // long-lived worker connections.
    let client = Client::builder()
        .pool_idle_timeout(Some(Duration::from_secs(50)))
        .pool_max_idle_per_host(500)
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10))
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30)))
        .build()
        .expect("Failed to create HTTP client");

    let app_context = AppContext::new(
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
        config.router_config.rate_limit_tokens_per_second,
    )?;

    let app_context = Arc::new(app_context);

    // Register the statically configured workers before routing starts.
    info!(
        "Initializing workers for routing mode: {:?}",
        config.router_config.mode
    );
    WorkerManager::initialize_workers(
        &config.router_config,
        &app_context.worker_registry,
        Some(&app_context.policy_registry),
    )
    .await
    .map_err(|e| format!("Failed to initialize workers: {}", e))?;

    let worker_stats = app_context.worker_registry.stats();
    info!(
        "Workers initialized: {} total, {} healthy",
        worker_stats.total_workers, worker_stats.healthy_workers
    );

    let router_manager = RouterManager::from_config(&config, &app_context).await?;
    let router: Arc<dyn RouterTrait> = router_manager.clone();

    // Background health checker; the returned handle is kept alive for the
    // duration of this function.
    let _health_checker = app_context
        .worker_registry
        .start_health_checker(config.router_config.health_check.check_interval_secs);
    info!(
        "Started health checker for workers with {}s interval",
        config.router_config.health_check.check_interval_secs
    );

    if let Some(ref load_monitor) = app_context.load_monitor {
        load_monitor.start().await;
        info!("Started LoadMonitor for PowerOfTwo policies");
    }

    // Concurrency limiter plus optional queue processor for requests that
    // exceed the limit.
    let (limiter, processor) = middleware::ConcurrencyLimiter::new(
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

    if let Some(processor) = processor {
        spawn(processor.run());
        info!(
            "Started request queue with size: {}, timeout: {}s",
            config.router_config.queue_size, config.router_config.queue_timeout_secs
        );
    }

    let app_state = Arc::new(AppState {
        router,
        context: app_context.clone(),
        concurrency_queue_tx: limiter.queue_tx.clone(),
        router_manager: Some(router_manager),
    });

    // Service discovery is best-effort: a failure here is logged and the
    // server continues without it.
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
            let app_context_arc = Arc::clone(&app_state.context);
            match start_service_discovery(service_discovery_config, app_context_arc).await {
                Ok(handle) => {
                    info!("Service discovery started");
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {e}");
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

    info!(
        "Router ready | workers: {:?}",
        WorkerManager::get_worker_urls(&app_state.context.worker_registry)
    );

    // Default set of headers to source/propagate request IDs from.
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    let auth_config = AuthConfig {
        api_key: config.router_config.api_key.clone(),
    };

    let app = build_app(
        app_state,
        auth_config,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );

    // TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs
    let bind_addr = format!("{}:{}", config.host, config.port);
    info!("Starting server on {}", bind_addr);
    let listener = TcpListener::bind(&bind_addr)
        .await
        .map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?;
    serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

/// Resolves when Ctrl+C (SIGINT) or — on Unix — SIGTERM is received,
/// triggering axum's graceful shutdown.
async fn shutdown_signal() {
    #[cfg(unix)]
    let terminate = async {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        sigterm.recv().await;
    };

    // On non-Unix platforms only Ctrl+C can end the wait.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = terminate => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
955
}