server.rs 14.2 KB
Newer Older
1
use crate::config::RouterConfig;
2
use crate::logging::{self, LoggingConfig};
3
use crate::metrics::{self, PrometheusConfig};
4
use crate::middleware::TokenBucket;
5
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
6
use crate::routers::{RouterFactory, RouterTrait};
7
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
8
9
10
11
12
13
use axum::{
    extract::{Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
14
};
15
use reqwest::Client;
16
use std::collections::HashMap;
17
use std::sync::atomic::{AtomicBool, Ordering};
18
use std::sync::Arc;
19
use std::time::Duration;
20
21
use tokio::net::TcpListener;
use tokio::signal;
22
23
use tokio::spawn;
use tracing::{error, info, warn, Level};
24

25
/// Shared, clonable application-wide dependencies, built once at startup
/// and handed to the router factory and middleware.
#[derive(Clone)]
pub struct AppContext {
    // Shared HTTP client used for proxying requests to workers.
    pub client: Client,
    // Full router configuration this server was started with.
    pub router_config: RouterConfig,
    // Token bucket used by the concurrency/rate-limit middleware.
    pub rate_limiter: Arc<TokenBucket>,
    // Future dependencies can be added here
}

33
impl AppContext {
34
35
36
37
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
38
        rate_limit_tokens_per_second: Option<usize>,
39
    ) -> Self {
40
41
        let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
        let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
42
        Self {
43
            client,
44
            router_config,
45
            rate_limiter,
46
        }
47
48
49
    }
}

50
51
52
53
/// Per-request handler state: the concrete router plus the shared context.
#[derive(Clone)]
pub struct AppState {
    // The routing implementation selected by `RouterFactory` at startup.
    pub router: Arc<dyn RouterTrait>,
    // Shared dependencies (HTTP client, config, rate limiter).
    pub context: Arc<AppContext>,
    // Sender into the request queue, present only when queuing is enabled.
    pub concurrency_queue_tx: Option<tokio::sync::mpsc::Sender<crate::middleware::QueuedRequest>>,
}

57
58
59
// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
60
61
}

62
63
64
// Health check endpoints
/// GET /liveness — synchronous liveness probe, delegated to the router.
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
    state.router.liveness()
}

67
68
/// GET /readiness — synchronous readiness probe, delegated to the router.
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
    state.router.readiness()
}

71
/// GET /health — forwards the raw request to the router's async health check.
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health(req).await
}

75
/// GET /health_generate — forwards to the router's generation health check.
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health_generate(req).await
}

79
/// GET /get_server_info — proxies server info retrieval to the router.
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_server_info(req).await
}

83
/// GET /v1/models — OpenAI-compatible model listing, delegated to the router.
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_models(req).await
}

87
/// GET /get_model_info — proxies model info retrieval to the router.
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_model_info(req).await
}
90

91
92
// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
/// POST /generate — routes a typed generate request, passing the incoming
/// headers through so the router can honor per-request metadata.
async fn generate(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
    state.router.route_generate(Some(&headers), &body).await
}

/// POST /v1/chat/completions — routes an OpenAI-style chat completion
/// request, forwarding the incoming headers to the router.
async fn v1_chat_completions(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
    state.router.route_chat(Some(&headers), &body).await
}

/// POST /v1/completions — routes an OpenAI-style text completion request,
/// forwarding the incoming headers to the router.
async fn v1_completions(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
    state.router.route_completion(Some(&headers), &body).await
}

117
// Worker management endpoints
118
async fn add_worker(
119
120
121
122
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
123
124
        Some(url) => url.to_string(),
        None => {
125
126
127
128
129
            return (
                StatusCode::BAD_REQUEST,
                "Worker URL required. Provide 'url' query parameter",
            )
                .into_response();
130
131
        }
    };
132

133
134
135
    match state.router.add_worker(&worker_url).await {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
136
    }
137
138
}

139
140
141
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let worker_list = state.router.get_worker_urls();
    Json(serde_json::json!({ "urls": worker_list })).into_response()
142
143
}

144
async fn remove_worker(
145
146
147
148
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
149
        Some(url) => url.to_string(),
150
        None => return StatusCode::BAD_REQUEST.into_response(),
151
    };
152

153
154
155
156
157
158
    state.router.remove_worker(&worker_url);
    (
        StatusCode::OK,
        format!("Successfully removed worker: {}", worker_url),
    )
        .into_response()
159
160
}

161
/// POST /flush_cache — asks the router to flush worker caches; the raw
/// request is accepted but unused.
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.flush_cache().await
}

165
/// GET /get_loads — reports per-worker load information from the router.
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.get_worker_loads().await
}

169
170
171
/// Everything `startup` needs to bring the HTTP server up.
pub struct ServerConfig {
    // Bind address host (e.g. "0.0.0.0").
    pub host: String,
    // Bind port.
    pub port: u16,
    // Router mode/policy/limits configuration.
    pub router_config: RouterConfig,
    // Maximum accepted request body size, in bytes.
    pub max_payload_size: usize,
    // Optional directory for log files; None logs to stdout/stderr only.
    pub log_dir: Option<String>,
    // Optional log level name (case-insensitive); invalid values fall back to INFO.
    pub log_level: Option<String>,
    // Optional Kubernetes-style service discovery settings.
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    // Optional Prometheus exporter settings.
    pub prometheus_config: Option<PrometheusConfig>,
    // Upstream request timeout for the shared HTTP client, in seconds.
    pub request_timeout_secs: u64,
    // Header names to honor as request IDs; None uses the built-in defaults.
    pub request_id_headers: Option<Vec<String>>,
}

182
183
184
185
186
187
188
189
190
191
192
/// Build the Axum application with all routes and middleware.
///
/// Route groups:
/// - protected: generation endpoints, gated by the concurrency-limit middleware
/// - public: health and info endpoints, ungated
/// - admin: worker management, cache flush, and load reporting
///
/// NOTE: layer ordering below is deliberate — axum layers execute bottom-up,
/// so the request-ID layer must be added after the logging layer in source
/// order for IDs to be visible to the logger at runtime.
pub fn build_app(
    app_state: Arc<AppState>,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Create routes
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions))
        .route_layer(axum::middleware::from_fn_with_state(
            app_state.clone(),
            crate::middleware::concurrency_limit_middleware,
        ));

    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads));

    // Build app with all routes and middleware
    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        // Request body size limiting
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        // Request ID layer - must be added AFTER logging layer in the code
        // so it executes BEFORE logging layer at runtime (layers execute bottom-up)
        .layer(crate::middleware::RequestIdLayer::new(request_id_headers))
        // Custom logging layer that can now see request IDs from extensions
        .layer(crate::middleware::create_logging_layer())
        // CORS (should be outermost)
        .layer(create_cors_layer(cors_allowed_origins))
        // Fallback
        .fallback(sink_handler)
        // State - apply last to get Router<Arc<AppState>>
        .with_state(app_state)
}

pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
238
239
240
241
242
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
243
244
245
246
247
248
249
250
251
252
253
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
254
255
256
257
258
259
260
261
262
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
263

264
265
    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
266
        metrics::start_prometheus(prometheus_config);
267
268
    }

269
    info!(
270
271
272
273
274
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
275
276
277
        config.max_payload_size / (1024 * 1024)
    );

278
    let client = Client::builder()
279
        .pool_idle_timeout(Some(Duration::from_secs(50)))
280
        .pool_max_idle_per_host(500) // Increase to 500 connections per host
281
282
283
284
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10)) // Separate connection timeout
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
285
286
287
        .build()
        .expect("Failed to create HTTP client");

288
289
    // Create the application context with all dependencies
    let app_context = Arc::new(AppContext::new(
290
291
292
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
293
        config.router_config.rate_limit_tokens_per_second,
294
295
296
    ));

    // Create router with the context
297
    let router = RouterFactory::create_router(&app_context).await?;
298

299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
    // Set up concurrency limiter with queue if configured
    let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
        app_context.rate_limiter.clone(),
        config.router_config.queue_size,
        Duration::from_secs(config.router_config.queue_timeout_secs),
    );

    // Start queue processor if enabled
    if let Some(processor) = processor {
        tokio::spawn(processor.run());
        info!(
            "Started request queue with size: {}, timeout: {}s",
            config.router_config.queue_size, config.router_config.queue_timeout_secs
        );
    }

315
316
317
318
    // Create app state with router and context
    let app_state = Arc::new(AppState {
        router: Arc::from(router),
        context: app_context.clone(),
319
        concurrency_queue_tx: limiter.queue_tx.clone(),
320
    });
321
    let router_arc = Arc::clone(&app_state.router);
322

323
324
325
    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
326
            match start_service_discovery(service_discovery_config, router_arc).await {
327
                Ok(handle) => {
328
                    info!("Service discovery started");
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

344
    info!(
345
        "Router ready | workers: {:?}",
346
        app_state.router.get_worker_urls()
347
    );
348

349
350
351
352
353
354
355
356
357
358
    // Configure request ID headers
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

359
360
361
362
363
364
365
    // Build the application
    let app = build_app(
        app_state,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );
366

367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
    // Create TCP listener - use the configured host
    let addr = format!("{}:{}", config.host, config.port);
    let listener = TcpListener::bind(&addr).await?;

    // Start server with graceful shutdown
    info!("Starting server on {}", addr);

    // Serve the application with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

// Graceful shutdown handler
/// Resolves once the process receives Ctrl+C, or SIGTERM on Unix.
/// On non-Unix targets only Ctrl+C is observed.
async fn shutdown_signal() {
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminated = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    // Never resolves, so the select below only reacts to Ctrl+C.
    #[cfg(not(unix))]
    let terminated = std::future::pending::<()>();

    tokio::select! {
        _ = interrupt => info!("Received Ctrl+C, starting graceful shutdown"),
        _ = terminated => info!("Received terminate signal, starting graceful shutdown"),
    }
}

// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        // Allow all origins if none specified
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        // Restrict to specific origins
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
438
}