Unverified Commit 4c56c8ae authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: added middleware layer to catch json validation errors (#3182)


Signed-off-by: ayushag <ayushag@nvidia.com>
parent 37bc8444
......@@ -9,8 +9,11 @@ use std::{
use axum::{
Json, Router,
body::Body,
extract::State,
http::Request,
http::{HeaderMap, StatusCode},
middleware::{self, Next},
response::{
IntoResponse, Response,
sse::{KeepAlive, Sse},
......@@ -65,7 +68,7 @@ fn get_body_limit() -> usize {
pub type ErrorResponse = (StatusCode, Json<ErrorMessage>);
#[derive(Serialize, Deserialize)]
/// Error payload with a single `error` text field.
///
/// It is wrapped in `Json` and paired with a `StatusCode` by the
/// `ErrorResponse` type alias.
#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct ErrorMessage {
// Human-readable description of what went wrong.
error: String,
}
......@@ -165,6 +168,31 @@ impl From<HttpError> for ErrorMessage {
}
}
// Problem: Currently we are using JSON from axum as the request validator. Whenever there is an invalid JSON, it will return a 422.
// But all the downstream apps that relies on openai based APIs, expects to get 400 for all these cases otherwise they fail badly
// Solution: Intercept the response from handlers and convert ANY 422 status codes to 400 with the actual error message.
pub async fn smart_json_error_middleware(request: Request<Body>, next: Next) -> Response {
let response = next.run(request).await;
if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
let (_parts, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();
let error_message = String::from_utf8_lossy(&body_bytes).to_string();
(
StatusCode::BAD_REQUEST,
Json(ErrorMessage {
error: error_message,
}),
)
.into_response()
} else {
// Pass through if it is not a 422
response
}
}
/// Get the request ID from the primary source if one is given, otherwise from the headers, and as a last resort create a new one
// TODO: Similar function exists in lib/llm/src/grpc/service/openai.rs but with different signature and simpler logic
fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
......@@ -1054,6 +1082,7 @@ pub fn completions_router(
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(handler_completions))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
......@@ -1070,6 +1099,7 @@ pub fn chat_completions_router(
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(handler_chat_completions))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state((state, template));
(vec![doc], router)
......@@ -1085,6 +1115,7 @@ pub fn embeddings_router(
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(embeddings))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state);
(vec![doc], router)
......@@ -1117,6 +1148,7 @@ pub fn responses_router(
let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new()
.route(&path, post(handler_responses))
.layer(middleware::from_fn(smart_json_error_middleware))
.with_state((state, template));
(vec![doc], router)
}
......
......@@ -531,12 +531,7 @@ async fn test_http_service() {
.await
.unwrap();
assert_eq!(
response.status(),
StatusCode::UNPROCESSABLE_ENTITY,
"{:?}",
response
);
assert_eq!(response.status(), StatusCode::BAD_REQUEST, "{:?}", response);
// =========== Query /metrics endpoint ===========
let response = client
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment