server.rs 28.7 KB
Newer Older
1
2
3
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
4
        Arc,
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
    },
    time::Duration,
};

use axum::{
    extract::{Path, Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{delete, get, post},
    serve, Json, Router,
};
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};

21
use crate::{
22
    app_context::AppContext,
23
    config::{RouterConfig, RoutingMode},
24
    core::{
25
26
27
28
29
        worker_to_info,
        workflow::{
            create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber,
            WorkflowEngine,
        },
30
        Job, JobQueue, JobQueueConfig, WorkerManager, WorkerType,
31
    },
32
33
    logging::{self, LoggingConfig},
    metrics::{self, PrometheusConfig},
34
    middleware::{self, AuthConfig, QueuedRequest},
35
    protocols::{
36
        chat::ChatCompletionRequest,
37
        classify::ClassifyRequest,
38
39
40
41
42
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::{RerankRequest, V1RerankReqInput},
        responses::{ResponsesGetParams, ResponsesRequest},
43
        validated::ValidatedJson,
44
        worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo},
45
    },
46
    routers::{router_manager::RouterManager, RouterTrait},
47
    service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
48
};
49

50
51
52
53
/// Shared state handed to every axum handler via `State<Arc<AppState>>`.
#[derive(Clone)]
pub struct AppState {
    /// Router implementation that dispatches requests to workers.
    pub router: Arc<dyn RouterTrait>,
    /// Shared application context (config, worker registry, HTTP client, queues).
    pub context: Arc<AppContext>,
    /// Sender used by the concurrency-limit middleware to queue requests;
    /// `None` when queuing is disabled.
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>,
    /// Router manager, set during startup.
    pub router_manager: Option<Arc<RouterManager>>,
}

58
59
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
60
61
}

62
63
async fn liveness() -> Response {
    (StatusCode::OK, "OK").into_response()
64
65
}

66
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
    let workers = state.context.worker_registry.get_all();
    let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();

    let is_ready = if state.context.router_config.enable_igw {
        !healthy_workers.is_empty()
    } else {
        match &state.context.router_config.mode {
            RoutingMode::PrefillDecode { .. } => {
                let has_prefill = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }));
                let has_decode = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Decode));
                has_prefill && has_decode
            }
            RoutingMode::Regular { .. } => !healthy_workers.is_empty(),
            RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(),
        }
    };

    if is_ready {
        (
            StatusCode::OK,
            Json(json!({
                "status": "ready",
                "healthy_workers": healthy_workers.len(),
                "total_workers": workers.len()
            })),
        )
            .into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({
                "status": "not ready",
                "reason": "insufficient healthy workers"
            })),
        )
            .into_response()
    }
108
109
}

110
111
async fn health(_state: State<Arc<AppState>>) -> Response {
    liveness().await
112
113
}

114
/// Proxies a health-generate probe through the router.
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = &state.router;
    router.health_generate(req).await
}

118
/// Returns server info as reported by the router.
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = &state.router;
    router.get_server_info(req).await
}

122
/// GET /v1/models — lists models known to the router.
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = &state.router;
    router.get_models(req).await
}

126
/// GET /get_model_info — forwards the model-info query to the router.
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    let router = &state.router;
    router.get_model_info(req).await
}
129

130
async fn generate(
131
132
133
134
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
135
136
137
138
    state
        .router
        .route_generate(Some(&headers), &body, None)
        .await
139
140
141
}

async fn v1_chat_completions(
142
143
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
144
    ValidatedJson(body): ValidatedJson<ChatCompletionRequest>,
145
) -> Response {
146
    state.router.route_chat(Some(&headers), &body, None).await
147
148
149
}

async fn v1_completions(
150
151
152
153
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
154
155
156
157
    state
        .router
        .route_completion(Some(&headers), &body, None)
        .await
158
159
}

160
161
162
async fn rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
163
    ValidatedJson(body): ValidatedJson<RerankRequest>,
164
) -> Response {
165
    state.router.route_rerank(Some(&headers), &body, None).await
166
167
168
169
170
171
172
173
174
}

async fn v1_rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<V1RerankReqInput>,
) -> Response {
    state
        .router
175
        .route_rerank(Some(&headers), &body.into(), None)
176
177
178
        .await
}

179
180
181
182
183
/// POST /v1/responses — routes a Responses API creation request.
async fn v1_responses(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ResponsesRequest>,
) -> Response {
    let router = &state.router;
    router.route_responses(Some(&headers), &body, None).await
}

190
191
192
193
194
195
196
197
198
199
200
/// POST /v1/embeddings — routes an embedding request.
async fn v1_embeddings(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<EmbeddingRequest>,
) -> Response {
    let router = &state.router;
    router.route_embeddings(Some(&headers), &body, None).await
}

201
202
203
204
205
206
207
208
209
210
211
/// POST /v1/classify — routes a classification request.
async fn v1_classify(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ClassifyRequest>,
) -> Response {
    let router = &state.router;
    router.route_classify(Some(&headers), &body, None).await
}

212
213
214
215
async fn v1_responses_get(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
216
    Query(params): Query<ResponsesGetParams>,
217
218
219
) -> Response {
    state
        .router
220
        .get_response(Some(&headers), &response_id, &params)
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
        .await
}

/// POST /v1/responses/{response_id}/cancel — cancels an in-flight response.
async fn v1_responses_cancel(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.cancel_response(Some(&headers), &response_id).await
}

/// DELETE /v1/responses/{response_id} — deletes a stored response.
async fn v1_responses_delete(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.delete_response(Some(&headers), &response_id).await
}

/// GET /v1/responses/{response_id}/input_items — lists a response's inputs.
async fn v1_responses_list_input_items(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .list_response_input_items(Some(&headers), &response_id)
        .await
}

257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
/// POST /v1/conversations — creates a conversation from a raw JSON body.
async fn v1_conversations_create(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router.create_conversation(Some(&headers), &body).await
}

/// GET /v1/conversations/{conversation_id} — fetches a conversation.
async fn v1_conversations_get(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .get_conversation(Some(&headers), &conversation_id)
        .await
}

/// POST /v1/conversations/{conversation_id} — updates a conversation.
async fn v1_conversations_update(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .update_conversation(Some(&headers), &conversation_id, &body)
        .await
}

/// DELETE /v1/conversations/{conversation_id} — deletes a conversation.
async fn v1_conversations_delete(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation(Some(&headers), &conversation_id)
        .await
}

302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
/// Query parameters for listing items in a conversation; all optional and
/// forwarded to the router unchanged.
#[derive(Deserialize, Default)]
struct ListItemsQuery {
    /// Maximum number of items to return.
    limit: Option<usize>,
    /// Sort order — presumably "asc"/"desc"; interpreted by the router.
    order: Option<String>,
    /// Pagination cursor: item id to start listing after.
    after: Option<String>,
}

/// GET /v1/conversations/{conversation_id}/items — lists conversation items
/// with optional limit/order/after pagination.
async fn v1_conversations_list_items(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    Query(query): Query<ListItemsQuery>,
    headers: http::HeaderMap,
) -> Response {
    state
        .router
        .list_conversation_items(
            Some(&headers),
            &conversation_id,
            query.limit,
            query.order,
            query.after,
        )
        .await
}

325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
/// Query parameters for fetching a single conversation item.
#[derive(Deserialize, Default)]
struct GetItemQuery {
    /// Additional fields to include in response (not yet implemented)
    include: Option<Vec<String>>,
}

/// POST /v1/conversations/{conversation_id}/items — appends items from a
/// raw JSON body.
async fn v1_conversations_create_items(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .create_conversation_items(Some(&headers), &conversation_id, &body)
        .await
}

/// GET /v1/conversations/{conversation_id}/items/{item_id} — fetches one item.
async fn v1_conversations_get_item(
    State(state): State<Arc<AppState>>,
    Path((conversation_id, item_id)): Path<(String, String)>,
    Query(query): Query<GetItemQuery>,
    headers: http::HeaderMap,
) -> Response {
    let include = query.include;
    state
        .router
        .get_conversation_item(Some(&headers), &conversation_id, &item_id, include)
        .await
}

/// DELETE /v1/conversations/{conversation_id}/items/{item_id} — removes one item.
async fn v1_conversations_delete_item(
    State(state): State<Arc<AppState>>,
    Path((conversation_id, item_id)): Path<(String, String)>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation_item(Some(&headers), &conversation_id, &item_id)
        .await
}

366
/// POST /flush_cache — asks every HTTP worker to flush its cache.
///
/// Response codes: 200 when all workers succeed, 206 (Partial Content) when
/// some fail (with a per-worker error list), 500 when the operation itself
/// errors out before reaching workers.
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
        .await
    {
        Ok(result) => {
            if result.failed.is_empty() {
                // Every worker flushed successfully.
                (
                    StatusCode::OK,
                    Json(json!({
                        "status": "success",
                        "message": result.message,
                        "workers_flushed": result.successful.len(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            } else {
                // Mixed outcome: surface both the successes and the
                // per-worker failure reasons so callers can retry selectively.
                (
                    StatusCode::PARTIAL_CONTENT,
                    Json(json!({
                        "status": "partial_success",
                        "message": result.message,
                        "successful": result.successful,
                        "failed": result.failed.into_iter().map(|(url, err)| json!({
                            "worker": url,
                            "error": err
                        })).collect::<Vec<_>>(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            }
        }
        Err(e) => {
            error!("Failed to flush cache: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({
                    "status": "error",
                    "message": format!("Failed to flush cache: {}", e)
                })),
            )
                .into_response()
        }
    }
}

415
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
    let result =
        WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
            .await;

    let loads: Vec<Value> = result
        .loads
        .iter()
        .map(|info| {
            json!({
                "worker": &info.worker,
                "load": info.load
            })
        })
        .collect();

431
    (StatusCode::OK, Json(json!({ "workers": loads }))).into_response()
432
433
}

434
435
436
437
/// POST /workers — queues a new worker for asynchronous registration.
///
/// Returns 202 Accepted immediately; the actual registration happens in the
/// background job queue. Responds 500 only if the job cannot be enqueued.
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            config.url
        );
    }

    // Populate dp_aware from router's configuration
    let config = WorkerConfigRequest {
        dp_aware: state.context.router_config.dp_aware,
        ..config
    };

    // Submit job for async processing
    let worker_url = config.url.clone();
    let job = Job::AddWorker {
        config: Box::new(config),
    };

    // The queue is set during startup(); a missing queue is a programming
    // error, hence the expect().
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");
    match job_queue.submit(job).await {
        Ok(_) => {
            let response = json!({
                "status": "accepted",
                "worker_id": worker_url,
                "message": "Worker addition queued for background processing"
            });
            (StatusCode::ACCEPTED, Json(response)).into_response()
        }
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
                code: "INTERNAL_SERVER_ERROR".to_string(),
            };
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
        }
    }
}

async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
485
    let workers = state.context.worker_registry.get_all();
486
    let worker_infos: Vec<WorkerInfo> = workers.iter().map(worker_to_info).collect();
487

488
489
    let response = json!({
        "workers": worker_infos,
490
491
492
493
494
495
496
497
        "total": workers.len(),
        "stats": {
            "prefill_count": state.context.worker_registry.get_prefill_workers().len(),
            "decode_count": state.context.worker_registry.get_decode_workers().len(),
            "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(),
        }
    });
    Json(response).into_response()
498
499
}

500
/// GET /workers/{url} — reports a worker's info, merging in any pending
/// background-job status for that URL.
///
/// Lookup order: registry entry (with job status attached), then job queue
/// alone (placeholder info), then 404.
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
    // Queue is set during startup(); absence is a programming error.
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");

    if let Some(worker) = state.context.worker_registry.get_by_url(&url) {
        // Worker exists in registry, get its full info and attach job status if any
        let mut worker_info = worker_to_info(&worker);
        if let Some(status) = job_queue.get_status(&url) {
            worker_info.job_status = Some(status);
        }
        return Json(worker_info).into_response();
    }

    // Worker not in registry, check job queue for its status
    if let Some(status) = job_queue.get_status(&url) {
        // Create a partial WorkerInfo to report the job status
        // (placeholder fields, since the worker has not been registered yet).
        let worker_info = WorkerInfo {
            id: url.clone(),
            url: url.clone(),
            model_id: "unknown".to_string(),
            priority: 0,
            cost: 1.0,
            worker_type: "unknown".to_string(),
            is_healthy: false,
            load: 0,
            connection_mode: "unknown".to_string(),
            tokenizer_path: None,
            reasoning_parser: None,
            tool_parser: None,
            chat_template: None,
            bootstrap_port: None,
            metadata: std::collections::HashMap::new(),
            job_status: Some(status),
        };
        return Json(worker_info).into_response();
    }

    // Worker not found in registry or job queue
    let error = WorkerErrorResponse {
        error: format!("Worker {url} not found"),
        code: "WORKER_NOT_FOUND".to_string(),
    };
    (StatusCode::NOT_FOUND, Json(error)).into_response()
}

548
/// DELETE /workers/{url} — queues worker removal for asynchronous processing.
///
/// Returns 202 Accepted immediately; 500 only if the job cannot be enqueued.
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
    let worker_id = url.clone();
    let job = Job::RemoveWorker { url };

    // Queue is set during startup(); absence is a programming error.
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");
    match job_queue.submit(job).await {
        Ok(_) => {
            let response = json!({
                "status": "accepted",
                "worker_id": worker_id,
                "message": "Worker removal queued for background processing"
            });
            (StatusCode::ACCEPTED, Json(response)).into_response()
        }
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
                code: "INTERNAL_SERVER_ERROR".to_string(),
            };
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
        }
    }
}

576
577
578
/// Full server configuration consumed by [`startup`].
pub struct ServerConfig {
    /// Bind host (IPv4, IPv6, or hostname).
    pub host: String,
    /// Bind port.
    pub port: u16,
    /// Router configuration (routing mode, policy, auth, limits, ...).
    pub router_config: RouterConfig,
    /// Maximum request body size in bytes.
    pub max_payload_size: usize,
    /// Optional directory for log files.
    pub log_dir: Option<String>,
    /// Optional log level name; parsed case-insensitively, defaults to INFO
    /// when absent or unparseable.
    pub log_level: Option<String>,
    /// Optional service discovery settings.
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    /// Optional Prometheus metrics exporter settings.
    pub prometheus_config: Option<PrometheusConfig>,
    /// Request timeout in seconds, passed to `AppContext::from_config`.
    pub request_timeout_secs: u64,
    /// Header names used for request-ID propagation; a default list is used
    /// when `None`.
    pub request_id_headers: Option<Vec<String>>,
}

589
590
/// Assembles the axum application.
///
/// Route groups: inference endpoints (concurrency-limited + auth), public
/// health/info endpoints, admin cache/load endpoints (auth), and worker CRUD
/// endpoints (auth). The merged router is wrapped in body-size limits,
/// logging, request-ID, and CORS layers, with a 404 fallback.
pub fn build_app(
    app_state: Arc<AppState>,
    auth_config: AuthConfig,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Inference endpoints: concurrency-limit middleware runs first, then auth.
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions))
        .route("/rerank", post(rerank))
        .route("/v1/rerank", post(v1_rerank))
        .route("/v1/responses", post(v1_responses))
        .route("/v1/embeddings", post(v1_embeddings))
        .route("/v1/classify", post(v1_classify))
        .route("/v1/responses/{response_id}", get(v1_responses_get))
        .route(
            "/v1/responses/{response_id}/cancel",
            post(v1_responses_cancel),
        )
        .route("/v1/responses/{response_id}", delete(v1_responses_delete))
        .route(
            "/v1/responses/{response_id}/input_items",
            get(v1_responses_list_input_items),
        )
        .route("/v1/conversations", post(v1_conversations_create))
        .route(
            "/v1/conversations/{conversation_id}",
            get(v1_conversations_get)
                .post(v1_conversations_update)
                .delete(v1_conversations_delete),
        )
        .route(
            "/v1/conversations/{conversation_id}/items",
            get(v1_conversations_list_items).post(v1_conversations_create_items),
        )
        .route(
            "/v1/conversations/{conversation_id}/items/{item_id}",
            get(v1_conversations_get_item).delete(v1_conversations_delete_item),
        )
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
            middleware::concurrency_limit_middleware,
        ))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    // Health and metadata endpoints: no auth.
    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    // Cache/load admin endpoints: auth only.
    let admin_routes = Router::new()
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    // Worker CRUD endpoints: auth only.
    let worker_routes = Router::new()
        .route("/workers", post(create_worker))
        .route("/workers", get(list_workers_rest))
        .route("/workers/{url}", get(get_worker))
        .route("/workers/{url}", delete(delete_worker))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));

    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        .merge(worker_routes)
        // Two body-size guards: axum's extractor limit plus tower-http's layer.
        .layer(axum::extract::DefaultBodyLimit::max(max_payload_size))
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        .layer(middleware::create_logging_layer())
        .layer(middleware::RequestIdLayer::new(request_id_headers))
        .layer(create_cors_layer(cors_allowed_origins))
        .fallback(sink_handler)
        .with_state(app_state)
}

/// Boots the router server: initializes logging, metrics, the application
/// context, the worker job queue, the workflow engine, health checking,
/// rate limiting, and optional service discovery, then binds a TCP listener
/// and serves until a shutdown signal arrives.
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
    // Guard so logging is initialized at most once per process, even if
    // startup() is called again (e.g. from tests or embedding code).
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            // Parse the configured level; fall back to INFO on any parse error.
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{s}'. Defaulting to INFO.");
                        None
                    }
                })
                .unwrap_or(Level::INFO),
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };

    if let Some(prometheus_config) = &config.prometheus_config {
        metrics::start_prometheus(prometheus_config.clone());
    }

    info!(
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
        config.max_payload_size / (1024 * 1024)
    );

    let app_context = Arc::new(
        AppContext::from_config(config.router_config.clone(), config.request_timeout_secs).await?,
    );

    // The job queue holds a Weak reference back to the context to avoid an
    // Arc reference cycle.
    let weak_context = Arc::downgrade(&app_context);
    let worker_job_queue = JobQueue::new(JobQueueConfig::default(), weak_context);
    app_context
        .worker_job_queue
        .set(worker_job_queue)
        .expect("JobQueue should only be initialized once");

    // Initialize workflow engine and register workflows
    let engine = Arc::new(WorkflowEngine::new());

    engine
        .event_bus()
        .subscribe(Arc::new(LoggingSubscriber))
        .await;

    engine.register_workflow(create_worker_registration_workflow());
    engine.register_workflow(create_worker_removal_workflow());
    app_context
        .workflow_engine
        .set(engine)
        .expect("WorkflowEngine should only be initialized once");
    info!("Workflow engine initialized with worker registration and removal workflows");

    info!(
        "Initializing workers for routing mode: {:?}",
        config.router_config.mode
    );

    // Submit worker initialization job to queue
    let job_queue = app_context
        .worker_job_queue
        .get()
        .expect("JobQueue should be initialized");
    let job = Job::InitializeWorkersFromConfig {
        router_config: Box::new(config.router_config.clone()),
    };
    job_queue
        .submit(job)
        .await
        .map_err(|e| format!("Failed to submit worker initialization job: {}", e))?;

    // NOTE(review): initialization runs asynchronously, so these stats may
    // reflect a not-yet-complete registration pass.
    let worker_stats = app_context.worker_registry.stats();
    info!(
        "Workers initialized: {} total, {} healthy",
        worker_stats.total_workers, worker_stats.healthy_workers
    );

    let router_manager = RouterManager::from_config(&config, &app_context).await?;
    let router: Arc<dyn RouterTrait> = router_manager.clone();

    // Keep the handle alive for the lifetime of this function.
    let _health_checker = app_context
        .worker_registry
        .start_health_checker(config.router_config.health_check.check_interval_secs);
    info!(
        "Started health checker for workers with {}s interval",
        config.router_config.health_check.check_interval_secs
    );

    if let Some(ref load_monitor) = app_context.load_monitor {
        load_monitor.start().await;
        info!("Started LoadMonitor for PowerOfTwo policies");
    }

    let (limiter, processor) = middleware::ConcurrencyLimiter::new(
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

    if app_context.rate_limiter.is_none() {
        info!("Rate limiting is disabled (max_concurrent_requests = -1)");
    }

    // A Some(processor) means queued rate limiting: spawn the queue worker.
    match processor {
        Some(proc) => {
            spawn(proc.run());
            info!(
                "Started request queue (size: {}, timeout: {}s)",
                config.router_config.queue_size, config.router_config.queue_timeout_secs
            );
        }
        None => {
            info!(
                "Rate limiting enabled (max_concurrent_requests = {}, queue disabled)",
                config.router_config.max_concurrent_requests
            );
        }
    }

    let app_state = Arc::new(AppState {
        router,
        context: app_context.clone(),
        concurrency_queue_tx: limiter.queue_tx.clone(),
        router_manager: Some(router_manager),
    });

    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
            let app_context_arc = Arc::clone(&app_state.context);
            match start_service_discovery(service_discovery_config, app_context_arc).await {
                Ok(handle) => {
                    info!("Service discovery started");
                    // Detached task: only surfaces failures via the error log.
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    // Service discovery is best-effort; the server still runs.
                    error!("Failed to start service discovery: {e}");
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

    info!(
        "Router ready | workers: {:?}",
        WorkerManager::get_worker_urls(&app_state.context.worker_registry)
    );

    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    let auth_config = AuthConfig {
        api_key: config.router_config.api_key.clone(),
    };

    let app = build_app(
        app_state,
        auth_config,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );

    // TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs
    let bind_addr = format!("{}:{}", config.host, config.port);
    info!("Starting server on {}", bind_addr);
    let listener = TcpListener::bind(&bind_addr)
        .await
        .map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?;
    serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

/// Resolves when the process receives Ctrl+C or (on Unix) SIGTERM, which
/// triggers axum's graceful shutdown.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        sigterm.recv().await;
    };

    // On non-Unix platforms only Ctrl+C can end the wait.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = terminate => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
931
}