Unverified Commit 9a0cc2e9 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

[router] Forward all request headers from router to workers (#3070)

parent 7bad7e75
#!/bin/bash #!/bin/bash
# Check if sudo is available
if command -v sudo >/dev/null 2>&1; then
sudo apt-get update
sudo apt-get install -y lsof
else
apt-get update
apt-get install -y lsof
fi
# Show current GPU status # Show current GPU status
nvidia-smi nvidia-smi
......
...@@ -22,6 +22,7 @@ def popen_launch_router( ...@@ -22,6 +22,7 @@ def popen_launch_router(
timeout: float, timeout: float,
policy: str = "cache_aware", policy: str = "cache_aware",
max_payload_size: int = None, max_payload_size: int = None,
api_key: str = None,
): ):
""" """
Launch the router server process. Launch the router server process.
...@@ -33,6 +34,7 @@ def popen_launch_router( ...@@ -33,6 +34,7 @@ def popen_launch_router(
timeout: Server launch timeout timeout: Server launch timeout
policy: Router policy, one of "cache_aware", "round_robin", "random" policy: Router policy, one of "cache_aware", "round_robin", "random"
max_payload_size: Maximum payload size in bytes max_payload_size: Maximum payload size in bytes
api_key: API key for the router
""" """
_, host, port = base_url.split(":") _, host, port = base_url.split(":")
host = host[2:] host = host[2:]
...@@ -55,6 +57,9 @@ def popen_launch_router( ...@@ -55,6 +57,9 @@ def popen_launch_router(
policy, policy,
] ]
if api_key is not None:
command.extend(["--api-key", api_key])
if max_payload_size is not None: if max_payload_size is not None:
command.extend(["--router-max-payload-size", str(max_payload_size)]) command.extend(["--router-max-payload-size", str(max_payload_size)])
...@@ -333,6 +338,57 @@ class TestLaunchServer(unittest.TestCase): ...@@ -333,6 +338,57 @@ class TestLaunchServer(unittest.TestCase):
f"1.2MB payload should fail with 413 but got status {response.status_code}", f"1.2MB payload should fail with 413 but got status {response.status_code}",
) )
def test_5_api_key(self):
print("Running test_5_api_key...")
self.process = popen_launch_router(
self.model,
self.base_url,
dp_size=1,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
policy="round_robin",
api_key="correct_api_key",
)
# # Test case 1: request without api key should fail
with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json={"text": "Kanye west is, ", "temperature": 0},
)
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(
response.status_code,
401,
"Request without api key should fail with 401",
)
# Test case 2: request with invalid api key should fail
with requests.Session() as session:
response = requests.post(
f"{self.base_url}/generate",
json={"text": "Kanye west is, ", "temperature": 0},
headers={"Authorization": "Bearer 123"},
)
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(
response.status_code,
401,
"Request with invalid api key should fail with 401",
)
# Test case 3: request with correct api key should succeed
with requests.Session() as session:
response = session.post(
f"{self.base_url}/generate",
json={"text": "Kanye west is ", "temperature": 0},
headers={"Authorization": "Bearer correct_api_key"},
)
print(f"status code: {response.status_code}, response: {response.text}")
self.assertEqual(
response.status_code, 200, "Request with correct api key should succeed"
)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -12,6 +12,18 @@ use std::thread; ...@@ -12,6 +12,18 @@ use std::thread;
use std::time::Duration; use std::time::Duration;
use tokio; use tokio;
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
req.headers()
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|v| (name.to_string(), v.to_string()))
})
.collect()
}
#[derive(Debug)] #[derive(Debug)]
pub enum Router { pub enum Router {
RoundRobin { RoundRobin {
...@@ -303,8 +315,18 @@ impl Router { ...@@ -303,8 +315,18 @@ impl Router {
client: &reqwest::Client, client: &reqwest::Client,
worker_url: &str, worker_url: &str,
route: &str, route: &str,
req: &HttpRequest,
) -> HttpResponse { ) -> HttpResponse {
match client.get(format!("{}{}", worker_url, route)).send().await { let mut request_builder = client.get(format!("{}{}", worker_url, route));
// Copy all headers from original request except for /health because it does not need authorization
if route != "/health" {
for (name, value) in copy_request_headers(req) {
request_builder = request_builder.header(name, value);
}
}
match request_builder.send().await {
Ok(res) => { Ok(res) => {
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
...@@ -322,7 +344,12 @@ impl Router { ...@@ -322,7 +344,12 @@ impl Router {
} }
} }
pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse { pub async fn route_to_first(
&self,
client: &reqwest::Client,
route: &str,
req: &HttpRequest,
) -> HttpResponse {
const MAX_REQUEST_RETRIES: u32 = 3; const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6; const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0; let mut total_retries = 0;
...@@ -338,10 +365,17 @@ impl Router { ...@@ -338,10 +365,17 @@ impl Router {
info!("Retrying request after {} failed attempts", total_retries); info!("Retrying request after {} failed attempts", total_retries);
} }
let response = self.send_request(client, &worker_url, route).await; let response = self.send_request(client, &worker_url, route, req).await;
if response.status().is_success() { if response.status().is_success() {
return response; return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self.send_request(client, &worker_url, "/health", req).await;
if health_response.status().is_success() {
return response;
}
} }
warn!( warn!(
...@@ -496,19 +530,16 @@ impl Router { ...@@ -496,19 +530,16 @@ impl Router {
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false); .unwrap_or(false);
let res = match client let mut request_builder = client
.post(format!("{}{}", worker_url, route)) .post(format!("{}{}", worker_url, route))
.header( .body(body.to_vec());
"Content-Type",
req.headers() // Copy all headers from original request
.get("Content-Type") for (name, value) in copy_request_headers(req) {
.and_then(|h| h.to_str().ok()) request_builder = request_builder.header(name, value);
.unwrap_or("application/json"), }
)
.body(body.to_vec()) let res = match request_builder.send().await {
.send()
.await
{
Ok(res) => res, Ok(res) => res,
Err(_) => return HttpResponse::InternalServerError().finish(), Err(_) => return HttpResponse::InternalServerError().finish(),
}; };
...@@ -596,6 +627,13 @@ impl Router { ...@@ -596,6 +627,13 @@ impl Router {
if response.status().is_success() { if response.status().is_success() {
return response; return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self.send_request(client, &worker_url, "/health", req).await;
if health_response.status().is_success() {
return response;
}
} }
warn!( warn!(
......
...@@ -26,33 +26,37 @@ impl AppState { ...@@ -26,33 +26,37 @@ impl AppState {
} }
#[get("/health")] #[get("/health")]
async fn health(data: web::Data<AppState>) -> impl Responder { async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.route_to_first(&data.client, "/health").await data.router
.route_to_first(&data.client, "/health", &req)
.await
} }
#[get("/health_generate")] #[get("/health_generate")]
async fn health_generate(data: web::Data<AppState>) -> impl Responder { async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router data.router
.route_to_first(&data.client, "/health_generate") .route_to_first(&data.client, "/health_generate", &req)
.await .await
} }
#[get("/get_server_info")] #[get("/get_server_info")]
async fn get_server_info(data: web::Data<AppState>) -> impl Responder { async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router data.router
.route_to_first(&data.client, "/get_server_info") .route_to_first(&data.client, "/get_server_info", &req)
.await .await
} }
#[get("/v1/models")] #[get("/v1/models")]
async fn v1_models(data: web::Data<AppState>) -> impl Responder { async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.route_to_first(&data.client, "/v1/models").await data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
} }
#[get("/get_model_info")] #[get("/get_model_info")]
async fn get_model_info(data: web::Data<AppState>) -> impl Responder { async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router data.router
.route_to_first(&data.client, "/get_model_info") .route_to_first(&data.client, "/get_model_info", &req)
.await .await
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment