server.rs 15.9 KB
Newer Older
1
use crate::config::RouterConfig;
2
use crate::logging::{self, LoggingConfig};
3
use crate::metrics::{self, PrometheusConfig};
4
use crate::middleware::TokenBucket;
5
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
6
use crate::reasoning_parser::ParserFactory;
7
use crate::routers::{RouterFactory, RouterTrait};
8
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
9
10
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
use crate::tool_parser::ParserRegistry;
11
12
13
14
15
16
use axum::{
    extract::{Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
17
};
18
use reqwest::Client;
19
use std::collections::HashMap;
20
use std::sync::atomic::{AtomicBool, Ordering};
21
use std::sync::Arc;
22
use std::time::Duration;
23
24
use tokio::net::TcpListener;
use tokio::signal;
25
26
use tokio::spawn;
use tracing::{error, info, warn, Level};
27

28
/// Shared, cheaply clonable application context: the HTTP client,
/// router configuration, rate limiter, and (gRPC mode only) the
/// tokenizer/parser components used by request handlers.
#[derive(Clone)]
pub struct AppContext {
    // Shared HTTP client used to reach workers.
    pub client: Client,
    // Router configuration this server was started with.
    pub router_config: RouterConfig,
    // Token bucket consumed by the concurrency-limit middleware.
    pub rate_limiter: Arc<TokenBucket>,
    // Tokenizer — populated only in gRPC connection mode (see `AppContext::new`).
    pub tokenizer: Option<Arc<dyn Tokenizer>>,
    // Reasoning-parser factory — populated only in gRPC connection mode.
    pub reasoning_parser_factory: Option<ParserFactory>,
    // Tool-call parser registry — populated only in gRPC connection mode.
    pub tool_parser_registry: Option<&'static ParserRegistry>,
}

38
impl AppContext {
39
40
41
42
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
43
        rate_limit_tokens_per_second: Option<usize>,
44
    ) -> Result<Self, String> {
45
46
        let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
        let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75

        // Initialize gRPC-specific components only when in gRPC mode
        let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
            if router_config.connection_mode == crate::config::ConnectionMode::Grpc {
                // Get tokenizer path (required for gRPC mode)
                let tokenizer_path = router_config
                    .tokenizer_path
                    .clone()
                    .or_else(|| router_config.model_path.clone())
                    .ok_or_else(|| {
                        "gRPC mode requires either --tokenizer-path or --model-path to be specified"
                            .to_string()
                    })?;

                // Initialize all gRPC components
                let tokenizer = Some(
                    tokenizer_factory::create_tokenizer(&tokenizer_path)
                        .map_err(|e| format!("Failed to create tokenizer: {}", e))?,
                );
                let reasoning_parser_factory = Some(ParserFactory::new());
                let tool_parser_registry = Some(ParserRegistry::new());

                (tokenizer, reasoning_parser_factory, tool_parser_registry)
            } else {
                // HTTP mode doesn't need these components
                (None, None, None)
            };

        Ok(Self {
76
            client,
77
            router_config,
78
            rate_limiter,
79
80
81
82
            tokenizer,
            reasoning_parser_factory,
            tool_parser_registry,
        })
83
84
85
    }
}

86
87
88
89
/// Per-request handler state handed to every Axum route via `State`.
#[derive(Clone)]
pub struct AppState {
    // Active router implementation selected by `RouterFactory`.
    pub router: Arc<dyn RouterTrait>,
    // Shared application context (client, config, rate limiter, parsers).
    pub context: Arc<AppContext>,
    // Sender used to queue requests when the concurrency queue is enabled;
    // `None` when queuing is disabled.
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
}

93
94
95
// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
96
97
}

98
99
100
// Health check endpoints
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
    state.router.liveness()
101
102
}

103
104
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
    state.router.readiness()
105
106
}

107
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
108
    state.router.health(req).await
109
110
}

111
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
112
    state.router.health_generate(req).await
113
114
}

115
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
116
    state.router.get_server_info(req).await
117
118
}

119
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
120
    state.router.get_models(req).await
121
122
}

123
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
124
    state.router.get_model_info(req).await
125
}
126

127
128
// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
129
async fn generate(
130
131
132
133
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
134
    state.router.route_generate(Some(&headers), &body).await
135
136
137
}

async fn v1_chat_completions(
138
139
140
141
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
142
    state.router.route_chat(Some(&headers), &body).await
143
144
145
}

async fn v1_completions(
146
147
148
149
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
150
    state.router.route_completion(Some(&headers), &body).await
151
152
}

153
// Worker management endpoints
154
async fn add_worker(
155
156
157
158
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
159
160
        Some(url) => url.to_string(),
        None => {
161
162
163
164
165
            return (
                StatusCode::BAD_REQUEST,
                "Worker URL required. Provide 'url' query parameter",
            )
                .into_response();
166
167
        }
    };
168

169
170
171
    match state.router.add_worker(&worker_url).await {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
172
    }
173
174
}

175
176
177
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let worker_list = state.router.get_worker_urls();
    Json(serde_json::json!({ "urls": worker_list })).into_response()
178
179
}

180
async fn remove_worker(
181
182
183
184
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
185
        Some(url) => url.to_string(),
186
        None => return StatusCode::BAD_REQUEST.into_response(),
187
    };
188

189
190
191
192
193
194
    state.router.remove_worker(&worker_url);
    (
        StatusCode::OK,
        format!("Successfully removed worker: {}", worker_url),
    )
        .into_response()
195
196
}

197
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
198
    state.router.flush_cache().await
199
200
}

201
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
202
    state.router.get_worker_loads().await
203
204
}

205
206
207
/// Top-level configuration consumed by [`startup`]: network binding,
/// router behavior, logging, metrics, and service discovery.
pub struct ServerConfig {
    // Interface/host to bind (combined with `port` into "host:port").
    pub host: String,
    // TCP port to listen on.
    pub port: u16,
    // Router-specific configuration (mode, policy, limits, ...).
    pub router_config: RouterConfig,
    // Maximum accepted request body size, in bytes.
    pub max_payload_size: usize,
    // Optional log directory, passed through to `logging::init_logging`.
    pub log_dir: Option<String>,
    // Log level string (parsed case-insensitively); invalid or missing
    // values fall back to INFO in `startup`.
    pub log_level: Option<String>,
    // Service discovery settings; discovery starts only when present and enabled.
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    // Prometheus exporter settings; metrics start only when present.
    pub prometheus_config: Option<PrometheusConfig>,
    // Total timeout for proxied HTTP requests, in seconds.
    pub request_timeout_secs: u64,
    // Header names checked for request IDs; `startup` supplies defaults when `None`.
    pub request_id_headers: Option<Vec<String>>,
}

218
219
220
221
222
223
224
225
226
227
228
/// Build the Axum application with all routes and middleware.
///
/// Route groups:
/// - protected: generation endpoints guarded by the concurrency-limit middleware
/// - public: health and metadata endpoints with no extra middleware
/// - admin: worker-management endpoints
///
/// Layer ordering is significant — Axum layers execute bottom-up, so layers
/// added later in this chain run earlier at request time.
pub fn build_app(
    app_state: Arc<AppState>,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Generation endpoints: `route_layer` scopes the concurrency-limit
    // middleware to these routes only.
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions))
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
            crate::middleware::concurrency_limit_middleware,
        ));

    // Health/metadata endpoints — intentionally not concurrency-limited.
    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    // Worker management endpoints.
    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads));

    // Build app with all routes and middleware
    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        // Request body size limiting
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        // Request ID layer - must be added AFTER logging layer in the code
        // so it executes BEFORE logging layer at runtime (layers execute bottom-up)
        .layer(crate::middleware::RequestIdLayer::new(request_id_headers))
        // Custom logging layer that can now see request IDs from extensions
        .layer(crate::middleware::create_logging_layer())
        // CORS (should be outermost)
        .layer(create_cors_layer(cors_allowed_origins))
        // Fallback
        .fallback(sink_handler)
        // State - apply last to get Router<Arc<AppState>>
        .with_state(app_state)
}

pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
274
275
276
277
278
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
279
280
281
282
283
284
285
286
287
288
289
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
290
291
292
293
294
295
296
297
298
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
299

300
301
    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
302
        metrics::start_prometheus(prometheus_config);
303
304
    }

305
    info!(
306
307
308
309
310
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
311
312
313
        config.max_payload_size / (1024 * 1024)
    );

314
    let client = Client::builder()
315
        .pool_idle_timeout(Some(Duration::from_secs(50)))
316
        .pool_max_idle_per_host(500) // Increase to 500 connections per host
317
318
319
320
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10)) // Separate connection timeout
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
321
322
323
        .build()
        .expect("Failed to create HTTP client");

324
325
    // Create the application context with all dependencies
    let app_context = Arc::new(AppContext::new(
326
327
328
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
329
        config.router_config.rate_limit_tokens_per_second,
330
    )?);
331
332

    // Create router with the context
333
    let router = RouterFactory::create_router(&app_context).await?;
334

335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
    // Set up concurrency limiter with queue if configured
    let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

    // Start queue processor if enabled
    if let Some(processor) = processor {
        tokio::spawn(processor.run());
        info!(
            "Started request queue with size: {}, timeout: {}s",
            config.router_config.queue_size, config.router_config.queue_timeout_secs
        );
    }

351
352
353
354
    // Create app state with router and context
    let app_state = Arc::new(AppState {
        router: Arc::from(router),
        context: app_context.clone(),
355
        concurrency_queue_tx: limiter.queue_tx.clone(),
356
    });
357
    let router_arc = Arc::clone(&app_state.router);
358

359
360
361
    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
362
            match start_service_discovery(service_discovery_config, router_arc).await {
363
                Ok(handle) => {
364
                    info!("Service discovery started");
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

380
    info!(
381
        "Router ready | workers: {:?}",
382
        app_state.router.get_worker_urls()
383
    );
384

385
386
387
388
389
390
391
392
393
394
    // Configure request ID headers
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

395
396
397
398
399
400
401
    // Build the application
    let app = build_app(
        app_state,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );
402

403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
    // Create TCP listener - use the configured host
    let addr = format!("{}:{}", config.host, config.port);
    let listener = TcpListener::bind(&addr).await?;

    // Start server with graceful shutdown
    info!("Starting server on {}", addr);

    // Serve the application with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

// Graceful shutdown handler
/// Resolves when Ctrl+C (all platforms) or SIGTERM (unix only) is received,
/// letting `axum::serve` drain in-flight requests before exiting.
async fn shutdown_signal() {
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    // On non-unix targets there is no SIGTERM; use a future that never resolves.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    tokio::select! {
        _ = interrupt => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = sigterm => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        // Allow all origins if none specified
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        // Restrict to specific origins
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
474
}