//! HTTP server entry point for the router: route definitions, middleware
//! stack, application state, and server lifecycle (startup/shutdown).
1
use crate::config::RouterConfig;
2
use crate::logging::{self, LoggingConfig};
3
use crate::metrics::{self, PrometheusConfig};
4
use crate::middleware::TokenBucket;
5
use crate::protocols::spec::{
6
7
    ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
    V1RerankReqInput,
8
};
9
use crate::reasoning_parser::ParserFactory;
10
use crate::routers::{RouterFactory, RouterTrait};
11
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
12
13
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;
14
15
16
17
18
19
use axum::{
    extract::{Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
20
};
21
use reqwest::Client;
22
use std::collections::HashMap;
23
use std::sync::atomic::{AtomicBool, Ordering};
24
use std::sync::Arc;
25
use std::time::Duration;
26
27
use tokio::net::TcpListener;
use tokio::signal;
28
29
use tokio::spawn;
use tracing::{error, info, warn, Level};
30

31
#[derive(Clone)]
32
pub struct AppContext {
33
    pub client: Client,
34
    pub router_config: RouterConfig,
35
    pub rate_limiter: Arc<TokenBucket>,
36
37
38
    pub tokenizer: Option<Arc<dyn Tokenizer>>,
    pub reasoning_parser_factory: Option<ParserFactory>,
    pub tool_parser_registry: Option<&'static ParserRegistry>,
39
40
}

41
impl AppContext {
42
43
44
45
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
46
        rate_limit_tokens_per_second: Option<usize>,
47
    ) -> Result<Self, String> {
48
49
        let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
        let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

        // Initialize gRPC-specific components only when in gRPC mode
        let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
            if router_config.connection_mode == crate::config::ConnectionMode::Grpc {
                // Get tokenizer path (required for gRPC mode)
                let tokenizer_path = router_config
                    .tokenizer_path
                    .clone()
                    .or_else(|| router_config.model_path.clone())
                    .ok_or_else(|| {
                        "gRPC mode requires either --tokenizer-path or --model-path to be specified"
                            .to_string()
                    })?;

                // Initialize all gRPC components
                let tokenizer = Some(
                    tokenizer_factory::create_tokenizer(&tokenizer_path)
                        .map_err(|e| format!("Failed to create tokenizer: {}", e))?,
                );
                let reasoning_parser_factory = Some(ParserFactory::new());
                let tool_parser_registry = Some(ParserRegistry::new());

                (tokenizer, reasoning_parser_factory, tool_parser_registry)
            } else {
                // HTTP mode doesn't need these components
                (None, None, None)
            };

        Ok(Self {
79
            client,
80
            router_config,
81
            rate_limiter,
82
83
84
85
            tokenizer,
            reasoning_parser_factory,
            tool_parser_registry,
        })
86
87
88
    }
}

89
90
91
92
#[derive(Clone)]
pub struct AppState {
    pub router: Arc<dyn RouterTrait>,
    pub context: Arc<AppContext>,
93
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
94
95
}

96
97
98
// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
99
100
}

101
102
103
// Health check endpoints
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
    state.router.liveness()
104
105
}

106
107
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
    state.router.readiness()
108
109
}

110
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
111
    state.router.health(req).await
112
113
}

114
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
115
    state.router.health_generate(req).await
116
117
}

118
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
119
    state.router.get_server_info(req).await
120
121
}

122
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
123
    state.router.get_models(req).await
124
125
}

126
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
127
    state.router.get_model_info(req).await
128
}
129

130
131
// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
132
async fn generate(
133
134
135
136
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
137
    state.router.route_generate(Some(&headers), &body).await
138
139
140
}

async fn v1_chat_completions(
141
142
143
144
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
145
    state.router.route_chat(Some(&headers), &body).await
146
147
148
}

async fn v1_completions(
149
150
151
152
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
153
    state.router.route_completion(Some(&headers), &body).await
154
155
}

156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
async fn rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<RerankRequest>,
) -> Response {
    state.router.route_rerank(Some(&headers), &body).await
}

async fn v1_rerank(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<V1RerankReqInput>,
) -> Response {
    state
        .router
        .route_rerank(Some(&headers), &body.into())
        .await
}

175
176
177
178
179
180
181
182
async fn v1_responses(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ResponsesRequest>,
) -> Response {
    state.router.route_responses(Some(&headers), &body).await
}

183
// Worker management endpoints
184
async fn add_worker(
185
186
187
188
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
189
190
        Some(url) => url.to_string(),
        None => {
191
192
193
194
195
            return (
                StatusCode::BAD_REQUEST,
                "Worker URL required. Provide 'url' query parameter",
            )
                .into_response();
196
197
        }
    };
198

199
200
201
    match state.router.add_worker(&worker_url).await {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
202
    }
203
204
}

205
206
207
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let worker_list = state.router.get_worker_urls();
    Json(serde_json::json!({ "urls": worker_list })).into_response()
208
209
}

210
async fn remove_worker(
211
212
213
214
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
215
        Some(url) => url.to_string(),
216
        None => return StatusCode::BAD_REQUEST.into_response(),
217
    };
218

219
220
221
222
223
224
    state.router.remove_worker(&worker_url);
    (
        StatusCode::OK,
        format!("Successfully removed worker: {}", worker_url),
    )
        .into_response()
225
226
}

227
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
228
    state.router.flush_cache().await
229
230
}

231
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
232
    state.router.get_worker_loads().await
233
234
}

235
236
237
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
238
    pub router_config: RouterConfig,
239
    pub max_payload_size: usize,
240
    pub log_dir: Option<String>,
241
    pub log_level: Option<String>,
242
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
243
    pub prometheus_config: Option<PrometheusConfig>,
244
    pub request_timeout_secs: u64,
245
    pub request_id_headers: Option<Vec<String>>,
246
247
}

248
249
250
251
252
253
254
255
256
257
258
/// Build the Axum application with all routes and middleware
pub fn build_app(
    app_state: Arc<AppState>,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Create routes
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
259
        .route("/v1/completions", post(v1_completions))
260
261
        .route("/rerank", post(rerank))
        .route("/v1/rerank", post(v1_rerank))
262
        .route("/v1/responses", post(v1_responses))
263
264
265
266
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
            crate::middleware::concurrency_limit_middleware,
        ));
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306

    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads));

    // Build app with all routes and middleware
    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        // Request body size limiting
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        // Request ID layer - must be added AFTER logging layer in the code
        // so it executes BEFORE logging layer at runtime (layers execute bottom-up)
        .layer(crate::middleware::RequestIdLayer::new(request_id_headers))
        // Custom logging layer that can now see request IDs from extensions
        .layer(crate::middleware::create_logging_layer())
        // CORS (should be outermost)
        .layer(create_cors_layer(cors_allowed_origins))
        // Fallback
        .fallback(sink_handler)
        // State - apply last to get Router<Arc<AppState>>
        .with_state(app_state)
}

pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
307
308
309
310
311
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
312
313
314
315
316
317
318
319
320
321
322
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
323
324
325
326
327
328
329
330
331
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
332

333
334
    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
335
        metrics::start_prometheus(prometheus_config);
336
337
    }

338
    info!(
339
340
341
342
343
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
344
345
346
        config.max_payload_size / (1024 * 1024)
    );

347
    let client = Client::builder()
348
        .pool_idle_timeout(Some(Duration::from_secs(50)))
349
        .pool_max_idle_per_host(500) // Increase to 500 connections per host
350
351
352
353
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10)) // Separate connection timeout
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
354
355
356
        .build()
        .expect("Failed to create HTTP client");

357
358
    // Create the application context with all dependencies
    let app_context = Arc::new(AppContext::new(
359
360
361
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
362
        config.router_config.rate_limit_tokens_per_second,
363
    )?);
364
365

    // Create router with the context
366
    let router = RouterFactory::create_router(&app_context).await?;
367

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
    // Set up concurrency limiter with queue if configured
    let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

    // Start queue processor if enabled
    if let Some(processor) = processor {
        tokio::spawn(processor.run());
        info!(
            "Started request queue with size: {}, timeout: {}s",
            config.router_config.queue_size, config.router_config.queue_timeout_secs
        );
    }

384
385
386
387
    // Create app state with router and context
    let app_state = Arc::new(AppState {
        router: Arc::from(router),
        context: app_context.clone(),
388
        concurrency_queue_tx: limiter.queue_tx.clone(),
389
    });
390
    let router_arc = Arc::clone(&app_state.router);
391

392
393
394
    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
395
            match start_service_discovery(service_discovery_config, router_arc).await {
396
                Ok(handle) => {
397
                    info!("Service discovery started");
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

413
    info!(
414
        "Router ready | workers: {:?}",
415
        app_state.router.get_worker_urls()
416
    );
417

418
419
420
421
422
423
424
425
426
427
    // Configure request ID headers
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

428
429
430
431
432
433
434
    // Build the application
    let app = build_app(
        app_state,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );
435

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
    // Create TCP listener - use the configured host
    let addr = format!("{}:{}", config.host, config.port);
    let listener = TcpListener::bind(&addr).await?;

    // Start server with graceful shutdown
    info!("Starting server on {}", addr);

    // Serve the application with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

// Graceful shutdown handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = terminate => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        // Allow all origins if none specified
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        // Restrict to specific origins
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
507
}