server.rs 13 KB
Newer Older
1
use crate::config::RouterConfig;
2
use crate::logging::{self, LoggingConfig};
3
use crate::metrics::{self, PrometheusConfig};
4
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
5
use crate::routers::{RouterFactory, RouterTrait};
6
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
7
8
9
10
11
12
use axum::{
    extract::{Query, Request, State},
    http::StatusCode,
    response::{IntoResponse, Response},
    routing::{get, post},
    Json, Router,
13
};
14
use reqwest::Client;
15
use std::collections::HashMap;
16
use std::sync::atomic::{AtomicBool, Ordering};
17
use std::sync::Arc;
18
use std::time::Duration;
19
20
use tokio::net::TcpListener;
use tokio::signal;
21
22
use tokio::spawn;
use tracing::{error, info, warn, Level};
23

24
#[derive(Clone)]
pub struct AppState {
    /// Router implementation created by `RouterFactory` from the `RouterConfig`.
    pub router: Arc<dyn RouterTrait>,
    /// Shared HTTP client handed to the router for proxying to workers.
    pub client: Client,
    /// Semaphore sized to `max_concurrent_requests`. Not read anywhere in this
    /// file (hence the leading underscore) — presumably reserved for future
    /// request throttling; confirm before removing.
    pub _concurrency_limiter: Arc<tokio::sync::Semaphore>,
}

31
impl AppState {
32
33
34
35
36
    pub fn new(
        router_config: RouterConfig,
        client: Client,
        max_concurrent_requests: usize,
    ) -> Result<Self, String> {
37
38
        let router = RouterFactory::create_router(&router_config)?;
        let router = Arc::from(router);
39
40
41
42
43
44
        let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
        Ok(Self {
            router,
            client,
            _concurrency_limiter: concurrency_limiter,
        })
45
46
47
    }
}

48
49
50
// Fallback handler for unmatched routes
async fn sink_handler() -> Response {
    StatusCode::NOT_FOUND.into_response()
51
52
}

53
54
55
// Health check endpoints
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
    state.router.liveness()
56
57
}

58
59
async fn readiness(State(state): State<Arc<AppState>>) -> Response {
    state.router.readiness()
60
61
}

62
63
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health(&state.client, req).await
64
65
}

66
67
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.health_generate(&state.client, req).await
68
69
}

70
71
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_server_info(&state.client, req).await
72
73
}

74
75
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_models(&state.client, req).await
76
77
}

78
79
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
    state.router.get_model_info(&state.client, req).await
80
}
81

82
83
// Generation endpoints
// The RouterTrait now accepts optional headers and typed body directly
84
async fn generate(
85
86
87
88
89
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<GenerateRequest>,
) -> Response {
    state
90
        .router
91
92
        .route_generate(&state.client, Some(&headers), &body)
        .await
93
94
95
}

async fn v1_chat_completions(
96
97
98
99
100
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<ChatCompletionRequest>,
) -> Response {
    state
101
        .router
102
103
        .route_chat(&state.client, Some(&headers), &body)
        .await
104
105
106
}

async fn v1_completions(
107
108
109
110
111
    State(state): State<Arc<AppState>>,
    headers: http::HeaderMap,
    Json(body): Json<CompletionRequest>,
) -> Response {
    state
112
        .router
113
114
        .route_completion(&state.client, Some(&headers), &body)
        .await
115
116
}

117
// Worker management endpoints
118
async fn add_worker(
119
120
121
122
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
123
124
        Some(url) => url.to_string(),
        None => {
125
126
127
128
129
            return (
                StatusCode::BAD_REQUEST,
                "Worker URL required. Provide 'url' query parameter",
            )
                .into_response();
130
131
        }
    };
132

133
134
135
    match state.router.add_worker(&worker_url).await {
        Ok(message) => (StatusCode::OK, message).into_response(),
        Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
136
    }
137
138
}

139
140
141
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
    let worker_list = state.router.get_worker_urls();
    Json(serde_json::json!({ "urls": worker_list })).into_response()
142
143
}

144
async fn remove_worker(
145
146
147
148
    State(state): State<Arc<AppState>>,
    Query(params): Query<HashMap<String, String>>,
) -> Response {
    let worker_url = match params.get("url") {
149
        Some(url) => url.to_string(),
150
        None => return StatusCode::BAD_REQUEST.into_response(),
151
    };
152

153
154
155
156
157
158
    state.router.remove_worker(&worker_url);
    (
        StatusCode::OK,
        format!("Successfully removed worker: {}", worker_url),
    )
        .into_response()
159
160
}

161
162
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.flush_cache(&state.client).await
163
164
}

165
166
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
    state.router.get_worker_loads(&state.client).await
167
168
}

169
170
171
/// Everything `startup` needs to bring the router server up.
pub struct ServerConfig {
    /// Interface to bind, combined with `port` into the listen address.
    pub host: String,
    pub port: u16,
    /// Configuration passed to `RouterFactory`; also supplies CORS origins
    /// and the concurrency limit.
    pub router_config: RouterConfig,
    /// Maximum request body size in bytes (enforced by the body-limit layer).
    pub max_payload_size: usize,
    /// Optional directory for log files; `None` presumably means stdout only —
    /// confirm against `logging::init_logging`.
    pub log_dir: Option<String>,
    /// Log level name, parsed case-insensitively; invalid values fall back to INFO.
    pub log_level: Option<String>,
    /// When present and `enabled`, Kubernetes-style service discovery is started.
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
    /// When present, a Prometheus metrics exporter is started.
    pub prometheus_config: Option<PrometheusConfig>,
    /// Overall per-request timeout for the shared HTTP client, in seconds.
    pub request_timeout_secs: u64,
    /// Header names to read request IDs from; defaults to the common set
    /// (x-request-id, x-correlation-id, ...) when `None`.
    pub request_id_headers: Option<Vec<String>>,
}

182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
/// Build the Axum application with all routes and middleware.
///
/// Middleware order is significant and preserved exactly: body limit,
/// request-ID, logging, then CORS outermost.
pub fn build_app(
    app_state: Arc<AppState>,
    max_payload_size: usize,
    request_id_headers: Vec<String>,
    cors_allowed_origins: Vec<String>,
) -> Router {
    // Inference endpoints.
    let generation_routes = Router::new()
        .route("/generate", post(generate))
        .route("/v1/chat/completions", post(v1_chat_completions))
        .route("/v1/completions", post(v1_completions));

    // Health and introspection endpoints.
    let info_routes = Router::new()
        .route("/liveness", get(liveness))
        .route("/readiness", get(readiness))
        .route("/health", get(health))
        .route("/health_generate", get(health_generate))
        .route("/v1/models", get(v1_models))
        .route("/get_model_info", get(get_model_info))
        .route("/get_server_info", get(get_server_info));

    // Worker administration endpoints.
    let worker_routes = Router::new()
        .route("/add_worker", post(add_worker))
        .route("/remove_worker", post(remove_worker))
        .route("/list_workers", get(list_workers))
        .route("/flush_cache", post(flush_cache))
        .route("/get_loads", get(get_loads));

    Router::new()
        .merge(generation_routes)
        .merge(info_routes)
        .merge(worker_routes)
        // Cap request body size.
        .layer(tower_http::limit::RequestBodyLimitLayer::new(
            max_payload_size,
        ))
        // NOTE(review): the original comment claimed the request-ID layer is
        // added after the logging layer so it runs first at runtime, but it is
        // listed before it here — confirm the intended layer ordering.
        .layer(crate::middleware::RequestIdLayer::new(request_id_headers))
        // Custom logging layer that reads request IDs from extensions.
        .layer(crate::middleware::create_logging_layer())
        // CORS sits outermost.
        .layer(create_cors_layer(cors_allowed_origins))
        // Anything unmatched falls through to a 404.
        .fallback(sink_handler)
        // Apply state last to obtain Router<Arc<AppState>>.
        .with_state(app_state)
}

/// Bring the router server up and serve until a shutdown signal arrives.
///
/// Sequence: initialize logging (once per process), start the optional
/// Prometheus exporter, build the shared HTTP client and `AppState`, start
/// optional service discovery, then bind the listener and serve with
/// graceful shutdown. Returns any bind/serve or state-construction error.
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    // swap() makes the first caller the initializer; later calls get None and
    // keep whatever logging was already installed.
    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            // Parse the configured level case-insensitively; anything invalid
            // logs a warning and falls back to INFO.
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };

    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
        metrics::start_prometheus(prometheus_config);
    }

    info!(
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
        config.max_payload_size / (1024 * 1024)
    );

    // Shared outbound HTTP client; also cloned into AppState below.
    let client = Client::builder()
        .pool_idle_timeout(Some(Duration::from_secs(50)))
        .pool_max_idle_per_host(100) // Increase from default of 1 to allow more concurrent connections
        .timeout(Duration::from_secs(config.request_timeout_secs))
        .connect_timeout(Duration::from_secs(10)) // Separate connection timeout
        .tcp_nodelay(true)
        .tcp_keepalive(Some(Duration::from_secs(30))) // Keep connections alive
        .build()
        .expect("Failed to create HTTP client");

    let app_state = Arc::new(AppState::new(
        config.router_config.clone(),
        client.clone(),
        config.router_config.max_concurrent_requests,
    )?);
    // Separate Arc handle to the router for the service-discovery task.
    let router_arc = Arc::clone(&app_state.router);

    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
            match start_service_discovery(service_discovery_config, router_arc).await {
                Ok(handle) => {
                    info!("Service discovery started");
                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                // Discovery failure is non-fatal: log and keep serving.
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

    info!(
        "Router ready | workers: {:?}",
        app_state.router.get_worker_urls()
    );

    // Configure request ID headers
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    // Build the application
    let app = build_app(
        app_state,
        config.max_payload_size,
        request_id_headers,
        config.router_config.cors_allowed_origins.clone(),
    );

    // Create TCP listener - use the configured host
    let addr = format!("{}:{}", config.host, config.port);
    let listener = TcpListener::bind(&addr).await?;

    // Start server with graceful shutdown
    info!("Starting server on {}", addr);

    // Serve the application with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;

    Ok(())
}

// Graceful shutdown handler
/// Resolves when the process receives Ctrl+C, or SIGTERM on Unix.
/// On non-Unix targets only Ctrl+C is watched; the terminate arm never fires.
async fn shutdown_signal() {
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    // Whichever signal arrives first wins.
    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C, starting graceful shutdown");
        },
        _ = terminate => {
            info!("Received terminate signal, starting graceful shutdown");
        },
    }
}

// CORS Layer Creation
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
    use tower_http::cors::Any;

    let cors = if allowed_origins.is_empty() {
        // Allow all origins if none specified
        tower_http::cors::CorsLayer::new()
            .allow_origin(Any)
            .allow_methods(Any)
            .allow_headers(Any)
            .expose_headers(Any)
    } else {
        // Restrict to specific origins
        let origins: Vec<http::HeaderValue> = allowed_origins
            .into_iter()
            .filter_map(|origin| origin.parse().ok())
            .collect();

        tower_http::cors::CorsLayer::new()
            .allow_origin(origins)
            .allow_methods([http::Method::GET, http::Method::POST, http::Method::OPTIONS])
            .allow_headers([http::header::CONTENT_TYPE, http::header::AUTHORIZATION])
            .expose_headers([http::header::HeaderName::from_static("x-request-id")])
    };

    cors.max_age(Duration::from_secs(3600))
406
}