server.rs 30.1 KB
Newer Older
1
2
3
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
4
        Arc,
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
    },
    time::Duration,
};

use axum::{
    extract::{Path, Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{delete, get, post},
    serve, Json, Router,
};
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};

21
use crate::{
22
    app_context::AppContext,
23
    config::{RouterConfig, RoutingMode},
24
    core::{
25
26
        worker_to_info,
        workflow::{
27
28
            create_mcp_registration_workflow, create_worker_registration_workflow,
            create_worker_removal_workflow, LoggingSubscriber, WorkflowEngine,
29
        },
30
        Job, JobQueue, JobQueueConfig, WorkerManager, WorkerType,
31
    },
32
33
    logging::{self, LoggingConfig},
    metrics::{self, PrometheusConfig},
34
    middleware::{self, AuthConfig, QueuedRequest},
35
    protocols::{
36
        chat::ChatCompletionRequest,
37
        classify::ClassifyRequest,
38
39
40
41
42
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::{RerankRequest, V1RerankReqInput},
        responses::{ResponsesGetParams, ResponsesRequest},
43
        validated::ValidatedJson,
44
        worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo},
45
    },
46
    routers::{router_manager::RouterManager, RouterTrait},
47
    service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
48
};
49

50
51
52
53
/// Shared state handed to every axum handler via `State(...)`.
#[derive(Clone)]
pub struct AppState {
    /// Router implementation that serves inference/routing requests.
    pub router: Arc<dyn RouterTrait>,
    /// Shared application context (worker registry, router config, HTTP client, job queue, ...).
    pub context: Arc<AppContext>,
    /// Sender for the concurrency-limit request queue; `None` when queuing is disabled.
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>,
    /// The `RouterManager` that built `router`, when one exists.
    pub router_manager: Option<Arc<RouterManager>>,
}

58
59
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
60
61
}

62
63
async fn liveness() -> Response {
    (StatusCode::OK, "OK").into_response()
64
65
}

66
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    let workers = state.context.worker_registry.get_all();
    let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();

    let is_ready = if state.context.router_config.enable_igw {
        !healthy_workers.is_empty()
    } else {
        match &state.context.router_config.mode {
            RoutingMode::PrefillDecode { .. } => {
                let has_prefill = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }));
                let has_decode = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Decode));
                has_prefill && has_decode
            }
            RoutingMode::Regular { .. } => !healthy_workers.is_empty(),
            RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(),
        }
    };

    if is_ready {
        (
            StatusCode::OK,
            Json(json!({
                "status": "ready",
                "healthy_workers": healthy_workers.len(),
                "total_workers": workers.len()
            })),
        )
            .into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({
                "status": "not ready",
                "reason": "insufficient healthy workers"
            })),
        )
            .into_response()
    }
108
109
}

110
111
/// GET /health — alias for the liveness probe; the app state is unused.
async fn health(_state: State<Arc<AppState>>) -> Response {
    liveness().await
}

114
/// GET /health_generate — delegated to the router's generation health check.
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = &state.router;
    router.health_generate(req).await
}

118
119
120
121
/// GET /engine_metrics — gather engine metrics from the registered workers.
async fn engine_metrics(State(state): State<Arc<AppState>>) -> Response {
    let ctx = &state.context;
    WorkerManager::get_engine_metrics(&ctx.worker_registry, &ctx.client).await
}

122
/// GET /get_server_info — delegated to the router.
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = &state.router;
    router.get_server_info(req).await
}

126
/// GET /v1/models — delegated to the router.
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = &state.router;
    router.get_models(req).await
}

130
/// GET /get_model_info — delegated to the router.
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = &state.router;
    router.get_model_info(req).await
}
133

134
async fn generate(
135
136
137
138
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
139
140
141
142
    state
        .router
        .route_generate(Some(&headers), &body, None)
        .await
143
144
145
}

async fn v1_chat_completions(
146
147
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
148
    ValidatedJson(body): ValidatedJson<ChatCompletionRequest>,
149
) -> Response {
150
    state.router.route_chat(Some(&headers), &body, None).await
151
152
153
}

async fn v1_completions(
154
155
156
157
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
158
159
160
161
    state
        .router
        .route_completion(Some(&headers), &body, None)
        .await
162
163
}

164
165
166
async fn rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
167
    ValidatedJson(body): ValidatedJson<RerankRequest>,
168
) -> Response {
169
    state.router.route_rerank(Some(&headers), &body, None).await
170
171
172
173
174
175
176
177
178
}

async fn v1_rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<V1RerankReqInput>,
) -> Response {
    state
        .router
179
        .route_rerank(Some(&headers), &body.into(), None)
180
181
182
        .await
}

183
184
185
186
187
/// POST /v1/responses — forward a responses-API request to the router.
async fn v1_responses(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ResponsesRequest>,
) -> Response {
    let router = &state.router;
    router.route_responses(Some(&headers), &body, None).await
}

194
195
196
197
198
199
200
201
202
203
204
/// POST /v1/embeddings — forward an embedding request to the router.
async fn v1_embeddings(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<EmbeddingRequest>,
) -> Response {
    let router = &state.router;
    router.route_embeddings(Some(&headers), &body, None).await
}

205
206
207
208
209
210
211
212
213
214
215
/// POST /v1/classify — forward a classification request to the router.
async fn v1_classify(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ClassifyRequest>,
) -> Response {
    let router = &state.router;
    router.route_classify(Some(&headers), &body, None).await
}

216
217
218
219
async fn v1_responses_get(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
220
    Query(params): Query<ResponsesGetParams>,
221
222
223
) -> Response {
    state
        .router
224
        .get_response(Some(&headers), &response_id, &params)
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
        .await
}

/// POST /v1/responses/{response_id}/cancel — cancel an in-flight response.
async fn v1_responses_cancel(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.cancel_response(Some(&headers), &response_id).await
}

/// DELETE /v1/responses/{response_id} — delete a stored response.
async fn v1_responses_delete(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.delete_response(Some(&headers), &response_id).await
}

/// GET /v1/responses/{response_id}/input_items — list a response's input items.
async fn v1_responses_list_input_items(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .list_response_input_items(Some(&headers), &response_id)
        .await
}

261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
/// POST /v1/conversations — create a conversation from a raw JSON body.
async fn v1_conversations_create(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router.create_conversation(Some(&headers), &body).await
}

/// GET /v1/conversations/{conversation_id} — fetch a conversation.
async fn v1_conversations_get(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .get_conversation(Some(&headers), &conversation_id)
        .await
}

/// POST /v1/conversations/{conversation_id} — update a conversation from a raw JSON body.
async fn v1_conversations_update(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .update_conversation(Some(&headers), &conversation_id, &body)
        .await
}

/// DELETE /v1/conversations/{conversation_id} — delete a conversation.
async fn v1_conversations_delete(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation(Some(&headers), &conversation_id)
        .await
}

306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
/// Query-string parameters for listing conversation items.
#[derive(Deserialize, Default)]
struct ListItemsQuery {
    /// Maximum number of items to return.
    limit: Option<usize>,
    /// Sort order (passed through to the router as-is).
    order: Option<String>,
    /// Pagination cursor: return items after this id.
    after: Option<String>,
}

/// GET /v1/conversations/{conversation_id}/items — list items with optional paging.
async fn v1_conversations_list_items(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    Query(query): Query<ListItemsQuery>,
    headers: http::HeaderMap,
) -> Response {
    let ListItemsQuery {
        limit,
        order,
        after,
    } = query;
    state
        .router
        .list_conversation_items(Some(&headers), &conversation_id, limit, order, after)
        .await
}

329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
/// Query-string parameters for fetching a single conversation item.
#[derive(Deserialize, Default)]
struct GetItemQuery {
    /// Additional fields to include in response (not yet implemented)
    include: Option<Vec<String>>,
}

/// POST /v1/conversations/{conversation_id}/items — append items from a raw JSON body.
async fn v1_conversations_create_items(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .create_conversation_items(Some(&headers), &conversation_id, &body)
        .await
}

async fn v1_conversations_get_item(
    State(state): State<Arc<AppState>>,
    Path((conversation_id, item_id)): Path<(String, String)>,
    Query(query): Query<GetItemQuery>,
    headers: http::HeaderMap,
) -> Response {
    state
        .router
        .get_conversation_item(Some(&headers), &conversation_id, &item_id, query.include)
        .await
}

/// DELETE /v1/conversations/{conversation_id}/items/{item_id} — delete one item.
async fn v1_conversations_delete_item(
    State(state): State<Arc<AppState>>,
    Path((conversation_id, item_id)): Path<(String, String)>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation_item(Some(&headers), &conversation_id, &item_id)
        .await
}

370
/// POST /flush_cache — ask every worker to flush its cache via `WorkerManager`.
///
/// Returns 200 when all workers flushed, 206 with per-worker errors on partial
/// failure, and 500 when the operation itself failed.
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
        .await
    {
        Ok(result) => {
            if result.failed.is_empty() {
                // Every worker flushed successfully.
                (
                    StatusCode::OK,
                    Json(json!({
                        "status": "success",
                        "message": result.message,
                        "workers_flushed": result.successful.len(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            } else {
                // Some workers failed; report both the successes and each failure.
                (
                    StatusCode::PARTIAL_CONTENT,
                    Json(json!({
                        "status": "partial_success",
                        "message": result.message,
                        "successful": result.successful,
                        "failed": result.failed.into_iter().map(|(url, err)| json!({
                            "worker": url,
                            "error": err
                        })).collect::<Vec<_>>(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            }
        }
        Err(e) => {
            error!("Failed to flush cache: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({
                    "status": "error",
                    "message": format!("Failed to flush cache: {}", e)
                })),
            )
                .into_response()
        }
    }
}

419
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
    let result =
        WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
            .await;

    let loads: Vec<Value> = result
        .loads
        .iter()
        .map(|info| {
            json!({
                "worker": &info.worker,
                "load": info.load
            })
        })
        .collect();

435
    (StatusCode::OK, Json(json!({ "workers": loads }))).into_response()
436
437
}

438
439
440
441
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
442
443
444
445
446
447
448
449
450
451
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            config.url
        );
    }

452
453
454
455
456
457
    // Populate dp_aware from router's configuration
    let config = WorkerConfigRequest {
        dp_aware: state.context.router_config.dp_aware,
        ..config
    };

458
459
460
461
462
    // Submit job for async processing
    let worker_url = config.url.clone();
    let job = Job::AddWorker {
        config: Box::new(config),
    };
463

464
465
466
467
468
469
470
471
472
473
474
475
476
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");
    match job_queue.submit(job).await {
        Ok(_) => {
            let response = json!({
                "status": "accepted",
                "worker_id": worker_url,
                "message": "Worker addition queued for background processing"
            });
            (StatusCode::ACCEPTED, Json(response)).into_response()
477
        }
478
479
480
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
481
                code: "INTERNAL_SERVER_ERROR".to_string(),
482
            };
483
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
484
485
486
487
488
        }
    }
}

async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
489
    let workers = state.context.worker_registry.get_all();
490
    let worker_infos: Vec<WorkerInfo> = workers.iter().map(worker_to_info).collect();
491

492
493
    let response = json!({
        "workers": worker_infos,
494
495
496
497
498
499
500
501
        "total": workers.len(),
        "stats": {
            "prefill_count": state.context.worker_registry.get_prefill_workers().len(),
            "decode_count": state.context.worker_registry.get_decode_workers().len(),
            "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(),
        }
    });
    Json(response).into_response()
502
503
}

504
/// GET /workers/{url} — report a worker's info, attaching any pending
/// job-queue status; 404 when the worker is in neither the registry nor
/// the job queue.
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");

    if let Some(worker) = state.context.worker_registry.get_by_url(&url) {
        // Worker exists in registry, get its full info and attach job status if any
        let mut worker_info = worker_to_info(&worker);
        if let Some(status) = job_queue.get_status(&url) {
            worker_info.job_status = Some(status);
        }
        return Json(worker_info).into_response();
    }

    // Worker not in registry, check job queue for its status
    if let Some(status) = job_queue.get_status(&url) {
        // Create a partial WorkerInfo to report the job status
        // (placeholder fields; only `url` and `job_status` carry real data here).
        let worker_info = WorkerInfo {
            id: url.clone(),
            url: url.clone(),
            model_id: "unknown".to_string(),
            priority: 0,
            cost: 1.0,
            worker_type: "unknown".to_string(),
            is_healthy: false,
            load: 0,
            connection_mode: "unknown".to_string(),
            tokenizer_path: None,
            reasoning_parser: None,
            tool_parser: None,
            chat_template: None,
            bootstrap_port: None,
            metadata: std::collections::HashMap::new(),
            job_status: Some(status),
        };
        return Json(worker_info).into_response();
    }

    // Worker not found in registry or job queue
    let error = WorkerErrorResponse {
        error: format!("Worker {url} not found"),
        code: "WORKER_NOT_FOUND".to_string(),
    };
    (StatusCode::NOT_FOUND, Json(error)).into_response()
}

552
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
    let worker_id = url.clone();
    let job = Job::RemoveWorker { url };

    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");
    match job_queue.submit(job).await {
        Ok(_) => {
            let response = json!({
                "status": "accepted",
                "worker_id": worker_id,
                "message": "Worker removal queued for background processing"
            });
            (StatusCode::ACCEPTED, Json(response)).into_response()
569
570
571
572
        }
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
573
                code: "INTERNAL_SERVER_ERROR".to_string(),
574
            };
575
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
576
577
578
579
        }
    }
}

580
581
582
/// Everything needed to boot the router HTTP server (see [`startup`]).
pub struct ServerConfig {
    /// Host/interface to bind.
    pub host: String,
    /// TCP port to bind.
    pub port: u16,
    /// Full router configuration (mode, policy, workers, auth, ...).
    pub router_config: RouterConfig,
    /// Maximum accepted request body size, in bytes.
    pub max_payload_size: usize,
    /// Optional directory for log files.
    pub log_dir: Option<String>,
    /// Optional log level name; parsed case-insensitively, falls back to INFO.
    pub log_level: Option<String>,
    /// Optional service discovery configuration; only started when `enabled`.
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    /// Optional Prometheus metrics exporter configuration.
    pub prometheus_config: Option<PrometheusConfig>,
    /// Request timeout in seconds (passed to `AppContext::from_config`).
    pub request_timeout_secs: u64,
    /// Header names scanned for an incoming request id; a default list is used when `None`.
    pub request_id_headers: Option<Vec<String>>,
}

593
594
/// Assemble the axum [`Router`]: inference routes (auth + concurrency limit),
/// public health/info routes, admin routes, and worker-management routes,
/// wrapped in body-limit, logging, request-id, and CORS layers.
pub fn build_app(
    app_state: Arc<AppState>,
    auth_config: AuthConfig,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Inference/data-plane routes: concurrency-limited AND auth-protected.
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions))
        .route("/rerank", post(rerank))
        .route("/v1/rerank", post(v1_rerank))
        .route("/v1/responses", post(v1_responses))
        .route("/v1/embeddings", post(v1_embeddings))
        .route("/v1/classify", post(v1_classify))
        .route("/v1/responses/{response_id}", get(v1_responses_get))
        .route(
            "/v1/responses/{response_id}/cancel",
            post(v1_responses_cancel),
        )
        .route("/v1/responses/{response_id}", delete(v1_responses_delete))
        .route(
            "/v1/responses/{response_id}/input_items",
            get(v1_responses_list_input_items),
        )
        .route("/v1/conversations", post(v1_conversations_create))
        .route(
            "/v1/conversations/{conversation_id}",
            get(v1_conversations_get)
                .post(v1_conversations_update)
                .delete(v1_conversations_delete),
        )
        .route(
            "/v1/conversations/{conversation_id}/items",
            get(v1_conversations_list_items).post(v1_conversations_create_items),
        )
        .route(
            "/v1/conversations/{conversation_id}/items/{item_id}",
            get(v1_conversations_get_item).delete(v1_conversations_delete_item),
        )
        // Layers run outermost-last: auth is checked before the concurrency limit.
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
            middleware::concurrency_limit_middleware,
        ))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    // Health and info endpoints: no auth, no concurrency limit.
    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/engine_metrics", get(engine_metrics))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    // Operational endpoints: auth-protected but not concurrency-limited.
    let admin_routes = Router::new()
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    // Worker CRUD endpoints: auth-protected.
    let worker_routes = Router::new()
        .route("/workers", post(create_worker))
        .route("/workers", get(list_workers_rest))
        .route("/workers/{url}", get(get_worker))
        .route("/workers/{url}", delete(delete_worker))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        .merge(worker_routes)
        .layer(axum::extract::DefaultBodyLimit::max(max_payload_size))
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        .layer(middleware::create_logging_layer())
        .layer(middleware::RequestIdLayer::new(request_id_headers))
        .layer(create_cors_layer(cors_allowed_origins))
        .fallback(sink_handler)
        .with_state(app_state)
}

/// Boot the router: initialize logging/metrics, build the [`AppContext`],
/// queue worker/MCP initialization jobs, start background services (health
/// checker, load monitor, request queue, service discovery), then bind and
/// serve until a shutdown signal arrives.
///
/// Initialization order matters: the job queue and workflow engine must be
/// registered on the context before any jobs are submitted.
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
    // Guard so logging is initialized at most once per process, even if
    // startup() is invoked multiple times (e.g. from tests/bindings).
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{s}'. Defaulting to INFO.");
                        None
                    }
                })
                .unwrap_or(Level::INFO),
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };

    if let Some(prometheus_config) = &config.prometheus_config {
        metrics::start_prometheus(prometheus_config.clone());
    }

    info!(
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
        config.max_payload_size / (1024 * 1024)
    );

    let app_context = Arc::new(
        AppContext::from_config(config.router_config.clone(), config.request_timeout_secs).await?,
    );

    // The job queue holds only a weak reference to the context to avoid a
    // reference cycle (the context owns the queue).
    let weak_context = Arc::downgrade(&app_context);
    let worker_job_queue = JobQueue::new(JobQueueConfig::default(), weak_context);
    app_context
        .worker_job_queue
        .set(worker_job_queue)
        .expect("JobQueue not initialized");

    // Initialize workflow engine and register workflows
    let engine = Arc::new(WorkflowEngine::new());

    engine
        .event_bus()
        .subscribe(Arc::new(LoggingSubscriber))
        .await;

    engine.register_workflow(create_worker_registration_workflow(&config.router_config));
    engine.register_workflow(create_worker_removal_workflow());
    engine.register_workflow(create_mcp_registration_workflow());

    app_context
        .workflow_engine
        .set(engine)
        .expect("WorkflowEngine should only be initialized once");

    info!(
        "Workflow engine initialized with worker and MCP registration workflows (health check timeout: {}s)",
        config.router_config.health_check.timeout_secs
    );

    info!(
        "Initializing workers for routing mode: {:?}",
        config.router_config.mode
    );

    // Submit worker initialization job to queue
    let job_queue = app_context
        .worker_job_queue
        .get()
        .expect("JobQueue should be initialized");
    let job = Job::InitializeWorkersFromConfig {
        router_config: Box::new(config.router_config.clone()),
    };
    job_queue
        .submit(job)
        .await
        .map_err(|e| format!("Failed to submit worker initialization job: {}", e))?;

    info!("Worker initialization job submitted (will complete in background)");

    if let Some(mcp_config) = &config.router_config.mcp_config {
        info!("Found {} MCP server(s) in config", mcp_config.servers.len());
        let mcp_job = Job::InitializeMcpServers {
            mcp_config: Box::new(mcp_config.clone()),
        };
        job_queue
            .submit(mcp_job)
            .await
            .map_err(|e| format!("Failed to submit MCP initialization job: {}", e))?;
    } else {
        info!("No MCP config provided, skipping MCP server initialization");
    }

    // Start background refresh for ALL MCP servers (static + dynamic in LRU cache)
    if let Some(mcp_manager) = app_context.mcp_manager.get() {
        let refresh_interval = Duration::from_secs(600); // 10 minutes
        let _refresh_handle =
            Arc::clone(mcp_manager).spawn_background_refresh_all(refresh_interval);
        info!("Started background refresh for all MCP servers (every 10 minutes)");
    }

    // NOTE: worker init runs in the background, so these counts may still be 0 here.
    let worker_stats = app_context.worker_registry.stats();
    info!(
        "Workers initialized: {} total, {} healthy",
        worker_stats.total_workers, worker_stats.healthy_workers
    );

    let router_manager = RouterManager::from_config(&config, &app_context).await?;
    let router: Arc<dyn RouterTrait> = router_manager.clone();

    let _health_checker = app_context
        .worker_registry
        .start_health_checker(config.router_config.health_check.check_interval_secs);
    info!(
        "Started health checker for workers with {}s interval",
        config.router_config.health_check.check_interval_secs
    );

    if let Some(ref load_monitor) = app_context.load_monitor {
        load_monitor.start().await;
        info!("Started LoadMonitor for PowerOfTwo policies");
    }

    let (limiter, processor) = middleware::ConcurrencyLimiter::new(
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

    if app_context.rate_limiter.is_none() {
        info!("Rate limiting is disabled (max_concurrent_requests = -1)");
    }

    match processor {
        Some(proc) => {
            // Queue processor runs as a detached background task.
            spawn(proc.run());
            info!(
                "Started request queue (size: {}, timeout: {}s)",
                config.router_config.queue_size, config.router_config.queue_timeout_secs
            );
        }
        None => {
            info!(
                "Rate limiting enabled (max_concurrent_requests = {}, queue disabled)",
                config.router_config.max_concurrent_requests
            );
        }
    }

    let app_state = Arc::new(AppState {
        router,
        context: app_context.clone(),
        concurrency_queue_tx: limiter.queue_tx.clone(),
        router_manager: Some(router_manager),
    });

    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
            let app_context_arc = Arc::clone(&app_state.context);
            match start_service_discovery(service_discovery_config, app_context_arc).await {
                Ok(handle) => {
                    info!("Service discovery started");
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    // Service discovery is best-effort: log and keep serving.
                    error!("Failed to start service discovery: {e}");
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

    info!(
        "Router ready | workers: {:?}",
        WorkerManager::get_worker_urls(&app_state.context.worker_registry)
    );

    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    let auth_config = AuthConfig {
        api_key: config.router_config.api_key.clone(),
    };

    let app = build_app(
        app_state,
        auth_config,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );

    // TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs
    let bind_addr = format!("{}:{}", config.host, config.port);
    info!("Starting server on {}", bind_addr);
    let listener = TcpListener::bind(&bind_addr)
        .await
        .map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?;
    serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

/// Resolves when the process receives Ctrl+C or (on Unix) SIGTERM, which
/// triggers axum's graceful shutdown in [`startup`].
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    // On Unix, also listen for SIGTERM.
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    // On non-Unix targets there is no SIGTERM; substitute a future that never resolves.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = terminate => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
963
}