use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};

use axum::{
    extract::{Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
};
use reqwest::Client;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::signal;
use tokio::spawn;
use tracing::{error, info, warn, Level};

/// Shared application context holding the dependencies (HTTP client, router
/// configuration, concurrency limiter) used across the server.
#[derive(Clone)]
pub struct AppContext {
    pub client: Client,
    pub router_config: RouterConfig,
    pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
    // Future dependencies can be added here
}

impl AppContext {
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
    ) -> Self {
        let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
        Self {
            client,
            router_config,
            concurrency_limiter,
        }
    }
}
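// Illustrative sketch (not part of this module): given an already-built
// `RouterConfig` and a reqwest `Client`, a context capped at 256 concurrent
// requests could be created, and its limiter used to gate a request, like so:
//
//     let ctx = AppContext::new(router_config, Client::new(), 256);
//     let _permit = ctx.concurrency_limiter.acquire().await?;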

/// Application state handed to every Axum handler: the active router plus the
/// shared context it was built from.
#[derive(Clone)]
pub struct AppState {
    pub router: Arc<dyn RouterTrait>,
    pub context: Arc<AppContext>,
}

// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
}

// Health check endpoints
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
    state.router.liveness()
}

async fn readiness(State(state): State<Arc<AppState>>) -> Response {
    state.router.readiness()
}

async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health(req).await
}

async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health_generate(req).await
}

async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_server_info(req).await
}

async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_models(req).await
}

async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_model_info(req).await
}

// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
async fn generate(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
    state.router.route_generate(Some(&headers), &body).await
}

async fn v1_chat_completions(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
    state.router.route_chat(Some(&headers), &body).await
}

async fn v1_completions(
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
    state.router.route_completion(Some(&headers), &body).await
}

// Worker management endpoints
async fn add_worker(
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
        Some(url) => url.to_string(),
        None => {
            return (
                StatusCode::BAD_REQUEST,
                "Worker URL required. Provide 'url' query parameter",
            )
                .into_response();
        }
    };

    match state.router.add_worker(&worker_url).await {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
    }
}

async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let worker_list = state.router.get_worker_urls();
    Json(serde_json::json!({ "urls": worker_list })).into_response()
}

async fn remove_worker(
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
        Some(url) => url.to_string(),
        None => return StatusCode::BAD_REQUEST.into_response(),
    };

    state.router.remove_worker(&worker_url);
    (
        StatusCode::OK,
        format!("Successfully removed worker: {}", worker_url),
    )
        .into_response()
}

async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.flush_cache().await
}

async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.get_worker_loads().await
}

/// Configuration for the router HTTP server and its supporting subsystems
/// (logging, metrics, service discovery).
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    pub router_config: RouterConfig,
    pub max_payload_size: usize,
    pub log_dir: Option<String>,
    pub log_level: Option<String>,
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    pub prometheus_config: Option<PrometheusConfig>,
    pub request_timeout_secs: u64,
    pub request_id_headers: Option<Vec<String>>,
}

/// Build the Axum application with all routes and middleware
pub fn build_app(
    app_state: Arc<AppState>,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Create routes
    let protected_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions));

    let public_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    let admin_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads));

    // Build app with all routes and middleware
    Router::new()
        .merge(protected_routes)
        .merge(public_routes)
        .merge(admin_routes)
        // Request body size limiting
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        // Request ID layer - must be added AFTER logging layer in the code
        // so it executes BEFORE logging layer at runtime (layers execute bottom-up)
        .layer(crate::middleware::RequestIdLayer::new(request_id_headers))
        // Custom logging layer that can now see request IDs from extensions
        .layer(crate::middleware::create_logging_layer())
        // CORS (should be outermost)
        .layer(create_cors_layer(cors_allowed_origins))
        // Fallback
        .fallback(sink_handler)
        // State - apply last to get Router<Arc<AppState>>
        .with_state(app_state)
}
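
// Illustrative test sketch (assumes a `tower` dev-dependency for `ServiceExt::oneshot`
// and some way to obtain an `Arc<AppState>`; neither is defined in this file):
//
//     use tower::ServiceExt;
//     let app = build_app(app_state, 32 * 1024 * 1024, vec!["x-request-id".into()], vec![]);
//     let res = app
//         .oneshot(Request::builder().uri("/liveness").body(axum::body::Body::empty())?)
//         .await?;
//     // The expected status depends on the router implementation behind `AppState`.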

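/// Initialize logging and metrics, construct the router and Axum application,
/// and serve it until a shutdown signal (Ctrl+C or SIGTERM) is received.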
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };

    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
        metrics::start_prometheus(prometheus_config);
    }

    info!(
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
        config.max_payload_size / (1024 * 1024)
    );

    let client = Client::builder()
        .pool_idle_timeout(Some(Duration::from_secs(50)))
        .pool_max_idle_per_host(500) // Allow up to 500 idle connections per host
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10)) // Separate connection timeout
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
        .build()
        .expect("Failed to create HTTP client");

    // Create the application context with all dependencies
    let app_context = Arc::new(AppContext::new(
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
    ));

    // Create router with the context
    let router = RouterFactory::create_router(&app_context).await?;

    // Create app state with router and context
    let app_state = Arc::new(AppState {
        router: Arc::from(router),
        context: app_context.clone(),
    });
    let router_arc = Arc::clone(&app_state.router);

    // Start service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
            match start_service_discovery(service_discovery_config, router_arc).await {
                Ok(handle) => {
                    info!("Service discovery started");
                    // Spawn a task to monitor the service discovery handle
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

    info!(
        "Router ready | workers: {:?}",
        app_state.router.get_worker_urls()
    );

    // Configure request ID headers
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    // Build the application
    let app = build_app(
        app_state,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );

    // Create TCP listener - use the configured host
    let addr = format!("{}:{}", config.host, config.port);
    let listener = TcpListener::bind(&addr).await?;

    info!("Starting server on {}", addr);

    // Serve the application with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}
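
// Illustrative entry-point sketch (field values are placeholders and `router_config`
// must come from this crate's configuration layer, which is not shown here):
//
//     #[tokio::main]
//     async fn main() -> Result<(), Box<dyn std::error::Error>> {
//         let config = ServerConfig {
//             host: "0.0.0.0".to_string(),
//             port: 30000,
//             router_config,
//             max_payload_size: 32 * 1024 * 1024,
//             log_dir: None,
//             log_level: Some("info".to_string()),
//             service_discovery_config: None,
//             prometheus_config: None,
//             request_timeout_secs: 600,
//             request_id_headers: None,
//         };
//         startup(config).await
//     }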

// Graceful shutdown handler
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = terminate => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        // Allow all origins if none specified
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        // Restrict to specific origins
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
}