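//! HTTP front-end for the router: actix-web handlers that forward requests through
//! [`Router`], plus [`startup`] to configure logging and launch the server.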
use crate::router::PolicyConfig;
2
use crate::router::Router;
3
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
4
use bytes::Bytes;
5
use env_logger::Builder;
Byron Hsu's avatar
Byron Hsu committed
6
use log::{info, LevelFilter};
7
use std::collections::HashMap;
8
use std::io::Write;
9
10
11

#[derive(Debug)]
pub struct AppState {
12
    router: Router,
13
14
15
    client: reqwest::Client,
}

16
impl AppState {
17
18
19
20
21
    pub fn new(
        worker_urls: Vec<String>,
        client: reqwest::Client,
        policy_config: PolicyConfig,
    ) -> Self {
22
        // Create router based on policy
23
24
25
26
        let router = match Router::new(worker_urls, policy_config) {
            Ok(router) => router,
            Err(error) => panic!("Failed to create router: {}", error),
        };
27
28

        Self { router, client }
29
30
31
    }
}

32
33
#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
34
    data.router.route_to_first(&data.client, "/health").await
35
36
37
38
}

#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
39
40
41
    data.router
        .route_to_first(&data.client, "/health_generate")
        .await
42
43
}

44
45
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
46
47
48
    data.router
        .route_to_first(&data.client, "/get_server_info")
        .await
49
50
}

51
#[get("/v1/models")]
52
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
53
    data.router.route_to_first(&data.client, "/v1/models").await
54
55
}

56
57
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
58
59
60
    data.router
        .route_to_first(&data.client, "/get_model_info")
        .await
61
}
62

63
64
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
65
    data.router
66
        .route_generate_request(&data.client, &req, &body, "/generate")
67
68
69
70
71
72
73
74
75
76
        .await
}

#[post("/v1/chat/completions")]
async fn v1_chat_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    data.router
77
        .route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
78
79
80
81
82
83
84
85
86
87
        .await
}

#[post("/v1/completions")]
async fn v1_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    data.router
88
        .route_generate_request(&data.client, &req, &body, "/v1/completions")
89
        .await
90
91
}

92
93
94
95
96
97
98
99
100
101
102
103
#[post("/add_worker")]
async fn add_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter")
        }
    };
104

105
    match data.router.add_worker(&worker_url).await {
106
107
108
        Ok(message) => HttpResponse::Ok().body(message),
        Err(error) => HttpResponse::BadRequest().body(error),
    }
109
110
}

111
112
113
114
115
116
117
118
119
#[post("/remove_worker")]
async fn remove_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => return HttpResponse::BadRequest().finish(),
    };
120
    data.router.remove_worker(&worker_url);
121
    HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
122
123
}

124
125
126
127
128
129
pub struct ServerConfig {
    pub host: String,
    pub port: u16,
    pub worker_urls: Vec<String>,
    pub policy_config: PolicyConfig,
    pub verbose: bool,
130
    pub max_payload_size: usize,
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
    Builder::new()
        .format(|buf, record| {
            use chrono::Local;
            writeln!(
                buf,
                "[Router (Rust)] {} - {} - {}",
                Local::now().format("%Y-%m-%d %H:%M:%S"),
                record.level(),
                record.args()
            )
        })
        .filter(
            None,
            if config.verbose {
                LevelFilter::Debug
            } else {
                LevelFilter::Info
            },
        )
        .init();

155
156
157
158
    let client = reqwest::Client::builder()
        .build()
        .expect("Failed to create HTTP client");

159
    let app_state = web::Data::new(AppState::new(
160
        config.worker_urls.clone(),
161
        client,
162
        config.policy_config.clone(),
163
    ));
164

165
166
167
    info!("✅ Starting router on {}:{}", config.host, config.port);
    info!("✅ Serving Worker URLs: {:?}", config.worker_urls);
    info!("✅ Policy Config: {:?}", config.policy_config);
168
169
170
171
    info!(
        "✅ Max payload size: {} MB",
        config.max_payload_size / (1024 * 1024)
    );
172

173
174
175
    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
176
177
            .app_data(web::JsonConfig::default().limit(config.max_payload_size))
            .app_data(web::PayloadConfig::default().limit(config.max_payload_size))
178
            .service(generate)
179
180
181
            .service(v1_chat_completions)
            .service(v1_completions)
            .service(v1_models)
182
            .service(get_model_info)
183
184
            .service(health)
            .service(health_generate)
185
            .service(get_server_info)
186
            .service(add_worker)
187
            .service(remove_worker)
188
    })
189
    .bind((config.host, config.port))?
190
191
    .run()
    .await
192
}