//! HTTP server setup for the router: routing, middleware, startup and
//! graceful shutdown.
1
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::TokenBucket;
use crate::protocols::spec::{
    ChatCompletionRequest, CompletionRequest, GenerateRequest, ResponsesRequest,
};
use crate::reasoning_parser::ParserFactory;
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;

use axum::{
    extract::{Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
};
use reqwest::Client;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tokio::spawn;
use tracing::{error, info, warn, Level};
29

30
#[derive(Clone)]
31
pub struct AppContext {
32
    pub client: Client,
33
    pub router_config: RouterConfig,
34
    pub rate_limiter: Arc<TokenBucket>,
35
36
37
    pub tokenizer: Option<Arc<dyn Tokenizer>>,
    pub reasoning_parser_factory: Option<ParserFactory>,
    pub tool_parser_registry: Option<&'static ParserRegistry>,
38
39
}

40
impl AppContext {
41
42
43
44
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
45
        rate_limit_tokens_per_second: Option<usize>,
46
    ) -> Result<Self, String> {
47
48
        let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
        let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

        // Initialize gRPC-specific components only when in gRPC mode
        let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
            if router_config.connection_mode == crate::config::ConnectionMode::Grpc {
                // Get tokenizer path (required for gRPC mode)
                let tokenizer_path = router_config
                    .tokenizer_path
                    .clone()
                    .or_else(|| router_config.model_path.clone())
                    .ok_or_else(|| {
                        "gRPC mode requires either --tokenizer-path or --model-path to be specified"
                            .to_string()
                    })?;

                // Initialize all gRPC components
                let tokenizer = Some(
                    tokenizer_factory::create_tokenizer(&tokenizer_path)
                        .map_err(|e| format!("Failed to create tokenizer: {}", e))?,
                );
                let reasoning_parser_factory = Some(ParserFactory::new());
                let tool_parser_registry = Some(ParserRegistry::new());

                (tokenizer, reasoning_parser_factory, tool_parser_registry)
            } else {
                // HTTP mode doesn't need these components
                (None, None, None)
            };

        Ok(Self {
78
            client,
79
            router_config,
80
            rate_limiter,
81
82
83
84
            tokenizer,
            reasoning_parser_factory,
            tool_parser_registry,
        })
85
86
87
    }
}

88
89
90
91
#[derive(Clone)]
pub struct AppState {
    pub router: Arc<dyn RouterTrait>,
    pub context: Arc<AppContext>,
92
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
93
94
}

95
96
97
// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
98
99
}

100
101
102
// Health check endpoints
/// GET /liveness — delegate the liveness probe to the router.
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
    state.router.liveness()
}

105
106
/// GET /readiness — delegate the readiness probe to the router.
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
    state.router.readiness()
}

109
/// GET /health — forward the health check request to the router.
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health(req).await
}

113
/// GET /health_generate — forward the generation health check to the router.
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health_generate(req).await
}

117
/// GET /get_server_info — forward the server-info request to the router.
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_server_info(req).await
}

121
/// GET /v1/models — forward the model listing request to the router.
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_models(req).await
}

125
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
126
    state.router.get_model_info(req).await
127
}
128

129
130
// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
131
async fn generate(
132
133
134
135
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
136
    state.router.route_generate(Some(&headers), &body).await
137
138
139
}

async fn v1_chat_completions(
140
141
142
143
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
144
    state.router.route_chat(Some(&headers), &body).await
145
146
147
}

async fn v1_completions(
148
149
150
151
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
152
    state.router.route_completion(Some(&headers), &body).await
153
154
}

155
156
157
158
159
160
161
162
/// POST /v1/responses — route a typed Responses API request.
async fn v1_responses(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ResponsesRequest>,
) -> Response {
    state.router.route_responses(Some(&headers), &body).await
}

163
// Worker management endpoints
164
async fn add_worker(
165
166
167
168
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
169
170
        Some(url) => url.to_string(),
        None => {
171
172
173
174
175
            return (
                StatusCode::BAD_REQUEST,
                "Worker URL required. Provide 'url' query parameter",
            )
                .into_response();
176
177
        }
    };
178

179
180
181
    match state.router.add_worker(&worker_url).await {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
182
    }
183
184
}

185
186
187
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let worker_list = state.router.get_worker_urls();
    Json(serde_json::json!({ "urls": worker_list })).into_response()
188
189
}

190
async fn remove_worker(
191
192
193
194
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
195
        Some(url) => url.to_string(),
196
        None => return StatusCode::BAD_REQUEST.into_response(),
197
    };
198

199
200
201
202
203
204
    state.router.remove_worker(&worker_url);
    (
        StatusCode::OK,
        format!("Successfully removed worker: {}", worker_url),
    )
        .into_response()
205
206
}

207
/// POST /flush_cache — ask the router to flush worker caches.
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.flush_cache().await
}

211
/// GET /get_loads — fetch per-worker load information from the router.
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.get_worker_loads().await
}

215
216
217
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
218
    pub router_config: RouterConfig,
219
    pub max_payload_size: usize,
220
    pub log_dir: Option<String>,
221
    pub log_level: Option<String>,
222
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
223
    pub prometheus_config: Option<PrometheusConfig>,
224
    pub request_timeout_secs: u64,
225
    pub request_id_headers: Option<Vec<String>>,
226
227
}

228
229
230
231
232
233
234
235
236
237
238
/// Build the Axum application with all routes and middleware
pub fn build_app(
    app_state: Arc<AppState>,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Create routes
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
239
        .route("/v1/completions", post(v1_completions))
240
        .route("/v1/responses", post(v1_responses))
241
242
243
244
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
            crate::middleware::concurrency_limit_middleware,
        ));
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284

    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads));

    // Build app with all routes and middleware
    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        // Request body size limiting
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        // Request ID layer - must be added AFTER logging layer in the code
        // so it executes BEFORE logging layer at runtime (layers execute bottom-up)
        .layer(crate::middleware::RequestIdLayer::new(request_id_headers))
        // Custom logging layer that can now see request IDs from extensions
        .layer(crate::middleware::create_logging_layer())
        // CORS (should be outermost)
        .layer(create_cors_layer(cors_allowed_origins))
        // Fallback
        .fallback(sink_handler)
        // State - apply last to get Router<Arc<AppState>>
        .with_state(app_state)
}

pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
285
286
287
288
289
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
290
291
292
293
294
295
296
297
298
299
300
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
301
302
303
304
305
306
307
308
309
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
310

311
312
    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
313
        metrics::start_prometheus(prometheus_config);
314
315
    }

316
    info!(
317
318
319
320
321
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
322
323
324
        config.max_payload_size / (1024 * 1024)
    );

325
    let client = Client::builder()
326
        .pool_idle_timeout(Some(Duration::from_secs(50)))
327
        .pool_max_idle_per_host(500) // Increase to 500 connections per host
328
329
330
331
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10)) // Separate connection timeout
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
332
333
334
        .build()
        .expect("Failed to create HTTP client");

335
336
    // Create the application context with all dependencies
    let app_context = Arc::new(AppContext::new(
337
338
339
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
340
        config.router_config.rate_limit_tokens_per_second,
341
    )?);
342
343

    // Create router with the context
344
    let router = RouterFactory::create_router(&app_context).await?;
345

346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
    // Set up concurrency limiter with queue if configured
    let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

    // Start queue processor if enabled
    if let Some(processor) = processor {
        tokio::spawn(processor.run());
        info!(
            "Started request queue with size: {}, timeout: {}s",
            config.router_config.queue_size, config.router_config.queue_timeout_secs
        );
    }

362
363
364
365
    // Create app state with router and context
    let app_state = Arc::new(AppState {
        router: Arc::from(router),
        context: app_context.clone(),
366
        concurrency_queue_tx: limiter.queue_tx.clone(),
367
    });
368
    let router_arc = Arc::clone(&app_state.router);
369

370
371
372
    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
373
            match start_service_discovery(service_discovery_config, router_arc).await {
374
                Ok(handle) => {
375
                    info!("Service discovery started");
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

391
    info!(
392
        "Router ready | workers: {:?}",
393
        app_state.router.get_worker_urls()
394
    );
395

396
397
398
399
400
401
402
403
404
405
    // Configure request ID headers
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

406
407
408
409
410
411
412
    // Build the application
    let app = build_app(
        app_state,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );
413

414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
    // Create TCP listener - use the configured host
    let addr = format!("{}:{}", config.host, config.port);
    let listener = TcpListener::bind(&addr).await?;

    // Start server with graceful shutdown
    info!("Starting server on {}", addr);

    // Serve the application with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

// Graceful shutdown handler
/// Resolves once the process receives Ctrl+C or, on Unix, SIGTERM.
async fn shutdown_signal() {
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        sigterm.recv().await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Wait for whichever signal fires first.
    tokio::select! {
        _ = interrupt => info!("Received Ctrl+C, starting graceful shutdown"),
        _ = terminate => info!("Received terminate signal, starting graceful shutdown"),
    }
}

// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        // Allow all origins if none specified
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        // Restrict to specific origins
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
485
}