server.rs 13.7 KB
Newer Older
1
use crate::logging::{self, LoggingConfig};
2
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
3
use crate::prometheus::{self, PrometheusConfig};
4
use crate::request_adapter::ToPdRequest;
5
use crate::router::PolicyConfig;
6
use crate::router::Router;
7
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
8
9
10
11
use actix_web::{
    error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
};
use futures_util::StreamExt;
12
use reqwest::Client;
13
use std::collections::HashMap;
14
use std::sync::atomic::{AtomicBool, Ordering};
15
use std::sync::Arc;
16
use std::time::Duration;
17
18
use tokio::spawn;
use tracing::{error, info, warn, Level};
19
20
21

#[derive(Debug)]
pub struct AppState {
22
    router: Arc<Router>,
23
    client: Client,
24
    is_pd_mode: bool, // Add flag to track PD mode
25
26
}

27
impl AppState {
28
29
    pub fn new(
        worker_urls: Vec<String>,
30
        client: Client,
31
        policy_config: PolicyConfig,
32
    ) -> Result<Self, String> {
33
34
35
        // Check if this is PD mode from policy config
        let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });

36
        // Create router based on policy
37
        let router = Arc::new(Router::new(worker_urls, policy_config)?);
38
39
40
41
42
        Ok(Self {
            router,
            client,
            is_pd_mode,
        })
43
44
45
    }
}

46
47
48
49
50
51
52
53
54
55
56
57
async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<HttpResponse, Error> {
    // Drain the payload
    while let Some(chunk) = payload.next().await {
        if let Err(err) = chunk {
            println!("Error while draining payload: {:?}", err);
            break;
        }
    }
    Ok(HttpResponse::NotFound().finish())
}

// Custom error handler for JSON payload errors.
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
    error!("JSON payload error: {:?}", err);
    match &err {
        error::JsonPayloadError::OverflowKnownLength { length, limit } => {
            error!(
                "Payload too large: {} bytes exceeds limit of {} bytes",
                length, limit
            );
            error::ErrorPayloadTooLarge(format!(
                "Payload too large: {} bytes exceeds limit of {} bytes",
                length, limit
            ))
        }
        error::JsonPayloadError::Overflow { limit } => {
            error!("Payload overflow: exceeds limit of {} bytes", limit);
            error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
        }
        _ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
    }
77
78
}

79
#[get("/health")]
80
81
82
83
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_to_first(&data.client, "/health", &req)
        .await
84
85
86
}

#[get("/health_generate")]
87
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
88
89
90
91
92
93
94
95
96
97
98
99
    // Check if we're in PD mode
    if data.is_pd_mode {
        // For PD mode, check health on all servers
        data.router
            .route_pd_health_generate(&data.client, &req)
            .await
    } else {
        // Regular mode
        data.router
            .route_to_first(&data.client, "/health_generate", &req)
            .await
    }
100
101
}

102
#[get("/get_server_info")]
103
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
104
105
106
107
108
109
110
111
112
    if data.is_pd_mode {
        // For PD mode, aggregate info from both prefill and decode servers
        data.router.get_pd_server_info(&data.client, &req).await
    } else {
        // Regular mode - return first server's info
        data.router
            .route_to_first(&data.client, "/get_server_info", &req)
            .await
    }
113
114
}

115
#[get("/v1/models")]
116
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
117
118
119
120
121
122
123
124
125
    if data.is_pd_mode {
        // For PD mode, return models from the first prefill server
        data.router.get_pd_models(&data.client, &req).await
    } else {
        // Regular mode
        data.router
            .route_to_first(&data.client, "/v1/models", &req)
            .await
    }
126
127
}

128
#[get("/get_model_info")]
129
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
130
131
132
133
134
135
136
137
    if data.is_pd_mode {
        // For PD mode, get model info from the first prefill server
        data.router.get_pd_model_info(&data.client, &req).await
    } else {
        data.router
            .route_to_first(&data.client, "/get_model_info", &req)
            .await
    }
138
}
139

140
#[post("/generate")]
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
async fn generate(
    req: HttpRequest,
    body: web::Json<GenerateRequest>,
    state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
    let client = &state.client;
    let router = &state.router;

    // Use typed request directly for both PD and regular routing
    if state.is_pd_mode {
        // For PD mode, convert to PD request with bootstrap
        let pd_request = body.into_inner().to_pd_request();

        Ok(router
            .route_pd_generate_typed(&client, &req, pd_request, "/generate")
            .await)
    } else {
        // For regular mode, use typed request directly
        let request = body.into_inner();
        Ok(router
            .route_typed_request(&client, &req, &request, "/generate")
            .await)
    }
164
165
166
167
168
}

#[post("/v1/chat/completions")]
async fn v1_chat_completions(
    req: HttpRequest,
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    body: web::Json<ChatCompletionRequest>,
    state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
    let client = &state.client;
    let router = &state.router;

    // Use typed request directly for both PD and regular routing
    if state.is_pd_mode {
        // For PD mode, convert to PD request with bootstrap
        let pd_request = body.into_inner().to_pd_request();

        Ok(router
            .route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
            .await)
    } else {
        // For regular mode, use typed request directly
        let request = body.into_inner();
        Ok(router
            .route_typed_request(&client, &req, &request, "/v1/chat/completions")
            .await)
    }
190
191
192
193
194
}

#[post("/v1/completions")]
async fn v1_completions(
    req: HttpRequest,
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
    body: web::Json<CompletionRequest>,
    state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
    let client = &state.client;
    let router = &state.router;

    // Use typed request directly for both PD and regular routing
    if state.is_pd_mode {
        // For PD mode, convert to PD request with bootstrap
        let pd_request = body.into_inner().to_pd_request();

        Ok(router
            .route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
            .await)
    } else {
        // For regular mode, use typed request directly
        let request = body.into_inner();
        Ok(router
            .route_typed_request(&client, &req, &request, "/v1/completions")
            .await)
    }
216
217
}

218
219
220
221
222
223
224
225
226
227
228
229
/// Registers a new worker with the router.
///
/// The worker address comes from the `url` query parameter. Responds `400` if
/// the parameter is missing or the router rejects the worker. On success it
/// responds `200` with the router's message.
#[post("/add_worker")]
async fn add_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = if let Some(url) = query.get("url") {
        url.to_string()
    } else {
        return HttpResponse::BadRequest()
            .body("Worker URL required. Provide 'url' query parameter");
    };

    data.router.add_worker(&worker_url).await.map_or_else(
        |reason| HttpResponse::BadRequest().body(reason),
        |message| HttpResponse::Ok().body(message),
    )
}

237
238
/// Lists the worker URLs currently known to the router as `{"urls": [...]}`.
#[get("/list_workers")]
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
    let payload = serde_json::json!({ "urls": data.router.get_worker_urls() });
    HttpResponse::Ok().json(payload)
}

243
244
245
246
247
248
249
250
251
/// Asks the router to remove a worker.
///
/// The worker address comes from the `url` query parameter. A missing
/// parameter yields `400` with the same explanatory message `/add_worker`
/// uses; otherwise the router is called and `200` is returned.
///
/// NOTE(review): any value returned by `Router::remove_worker` is not
/// inspected, so success is reported even for a URL that was never
/// registered. Confirm this is intended.
#[post("/remove_worker")]
async fn remove_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
            // Previously an empty 400; tell the caller what is missing, as
            // `/add_worker` does.
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter");
        }
    };

    data.router.remove_worker(&worker_url);
    HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
}

256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
#[post("/flush_cache")]
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    if data.is_pd_mode {
        // For PD mode, flush cache on both prefill and decode servers
        data.router.route_pd_flush_cache(&data.client).await
    } else {
        // Route to all workers for cache flushing
        data.router
            .route_to_all(&data.client, "/flush_cache", &req)
            .await
    }
}

/// Collects load figures from all workers through the router.
#[get("/get_loads")]
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    let router = &data.router;
    router.get_all_loads(&data.client, &req).await
}

275
276
277
278
279
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    pub worker_urls: Vec<String>,
    pub policy_config: PolicyConfig,
280
    pub max_payload_size: usize,
281
    pub log_dir: Option<String>,
282
    pub log_level: Option<String>,
283
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
284
    pub prometheus_config: Option<PrometheusConfig>,
285
    pub request_timeout_secs: u64,
286
287
288
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
289
290
291
292
293
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
294
295
296
297
298
299
300
301
302
303
304
            level: config
                .log_level
                .as_deref()
                .and_then(|s| match s.to_uppercase().parse::<Level>() {
                    Ok(l) => Some(l),
                    Err(_) => {
                        warn!("Invalid log level string: '{}'. Defaulting to INFO.", s);
                        None
                    }
                })
                .unwrap_or(Level::INFO),
305
306
307
308
309
310
311
312
313
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
314

315
316
317
318
319
320
321
322
323
324
325
    // Initialize prometheus metrics exporter
    if let Some(prometheus_config) = config.prometheus_config {
        info!(
            "🚧 Initializing Prometheus metrics on {}:{}",
            prometheus_config.host, prometheus_config.port
        );
        prometheus::start_prometheus(prometheus_config);
    } else {
        info!("🚧 Prometheus metrics disabled");
    }

326
327
328
329
330
331
332
333
    info!("🚧 Initializing router on {}:{}", config.host, config.port);
    info!("🚧 Initializing workers on {:?}", config.worker_urls);
    info!("🚧 Policy Config: {:?}", config.policy_config);
    info!(
        "🚧 Max payload size: {} MB",
        config.max_payload_size / (1024 * 1024)
    );

334
335
336
337
338
339
340
341
342
    // Log service discovery status
    if let Some(service_discovery_config) = &config.service_discovery_config {
        info!("🚧 Service discovery enabled");
        info!("🚧 Selector: {:?}", service_discovery_config.selector);
    } else {
        info!("🚧 Service discovery disabled");
    }

    let client = Client::builder()
343
        .pool_idle_timeout(Some(Duration::from_secs(50)))
344
        .timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
345
346
347
        .build()
        .expect("Failed to create HTTP client");

348
349
350
351
352
353
354
355
    let app_state_init = AppState::new(
        config.worker_urls.clone(),
        client.clone(),
        config.policy_config.clone(),
    )
    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
    let router_arc = Arc::clone(&app_state_init.router);
    let app_state = web::Data::new(app_state_init);
356

357
358
359
    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
360
361
362
            info!("🚧 Initializing Kubernetes service discovery");
            // Pass the Arc<Router> directly
            match start_service_discovery(service_discovery_config, router_arc).await {
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
                Ok(handle) => {
                    info!("✅ Service discovery started successfully");
                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

380
    info!("✅ Serving router on {}:{}", config.host, config.port);
381
382
    info!(
        "✅ Serving workers on {:?}",
383
        app_state.router.get_worker_urls()
384
    );
385

386
387
388
    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
389
390
391
392
393
            .app_data(
                web::JsonConfig::default()
                    .limit(config.max_payload_size)
                    .error_handler(json_error_handler),
            )
394
            .app_data(web::PayloadConfig::default().limit(config.max_payload_size))
395
            .service(generate)
396
397
398
            .service(v1_chat_completions)
            .service(v1_completions)
            .service(v1_models)
399
            .service(get_model_info)
400
401
            .service(health)
            .service(health_generate)
402
            .service(get_server_info)
403
            .service(add_worker)
404
            .service(remove_worker)
405
            .service(list_workers)
406
407
            .service(flush_cache)
            .service(get_loads)
408
            .default_service(web::route().to(sink_handler))
409
    })
410
    .bind_auto_h2c((config.host, config.port))?
411
412
    .run()
    .await
413
}