Unverified Commit cb3dc244 authored by Ayush Agarwal's avatar Ayush Agarwal Committed by GitHub
Browse files

feat: updated error schema (#3210)


Signed-off-by: default avatarayushag <ayushag@nvidia.com>
parent 818c2642
......@@ -70,16 +70,30 @@ pub type ErrorResponse = (StatusCode, Json<ErrorMessage>);
#[derive(Serialize, Deserialize, Debug)]
pub(crate) struct ErrorMessage {
error: String,
message: String,
#[serde(rename = "type")]
error_type: String,
code: u16,
}
fn map_error_code_to_error_type(code: StatusCode) -> String {
match code.canonical_reason() {
Some(reason) => reason.to_string(),
None => "UnknownError".to_string(),
}
}
impl ErrorMessage {
/// Not Found Error
pub fn model_not_found() -> ErrorResponse {
let code = StatusCode::NOT_FOUND;
let error_type = map_error_code_to_error_type(code);
(
StatusCode::NOT_FOUND,
code,
Json(ErrorMessage {
error: "Model not found".to_string(),
message: "Model not found".to_string(),
error_type,
code: code.as_u16(),
}),
)
}
......@@ -87,10 +101,14 @@ impl ErrorMessage {
/// Service Unavailable
/// This is returned when the service is live, but not ready.
pub fn _service_unavailable() -> ErrorResponse {
let code = StatusCode::SERVICE_UNAVAILABLE;
let error_type = map_error_code_to_error_type(code);
(
StatusCode::SERVICE_UNAVAILABLE,
code,
Json(ErrorMessage {
error: "Service is not ready".to_string(),
message: "Service is not ready".to_string(),
error_type,
code: code.as_u16(),
}),
)
}
......@@ -101,10 +119,14 @@ impl ErrorMessage {
/// Internal Services errors are the result of misconfiguration or bugs in the service.
pub fn internal_server_error(msg: &str) -> ErrorResponse {
tracing::error!("Internal server error: {msg}");
let code = StatusCode::INTERNAL_SERVER_ERROR;
let error_type = map_error_code_to_error_type(code);
(
StatusCode::INTERNAL_SERVER_ERROR,
code,
Json(ErrorMessage {
error: msg.to_string(),
message: msg.to_string(),
error_type,
code: code.as_u16(),
}),
)
}
......@@ -114,10 +136,14 @@ impl ErrorMessage {
/// This should be used for features that are planned but not available.
pub fn not_implemented_error(msg: &str) -> ErrorResponse {
tracing::error!("Not Implemented error: {msg}");
let code = StatusCode::NOT_IMPLEMENTED;
let error_type = map_error_code_to_error_type(code);
(
StatusCode::NOT_IMPLEMENTED,
code,
Json(ErrorMessage {
error: msg.to_string(),
message: msg.to_string(),
error_type,
code: code.as_u16(),
}),
)
}
......@@ -138,7 +164,9 @@ impl ErrorMessage {
return (
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorMessage {
error: pipeline_err.to_string(),
message: pipeline_err.to_string(),
error_type: map_error_code_to_error_type(StatusCode::SERVICE_UNAVAILABLE),
code: StatusCode::SERVICE_UNAVAILABLE.as_u16(),
}),
);
}
......@@ -156,7 +184,14 @@ impl ErrorMessage {
return ErrorMessage::internal_server_error(&err.message);
}
match StatusCode::from_u16(err.code) {
Ok(code) => (code, Json(ErrorMessage { error: err.message })),
Ok(code) => (
code,
Json(ErrorMessage {
message: err.message,
error_type: map_error_code_to_error_type(code),
code: code.as_u16(),
}),
),
Err(_) => ErrorMessage::internal_server_error(&err.message),
}
}
......@@ -164,7 +199,13 @@ impl ErrorMessage {
impl From<HttpError> for ErrorMessage {
fn from(err: HttpError) -> Self {
ErrorMessage { error: err.message }
ErrorMessage {
message: err.message,
error_type: map_error_code_to_error_type(
StatusCode::from_u16(err.code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
),
code: err.code,
}
}
}
......@@ -183,7 +224,9 @@ pub async fn smart_json_error_middleware(request: Request<Body>, next: Next) ->
(
StatusCode::BAD_REQUEST,
Json(ErrorMessage {
error: error_message,
message: error_message,
error_type: map_error_code_to_error_type(StatusCode::BAD_REQUEST),
code: StatusCode::BAD_REQUEST.as_u16(),
}),
)
.into_response()
......@@ -1221,36 +1264,36 @@ mod tests {
#[test]
fn test_http_error_response_from_anyhow() {
let err = http_error_from_engine(400).unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::BAD_REQUEST);
assert_eq!(response.error, "custom error message");
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::BAD_REQUEST);
assert_eq!(response.1.message, "custom error message");
}
#[test]
fn test_error_response_from_anyhow_out_of_range() {
let err = http_error_from_engine(399).unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message");
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.1.message, "custom error message");
let err = http_error_from_engine(500).unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message");
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.1.message, "custom error message");
let err = http_error_from_engine(501).unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.error, "custom error message");
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(response.1.message, "custom error message");
}
#[test]
fn test_other_error_response_from_anyhow() {
let err = other_error_from_engine().unwrap_err();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
response.error,
response.1.message,
format!(
"{}: {}",
BACKUP_ERROR_MESSAGE,
......@@ -1267,10 +1310,10 @@ mod tests {
"All workers are busy, please retry later".to_string(),
)
.into();
let (status, response) = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
let response = ErrorMessage::from_anyhow(err, BACKUP_ERROR_MESSAGE);
assert_eq!(response.0, StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(
response.error,
response.1.message,
"Service temporarily unavailable: All workers are busy, please retry later"
);
}
......@@ -1386,10 +1429,10 @@ mod tests {
};
let result = validate_chat_completion_required_fields(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"The 'messages' field cannot be empty. At least one message is required."
);
}
......@@ -1450,10 +1493,10 @@ mod tests {
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Frequency penalty must be between -2 and 2, got -3"
);
}
......@@ -1471,10 +1514,10 @@ mod tests {
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Presence penalty must be between -2 and 2, got -3"
);
}
......@@ -1492,10 +1535,10 @@ mod tests {
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Temperature must be between 0 and 2, got -3"
);
}
......@@ -1513,10 +1556,10 @@ mod tests {
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Top_p must be between 0 and 1, got -3"
);
}
......@@ -1536,10 +1579,10 @@ mod tests {
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Repetition penalty must be between 0 and 2, got -3"
);
}
......@@ -1557,10 +1600,10 @@ mod tests {
};
let result = validate_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Logprobs must be between 0 and 5, got 6"
);
}
......@@ -1588,10 +1631,10 @@ mod tests {
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Frequency penalty must be between -2 and 2, got -3"
);
}
......@@ -1615,10 +1658,10 @@ mod tests {
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Presence penalty must be between -2 and 2, got -3"
);
}
......@@ -1642,10 +1685,10 @@ mod tests {
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Temperature must be between 0 and 2, got -3"
);
}
......@@ -1669,10 +1712,10 @@ mod tests {
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Top_p must be between 0 and 1, got -3"
);
}
......@@ -1698,10 +1741,10 @@ mod tests {
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Repetition penalty must be between 0 and 2, got -3"
);
}
......@@ -1725,10 +1768,10 @@ mod tests {
};
let result = validate_chat_completion_fields_generic(&request);
assert!(result.is_err());
if let Err((status, error_response)) = result {
assert_eq!(status, StatusCode::BAD_REQUEST);
if let Err(error_response) = result {
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
assert_eq!(
error_response.error,
error_response.1.message,
"Top_logprobs must be between 0 and 20, got 25"
);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment