// server.rs
use crate::router::PolicyConfig;
use crate::router::Router;
use actix_web::http::StatusCode;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
use bytes::Bytes;

#[derive(Debug)]
pub struct AppState {
8
    router: Router,
9
10
11
    client: reqwest::Client,
}

12
impl AppState {
13
14
15
16
17
    pub fn new(
        worker_urls: Vec<String>,
        client: reqwest::Client,
        policy_config: PolicyConfig,
    ) -> Self {
18
        // Create router based on policy
19
        let router = Router::new(worker_urls, policy_config);
20
21

        Self { router, client }
22
23
24
    }
}

25
26
27
28
29
30
async fn forward_request(
    client: &reqwest::Client,
    worker_url: String,
    route: String,
) -> HttpResponse {
    match client.get(format!("{}{}", worker_url, route)).send().await {
31
32
        Ok(res) => {
            let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
33
34
                .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

35
36
37
38
39
40
            // print the status
            println!("Worker URL: {}, Status: {}", worker_url, status);
            match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(_) => HttpResponse::InternalServerError().finish(),
            }
41
        }
42
43
44
45
        Err(_) => HttpResponse::InternalServerError().finish(),
    }
}

46
47
#[get("/v1/models")]
async fn v1_model(data: web::Data<AppState>) -> impl Responder {
48
    let worker_url = match data.router.get_first() {
49
50
51
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };
52

53
    forward_request(&data.client, worker_url, "/v1/models".to_string()).await
54
55
}

56
57
58
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
    let worker_url = match data.router.get_first() {
59
60
61
62
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };

63
64
    forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
}
65

66
67
68
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
    data.router.dispatch(&data.client, req, body).await
69
70
}

71
72
73
74
pub async fn startup(
    host: String,
    port: u16,
    worker_urls: Vec<String>,
75
    policy_config: PolicyConfig,
76
) -> std::io::Result<()> {
77
78
    println!("Starting server on {}:{}", host, port);
    println!("Worker URLs: {:?}", worker_urls);
79
    println!("Policy Config: {:?}", policy_config);
80
81
82
83
84
85
86

    // Create client once with configuration
    let client = reqwest::Client::builder()
        .build()
        .expect("Failed to create HTTP client");

    // Store both worker_urls and client in AppState
87
    let app_state = web::Data::new(AppState::new(worker_urls, client, policy_config));
88
89
90
91
92
93
94
95
96
97
98

    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
            .service(generate)
            .service(v1_model)
            .service(get_model_info)
    })
    .bind((host, port))?
    .run()
    .await
99
}