server.rs 5.33 KB
Newer Older
1
use crate::router::PolicyConfig;
2
3
use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
4
use bytes::Bytes;
5
use env_logger::Builder;
Byron Hsu's avatar
Byron Hsu committed
6
use log::{info, LevelFilter};
7
use std::io::Write;
8
9
10

#[derive(Debug)]
pub struct AppState {
11
    router: Router,
12
13
14
    client: reqwest::Client,
}

15
impl AppState {
16
17
18
19
20
    pub fn new(
        worker_urls: Vec<String>,
        client: reqwest::Client,
        policy_config: PolicyConfig,
    ) -> Self {
21
        // Create router based on policy
22
        let router = Router::new(worker_urls, policy_config);
23
24

        Self { router, client }
25
26
27
    }
}

28
29
30
31
32
33
async fn forward_request(
    client: &reqwest::Client,
    worker_url: String,
    route: String,
) -> HttpResponse {
    match client.get(format!("{}{}", worker_url, route)).send().await {
34
35
        Ok(res) => {
            let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
36
37
                .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

38
            // print the status
39
40
41
42
            println!(
                "Forwarding Request Worker URL: {}, Route: {}, Status: {}",
                worker_url, route, status
            );
43
44
45
46
            match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(_) => HttpResponse::InternalServerError().finish(),
            }
47
        }
48
49
50
51
        Err(_) => HttpResponse::InternalServerError().finish(),
    }
}

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/// `GET /health`: forwards the request to the worker URL yielded by
/// `router.get_first()`, or answers 500 when no worker URL is available.
#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
    match data.router.get_first() {
        Some(worker_url) => {
            forward_request(&data.client, worker_url, String::from("/health")).await
        }
        None => HttpResponse::InternalServerError().finish(),
    }
}

/// `GET /health_generate`: forwards the request to the worker URL yielded by
/// `router.get_first()`, or answers 500 when no worker URL is available.
#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
    if let Some(target) = data.router.get_first() {
        forward_request(&data.client, target, String::from("/health_generate")).await
    } else {
        HttpResponse::InternalServerError().finish()
    }
}

72
73
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
74
75
76
77
78
    let worker_url = match data.router.get_first() {
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };

79
    forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
80
81
}

82
#[get("/v1/models")]
83
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
84
    let worker_url = match data.router.get_first() {
85
86
87
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };
88

89
    forward_request(&data.client, worker_url, "/v1/models".to_string()).await
90
91
}

92
93
94
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
    let worker_url = match data.router.get_first() {
95
96
97
98
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };

99
100
    forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
}
101

102
103
/// `POST /generate`: hands the incoming request and raw body to
/// `router.dispatch` under the route name `"generate"` and returns its result.
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
    let router = &data.router;
    let client = &data.client;

    router.dispatch(client, req, body, "generate").await
}

/// `POST /v1/chat/completions`: hands the incoming request and raw body to
/// `router.dispatch` under the route name `"v1/chat/completions"` and returns
/// its result.
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    let router = &data.router;
    let client = &data.client;

    router
        .dispatch(client, req, body, "v1/chat/completions")
        .await
}

/// `POST /v1/completions`: hands the incoming request and raw body to
/// `router.dispatch` under the route name `"v1/completions"` and returns its
/// result.
#[post("/v1/completions")]
async fn v1_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    let router = &data.router;
    let client = &data.client;

    router.dispatch(client, req, body, "v1/completions").await
}

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
/// Settings consumed by [`startup`] to configure logging, the router, and the
/// listening socket.
#[derive(Debug)]
pub struct ServerConfig {
    /// Host name or address the HTTP server binds to.
    pub host: String,
    /// TCP port the HTTP server binds to.
    pub port: u16,
    /// Worker URLs handed to the router.
    pub worker_urls: Vec<String>,
    /// Policy configuration handed to the router.
    pub policy_config: PolicyConfig,
    /// When `true`, logging is filtered at `Debug` level; otherwise at `Info`.
    pub verbose: bool,
}

pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
    Builder::new()
        .format(|buf, record| {
            use chrono::Local;
            writeln!(
                buf,
                "[Router (Rust)] {} - {} - {}",
                Local::now().format("%Y-%m-%d %H:%M:%S"),
                record.level(),
                record.args()
            )
        })
        .filter(
            None,
            if config.verbose {
                LevelFilter::Debug
            } else {
                LevelFilter::Info
            },
        )
        .init();

    info!("Starting server on {}:{}", config.host, config.port);
    info!("Worker URLs: {:?}", config.worker_urls);
    info!("Policy Config: {:?}", config.policy_config);

165
166
167
168
    let client = reqwest::Client::builder()
        .build()
        .expect("Failed to create HTTP client");

169
170
171
172
173
    let app_state = web::Data::new(AppState::new(
        config.worker_urls,
        client,
        config.policy_config,
    ));
174
175
176
177
178

    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
            .service(generate)
179
180
181
            .service(v1_chat_completions)
            .service(v1_completions)
            .service(v1_models)
182
            .service(get_model_info)
183
184
            .service(health)
            .service(health_generate)
185
            .service(get_server_info)
186
    })
187
    .bind((config.host, config.port))?
188
189
    .run()
    .await
190
}