server.rs 8.45 KB
Newer Older
1
use crate::logging::{self, LoggingConfig};
2
use crate::router::PolicyConfig;
3
use crate::router::Router;
4
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
5
6
7
use actix_web::{
    error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
};
8
use bytes::Bytes;
9
use futures_util::StreamExt;
10
use reqwest::Client;
11
use std::collections::HashMap;
12
use std::sync::atomic::{AtomicBool, Ordering};
13
use std::sync::Arc;
14
use std::time::Duration;
15
16
use tokio::spawn;
use tracing::{error, info, warn, Level};
17
18
19

#[derive(Debug)]
pub struct AppState {
20
    router: Router,
21
    client: Client,
22
23
}

24
impl AppState {
25
26
    pub fn new(
        worker_urls: Vec<String>,
27
        client: Client,
28
        policy_config: PolicyConfig,
29
    ) -> Result<Self, String> {
30
        // Create router based on policy
31
32
        let router = Router::new(worker_urls, policy_config)?;
        Ok(Self { router, client })
33
34
35
    }
}

36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<HttpResponse, Error> {
    // Drain the payload
    while let Some(chunk) = payload.next().await {
        if let Err(err) = chunk {
            println!("Error while draining payload: {:?}", err);
            break;
        }
    }
    Ok(HttpResponse::NotFound().finish())
}

// Custom error handler for JSON payload errors.
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
    error::ErrorPayloadTooLarge("Payload too large")
}

52
#[get("/health")]
53
54
55
56
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_to_first(&data.client, "/health", &req)
        .await
57
58
59
}

#[get("/health_generate")]
60
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
61
    data.router
62
        .route_to_first(&data.client, "/health_generate", &req)
63
        .await
64
65
}

66
#[get("/get_server_info")]
67
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
68
    data.router
69
        .route_to_first(&data.client, "/get_server_info", &req)
70
        .await
71
72
}

73
#[get("/v1/models")]
74
75
76
77
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_to_first(&data.client, "/v1/models", &req)
        .await
78
79
}

80
#[get("/get_model_info")]
81
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
82
    data.router
83
        .route_to_first(&data.client, "/get_model_info", &req)
84
        .await
85
}
86

87
88
#[post("/generate")]
/// Hands `POST /generate` (request plus raw body) to
/// [`Router::route_generate_request`] and returns its response.
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
    let state: &AppState = &data;
    state
        .router
        .route_generate_request(&state.client, &req, &body, "/generate")
        .await
}

#[post("/v1/chat/completions")]
/// Hands `POST /v1/chat/completions` (request plus raw body) to
/// [`Router::route_generate_request`] and returns its response.
async fn v1_chat_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    let state: &AppState = &data;
    state
        .router
        .route_generate_request(&state.client, &req, &body, "/v1/chat/completions")
        .await
}

#[post("/v1/completions")]
/// Hands `POST /v1/completions` (request plus raw body) to
/// [`Router::route_generate_request`] and returns its response.
async fn v1_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    let state: &AppState = &data;
    state
        .router
        .route_generate_request(&state.client, &req, &body, "/v1/completions")
        .await
}

116
117
118
119
120
121
122
123
124
125
126
127
#[post("/add_worker")]
async fn add_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter")
        }
    };
128

129
    match data.router.add_worker(&worker_url).await {
130
131
132
        Ok(message) => HttpResponse::Ok().body(message),
        Err(error) => HttpResponse::BadRequest().body(error),
    }
133
134
}

135
136
137
138
139
140
141
#[get("/list_workers")]
/// Returns the router's current worker URLs as JSON: `{"urls": [...]}`.
///
/// # Panics
/// Panics if the worker-URL lock is poisoned (the read result is unwrapped).
async fn list_workers(data: web::Data<AppState>) -> impl Responder {
    // Copy the list out in its own scope so the read guard is released before responding.
    let urls = {
        let shared = data.router.get_worker_urls();
        let guard = shared.read().unwrap();
        guard.clone()
    };
    HttpResponse::Ok().json(serde_json::json!({ "urls": urls }))
}

142
143
144
145
146
147
148
149
150
#[post("/remove_worker")]
/// Removes the worker given by the `url` query parameter from the router.
///
/// Responds `400 Bad Request` with an explanatory message when `url` is missing
/// (matching `/add_worker`), otherwise `200 OK` with a confirmation message.
async fn remove_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
            // Previously an empty 400; tell the caller what is missing, like add_worker.
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter");
        }
    };

    data.router.remove_worker(&worker_url);
    HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
}

155
156
157
158
159
160
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    pub worker_urls: Vec<String>,
    pub policy_config: PolicyConfig,
    pub verbose: bool,
161
    pub max_payload_size: usize,
162
    pub log_dir: Option<String>,
163
    pub service_discovery_config: Option<ServiceDiscoveryConfig>,
164
165
166
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
167
168
169
170
171
172
173
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            level: if config.verbose {
                Level::DEBUG
174
            } else {
175
                Level::INFO
176
            },
177
178
179
180
181
182
183
184
185
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
186

187
188
189
190
191
192
193
194
    info!("🚧 Initializing router on {}:{}", config.host, config.port);
    info!("🚧 Initializing workers on {:?}", config.worker_urls);
    info!("🚧 Policy Config: {:?}", config.policy_config);
    info!(
        "🚧 Max payload size: {} MB",
        config.max_payload_size / (1024 * 1024)
    );

195
196
197
198
199
200
201
202
203
    // Log service discovery status
    if let Some(service_discovery_config) = &config.service_discovery_config {
        info!("🚧 Service discovery enabled");
        info!("🚧 Selector: {:?}", service_discovery_config.selector);
    } else {
        info!("🚧 Service discovery disabled");
    }

    let client = Client::builder()
204
        .pool_idle_timeout(Some(Duration::from_secs(50)))
205
206
207
        .build()
        .expect("Failed to create HTTP client");

208
209
210
    let app_state = web::Data::new(
        AppState::new(
            config.worker_urls.clone(),
211
            client.clone(), // Clone the client here
212
213
214
            config.policy_config.clone(),
        )
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
215
    );
216

217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
    // Start the service discovery if enabled
    if let Some(service_discovery_config) = config.service_discovery_config {
        if service_discovery_config.enabled {
            let worker_urls = Arc::clone(&app_state.router.get_worker_urls());

            match start_service_discovery(service_discovery_config, worker_urls).await {
                Ok(handle) => {
                    info!("✅ Service discovery started successfully");

                    // Spawn a task to handle the service discovery thread
                    spawn(async move {
                        if let Err(e) = handle.await {
                            error!("Service discovery task failed: {:?}", e);
                        }
                    });
                }
                Err(e) => {
                    error!("Failed to start service discovery: {}", e);
                    warn!("Continuing without service discovery");
                }
            }
        }
    }

241
242
243
    info!("✅ Serving router on {}:{}", config.host, config.port);
    info!("✅ Serving workers on {:?}", config.worker_urls);

244
245
246
    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
247
248
249
250
251
            .app_data(
                web::JsonConfig::default()
                    .limit(config.max_payload_size)
                    .error_handler(json_error_handler),
            )
252
            .app_data(web::PayloadConfig::default().limit(config.max_payload_size))
253
            .service(generate)
254
255
256
            .service(v1_chat_completions)
            .service(v1_completions)
            .service(v1_models)
257
            .service(get_model_info)
258
259
            .service(health)
            .service(health_generate)
260
            .service(get_server_info)
261
            .service(add_worker)
262
            .service(remove_worker)
263
            .service(list_workers)
264
265
            // Default handler for unmatched routes.
            .default_service(web::route().to(sink_handler))
266
    })
267
    .bind_auto_h2c((config.host, config.port))?
268
269
    .run()
    .await
270
}