//! server.rs — actix-web HTTP front end for the router (handlers, shared state, and `startup`).
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::metrics::{self, PrometheusConfig};
use crate::middleware::{get_request_id, RequestIdMiddleware};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{
    error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::spawn;
use tracing::{error, info, warn, Level};

#[derive(Debug)]
pub struct AppState {
22
    router: Arc<dyn RouterTrait>,
23
    client: Client,
24
25
}

26
impl AppState {
27
28
29
30
31
32
33
34
    pub fn new(router_config: RouterConfig, client: Client) -> Result<Self, String> {
        // Use RouterFactory to create the appropriate router type
        let router = RouterFactory::create_router(&router_config)?;

        // Convert Box<dyn RouterTrait> to Arc<dyn RouterTrait>
        let router = Arc::from(router);

        Ok(Self { router, client })
35
36
37
    }
}

38
39
40
41
42
43
44
45
46
47
48
49
async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<HttpResponse, Error> {
    // Drain the payload
    while let Some(chunk) = payload.next().await {
        if let Err(err) = chunk {
            println!("Error while draining payload: {:?}", err);
            break;
        }
    }
    Ok(HttpResponse::NotFound().finish())
}

// Custom error handler for JSON payload errors.
50
51
fn json_error_handler(err: error::JsonPayloadError, req: &HttpRequest) -> Error {
    let request_id = get_request_id(req);
52
53
54
    match &err {
        error::JsonPayloadError::OverflowKnownLength { length, limit } => {
            error!(
55
56
                request_id = %request_id,
                "Payload too large length={} limit={}", length, limit
57
58
59
60
61
62
63
            );
            error::ErrorPayloadTooLarge(format!(
                "Payload too large: {} bytes exceeds limit of {} bytes",
                length, limit
            ))
        }
        error::JsonPayloadError::Overflow { limit } => {
64
65
66
67
            error!(
                request_id = %request_id,
                "Payload overflow limit={}", limit
            );
68
69
            error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
        }
70
71
72
73
74
75
76
        _ => {
            error!(
                request_id = %request_id,
                "Invalid JSON payload error={}", err
            );
            error::ErrorBadRequest(format!("Invalid JSON payload: {}", err))
        }
77
    }
78
79
}

80
81
82
83
84
85
86
87
88
89
#[get("/liveness")]
async fn liveness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router.liveness()
}

#[get("/readiness")]
async fn readiness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router.readiness()
}

90
#[get("/health")]
91
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
92
    data.router.health(&data.client, &req).await
93
94
95
}

#[get("/health_generate")]
96
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
97
    data.router.health_generate(&data.client, &req).await
98
99
}

100
#[get("/get_server_info")]
101
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
102
    data.router.get_server_info(&data.client, &req).await
103
104
}

105
#[get("/v1/models")]
106
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
107
    data.router.get_models(&data.client, &req).await
108
109
}

110
#[get("/get_model_info")]
111
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
112
    data.router.get_model_info(&data.client, &req).await
113
}
114

115
#[post("/generate")]
116
117
118
119
120
async fn generate(
    req: HttpRequest,
    body: web::Json<GenerateRequest>,
    state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
121
122
123
124
125
126
127
128
129
130
131
132
133
134
    let request_id = get_request_id(&req);
    info!(
        request_id = %request_id,
        "Received generate request method=\"POST\" path=\"/generate\""
    );

    let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
        error!(
            request_id = %request_id,
            "Failed to parse generate request body error={}", e
        );
        error::ErrorBadRequest(format!("Invalid JSON: {}", e))
    })?;

135
136
137
138
    Ok(state
        .router
        .route_generate(&state.client, &req, json_body)
        .await)
139
140
141
142
143
}

#[post("/v1/chat/completions")]
async fn v1_chat_completions(
    req: HttpRequest,
144
145
146
    body: web::Json<ChatCompletionRequest>,
    state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
147
148
149
150
151
152
153
154
155
156
157
158
159
160
    let request_id = get_request_id(&req);
    info!(
        request_id = %request_id,
        "Received chat completion request method=\"POST\" path=\"/v1/chat/completions\""
    );

    let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
        error!(
            request_id = %request_id,
            "Failed to parse chat completion request body error={}", e
        );
        error::ErrorBadRequest(format!("Invalid JSON: {}", e))
    })?;

161
162
163
164
    Ok(state
        .router
        .route_chat(&state.client, &req, json_body)
        .await)
165
166
167
168
169
}

#[post("/v1/completions")]
async fn v1_completions(
    req: HttpRequest,
170
171
172
    body: web::Json<CompletionRequest>,
    state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
173
174
175
176
177
178
179
180
181
182
183
184
185
186
    let request_id = get_request_id(&req);
    info!(
        request_id = %request_id,
        "Received completion request method=\"POST\" path=\"/v1/completions\""
    );

    let json_body = serde_json::to_value(body.into_inner()).map_err(|e| {
        error!(
            request_id = %request_id,
            "Failed to parse completion request body error={}", e
        );
        error::ErrorBadRequest(format!("Invalid JSON: {}", e))
    })?;

187
188
189
190
    Ok(state
        .router
        .route_completion(&state.client, &req, json_body)
        .await)
191
192
}

193
194
#[post("/add_worker")]
async fn add_worker(
195
    req: HttpRequest,
196
197
198
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
199
200
    let request_id = get_request_id(&req);

201
202
203
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
204
205
206
207
            warn!(
                request_id = %request_id,
                "Add worker request missing URL parameter"
            );
208
            return HttpResponse::BadRequest()
209
                .body("Worker URL required. Provide 'url' query parameter");
210
211
        }
    };
212

213
214
215
216
217
218
    info!(
        request_id = %request_id,
        worker_url = %worker_url,
        "Adding worker"
    );

219
    match data.router.add_worker(&worker_url).await {
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
        Ok(message) => {
            info!(
                request_id = %request_id,
                worker_url = %worker_url,
                "Successfully added worker"
            );
            HttpResponse::Ok().body(message)
        }
        Err(error) => {
            error!(
                request_id = %request_id,
                worker_url = %worker_url,
                error = %error,
                "Failed to add worker"
            );
            HttpResponse::BadRequest().body(error)
        }
237
    }
238
239
}

240
241
#[get("/list_workers")]
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
242
    let worker_list = data.router.get_worker_urls();
243
244
245
    HttpResponse::Ok().json(serde_json::json!({ "urls": worker_list }))
}

246
247
#[post("/remove_worker")]
async fn remove_worker(
248
    req: HttpRequest,
249
250
251
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
252
253
    let request_id = get_request_id(&req);

254
255
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
256
257
258
259
260
261
262
        None => {
            warn!(
                request_id = %request_id,
                "Remove worker request missing URL parameter"
            );
            return HttpResponse::BadRequest().finish();
        }
263
    };
264
265
266
267
268
269
270

    info!(
        request_id = %request_id,
        worker_url = %worker_url,
        "Removing worker"
    );

271
    data.router.remove_worker(&worker_url);
272
    HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
273
274
}

275
#[post("/flush_cache")]
276
277
async fn flush_cache(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router.flush_cache(&data.client).await
278
279
280
}

#[get("/get_loads")]
281
282
async fn get_loads(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router.get_worker_loads(&data.client).await
283
284
}

285
286
287
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
288
    pub router_config: RouterConfig,
289
    pub max_payload_size: usize,
290
    pub log_dir: Option<String>,
291
    pub log_level: Option<String>,
292
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
293
    pub prometheus_config: Option<PrometheusConfig>,
294
    pub request_timeout_secs: u64,
295
    pub request_id_headers: Option<Vec<String>>,
296
297
298
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
299
300
301
302
303
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
304
305
306
307
308
309
310
311
312
313
314
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
315
316
317
318
319
320
321
322
323
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
324

325
326
    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
327
        metrics::start_prometheus(prometheus_config);
328
329
    }

330
    info!(
331
332
333
334
335
        "Starting router on {}:{} | mode: {:?} | policy: {:?} | max_payload: {}MB",
        config.host,
        config.port,
        config.router_config.mode,
        config.router_config.policy,
336
337
338
        config.max_payload_size / (1024 * 1024)
    );

339
    let client = Client::builder()
340
        .pool_idle_timeout(Some(Duration::from_secs(50)))
341
        .timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
342
343
344
        .build()
        .expect("Failed to create HTTP client");

345
346
    let app_state_init = AppState::new(config.router_config.clone(), client.clone())
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
347
348
    let router_arc = Arc::clone(&app_state_init.router);
    let app_state = web::Data::new(app_state_init);
349

350
351
352
    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
353
            match start_service_discovery(service_discovery_config, router_arc).await {
354
                Ok(handle) => {
355
                    info!("Service discovery started");
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

371
    info!(
372
        "Router ready | workers: {:?}",
373
        app_state.router.get_worker_urls()
374
    );
375

376
377
378
379
380
381
382
383
384
385
    // Configure request ID headers
    let request_id_headers = config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

386
    HttpServer::new(move || {
387
388
        let request_id_middleware = RequestIdMiddleware::new(request_id_headers.clone());

389
        App::new()
390
            .wrap(request_id_middleware)
391
            .app_data(app_state.clone())
392
393
394
395
396
            .app_data(
                web::JsonConfig::default()
                    .limit(config.max_payload_size)
                    .error_handler(json_error_handler),
            )
397
            .app_data(web::PayloadConfig::default().limit(config.max_payload_size))
398
            .service(generate)
399
400
401
            .service(v1_chat_completions)
            .service(v1_completions)
            .service(v1_models)
402
            .service(get_model_info)
403
404
            .service(liveness)
            .service(readiness)
405
406
            .service(health)
            .service(health_generate)
407
            .service(get_server_info)
408
            .service(add_worker)
409
            .service(remove_worker)
410
            .service(list_workers)
411
412
            .service(flush_cache)
            .service(get_loads)
413
            .default_service(web::route().to(sink_handler))
414
    })
415
    .bind_auto_h2c((config.host, config.port))?
416
417
    .run()
    .await
418
}