// server.rs — HTTP server construction, route handlers, and startup for the router.
use std::{
    sync::{
        atomic::{AtomicBool, Ordering},
        Arc, OnceLock,
    },
    time::Duration,
};

use axum::{
    extract::{Path, Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{delete, get, post},
    serve, Json, Router,
};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{json, Value};
use tokio::{net::TcpListener, signal, spawn};
use tracing::{error, info, warn, Level};

use crate::{
    config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
    core::{
        worker_to_info, workflow::WorkflowEngine, Job, JobQueue, JobQueueConfig, LoadMonitor,
        WorkerManager, WorkerRegistry, WorkerType,
    },
    data_connector::{
        MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
        NoOpConversationStorage, NoOpResponseStorage, OracleConversationItemStorage,
        OracleConversationStorage, OracleResponseStorage, SharedConversationItemStorage,
        SharedConversationStorage, SharedResponseStorage,
    },
    logging::{self, LoggingConfig},
    metrics::{self, PrometheusConfig},
    middleware::{self, AuthConfig, QueuedRequest, TokenBucket},
    policies::PolicyRegistry,
    protocols::{
        chat::ChatCompletionRequest,
        classify::ClassifyRequest,
        completion::CompletionRequest,
        embedding::EmbeddingRequest,
        generate::GenerateRequest,
        rerank::{RerankRequest, V1RerankReqInput},
        responses::{ResponsesGetParams, ResponsesRequest},
        validated::ValidatedJson,
        worker_spec::{WorkerConfigRequest, WorkerErrorResponse, WorkerInfo},
    },
    reasoning_parser::ParserFactory as ReasoningParserFactory,
    routers::{router_manager::RouterManager, RouterTrait},
    service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
    tokenizer::{
        cache::{CacheConfig, CachedTokenizer},
        factory as tokenizer_factory,
        traits::Tokenizer,
    },
    tool_parser::ParserFactory as ToolParserFactory,
};

60
61
//

62
#[derive(Clone)]
63
pub struct AppContext {
64
    pub client: Client,
65
    pub router_config: RouterConfig,
66
    pub rate_limiter: Option<Arc<TokenBucket>>,
67
    pub tokenizer: Option<Arc<dyn Tokenizer>>,
68
    pub reasoning_parser_factory: Option<ReasoningParserFactory>,
69
    pub tool_parser_factory: Option<ToolParserFactory>,
70
71
72
    pub worker_registry: Arc<WorkerRegistry>,
    pub policy_registry: Arc<PolicyRegistry>,
    pub router_manager: Option<Arc<RouterManager>>,
73
    pub response_storage: SharedResponseStorage,
74
    pub conversation_storage: SharedConversationStorage,
75
    pub conversation_item_storage: SharedConversationItemStorage,
76
    pub load_monitor: Option<Arc<LoadMonitor>>,
77
78
    pub configured_reasoning_parser: Option<String>,
    pub configured_tool_parser: Option<String>,
79
    pub worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
80
    pub workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
81
82
}

83
impl AppContext {
84
    #[allow(clippy::too_many_arguments)]
85
86
87
    pub fn new(
        router_config: RouterConfig,
        client: Client,
88
89
90
91
92
93
94
95
96
97
98
        rate_limiter: Option<Arc<TokenBucket>>,
        tokenizer: Option<Arc<dyn Tokenizer>>,
        reasoning_parser_factory: Option<ReasoningParserFactory>,
        tool_parser_factory: Option<ToolParserFactory>,
        worker_registry: Arc<WorkerRegistry>,
        policy_registry: Arc<PolicyRegistry>,
        response_storage: SharedResponseStorage,
        conversation_storage: SharedConversationStorage,
        conversation_item_storage: SharedConversationItemStorage,
        load_monitor: Option<Arc<LoadMonitor>>,
        worker_job_queue: Arc<OnceLock<Arc<JobQueue>>>,
99
        workflow_engine: Arc<OnceLock<Arc<WorkflowEngine>>>,
100
    ) -> Self {
101
102
103
        let configured_reasoning_parser = router_config.reasoning_parser.clone();
        let configured_tool_parser = router_config.tool_call_parser.clone();

104
        Self {
105
            client,
106
            router_config,
107
            rate_limiter,
108
109
            tokenizer,
            reasoning_parser_factory,
110
            tool_parser_factory,
111
112
            worker_registry,
            policy_registry,
113
            router_manager: None,
114
            response_storage,
115
            conversation_storage,
116
            conversation_item_storage,
117
            load_monitor,
118
119
            configured_reasoning_parser,
            configured_tool_parser,
120
            worker_job_queue,
121
            workflow_engine,
122
        }
123
124
125
    }
}

126
127
128
129
#[derive(Clone)]
pub struct AppState {
    pub router: Arc<dyn RouterTrait>,
    pub context: Arc<AppContext>,
130
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<QueuedRequest>>,
131
    pub router_manager: Option<Arc<RouterManager>>,
132
133
}

134
135
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
136
137
}

138
139
async fn liveness() -> Response {
    (StatusCode::OK, "OK").into_response()
140
141
}

142
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
    let workers = state.context.worker_registry.get_all();
    let healthy_workers: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect();

    let is_ready = if state.context.router_config.enable_igw {
        !healthy_workers.is_empty()
    } else {
        match &state.context.router_config.mode {
            RoutingMode::PrefillDecode { .. } => {
                let has_prefill = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Prefill { .. }));
                let has_decode = healthy_workers
                    .iter()
                    .any(|w| matches!(w.worker_type(), WorkerType::Decode));
                has_prefill && has_decode
            }
            RoutingMode::Regular { .. } => !healthy_workers.is_empty(),
            RoutingMode::OpenAI { .. } => !healthy_workers.is_empty(),
        }
    };

    if is_ready {
        (
            StatusCode::OK,
            Json(json!({
                "status": "ready",
                "healthy_workers": healthy_workers.len(),
                "total_workers": workers.len()
            })),
        )
            .into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({
                "status": "not ready",
                "reason": "insufficient healthy workers"
            })),
        )
            .into_response()
    }
184
185
}

186
187
async fn health(_state: State<Arc<AppState>>) -> Response {
    liveness().await
188
189
}

190
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
191
    state.router.health_generate(req).await
192
193
}

194
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
195
    state.router.get_server_info(req).await
196
197
}

198
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
199
    state.router.get_models(req).await
200
201
}

202
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
203
    state.router.get_model_info(req).await
204
}
205

206
async fn generate(
207
208
209
210
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
211
212
213
214
    state
        .router
        .route_generate(Some(&headers), &body, None)
        .await
215
216
217
}

async fn v1_chat_completions(
218
219
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
220
    ValidatedJson(body): ValidatedJson<ChatCompletionRequest>,
221
) -> Response {
222
    state.router.route_chat(Some(&headers), &body, None).await
223
224
225
}

async fn v1_completions(
226
227
228
229
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
230
231
232
233
    state
        .router
        .route_completion(Some(&headers), &body, None)
        .await
234
235
}

236
237
238
async fn rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
239
    ValidatedJson(body): ValidatedJson<RerankRequest>,
240
) -> Response {
241
    state.router.route_rerank(Some(&headers), &body, None).await
242
243
244
245
246
247
248
249
250
}

async fn v1_rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<V1RerankReqInput>,
) -> Response {
    state
        .router
251
        .route_rerank(Some(&headers), &body.into(), None)
252
253
254
        .await
}

255
256
257
258
259
async fn v1_responses(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ResponsesRequest>,
) -> Response {
260
261
262
263
    state
        .router
        .route_responses(Some(&headers), &body, None)
        .await
264
265
}

266
267
268
269
270
271
272
273
274
275
276
async fn v1_embeddings(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<EmbeddingRequest>,
) -> Response {
    state
        .router
        .route_embeddings(Some(&headers), &body, None)
        .await
}

277
278
279
280
281
282
283
284
285
286
287
async fn v1_classify(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ClassifyRequest>,
) -> Response {
    state
        .router
        .route_classify(Some(&headers), &body, None)
        .await
}

288
289
290
291
async fn v1_responses_get(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
292
    Query(params): Query<ResponsesGetParams>,
293
294
295
) -> Response {
    state
        .router
296
        .get_response(Some(&headers), &response_id, &params)
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
        .await
}

/// POST /v1/responses/{response_id}/cancel — cancel an in-flight response.
async fn v1_responses_cancel(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.cancel_response(Some(&headers), &response_id).await
}

/// DELETE /v1/responses/{response_id} — delete a stored response.
async fn v1_responses_delete(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router.delete_response(Some(&headers), &response_id).await
}

/// GET /v1/responses/{response_id}/input — list a response's input items.
async fn v1_responses_list_input_items(
    State(state): State<Arc<AppState>>,
    Path(response_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .list_response_input_items(Some(&headers), &response_id)
        .await
}

333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
async fn v1_conversations_create(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    state
        .router
        .create_conversation(Some(&headers), &body)
        .await
}

/// GET /v1/conversations/{conversation_id} — fetch a conversation.
async fn v1_conversations_get(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .get_conversation(Some(&headers), &conversation_id)
        .await
}

/// POST /v1/conversations/{conversation_id} — update a conversation.
async fn v1_conversations_update(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .update_conversation(Some(&headers), &conversation_id, &body)
        .await
}

/// DELETE /v1/conversations/{conversation_id} — delete a conversation.
async fn v1_conversations_delete(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation(Some(&headers), &conversation_id)
        .await
}

378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
#[derive(Deserialize, Default)]
struct ListItemsQuery {
    limit: Option<usize>,
    order: Option<String>,
    after: Option<String>,
}

/// GET /v1/conversations/{conversation_id}/items — paginated item listing.
async fn v1_conversations_list_items(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    Query(query): Query<ListItemsQuery>,
    headers: http::HeaderMap,
) -> Response {
    let ListItemsQuery {
        limit,
        order,
        after,
    } = query;
    state
        .router
        .list_conversation_items(Some(&headers), &conversation_id, limit, order, after)
        .await
}

401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
#[derive(Deserialize, Default)]
struct GetItemQuery {
    /// Additional fields to include in response (not yet implemented)
    include: Option<Vec<String>>,
}

/// POST /v1/conversations/{conversation_id}/items — append items.
async fn v1_conversations_create_items(
    State(state): State<Arc<AppState>>,
    Path(conversation_id): Path<String>,
    headers: http::HeaderMap,
    Json(body): Json<Value>,
) -> Response {
    let router = &state.router;
    router
        .create_conversation_items(Some(&headers), &conversation_id, &body)
        .await
}

/// GET /v1/conversations/{conversation_id}/items/{item_id} — fetch one item.
async fn v1_conversations_get_item(
    State(state): State<Arc<AppState>>,
    Path((conversation_id, item_id)): Path<(String, String)>,
    Query(query): Query<GetItemQuery>,
    headers: http::HeaderMap,
) -> Response {
    let include = query.include;
    state
        .router
        .get_conversation_item(Some(&headers), &conversation_id, &item_id, include)
        .await
}

/// DELETE /v1/conversations/{conversation_id}/items/{item_id} — remove one item.
async fn v1_conversations_delete_item(
    State(state): State<Arc<AppState>>,
    Path((conversation_id, item_id)): Path<(String, String)>,
    headers: http::HeaderMap,
) -> Response {
    let router = &state.router;
    router
        .delete_conversation_item(Some(&headers), &conversation_id, &item_id)
        .await
}

442
#[derive(Deserialize)]
443
struct AddWorkerQuery {
444
    url: String,
445
    api_key: Option<String>,
446
447
}

448
async fn add_worker(
449
    State(state): State<Arc<AppState>>,
450
    Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
451
) -> Response {
452
453
454
455
456
457
458
459
460
461
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            url
        );
    }

462
463
464
    let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;

    match result {
465
466
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
467
    }
468
469
}

470
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
471
472
    let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
    Json(json!({ "urls": worker_list })).into_response()
473
474
}

475
async fn remove_worker(
476
    State(state): State<Arc<AppState>>,
477
    Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
478
) -> Response {
479
480
481
482
483
484
    let result = WorkerManager::remove_worker(&url, &state.context);

    match result {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
    }
485
486
}

487
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
    match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
        .await
    {
        Ok(result) => {
            if result.failed.is_empty() {
                (
                    StatusCode::OK,
                    Json(json!({
                        "status": "success",
                        "message": result.message,
                        "workers_flushed": result.successful.len(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            } else {
                (
                    StatusCode::PARTIAL_CONTENT,
                    Json(json!({
                        "status": "partial_success",
                        "message": result.message,
                        "successful": result.successful,
                        "failed": result.failed.into_iter().map(|(url, err)| json!({
                            "worker": url,
                            "error": err
                        })).collect::<Vec<_>>(),
                        "total_http_workers": result.http_workers,
                        "total_workers": result.total_workers
                    })),
                )
                    .into_response()
            }
        }
        Err(e) => {
            error!("Failed to flush cache: {}", e);
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({
                    "status": "error",
                    "message": format!("Failed to flush cache: {}", e)
                })),
            )
                .into_response()
        }
    }
534
535
}

536
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
    let result =
        WorkerManager::get_all_worker_loads(&state.context.worker_registry, &state.context.client)
            .await;

    let loads: Vec<Value> = result
        .loads
        .iter()
        .map(|info| {
            json!({
                "worker": &info.worker,
                "load": info.load
            })
        })
        .collect();

552
    (StatusCode::OK, Json(json!({ "workers": loads }))).into_response()
553
554
}

555
556
557
558
async fn create_worker(
    State(state): State<Arc<AppState>>,
    Json(config): Json<WorkerConfigRequest>,
) -> Response {
559
560
561
562
563
564
565
566
567
568
    // Warn if router has API key but worker is being added without one
    if state.context.router_config.api_key.is_some() && config.api_key.is_none() {
        warn!(
            "Adding worker {} without API key while router has API key configured. \
            Worker will be accessible without authentication. \
            If the worker requires the same API key as the router, please specify it explicitly.",
            config.url
        );
    }

569
570
571
572
573
    // Submit job for async processing
    let worker_url = config.url.clone();
    let job = Job::AddWorker {
        config: Box::new(config),
    };
574

575
576
577
578
579
580
581
582
583
584
585
586
587
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");
    match job_queue.submit(job).await {
        Ok(_) => {
            let response = json!({
                "status": "accepted",
                "worker_id": worker_url,
                "message": "Worker addition queued for background processing"
            });
            (StatusCode::ACCEPTED, Json(response)).into_response()
588
        }
589
590
591
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
592
                code: "INTERNAL_SERVER_ERROR".to_string(),
593
            };
594
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
595
596
597
598
599
        }
    }
}

async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
600
    let workers = state.context.worker_registry.get_all();
601
    let worker_infos: Vec<WorkerInfo> = workers.iter().map(worker_to_info).collect();
602

603
604
    let response = json!({
        "workers": worker_infos,
605
606
607
608
609
610
611
612
        "total": workers.len(),
        "stats": {
            "prefill_count": state.context.worker_registry.get_prefill_workers().len(),
            "decode_count": state.context.worker_registry.get_decode_workers().len(),
            "regular_count": state.context.worker_registry.get_by_type(&WorkerType::Regular).len(),
        }
    });
    Json(response).into_response()
613
614
}

615
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");

    if let Some(worker) = state.context.worker_registry.get_by_url(&url) {
        // Worker exists in registry, get its full info and attach job status if any
        let mut worker_info = worker_to_info(&worker);
        if let Some(status) = job_queue.get_status(&url) {
            worker_info.job_status = Some(status);
        }
        return Json(worker_info).into_response();
    }

    // Worker not in registry, check job queue for its status
    if let Some(status) = job_queue.get_status(&url) {
        // Create a partial WorkerInfo to report the job status
        let worker_info = WorkerInfo {
            id: url.clone(),
            url: url.clone(),
            model_id: "unknown".to_string(),
            priority: 0,
            cost: 1.0,
            worker_type: "unknown".to_string(),
            is_healthy: false,
            load: 0,
            connection_mode: "unknown".to_string(),
            tokenizer_path: None,
            reasoning_parser: None,
            tool_parser: None,
            chat_template: None,
            bootstrap_port: None,
            metadata: std::collections::HashMap::new(),
            job_status: Some(status),
651
        };
652
        return Json(worker_info).into_response();
653
    }
654
655
656
657
658
659
660

    // Worker not found in registry or job queue
    let error = WorkerErrorResponse {
        error: format!("Worker {url} not found"),
        code: "WORKER_NOT_FOUND".to_string(),
    };
    (StatusCode::NOT_FOUND, Json(error)).into_response()
661
662
}

663
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
    let worker_id = url.clone();
    let job = Job::RemoveWorker { url };

    let job_queue = state
        .context
        .worker_job_queue
        .get()
        .expect("JobQueue not initialized");
    match job_queue.submit(job).await {
        Ok(_) => {
            let response = json!({
                "status": "accepted",
                "worker_id": worker_id,
                "message": "Worker removal queued for background processing"
            });
            (StatusCode::ACCEPTED, Json(response)).into_response()
680
681
682
683
        }
        Err(error) => {
            let error_response = WorkerErrorResponse {
                error,
684
                code: "INTERNAL_SERVER_ERROR".to_string(),
685
            };
686
            (StatusCode::INTERNAL_SERVER_ERROR, Json(error_response)).into_response()
687
688
689
690
        }
    }
}

691
692
693
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
694
    pub router_config: RouterConfig,
695
    pub max_payload_size: usize,
696
    pub log_dir: Option<String>,
697
    pub log_level: Option<String>,
698
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
699
    pub prometheus_config: Option<PrometheusConfig>,
700
    pub request_timeout_secs: u64,
701
    pub request_id_headers: Option<Vec<String>>,
702
703
}

704
705
pub fn build_app(
    app_state: Arc<AppState>,
706
    auth_config: AuthConfig,
707
708
709
710
711
712
713
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
714
        .route("/v1/completions", post(v1_completions))
715
716
        .route("/rerank", post(rerank))
        .route("/v1/rerank", post(v1_rerank))
717
        .route("/v1/responses", post(v1_responses))
718
        .route("/v1/embeddings", post(v1_embeddings))
719
        .route("/v1/classify", post(v1_classify))
720
721
722
723
724
725
726
727
728
729
        .route("/v1/responses/{response_id}", get(v1_responses_get))
        .route(
            "/v1/responses/{response_id}/cancel",
            post(v1_responses_cancel),
        )
        .route("/v1/responses/{response_id}", delete(v1_responses_delete))
        .route(
            "/v1/responses/{response_id}/input",
            get(v1_responses_list_input_items),
        )
730
731
732
733
734
735
736
        .route("/v1/conversations", post(v1_conversations_create))
        .route(
            "/v1/conversations/{conversation_id}",
            get(v1_conversations_get)
                .post(v1_conversations_update)
                .delete(v1_conversations_delete),
        )
737
738
        .route(
            "/v1/conversations/{conversation_id}/items",
739
740
741
742
743
            get(v1_conversations_list_items).post(v1_conversations_create_items),
        )
        .route(
            "/v1/conversations/{conversation_id}/items/{item_id}",
            get(v1_conversations_get_item).delete(v1_conversations_delete_item),
744
        )
745
746
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
747
            middleware::concurrency_limit_middleware,
748
749
750
751
        ))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
752
        ));
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767

    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
768
769
770
771
772
        .route("/get_loads", get(get_loads))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));
773

774
775
776
777
    let worker_routes = Router::new()
        .route("/workers", post(create_worker))
        .route("/workers", get(list_workers_rest))
        .route("/workers/{url}", get(get_worker))
778
779
780
781
782
        .route("/workers/{url}", delete(delete_worker))
        .route_layer(axum::middleware::from_fn_with_state(
            auth_config.clone(),
            middleware::auth_middleware,
        ));
783

784
785
786
787
    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
788
        .merge(worker_routes)
789
        .layer(axum::extract::DefaultBodyLimit::max(max_payload_size))
790
791
792
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
793
794
        .layer(middleware::create_logging_layer())
        .layer(middleware::RequestIdLayer::new(request_id_headers))
795
796
797
798
799
800
        .layer(create_cors_layer(cors_allowed_origins))
        .fallback(sink_handler)
        .with_state(app_state)
}

pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
801
802
803
804
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
805
806
807
808
809
810
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
811
                        warn!("Invalid log level string: '{s}'. Defaulting to INFO.");
812
813
814
815
                        None
                    }
                })
                .unwrap_or(Level::INFO),
816
817
818
819
820
821
822
823
824
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
825

826
827
    if let Some(prometheus_config) = &config.prometheus_config {
        metrics::start_prometheus(prometheus_config.clone());
828
829
    }

830
    info!(
831
832
833
834
835
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
836
837
838
        config.max_payload_size / (1024 * 1024)
    );

839
    let client = Client::builder()
840
        .pool_idle_timeout(Some(Duration::from_secs(50)))
841
        .pool_max_idle_per_host(500)
842
        .timeout(Duration::from_secs(config.request_timeout_secs))
843
        .connect_timeout(Duration::from_secs(10))
844
        .tcp_nodelay(true)
845
        .tcp_keepalive(Some(Duration::from_secs(30)))
846
847
848
        .build()
        .expect("Failed to create HTTP client");

849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
    // Initialize rate limiter
    let rate_limiter = match config.router_config.max_concurrent_requests {
        n if n <= 0 => None,
        n => {
            let rate_limit_tokens = config
                .router_config
                .rate_limit_tokens_per_second
                .filter(|&t| t > 0)
                .unwrap_or(n);
            Some(Arc::new(TokenBucket::new(
                n as usize,
                rate_limit_tokens as usize,
            )))
        }
    };

    // Initialize tokenizer and parser factories for gRPC mode
    let (tokenizer, reasoning_parser_factory, tool_parser_factory) = if config
        .router_config
        .connection_mode
        == ConnectionMode::Grpc
    {
        let tokenizer_path = config
            .router_config
            .tokenizer_path
            .clone()
            .or_else(|| config.router_config.model_path.clone())
            .ok_or_else(|| {
                "gRPC mode requires either --tokenizer-path or --model-path to be specified"
                    .to_string()
            })?;

881
        let base_tokenizer =
882
883
884
885
886
887
888
889
890
891
892
                tokenizer_factory::create_tokenizer_with_chat_template_blocking(
                    &tokenizer_path,
                    config.router_config.chat_template.as_deref(),
                )
                .map_err(|e| {
                    format!(
                        "Failed to create tokenizer from '{}': {}. \
                        Ensure the path is valid and points to a tokenizer file (tokenizer.json) \
                        or a HuggingFace model ID. For directories, ensure they contain tokenizer files.",
                        tokenizer_path, e
                    )
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
                })?;

        // Conditionally wrap with caching layer if at least one cache is enabled
        let tokenizer = if config.router_config.tokenizer_cache.enable_l0
            || config.router_config.tokenizer_cache.enable_l1
        {
            let cache_config = CacheConfig {
                enable_l0: config.router_config.tokenizer_cache.enable_l0,
                l0_max_entries: config.router_config.tokenizer_cache.l0_max_entries,
                enable_l1: config.router_config.tokenizer_cache.enable_l1,
                l1_max_memory: config.router_config.tokenizer_cache.l1_max_memory,
            };
            Some(Arc::new(CachedTokenizer::new(base_tokenizer, cache_config)) as Arc<dyn Tokenizer>)
        } else {
            // Use base tokenizer directly without caching
            Some(base_tokenizer)
        };
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
        let reasoning_parser_factory = Some(ReasoningParserFactory::new());
        let tool_parser_factory = Some(ToolParserFactory::new());

        (tokenizer, reasoning_parser_factory, tool_parser_factory)
    } else {
        (None, None, None)
    };

    // Initialize worker registry and policy registry
    let worker_registry = Arc::new(WorkerRegistry::new());
    let policy_registry = Arc::new(PolicyRegistry::new(config.router_config.policy.clone()));

    // Initialize storage backends
    let (response_storage, conversation_storage): (
        SharedResponseStorage,
        SharedConversationStorage,
    ) = match config.router_config.history_backend {
        HistoryBackend::Memory => {
            info!("Initializing data connector: Memory");
            (
                Arc::new(MemoryResponseStorage::new()),
                Arc::new(MemoryConversationStorage::new()),
            )
        }
        HistoryBackend::None => {
            info!("Initializing data connector: None (no persistence)");
            (
                Arc::new(NoOpResponseStorage::new()),
                Arc::new(NoOpConversationStorage::new()),
            )
        }
        HistoryBackend::Oracle => {
            let oracle_cfg = config.router_config.oracle.clone().ok_or_else(|| {
                "oracle configuration is required when history_backend=oracle".to_string()
            })?;
            info!(
                "Initializing data connector: Oracle ATP (pool: {}-{})",
                oracle_cfg.pool_min, oracle_cfg.pool_max
            );

            let response_storage = OracleResponseStorage::new(oracle_cfg.clone())
                .map_err(|err| format!("failed to initialize Oracle response storage: {err}"))?;

            let conversation_storage =
                OracleConversationStorage::new(oracle_cfg.clone()).map_err(|err| {
                    format!("failed to initialize Oracle conversation storage: {err}")
                })?;
            info!("Data connector initialized successfully: Oracle ATP");

            (Arc::new(response_storage), Arc::new(conversation_storage))
        }
    };

    // Initialize conversation items storage
    let conversation_item_storage: SharedConversationItemStorage =
        match config.router_config.history_backend {
            HistoryBackend::Oracle => {
                let oracle_cfg = config.router_config.oracle.clone().ok_or_else(|| {
                    "oracle configuration is required when history_backend=oracle".to_string()
                })?;
                Arc::new(OracleConversationItemStorage::new(oracle_cfg).map_err(|e| {
                    format!("failed to initialize Oracle conversation item storage: {e}")
                })?)
            }
            _ => Arc::new(MemoryConversationItemStorage::new()),
        };

    // Initialize load monitor
    let load_monitor = Some(Arc::new(LoadMonitor::new(
        worker_registry.clone(),
        policy_registry.clone(),
        client.clone(),
        config.router_config.worker_startup_check_interval_secs,
    )));

985
    // Create empty OnceLock for worker job queue and workflow engine (will be initialized below)
986
    let worker_job_queue = Arc::new(OnceLock::new());
987
    let workflow_engine = Arc::new(OnceLock::new());
988
989

    // Create AppContext with all initialized components
990
    let app_context = AppContext::new(
991
992
        config.router_config.clone(),
        client.clone(),
993
994
995
996
997
998
999
1000
1001
1002
1003
        rate_limiter,
        tokenizer,
        reasoning_parser_factory,
        tool_parser_factory,
        worker_registry,
        policy_registry,
        response_storage,
        conversation_storage,
        conversation_item_storage,
        load_monitor,
        worker_job_queue,
1004
        workflow_engine,
1005
    );
1006
1007
1008

    let app_context = Arc::new(app_context);

1009
1010
1011
1012
1013
1014
1015
    let weak_context = Arc::downgrade(&app_context);
    let worker_job_queue = JobQueue::new(JobQueueConfig::default(), weak_context);
    app_context
        .worker_job_queue
        .set(worker_job_queue)
        .expect("JobQueue should only be initialized once");

1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
    // Initialize workflow engine and register workflows
    let engine = Arc::new(WorkflowEngine::new());

    engine
        .event_bus()
        .subscribe(Arc::new(crate::core::workflow::LoggingSubscriber))
        .await;

    engine.register_workflow(crate::core::workflow::create_worker_registration_workflow());
    app_context
        .workflow_engine
        .set(engine)
        .expect("WorkflowEngine should only be initialized once");
    info!("Workflow engine initialized with worker registration workflow");

1031
1032
1033
1034
    info!(
        "Initializing workers for routing mode: {:?}",
        config.router_config.mode
    );
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047

    // Submit worker initialization job to queue
    let job_queue = app_context
        .worker_job_queue
        .get()
        .expect("JobQueue should be initialized");
    let job = Job::InitializeWorkersFromConfig {
        router_config: Box::new(config.router_config.clone()),
    };
    job_queue
        .submit(job)
        .await
        .map_err(|e| format!("Failed to submit worker initialization job: {}", e))?;
1048
1049
1050
1051
1052
1053
1054

    let worker_stats = app_context.worker_registry.stats();
    info!(
        "Workers initialized: {} total, {} healthy",
        worker_stats.total_workers, worker_stats.healthy_workers
    );

1055
1056
    let router_manager = RouterManager::from_config(&config, &app_context).await?;
    let router: Arc<dyn RouterTrait> = router_manager.clone();
1057
1058
1059
1060
1061
1062
1063
1064

    let _health_checker = app_context
        .worker_registry
        .start_health_checker(config.router_config.health_check.check_interval_secs);
    info!(
        "Started health checker for workers with {}s interval",
        config.router_config.health_check.check_interval_secs
    );
1065

1066
1067
1068
1069
1070
    if let Some(ref load_monitor) = app_context.load_monitor {
        load_monitor.start().await;
        info!("Started LoadMonitor for PowerOfTwo policies");
    }

1071
    let (limiter, processor) = middleware::ConcurrencyLimiter::new(
1072
1073
1074
1075
1076
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
    if app_context.rate_limiter.is_none() {
        info!("Rate limiting is disabled (max_concurrent_requests = -1)");
    }

    match processor {
        Some(proc) => {
            spawn(proc.run());
            info!(
                "Started request queue (size: {}, timeout: {}s)",
                config.router_config.queue_size, config.router_config.queue_timeout_secs
            );
        }
        None => {
            info!(
                "Rate limiting enabled (max_concurrent_requests = {}, queue disabled)",
                config.router_config.max_concurrent_requests
            );
        }
1095
1096
    }

1097
    let app_state = Arc::new(AppState {
1098
        router,
1099
        context: app_context.clone(),
1100
        concurrency_queue_tx: limiter.queue_tx.clone(),
1101
        router_manager: Some(router_manager),
1102
    });
1103
1104
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
1105
1106
            let app_context_arc = Arc::clone(&app_state.context);
            match start_service_discovery(service_discovery_config, app_context_arc).await {
1107
                Ok(handle) => {
1108
                    info!("Service discovery started");
1109
1110
1111
1112
1113
1114
1115
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
1116
                    error!("Failed to start service discovery: {e}");
1117
1118
1119
1120
1121
1122
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

1123
    info!(
1124
        "Router ready | workers: {:?}",
1125
        WorkerManager::get_worker_urls(&app_state.context.worker_registry)
1126
    );
1127

1128
1129
1130
1131
1132
1133
1134
1135
1136
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

1137
1138
1139
1140
    let auth_config = AuthConfig {
        api_key: config.router_config.api_key.clone(),
    };

1141
1142
    let app = build_app(
        app_state,
1143
        auth_config,
1144
1145
1146
1147
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );
1148

1149
1150
1151
1152
1153
1154
    // TcpListener::bind accepts &str and handles IPv4/IPv6 via ToSocketAddrs
    let bind_addr = format!("{}:{}", config.host, config.port);
    info!("Starting server on {}", bind_addr);
    let listener = TcpListener::bind(&bind_addr)
        .await
        .map_err(|e| format!("Failed to bind to {}: {}", bind_addr, e))?;
1155
    serve(listener, app)
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

/// Resolves once the process receives a shutdown signal.
///
/// Waits for Ctrl+C on every platform and additionally for SIGTERM on
/// Unix; whichever arrives first completes this future and logs the
/// start of graceful shutdown.
async fn shutdown_signal() {
    // Ctrl+C (SIGINT) works on all supported platforms.
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    // SIGTERM is Unix-only; elsewhere substitute a future that never fires.
    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        stream.recv().await;
    };

    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    // Complete on whichever signal is delivered first.
    tokio::select! {
        _ = interrupt => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = sigterm => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
1214
}