Unverified Commit cb3dc244 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: updated error schema (#3210)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 818c2642
...@@ -70,16 +70,30 @@ pub type ErrorResponse = (StatusCode, Json<ErrorMessage>); ...@@ -70,16 +70,30 @@ pub type ErrorResponse = (StatusCode, Json<ErrorMessage>);
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub(crate) struct ErrorMessage { pub(crate) struct ErrorMessage {
error: String, message: String,
#[serde(rename = "type")]
error_type: String,
code: u16,
}
fn map_error_code_to_error_type(code: StatusCode) -> String {
match code.canonical_reason() {
Some(reason) => reason.to_string(),
None => "UnknownError".to_string(),
}
} }
impl ErrorMessage { impl ErrorMessage {
/// Not Found Error /// Not Found Error
pub fn model_not_found() -> ErrorResponse { pub fn model_not_found() -> ErrorResponse {
let code = StatusCode::NOT_FOUND;
let error_type = map_error_code_to_error_type(code);
( (
StatusCode::NOT_FOUND, code,
Json(ErrorMessage { Json(ErrorMessage {
error: "Model not found".to_string(), message: "Model not found".to_string(),
error_type,
code: code.as_u16(),
}), }),
) )
} }
...@@ -87,10 +101,14 @@ impl ErrorMessage { ...@@ -87,10 +101,14 @@ impl ErrorMessage {
/// Service Unavailable /// Service Unavailable
/// This is returned when the service is live, but not ready. /// This is returned when the service is live, but not ready.
pub fn _service_unavailable() -> ErrorResponse { pub fn _service_unavailable() -> ErrorResponse {
let code = StatusCode::SERVICE_UNAVAILABLE;
let error_type = map_error_code_to_error_type(code);
( (
StatusCode::SERVICE_UNAVAILABLE, code,
Json(ErrorMessage { Json(ErrorMessage {
error: "Service is not ready".to_string(), message: "Service is not ready".to_string(),
error_type,
code: code.as_u16(),
}), }),
) )
} }
...@@ -101,10 +119,14 @@ impl ErrorMessage { ...@@ -101,10 +119,14 @@ impl ErrorMessage {
/// Internal Services errors are the result of misconfiguration or bugs in the service. /// Internal Services errors are the result of misconfiguration or bugs in the service.
pub fn internal_server_error(msg: &str) -> ErrorResponse { pub fn internal_server_error(msg: &str) -> ErrorResponse {
tracing::error!("Internal server error: {msg}"); tracing::error!("Internal server error: {msg}");
let code = StatusCode::INTERNAL_SERVER_ERROR;
let error_type = map_error_code_to_error_type(code);
( (
StatusCode::INTERNAL_SERVER_ERROR, code,
Json(ErrorMessage { Json(ErrorMessage {
error: msg.to_string(), message: msg.to_string(),
error_type,
code: code.as_u16(),
}), }),
) )
} }
...@@ -114,10 +136,14 @@ impl ErrorMessage { ...@@ -114,10 +136,14 @@ impl ErrorMessage {
/// This should be used for features that are planned but not available. /// This should be used for features that are planned but not available.
pub fn not_implemented_error(msg: &str) -> ErrorResponse { pub fn not_implemented_error(msg: &str) -> ErrorResponse {
tracing::error!("Not Implemented error: {msg}"); tracing::error!("Not Implemented error: {msg}");
let code = StatusCode::NOT_IMPLEMENTED;
let error_type = map_error_code_to_error_type(code);
( (
StatusCode::NOT_IMPLEMENTED, code,
Json(ErrorMessage { Json(ErrorMessage {
error: msg.to_string(), message: msg.to_string(),
error_type,
code: code.as_u16(),
}), }),
) )
} }
...@@ -138,7 +164,9 @@ impl ErrorMessage { ...@@ -138,7 +164,9 @@ impl ErrorMessage {
return ( return (
StatusCode::SERVICE_UNAVAILABLE, StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorMessage { Json(ErrorMessage {
error: pipeline_err.to_string(), message: pipeline_err.to_string(),
error_type: map_error_code_to_error_type(StatusCode::SERVICE_UNAVAILABLE),
code: StatusCode::SERVICE_UNAVAILABLE.as_u16(),
}), }),
); );
} }
...@@ -156,7 +184,14 @@ impl ErrorMessage { ...@@ -156,7 +184,14 @@ impl ErrorMessage {
return ErrorMessage::internal_server_error(&err.message); return ErrorMessage::internal_server_error(&err.message);
} }
match StatusCode::from_u16(err.code) { match StatusCode::from_u16(err.code) {
Ok(code) => (code, Json(ErrorMessage { error: err.message })), Ok(code) => (
code,
Json(ErrorMessage {
message: err.message,
error_type: map_error_code_to_error_type(code),
code: code.as_u16(),
}),
),
Err(_) => ErrorMessage::internal_server_error(&err.message), Err(_) => ErrorMessage::internal_server_error(&err.message),
} }
} }
...@@ -164,7 +199,13 @@ impl ErrorMessage { ...@@ -164,7 +199,13 @@ impl ErrorMessage {
impl From<HttpError> for ErrorMessage { impl From<HttpError> for ErrorMessage {
fn from(err: HttpError) -> Self { fn from(err: HttpError) -> Self {
ErrorMessage { error: err.message } ErrorMessage {
message: err.message,
error_type: map_error_code_to_error_type(
StatusCode::from_u16(err.code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
),
code: err.code,
}
} }
} }
...@@ -183,7 +224,9 @@ pub async fn smart_json_error_middleware(request: Request<Body>, next: Next) -> ...@@ -183,7 +224,9 @@ pub async fn smart_json_error_middleware(request: Request<Body>, next: Next) ->
( (
StatusCode::BAD_REQUEST, StatusCode::BAD_REQUEST,
Json(ErrorMessage { Json(ErrorMessage {
error: error_message, message: error_message,
error_type: map_error_code_to_error_type(StatusCode::BAD_REQUEST),
code: StatusCode::BAD_REQUEST.as_u16(),
}), }),
) )
.into_response() .into_response()
...@@ -1221,36 +1264,36 @@ mod tests { ...@@ -1221,36 +1264,36 @@ mod tests {
#[test] #[test]
fn test_http_error_response_from_anyhow() { fn test_http_error_response_from_anyhow() {
let err = http_error_from_engine(400).unwrap_err(); let err = http_error_from_engine(400).unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE); let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(response.0, StatusCode::BAD_REQUEST);
assert_eq!(response.error, "custom error message"); assert_eq!(response.1.message, "custom error message");
} }
#[test] #[test]
fn test_error_response_from_anyhow_out_of_range() { fn test_error_response_from_anyhow_out_of_range() {
let err = http_error_from_engine(399).unwrap_err(); let err = http_error_from_engine(399).unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE); let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message"); assert_eq!(response.1.message, "custom error message");
let err = http_error_from_engine(500).unwrap_err(); let err = http_error_from_engine(500).unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE); let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message"); assert_eq!(response.1.message, "custom error message");
let err = http_error_from_engine(501).unwrap_err(); let err = http_error_from_engine(501).unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE); let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message"); assert_eq!(response.1.message, "custom error message");
} }
#[test] #[test]
fn test_other_error_response_from_anyhow() { fn test_other_error_response_from_anyhow() {
let err = other_error_from_engine().unwrap_err(); let err = other_error_from_engine().unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE); let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!( assert_eq!(
response.error, response.1.message,
format!( format!(
"{}: {}", "{}: {}",
BACKUP_ERROR_MESSAGE, BACKUP_ERROR_MESSAGE,
...@@ -1267,10 +1310,10 @@ mod tests { ...@@ -1267,10 +1310,10 @@ mod tests {
"All workers are busy, please retry later".to_string(), "All workers are busy, please retry later".to_string(),
) )
.into(); .into();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE); let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); assert_eq!(response.0, StatusCode::SERVICE_UNAVAILABLE);
assert_eq!( assert_eq!(
response.error, response.1.message,
"Service temporarily unavailable: All workers are busy, please retry later" "Service temporarily unavailable: All workers are busy, please retry later"
); );
} }
...@@ -1386,10 +1429,10 @@ mod tests { ...@@ -1386,10 +1429,10 @@ mod tests {
}; };
let result = validate_chat_completion_required_fields(&request); let result = validate_chat_completion_required_fields(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"The 'messages' field cannot be empty. At least one message is required." "The 'messages' field cannot be empty. At least one message is required."
); );
} }
...@@ -1450,10 +1493,10 @@ mod tests { ...@@ -1450,10 +1493,10 @@ mod tests {
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Frequency penalty must be between -2 and 2, got -3" "Frequency penalty must be between -2 and 2, got -3"
); );
} }
...@@ -1471,10 +1514,10 @@ mod tests { ...@@ -1471,10 +1514,10 @@ mod tests {
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Presence penalty must be between -2 and 2, got -3" "Presence penalty must be between -2 and 2, got -3"
); );
} }
...@@ -1492,10 +1535,10 @@ mod tests { ...@@ -1492,10 +1535,10 @@ mod tests {
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Temperature must be between 0 and 2, got -3" "Temperature must be between 0 and 2, got -3"
); );
} }
...@@ -1513,10 +1556,10 @@ mod tests { ...@@ -1513,10 +1556,10 @@ mod tests {
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Top_p must be between 0 and 1, got -3" "Top_p must be between 0 and 1, got -3"
); );
} }
...@@ -1536,10 +1579,10 @@ mod tests { ...@@ -1536,10 +1579,10 @@ mod tests {
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Repetition penalty must be between 0 and 2, got -3" "Repetition penalty must be between 0 and 2, got -3"
); );
} }
...@@ -1557,10 +1600,10 @@ mod tests { ...@@ -1557,10 +1600,10 @@ mod tests {
}; };
let result = validate_completion_fields_generic(&request); let result = validate_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Logprobs must be between 0 and 5, got 6" "Logprobs must be between 0 and 5, got 6"
); );
} }
...@@ -1588,10 +1631,10 @@ mod tests { ...@@ -1588,10 +1631,10 @@ mod tests {
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Frequency penalty must be between -2 and 2, got -3" "Frequency penalty must be between -2 and 2, got -3"
); );
} }
...@@ -1615,10 +1658,10 @@ mod tests { ...@@ -1615,10 +1658,10 @@ mod tests {
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Presence penalty must be between -2 and 2, got -3" "Presence penalty must be between -2 and 2, got -3"
); );
} }
...@@ -1642,10 +1685,10 @@ mod tests { ...@@ -1642,10 +1685,10 @@ mod tests {
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Temperature must be between 0 and 2, got -3" "Temperature must be between 0 and 2, got -3"
); );
} }
...@@ -1669,10 +1712,10 @@ mod tests { ...@@ -1669,10 +1712,10 @@ mod tests {
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Top_p must be between 0 and 1, got -3" "Top_p must be between 0 and 1, got -3"
); );
} }
...@@ -1698,10 +1741,10 @@ mod tests { ...@@ -1698,10 +1741,10 @@ mod tests {
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Repetition penalty must be between 0 and 2, got -3" "Repetition penalty must be between 0 and 2, got -3"
); );
} }
...@@ -1725,10 +1768,10 @@ mod tests { ...@@ -1725,10 +1768,10 @@ mod tests {
}; };
let result = validate_chat_completion_fields_generic(&request); let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err()); assert!(result.is_err());
if let Err((status, error_response)) = result { if let Err(error_response) = result {
assert_eq!(status, StatusCode::BAD_REQUEST); assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!( assert_eq!(
error_response.error, error_response.1.message,
"Top_logprobs must be between 0 and 20, got 25" "Top_logprobs must be between 0 and 20, got 25"
); );
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment