server.rs 4.68 KB
Newer Older
1
use crate::router::PolicyConfig;
2
3
use crate::router::Router;
use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder};
4
5
6
7
use bytes::Bytes;

#[derive(Debug)]
pub struct AppState {
8
    router: Router,
9
10
11
    client: reqwest::Client,
}

12
impl AppState {
13
14
15
16
17
    pub fn new(
        worker_urls: Vec<String>,
        client: reqwest::Client,
        policy_config: PolicyConfig,
    ) -> Self {
18
        // Create router based on policy
19
        let router = Router::new(worker_urls, policy_config);
20
21

        Self { router, client }
22
23
24
    }
}

25
26
27
28
29
30
async fn forward_request(
    client: &reqwest::Client,
    worker_url: String,
    route: String,
) -> HttpResponse {
    match client.get(format!("{}{}", worker_url, route)).send().await {
31
32
        Ok(res) => {
            let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
33
34
                .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);

35
            // print the status
36
37
38
39
            println!(
                "Forwarding Request Worker URL: {}, Route: {}, Status: {}",
                worker_url, route, status
            );
40
41
42
43
            match res.bytes().await {
                Ok(body) => HttpResponse::build(status).body(body.to_vec()),
                Err(_) => HttpResponse::InternalServerError().finish(),
            }
44
        }
45
46
47
48
        Err(_) => HttpResponse::InternalServerError().finish(),
    }
}

49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/// `GET /health`: forwards the request to the worker returned by
/// `Router::get_first`, or answers with an empty 500 when it returns `None`.
#[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder {
    match data.router.get_first() {
        Some(target) => forward_request(&data.client, target, String::from("/health")).await,
        None => HttpResponse::InternalServerError().finish(),
    }
}

/// `GET /health_generate`: forwards the request to the worker returned by
/// `Router::get_first`, or answers with an empty 500 when it returns `None`.
#[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder {
    match data.router.get_first() {
        Some(target) => {
            let route = String::from("/health_generate");
            forward_request(&data.client, target, route).await
        }
        None => HttpResponse::InternalServerError().finish(),
    }
}

69
70
#[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder {
71
72
73
74
75
    let worker_url = match data.router.get_first() {
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };

76
    forward_request(&data.client, worker_url, "/get_server_info".to_string()).await
77
78
}

79
#[get("/v1/models")]
80
async fn v1_models(data: web::Data<AppState>) -> impl Responder {
81
    let worker_url = match data.router.get_first() {
82
83
84
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };
85

86
    forward_request(&data.client, worker_url, "/v1/models".to_string()).await
87
88
}

89
90
91
#[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder {
    let worker_url = match data.router.get_first() {
92
93
94
95
        Some(url) => url,
        None => return HttpResponse::InternalServerError().finish(),
    };

96
97
    forward_request(&data.client, worker_url, "/get_model_info".to_string()).await
}
98

99
100
/// `POST /generate`: hands the incoming request and its raw body to
/// `Router::dispatch` with the route `"generate"` and returns its result.
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
    data.router
        .dispatch(&data.client, req, body, "generate")
        .await
}

/// `POST /v1/chat/completions`: hands the incoming request and its raw body to
/// `Router::dispatch` with the route `"v1/chat/completions"` and returns its
/// result.
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    let state: &AppState = &data;
    let route = "v1/chat/completions";
    state.router.dispatch(&state.client, req, body, route).await
}

/// `POST /v1/completions`: hands the incoming request and its raw body to
/// `Router::dispatch` with the route `"v1/completions"` and returns its result.
#[post("/v1/completions")]
async fn v1_completions(
    req: HttpRequest,
    body: Bytes,
    data: web::Data<AppState>,
) -> impl Responder {
    data.router
        .dispatch(&data.client, req, body, "v1/completions")
        .await
}

128
129
130
131
pub async fn startup(
    host: String,
    port: u16,
    worker_urls: Vec<String>,
132
    policy_config: PolicyConfig,
133
) -> std::io::Result<()> {
134
135
    println!("Starting server on {}:{}", host, port);
    println!("Worker URLs: {:?}", worker_urls);
136
    println!("Policy Config: {:?}", policy_config);
137
138
139
140
141
142
143

    // Create client once with configuration
    let client = reqwest::Client::builder()
        .build()
        .expect("Failed to create HTTP client");

    // Store both worker_urls and client in AppState
144
    let app_state = web::Data::new(AppState::new(worker_urls, client, policy_config));
145
146
147
148
149

    HttpServer::new(move || {
        App::new()
            .app_data(app_state.clone())
            .service(generate)
150
151
152
            .service(v1_chat_completions)
            .service(v1_completions)
            .service(v1_models)
153
            .service(get_model_info)
154
155
            .service(health)
            .service(health_generate)
156
            .service(get_server_info)
157
158
159
160
    })
    .bind((host, port))?
    .run()
    .await
161
}