use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::protocols::{
    generate::GenerateRequest,
    openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use axum::{
    extract::{Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
};
use reqwest::Client;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tokio::spawn;
use tracing::{error, info, warn, Level};

#[derive(Clone)]
pub struct AppContext {
    pub client: Client,
    pub router_config: RouterConfig,
    pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
    // Future dependencies can be added here
}

impl AppContext {
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
    ) -> Self {
        let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
        Self {
            client,
            router_config,
            concurrency_limiter,
        }
    }
}
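
// Illustrative sketch (not part of the original source): code holding an
// `Arc<AppContext>` can bound in-flight work with the shared semaphore via
// tokio's `Semaphore::acquire`, e.g.
//
//     let _permit = ctx.concurrency_limiter.acquire().await.expect("semaphore closed");
//     // ... forward the request while the permit is held ...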

#[derive(Clone)]
pub struct AppState {
    pub router: Arc<dyn RouterTrait>,
    pub context: Arc<AppContext>,
}

// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
}

// Health check endpoints
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
    state.router.liveness()
}

async fn readiness(State(state): State<Arc<AppState>>) -> Response {
    state.router.readiness()
}

async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health(req).await
}

async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health_generate(req).await
}

async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_server_info(req).await
}

async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_models(req).await
}

async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_model_info(req).await
}

// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
async fn generate(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
    state.router.route_generate(Some(&headers), &body).await
}

async fn v1_chat_completions(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
    state.router.route_chat(Some(&headers), &body).await
}

async fn v1_completions(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
    state.router.route_completion(Some(&headers), &body).await
}

// Worker management endpoints
async fn add_worker(
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
        Some(url) => url.to_string(),
        None => {
            return (
                StatusCode::BAD_REQUEST,
                "Worker URL required. Provide 'url' query parameter",
            )
                .into_response();
        }
    };

    match state.router.add_worker(&worker_url).await {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
    }
}
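
// Illustrative usage (sketch, not from the original source; host, port, and the
// worker URL are placeholders): with the routes registered in `build_app`, workers
// can be attached or detached at runtime via the admin endpoints, e.g.
//   POST http://<router-host>:<port>/add_worker?url=http://worker-0:8000
//   POST http://<router-host>:<port>/remove_worker?url=http://worker-0:8000
// A missing `url` query parameter yields 400 Bad Request.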

async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let worker_list = state.router.get_worker_urls();
    Json(serde_json::json!({ "urls": worker_list })).into_response()
}

async fn remove_worker(
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
        Some(url) => url.to_string(),
        None => return StatusCode::BAD_REQUEST.into_response(),
    };

    state.router.remove_worker(&worker_url);
    (
        StatusCode::OK,
        format!("Successfully removed worker: {}", worker_url),
    )
        .into_response()
}

async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.flush_cache().await
}

async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.get_worker_loads().await
}

pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    pub router_config: RouterConfig,
    pub max_payload_size: usize,
    pub log_dir: Option<String>,
    pub log_level: Option<String>,
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    pub prometheus_config: Option<PrometheusConfig>,
    pub request_timeout_secs: u64,
    pub request_id_headers: Option<Vec<String>>,
}
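
// Illustrative sketch (not part of the original source; all concrete values are
// assumptions): a caller might populate `ServerConfig` and hand it to `startup`:
//
//     let server_config = ServerConfig {
//         host: "0.0.0.0".to_string(),
//         port: 30000,
//         router_config,
//         max_payload_size: 512 * 1024 * 1024, // bytes
//         log_dir: None,
//         log_level: Some("info".to_string()),
//         service_discovery_config: None,
//         prometheus_config: None,
//         request_timeout_secs: 600,
//         request_id_headers: None,
//     };
//     startup(server_config).await?;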

/// Build the Axum application with all routes and middleware
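///
/// Example (illustrative sketch, not from the original source; assumes a boxed
/// `RouterTrait` implementation is already available as `router`, a `router_config`
/// is in scope, and the numeric values are placeholders):
///
/// ```ignore
/// let context = Arc::new(AppContext::new(router_config, Client::new(), 64));
/// let app_state = Arc::new(AppState { router: Arc::from(router), context });
/// let app = build_app(
///     app_state,
///     512 * 1024 * 1024,                // max_payload_size in bytes
///     vec!["x-request-id".to_string()], // request ID headers to honor
///     vec![],                           // empty list = allow all CORS origins
/// );
/// ```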
pub fn build_app(
    app_state: Arc<AppState>,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Create routes
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions));

    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads));

    // Build app with all routes and middleware
    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        // Request body size limiting
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        // Request ID layer - must be added AFTER logging layer in the code
        // so it executes BEFORE logging layer at runtime (layers execute bottom-up)
        .layer(crate::middleware::RequestIdLayer::new(request_id_headers))
        // Custom logging layer that can now see request IDs from extensions
        .layer(crate::middleware::create_logging_layer())
        // CORS (should be outermost)
        .layer(create_cors_layer(cors_allowed_origins))
        // Fallback
        .fallback(sink_handler)
        // State - apply last to get Router<Arc<AppState>>
        .with_state(app_state)
}

pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };

    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
        metrics::start_prometheus(prometheus_config);
    }

    info!(
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
        config.max_payload_size / (1024 * 1024)
    );

    let client = Client::builder()
        .pool_idle_timeout(Some(Duration::from_secs(50)))
        .pool_max_idle_per_host(500) // Increase to 500 connections per host
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10)) // Separate connection timeout
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
        .build()
        .expect("Failed to create HTTP client");

    // Create the application context with all dependencies
    let app_context = Arc::new(AppContext::new(
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
    ));

    // Create router with the context
    let router = RouterFactory::create_router(&app_context).await?;

    // Create app state with router and context
    let app_state = Arc::new(AppState {
        router: Arc::from(router),
        context: app_context.clone(),
    });
    let router_arc = Arc::clone(&app_state.router);

    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
            match start_service_discovery(service_discovery_config, router_arc).await {
                Ok(handle) => {
                    info!("Service discovery started");
                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

    info!(
        "Router ready | workers: {:?}",
        app_state.router.get_worker_urls()
    );

    // Configure request ID headers
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    // Build the application
    let app = build_app(
        app_state,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );

    // Create TCP listener - use the configured host
    let addr = format!("{}:{}", config.host, config.port);
    let listener = TcpListener::bind(&addr).await?;

    // Start server with graceful shutdown
    info!("Starting server on {}", addr);

    // Serve the application with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

// Graceful shutdown handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = terminate => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        // Allow all origins if none specified
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        // Restrict to specific origins
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
}