// server.rs — HTTP server setup, request handlers, and startup logic.
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    },
    time::Duration,
};

use axum::{
    extract::{Path, Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{delete, get, post},
    serve, Json, Router,
};
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};

use crate::{
    app_context::AppContext,
    config::{RouterConfig, RoutingMode},
    core::{
        worker_to_info,
        workflow::{
            create_mcp_registration_workflow, create_worker_registration_workflow,
            create_worker_removal_workflow, LoggingSubscriber, WorkflowEngine,
        },
        Job, JobQueue, JobQueueConfig, WorkerManager, WorkerType,
    },
    logging::{self, LoggingConfig},
    metrics::{self, PrometheusConfig},
    middleware::{self, AuthConfig, QueuedRequest},
    protocols::{
        chat::ChatCompletionRequest,
        classify::ClassifyRequest,
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::{RerankRequest, V1RerankReqInput},
        responses::{ResponsesGetParams, ResponsesRequest},
        validated::ValidatedJson,
        worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo},
    },
    routers::{router_manager::RouterManager, RouterTrait},
    service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
};
49

50
51
52
53
#[derive(Clone)]
pub struct AppState {
    pub router: Arc<dyn RouterTrait>,
    pub context: Arc<AppContext>,
54
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>,
55
    pub router_manager: Option<Arc<RouterManager>>,
56
57
}

58
59
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
60
61
}

62
63
async fn liveness() -> Response {
    (StatusCode::OK, "OK").into_response()
64
65
}

66
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    let workers = state.context.worker_registry.get_all();
    let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();

    let is_ready = if state.context.router_config.enable_igw {
        !healthy_workers.is_empty()
    } else {
        match &state.context.router_config.mode {
            RoutingMode::PrefillDecode { .. } => {
                let has_prefill = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }));
                let has_decode = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Decode));
                has_prefill && has_decode
            }
            RoutingMode::Regular { .. } => !healthy_workers.is_empty(),
            RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(),
        }
    };

    if is_ready {
        (
            StatusCode::OK,
            Json(json!({
                "status": "ready",
                "healthy_workers": healthy_workers.len(),
                "total_workers": workers.len()
            })),
        )
            .into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({
                "status": "not ready",
                "reason": "insufficient healthy workers"
            })),
        )
            .into_response()
    }
108
109
}

110
111
async fn health(_state: State<Arc<AppState>>) -> Response {
    liveness().await
112
113
}

114
/// Proxies a health-generate request through the active router.
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health_generate(req).await
}

118
/// Returns server information as reported by the active router.
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_server_info(req).await
}

122
/// Lists available models via the active router.
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_models(req).await
}

126
/// Returns model information via the active router.
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_model_info(req).await
}
129

130
async fn generate(
131
132
133
134
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
135
136
137
138
    state
        .router
        .route_generate(Some(&headers), &body, None)
        .await
139
140
141
}

async fn v1_chat_completions(
142
143
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
144
    ValidatedJson(body): ValidatedJson<ChatCompletionRequest>,
145
) -> Response {
146
    state.router.route_chat(Some(&headers), &body, None).await
147
148
149
}

async fn v1_completions(
150
151
152
153
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
154
155
156
157
    state
        .router
        .route_completion(Some(&headers), &body, None)
        .await
158
159
}

160
161
162
async fn rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
163
    ValidatedJson(body): ValidatedJson<RerankRequest>,
164
) -> Response {
165
    state.router.route_rerank(Some(&headers), &body, None).await
166
167
168
169
170
171
172
173
174
}

async fn v1_rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<V1RerankReqInput>,
) -> Response {
    state
        .router
175
        .route_rerank(Some(&headers), &body.into(), None)
176
177
178
        .await
}

179
180
181
182
183
/// Handler for creating a response; forwards to the router.
async fn v1_responses(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ResponsesRequest>,
) -> Response {
    state
        .router
        .route_responses(Some(&headers), &body, None)
        .await
}

190
191
192
193
194
195
196
197
198
199
200
/// Handler for embedding requests; forwards headers and body to the router.
async fn v1_embeddings(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<EmbeddingRequest>,
) -> Response {
    let router = &state.router;
    router.route_embeddings(Some(&headers), &body, None).await
}

201
202
203
204
205
206
207
208
209
210
211
/// Handler for classification requests; forwards headers and body to the router.
async fn v1_classify(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ClassifyRequest>,
) -> Response {
    let router = &state.router;
    router.route_classify(Some(&headers), &body, None).await
}

212
213
214
215
async fn v1_responses_get(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
216
    Query(params): Query<ResponsesGetParams>,
217
218
219
) -> Response {
    state
        .router
220
        .get_response(Some(&headers), &response_id, &params)
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
        .await
}

/// Cancels an in-flight response by id.
async fn v1_responses_cancel(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.cancel_response(Some(&headers), &response_id).await
}

/// Deletes a stored response by id.
async fn v1_responses_delete(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.delete_response(Some(&headers), &response_id).await
}

/// Lists the input items attached to a stored response.
async fn v1_responses_list_input_items(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .list_response_input_items(Some(&headers), &response_id)
        .await
}

257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
/// Creates a conversation from an arbitrary JSON payload.
async fn v1_conversations_create(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router.create_conversation(Some(&headers), &body).await
}

/// Fetches a conversation by id.
async fn v1_conversations_get(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .get_conversation(Some(&headers), &conversation_id)
        .await
}

/// Updates a conversation by id with an arbitrary JSON payload.
async fn v1_conversations_update(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .update_conversation(Some(&headers), &conversation_id, &body)
        .await
}

/// Deletes a conversation by id.
async fn v1_conversations_delete(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation(Some(&headers), &conversation_id)
        .await
}

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
/// Query parameters for listing conversation items.
#[derive(Deserialize, Default)]
struct ListItemsQuery {
    // Maximum number of items to return; interpretation is left to the router.
    limit: Option<usize>,
    // Sort order — presumably "asc"/"desc"; TODO confirm accepted values in the router.
    order: Option<String>,
    // Pagination cursor: list items after this id — TODO confirm semantics.
    after: Option<String>,
}

/// Lists the items in a conversation, honoring limit/order/after query params.
async fn v1_conversations_list_items(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    Query(query): Query<ListItemsQuery>,
    headers: http::HeaderMap,
) -> Response {
    let ListItemsQuery {
        limit,
        order,
        after,
    } = query;
    state
        .router
        .list_conversation_items(Some(&headers), &conversation_id, limit, order, after)
        .await
}

325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
/// Query parameters for fetching a single conversation item.
#[derive(Deserialize, Default)]
struct GetItemQuery {
    /// Additional fields to include in response (not yet implemented)
    include: Option<Vec<String>>,
}

/// Adds items to a conversation from an arbitrary JSON payload.
async fn v1_conversations_create_items(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .create_conversation_items(Some(&headers), &conversation_id, &body)
        .await
}

async fn v1_conversations_get_item(
    State(state): State<Arc<AppState>>,
    Path((conversation_id, item_id)): Path<(String, String)>,
    Query(query): Query<GetItemQuery>,
    headers: http::HeaderMap,
) -> Response {
    state
        .router
        .get_conversation_item(Some(&headers), &conversation_id, &item_id, query.include)
        .await
}

/// Deletes a single item from a conversation.
async fn v1_conversations_delete_item(
    State(state): State<Arc<AppState>>,
    Path((conversation_id, item_id)): Path<(String, String)>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation_item(Some(&headers), &conversation_id, &item_id)
        .await
}

366
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
    match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
        .await
    {
        Ok(result) => {
            if result.failed.is_empty() {
                (
                    StatusCode::OK,
                    Json(json!({
                        "status": "success",
                        "message": result.message,
                        "workers_flushed": result.successful.len(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            } else {
                (
                    StatusCode::PARTIAL_CONTENT,
                    Json(json!({
                        "status": "partial_success",
                        "message": result.message,
                        "successful": result.successful,
                        "failed": result.failed.into_iter().map(|(url, err)| json!({
                            "worker": url,
                            "error": err
                        })).collect::<Vec<_>>(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            }
        }
        Err(e) => {
            error!("Failed to flush cache: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({
                    "status": "error",
                    "message": format!("Failed to flush cache: {}", e)
                })),
            )
                .into_response()
        }
    }
413
414
}

415
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
    let result =
        WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
            .await;

    let loads: Vec<Value> = result
        .loads
        .iter()
        .map(|info| {
            json!({
                "worker": &info.worker,
                "load": info.load
            })
        })
        .collect();

431
    (StatusCode::OK, Json(json!({ "workers": loads }))).into_response()
432
433
}

434
435
436
437
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
438
439
440
441
442
443
444
445
446
447
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            config.url
        );
    }

448
449
450
451
452
453
    // Populate dp_aware from router's configuration
    let config = WorkerConfigRequest {
        dp_aware: state.context.router_config.dp_aware,
        ..config
    };

454
455
456
457
458
    // Submit job for async processing
    let worker_url = config.url.clone();
    let job = Job::AddWorker {
        config: Box::new(config),
    };
459

460
461
462
463
464
465
466
467
468
469
470
471
472
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");
    match job_queue.submit(job).await {
        Ok(_) => {
            let response = json!({
                "status": "accepted",
                "worker_id": worker_url,
                "message": "Worker addition queued for background processing"
            });
            (StatusCode::ACCEPTED, Json(response)).into_response()
473
        }
474
475
476
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
477
                code: "INTERNAL_SERVER_ERROR".to_string(),
478
            };
479
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
480
481
482
483
484
        }
    }
}

async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
485
    let workers = state.context.worker_registry.get_all();
486
    let worker_infos: Vec<WorkerInfo> = workers.iter().map(worker_to_info).collect();
487

488
489
    let response = json!({
        "workers": worker_infos,
490
491
492
493
494
495
496
497
        "total": workers.len(),
        "stats": {
            "prefill_count": state.context.worker_registry.get_prefill_workers().len(),
            "decode_count": state.context.worker_registry.get_decode_workers().len(),
            "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(),
        }
    });
    Json(response).into_response()
498
499
}

500
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");

    if let Some(worker) = state.context.worker_registry.get_by_url(&url) {
        // Worker exists in registry, get its full info and attach job status if any
        let mut worker_info = worker_to_info(&worker);
        if let Some(status) = job_queue.get_status(&url) {
            worker_info.job_status = Some(status);
        }
        return Json(worker_info).into_response();
    }

    // Worker not in registry, check job queue for its status
    if let Some(status) = job_queue.get_status(&url) {
        // Create a partial WorkerInfo to report the job status
        let worker_info = WorkerInfo {
            id: url.clone(),
            url: url.clone(),
            model_id: "unknown".to_string(),
            priority: 0,
            cost: 1.0,
            worker_type: "unknown".to_string(),
            is_healthy: false,
            load: 0,
            connection_mode: "unknown".to_string(),
            tokenizer_path: None,
            reasoning_parser: None,
            tool_parser: None,
            chat_template: None,
            bootstrap_port: None,
            metadata: std::collections::HashMap::new(),
            job_status: Some(status),
536
        };
537
        return Json(worker_info).into_response();
538
    }
539
540
541
542
543
544
545

    // Worker not found in registry or job queue
    let error = WorkerErrorResponse {
        error: format!("Worker {url} not found"),
        code: "WORKER_NOT_FOUND".to_string(),
    };
    (StatusCode::NOT_FOUND, Json(error)).into_response()
546
547
}

548
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
    let worker_id = url.clone();
    let job = Job::RemoveWorker { url };

    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");
    match job_queue.submit(job).await {
        Ok(_) => {
            let response = json!({
                "status": "accepted",
                "worker_id": worker_id,
                "message": "Worker removal queued for background processing"
            });
            (StatusCode::ACCEPTED, Json(response)).into_response()
565
566
567
568
        }
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
569
                code: "INTERNAL_SERVER_ERROR".to_string(),
570
            };
571
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
572
573
574
575
        }
    }
}

576
577
578
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
579
    pub router_config: RouterConfig,
580
    pub max_payload_size: usize,
581
    pub log_dir: Option<String>,
582
    pub log_level: Option<String>,
583
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
584
    pub prometheus_config: Option<PrometheusConfig>,
585
    pub request_timeout_secs: u64,
586
    pub request_id_headers: Option<Vec<String>>,
587
588
}

589
590
pub fn build_app(
    app_state: Arc<AppState>,
591
    auth_config: AuthConfig,
592
593
594
595
596
597
598
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
599
        .route("/v1/completions", post(v1_completions))
600
601
        .route("/rerank", post(rerank))
        .route("/v1/rerank", post(v1_rerank))
602
        .route("/v1/responses", post(v1_responses))
603
        .route("/v1/embeddings", post(v1_embeddings))
604
        .route("/v1/classify", post(v1_classify))
605
606
607
608
609
610
611
        .route("/v1/responses/{response_id}", get(v1_responses_get))
        .route(
            "/v1/responses/{response_id}/cancel",
            post(v1_responses_cancel),
        )
        .route("/v1/responses/{response_id}", delete(v1_responses_delete))
        .route(
612
            "/v1/responses/{response_id}/input_items",
613
614
            get(v1_responses_list_input_items),
        )
615
616
617
618
619
620
621
        .route("/v1/conversations", post(v1_conversations_create))
        .route(
            "/v1/conversations/{conversation_id}",
            get(v1_conversations_get)
                .post(v1_conversations_update)
                .delete(v1_conversations_delete),
        )
622
623
        .route(
            "/v1/conversations/{conversation_id}/items",
624
625
626
627
628
            get(v1_conversations_list_items).post(v1_conversations_create_items),
        )
        .route(
            "/v1/conversations/{conversation_id}/items/{item_id}",
            get(v1_conversations_get_item).delete(v1_conversations_delete_item),
629
        )
630
631
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
632
            middleware::concurrency_limit_middleware,
633
634
635
636
        ))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
637
        ));
638
639
640
641
642
643
644
645
646
647
648
649

    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    let admin_routes = Router::new()
        .route("/flush_cache", post(flush_cache))
650
651
652
653
654
        .route("/get_loads", get(get_loads))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));
655

656
657
658
659
    let worker_routes = Router::new()
        .route("/workers", post(create_worker))
        .route("/workers", get(list_workers_rest))
        .route("/workers/{url}", get(get_worker))
660
661
662
663
664
        .route("/workers/{url}", delete(delete_worker))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));
665

666
667
668
669
    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
670
        .merge(worker_routes)
671
        .layer(axum::extract::DefaultBodyLimit::max(max_payload_size))
672
673
674
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
675
676
        .layer(middleware::create_logging_layer())
        .layer(middleware::RequestIdLayer::new(request_id_headers))
677
678
679
680
681
682
        .layer(create_cors_layer(cors_allowed_origins))
        .fallback(sink_handler)
        .with_state(app_state)
}

pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
683
684
685
686
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
687
688
689
690
691
692
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
693
                        warn!("Invalid log level string: '{s}'. Defaulting to INFO.");
694
695
696
697
                        None
                    }
                })
                .unwrap_or(Level::INFO),
698
699
700
701
702
703
704
705
706
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
707

708
709
    if let Some(prometheus_config) = &config.prometheus_config {
        metrics::start_prometheus(prometheus_config.clone());
710
711
    }

712
    info!(
713
714
715
716
717
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
718
719
720
        config.max_payload_size / (1024 * 1024)
    );

721
    let app_context = Arc::new(
722
        AppContext::from_config(config.router_config.clone(), config.request_timeout_secs).await?,
723
    );
724

725
726
727
728
729
730
731
    let weak_context = Arc::downgrade(&app_context);
    let worker_job_queue = JobQueue::new(JobQueueConfig::default(), weak_context);
    app_context
        .worker_job_queue
        .set(worker_job_queue)
        .expect("JobQueue should only be initialized once");

732
733
734
735
736
    // Initialize workflow engine and register workflows
    let engine = Arc::new(WorkflowEngine::new());

    engine
        .event_bus()
737
        .subscribe(Arc::new(LoggingSubscriber))
738
739
        .await;

740
    engine.register_workflow(create_worker_registration_workflow(&config.router_config));
741
    engine.register_workflow(create_worker_removal_workflow());
742
    engine.register_workflow(create_mcp_registration_workflow());
743
744
745
746
    app_context
        .workflow_engine
        .set(engine)
        .expect("WorkflowEngine should only be initialized once");
747
748
749
750
    info!(
        "Workflow engine initialized with worker and MCP registration workflows (health check timeout: {}s)",
        config.router_config.health_check.timeout_secs
    );
751

752
753
754
755
    info!(
        "Initializing workers for routing mode: {:?}",
        config.router_config.mode
    );
756
757
758
759
760
761
762
763
764
765
766
767
768

    // Submit worker initialization job to queue
    let job_queue = app_context
        .worker_job_queue
        .get()
        .expect("JobQueue should be initialized");
    let job = Job::InitializeWorkersFromConfig {
        router_config: Box::new(config.router_config.clone()),
    };
    job_queue
        .submit(job)
        .await
        .map_err(|e| format!("Failed to submit worker initialization job: {}", e))?;
769

770
771
    info!("Worker initialization job submitted (will complete in background)");

772
773
774
775
776
777
778
779
780
781
782
783
784
    if let Some(mcp_config) = &config.router_config.mcp_config {
        info!("Found {} MCP server(s) in config", mcp_config.servers.len());
        let mcp_job = Job::InitializeMcpServers {
            mcp_config: Box::new(mcp_config.clone()),
        };
        job_queue
            .submit(mcp_job)
            .await
            .map_err(|e| format!("Failed to submit MCP initialization job: {}", e))?;
    } else {
        info!("No MCP config provided, skipping MCP server initialization");
    }

785
    // Start background refresh for ALL MCP servers (static + dynamic in LRU cache)
786
    if let Some(mcp_manager) = app_context.mcp_manager.get() {
787
        let refresh_interval = Duration::from_secs(600); // 10 minutes
788
789
        let _refresh_handle =
            Arc::clone(mcp_manager).spawn_background_refresh_all(refresh_interval);
790
        info!("Started background refresh for all MCP servers (every 10 minutes)");
791
792
    }

793
794
795
796
797
798
    let worker_stats = app_context.worker_registry.stats();
    info!(
        "Workers initialized: {} total, {} healthy",
        worker_stats.total_workers, worker_stats.healthy_workers
    );

799
800
    let router_manager = RouterManager::from_config(&config, &app_context).await?;
    let router: Arc<dyn RouterTrait> = router_manager.clone();
801
802
803
804
805
806
807
808

    let _health_checker = app_context
        .worker_registry
        .start_health_checker(config.router_config.health_check.check_interval_secs);
    info!(
        "Started health checker for workers with {}s interval",
        config.router_config.health_check.check_interval_secs
    );
809

810
811
812
813
814
    if let Some(ref load_monitor) = app_context.load_monitor {
        load_monitor.start().await;
        info!("Started LoadMonitor for PowerOfTwo policies");
    }

815
    let (limiter, processor) = middleware::ConcurrencyLimiter::new(
816
817
818
819
820
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
    if app_context.rate_limiter.is_none() {
        info!("Rate limiting is disabled (max_concurrent_requests = -1)");
    }

    match processor {
        Some(proc) => {
            spawn(proc.run());
            info!(
                "Started request queue (size: {}, timeout: {}s)",
                config.router_config.queue_size, config.router_config.queue_timeout_secs
            );
        }
        None => {
            info!(
                "Rate limiting enabled (max_concurrent_requests = {}, queue disabled)",
                config.router_config.max_concurrent_requests
            );
        }
839
840
    }

841
    let app_state = Arc::new(AppState {
842
        router,
843
        context: app_context.clone(),
844
        concurrency_queue_tx: limiter.queue_tx.clone(),
845
        router_manager: Some(router_manager),
846
    });
847
848
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
849
850
            let app_context_arc = Arc::clone(&app_state.context);
            match start_service_discovery(service_discovery_config, app_context_arc).await {
851
                Ok(handle) => {
852
                    info!("Service discovery started");
853
854
855
856
857
858
859
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
860
                    error!("Failed to start service discovery: {e}");
861
862
863
864
865
866
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

867
    info!(
868
        "Router ready | workers: {:?}",
869
        WorkerManager::get_worker_urls(&app_state.context.worker_registry)
870
    );
871

872
873
874
875
876
877
878
879
880
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

881
882
883
884
    let auth_config = AuthConfig {
        api_key: config.router_config.api_key.clone(),
    };

885
886
    let app = build_app(
        app_state,
887
        auth_config,
888
889
890
891
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );
892

893
894
895
896
897
898
    // TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs
    let bind_addr = format!("{}:{}", config.host, config.port);
    info!("Starting server on {}", bind_addr);
    let listener = TcpListener::bind(&bind_addr)
        .await
        .map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?;
899
    serve(listener, app)
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

/// Resolves when Ctrl+C (SIGINT) or, on unix, SIGTERM is received, triggering
/// axum's graceful shutdown.
async fn shutdown_signal() {
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    // On non-unix platforms there is no SIGTERM; never resolve this branch.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    tokio::select! {
        _ = ctrl_c => info!("Received Ctrl+C, starting graceful shutdown"),
        _ = terminate => info!("Received terminate signal, starting graceful shutdown"),
    }
}

fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
958
}