server.rs 2.88 KB
Newer Older
1
2
use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
3
4
5
6
use bytes::Bytes;

#[derive(Debug)]
pub struct AppState {
7
    router: Router,
8
9
10
    client: reqwest::Client,
}

11
impl AppState {
12
13
    pub fn new(worker_urls: Vec<String>, policy: String, client: reqwest::Client) -> Self {
        // Create router based on policy
14
        let router = Router::new(worker_urls, policy);
15
16

        Self { router, client }
17
18
19
    }
}

20
21
22
23
24
25
async fn forward_request(
    client: &reqwest::Client,
    worker_url: String,
    route: String,
) -> HttpResponse {
    match client.get(format!("{}{}", worker_url, route)).send().await {
26
27
        Ok(res) => {
            let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
28
29
                .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

30
31
32
33
34
35
            // print the status
            println!("Worker URL: {}, Status: {}", worker_url, status);
            match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(_) => HttpResponse::InternalServerError().finish(),
            }
36
        }
37
38
39
40
        Err(_) => HttpResponse::InternalServerError().finish(),
    }
}

41
42
43
#[get("/v1/models")]
async fn v1_model(data: web::Data<AppState>) -> impl Responder {
    // TODO: extract forward_to_route
44
    let worker_url = match data.router.get_first() {
45
46
47
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };
48

49
    forward_request(&data.client, worker_url, "/v1/models".to_string()).await
50
51
}

52
53
54
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
    let worker_url = match data.router.get_first() {
55
56
57
58
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };

59
60
    forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
}
61

62
63
64
65
/// Handles `POST /generate`.
///
/// The incoming request and its raw body bytes are handed to
/// `Router::dispatch` as-is (nothing is deserialized or re-serialized here),
/// and whatever `dispatch` produces is returned to the caller.
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
    let state = data.get_ref();
    state.router.dispatch(&state.client, req, body).await
}

68
69
70
71
72
73
pub async fn startup(
    host: String,
    port: u16,
    worker_urls: Vec<String>,
    routing_policy: String,
) -> std::io::Result<()> {
74
75
76
77
78
79
80
81
82
    println!("Starting server on {}:{}", host, port);
    println!("Worker URLs: {:?}", worker_urls);

    // Create client once with configuration
    let client = reqwest::Client::builder()
        .build()
        .expect("Failed to create HTTP client");

    // Store both worker_urls and client in AppState
83
    let app_state = web::Data::new(AppState::new(worker_urls, routing_policy, client));
84
85
86
87
88
89
90
91
92
93
94

    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
            .service(generate)
            .service(v1_model)
            .service(get_model_info)
    })
    .bind((host, port))?
    .run()
    .await
95
}