server.rs 28.8 KB
Newer Older
1
use crate::{
2
    config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
3
    core::{LoadMonitor, WorkerManager, WorkerRegistry, WorkerType},
4
5
6
    data_connector::{
        MemoryResponseStorage, NoOpResponseStorage, OracleResponseStorage, SharedResponseStorage,
    },
7
8
    logging::{self, LoggingConfig},
    metrics::{self, PrometheusConfig},
9
    middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
10
11
12
13
    policies::PolicyRegistry,
    protocols::{
        spec::{
            ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
14
            RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput,
15
16
17
        },
        worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
    },
18
    reasoning_parser::ReasoningParserFactory,
19
    routers::{router_manager::RouterManager, RouterTrait},
20
21
    service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
    tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
22
    tool_parser::ToolParserFactory,
23
};
24
use axum::{
25
    extract::{Path, Query, Request, State},
26
27
    http::StatusCode,
    response::{IntoResponse, Response},
28
    routing::{delete, get, post},
29
    serve, Json, Router,
30
};
31
use reqwest::Client;
32
use serde::Deserialize;
33
use serde_json::{json, Value};
34
35
36
37
38
39
use std::{
    sync::atomic::{AtomicBool, Ordering},
    sync::Arc,
    time::Duration,
};
use tokio::{net::TcpListener, signal, spawn};
40
use tracing::{error, info, warn, Level};
41

42
/// Shared, clonable application context: router configuration plus every
/// registry/service that handlers and routers need at request time.
#[derive(Clone)]
pub struct AppContext {
    /// Shared HTTP client used to talk to workers.
    pub client: Client,
    /// Full router configuration this server was started with.
    pub router_config: RouterConfig,
    /// Token bucket backing the concurrency/rate-limit middleware.
    pub rate_limiter: Arc<TokenBucket>,
    /// Tokenizer; populated only in gRPC connection mode (see `new`).
    pub tokenizer: Option<Arc<dyn Tokenizer>>,
    /// Reasoning-parser factory; populated only in gRPC connection mode.
    pub reasoning_parser_factory: Option<ReasoningParserFactory>,
    /// Tool-call-parser factory; populated only in gRPC connection mode.
    pub tool_parser_factory: Option<ToolParserFactory>,
    /// Registry of all known workers (health, type, load, ...).
    pub worker_registry: Arc<WorkerRegistry>,
    /// Registry of routing policies, seeded from `router_config.policy`.
    pub policy_registry: Arc<PolicyRegistry>,
    // NOTE(review): `new` always leaves this as None; presumably populated by
    // some other wiring path — confirm before relying on it.
    pub router_manager: Option<Arc<RouterManager>>,
    /// Response-history storage backend (memory, no-op, or Oracle).
    pub response_storage: SharedResponseStorage,
    /// Load monitor; started during server startup when present.
    pub load_monitor: Option<Arc<LoadMonitor>>,
    /// Reasoning-parser name from configuration, if any.
    pub configured_reasoning_parser: Option<String>,
    /// Tool-call-parser name from configuration, if any.
    pub configured_tool_parser: Option<String>,
}

59
impl AppContext {
    /// Build the shared application context.
    ///
    /// `max_concurrent_requests` sets the token-bucket capacity; if
    /// `rate_limit_tokens_per_second` is `None`, the refill rate defaults to
    /// the concurrency limit. Returns a descriptive error `String` when gRPC
    /// mode lacks a tokenizer/model path or Oracle storage cannot be set up.
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
        rate_limit_tokens_per_second: Option<usize>,
    ) -> Result<Self, String> {
        let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
        let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));

        // Tokenizer and parser factories are only required when workers are
        // reached over gRPC; HTTP mode leaves all three unset.
        let (tokenizer, reasoning_parser_factory, tool_parser_factory) =
            if router_config.connection_mode == ConnectionMode::Grpc {
                // Prefer an explicit tokenizer path, fall back to the model path.
                let tokenizer_path = router_config
                    .tokenizer_path
                    .clone()
                    .or_else(|| router_config.model_path.clone())
                    .ok_or_else(|| {
                        "gRPC mode requires either --tokenizer-path or --model-path to be specified"
                            .to_string()
                    })?;

                let tokenizer = Some(
                    tokenizer_factory::create_tokenizer(&tokenizer_path)
                        .map_err(|e| format!("Failed to create tokenizer: {e}"))?,
                );
                let reasoning_parser_factory = Some(ReasoningParserFactory::new());
                let tool_parser_factory = Some(ToolParserFactory::new());

                (tokenizer, reasoning_parser_factory, tool_parser_factory)
            } else {
                (None, None, None)
            };

        let worker_registry = Arc::new(WorkerRegistry::new());
        let policy_registry = Arc::new(PolicyRegistry::new(router_config.policy.clone()));

        // Not constructed here; see the field-level note on `router_manager`.
        let router_manager = None;

        // Pick the response-history backend; Oracle needs its own config block.
        let response_storage: SharedResponseStorage = match router_config.history_backend {
            HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
            HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
            HistoryBackend::Oracle => {
                let oracle_cfg = router_config.oracle.clone().ok_or_else(|| {
                    "oracle configuration is required when history_backend=oracle".to_string()
                })?;

                let storage = OracleResponseStorage::new(oracle_cfg).map_err(|err| {
                    format!("failed to initialize Oracle response storage: {err}")
                })?;

                Arc::new(storage)
            }
        };

        // Always constructed; only *started* later by the server startup path.
        let load_monitor = Some(Arc::new(LoadMonitor::new(
            worker_registry.clone(),
            policy_registry.clone(),
            client.clone(),
            router_config.worker_startup_check_interval_secs,
        )));

        let configured_reasoning_parser = router_config.reasoning_parser.clone();
        let configured_tool_parser = router_config.tool_call_parser.clone();

        Ok(Self {
            client,
            router_config,
            rate_limiter,
            tokenizer,
            reasoning_parser_factory,
            tool_parser_factory,
            worker_registry,
            policy_registry,
            router_manager,
            response_storage,
            load_monitor,
            configured_reasoning_parser,
            configured_tool_parser,
        })
    }
}

141
142
143
144
/// Shared state handed to every axum handler via `State`.
#[derive(Clone)]
pub struct AppState {
    /// Active router implementation that all requests are delegated to.
    pub router: Arc<dyn RouterTrait>,
    /// Shared application context (registries, config, HTTP client, ...).
    pub context: Arc<AppContext>,
    /// Sender for the request queue used by the concurrency-limit
    /// middleware; `None` when no queue processor was created.
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>,
    /// Router manager, present when the server was started via `startup`.
    pub router_manager: Option<Arc<RouterManager>>,
}

149
150
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
151
152
}

153
154
async fn liveness() -> Response {
    (StatusCode::OK, "OK").into_response()
155
156
}

157
/// Readiness probe.
///
/// Ready iff enough healthy workers are registered for the configured
/// routing mode: prefill-decode mode needs at least one healthy prefill AND
/// one healthy decode worker; every other mode (and IGW mode) needs any
/// healthy worker. Returns 200 with counts, or 503 when not ready.
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
    let workers = state.context.worker_registry.get_all();
    let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();

    let is_ready = if state.context.router_config.enable_igw {
        // IGW mode: any healthy worker is enough, regardless of routing mode.
        !healthy_workers.is_empty()
    } else {
        match &state.context.router_config.mode {
            RoutingMode::PrefillDecode { .. } => {
                // PD serving requires both halves of the pipeline to be up.
                let has_prefill = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }));
                let has_decode = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Decode));
                has_prefill && has_decode
            }
            RoutingMode::Regular { .. } => !healthy_workers.is_empty(),
            RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(),
        }
    };

    if is_ready {
        (
            StatusCode::OK,
            Json(json!({
                "status": "ready",
                "healthy_workers": healthy_workers.len(),
                "total_workers": workers.len()
            })),
        )
            .into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({
                "status": "not ready",
                "reason": "insufficient healthy workers"
            })),
        )
            .into_response()
    }
}

201
202
/// GET /health — alias for the liveness probe; the state is unused.
async fn health(_state: State<Arc<AppState>>) -> Response {
    liveness().await
}

205
/// GET /health_generate — delegate the generation health check to the router.
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = state.router.clone();
    router.health_generate(req).await
}

209
/// GET /get_server_info — delegate server-info retrieval to the router.
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = state.router.clone();
    router.get_server_info(req).await
}

213
/// GET /v1/models — delegate the OpenAI-compatible model listing to the router.
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = state.router.clone();
    router.get_models(req).await
}

217
/// GET /get_model_info — delegate model-info retrieval to the router.
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = state.router.clone();
    router.get_model_info(req).await
}
220

221
async fn generate(
222
223
224
225
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
226
227
228
229
    state
        .router
        .route_generate(Some(&headers), &body, None)
        .await
230
231
232
}

async fn v1_chat_completions(
233
234
235
236
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
237
    state.router.route_chat(Some(&headers), &body, None).await
238
239
240
}

async fn v1_completions(
241
242
243
244
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
245
246
247
248
    state
        .router
        .route_completion(Some(&headers), &body, None)
        .await
249
250
}

251
252
253
254
255
/// POST /rerank — route a rerank request through the router.
async fn rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<RerankRequest>,
) -> Response {
    let router = state.router.clone();
    router.route_rerank(Some(&headers), &body, None).await
}

async fn v1_rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<V1RerankReqInput>,
) -> Response {
    state
        .router
266
        .route_rerank(Some(&headers), &body.into(), None)
267
268
269
        .await
}

270
271
272
273
274
/// POST /v1/responses — OpenAI-compatible responses endpoint.
async fn v1_responses(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ResponsesRequest>,
) -> Response {
    let router = state.router.clone();
    router.route_responses(Some(&headers), &body, None).await
}

281
282
283
284
285
286
287
288
289
290
291
/// POST /v1/embeddings — OpenAI-compatible embeddings endpoint.
async fn v1_embeddings(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<EmbeddingRequest>,
) -> Response {
    let router = state.router.clone();
    router.route_embeddings(Some(&headers), &body, None).await
}

292
293
294
295
async fn v1_responses_get(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
296
    Query(params): Query<ResponsesGetParams>,
297
298
299
) -> Response {
    state
        .router
300
        .get_response(Some(&headers), &response_id, &params)
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
        .await
}

/// POST /v1/responses/{response_id}/cancel — cancel an in-flight response.
async fn v1_responses_cancel(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = state.router.clone();
    router.cancel_response(Some(&headers), &response_id).await
}

/// DELETE /v1/responses/{response_id} — delete a stored response.
async fn v1_responses_delete(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = state.router.clone();
    router.delete_response(Some(&headers), &response_id).await
}

/// GET /v1/responses/{response_id}/input — list the input items of a stored response.
async fn v1_responses_list_input_items(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = state.router.clone();
    router
        .list_response_input_items(Some(&headers), &response_id)
        .await
}

337
/// Query parameters for the legacy /add_worker and /remove_worker endpoints.
#[derive(Deserialize)]
struct AddWorkerQuery {
    /// Base URL of the worker being added or removed.
    url: String,
    /// Optional API key the router should use for this worker; ignored by
    /// /remove_worker (see its destructuring pattern).
    api_key: Option<String>,
}

343
async fn add_worker(
344
    State(state): State<Arc<AppState>>,
345
    Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
346
) -> Response {
347
348
349
350
351
352
353
354
355
356
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            url
        );
    }

357
358
359
    let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;

    match result {
360
361
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
362
    }
363
364
}

365
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
366
367
    let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
    Json(json!({ "urls": worker_list })).into_response()
368
369
}

370
async fn remove_worker(
371
    State(state): State<Arc<AppState>>,
372
    Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
373
) -> Response {
374
375
376
377
378
379
    let result = WorkerManager::remove_worker(&url, &state.context);

    match result {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
    }
380
381
}

382
/// POST /flush_cache — ask every HTTP worker to flush its cache.
///
/// Status codes: 200 when all workers succeeded, 206 (Partial Content) when
/// some failed (per-worker errors listed in the body), 500 when the flush
/// operation itself errored.
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
        .await
    {
        Ok(result) => {
            if result.failed.is_empty() {
                (
                    StatusCode::OK,
                    Json(json!({
                        "status": "success",
                        "message": result.message,
                        "workers_flushed": result.successful.len(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            } else {
                // Mixed outcome: report both successful URLs and the
                // failures with their error strings so callers can retry.
                (
                    StatusCode::PARTIAL_CONTENT,
                    Json(json!({
                        "status": "partial_success",
                        "message": result.message,
                        "successful": result.successful,
                        "failed": result.failed.into_iter().map(|(url, err)| json!({
                            "worker": url,
                            "error": err
                        })).collect::<Vec<_>>(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            }
        }
        Err(e) => {
            error!("Failed to flush cache: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({
                    "status": "error",
                    "message": format!("Failed to flush cache: {}", e)
                })),
            )
                .into_response()
        }
    }
}

431
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
    let result =
        WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
            .await;

    let loads: Vec<Value> = result
        .loads
        .iter()
        .map(|info| {
            json!({
                "worker": &info.worker,
                "load": info.load
            })
        })
        .collect();

    (
        StatusCode::OK,
        Json(json!({
            "workers": loads
        })),
    )
        .into_response()
454
455
}

456
457
458
459
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
460
461
462
463
464
465
466
467
468
469
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            config.url
        );
    }

470
471
472
473
474
475
476
477
478
479
    let result = WorkerManager::add_worker_from_config(&config, &state.context).await;

    match result {
        Ok(message) => {
            let response = WorkerApiResponse {
                success: true,
                message,
                worker: None,
            };
            (StatusCode::OK, Json(response)).into_response()
480
        }
481
482
483
484
485
486
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
                code: "ADD_WORKER_FAILED".to_string(),
            };
            (StatusCode::BAD_REQUEST, Json(error_response)).into_response()
487
488
489
490
491
        }
    }
}

/// GET /workers — REST-style listing with full per-worker detail plus
/// aggregate counts by worker type.
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
    let workers = state.context.worker_registry.get_all();
    let response = serde_json::json!({
        "workers": workers.iter().map(|worker| {
            let mut worker_info = serde_json::json!({
                "url": worker.url(),
                "model_id": worker.model_id(),
                "worker_type": match worker.worker_type() {
                    WorkerType::Regular => "regular",
                    WorkerType::Prefill { .. } => "prefill",
                    WorkerType::Decode => "decode",
                },
                "is_healthy": worker.is_healthy(),
                "load": worker.load(),
                "connection_mode": format!("{:?}", worker.connection_mode()),
                "priority": worker.priority(),
                "cost": worker.cost(),
            });

            // Prefill workers additionally expose their bootstrap port.
            if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
                worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
            }

            worker_info
        }).collect::<Vec<_>>(),
        "total": workers.len(),
        "stats": {
            "prefill_count": state.context.worker_registry.get_prefill_workers().len(),
            "decode_count": state.context.worker_registry.get_decode_workers().len(),
            "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(),
        }
    });
    Json(response).into_response()
}

526
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
527
528
529
530
531
532
533
534
    let workers = WorkerManager::get_worker_urls(&state.context.worker_registry);
    if workers.contains(&url) {
        Json(json!({
            "url": url,
            "model_id": "unknown",
            "is_healthy": true
        }))
        .into_response()
535
    } else {
536
537
538
539
540
        let error = WorkerErrorResponse {
            error: format!("Worker {url} not found"),
            code: "WORKER_NOT_FOUND".to_string(),
        };
        (StatusCode::NOT_FOUND, Json(error)).into_response()
541
542
543
    }
}

544
/// DELETE /workers/{url} — REST-style worker removal.
///
/// Responds 200 with a `WorkerApiResponse` on success, 400 with a
/// `WorkerErrorResponse` on failure.
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
    match WorkerManager::remove_worker(&url, &state.context) {
        Ok(message) => {
            let ok = WorkerApiResponse {
                success: true,
                message,
                worker: None,
            };
            (StatusCode::OK, Json(ok)).into_response()
        }
        Err(error) => {
            let err = WorkerErrorResponse {
                error,
                code: "REMOVE_WORKER_FAILED".to_string(),
            };
            (StatusCode::BAD_REQUEST, Json(err)).into_response()
        }
    }
}

566
567
568
/// Top-level configuration for running the router's HTTP server.
pub struct ServerConfig {
    /// Bind host (IPv4/IPv6 literal or hostname).
    pub host: String,
    /// Bind port.
    pub port: u16,
    /// Router behavior configuration (mode, policy, auth, workers, ...).
    pub router_config: RouterConfig,
    /// Maximum accepted request body size, in bytes.
    pub max_payload_size: usize,
    /// Optional directory for log files; `None` means no file logging dir.
    pub log_dir: Option<String>,
    /// Optional log level string; unparseable values fall back to INFO.
    pub log_level: Option<String>,
    /// Optional service-discovery settings; discovery starts only when
    /// present AND `enabled` is set.
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    /// Optional Prometheus metrics exporter settings.
    pub prometheus_config: Option<PrometheusConfig>,
    /// Per-request timeout for the shared HTTP client, in seconds.
    pub request_timeout_secs: u64,
    /// Header names to scan for an incoming request ID; `None` uses the
    /// built-in defaults (x-request-id, x-correlation-id, ...).
    pub request_id_headers: Option<Vec<String>>,
}

579
580
/// Assemble the axum application.
///
/// Route groups:
/// - protected inference routes: concurrency-limited AND auth-guarded;
/// - public probe/info routes: no middleware;
/// - admin + REST worker-management routes: auth-guarded only.
///
/// Outer layers (body limits, logging, request-id, CORS) apply to all routes.
pub fn build_app(
    app_state: Arc<AppState>,
    auth_config: AuthConfig,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Inference endpoints: rate/concurrency limited, then authenticated.
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions))
        .route("/rerank", post(rerank))
        .route("/v1/rerank", post(v1_rerank))
        .route("/v1/responses", post(v1_responses))
        .route("/v1/embeddings", post(v1_embeddings))
        .route("/v1/responses/{response_id}", get(v1_responses_get))
        .route(
            "/v1/responses/{response_id}/cancel",
            post(v1_responses_cancel),
        )
        .route("/v1/responses/{response_id}", delete(v1_responses_delete))
        .route(
            "/v1/responses/{response_id}/input",
            get(v1_responses_list_input_items),
        )
        // route_layer applies only to the routes above, not merged siblings.
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
            middleware::concurrency_limit_middleware,
        ))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    // Health probes and read-only info: intentionally unauthenticated.
    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    // Legacy admin endpoints: authenticated, but not concurrency-limited.
    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    // REST-style worker management: same auth guard as the admin routes.
    let worker_routes = Router::new()
        .route("/workers", post(create_worker))
        .route("/workers", get(list_workers_rest))
        .route("/workers/{url}", get(get_worker))
        .route("/workers/{url}", delete(delete_worker))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        .merge(worker_routes)
        // Two body limits: axum's extractor limit and tower-http's layer.
        .layer(axum::extract::DefaultBodyLimit::max(max_payload_size))
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        .layer(middleware::create_logging_layer())
        .layer(middleware::RequestIdLayer::new(request_id_headers))
        .layer(create_cors_layer(cors_allowed_origins))
        // Anything unmatched falls through to an empty 404.
        .fallback(sink_handler)
        .with_state(app_state)
}

/// Boot the router server: logging, metrics, context, workers, background
/// services, then the axum HTTP server with graceful shutdown.
///
/// Returns only after the server has shut down, or with an error if any
/// startup step fails.
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
    // Guard so repeated calls (e.g. from tests) initialize logging only once.
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            // Parse the configured level; on failure warn and use INFO.
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{s}'. Defaulting to INFO.");
                        None
                    }
                })
                .unwrap_or(Level::INFO),
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };

    if let Some(prometheus_config) = &config.prometheus_config {
        metrics::start_prometheus(prometheus_config.clone());
    }

    info!(
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
        config.max_payload_size / (1024 * 1024)
    );

    // Shared HTTP client for all worker communication.
    // NOTE(review): `.expect` panics instead of returning Err from this
    // fallible function — consider propagating the builder error.
    let client = Client::builder()
        .pool_idle_timeout(Some(Duration::from_secs(50)))
        .pool_max_idle_per_host(500)
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10))
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30)))
        .build()
        .expect("Failed to create HTTP client");

    let app_context = AppContext::new(
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
        config.router_config.rate_limit_tokens_per_second,
    )?;

    let app_context = Arc::new(app_context);

    // Register the statically configured workers before building routers.
    info!(
        "Initializing workers for routing mode: {:?}",
        config.router_config.mode
    );
    WorkerManager::initialize_workers(
        &config.router_config,
        &app_context.worker_registry,
        Some(&app_context.policy_registry),
    )
    .await
    .map_err(|e| format!("Failed to initialize workers: {}", e))?;

    let worker_stats = app_context.worker_registry.stats();
    info!(
        "Workers initialized: {} total, {} healthy",
        worker_stats.total_workers, worker_stats.healthy_workers
    );

    // The RouterManager doubles as the RouterTrait implementation used by
    // all request handlers.
    let router_manager = RouterManager::from_config(&config, &app_context).await?;
    let router: Arc<dyn RouterTrait> = router_manager.clone();

    // Background health checking; the returned handle is intentionally kept
    // alive for the duration of this function.
    let _health_checker = app_context
        .worker_registry
        .start_health_checker(config.router_config.health_check.check_interval_secs);
    info!(
        "Started health checker for workers with {}s interval",
        config.router_config.health_check.check_interval_secs
    );

    if let Some(ref load_monitor) = app_context.load_monitor {
        load_monitor.start().await;
        info!("Started LoadMonitor for PowerOfTwo policies");
    }

    // Optional request queue fed by the concurrency-limit middleware.
    let (limiter, processor) = middleware::ConcurrencyLimiter::new(
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

    if let Some(processor) = processor {
        spawn(processor.run());
        info!(
            "Started request queue with size: {}, timeout: {}s",
            config.router_config.queue_size, config.router_config.queue_timeout_secs
        );
    }

    let app_state = Arc::new(AppState {
        router,
        context: app_context.clone(),
        concurrency_queue_tx: limiter.queue_tx.clone(),
        router_manager: Some(router_manager),
    });

    // Service discovery is best-effort: a startup failure is logged and the
    // server continues without it.
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
            let app_context_arc = Arc::clone(&app_state.context);
            match start_service_discovery(service_discovery_config, app_context_arc).await {
                Ok(handle) => {
                    info!("Service discovery started");
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {e}");
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

    info!(
        "Router ready | workers: {:?}",
        WorkerManager::get_worker_urls(&app_state.context.worker_registry)
    );

    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    let auth_config = AuthConfig {
        api_key: config.router_config.api_key.clone(),
    };

    let app = build_app(
        app_state,
        auth_config,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );

    // TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs
    let bind_addr = format!("{}:{}", config.host, config.port);
    info!("Starting server on {}", bind_addr);
    let listener = TcpListener::bind(&bind_addr)
        .await
        .map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?;
    serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

/// Resolve when a shutdown signal arrives: Ctrl+C on all platforms, plus
/// SIGTERM on Unix. Used as axum's graceful-shutdown trigger.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    // No SIGTERM on non-Unix targets; this arm never completes.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = terminate => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
882
}