Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
4c56c8ae
Unverified
Commit
4c56c8ae
authored
Sep 23, 2025
by
Ayush Agarwal
Committed by
GitHub
Sep 23, 2025
Browse files
chore: added middleware layer to catch json validation errors (#3182)
Signed-off-by:
ayushag
<
ayushag@nvidia.com
>
parent
37bc8444
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
7 deletions
+34
-7
lib/llm/src/http/service/openai.rs
lib/llm/src/http/service/openai.rs
+33
-1
lib/llm/tests/http-service.rs
lib/llm/tests/http-service.rs
+1
-6
No files found.
lib/llm/src/http/service/openai.rs
View file @
4c56c8ae
...
@@ -9,8 +9,11 @@ use std::{
...
@@ -9,8 +9,11 @@ use std::{
use
axum
::{
use
axum
::{
Json
,
Router
,
Json
,
Router
,
body
::
Body
,
extract
::
State
,
extract
::
State
,
http
::
Request
,
http
::{
HeaderMap
,
StatusCode
},
http
::{
HeaderMap
,
StatusCode
},
middleware
::{
self
,
Next
},
response
::{
response
::{
IntoResponse
,
Response
,
IntoResponse
,
Response
,
sse
::{
KeepAlive
,
Sse
},
sse
::{
KeepAlive
,
Sse
},
...
@@ -65,7 +68,7 @@ fn get_body_limit() -> usize {
...
@@ -65,7 +68,7 @@ fn get_body_limit() -> usize {
pub
type
ErrorResponse
=
(
StatusCode
,
Json
<
ErrorMessage
>
);
pub
type
ErrorResponse
=
(
StatusCode
,
Json
<
ErrorMessage
>
);
#[derive(Serialize,
Deserialize)]
#[derive(Serialize,
Deserialize
,
Debug
)]
pub
(
crate
)
struct
ErrorMessage
{
pub
(
crate
)
struct
ErrorMessage
{
error
:
String
,
error
:
String
,
}
}
...
@@ -165,6 +168,31 @@ impl From<HttpError> for ErrorMessage {
...
@@ -165,6 +168,31 @@ impl From<HttpError> for ErrorMessage {
}
}
}
}
// Problem: We currently use axum's `Json` as the request validator, and it returns a 422 whenever the JSON is invalid.
// However, downstream apps that rely on OpenAI-based APIs expect a 400 in all of these cases; otherwise they fail badly.
// Solution: Intercept the response from handlers and convert ANY 422 status codes to 400 with the actual error message.
pub
async
fn
smart_json_error_middleware
(
request
:
Request
<
Body
>
,
next
:
Next
)
->
Response
{
let
response
=
next
.run
(
request
)
.await
;
if
response
.status
()
==
StatusCode
::
UNPROCESSABLE_ENTITY
{
let
(
_
parts
,
body
)
=
response
.into_parts
();
let
body_bytes
=
axum
::
body
::
to_bytes
(
body
,
usize
::
MAX
)
.await
.unwrap_or_default
();
let
error_message
=
String
::
from_utf8_lossy
(
&
body_bytes
)
.to_string
();
(
StatusCode
::
BAD_REQUEST
,
Json
(
ErrorMessage
{
error
:
error_message
,
}),
)
.into_response
()
}
else
{
// Pass through if it is not a 422
response
}
}
/// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present
/// Get the request ID from a primary source, or next from the headers, or lastly create a new one if not present
// TODO: Similar function exists in lib/llm/src/grpc/service/openai.rs but with different signature and simpler logic
// TODO: Similar function exists in lib/llm/src/grpc/service/openai.rs but with different signature and simpler logic
fn
get_or_create_request_id
(
primary
:
Option
<&
str
>
,
headers
:
&
HeaderMap
)
->
String
{
fn
get_or_create_request_id
(
primary
:
Option
<&
str
>
,
headers
:
&
HeaderMap
)
->
String
{
...
@@ -1054,6 +1082,7 @@ pub fn completions_router(
...
@@ -1054,6 +1082,7 @@ pub fn completions_router(
let
doc
=
RouteDoc
::
new
(
axum
::
http
::
Method
::
POST
,
&
path
);
let
doc
=
RouteDoc
::
new
(
axum
::
http
::
Method
::
POST
,
&
path
);
let
router
=
Router
::
new
()
let
router
=
Router
::
new
()
.route
(
&
path
,
post
(
handler_completions
))
.route
(
&
path
,
post
(
handler_completions
))
.layer
(
middleware
::
from_fn
(
smart_json_error_middleware
))
.layer
(
axum
::
extract
::
DefaultBodyLimit
::
max
(
get_body_limit
()))
.layer
(
axum
::
extract
::
DefaultBodyLimit
::
max
(
get_body_limit
()))
.with_state
(
state
);
.with_state
(
state
);
(
vec!
[
doc
],
router
)
(
vec!
[
doc
],
router
)
...
@@ -1070,6 +1099,7 @@ pub fn chat_completions_router(
...
@@ -1070,6 +1099,7 @@ pub fn chat_completions_router(
let
doc
=
RouteDoc
::
new
(
axum
::
http
::
Method
::
POST
,
&
path
);
let
doc
=
RouteDoc
::
new
(
axum
::
http
::
Method
::
POST
,
&
path
);
let
router
=
Router
::
new
()
let
router
=
Router
::
new
()
.route
(
&
path
,
post
(
handler_chat_completions
))
.route
(
&
path
,
post
(
handler_chat_completions
))
.layer
(
middleware
::
from_fn
(
smart_json_error_middleware
))
.layer
(
axum
::
extract
::
DefaultBodyLimit
::
max
(
get_body_limit
()))
.layer
(
axum
::
extract
::
DefaultBodyLimit
::
max
(
get_body_limit
()))
.with_state
((
state
,
template
));
.with_state
((
state
,
template
));
(
vec!
[
doc
],
router
)
(
vec!
[
doc
],
router
)
...
@@ -1085,6 +1115,7 @@ pub fn embeddings_router(
...
@@ -1085,6 +1115,7 @@ pub fn embeddings_router(
let
doc
=
RouteDoc
::
new
(
axum
::
http
::
Method
::
POST
,
&
path
);
let
doc
=
RouteDoc
::
new
(
axum
::
http
::
Method
::
POST
,
&
path
);
let
router
=
Router
::
new
()
let
router
=
Router
::
new
()
.route
(
&
path
,
post
(
embeddings
))
.route
(
&
path
,
post
(
embeddings
))
.layer
(
middleware
::
from_fn
(
smart_json_error_middleware
))
.layer
(
axum
::
extract
::
DefaultBodyLimit
::
max
(
get_body_limit
()))
.layer
(
axum
::
extract
::
DefaultBodyLimit
::
max
(
get_body_limit
()))
.with_state
(
state
);
.with_state
(
state
);
(
vec!
[
doc
],
router
)
(
vec!
[
doc
],
router
)
...
@@ -1117,6 +1148,7 @@ pub fn responses_router(
...
@@ -1117,6 +1148,7 @@ pub fn responses_router(
let
doc
=
RouteDoc
::
new
(
axum
::
http
::
Method
::
POST
,
&
path
);
let
doc
=
RouteDoc
::
new
(
axum
::
http
::
Method
::
POST
,
&
path
);
let
router
=
Router
::
new
()
let
router
=
Router
::
new
()
.route
(
&
path
,
post
(
handler_responses
))
.route
(
&
path
,
post
(
handler_responses
))
.layer
(
middleware
::
from_fn
(
smart_json_error_middleware
))
.with_state
((
state
,
template
));
.with_state
((
state
,
template
));
(
vec!
[
doc
],
router
)
(
vec!
[
doc
],
router
)
}
}
...
...
lib/llm/tests/http-service.rs
View file @
4c56c8ae
...
@@ -531,12 +531,7 @@ async fn test_http_service() {
...
@@ -531,12 +531,7 @@ async fn test_http_service() {
.await
.await
.unwrap
();
.unwrap
();
assert_eq!
(
assert_eq!
(
response
.status
(),
StatusCode
::
BAD_REQUEST
,
"{:?}"
,
response
);
response
.status
(),
StatusCode
::
UNPROCESSABLE_ENTITY
,
"{:?}"
,
response
);
// =========== Query /metrics endpoint ===========
// =========== Query /metrics endpoint ===========
let
response
=
client
let
response
=
client
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment