Unverified Commit 4c56c8ae authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

chore: added middleware layer to catch json validation errors (#3182)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 37bc8444
...@@ -9,8 +9,11 @@ use std::{ ...@@ -9,8 +9,11 @@ use std::{
use axum::{ use axum::{
Json, Router, Json, Router,
body::Body,
extract::State, extract::State,
http::Request,
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
middleware::{self, Next},
response::{ response::{
IntoResponse, Response, IntoResponse, Response,
sse::{KeepAlive, Sse}, sse::{KeepAlive, Sse},
...@@ -65,7 +68,7 @@ fn get_body_limit() -> usize { ...@@ -65,7 +68,7 @@ fn get_body_limit() -> usize {
pub type ErrorResponse = (StatusCode, Json<ErrorMessage>); pub type ErrorResponse = (StatusCode, Json<ErrorMessage>);
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize, Debug)]
pub(crate) struct ErrorMessage { pub(crate) struct ErrorMessage {
error: String, error: String,
} }
...@@ -165,6 +168,31 @@ impl From<HttpError> for ErrorMessage { ...@@ -165,6 +168,31 @@ impl From<HttpError> for ErrorMessage {
} }
} }
// Problem: Currently we are using JSON from axum as the request validator. Whenever there is an invalid JSON, it will return a 422.
// But all the downstream apps that relies on openai based APIs, expects to get 400 for all these cases otherwise they fail badly
// Solution: Intercept the response from handlers and convert ANY 422 status codes to 400 with the actual error message.
pub async fn smart_json_error_middleware(request: Request<Body>, next: Next) -> Response {
let response = next.run(request).await;
if response.status() == StatusCode::UNPROCESSABLE_ENTITY {
let (_parts, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX)
.await
.unwrap_or_default();
let error_message = String::from_utf8_lossy(&body_bytes).to_string();
(
StatusCode::BAD_REQUEST,
Json(ErrorMessage {
error: error_message,
}),
)
.into_response()
} else {
// Pass through if it is not a 422
response
}
}
/// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present /// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present
// TODO: Similar function exists in lib/llm/src/grpc/service/openai.rs but with different signature and simpler logic // TODO: Similar function exists in lib/llm/src/grpc/service/openai.rs but with different signature and simpler logic
fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String { fn get_or_create_request_id(primary: Option<&str>, headers: &HeaderMap) -> String {
...@@ -1054,6 +1082,7 @@ pub fn completions_router( ...@@ -1054,6 +1082,7 @@ pub fn completions_router(
let doc = RouteDoc::new(axum::http::Method::POST, &path); let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new() let router = Router::new()
.route(&path, post(handler_completions)) .route(&path, post(handler_completions))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit())) .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state); .with_state(state);
(vec![doc], router) (vec![doc], router)
...@@ -1070,6 +1099,7 @@ pub fn chat_completions_router( ...@@ -1070,6 +1099,7 @@ pub fn chat_completions_router(
let doc = RouteDoc::new(axum::http::Method::POST, &path); let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new() let router = Router::new()
.route(&path, post(handler_chat_completions)) .route(&path, post(handler_chat_completions))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit())) .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state((state, template)); .with_state((state, template));
(vec![doc], router) (vec![doc], router)
...@@ -1085,6 +1115,7 @@ pub fn embeddings_router( ...@@ -1085,6 +1115,7 @@ pub fn embeddings_router(
let doc = RouteDoc::new(axum::http::Method::POST, &path); let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new() let router = Router::new()
.route(&path, post(embeddings)) .route(&path, post(embeddings))
.layer(middleware::from_fn(smart_json_error_middleware))
.layer(axum::extract::DefaultBodyLimit::max(get_body_limit())) .layer(axum::extract::DefaultBodyLimit::max(get_body_limit()))
.with_state(state); .with_state(state);
(vec![doc], router) (vec![doc], router)
...@@ -1117,6 +1148,7 @@ pub fn responses_router( ...@@ -1117,6 +1148,7 @@ pub fn responses_router(
let doc = RouteDoc::new(axum::http::Method::POST, &path); let doc = RouteDoc::new(axum::http::Method::POST, &path);
let router = Router::new() let router = Router::new()
.route(&path, post(handler_responses)) .route(&path, post(handler_responses))
.layer(middleware::from_fn(smart_json_error_middleware))
.with_state((state, template)); .with_state((state, template));
(vec![doc], router) (vec![doc], router)
} }
......
...@@ -531,12 +531,7 @@ async fn test_http_service() { ...@@ -531,12 +531,7 @@ async fn test_http_service() {
.await .await
.unwrap(); .unwrap();
assert_eq!( assert_eq!(response.status(), StatusCode::BAD_REQUEST, "{:?}", response);
response.status(),
StatusCode::UNPROCESSABLE_ENTITY,
"{:?}",
response
);
// =========== Query /metrics endpoint =========== // =========== Query /metrics endpoint ===========
let response = client let response = client
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment