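//! server.rs — actix-web HTTP front end of the router: forwards client requests to the
//! configured worker servers and exposes endpoints to add or remove workers at runtime.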
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;

use actix_web::{
    error, get, http::StatusCode, post, web, App, Error, HttpRequest, HttpResponse, HttpServer,
    Responder, ResponseError,
};
use bytes::Bytes;
use futures_util::StreamExt;
use tracing::{info, warn, Level};

use crate::logging::{self, LoggingConfig};
use crate::router::PolicyConfig;
use crate::router::Router;

#[derive(Debug)]
pub struct AppState {
16
    router: Router,
17
18
19
    client: reqwest::Client,
}

20
impl AppState {
21
22
23
24
    pub fn new(
        worker_urls: Vec<String>,
        client: reqwest::Client,
        policy_config: PolicyConfig,
25
    ) -> Result<Self, String> {
26
        // Create router based on policy
27
28
        let router = Router::new(worker_urls, policy_config)?;
        Ok(Self { router, client })
29
30
31
    }
}

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<HttpResponse, Error> {
    // Drain the payload
    while let Some(chunk) = payload.next().await {
        if let Err(err) = chunk {
            println!("Error while draining payload: {:?}", err);
            break;
        }
    }
    Ok(HttpResponse::NotFound().finish())
}

// Custom error handler for JSON payload errors.
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
    error::ErrorPayloadTooLarge("Payload too large")
}

48
#[get("/health")]
49
50
51
52
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_to_first(&data.client, "/health", &req)
        .await
53
54
55
}

#[get("/health_generate")]
56
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
57
    data.router
58
        .route_to_first(&data.client, "/health_generate", &req)
59
        .await
60
61
}

62
#[get("/get_server_info")]
63
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
64
    data.router
65
        .route_to_first(&data.client, "/get_server_info", &req)
66
        .await
67
68
}

69
#[get("/v1/models")]
70
71
72
73
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
    data.router
        .route_to_first(&data.client, "/v1/models", &req)
        .await
74
75
}

76
#[get("/get_model_info")]
77
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
78
    data.router
79
        .route_to_first(&data.client, "/get_model_info", &req)
80
        .await
81
}
82

83
84
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
85
    data.router
86
        .route_generate_request(&data.client, &req, &body, "/generate")
87
88
89
90
91
92
93
94
95
96
        .await
}

/// `POST /v1/chat/completions`: passes the raw request body to
/// `Router::route_generate_request` (with the same path) and returns the router's
/// response unchanged.
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    data.router
        .route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
        .await
}

/// `POST /v1/completions`: passes the raw request body to
/// `Router::route_generate_request` (with the same path) and returns the router's
/// response unchanged.
#[post("/v1/completions")]
async fn v1_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    data.router
        .route_generate_request(&data.client, &req, &body, "/v1/completions")
        .await
}

112
113
114
115
116
117
118
119
120
121
122
123
#[post("/add_worker")]
async fn add_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter")
        }
    };
124

125
    match data.router.add_worker(&worker_url).await {
126
127
128
        Ok(message) => HttpResponse::Ok().body(message),
        Err(error) => HttpResponse::BadRequest().body(error),
    }
129
130
}

131
132
133
134
135
136
137
138
139
#[post("/remove_worker")]
async fn remove_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => return HttpResponse::BadRequest().finish(),
    };
140
    data.router.remove_worker(&worker_url);
141
    HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
142
143
}

144
145
146
147
148
149
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    pub worker_urls: Vec<String>,
    pub policy_config: PolicyConfig,
    pub verbose: bool,
150
    pub max_payload_size: usize,
151
    pub log_dir: Option<String>,
152
153
154
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
155
156
157
158
159
160
161
    // Only initialize logging if not already done (for Python bindings support)
    static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);

    let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
        Some(logging::init_logging(LoggingConfig {
            level: if config.verbose {
                Level::DEBUG
162
            } else {
163
                Level::INFO
164
            },
165
166
167
168
169
170
171
172
173
            json_format: false,
            log_dir: config.log_dir.clone(),
            colorize: true,
            log_file_name: "sgl-router".to_string(),
            log_targets: None,
        }))
    } else {
        None
    };
174

175
176
177
178
179
180
181
182
    info!("🚧 Initializing router on {}:{}", config.host, config.port);
    info!("🚧 Initializing workers on {:?}", config.worker_urls);
    info!("🚧 Policy Config: {:?}", config.policy_config);
    info!(
        "🚧 Max payload size: {} MB",
        config.max_payload_size / (1024 * 1024)
    );

183
    let client = reqwest::Client::builder()
184
        .pool_idle_timeout(Some(Duration::from_secs(50)))
185
186
187
        .build()
        .expect("Failed to create HTTP client");

188
189
190
    let app_state = web::Data::new(
        AppState::new(
            config.worker_urls.clone(),
191
            client.clone(), // Clone the client here
192
193
194
            config.policy_config.clone(),
        )
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?,
195
    );
196

197
198
199
    info!("✅ Serving router on {}:{}", config.host, config.port);
    info!("✅ Serving workers on {:?}", config.worker_urls);

200
201
202
    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
203
204
205
206
207
            .app_data(
                web::JsonConfig::default()
                    .limit(config.max_payload_size)
                    .error_handler(json_error_handler),
            )
208
            .app_data(web::PayloadConfig::default().limit(config.max_payload_size))
209
            .service(generate)
210
211
212
            .service(v1_chat_completions)
            .service(v1_completions)
            .service(v1_models)
213
            .service(get_model_info)
214
215
            .service(health)
            .service(health_generate)
216
            .service(get_server_info)
217
            .service(add_worker)
218
            .service(remove_worker)
219
220
            // Default handler for unmatched routes.
            .default_service(web::route().to(sink_handler))
221
    })
222
    .bind((config.host, config.port))?
223
224
    .run()
    .await
225
}