server.rs 4.95 KB
Newer Older
1
use crate::router::PolicyConfig;
2
use crate::router::Router;
3
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
4
use bytes::Bytes;
5
use env_logger::Builder;
Byron Hsu's avatar
Byron Hsu committed
6
use log::{info, LevelFilter};
7
use std::collections::HashMap;
8
use std::io::Write;
9
10
11

#[derive(Debug)]
pub struct AppState {
12
    router: Router,
13
14
15
    client: reqwest::Client,
}

16
impl AppState {
17
18
19
20
21
    pub fn new(
        worker_urls: Vec<String>,
        client: reqwest::Client,
        policy_config: PolicyConfig,
    ) -> Self {
22
        // Create router based on policy
23
24
25
26
        let router = match Router::new(worker_urls, policy_config) {
            Ok(router) => router,
            Err(error) => panic!("Failed to create router: {}", error),
        };
27
28

        Self { router, client }
29
30
31
    }
}

32
33
/// `GET /health`: delegates to `Router::route_to_first` with the path `/health`.
#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
    let state = data.get_ref();
    state.router.route_to_first(&state.client, "/health").await
}

#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
39
40
41
    data.router
        .route_to_first(&data.client, "/health_generate")
        .await
42
43
}

44
45
/// `GET /get_server_info`: delegates to `Router::route_to_first` with the
/// path `/get_server_info`.
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
    let (router, client) = (&data.router, &data.client);
    router.route_to_first(client, "/get_server_info").await
}

51
#[get("/v1/models")]
52
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
53
    data.router.route_to_first(&data.client, "/v1/models").await
54
55
}

56
57
/// `GET /get_model_info`: delegates to `Router::route_to_first` with the
/// path `/get_model_info`.
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
    let state = data.get_ref();
    let client = &state.client;
    state.router.route_to_first(client, "/get_model_info").await
}
62

63
64
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
65
    data.router
66
        .route_generate_request(&data.client, &req, &body, "/generate")
67
68
69
70
71
72
73
74
75
76
        .await
}

/// `POST /v1/chat/completions`: hands the incoming request, its raw body and
/// the path `/v1/chat/completions` to `Router::route_generate_request`.
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    let state = data.get_ref();
    state
        .router
        .route_generate_request(&state.client, &req, &body, "/v1/chat/completions")
        .await
}

/// `POST /v1/completions`: hands the incoming request, its raw body and the
/// path `/v1/completions` to `Router::route_generate_request`.
#[post("/v1/completions")]
async fn v1_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    let (router, client) = (&data.router, &data.client);
    let pending = router.route_generate_request(client, &req, &body, "/v1/completions");
    pending.await
}

92
93
94
95
96
97
98
99
100
101
102
103
#[post("/add_worker")]
async fn add_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter")
        }
    };
104

105
    match data.router.add_worker(&worker_url).await {
106
107
108
        Ok(message) => HttpResponse::Ok().body(message),
        Err(error) => HttpResponse::BadRequest().body(error),
    }
109
110
}

111
112
113
114
115
116
117
118
119
/// `POST /remove_worker?url=<worker>`: asks the router to drop a worker.
///
/// Responds `400 Bad Request` when the `url` query parameter is absent, using
/// the same explanatory body as `add_worker` so clients learn what is missing.
/// Otherwise calls the router's `remove_worker` and replies `200 OK` with an
/// empty body; whatever `remove_worker` returns is not inspected here.
#[post("/remove_worker")]
async fn remove_worker(
    query: web::Query<HashMap<String, String>>,
    data: web::Data<AppState>,
) -> impl Responder {
    let worker_url = match query.get("url") {
        Some(url) => url.to_string(),
        None => {
            return HttpResponse::BadRequest()
                .body("Worker URL required. Provide 'url' query parameter")
        }
    };

    data.router.remove_worker(&worker_url);

    HttpResponse::Ok().finish()
}

124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
/// Settings consumed by [`startup`] to bring the router server up.
pub struct ServerConfig {
    /// Host the HTTP server binds to.
    pub host: String,
    /// Port the HTTP server binds to.
    pub port: u16,
    /// Worker URLs passed to `AppState::new` when the router is built.
    pub worker_urls: Vec<String>,
    /// Routing policy settings passed to `AppState::new` alongside the workers.
    pub policy_config: PolicyConfig,
    /// Selects the log level: `Debug` when true, `Info` otherwise.
    pub verbose: bool,
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
    Builder::new()
        .format(|buf, record| {
            use chrono::Local;
            writeln!(
                buf,
                "[Router (Rust)] {} - {} - {}",
                Local::now().format("%Y-%m-%d %H:%M:%S"),
                record.level(),
                record.args()
            )
        })
        .filter(
            None,
            if config.verbose {
                LevelFilter::Debug
            } else {
                LevelFilter::Info
            },
        )
        .init();

154
155
156
157
    let client = reqwest::Client::builder()
        .build()
        .expect("Failed to create HTTP client");

158
    let app_state = web::Data::new(AppState::new(
159
        config.worker_urls.clone(),
160
        client,
161
        config.policy_config.clone(),
162
    ));
163

164
165
166
167
    info!("✅ Starting router on {}:{}", config.host, config.port);
    info!("✅ Serving Worker URLs: {:?}", config.worker_urls);
    info!("✅ Policy Config: {:?}", config.policy_config);

168
169
170
171
    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
            .service(generate)
172
173
174
            .service(v1_chat_completions)
            .service(v1_completions)
            .service(v1_models)
175
            .service(get_model_info)
176
177
            .service(health)
            .service(health_generate)
178
            .service(get_server_info)
179
            .service(add_worker)
180
            .service(remove_worker)
181
    })
182
    .bind((config.host, config.port))?
183
184
    .run()
    .await
185
}